java限流-基于redis+lua

发布时间 2023-07-04 15:21:05作者: 五官一体即忢

Redis 以单线程方式执行命令,单条命令以及 Lua 脚本的执行都是原子的,因此天然具备线程安全的特性。限流服务不仅需要承接超高QPS,还要保证限流逻辑在执行层面具备线程安全的特性;利用 Redis 的这些特性做限流,既能保证线程安全,也能保证性能。

结合上面的流程图,这里梳理出一个整体的实现思路:

编写lua脚本,指定入参的限流规则,比如对特定的接口限流时,可以根据某个或几个参数进行判定,调用该接口的请求,在一定的时间窗口内监控请求次数;

既然是限流,最好能够通用,可将限流规则应用到任何接口上,那么最合适的方式就是通过自定义注解形式切入;

引入redis依赖
<!-- Spring Boot starter for Spring Data Redis; supplies the org.springframework.data.redis
     classes (RedisTemplate, DefaultRedisScript, serializers) imported by the snippets below.
     No <version> is declared here, so it is presumably managed by the Spring Boot
     parent/BOM of the project - confirm. -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
自定义注解
/**
 * Marks a method as rate limited through Redis.
 *
 * <p>{@code LimitRestAspect} intercepts methods carrying this annotation and passes
 * {@link #count()} and {@link #period()} to the rate-limit Lua script.
 *
 * <p>NOTE(review): {@code ElementType.TYPE} is allowed as a target, but the aspect's pointcut
 * uses {@code @annotation(...)} and reads the annotation from the method only, so placing it
 * on a class appears to have no effect - confirm before relying on it.
 */
@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Inherited
@Documented
public @interface RedisLimitAnnotation {
 
    /**
     * Caller-chosen key; {@code LimitRestAspect} appends it as the last segment of the
     * Redis key it builds. Defaults to the empty string.
     */
    String key() default "";
    /**
     * Prefix for the key. Defaults to the empty string.
     * NOTE(review): {@code LimitRestAspect} as shown does not read this value - confirm.
     */
    String prefix() default "";
    /**
     * Maximum number of calls allowed within the time range.
     */
    int count();
    /**
     * Length of the time range, in seconds.
     */
    int period();
    /**
     * Kind of limit to apply (user-defined key or request IP). Defaults to
     * {@code LimitType.CUSTOMER}.
     * NOTE(review): {@code LimitRestAspect} as shown does not read this value - confirm.
     */
    LimitType limitType() default LimitType.CUSTOMER;
 
}
 自定义redis配置类
import org.springframework.context.annotation.Bean;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.scripting.support.ResourceScriptSource;
import org.springframework.stereotype.Component;
 
import java.io.Serializable;
 
/**
 * Registers the Redis beans used by the rate limiter: the handle to the {@code limit.lua}
 * script and a {@link RedisTemplate} with String keys and JSON-serialized values.
 *
 * <p>NOTE(review): the class is a {@code @Component}, so its {@code @Bean} methods are
 * processed in Spring's "lite" mode. That is sufficient here because neither bean method
 * calls the other - revisit if inter-bean calls are ever added.
 */
@Component
public class RedisConfiguration {

    /**
     * Creates the script handle for the rate-limit Lua script.
     *
     * @return a script loaded from {@code limit.lua} on the classpath whose result is read
     *         as a {@link Number}
     */
    @Bean
    public DefaultRedisScript<Number> redisluaScript() {
        DefaultRedisScript<Number> redisScript = new DefaultRedisScript<>();
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("limit.lua")));
        redisScript.setResultType(Number.class);
        return redisScript;
    }

    /**
     * Creates the application's {@code redisTemplate} bean.
     *
     * @param redisConnectionFactory connection factory the template uses to reach Redis
     * @return a fully initialized template: keys and hash keys are written as plain strings,
     *         values and hash values as JSON
     */
    @Bean("redisTemplate")
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        // Typed (non-raw) serializer so the compiler checks the setter calls below.
        Jackson2JsonRedisSerializer<Object> jsonSerializer = new Jackson2JsonRedisSerializer<>(Object.class);
        // The string serializer is stateless, so one instance serves keys and hash keys.
        StringRedisSerializer stringSerializer = new StringRedisSerializer();

        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(redisConnectionFactory);

        // Keys as String, values as JSON.
        redisTemplate.setKeySerializer(stringSerializer);
        redisTemplate.setValueSerializer(jsonSerializer);

        // Same split for hash entries.
        redisTemplate.setHashKeySerializer(stringSerializer);
        redisTemplate.setHashValueSerializer(jsonSerializer);

        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }

}
自定义限流AOP类
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
 
import javax.servlet.http.HttpServletRequest;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;
 
/**
 * Aspect that rate-limits methods annotated with {@link RedisLimitAnnotation}.
 *
 * <p>For every intercepted call it builds a Redis key from the client IP, the declaring
 * class, the method name and the annotation's {@code key()}, runs the rate-limit Lua script
 * with the annotation's {@code count()} and {@code period()}, and either lets the call
 * proceed or rejects it with a {@link RuntimeException}.
 */
@Aspect
@Configuration
public class LimitRestAspect {

    private static final Logger logger = LoggerFactory.getLogger(LimitRestAspect.class);

    /** Placeholder value treated the same as a missing forwarding header. */
    private static final String UNKNOWN = "unknown";

    @Autowired
    private RedisTemplate<String, Object> redisTemplate;

    @Autowired
    private DefaultRedisScript<Number> redisluaScript;

    /** Matches every method annotated with {@code RedisLimitAnnotation}. */
    @Pointcut(value = "@annotation(com.congge.config.limit.RedisLimitAnnotation)")
    public void rateLimit() {

    }

    /**
     * Applies the rate limit around the intercepted method.
     *
     * @param joinPoint the intercepted invocation
     * @return whatever the intercepted method returns
     * @throws RuntimeException if the script reports that the limit is exceeded (returns 0),
     *                          returns no result, or returns a value above {@code count()}
     * @throws Throwable        anything thrown by the intercepted method
     */
    @Around("rateLimit()")
    public Object interceptor(ProceedingJoinPoint joinPoint) throws Throwable {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        RedisLimitAnnotation rateLimit = method.getAnnotation(RedisLimitAnnotation.class);
        if (rateLimit == null) {
            // Nothing to enforce on this method.
            return joinPoint.proceed();
        }

        List<String> keys = Collections.singletonList(buildLimitKey(method, rateLimit));
        // The script returns the ordinal of this request inside the window, or 0 when rejected.
        Number number = redisTemplate.execute(
                redisluaScript,
                keys,
                rateLimit.count(),
                rateLimit.period()
        );
        if (number != null && number.intValue() != 0 && number.intValue() <= rateLimit.count()) {
            logger.info("限流时间段内访问了第:{} 次", number);
            return joinPoint.proceed();
        }
        throw new RuntimeException("访问频率过快,被限流了");
    }

    /**
     * Builds the key handed to the Lua script: {@code <ip>-<class>- <method>-<key>}.
     * The "- " separator (with its trailing space) is kept exactly as it was so the
     * key format stays unchanged.
     */
    private static String buildLimitKey(Method method, RedisLimitAnnotation rateLimit) {
        return new StringBuilder()
                .append(getIpAddr(currentRequest())).append("-")
                .append(method.getDeclaringClass().getName()).append("- ")
                .append(method.getName()).append("-")
                .append(rateLimit.key())
                .toString();
    }

    /**
     * Returns the servlet request bound to the current thread, or {@code null} when the
     * annotated method is invoked outside of a servlet request.
     */
    private static HttpServletRequest currentRequest() {
        Object attributes = RequestContextHolder.getRequestAttributes();
        if (attributes instanceof ServletRequestAttributes) {
            return ((ServletRequestAttributes) attributes).getRequest();
        }
        return null;
    }

    /**
     * Resolves the client IP of a request.
     *
     * <p>Checks {@code x-forwarded-for}, {@code Proxy-Client-IP} and
     * {@code WL-Proxy-Client-IP} in that order and falls back to the remote address. When
     * the value is a comma-separated proxy chain, the first entry is used.
     *
     * <p>SECURITY NOTE: these headers are supplied by the client unless a trusted proxy
     * overwrites them, so a caller can vary them to obtain a fresh limit key.
     *
     * @param request the current request; may be {@code null}
     * @return the resolved IP, or an empty string if it cannot be determined
     */
    private static String getIpAddr(HttpServletRequest request) {
        if (request == null) {
            return "";
        }
        try {
            String ipAddress = request.getHeader("x-forwarded-for");
            if (isUnknown(ipAddress)) {
                ipAddress = request.getHeader("Proxy-Client-IP");
            }
            if (isUnknown(ipAddress)) {
                ipAddress = request.getHeader("WL-Proxy-Client-IP");
            }
            if (isUnknown(ipAddress)) {
                ipAddress = request.getRemoteAddr();
            }
            if (ipAddress == null) {
                return "";
            }
            // With several proxies the header is a comma-separated list whose first entry is
            // the client. Split regardless of length: "1.1.1.1,2.2.2.2" is only 15 chars.
            int comma = ipAddress.indexOf(',');
            if (comma > 0) {
                ipAddress = ipAddress.substring(0, comma);
            }
            return ipAddress.trim();
        } catch (RuntimeException e) {
            // Best effort: a failure to read the request must not break the limited call.
            logger.warn("Unable to resolve client IP, using an empty value", e);
            return "";
        }
    }

    /** Tells whether a header value is absent, empty, or the literal "unknown". */
    private static boolean isUnknown(String value) {
        return value == null || value.length() == 0 || UNKNOWN.equalsIgnoreCase(value);
    }

}

该类要做的事情和上面的两种限流措施类似,不过在这里核心的限流逻辑是通过读取lua脚本,并将参数传递给lua脚本来实现的。

自定义lua脚本

在工程的resources目录下,添加如下的lua脚本

-- Fixed-window rate limiter; Redis runs the whole script atomically.
--   KEYS[1]  caller-supplied identifier, stored under "rate.limit:" .. KEYS[1]
--   ARGV[1]  maximum number of requests allowed per window
--   ARGV[2]  window length in seconds (optional; 2 is used when missing or invalid)
-- Returns 0 when the request would exceed the limit, otherwise the ordinal of this
-- request inside the current window (1 .. limit).
local key = "rate.limit:" .. KEYS[1]

local limit = tonumber(ARGV[1])

-- Honour the caller's window instead of a hard-coded value; keep 2s as the fallback.
local period = tonumber(ARGV[2])
if not period or period < 1 then
  period = 2
end
period = math.floor(period)

-- GET yields false for a missing key, hence the "0" default.
local current = tonumber(redis.call('get', key) or "0")

if current + 1 > limit then
  return 0
end

-- Below the threshold: count this request.
local updated = redis.call("INCRBY", key, 1)

-- Start the window only when the key has no expiry yet (i.e. it was just created).
-- Refreshing the TTL on every request would keep extending the window under steady traffic.
if redis.call("TTL", key) < 0 then
  redis.call("EXPIRE", key, period)
end

return updated
添加测试接口
/**
 * Demo controller exposing one endpoint guarded by {@link RedisLimitAnnotation}.
 */
@RestController
public class RedisController {
 
    // Example URL: localhost:8081/redis/limit
    /**
     * Rate-limited test endpoint. {@code count = 1} and {@code period = 1} are the values
     * the limiting aspect hands to the Lua script for this method.
     *
     * @return the literal {@code "success"} when the call is allowed through
     */
    @GetMapping("/redis/limit")
    @RedisLimitAnnotation(key = "queryFromRedis",period = 1, count = 1)
    public String queryFromRedis(){
        return "success";
    }
 
}