Spring Cloud Gateway实践(一):获取参数

发布时间 2023-12-10 15:22:30 作者: _且歌且行

SCG(Spring Cloud Gateway)就我个人理解,是想让开发者把它作为一个较为简单的网关框架,只需在yml文件中写几个配置项就可以运行。所以它并不推荐在网关这一层读取body数据或者做一些复杂的业务处理。因此在实际编写代码时,获取queryParam很容易,但获取body数据就比较麻烦,如果还要修改body就更麻烦了。本篇文章主要讨论如何在不同的请求方式中获取参数。
SCG获取参数一般有两种方式:

  1. 通过Filter过滤器
  2. 通过Predicate断言

配置Filter获取

import lombok.NonNull;  
import lombok.extern.slf4j.Slf4j;  
import org.springframework.cloud.gateway.filter.GatewayFilterChain;  
import org.springframework.cloud.gateway.filter.GlobalFilter;  
import org.springframework.core.Ordered;  
import org.springframework.core.io.buffer.DataBuffer;  
import org.springframework.core.io.buffer.DataBufferUtils;  
import org.springframework.http.server.reactive.ServerHttpRequest;  
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;  
import org.springframework.stereotype.Component;  
import org.springframework.web.server.ServerWebExchange;  
import reactor.core.publisher.Flux;  
import reactor.core.publisher.Mono;  
   
/**
 * Global filter that buffers the request body so it can be read again further down the chain.
 *
 * <p>Requests without a {@code Content-Type} header are forwarded untouched. For all other
 * requests the body is joined into a single {@link DataBuffer}. The request is then wrapped so
 * that {@link ServerHttpRequest#getBody()} replays that buffer. The replaying {@link Flux} is also
 * stored in the {@code "cachedRequestBodyObject"} exchange attribute.
 */
@Slf4j
@Component
public class ReadParamFilter implements GlobalFilter, Ordered {

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        if (exchange.getRequest().getHeaders().getContentType() == null) {
            return chain.filter(exchange);
        }
        return DataBufferUtils.join(exchange.getRequest().getBody())
                // join() completes empty when there is no payload (e.g. a Content-Type header but
                // an empty body). Without a default, the flatMap below would never run and the
                // request would never reach chain.filter().
                .defaultIfEmpty(exchange.getResponse().bufferFactory().wrap(new byte[0]))
                .flatMap(dataBuffer -> {
                    // NOTE(review): this retain() is not balanced by a release() in this filter;
                    // confirm the buffer is released downstream, otherwise pooled buffers leak.
                    DataBufferUtils.retain(dataBuffer);
                    // Every subscription gets a fresh slice over the same bytes, so the body can
                    // be consumed more than once.
                    Flux<DataBuffer> cachedFlux = Flux
                            .defer(() -> Flux.just(dataBuffer.slice(0, dataBuffer.readableByteCount())));
                    ServerHttpRequest mutatedRequest = new ServerHttpRequestDecorator(exchange.getRequest()) {
                        @Override
                        public @NonNull Flux<DataBuffer> getBody() {
                            return cachedFlux;
                        }

                    };
                    // NOTE(review): this stores a Flux<DataBuffer>. Code that reads the same key
                    // expecting a String body (e.g. one cached by a ReadBody predicate) would fail
                    // with a ClassCastException - confirm the key does not collide.
                    exchange.getAttributes().put("cachedRequestBodyObject", cachedFlux);
                    log.debug("global filter header: {}", mutatedRequest.getHeaders());
                    log.debug("query param:{}", exchange.getRequest().getQueryParams());
                    log.debug("mutatedRequest:{}", mutatedRequest.getQueryParams());
                    return chain.filter(exchange.mutate().request(mutatedRequest).build());
                });
    }

    /**
     * Runs ahead of every other global filter, so later filters already see the buffered body.
     *
     * @return {@link Ordered#HIGHEST_PRECEDENCE}
     */
    @Override
    public int getOrder() {
        return Ordered.HIGHEST_PRECEDENCE;
    }
}

缺点:全局过滤器(GlobalFilter)在路由断言匹配完成之后才执行,因此在断言阶段无法获取到这些参数。

配置Predicate获取

配置文件

server:
  port: 8081

logging:
  level:
    root: DEBUG

management:
  # Actuator endpoints exposed over HTTP
  endpoints:
    web:
      base-path: /metrics
      exposure:
        include: prometheus,health,loggers
  # Per-endpoint switches
  endpoint:
    health:
      show-details: always
    shutdown:
      enabled: false
    prometheus:
      enabled: true
    metrics:
      enabled: true
  # Metrics export
  metrics:
    export:
      prometheus:
        enabled: true
  # Port for the management (actuator) endpoints, separate from server.port above
  server:
    port: 8093

spring:
  cloud:
    gateway:
      predicate:
        # Presumably toggles the ReadBody route predicate used by the routes below
        read-body:
          enabled: true
      routes:
        # NOTE(review): every route below forwards to https://localhost:8081, which is this
        # gateway's own server.port (and no SSL is configured here). Confirm the real upstream URI;
        # forwarding to the gateway itself would re-match the same route.
        - id: post-json
          uri: 'https://localhost:8081'
          predicates:
            - Path=/post-json
            - name: ReadBody
              args:
                inClass: "#{T(String)}" # read the body as a raw String
                predicate: "#{@bodyPredicate}" # custom predicate bean (BodyPredicate)
            # Presumably a project-defined predicate (not shown in this file) that validates each
            # listed param against legalRange; for JSON, `position` is a JsonPath expression.
            - name: CheckParamLegal
              args:
                mediaType: "JSON"
                params:
                  - name: "hello"
                    position: "$.data.[*].hello"
                    legalRange:
                      - 1
                      - 2
                      - 3
          filters:
            # NOTE(review): presumably a project-defined filter factory (not shown) - confirm.
            - name: Redirect
              args:
                aggregate: true
        - id: post-form
          uri: 'https://localhost:8081'
          predicates:
            - Path=/post-form
            - name: ReadBody
              args:
                inClass: "#{T(String)}" # read the body as a raw String
                predicate: "#{@bodyPredicate}" # custom predicate bean (BodyPredicate)
            - name: CheckParamLegal
              args:
                mediaType: "FORM"
                params:
                  - name: "hello"
                    position: "$.data.hello"
                    legalRange:
                      - 1
                      - 2
                      - 3
          filters: # no filters on this route
        - id: post-xwform
          uri: 'https://localhost:8081'
          predicates:
            - Path=/post-xwform
            - name: ReadBody
              args:
                inClass: "#{T(String)}" # read the body as a raw String
                predicate: "#{@bodyPredicate}" # custom predicate bean (BodyPredicate)
            - name: CheckParamLegal
              args:
                mediaType: "XW_FORM"
                params:
                  - name: "hello"
                    position: "$.data.hello"
                    legalRange:
                      - 1
                      - 2
                      - 3
          filters: # no filters on this route
        - id: get
          uri: 'https://localhost:8081'
          predicates:
            - Method=GET
            - Path=/get
            # GET routes carry no body, so only query params are checked (no ReadBody predicate)
            - name: CheckParamLegal
              args:
                params:
                  - name: "hello"
                    legalRange:
                      - 1
                      - 2
                      - 3
          filters:
            - name: Redirect
              args:
                aggregate: true



  datasource:
    driver-class-name: com.mysql.cj.jdbc.Driver
    username: "xxxx"
    password: "xxx"
    url: jdbc:mysql://localhost:3306/test?serverTimezone=UTC&characterEncoding=utf-8&useSSL=false&allowPublicKeyRetrieval=true
  # Maximum number of bytes codecs may buffer in memory
  codec:
    max-in-memory-size: 100MB
  main:
    allow-circular-references: true
    allow-bean-definition-overriding: true
    web-application-type: reactive
  application:
    name: xxxx
  project:
    name: xxx

断言类

/**
 * Body predicate referenced from the route configuration as {@code #{@bodyPredicate}}.
 *
 * <p>It accepts every body unconditionally. Presumably it exists only so that the ReadBody
 * predicate caches the body for later use, not to filter requests.
 */
@Component
public class BodyPredicate implements Predicate<Object> {

    /**
     * Accepts any body.
     *
     * @param o the request body as converted by ReadBody (ignored)
     * @return always {@code true}
     */
    @Override
    public boolean test(Object o) {
        return true;
    }
}

获取参数

/**
 * Looks up the {@link ParamStrategy} responsible for an HTTP method.
 *
 * <p>Strategies are Spring beans whose bean name is the HTTP method name (e.g.
 * {@code @Component("GET")}, {@code @Component("POST")}). Spring injects all of them into a map
 * keyed by bean name.
 */
@Service
public class ParamFactory {

    /** All {@link ParamStrategy} beans, keyed by bean name (i.e. by HTTP method name). */
    @Autowired
    Map<String, ParamStrategy> getParamFactoryMap;

    /**
     * Returns the strategy registered for {@code requestMethod}.
     *
     * @param requestMethod the request's HTTP method; may be {@code null}
     * @return the matching strategy, or {@code null} when {@code requestMethod} is {@code null}
     *     or no strategy bean is registered for it
     */
    public ParamStrategy getParamStrategy(HttpMethod requestMethod) {
        if (requestMethod == null) {
            // ServerHttpRequest#getMethod() may return null (e.g. in Spring 5 for non-standard
            // methods); treat that like an unsupported method instead of throwing an NPE.
            return null;
        }
        return getParamFactoryMap.get(requestMethod.name());
    }
}

获取参数策略

/**
 * Base class for per-HTTP-method request parameter handling.
 *
 * <p>Subclasses decide how to extract parameters from an exchange, how to read a single value
 * and how to rename parameters.
 */
public abstract class ParamStrategy {

    /**
     * Extracts the parameters of the current request.
     *
     * <p>This is a thin public entry point that delegates to
     * {@link #doAnalyzeRequestParam(ServerWebExchange)}.
     *
     * @param exchange the current exchange
     * @return the extracted parameters, as produced by the subclass
     */
    public RequestParamBO analyzeRequestParam(ServerWebExchange exchange) {
        final RequestParamBO parsed = doAnalyzeRequestParam(exchange);
        return parsed;
    }

    /**
     * Subclass hook that parses the request data of {@code exchange}.
     *
     * @param exchange the current exchange
     * @return the parsed request parameters
     */
    protected abstract RequestParamBO doAnalyzeRequestParam(ServerWebExchange exchange);

    /**
     * Reads one request parameter.
     *
     * @param requestMessage the previously parsed request
     * @param paramKey       the parameter name
     * @param position       an implementation-specific locator (e.g. a JsonPath for JSON bodies)
     * @return the parameter value as a string
     */
    public abstract String getParamValue(RequestMessageBO requestMessage, String paramKey,
                                         String position);

    /**
     * Renames request parameters according to {@code names} (old name to new name).
     *
     * <p>GET requests rename query params. POST requests rename query params and form params;
     * JSON bodies are left unchanged.
     *
     * @param requestParam the parsed request whose parameter maps are modified in place
     * @param names        mapping from current parameter name to the new name
     */
    public abstract void reWriteParamName(RequestMessageBO requestParam, Map<String, String> names);

    /**
     * Puts {@code paramKey=paramValue} into the query-param map of {@code requestParam}.
     *
     * @param requestParam the parsed request whose query-param map is modified in place
     * @param paramKey     the parameter name
     * @param paramValue   the parameter value
     */
    public void addQueryParam(RequestParamBO requestParam, String paramKey, String paramValue) {
        Map<String, String> target = requestParam.getQueryParams();
        target.put(paramKey, paramValue);
    }

}

get

@Component("GET")  
public class GetParamStrategy extends ParamStrategy {  
  
    /**  
     * 解析请求数据  
     *  
     * @param exchange@return  
     */    @Override  
    protected RequestParamBO doAnalyzeRequestParam(ServerWebExchange exchange) {  
        Map<String, String> paramMap = new HashMap<>();  
        MultiValueMap<String, String> queryParams = exchange.getRequest().getQueryParams();  
        if (!queryParams.isEmpty()) {  
            paramMap =  
                    queryParams.entrySet().stream()  
                            .collect(  
                                    Collectors.toMap(  
                                            Map.Entry::getKey,  
                                            entry -> {  
                                                List<String> list =  
                                                        new ArrayList<>(entry.getValue());  
                                                // list包含空数据  
                                                list.removeIf(Objects::isNull);  
                                                if (list.size() != 0) {  
                                                    return entry.getValue().get(0);  
                                                } else {  
                                                    return "";  
                                                }  
                                            }));  
        }  
        return RequestParamBO.builder()  
                .queryParams(paramMap)  
                .build();  
    }  
  
    @Override  
    public String getParamValue(RequestMessageBO requestMessage, String paramKey, String position) {  
        Map<String,String> queryParam = requestMessage.getParam().getQueryParams();  
        if (CollectionUtils.isEmpty(queryParam)){  
            return null;  
        }  
        return queryParam.get(paramKey);  
    }  
  
    @Override  
    public void reWriteParamName(RequestMessageBO requestParam, Map<String, String> names) {  
        Map<String,String> queryParam = requestParam.getParam().getQueryParams();  
        if (CollectionUtils.isEmpty(queryParam)){  
            return;  
        }  
        for (var entry : names.entrySet()) {  
            String value = queryParam.get(entry.getKey());  
            queryParam.put(entry.getValue(),value);  
            queryParam.remove(entry.getKey());  
        }  
  
    }  
}

post

/**
 * Parameter strategy for POST requests.
 *
 * <p>Query params are always parsed. The body is parsed according to the Content-Type:
 * multipart and url-encoded forms become form params, and JSON is kept as a raw string.
 */
@Component("POST")
public class PostParamStrategy extends ParamStrategy {

    private static final String XW_FORM_PARAM_REGEX = "&";
    private static final String XW_KEY_VALUE_REGEX = "=";

    /** Marker preceding the boundary value in a multipart Content-Type header. */
    private static final String BOUNDARY_MARKER = "boundary=";

    /**
     * Matches one simple (non-file) multipart part as it appears between two boundaries.
     * Group 1 is the field name and group 2 the field value.
     *
     * <p>Parts carrying extra headers (e.g. file uploads with {@code filename=} / Content-Type)
     * do not match. DOTALL lets multi-line values through.
     */
    private static final Pattern FORM_PART_PATTERN = Pattern.compile(
            "^\r\nContent-Disposition: form-data; name=\"([^\"]+)\"\r\n\r\n(.*)\r\n--?$",
            Pattern.DOTALL);

    /**
     * Parses query params and, depending on the Content-Type, the cached request body.
     *
     * @param exchange the current exchange; the body is read from the
     *                 {@code CACHE_REQUEST_BODY_OBJECT} attribute, where it is expected to be a String
     * @return the parsed parameters; query params are always a non-null, mutable map
     */
    @Override
    protected RequestParamBO doAnalyzeRequestParam(ServerWebExchange exchange) {
        RequestParamBO requestParams = new RequestParamBO();
        requestParams.setQueryParams(toSingleValueMap(exchange.getRequest().getQueryParams()));

        MediaType contentType = exchange.getRequest().getHeaders().getContentType();
        Object cachedBody = exchange.getAttributes().get(CACHE_REQUEST_BODY_OBJECT);
        // A missing body, or a cached value that is not a String, is treated as "no body" instead
        // of failing with an NPE/ClassCastException.
        String body = cachedBody instanceof String ? (String) cachedBody : null;

        // isCompatibleWith(null) is false, so contentType is non-null inside every branch.
        if (MULTIPART_FORM_DATA.isCompatibleWith(contentType)) {
            requestParams.setFormParams(getFormParam(contentType.toString(), body));
        } else if (APPLICATION_FORM_URLENCODED.isCompatibleWith(contentType)) {
            requestParams.setFormParams(getXwFormParam(body));
        } else if (APPLICATION_JSON.isCompatibleWith(contentType)) {
            requestParams.setJsonBody(body);
        }

        return requestParams;
    }

    /**
     * Reads one parameter value.
     *
     * <p>For JSON bodies, {@code position} is evaluated as a JsonPath. A list result (indefinite
     * path) yields its first element, and a scalar result (definite path) is used directly.
     * Otherwise the value is looked up in the form params by {@code paramKey}.
     *
     * @return the value as a string, or {@code null} when it is absent
     */
    @Override
    public String getParamValue(RequestMessageBO requestMessage, String paramKey, String position) {
        MediaType mediaType = requestMessage.getMediaType();
        if (APPLICATION_JSON.isCompatibleWith(mediaType)) {
            String jsonBody = requestMessage.getParam().getJsonBody();
            if (jsonBody == null) {
                return null;
            }
            Object document = Configuration.defaultConfiguration()
                    .jsonProvider().parse(jsonBody);
            Object result = JsonPath.read(document, position);
            if (result instanceof List) {
                List<?> values = (List<?>) result;
                if (values.isEmpty()) {
                    return null;
                }
                result = values.get(0);
            }
            return result == null ? null : String.valueOf(result);
        }
        Map<String, String> formParams = requestMessage.getParam().getFormParams();
        return formParams == null ? null : formParams.get(paramKey);
    }

    /**
     * Renames form params and query params in place (old name to new name); the JSON body is left
     * unchanged. Names absent from a map are skipped, and mapping a name to itself is a no-op.
     */
    @Override
    public void reWriteParamName(RequestMessageBO requestParam, Map<String, String> names) {
        renameKeys(requestParam.getParam().getFormParams(), names);
        renameKeys(requestParam.getParam().getQueryParams(), names);
    }

    /** Renames keys of {@code params} in place according to {@code names}. */
    private static void renameKeys(Map<String, String> params, Map<String, String> names) {
        if (!CollectionUtil.isNotEmpty(params) || names == null) {
            return;
        }
        for (Map.Entry<String, String> entry : names.entrySet()) {
            String oldName = entry.getKey();
            String newName = entry.getValue();
            // Without these checks a missing key produced a null entry, and a self-rename deleted it.
            if (!params.containsKey(oldName) || Objects.equals(oldName, newName)) {
                continue;
            }
            params.put(newName, params.remove(oldName));
        }
    }

    /**
     * Collapses a multi-valued query map to the first non-null value per key ({@code ""} if none).
     * Returns a mutable {@link HashMap} because renames and addQueryParam modify it in place.
     */
    private static Map<String, String> toSingleValueMap(MultiValueMap<String, String> queryParams) {
        Map<String, String> result = new HashMap<>();
        if (CollectionUtil.isNotEmpty(queryParams)) {
            for (Map.Entry<String, List<String>> entry : queryParams.entrySet()) {
                result.put(entry.getKey(), firstNonNullValue(entry.getValue()));
            }
        }
        return result;
    }

    /** Returns the first non-null element of {@code values}, or {@code ""} when there is none. */
    private static String firstNonNullValue(List<String> values) {
        if (values == null) {
            return "";
        }
        for (String value : values) {
            if (value != null) {
                return value;
            }
        }
        return "";
    }

    /**
     * Extracts simple (non-file) fields from a multipart/form-data body.
     *
     * @param contentType the full Content-Type header value, including {@code boundary=...}
     * @param bodyString  the raw body; may be {@code null}
     * @return field name to value; empty when the body or boundary is missing
     */
    private Map<String, String> getFormParam(String contentType, String bodyString) {
        Map<String, String> formMap = new HashMap<>();
        int markerIndex = contentType.lastIndexOf(BOUNDARY_MARKER);
        if (bodyString == null || markerIndex < 0) {
            return formMap;
        }
        String boundary = contentType.substring(markerIndex + BOUNDARY_MARKER.length());
        // Drop any parameters that follow the boundary, and optional surrounding quotes.
        int paramEnd = boundary.indexOf(';');
        if (paramEnd >= 0) {
            boundary = boundary.substring(0, paramEnd);
        }
        boundary = boundary.trim();
        if (boundary.length() >= 2 && boundary.startsWith("\"") && boundary.endsWith("\"")) {
            boundary = boundary.substring(1, boundary.length() - 1);
        }
        if (boundary.isEmpty()) {
            return formMap;
        }
        // split() takes a regex, and boundaries may legally contain '+', '.', '(', ')' or '?'.
        String[] segments = bodyString.split(Pattern.quote(boundary));
        // The first segment precedes the first boundary; the last one is the closing "--".
        for (int x = 1; x < segments.length - 1; x++) {
            Matcher m = FORM_PART_PATTERN.matcher(segments[x]);
            if (m.find()) {
                formMap.put(m.group(1), m.group(2));
            }
        }
        return formMap;
    }

    /**
     * Parses an application/x-www-form-urlencoded body.
     *
     * <p>The body is split into pairs first and each key and value is decoded afterwards, so
     * encoded {@code &} ({@code %26}) and {@code =} ({@code %3D}) inside values survive.
     *
     * @param bodyStr the raw body; may be {@code null}
     * @return decoded field name to value; pairs without {@code =} are skipped
     */
    private Map<String, String> getXwFormParam(String bodyStr) {
        Map<String, String> paramMap = new HashMap<>();
        if (bodyStr == null || bodyStr.isEmpty()) {
            return paramMap;
        }
        for (String paramKeyValue : bodyStr.split(XW_FORM_PARAM_REGEX)) {
            // limit 2: keep '=' inside values and keep empty values ("a=")
            String[] keyValue = paramKeyValue.split(XW_KEY_VALUE_REGEX, 2);
            if (keyValue.length == 2) {
                paramMap.put(decodeUtf8(keyValue[0]), decodeUtf8(keyValue[1]));
            }
        }
        return paramMap;
    }

    /** URL-decodes {@code value} as UTF-8. */
    private static String decodeUtf8(String value) {
        try {
            return URLDecoder.decode(value, "utf-8");
        } catch (UnsupportedEncodingException e) {
            // UTF-8 is always supported by the JVM; this cannot happen in practice.
            throw new IllegalStateException("UTF-8 not supported", e);
        }
    }
}