Middleware Extensions
Introduction
Kitex is a lightweight RPC framework with strong extensibility. It provides two main extension mechanisms: a relatively low-level one, which adds middleware directly, and a higher-level one, which adds suites. This document focuses on middleware.
Middleware
Middleware is a relatively low-level extension mechanism. Most Kitex-based extensions and secondary development are implemented with middleware. Kitex's middleware is defined in pkg/endpoint/endpoint.go, and the two most important types are:
- Endpoint is a function that accepts ctx, req, and resp and returns err; see the “Example” code below.
type Endpoint func(ctx context.Context, req, resp interface{}) (err error)
- Middleware (hereinafter MW) is a function that receives an Endpoint and returns an Endpoint.
type Middleware func(Endpoint) Endpoint
In other words, a middleware takes an Endpoint as input and returns an Endpoint as output. This keeps it transparent to the application, which cannot tell whether it is being decorated by middleware. Because of this property, middlewares can be nested and combined.
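The nesting property can be seen in a small, self-contained sketch. So that it runs without Kitex, it re-declares Endpoint and Middleware locally with the same shapes as in pkg/endpoint; the chain and tracer helpers are hypothetical names used only for illustration. The first middleware passed to chain ends up outermost.

```go
package main

import (
	"context"
	"fmt"
)

// Local stand-ins with the same shapes as Kitex's pkg/endpoint types.
type Endpoint func(ctx context.Context, req, resp interface{}) error
type Middleware func(Endpoint) Endpoint

// chain composes middlewares so that mws[0] becomes the outermost wrapper.
func chain(mws ...Middleware) Middleware {
	return func(next Endpoint) Endpoint {
		for i := len(mws) - 1; i >= 0; i-- {
			next = mws[i](next)
		}
		return next
	}
}

// tracer returns a middleware that records when it is entered and left.
func tracer(name string, trace *[]string) Middleware {
	return func(next Endpoint) Endpoint {
		return func(ctx context.Context, req, resp interface{}) error {
			*trace = append(*trace, "enter "+name)
			err := next(ctx, req, resp)
			*trace = append(*trace, "leave "+name)
			return err
		}
	}
}

// runTrace wraps a trivial final endpoint with two middlewares and returns the trace.
func runTrace() []string {
	var trace []string
	final := func(ctx context.Context, req, resp interface{}) error {
		trace = append(trace, "handler")
		return nil
	}
	ep := chain(tracer("A", &trace), tracer("B", &trace))(final)
	_ = ep(context.Background(), nil, nil)
	return trace
}

func main() {
	fmt.Println(runTrace()) // [enter A enter B handler leave B leave A]
}
```

Code before next runs outside-in and code after next runs inside-out, which is why neither the handler nor any individual middleware needs to know how many layers surround it.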
Middlewares are called in a chain. By calling the provided next function, a middleware obtains the response (if any) and the error returned by the subsequent middleware. It can then do whatever processing it needs and either return an error to the previous middleware or set the response accordingly. Always check the error returned by next; do not swallow it.
Client Middleware
How To Use
- client.WithMiddleware adds a middleware to the current client; it is executed after the service-level circuit breaker and timeout middlewares.
- client.WithInstanceMW adds a middleware to the current client that is executed after service discovery and load balancing. If there is an instance-level circuit breaker, it is executed after that circuit breaker. If a proxy is used (for example, in Mesh mode), it is not called.
- client.WithContextMiddlewares sets context middlewares; see Context Middleware below for details.
- client.WithErrorHandler sets the error handler middleware.
Note: client.WithMiddleware, client.WithInstanceMW, and client.WithErrorHandler are passed as options when creating the client, while client.WithContextMiddlewares injects middlewares into ctx.
Call Sequence
- XDS routing, service-level circuit breaker, timeout
- Middlewares set by client.WithContextMiddlewares
- Middlewares set by client.WithMiddleware
- ACLMiddleware, see Customized Access Control
- Service discovery, instance circuit breaker, instance-level middleware (without a proxy) / service discovery, proxy middleware (with a proxy)
- Error handler set by client.WithErrorHandler
The above can be seen in client.go.
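The order above can be pictured as a chain of wrappers. The sketch below uses placeholder middlewares named after each layer (hypothetical stand-ins, not Kitex's real implementations) and records the order in which they are entered for one call.

```go
package main

import (
	"context"
	"fmt"
	"strings"
)

type Endpoint func(ctx context.Context, req, resp interface{}) error
type Middleware func(Endpoint) Endpoint

// named returns a placeholder middleware that records its name on entry.
func named(name string, order *[]string) Middleware {
	return func(next Endpoint) Endpoint {
		return func(ctx context.Context, req, resp interface{}) error {
			*order = append(*order, name)
			return next(ctx, req, resp)
		}
	}
}

// clientOrder wraps the layers from the list above, outermost first,
// and returns the order in which they run for one call.
func clientOrder() string {
	var order []string
	layers := []string{
		"xds/serviceCB/timeout",
		"contextMWs",
		"WithMiddleware",
		"ACL",
		"discovery/instanceCB/WithInstanceMW",
		"errorHandler",
	}
	var ep Endpoint = func(ctx context.Context, req, resp interface{}) error {
		order = append(order, "remote call")
		return nil
	}
	// Wrap from the innermost layer outwards so layers[0] ends up outermost.
	for i := len(layers) - 1; i >= 0; i-- {
		ep = named(layers[i], &order)(ep)
	}
	_ = ep(context.Background(), nil, nil)
	return strings.Join(order, " -> ")
}

func main() {
	fmt.Println(clientOrder())
}
```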
Context Middleware
Context Middleware is essentially client middleware; the difference is that ctx controls whether any middleware is injected and which one.
Context Middleware provides a way to inject client middleware globally or dynamically. A typical use case is collecting statistics on which downstream interfaces are called.
Middleware can be injected into ctx with ctx = client.WithContextMiddlewares(ctx, mw).
Note: Context Middleware executes before middleware set by client.WithMiddleware().
Server Middleware
How To Use
- server.WithMiddleware adds a middleware to the current server.
- server.WithErrorHandler sets the error handler middleware.
Note: both functions are passed as options when creating the server.
Call Sequence
- Middlewares set by server.WithMiddleware
- ACLMiddleware, see Customized Access Control
- Error handler set by server.WithErrorHandler
The above can be seen in server.go.
Example
We can use the following example to see how to use Middleware.
Request/Response
To print the request content before the call and the response content after it, we can write the following middleware (for services that include streaming calls, see gRPC Middleware below):
/*
type Request struct {
	Message string     `thrift:"Message,1,required" frugal:"1,required,string" json:"Message"`
	Base    *base.Base `thrift:"Base,255,optional" frugal:"255,optional,base.Base" json:"Base,omitempty"`
}
type Response struct {
	Message  string         `thrift:"Message,1,required" frugal:"1,required,string" json:"Message"`
	BaseResp *base.BaseResp `thrift:"BaseResp,255,optional" frugal:"255,optional,base.BaseResp" json:"BaseResp,omitempty"`
}
*/
import (
	"context"
	"github.com/cloudwego/kitex/pkg/endpoint"
	"github.com/cloudwego/kitex/pkg/klog"
	"github.com/cloudwego/kitex/pkg/utils"
	// plus the generated echo package from your kitex_gen directory
)
func ExampleMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
	return func(ctx context.Context, request, response interface{}) error {
		if arg, ok := request.(utils.KitexArgs); ok {
			// two-value assertion: an unexpected type must not panic; also skip nil pointers
			if req, ok := arg.GetFirstArgument().(*echo.Request); ok && req != nil {
				klog.Debugf("Request Message: %v", req.Message)
			}
		}
		err := next(ctx, request, response)
		if result, ok := response.(utils.KitexResult); ok {
			// resp may be a nil pointer (e.g. when the call failed), so check it before use
			if resp, ok := result.GetResult().(*echo.Response); ok && resp != nil {
				klog.Debugf("Response Message: %v", resp.Message)
				// result.SetSuccess(...) could be used to replace the response with a customized one,
				// but its type must be the same as this method's response type
			}
		}
		return err
	}
}
This example is for illustration only. Be careful with such logging in production: logging every request and response indiscriminately can hurt performance, especially with large response bodies.
Precautions
- If you use RPCInfo in a custom middleware, be aware that RPCInfo is recycled after the RPC ends. Accessing RPCInfo from a goroutine started in the middleware can therefore cause problems; avoid doing so.
- Middlewares are called in a chain: if any middleware modifies the response, whether with result.SetSuccess() or in some other way, the upstream middlewares receive the modified response.
gRPC Middleware
In addition to Thrift, Kitex supports the protobuf and gRPC encoding/decoding protocols. With protobuf, protobuf is used only to define the payload format, and the service definition contains only unary methods. Once streaming methods are introduced, Kitex uses the gRPC protocol for encoding/decoding and communication.
For services that use protobuf (unary methods only), middleware is developed exactly as described above, since the two designs are identical.
With streaming methods, however, middleware development is completely different, so gRPC streaming middleware is explained separately here.
For streaming methods (client streaming, server streaming, bidirectional streaming), sending and receiving messages (Send & Recv) is driven by business logic, so middleware cannot cover the messages themselves. To log requests/responses at the message level during Send/Recv, wrap Kitex's streaming.Stream as follows:
import (
	"log"
	"github.com/cloudwego/kitex/pkg/streaming"
)
type wrappedStream struct {
	streaming.Stream
}
func (w *wrappedStream) RecvMsg(m interface{}) error {
	err := w.Stream.RecvMsg(m)
	if err == nil { // only log messages that were actually received
		log.Printf("Receive a message: %T(%v)", m, m)
	}
	return err
}
func (w *wrappedStream) SendMsg(m interface{}) error {
	log.Printf("Send a message: %T(%v)", m, m)
	return w.Stream.SendMsg(m)
}
func newWrappedStream(s streaming.Stream) streaming.Stream {
	return &wrappedStream{s}
}
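The wrapper above relies on Go struct embedding: wrappedStream inherits every streaming.Stream method and overrides only RecvMsg and SendMsg. The self-contained sketch below shows the same pattern against a hypothetical two-method msgStream interface (a subset of what streaming.Stream offers) and an in-memory fakeStream, counting messages instead of logging them so the result can be checked.

```go
package main

import (
	"errors"
	"fmt"
)

// msgStream is a hypothetical two-method subset of streaming.Stream.
type msgStream interface {
	RecvMsg(m interface{}) error
	SendMsg(m interface{}) error
}

// fakeStream is an in-memory stream: SendMsg enqueues, RecvMsg dequeues.
type fakeStream struct{ queue []string }

func (f *fakeStream) SendMsg(m interface{}) error {
	f.queue = append(f.queue, *m.(*string))
	return nil
}

func (f *fakeStream) RecvMsg(m interface{}) error {
	if len(f.queue) == 0 {
		return errors.New("EOF")
	}
	p := m.(*string)
	*p = f.queue[0]
	f.queue = f.queue[1:]
	return nil
}

// countingStream embeds the inner stream and overrides only SendMsg/RecvMsg.
type countingStream struct {
	msgStream
	sent, received int
}

var _ msgStream = (*countingStream)(nil) // the wrapper still satisfies the interface

func (c *countingStream) SendMsg(m interface{}) error {
	c.sent++
	return c.msgStream.SendMsg(m)
}

func (c *countingStream) RecvMsg(m interface{}) error {
	err := c.msgStream.RecvMsg(m)
	if err == nil {
		c.received++ // count only messages that were actually received
	}
	return err
}

func main() {
	c := &countingStream{msgStream: &fakeStream{}}
	for _, text := range []string{"a", "b"} {
		_ = c.SendMsg(&text)
	}
	var got string
	for c.RecvMsg(&got) == nil {
		fmt.Println("received", got)
	}
	fmt.Println(c.sent, c.received) // 2 2
}
```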
Then, in the middleware, replace the stream with the wrapped streaming.Stream at the appropriate point:
import (
	"context"
	"github.com/cloudwego/kitex/pkg/endpoint"
	"github.com/cloudwego/kitex/pkg/streaming"
)
// A middleware that can be used on both the client side and the server side in Kitex with gRPC/Thrift/TTHeader-protobuf
func DemoGRPCMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
	return func(ctx context.Context, req, res interface{}) error {
		var Nil interface{} // a bare `switch nil` does not compile in Go
		switch Nil {
		case req: // client side, streaming method
			err := next(ctx, req, res)
			// the stream object can only be obtained after the final endpoint returns
			if tmp, ok := res.(*streaming.Result); err == nil && ok {
				tmp.Stream = newWrappedStream(tmp.Stream) // wrap the stream object
			}
			return err
		case res: // server side (gRPC unary or streaming method)
			if tmp, ok := req.(*streaming.Args); ok {
				tmp.Stream = newWrappedStream(tmp.Stream) // wrap the stream object
			}
		default: // pure unary method, or Thrift method
			// do something else
		}
		return next(ctx, req, res)
	}
}
The request/response parameter types that a Kitex middleware receives in different gRPC scenarios are:
Scenario | Request Type | Response Type |
---|---|---|
Kitex-gRPC Server Unary/Streaming | *streaming.Args | nil |
Kitex-gRPC Client Unary | *xxxservice.XXXMethodArgs | *xxxservice.XXXMethodResult |
Kitex-gRPC Client Streaming | nil | *streaming.Result |
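The table can be turned into a small classifier. This sketch uses a tagless switch, an alternative to the var Nil trick in DemoGRPCMiddleware; streamArgs and streamResult are hypothetical stand-ins for *streaming.Args and *streaming.Result. Note that req == nil is true only when the interface itself is nil, which is the case the table describes.

```go
package main

import "fmt"

// Hypothetical stand-ins for *streaming.Args and *streaming.Result.
type streamArgs struct{}
type streamResult struct{}

// classify reports the scenario from the table, based only on
// which of req/resp is nil and on their dynamic types.
func classify(req, resp interface{}) string {
	switch {
	case req == nil:
		if _, ok := resp.(*streamResult); ok {
			return "client streaming"
		}
	case resp == nil:
		if _, ok := req.(*streamArgs); ok {
			return "server unary/streaming"
		}
	}
	return "client unary (or non-gRPC)"
}

func main() {
	fmt.Println(classify(nil, &streamResult{}))       // client streaming
	fmt.Println(classify(&streamArgs{}, nil))         // server unary/streaming
	fmt.Println(classify(&struct{}{}, &struct{}{})) // client unary (or non-gRPC)
}
```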
Summary
Middleware is a lower-level extension mechanism, typically used to inject simple code with a specific function. In complex scenarios, a single middleware may not meet business requirements, and multiple middlewares or options need to be assembled into a complete middleware layer. Users can build this on suites; see Suite Extend.
FAQ
How to recover a handler panic in middleware
Question: A user wanted to recover, in their own middleware, a panic thrown by the business handler, but found that the panic had already been recovered by the framework.
Description:
The framework recovers and reports panics in the handler. To capture a panic in custom middleware, check whether the error returned by next is kerrors.ErrPanic.
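import (
	"context"
	"errors"
	"fmt"
	"github.com/cloudwego/kitex/pkg/endpoint"
	"github.com/cloudwego/kitex/pkg/kerrors"
)
func TestServerMiddleware(next endpoint.Endpoint) endpoint.Endpoint {
	return func(ctx context.Context, req, resp interface{}) (err error) {
		err = next(ctx, req, resp)
		// the framework has already recovered the panic; it reaches this middleware as an error
		if errors.Is(err, kerrors.ErrPanic) {
			fmt.Println("capture panic")
		}
		return err
	}
}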
How to get the real Request/Response in Middleware?
For implementation reasons, the req and resp passed to middlewares are not the req and resp passed by the user, but objects wrapped by Kitex, with a structure similar to the following.
Thrift
// req
type ${XService}${XMethod}Args struct {
	Req *${XRequest} `thrift:"req,1" json:"req"`
}
func (p *${XService}${XMethod}Args) GetFirstArgument() interface{} {
	return p.Req
}
// resp
type ${XService}${XMethod}Result struct {
	Success *${XResponse} `thrift:"success,0" json:"success,omitempty"`
}
func (p *${XService}${XMethod}Result) GetResult() interface{} {
	return p.Success
}
Protobuf
// req
type ${XMethod}Args struct {
	Req *${XRequest}
}
func (p *${XMethod}Args) GetReq() *${XRequest} {
	if !p.IsSetReq() {
		return ${XMethod}Args_Req_DEFAULT
	}
	return p.Req
}
// resp
type ${XMethod}Result struct {
	Success *${XResponse}
}
func (p *${XMethod}Result) GetSuccess() *${XResponse} {
	if !p.IsSetSuccess() {
		return ${XMethod}Result_Success_DEFAULT
	}
	return p.Success
}
The generated code above can be found in the kitex_gen directory. The business side therefore has three ways to obtain the real req and resp:
- If you know which method is being called and which req type it uses, assert the request directly to the specific Args type and obtain the real req through its GetReq method.
- For Thrift-generated code, assert to an interface with GetFirstArgument or GetResult to obtain an interface{}, then type-assert that value to the real req or resp. Note: the returned interface{} carries a type, so checking the interface{} against nil does not catch the case where req/resp itself is a nil pointer; check the asserted req/resp for nil again.
- Obtain the real request/response body through reflection, as in the following code:
import (
	"context"
	"reflect"
	"github.com/cloudwego/kitex/pkg/endpoint"
	"github.com/cloudwego/kitex/pkg/klog"
)
var ExampleMW endpoint.Middleware = func(next endpoint.Endpoint) endpoint.Endpoint {
	return func(ctx context.Context, request, response interface{}) error {
		// assumes request/response are generated Args/Result types with GetReq/GetSuccess methods
		reqV := reflect.ValueOf(request).MethodByName("GetReq").Call(nil)[0]
		klog.CtxInfof(ctx, "request: %T", reqV.Interface())
		err := next(ctx, request, response)
		respV := reflect.ValueOf(response).MethodByName("GetSuccess").Call(nil)[0]
		klog.CtxInfof(ctx, "response: %T", respV.Interface())
		return err
	}
}
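The reflection approach above panics if request or response is nil (as in the gRPC client streaming case) or lacks the expected method. A more defensive sketch follows, together with a check for the typed-nil pointer issue mentioned in the second solution. EchoArgs/EchoResult are hypothetical stand-ins for generated types; callGetter and isNilPointer are illustrative helpers, not Kitex APIs.

```go
package main

import (
	"fmt"
	"reflect"
)

// Hypothetical stand-ins for generated protobuf Args/Result types.
type Request struct{ Message string }
type Response struct{ Message string }
type EchoArgs struct{ Req *Request }
type EchoResult struct{ Success *Response }

func (p *EchoArgs) GetReq() *Request         { return p.Req }
func (p *EchoResult) GetSuccess() *Response { return p.Success }

// callGetter invokes a zero-argument, single-result method by name via
// reflection and reports false instead of panicking when v is nil or lacks it.
func callGetter(v interface{}, name string) (interface{}, bool) {
	rv := reflect.ValueOf(v)
	if !rv.IsValid() {
		return nil, false
	}
	m := rv.MethodByName(name)
	if !m.IsValid() || m.Type().NumIn() != 0 || m.Type().NumOut() != 1 {
		return nil, false
	}
	return m.Call(nil)[0].Interface(), true
}

// isNilPointer catches a typed nil pointer hidden inside a non-nil interface{}.
func isNilPointer(x interface{}) bool {
	if x == nil {
		return true
	}
	rv := reflect.ValueOf(x)
	return rv.Kind() == reflect.Ptr && rv.IsNil()
}

func main() {
	req, _ := callGetter(&EchoArgs{Req: &Request{Message: "hi"}}, "GetReq")
	fmt.Printf("%T\n", req) // *main.Request
	empty, ok := callGetter(&EchoResult{}, "GetSuccess")
	fmt.Println(ok, empty == nil, isNilPointer(empty)) // true false true
	_, ok = callGetter(nil, "GetReq")
	fmt.Println(ok) // false
}
```

The middle line of output is the pitfall in action: the interface{} returned for an unset Success is not == nil, yet the pointer inside it is nil.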