diff --git a/README.md b/README.md index e8110b9..c2182e6 100644 --- a/README.md +++ b/README.md @@ -107,12 +107,17 @@ func main() { log.Fatalf("invalid config: %s", err) } + // add http middlewares (logger) + builder.With(func(next http.Handler) http.Handler{ /* ... */ }) + + // add contextual middlewares (authentication) + builder.WithContext(func(next http.Handler) http.Handler{ /* ... */ }) + // bind handlers err = builder.Bind(http.MethodGet, "/user/{id}", getUserById) if err != nil { log.Fatalf("cannog bind GET /user/{id}: %s", err) } - // ... // build your services handler, err := builder.Build() @@ -261,7 +266,7 @@ type res struct{ Output2 bool } -func myHandler(r req) (*res, api.Err) { +func myHandler(ctx context.Context, r req) (*res, api.Err) { err := doSomething() if err != nil { return nil, api.ErrFailure diff --git a/api/adapter.go b/api/adapter.go deleted file mode 100644 index 646f881..0000000 --- a/api/adapter.go +++ /dev/null @@ -1,13 +0,0 @@ -package api - -import "net/http" - -// Adapter to encapsulate incoming requests -type Adapter func(http.HandlerFunc) http.HandlerFunc - -// AuthHandlerFunc is http.HandlerFunc with additional Authorization information -type AuthHandlerFunc func(Auth, http.ResponseWriter, *http.Request) - -// AuthAdapter to encapsulate incoming request with access to api.Auth -// to manage permissions -type AuthAdapter func(AuthHandlerFunc) AuthHandlerFunc diff --git a/api/auth.go b/api/auth.go index ff2699a..cc162b7 100644 --- a/api/auth.go +++ b/api/auth.go @@ -21,7 +21,7 @@ type Auth struct { // Granted returns whether the authorization is granted // i.e. Auth.Active fulfills Auth.Required -func (a Auth) Granted() bool { +func (a *Auth) Granted() bool { var nothingRequired = true // first dimension: OR ; at least one is valid @@ -43,7 +43,7 @@ func (a Auth) Granted() bool { } // returns whether Auth.Active fulfills (contains) all @required roles -func (a Auth) fulfills(required []string) bool { +func (a *Auth) fulfills(required []string) bool { for _, requiredRole := range required { var found = false for _, activeRole := range a.Active { diff --git a/api/context.go b/api/context.go index d8e390e..a384282 100644 --- a/api/context.go +++ b/api/context.go @@ -1,17 +1,44 @@ package api import ( + "context" "net/http" + + "git.xdrm.io/go/aicra/internal/ctx" ) -// Ctx contains additional information for handlers -// -// usually input/output arguments built by aicra are sufficient -// but the Ctx lets you manage your request from scratch if required -// -// If required, set api.Ctx as the first argument of your handler; if you -// don't need it, only use standard input arguments and it will be ignored -type Ctx struct { - Res http.ResponseWriter - Req *http.Request +// GetRequest extracts the current request from a context.Context +func GetRequest(c context.Context) *http.Request { + var ( + raw = c.Value(ctx.Request) + cast, ok = raw.(*http.Request) + ) + if !ok { + return nil + } + return cast +} + +// GetResponseWriter extracts the response writer from a context.Context +func GetResponseWriter(c context.Context) http.ResponseWriter { + var ( + raw = c.Value(ctx.Response) + cast, ok = raw.(http.ResponseWriter) + ) + if !ok { + return nil + } + return cast +} + +// GetAuth returns the api.Auth associated with this request from a context.Context +func GetAuth(c context.Context) *Auth { + var ( + raw = c.Value(ctx.Auth) + cast, ok = raw.(*Auth) + ) + if !ok { + return nil + } + return cast } diff --git a/builder.go b/builder.go index 15d94e4..d3c9eb9 100644 --- a/builder.go +++ b/builder.go @@ -5,7 +5,6 @@ import ( "io" "net/http" - "git.xdrm.io/go/aicra/api" "git.xdrm.io/go/aicra/datatype" "git.xdrm.io/go/aicra/internal/config" "git.xdrm.io/go/aicra/internal/dynfunc" @@ -13,10 +12,16 @@ import ( // Builder for an aicra server type Builder struct { - conf *config.Server - handlers []*apiHandler - adapters []api.Adapter - authAdapters []api.AuthAdapter + // the server configuration defining available services + conf *config.Server + // user-defined handlers bound to services from the configuration + handlers []*apiHandler + // http middlewares wrapping the entire http connection (e.g. logger) + middlewares []func(http.Handler) http.Handler + // custom middlewares only wrapping the service handler of a request + // they will benefit from the request's context that contains service-specific + // information (e.g. required permisisons from the configuration) + ctxMiddlewares []func(http.Handler) http.Handler } // represents an api handler (method-pattern combination) @@ -41,26 +46,36 @@ func (b *Builder) AddType(t datatype.T) error { return nil } -// With adds an http adapter (middleware) -func (b *Builder) With(adapter api.Adapter) { +// With adds an http middleware on top of the http connection +// +// Authentication management can only be done with the WithContext() methods as +// the service associated with the request has not been found at this stage. +// This stage is perfect for logging or generic request management. +func (b *Builder) With(mw func(http.Handler) http.Handler) { if b.conf == nil { b.conf = &config.Server{} } - if b.adapters == nil { - b.adapters = make([]api.Adapter, 0) + if b.middlewares == nil { + b.middlewares = make([]func(http.Handler) http.Handler, 0) } - b.adapters = append(b.adapters, adapter) + b.middlewares = append(b.middlewares, mw) } -// WithAuth adds an http adapter with auth capabilities (middleware) -func (b *Builder) WithAuth(adapter api.AuthAdapter) { +// WithContext adds an http middleware with the fully loaded context +// +// Logging or generic request management should be done with the With() method as +// it wraps the full http connection. Middlewares added through this method only +// wrap the user-defined service handler. The context.Context is filled with useful +// data that can be access with api.GetRequest(), api.GetResponseWriter(), +// api.GetAuth(), etc methods. +func (b *Builder) WithContext(mw func(http.Handler) http.Handler) { if b.conf == nil { b.conf = &config.Server{} } - if b.authAdapters == nil { - b.authAdapters = make([]api.AuthAdapter, 0) + if b.ctxMiddlewares == nil { + b.ctxMiddlewares = make([]func(http.Handler) http.Handler, 0) } - b.authAdapters = append(b.authAdapters, adapter) + b.ctxMiddlewares = append(b.ctxMiddlewares, mw) } // Setup the builder with its api definition file diff --git a/builder_test.go b/builder_test.go index ab98dca..0207c1e 100644 --- a/builder_test.go +++ b/builder_test.go @@ -1,6 +1,7 @@ package aicra import ( + "context" "errors" "net/http" "strings" @@ -72,7 +73,7 @@ func TestBind(t *testing.T) { Config: "[]", HandlerMethod: "", HandlerPath: "", - HandlerFn: func() (*struct{}, api.Err) { return nil, api.ErrSuccess }, + HandlerFn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, BindErr: errUnknownService, BuildErr: nil, }, @@ -108,7 +109,7 @@ func TestBind(t *testing.T) { ]`, HandlerMethod: http.MethodPost, HandlerPath: "/path", - HandlerFn: func() (*struct{}, api.Err) { return nil, api.ErrSuccess }, + HandlerFn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, BindErr: errUnknownService, BuildErr: errMissingHandler, }, @@ -126,7 +127,7 @@ func TestBind(t *testing.T) { ]`, HandlerMethod: http.MethodGet, HandlerPath: "/paths", - HandlerFn: func() (*struct{}, api.Err) { return nil, api.ErrSuccess }, + HandlerFn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, BindErr: errUnknownService, BuildErr: errMissingHandler, }, @@ -144,7 +145,7 @@ func TestBind(t *testing.T) { ]`, HandlerMethod: http.MethodGet, HandlerPath: "/path", - HandlerFn: func() (*struct{}, api.Err) { return nil, api.ErrSuccess }, + HandlerFn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, BindErr: nil, BuildErr: nil, }, @@ -164,7 +165,7 @@ func TestBind(t *testing.T) { ]`, HandlerMethod: http.MethodGet, HandlerPath: "/path", - HandlerFn: func(struct{ Name int }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + HandlerFn: func(context.Context, struct{ Name int }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, BindErr: nil, BuildErr: nil, }, @@ -184,7 +185,7 @@ func TestBind(t *testing.T) { ]`, HandlerMethod: http.MethodGet, HandlerPath: "/path", - HandlerFn: func(struct{ Name uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + HandlerFn: func(context.Context, struct{ Name uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, BindErr: nil, BuildErr: nil, }, @@ -204,7 +205,7 @@ func TestBind(t *testing.T) { ]`, HandlerMethod: http.MethodGet, HandlerPath: "/path", - HandlerFn: func(struct{ Name string }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + HandlerFn: func(context.Context, struct{ Name string }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, BindErr: nil, BuildErr: nil, }, @@ -224,7 +225,7 @@ func TestBind(t *testing.T) { ]`, HandlerMethod: http.MethodGet, HandlerPath: "/path", - HandlerFn: func(struct{ Name bool }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + HandlerFn: func(context.Context, struct{ Name bool }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, BindErr: nil, BuildErr: nil, }, diff --git a/handler.go b/handler.go index 87ce076..4dea77a 100644 --- a/handler.go +++ b/handler.go @@ -1,12 +1,14 @@ package aicra import ( + "context" "fmt" "net/http" "strings" "git.xdrm.io/go/aicra/api" "git.xdrm.io/go/aicra/internal/config" + "git.xdrm.io/go/aicra/internal/ctx" "git.xdrm.io/go/aicra/internal/reqdata" ) @@ -15,30 +17,31 @@ type Handler Builder // ServeHTTP implements http.Handler and wraps it in middlewares (adapters) func (s Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - var h = http.HandlerFunc(s.resolve) + var h http.Handler = http.HandlerFunc(s.resolve) - for _, adapter := range s.adapters { - h = adapter(h) + for _, mw := range s.middlewares { + h = mw(h) } - h(w, r) + h.ServeHTTP(w, r) } +// ServeHTTP implements http.Handler and wraps it in middlewares (adapters) func (s Handler) resolve(w http.ResponseWriter, r *http.Request) { - // 1. find a matching service from config + // 1ind a matching service from config var service = s.conf.Find(r) if service == nil { handleError(api.ErrUnknownService, w, r) return } - // 2. extract request data + // extract request data var input, err = extractInput(service, *r) if err != nil { handleError(api.ErrMissingParam, w, r) return } - // 3. find a matching handler + // find a matching handler var handler *apiHandler for _, h := range s.handlers { if h.Method == service.Method && h.Path == service.Pattern { @@ -46,59 +49,50 @@ func (s Handler) resolve(w http.ResponseWriter, r *http.Request) { } } - // 4. fail on no matching handler + // fail on no matching handler if handler == nil { handleError(api.ErrUncallableService, w, r) return } - // replace format '[a]' in scope where 'a' is an existing input's name - scope := make([][]string, len(service.Scope)) - for a, list := range service.Scope { - scope[a] = make([]string, len(list)) - for b, perm := range list { - scope[a][b] = perm - for name, value := range input.Data { - var ( - token = fmt.Sprintf("[%s]", name) - replacement = "" - ) - if value != nil { - replacement = fmt.Sprintf("[%v]", value) - } - scope[a][b] = strings.ReplaceAll(scope[a][b], token, replacement) - } - } - } + // build context with builtin data + c := r.Context() + c = context.WithValue(c, ctx.Request, r) + c = context.WithValue(c, ctx.Response, w) + c = context.WithValue(c, ctx.Auth, buildAuth(service.Scope, input.Data)) - var auth = api.Auth{ - Required: scope, - Active: []string{}, - } - - // 5. run auth-aware middlewares - var h = api.AuthHandlerFunc(func(a api.Auth, w http.ResponseWriter, r *http.Request) { - if !a.Granted() { + // create http handler + var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := api.GetAuth(r.Context()) + if auth == nil { handleError(api.ErrPermission, w, r) return } - s.handle(input, handler, service, w, r) + // reject non granted requests + if !auth.Granted() { + handleError(api.ErrPermission, w, r) + return + } + + // use context defined in the request + s.handle(r.Context(), input, handler, service, w, r) }) - for _, adapter := range s.authAdapters { - h = adapter(h) + // run middlewares the handler + for _, mw := range s.ctxMiddlewares { + h = mw(h) } - h(auth, w, r) + // serve using the context with values + h.ServeHTTP(w, r.WithContext(c)) } -func (s *Handler) handle(input *reqdata.T, handler *apiHandler, service *config.Service, w http.ResponseWriter, r *http.Request) { - // 5. pass execution to the handler - ctx := api.Ctx{Res: w, Req: r} - var outData, outErr = handler.dyn.Handle(ctx, input.Data) +func (s *Handler) handle(c context.Context, input *reqdata.T, handler *apiHandler, service *config.Service, w http.ResponseWriter, r *http.Request) { + // pass execution to the handler + var outData, outErr = handler.dyn.Handle(c, input.Data) - // 6. build res from returned data + // build response from returned arguments var res = api.EmptyResponse().WithError(outErr) for key, value := range outData { @@ -149,3 +143,35 @@ func extractInput(service *config.Service, req http.Request) (*reqdata.T, error) return dataset, nil } + +// buildAuth builds the api.Auth struct from the service scope configuration +// +// it replaces format '[a]' in scope where 'a' is an existing input argument's +// name with its value +func buildAuth(scope [][]string, in map[string]interface{}) *api.Auth { + updated := make([][]string, len(scope)) + + // replace '[arg_name]' with the 'arg_name' value if it is a known variable + // name + for a, list := range scope { + updated[a] = make([]string, len(list)) + for b, perm := range list { + updated[a][b] = perm + for name, value := range in { + var ( + token = fmt.Sprintf("[%s]", name) + replacement = "" + ) + if value != nil { + replacement = fmt.Sprintf("[%v]", value) + } + updated[a][b] = strings.ReplaceAll(updated[a][b], token, replacement) + } + } + } + + return &api.Auth{ + Required: updated, + Active: []string{}, + } +} diff --git a/handler_test.go b/handler_test.go index 6d0cfd8..c5883c5 100644 --- a/handler_test.go +++ b/handler_test.go @@ -3,6 +3,7 @@ package aicra_test import ( "bytes" "context" + "encoding/json" "fmt" "net/http" "net/http/httptest" @@ -48,15 +49,15 @@ func TestWith(t *testing.T) { type ckey int const key ckey = 0 - middleware := func(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { + middleware := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { newr := r // first time -> store 1 value := r.Context().Value(key) if value == nil { newr = r.WithContext(context.WithValue(r.Context(), key, int(1))) - next(w, newr) + next.ServeHTTP(w, newr) return } @@ -67,8 +68,8 @@ func TestWith(t *testing.T) { } cast++ newr = r.WithContext(context.WithValue(r.Context(), key, cast)) - next(w, newr) - } + next.ServeHTTP(w, newr) + }) } // add middleware @n times @@ -82,9 +83,9 @@ func TestWith(t *testing.T) { t.Fatalf("setup: unexpected error <%v>", err) } - pathHandler := func(ctx api.Ctx) (*struct{}, api.Err) { + pathHandler := func(ctx context.Context) (*struct{}, api.Err) { // write value from middlewares into response - value := ctx.Req.Context().Value(key) + value := ctx.Value(key) if value == nil { t.Fatalf("nothing found in context") } @@ -93,7 +94,7 @@ func TestWith(t *testing.T) { t.Fatalf("cannot cast context data to int") } // write to response - ctx.Res.Write([]byte(fmt.Sprintf("#%d#", cast))) + api.GetResponseWriter(ctx).Write([]byte(fmt.Sprintf("#%d#", cast))) return nil, api.ErrSuccess } @@ -212,8 +213,13 @@ func TestWithAuth(t *testing.T) { } // tester middleware (last executed) - builder.WithAuth(func(next api.AuthHandlerFunc) api.AuthHandlerFunc { - return func(a api.Auth, w http.ResponseWriter, r *http.Request) { + builder.WithContext(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + a := api.GetAuth(r.Context()) + if a == nil { + t.Fatalf("cannot access api.Auth form request context") + } + if a.Granted() == tc.granted { return } @@ -222,14 +228,20 @@ func TestWithAuth(t *testing.T) { } else { t.Fatalf("expected granted auth") } - } + next.ServeHTTP(w, r) + }) }) - builder.WithAuth(func(next api.AuthHandlerFunc) api.AuthHandlerFunc { - return func(a api.Auth, w http.ResponseWriter, r *http.Request) { + builder.WithContext(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + a := api.GetAuth(r.Context()) + if a == nil { + t.Fatalf("cannot access api.Auth form request context") + } + a.Active = tc.permissions - next(a, w, r) - } + next.ServeHTTP(w, r) + }) }) err := builder.Setup(strings.NewReader(tc.manifest)) @@ -237,7 +249,7 @@ func TestWithAuth(t *testing.T) { t.Fatalf("setup: unexpected error <%v>", err) } - pathHandler := func(ctx api.Ctx) (*struct{}, api.Err) { + pathHandler := func(ctx context.Context) (*struct{}, api.Err) { return nil, api.ErrNotImplemented } @@ -264,6 +276,97 @@ func TestWithAuth(t *testing.T) { } +func TestPermissionError(t *testing.T) { + + tt := []struct { + name string + manifest string + permissions []string + granted bool + }{ + { + name: "permission fulfilled", + manifest: `[ { "method": "GET", "path": "/path", "scope": [["A"]], "info": "info", "in": {}, "out": {} } ]`, + permissions: []string{"A"}, + granted: true, + }, + { + name: "missing permission", + manifest: `[ { "method": "GET", "path": "/path", "scope": [["A"]], "info": "info", "in": {}, "out": {} } ]`, + permissions: []string{}, + granted: false, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + builder := &aicra.Builder{} + if err := addBuiltinTypes(builder); err != nil { + t.Fatalf("unexpected error <%v>", err) + } + + // add active permissions + builder.WithContext(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + a := api.GetAuth(r.Context()) + if a == nil { + t.Fatalf("cannot access api.Auth form request context") + } + + a.Active = tc.permissions + next.ServeHTTP(w, r) + }) + }) + + err := builder.Setup(strings.NewReader(tc.manifest)) + if err != nil { + t.Fatalf("setup: unexpected error <%v>", err) + } + + pathHandler := func(ctx context.Context) (*struct{}, api.Err) { + return nil, api.ErrNotImplemented + } + + if err := builder.Bind(http.MethodGet, "/path", pathHandler); err != nil { + t.Fatalf("bind: unexpected error <%v>", err) + } + + handler, err := builder.Build() + if err != nil { + t.Fatalf("build: unexpected error <%v>", err) + } + + response := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/path", &bytes.Buffer{}) + + // test request + handler.ServeHTTP(response, request) + if response.Body == nil { + t.Fatalf("response has no body") + } + type jsonResponse struct { + Err api.Err `json:"error"` + } + var res jsonResponse + err = json.Unmarshal(response.Body.Bytes(), &res) + if err != nil { + t.Fatalf("cannot unmarshal response: %s", err) + } + + expectedError := api.ErrNotImplemented + if !tc.granted { + expectedError = api.ErrPermission + } + + if res.Err.Code != expectedError.Code { + t.Fatalf("expected error code %d got %d", expectedError.Code, res.Err.Code) + } + + }) + } + +} + func TestDynamicScope(t *testing.T) { tt := []struct { name string @@ -290,7 +393,7 @@ func TestDynamicScope(t *testing.T) { } ]`, path: "/path/{id}", - handler: func(struct{ Input1 uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + handler: func(context.Context, struct{ Input1 uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, url: "/path/123", body: ``, permissions: []string{"user[123]"}, @@ -311,7 +414,7 @@ func TestDynamicScope(t *testing.T) { } ]`, path: "/path/{id}", - handler: func(struct{ Input1 uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + handler: func(context.Context, struct{ Input1 uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, url: "/path/666", body: ``, permissions: []string{"user[123]"}, @@ -332,7 +435,7 @@ func TestDynamicScope(t *testing.T) { } ]`, path: "/path/{id}", - handler: func(struct{ User uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + handler: func(context.Context, struct{ User uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, url: "/path/123", body: ``, permissions: []string{"prefix.user[123].suffix"}, @@ -354,7 +457,7 @@ func TestDynamicScope(t *testing.T) { } ]`, path: "/prefix/{pid}/user/{uid}", - handler: func(struct { + handler: func(context.Context, struct { Prefix uint User uint }) (*struct{}, api.Err) { @@ -381,7 +484,7 @@ func TestDynamicScope(t *testing.T) { } ]`, path: "/prefix/{pid}/user/{uid}", - handler: func(struct { + handler: func(context.Context, struct { Prefix uint User uint }) (*struct{}, api.Err) { @@ -409,7 +512,7 @@ func TestDynamicScope(t *testing.T) { } ]`, path: "/prefix/{pid}/user/{uid}/suffix/{sid}", - handler: func(struct { + handler: func(context.Context, struct { Prefix uint User uint Suffix uint @@ -438,7 +541,7 @@ func TestDynamicScope(t *testing.T) { } ]`, path: "/prefix/{pid}/user/{uid}/suffix/{sid}", - handler: func(struct { + handler: func(context.Context, struct { Prefix uint User uint Suffix uint @@ -460,8 +563,12 @@ func TestDynamicScope(t *testing.T) { } // tester middleware (last executed) - builder.WithAuth(func(next api.AuthHandlerFunc) api.AuthHandlerFunc { - return func(a api.Auth, w http.ResponseWriter, r *http.Request) { + builder.WithContext(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + a := api.GetAuth(r.Context()) + if a == nil { + t.Fatalf("cannot access api.Auth form request context") + } if a.Granted() == tc.granted { return } @@ -470,15 +577,20 @@ func TestDynamicScope(t *testing.T) { } else { t.Fatalf("expected granted auth") } - } + next.ServeHTTP(w, r) + }) }) // update permissions - builder.WithAuth(func(next api.AuthHandlerFunc) api.AuthHandlerFunc { - return func(a api.Auth, w http.ResponseWriter, r *http.Request) { + builder.WithContext(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + a := api.GetAuth(r.Context()) + if a == nil { + t.Fatalf("cannot access api.Auth form request context") + } a.Active = tc.permissions - next(a, w, r) - } + next.ServeHTTP(w, r) + }) }) err := builder.Setup(strings.NewReader(tc.manifest)) diff --git a/internal/ctx/ctx.go b/internal/ctx/ctx.go new file mode 100644 index 0000000..dd59756 --- /dev/null +++ b/internal/ctx/ctx.go @@ -0,0 +1,13 @@ +package ctx + +// Key defines a custom context key type +type Key int + +const ( + // Request is the key for the current *http.Request + Request Key = iota + // Response is the key for the associated http.ResponseWriter + Response + // Auth is the key for the request's authentication information + Auth +) diff --git a/internal/dynfunc/errors.go b/internal/dynfunc/errors.go index ca87692..903e5f0 100644 --- a/internal/dynfunc/errors.go +++ b/internal/dynfunc/errors.go @@ -14,16 +14,19 @@ const errHandlerNotFunc = cerr("handler must be a func") const errNoServiceForHandler = cerr("no service found for this handler") // errMissingHandlerArgumentParam - missing params arguments for handler -const errMissingHandlerArgumentParam = cerr("missing handler argument : parameter struct") +const errMissingHandlerContextArgument = cerr("missing handler first argument of type context.Context") + +// errMissingHandlerInputArgument - missing params arguments for handler +const errMissingHandlerInputArgument = cerr("missing handler argument: input struct") // errUnexpectedInput - input argument is not expected const errUnexpectedInput = cerr("unexpected input struct") -// errMissingHandlerOutput - missing output for handler -const errMissingHandlerOutput = cerr("handler must have at least 1 output") +// errMissingHandlerOutputArgument - missing output for handler +const errMissingHandlerOutputArgument = cerr("missing handler first output argument: output struct") // errMissingHandlerOutputError - missing error output for handler -const errMissingHandlerOutputError = cerr("handler must have its last output of type api.Err") +const errMissingHandlerOutputError = cerr("missing handler last output argument of type api.Err") // errMissingRequestArgument - missing request argument for handler const errMissingRequestArgument = cerr("handler first argument must be of type api.Request") @@ -34,17 +37,14 @@ const errMissingParamArgument = cerr("handler second argument must be a struct") // errUnexportedName - argument is unexported in struct const errUnexportedName = cerr("unexported name") -// errMissingParamOutput - missing output argument for handler -const errMissingParamOutput = cerr("handler first output must be a *struct") +// errWrongOutputArgumentType - wrong type for output first argument +const errWrongOutputArgumentType = cerr("handler first output argument must be a *struct") -// errMissingParamFromConfig - missing a parameter in handler struct -const errMissingParamFromConfig = cerr("missing a parameter from configuration") - -// errMissingOutputFromConfig - missing a parameter in handler struct -const errMissingOutputFromConfig = cerr("missing a parameter from configuration") +// errMissingConfigArgument - missing an input/output argument in handler struct +const errMissingConfigArgument = cerr("missing an argument from the configuration") // errWrongParamTypeFromConfig - a configuration parameter type is invalid in the handler param struct const errWrongParamTypeFromConfig = cerr("invalid struct field type") -// errMissingHandlerErrorOutput - missing handler output error -const errMissingHandlerErrorOutput = cerr("last output must be of type api.Err") +// errMissingHandlerErrorArgument - missing handler output error +const errMissingHandlerErrorArgument = cerr("last output must be of type api.Err") diff --git a/internal/dynfunc/handler.go b/internal/dynfunc/handler.go index 634cfc6..783612b 100644 --- a/internal/dynfunc/handler.go +++ b/internal/dynfunc/handler.go @@ -1,6 +1,7 @@ package dynfunc import ( + "context" "fmt" "log" "reflect" @@ -9,73 +10,69 @@ import ( "git.xdrm.io/go/aicra/internal/config" ) -// Handler represents a dynamic api handler +// Handler represents a dynamic aicra service handler type Handler struct { - spec *signature - fn interface{} - // whether fn uses api.Ctx as 1st argument - hasContext bool - // index in input arguments where the data struct must be - dataIndex int + // signature defined from the service configuration + signature *Signature + // fn provided function that will be the service's handler + fn interface{} } -// Build a handler from a service configuration and a dynamic function -// -// @fn must have as a signature : `func(inputStruct) (*outputStruct, api.Err)` -// - `inputStruct` is a struct{} containing a field for each service input (with valid reflect.Type) -// - `outputStruct` is a struct{} containing a field for each service output (with valid reflect.Type) +// Build a handler from a dynamic function and checks its signature against a +// service configuration +//e +// `fn` must have as a signature : `func(*api.Context, in) (*out, api.Err)` +// - `in` is a struct{} containing a field for each service input (with valid reflect.Type) +// - `out` is a struct{} containing a field for each service output (with valid reflect.Type) // // Special cases: -// - a first optional input parameter of type `api.Ctx` can be added -// - it there is no input, `inputStruct` must be omitted -// - it there is no output, `outputStruct` must be omitted +// - it there is no input, `in` MUST be omitted +// - it there is no output, `out` MUST be omitted func Build(fn interface{}, service config.Service) (*Handler, error) { - h := &Handler{ - spec: signatureFromService(service), - fn: fn, - } + var ( + h = &Handler{ + signature: BuildSignature(service), + fn: fn, + } + fnType = reflect.TypeOf(fn) + ) - impl := reflect.TypeOf(fn) - - if impl.Kind() != reflect.Func { + if fnType.Kind() != reflect.Func { return nil, errHandlerNotFunc } - - h.hasContext = impl.NumIn() >= 1 && reflect.TypeOf(api.Ctx{}).AssignableTo(impl.In(0)) - if h.hasContext { - h.dataIndex = 1 - } - - if err := h.spec.checkInput(impl, h.dataIndex); err != nil { + if err := h.signature.ValidateInput(fnType); err != nil { return nil, fmt.Errorf("input: %w", err) } - if err := h.spec.checkOutput(impl); err != nil { + if err := h.signature.ValidateOutput(fnType); err != nil { return nil, fmt.Errorf("output: %w", err) } return h, nil } -// Handle binds input @data into the dynamic function and returns map output -func (h *Handler) Handle(ctx api.Ctx, data map[string]interface{}) (map[string]interface{}, api.Err) { - var ert = reflect.TypeOf(api.Err{}) - var fnv = reflect.ValueOf(h.fn) +// Handle binds input `data` into the dynamic function and returns an output map +func (h *Handler) Handle(ctx context.Context, data map[string]interface{}) (map[string]interface{}, api.Err) { + var ( + ert = reflect.TypeOf(api.Err{}) + fnv = reflect.ValueOf(h.fn) + callArgs = make([]reflect.Value, 0) + ) - callArgs := []reflect.Value{} + // bind context + callArgs = append(callArgs, reflect.ValueOf(ctx)) - // bind context if used in handler - if h.hasContext { - callArgs = append(callArgs, reflect.ValueOf(ctx)) - } + inputStructRequired := fnv.Type().NumIn() > 1 - // bind input data - if fnv.Type().NumIn() > h.dataIndex { + // bind input arguments + if inputStructRequired { // create zero value struct - callStructPtr := reflect.New(fnv.Type().In(0)) - callStruct := callStructPtr.Elem() + var ( + callStructPtr = reflect.New(fnv.Type().In(1)) + callStruct = callStructPtr.Elem() + ) // set each field - for name := range h.spec.Input { + for name := range h.signature.Input { field := callStruct.FieldByName(name) if !field.CanSet() { continue @@ -115,12 +112,12 @@ func (h *Handler) Handle(ctx api.Ctx, data map[string]interface{}) (map[string]i callArgs = append(callArgs, callStruct) } - // call the HandlerFn + // call the handler output := fnv.Call(callArgs) // no output OR pointer to output struct is nil outdata := make(map[string]interface{}) - if len(h.spec.Output) < 1 || output[0].IsNil() { + if len(h.signature.Output) < 1 || output[0].IsNil() { var structerr = output[len(output)-1].Convert(ert) return outdata, api.Err{ Code: int(structerr.FieldByName("Code").Int()), @@ -132,7 +129,7 @@ func (h *Handler) Handle(ctx api.Ctx, data map[string]interface{}) (map[string]i // extract struct from pointer returnStruct := output[0].Elem() - for name := range h.spec.Output { + for name := range h.signature.Output { field := returnStruct.FieldByName(name) outdata[name] = field.Interface() } diff --git a/internal/dynfunc/handler_test.go b/internal/dynfunc/handler_test.go index a457f1e..053cc25 100644 --- a/internal/dynfunc/handler_test.go +++ b/internal/dynfunc/handler_test.go @@ -1,6 +1,7 @@ package dynfunc import ( + "context" "fmt" "reflect" "testing" @@ -8,7 +9,7 @@ import ( "git.xdrm.io/go/aicra/api" ) -type testsignature signature +type testsignature Signature // builds a mock service with provided arguments as Input and matched as Output func (s *testsignature) withArgs(dtypes ...reflect.Type) *testsignature { @@ -52,7 +53,7 @@ func TestInput(t *testing.T) { { Name: "none required none provided", Spec: (&testsignature{}).withArgs(), - Fn: func() (*struct{}, api.Err) { return nil, api.ErrSuccess }, + Fn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, HasContext: false, Input: []interface{}{}, ExpectedOutput: []interface{}{}, @@ -61,7 +62,7 @@ func TestInput(t *testing.T) { { Name: "int proxy (0)", Spec: (&testsignature{}).withArgs(reflect.TypeOf(int(0))), - Fn: func(in intstruct) (*intstruct, api.Err) { + Fn: func(ctx context.Context, in intstruct) (*intstruct, api.Err) { return &intstruct{P1: in.P1}, api.ErrSuccess }, HasContext: false, @@ -72,7 +73,7 @@ func TestInput(t *testing.T) { { Name: "int proxy (11)", Spec: (&testsignature{}).withArgs(reflect.TypeOf(int(0))), - Fn: func(in intstruct) (*intstruct, api.Err) { + Fn: func(ctx context.Context, in intstruct) (*intstruct, api.Err) { return &intstruct{P1: in.P1}, api.ErrSuccess }, HasContext: false, @@ -83,7 +84,7 @@ func TestInput(t *testing.T) { { Name: "*int proxy (nil)", Spec: (&testsignature{}).withArgs(reflect.TypeOf(new(int))), - Fn: func(in intptrstruct) (*intptrstruct, api.Err) { + Fn: func(ctx context.Context, in intptrstruct) (*intptrstruct, api.Err) { return &intptrstruct{P1: in.P1}, api.ErrSuccess }, HasContext: false, @@ -94,7 +95,7 @@ func TestInput(t *testing.T) { { Name: "*int proxy (28)", Spec: (&testsignature{}).withArgs(reflect.TypeOf(new(int))), - Fn: func(in intptrstruct) (*intstruct, api.Err) { + Fn: func(ctx context.Context, in intptrstruct) (*intstruct, api.Err) { return &intstruct{P1: *in.P1}, api.ErrSuccess }, HasContext: false, @@ -105,7 +106,7 @@ func TestInput(t *testing.T) { { Name: "*int proxy (13)", Spec: (&testsignature{}).withArgs(reflect.TypeOf(new(int))), - Fn: func(in intptrstruct) (*intstruct, api.Err) { + Fn: func(ctx context.Context, in intptrstruct) (*intstruct, api.Err) { return &intstruct{P1: *in.P1}, api.ErrSuccess }, HasContext: false, @@ -119,16 +120,9 @@ func TestInput(t *testing.T) { t.Run(tcase.Name, func(t *testing.T) { t.Parallel() - var dataIndex = 0 - if tcase.HasContext { - dataIndex = 1 - } - var handler = &Handler{ - spec: &signature{Input: tcase.Spec.Input, Output: tcase.Spec.Output}, - fn: tcase.Fn, - dataIndex: dataIndex, - hasContext: tcase.HasContext, + signature: &Signature{Input: tcase.Spec.Input, Output: tcase.Spec.Output}, + fn: tcase.Fn, } // build input @@ -138,7 +132,7 @@ func TestInput(t *testing.T) { input[key] = val } - var output, err = handler.Handle(api.Ctx{}, input) + var output, err = handler.Handle(context.Background(), input) if err != tcase.ExpectedErr { t.Fatalf("expected api error <%v> got <%v>", tcase.ExpectedErr, err) } diff --git a/internal/dynfunc/signature.go b/internal/dynfunc/signature.go index 2ee32ae..e0f9a19 100644 --- a/internal/dynfunc/signature.go +++ b/internal/dynfunc/signature.go @@ -1,6 +1,7 @@ package dynfunc import ( + "context" "fmt" "reflect" "strings" @@ -9,15 +10,17 @@ import ( "git.xdrm.io/go/aicra/internal/config" ) -// signature represents input/output arguments for a dynamic function -type signature struct { - Input map[string]reflect.Type +// Signature represents input/output arguments for service from the aicra configuration +type Signature struct { + // Input arguments of the service + Input map[string]reflect.Type + // Output arguments of the service Output map[string]reflect.Type } -// builds a spec from the configuration service -func signatureFromService(service config.Service) *signature { - s := &signature{ +// BuildSignature builds a signature for a service configuration +func BuildSignature(service config.Service) *Signature { + s := &Signature{ Input: make(map[string]reflect.Type), Output: make(map[string]reflect.Type), } @@ -44,31 +47,37 @@ func signatureFromService(service config.Service) *signature { return s } -// checks for HandlerFn input arguments -func (s *signature) checkInput(impl reflect.Type, index int) error { - var requiredInput, structIndex = index, index - if len(s.Input) > 0 { // arguments struct - requiredInput++ +// ValidateInput validates a handler's input arguments against the service signature +func (s *Signature) ValidateInput(handlerType reflect.Type) error { + ctxType := reflect.TypeOf((*context.Context)(nil)).Elem() + + // missing or invalid first arg: context.Context + if handlerType.NumIn() < 1 { + return errMissingHandlerContextArgument + } + firstArgType := handlerType.In(0) + + if !firstArgType.Implements(ctxType) { + return fmt.Errorf("fock") } - // missing arguments - if impl.NumIn() > requiredInput { - return errUnexpectedInput - } - - // none required + // no input required if len(s.Input) == 0 { + // input struct provided + if handlerType.NumIn() > 1 { + return errUnexpectedInput + } return nil } // too much arguments - if impl.NumIn() != requiredInput { - return errMissingHandlerArgumentParam + if handlerType.NumIn() > 2 { + return errMissingHandlerInputArgument } // arg must be a struct - structArg := impl.In(structIndex) - if structArg.Kind() != reflect.Struct { + inStruct := handlerType.In(1) + if inStruct.Kind() != reflect.Struct { return errMissingParamArgument } @@ -78,9 +87,9 @@ func (s *signature) checkInput(impl reflect.Type, index int) error { return fmt.Errorf("%s: %w", name, errUnexportedName) } - field, exists := structArg.FieldByName(name) + field, exists := inStruct.FieldByName(name) if !exists { - return fmt.Errorf("%s: %w", name, errMissingParamFromConfig) + return fmt.Errorf("%s: %w", name, errMissingConfigArgument) } if !ptype.AssignableTo(field.Type) { @@ -91,16 +100,18 @@ func (s *signature) checkInput(impl reflect.Type, index int) error { return nil } -// checks for HandlerFn output arguments -func (s signature) checkOutput(impl reflect.Type) error { - if impl.NumOut() < 1 { - return errMissingHandlerOutput +// ValidateOutput validates a handler's output arguments against the service signature +func (s Signature) ValidateOutput(handlerType reflect.Type) error { + errType := reflect.TypeOf(api.ErrUnknown) + + if handlerType.NumOut() < 1 { + return errMissingHandlerErrorArgument } // last output must be api.Err - errOutput := impl.Out(impl.NumOut() - 1) - if !errOutput.AssignableTo(reflect.TypeOf(api.ErrUnknown)) { - return errMissingHandlerErrorOutput + lastArgType := handlerType.Out(handlerType.NumOut() - 1) + if !lastArgType.AssignableTo(errType) { + return errMissingHandlerErrorArgument } // no output -> ok @@ -108,19 +119,19 @@ func (s signature) checkOutput(impl reflect.Type) error { return nil } - if impl.NumOut() != 2 { - return errMissingParamOutput + if handlerType.NumOut() < 2 { + return errMissingHandlerOutputArgument } // fail if first output is not a pointer to struct - structOutputPtr := impl.Out(0) - if structOutputPtr.Kind() != reflect.Ptr { - return errMissingParamOutput + outStructPtr := handlerType.Out(0) + if outStructPtr.Kind() != reflect.Ptr { + return errWrongOutputArgumentType } - structOutput := structOutputPtr.Elem() - if structOutput.Kind() != reflect.Struct { - return errMissingParamOutput + outStruct := outStructPtr.Elem() + if outStruct.Kind() != reflect.Struct { + return errWrongOutputArgumentType } // fail on invalid output @@ -129,9 +140,9 @@ func (s signature) checkOutput(impl reflect.Type) error { return fmt.Errorf("%s: %w", name, errUnexportedName) } - field, exists := structOutput.FieldByName(name) + field, exists := outStruct.FieldByName(name) if !exists { - return fmt.Errorf("%s: %w", name, errMissingOutputFromConfig) + return fmt.Errorf("%s: %w", name, errMissingConfigArgument) } // ignore types evalutating to nil diff --git a/internal/dynfunc/signature_test.go b/internal/dynfunc/signature_test.go index 8860a92..874834c 100644 --- a/internal/dynfunc/signature_test.go +++ b/internal/dynfunc/signature_test.go @@ -1,6 +1,7 @@ package dynfunc import ( + "context" "errors" "fmt" "reflect" @@ -20,22 +21,22 @@ func TestInputCheck(t *testing.T) { { Name: "no input 0 given", Input: map[string]reflect.Type{}, - Fn: func() {}, - FnCtx: func(api.Ctx) {}, + Fn: func(context.Context) {}, + FnCtx: func(context.Context) {}, Err: nil, }, { Name: "no input 1 given", Input: map[string]reflect.Type{}, - Fn: func(int) {}, - FnCtx: func(api.Ctx, int) {}, + Fn: func(context.Context, int) {}, + FnCtx: func(context.Context, int) {}, Err: errUnexpectedInput, }, { Name: "no input 2 given", Input: map[string]reflect.Type{}, - Fn: func(int, string) {}, - FnCtx: func(api.Ctx, int, string) {}, + Fn: func(context.Context, int, string) {}, + FnCtx: func(context.Context, int, string) {}, Err: errUnexpectedInput, }, { @@ -43,17 +44,17 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func() {}, - FnCtx: func(api.Ctx) {}, - Err: errMissingHandlerArgumentParam, + Fn: func(context.Context) {}, + FnCtx: func(context.Context) {}, + Err: errMissingHandlerInputArgument, }, { Name: "1 input non-struct given", Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func(int) {}, - FnCtx: func(api.Ctx, int) {}, + Fn: func(context.Context, int) {}, + FnCtx: func(context.Context, int) {}, Err: errMissingParamArgument, }, { @@ -61,8 +62,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "test1": reflect.TypeOf(int(0)), }, - Fn: func(struct{}) {}, - FnCtx: func(api.Ctx, struct{}) {}, + Fn: func(context.Context, struct{}) {}, + FnCtx: func(context.Context, struct{}) {}, Err: errUnexportedName, }, { @@ -70,17 +71,17 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func(struct{}) {}, - FnCtx: func(api.Ctx, struct{}) {}, - Err: errMissingParamFromConfig, + Fn: func(context.Context, struct{}) {}, + FnCtx: func(context.Context, struct{}) {}, + Err: errMissingConfigArgument, }, { Name: "1 input invalid given", Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func(struct{ Test1 string }) {}, - FnCtx: func(api.Ctx, struct{ Test1 string }) {}, + Fn: func(context.Context, struct{ Test1 string }) {}, + FnCtx: func(context.Context, struct{ Test1 string }) {}, Err: errWrongParamTypeFromConfig, }, { @@ -88,8 +89,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func(struct{ Test1 int }) {}, - FnCtx: func(api.Ctx, struct{ Test1 int }) {}, + Fn: func(context.Context, struct{ Test1 int }) {}, + FnCtx: func(context.Context, struct{ Test1 int }) {}, Err: nil, }, { @@ -97,17 +98,17 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(int)), }, - Fn: func(struct{}) {}, - FnCtx: func(api.Ctx, struct{}) {}, - Err: errMissingParamFromConfig, + Fn: func(context.Context, struct{}) {}, + FnCtx: func(context.Context, struct{}) {}, + Err: errMissingConfigArgument, }, { Name: "1 input ptr invalid given", Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(int)), }, - Fn: func(struct{ Test1 string }) {}, - FnCtx: func(api.Ctx, struct{ Test1 string }) {}, + Fn: func(context.Context, struct{ Test1 string }) {}, + FnCtx: func(context.Context, struct{ Test1 string }) {}, Err: errWrongParamTypeFromConfig, }, { @@ -115,8 +116,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(int)), }, - Fn: func(struct{ Test1 *string }) {}, - FnCtx: func(api.Ctx, struct{ Test1 *string }) {}, + Fn: func(context.Context, struct{ Test1 *string }) {}, + FnCtx: func(context.Context, struct{ Test1 *string }) {}, Err: errWrongParamTypeFromConfig, }, { @@ -124,8 +125,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(int)), }, - Fn: func(struct{ Test1 *int }) {}, - FnCtx: func(api.Ctx, struct{ Test1 *int }) {}, + Fn: func(context.Context, struct{ Test1 *int }) {}, + FnCtx: func(context.Context, struct{ Test1 *int }) {}, Err: nil, }, { @@ -133,8 +134,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(string("")), }, - Fn: func(struct{ Test1 string }) {}, - FnCtx: func(api.Ctx, struct{ Test1 string }) {}, + Fn: func(context.Context, struct{ Test1 string }) {}, + FnCtx: func(context.Context, struct{ Test1 string }) {}, Err: nil, }, { @@ -142,8 +143,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(uint(0)), }, - Fn: func(struct{ Test1 uint }) {}, - FnCtx: func(api.Ctx, struct{ Test1 uint }) {}, + Fn: func(context.Context, struct{ Test1 uint }) {}, + FnCtx: func(context.Context, struct{ Test1 uint }) {}, Err: nil, }, { @@ -151,8 +152,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(float64(0)), }, - Fn: func(struct{ Test1 float64 }) {}, - FnCtx: func(api.Ctx, struct{ Test1 float64 }) {}, + Fn: func(context.Context, struct{ Test1 float64 }) {}, + FnCtx: func(context.Context, struct{ Test1 float64 }) {}, Err: nil, }, { @@ -160,8 +161,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf([]byte("")), }, - Fn: func(struct{ Test1 []byte }) {}, - FnCtx: func(api.Ctx, struct{ Test1 []byte }) {}, + Fn: func(context.Context, struct{ Test1 []byte }) {}, + FnCtx: func(context.Context, struct{ Test1 []byte }) {}, Err: nil, }, { @@ -169,8 +170,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf([]rune("")), }, - Fn: func(struct{ Test1 []rune }) {}, - FnCtx: func(api.Ctx, struct{ Test1 []rune }) {}, + Fn: func(context.Context, struct{ Test1 []rune }) {}, + FnCtx: func(context.Context, struct{ Test1 []rune }) {}, Err: nil, }, { @@ -178,8 +179,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(string)), }, - Fn: func(struct{ Test1 *string }) {}, - FnCtx: func(api.Ctx, struct{ Test1 *string }) {}, + Fn: func(context.Context, struct{ Test1 *string }) {}, + FnCtx: func(context.Context, struct{ Test1 *string }) {}, Err: nil, }, { @@ -187,8 +188,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(uint)), }, - Fn: func(struct{ Test1 *uint }) {}, - FnCtx: func(api.Ctx, struct{ Test1 *uint }) {}, + Fn: func(context.Context, struct{ Test1 *uint }) {}, + FnCtx: func(context.Context, struct{ Test1 *uint }) {}, Err: nil, }, { @@ -196,8 +197,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(float64)), }, - Fn: func(struct{ Test1 *float64 }) {}, - FnCtx: func(api.Ctx, struct{ Test1 *float64 }) {}, + Fn: func(context.Context, struct{ Test1 *float64 }) {}, + FnCtx: func(context.Context, struct{ Test1 *float64 }) {}, Err: nil, }, { @@ -205,8 +206,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new([]byte)), }, - Fn: func(struct{ Test1 *[]byte }) {}, - FnCtx: func(api.Ctx, struct{ Test1 *[]byte }) {}, + Fn: func(context.Context, struct{ Test1 *[]byte }) {}, + FnCtx: func(context.Context, struct{ Test1 *[]byte }) {}, Err: nil, }, { @@ -214,8 +215,8 @@ func TestInputCheck(t *testing.T) { Input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new([]rune)), }, - Fn: func(struct{ Test1 *[]rune }) {}, - FnCtx: func(api.Ctx, struct{ Test1 *[]rune }) {}, + Fn: func(context.Context, struct{ Test1 *[]rune }) {}, + FnCtx: func(context.Context, struct{ Test1 *[]rune }) {}, Err: nil, }, } @@ -225,47 +226,27 @@ func TestInputCheck(t *testing.T) { t.Parallel() // mock spec - s := signature{ + s := Signature{ Input: tcase.Input, Output: nil, } - t.Run("with-context", func(t *testing.T) { - err := s.checkInput(reflect.TypeOf(tcase.FnCtx), 1) - if err == nil && tcase.Err != nil { - t.Errorf("expected an error: '%s'", tcase.Err.Error()) - t.FailNow() - } - if err != nil && tcase.Err == nil { - t.Errorf("unexpected error: '%s'", err.Error()) - t.FailNow() - } + err := s.ValidateInput(reflect.TypeOf(tcase.FnCtx)) + if err == nil && tcase.Err != nil { + t.Errorf("expected an error: '%s'", tcase.Err.Error()) + t.FailNow() + } + if err != nil && tcase.Err == nil { + t.Errorf("unexpected error: '%s'", err.Error()) + t.FailNow() + } - if err != nil && tcase.Err != nil { - if !errors.Is(err, tcase.Err) { - t.Errorf("expected the error <%s> got <%s>", tcase.Err, err) - t.FailNow() - } - } - }) - t.Run("without-context", func(t *testing.T) { - err := s.checkInput(reflect.TypeOf(tcase.Fn), 0) - if err == nil && tcase.Err != nil { - t.Errorf("expected an error: '%s'", tcase.Err.Error()) + if err != nil && tcase.Err != nil { + if !errors.Is(err, tcase.Err) { + t.Errorf("expected the error <%s> got <%s>", tcase.Err, err) t.FailNow() } - if err != nil && tcase.Err == nil { - t.Errorf("unexpected error: '%s'", err.Error()) - t.FailNow() - } - - if err != nil && tcase.Err != nil { - if !errors.Is(err, tcase.Err) { - t.Errorf("expected the error <%s> got <%s>", tcase.Err, err) - t.FailNow() - } - } - }) + } }) } } @@ -279,25 +260,37 @@ func TestOutputCheck(t *testing.T) { // no input -> missing api.Err { Output: map[string]reflect.Type{}, - Fn: func() {}, - Err: errMissingHandlerOutput, + Fn: func(context.Context) {}, + Err: errMissingHandlerOutputArgument, }, // no input -> with last type not api.Err { Output: map[string]reflect.Type{}, - Fn: func() bool { return true }, - Err: errMissingHandlerErrorOutput, + Fn: func(context.Context) bool { return true }, + Err: errMissingHandlerErrorArgument, }, // no input -> with api.Err { Output: map[string]reflect.Type{}, - Fn: func() api.Err { return api.ErrSuccess }, + Fn: func(context.Context) api.Err { return api.ErrSuccess }, Err: nil, }, + // no input -> missing context.Context + { + Output: map[string]reflect.Type{}, + Fn: func(context.Context) api.Err { return api.ErrSuccess }, + Err: errMissingHandlerContextArgument, + }, + // no input -> invlaid context.Context type + { + Output: map[string]reflect.Type{}, + Fn: func(context.Context, int) api.Err { return api.ErrSuccess }, + Err: errMissingHandlerContextArgument, + }, // func can have output if not specified { Output: map[string]reflect.Type{}, - Fn: func() (*struct{}, api.Err) { return nil, api.ErrSuccess }, + Fn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, Err: nil, }, // missing output struct in func @@ -306,7 +299,7 @@ func TestOutputCheck(t *testing.T) { "Test1": reflect.TypeOf(int(0)), }, Fn: func() api.Err { return api.ErrSuccess }, - Err: errMissingParamOutput, + Err: errWrongOutputArgumentType, }, // output not a pointer { @@ -314,7 +307,7 @@ func TestOutputCheck(t *testing.T) { "Test1": reflect.TypeOf(int(0)), }, Fn: func() (int, api.Err) { return 0, api.ErrSuccess }, - Err: errMissingParamOutput, + Err: errWrongOutputArgumentType, }, // output not a pointer to struct { @@ -322,7 +315,7 @@ func TestOutputCheck(t *testing.T) { "Test1": reflect.TypeOf(int(0)), }, Fn: func() (*int, api.Err) { return nil, api.ErrSuccess }, - Err: errMissingParamOutput, + Err: errWrongOutputArgumentType, }, // unexported param name { @@ -338,7 +331,7 @@ func TestOutputCheck(t *testing.T) { "Test1": reflect.TypeOf(int(0)), }, Fn: func() (*struct{}, api.Err) { return nil, api.ErrSuccess }, - Err: errMissingParamFromConfig, + Err: errMissingConfigArgument, }, // output field invalid type { @@ -371,12 +364,12 @@ func TestOutputCheck(t *testing.T) { t.Parallel() // mock spec - s := signature{ + s := Signature{ Input: nil, Output: tcase.Output, } - err := s.checkOutput(reflect.TypeOf(tcase.Fn)) + err := s.ValidateOutput(reflect.TypeOf(tcase.Fn)) if err == nil && tcase.Err != nil { t.Errorf("expected an error: '%s'", tcase.Err.Error()) t.FailNow()