上篇文章(限流算法与Guava RateLimiter解析)对常用的限流算法及Google Guava基于令牌桶算法的实现RateLimiter进行了介绍。RateLimiter通过线程锁控制同步,只适用于单机应用,在分布式环境下,虽然有像阿里Sentinel的限流开源框架,但对于一些小型应用来说未免过重,但限流的需求在小型项目中也是存在的,比如获取手机验证码的控制,对资源消耗较大操作的访问频率控制等。本文介绍最近写的一个基于RateLimiter,适用于分布式环境下的限流实现,并使用spring-boot-starter的形式发布,比较轻量级且“开箱即用”。

本文限流实现包括两种形式:

  1. 基于RateLimiter令牌桶算法的限速控制(严格限制访问速度)
  2. 基于Lua脚本的限量控制(限制一个时间窗口内的访问量,对访问速度没有严格限制)

限速控制

1. 令牌桶模型

首先定义令牌桶模型,与RateLimiter中类似,包括几个关键属性与关键方法。其中关键属性定义如下,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
@Data
public class RedisPermits {

/**
* 最大存储令牌数
*/
private double maxPermits;
/**
* 当前存储令牌数
*/
private double storedPermits;
/**
* 添加令牌的时间间隔/毫秒
*/
private double intervalMillis;
/**
* 下次请求可以获取令牌的时间,可以是过去(令牌积累)也可以是将来的时间(令牌预消费)
*/
private long nextFreeTicketMillis;

//...

关键方法定义与RateLimiter也大同小异,方法注释基本已描述各方法用途,不再赘述。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
/**
* 构建Redis令牌数据模型
*
* @param permitsPerSecond 每秒放入的令牌数
* @param maxBurstSeconds maxPermits由此字段计算,最大存储maxBurstSeconds秒生成的令牌
* @param nextFreeTicketMillis 下次请求可以获取令牌的起始时间,默认当前系统时间
*/
public RedisPermits(double permitsPerSecond, double maxBurstSeconds, Long nextFreeTicketMillis) {
this.maxPermits = permitsPerSecond * maxBurstSeconds;
this.storedPermits = maxPermits;
this.intervalMillis = TimeUnit.SECONDS.toMillis(1) / permitsPerSecond;
this.nextFreeTicketMillis = nextFreeTicketMillis;
}

/**
* 基于当前时间,若当前时间晚于nextFreeTicketMicros,则计算该段时间内可以生成多少令牌,将生成的令牌加入令牌桶中并更新数据
*/
public void resync(long nowMillis) {
if (nowMillis > nextFreeTicketMillis) {
double newPermits = (nowMillis - nextFreeTicketMillis) / intervalMillis;
storedPermits = Math.min(maxPermits, storedPermits + newPermits);
nextFreeTicketMillis = nowMillis;
}
}

/**
* 保留指定数量令牌,并返回需要等待的时间
*/
public long reserveAndGetWaitLength(long nowMillis, int permits) {
resync(nowMillis);
double storedPermitsToSpend = Math.min(permits, storedPermits); // 可以消耗的令牌数
double freshPermits = permits - storedPermitsToSpend; // 需要等待的令牌数
long waitMillis = (long) (freshPermits * intervalMillis); // 需要等待的时间

nextFreeTicketMillis = LongMath.saturatedAdd(nextFreeTicketMillis, waitMillis);
storedPermits -= storedPermitsToSpend;
return waitMillis;
}

/**
* 在超时时间内,是否有指定数量的令牌可用
*/
public boolean canAcquire(long nowMillis, int permits, long timeoutMillis) {
return queryEarliestAvailable(nowMillis, permits) <= timeoutMillis;
}

/**
* 指定数量令牌数可用需等待的时间
*
* @param permits 需保留的令牌数
* @return 指定数量令牌可用的等待时间,如果为0或负数,表示当前可用
*/
private long queryEarliestAvailable(long nowMillis, int permits) {
resync(nowMillis);
double storedPermitsToSpend = Math.min(permits, storedPermits); // 可以消耗的令牌数
double freshPermits = permits - storedPermitsToSpend; // 需要等待的令牌数
long waitMillis = (long) (freshPermits * intervalMillis); // 需要等待的时间

return LongMath.saturatedAdd(nextFreeTicketMillis - nowMillis, waitMillis);
}

2. 令牌桶控制类

Guava RateLimiter中的控制都在RateLimiter及其子类中(如SmoothBursty),本处涉及到分布式环境下的同步,因此将其解耦,令牌桶模型存储于Redis中,对其同步操作的控制放置在如下控制类,其中同步控制使用到了前面介绍的分布式锁(参考基于Redis分布式锁的正确打开方式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
@Slf4j
public class RedisRateLimiter {

/**
* 获取一个令牌,阻塞一直到获取令牌,返回阻塞等待时间
*
* @return time 阻塞等待时间/毫秒
*/
public long acquire(String key) throws IllegalArgumentException {
return acquire(key, 1);
}

/**
* 获取指定数量的令牌,如果令牌数不够,则一直阻塞,返回阻塞等待的时间
*
* @param permits 需要获取的令牌数
* @return time 等待的时间/毫秒
* @throws IllegalArgumentException tokens值不能为负数或零
*/
public long acquire(String key, int permits) throws IllegalArgumentException {
long millisToWait = reserve(key, permits);
log.info("acquire {} permits for key[{}], waiting for {}ms", permits, key, millisToWait);
try {
Thread.sleep(millisToWait);
} catch (InterruptedException e) {
log.error("Interrupted when trying to acquire {} permits for key[{}]", permits, key, e);
}
return millisToWait;
}

/**
* 在指定时间内获取一个令牌,如果获取不到则一直阻塞,直到超时
*
* @param timeout 最大等待时间(超时时间),为0则不等待立即返回
* @param unit 时间单元
* @return 获取到令牌则true,否则false
* @throws IllegalArgumentException
*/
public boolean tryAcquire(String key, long timeout, TimeUnit unit) throws IllegalArgumentException {
return tryAcquire(key, 1, timeout, unit);
}

/**
* 在指定时间内获取指定数量的令牌,如果在指定时间内获取不到指定数量的令牌,则直接返回false,
* 否则阻塞直到能获取到指定数量的令牌
*
* @param permits 需要获取的令牌数
* @param timeout 最大等待时间(超时时间)
* @param unit 时间单元
* @return 如果在指定时间内能获取到指定令牌数,则true,否则false
* @throws IllegalArgumentException tokens为负数或零,抛出异常
*/
public boolean tryAcquire(String key, int permits, long timeout, TimeUnit unit) throws IllegalArgumentException {
long timeoutMillis = Math.max(unit.toMillis(timeout), 0);
checkPermits(permits);

long millisToWait;
boolean locked = false;
try {
locked = lock.lock(key + LOCK_KEY_SUFFIX, WebUtil.getRequestId(), 60, 2, TimeUnit.SECONDS);
if (locked) {
long nowMillis = getNowMillis();
RedisPermits permit = getPermits(key, nowMillis);
if (!permit.canAcquire(nowMillis, permits, timeoutMillis)) {
return false;
} else {
millisToWait = permit.reserveAndGetWaitLength(nowMillis, permits);
permitsRedisTemplate.opsForValue().set(key, permit, expire, TimeUnit.SECONDS);
}
} else {
return false; //超时获取不到锁,也返回false
}
} finally {
if (locked) {
lock.unLock(key + LOCK_KEY_SUFFIX, WebUtil.getRequestId());
}
}
if (millisToWait > 0) {
try {
Thread.sleep(millisToWait);
} catch (InterruptedException e) {

}
}
return true;
}

/**
* 保留指定的令牌数待用
*
* @param permits 需保留的令牌数
* @return time 令牌可用的等待时间
* @throws IllegalArgumentException tokens不能为负数或零
*/
private long reserve(String key, int permits) throws IllegalArgumentException {
checkPermits(permits);
try {
lock.lock(key + LOCK_KEY_SUFFIX, WebUtil.getRequestId(), 60, 2, TimeUnit.SECONDS);
long nowMillis = getNowMillis();
RedisPermits permit = getPermits(key, nowMillis);
long waitMillis = permit.reserveAndGetWaitLength(nowMillis, permits);
permitsRedisTemplate.opsForValue().set(key, permit, expire, TimeUnit.SECONDS);
return waitMillis;
} finally {
lock.unLock(key + LOCK_KEY_SUFFIX, WebUtil.getRequestId());
}
}

/**
* 获取令牌桶
*
* @return
*/
private RedisPermits getPermits(String key, long nowMillis) {
RedisPermits permit = permitsRedisTemplate.opsForValue().get(key);
if (permit == null) {
permit = new RedisPermits(permitsPerSecond, maxBurstSeconds, nowMillis);
}
return permit;
}

/**
* 获取redis服务器时间
*/
private long getNowMillis() {
String luaScript = "return redis.call('time')";
DefaultRedisScript<List> redisScript = new DefaultRedisScript<>(luaScript, List.class);
List<String> now = (List<String>)stringRedisTemplate.execute(redisScript, null);
return now == null ? System.currentTimeMillis() : Long.valueOf(now.get(0))*1000+Long.valueOf(now.get(1))/1000;
}

//...
}

其中:

  1. acquire 是阻塞方法,如果没有可用的令牌,则一直阻塞直到获取到令牌。
  2. tryAcquire 则是非阻塞方法,如果在指定超时时间内获取不到指定数量的令牌,则直接返回false,不阻塞等待。
  3. getNowMillis 获取Redis服务器时间,避免业务服务器时间不一致导致的问题,如果业务服务器能保障时间同步,则可从本地获取提高效率。

3. 令牌桶控制工厂类

工厂类负责管理令牌桶控制类,将其缓存在本地,这里使用了Guava中的Cache,一方面避免每次都新建控制类提高效率,另一方面通过控制缓存的最大容量来避免像用户粒度的限流占用过多的内存。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
public class RedisRateLimiterFactory {

private PermitsRedisTemplate permitsRedisTemplate;
private StringRedisTemplate stringRedisTemplate;
private DistributedLock distributedLock;

private Cache<String, RedisRateLimiter> cache = CacheBuilder.newBuilder()
.initialCapacity(100) //初始大小
.maximumSize(10000) // 缓存的最大容量
.expireAfterAccess(5, TimeUnit.MINUTES) // 缓存在最后一次访问多久之后失效
.concurrencyLevel(Runtime.getRuntime().availableProcessors()) // 设置并发级别
.build();

public RedisRateLimiterFactory(PermitsRedisTemplate permitsRedisTemplate, StringRedisTemplate stringRedisTemplate, DistributedLock distributedLock) {
this.permitsRedisTemplate = permitsRedisTemplate;
this.stringRedisTemplate = stringRedisTemplate;
this.distributedLock = distributedLock;
}

/**
* 创建RateLimiter
*
* @param key RedisRateLimiter本地缓存key
* @param permitsPerSecond 每秒放入的令牌数
* @param maxBurstSeconds 最大存储maxBurstSeconds秒生成的令牌
* @param expire 该令牌桶的redis tty/秒
* @return RateLimiter
*/
public RedisRateLimiter build(String key, double permitsPerSecond, double maxBurstSeconds, int expire) {
if (cache.getIfPresent(key) == null) {
synchronized (this) {
if (cache.getIfPresent(key) == null) {
cache.put(key, new RedisRateLimiter(permitsRedisTemplate, stringRedisTemplate, distributedLock, permitsPerSecond,
maxBurstSeconds, expire));
}
}
}
return cache.getIfPresent(key);
}
}

4. 注解支持

定义注解 @RateLimit 如下,表示以每秒rate的速率放置令牌,最多保留burst秒的令牌,取令牌的超时时间为timeout,limitType用于控制key类型,目前支持:

  1. IP, 根据客户端IP限流
  2. USER, 根据用户限流,对于Spring Security可从SecurityContextHolder中获取当前用户信息,如userId
  3. METHOD, 根据方法名全局限流,className.methodName,注意避免同时对同一个类中的同名方法做限流控制,否则需要修改获取key的逻辑
  4. CUSTOM,自定义,支持表达式解析,如#{id}, #{user.id}
1
2
3
4
5
6
7
8
9
10
11
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimit {
String key() default "";
String prefix() default "rateLimit:"; //key前缀
int expire() default 60; // 表示令牌桶模型RedisPermits redis key的过期时间/秒
double rate() default 1.0; // permitsPerSecond值
double burst() default 1.0; // maxBurstSeconds值
int timeout() default 0; // 超时时间/秒
LimitType limitType() default LimitType.METHOD;
}

通过切面的前置增强来为添加了 @RateLimit 注解的方法提供限流控制,如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
@Aspect
@Slf4j
public class RedisLimitAspect {
//...

@Before(value = "@annotation(rateLimit)")
public void rateLimit(JoinPoint point, RateLimit rateLimit) throws Throwable {
String key = getKey(point, rateLimit.limitType(), rateLimit.key(), rateLimit.prefix());
RedisRateLimiter redisRateLimiter = redisRateLimiterFactory.build(key, rateLimit.rate(), rateLimit.burst(), rateLimit.expire());
if(!redisRateLimiter.tryAcquire(key, rateLimit.timeout(), TimeUnit.SECONDS)){
ExceptionUtil.rethrowClientSideException(LIMIT_MESSAGE);
}
}

//...

限量控制

1. 限量控制类

限制一个时间窗口内的访问量,可使用计数器算法,借助Lua脚本执行的原子性来实现。

Lua脚本逻辑:

  1. 以需要控制的对象为key(如方法,用户ID,或IP等),当前访问次数为Value,时间窗口值为缓存的过期时间
  2. 如果key存在则将其增1,判断当前值是否大于访问量限制值,如果大于则返回0,表示该时间窗口内已达访问量上限,如果小于则返回1表示允许访问
  3. 如果key不存在,则将其初始化为1,并设置过期时间,返回1表示允许访问
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
public class RedisCountLimiter {

private StringRedisTemplate stringRedisTemplate;

private static final String LUA_SCRIPT = "local c \nc = redis.call('get',KEYS[1]) \nif c and redis.call('incr',KEYS[1]) > tonumber(ARGV[1]) then return 0 end"
+ " \nif c then return 1 else \nredis.call('set', KEYS[1], 1) \nredis.call('expire', KEYS[1], tonumber(ARGV[2])) \nreturn 1 end";

private static final int SUCCESS_RESULT = 1;
private static final int FAIL_RESULT = 0;

public RedisCountLimiter(StringRedisTemplate stringRedisTemplate) {
this.stringRedisTemplate = stringRedisTemplate;
}

/**
* 是否允许访问
*
* @param key redis key
* @param limit 限制次数
* @param expire 时间段/秒
* @return 获取成功true,否则false
* @throws IllegalArgumentException
*/
public boolean tryAcquire(String key, int limit, int expire) throws IllegalArgumentException {
RedisScript<Number> redisScript = new DefaultRedisScript<>(LUA_SCRIPT, Number.class);
Number result = stringRedisTemplate.execute(redisScript, Collections.singletonList(key), String.valueOf(limit), String.valueOf(expire));
if(result != null && result.intValue() == SUCCESS_RESULT) {
return true;
}
return false;
}

}

2. 注解支持

定义注解 @CountLimit 如下,表示在period时间窗口内,最多允许访问limit次,limitType用于控制key类型,取值与 @RateLimit 同。

1
2
3
4
5
6
7
8
9
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface CountLimit {
String key() default "";
String prefix() default "countLimit:"; //key前缀
int limit() default 1; // expire时间段内限制访问次数
int period() default 1; // 表示时间段/秒
LimitType limitType() default LimitType.METHOD;
}

同样采用前值增强来为添加了 @CountLimit 注解的方法提供限流控制,如下

1
2
3
4
5
6
7
@Before(value = "@annotation(countLimit)")
public void countLimit(JoinPoint point, CountLimit countLimit) throws Throwable {
String key = getKey(point, countLimit.limitType(), countLimit.key(), countLimit.prefix());
if (!redisCountLimiter.tryAcquire(key, countLimit.limit(), countLimit.period())) {
ExceptionUtil.rethrowClientSideException(LIMIT_MESSAGE);
}
}

使用示例

1.添加依赖

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
<dependencies>
<dependency>
<groupId>cn.jboost.springboot</groupId>
<artifactId>limiter-spring-boot-starter</artifactId>
<version>1.3-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
</dependencies>

2.配置redis相关参数

1
2
3
4
5
6
7
8
9
10
11
spring:
application:
name: limiter-demo
redis:
#数据库索引
database: 0
host: 192.168.40.92
port: 6379
password: password
#连接超时时间
timeout: 2000

3.测试类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
@RestController
@RequestMapping("limiter")
public class LimiterController {

/**
* 注解形式
* @param key
* @return
*/
@GetMapping("/count")
@CountLimit(key = "#{key}", limit = 2, period = 10, limitType = LimitType.CUSTOM)
public String testCountLimit(@RequestParam("key") String key){
return "test count limiter...";
}

/**
* 注解形式
* @param key
* @return
*/
@GetMapping("/rate")
@RateLimit(rate = 1.0/5, burst = 5.0, expire = 120, timeout = 0)
public String testRateLimit(@RequestParam("key") String key){
return "test rate limiter...";
}

@Autowired
private RedisRateLimiterFactory redisRateLimiterFactory;
/**
* 代码段形式
* @param
* @return
*/
@GetMapping("/rate2")
public String testRateLimit(){
RedisRateLimiter limiter = redisRateLimiterFactory.build("LimiterController.testRateLimit", 1.0/30, 30, 120);
if(!limiter.tryAcquire("app.limiter", 0, TimeUnit.SECONDS)) {
System.out.println(LocalDateTime.now());
ExceptionUtil.rethrowClientSideException("您的访问过于频繁,请稍后重试");
}
return "test rate limiter 2...";
}
}

4.验证

启动测试项目,浏览器中访问 http://localhost:8080/limiter/rate?key=test ,第一次访问成功,如图

ratelimiter1

持续刷新,将返回如下错误,直到5s之后再返回成功,限制5秒1次的访问速度

ratelimiter2

注解的使用

  1. 限流类型LimitType支持IP(客户端IP)、用户(userId)、方法(className.methodName)、自定义(CUSTOM)几种形式,默认为METHOD
  2. LimitType为CUSTOM时,需要手动指定key(其它key自动为ip,userid,或methodname),key支持表达式形式,如#{id}, #{user.id}
  3. 针对某个时间窗口内限制访问一次的场景,既可以使用 @CountLimit, 也可以使用 @RateLimit,比如验证码一分钟内只允许获取一次,以下两种形式都能达到目的
1
2
3
4
//同一个手机号码60s内最多访问一次
@CountLimit(key = "#{params.phone}", limit = 1, period = 60, limitType = LimitType.CUSTOM)
//以1/60的速度放置令牌,最多保存60s的令牌(也就是最多保存一个),控制访问速度为1/60个每秒(1个每分钟)
@RateLimit(key = "#{params.phone}", rate = 1.0/60, burst = 60, expire = 120, limitType = LimitType.CUSTOM)

总结

本文介绍了适用于分布式环境的基于RateLimiter令牌桶算法的限速控制与基于计数器算法的限量控制,可应用于中小型项目中有相关需求的场景(注:本实现未做压力测试,如果用户并发量较大需验证效果)。

如果觉得有帮助,别忘了给个star ^_^。作者公众号:半路雨歌,欢迎关注查看更多干货文章。

评论