diff --git a/api/adapter.go b/api/adapter.go deleted file mode 100644 index 646f881..0000000 --- a/api/adapter.go +++ /dev/null @@ -1,13 +0,0 @@ -package api - -import "net/http" - -// Adapter to encapsulate incoming requests -type Adapter func(http.HandlerFunc) http.HandlerFunc - -// AuthHandlerFunc is http.HandlerFunc with additional Authorization information -type AuthHandlerFunc func(Auth, http.ResponseWriter, *http.Request) - -// AuthAdapter to encapsulate incoming request with access to api.Auth -// to manage permissions -type AuthAdapter func(AuthHandlerFunc) AuthHandlerFunc diff --git a/api/auth.go b/api/auth.go index ff2699a..cc162b7 100644 --- a/api/auth.go +++ b/api/auth.go @@ -21,7 +21,7 @@ type Auth struct { // Granted returns whether the authorization is granted // i.e. Auth.Active fulfills Auth.Required -func (a Auth) Granted() bool { +func (a *Auth) Granted() bool { var nothingRequired = true // first dimension: OR ; at least one is valid @@ -43,7 +43,7 @@ func (a Auth) Granted() bool { } // returns whether Auth.Active fulfills (contains) all @required roles -func (a Auth) fulfills(required []string) bool { +func (a *Auth) fulfills(required []string) bool { for _, requiredRole := range required { var found = false for _, activeRole := range a.Active { diff --git a/api/context.go b/api/context.go index c5f9a9f..a384282 100644 --- a/api/context.go +++ b/api/context.go @@ -7,12 +7,8 @@ import ( "git.xdrm.io/go/aicra/internal/ctx" ) -// Context is a simple wrapper around context.Context that adds helper methods -// to access additional information -type Context struct{ context.Context } - -// Request current request -func (c Context) Request() *http.Request { +// GetRequest extracts the current request from a context.Context +func GetRequest(c context.Context) *http.Request { var ( raw = c.Value(ctx.Request) cast, ok = raw.(*http.Request) @@ -23,8 +19,8 @@ func (c Context) Request() *http.Request { return cast } -// ResponseWriter for this request -func (c Context) ResponseWriter() http.ResponseWriter { +// GetResponseWriter extracts the response writer from a context.Context +func GetResponseWriter(c context.Context) http.ResponseWriter { var ( raw = c.Value(ctx.Response) cast, ok = raw.(http.ResponseWriter) @@ -35,8 +31,8 @@ func (c Context) ResponseWriter() http.ResponseWriter { return cast } -// Auth associated with this request -func (c Context) Auth() *Auth { +// GetAuth returns the api.Auth associated with this request from a context.Context +func GetAuth(c context.Context) *Auth { var ( raw = c.Value(ctx.Auth) cast, ok = raw.(*Auth) diff --git a/builder.go b/builder.go index 15d94e4..d3c9eb9 100644 --- a/builder.go +++ b/builder.go @@ -5,7 +5,6 @@ import ( "io" "net/http" - "git.xdrm.io/go/aicra/api" "git.xdrm.io/go/aicra/datatype" "git.xdrm.io/go/aicra/internal/config" "git.xdrm.io/go/aicra/internal/dynfunc" @@ -13,10 +12,16 @@ import ( // Builder for an aicra server type Builder struct { - conf *config.Server - handlers []*apiHandler - adapters []api.Adapter - authAdapters []api.AuthAdapter + // the server configuration defining available services + conf *config.Server + // user-defined handlers bound to services from the configuration + handlers []*apiHandler + // http middlewares wrapping the entire http connection (e.g. logger) + middlewares []func(http.Handler) http.Handler + // custom middlewares only wrapping the service handler of a request + // they will benefit from the request's context that contains service-specific + // information (e.g. required permisisons from the configuration) + ctxMiddlewares []func(http.Handler) http.Handler } // represents an api handler (method-pattern combination) @@ -41,26 +46,36 @@ func (b *Builder) AddType(t datatype.T) error { return nil } -// With adds an http adapter (middleware) -func (b *Builder) With(adapter api.Adapter) { +// With adds an http middleware on top of the http connection +// +// Authentication management can only be done with the WithContext() methods as +// the service associated with the request has not been found at this stage. +// This stage is perfect for logging or generic request management. +func (b *Builder) With(mw func(http.Handler) http.Handler) { if b.conf == nil { b.conf = &config.Server{} } - if b.adapters == nil { - b.adapters = make([]api.Adapter, 0) + if b.middlewares == nil { + b.middlewares = make([]func(http.Handler) http.Handler, 0) } - b.adapters = append(b.adapters, adapter) + b.middlewares = append(b.middlewares, mw) } -// WithAuth adds an http adapter with auth capabilities (middleware) -func (b *Builder) WithAuth(adapter api.AuthAdapter) { +// WithContext adds an http middleware with the fully loaded context +// +// Logging or generic request management should be done with the With() method as +// it wraps the full http connection. Middlewares added through this method only +// wrap the user-defined service handler. The context.Context is filled with useful +// data that can be access with api.GetRequest(), api.GetResponseWriter(), +// api.GetAuth(), etc methods. +func (b *Builder) WithContext(mw func(http.Handler) http.Handler) { if b.conf == nil { b.conf = &config.Server{} } - if b.authAdapters == nil { - b.authAdapters = make([]api.AuthAdapter, 0) + if b.ctxMiddlewares == nil { + b.ctxMiddlewares = make([]func(http.Handler) http.Handler, 0) } - b.authAdapters = append(b.authAdapters, adapter) + b.ctxMiddlewares = append(b.ctxMiddlewares, mw) } // Setup the builder with its api definition file diff --git a/builder_test.go b/builder_test.go index c123f3a..0207c1e 100644 --- a/builder_test.go +++ b/builder_test.go @@ -1,6 +1,7 @@ package aicra import ( + "context" "errors" "net/http" "strings" @@ -72,7 +73,7 @@ func TestBind(t *testing.T) { Config: "[]", HandlerMethod: "", HandlerPath: "", - HandlerFn: func(*api.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + HandlerFn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, BindErr: errUnknownService, BuildErr: nil, }, @@ -108,7 +109,7 @@ func TestBind(t *testing.T) { ]`, HandlerMethod: http.MethodPost, HandlerPath: "/path", - HandlerFn: func(*api.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + HandlerFn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, BindErr: errUnknownService, BuildErr: errMissingHandler, }, @@ -126,7 +127,7 @@ func TestBind(t *testing.T) { ]`, HandlerMethod: http.MethodGet, HandlerPath: "/paths", - HandlerFn: func(*api.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + HandlerFn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, BindErr: errUnknownService, BuildErr: errMissingHandler, }, @@ -144,7 +145,7 @@ func TestBind(t *testing.T) { ]`, HandlerMethod: http.MethodGet, HandlerPath: "/path", - HandlerFn: func(*api.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + HandlerFn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, BindErr: nil, BuildErr: nil, }, @@ -164,7 +165,7 @@ func TestBind(t *testing.T) { ]`, HandlerMethod: http.MethodGet, HandlerPath: "/path", - HandlerFn: func(*api.Context, struct{ Name int }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + HandlerFn: func(context.Context, struct{ Name int }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, BindErr: nil, BuildErr: nil, }, @@ -184,7 +185,7 @@ func TestBind(t *testing.T) { ]`, HandlerMethod: http.MethodGet, HandlerPath: "/path", - HandlerFn: func(*api.Context, struct{ Name uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + HandlerFn: func(context.Context, struct{ Name uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, BindErr: nil, BuildErr: nil, }, @@ -204,7 +205,7 @@ func TestBind(t *testing.T) { ]`, HandlerMethod: http.MethodGet, HandlerPath: "/path", - HandlerFn: func(*api.Context, struct{ Name string }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + HandlerFn: func(context.Context, struct{ Name string }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, BindErr: nil, BuildErr: nil, }, @@ -224,7 +225,7 @@ func TestBind(t *testing.T) { ]`, HandlerMethod: http.MethodGet, HandlerPath: "/path", - HandlerFn: func(*api.Context, struct{ Name bool }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + HandlerFn: func(context.Context, struct{ Name bool }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, BindErr: nil, BuildErr: nil, }, diff --git a/handler.go b/handler.go index d91f5e7..6a269bd 100644 --- a/handler.go +++ b/handler.go @@ -17,30 +17,31 @@ type Handler Builder // ServeHTTP implements http.Handler and wraps it in middlewares (adapters) func (s Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - var h = http.HandlerFunc(s.resolve) + var h http.Handler = http.HandlerFunc(s.resolve) - for _, adapter := range s.adapters { - h = adapter(h) + for _, mw := range s.middlewares { + h = mw(h) } - h(w, r) + h.ServeHTTP(w, r) } +// ServeHTTP implements http.Handler and wraps it in middlewares (adapters) func (s Handler) resolve(w http.ResponseWriter, r *http.Request) { - // 1. find a matching service from config + // 1ind a matching service from config var service = s.conf.Find(r) if service == nil { handleError(api.ErrUnknownService, w, r) return } - // 2. extract request data + // extract request data var input, err = extractInput(service, *r) if err != nil { handleError(api.ErrMissingParam, w, r) return } - // 3. find a matching handler + // find a matching handler var handler *apiHandler for _, h := range s.handlers { if h.Method == service.Method && h.Path == service.Pattern { @@ -48,63 +49,36 @@ func (s Handler) resolve(w http.ResponseWriter, r *http.Request) { } } - // 4. fail on no matching handler + // fail on no matching handler if handler == nil { handleError(api.ErrUncallableService, w, r) return } - // replace format '[a]' in scope where 'a' is an existing input's name - scope := make([][]string, len(service.Scope)) - for a, list := range service.Scope { - scope[a] = make([]string, len(list)) - for b, perm := range list { - scope[a][b] = perm - for name, value := range input.Data { - var ( - token = fmt.Sprintf("[%s]", name) - replacement = "" - ) - if value != nil { - replacement = fmt.Sprintf("[%v]", value) - } - scope[a][b] = strings.ReplaceAll(scope[a][b], token, replacement) - } - } - } - - var auth = api.Auth{ - Required: scope, - Active: []string{}, - } - - // 5. run auth-aware middlewares - var h = api.AuthHandlerFunc(func(a api.Auth, w http.ResponseWriter, r *http.Request) { - if !a.Granted() { - handleError(api.ErrPermission, w, r) - return - } - - s.handle(input, handler, service, w, r) - }) - - for _, adapter := range s.authAdapters { - h = adapter(h) - } - h(auth, w, r) - -} - -func (s *Handler) handle(input *reqdata.T, handler *apiHandler, service *config.Service, w http.ResponseWriter, r *http.Request) { // build context with builtin data c := r.Context() c = context.WithValue(c, ctx.Request, r) c = context.WithValue(c, ctx.Response, w) - c = context.WithValue(c, ctx.Auth, w) - apictx := &api.Context{Context: c} + c = context.WithValue(c, ctx.Auth, buildAuth(service.Scope, input.Data)) + // create http handler + var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // use context defined in the request + s.handle(r.Context(), input, handler, service, w, r) + }) + + // run middlewares the handler + for _, mw := range s.ctxMiddlewares { + h = mw(h) + } + + // serve using the context with values + h.ServeHTTP(w, r.WithContext(c)) +} + +func (s *Handler) handle(c context.Context, input *reqdata.T, handler *apiHandler, service *config.Service, w http.ResponseWriter, r *http.Request) { // pass execution to the handler - var outData, outErr = handler.dyn.Handle(apictx, input.Data) + var outData, outErr = handler.dyn.Handle(c, input.Data) // build response from returned arguments var res = api.EmptyResponse().WithError(outErr) @@ -157,3 +131,35 @@ func extractInput(service *config.Service, req http.Request) (*reqdata.T, error) return dataset, nil } + +// buildAuth builds the api.Auth struct from the service scope configuration +// +// it replaces format '[a]' in scope where 'a' is an existing input argument's +// name with its value +func buildAuth(scope [][]string, in map[string]interface{}) *api.Auth { + updated := make([][]string, len(scope)) + + // replace '[arg_name]' with the 'arg_name' value if it is a known variable + // name + for a, list := range scope { + updated[a] = make([]string, len(list)) + for b, perm := range list { + updated[a][b] = perm + for name, value := range in { + var ( + token = fmt.Sprintf("[%s]", name) + replacement = "" + ) + if value != nil { + replacement = fmt.Sprintf("[%v]", value) + } + updated[a][b] = strings.ReplaceAll(updated[a][b], token, replacement) + } + } + } + + return &api.Auth{ + Required: updated, + Active: []string{}, + } +} diff --git a/handler_test.go b/handler_test.go index be2fdb7..eba1f0d 100644 --- a/handler_test.go +++ b/handler_test.go @@ -48,15 +48,15 @@ func TestWith(t *testing.T) { type ckey int const key ckey = 0 - middleware := func(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { + middleware := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { newr := r // first time -> store 1 value := r.Context().Value(key) if value == nil { newr = r.WithContext(context.WithValue(r.Context(), key, int(1))) - next(w, newr) + next.ServeHTTP(w, newr) return } @@ -67,8 +67,8 @@ func TestWith(t *testing.T) { } cast++ newr = r.WithContext(context.WithValue(r.Context(), key, cast)) - next(w, newr) - } + next.ServeHTTP(w, newr) + }) } // add middleware @n times @@ -82,7 +82,7 @@ func TestWith(t *testing.T) { t.Fatalf("setup: unexpected error <%v>", err) } - pathHandler := func(ctx *api.Context) (*struct{}, api.Err) { + pathHandler := func(ctx context.Context) (*struct{}, api.Err) { // write value from middlewares into response value := ctx.Value(key) if value == nil { @@ -93,7 +93,7 @@ func TestWith(t *testing.T) { t.Fatalf("cannot cast context data to int") } // write to response - ctx.ResponseWriter().Write([]byte(fmt.Sprintf("#%d#", cast))) + api.GetResponseWriter(ctx).Write([]byte(fmt.Sprintf("#%d#", cast))) return nil, api.ErrSuccess } @@ -212,8 +212,13 @@ func TestWithAuth(t *testing.T) { } // tester middleware (last executed) - builder.WithAuth(func(next api.AuthHandlerFunc) api.AuthHandlerFunc { - return func(a api.Auth, w http.ResponseWriter, r *http.Request) { + builder.WithContext(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + a := api.GetAuth(r.Context()) + if a == nil { + t.Fatalf("cannot access api.Auth form request context") + } + if a.Granted() == tc.granted { return } @@ -222,14 +227,20 @@ func TestWithAuth(t *testing.T) { } else { t.Fatalf("expected granted auth") } - } + next.ServeHTTP(w, r) + }) }) - builder.WithAuth(func(next api.AuthHandlerFunc) api.AuthHandlerFunc { - return func(a api.Auth, w http.ResponseWriter, r *http.Request) { + builder.WithContext(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + a := api.GetAuth(r.Context()) + if a == nil { + t.Fatalf("cannot access api.Auth form request context") + } + a.Active = tc.permissions - next(a, w, r) - } + next.ServeHTTP(w, r) + }) }) err := builder.Setup(strings.NewReader(tc.manifest)) @@ -237,7 +248,7 @@ func TestWithAuth(t *testing.T) { t.Fatalf("setup: unexpected error <%v>", err) } - pathHandler := func(ctx *api.Context) (*struct{}, api.Err) { + pathHandler := func(ctx context.Context) (*struct{}, api.Err) { return nil, api.ErrNotImplemented } @@ -290,7 +301,7 @@ func TestDynamicScope(t *testing.T) { } ]`, path: "/path/{id}", - handler: func(*api.Context, struct{ Input1 uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + handler: func(context.Context, struct{ Input1 uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, url: "/path/123", body: ``, permissions: []string{"user[123]"}, @@ -311,7 +322,7 @@ func TestDynamicScope(t *testing.T) { } ]`, path: "/path/{id}", - handler: func(*api.Context, struct{ Input1 uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + handler: func(context.Context, struct{ Input1 uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, url: "/path/666", body: ``, permissions: []string{"user[123]"}, @@ -332,7 +343,7 @@ func TestDynamicScope(t *testing.T) { } ]`, path: "/path/{id}", - handler: func(*api.Context, struct{ User uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + handler: func(context.Context, struct{ User uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, url: "/path/123", body: ``, permissions: []string{"prefix.user[123].suffix"}, @@ -354,7 +365,7 @@ func TestDynamicScope(t *testing.T) { } ]`, path: "/prefix/{pid}/user/{uid}", - handler: func(*api.Context, struct { + handler: func(context.Context, struct { Prefix uint User uint }) (*struct{}, api.Err) { @@ -381,7 +392,7 @@ func TestDynamicScope(t *testing.T) { } ]`, path: "/prefix/{pid}/user/{uid}", - handler: func(*api.Context, struct { + handler: func(context.Context, struct { Prefix uint User uint }) (*struct{}, api.Err) { @@ -409,7 +420,7 @@ func TestDynamicScope(t *testing.T) { } ]`, path: "/prefix/{pid}/user/{uid}/suffix/{sid}", - handler: func(*api.Context, struct { + handler: func(context.Context, struct { Prefix uint User uint Suffix uint @@ -438,7 +449,7 @@ func TestDynamicScope(t *testing.T) { } ]`, path: "/prefix/{pid}/user/{uid}/suffix/{sid}", - handler: func(*api.Context, struct { + handler: func(context.Context, struct { Prefix uint User uint Suffix uint @@ -460,8 +471,12 @@ func TestDynamicScope(t *testing.T) { } // tester middleware (last executed) - builder.WithAuth(func(next api.AuthHandlerFunc) api.AuthHandlerFunc { - return func(a api.Auth, w http.ResponseWriter, r *http.Request) { + builder.WithContext(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + a := api.GetAuth(r.Context()) + if a == nil { + t.Fatalf("cannot access api.Auth form request context") + } if a.Granted() == tc.granted { return } @@ -470,15 +485,20 @@ func TestDynamicScope(t *testing.T) { } else { t.Fatalf("expected granted auth") } - } + next.ServeHTTP(w, r) + }) }) // update permissions - builder.WithAuth(func(next api.AuthHandlerFunc) api.AuthHandlerFunc { - return func(a api.Auth, w http.ResponseWriter, r *http.Request) { + builder.WithContext(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + a := api.GetAuth(r.Context()) + if a == nil { + t.Fatalf("cannot access api.Auth form request context") + } a.Active = tc.permissions - next(a, w, r) - } + next.ServeHTTP(w, r) + }) }) err := builder.Setup(strings.NewReader(tc.manifest)) diff --git a/internal/dynfunc/errors.go b/internal/dynfunc/errors.go index 7d45b64..903e5f0 100644 --- a/internal/dynfunc/errors.go +++ b/internal/dynfunc/errors.go @@ -14,7 +14,7 @@ const errHandlerNotFunc = cerr("handler must be a func") const errNoServiceForHandler = cerr("no service found for this handler") // errMissingHandlerArgumentParam - missing params arguments for handler -const errMissingHandlerContextArgument = cerr("missing handler first argument of type *api.Context") +const errMissingHandlerContextArgument = cerr("missing handler first argument of type context.Context") // errMissingHandlerInputArgument - missing params arguments for handler const errMissingHandlerInputArgument = cerr("missing handler argument: input struct") diff --git a/internal/dynfunc/handler.go b/internal/dynfunc/handler.go index dd3bcdc..783612b 100644 --- a/internal/dynfunc/handler.go +++ b/internal/dynfunc/handler.go @@ -1,6 +1,7 @@ package dynfunc import ( + "context" "fmt" "log" "reflect" @@ -50,7 +51,7 @@ func Build(fn interface{}, service config.Service) (*Handler, error) { } // Handle binds input `data` into the dynamic function and returns an output map -func (h *Handler) Handle(ctx *api.Context, data map[string]interface{}) (map[string]interface{}, api.Err) { +func (h *Handler) Handle(ctx context.Context, data map[string]interface{}) (map[string]interface{}, api.Err) { var ( ert = reflect.TypeOf(api.Err{}) fnv = reflect.ValueOf(h.fn) diff --git a/internal/dynfunc/handler_test.go b/internal/dynfunc/handler_test.go index 851e2f2..053cc25 100644 --- a/internal/dynfunc/handler_test.go +++ b/internal/dynfunc/handler_test.go @@ -1,6 +1,7 @@ package dynfunc import ( + "context" "fmt" "reflect" "testing" @@ -52,7 +53,7 @@ func TestInput(t *testing.T) { { Name: "none required none provided", Spec: (&testsignature{}).withArgs(), - Fn: func(*api.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + Fn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, HasContext: false, Input: []interface{}{}, ExpectedOutput: []interface{}{}, @@ -61,7 +62,7 @@ func TestInput(t *testing.T) { { Name: "int proxy (0)", Spec: (&testsignature{}).withArgs(reflect.TypeOf(int(0))), - Fn: func(ctx *api.Context, in intstruct) (*intstruct, api.Err) { + Fn: func(ctx context.Context, in intstruct) (*intstruct, api.Err) { return &intstruct{P1: in.P1}, api.ErrSuccess }, HasContext: false, @@ -72,7 +73,7 @@ func TestInput(t *testing.T) { { Name: "int proxy (11)", Spec: (&testsignature{}).withArgs(reflect.TypeOf(int(0))), - Fn: func(ctx *api.Context, in intstruct) (*intstruct, api.Err) { + Fn: func(ctx context.Context, in intstruct) (*intstruct, api.Err) { return &intstruct{P1: in.P1}, api.ErrSuccess }, HasContext: false, @@ -83,7 +84,7 @@ func TestInput(t *testing.T) { { Name: "*int proxy (nil)", Spec: (&testsignature{}).withArgs(reflect.TypeOf(new(int))), - Fn: func(ctx *api.Context, in intptrstruct) (*intptrstruct, api.Err) { + Fn: func(ctx context.Context, in intptrstruct) (*intptrstruct, api.Err) { return &intptrstruct{P1: in.P1}, api.ErrSuccess }, HasContext: false, @@ -94,7 +95,7 @@ func TestInput(t *testing.T) { { Name: "*int proxy (28)", Spec: (&testsignature{}).withArgs(reflect.TypeOf(new(int))), - Fn: func(ctx *api.Context, in intptrstruct) (*intstruct, api.Err) { + Fn: func(ctx context.Context, in intptrstruct) (*intstruct, api.Err) { return &intstruct{P1: *in.P1}, api.ErrSuccess }, HasContext: false, @@ -105,7 +106,7 @@ func TestInput(t *testing.T) { { Name: "*int proxy (13)", Spec: (&testsignature{}).withArgs(reflect.TypeOf(new(int))), - Fn: func(ctx *api.Context, in intptrstruct) (*intstruct, api.Err) { + Fn: func(ctx context.Context, in intptrstruct) (*intstruct, api.Err) { return &intstruct{P1: *in.P1}, api.ErrSuccess }, HasContext: false, @@ -131,7 +132,7 @@ func TestInput(t *testing.T) { input[key] = val } - var output, err = handler.Handle(&api.Context{}, input) + var output, err = handler.Handle(context.Background(), input) if err != tcase.ExpectedErr { t.Fatalf("expected api error <%v> got <%v>", tcase.ExpectedErr, err) } diff --git a/internal/dynfunc/signature.go b/internal/dynfunc/signature.go index 6d33a33..e0f9a19 100644 --- a/internal/dynfunc/signature.go +++ b/internal/dynfunc/signature.go @@ -1,6 +1,7 @@ package dynfunc import ( + "context" "fmt" "reflect" "strings" @@ -48,12 +49,17 @@ func BuildSignature(service config.Service) *Signature { // ValidateInput validates a handler's input arguments against the service signature func (s *Signature) ValidateInput(handlerType reflect.Type) error { - ctxType := reflect.TypeOf(api.Context{}) + ctxType := reflect.TypeOf((*context.Context)(nil)).Elem() - // missing or invalid first arg: api.Context - if handlerType.NumIn() < 1 || ctxType.AssignableTo(handlerType.In(0)) { + // missing or invalid first arg: context.Context + if handlerType.NumIn() < 1 { return errMissingHandlerContextArgument } + firstArgType := handlerType.In(0) + + if !firstArgType.Implements(ctxType) { + return fmt.Errorf("fock") + } // no input required if len(s.Input) == 0 { diff --git a/internal/dynfunc/signature_test.go b/internal/dynfunc/signature_test.go index e9e97b2..874834c 100644 --- a/internal/dynfunc/signature_test.go +++ b/internal/dynfunc/signature_test.go @@ -1,6 +1,7 @@ package dynfunc import ( + "context" "errors" "fmt" "reflect" @@ -20,22 +21,22 @@ func TestInputCheck(t *testing.T) { { Name: "no input 0 given", Input: map[string]reflect.Type{}, - Fn: func(*api.Context) {}, - FnCtx: func(*api.Context) {}, + Fn: func(context.Context) {}, + FnCtx: func(context.Context) {}, Err: nil, }, { Name: "no input 1 given", Input: map[string]reflect.Type{}, - Fn: func(*api.Context, int) {}, - FnCtx: func(*api.Context, int) {}, + Fn: func(context.Context, int) {}, + FnCtx: func(context.Context, int) {}, Err: errUnexpectedInput, }, { Name: "no input 2 given", Input: map[string]reflect.Type{}, - Fn: func(*api.Context, int, string) {}, - FnCtx: func(*api.Context, int, string) {}, + Fn: func(context.Context, int, string) {}, + FnCtx: func(context.Context, int, string) {}, Err: errUnexpectedInput, }, { @@ -43,8 +44,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func(*api.Context) {}, - FnCtx: func(*api.Context) {}, + Fn: func(context.Context) {}, + FnCtx: func(context.Context) {}, Err: errMissingHandlerInputArgument, }, { @@ -52,8 +53,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func(*api.Context, int) {}, - FnCtx: func(*api.Context, int) {}, + Fn: func(context.Context, int) {}, + FnCtx: func(context.Context, int) {}, Err: errMissingParamArgument, }, { @@ -61,8 +62,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "test1": reflect.TypeOf(int(0)), }, - Fn: func(*api.Context, struct{}) {}, - FnCtx: func(*api.Context, struct{}) {}, + Fn: func(context.Context, struct{}) {}, + FnCtx: func(context.Context, struct{}) {}, Err: errUnexportedName, }, { @@ -70,8 +71,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func(*api.Context, struct{}) {}, - FnCtx: func(*api.Context, struct{}) {}, + Fn: func(context.Context, struct{}) {}, + FnCtx: func(context.Context, struct{}) {}, Err: errMissingConfigArgument, }, { @@ -79,8 +80,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func(*api.Context, struct{ Test1 string }) {}, - FnCtx: func(*api.Context, struct{ Test1 string }) {}, + Fn: func(context.Context, struct{ Test1 string }) {}, + FnCtx: func(context.Context, struct{ Test1 string }) {}, Err: errWrongParamTypeFromConfig, }, { @@ -88,8 +89,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func(*api.Context, struct{ Test1 int }) {}, - FnCtx: func(*api.Context, struct{ Test1 int }) {}, + Fn: func(context.Context, struct{ Test1 int }) {}, + FnCtx: func(context.Context, struct{ Test1 int }) {}, Err: nil, }, { @@ -97,8 +98,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(int)), }, - Fn: func(*api.Context, struct{}) {}, - FnCtx: func(*api.Context, struct{}) {}, + Fn: func(context.Context, struct{}) {}, + FnCtx: func(context.Context, struct{}) {}, Err: errMissingConfigArgument, }, { @@ -106,8 +107,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(int)), }, - Fn: func(*api.Context, struct{ Test1 string }) {}, - FnCtx: func(*api.Context, struct{ Test1 string }) {}, + Fn: func(context.Context, struct{ Test1 string }) {}, + FnCtx: func(context.Context, struct{ Test1 string }) {}, Err: errWrongParamTypeFromConfig, }, { @@ -115,8 +116,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(int)), }, - Fn: func(*api.Context, struct{ Test1 *string }) {}, - FnCtx: func(*api.Context, struct{ Test1 *string }) {}, + Fn: func(context.Context, struct{ Test1 *string }) {}, + FnCtx: func(context.Context, struct{ Test1 *string }) {}, Err: errWrongParamTypeFromConfig, }, { @@ -124,8 +125,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(int)), }, - Fn: func(*api.Context, struct{ Test1 *int }) {}, - FnCtx: func(*api.Context, struct{ Test1 *int }) {}, + Fn: func(context.Context, struct{ Test1 *int }) {}, + FnCtx: func(context.Context, struct{ Test1 *int }) {}, Err: nil, }, { @@ -133,8 +134,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(string("")), }, - Fn: func(*api.Context, struct{ Test1 string }) {}, - FnCtx: func(*api.Context, struct{ Test1 string }) {}, + Fn: func(context.Context, struct{ Test1 string }) {}, + FnCtx: func(context.Context, struct{ Test1 string }) {}, Err: nil, }, { @@ -142,8 +143,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(uint(0)), }, - Fn: func(*api.Context, struct{ Test1 uint }) {}, - FnCtx: func(*api.Context, struct{ Test1 uint }) {}, + Fn: func(context.Context, struct{ Test1 uint }) {}, + FnCtx: func(context.Context, struct{ Test1 uint }) {}, Err: nil, }, { @@ -151,8 +152,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(float64(0)), }, - Fn: func(*api.Context, struct{ Test1 float64 }) {}, - FnCtx: func(*api.Context, struct{ Test1 float64 }) {}, + Fn: func(context.Context, struct{ Test1 float64 }) {}, + FnCtx: func(context.Context, struct{ Test1 float64 }) {}, Err: nil, }, { @@ -160,8 +161,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf([]byte("")), }, - Fn: func(*api.Context, struct{ Test1 []byte }) {}, - FnCtx: func(*api.Context, struct{ Test1 []byte }) {}, + Fn: func(context.Context, struct{ Test1 []byte }) {}, + FnCtx: func(context.Context, struct{ Test1 []byte }) {}, Err: nil, }, { @@ -169,8 +170,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf([]rune("")), }, - Fn: func(*api.Context, struct{ Test1 []rune }) {}, - FnCtx: func(*api.Context, struct{ Test1 []rune }) {}, + Fn: func(context.Context, struct{ Test1 []rune }) {}, + FnCtx: func(context.Context, struct{ Test1 []rune }) {}, Err: nil, }, { @@ -178,8 +179,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(string)), }, - Fn: func(*api.Context, struct{ Test1 *string }) {}, - FnCtx: func(*api.Context, struct{ Test1 *string }) {}, + Fn: func(context.Context, struct{ Test1 *string }) {}, + FnCtx: func(context.Context, struct{ Test1 *string }) {}, Err: nil, }, { @@ -187,8 +188,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(uint)), }, - Fn: func(*api.Context, struct{ Test1 *uint }) {}, - FnCtx: func(*api.Context, struct{ Test1 *uint }) {}, + Fn: func(context.Context, struct{ Test1 *uint }) {}, + FnCtx: func(context.Context, struct{ Test1 *uint }) {}, Err: nil, }, { @@ -196,8 +197,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(float64)), }, - Fn: func(*api.Context, struct{ Test1 *float64 }) {}, - FnCtx: func(*api.Context, struct{ Test1 *float64 }) {}, + Fn: func(context.Context, struct{ Test1 *float64 }) {}, + FnCtx: func(context.Context, struct{ Test1 *float64 }) {}, Err: nil, }, { @@ -205,8 +206,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new([]byte)), }, - Fn: func(*api.Context, struct{ Test1 *[]byte }) {}, - FnCtx: func(*api.Context, struct{ Test1 *[]byte }) {}, + Fn: func(context.Context, struct{ Test1 *[]byte }) {}, + FnCtx: func(context.Context, struct{ Test1 *[]byte }) {}, Err: nil, }, { @@ -214,8 +215,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new([]rune)), }, - Fn: func(*api.Context, struct{ Test1 *[]rune }) {}, - FnCtx: func(*api.Context, struct{ Test1 *[]rune }) {}, + Fn: func(context.Context, struct{ Test1 *[]rune }) {}, + FnCtx: func(context.Context, struct{ Test1 *[]rune }) {}, Err: nil, }, } @@ -259,37 +260,37 @@ func TestOutputCheck(t *testing.T) { // no input -> missing api.Err { Output: map[string]reflect.Type{}, - Fn: func(*api.Context) {}, + Fn: func(context.Context) {}, Err: errMissingHandlerOutputArgument, }, // no input -> with last type not api.Err { Output: map[string]reflect.Type{}, - Fn: func(*api.Context) bool { return true }, + Fn: func(context.Context) bool { return true }, Err: errMissingHandlerErrorArgument, }, // no input -> with api.Err { Output: map[string]reflect.Type{}, - Fn: func(*api.Context) api.Err { return api.ErrSuccess }, + Fn: func(context.Context) api.Err { return api.ErrSuccess }, Err: nil, }, - // no input -> missing *api.Context + // no input -> missing context.Context { Output: map[string]reflect.Type{}, - Fn: func(*api.Context) api.Err { return api.ErrSuccess }, + Fn: func(context.Context) api.Err { return api.ErrSuccess }, Err: errMissingHandlerContextArgument, }, - // no input -> invlaid *api.Context type + // no input -> invlaid context.Context type { Output: map[string]reflect.Type{}, - Fn: func(*api.Context, int) api.Err { return api.ErrSuccess }, + Fn: func(context.Context, int) api.Err { return api.ErrSuccess }, Err: errMissingHandlerContextArgument, }, // func can have output if not specified { Output: map[string]reflect.Type{}, - Fn: func(*api.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + Fn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, Err: nil, }, // missing output struct in func