diff --git a/api/auth.go b/api/auth.go index cc162b7..343437b 100644 --- a/api/auth.go +++ b/api/auth.go @@ -1,62 +1,62 @@ -package api - -// Auth can be used by http middleware to -// 1) consult required roles in @Auth.Required -// 2) update active roles in @Auth.Active -type Auth struct { - // required roles for this request - // - the first dimension of the array reads as a OR - // - the second dimension reads as a AND - // - // Example: - // [ [A, B], [C, D] ] reads: roles (A and B) or (C and D) are required - // - // Warning: must not be mutated - Required [][]string - - // active roles to be updated by authentication - // procedures (e.g. jwt) - Active []string -} - -// Granted returns whether the authorization is granted -// i.e. Auth.Active fulfills Auth.Required -func (a *Auth) Granted() bool { - var nothingRequired = true - - // first dimension: OR ; at least one is valid - for _, required := range a.Required { - // empty list - if len(required) < 1 { - continue - } - - nothingRequired = false - - // second dimension: AND ; all required must be fulfilled - if a.fulfills(required) { - return true - } - } - - return nothingRequired -} - -// returns whether Auth.Active fulfills (contains) all @required roles -func (a *Auth) fulfills(required []string) bool { - for _, requiredRole := range required { - var found = false - for _, activeRole := range a.Active { - if activeRole == requiredRole { - found = true - break - } - } - // missing role -> fail - if !found { - return false - } - } - // all @required are fulfilled - return true -} +package api + +// Auth can be used by http middleware to +// 1) consult required roles in @Auth.Required +// 2) update active roles in @Auth.Active +type Auth struct { + // required roles for this request + // - the first dimension of the array reads as a OR + // - the second dimension reads as a AND + // + // Example: + // [ [A, B], [C, D] ] reads: roles (A and B) or (C and D) are required + // + // Warning: must not be mutated + Required [][]string + + // active roles to be updated by authentication + // procedures (e.g. jwt) + Active []string +} + +// Granted returns whether the authorization is granted +// i.e. Auth.Active fulfills Auth.Required +func (a *Auth) Granted() bool { + var nothingRequired = true + + // first dimension: OR ; at least one is valid + for _, required := range a.Required { + // empty list + if len(required) < 1 { + continue + } + + nothingRequired = false + + // second dimension: AND ; all required must be fulfilled + if a.fulfills(required) { + return true + } + } + + return nothingRequired +} + +// returns whether Auth.Active fulfills (contains) all @required roles +func (a *Auth) fulfills(required []string) bool { + for _, requiredRole := range required { + var found = false + for _, activeRole := range a.Active { + if activeRole == requiredRole { + found = true + break + } + } + // missing role -> fail + if !found { + return false + } + } + // all @required are fulfilled + return true +} diff --git a/api/auth_test.go b/api/auth_test.go index a230b59..8d63b42 100644 --- a/api/auth_test.go +++ b/api/auth_test.go @@ -1,108 +1,114 @@ -package api - -import ( - "testing" -) - -func TestCombination(t *testing.T) { - tcases := []struct { - Name string - Required [][]string - Active []string - Granted bool - }{ - { - Name: "no requirement none given", - Required: [][]string{}, - Active: []string{}, - Granted: true, - }, - { - Name: "no requirement 1 given", - Required: [][]string{}, - Active: []string{"a"}, - Granted: true, - }, - { - Name: "no requirement some given", - Required: [][]string{}, - Active: []string{"a", "b"}, - Granted: true, - }, - - { - Name: "1 required none given", - Required: [][]string{{"a"}}, - Active: []string{}, - Granted: false, - }, - { - Name: "1 required fulfilled", - Required: [][]string{{"a"}}, - Active: []string{"a"}, - Granted: true, - }, - { - Name: "1 required mismatch", - Required: [][]string{{"a"}}, - Active: []string{"b"}, - Granted: false, - }, - { - Name: "2 required none gien", - Required: [][]string{{"a", "b"}}, - Active: []string{}, - Granted: false, - }, - { - Name: "2 required other given", - Required: [][]string{{"a", "b"}}, - Active: []string{"c"}, - Granted: false, - }, - { - Name: "2 required one given", - Required: [][]string{{"a", "b"}}, - Active: []string{"a"}, - Granted: false, - }, - { - Name: "2 required fulfilled", - Required: [][]string{{"a", "b"}}, - Active: []string{"a", "b"}, - Granted: true, - }, - { - Name: "2 or 2 required first fulfilled", - Required: [][]string{{"a", "b"}, {"c", "d"}}, - Active: []string{"a", "b"}, - Granted: true, - }, - { - Name: "2 or 2 required second fulfilled", - Required: [][]string{{"a", "b"}, {"c", "d"}}, - Active: []string{"c", "d"}, - Granted: true, - }, - } - - for _, tcase := range tcases { - t.Run(tcase.Name, func(t *testing.T) { - - auth := Auth{ - Required: tcase.Required, - Active: tcase.Active, - } - - // all right - if tcase.Granted == auth.Granted() { - return - } - - if tcase.Granted && !auth.Granted() { - t.Fatalf("expected granted authorization") - } - t.Fatalf("unexpected granted authorization") - }) - } -} +package api + +import ( + "testing" +) + +func TestCombination(t *testing.T) { + tcases := []struct { + Name string + Required [][]string + Active []string + Granted bool + }{ + { + Name: "no requirement none given", + Required: [][]string{}, + Active: []string{}, + Granted: true, + }, + { + Name: "empty requirements none given", + Required: [][]string{{}}, + Active: []string{}, + Granted: true, + }, + { + Name: "no requirement 1 given", + Required: [][]string{}, + Active: []string{"a"}, + Granted: true, + }, + { + Name: "no requirement some given", + Required: [][]string{}, + Active: []string{"a", "b"}, + Granted: true, + }, + + { + Name: "1 required none given", + Required: [][]string{{"a"}}, + Active: []string{}, + Granted: false, + }, + { + Name: "1 required fulfilled", + Required: [][]string{{"a"}}, + Active: []string{"a"}, + Granted: true, + }, + { + Name: "1 required mismatch", + Required: [][]string{{"a"}}, + Active: []string{"b"}, + Granted: false, + }, + { + Name: "2 required none gien", + Required: [][]string{{"a", "b"}}, + Active: []string{}, + Granted: false, + }, + { + Name: "2 required other given", + Required: [][]string{{"a", "b"}}, + Active: []string{"c"}, + Granted: false, + }, + { + Name: "2 required one given", + Required: [][]string{{"a", "b"}}, + Active: []string{"a"}, + Granted: false, + }, + { + Name: "2 required fulfilled", + Required: [][]string{{"a", "b"}}, + Active: []string{"a", "b"}, + Granted: true, + }, + { + Name: "2 or 2 required first fulfilled", + Required: [][]string{{"a", "b"}, {"c", "d"}}, + Active: []string{"a", "b"}, + Granted: true, + }, + { + Name: "2 or 2 required second fulfilled", + Required: [][]string{{"a", "b"}, {"c", "d"}}, + Active: []string{"c", "d"}, + Granted: true, + }, + } + + for _, tcase := range tcases { + t.Run(tcase.Name, func(t *testing.T) { + + auth := Auth{ + Required: tcase.Required, + Active: tcase.Active, + } + + // all right + if tcase.Granted == auth.Granted() { + return + } + + if tcase.Granted && !auth.Granted() { + t.Fatalf("expected granted authorization") + } + t.Fatalf("unexpected granted authorization") + }) + } +} diff --git a/api/context_test.go b/api/context_test.go new file mode 100644 index 0000000..40568ce --- /dev/null +++ b/api/context_test.go @@ -0,0 +1,79 @@ +package api_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/xdrm-io/aicra/api" + "github.com/xdrm-io/aicra/internal/ctx" +) + +func TestContextGetRequest(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/random", nil) + if err != nil { + t.Fatalf("cannot create http request: %s", err) + } + + // store in bare context + c := context.Background() + c = context.WithValue(c, ctx.Request, req) + + // fetch from context + fetched := api.GetRequest(c) + if fetched != req { + t.Fatalf("fetched http request %v ; expected %v", fetched, req) + } +} +func TestContextGetNilRequest(t *testing.T) { + // fetch from bare context + fetched := api.GetRequest(context.Background()) + if fetched != nil { + t.Fatalf("fetched http request %v from empty context; expected nil", fetched) + } +} + +func TestContextGetResponseWriter(t *testing.T) { + res := httptest.NewRecorder() + + // store in bare context + c := context.Background() + c = context.WithValue(c, ctx.Response, res) + + // fetch from context + fetched := api.GetResponseWriter(c) + if fetched != res { + t.Fatalf("fetched http response writer %v ; expected %v", fetched, res) + } +} + +func TestContextGetNilResponseWriter(t *testing.T) { + // fetch from bare context + fetched := api.GetResponseWriter(context.Background()) + if fetched != nil { + t.Fatalf("fetched http response writer %v from empty context; expected nil", fetched) + } +} + +func TestContextGetAuth(t *testing.T) { + auth := &api.Auth{} + + // store in bare context + c := context.Background() + c = context.WithValue(c, ctx.Auth, auth) + + // fetch from context + fetched := api.GetAuth(c) + if fetched != auth { + t.Fatalf("fetched api auth %v ; expected %v", fetched, auth) + } +} + +func TestContextGetNilAuth(t *testing.T) { + // fetch from bare context + fetched := api.GetAuth(context.Background()) + if fetched != nil { + t.Fatalf("fetched api auth %v from empty context; expected nil", fetched) + } +} diff --git a/api/response.go b/api/response.go deleted file mode 100644 index ae403f8..0000000 --- a/api/response.go +++ /dev/null @@ -1,63 +0,0 @@ -package api - -import ( - "encoding/json" - "net/http" -) - -// ResponseData defines format for response parameters to return -type ResponseData map[string]interface{} - -// Response represents an API response to be sent -type Response struct { - Data ResponseData - Status int - Headers http.Header - err Err -} - -// EmptyResponse creates an empty response. -func EmptyResponse() *Response { - return &Response{ - Status: http.StatusOK, - Data: make(ResponseData), - err: ErrFailure, - Headers: make(http.Header), - } -} - -// WithError sets the error -func (res *Response) WithError(err Err) *Response { - res.err = err - return res -} - -func (res *Response) Error() string { - return res.err.Error() -} - -// SetData adds/overrides a new response field -func (res *Response) SetData(name string, value interface{}) { - res.Data[name] = value -} - -// MarshalJSON implements the 'json.Marshaler' interface and is used -// to generate the JSON representation of the response -func (res *Response) MarshalJSON() ([]byte, error) { - fmt := make(map[string]interface{}) - for k, v := range res.Data { - fmt[k] = v - } - fmt["error"] = res.err - return json.Marshal(fmt) -} - -func (res *Response) ServeHTTP(w http.ResponseWriter, r *http.Request) error { - w.WriteHeader(res.err.Status) - encoded, err := json.Marshal(res) - if err != nil { - return err - } - w.Write(encoded) - return nil -} diff --git a/builder.go b/builder.go index fc17243..63cf5a9 100644 --- a/builder.go +++ b/builder.go @@ -20,7 +20,7 @@ type Builder struct { middlewares []func(http.Handler) http.Handler // custom middlewares only wrapping the service handler of a request // they will benefit from the request's context that contains service-specific - // information (e.g. required permisisons from the configuration) + // information (e.g. required permissions from the configuration) ctxMiddlewares []func(http.Handler) http.Handler } @@ -52,9 +52,6 @@ func (b *Builder) Validate(t validator.Type) error { // the service associated with the request has not been found at this stage. // This stage is perfect for logging or generic request management. func (b *Builder) With(mw func(http.Handler) http.Handler) { - if b.conf == nil { - b.conf = &config.Server{} - } if b.middlewares == nil { b.middlewares = make([]func(http.Handler) http.Handler, 0) } @@ -69,9 +66,6 @@ func (b *Builder) With(mw func(http.Handler) http.Handler) { // data that can be access with api.GetRequest(), api.GetResponseWriter(), // api.GetAuth(), etc methods. func (b *Builder) WithContext(mw func(http.Handler) http.Handler) { - if b.conf == nil { - b.conf = &config.Server{} - } if b.ctxMiddlewares == nil { b.ctxMiddlewares = make([]func(http.Handler) http.Handler, 0) } @@ -85,14 +79,14 @@ func (b *Builder) Setup(r io.Reader) error { b.conf = &config.Server{} } if b.conf.Services != nil { - panic(errAlreadySetup) + return errAlreadySetup } return b.conf.Parse(r) } // Bind a dynamic handler to a REST service (method and pattern) func (b *Builder) Bind(method, path string, fn interface{}) error { - if b.conf.Services == nil { + if b.conf == nil || b.conf.Services == nil { return errNotSetup } diff --git a/builder_test.go b/builder_test.go index fbbebe0..235901e 100644 --- a/builder_test.go +++ b/builder_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/xdrm-io/aicra/api" + "github.com/xdrm-io/aicra/internal/dynfunc" "github.com/xdrm-io/aicra/validator" ) @@ -34,6 +35,8 @@ func addBuiltinTypes(b *Builder) error { } func TestAddType(t *testing.T) { + t.Parallel() + builder := &Builder{} err := builder.Validate(validator.BoolType{}) if err != nil { @@ -49,7 +52,180 @@ func TestAddType(t *testing.T) { } } +func TestSetupNoType(t *testing.T) { + t.Parallel() + + builder := &Builder{} + err := builder.Setup(strings.NewReader("[]")) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } +} +func TestSetupTwice(t *testing.T) { + t.Parallel() + + builder := &Builder{} + err := builder.Setup(strings.NewReader("[]")) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + // double Setup() must fail + err = builder.Setup(strings.NewReader("[]")) + if err != errAlreadySetup { + t.Fatalf("expected error %v, got %v", errAlreadySetup, err) + } +} + +func TestBindBeforeSetup(t *testing.T) { + t.Parallel() + + builder := &Builder{} + // binding before Setup() must fail + err := builder.Bind(http.MethodGet, "/path", func() {}) + if err != errNotSetup { + t.Fatalf("expected error %v, got %v", errNotSetup, err) + } +} + +func TestBindUnknownService(t *testing.T) { + t.Parallel() + + builder := &Builder{} + err := builder.Setup(strings.NewReader("[]")) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + err = builder.Bind(http.MethodGet, "/path", func() {}) + if !errors.Is(err, errUnknownService) { + t.Fatalf("expected error %v, got %v", errUnknownService, err) + } +} +func TestBindInvalidHandler(t *testing.T) { + t.Parallel() + + builder := &Builder{} + err := builder.Setup(strings.NewReader(`[ + { + "method": "GET", + "path": "/path", + "scope": [[]], + "info": "info", + "in": {}, + "out": {} + } + ]`)) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + err = builder.Bind(http.MethodGet, "/path", func() {}) + + if err == nil { + t.Fatalf("expected an error") + } + + if !errors.Is(err, dynfunc.ErrMissingHandlerContextArgument) { + t.Fatalf("expected a dynfunc.Err got %v", err) + } +} +func TestBindGet(t *testing.T) { + t.Parallel() + + builder := &Builder{} + err := builder.Setup(strings.NewReader(`[ + { + "method": "GET", + "path": "/path", + "scope": [[]], + "info": "info", + "in": {}, + "out": {} + }, + { + "method": "POST", + "path": "/path", + "scope": [[]], + "info": "info", + "in": {}, + "out": {} + }, + { + "method": "PUT", + "path": "/path", + "scope": [[]], + "info": "info", + "in": {}, + "out": {} + }, + { + "method": "DELETE", + "path": "/path", + "scope": [[]], + "info": "info", + "in": {}, + "out": {} + } + ]`)) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + err = builder.Get("/path", func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }) + if err != nil { + t.Fatalf("unexpected error %v", err) + } + err = builder.Post("/path", func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }) + if err != nil { + t.Fatalf("unexpected error %v", err) + } + err = builder.Put("/path", func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }) + if err != nil { + t.Fatalf("unexpected error %v", err) + } + err = builder.Delete("/path", func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }) + if err != nil { + t.Fatalf("unexpected error %v", err) + } +} + +func TestUnhandledService(t *testing.T) { + t.Parallel() + + builder := &Builder{} + err := builder.Setup(strings.NewReader(`[ + { + "method": "GET", + "path": "/path", + "scope": [[]], + "info": "info", + "in": {}, + "out": {} + }, + { + "method": "POST", + "path": "/path", + "scope": [[]], + "info": "info", + "in": {}, + "out": {} + } + ]`)) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + err = builder.Get("/path", func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }) + if err != nil { + t.Fatalf("unexpected error %v", err) + } + + _, err = builder.Build() + if !errors.Is(err, errMissingHandler) { + t.Fatalf("expected a %v error, got %v", errMissingHandler, err) + } +} func TestBind(t *testing.T) { + t.Parallel() + tcases := []struct { Name string Config string diff --git a/handler.go b/handler.go index de3c151..d5553c1 100644 --- a/handler.go +++ b/handler.go @@ -2,6 +2,7 @@ package aicra import ( "context" + "errors" "fmt" "net/http" "strings" @@ -27,21 +28,25 @@ func (s Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // ServeHTTP implements http.Handler and wraps it in middlewares (adapters) func (s Handler) resolve(w http.ResponseWriter, r *http.Request) { - // 1ind a matching service from config + // match service from config var service = s.conf.Find(r) if service == nil { - handleError(api.ErrUnknownService, w, r) + newResponse().WithError(api.ErrUnknownService).ServeHTTP(w, r) return } // extract request data var input, err = extractInput(service, *r) if err != nil { - handleError(api.ErrMissingParam, w, r) + if errors.Is(err, reqdata.ErrInvalidType) { + newResponse().WithError(api.ErrInvalidParam).ServeHTTP(w, r) + } else { + newResponse().WithError(api.ErrMissingParam).ServeHTTP(w, r) + } return } - // find a matching handler + // match handler var handler *apiHandler for _, h := range s.handlers { if h.Method == service.Method && h.Path == service.Pattern { @@ -49,13 +54,13 @@ func (s Handler) resolve(w http.ResponseWriter, r *http.Request) { } } - // fail on no matching handler + // no handler found if handler == nil { - handleError(api.ErrUncallableService, w, r) + newResponse().WithError(api.ErrUncallableService).ServeHTTP(w, r) return } - // build context with builtin data + // add info into context c := r.Context() c = context.WithValue(c, ctx.Request, r) c = context.WithValue(c, ctx.Response, w) @@ -63,63 +68,55 @@ func (s Handler) resolve(w http.ResponseWriter, r *http.Request) { // create http handler var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // should not happen auth := api.GetAuth(r.Context()) if auth == nil { - handleError(api.ErrPermission, w, r) + newResponse().WithError(api.ErrPermission).ServeHTTP(w, r) return } // reject non granted requests if !auth.Granted() { - handleError(api.ErrPermission, w, r) + newResponse().WithError(api.ErrPermission).ServeHTTP(w, r) return } - // use context defined in the request + // execute the service handler s.handle(r.Context(), input, handler, service, w, r) }) - // run middlewares the handler + // run contextual middlewares for _, mw := range s.ctxMiddlewares { h = mw(h) } - // serve using the context with values + // serve using the pre-filled context h.ServeHTTP(w, r.WithContext(c)) } +// handle the service request with the associated handler func and respond using +// the handler func output func (s *Handler) handle(c context.Context, input *reqdata.T, handler *apiHandler, service *config.Service, w http.ResponseWriter, r *http.Request) { - // pass execution to the handler + // pass execution to the handler function var outData, outErr = handler.dyn.Handle(c, input.Data) - // build response from returned arguments - var res = api.EmptyResponse().WithError(outErr) + // build response from output arguments + var res = newResponse().WithError(outErr) for key, value := range outData { // find original name from 'rename' field for name, param := range service.Output { if param.Rename == key { - res.SetData(name, value) + res.WithValue(name, value) } } } - // 7. apply headers + // write response and close request w.Header().Set("Content-Type", "application/json; charset=utf-8") - for key, values := range res.Headers { - for _, value := range values { - w.Header().Add(key, value) - } - } - res.ServeHTTP(w, r) } -func handleError(err api.Err, w http.ResponseWriter, r *http.Request) { - var response = api.EmptyResponse().WithError(err) - response.ServeHTTP(w, r) -} - func extractInput(service *config.Service, req http.Request) (*reqdata.T, error) { var dataset = reqdata.New(service) diff --git a/handler_test.go b/handler_test.go index fe3327e..eeaca6b 100644 --- a/handler_test.go +++ b/handler_test.go @@ -1,623 +1,1138 @@ -package aicra_test - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/xdrm-io/aicra" - "github.com/xdrm-io/aicra/api" - "github.com/xdrm-io/aicra/validator" -) - -func addBuiltinTypes(b *aicra.Builder) error { - if err := b.Validate(validator.AnyType{}); err != nil { - return err - } - if err := b.Validate(validator.BoolType{}); err != nil { - return err - } - if err := b.Validate(validator.FloatType{}); err != nil { - return err - } - if err := b.Validate(validator.IntType{}); err != nil { - return err - } - if err := b.Validate(validator.StringType{}); err != nil { - return err - } - if err := b.Validate(validator.UintType{}); err != nil { - return err - } - return nil -} - -func TestWith(t *testing.T) { - builder := &aicra.Builder{} - if err := addBuiltinTypes(builder); err != nil { - t.Fatalf("unexpected error <%v>", err) - } - - // build @n middlewares that take data from context and increment it - n := 1024 - - type ckey int - const key ckey = 0 - - middleware := func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - newr := r - - // first time -> store 1 - value := r.Context().Value(key) - if value == nil { - newr = r.WithContext(context.WithValue(r.Context(), key, int(1))) - next.ServeHTTP(w, newr) - return - } - - // get value and increment - cast, ok := value.(int) - if !ok { - t.Fatalf("value is not an int") - } - cast++ - newr = r.WithContext(context.WithValue(r.Context(), key, cast)) - next.ServeHTTP(w, newr) - }) - } - - // add middleware @n times - for i := 0; i < n; i++ { - builder.With(middleware) - } - - config := strings.NewReader(`[ { "method": "GET", "path": "/path", "scope": [[]], "info": "info", "in": {}, "out": {} } ]`) - err := builder.Setup(config) - if err != nil { - t.Fatalf("setup: unexpected error <%v>", err) - } - - pathHandler := func(ctx context.Context) (*struct{}, api.Err) { - // write value from middlewares into response - value := ctx.Value(key) - if value == nil { - t.Fatalf("nothing found in context") - } - cast, ok := value.(int) - if !ok { - t.Fatalf("cannot cast context data to int") - } - // write to response - api.GetResponseWriter(ctx).Write([]byte(fmt.Sprintf("#%d#", cast))) - - return nil, api.ErrSuccess - } - - if err := builder.Bind(http.MethodGet, "/path", pathHandler); err != nil { - t.Fatalf("bind: unexpected error <%v>", err) - } - - handler, err := builder.Build() - if err != nil { - t.Fatalf("build: unexpected error <%v>", err) - } - - response := httptest.NewRecorder() - request := httptest.NewRequest(http.MethodGet, "/path", &bytes.Buffer{}) - - // test request - handler.ServeHTTP(response, request) - if response.Body == nil { - t.Fatalf("response has no body") - } - token := fmt.Sprintf("#%d#", n) - if !strings.Contains(response.Body.String(), token) { - t.Fatalf("expected '%s' to be in response <%s>", token, response.Body.String()) - } - -} - -func TestWithAuth(t *testing.T) { - - tt := []struct { - name string - manifest string - permissions []string - granted bool - }{ - { - name: "provide only requirement A", - manifest: `[ { "method": "GET", "path": "/path", "scope": [["A"]], "info": "info", "in": {}, "out": {} } ]`, - permissions: []string{"A"}, - granted: true, - }, - { - name: "missing requirement", - manifest: `[ { "method": "GET", "path": "/path", "scope": [["A"]], "info": "info", "in": {}, "out": {} } ]`, - permissions: []string{}, - granted: false, - }, - { - name: "missing requirements", - manifest: `[ { "method": "GET", "path": "/path", "scope": [["A", "B"]], "info": "info", "in": {}, "out": {} } ]`, - permissions: []string{}, - granted: false, - }, - { - name: "missing some requirements", - manifest: `[ { "method": "GET", "path": "/path", "scope": [["A", "B"]], "info": "info", "in": {}, "out": {} } ]`, - permissions: []string{"A"}, - granted: false, - }, - { - name: "provide requirements", - manifest: `[ { "method": "GET", "path": "/path", "scope": [["A", "B"]], "info": "info", "in": {}, "out": {} } ]`, - permissions: []string{"A", "B"}, - granted: true, - }, - { - name: "missing OR requirements", - manifest: `[ { "method": "GET", "path": "/path", "scope": [["A"], ["B"]], "info": "info", "in": {}, "out": {} } ]`, - permissions: []string{"C"}, - granted: false, - }, - { - name: "provide 1 OR requirement", - manifest: `[ { "method": "GET", "path": "/path", "scope": [["A"], ["B"]], "info": "info", "in": {}, "out": {} } ]`, - permissions: []string{"A"}, - granted: true, - }, - { - name: "provide both OR requirements", - manifest: `[ { "method": "GET", "path": "/path", "scope": [["A"], ["B"]], "info": "info", "in": {}, "out": {} } ]`, - permissions: []string{"A", "B"}, - granted: true, - }, - { - name: "missing composite OR requirements", - manifest: `[ { "method": "GET", "path": "/path", "scope": [["A", "B"], ["C", "D"]], "info": "info", "in": {}, "out": {} } ]`, - permissions: []string{}, - granted: false, - }, - { - name: "missing partial composite OR requirements", - manifest: `[ { "method": "GET", "path": "/path", "scope": [["A", "B"], ["C", "D"]], "info": "info", "in": {}, "out": {} } ]`, - permissions: []string{"A", "C"}, - granted: false, - }, - { - name: "provide 1 composite OR requirement", - manifest: `[ { "method": "GET", "path": "/path", "scope": [["A", "B"], ["C", "D"]], "info": "info", "in": {}, "out": {} } ]`, - permissions: []string{"A", "B", "C"}, - granted: true, - }, - { - name: "provide both composite OR requirements", - manifest: `[ { "method": "GET", "path": "/path", "scope": [["A", "B"], ["C", "D"]], "info": "info", "in": {}, "out": {} } ]`, - permissions: []string{"A", "B", "C", "D"}, - granted: true, - }, - } - - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - builder := &aicra.Builder{} - if err := addBuiltinTypes(builder); err != nil { - t.Fatalf("unexpected error <%v>", err) - } - - // tester middleware (last executed) - builder.WithContext(func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - a := api.GetAuth(r.Context()) - if a == nil { - t.Fatalf("cannot access api.Auth form request context") - } - - if a.Granted() == tc.granted { - return - } - if a.Granted() { - t.Fatalf("unexpected granted auth") - } else { - t.Fatalf("expected granted auth") - } - next.ServeHTTP(w, r) - }) - }) - - builder.WithContext(func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - a := api.GetAuth(r.Context()) - if a == nil { - t.Fatalf("cannot access api.Auth form request context") - } - - a.Active = tc.permissions - next.ServeHTTP(w, r) - }) - }) - - err := builder.Setup(strings.NewReader(tc.manifest)) - if err != nil { - t.Fatalf("setup: unexpected error <%v>", err) - } - - pathHandler := func(ctx context.Context) (*struct{}, api.Err) { - return nil, api.ErrNotImplemented - } - - if err := builder.Bind(http.MethodGet, "/path", pathHandler); err != nil { - t.Fatalf("bind: unexpected error <%v>", err) - } - - handler, err := builder.Build() - if err != nil { - t.Fatalf("build: unexpected error <%v>", err) - } - - response := httptest.NewRecorder() - request := httptest.NewRequest(http.MethodGet, "/path", &bytes.Buffer{}) - - // test request - handler.ServeHTTP(response, request) - if response.Body == nil { - t.Fatalf("response has no body") - } - - }) - } - -} - -func TestPermissionError(t *testing.T) { - - tt := []struct { - name string - manifest string - permissions []string - granted bool - }{ - { - name: "permission fulfilled", - manifest: `[ { "method": "GET", "path": "/path", "scope": [["A"]], "info": "info", "in": {}, "out": {} } ]`, - permissions: []string{"A"}, - granted: true, - }, - { - name: "missing permission", - manifest: `[ { "method": "GET", "path": "/path", "scope": [["A"]], "info": "info", "in": {}, "out": {} } ]`, - permissions: []string{}, - granted: false, - }, - } - - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - builder := &aicra.Builder{} - if err := addBuiltinTypes(builder); err != nil { - t.Fatalf("unexpected error <%v>", err) - } - - // add active permissions - builder.WithContext(func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - a := api.GetAuth(r.Context()) - if a == nil { - t.Fatalf("cannot access api.Auth form request context") - } - - a.Active = tc.permissions - next.ServeHTTP(w, r) - }) - }) - - err := builder.Setup(strings.NewReader(tc.manifest)) - if err != nil { - t.Fatalf("setup: unexpected error <%v>", err) - } - - pathHandler := func(ctx context.Context) (*struct{}, api.Err) { - return nil, api.ErrNotImplemented - } - - if err := builder.Bind(http.MethodGet, "/path", pathHandler); err != nil { - t.Fatalf("bind: unexpected error <%v>", err) - } - - handler, err := builder.Build() - if err != nil { - t.Fatalf("build: unexpected error <%v>", err) - } - - response := httptest.NewRecorder() - request := httptest.NewRequest(http.MethodGet, "/path", &bytes.Buffer{}) - - // test request - handler.ServeHTTP(response, request) - if response.Body == nil { - t.Fatalf("response has no body") - } - type jsonResponse struct { - Err api.Err `json:"error"` - } - var res jsonResponse - err = json.Unmarshal(response.Body.Bytes(), &res) - if err != nil { - t.Fatalf("cannot unmarshal response: %s", err) - } - - expectedError := api.ErrNotImplemented - if !tc.granted { - expectedError = api.ErrPermission - } - - if res.Err.Code != expectedError.Code { - t.Fatalf("expected error code %d got %d", expectedError.Code, res.Err.Code) - } - - }) - } - -} - -func TestDynamicScope(t *testing.T) { - tt := []struct { - name string - manifest string - path string - handler interface{} - url string - body string - permissions []string - granted bool - }{ - { - name: "replace one granted", - manifest: `[ - { - "method": "POST", - "path": "/path/{id}", - "info": "info", - "scope": [["user[Input1]"]], - "in": { - "{id}": { "info": "info", "name": "Input1", "type": "uint" } - }, - "out": {} - } - ]`, - path: "/path/{id}", - handler: func(context.Context, struct{ Input1 uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, - url: "/path/123", - body: ``, - permissions: []string{"user[123]"}, - granted: true, - }, - { - name: "replace one mismatch", - manifest: `[ - { - "method": "POST", - "path": "/path/{id}", - "info": "info", - "scope": [["user[Input1]"]], - "in": { - "{id}": { "info": "info", "name": "Input1", "type": "uint" } - }, - "out": {} - } - ]`, - path: "/path/{id}", - handler: func(context.Context, struct{ Input1 uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, - url: "/path/666", - body: ``, - permissions: []string{"user[123]"}, - granted: false, - }, - { - name: "replace one valid dot separated", - manifest: `[ - { - "method": "POST", - "path": "/path/{id}", - "info": "info", - "scope": [["prefix.user[User].suffix"]], - "in": { - "{id}": { "info": "info", "name": "User", "type": "uint" } - }, - "out": {} - } - ]`, - path: "/path/{id}", - handler: func(context.Context, struct{ User uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, - url: "/path/123", - body: ``, - permissions: []string{"prefix.user[123].suffix"}, - granted: true, - }, - { - name: "replace two valid dot separated", - manifest: `[ - { - "method": "POST", - "path": "/prefix/{pid}/user/{uid}", - "info": "info", - "scope": [["prefix[Prefix].user[User].suffix"]], - "in": { - "{pid}": { "info": "info", "name": "Prefix", "type": "uint" }, - "{uid}": { "info": "info", "name": "User", "type": "uint" } - }, - "out": {} - } - ]`, - path: "/prefix/{pid}/user/{uid}", - handler: func(context.Context, struct { - Prefix uint - User uint - }) (*struct{}, api.Err) { - return nil, api.ErrSuccess - }, - url: "/prefix/123/user/456", - body: ``, - permissions: []string{"prefix[123].user[456].suffix"}, - granted: true, - }, - { - name: "replace two invalid dot separated", - manifest: `[ - { - "method": "POST", - "path": "/prefix/{pid}/user/{uid}", - "info": "info", - "scope": [["prefix[Prefix].user[User].suffix"]], - "in": { - "{pid}": { "info": "info", "name": "Prefix", "type": "uint" }, - "{uid}": { "info": "info", "name": "User", "type": "uint" } - }, - "out": {} - } - ]`, - path: "/prefix/{pid}/user/{uid}", - handler: func(context.Context, struct { - Prefix uint - User uint - }) (*struct{}, api.Err) { - return nil, api.ErrSuccess - }, - url: "/prefix/123/user/666", - body: ``, - permissions: []string{"prefix[123].user[456].suffix"}, - granted: false, - }, - { - name: "replace three valid dot separated", - manifest: `[ - { - "method": "POST", - "path": "/prefix/{pid}/user/{uid}/suffix/{sid}", - "info": "info", - "scope": [["prefix[Prefix].user[User].suffix[Suffix]"]], - "in": { - "{pid}": { "info": "info", "name": "Prefix", "type": "uint" }, - "{uid}": { "info": "info", "name": "User", "type": "uint" }, - "{sid}": { "info": "info", "name": "Suffix", "type": "uint" } - }, - "out": {} - } - ]`, - path: "/prefix/{pid}/user/{uid}/suffix/{sid}", - handler: func(context.Context, struct { - Prefix uint - User uint - Suffix uint - }) (*struct{}, api.Err) { - return nil, api.ErrSuccess - }, - url: "/prefix/123/user/456/suffix/789", - body: ``, - permissions: []string{"prefix[123].user[456].suffix[789]"}, - granted: true, - }, - { - name: "replace three invalid dot separated", - manifest: `[ - { - "method": "POST", - "path": "/prefix/{pid}/user/{uid}/suffix/{sid}", - "info": "info", - "scope": [["prefix[Prefix].user[User].suffix[Suffix]"]], - "in": { - "{pid}": { "info": "info", "name": "Prefix", "type": "uint" }, - "{uid}": { "info": "info", "name": "User", "type": "uint" }, - "{sid}": { "info": "info", "name": "Suffix", "type": "uint" } - }, - "out": {} - } - ]`, - path: "/prefix/{pid}/user/{uid}/suffix/{sid}", - handler: func(context.Context, struct { - Prefix uint - User uint - Suffix uint - }) (*struct{}, api.Err) { - return nil, api.ErrSuccess - }, - url: "/prefix/123/user/666/suffix/789", - body: ``, - permissions: []string{"prefix[123].user[456].suffix[789]"}, - granted: false, - }, - } - - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - builder := &aicra.Builder{} - if err := addBuiltinTypes(builder); err != nil { - t.Fatalf("unexpected error <%v>", err) - } - - // tester middleware (last executed) - builder.WithContext(func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - a := api.GetAuth(r.Context()) - if a == nil { - t.Fatalf("cannot access api.Auth form request context") - } - if a.Granted() == tc.granted { - return - } - if a.Granted() { - t.Fatalf("unexpected granted auth") - } else { - t.Fatalf("expected granted auth") - } - next.ServeHTTP(w, r) - }) - }) - - // update permissions - builder.WithContext(func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - a := api.GetAuth(r.Context()) - if a == nil { - t.Fatalf("cannot access api.Auth form request context") - } - a.Active = tc.permissions - next.ServeHTTP(w, r) - }) - }) - - err := builder.Setup(strings.NewReader(tc.manifest)) - if err != nil { - t.Fatalf("setup: unexpected error <%v>", err) - } - - if err := builder.Bind(http.MethodPost, tc.path, tc.handler); err != nil { - t.Fatalf("bind: unexpected error <%v>", err) - } - - handler, err := builder.Build() - if err != nil { - t.Fatalf("build: unexpected error <%v>", err) - } - - response := httptest.NewRecorder() - body := strings.NewReader(tc.body) - request := httptest.NewRequest(http.MethodPost, tc.url, body) - - // test request - handler.ServeHTTP(response, request) - if response.Body == nil { - t.Fatalf("response has no body") - } - - }) - } - -} +package aicra_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/xdrm-io/aicra" + "github.com/xdrm-io/aicra/api" + "github.com/xdrm-io/aicra/validator" +) + +func printEscaped(raw string) string { + raw = strings.ReplaceAll(raw, "\n", "\\n") + raw = strings.ReplaceAll(raw, "\r", "\\r") + return raw +} + +func addDefaultTypes(b *aicra.Builder) error { + if err := b.Validate(validator.AnyType{}); err != nil { + return err + } + if err := b.Validate(validator.BoolType{}); err != nil { + return err + } + if err := b.Validate(validator.FloatType{}); err != nil { + return err + } + if err := b.Validate(validator.IntType{}); err != nil { + return err + } + if err := b.Validate(validator.StringType{}); err != nil { + return err + } + if err := b.Validate(validator.UintType{}); err != nil { + return err + } + return nil +} + +func TestHandler_With(t *testing.T) { + builder := &aicra.Builder{} + if err := addDefaultTypes(builder); err != nil { + t.Fatalf("unexpected error <%v>", err) + } + + // build @n middlewares that take data from context and increment it + n := 1024 + + type ckey int + const key ckey = 0 + + middleware := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // first time -> store 1 + value := r.Context().Value(key) + if value == nil { + r = r.WithContext(context.WithValue(r.Context(), key, int(1))) + next.ServeHTTP(w, r) + return + } + + // get value and increment + cast, ok := value.(int) + if !ok { + t.Fatalf("value is not an int") + } + cast++ + r = r.WithContext(context.WithValue(r.Context(), key, cast)) + next.ServeHTTP(w, r) + }) + } + + // add middleware @n times + for i := 0; i < n; i++ { + builder.With(middleware) + } + + config := strings.NewReader(`[ { "method": "GET", "path": "/path", "scope": [[]], "info": "info", "in": {}, "out": {} } ]`) + err := builder.Setup(config) + if err != nil { + t.Fatalf("setup: unexpected error <%v>", err) + } + + pathHandler := func(ctx context.Context) (*struct{}, api.Err) { + // write value from middlewares into response + value := ctx.Value(key) + if value == nil { + t.Fatalf("nothing found in context") + } + cast, ok := value.(int) + if !ok { + t.Fatalf("cannot cast context data to int") + } + // write to response + api.GetResponseWriter(ctx).Write([]byte(fmt.Sprintf("#%d#", cast))) + + return nil, api.ErrSuccess + } + + if err := builder.Bind(http.MethodGet, "/path", pathHandler); err != nil { + t.Fatalf("bind: unexpected error <%v>", err) + } + + handler, err := builder.Build() + if err != nil { + t.Fatalf("build: unexpected error <%v>", err) + } + + response := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/path", &bytes.Buffer{}) + + // test request + handler.ServeHTTP(response, request) + if response.Body == nil { + t.Fatalf("response has no body") + } + token := fmt.Sprintf("#%d#", n) + if !strings.Contains(response.Body.String(), token) { + t.Fatalf("expected '%s' to be in response <%s>", token, response.Body.String()) + } + +} + +func TestHandler_WithAuth(t *testing.T) { + + tt := []struct { + name string + manifest string + permissions []string + granted bool + }{ + { + name: "provide only requirement A", + manifest: `[ { "method": "GET", "path": "/path", "scope": [["A"]], "info": "info", "in": {}, "out": {} } ]`, + permissions: []string{"A"}, + granted: true, + }, + { + name: "missing requirement", + manifest: `[ { "method": "GET", "path": "/path", "scope": [["A"]], "info": "info", "in": {}, "out": {} } ]`, + permissions: []string{}, + granted: false, + }, + { + name: "missing requirements", + manifest: `[ { "method": "GET", "path": "/path", "scope": [["A", "B"]], "info": "info", "in": {}, "out": {} } ]`, + permissions: []string{}, + granted: false, + }, + { + name: "missing some requirements", + manifest: `[ { "method": "GET", "path": "/path", "scope": [["A", "B"]], "info": "info", "in": {}, "out": {} } ]`, + permissions: []string{"A"}, + granted: false, + }, + { + name: "provide requirements", + manifest: `[ { "method": "GET", "path": "/path", "scope": [["A", "B"]], "info": "info", "in": {}, "out": {} } ]`, + permissions: []string{"A", "B"}, + granted: true, + }, + { + name: "missing OR requirements", + manifest: `[ { "method": "GET", "path": "/path", "scope": [["A"], ["B"]], "info": "info", "in": {}, "out": {} } ]`, + permissions: []string{"C"}, + granted: false, + }, + { + name: "provide 1 OR requirement", + manifest: `[ { "method": "GET", "path": "/path", "scope": [["A"], ["B"]], "info": "info", "in": {}, "out": {} } ]`, + permissions: []string{"A"}, + granted: true, + }, + { + name: "provide both OR requirements", + manifest: `[ { "method": "GET", "path": "/path", "scope": [["A"], ["B"]], "info": "info", "in": {}, "out": {} } ]`, + permissions: []string{"A", "B"}, + granted: true, + }, + { + name: "missing composite OR requirements", + manifest: `[ { "method": "GET", "path": "/path", "scope": [["A", "B"], ["C", "D"]], "info": "info", "in": {}, "out": {} } ]`, + permissions: []string{}, + granted: false, + }, + { + name: "missing partial composite OR requirements", + manifest: `[ { "method": "GET", "path": "/path", "scope": [["A", "B"], ["C", "D"]], "info": "info", "in": {}, "out": {} } ]`, + permissions: []string{"A", "C"}, + granted: false, + }, + { + name: "provide 1 composite OR requirement", + manifest: `[ { "method": "GET", "path": "/path", "scope": [["A", "B"], ["C", "D"]], "info": "info", "in": {}, "out": {} } ]`, + permissions: []string{"A", "B", "C"}, + granted: true, + }, + { + name: "provide both composite OR requirements", + manifest: `[ { "method": "GET", "path": "/path", "scope": [["A", "B"], ["C", "D"]], "info": "info", "in": {}, "out": {} } ]`, + permissions: []string{"A", "B", "C", "D"}, + granted: true, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + builder := &aicra.Builder{} + if err := addDefaultTypes(builder); err != nil { + t.Fatalf("unexpected error <%v>", err) + } + + // tester middleware (last executed) + builder.WithContext(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + a := api.GetAuth(r.Context()) + if a == nil { + t.Fatalf("cannot access api.Auth form request context") + } + + if a.Granted() == tc.granted { + return + } + if a.Granted() { + t.Fatalf("unexpected granted auth") + } else { + t.Fatalf("expected granted auth") + } + next.ServeHTTP(w, r) + }) + }) + + builder.WithContext(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + a := api.GetAuth(r.Context()) + if a == nil { + t.Fatalf("cannot access api.Auth form request context") + } + + a.Active = tc.permissions + next.ServeHTTP(w, r) + }) + }) + + err := builder.Setup(strings.NewReader(tc.manifest)) + if err != nil { + t.Fatalf("setup: unexpected error <%v>", err) + } + + pathHandler := func(ctx context.Context) (*struct{}, api.Err) { + return nil, api.ErrNotImplemented + } + + if err := builder.Bind(http.MethodGet, "/path", pathHandler); err != nil { + t.Fatalf("bind: unexpected error <%v>", err) + } + + handler, err := builder.Build() + if err != nil { + t.Fatalf("build: unexpected error <%v>", err) + } + + response := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/path", &bytes.Buffer{}) + + // test request + handler.ServeHTTP(response, request) + if response.Body == nil { + t.Fatalf("response has no body") + } + + }) + } + +} + +func TestHandler_PermissionError(t *testing.T) { + + tt := []struct { + name string + manifest string + permissions []string + granted bool + }{ + { + name: "permission fulfilled", + manifest: `[ { "method": "GET", "path": "/path", "scope": [["A"]], "info": "info", "in": {}, "out": {} } ]`, + permissions: []string{"A"}, + granted: true, + }, + { + name: "missing permission", + manifest: `[ { "method": "GET", "path": "/path", "scope": [["A"]], "info": "info", "in": {}, "out": {} } ]`, + permissions: []string{}, + granted: false, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + builder := &aicra.Builder{} + if err := addDefaultTypes(builder); err != nil { + t.Fatalf("unexpected error <%v>", err) + } + + // add active permissions + builder.WithContext(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + a := api.GetAuth(r.Context()) + if a == nil { + t.Fatalf("cannot access api.Auth form request context") + } + + a.Active = tc.permissions + next.ServeHTTP(w, r) + }) + }) + + err := builder.Setup(strings.NewReader(tc.manifest)) + if err != nil { + t.Fatalf("setup: unexpected error <%v>", err) + } + + pathHandler := func(ctx context.Context) (*struct{}, api.Err) { + return nil, api.ErrNotImplemented + } + + if err := builder.Bind(http.MethodGet, "/path", pathHandler); err != nil { + t.Fatalf("bind: unexpected error <%v>", err) + } + + handler, err := builder.Build() + if err != nil { + t.Fatalf("build: unexpected error <%v>", err) + } + + response := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/path", &bytes.Buffer{}) + + // test request + handler.ServeHTTP(response, request) + if response.Body == nil { + t.Fatalf("response has no body") + } + type jsonResponse struct { + Err api.Err `json:"error"` + } + var res jsonResponse + err = json.Unmarshal(response.Body.Bytes(), &res) + if err != nil { + t.Fatalf("cannot unmarshal response: %s", err) + } + + expectedError := api.ErrNotImplemented + if !tc.granted { + expectedError = api.ErrPermission + } + + if res.Err.Code != expectedError.Code { + t.Fatalf("expected error code %d got %d", expectedError.Code, res.Err.Code) + } + + }) + } + +} + +func TestHandler_DynamicScope(t *testing.T) { + tt := []struct { + name string + manifest string + path string + handler interface{} + url string + body string + permissions []string + granted bool + }{ + { + name: "replace one granted", + manifest: `[ + { + "method": "POST", + "path": "/path/{id}", + "info": "info", + "scope": [["user[Input1]"]], + "in": { + "{id}": { "info": "info", "name": "Input1", "type": "uint" } + }, + "out": {} + } + ]`, + path: "/path/{id}", + handler: func(context.Context, struct{ Input1 uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + url: "/path/123", + body: ``, + permissions: []string{"user[123]"}, + granted: true, + }, + { + name: "replace one mismatch", + manifest: `[ + { + "method": "POST", + "path": "/path/{id}", + "info": "info", + "scope": [["user[Input1]"]], + "in": { + "{id}": { "info": "info", "name": "Input1", "type": "uint" } + }, + "out": {} + } + ]`, + path: "/path/{id}", + handler: func(context.Context, struct{ Input1 uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + url: "/path/666", + body: ``, + permissions: []string{"user[123]"}, + granted: false, + }, + { + name: "replace one valid dot separated", + manifest: `[ + { + "method": "POST", + "path": "/path/{id}", + "info": "info", + "scope": [["prefix.user[User].suffix"]], + "in": { + "{id}": { "info": "info", "name": "User", "type": "uint" } + }, + "out": {} + } + ]`, + path: "/path/{id}", + handler: func(context.Context, struct{ User uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + url: "/path/123", + body: ``, + permissions: []string{"prefix.user[123].suffix"}, + granted: true, + }, + { + name: "replace two valid dot separated", + manifest: `[ + { + "method": "POST", + "path": "/prefix/{pid}/user/{uid}", + "info": "info", + "scope": [["prefix[Prefix].user[User].suffix"]], + "in": { + "{pid}": { "info": "info", "name": "Prefix", "type": "uint" }, + "{uid}": { "info": "info", "name": "User", "type": "uint" } + }, + "out": {} + } + ]`, + path: "/prefix/{pid}/user/{uid}", + handler: func(context.Context, struct { + Prefix uint + User uint + }) (*struct{}, api.Err) { + return nil, api.ErrSuccess + }, + url: "/prefix/123/user/456", + body: ``, + permissions: []string{"prefix[123].user[456].suffix"}, + granted: true, + }, + { + name: "replace two invalid dot separated", + manifest: `[ + { + "method": "POST", + "path": "/prefix/{pid}/user/{uid}", + "info": "info", + "scope": [["prefix[Prefix].user[User].suffix"]], + "in": { + "{pid}": { "info": "info", "name": "Prefix", "type": "uint" }, + "{uid}": { "info": "info", "name": "User", "type": "uint" } + }, + "out": {} + } + ]`, + path: "/prefix/{pid}/user/{uid}", + handler: func(context.Context, struct { + Prefix uint + User uint + }) (*struct{}, api.Err) { + return nil, api.ErrSuccess + }, + url: "/prefix/123/user/666", + body: ``, + permissions: []string{"prefix[123].user[456].suffix"}, + granted: false, + }, + { + name: "replace three valid dot separated", + manifest: `[ + { + "method": "POST", + "path": "/prefix/{pid}/user/{uid}/suffix/{sid}", + "info": "info", + "scope": [["prefix[Prefix].user[User].suffix[Suffix]"]], + "in": { + "{pid}": { "info": "info", "name": "Prefix", "type": "uint" }, + "{uid}": { "info": "info", "name": "User", "type": "uint" }, + "{sid}": { "info": "info", "name": "Suffix", "type": "uint" } + }, + "out": {} + } + ]`, + path: "/prefix/{pid}/user/{uid}/suffix/{sid}", + handler: func(context.Context, struct { + Prefix uint + User uint + Suffix uint + }) (*struct{}, api.Err) { + return nil, api.ErrSuccess + }, + url: "/prefix/123/user/456/suffix/789", + body: ``, + permissions: []string{"prefix[123].user[456].suffix[789]"}, + granted: true, + }, + { + name: "replace three invalid dot separated", + manifest: `[ + { + "method": "POST", + "path": "/prefix/{pid}/user/{uid}/suffix/{sid}", + "info": "info", + "scope": [["prefix[Prefix].user[User].suffix[Suffix]"]], + "in": { + "{pid}": { "info": "info", "name": "Prefix", "type": "uint" }, + "{uid}": { "info": "info", "name": "User", "type": "uint" }, + "{sid}": { "info": "info", "name": "Suffix", "type": "uint" } + }, + "out": {} + } + ]`, + path: "/prefix/{pid}/user/{uid}/suffix/{sid}", + handler: func(context.Context, struct { + Prefix uint + User uint + Suffix uint + }) (*struct{}, api.Err) { + return nil, api.ErrSuccess + }, + url: "/prefix/123/user/666/suffix/789", + body: ``, + permissions: []string{"prefix[123].user[456].suffix[789]"}, + granted: false, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + builder := &aicra.Builder{} + if err := addDefaultTypes(builder); err != nil { + t.Fatalf("unexpected error <%v>", err) + } + + // tester middleware (last executed) + builder.WithContext(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + a := api.GetAuth(r.Context()) + if a == nil { + t.Fatalf("cannot access api.Auth form request context") + } + if a.Granted() == tc.granted { + return + } + if a.Granted() { + t.Fatalf("unexpected granted auth") + } else { + t.Fatalf("expected granted auth") + } + next.ServeHTTP(w, r) + }) + }) + + // update permissions + builder.WithContext(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + a := api.GetAuth(r.Context()) + if a == nil { + t.Fatalf("cannot access api.Auth form request context") + } + a.Active = tc.permissions + next.ServeHTTP(w, r) + }) + }) + + err := builder.Setup(strings.NewReader(tc.manifest)) + if err != nil { + t.Fatalf("setup: unexpected error <%v>", err) + } + + if err := builder.Bind(http.MethodPost, tc.path, tc.handler); err != nil { + t.Fatalf("bind: unexpected error <%v>", err) + } + + handler, err := builder.Build() + if err != nil { + t.Fatalf("build: unexpected error <%v>", err) + } + + response := httptest.NewRecorder() + body := strings.NewReader(tc.body) + request := httptest.NewRequest(http.MethodPost, tc.url, body) + + // test request + handler.ServeHTTP(response, request) + if response.Body == nil { + t.Fatalf("response has no body") + } + + }) + } + +} + +func TestHandler_ServiceErrors(t *testing.T) { + tt := []struct { + name string + manifest string + // handler + hmethod, huri string + hfn interface{} + // request + method, url string + contentType string + body string + permissions []string + err api.Err + }{ + // service match + { + name: "unknown service method", + manifest: `[ + { + "method": "GET", + "path": "/", + "info": "info", + "scope": [], + "in": {}, + "out": {} + } + ]`, + hmethod: http.MethodGet, + huri: "/", + hfn: func(context.Context) api.Err { + return api.ErrSuccess + }, + method: http.MethodPost, + url: "/", + body: ``, + permissions: []string{}, + err: api.ErrUnknownService, + }, + { + name: "unknown service path", + manifest: `[ + { + "method": "GET", + "path": "/", + "info": "info", + "scope": [], + "in": {}, + "out": {} + } + ]`, + hmethod: http.MethodGet, + huri: "/", + hfn: func(context.Context) api.Err { + return api.ErrSuccess + }, + method: http.MethodGet, + url: "/invalid", + body: ``, + permissions: []string{}, + err: api.ErrUnknownService, + }, + { + name: "valid empty service", + manifest: `[ + { + "method": "GET", + "path": "/", + "info": "info", + "scope": [], + "in": {}, + "out": {} + } + ]`, + hmethod: http.MethodGet, + huri: "/", + hfn: func(context.Context) api.Err { + return api.ErrSuccess + }, + method: http.MethodGet, + url: "/", + body: ``, + permissions: []string{}, + err: api.ErrSuccess, + }, + + // invalid uri param -> unknown service + { + name: "invalid uri param", + manifest: `[ + { + "method": "GET", + "path": "/a/{id}/b", + "info": "info", + "scope": [], + "in": { + "{id}": { "info": "info", "type": "int", "name": "ID" } + }, + "out": {} + } + ]`, + hmethod: http.MethodGet, + huri: "/a/{id}/b", + hfn: func(context.Context, struct{ ID int }) api.Err { + return api.ErrSuccess + }, + method: http.MethodGet, + url: "/a/invalid/b", + body: ``, + permissions: []string{}, + err: api.ErrUnknownService, + }, + + // query param + { + name: "missing query param", + manifest: `[ + { + "method": "GET", + "path": "/", + "info": "info", + "scope": [], + "in": { + "GET@id": { "info": "info", "type": "int", "name": "ID" } + }, + "out": {} + } + ]`, + hmethod: http.MethodGet, + huri: "/", + hfn: func(context.Context, struct{ ID int }) api.Err { + return api.ErrSuccess + }, + method: http.MethodGet, + url: "/", + body: ``, + permissions: []string{}, + err: api.ErrMissingParam, + }, + { + name: "invalid query param", + manifest: `[ + { + "method": "GET", + "path": "/a", + "info": "info", + "scope": [], + "in": { + "GET@id": { "info": "info", "type": "int", "name": "ID" } + }, + "out": {} + } + ]`, + hmethod: http.MethodGet, + huri: "/a", + hfn: func(context.Context, struct{ ID int }) api.Err { + return api.ErrSuccess + }, + method: http.MethodGet, + url: "/a?id=abc", + body: ``, + permissions: []string{}, + err: api.ErrInvalidParam, + }, + { + name: "invalid query multi param", + manifest: `[ + { + "method": "GET", + "path": "/a", + "info": "info", + "scope": [], + "in": { + "GET@id": { "info": "info", "type": "int", "name": "ID" } + }, + "out": {} + } + ]`, + hmethod: http.MethodGet, + huri: "/a", + hfn: func(context.Context, struct{ ID int }) api.Err { + return api.ErrSuccess + }, + method: http.MethodGet, + url: "/a?id=123&id=456", + body: ``, + permissions: []string{}, + err: api.ErrInvalidParam, + }, + { + name: "valid query param", + manifest: `[ + { + "method": "GET", + "path": "/a", + "info": "info", + "scope": [], + "in": { + "GET@id": { "info": "info", "type": "int", "name": "ID" } + }, + "out": {} + } + ]`, + hmethod: http.MethodGet, + huri: "/a", + hfn: func(context.Context, struct{ ID int }) api.Err { + return api.ErrSuccess + }, + method: http.MethodGet, + url: "/a?id=123", + body: ``, + permissions: []string{}, + err: api.ErrSuccess, + }, + + // json param + { + name: "missing json param", + manifest: `[ + { + "method": "POST", + "path": "/", + "info": "info", + "scope": [], + "in": { + "id": { "info": "info", "type": "int", "name": "ID" } + }, + "out": {} + } + ]`, + hmethod: http.MethodPost, + huri: "/", + hfn: func(context.Context, struct{ ID int }) api.Err { + return api.ErrSuccess + }, + contentType: "application/json", + method: http.MethodPost, + url: "/", + body: ``, + permissions: []string{}, + err: api.ErrMissingParam, + }, + { + name: "invalid json param", + manifest: `[ + { + "method": "POST", + "path": "/", + "info": "info", + "scope": [], + "in": { + "id": { "info": "info", "type": "int", "name": "ID" } + }, + "out": {} + } + ]`, + hmethod: http.MethodPost, + huri: "/", + hfn: func(context.Context, struct{ ID int }) api.Err { + return api.ErrSuccess + }, + contentType: "application/json", + method: http.MethodPost, + url: "/", + body: `{ "id": "invalid type" }`, + permissions: []string{}, + err: api.ErrInvalidParam, + }, + { + name: "valid json param", + manifest: `[ + { + "method": "POST", + "path": "/", + "info": "info", + "scope": [], + "in": { + "id": { "info": "info", "type": "int", "name": "ID" } + }, + "out": {} + } + ]`, + hmethod: http.MethodPost, + huri: "/", + hfn: func(context.Context, struct{ ID int }) api.Err { + return api.ErrSuccess + }, + contentType: "application/json", + method: http.MethodPost, + url: "/", + body: `{ "id": 123 }`, + permissions: []string{}, + err: api.ErrSuccess, + }, + + // urlencoded param + { + name: "missing urlencoded param", + manifest: `[ + { + "method": "POST", + "path": "/", + "info": "info", + "scope": [], + "in": { + "id": { "info": "info", "type": "int", "name": "ID" } + }, + "out": {} + } + ]`, + hmethod: http.MethodPost, + huri: "/", + hfn: func(context.Context, struct{ ID int }) api.Err { + return api.ErrSuccess + }, + contentType: "application/x-www-form-urlencoded", + method: http.MethodPost, + url: "/", + body: ``, + permissions: []string{}, + err: api.ErrMissingParam, + }, + { + name: "invalid urlencoded param", + manifest: `[ + { + "method": "POST", + "path": "/", + "info": "info", + "scope": [], + "in": { + "id": { "info": "info", "type": "int", "name": "ID" } + }, + "out": {} + } + ]`, + hmethod: http.MethodPost, + huri: "/", + hfn: func(context.Context, struct{ ID int }) api.Err { + return api.ErrSuccess + }, + contentType: "application/x-www-form-urlencoded", + method: http.MethodPost, + url: "/", + body: `id=abc`, + permissions: []string{}, + err: api.ErrInvalidParam, + }, + { + name: "valid urlencoded param", + manifest: `[ + { + "method": "POST", + "path": "/", + "info": "info", + "scope": [], + "in": { + "id": { "info": "info", "type": "int", "name": "ID" } + }, + "out": {} + } + ]`, + hmethod: http.MethodPost, + huri: "/", + hfn: func(context.Context, struct{ ID int }) api.Err { + return api.ErrSuccess + }, + contentType: "application/x-www-form-urlencoded", + method: http.MethodPost, + url: "/", + body: `id=123`, + permissions: []string{}, + err: api.ErrSuccess, + }, + + // formdata param + { + name: "missing multipart param", + manifest: `[ + { + "method": "POST", + "path": "/", + "info": "info", + "scope": [], + "in": { + "id": { "info": "info", "type": "int", "name": "ID" } + }, + "out": {} + } + ]`, + hmethod: http.MethodPost, + huri: "/", + hfn: func(context.Context, struct{ ID int }) api.Err { + return api.ErrSuccess + }, + contentType: "multipart/form-data; boundary=xxx", + method: http.MethodPost, + url: "/", + body: ``, + permissions: []string{}, + err: api.ErrMissingParam, + }, + { + name: "invalid multipart param", + manifest: `[ + { + "method": "POST", + "path": "/", + "info": "info", + "scope": [], + "in": { + "id": { "info": "info", "type": "int", "name": "ID" } + }, + "out": {} + } + ]`, + hmethod: http.MethodPost, + huri: "/", + hfn: func(context.Context, struct{ ID int }) api.Err { + return api.ErrSuccess + }, + contentType: "multipart/form-data; boundary=xxx", + method: http.MethodPost, + url: "/", + body: `--xxx +Content-Disposition: form-data; name="id" + +abc +--xxx--`, + permissions: []string{}, + err: api.ErrInvalidParam, + }, + { + name: "valid multipart param", + manifest: `[ + { + "method": "POST", + "path": "/", + "info": "info", + "scope": [], + "in": { + "id": { "info": "info", "type": "int", "name": "ID" } + }, + "out": {} + } + ]`, + hmethod: http.MethodPost, + huri: "/", + hfn: func(context.Context, struct{ ID int }) api.Err { + return api.ErrSuccess + }, + contentType: "multipart/form-data; boundary=xxx", + method: http.MethodPost, + url: "/", + body: `--xxx +Content-Disposition: form-data; name="id" + +123 +--xxx--`, + permissions: []string{}, + err: api.ErrSuccess, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + builder := &aicra.Builder{} + if err := addDefaultTypes(builder); err != nil { + t.Fatalf("unexpected error <%v>", err) + } + + err := builder.Setup(strings.NewReader(tc.manifest)) + if err != nil { + t.Fatalf("setup: unexpected error <%v>", err) + } + + if err := builder.Bind(tc.hmethod, tc.huri, tc.hfn); err != nil { + t.Fatalf("bind: unexpected error <%v>", err) + } + + handler, err := builder.Build() + if err != nil { + t.Fatalf("build: unexpected error <%v>", err) + } + + var ( + response = httptest.NewRecorder() + body = strings.NewReader(tc.body) + request = httptest.NewRequest(tc.method, tc.url, body) + ) + if len(tc.contentType) > 0 { + request.Header.Add("Content-Type", tc.contentType) + } + + // test request + handler.ServeHTTP(response, request) + if response.Body == nil { + t.Fatalf("response has no body") + } + + jsonErr, err := json.Marshal(tc.err) + if err != nil { + t.Fatalf("cannot marshal expected error: %v", err) + } + jsonExpected := fmt.Sprintf(`{"error":%s}`, jsonErr) + if response.Body.String() != jsonExpected { + t.Fatalf("invalid response:\n- actual: %s\n- expect: %s\n", printEscaped(response.Body.String()), printEscaped(jsonExpected)) + } + }) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index cc529d0..cea9b23 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -18,37 +18,37 @@ type Server struct { // Parse a configuration into a server. Server.Types must be set beforehand to // make datatypes available when checking and formatting the read configuration. -func (srv *Server) Parse(r io.Reader) error { - err := json.NewDecoder(r).Decode(&srv.Services) +func (s *Server) Parse(r io.Reader) error { + err := json.NewDecoder(r).Decode(&s.Services) if err != nil { - return fmt.Errorf("%s: %w", errRead, err) + return fmt.Errorf("%s: %w", ErrRead, err) } - err = srv.validate() + err = s.validate() if err != nil { - return fmt.Errorf("%s: %w", errFormat, err) + return fmt.Errorf("%s: %w", ErrFormat, err) } return nil } // validate implements the validator interface -func (server Server) validate(datatypes ...validator.Type) error { - for _, service := range server.Services { - err := service.validate(server.Validators...) +func (s Server) validate(datatypes ...validator.Type) error { + for _, service := range s.Services { + err := service.validate(s.Validators...) if err != nil { return fmt.Errorf("%s '%s': %w", service.Method, service.Pattern, err) } } - if err := server.collide(); err != nil { - return fmt.Errorf("%s: %w", errFormat, err) + if err := s.collide(); err != nil { + return fmt.Errorf("%s: %w", ErrFormat, err) } return nil } // Find a service matching an incoming HTTP request -func (server Server) Find(r *http.Request) *Service { - for _, service := range server.Services { +func (s Server) Find(r *http.Request) *Service { + for _, service := range s.Services { if matches := service.Match(r); matches { return service } @@ -62,14 +62,14 @@ func (server Server) Find(r *http.Request) *Service { // - example 1: `/user/{id}` and `/user/articles` will not collide as {id} is an int and "articles" is not // - example 2: `/user/{name}` and `/user/articles` will collide as {name} is a string so as "articles" // - example 3: `/user/{name}` and `/user/{id}` will collide as {name} and {id} cannot be checked against their potential values -func (server *Server) collide() error { - length := len(server.Services) +func (s *Server) collide() error { + length := len(s.Services) // for each service combination for a := 0; a < length; a++ { for b := a + 1; b < length; b++ { - aService := server.Services[a] - bService := server.Services[b] + aService := s.Services[a] + bService := s.Services[b] if aService.Method != bService.Method { continue @@ -105,14 +105,14 @@ func checkURICollision(uriA, uriB []string, inputA, inputB map[string]*Parameter // both captures -> as we cannot check, consider a collision if aIsCapture && bIsCapture { - errors = append(errors, fmt.Errorf("%w (path %s and %s)", errPatternCollision, aPart, bPart)) + errors = append(errors, fmt.Errorf("%w (path %s and %s)", ErrPatternCollision, aPart, bPart)) continue } // no capture -> check strict equality if !aIsCapture && !bIsCapture { if aPart == bPart { - errors = append(errors, fmt.Errorf("%w (same path '%s')", errPatternCollision, aPart)) + errors = append(errors, fmt.Errorf("%w (same path '%s')", ErrPatternCollision, aPart)) continue } } @@ -123,13 +123,13 @@ func checkURICollision(uriA, uriB []string, inputA, inputB map[string]*Parameter // fail if no type or no validator if !exists || input.Validator == nil { - errors = append(errors, fmt.Errorf("%w (invalid type for %s)", errPatternCollision, aPart)) + errors = append(errors, fmt.Errorf("%w (invalid type for %s)", ErrPatternCollision, aPart)) continue } // fail if not valid if _, valid := input.Validator(bPart); valid { - errors = append(errors, fmt.Errorf("%w (%s captures '%s')", errPatternCollision, aPart, bPart)) + errors = append(errors, fmt.Errorf("%w (%s captures '%s')", ErrPatternCollision, aPart, bPart)) continue } @@ -139,13 +139,13 @@ func checkURICollision(uriA, uriB []string, inputA, inputB map[string]*Parameter // fail if no type or no validator if !exists || input.Validator == nil { - errors = append(errors, fmt.Errorf("%w (invalid type for %s)", errPatternCollision, bPart)) + errors = append(errors, fmt.Errorf("%w (invalid type for %s)", ErrPatternCollision, bPart)) continue } // fail if not valid if _, valid := input.Validator(aPart); valid { - errors = append(errors, fmt.Errorf("%w (%s captures '%s')", errPatternCollision, bPart, aPart)) + errors = append(errors, fmt.Errorf("%w (%s captures '%s')", ErrPatternCollision, bPart, aPart)) continue } } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 5db787c..9d9dd60 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -21,15 +21,15 @@ func TestLegalServiceName(t *testing.T) { // empty { `[ { "method": "GET", "info": "a", "path": "" } ]`, - errInvalidPattern, + ErrInvalidPattern, }, { `[ { "method": "GET", "info": "a", "path": "no-starting-slash" } ]`, - errInvalidPattern, + ErrInvalidPattern, }, { `[ { "method": "GET", "info": "a", "path": "ending-slash/" } ]`, - errInvalidPattern, + ErrInvalidPattern, }, { `[ { "method": "GET", "info": "a", "path": "/" } ]`, @@ -45,35 +45,35 @@ func TestLegalServiceName(t *testing.T) { }, { `[ { "method": "GET", "info": "a", "path": "/invalid/s{braces}" } ]`, - errInvalidPatternBraceCapture, + ErrInvalidPatternBraceCapture, }, { `[ { "method": "GET", "info": "a", "path": "/invalid/{braces}a" } ]`, - errInvalidPatternBraceCapture, + ErrInvalidPatternBraceCapture, }, { `[ { "method": "GET", "info": "a", "path": "/invalid/{braces}" } ]`, - errUndefinedBraceCapture, + ErrUndefinedBraceCapture, }, { `[ { "method": "GET", "info": "a", "path": "/invalid/s{braces}/abc" } ]`, - errInvalidPatternBraceCapture, + ErrInvalidPatternBraceCapture, }, { `[ { "method": "GET", "info": "a", "path": "/invalid/{braces}s/abc" } ]`, - errInvalidPatternBraceCapture, + ErrInvalidPatternBraceCapture, }, { `[ { "method": "GET", "info": "a", "path": "/invalid/{braces}/abc" } ]`, - errUndefinedBraceCapture, + ErrUndefinedBraceCapture, }, { `[ { "method": "GET", "info": "a", "path": "/invalid/{b{races}s/abc" } ]`, - errInvalidPatternBraceCapture, + ErrInvalidPatternBraceCapture, }, { `[ { "method": "GET", "info": "a", "path": "/invalid/{braces}/}abc" } ]`, - errInvalidPatternBraceCapture, + ErrInvalidPatternBraceCapture, }, } @@ -143,8 +143,8 @@ func TestAvailableMethods(t *testing.T) { t.FailNow() } - if !test.ValidMethod && !errors.Is(err, errUnknownMethod) { - t.Errorf("expected error <%s> got <%s>", errUnknownMethod, err) + if !test.ValidMethod && !errors.Is(err, ErrUnknownMethod) { + t.Errorf("expected error <%s> got <%s>", ErrUnknownMethod, err) t.FailNow() } }) @@ -184,7 +184,7 @@ func TestParseMissingMethodDescription(t *testing.T) { `[ { "method": "GET", "path": "/" }]`, false, }, - { // missing description + { // missing descriptiontype `[ { "method": "GET", "path": "/subservice" }]`, false, }, @@ -217,8 +217,8 @@ func TestParseMissingMethodDescription(t *testing.T) { t.FailNow() } - if !test.ValidDescription && !errors.Is(err, errMissingDescription) { - t.Errorf("expected error <%s> got <%s>", errMissingDescription, err) + if !test.ValidDescription && !errors.Is(err, ErrMissingDescription) { + t.Errorf("expected error <%s> got <%s>", ErrMissingDescription, err) t.FailNow() } }) @@ -321,7 +321,7 @@ func TestParseParameters(t *testing.T) { } } ]`, - errMissingParamDesc, + ErrMissingParamDesc, }, { // invalid param name suffix `[ @@ -334,7 +334,7 @@ func TestParseParameters(t *testing.T) { } } ]`, - errMissingParamDesc, + ErrMissingParamDesc, }, { // missing param description @@ -348,7 +348,7 @@ func TestParseParameters(t *testing.T) { } } ]`, - errMissingParamDesc, + ErrMissingParamDesc, }, { // empty param description `[ @@ -361,7 +361,7 @@ func TestParseParameters(t *testing.T) { } } ]`, - errMissingParamDesc, + ErrMissingParamDesc, }, { // missing param type @@ -375,7 +375,7 @@ func TestParseParameters(t *testing.T) { } } ]`, - errMissingParamType, + ErrMissingParamType, }, { // empty param type `[ @@ -388,7 +388,7 @@ func TestParseParameters(t *testing.T) { } } ]`, - errMissingParamType, + ErrMissingParamType, }, { // invalid type (optional mark only) `[ @@ -402,7 +402,7 @@ func TestParseParameters(t *testing.T) { } ]`, - errMissingParamType, + ErrMissingParamType, }, { // valid description + valid type `[ @@ -444,7 +444,7 @@ func TestParseParameters(t *testing.T) { } ]`, // 2 possible errors as map order is not deterministic - errParamNameConflict, + ErrParamNameConflict, }, { // rename conflict with name `[ @@ -459,7 +459,7 @@ func TestParseParameters(t *testing.T) { } ]`, // 2 possible errors as map order is not deterministic - errParamNameConflict, + ErrParamNameConflict, }, { // rename conflict with rename `[ @@ -474,7 +474,7 @@ func TestParseParameters(t *testing.T) { } ]`, // 2 possible errors as map order is not deterministic - errParamNameConflict, + ErrParamNameConflict, }, { // both renamed with no conflict @@ -503,7 +503,7 @@ func TestParseParameters(t *testing.T) { } } ]`, - errMandatoryRename, + ErrMandatoryRename, }, { `[ @@ -516,7 +516,7 @@ func TestParseParameters(t *testing.T) { } } ]`, - errMandatoryRename, + ErrMandatoryRename, }, { `[ @@ -556,7 +556,7 @@ func TestParseParameters(t *testing.T) { } } ]`, - errIllegalOptionalURIParam, + ErrIllegalOptionalURIParam, }, { // URI parameter not specified `[ @@ -569,7 +569,7 @@ func TestParseParameters(t *testing.T) { } } ]`, - errUnspecifiedBraceCapture, + ErrUnspecifiedBraceCapture, }, { // URI parameter not defined `[ @@ -580,7 +580,7 @@ func TestParseParameters(t *testing.T) { "in": { } } ]`, - errUndefinedBraceCapture, + ErrUndefinedBraceCapture, }, } @@ -637,7 +637,7 @@ func TestServiceCollision(t *testing.T) { "info": "info", "in": {} } ]`, - errPatternCollision, + ErrPatternCollision, }, { `[ @@ -672,7 +672,7 @@ func TestServiceCollision(t *testing.T) { } } ]`, - errPatternCollision, + ErrPatternCollision, }, { `[ @@ -698,7 +698,7 @@ func TestServiceCollision(t *testing.T) { } } ]`, - errPatternCollision, + ErrPatternCollision, }, { `[ @@ -711,7 +711,7 @@ func TestServiceCollision(t *testing.T) { } } ]`, - errPatternCollision, + ErrPatternCollision, }, { `[ @@ -750,7 +750,7 @@ func TestServiceCollision(t *testing.T) { } } ]`, - errPatternCollision, + ErrPatternCollision, }, { `[ @@ -789,7 +789,7 @@ func TestServiceCollision(t *testing.T) { } } ]`, - errPatternCollision, + ErrPatternCollision, }, { `[ @@ -804,7 +804,7 @@ func TestServiceCollision(t *testing.T) { } } ]`, - errPatternCollision, + ErrPatternCollision, }, { `[ diff --git a/internal/config/errors.go b/internal/config/errors.go index 91f4997..96fe253 100644 --- a/internal/config/errors.go +++ b/internal/config/errors.go @@ -1,59 +1,61 @@ package config -// cerr allows you to create constant "const" error with type boxing. -type cerr string +// Err allows you to create constant "const" error with type boxing. +type Err string -func (err cerr) Error() string { +func (err Err) Error() string { return string(err) } -// errRead - read error -const errRead = cerr("cannot read config") +const ( + // ErrRead - read error + ErrRead = Err("cannot read config") -// errUnknownMethod - unknown http method -const errUnknownMethod = cerr("unknown HTTP method") + // ErrUnknownMethod - unknown http method + ErrUnknownMethod = Err("unknown HTTP method") -// errFormat - invalid format -const errFormat = cerr("invalid config format") + // ErrFormat - invalid format + ErrFormat = Err("invalid config format") -// errPatternCollision - collision between 2 services' patterns -const errPatternCollision = cerr("pattern collision") + // ErrPatternCollision - collision between 2 services' patterns + ErrPatternCollision = Err("pattern collision") -// errInvalidPattern - malformed service pattern -const errInvalidPattern = cerr("malformed service path: must begin with a '/' and not end with") + // ErrInvalidPattern - malformed service pattern + ErrInvalidPattern = Err("malformed service path: must begin with a '/' and not end with") -// errInvalidPatternBraceCapture - invalid brace capture -const errInvalidPatternBraceCapture = cerr("invalid uri parameter") + // ErrInvalidPatternBraceCapture - invalid brace capture + ErrInvalidPatternBraceCapture = Err("invalid uri parameter") -// errUnspecifiedBraceCapture - missing path brace capture -const errUnspecifiedBraceCapture = cerr("missing uri parameter") + // ErrUnspecifiedBraceCapture - missing path brace capture + ErrUnspecifiedBraceCapture = Err("missing uri parameter") -// errUndefinedBraceCapture - missing capturing brace definition -const errUndefinedBraceCapture = cerr("missing uri parameter definition") + // ErrUndefinedBraceCapture - missing capturing brace definition + ErrUndefinedBraceCapture = Err("missing uri parameter definition") -// errMandatoryRename - capture/query parameters must be renamed -const errMandatoryRename = cerr("uri and query parameters must be renamed") + // ErrMandatoryRename - capture/query parameters must be renamed + ErrMandatoryRename = Err("uri and query parameters must be renamed") -// errMissingDescription - a service is missing its description -const errMissingDescription = cerr("missing description") + // ErrMissingDescription - a service is missing its description + ErrMissingDescription = Err("missing description") -// errIllegalOptionalURIParam - uri parameter cannot optional -const errIllegalOptionalURIParam = cerr("uri parameter cannot be optional") + // ErrIllegalOptionalURIParam - uri parameter cannot optional + ErrIllegalOptionalURIParam = Err("uri parameter cannot be optional") -// errOptionalOption - cannot have optional output -const errOptionalOption = cerr("output cannot be optional") + // ErrOptionalOption - cannot have optional output + ErrOptionalOption = Err("output cannot be optional") -// errMissingParamDesc - missing parameter description -const errMissingParamDesc = cerr("missing parameter description") + // ErrMissingParamDesc - missing parameter description + ErrMissingParamDesc = Err("missing parameter description") -// errUnknownDataType - unknown parameter datatype -const errUnknownDataType = cerr("unknown parameter datatype") + // ErrUnknownParamType - unknown parameter type + ErrUnknownParamType = Err("unknown parameter datatype") -// errIllegalParamName - illegal parameter name -const errIllegalParamName = cerr("illegal parameter name") + // ErrIllegalParamName - illegal parameter name + ErrIllegalParamName = Err("illegal parameter name") -// errMissingParamType - missing parameter type -const errMissingParamType = cerr("missing parameter type") + // ErrMissingParamType - missing parameter type + ErrMissingParamType = Err("missing parameter type") -// errParamNameConflict - name/rename conflict -const errParamNameConflict = cerr("parameter name conflict") + // ErrParamNameConflict - name/rename conflict + ErrParamNameConflict = Err("parameter name conflict") +) diff --git a/internal/config/parameter.go b/internal/config/parameter.go index a22227b..9e13e77 100644 --- a/internal/config/parameter.go +++ b/internal/config/parameter.go @@ -20,11 +20,11 @@ type Parameter struct { func (param *Parameter) validate(datatypes ...validator.Type) error { if len(param.Description) < 1 { - return errMissingParamDesc + return ErrMissingParamDesc } if len(param.Type) < 1 || param.Type == "?" { - return errMissingParamType + return ErrMissingParamType } // optional type @@ -42,7 +42,7 @@ func (param *Parameter) validate(datatypes ...validator.Type) error { } } if param.Validator == nil { - return errUnknownDataType + return ErrUnknownParamType } return nil } diff --git a/internal/config/service.go b/internal/config/service.go index 6d10d27..1b84dc7 100644 --- a/internal/config/service.go +++ b/internal/config/service.go @@ -9,9 +9,11 @@ import ( "github.com/xdrm-io/aicra/validator" ) -var braceRegex = regexp.MustCompile(`^{([a-z_-]+)}$`) -var queryRegex = regexp.MustCompile(`^GET@([a-z_-]+)$`) -var availableHTTPMethods = []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete} +var ( + captureRegex = regexp.MustCompile(`^{([a-z_-]+)}$`) + queryRegex = regexp.MustCompile(`^GET@([a-z_-]+)$`) + availableHTTPMethods = []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete} +) // Service definition type Service struct { @@ -43,19 +45,25 @@ type BraceCapture struct { // Match returns if this service would handle this HTTP request func (svc *Service) Match(req *http.Request) bool { - if req.Method != svc.Method { - return false + var ( + uri = req.RequestURI + queryIndex = strings.IndexByte(uri, '?') + ) + + // remove query part for matching the pattern + if queryIndex > -1 { + uri = uri[:queryIndex] } - if !svc.matchPattern(req.RequestURI) { - return false - } - return true + + return req.Method == svc.Method && svc.matchPattern(uri) } // checks if an uri matches the service's pattern func (svc *Service) matchPattern(uri string) bool { - uriparts := SplitURL(uri) - parts := SplitURL(svc.Pattern) + var ( + uriparts = SplitURL(uri) + parts = SplitURL(svc.Pattern) + ) if len(uriparts) != len(parts) { return false @@ -98,39 +106,34 @@ func (svc *Service) matchPattern(uri string) bool { // Validate implements the validator interface func (svc *Service) validate(datatypes ...validator.Type) error { - // check method - err := svc.isMethodAvailable() + err := svc.checkMethod() if err != nil { return fmt.Errorf("field 'method': %w", err) } - // check pattern svc.Pattern = strings.Trim(svc.Pattern, " \t\r\n") - err = svc.isPatternValid() + err = svc.checkPattern() if err != nil { return fmt.Errorf("field 'path': %w", err) } - // check description if len(strings.Trim(svc.Description, " \t\r\n")) < 1 { - return fmt.Errorf("field 'description': %w", errMissingDescription) + return fmt.Errorf("field 'description': %w", ErrMissingDescription) } - // check input parameters - err = svc.validateInput(datatypes) + err = svc.checkInput(datatypes) if err != nil { return fmt.Errorf("field 'in': %w", err) } - // fail if a brace capture remains undefined + // fail when a brace capture remains undefined for _, capture := range svc.Captures { if capture.Ref == nil { - return fmt.Errorf("field 'in': %s: %w", capture.Name, errUndefinedBraceCapture) + return fmt.Errorf("field 'in': %s: %w", capture.Name, ErrUndefinedBraceCapture) } } - // check output - err = svc.validateOutput(datatypes) + err = svc.checkOutput(datatypes) if err != nil { return fmt.Errorf("field 'out': %w", err) } @@ -138,27 +141,34 @@ func (svc *Service) validate(datatypes ...validator.Type) error { return nil } -func (svc *Service) isMethodAvailable() error { +func (svc *Service) checkMethod() error { for _, available := range availableHTTPMethods { if svc.Method == available { return nil } } - return errUnknownMethod + return ErrUnknownMethod } -func (svc *Service) isPatternValid() error { +// checkPattern checks for the validity of the pattern definition (i.e. the uri) +// +// Note that the uri can contain capture params e.g. `/a/{b}/c/{d}`, in this +// example, input parameters with names `{b}` and `{d}` are expected. +// +// This methods sets up the service state with adding capture params that are +// expected; checkInputs() will be able to check params agains pattern captures. +func (svc *Service) checkPattern() error { length := len(svc.Pattern) // empty pattern if length < 1 { - return errInvalidPattern + return ErrInvalidPattern } if length > 1 { // pattern not starting with '/' or ending with '/' if svc.Pattern[0] != '/' || svc.Pattern[length-1] == '/' { - return errInvalidPattern + return ErrInvalidPattern } } @@ -166,11 +176,11 @@ func (svc *Service) isPatternValid() error { parts := SplitURL(svc.Pattern) for i, part := range parts { if len(part) < 1 { - return errInvalidPattern + return ErrInvalidPattern } // if brace capture - if matches := braceRegex.FindAllStringSubmatch(part, -1); len(matches) > 0 && len(matches[0]) > 1 { + if matches := captureRegex.FindAllStringSubmatch(part, -1); len(matches) > 0 && len(matches[0]) > 1 { braceName := matches[0][1] // append @@ -187,149 +197,185 @@ func (svc *Service) isPatternValid() error { // fail on invalid format if strings.ContainsAny(part, "{}") { - return errInvalidPatternBraceCapture + return ErrInvalidPatternBraceCapture } - } return nil } -func (svc *Service) validateInput(types []validator.Type) error { - - // ignore no parameter +func (svc *Service) checkInput(types []validator.Type) error { + // no parameter if svc.Input == nil || len(svc.Input) < 1 { - svc.Input = make(map[string]*Parameter, 0) + svc.Input = map[string]*Parameter{} return nil } // for each parameter - for paramName, param := range svc.Input { - if len(paramName) < 1 { - return fmt.Errorf("%s: %w", paramName, errIllegalParamName) + for name, p := range svc.Input { + if len(name) < 1 { + return fmt.Errorf("%s: %w", name, ErrIllegalParamName) } - // fail if brace capture does not exists in pattern - var iscapture, isquery bool - if matches := braceRegex.FindAllStringSubmatch(paramName, -1); len(matches) > 0 && len(matches[0]) > 1 { - braceName := matches[0][1] - - found := false - for _, capture := range svc.Captures { - if capture.Name == braceName { - capture.Ref = param - found = true - break - } - } - if !found { - return fmt.Errorf("%s: %w", paramName, errUnspecifiedBraceCapture) - } - iscapture = true - - } else if matches := queryRegex.FindAllStringSubmatch(paramName, -1); len(matches) > 0 && len(matches[0]) > 1 { - - queryName := matches[0][1] - - // init map - if svc.Query == nil { - svc.Query = make(map[string]*Parameter) - } - svc.Query[queryName] = param - isquery = true - } else { - if svc.Form == nil { - svc.Form = make(map[string]*Parameter) - } - svc.Form[paramName] = param - } - - // fail if capture or query without rename - if len(param.Rename) < 1 && (iscapture || isquery) { - return fmt.Errorf("%s: %w", paramName, errMandatoryRename) - } - - // use param name if no rename - if len(param.Rename) < 1 { - param.Rename = paramName - } - - err := param.validate(types...) + // parse parameters: capture (uri), query or form and update the service + // attributes accordingly + ptype, err := svc.parseParam(name, p) if err != nil { - return fmt.Errorf("%s: %w", paramName, err) + return err + } + + // Rename mandatory for capture and query + if len(p.Rename) < 1 && (ptype == captureParam || ptype == queryParam) { + return fmt.Errorf("%s: %w", name, ErrMandatoryRename) + } + + // fallback to name when Rename is not provided + if len(p.Rename) < 1 { + p.Rename = name + } + + err = p.validate(types...) + if err != nil { + return fmt.Errorf("%s: %w", name, err) } // capture parameter cannot be optional - if iscapture && param.Optional { - return fmt.Errorf("%s: %w", paramName, errIllegalOptionalURIParam) + if p.Optional && ptype == captureParam { + return fmt.Errorf("%s: %w", name, ErrIllegalOptionalURIParam) } - // fail on name/rename conflict - for paramName2, param2 := range svc.Input { - // ignore self - if paramName == paramName2 { - continue - } - - // 3.2.1. Same rename field - // 3.2.2. Not-renamed field matches a renamed field - // 3.2.3. Renamed field matches name - if param.Rename == param2.Rename || paramName == param2.Rename || paramName2 == param.Rename { - return fmt.Errorf("%s: %w", paramName, errParamNameConflict) - } - + err = nameConflicts(name, p, svc.Input) + if err != nil { + return err } - } - return nil } -func (svc *Service) validateOutput(types []validator.Type) error { - - // ignore no parameter +func (svc *Service) checkOutput(types []validator.Type) error { + // no parameter if svc.Output == nil || len(svc.Output) < 1 { svc.Output = make(map[string]*Parameter, 0) return nil } - // for each parameter - for paramName, param := range svc.Output { - if len(paramName) < 1 { - return fmt.Errorf("%s: %w", paramName, errIllegalParamName) + for name, p := range svc.Output { + if len(name) < 1 { + return fmt.Errorf("%s: %w", name, ErrIllegalParamName) } - // use param name if no rename - if len(param.Rename) < 1 { - param.Rename = paramName + // fallback to name when Rename is not provided + if len(p.Rename) < 1 { + p.Rename = name } - err := param.validate(types...) + err := p.validate(types...) if err != nil { - return fmt.Errorf("%s: %w", paramName, err) + return fmt.Errorf("%s: %w", name, err) } - if param.Optional { - return fmt.Errorf("%s: %w", paramName, errOptionalOption) + if p.Optional { + return fmt.Errorf("%s: %w", name, ErrOptionalOption) } - // fail on name/rename conflict - for paramName2, param2 := range svc.Output { - // ignore self - if paramName == paramName2 { - continue - } - - // 3.2.1. Same rename field - // 3.2.2. Not-renamed field matches a renamed field - // 3.2.3. Renamed field matches name - if param.Rename == param2.Rename || paramName == param2.Rename || paramName2 == param.Rename { - return fmt.Errorf("%s: %w", paramName, errParamNameConflict) - } - + err = nameConflicts(name, p, svc.Output) + if err != nil { + return err + } + } + return nil +} + +type paramType int + +const ( + captureParam paramType = iota + queryParam + formParam +) + +// parseParam determines which param type it is from its name: +// - `{paramName}` is an capture; it captures a segment of the uri defined in +// the pattern definition, e.g. `/some/path/with/{paramName}/somewhere` +// - `GET@paramName` is an uri query that is received from the http query format +// in the uri, e.g. `http://domain.com/uri?paramName=paramValue¶m2=value2` +// - any other name that contains valid characters is considered a Form +// parameter; it is extracted from the http request's body as: json, multipart +// or using the x-www-form-urlencoded format. +// +// Special notes: +// - capture params MUST be found in the pattern definition. +// - capture params MUST NOT be optional as they are in the pattern anyways. +// - capture and query params MUST be renamed because the `{param}` or +// `GET@param` name formats cannot be translated to a valid go exported name. +// c.f. the `dynfunc` package that creates a handler func() signature from +// the service definitions (i.e. input and output parameters). +func (svc *Service) parseParam(name string, p *Parameter) (paramType, error) { + var ( + captureMatches = captureRegex.FindAllStringSubmatch(name, -1) + isCapture = len(captureMatches) > 0 && len(captureMatches[0]) > 1 + ) + + // Parameter is a capture (uri/{param}) + if isCapture { + captureName := captureMatches[0][1] + + // fail if brace capture does not exists in pattern + found := false + for _, capture := range svc.Captures { + if capture.Name == captureName { + capture.Ref = p + found = true + break + } + } + if !found { + return captureParam, fmt.Errorf("%s: %w", name, ErrUnspecifiedBraceCapture) + } + return captureParam, nil + } + + var ( + queryMatches = queryRegex.FindAllStringSubmatch(name, -1) + isQuery = len(queryMatches) > 0 && len(queryMatches[0]) > 1 + ) + + // Parameter is a query (uri?param) + if isQuery { + queryName := queryMatches[0][1] + + // init map + if svc.Query == nil { + svc.Query = make(map[string]*Parameter) + } + svc.Query[queryName] = p + + return queryParam, nil + } + + // Parameter is a form param + if svc.Form == nil { + svc.Form = make(map[string]*Parameter) + } + svc.Form[name] = p + return formParam, nil +} + +// nameConflicts returns whether ar given parameter has its name or Rename field +// in conflict with an existing parameter +func nameConflicts(name string, param *Parameter, others map[string]*Parameter) error { + for otherName, other := range others { + // ignore self + if otherName == name { + continue + } + + // 1. same rename field + // 2. original name matches a renamed field + // 3. renamed field matches an original name + if param.Rename == other.Rename || name == other.Rename || otherName == param.Rename { + return fmt.Errorf("%s: %w", otherName, ErrParamNameConflict) } - } - return nil } diff --git a/internal/dynfunc/errors.go b/internal/dynfunc/errors.go index 903e5f0..0fa7278 100644 --- a/internal/dynfunc/errors.go +++ b/internal/dynfunc/errors.go @@ -1,50 +1,52 @@ package dynfunc -// cerr allows you to create constant "const" error with type boxing. -type cerr string +// Err allows you to create constant "const" error with type boxing. +type Err string -func (err cerr) Error() string { +func (err Err) Error() string { return string(err) } -// errHandlerNotFunc - handler is not a func -const errHandlerNotFunc = cerr("handler must be a func") +const ( + // ErrHandlerNotFunc - handler is not a func + ErrHandlerNotFunc = Err("handler must be a func") -// errNoServiceForHandler - no service matching this handler -const errNoServiceForHandler = cerr("no service found for this handler") + // ErrNoServiceForHandler - no service matching this handler + ErrNoServiceForHandler = Err("no service found for this handler") -// errMissingHandlerArgumentParam - missing params arguments for handler -const errMissingHandlerContextArgument = cerr("missing handler first argument of type context.Context") + // errMissingHandlerArgumentParam - missing params arguments for handler + ErrMissingHandlerContextArgument = Err("missing handler first argument of type context.Context") -// errMissingHandlerInputArgument - missing params arguments for handler -const errMissingHandlerInputArgument = cerr("missing handler argument: input struct") + // ErrInvalidHandlerContextArgument - missing handler output error + ErrInvalidHandlerContextArgument = Err("first input argument should be of type context.Context") -// errUnexpectedInput - input argument is not expected -const errUnexpectedInput = cerr("unexpected input struct") + // ErrMissingHandlerInputArgument - missing params arguments for handler + ErrMissingHandlerInputArgument = Err("missing handler argument: input struct") -// errMissingHandlerOutputArgument - missing output for handler -const errMissingHandlerOutputArgument = cerr("missing handler first output argument: output struct") + // ErrUnexpectedInput - input argument is not expected + ErrUnexpectedInput = Err("unexpected input struct") -// errMissingHandlerOutputError - missing error output for handler -const errMissingHandlerOutputError = cerr("missing handler last output argument of type api.Err") + // ErrMissingHandlerOutputArgument - missing output for handler + ErrMissingHandlerOutputArgument = Err("missing handler first output argument: output struct") -// errMissingRequestArgument - missing request argument for handler -const errMissingRequestArgument = cerr("handler first argument must be of type api.Request") + // ErrMissingHandlerErrorArgument - missing error output for handler + ErrMissingHandlerErrorArgument = Err("missing handler last output argument of type api.Err") -// errMissingParamArgument - missing parameters argument for handler -const errMissingParamArgument = cerr("handler second argument must be a struct") + // ErrInvalidHandlerErrorArgument - missing handler output error + ErrInvalidHandlerErrorArgument = Err("last output must be of type api.Err") -// errUnexportedName - argument is unexported in struct -const errUnexportedName = cerr("unexported name") + // ErrMissingParamArgument - missing parameters argument for handler + ErrMissingParamArgument = Err("handler second argument must be a struct") -// errWrongOutputArgumentType - wrong type for output first argument -const errWrongOutputArgumentType = cerr("handler first output argument must be a *struct") + // ErrUnexportedName - argument is unexported in struct + ErrUnexportedName = Err("unexported name") -// errMissingConfigArgument - missing an input/output argument in handler struct -const errMissingConfigArgument = cerr("missing an argument from the configuration") + // ErrWrongOutputArgumentType - wrong type for output first argument + ErrWrongOutputArgumentType = Err("handler first output argument must be a *struct") -// errWrongParamTypeFromConfig - a configuration parameter type is invalid in the handler param struct -const errWrongParamTypeFromConfig = cerr("invalid struct field type") + // ErrMissingConfigArgument - missing an input/output argument in handler struct + ErrMissingConfigArgument = Err("missing an argument from the configuration") -// errMissingHandlerErrorArgument - missing handler output error -const errMissingHandlerErrorArgument = cerr("last output must be of type api.Err") + // ErrWrongParamTypeFromConfig - a configuration parameter type is invalid in the handler param struct + ErrWrongParamTypeFromConfig = Err("invalid struct field type") +) diff --git a/internal/dynfunc/handler.go b/internal/dynfunc/handler.go index 395d485..3010396 100644 --- a/internal/dynfunc/handler.go +++ b/internal/dynfunc/handler.go @@ -20,14 +20,14 @@ type Handler struct { // Build a handler from a dynamic function and checks its signature against a // service configuration -//e -// `fn` must have as a signature : `func(*api.Context, in) (*out, api.Err)` +// +// `fn` must have as a signature : `func(context.Context, in) (*out, api.Err)` // - `in` is a struct{} containing a field for each service input (with valid reflect.Type) // - `out` is a struct{} containing a field for each service output (with valid reflect.Type) // // Special cases: -// - it there is no input, `in` MUST be omitted -// - it there is no output, `out` MUST be omitted +// - it there is no input, `in` MUST be omitted +// - it there is no output, `out` CAN be omitted func Build(fn interface{}, service config.Service) (*Handler, error) { var ( h = &Handler{ @@ -38,7 +38,7 @@ func Build(fn interface{}, service config.Service) (*Handler, error) { ) if fnType.Kind() != reflect.Func { - return nil, errHandlerNotFunc + return nil, ErrHandlerNotFunc } if err := h.signature.ValidateInput(fnType); err != nil { return nil, fmt.Errorf("input: %w", err) diff --git a/internal/dynfunc/signature.go b/internal/dynfunc/signature.go index d5c30aa..4fdb474 100644 --- a/internal/dynfunc/signature.go +++ b/internal/dynfunc/signature.go @@ -53,47 +53,47 @@ func (s *Signature) ValidateInput(handlerType reflect.Type) error { // missing or invalid first arg: context.Context if handlerType.NumIn() < 1 { - return errMissingHandlerContextArgument + return ErrMissingHandlerContextArgument } firstArgType := handlerType.In(0) if !firstArgType.Implements(ctxType) { - return fmt.Errorf("fock") + return ErrInvalidHandlerContextArgument } // no input required if len(s.Input) == 0 { // input struct provided if handlerType.NumIn() > 1 { - return errUnexpectedInput + return ErrUnexpectedInput } return nil } // too much arguments - if handlerType.NumIn() > 2 { - return errMissingHandlerInputArgument + if handlerType.NumIn() != 2 { + return ErrMissingHandlerInputArgument } // arg must be a struct inStruct := handlerType.In(1) if inStruct.Kind() != reflect.Struct { - return errMissingParamArgument + return ErrMissingParamArgument } // check for invalid param for name, ptype := range s.Input { if name[0] == strings.ToLower(name)[0] { - return fmt.Errorf("%s: %w", name, errUnexportedName) + return fmt.Errorf("%s: %w", name, ErrUnexportedName) } field, exists := inStruct.FieldByName(name) if !exists { - return fmt.Errorf("%s: %w", name, errMissingConfigArgument) + return fmt.Errorf("%s: %w", name, ErrMissingConfigArgument) } if !ptype.AssignableTo(field.Type) { - return fmt.Errorf("%s: %w (%s instead of %s)", name, errWrongParamTypeFromConfig, field.Type, ptype) + return fmt.Errorf("%s: %w (%s instead of %s)", name, ErrWrongParamTypeFromConfig, field.Type, ptype) } } @@ -105,44 +105,44 @@ func (s Signature) ValidateOutput(handlerType reflect.Type) error { errType := reflect.TypeOf(api.ErrUnknown) if handlerType.NumOut() < 1 { - return errMissingHandlerErrorArgument + return ErrMissingHandlerErrorArgument } // last output must be api.Err lastArgType := handlerType.Out(handlerType.NumOut() - 1) if !lastArgType.AssignableTo(errType) { - return errMissingHandlerErrorArgument + return ErrInvalidHandlerErrorArgument } - // no output -> ok + // no output required -> ok if len(s.Output) == 0 { return nil } if handlerType.NumOut() < 2 { - return errMissingHandlerOutputArgument + return ErrMissingHandlerOutputArgument } // fail if first output is not a pointer to struct outStructPtr := handlerType.Out(0) if outStructPtr.Kind() != reflect.Ptr { - return errWrongOutputArgumentType + return ErrWrongOutputArgumentType } outStruct := outStructPtr.Elem() if outStruct.Kind() != reflect.Struct { - return errWrongOutputArgumentType + return ErrWrongOutputArgumentType } // fail on invalid output for name, ptype := range s.Output { if name[0] == strings.ToLower(name)[0] { - return fmt.Errorf("%s: %w", name, errUnexportedName) + return fmt.Errorf("%s: %w", name, ErrUnexportedName) } field, exists := outStruct.FieldByName(name) if !exists { - return fmt.Errorf("%s: %w", name, errMissingConfigArgument) + return fmt.Errorf("%s: %w", name, ErrMissingConfigArgument) } // ignore types evalutating to nil @@ -151,7 +151,7 @@ func (s Signature) ValidateOutput(handlerType reflect.Type) error { } if !field.Type.ConvertibleTo(ptype) { - return fmt.Errorf("%s: %w (%s instead of %s)", name, errWrongParamTypeFromConfig, field.Type, ptype) + return fmt.Errorf("%s: %w (%s instead of %s)", name, ErrWrongParamTypeFromConfig, field.Type, ptype) } } diff --git a/internal/dynfunc/signature_test.go b/internal/dynfunc/signature_test.go index 3554abf..6669643 100644 --- a/internal/dynfunc/signature_test.go +++ b/internal/dynfunc/signature_test.go @@ -8,382 +8,562 @@ import ( "testing" "github.com/xdrm-io/aicra/api" + "github.com/xdrm-io/aicra/internal/config" ) -func TestInputCheck(t *testing.T) { - tcases := []struct { - Name string - Input map[string]reflect.Type - Fn interface{} - FnCtx interface{} - Err error +func TestInputValidation(t *testing.T) { + tt := []struct { + name string + input map[string]reflect.Type + fn interface{} + err error }{ { - Name: "no input 0 given", - Input: map[string]reflect.Type{}, - Fn: func(context.Context) {}, - FnCtx: func(context.Context) {}, - Err: nil, + name: "missing context", + input: map[string]reflect.Type{}, + fn: func() {}, + err: ErrMissingHandlerContextArgument, }, { - Name: "no input 1 given", - Input: map[string]reflect.Type{}, - Fn: func(context.Context, int) {}, - FnCtx: func(context.Context, int) {}, - Err: errUnexpectedInput, + name: "invalid context", + input: map[string]reflect.Type{}, + fn: func(int) {}, + err: ErrInvalidHandlerContextArgument, }, { - Name: "no input 2 given", - Input: map[string]reflect.Type{}, - Fn: func(context.Context, int, string) {}, - FnCtx: func(context.Context, int, string) {}, - Err: errUnexpectedInput, + name: "no input 0 given", + input: map[string]reflect.Type{}, + fn: func(context.Context) {}, + err: nil, }, { - Name: "1 input 0 given", - Input: map[string]reflect.Type{ + name: "no input 1 given", + input: map[string]reflect.Type{}, + fn: func(context.Context, int) {}, + err: ErrUnexpectedInput, + }, + { + name: "no input 2 given", + input: map[string]reflect.Type{}, + fn: func(context.Context, int, string) {}, + err: ErrUnexpectedInput, + }, + { + name: "1 input 0 given", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func(context.Context) {}, - FnCtx: func(context.Context) {}, - Err: errMissingHandlerInputArgument, + fn: func(context.Context) {}, + err: ErrMissingHandlerInputArgument, }, { - Name: "1 input non-struct given", - Input: map[string]reflect.Type{ + name: "1 input non-struct given", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func(context.Context, int) {}, - FnCtx: func(context.Context, int) {}, - Err: errMissingParamArgument, + fn: func(context.Context, int) {}, + err: ErrMissingParamArgument, }, { - Name: "unexported input", - Input: map[string]reflect.Type{ + name: "unexported input", + input: map[string]reflect.Type{ "test1": reflect.TypeOf(int(0)), }, - Fn: func(context.Context, struct{}) {}, - FnCtx: func(context.Context, struct{}) {}, - Err: errUnexportedName, + fn: func(context.Context, struct{}) {}, + err: ErrUnexportedName, }, { - Name: "1 input empty struct given", - Input: map[string]reflect.Type{ + name: "1 input empty struct given", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func(context.Context, struct{}) {}, - FnCtx: func(context.Context, struct{}) {}, - Err: errMissingConfigArgument, + fn: func(context.Context, struct{}) {}, + err: ErrMissingConfigArgument, }, { - Name: "1 input invalid given", - Input: map[string]reflect.Type{ + name: "1 input invalid given", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func(context.Context, struct{ Test1 string }) {}, - FnCtx: func(context.Context, struct{ Test1 string }) {}, - Err: errWrongParamTypeFromConfig, + fn: func(context.Context, struct{ Test1 string }) {}, + err: ErrWrongParamTypeFromConfig, }, { - Name: "1 input valid given", - Input: map[string]reflect.Type{ + name: "1 input valid given", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func(context.Context, struct{ Test1 int }) {}, - FnCtx: func(context.Context, struct{ Test1 int }) {}, - Err: nil, + fn: func(context.Context, struct{ Test1 int }) {}, + err: nil, }, { - Name: "1 input ptr empty struct given", - Input: map[string]reflect.Type{ + name: "1 input ptr empty struct given", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(int)), }, - Fn: func(context.Context, struct{}) {}, - FnCtx: func(context.Context, struct{}) {}, - Err: errMissingConfigArgument, + fn: func(context.Context, struct{}) {}, + err: ErrMissingConfigArgument, }, { - Name: "1 input ptr invalid given", - Input: map[string]reflect.Type{ + name: "1 input ptr invalid given", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(int)), }, - Fn: func(context.Context, struct{ Test1 string }) {}, - FnCtx: func(context.Context, struct{ Test1 string }) {}, - Err: errWrongParamTypeFromConfig, + fn: func(context.Context, struct{ Test1 string }) {}, + err: ErrWrongParamTypeFromConfig, }, { - Name: "1 input ptr invalid ptr type given", - Input: map[string]reflect.Type{ + name: "1 input ptr invalid ptr type given", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(int)), }, - Fn: func(context.Context, struct{ Test1 *string }) {}, - FnCtx: func(context.Context, struct{ Test1 *string }) {}, - Err: errWrongParamTypeFromConfig, + fn: func(context.Context, struct{ Test1 *string }) {}, + err: ErrWrongParamTypeFromConfig, }, { - Name: "1 input ptr valid given", - Input: map[string]reflect.Type{ + name: "1 input ptr valid given", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(int)), }, - Fn: func(context.Context, struct{ Test1 *int }) {}, - FnCtx: func(context.Context, struct{ Test1 *int }) {}, - Err: nil, + fn: func(context.Context, struct{ Test1 *int }) {}, + err: nil, }, { - Name: "1 valid string", - Input: map[string]reflect.Type{ + name: "1 valid string", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf(string("")), }, - Fn: func(context.Context, struct{ Test1 string }) {}, - FnCtx: func(context.Context, struct{ Test1 string }) {}, - Err: nil, + fn: func(context.Context, struct{ Test1 string }) {}, + err: nil, }, { - Name: "1 valid uint", - Input: map[string]reflect.Type{ + name: "1 valid uint", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf(uint(0)), }, - Fn: func(context.Context, struct{ Test1 uint }) {}, - FnCtx: func(context.Context, struct{ Test1 uint }) {}, - Err: nil, + fn: func(context.Context, struct{ Test1 uint }) {}, + err: nil, }, { - Name: "1 valid float64", - Input: map[string]reflect.Type{ + name: "1 valid float64", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf(float64(0)), }, - Fn: func(context.Context, struct{ Test1 float64 }) {}, - FnCtx: func(context.Context, struct{ Test1 float64 }) {}, - Err: nil, + fn: func(context.Context, struct{ Test1 float64 }) {}, + err: nil, }, { - Name: "1 valid []byte", - Input: map[string]reflect.Type{ + name: "1 valid []byte", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf([]byte("")), }, - Fn: func(context.Context, struct{ Test1 []byte }) {}, - FnCtx: func(context.Context, struct{ Test1 []byte }) {}, - Err: nil, + fn: func(context.Context, struct{ Test1 []byte }) {}, + err: nil, }, { - Name: "1 valid []rune", - Input: map[string]reflect.Type{ + name: "1 valid []rune", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf([]rune("")), }, - Fn: func(context.Context, struct{ Test1 []rune }) {}, - FnCtx: func(context.Context, struct{ Test1 []rune }) {}, - Err: nil, + fn: func(context.Context, struct{ Test1 []rune }) {}, + err: nil, }, { - Name: "1 valid *string", - Input: map[string]reflect.Type{ + name: "1 valid *string", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(string)), }, - Fn: func(context.Context, struct{ Test1 *string }) {}, - FnCtx: func(context.Context, struct{ Test1 *string }) {}, - Err: nil, + fn: func(context.Context, struct{ Test1 *string }) {}, + err: nil, }, { - Name: "1 valid *uint", - Input: map[string]reflect.Type{ + name: "1 valid *uint", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(uint)), }, - Fn: func(context.Context, struct{ Test1 *uint }) {}, - FnCtx: func(context.Context, struct{ Test1 *uint }) {}, - Err: nil, + fn: func(context.Context, struct{ Test1 *uint }) {}, + err: nil, }, { - Name: "1 valid *float64", - Input: map[string]reflect.Type{ + name: "1 valid *float64", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new(float64)), }, - Fn: func(context.Context, struct{ Test1 *float64 }) {}, - FnCtx: func(context.Context, struct{ Test1 *float64 }) {}, - Err: nil, + fn: func(context.Context, struct{ Test1 *float64 }) {}, + err: nil, }, { - Name: "1 valid *[]byte", - Input: map[string]reflect.Type{ + name: "1 valid *[]byte", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new([]byte)), }, - Fn: func(context.Context, struct{ Test1 *[]byte }) {}, - FnCtx: func(context.Context, struct{ Test1 *[]byte }) {}, - Err: nil, + fn: func(context.Context, struct{ Test1 *[]byte }) {}, + err: nil, }, { - Name: "1 valid *[]rune", - Input: map[string]reflect.Type{ + name: "1 valid *[]rune", + input: map[string]reflect.Type{ "Test1": reflect.TypeOf(new([]rune)), }, - Fn: func(context.Context, struct{ Test1 *[]rune }) {}, - FnCtx: func(context.Context, struct{ Test1 *[]rune }) {}, - Err: nil, + fn: func(context.Context, struct{ Test1 *[]rune }) {}, + err: nil, }, } - for _, tcase := range tcases { - t.Run(tcase.Name, func(t *testing.T) { - t.Parallel() - + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { // mock spec s := Signature{ - Input: tcase.Input, + Input: tc.input, Output: nil, } - err := s.ValidateInput(reflect.TypeOf(tcase.FnCtx)) - if err == nil && tcase.Err != nil { - t.Errorf("expected an error: '%s'", tcase.Err.Error()) - t.FailNow() + err := s.ValidateInput(reflect.TypeOf(tc.fn)) + if err == nil && tc.err != nil { + t.Fatalf("expected an error: '%s'", tc.err.Error()) } - if err != nil && tcase.Err == nil { - t.Errorf("unexpected error: '%s'", err.Error()) - t.FailNow() + if err != nil && tc.err == nil { + t.Fatalf("unexpected error: '%s'", err.Error()) } - if err != nil && tcase.Err != nil { - if !errors.Is(err, tcase.Err) { - t.Errorf("expected the error <%s> got <%s>", tcase.Err, err) - t.FailNow() + if err != nil && tc.err != nil { + if !errors.Is(err, tc.err) { + t.Fatalf("expected the error <%s> got <%s>", tc.err, err) } } }) } } -func TestOutputCheck(t *testing.T) { - tcases := []struct { - Output map[string]reflect.Type - Fn interface{} - Err error +func TestOutputValidation(t *testing.T) { + tt := []struct { + name string + output map[string]reflect.Type + fn interface{} + err error }{ - // no input -> missing api.Err { - Output: map[string]reflect.Type{}, - Fn: func(context.Context) {}, - Err: errMissingHandlerOutputArgument, + name: "no output missing err", + output: map[string]reflect.Type{}, + fn: func() {}, + err: ErrMissingHandlerErrorArgument, }, - // no input -> with last type not api.Err { - Output: map[string]reflect.Type{}, - Fn: func(context.Context) bool { return true }, - Err: errMissingHandlerErrorArgument, + name: "no output invalid err", + output: map[string]reflect.Type{}, + fn: func() bool { return true }, + err: ErrInvalidHandlerErrorArgument, }, - // no input -> with api.Err { - Output: map[string]reflect.Type{}, - Fn: func(context.Context) api.Err { return api.ErrSuccess }, - Err: nil, + name: "1 output none required", + output: map[string]reflect.Type{}, + fn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + err: nil, }, - // no input -> missing context.Context { - Output: map[string]reflect.Type{}, - Fn: func(context.Context) api.Err { return api.ErrSuccess }, - Err: errMissingHandlerContextArgument, - }, - // no input -> invlaid context.Context type - { - Output: map[string]reflect.Type{}, - Fn: func(context.Context, int) api.Err { return api.ErrSuccess }, - Err: errMissingHandlerContextArgument, - }, - // func can have output if not specified - { - Output: map[string]reflect.Type{}, - Fn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, - Err: nil, - }, - // missing output struct in func - { - Output: map[string]reflect.Type{ + name: "no output 1 required", + output: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func() api.Err { return api.ErrSuccess }, - Err: errWrongOutputArgumentType, + fn: func() api.Err { return api.ErrSuccess }, + err: ErrMissingHandlerOutputArgument, }, - // output not a pointer { - Output: map[string]reflect.Type{ + name: "invalid int output", + output: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func() (int, api.Err) { return 0, api.ErrSuccess }, - Err: errWrongOutputArgumentType, + fn: func() (int, api.Err) { return 0, api.ErrSuccess }, + err: ErrWrongOutputArgumentType, }, - // output not a pointer to struct { - Output: map[string]reflect.Type{ + name: "invalid int ptr output", + output: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func() (*int, api.Err) { return nil, api.ErrSuccess }, - Err: errWrongOutputArgumentType, + fn: func() (*int, api.Err) { return nil, api.ErrSuccess }, + err: ErrWrongOutputArgumentType, }, - // unexported param name { - Output: map[string]reflect.Type{ + name: "invalid struct output", + output: map[string]reflect.Type{ + "Test1": reflect.TypeOf(int(0)), + }, + fn: func() (struct{ Test1 int }, api.Err) { return struct{ Test1 int }{Test1: 1}, api.ErrSuccess }, + err: ErrWrongOutputArgumentType, + }, + { + name: "unexported param", + output: map[string]reflect.Type{ "test1": reflect.TypeOf(int(0)), }, - Fn: func() (*struct{}, api.Err) { return nil, api.ErrSuccess }, - Err: errUnexportedName, + fn: func() (*struct{}, api.Err) { return nil, api.ErrSuccess }, + err: ErrUnexportedName, }, - // output field missing { - Output: map[string]reflect.Type{ + name: "missing output param", + output: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func() (*struct{}, api.Err) { return nil, api.ErrSuccess }, - Err: errMissingConfigArgument, + fn: func() (*struct{}, api.Err) { return nil, api.ErrSuccess }, + err: ErrMissingConfigArgument, }, - // output field invalid type { - Output: map[string]reflect.Type{ + name: "invalid output param", + output: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func() (*struct{ Test1 string }, api.Err) { return nil, api.ErrSuccess }, - Err: errWrongParamTypeFromConfig, + fn: func() (*struct{ Test1 string }, api.Err) { return nil, api.ErrSuccess }, + err: ErrWrongParamTypeFromConfig, }, - // output field valid type { - Output: map[string]reflect.Type{ + name: "valid param", + output: map[string]reflect.Type{ "Test1": reflect.TypeOf(int(0)), }, - Fn: func() (*struct{ Test1 int }, api.Err) { return nil, api.ErrSuccess }, - Err: nil, + fn: func() (*struct{ Test1 int }, api.Err) { return nil, api.ErrSuccess }, + err: nil, }, - // ignore type check on nil type { - Output: map[string]reflect.Type{ + name: "2 valid params", + output: map[string]reflect.Type{ + "Test1": reflect.TypeOf(int(0)), + "Test2": reflect.TypeOf(string("")), + }, + fn: func() (*struct { + Test1 int + Test2 string + }, api.Err) { + return nil, api.ErrSuccess + }, + err: nil, + }, + { + name: "nil type ignore typecheck", + output: map[string]reflect.Type{ "Test1": nil, }, - Fn: func() (*struct{ Test1 int }, api.Err) { return nil, api.ErrSuccess }, - Err: nil, + fn: func() (*struct{ Test1 int }, api.Err) { return nil, api.ErrSuccess }, + err: nil, }, } - for i, tcase := range tcases { - t.Run(fmt.Sprintf("case.%d", i), func(t *testing.T) { - t.Parallel() - + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { // mock spec s := Signature{ Input: nil, - Output: tcase.Output, + Output: tc.output, } - err := s.ValidateOutput(reflect.TypeOf(tcase.Fn)) - if err == nil && tcase.Err != nil { - t.Errorf("expected an error: '%s'", tcase.Err.Error()) - t.FailNow() - } - if err != nil && tcase.Err == nil { - t.Errorf("unexpected error: '%s'", err.Error()) - t.FailNow() - } - - if err != nil && tcase.Err != nil { - if !errors.Is(err, tcase.Err) { - t.Errorf("expected the error <%s> got <%s>", tcase.Err, err) - t.FailNow() - } + err := s.ValidateOutput(reflect.TypeOf(tc.fn)) + if !errors.Is(err, tc.err) { + t.Fatalf("expected the error <%s> got <%s>", tc.err, err) + } + }) + } +} + +func TestServiceValidation(t *testing.T) { + + tt := []struct { + name string + in []*config.Parameter + out []*config.Parameter + fn interface{} + err error + }{ + { + name: "missing context", + fn: func() {}, + err: ErrMissingHandlerContextArgument, + }, + { + name: "invalid context", + fn: func(int) {}, + err: ErrInvalidHandlerContextArgument, + }, + { + name: "missing error", + fn: func(context.Context) {}, + err: ErrMissingHandlerErrorArgument, + }, + { + name: "invalid error", + fn: func(context.Context) int { return 1 }, + err: ErrInvalidHandlerErrorArgument, + }, + { + name: "no in no out", + fn: func(context.Context) api.Err { return api.ErrSuccess }, + err: nil, + }, + { + name: "unamed in", + in: []*config.Parameter{ + { + Rename: "", // should be ignored + GoType: reflect.TypeOf(int(0)), + }, + }, + fn: func(context.Context) api.Err { return api.ErrSuccess }, + err: nil, + }, + { + name: "missing in", + in: []*config.Parameter{ + { + Rename: "Test1", + GoType: reflect.TypeOf(int(0)), + }, + }, + fn: func(context.Context) api.Err { return api.ErrSuccess }, + err: ErrMissingHandlerInputArgument, + }, + { + name: "valid in", + in: []*config.Parameter{ + { + Rename: "Test1", + GoType: reflect.TypeOf(int(0)), + }, + }, + fn: func(context.Context, struct{ Test1 int }) api.Err { return api.ErrSuccess }, + err: nil, + }, + { + name: "optional in not ptr", + in: []*config.Parameter{ + { + Rename: "Test1", + GoType: reflect.TypeOf(int(0)), + Optional: true, + }, + }, + fn: func(context.Context, struct{ Test1 int }) api.Err { return api.ErrSuccess }, + err: ErrWrongParamTypeFromConfig, + }, + { + name: "valid optional in", + in: []*config.Parameter{ + { + Rename: "Test1", + GoType: reflect.TypeOf(int(0)), + Optional: true, + }, + }, + fn: func(context.Context, struct{ Test1 *int }) api.Err { return api.ErrSuccess }, + err: nil, + }, + + { + name: "unamed out", + out: []*config.Parameter{ + { + Rename: "", // should be ignored + GoType: reflect.TypeOf(int(0)), + }, + }, + fn: func(context.Context) api.Err { return api.ErrSuccess }, + err: nil, + }, + { + name: "missing out struct", + out: []*config.Parameter{ + { + Rename: "Test1", + GoType: reflect.TypeOf(int(0)), + }, + }, + fn: func(context.Context) api.Err { return api.ErrSuccess }, + err: ErrMissingHandlerOutputArgument, + }, + { + name: "invalid out struct type", + out: []*config.Parameter{ + { + Rename: "Test1", + GoType: reflect.TypeOf(int(0)), + }, + }, + fn: func(context.Context) (int, api.Err) { return 0, api.ErrSuccess }, + err: ErrWrongOutputArgumentType, + }, + { + name: "missing out", + out: []*config.Parameter{ + { + Rename: "Test1", + GoType: reflect.TypeOf(int(0)), + }, + }, + fn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, + err: ErrMissingConfigArgument, + }, + { + name: "valid out", + out: []*config.Parameter{ + { + Rename: "Test1", + GoType: reflect.TypeOf(int(0)), + }, + }, + fn: func(context.Context) (*struct{ Test1 int }, api.Err) { return nil, api.ErrSuccess }, + err: nil, + }, + { + name: "optional out not ptr", + out: []*config.Parameter{ + { + Rename: "Test1", + GoType: reflect.TypeOf(int(0)), + Optional: true, + }, + }, + fn: func(context.Context) (*struct{ Test1 *int }, api.Err) { return nil, api.ErrSuccess }, + err: ErrWrongParamTypeFromConfig, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + service := config.Service{ + Input: make(map[string]*config.Parameter), + Output: make(map[string]*config.Parameter), + } + + // fill service with arguments + if tc.in != nil && len(tc.in) > 0 { + for i, in := range tc.in { + service.Input[fmt.Sprintf("%d", i)] = in + } + } + if tc.out != nil && len(tc.out) > 0 { + for i, out := range tc.out { + service.Output[fmt.Sprintf("%d", i)] = out + } + } + + s := BuildSignature(service) + + err := s.ValidateInput(reflect.TypeOf(tc.fn)) + if err != nil { + if !errors.Is(err, tc.err) { + t.Fatalf("expected the error <%s> got <%s>", tc.err, err) + } + return + } + err = s.ValidateOutput(reflect.TypeOf(tc.fn)) + if err != nil { + if !errors.Is(err, tc.err) { + t.Fatalf("expected the error <%s> got <%s>", tc.err, err) + } + return + } + + // no error encountered but expected 1 + if tc.err != nil { + t.Fatalf("expected an error <%v>", tc.err) } }) } diff --git a/internal/reqdata/set.go b/internal/reqdata/set.go index 2b217dd..2024752 100644 --- a/internal/reqdata/set.go +++ b/internal/reqdata/set.go @@ -66,17 +66,24 @@ func (i *T) GetQuery(req http.Request) error { query := req.URL.Query() for name, param := range i.service.Query { - value, exist := query[name] - - if !exist && !param.Optional { - return fmt.Errorf("%s: %w", name, ErrMissingRequiredParam) - } + values, exist := query[name] if !exist { + if !param.Optional { + return fmt.Errorf("%s: %w", name, ErrMissingRequiredParam) + } continue } - parsed := parseParameter(value) + var parsed interface{} + + // consider element instead of slice or elements when only 1 + if len(values) == 1 { + parsed = parseParameter(values[0]) + } else { // consider slice + parsed = parseParameter(values) + } + cast, valid := param.Validator(parsed) if !valid { return fmt.Errorf("%s: %w", name, ErrInvalidType) @@ -99,17 +106,32 @@ func (i *T) GetForm(req http.Request) error { ct := req.Header.Get("Content-Type") switch { case strings.HasPrefix(ct, "application/json"): - return i.parseJSON(req) + err := i.parseJSON(req) + if err != nil { + return err + } case strings.HasPrefix(ct, "application/x-www-form-urlencoded"): - return i.parseUrlencoded(req) + err := i.parseUrlencoded(req) + if err != nil { + return err + } case strings.HasPrefix(ct, "multipart/form-data; boundary="): - return i.parseMultipart(req) - - default: - return nil + err := i.parseMultipart(req) + if err != nil { + return err + } } + + // fail on at least 1 mandatory form param when there is no body + for name, param := range i.service.Form { + _, exists := i.Data[param.Rename] + if !exists && !param.Optional { + return fmt.Errorf("%s: %w", name, ErrMissingRequiredParam) + } + } + return nil } // parseJSON parses JSON from the request body inside 'Form' @@ -129,10 +151,6 @@ func (i *T) parseJSON(req http.Request) error { for name, param := range i.service.Form { value, exist := parsed[name] - if !exist && !param.Optional { - return fmt.Errorf("%s: %w", name, ErrMissingRequiredParam) - } - if !exist { continue } @@ -155,17 +173,21 @@ func (i *T) parseUrlencoded(req http.Request) error { } for name, param := range i.service.Form { - value, exist := req.PostForm[name] - - if !exist && !param.Optional { - return fmt.Errorf("%s: %w", name, ErrMissingRequiredParam) - } + values, exist := req.PostForm[name] if !exist { continue } - parsed := parseParameter(value) + var parsed interface{} + + // consider element instead of slice or elements when only 1 + if len(values) == 1 { + parsed = parseParameter(values[0]) + } else { // consider slice + parsed = parseParameter(values) + } + cast, valid := param.Validator(parsed) if !valid { return fmt.Errorf("%s: %w", name, ErrInvalidType) @@ -185,7 +207,7 @@ func (i *T) parseMultipart(req http.Request) error { return nil } if err != nil { - return err + return fmt.Errorf("%s: %w", err, ErrInvalidMultipart) } err = mpr.Parse() @@ -196,10 +218,6 @@ func (i *T) parseMultipart(req http.Request) error { for name, param := range i.service.Form { component, exist := mpr.Data[name] - if !exist && !param.Optional { - return fmt.Errorf("%s: %w", name, ErrMissingRequiredParam) - } - if !exist { continue } diff --git a/internal/reqdata/set_test.go b/internal/reqdata/set_test.go index b71a7be..2651a2d 100644 --- a/internal/reqdata/set_test.go +++ b/internal/reqdata/set_test.go @@ -135,13 +135,11 @@ func TestStoreWithUri(t *testing.T) { if err != nil { if test.Err != nil { if !errors.Is(err, test.Err) { - t.Errorf("expected error <%s>, got <%s>", test.Err, err) - t.FailNow() + t.Fatalf("expected error <%s>, got <%s>", test.Err, err) } return } - t.Errorf("unexpected error <%s>", err) - t.FailNow() + t.Fatalf("unexpected error <%s>", err) } if len(store.Data) != len(service.Input) { @@ -183,14 +181,14 @@ func TestExtractQuery(t *testing.T) { Query: "a", Err: nil, ParamNames: []string{"a"}, - ParamValues: [][]string{[]string{""}}, + ParamValues: [][]string{{""}}, }, { ServiceParam: []string{"a"}, Query: "a&b", Err: nil, ParamNames: []string{"a"}, - ParamValues: [][]string{[]string{""}}, + ParamValues: [][]string{{""}}, }, { ServiceParam: []string{"a", "missing"}, @@ -204,40 +202,40 @@ func TestExtractQuery(t *testing.T) { Query: "a&b", Err: nil, ParamNames: []string{"a", "b"}, - ParamValues: [][]string{[]string{""}, []string{""}}, + ParamValues: [][]string{{""}, {""}}, }, { ServiceParam: []string{"a"}, Err: nil, Query: "a=", ParamNames: []string{"a"}, - ParamValues: [][]string{[]string{""}}, + ParamValues: [][]string{{""}}, }, { ServiceParam: []string{"a", "b"}, Err: nil, Query: "a=&b=x", ParamNames: []string{"a", "b"}, - ParamValues: [][]string{[]string{""}, []string{"x"}}, + ParamValues: [][]string{{""}, {"x"}}, }, { ServiceParam: []string{"a", "c"}, Err: nil, Query: "a=b&c=d", ParamNames: []string{"a", "c"}, - ParamValues: [][]string{[]string{"b"}, []string{"d"}}, + ParamValues: [][]string{{"b"}, {"d"}}, }, { ServiceParam: []string{"a", "c"}, Err: nil, Query: "a=b&c=d&a=x", ParamNames: []string{"a", "c"}, - ParamValues: [][]string{[]string{"b", "x"}, []string{"d"}}, + ParamValues: [][]string{{"b", "x"}, {"d"}}, }, } for i, test := range tests { - t.Run(fmt.Sprintf("request.%d", i), func(t *testing.T) { + t.Run(fmt.Sprintf("request[%d]", i), func(t *testing.T) { store := New(getServiceWithQuery(test.ServiceParam...)) @@ -246,19 +244,16 @@ func TestExtractQuery(t *testing.T) { if err != nil { if test.Err != nil { if !errors.Is(err, test.Err) { - t.Errorf("expected error <%s>, got <%s>", test.Err, err) - t.FailNow() + t.Fatalf("expected error <%s>, got <%s>", test.Err, err) } return } - t.Errorf("unexpected error <%s>", err) - t.FailNow() + t.Fatalf("unexpected error <%s>", err) } if test.ParamNames == nil || test.ParamValues == nil { if len(store.Data) != 0 { - t.Errorf("expected no GET parameters and got %d", len(store.Data)) - t.FailNow() + t.Fatalf("expected no GET parameters and got %d", len(store.Data)) } // no param to check @@ -266,8 +261,7 @@ func TestExtractQuery(t *testing.T) { } if len(test.ParamNames) != len(test.ParamValues) { - t.Errorf("invalid test: names and values differ in size (%d vs %d)", len(test.ParamNames), len(test.ParamValues)) - t.FailNow() + t.Fatalf("invalid test: names and values differ in size (%d vs %d)", len(test.ParamNames), len(test.ParamValues)) } for pi, pName := range test.ParamNames { @@ -276,29 +270,35 @@ func TestExtractQuery(t *testing.T) { t.Run(pName, func(t *testing.T) { param, isset := store.Data[pName] if !isset { - t.Errorf("param does not exist") - t.FailNow() + t.Fatalf("param does not exist") } + // single value, should return a single element + if len(values) == 1 { + cast, canCast := param.(string) + if !canCast { + t.Fatalf("should return a string (got '%v')", cast) + } + if values[0] != cast { + t.Fatalf("should return '%s' (got '%s')", values[0], cast) + } + return + } + + // multiple values, should return a slice cast, canCast := param.([]interface{}) if !canCast { - t.Errorf("should return a []string (got '%v')", cast) - t.FailNow() + t.Fatalf("should return a []string (got '%v')", cast) } if len(cast) != len(values) { - t.Errorf("should return %d string(s) (got '%d')", len(values), len(cast)) - t.FailNow() + t.Fatalf("should return %d string(s) (got '%d')", len(values), len(cast)) } for vi, value := range values { - - t.Run(fmt.Sprintf("value.%d", vi), func(t *testing.T) { - if value != cast[vi] { - t.Errorf("should return '%s' (got '%s')", value, cast[vi]) - t.FailNow() - } - }) + if value != cast[vi] { + t.Fatalf("should return '%s' (got '%s')", value, cast[vi]) + } } }) @@ -326,9 +326,7 @@ func TestStoreWithUrlEncodedFormParseError(t *testing.T) { store := New(nil) err := store.GetForm(*req) if err == nil { - t.Errorf("expected malformed urlencoded to have FailNow being parsed (got %d elements)", len(store.Data)) - t.FailNow() - + t.Fatalf("expected malformed urlencoded to have FailNow being parsed (got %d elements)", len(store.Data)) } } func TestExtractFormUrlEncoded(t *testing.T) { @@ -359,14 +357,14 @@ func TestExtractFormUrlEncoded(t *testing.T) { URLEncoded: "a", Err: nil, ParamNames: []string{"a"}, - ParamValues: [][]string{[]string{""}}, + ParamValues: [][]string{{""}}, }, { ServiceParams: []string{"a"}, URLEncoded: "a&b", Err: nil, ParamNames: []string{"a"}, - ParamValues: [][]string{[]string{""}}, + ParamValues: [][]string{{""}}, }, { ServiceParams: []string{"a", "missing"}, @@ -380,35 +378,35 @@ func TestExtractFormUrlEncoded(t *testing.T) { URLEncoded: "a&b", Err: nil, ParamNames: []string{"a", "b"}, - ParamValues: [][]string{[]string{""}, []string{""}}, + ParamValues: [][]string{{""}, {""}}, }, { ServiceParams: []string{"a"}, Err: nil, URLEncoded: "a=", ParamNames: []string{"a"}, - ParamValues: [][]string{[]string{""}}, + ParamValues: [][]string{{""}}, }, { ServiceParams: []string{"a", "b"}, Err: nil, URLEncoded: "a=&b=x", ParamNames: []string{"a", "b"}, - ParamValues: [][]string{[]string{""}, []string{"x"}}, + ParamValues: [][]string{{""}, {"x"}}, }, { ServiceParams: []string{"a", "c"}, Err: nil, URLEncoded: "a=b&c=d", ParamNames: []string{"a", "c"}, - ParamValues: [][]string{[]string{"b"}, []string{"d"}}, + ParamValues: [][]string{{"b"}, {"d"}}, }, { ServiceParams: []string{"a", "c"}, Err: nil, URLEncoded: "a=b&c=d&a=x", ParamNames: []string{"a", "c"}, - ParamValues: [][]string{[]string{"b", "x"}, []string{"d"}}, + ParamValues: [][]string{{"b", "x"}, {"d"}}, }, } @@ -424,19 +422,16 @@ func TestExtractFormUrlEncoded(t *testing.T) { if err != nil { if test.Err != nil { if !errors.Is(err, test.Err) { - t.Errorf("expected error <%s>, got <%s>", test.Err, err) - t.FailNow() + t.Fatalf("expected error <%s>, got <%s>", test.Err, err) } return } - t.Errorf("unexpected error <%s>", err) - t.FailNow() + t.Fatalf("unexpected error <%s>", err) } if test.ParamNames == nil || test.ParamValues == nil { if len(store.Data) != 0 { - t.Errorf("expected no GET parameters and got %d", len(store.Data)) - t.FailNow() + t.Fatalf("expected no GET parameters and got %d", len(store.Data)) } // no param to check @@ -444,8 +439,7 @@ func TestExtractFormUrlEncoded(t *testing.T) { } if len(test.ParamNames) != len(test.ParamValues) { - t.Errorf("invalid test: names and values differ in size (%d vs %d)", len(test.ParamNames), len(test.ParamValues)) - t.FailNow() + t.Fatalf("invalid test: names and values differ in size (%d vs %d)", len(test.ParamNames), len(test.ParamValues)) } for pi, key := range test.ParamNames { @@ -454,29 +448,35 @@ func TestExtractFormUrlEncoded(t *testing.T) { t.Run(key, func(t *testing.T) { param, isset := store.Data[key] if !isset { - t.Errorf("param does not exist") - t.FailNow() + t.Fatalf("param does not exist") } + // single value, should return a single element + if len(values) == 1 { + cast, canCast := param.(string) + if !canCast { + t.Fatalf("should return a string (got '%v')", cast) + } + if values[0] != cast { + t.Fatalf("should return '%s' (got '%s')", values[0], cast) + } + return + } + + // multiple values, should return a slice cast, canCast := param.([]interface{}) if !canCast { - t.Errorf("should return a []interface{} (got '%v')", cast) - t.FailNow() + t.Fatalf("should return a []string (got '%v')", cast) } if len(cast) != len(values) { - t.Errorf("should return %d string(s) (got '%d')", len(values), len(cast)) - t.FailNow() + t.Fatalf("should return %d string(s) (got '%d')", len(values), len(cast)) } for vi, value := range values { - - t.Run(fmt.Sprintf("value.%d", vi), func(t *testing.T) { - if value != cast[vi] { - t.Errorf("should return '%s' (got '%s')", value, cast[vi]) - t.FailNow() - } - }) + if value != cast[vi] { + t.Fatalf("should return '%s' (got '%s')", value, cast[vi]) + } } }) @@ -567,19 +567,16 @@ func TestJsonParameters(t *testing.T) { if err != nil { if test.Err != nil { if !errors.Is(err, test.Err) { - t.Errorf("expected error <%s>, got <%s>", test.Err, err) - t.FailNow() + t.Fatalf("expected error <%s>, got <%s>", test.Err, err) } return } - t.Errorf("unexpected error <%s>", err) - t.FailNow() + t.Fatalf("unexpected error <%s>", err) } if test.ParamNames == nil || test.ParamValues == nil { if len(store.Data) != 0 { - t.Errorf("expected no JSON parameters and got %d", len(store.Data)) - t.FailNow() + t.Fatalf("expected no JSON parameters and got %d", len(store.Data)) } // no param to check @@ -587,8 +584,7 @@ func TestJsonParameters(t *testing.T) { } if len(test.ParamNames) != len(test.ParamValues) { - t.Errorf("invalid test: names and values differ in size (%d vs %d)", len(test.ParamNames), len(test.ParamValues)) - t.FailNow() + t.Fatalf("invalid test: names and values differ in size (%d vs %d)", len(test.ParamNames), len(test.ParamValues)) } for pi, pName := range test.ParamNames { @@ -599,8 +595,7 @@ func TestJsonParameters(t *testing.T) { param, isset := store.Data[key] if !isset { - t.Errorf("store should contain element with key '%s'", key) - t.FailNow() + t.Fatalf("store should contain element with key '%s'", key) return } @@ -610,13 +605,11 @@ func TestJsonParameters(t *testing.T) { paramValueType := reflect.TypeOf(param) if valueType != paramValueType { - t.Errorf("should be of type %v (got '%v')", valueType, paramValueType) - t.FailNow() + t.Fatalf("should be of type %v (got '%v')", valueType, paramValueType) } if paramValue != value { - t.Errorf("should return %v (got '%v')", value, paramValue) - t.FailNow() + t.Fatalf("should return %v (got '%v')", value, paramValue) } }) @@ -724,19 +717,16 @@ x if err != nil { if test.Err != nil { if !errors.Is(err, test.Err) { - t.Errorf("expected error <%s>, got <%s>", test.Err, err) - t.FailNow() + t.Fatalf("expected error <%s>, got <%s>", test.Err, err) } return } - t.Errorf("unexpected error <%s>", err) - t.FailNow() + t.Fatalf("unexpected error <%s>", err) } if test.ParamNames == nil || test.ParamValues == nil { if len(store.Data) != 0 { - t.Errorf("expected no JSON parameters and got %d", len(store.Data)) - t.FailNow() + t.Fatalf("expected no JSON parameters and got %d", len(store.Data)) } // no param to check @@ -744,8 +734,7 @@ x } if len(test.ParamNames) != len(test.ParamValues) { - t.Errorf("invalid test: names and values differ in size (%d vs %d)", len(test.ParamNames), len(test.ParamValues)) - t.FailNow() + t.Fatalf("invalid test: names and values differ in size (%d vs %d)", len(test.ParamNames), len(test.ParamValues)) } for pi, key := range test.ParamNames { @@ -755,8 +744,7 @@ x param, isset := store.Data[key] if !isset { - t.Errorf("store should contain element with key '%s'", key) - t.FailNow() + t.Fatalf("store should contain element with key '%s'", key) return } @@ -766,13 +754,11 @@ x paramValueType := reflect.TypeOf(param) if valueType != paramValueType { - t.Errorf("should be of type %v (got '%v')", valueType, paramValueType) - t.FailNow() + t.Fatalf("should be of type %v (got '%v')", valueType, paramValueType) } if paramValue != value { - t.Errorf("should return %v (got '%v')", value, paramValue) - t.FailNow() + t.Fatalf("should return %v (got '%v')", value, paramValue) } }) diff --git a/response.go b/response.go new file mode 100644 index 0000000..b4ec028 --- /dev/null +++ b/response.go @@ -0,0 +1,60 @@ +package aicra + +import ( + "encoding/json" + "net/http" + + "github.com/xdrm-io/aicra/api" +) + +// response for an service call +type response struct { + Data map[string]interface{} + Status int + err api.Err +} + +// newResponse creates an empty response. +func newResponse() *response { + return &response{ + Status: http.StatusOK, + Data: make(map[string]interface{}), + err: api.ErrFailure, + } +} + +// WithError sets the response error +func (r *response) WithError(err api.Err) *response { + r.err = err + return r +} + +// WithValue sets a response value +func (r *response) WithValue(name string, value interface{}) *response { + r.Data[name] = value + return r +} + +// MarshalJSON generates the JSON representation of the response +// +// implements json.Marshaler +func (r *response) MarshalJSON() ([]byte, error) { + fmt := make(map[string]interface{}) + for k, v := range r.Data { + fmt[k] = v + } + fmt["error"] = r.err + return json.Marshal(fmt) +} + +// ServeHTTP writes the response representation back to the http.ResponseWriter +// +// implements http.Handler +func (res *response) ServeHTTP(w http.ResponseWriter, r *http.Request) error { + w.WriteHeader(res.err.Status) + encoded, err := json.Marshal(res) + if err == nil { + w.Write(encoded) + } + return err +} diff --git a/response_test.go b/response_test.go new file mode 100644 index 0000000..7807450 --- /dev/null +++ b/response_test.go @@ -0,0 +1,95 @@ +package aicra + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/xdrm-io/aicra/api" +) + +func printEscaped(raw string) string { + raw = strings.ReplaceAll(raw, "\n", "\\n") + raw = strings.ReplaceAll(raw, "\r", "\\r") + return raw +} + +func TestResponseJSON(t *testing.T) { + t.Parallel() + + tt := []struct { + name string + err api.Err + data map[string]interface{} + json string + }{ + { + name: "empty success response", + err: api.ErrSuccess, + data: map[string]interface{}{}, + json: `{"error":{"code":0,"reason":"all right"}}`, + }, + { + name: "empty failure response", + err: api.ErrFailure, + data: map[string]interface{}{}, + json: `{"error":{"code":1,"reason":"it failed"}}`, + }, + { + name: "empty unknown error response", + err: api.ErrUnknown, + data: map[string]interface{}{}, + json: `{"error":{"code":-1,"reason":"unknown error"}}`, + }, + { + name: "success with data before err", + err: api.ErrSuccess, + data: map[string]interface{}{"a": 12}, + json: `{"a":12,"error":{"code":0,"reason":"all right"}}`, + }, + { + name: "success with data right before err", + err: api.ErrSuccess, + data: map[string]interface{}{"e": 12}, + json: `{"e":12,"error":{"code":0,"reason":"all right"}}`, + }, + { + name: "success with data right after err", + err: api.ErrSuccess, + data: map[string]interface{}{"f": 12}, + json: `{"error":{"code":0,"reason":"all right"},"f":12}`, + }, + { + name: "success with data after err", + err: api.ErrSuccess, + data: map[string]interface{}{"z": 12}, + json: `{"error":{"code":0,"reason":"all right"},"z":12}`, + }, + { + name: "success with data around err", + err: api.ErrSuccess, + data: map[string]interface{}{"d": "before", "f": "after"}, + json: `{"d":"before","error":{"code":0,"reason":"all right"},"f":"after"}`, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + res := newResponse().WithError(tc.err) + for k, v := range tc.data { + res.WithValue(k, v) + } + + raw, err := json.Marshal(res) + if err != nil { + t.Fatalf("cannot marshal to json: %s", err) + } + + if string(raw) != tc.json { + t.Fatalf("mismatching json:\nexpect: %v\nactual: %v", printEscaped(tc.json), printEscaped(string(raw))) + } + + }) + } + +} diff --git a/validator/any_test.go b/validator/any_test.go index d612047..4ec04f1 100644 --- a/validator/any_test.go +++ b/validator/any_test.go @@ -2,11 +2,24 @@ package validator_test import ( "fmt" + "reflect" "testing" "github.com/xdrm-io/aicra/validator" ) +func TestAny_ReflectType(t *testing.T) { + t.Parallel() + + var ( + dt = validator.AnyType{} + expected = reflect.TypeOf(interface{}(nil)) + ) + if dt.GoType() != expected { + t.Fatalf("invalid GoType() %v ; expected %v", dt.GoType(), expected) + } +} + func TestAny_AvailableTypes(t *testing.T) { t.Parallel() diff --git a/validator/bool_test.go b/validator/bool_test.go index 8600d34..f34819d 100644 --- a/validator/bool_test.go +++ b/validator/bool_test.go @@ -2,11 +2,24 @@ package validator_test import ( "fmt" + "reflect" "testing" "github.com/xdrm-io/aicra/validator" ) +func TestBool_ReflectType(t *testing.T) { + t.Parallel() + + var ( + dt = validator.BoolType{} + expected = reflect.TypeOf(true) + ) + if dt.GoType() != expected { + t.Fatalf("invalid GoType() %v ; expected %v", dt.GoType(), expected) + } +} + func TestBool_AvailableTypes(t *testing.T) { t.Parallel() diff --git a/validator/float_test.go b/validator/float_test.go index b965441..1a77949 100644 --- a/validator/float_test.go +++ b/validator/float_test.go @@ -3,11 +3,24 @@ package validator_test import ( "fmt" "math" + "reflect" "testing" "github.com/xdrm-io/aicra/validator" ) +func TestFloat64_ReflectType(t *testing.T) { + t.Parallel() + + var ( + dt = validator.FloatType{} + expected = reflect.TypeOf(float64(0.0)) + ) + if dt.GoType() != expected { + t.Fatalf("invalid GoType() %v ; expected %v", dt.GoType(), expected) + } +} + func TestFloat64_AvailableTypes(t *testing.T) { t.Parallel() diff --git a/validator/int_test.go b/validator/int_test.go index 14dc515..dc688fd 100644 --- a/validator/int_test.go +++ b/validator/int_test.go @@ -3,11 +3,24 @@ package validator_test import ( "fmt" "math" + "reflect" "testing" "github.com/xdrm-io/aicra/validator" ) +func TestInt_ReflectType(t *testing.T) { + t.Parallel() + + var ( + dt = validator.IntType{} + expected = reflect.TypeOf(int(0)) + ) + if dt.GoType() != expected { + t.Fatalf("invalid GoType() %v ; expected %v", dt.GoType(), expected) + } +} + func TestInt_AvailableTypes(t *testing.T) { t.Parallel() @@ -71,7 +84,7 @@ func TestInt_Values(t *testing.T) { {uint(math.MaxInt64 + 1), false}, {float64(math.MinInt64), true}, - // we cannot just substract 1 because of how precision works + // we cannot just subtract 1 because of how precision works {float64(math.MinInt64 - 1024 - 1), false}, // WARNING : this is due to how floats are compared diff --git a/validator/string_test.go b/validator/string_test.go index 0a038ef..deb8ec1 100644 --- a/validator/string_test.go +++ b/validator/string_test.go @@ -2,11 +2,24 @@ package validator_test import ( "fmt" + "reflect" "testing" "github.com/xdrm-io/aicra/validator" ) +func TestString_ReflectType(t *testing.T) { + t.Parallel() + + var ( + dt = validator.StringType{} + expected = reflect.TypeOf(string("abc")) + ) + if dt.GoType() != expected { + t.Fatalf("invalid GoType() %v ; expected %v", dt.GoType(), expected) + } +} + func TestString_AvailableTypes(t *testing.T) { t.Parallel() diff --git a/validator/uint_test.go b/validator/uint_test.go index 4fc9be9..d754124 100644 --- a/validator/uint_test.go +++ b/validator/uint_test.go @@ -3,11 +3,24 @@ package validator_test import ( "fmt" "math" + "reflect" "testing" "github.com/xdrm-io/aicra/validator" ) +func TestUint_ReflectType(t *testing.T) { + t.Parallel() + + var ( + dt = validator.UintType{} + expected = reflect.TypeOf(uint(0)) + ) + if dt.GoType() != expected { + t.Fatalf("invalid GoType() %v ; expected %v", dt.GoType(), expected) + } +} + func TestUint_AvailableTypes(t *testing.T) { t.Parallel()