Java Spring Boot 拦截器的使用小结

发布时间 2023-12-12 18:14:12 作者:进击的davis

很多时候,我们在开发项目时,总是希望接口中只专注于业务处理,把其余的通用事项交给其他组件来处理,比如:

  • 登录验证
  • 日志记录
  • 接口性能

在 Spring Boot 中,正如大多数框架一样,可以用拦截件进行处理。不管叫中间件还是拦截件,总之都是为了让我们更好地专注于业务、解耦功能。

我们看看 Spring Boot 中应该怎样应用拦截件。

环境:

  • Spring Boot:3.1.6
  • JDK:17

拦截件接口声明

如果我们自己写个拦截件,通常需要实现这个接口:

package org.springframework.web.servlet;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.lang.Nullable;

/**
 * Callback contract invoked by Spring MVC around handler execution.
 *
 * <p>Every method has a default implementation, so an interceptor only needs to override the
 * hooks it actually uses.
 */
public interface HandlerInterceptor {
    /**
     * Hook invoked before the handler runs.
     *
     * @param request current request
     * @param response current response
     * @param handler the handler chosen for this request
     * @return {@code true} to continue processing, {@code false} to stop here; the default
     *         implementation always returns {@code true}
     * @throws Exception any error raised by the implementation
     */
    default boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        return true;
    }

    /**
     * Hook invoked after the handler runs. The default implementation does nothing.
     *
     * @param modelAndView the model and view produced by the handler; may be {@code null}
     * @throws Exception any error raised by the implementation
     */
    default void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, @Nullable ModelAndView modelAndView) throws Exception {
    }

    /**
     * Hook invoked when request processing has completed. The default implementation does
     * nothing.
     *
     * @param ex exception raised during processing, if any; may be {@code null}
     * @throws Exception any error raised by the implementation
     */
    default void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, @Nullable Exception ex) throws Exception {
    }
}

正如接口声明一样[1]:

  • boolean preHandle(),方法在请求处理之前被调用。该方法在 Interceptor 类中最先执行,用来进行一些前置初始化操作或是对当前请求做预处理,也可以进行一些判断来决定请求是否要继续进行下去。该方法的返回值是 boolean 类型:当它返回 false 时,表示请求结束,后续的 Interceptor 和 Controller 都不会再执行;当它返回 true 时,会继续调用下一个 Interceptor 的 preHandle 方法,如果已经是最后一个 Interceptor,就会调用当前请求的 Controller 方法。

  • void postHandle(),方法在当前请求处理完成之后,也就是 Controller 方法调用之后执行,但它会在 DispatcherServlet 进行视图返回渲染之前被调用,所以我们可以在这个方法中对 Controller 处理之后的 ModelAndView 对象进行操作。注意:如果 Controller 抛出异常,postHandle 不会被调用;对于 @RestController 这类直接写出响应体的接口,modelAndView 参数为 null。

  • void afterCompletion(),方法需要在当前对应的 Interceptor 类的 preHandle 方法返回值为 true 时才会执行。顾名思义,该方法将在整个请求结束之后,也就是在 DispatcherServlet 渲染了对应的视图之后执行,即使处理过程中抛出了异常也会被调用。此方法主要用来进行资源清理。

拦截件的执行顺序

这里不考虑过滤器,根据注册拦截件的顺序和限定的 pathPattern,执行顺序如下(preHandle 按注册顺序依次执行,postHandle 和 afterCompletion 则按注册的逆序执行;pathPattern 不匹配当前请求的拦截件会被跳过):

image.png

接下来我们在某一需求中,实现我们的拦截件。

需求:

  • 需要统计某些接口的性能
  • 需要记录请求响应的日志
  • 需要统计接口的访问次数
  • 需要对需登录的接口做登录校验,拦截未登录的请求

基于上面的需求,我们逐步实现。

拦截件实现流程

说说大致流程:

1.明确拦截需求

2.编写拦截件

3.注册拦截件

4.编写业务接口

接下来是代码实现,在实现中我们暂不依赖其他第三方库。

代码目录:

image.png

拦截件的实现

性能拦截件

package com.example.springbootinterceptorsdemo.interceptors;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.core.NamedThreadLocal;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;

/**
 * function: record time cost of every api handler
 *
 * <p>{@code preHandle} stores a start timestamp in a thread-local. {@code afterCompletion} then
 * computes the elapsed milliseconds and prints the request URI for requests slower than
 * {@link #THRESHOLD}. The thread-local is always cleared in {@code afterCompletion}, because
 * request threads are typically pooled and outlive a single request.
 */
public class ApiBenchMarkInterceptor implements HandlerInterceptor {

    /** Monotonic start time ({@link System#nanoTime()}) of the request on this thread. */
    private final NamedThreadLocal<Long> startTimeThreadLocal = new NamedThreadLocal<>("StopWatch-StartTime");

    /** Requests that take longer than this many milliseconds are reported as slow. */
    private static final long THRESHOLD = 500;

    /**
     * Records the request start time.
     *
     * @return always {@code true}; this interceptor never blocks a request
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        System.out.println("------ApiBenchMarkInterceptor.preHandle------");
        // nanoTime is monotonic, so the result is unaffected by wall-clock adjustments.
        startTimeThreadLocal.set(System.nanoTime());
        return true;
    }

    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
        System.out.println("------ApiBenchMarkInterceptor.postHandle------");
    }

    /**
     * Computes the elapsed time, reports slow requests, and clears the thread-local.
     */
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        System.out.println("------ApiBenchMarkInterceptor.afterCompletion------");
        try {
            Long beginTime = startTimeThreadLocal.get();
            if (beginTime == null) {
                // Defensive: no start time recorded on this thread, so there is nothing to measure.
                return;
            }
            long cost = (System.nanoTime() - beginTime) / 1_000_000;

            // record slow request
            if (cost > THRESHOLD) {
                System.out.println(String.format("%s cost %d millis", request.getRequestURI(), cost));
            }
        } finally {
            // Remove the value so a reused worker thread never sees a stale start time
            // and the value does not leak.
            startTimeThreadLocal.remove();
        }
    }
}

接口访问拦截件

package com.example.springbootinterceptorsdemo.interceptors;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;

import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * function: record api access count
 */
public class ApiAccessRecordInterceptors implements HandlerInterceptor {
    private static Map<String, Integer> pv = new ConcurrentHashMap<>();

    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        System.out.println("------ApiAccessRecordInterceptors.preHandle------");
        HandlerMethod handlerMethod = (HandlerMethod) handler;
        Method method = handlerMethod.getMethod();
        if (pv.get(method.getName()) == null) {
            pv.put(method.getName(), 1);
        } else {
            Integer count = pv.get(method.getName());
            pv.put(method.getName(), count+1);
        }
        return true;
    }

    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
        System.out.println("------ApiAccessRecordInterceptors.postHandle------");
    }

    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        System.out.println("------ApiAccessRecordInterceptors.afterCompletion------");
        System.out.println(pv.toString());
    }
}

日志记录拦截件

package com.example.springbootinterceptorsdemo.interceptors;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;

/**
 * function: record request and response
 *
 * <p>Prints the client IP, request path and HTTP method before each request is handled.
 */
public class LogRecordInterceptors implements HandlerInterceptor {

    /**
     * Logs basic request information.
     *
     * @return always {@code true}; logging never blocks a request
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        String path = request.getRequestURI();
        String method = request.getMethod();
        String clientIp = resolveClientIp(request);
        System.out.println("------LogRecordInterceptors.preHandle------");
        System.out.println(String.format("Log -> clientIP: %s, path: %s, method: %s", clientIp, path, method));

        return true;
    }

    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
        System.out.println("------LogRecordInterceptors.postHandle------");
    }

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        System.out.println("------LogRecordInterceptors.afterCompletion------");
    }

    /**
     * Resolves the client address.
     *
     * <p>Sources are tried in this order: {@code X-Real-IP}, then the first entry of
     * {@code X-Forwarded-For}, then the socket peer address. The headers are client-supplied and
     * can be spoofed unless a trusted proxy sets them, so use the result for logging only.
     */
    private static String resolveClientIp(HttpServletRequest request) {
        String realIp = request.getHeader("X-Real-IP");
        if (realIp != null && !realIp.isBlank()) {
            return realIp.strip();
        }
        String forwardedFor = request.getHeader("X-Forwarded-For");
        if (forwardedFor != null && !forwardedFor.isBlank()) {
            // X-Forwarded-For is "client, proxy1, proxy2"; the first entry is the originating client.
            int comma = forwardedFor.indexOf(',');
            return (comma >= 0 ? forwardedFor.substring(0, comma) : forwardedFor).strip();
        }
        // No proxy headers: fall back to the direct peer instead of logging "null".
        return request.getRemoteAddr();
    }
}

登录验证拦截件

package com.example.springbootinterceptorsdemo.interceptors;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;

import java.lang.reflect.Method;

/**
 * function: check auth
 */
public class AuthCheckInterceptors implements HandlerInterceptor {
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        System.out.println("------AuthCheckInterceptors.postHandle------");
        String token = request.getHeader("token");
        // TODO: verify token
        return true;
    }

    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
        System.out.println("------AuthCheckInterceptors.postHandle------");
    }

    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        System.out.println("------AuthCheckInterceptors.afterCompletion------");
    }
}

注册拦截件

需要明确哪些接口需要用到,什么时候用到。

package com.example.springbootinterceptorsdemo.config;

import com.example.springbootinterceptorsdemo.interceptors.ApiAccessRecordInterceptors;
import com.example.springbootinterceptorsdemo.interceptors.ApiBenchMarkInterceptor;
import com.example.springbootinterceptorsdemo.interceptors.AuthCheckInterceptors;
import com.example.springbootinterceptorsdemo.interceptors.LogRecordInterceptors;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

/**
 * Registers the demo interceptors with Spring MVC.
 *
 * <p>Notes:
 * <ul>
 *   <li>Registration order matters: preHandle callbacks run in this order, so the performance
 *       interceptor is registered first to time as much of the chain as possible.</li>
 *   <li>A registration without path patterns intercepts every request.</li>
 *   <li>Exclude patterns, when added, intercept every request except the excluded paths.</li>
 * </ul>
 */
@Configuration
public class InterceptorsConfig implements WebMvcConfigurer {

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        // 1) timing -- only the /performance endpoint, registered first
        registry.addInterceptor(new ApiBenchMarkInterceptor())
                .addPathPatterns("/performance");

        // 2) per-method access counter -- every request
        registry.addInterceptor(new ApiAccessRecordInterceptors());

        // 3) request logging -- every request
        registry.addInterceptor(new LogRecordInterceptors());

        // 4) auth check -- user and article endpoints only
        registry.addInterceptor(new AuthCheckInterceptors())
                .addPathPatterns("/user/*", "/article/*");
    }
}

用户类

package com.example.springbootinterceptorsdemo.model;

/**
 * Mutable user model with name, password and age; {@code age} defaults to 20.
 */
public class User {

    private String name;
    private String password;
    private int age = 20;

    /** @return the user name, or {@code null} if unset */
    public String getName() {
        return this.name;
    }

    /** @return the password, or {@code null} if unset */
    public String getPassword() {
        return this.password;
    }

    /** @return the age (20 unless changed) */
    public int getAge() {
        return this.age;
    }

    /** @param name new user name */
    public void setName(String name) {
        this.name = name;
    }

    /** @param password new password */
    public void setPassword(String password) {
        this.password = password;
    }

    /** @param age new age */
    public void setAge(int age) {
        this.age = age;
    }
}

控制器类

index控制器类

package com.example.springbootinterceptorsdemo.controller;

import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

/**
 * Minimal landing endpoint for the demo application.
 */
@RestController
public class IndexController {

    /**
     * Handles {@code GET /index}: prints a marker line to stdout and returns a fixed greeting.
     *
     * @return the greeting text
     */
    @GetMapping("/index")
    public String index() {
        final String greeting = "Hello, this is Spring Boot web application!";
        System.out.println("this is from index handler");
        return greeting;
    }
}

性能测试控制器类

package com.example.springbootinterceptorsdemo.controller;

import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

import java.util.Random;

/**
 * Endpoint with a random delay, used to exercise the performance interceptor.
 */
@RestController
public class CostTimeController {

    /**
     * Handles {@code GET /performance}. It sleeps for a random 0-999 ms to simulate work, then
     * reports the measured duration.
     *
     * @return a message containing the elapsed milliseconds
     * @throws InterruptedException if the sleeping thread is interrupted
     */
    @GetMapping("/performance")
    public Object performance() throws InterruptedException {
        final long startedAt = System.currentTimeMillis();
        final long pauseMillis = new Random().nextLong(1000);
        Thread.sleep(pauseMillis);
        final long elapsed = System.currentTimeMillis() - startedAt;
        return "performance is ok, cost: " + elapsed;
    }
}

登录控制器类

package com.example.springbootinterceptorsdemo.controller;

import com.example.springbootinterceptorsdemo.model.User;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;

import java.util.UUID;

/**
 * Demo login endpoint backed by a single hard-coded account.
 *
 * <p>On success, a random UUID is returned in the {@code token} response header. The token is
 * not stored anywhere in this class.
 */
@RestController
public class AuthController {
    // Demo-only credentials; never hard-code real secrets.
    private static final String USERNAME = "admin";

    private static final String PASSWORD = "12345678";

    /**
     * Authenticates the posted user.
     *
     * @param user credentials from the request body
     * @return {@code "login success!"} with a {@code token} header on success; otherwise
     *         {@code "login failed!"} with HTTP status 401
     */
    @PostMapping("/login")
    public Object login(HttpServletRequest request, HttpServletResponse response, @RequestBody User user) {
        // Compare constant-first: a body without "name"/"password" yields null fields, and
        // this order makes that an ordinary failed login instead of a NullPointerException.
        if (USERNAME.equals(user.getName()) && PASSWORD.equals(user.getPassword())) {
            String token = UUID.randomUUID().toString();
            response.addHeader("token", token);
            return "login success!";
        }

        // Wrong credentials must not be reported as success.
        response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
        return "login failed!";
    }
}

用户增删改查控制器类

package com.example.springbootinterceptorsdemo.controller;

import com.example.springbootinterceptorsdemo.model.User;

import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import java.util.Base64;

/**
 * Demo user endpoints under {@code /user}.
 */
@RestController
@RequestMapping("/user")
public class UserController {

    /**
     * Handles {@code GET /user/get}. It returns a fixed demo user whose password field holds the
     * Base64 encoding of the demo password.
     *
     * @return the demo {@link User}
     */
    @GetMapping("/get")
    public Object getUserInfo() {
        User user = new User();
        user.setName("admin");
        user.setAge(88);
        // encode(byte[]).toString() returns the byte array's identity string (e.g. "[B@1b6d3586"),
        // not Base64 text; encodeToString produces the intended encoded string.
        // NOTE(review): Base64 is reversible encoding, not hashing -- consider not returning
        // the password in API responses at all.
        user.setPassword(Base64.getEncoder().encodeToString(
                "12345678".getBytes(java.nio.charset.StandardCharsets.UTF_8)));

        return user;
    }
}

参考: