diff --git a/api/adapter.go b/api/adapter.go index 08a9959..646f881 100644 --- a/api/adapter.go +++ b/api/adapter.go @@ -4,3 +4,10 @@ import "net/http" // Adapter to encapsulate incoming requests type Adapter func(http.HandlerFunc) http.HandlerFunc + +// AuthHandlerFunc is http.HandlerFunc with additional Authorization information +type AuthHandlerFunc func(Auth, http.ResponseWriter, *http.Request) + +// AuthAdapter to encapsulate incoming request with access to api.Auth +// to manage permissions +type AuthAdapter func(AuthHandlerFunc) AuthHandlerFunc diff --git a/builder.go b/builder.go index 224ce7c..15d94e4 100644 --- a/builder.go +++ b/builder.go @@ -13,9 +13,10 @@ import ( // Builder for an aicra server type Builder struct { - conf *config.Server - handlers []*apiHandler - adapters []api.Adapter + conf *config.Server + handlers []*apiHandler + adapters []api.Adapter + authAdapters []api.AuthAdapter } // represents an api handler (method-pattern combination) @@ -40,8 +41,8 @@ func (b *Builder) AddType(t datatype.T) error { return nil } -// Use adds an http adapter (middleware) -func (b *Builder) Use(adapter api.Adapter) { +// With adds an http adapter (middleware) +func (b *Builder) With(adapter api.Adapter) { if b.conf == nil { b.conf = &config.Server{} } @@ -51,6 +52,17 @@ func (b *Builder) Use(adapter api.Adapter) { b.adapters = append(b.adapters, adapter) } +// WithAuth adds an http adapter with auth capabilities (middleware) +func (b *Builder) WithAuth(adapter api.AuthAdapter) { + if b.conf == nil { + b.conf = &config.Server{} + } + if b.authAdapters == nil { + b.authAdapters = make([]api.AuthAdapter, 0) + } + b.authAdapters = append(b.authAdapters, adapter) +} + // Setup the builder with its api definition file // panics if already setup func (b *Builder) Setup(r io.Reader) error { diff --git a/builder_test.go b/builder_test.go index 8e3358d..ab98dca 100644 --- a/builder_test.go +++ b/builder_test.go @@ -1,12 +1,8 @@ package aicra import ( - "bytes" - "context" "errors" - "fmt" "net/http" - "net/http/httptest" "strings" "testing" @@ -52,92 +48,6 @@ func TestAddType(t *testing.T) { } } -func TestUse(t *testing.T) { - builder := &Builder{} - if err := addBuiltinTypes(builder); err != nil { - t.Fatalf("unexpected error <%v>", err) - } - - // build @n middlewares that take data from context and increment it - n := 1024 - - type ckey int - const key ckey = 0 - - middleware := func(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - newr := r - - // first time -> store 1 - value := r.Context().Value(key) - if value == nil { - newr = r.WithContext(context.WithValue(r.Context(), key, int(1))) - next(w, newr) - return - } - - // get value and increment - cast, ok := value.(int) - if !ok { - t.Fatalf("value is not an int") - } - cast++ - newr = r.WithContext(context.WithValue(r.Context(), key, cast)) - next(w, newr) - } - } - - // add middleware @n times - for i := 0; i < n; i++ { - builder.Use(middleware) - } - - config := strings.NewReader(`[ { "method": "GET", "path": "/path", "scope": [[]], "info": "info", "in": {}, "out": {} } ]`) - err := builder.Setup(config) - if err != nil { - t.Fatalf("setup: unexpected error <%v>", err) - } - - pathHandler := func(ctx api.Ctx) (*struct{}, api.Err) { - // write value from middlewares into response - value := ctx.Req.Context().Value(key) - if value == nil { - t.Fatalf("nothing found in context") - } - cast, ok := value.(int) - if !ok { - t.Fatalf("cannot cast context data to int") - } - // write to response - ctx.Res.Write([]byte(fmt.Sprintf("#%d#", cast))) - - return nil, api.ErrSuccess - } - - if err := builder.Bind(http.MethodGet, "/path", pathHandler); err != nil { - t.Fatalf("bind: unexpected error <%v>", err) - } - - handler, err := builder.Build() - if err != nil { - t.Fatalf("build: unexpected error <%v>", err) - } - - response := httptest.NewRecorder() - request := httptest.NewRequest(http.MethodGet, "/path", &bytes.Buffer{}) - - // test request - handler.ServeHTTP(response, request) - if response.Body == nil { - t.Fatalf("response has no body") - } - token := fmt.Sprintf("#%d#", n) - if !strings.Contains(response.Body.String(), token) { - t.Fatalf("expected '%s' to be in response <%s>", token, response.Body.String()) - } - -} - func TestBind(t *testing.T) { tcases := []struct { Name string diff --git a/handler.go b/handler.go index 662daf3..1fce916 100644 --- a/handler.go +++ b/handler.go @@ -13,7 +13,7 @@ type Handler Builder // ServeHTTP implements http.Handler and wraps it in middlewares (adapters) func (s Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - var h = http.HandlerFunc(s.handleRequest) + var h = http.HandlerFunc(s.resolve) for _, adapter := range s.adapters { h = adapter(h) @@ -21,7 +21,7 @@ func (s Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h(w, r) } -func (s Handler) handleRequest(w http.ResponseWriter, r *http.Request) { +func (s Handler) resolve(w http.ResponseWriter, r *http.Request) { // 1. find a matching service from config var service = s.conf.Find(r) if service == nil { diff --git a/handler_test.go b/handler_test.go new file mode 100644 index 0000000..54b2e7b --- /dev/null +++ b/handler_test.go @@ -0,0 +1,99 @@ +package aicra + +import ( + "bytes" + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "git.xdrm.io/go/aicra/api" +) + +func TestWith(t *testing.T) { + builder := &Builder{} + if err := addBuiltinTypes(builder); err != nil { + t.Fatalf("unexpected error <%v>", err) + } + + // build @n middlewares that take data from context and increment it + n := 1024 + + type ckey int + const key ckey = 0 + + middleware := func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + newr := r + + // first time -> store 1 + value := r.Context().Value(key) + if value == nil { + newr = r.WithContext(context.WithValue(r.Context(), key, int(1))) + next(w, newr) + return + } + + // get value and increment + cast, ok := value.(int) + if !ok { + t.Fatalf("value is not an int") + } + cast++ + newr = r.WithContext(context.WithValue(r.Context(), key, cast)) + next(w, newr) + } + } + + // add middleware @n times + for i := 0; i < n; i++ { + builder.With(middleware) + } + + config := strings.NewReader(`[ { "method": "GET", "path": "/path", "scope": [[]], "info": "info", "in": {}, "out": {} } ]`) + err := builder.Setup(config) + if err != nil { + t.Fatalf("setup: unexpected error <%v>", err) + } + + pathHandler := func(ctx api.Ctx) (*struct{}, api.Err) { + // write value from middlewares into response + value := ctx.Req.Context().Value(key) + if value == nil { + t.Fatalf("nothing found in context") + } + cast, ok := value.(int) + if !ok { + t.Fatalf("cannot cast context data to int") + } + // write to response + ctx.Res.Write([]byte(fmt.Sprintf("#%d#", cast))) + + return nil, api.ErrSuccess + } + + if err := builder.Bind(http.MethodGet, "/path", pathHandler); err != nil { + t.Fatalf("bind: unexpected error <%v>", err) + } + + handler, err := builder.Build() + if err != nil { + t.Fatalf("build: unexpected error <%v>", err) + } + + response := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "/path", &bytes.Buffer{}) + + // test request + handler.ServeHTTP(response, request) + if response.Body == nil { + t.Fatalf("response has no body") + } + token := fmt.Sprintf("#%d#", n) + if !strings.Contains(response.Body.String(), token) { + t.Fatalf("expected '%s' to be in response <%s>", token, response.Body.String()) + } + +}