go-zero开发入门-网关往rpc服务传递额外数据

发布时间:2023-12-13 21:52:50 作者:-见

go-zero 的网关服务实际上是一个 go-zero 的 API 服务,也就是一个 HTTP 服务(或者说 REST 服务)。HTTP 到 gRPC 的转换使用了开源的 grpcurl 库。当网关需要往 rpc 服务传递额外的数据(比如鉴权数据)时,可以通过 HTTP header 来实现:

// AuthMiddleware authenticates the request via authClient and forwards the
// resulting user ID to the downstream rpc service through an HTTP header.
//
// The gateway only forwards headers whose name starts with "Grpc-Metadata-"
// and rewrites that prefix to "gateway-", so the rpc service reads this value
// under the metadata key "gateway-myuid".
//
// If authentication fails, the request is rejected with 401 and next is not
// called.
func AuthMiddleware(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) {
	// NOTE(review): authReq is assumed to be built from r elsewhere; it is not shown here.
	authResp, err := authClient.Authenticate(r.Context(), &authReq) // call the auth service
	if err != nil {
		// Without this check authResp may be unusable and an unauthenticated
		// request would be forwarded to the rpc service.
		http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
		return
	}

	// Set (rather than Add) replaces any value the client sent under the same
	// header, so a caller cannot spoof the forwarded user ID.
	r.Header.Set("Grpc-Metadata-myuid", authResp.UserId) // pass extra data to the rpc service
	next.ServeHTTP(w, r)
}

rpc 服务端从 metadata 取出数据:

// QueryUser reads the user ID that the gateway forwarded as gRPC metadata.
//
// The gateway turns the HTTP header "Grpc-Metadata-myuid" into the metadata
// key "gateway-myuid". metadata.ValueFromIncomingContext returns every value
// stored under that key. An error is returned when there are none, for
// example when the call did not come through the gateway.
func (l *QueryUserLogic) QueryUser(in *user.UserReq) (*user.UserResp, error) {
	vals := metadata.ValueFromIncomingContext(l.ctx, "gateway-myuid")
	if len(vals) == 0 {
		// Indexing vals[0] unconditionally would panic here.
		// NOTE(review): status.Error(codes.Unauthenticated, ...) would give
		// callers a precise gRPC code; a plain error is reported as Unknown.
		return nil, errors.New("missing gateway-myuid metadata")
	}
	uid := vals[0]

	// TODO: look up the user identified by uid; the query is elided in this excerpt.
	_ = uid
	return &user.UserResp{}, nil
}

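这里有两个需要注意的地方:在网关侧,header 名必须以“Grpc-Metadata-”打头;而在 rpc 服务端,读取时使用的 key 必须以“gateway-”打头。这是 go-zero 的 gateway/internal/headerprocessor.go 中写死的规则。另外,http.Header.Set 会把 header 名规范化为 Grpc-Metadata-Myuid 的形式,而 gRPC metadata 的 key 统一为小写,所以服务端用 gateway-myuid 即可取到该值: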

// Header-name prefixes used when translating HTTP headers into gRPC metadata.
const (
	metadataHeaderPrefix = "Grpc-Metadata-" // HTTP header prefix that the gateway forwards
	metadataPrefix       = "gateway-"       // replacement prefix the rpc service sees
)

// ProcessHeaders builds the headers for the gateway from HTTP headers.
//
// Behavior:
//   - Only headers whose name starts with metadataHeaderPrefix are kept. The
//     check is case-sensitive, so it relies on the canonical header form
//     produced by http.Header.Set.
//   - For kept headers, the prefix is swapped for metadataPrefix and every
//     value yields one "key:value" entry.
//   - All other headers are dropped.
//
// The result is nil when nothing matches. Its order follows map iteration
// and is therefore not deterministic.
func ProcessHeaders(header http.Header) []string {
	var result []string
	for name, values := range header {
		if strings.HasPrefix(name, metadataHeaderPrefix) {
			key := metadataPrefix + name[len(metadataHeaderPrefix):]
			for _, value := range values {
				result = append(result, fmt.Sprintf("%s:%s", key, value))
			}
		}
	}
	return result
}

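在文件 zrpc/internal/clientinterceptors/tracinginterceptor.go 中调用了 metadata.NewOutgoingContext。注意它先通过 metadata.FromOutgoingContext 取出 ctx 中已有的 metadata,注入 trace 信息后再写回,因此网关放入 outgoing context 的 metadata 不会被覆盖: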

// startSpan starts a client-kind span for an outgoing rpc call to method on
// target. It returns a context that carries both the span and the outgoing
// gRPC metadata, with the span context injected via the global text-map
// propagator.
//
// Metadata already attached to ctx is reused as the starting point. That
// matters because metadata.NewOutgoingContext replaces the outgoing metadata
// rather than merging it.
func startSpan(ctx context.Context, method, target string) (context.Context, trace.Span) {
	md, found := metadata.FromOutgoingContext(ctx)
	if !found {
		md = metadata.MD{}
	}

	tracer := otel.Tracer(ztrace.TraceName)
	spanName, spanAttrs := ztrace.SpanInfo(method, target)
	spanCtx, span := tracer.Start(ctx, spanName,
		trace.WithSpanKind(trace.SpanKindClient),
		trace.WithAttributes(spanAttrs...))

	// Injection must happen after Start so the new span's context is written.
	ztrace.Inject(spanCtx, otel.GetTextMapPropagator(), &md)
	return metadata.NewOutgoingContext(spanCtx, md), span
}

// UnaryTracingInterceptor is a grpc.UnaryClientInterceptor that wraps each
// unary call in an OpenTelemetry client span.
//
// The span behaves as follows:
//   - It records a "message sent" event for req before the call and a
//     "message received" event for reply after it.
//   - On success it records the OK status code.
//   - On failure it marks the span as an error. It uses the gRPC status
//     message and code when err carries a status, and err.Error() otherwise.
//
// The invoker's error is returned unchanged.
func UnaryTracingInterceptor(ctx context.Context, method string, req, reply any,
	cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
	ctx, span := startSpan(ctx, method, cc.Target())
	defer span.End()

	ztrace.MessageSent.Event(ctx, 1, req)
	callErr := invoker(ctx, method, req, reply, cc, opts...)
	ztrace.MessageReceived.Event(ctx, 1, reply)

	if callErr == nil {
		span.SetAttributes(ztrace.StatusCodeAttr(gcodes.OK))
		return nil
	}

	st, isStatus := status.FromError(callErr)
	if !isStatus {
		span.SetStatus(codes.Error, callErr.Error())
		return callErr
	}
	span.SetStatus(codes.Error, st.Message())
	span.SetAttributes(ztrace.StatusCodeAttr(st.Code()))
	return callErr
}

这两个 tracing 拦截器分别在以下位置注册:

./zrpc/internal/rpcserver.go:		interceptors = append(interceptors, serverinterceptors.UnaryTracingInterceptor)
./zrpc/internal/client.go:		interceptors = append(interceptors, clientinterceptors.UnaryTracingInterceptor)

服务端代码:

//zrpc/internal/rpcserver.go
// buildUnaryInterceptors assembles the server's unary interceptor list.
// The list is in the order it is passed to grpc.ChainUnaryInterceptor, where
// the first element is the outermost.
//
// Built-in interceptors are included only when the matching flag in
// s.middlewares is set, in this fixed order: tracing, recover, stat,
// prometheus, breaker. Interceptors from s.unaryInterceptors always come
// after them.
func (s *rpcServer) buildUnaryInterceptors() []grpc.UnaryServerInterceptor {
	mw := s.middlewares
	var chain []grpc.UnaryServerInterceptor

	if mw.Trace {
		chain = append(chain, serverinterceptors.UnaryTracingInterceptor)
	}
	if mw.Recover {
		chain = append(chain, serverinterceptors.UnaryRecoverInterceptor)
	}
	if mw.Stat {
		// The stat interceptor is only constructed when it is enabled.
		statInterceptor := serverinterceptors.UnaryStatInterceptor(s.metrics, mw.StatConf)
		chain = append(chain, statInterceptor)
	}
	if mw.Prometheus {
		chain = append(chain, serverinterceptors.UnaryPrometheusInterceptor)
	}
	if mw.Breaker {
		chain = append(chain, serverinterceptors.UnaryBreakerInterceptor)
	}

	// User-registered interceptors sit inside the built-in ones.
	chain = append(chain, s.unaryInterceptors...)
	return chain
}

// Start listens on TCP s.address and builds a gRPC server from s.options plus
// the unary and stream interceptor chains.
//
// It calls register with the new server, presumably so the application can
// add its services. It then registers and resumes the health service when
// s.health is set, and marks the health manager ready.
//
// It returns the net.Listen error if listening fails, otherwise whatever
// server.Serve returns.
func (s *rpcServer) Start(register RegisterFn) error {
	lis, err := net.Listen("tcp", s.address)
	if err != nil {
		return err
	}

	unaryInterceptorOption := grpc.ChainUnaryInterceptor(s.buildUnaryInterceptors()...)
	streamInterceptorOption := grpc.ChainStreamInterceptor(s.buildStreamInterceptors()...)

	// NOTE(review): if s.options has spare capacity, append writes into its
	// backing array; harmless as long as Start runs once per server — confirm.
	options := append(s.options, unaryInterceptorOption, streamInterceptorOption)
	server := grpc.NewServer(options...)
	register(server)

	// register the health check service
	if s.health != nil {
		grpc_health_v1.RegisterHealthServer(server, s.health)
		s.health.Resume()
	}
	s.healthManager.MarkReady()
	health.AddProbe(s.healthManager)

	// we need to make sure all others are wrapped up,
	// so we do graceful stop at shutdown phase instead of wrap up phase
	waitForCalled := proc.AddShutdownListener(func() {
		if s.health != nil {
			s.health.Shutdown()
		}
		server.GracefulStop()
	})
	// Presumably blocks until the shutdown listener above has run, so Start
	// does not return before the graceful stop finishes — confirm against proc.
	defer waitForCalled()

	return server.Serve(lis)
}

客户端代码:

//zrpc/internal/client.go
// buildUnaryInterceptors returns the client-side unary interceptors that are
// enabled in c.middlewares, in this fixed order: tracing, duration,
// prometheus, breaker, and finally a timeout interceptor configured with
// timeout. The result is nil when every middleware is disabled.
func (c *client) buildUnaryInterceptors(timeout time.Duration) []grpc.UnaryClientInterceptor {
	mw := c.middlewares
	var chain []grpc.UnaryClientInterceptor

	if mw.Trace {
		chain = append(chain, clientinterceptors.UnaryTracingInterceptor)
	}
	if mw.Duration {
		chain = append(chain, clientinterceptors.DurationInterceptor)
	}
	if mw.Prometheus {
		chain = append(chain, clientinterceptors.PrometheusInterceptor)
	}
	if mw.Breaker {
		chain = append(chain, clientinterceptors.BreakerInterceptor)
	}
	if mw.Timeout {
		// The timeout interceptor is only constructed when it is enabled.
		timeoutInterceptor := clientinterceptors.TimeoutInterceptor(timeout)
		chain = append(chain, timeoutInterceptor)
	}

	return chain
}

// buildDialOptions applies opts to a zero ClientOptions and translates the
// result into gRPC dial options, in this order:
//   - insecure transport credentials, unless Secure is set;
//   - grpc.WithBlock, unless NonBlock is set;
//   - the unary interceptor chain (built with Timeout);
//   - the stream interceptor chain;
//   - any raw DialOptions carried by the options, appended last.
func (c *client) buildDialOptions(opts ...ClientOption) []grpc.DialOption {
	var conf ClientOptions
	for _, apply := range opts {
		apply(&conf)
	}

	var dialOpts []grpc.DialOption
	if !conf.Secure {
		dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
	}
	if !conf.NonBlock {
		dialOpts = append(dialOpts, grpc.WithBlock())
	}

	unaryChain := grpc.WithChainUnaryInterceptor(c.buildUnaryInterceptors(conf.Timeout)...)
	streamChain := grpc.WithChainStreamInterceptor(c.buildStreamInterceptors()...)
	dialOpts = append(dialOpts, unaryChain, streamChain)

	return append(dialOpts, conf.DialOptions...)
}

// dial connects to server using the dial options built from opts and stores
// the connection in c.conn. The attempt is bounded by dialTimeout.
//
// On failure, the returned error wraps the underlying gRPC error, so callers
// can use errors.Is / errors.As on it. When the dial timed out, the message
// names the service: the part of server after the last separator, if that
// separator is neither the first nor the last character.
func (c *client) dial(server string, opts ...ClientOption) error {
	options := c.buildDialOptions(opts...)
	timeCtx, cancel := context.WithTimeout(context.Background(), dialTimeout)
	defer cancel()

	conn, err := grpc.DialContext(timeCtx, server, options...)
	if err != nil {
		service := server
		if errors.Is(err, context.DeadlineExceeded) {
			pos := strings.LastIndexByte(server, separator)
			// len(server) - 1 is the index of last char
			if 0 < pos && pos < len(server)-1 {
				service = server[pos+1:]
			}
		}
		// %w keeps err in the error chain; the rendered text matches the old %s form.
		return fmt.Errorf("rpc dial: %s, error: %w, make sure rpc service %q is already started",
			server, err, service)
	}

	c.conn = conn
	return nil
}

// NewClient returns a Client connected to target.
//
// The client uses the given middleware configuration. A default service
// config selecting the p2c load-balancing policy is placed before the
// caller's opts. If the dial fails, its error is returned with a nil Client.
func NewClient(target string, middlewares ClientMiddlewaresConf, opts ...ClientOption) (Client, error) {
	cli := client{middlewares: middlewares}

	serviceConfig := fmt.Sprintf(`{"loadBalancingPolicy":"%s"}`, p2c.Name)
	allOpts := make([]ClientOption, 0, len(opts)+1)
	allOpts = append(allOpts, WithDialOption(grpc.WithDefaultServiceConfig(serviceConfig)))
	allOpts = append(allOpts, opts...)

	if err := cli.dial(target, allOpts...); err != nil {
		return nil, err
	}
	return &cli, nil
}