设计模式之【策略模式】,去掉繁琐的if-else,实现算法的动态替换

文章目录

  • 一、什么是策略模式
    • 1、策略模式应用场景
    • 2、状态模式与策略模式的区别
    • 3、策略模式优缺点
    • 4、策略模式的三大角色
  • 二、实例
    • 1、策略模式的一般写法
    • 2、促销活动案例
    • 3、网购订单支付案例
    • 4、DispatcherServlet的优化
    • 5、文件排序案例
    • 6、Spring中策略模式简单应用
  • 三、源码中的策略模式
    • 1、Comparator接口
    • 2、Spring的InstantiationStrategy
    • 3、Spring的Resource

全网最全最细的【设计模式】总目录,收藏起来慢慢啃,看完不懂砍我

一、什么是策略模式

策略模式(Strategy Pattern)又叫政策模式(Policy Pattern),它是将定义的算法家族分别封装起来,让它们之间可以互相替换,从而让算法的变化不会影响到使用算法的用户。属于行为型模式。

策略模式使用的就是面向对象的继承和多态机制,从而实现同一行为在不同场景下具备不同实现。

1、策略模式应用场景

策略模式在生活中应用也非常多。比如一个人的交税比率与他的工资有关,不同工资对应不同的税率。再比如互联网移动支付,每次下单后付款都需要选择支付方式。

策略模式可以解决在有多种算法相似的情况下,使用if-else或者switch-case所带来的复杂性和臃肿性,策略模式通常适用于以下场景:

  • 针对同一类型问题,有多种处理方式,每一种都能独立解决问题;
  • 一个类定义了多种行为,并且这些行为在这个类的操作中以多个条件语句的形式出现,可将每个条件分支移入它们各自的策略类中以代替这些条件语句;
  • 多个类只区别在表现行为不同,可以使用策略模式,在运行时动态选择具体要执行的行为;
  • 算法需要自由切换的场景;
  • 需要屏蔽算法规则的场景。

2、状态模式与策略模式的区别

状态模式和策略模式的UML类图架构几乎完全一样,但他们的应用场景是不一样的。策略模式多种算法行为择其一都能满足,彼此之间是独立的,用户可自行更换策略算法;而状态模式各个状态间是存在相互关系的,彼此之间在一定条件下存在自动切换状态效果,且用户无法指定状态,只能设置初始状态。

3、策略模式优缺点

优点:

  • 策略类之间可以自由切换:由于策略类都实现同一个接口,所以使它们之间可以自由切换。
  • 易于扩展:增加一个新的策略只需要添加一个具体的策略类即可,基本不需要改变原有的代码,符合“开闭原则“
  • 避免使用多重条件选择语句(if else),充分体现面向对象设计思想。

缺点:

  • 客户端必须知道所有的策略类,并自行决定使用哪一个策略类。
  • 策略模式将造成产生很多策略类,可以通过使用享元模式在一定程度上减少对象的数量。

4、策略模式的三大角色


策略模式的主要角色如下:

  • 抽象策略(Strategy)类:这是一个抽象角色,通常由一个接口或抽象类实现。此角色给出所有的具体策略类所需的接口。
  • 具体策略(Concrete Strategy)类:实现了抽象策略定义的接口,提供具体的算法实现或行为。
  • 环境(Context)类:用来操作策略的上下文环境,屏蔽高层模块(客户端)对策略、算法的直接访问,封装可能存在的变化。

注:策略模式中的上下文环境(Context),其职责本来是隔离客户端与策略类的耦合,让客户端完全与上下文环境沟通,无需关心具体策略。

二、实例

1、策略模式的一般写法

//抽象策略类 Strategy
public interface IStrategy {
    void algorithm();
}
//具体策略类 ConcreteStrategy
public class ConcreteStrategyA implements IStrategy {
    public void algorithm() {
        System.out.println("Strategy A");
    }
}
//具体策略类 ConcreteStrategy
public class ConcreteStrategyB implements IStrategy {
    public void algorithm() {
        System.out.println("Strategy B");
    }
}
//上下文环境
public class Context {
    private IStrategy mStrategy;

    public Context(IStrategy strategy) {
        this.mStrategy = strategy;
    }

    public void algorithm() {
        this.mStrategy.algorithm();
    }
}
public class Test {
    public static void main(String[] args) {
        //选择一个具体策略
        IStrategy strategy = new ConcreteStrategyA();
        //来一个上下文环境
        Context context = new Context(strategy);
        //客户端直接让上下文环境执行算法
        context.algorithm();
    }
}

2、促销活动案例

一家百货公司在定年度的促销活动。针对不同的节日(春节、中秋节、圣诞节)推出不同的促销活动,由促销员将促销活动展示给客户。类图如下:

// 定义百货公司所有促销活动的共同接口
public interface Strategy {
	void show();
}
// 定义具体策略角色(Concrete Strategy):每个节日具体的促销活动
//为春节准备的促销活动A
public class StrategyA implements Strategy {
	public void show() {
		System.out.println("买一送一");
	}
}
//为中秋准备的促销活动B
public class StrategyB implements Strategy {
	public void show() {
		System.out.println("满200元减50元");
	}
}
//为圣诞准备的促销活动C
public class StrategyC implements Strategy {
	public void show() {
		System.out.println("满1000元加一元换购任意200元以下商品");
	}
}
// 定义环境角色(Context):用于连接上下文,即把促销活动推销给客户,这里可以理解为销售员
public class SalesMan {
	//持有抽象策略角色的引用
	private Strategy strategy;
	public SalesMan(Strategy strategy) {
		this.strategy = strategy;
	}
	//向客户展示促销活动
	public void salesManShow(){
		strategy.show();
	}
}
// 测试类
public class Client {
    public static void main(String[] args) {
        //春节来了,使用春节促销活动
        SalesMan salesMan = new SalesMan(new StrategyA());
        //展示促销活动
        salesMan.salesManShow();

        System.out.println("==============");
        //中秋节到了,使用中秋节的促销活动
        salesMan.setStrategy(new StrategyB());
        //展示促销活动
        salesMan.salesManShow();

        System.out.println("==============");
        //圣诞节到了,使用圣诞节的促销活动
        salesMan.setStrategy(new StrategyC());
        //展示促销活动
        salesMan.salesManShow();
    }
}

此时,我们发现,上面的测试代码放到实际业务场景其实并不实用,因为我们做活动时往往是要根据不同的需求对促销策略进行动态选择的,并不会一次性执行多种优惠,所以我们代码通常会这样写:

public class Client {
    public static void main(String[] args) {
        SalesMan salesMan = null;
        String saleKey = "A";
        
        if(saleKey.equals("A")){
            //春节来了,使用春节促销活动
            salesMan = new SalesMan(new StrategyA());
        } else if (saleKey.equals("B")) {
            //中秋节到了,使用中秋节的促销活动
            salesMan = new SalesMan(new StrategyB());
        } // ...

        //展示促销活动
        salesMan.salesManShow();
    }
}

这样改造之后,满足了业务需求,客户可以根据自己的需求选择不同的优惠策略了。但是这里的if-else随着促销活动的增多会越来越复杂,我们可以使用单例模式和工厂模式进行优化:

public class SalesMan {

    public static final String SaleKeyA = "A";
    public static final String SaleKeyB = "B";
    public static final String SaleKeyC = "C";

    private static Map<String, Strategy> sales = new HashMap<String, Strategy>();

    static {
        sales.put(SaleKeyA, new StrategyA());
        sales.put(SaleKeyB, new StrategyB());
        sales.put(SaleKeyC, new StrategyC());
    }

    public Strategy getStrategy(String key) {
        Strategy strategy = sales.get(key);
        if(strategy == null){
            throw new RuntimeException("策略有误");
        }
        return strategy;
    }
}

public class Client {
    public static void main(String[] args) {
        SalesMan salesMan = new SalesMan();
        String saleKey = "A";

        Strategy strategy = salesMan.getStrategy(saleKey);

        //展示促销活动
        strategy.show();
    }
}

3、网购订单支付案例

我们在网购下单时,会提示选择支付方式,通常会有支付宝、微信、银联等等支付方式,如果没选择,系统也会使用默认的支付方式,我们使用策略模式来模拟此场景:

// 支付状态包装类
public class MsgResult {
    private int code;
    private Object data;
    private String msg;

    public MsgResult(int code, String msg, Object data) {
        this.code = code;
        this.data = data;
        this.msg = msg;
    }

    @Override
    public String toString() {
        return "MsgResult{" +
                "code=" + code +
                ", data=" + data +
                ", msg='" + msg + '\'' +
                '}';
    }
}

// 定义支付逻辑,具体支付交由子类实现
public abstract class Payment {

    public abstract String getName();

    //通用逻辑放到抽象类里面实现
    public MsgResult pay(String uid, double amount){
        //余额是否足够
        if(queryBalance(uid) < amount){
            return new MsgResult(500,"支付失败","余额不足");
        }
        return new MsgResult(200,"支付成功","支付金额" + amount);
    }

    protected abstract double queryBalance(String uid);
}

// 定义具体支付方式
public class AliPay extends Payment {
    public String getName() {
        return "支付宝";
    }

    protected double queryBalance(String uid) {
        return 900;
    }
}
public class JDPay extends Payment {
    public String getName() {
        return "京东白条";
    }

    protected double queryBalance(String uid) {
        return 500;
    }
}
public class UnionPay extends Payment {
    public String getName() {
        return "银联支付";
    }

    protected double queryBalance(String uid) {
        return 120;
    }
}
public class WechatPay extends Payment {
    public String getName() {
        return "微信支付";
    }

    protected double queryBalance(String uid) {
        return 263;
    }
}

// 策略管理类
public class PayStrategy {
    public static  final String ALI_PAY = "AliPay";
    public static  final String JD_PAY = "JdPay";
    public static  final String WECHAT_PAY = "WechatPay";
    public static  final String UNION_PAY = "UnionPay";
    public static  final String DEFAULT_PAY = ALI_PAY;

    private static Map<String,Payment> strategy = new HashMap<String,Payment>();

    static {
        strategy.put(ALI_PAY,new AliPay());
        strategy.put(JD_PAY,new JDPay());
        strategy.put(WECHAT_PAY,new WechatPay());
        strategy.put(UNION_PAY,new UnionPay());
    }

    public static Payment get(String payKey){
        if(!strategy.containsKey(payKey)){
            return strategy.get(DEFAULT_PAY);
        }
        return strategy.get(payKey);
    }
}

// 订单类
public class Order {
    private String uid;
    private String orderId;
    private double amount;

    public Order(String uid, String orderId, double amount) {
        this.uid = uid;
        this.orderId = orderId;
        this.amount = amount;
    }

    public MsgResult pay(){
        return pay(PayStrategy.DEFAULT_PAY);
    }

    public MsgResult pay(String payKey){
        Payment payment = PayStrategy.get(payKey);
        System.out.println("欢迎使用" + payment.getName());
        System.out.println("本次交易金额为" + amount + ",开始扣款");
        return payment.pay(uid,amount);
    }
}

// 测试类
public class Test {
    public static void main(String[] args) {
        Order order = new Order("1","orderid",324.5);
        System.out.println(order.pay(PayStrategy.UNION_PAY));
    }
}

4、DispatcherServlet的优化

我们都知道SpringMVC的请求都是通过DispatcherServlet的doDispatch方法进行分发的,如果让我们设计,可能会这样实现:

public class DispatcherServlet extends HttpServlet {

	private void doDispatch(HttpServletRequest request, HttpServletResponse response) throws Exception {

		String uri = request.getRequestURI();
		String mid = request.getParameter("mid");
		if("getMemberById".equals(uri)){
			new MemberController().getMemberById(mid);
		} else if("getOrder".equals(uri)) {
			new OrderController().getOrder();
		}// ...
	}

}

上面的代码扩展性确实不太优雅,我们可以使用策略模式进行优化:

public class DispatcherServlet extends HttpServlet {

    private List<Handler> handlerMapping = new ArrayList<Handler>();

    @Override
    protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        doDispatch(req,resp);
    }

    private void doDispatch(HttpServletRequest req, HttpServletResponse resp) {
        String uri = req.getRequestURI();
        // 使用uri匹配handler
        Handler handler = null;
        for (Handler h : handlerMapping) {
            if(uri.equals(h.getUrl())){
                handler = h;
                break;
            }
        }
        // 将具体任务分发给Method
        Object result = null;
        try {
            result = handler.getMethod().invoke(handler.getController());
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        } catch (InvocationTargetException e) {
            e.printStackTrace();
        }
        
        // 写回结果
        try {
            resp.getWriter().write((String) result);
        } catch (IOException e) {
            e.printStackTrace();
        }

    }

    @Override
    public void init() throws ServletException {
        try {
            
            handlerMapping.add(new Handler().setController(MemberController.class.newInstance())
                            .setMethod(MemberController.class.getMethod("getMemberById", new Class[]{String.class}))
                            .setUrl("/web/getMemberById.json"));
            handlerMapping.add(new Handler().setController(OrderController.class.newInstance())
                            .setMethod(OrderController.class.getMethod("getOrderById", new Class[]{String.class}))
                            .setUrl("/web/getOrderById.json"));
            // ...其他handler
        }catch (Exception e){
            e.printStackTrace();
        }
    }
}

class Handler {
    private Object controller;
    private Method method;
    private String url;

    public Object getController() {
        return controller;
    }

    public Handler setController(Object controller) {
        this.controller = controller;
        return this;
    }

    public Method getMethod() {
        return method;
    }

    public Handler setMethod(Method method) {
        this.method = method;
        return this;
    }

    public String getUrl() {
        return url;
    }

    public Handler setUrl(String url) {
        this.url = url;
        return this;
    }
}

5、文件排序案例

我们再使用一个文件排序的案例加深策略模式的印象。

有这样一个需求,希望写一个小程序,实现对一个文件进行排序的功能。不同大小的文件排序的算法是不同的,我们初步可能会这样实现:

public class Sorter {
  private static final long GB = 1000 * 1000 * 1000;
  public void sortFile(String filePath) {
    // 省略校验逻辑
    File file = new File(filePath);
    long fileSize = file.length();
    if (fileSize < 6 * GB) { // [0, 6GB)
      quickSort(filePath);
    } else if (fileSize < 10 * GB) { // [6GB, 10GB)
      externalSort(filePath);
    } else if (fileSize < 100 * GB) { // [10GB, 100GB)
      concurrentExternalSort(filePath);
    } else { // [100GB, ~)
      mapreduceSort(filePath);
    }
  }
  private void quickSort(String filePath) {
    // 快速排序
  }
  private void externalSort(String filePath) {
    // 外部排序
  }
  private void concurrentExternalSort(String filePath) {
    // 多线程外部排序
  }
  private void mapreduceSort(String filePath) {
    // 利用MapReduce多机排序
  }
}
public class SortingTool {
  public static void main(String[] args) {
    Sorter sorter = new Sorter();
    sorter.sortFile(args[0]);
  }
}

以上代码并不能体现面向对象的魅力,完全是面向过程的,而我们日常开发中这样的代码也是占大多数。

我们使用策略模式进行重构:

// 策略接口
public interface ISortAlg {
  void sort(String filePath);
}
// 具体策略类
public class QuickSort implements ISortAlg {
  @Override
  public void sort(String filePath) {
    //...
  }
}
public class ExternalSort implements ISortAlg {
  @Override
  public void sort(String filePath) {
    //...
  }
}
public class ConcurrentExternalSort implements ISortAlg {
  @Override
  public void sort(String filePath) {
    //...
  }
}
public class MapReduceSort implements ISortAlg {
  @Override
  public void sort(String filePath) {
    //...
  }
}
// 排序
public class SortAlgFactory {
  private static final Map<String, ISortAlg> algs = new HashMap<>();
  static {
    algs.put("QuickSort", new QuickSort());
    algs.put("ExternalSort", new ExternalSort());
    algs.put("ConcurrentExternalSort", new ConcurrentExternalSort());
    algs.put("MapReduceSort", new MapReduceSort());
  }
  public static ISortAlg getSortAlg(String type) {
    if (type == null || type.isEmpty()) {
      throw new IllegalArgumentException("type should not be empty.");
    }
    return algs.get(type);
  }
}
public class Sorter {
  private static final long GB = 1000 * 1000 * 1000;
  public void sortFile(String filePath) {
    // 省略校验逻辑
    File file = new File(filePath);
    long fileSize = file.length();
    ISortAlg sortAlg;
    if (fileSize < 6 * GB) { // [0, 6GB)
      sortAlg = SortAlgFactory.getSortAlg("QuickSort");
    } else if (fileSize < 10 * GB) { // [6GB, 10GB)
      sortAlg = SortAlgFactory.getSortAlg("ExternalSort");
    } else if (fileSize < 100 * GB) { // [10GB, 100GB)
      sortAlg = SortAlgFactory.getSortAlg("ConcurrentExternalSort");
    } else { // [100GB, ~)
      sortAlg = SortAlgFactory.getSortAlg("MapReduceSort");
    }
    sortAlg.sort(filePath);
  }
}

Sorter 类中的 sortFile() 函数还是有一堆 if-else 逻辑。这里的 if-else 逻辑分支不多、也不复杂,这样写完全没问题。但如果你特别想将 if-else 分支判断移除掉,那也是有办法的。我直接给出代码,你一看就能明白。实际上,这也是基于查表法来解决的,其中的“algs”就是“表”。

public class Sorter {
  private static final long GB = 1000 * 1000 * 1000;
  private static final List<AlgRange> algs = new ArrayList<>();
  static {
    algs.add(new AlgRange(0, 6*GB, SortAlgFactory.getSortAlg("QuickSort")));
    algs.add(new AlgRange(6*GB, 10*GB, SortAlgFactory.getSortAlg("ExternalSort")));
    algs.add(new AlgRange(10*GB, 100*GB, SortAlgFactory.getSortAlg("ConcurrentExternalSort")));
    algs.add(new AlgRange(100*GB, Long.MAX_VALUE, SortAlgFactory.getSortAlg("MapReduceSort")));
  }
  public void sortFile(String filePath) {
    // 省略校验逻辑
    File file = new File(filePath);
    long fileSize = file.length();
    ISortAlg sortAlg = null;
    for (AlgRange algRange : algs) {
      if (algRange.inRange(fileSize)) {
        sortAlg = algRange.getAlg();
        break;
      }
    }
    sortAlg.sort(filePath);
  }
  private static class AlgRange {
    private long start;
    private long end;
    private ISortAlg alg;
    public AlgRange(long start, long end, ISortAlg alg) {
      this.start = start;
      this.end = end;
      this.alg = alg;
    }
    public ISortAlg getAlg() {
      return alg;
    }
    public boolean inRange(long size) {
      return size >= start && size < end;
    }
  }
}

6、Spring中策略模式简单应用

策略模式在实际开发中的应用

@Autowired
private Map<String,MyService> ServiceMap;

@Slf4j
@Service("MyServiceImpl")
public class MyServiceImpl implements MyService {}

Spring在自动注入时,使用Map的话,key就是接口实现类Bean的名称,value就是对应的Bean实例。

三、源码中的策略模式

1、Comparator接口

JDK中一个常用的比较器Comparator接口,有一个常用的方法compare():

public interface Comparator<T> {
	int compare(T o1, T o2);
}

Comparator抽象下面有非常多的实现类,我们经常会把Comparator作为参数传入作为排序策略,例如Arrays类的parallelSort方法等:

public class Arrays{
	public static <T> void sort(T[] a, Comparator<? super T> c) {
		if (c == null) {
		sort(a);
		} else {
		if (LegacyMergeSort.userRequested)
		legacyMergeSort(a, c);
		else
		TimSort.sort(a, 0, a.length, c, null, 0, 0);
		}
	}
}

Arrays就是一个环境角色类,这个sort方法可以传一个新策略让Arrays根据这个策略来进行排序。就比如下面的测试类。

public class demo {
	public static void main(String[] args) {
		Integer[] data = {12, 2, 3, 2, 4, 5, 1};
		// 实现降序排序
		Arrays.sort(data, new Comparator<Integer>() {
		public int compare(Integer o1, Integer o2) {
		return o2 - o1;
		}
		});
		System.out.println(Arrays.toString(data)); //[12, 5, 4, 3, 2, 2, 1]
	}
}

这里我们在调用Arrays的sort方法时,第二个参数传递的是Comparator接口的子实现类对象。所以Comparator充当的是抽象策略角色,而具体的子实现类充当的是具体策略角色。环境角色类(Arrays)应该持有抽象策略的引用来调用。那么,Arrays类的sort方法到底有没有使用
Comparator子实现类中的 compare() 方法吗?让我们继续查看TimSort类的 sort() 方法,代码如下:

static <T> void sort(T[] a, int lo, int hi, Comparator<? super T> c,
                     T[] work, int workBase, int workLen) {
    assert c != null && a != null && lo >= 0 && lo <= hi && hi <= a.length;

    int nRemaining  = hi - lo;
    if (nRemaining < 2)
        return;  // Arrays of size 0 and 1 are always sorted

    // If array is small, do a "mini-TimSort" with no merges
    if (nRemaining < MIN_MERGE) {
        int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
        binarySort(a, lo, hi, lo + initRunLen, c);
        return;
    }

    /**
     * March over the array once, left to right, finding natural runs,
     * extending short natural runs to minRun elements, and merging runs
     * to maintain stack invariant.
     */
    TimSort<T> ts = new TimSort<>(a, c, work, workBase, workLen);
    int minRun = minRunLength(nRemaining);
    do {
        // Identify next run
        int runLen = countRunAndMakeAscending(a, lo, hi, c);

        // If run is short, extend to min(minRun, nRemaining)
        if (runLen < minRun) {
            int force = nRemaining <= minRun ? nRemaining : minRun;
            binarySort(a, lo, lo + force, lo + runLen, c);
            runLen = force;
        }

        // Push run onto pending-run stack, and maybe merge
        ts.pushRun(lo, runLen);
        ts.mergeCollapse();

        // Advance to find next run
        lo += runLen;
        nRemaining -= runLen;
    } while (nRemaining != 0);

    // Merge all remaining runs to complete sort
    assert lo == hi;
    ts.mergeForceCollapse();
    assert ts.stackSize == 1;
}

上面的代码中最终会跑到 countRunAndMakeAscending() 这个方法中。我们可以看见,只用了compare方法,所以在调用Arrays.sort方法只传具体compare重写方法的类对象就行,这也是Comparator接口中必须要子类实现的一个方法。

2、Spring的InstantiationStrategy

Spring初始化用到的InstantiationStrategy接口:

public interface InstantiationStrategy {

	Object instantiate(RootBeanDefinition bd, @Nullable String beanName, BeanFactory owner)
			throws BeansException;

	Object instantiate(RootBeanDefinition bd, @Nullable String beanName, BeanFactory owner,
			Constructor<?> ctor, Object... args) throws BeansException;

	Object instantiate(RootBeanDefinition bd, @Nullable String beanName, BeanFactory owner,
			@Nullable Object factoryBean, Method factoryMethod, Object... args)
			throws BeansException;

}

它有两种策略:CglibSubclassingInstantiationStrategy和SimpleInstantiationStrategy,我们发现了CglibSubclassingInstantiationStrategy继承了SimpleInstantiationStrategy,说明在实际应用中多种策略之间还可以继承使用。我们可以作为一个参考,在实际业务场景中,可以根据需要进行设计。

3、Spring的Resource

package org.springframework.core.io;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URL;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import org.springframework.lang.Nullable;

public interface Resource extends InputStreamSource {

	boolean exists();

	default boolean isReadable() {
		return exists();
	}

	default boolean isOpen() {
		return false;
	}

	default boolean isFile() {
		return false;
	}

	URL getURL() throws IOException;

	URI getURI() throws IOException;

	File getFile() throws IOException;

	default ReadableByteChannel readableChannel() throws IOException {
		return Channels.newChannel(getInputStream());
	}

	long contentLength() throws IOException;

	long lastModified() throws IOException;

	Resource createRelative(String relativePath) throws IOException;

	@Nullable
	String getFilename();

	String getDescription();

}

Resource有很多子类:

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
社会演员多的头像社会演员多普通用户
上一篇 2023年12月11日
下一篇 2023年12月11日

相关推荐