refactor: idiomatic remove of api.Context for context.Context, custom middlewares for standard http middlewares
continuous-integration/drone/push Build is passing Details

- remove api.Context as using context.Context is more idiomatic
 - remove api.Adapter as it is redundant with func(http.Handler) http.Handler
 - remove authentication middlewares as they be achieved as normal middlewares but launched around the handler (after the service has been found and validated)
 - builder.With() adds an standard Middleware that runs before any aicra code
 - builder.WithContext() adds an http middleware that runs just before the service handler is called. The http.Request provided contains a context with useful values such as the required permissions (from the service configuration).
 - handlers take a context.Context variable as first argument instead of api.Context
This commit is contained in:
Adrien Marquès 2021-06-20 02:14:31 +02:00
parent 6a78351a2c
commit af63c4514b
Signed by: xdrm-brackets
GPG Key ID: D75243CA236D825E
12 changed files with 230 additions and 196 deletions

View File

@ -1,13 +0,0 @@
package api
import "net/http"
// Adapter to encapsulate incoming requests
type Adapter func(http.HandlerFunc) http.HandlerFunc
// AuthHandlerFunc is http.HandlerFunc with additional Authorization information
type AuthHandlerFunc func(Auth, http.ResponseWriter, *http.Request)
// AuthAdapter to encapsulate incoming request with access to api.Auth
// to manage permissions
type AuthAdapter func(AuthHandlerFunc) AuthHandlerFunc

View File

@ -21,7 +21,7 @@ type Auth struct {
// Granted returns whether the authorization is granted // Granted returns whether the authorization is granted
// i.e. Auth.Active fulfills Auth.Required // i.e. Auth.Active fulfills Auth.Required
func (a Auth) Granted() bool { func (a *Auth) Granted() bool {
var nothingRequired = true var nothingRequired = true
// first dimension: OR ; at least one is valid // first dimension: OR ; at least one is valid
@ -43,7 +43,7 @@ func (a Auth) Granted() bool {
} }
// returns whether Auth.Active fulfills (contains) all @required roles // returns whether Auth.Active fulfills (contains) all @required roles
func (a Auth) fulfills(required []string) bool { func (a *Auth) fulfills(required []string) bool {
for _, requiredRole := range required { for _, requiredRole := range required {
var found = false var found = false
for _, activeRole := range a.Active { for _, activeRole := range a.Active {

View File

@ -7,12 +7,8 @@ import (
"git.xdrm.io/go/aicra/internal/ctx" "git.xdrm.io/go/aicra/internal/ctx"
) )
// Context is a simple wrapper around context.Context that adds helper methods // GetRequest extracts the current request from a context.Context
// to access additional information func GetRequest(c context.Context) *http.Request {
type Context struct{ context.Context }
// Request current request
func (c Context) Request() *http.Request {
var ( var (
raw = c.Value(ctx.Request) raw = c.Value(ctx.Request)
cast, ok = raw.(*http.Request) cast, ok = raw.(*http.Request)
@ -23,8 +19,8 @@ func (c Context) Request() *http.Request {
return cast return cast
} }
// ResponseWriter for this request // GetResponseWriter extracts the response writer from a context.Context
func (c Context) ResponseWriter() http.ResponseWriter { func GetResponseWriter(c context.Context) http.ResponseWriter {
var ( var (
raw = c.Value(ctx.Response) raw = c.Value(ctx.Response)
cast, ok = raw.(http.ResponseWriter) cast, ok = raw.(http.ResponseWriter)
@ -35,8 +31,8 @@ func (c Context) ResponseWriter() http.ResponseWriter {
return cast return cast
} }
// Auth associated with this request // GetAuth returns the api.Auth associated with this request from a context.Context
func (c Context) Auth() *Auth { func GetAuth(c context.Context) *Auth {
var ( var (
raw = c.Value(ctx.Auth) raw = c.Value(ctx.Auth)
cast, ok = raw.(*Auth) cast, ok = raw.(*Auth)

View File

@ -5,7 +5,6 @@ import (
"io" "io"
"net/http" "net/http"
"git.xdrm.io/go/aicra/api"
"git.xdrm.io/go/aicra/datatype" "git.xdrm.io/go/aicra/datatype"
"git.xdrm.io/go/aicra/internal/config" "git.xdrm.io/go/aicra/internal/config"
"git.xdrm.io/go/aicra/internal/dynfunc" "git.xdrm.io/go/aicra/internal/dynfunc"
@ -13,10 +12,16 @@ import (
// Builder for an aicra server // Builder for an aicra server
type Builder struct { type Builder struct {
// the server configuration defining available services
conf *config.Server conf *config.Server
// user-defined handlers bound to services from the configuration
handlers []*apiHandler handlers []*apiHandler
adapters []api.Adapter // http middlewares wrapping the entire http connection (e.g. logger)
authAdapters []api.AuthAdapter middlewares []func(http.Handler) http.Handler
// custom middlewares only wrapping the service handler of a request
// they will benefit from the request's context that contains service-specific
// information (e.g. required permisisons from the configuration)
ctxMiddlewares []func(http.Handler) http.Handler
} }
// represents an api handler (method-pattern combination) // represents an api handler (method-pattern combination)
@ -41,26 +46,36 @@ func (b *Builder) AddType(t datatype.T) error {
return nil return nil
} }
// With adds an http adapter (middleware) // With adds an http middleware on top of the http connection
func (b *Builder) With(adapter api.Adapter) { //
// Authentication management can only be done with the WithContext() methods as
// the service associated with the request has not been found at this stage.
// This stage is perfect for logging or generic request management.
func (b *Builder) With(mw func(http.Handler) http.Handler) {
if b.conf == nil { if b.conf == nil {
b.conf = &config.Server{} b.conf = &config.Server{}
} }
if b.adapters == nil { if b.middlewares == nil {
b.adapters = make([]api.Adapter, 0) b.middlewares = make([]func(http.Handler) http.Handler, 0)
} }
b.adapters = append(b.adapters, adapter) b.middlewares = append(b.middlewares, mw)
} }
// WithAuth adds an http adapter with auth capabilities (middleware) // WithContext adds an http middleware with the fully loaded context
func (b *Builder) WithAuth(adapter api.AuthAdapter) { //
// Logging or generic request management should be done with the With() method as
// it wraps the full http connection. Middlewares added through this method only
// wrap the user-defined service handler. The context.Context is filled with useful
// data that can be access with api.GetRequest(), api.GetResponseWriter(),
// api.GetAuth(), etc methods.
func (b *Builder) WithContext(mw func(http.Handler) http.Handler) {
if b.conf == nil { if b.conf == nil {
b.conf = &config.Server{} b.conf = &config.Server{}
} }
if b.authAdapters == nil { if b.ctxMiddlewares == nil {
b.authAdapters = make([]api.AuthAdapter, 0) b.ctxMiddlewares = make([]func(http.Handler) http.Handler, 0)
} }
b.authAdapters = append(b.authAdapters, adapter) b.ctxMiddlewares = append(b.ctxMiddlewares, mw)
} }
// Setup the builder with its api definition file // Setup the builder with its api definition file

View File

@ -1,6 +1,7 @@
package aicra package aicra
import ( import (
"context"
"errors" "errors"
"net/http" "net/http"
"strings" "strings"
@ -72,7 +73,7 @@ func TestBind(t *testing.T) {
Config: "[]", Config: "[]",
HandlerMethod: "", HandlerMethod: "",
HandlerPath: "", HandlerPath: "",
HandlerFn: func(*api.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, HandlerFn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess },
BindErr: errUnknownService, BindErr: errUnknownService,
BuildErr: nil, BuildErr: nil,
}, },
@ -108,7 +109,7 @@ func TestBind(t *testing.T) {
]`, ]`,
HandlerMethod: http.MethodPost, HandlerMethod: http.MethodPost,
HandlerPath: "/path", HandlerPath: "/path",
HandlerFn: func(*api.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, HandlerFn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess },
BindErr: errUnknownService, BindErr: errUnknownService,
BuildErr: errMissingHandler, BuildErr: errMissingHandler,
}, },
@ -126,7 +127,7 @@ func TestBind(t *testing.T) {
]`, ]`,
HandlerMethod: http.MethodGet, HandlerMethod: http.MethodGet,
HandlerPath: "/paths", HandlerPath: "/paths",
HandlerFn: func(*api.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, HandlerFn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess },
BindErr: errUnknownService, BindErr: errUnknownService,
BuildErr: errMissingHandler, BuildErr: errMissingHandler,
}, },
@ -144,7 +145,7 @@ func TestBind(t *testing.T) {
]`, ]`,
HandlerMethod: http.MethodGet, HandlerMethod: http.MethodGet,
HandlerPath: "/path", HandlerPath: "/path",
HandlerFn: func(*api.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, HandlerFn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess },
BindErr: nil, BindErr: nil,
BuildErr: nil, BuildErr: nil,
}, },
@ -164,7 +165,7 @@ func TestBind(t *testing.T) {
]`, ]`,
HandlerMethod: http.MethodGet, HandlerMethod: http.MethodGet,
HandlerPath: "/path", HandlerPath: "/path",
HandlerFn: func(*api.Context, struct{ Name int }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, HandlerFn: func(context.Context, struct{ Name int }) (*struct{}, api.Err) { return nil, api.ErrSuccess },
BindErr: nil, BindErr: nil,
BuildErr: nil, BuildErr: nil,
}, },
@ -184,7 +185,7 @@ func TestBind(t *testing.T) {
]`, ]`,
HandlerMethod: http.MethodGet, HandlerMethod: http.MethodGet,
HandlerPath: "/path", HandlerPath: "/path",
HandlerFn: func(*api.Context, struct{ Name uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, HandlerFn: func(context.Context, struct{ Name uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess },
BindErr: nil, BindErr: nil,
BuildErr: nil, BuildErr: nil,
}, },
@ -204,7 +205,7 @@ func TestBind(t *testing.T) {
]`, ]`,
HandlerMethod: http.MethodGet, HandlerMethod: http.MethodGet,
HandlerPath: "/path", HandlerPath: "/path",
HandlerFn: func(*api.Context, struct{ Name string }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, HandlerFn: func(context.Context, struct{ Name string }) (*struct{}, api.Err) { return nil, api.ErrSuccess },
BindErr: nil, BindErr: nil,
BuildErr: nil, BuildErr: nil,
}, },
@ -224,7 +225,7 @@ func TestBind(t *testing.T) {
]`, ]`,
HandlerMethod: http.MethodGet, HandlerMethod: http.MethodGet,
HandlerPath: "/path", HandlerPath: "/path",
HandlerFn: func(*api.Context, struct{ Name bool }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, HandlerFn: func(context.Context, struct{ Name bool }) (*struct{}, api.Err) { return nil, api.ErrSuccess },
BindErr: nil, BindErr: nil,
BuildErr: nil, BuildErr: nil,
}, },

View File

@ -17,30 +17,31 @@ type Handler Builder
// ServeHTTP implements http.Handler and wraps it in middlewares (adapters) // ServeHTTP implements http.Handler and wraps it in middlewares (adapters)
func (s Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var h = http.HandlerFunc(s.resolve) var h http.Handler = http.HandlerFunc(s.resolve)
for _, adapter := range s.adapters { for _, mw := range s.middlewares {
h = adapter(h) h = mw(h)
} }
h(w, r) h.ServeHTTP(w, r)
} }
// ServeHTTP implements http.Handler and wraps it in middlewares (adapters)
func (s Handler) resolve(w http.ResponseWriter, r *http.Request) { func (s Handler) resolve(w http.ResponseWriter, r *http.Request) {
// 1. find a matching service from config // 1ind a matching service from config
var service = s.conf.Find(r) var service = s.conf.Find(r)
if service == nil { if service == nil {
handleError(api.ErrUnknownService, w, r) handleError(api.ErrUnknownService, w, r)
return return
} }
// 2. extract request data // extract request data
var input, err = extractInput(service, *r) var input, err = extractInput(service, *r)
if err != nil { if err != nil {
handleError(api.ErrMissingParam, w, r) handleError(api.ErrMissingParam, w, r)
return return
} }
// 3. find a matching handler // find a matching handler
var handler *apiHandler var handler *apiHandler
for _, h := range s.handlers { for _, h := range s.handlers {
if h.Method == service.Method && h.Path == service.Pattern { if h.Method == service.Method && h.Path == service.Pattern {
@ -48,63 +49,36 @@ func (s Handler) resolve(w http.ResponseWriter, r *http.Request) {
} }
} }
// 4. fail on no matching handler // fail on no matching handler
if handler == nil { if handler == nil {
handleError(api.ErrUncallableService, w, r) handleError(api.ErrUncallableService, w, r)
return return
} }
// replace format '[a]' in scope where 'a' is an existing input's name
scope := make([][]string, len(service.Scope))
for a, list := range service.Scope {
scope[a] = make([]string, len(list))
for b, perm := range list {
scope[a][b] = perm
for name, value := range input.Data {
var (
token = fmt.Sprintf("[%s]", name)
replacement = ""
)
if value != nil {
replacement = fmt.Sprintf("[%v]", value)
}
scope[a][b] = strings.ReplaceAll(scope[a][b], token, replacement)
}
}
}
var auth = api.Auth{
Required: scope,
Active: []string{},
}
// 5. run auth-aware middlewares
var h = api.AuthHandlerFunc(func(a api.Auth, w http.ResponseWriter, r *http.Request) {
if !a.Granted() {
handleError(api.ErrPermission, w, r)
return
}
s.handle(input, handler, service, w, r)
})
for _, adapter := range s.authAdapters {
h = adapter(h)
}
h(auth, w, r)
}
func (s *Handler) handle(input *reqdata.T, handler *apiHandler, service *config.Service, w http.ResponseWriter, r *http.Request) {
// build context with builtin data // build context with builtin data
c := r.Context() c := r.Context()
c = context.WithValue(c, ctx.Request, r) c = context.WithValue(c, ctx.Request, r)
c = context.WithValue(c, ctx.Response, w) c = context.WithValue(c, ctx.Response, w)
c = context.WithValue(c, ctx.Auth, w) c = context.WithValue(c, ctx.Auth, buildAuth(service.Scope, input.Data))
apictx := &api.Context{Context: c}
// create http handler
var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// use context defined in the request
s.handle(r.Context(), input, handler, service, w, r)
})
// run middlewares the handler
for _, mw := range s.ctxMiddlewares {
h = mw(h)
}
// serve using the context with values
h.ServeHTTP(w, r.WithContext(c))
}
func (s *Handler) handle(c context.Context, input *reqdata.T, handler *apiHandler, service *config.Service, w http.ResponseWriter, r *http.Request) {
// pass execution to the handler // pass execution to the handler
var outData, outErr = handler.dyn.Handle(apictx, input.Data) var outData, outErr = handler.dyn.Handle(c, input.Data)
// build response from returned arguments // build response from returned arguments
var res = api.EmptyResponse().WithError(outErr) var res = api.EmptyResponse().WithError(outErr)
@ -157,3 +131,35 @@ func extractInput(service *config.Service, req http.Request) (*reqdata.T, error)
return dataset, nil return dataset, nil
} }
// buildAuth builds the api.Auth struct from the service scope configuration
//
// it replaces format '[a]' in scope where 'a' is an existing input argument's
// name with its value
func buildAuth(scope [][]string, in map[string]interface{}) *api.Auth {
updated := make([][]string, len(scope))
// replace '[arg_name]' with the 'arg_name' value if it is a known variable
// name
for a, list := range scope {
updated[a] = make([]string, len(list))
for b, perm := range list {
updated[a][b] = perm
for name, value := range in {
var (
token = fmt.Sprintf("[%s]", name)
replacement = ""
)
if value != nil {
replacement = fmt.Sprintf("[%v]", value)
}
updated[a][b] = strings.ReplaceAll(updated[a][b], token, replacement)
}
}
}
return &api.Auth{
Required: updated,
Active: []string{},
}
}

View File

@ -48,15 +48,15 @@ func TestWith(t *testing.T) {
type ckey int type ckey int
const key ckey = 0 const key ckey = 0
middleware := func(next http.HandlerFunc) http.HandlerFunc { middleware := func(next http.Handler) http.Handler {
return func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
newr := r newr := r
// first time -> store 1 // first time -> store 1
value := r.Context().Value(key) value := r.Context().Value(key)
if value == nil { if value == nil {
newr = r.WithContext(context.WithValue(r.Context(), key, int(1))) newr = r.WithContext(context.WithValue(r.Context(), key, int(1)))
next(w, newr) next.ServeHTTP(w, newr)
return return
} }
@ -67,8 +67,8 @@ func TestWith(t *testing.T) {
} }
cast++ cast++
newr = r.WithContext(context.WithValue(r.Context(), key, cast)) newr = r.WithContext(context.WithValue(r.Context(), key, cast))
next(w, newr) next.ServeHTTP(w, newr)
} })
} }
// add middleware @n times // add middleware @n times
@ -82,7 +82,7 @@ func TestWith(t *testing.T) {
t.Fatalf("setup: unexpected error <%v>", err) t.Fatalf("setup: unexpected error <%v>", err)
} }
pathHandler := func(ctx *api.Context) (*struct{}, api.Err) { pathHandler := func(ctx context.Context) (*struct{}, api.Err) {
// write value from middlewares into response // write value from middlewares into response
value := ctx.Value(key) value := ctx.Value(key)
if value == nil { if value == nil {
@ -93,7 +93,7 @@ func TestWith(t *testing.T) {
t.Fatalf("cannot cast context data to int") t.Fatalf("cannot cast context data to int")
} }
// write to response // write to response
ctx.ResponseWriter().Write([]byte(fmt.Sprintf("#%d#", cast))) api.GetResponseWriter(ctx).Write([]byte(fmt.Sprintf("#%d#", cast)))
return nil, api.ErrSuccess return nil, api.ErrSuccess
} }
@ -212,8 +212,13 @@ func TestWithAuth(t *testing.T) {
} }
// tester middleware (last executed) // tester middleware (last executed)
builder.WithAuth(func(next api.AuthHandlerFunc) api.AuthHandlerFunc { builder.WithContext(func(next http.Handler) http.Handler {
return func(a api.Auth, w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
a := api.GetAuth(r.Context())
if a == nil {
t.Fatalf("cannot access api.Auth form request context")
}
if a.Granted() == tc.granted { if a.Granted() == tc.granted {
return return
} }
@ -222,14 +227,20 @@ func TestWithAuth(t *testing.T) {
} else { } else {
t.Fatalf("expected granted auth") t.Fatalf("expected granted auth")
} }
} next.ServeHTTP(w, r)
})
}) })
builder.WithAuth(func(next api.AuthHandlerFunc) api.AuthHandlerFunc { builder.WithContext(func(next http.Handler) http.Handler {
return func(a api.Auth, w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
a.Active = tc.permissions a := api.GetAuth(r.Context())
next(a, w, r) if a == nil {
t.Fatalf("cannot access api.Auth form request context")
} }
a.Active = tc.permissions
next.ServeHTTP(w, r)
})
}) })
err := builder.Setup(strings.NewReader(tc.manifest)) err := builder.Setup(strings.NewReader(tc.manifest))
@ -237,7 +248,7 @@ func TestWithAuth(t *testing.T) {
t.Fatalf("setup: unexpected error <%v>", err) t.Fatalf("setup: unexpected error <%v>", err)
} }
pathHandler := func(ctx *api.Context) (*struct{}, api.Err) { pathHandler := func(ctx context.Context) (*struct{}, api.Err) {
return nil, api.ErrNotImplemented return nil, api.ErrNotImplemented
} }
@ -290,7 +301,7 @@ func TestDynamicScope(t *testing.T) {
} }
]`, ]`,
path: "/path/{id}", path: "/path/{id}",
handler: func(*api.Context, struct{ Input1 uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, handler: func(context.Context, struct{ Input1 uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess },
url: "/path/123", url: "/path/123",
body: ``, body: ``,
permissions: []string{"user[123]"}, permissions: []string{"user[123]"},
@ -311,7 +322,7 @@ func TestDynamicScope(t *testing.T) {
} }
]`, ]`,
path: "/path/{id}", path: "/path/{id}",
handler: func(*api.Context, struct{ Input1 uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, handler: func(context.Context, struct{ Input1 uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess },
url: "/path/666", url: "/path/666",
body: ``, body: ``,
permissions: []string{"user[123]"}, permissions: []string{"user[123]"},
@ -332,7 +343,7 @@ func TestDynamicScope(t *testing.T) {
} }
]`, ]`,
path: "/path/{id}", path: "/path/{id}",
handler: func(*api.Context, struct{ User uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess }, handler: func(context.Context, struct{ User uint }) (*struct{}, api.Err) { return nil, api.ErrSuccess },
url: "/path/123", url: "/path/123",
body: ``, body: ``,
permissions: []string{"prefix.user[123].suffix"}, permissions: []string{"prefix.user[123].suffix"},
@ -354,7 +365,7 @@ func TestDynamicScope(t *testing.T) {
} }
]`, ]`,
path: "/prefix/{pid}/user/{uid}", path: "/prefix/{pid}/user/{uid}",
handler: func(*api.Context, struct { handler: func(context.Context, struct {
Prefix uint Prefix uint
User uint User uint
}) (*struct{}, api.Err) { }) (*struct{}, api.Err) {
@ -381,7 +392,7 @@ func TestDynamicScope(t *testing.T) {
} }
]`, ]`,
path: "/prefix/{pid}/user/{uid}", path: "/prefix/{pid}/user/{uid}",
handler: func(*api.Context, struct { handler: func(context.Context, struct {
Prefix uint Prefix uint
User uint User uint
}) (*struct{}, api.Err) { }) (*struct{}, api.Err) {
@ -409,7 +420,7 @@ func TestDynamicScope(t *testing.T) {
} }
]`, ]`,
path: "/prefix/{pid}/user/{uid}/suffix/{sid}", path: "/prefix/{pid}/user/{uid}/suffix/{sid}",
handler: func(*api.Context, struct { handler: func(context.Context, struct {
Prefix uint Prefix uint
User uint User uint
Suffix uint Suffix uint
@ -438,7 +449,7 @@ func TestDynamicScope(t *testing.T) {
} }
]`, ]`,
path: "/prefix/{pid}/user/{uid}/suffix/{sid}", path: "/prefix/{pid}/user/{uid}/suffix/{sid}",
handler: func(*api.Context, struct { handler: func(context.Context, struct {
Prefix uint Prefix uint
User uint User uint
Suffix uint Suffix uint
@ -460,8 +471,12 @@ func TestDynamicScope(t *testing.T) {
} }
// tester middleware (last executed) // tester middleware (last executed)
builder.WithAuth(func(next api.AuthHandlerFunc) api.AuthHandlerFunc { builder.WithContext(func(next http.Handler) http.Handler {
return func(a api.Auth, w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
a := api.GetAuth(r.Context())
if a == nil {
t.Fatalf("cannot access api.Auth form request context")
}
if a.Granted() == tc.granted { if a.Granted() == tc.granted {
return return
} }
@ -470,15 +485,20 @@ func TestDynamicScope(t *testing.T) {
} else { } else {
t.Fatalf("expected granted auth") t.Fatalf("expected granted auth")
} }
} next.ServeHTTP(w, r)
})
}) })
// update permissions // update permissions
builder.WithAuth(func(next api.AuthHandlerFunc) api.AuthHandlerFunc { builder.WithContext(func(next http.Handler) http.Handler {
return func(a api.Auth, w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
a.Active = tc.permissions a := api.GetAuth(r.Context())
next(a, w, r) if a == nil {
t.Fatalf("cannot access api.Auth form request context")
} }
a.Active = tc.permissions
next.ServeHTTP(w, r)
})
}) })
err := builder.Setup(strings.NewReader(tc.manifest)) err := builder.Setup(strings.NewReader(tc.manifest))

View File

@ -14,7 +14,7 @@ const errHandlerNotFunc = cerr("handler must be a func")
const errNoServiceForHandler = cerr("no service found for this handler") const errNoServiceForHandler = cerr("no service found for this handler")
// errMissingHandlerArgumentParam - missing params arguments for handler // errMissingHandlerArgumentParam - missing params arguments for handler
const errMissingHandlerContextArgument = cerr("missing handler first argument of type *api.Context") const errMissingHandlerContextArgument = cerr("missing handler first argument of type context.Context")
// errMissingHandlerInputArgument - missing params arguments for handler // errMissingHandlerInputArgument - missing params arguments for handler
const errMissingHandlerInputArgument = cerr("missing handler argument: input struct") const errMissingHandlerInputArgument = cerr("missing handler argument: input struct")

View File

@ -1,6 +1,7 @@
package dynfunc package dynfunc
import ( import (
"context"
"fmt" "fmt"
"log" "log"
"reflect" "reflect"
@ -50,7 +51,7 @@ func Build(fn interface{}, service config.Service) (*Handler, error) {
} }
// Handle binds input `data` into the dynamic function and returns an output map // Handle binds input `data` into the dynamic function and returns an output map
func (h *Handler) Handle(ctx *api.Context, data map[string]interface{}) (map[string]interface{}, api.Err) { func (h *Handler) Handle(ctx context.Context, data map[string]interface{}) (map[string]interface{}, api.Err) {
var ( var (
ert = reflect.TypeOf(api.Err{}) ert = reflect.TypeOf(api.Err{})
fnv = reflect.ValueOf(h.fn) fnv = reflect.ValueOf(h.fn)

View File

@ -1,6 +1,7 @@
package dynfunc package dynfunc
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"testing" "testing"
@ -52,7 +53,7 @@ func TestInput(t *testing.T) {
{ {
Name: "none required none provided", Name: "none required none provided",
Spec: (&testsignature{}).withArgs(), Spec: (&testsignature{}).withArgs(),
Fn: func(*api.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, Fn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess },
HasContext: false, HasContext: false,
Input: []interface{}{}, Input: []interface{}{},
ExpectedOutput: []interface{}{}, ExpectedOutput: []interface{}{},
@ -61,7 +62,7 @@ func TestInput(t *testing.T) {
{ {
Name: "int proxy (0)", Name: "int proxy (0)",
Spec: (&testsignature{}).withArgs(reflect.TypeOf(int(0))), Spec: (&testsignature{}).withArgs(reflect.TypeOf(int(0))),
Fn: func(ctx *api.Context, in intstruct) (*intstruct, api.Err) { Fn: func(ctx context.Context, in intstruct) (*intstruct, api.Err) {
return &intstruct{P1: in.P1}, api.ErrSuccess return &intstruct{P1: in.P1}, api.ErrSuccess
}, },
HasContext: false, HasContext: false,
@ -72,7 +73,7 @@ func TestInput(t *testing.T) {
{ {
Name: "int proxy (11)", Name: "int proxy (11)",
Spec: (&testsignature{}).withArgs(reflect.TypeOf(int(0))), Spec: (&testsignature{}).withArgs(reflect.TypeOf(int(0))),
Fn: func(ctx *api.Context, in intstruct) (*intstruct, api.Err) { Fn: func(ctx context.Context, in intstruct) (*intstruct, api.Err) {
return &intstruct{P1: in.P1}, api.ErrSuccess return &intstruct{P1: in.P1}, api.ErrSuccess
}, },
HasContext: false, HasContext: false,
@ -83,7 +84,7 @@ func TestInput(t *testing.T) {
{ {
Name: "*int proxy (nil)", Name: "*int proxy (nil)",
Spec: (&testsignature{}).withArgs(reflect.TypeOf(new(int))), Spec: (&testsignature{}).withArgs(reflect.TypeOf(new(int))),
Fn: func(ctx *api.Context, in intptrstruct) (*intptrstruct, api.Err) { Fn: func(ctx context.Context, in intptrstruct) (*intptrstruct, api.Err) {
return &intptrstruct{P1: in.P1}, api.ErrSuccess return &intptrstruct{P1: in.P1}, api.ErrSuccess
}, },
HasContext: false, HasContext: false,
@ -94,7 +95,7 @@ func TestInput(t *testing.T) {
{ {
Name: "*int proxy (28)", Name: "*int proxy (28)",
Spec: (&testsignature{}).withArgs(reflect.TypeOf(new(int))), Spec: (&testsignature{}).withArgs(reflect.TypeOf(new(int))),
Fn: func(ctx *api.Context, in intptrstruct) (*intstruct, api.Err) { Fn: func(ctx context.Context, in intptrstruct) (*intstruct, api.Err) {
return &intstruct{P1: *in.P1}, api.ErrSuccess return &intstruct{P1: *in.P1}, api.ErrSuccess
}, },
HasContext: false, HasContext: false,
@ -105,7 +106,7 @@ func TestInput(t *testing.T) {
{ {
Name: "*int proxy (13)", Name: "*int proxy (13)",
Spec: (&testsignature{}).withArgs(reflect.TypeOf(new(int))), Spec: (&testsignature{}).withArgs(reflect.TypeOf(new(int))),
Fn: func(ctx *api.Context, in intptrstruct) (*intstruct, api.Err) { Fn: func(ctx context.Context, in intptrstruct) (*intstruct, api.Err) {
return &intstruct{P1: *in.P1}, api.ErrSuccess return &intstruct{P1: *in.P1}, api.ErrSuccess
}, },
HasContext: false, HasContext: false,
@ -131,7 +132,7 @@ func TestInput(t *testing.T) {
input[key] = val input[key] = val
} }
var output, err = handler.Handle(&api.Context{}, input) var output, err = handler.Handle(context.Background(), input)
if err != tcase.ExpectedErr { if err != tcase.ExpectedErr {
t.Fatalf("expected api error <%v> got <%v>", tcase.ExpectedErr, err) t.Fatalf("expected api error <%v> got <%v>", tcase.ExpectedErr, err)
} }

View File

@ -1,6 +1,7 @@
package dynfunc package dynfunc
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
@ -48,12 +49,17 @@ func BuildSignature(service config.Service) *Signature {
// ValidateInput validates a handler's input arguments against the service signature // ValidateInput validates a handler's input arguments against the service signature
func (s *Signature) ValidateInput(handlerType reflect.Type) error { func (s *Signature) ValidateInput(handlerType reflect.Type) error {
ctxType := reflect.TypeOf(api.Context{}) ctxType := reflect.TypeOf((*context.Context)(nil)).Elem()
// missing or invalid first arg: api.Context // missing or invalid first arg: context.Context
if handlerType.NumIn() < 1 || ctxType.AssignableTo(handlerType.In(0)) { if handlerType.NumIn() < 1 {
return errMissingHandlerContextArgument return errMissingHandlerContextArgument
} }
firstArgType := handlerType.In(0)
if !firstArgType.Implements(ctxType) {
return fmt.Errorf("fock")
}
// no input required // no input required
if len(s.Input) == 0 { if len(s.Input) == 0 {

View File

@ -1,6 +1,7 @@
package dynfunc package dynfunc
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@ -20,22 +21,22 @@ func TestInputCheck(t *testing.T) {
{ {
Name: "no input 0 given", Name: "no input 0 given",
Input: map[string]reflect.Type{}, Input: map[string]reflect.Type{},
Fn: func(*api.Context) {}, Fn: func(context.Context) {},
FnCtx: func(*api.Context) {}, FnCtx: func(context.Context) {},
Err: nil, Err: nil,
}, },
{ {
Name: "no input 1 given", Name: "no input 1 given",
Input: map[string]reflect.Type{}, Input: map[string]reflect.Type{},
Fn: func(*api.Context, int) {}, Fn: func(context.Context, int) {},
FnCtx: func(*api.Context, int) {}, FnCtx: func(context.Context, int) {},
Err: errUnexpectedInput, Err: errUnexpectedInput,
}, },
{ {
Name: "no input 2 given", Name: "no input 2 given",
Input: map[string]reflect.Type{}, Input: map[string]reflect.Type{},
Fn: func(*api.Context, int, string) {}, Fn: func(context.Context, int, string) {},
FnCtx: func(*api.Context, int, string) {}, FnCtx: func(context.Context, int, string) {},
Err: errUnexpectedInput, Err: errUnexpectedInput,
}, },
{ {
@ -43,8 +44,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf(int(0)), "Test1": reflect.TypeOf(int(0)),
}, },
Fn: func(*api.Context) {}, Fn: func(context.Context) {},
FnCtx: func(*api.Context) {}, FnCtx: func(context.Context) {},
Err: errMissingHandlerInputArgument, Err: errMissingHandlerInputArgument,
}, },
{ {
@ -52,8 +53,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf(int(0)), "Test1": reflect.TypeOf(int(0)),
}, },
Fn: func(*api.Context, int) {}, Fn: func(context.Context, int) {},
FnCtx: func(*api.Context, int) {}, FnCtx: func(context.Context, int) {},
Err: errMissingParamArgument, Err: errMissingParamArgument,
}, },
{ {
@ -61,8 +62,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"test1": reflect.TypeOf(int(0)), "test1": reflect.TypeOf(int(0)),
}, },
Fn: func(*api.Context, struct{}) {}, Fn: func(context.Context, struct{}) {},
FnCtx: func(*api.Context, struct{}) {}, FnCtx: func(context.Context, struct{}) {},
Err: errUnexportedName, Err: errUnexportedName,
}, },
{ {
@ -70,8 +71,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf(int(0)), "Test1": reflect.TypeOf(int(0)),
}, },
Fn: func(*api.Context, struct{}) {}, Fn: func(context.Context, struct{}) {},
FnCtx: func(*api.Context, struct{}) {}, FnCtx: func(context.Context, struct{}) {},
Err: errMissingConfigArgument, Err: errMissingConfigArgument,
}, },
{ {
@ -79,8 +80,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf(int(0)), "Test1": reflect.TypeOf(int(0)),
}, },
Fn: func(*api.Context, struct{ Test1 string }) {}, Fn: func(context.Context, struct{ Test1 string }) {},
FnCtx: func(*api.Context, struct{ Test1 string }) {}, FnCtx: func(context.Context, struct{ Test1 string }) {},
Err: errWrongParamTypeFromConfig, Err: errWrongParamTypeFromConfig,
}, },
{ {
@ -88,8 +89,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf(int(0)), "Test1": reflect.TypeOf(int(0)),
}, },
Fn: func(*api.Context, struct{ Test1 int }) {}, Fn: func(context.Context, struct{ Test1 int }) {},
FnCtx: func(*api.Context, struct{ Test1 int }) {}, FnCtx: func(context.Context, struct{ Test1 int }) {},
Err: nil, Err: nil,
}, },
{ {
@ -97,8 +98,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf(new(int)), "Test1": reflect.TypeOf(new(int)),
}, },
Fn: func(*api.Context, struct{}) {}, Fn: func(context.Context, struct{}) {},
FnCtx: func(*api.Context, struct{}) {}, FnCtx: func(context.Context, struct{}) {},
Err: errMissingConfigArgument, Err: errMissingConfigArgument,
}, },
{ {
@ -106,8 +107,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf(new(int)), "Test1": reflect.TypeOf(new(int)),
}, },
Fn: func(*api.Context, struct{ Test1 string }) {}, Fn: func(context.Context, struct{ Test1 string }) {},
FnCtx: func(*api.Context, struct{ Test1 string }) {}, FnCtx: func(context.Context, struct{ Test1 string }) {},
Err: errWrongParamTypeFromConfig, Err: errWrongParamTypeFromConfig,
}, },
{ {
@ -115,8 +116,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf(new(int)), "Test1": reflect.TypeOf(new(int)),
}, },
Fn: func(*api.Context, struct{ Test1 *string }) {}, Fn: func(context.Context, struct{ Test1 *string }) {},
FnCtx: func(*api.Context, struct{ Test1 *string }) {}, FnCtx: func(context.Context, struct{ Test1 *string }) {},
Err: errWrongParamTypeFromConfig, Err: errWrongParamTypeFromConfig,
}, },
{ {
@ -124,8 +125,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf(new(int)), "Test1": reflect.TypeOf(new(int)),
}, },
Fn: func(*api.Context, struct{ Test1 *int }) {}, Fn: func(context.Context, struct{ Test1 *int }) {},
FnCtx: func(*api.Context, struct{ Test1 *int }) {}, FnCtx: func(context.Context, struct{ Test1 *int }) {},
Err: nil, Err: nil,
}, },
{ {
@ -133,8 +134,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf(string("")), "Test1": reflect.TypeOf(string("")),
}, },
Fn: func(*api.Context, struct{ Test1 string }) {}, Fn: func(context.Context, struct{ Test1 string }) {},
FnCtx: func(*api.Context, struct{ Test1 string }) {}, FnCtx: func(context.Context, struct{ Test1 string }) {},
Err: nil, Err: nil,
}, },
{ {
@ -142,8 +143,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf(uint(0)), "Test1": reflect.TypeOf(uint(0)),
}, },
Fn: func(*api.Context, struct{ Test1 uint }) {}, Fn: func(context.Context, struct{ Test1 uint }) {},
FnCtx: func(*api.Context, struct{ Test1 uint }) {}, FnCtx: func(context.Context, struct{ Test1 uint }) {},
Err: nil, Err: nil,
}, },
{ {
@ -151,8 +152,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf(float64(0)), "Test1": reflect.TypeOf(float64(0)),
}, },
Fn: func(*api.Context, struct{ Test1 float64 }) {}, Fn: func(context.Context, struct{ Test1 float64 }) {},
FnCtx: func(*api.Context, struct{ Test1 float64 }) {}, FnCtx: func(context.Context, struct{ Test1 float64 }) {},
Err: nil, Err: nil,
}, },
{ {
@ -160,8 +161,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf([]byte("")), "Test1": reflect.TypeOf([]byte("")),
}, },
Fn: func(*api.Context, struct{ Test1 []byte }) {}, Fn: func(context.Context, struct{ Test1 []byte }) {},
FnCtx: func(*api.Context, struct{ Test1 []byte }) {}, FnCtx: func(context.Context, struct{ Test1 []byte }) {},
Err: nil, Err: nil,
}, },
{ {
@ -169,8 +170,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf([]rune("")), "Test1": reflect.TypeOf([]rune("")),
}, },
Fn: func(*api.Context, struct{ Test1 []rune }) {}, Fn: func(context.Context, struct{ Test1 []rune }) {},
FnCtx: func(*api.Context, struct{ Test1 []rune }) {}, FnCtx: func(context.Context, struct{ Test1 []rune }) {},
Err: nil, Err: nil,
}, },
{ {
@ -178,8 +179,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf(new(string)), "Test1": reflect.TypeOf(new(string)),
}, },
Fn: func(*api.Context, struct{ Test1 *string }) {}, Fn: func(context.Context, struct{ Test1 *string }) {},
FnCtx: func(*api.Context, struct{ Test1 *string }) {}, FnCtx: func(context.Context, struct{ Test1 *string }) {},
Err: nil, Err: nil,
}, },
{ {
@ -187,8 +188,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf(new(uint)), "Test1": reflect.TypeOf(new(uint)),
}, },
Fn: func(*api.Context, struct{ Test1 *uint }) {}, Fn: func(context.Context, struct{ Test1 *uint }) {},
FnCtx: func(*api.Context, struct{ Test1 *uint }) {}, FnCtx: func(context.Context, struct{ Test1 *uint }) {},
Err: nil, Err: nil,
}, },
{ {
@ -196,8 +197,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf(new(float64)), "Test1": reflect.TypeOf(new(float64)),
}, },
Fn: func(*api.Context, struct{ Test1 *float64 }) {}, Fn: func(context.Context, struct{ Test1 *float64 }) {},
FnCtx: func(*api.Context, struct{ Test1 *float64 }) {}, FnCtx: func(context.Context, struct{ Test1 *float64 }) {},
Err: nil, Err: nil,
}, },
{ {
@ -205,8 +206,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf(new([]byte)), "Test1": reflect.TypeOf(new([]byte)),
}, },
Fn: func(*api.Context, struct{ Test1 *[]byte }) {}, Fn: func(context.Context, struct{ Test1 *[]byte }) {},
FnCtx: func(*api.Context, struct{ Test1 *[]byte }) {}, FnCtx: func(context.Context, struct{ Test1 *[]byte }) {},
Err: nil, Err: nil,
}, },
{ {
@ -214,8 +215,8 @@ func TestInputCheck(t *testing.T) {
Input: map[string]reflect.Type{ Input: map[string]reflect.Type{
"Test1": reflect.TypeOf(new([]rune)), "Test1": reflect.TypeOf(new([]rune)),
}, },
Fn: func(*api.Context, struct{ Test1 *[]rune }) {}, Fn: func(context.Context, struct{ Test1 *[]rune }) {},
FnCtx: func(*api.Context, struct{ Test1 *[]rune }) {}, FnCtx: func(context.Context, struct{ Test1 *[]rune }) {},
Err: nil, Err: nil,
}, },
} }
@ -259,37 +260,37 @@ func TestOutputCheck(t *testing.T) {
// no input -> missing api.Err // no input -> missing api.Err
{ {
Output: map[string]reflect.Type{}, Output: map[string]reflect.Type{},
Fn: func(*api.Context) {}, Fn: func(context.Context) {},
Err: errMissingHandlerOutputArgument, Err: errMissingHandlerOutputArgument,
}, },
// no input -> with last type not api.Err // no input -> with last type not api.Err
{ {
Output: map[string]reflect.Type{}, Output: map[string]reflect.Type{},
Fn: func(*api.Context) bool { return true }, Fn: func(context.Context) bool { return true },
Err: errMissingHandlerErrorArgument, Err: errMissingHandlerErrorArgument,
}, },
// no input -> with api.Err // no input -> with api.Err
{ {
Output: map[string]reflect.Type{}, Output: map[string]reflect.Type{},
Fn: func(*api.Context) api.Err { return api.ErrSuccess }, Fn: func(context.Context) api.Err { return api.ErrSuccess },
Err: nil, Err: nil,
}, },
// no input -> missing *api.Context // no input -> missing context.Context
{ {
Output: map[string]reflect.Type{}, Output: map[string]reflect.Type{},
Fn: func(*api.Context) api.Err { return api.ErrSuccess }, Fn: func(context.Context) api.Err { return api.ErrSuccess },
Err: errMissingHandlerContextArgument, Err: errMissingHandlerContextArgument,
}, },
// no input -> invlaid *api.Context type // no input -> invlaid context.Context type
{ {
Output: map[string]reflect.Type{}, Output: map[string]reflect.Type{},
Fn: func(*api.Context, int) api.Err { return api.ErrSuccess }, Fn: func(context.Context, int) api.Err { return api.ErrSuccess },
Err: errMissingHandlerContextArgument, Err: errMissingHandlerContextArgument,
}, },
// func can have output if not specified // func can have output if not specified
{ {
Output: map[string]reflect.Type{}, Output: map[string]reflect.Type{},
Fn: func(*api.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess }, Fn: func(context.Context) (*struct{}, api.Err) { return nil, api.ErrSuccess },
Err: nil, Err: nil,
}, },
// missing output struct in func // missing output struct in func