Netty 实现同一个端口接收并解析多种协议（TCP / HTTP）

发布时间 2023-08-21 16:26:47 作者: wang_longan

1、背景

项目需求：同一个端口既要能接收并解析 TCP 协议数据，又要能接收并解析 HTTP 协议数据。简单使用 Java Socket 也能做到，但当客户端通过 POST 请求发送二进制文件时，自己用 Socket 解析就很困难，因为难以判断二进制文件在字节流中的开始和结束（需要自行处理 Content-Length、chunked 编码、multipart 边界等）。

由于 Netty 自带现成的 HTTP 协议编解码器，使用 Netty 可以大大简化该需求的实现。

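2、使用 Netty 的实现思路：每个新连接的 pipeline 中先只放一个协议识别处理器（MyPortUnificationServerHandler）。它查看报文开头的字节，若与 HTTP 请求方法（GET、POST、PUT 等）的前两个字符匹配，就向 pipeline 添加 HTTP 编解码器和 HttpMsgHandler；否则添加字符串编解码器和 SocketMsgHandler。随后它把自己从 pipeline 中移除，已缓存的数据会交给新添加的处理器继续处理。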

3、贴上源代码

查看代码
 package com.example.demo.netty;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.Epoll;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollServerSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import lombok.extern.slf4j.Slf4j;

/**
 * Netty bootstrap entry point.
 * <p>
 * Binds a single port whose child channels start with a {@link MyPortUnificationServerHandler}.
 * That handler inspects the first bytes of each connection and switches the pipeline to either
 * HTTP or plain-TCP handling. The native epoll transport is used when it is available; otherwise
 * the server falls back to NIO.
 */
@Slf4j
public class NettyServer {

    /**
     * Binds {@code port} and blocks the calling thread until the server channel is closed.
     * <p>
     * Failures (including bind errors) are logged, not thrown. Both event-loop groups are shut
     * down gracefully on every exit path. If the waiting thread is interrupted, its interrupt
     * flag is restored before returning.
     *
     * @param port TCP port to listen on
     */
    public void start(int port) {
        // Decide once so that the event-loop groups and the server channel class always agree.
        final boolean useEpoll = Epoll.isAvailable();
        EventLoopGroup boss = useEpoll ? new EpollEventLoopGroup() : new NioEventLoopGroup();
        EventLoopGroup worker = useEpoll ? new EpollEventLoopGroup() : new NioEventLoopGroup();
        ServerBootstrap serverBootstrap = new ServerBootstrap();
        try {
            serverBootstrap.group(boss, worker)
                    .channel(useEpoll ? EpollServerSocketChannel.class : NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel ch) {
                            // The detector is a stateful ByteToMessageDecoder, so every channel
                            // gets its own instance.
                            ch.pipeline().addLast(new MyPortUnificationServerHandler());
                        }
                    })
                    .childOption(ChannelOption.SO_KEEPALIVE, true)
                    .childOption(ChannelOption.TCP_NODELAY, true);

            ChannelFuture future = serverBootstrap.bind(port).sync();
            future.channel().closeFuture().sync();
        } catch (InterruptedException e) {
            // Restore the flag so callers further up the stack can still observe the interruption.
            Thread.currentThread().interrupt();
            log.error("e: ", e);
        } catch (Exception e) {
            log.error("e: ", e);
        } finally {
            boss.shutdownGracefully();
            worker.shutdownGracefully();
        }
    }
}
查看代码
 package com.example.demo.netty;

import com.example.demo.netty.http.HttpMsgHandler;
import com.example.demo.netty.tcp.SocketMsgHandler;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.codec.string.StringDecoder;
import io.netty.handler.codec.string.StringEncoder;
import io.netty.util.CharsetUtil;
import lombok.extern.slf4j.Slf4j;

import java.util.List;

/**
 * Port-unification handler: serves TCP and HTTP requests on the same port.
 * <p>
 * HTTP itself runs on top of TCP. This handler therefore peeks at the first bytes of a connection
 * to decide whether the peer is speaking HTTP or plain TCP. It then installs the matching codecs
 * and business handler, and removes itself from the pipeline. Because it is a
 * {@link ByteToMessageDecoder}, the bytes it has buffered so far are handed to the newly added
 * handlers when it is removed.
 * <p>
 * NOTE(review): detection is a two-byte heuristic. A binary TCP payload that happens to start
 * with e.g. "GE" or "PO" will be routed to the HTTP decoder.
 */
@Slf4j
public class MyPortUnificationServerHandler extends ByteToMessageDecoder {

    /** Number of leading bytes examined by {@link #isHttp(int, int)}. */
    private static final int DETECTION_BYTES = 2;

    /**
     * Waits until enough bytes are available, picks a protocol, rewires the pipeline and removes
     * this handler. Nothing is added to {@code list}; the buffered bytes flow to the new handlers.
     */
    @Override
    protected void decode(ChannelHandlerContext channelHandlerContext, ByteBuf byteBuf, List<Object> list) throws Exception {
        // Only two bytes are inspected, so only two are required. Waiting for more (the old
        // threshold was 5) left short TCP messages such as "ok" or "ping" stuck forever.
        if (byteBuf.readableBytes() < DETECTION_BYTES) {
            return;
        }
        // Absolute-index getters peek without advancing readerIndex, so the bytes stay available
        // for the protocol handlers installed below.
        final int magic1 = byteBuf.getUnsignedByte(byteBuf.readerIndex());
        final int magic2 = byteBuf.getUnsignedByte(byteBuf.readerIndex() + 1);

        if (isHttp(magic1, magic2)) {
            log.info("this is a http msg");
            switchToHttp(channelHandlerContext);
        } else {
            log.info("this is a socket msg");
            switchToTcp(channelHandlerContext);
        }
        channelHandlerContext.pipeline().remove(this);
    }

    /**
     * Switches the pipeline to plain-TCP handling: UTF-8 string decoding/encoding followed by
     * {@link SocketMsgHandler}.
     * <p>
     * NOTE(review): no frame decoder is installed. Each {@code String} is whatever a single read
     * delivered, so one logical message may be split or merged across reads.
     *
     * @param ctx context of this handler, used to reach the pipeline
     */
    private void switchToTcp(ChannelHandlerContext ctx) {
        ChannelPipeline pipeline = ctx.pipeline();
        // Decoder
        pipeline.addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8));
        // Encoder
        pipeline.addLast("stringEncoder", new StringEncoder(CharsetUtil.UTF_8));
        // Application handler for TCP messages.
        pipeline.addLast(new SocketMsgHandler());
    }

    /**
     * Switches the pipeline to HTTP handling. The handlers installed are a request decoder, a
     * response encoder, an aggregator that assembles a complete request (body up to 8 MiB), and
     * {@link HttpMsgHandler}.
     *
     * @param ctx context of this handler, used to reach the pipeline
     */
    private void switchToHttp(ChannelHandlerContext ctx) {
        ChannelPipeline p = ctx.pipeline();

        p.addLast(new HttpRequestDecoder())
                .addLast(new HttpResponseEncoder())
                .addLast(new HttpObjectAggregator(1024 * 1024 * 8)) // 8 MiB max aggregated body
                .addLast(new HttpMsgHandler());
    }

    /**
     * Returns whether the first two bytes match the start of an HTTP request method.
     *
     * @param magic1 first byte of the stream (unsigned)
     * @param magic2 second byte of the stream (unsigned)
     * @return {@code true} for GET, POST, PUT, HEAD, OPTIONS, PATCH, DELETE, TRACE or CONNECT
     */
    private boolean isHttp(int magic1, int magic2) {
        return magic1 == 'G' && magic2 == 'E' || // GET
                magic1 == 'P' && magic2 == 'O' || // POST
                magic1 == 'P' && magic2 == 'U' || // PUT
                magic1 == 'H' && magic2 == 'E' || // HEAD
                magic1 == 'O' && magic2 == 'P' || // OPTIONS
                magic1 == 'P' && magic2 == 'A' || // PATCH
                magic1 == 'D' && magic2 == 'E' || // DELETE
                magic1 == 'T' && magic2 == 'R' || // TRACE
                magic1 == 'C' && magic2 == 'O';   // CONNECT
    }
}

 

查看代码
 package com.example.demo.netty.tcp;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import java.io.IOException;
import java.net.InetSocketAddress;


/**
 * Socket数据处理器
 *
 */
@Slf4j
@Component
public class SocketMsgHandler extends SimpleChannelInboundHandler<String> {

    public static SocketMsgHandler socketMsgHandler;

    @PostConstruct
    public void init() {
        socketMsgHandler = this;
    }


    /**
     * 处理TCP协议数据
     * @param ctx           the {@link ChannelHandlerContext} which this {@link SimpleChannelInboundHandler}
     *                      belongs to
     * @param msg           the message to handle
     * @throws Exception
     */
    @Override
    public void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception {
        log.info("tcp socket msg: {}", msg);
        String resp = "hello world";


        ctx.writeAndFlush(resp);
    }




    /**
     * 从客户端收到新的数据、读取完成时调用
     *
     * @param ctx
     */
    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws IOException {
        log.info("channelReadComplete");
        ctx.flush();
    }

    /**
     * 当出现 Throwable 对象才会被调用,即当 Netty 由于 IO 错误或者处理器在处理事件时抛出的异常时
     *
     * @param ctx
     * @param cause
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws IOException {
        log.error("exceptionCaught: {}", cause.toString());
        ctx.close();//抛出异常,断开与客户端的连接
    }

    /**
     * 客户端与服务端第一次建立连接时 执行
     *
     * @param ctx
     * @throws Exception
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception, IOException {
        super.channelActive(ctx);
        ctx.channel().read();
        InetSocketAddress insocket = (InetSocketAddress) ctx.channel().remoteAddress();
        String clientIp = insocket.getAddress().getHostAddress();
        //此处不能使用ctx.close(),否则客户端始终无法与服务端建立连接
        log.info("channelActive:" + clientIp + ctx.name());
    }

    /**
     * 客户端与服务端 断连时 执行
     *
     * @param ctx
     * @throws Exception
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception, IOException {
        super.channelInactive(ctx);
        InetSocketAddress insocket = (InetSocketAddress) ctx.channel().remoteAddress();
        String clientIp = insocket.getAddress().getHostAddress();
        ctx.close(); //断开连接时,必须关闭,否则造成资源浪费,并发量很大情况下可能造成宕机
        log.info("channelInactive:" + clientIp);
    }

    /**
     * 服务端当read超时, 会调用这个方法
     *
     * @param ctx
     * @param evt
     * @throws Exception
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception, IOException {
        super.userEventTriggered(ctx, evt);
        InetSocketAddress insocket = (InetSocketAddress) ctx.channel().remoteAddress();
        String clientIp = insocket.getAddress().getHostAddress();
        ctx.close();//超时时断开连接
        log.info("userEventTriggered:" + clientIp);
    }
}
查看代码
 package com.example.demo.netty.http;

import java.io.File;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import com.alibaba.fastjson.JSONObject;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.handler.codec.http.multipart.Attribute;
import io.netty.handler.codec.http.multipart.DiskFileUpload;
import io.netty.handler.codec.http.multipart.FileUpload;
import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder;
import io.netty.handler.codec.http.multipart.InterfaceHttpData;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;


/**
 * http协议格式数据的处理器
 */
@Slf4j
public class HttpMsgHandler extends SimpleChannelInboundHandler<FullHttpRequest> {


    static {
        DiskFileUpload.baseDirectory = "E:\\temporary\\upload";
    }


    @Override
    protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest httpRequest) throws Exception {

        String resBody = "hello world";

        //1.根据请求类型做出处理
        HttpMethod type = httpRequest.method();

        if (type.equals(HttpMethod.GET)) {
            //Get请求
            String getRespBody = parseGet(httpRequest);
            if (StringUtils.isNotBlank(getRespBody)) {
                resBody = getRespBody;
            }
        }
        else if (type.equals(HttpMethod.POST)) {
            //post请求
            parsePost(httpRequest);
        }
        else {
            log.error("不支持的请求方式,{}", type);
        }
        log.info("resp: {}", resBody);


        //给客户端写数据
        writeResponse(ctx, httpRequest, HttpResponseStatus.OK, resBody);

    }

    /**
     * 给客户端响应
     * @param ctx
     * @param fullHttpRequest
     * @param status
     * @param msg
     */
    private void writeResponse(ChannelHandlerContext ctx, FullHttpRequest fullHttpRequest, HttpResponseStatus status, String msg) {
        //创建一个默认的响应对象
        FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status);
        //写入数据
        response.content().writeBytes(msg.getBytes(StandardCharsets.UTF_8));
        //设置响应头--content-type
        response.headers().set(HttpHeaderNames.CONTENT_TYPE, "application/json");
        //设置内容长度--content-length
        HttpUtil.setContentLength(response, response.content().readableBytes());
        boolean keepAlive = HttpUtil.isKeepAlive(fullHttpRequest);
        if (keepAlive) {
            response.headers().set(HttpHeaderNames.CONNECTION, "keep-alive");
        }
        ctx.writeAndFlush(response);
    }

    /**
     * 处理get请求
     * @param httpRequest
     */
    private String parseGet(FullHttpRequest httpRequest) {
        log.info("request uri: {}", httpRequest.uri());
        //通过请求url获取参数信息
        Map<String, String> paramMap = parseKvStr(httpRequest.uri(), true);

        String responseBody = null;
        if (StringUtils.contains(httpRequest.uri(), "/ping?flag=")) {
            String flag = paramMap.get("flag");
            log.info("flag: {}", flag);
            responseBody = flag;

            return responseBody;
        }

        return responseBody;
    }

    /**
     * 从url中获取参数信息
     * @param uri 请求的url
     * @param hasPath
     */
    private Map<String, String> parseKvStr(String uri, boolean hasPath) {
        QueryStringDecoder queryStringDecoder = new QueryStringDecoder(uri, StandardCharsets.UTF_8, hasPath);
        Map<String, List<String>> parameters = queryStringDecoder.parameters();
        Map<String, String> queryParams = new HashMap<>();
        for (Map.Entry<String, List<String>> attr : parameters.entrySet()) {
            for (String attrVal : attr.getValue()) {
                queryParams.put(attr.getKey(), attrVal);
            }
        }
        return queryParams;
    }

    /**
     * 处理post请求
     * application/json
     * application/x-www-form-urlencoded
     * multipart/form-data
     * @param httpRequest
     */
    private void parsePost(FullHttpRequest httpRequest) {
        String contentType = getContentType(httpRequest);
        switch (contentType) {
            case "application/json":
                parseJson(httpRequest);
                break;
            case "application/x-www-form-urlencoded":
                parseFormData(httpRequest);
                break;
            case "multipart/form-data":
                parseMultipart(httpRequest);
                break;
            default:
                log.error("不支持的数据类型:{}", contentType);
                break;
        }

    }


    /**
     * 处理文件上传
     * 在该方法中的解析方式,同样也适用于解析普通的表单提交请求
     * 通用(普通post,文件上传)
     * @param httpRequest
     */
    private void parseMultipart(FullHttpRequest httpRequest) {
        HttpPostRequestDecoder httpPostRequestDecoder = new HttpPostRequestDecoder(httpRequest);
        //判断是否是multipart
        if (httpPostRequestDecoder.isMultipart()) {
            //获取body中的数据
            List<InterfaceHttpData> bodyHttpDatas = httpPostRequestDecoder.getBodyHttpDatas();
            for (InterfaceHttpData dataItem : bodyHttpDatas) {
                //判断表单项的类型
                InterfaceHttpData.HttpDataType httpDataType = dataItem.getHttpDataType();
                if (httpDataType.equals(InterfaceHttpData.HttpDataType.Attribute)){
                    //普通表单项,直接获取数据
                    Attribute attribute = (Attribute) dataItem;
                    try {
                        log.info("表单项名称:{},表单项值:{}",attribute.getName(),attribute.getValue());
                    } catch (IOException e) {
                        log.error("获取表单项数据错误,msg={}",e.getMessage());
                    }
                } else if (httpDataType.equals(InterfaceHttpData.HttpDataType.FileUpload)){
                    //文件上传项,将文件保存到磁盘
                    FileUpload fileUpload = (FileUpload) dataItem;
                    //获取原始文件名称
                    String filename = fileUpload.getFilename();
                    //获取表单项名称
                    String name = fileUpload.getName();
                    log.info("文件名称:{},表单项名称:{}",filename,name);
                    //将文件保存到磁盘
                    if (fileUpload.isCompleted()) {
                        try {
                            String path = DiskFileUpload.baseDirectory + File.separator + filename;
                            System.out.println("path: " + path);
                            fileUpload.renameTo(new File(path));
                        } catch (IOException e) {
                            log.error("文件转存失败,msg={}",e.getMessage());
                        }
                    }
                }
            }
        }
    }

    /**
     * 处理表单数据
     * @param httpRequest
     */
    private void parseFormData(FullHttpRequest httpRequest) {
        //两个部分有数据  uri,body
        parseKvStr(httpRequest.uri(), true);
        parseKvStr(httpRequest.content().toString(StandardCharsets.UTF_8), false);
    }

    /**
     * 处理json数据
     * @param httpRequest
     */
    private void parseJson(FullHttpRequest httpRequest) {
        String jsonStr = httpRequest.content().toString(StandardCharsets.UTF_8);
        JSONObject jsonObject = JSONObject.parseObject(jsonStr);
        jsonObject.forEach((key, value) -> log.info("item:{}={}", key, value));
    }

    /**
     * 获取请求数据类型
     * @param httpRequest
     * @return
     */
    private String getContentType(FullHttpRequest httpRequest) {
        HttpHeaders headers = httpRequest.headers();
        String contentType = headers.get(HttpHeaderNames.CONTENT_TYPE);
        return contentType.split(";")[0];
    }



    /**
     * 从客户端收到新的数据、读取完成时调用
     *
     * @param ctx
     */
    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws IOException {
        log.info("channelReadComplete");
        ctx.flush();
    }

    /**
     * 当出现 Throwable 对象才会被调用,即当 Netty 由于 IO 错误或者处理器在处理事件时抛出的异常时
     *
     * @param ctx
     * @param cause
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws IOException {
        log.error("exceptionCaught: {}", cause.toString());
        ctx.close();//抛出异常,断开与客户端的连接
    }

    /**
     * 客户端与服务端第一次建立连接时 执行
     *
     * @param ctx
     * @throws Exception
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception, IOException {
        super.channelActive(ctx);
        ctx.channel().read();
        InetSocketAddress insocket = (InetSocketAddress) ctx.channel().remoteAddress();
        String clientIp = insocket.getAddress().getHostAddress();
        //此处不能使用ctx.close(),否则客户端始终无法与服务端建立连接
        log.info("channelActive:" + clientIp + ctx.name());
    }

    /**
     * 客户端与服务端 断连时 执行
     *
     * @param ctx
     * @throws Exception
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception, IOException {
        super.channelInactive(ctx);
        InetSocketAddress insocket = (InetSocketAddress) ctx.channel().remoteAddress();
        String clientIp = insocket.getAddress().getHostAddress();
        ctx.close(); //断开连接时,必须关闭,否则造成资源浪费,并发量很大情况下可能造成宕机
        log.info("channelInactive:" + clientIp);
    }

    /**
     * 服务端当read超时, 会调用这个方法
     *
     * @param ctx
     * @param evt
     * @throws Exception
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception, IOException {
        super.userEventTriggered(ctx, evt);
        InetSocketAddress insocket = (InetSocketAddress) ctx.channel().remoteAddress();
        String clientIp = insocket.getAddress().getHostAddress();
        ctx.close();//超时时断开连接
        log.info("userEventTriggered:" + clientIp);
    }
}