feat: add optional api.Ctx first argument to handler checker

This commit is contained in:
Adrien Marquès 2021-04-18 19:25:31 +02:00
parent 24be7c294e
commit 0a55c2ee13
3 changed files with 144 additions and 80 deletions

View File

@ -11,8 +11,10 @@ import (
// Handler represents a dynamic api handler // Handler represents a dynamic api handler
type Handler struct { type Handler struct {
spec spec spec *spec
fn interface{} fn interface{}
// whether fn uses api.Ctx as 1st argument
hasContext bool
} }
// Build a handler from a service configuration and a dynamic function // Build a handler from a service configuration and a dynamic function
@ -30,16 +32,23 @@ func Build(fn interface{}, service config.Service) (*Handler, error) {
fn: fn, fn: fn,
} }
fnv := reflect.ValueOf(fn) impl := reflect.TypeOf(fn)
if fnv.Type().Kind() != reflect.Func { if impl.Kind() != reflect.Func {
return nil, errHandlerNotFunc return nil, errHandlerNotFunc
} }
if err := h.spec.checkInput(fnv); err != nil { h.hasContext = impl.NumIn() >= 1 && reflect.TypeOf(api.Ctx{}).AssignableTo(impl.In(0))
inputIndex := 0
if h.hasContext {
inputIndex = 1
}
if err := h.spec.checkInput(impl, inputIndex); err != nil {
return nil, fmt.Errorf("input: %w", err) return nil, fmt.Errorf("input: %w", err)
} }
if err := h.spec.checkOutput(fnv); err != nil { if err := h.spec.checkOutput(impl); err != nil {
return nil, fmt.Errorf("output: %w", err) return nil, fmt.Errorf("output: %w", err)
} }

View File

@ -12,11 +12,13 @@ import (
type spec struct { type spec struct {
Input map[string]reflect.Type Input map[string]reflect.Type
Output map[string]reflect.Type Output map[string]reflect.Type
// HasContext defines whether the given handler has api.Ctx as first argument
HasContext bool
} }
// builds a spec from the configuration service // builds a spec from the configuration service
func makeSpec(service config.Service) spec { func makeSpec(service config.Service) *spec {
spec := spec{ s := &spec{
Input: make(map[string]reflect.Type), Input: make(map[string]reflect.Type),
Output: make(map[string]reflect.Type), Output: make(map[string]reflect.Type),
} }
@ -27,40 +29,46 @@ func makeSpec(service config.Service) spec {
} }
// make a pointer if optional // make a pointer if optional
if param.Optional { if param.Optional {
spec.Input[param.Rename] = reflect.PtrTo(param.ExtractType) s.Input[param.Rename] = reflect.PtrTo(param.ExtractType)
continue continue
} }
spec.Input[param.Rename] = param.ExtractType s.Input[param.Rename] = param.ExtractType
} }
for _, param := range service.Output { for _, param := range service.Output {
if len(param.Rename) < 1 { if len(param.Rename) < 1 {
continue continue
} }
spec.Output[param.Rename] = param.ExtractType s.Output[param.Rename] = param.ExtractType
} }
return spec return s
} }
// checks for HandlerFn input arguments // checks for HandlerFn input arguments
func (s spec) checkInput(fnv reflect.Value) error { func (s *spec) checkInput(impl reflect.Type, index int) error {
fnt := fnv.Type() var requiredInput, structIndex = index, index
if len(s.Input) > 0 { // arguments struct
requiredInput++
}
// no input -> ok // missing arguments
if len(s.Input) == 0 { if impl.NumIn() > requiredInput {
if fnt.NumIn() > 0 {
return errUnexpectedInput return errUnexpectedInput
} }
// none required
if len(s.Input) == 0 {
return nil return nil
} }
if fnt.NumIn() != 1 { // too much arguments
if impl.NumIn() != requiredInput {
return errMissingHandlerArgumentParam return errMissingHandlerArgumentParam
} }
// arg must be a struct // arg must be a struct
structArg := fnt.In(0) structArg := impl.In(structIndex)
if structArg.Kind() != reflect.Struct { if structArg.Kind() != reflect.Struct {
return errMissingParamArgument return errMissingParamArgument
} }
@ -85,14 +93,13 @@ func (s spec) checkInput(fnv reflect.Value) error {
} }
// checks for HandlerFn output arguments // checks for HandlerFn output arguments
func (s spec) checkOutput(fnv reflect.Value) error { func (s spec) checkOutput(impl reflect.Type) error {
fnt := fnv.Type() if impl.NumOut() < 1 {
if fnt.NumOut() < 1 {
return errMissingHandlerOutput return errMissingHandlerOutput
} }
// last output must be api.Err // last output must be api.Err
errOutput := fnt.Out(fnt.NumOut() - 1) errOutput := impl.Out(impl.NumOut() - 1)
if !errOutput.AssignableTo(reflect.TypeOf(api.ErrUnknown)) { if !errOutput.AssignableTo(reflect.TypeOf(api.ErrUnknown)) {
return errMissingHandlerErrorOutput return errMissingHandlerErrorOutput
} }
@ -102,12 +109,12 @@ func (s spec) checkOutput(fnv reflect.Value) error {
return nil return nil
} }
if fnt.NumOut() != 2 { if impl.NumOut() != 2 {
return errMissingParamOutput return errMissingParamOutput
} }
// fail if first output is not a pointer to struct // fail if first output is not a pointer to struct
structOutputPtr := fnt.Out(0) structOutputPtr := impl.Out(0)
if structOutputPtr.Kind() != reflect.Ptr { if structOutputPtr.Kind() != reflect.Ptr {
return errMissingParamOutput return errMissingParamOutput
} }

View File

@ -14,24 +14,28 @@ func TestInputCheck(t *testing.T) {
Name string Name string
Input map[string]reflect.Type Input map[string]reflect.Type
Fn interface{} Fn interface{}
FnCtx interface{}
Err error Err error
}{ }{
{ {
Name: "no input 0 given", Name: "no input 0 given",
Input: map[string]reflect.Type{}, Input: map[string]reflect.Type{},
Fn: func() {}, Fn: func() {},
FnCtx: func(api.Ctx) {},
Err: nil, Err: nil,
}, },
{ {
Name: "no input 1 given", Name: "no input 1 given",
Input: map[string]reflect.Type{}, Input: map[string]reflect.Type{},
Fn: func(int) {}, Fn: func(int) {},
FnCtx: func(api.Ctx, int) {},
Err: errUnexpectedInput, Err: errUnexpectedInput,
}, },
{ {
Name: "no input 2 given", Name: "no input 2 given",
Input: map[string]reflect.Type{}, Input: map[string]reflect.Type{},
Fn: func(int, string) {}, Fn: func(int, string) {},
FnCtx: func(api.Ctx, int, string) {},
Err: errUnexpectedInput, Err: errUnexpectedInput,
}, },
{ {
@ -40,6 +44,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf(int(0)), "Test1": reflect.TypeOf(int(0)),
}, },
Fn: func() {}, Fn: func() {},
FnCtx: func(api.Ctx) {},
Err: errMissingHandlerArgumentParam, Err: errMissingHandlerArgumentParam,
}, },
{ {
@ -48,6 +53,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf(int(0)), "Test1": reflect.TypeOf(int(0)),
}, },
Fn: func(int) {}, Fn: func(int) {},
FnCtx: func(api.Ctx, int) {},
Err: errMissingParamArgument, Err: errMissingParamArgument,
}, },
{ {
@ -56,6 +62,7 @@ func TestInputCheck(t *testing.T) {
"test1": reflect.TypeOf(int(0)), "test1": reflect.TypeOf(int(0)),
}, },
Fn: func(struct{}) {}, Fn: func(struct{}) {},
FnCtx: func(api.Ctx, struct{}) {},
Err: errUnexportedName, Err: errUnexportedName,
}, },
{ {
@ -64,6 +71,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf(int(0)), "Test1": reflect.TypeOf(int(0)),
}, },
Fn: func(struct{}) {}, Fn: func(struct{}) {},
FnCtx: func(api.Ctx, struct{}) {},
Err: errMissingParamFromConfig, Err: errMissingParamFromConfig,
}, },
{ {
@ -72,6 +80,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf(int(0)), "Test1": reflect.TypeOf(int(0)),
}, },
Fn: func(struct{ Test1 string }) {}, Fn: func(struct{ Test1 string }) {},
FnCtx: func(api.Ctx, struct{ Test1 string }) {},
Err: errWrongParamTypeFromConfig, Err: errWrongParamTypeFromConfig,
}, },
{ {
@ -80,6 +89,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf(int(0)), "Test1": reflect.TypeOf(int(0)),
}, },
Fn: func(struct{ Test1 int }) {}, Fn: func(struct{ Test1 int }) {},
FnCtx: func(api.Ctx, struct{ Test1 int }) {},
Err: nil, Err: nil,
}, },
{ {
@ -88,6 +98,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf(new(int)), "Test1": reflect.TypeOf(new(int)),
}, },
Fn: func(struct{}) {}, Fn: func(struct{}) {},
FnCtx: func(api.Ctx, struct{}) {},
Err: errMissingParamFromConfig, Err: errMissingParamFromConfig,
}, },
{ {
@ -96,6 +107,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf(new(int)), "Test1": reflect.TypeOf(new(int)),
}, },
Fn: func(struct{ Test1 string }) {}, Fn: func(struct{ Test1 string }) {},
FnCtx: func(api.Ctx, struct{ Test1 string }) {},
Err: errWrongParamTypeFromConfig, Err: errWrongParamTypeFromConfig,
}, },
{ {
@ -104,6 +116,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf(new(int)), "Test1": reflect.TypeOf(new(int)),
}, },
Fn: func(struct{ Test1 *string }) {}, Fn: func(struct{ Test1 *string }) {},
FnCtx: func(api.Ctx, struct{ Test1 *string }) {},
Err: errWrongParamTypeFromConfig, Err: errWrongParamTypeFromConfig,
}, },
{ {
@ -112,6 +125,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf(new(int)), "Test1": reflect.TypeOf(new(int)),
}, },
Fn: func(struct{ Test1 *int }) {}, Fn: func(struct{ Test1 *int }) {},
FnCtx: func(api.Ctx, struct{ Test1 *int }) {},
Err: nil, Err: nil,
}, },
{ {
@ -120,6 +134,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf(string("")), "Test1": reflect.TypeOf(string("")),
}, },
Fn: func(struct{ Test1 string }) {}, Fn: func(struct{ Test1 string }) {},
FnCtx: func(api.Ctx, struct{ Test1 string }) {},
Err: nil, Err: nil,
}, },
{ {
@ -128,6 +143,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf(uint(0)), "Test1": reflect.TypeOf(uint(0)),
}, },
Fn: func(struct{ Test1 uint }) {}, Fn: func(struct{ Test1 uint }) {},
FnCtx: func(api.Ctx, struct{ Test1 uint }) {},
Err: nil, Err: nil,
}, },
{ {
@ -136,6 +152,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf(float64(0)), "Test1": reflect.TypeOf(float64(0)),
}, },
Fn: func(struct{ Test1 float64 }) {}, Fn: func(struct{ Test1 float64 }) {},
FnCtx: func(api.Ctx, struct{ Test1 float64 }) {},
Err: nil, Err: nil,
}, },
{ {
@ -144,6 +161,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf([]byte("")), "Test1": reflect.TypeOf([]byte("")),
}, },
Fn: func(struct{ Test1 []byte }) {}, Fn: func(struct{ Test1 []byte }) {},
FnCtx: func(api.Ctx, struct{ Test1 []byte }) {},
Err: nil, Err: nil,
}, },
{ {
@ -152,6 +170,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf([]rune("")), "Test1": reflect.TypeOf([]rune("")),
}, },
Fn: func(struct{ Test1 []rune }) {}, Fn: func(struct{ Test1 []rune }) {},
FnCtx: func(api.Ctx, struct{ Test1 []rune }) {},
Err: nil, Err: nil,
}, },
{ {
@ -160,6 +179,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf(new(string)), "Test1": reflect.TypeOf(new(string)),
}, },
Fn: func(struct{ Test1 *string }) {}, Fn: func(struct{ Test1 *string }) {},
FnCtx: func(api.Ctx, struct{ Test1 *string }) {},
Err: nil, Err: nil,
}, },
{ {
@ -168,6 +188,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf(new(uint)), "Test1": reflect.TypeOf(new(uint)),
}, },
Fn: func(struct{ Test1 *uint }) {}, Fn: func(struct{ Test1 *uint }) {},
FnCtx: func(api.Ctx, struct{ Test1 *uint }) {},
Err: nil, Err: nil,
}, },
{ {
@ -176,6 +197,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf(new(float64)), "Test1": reflect.TypeOf(new(float64)),
}, },
Fn: func(struct{ Test1 *float64 }) {}, Fn: func(struct{ Test1 *float64 }) {},
FnCtx: func(api.Ctx, struct{ Test1 *float64 }) {},
Err: nil, Err: nil,
}, },
{ {
@ -184,6 +206,7 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf(new([]byte)), "Test1": reflect.TypeOf(new([]byte)),
}, },
Fn: func(struct{ Test1 *[]byte }) {}, Fn: func(struct{ Test1 *[]byte }) {},
FnCtx: func(api.Ctx, struct{ Test1 *[]byte }) {},
Err: nil, Err: nil,
}, },
{ {
@ -192,19 +215,23 @@ func TestInputCheck(t *testing.T) {
"Test1": reflect.TypeOf(new([]rune)), "Test1": reflect.TypeOf(new([]rune)),
}, },
Fn: func(struct{ Test1 *[]rune }) {}, Fn: func(struct{ Test1 *[]rune }) {},
FnCtx: func(api.Ctx, struct{ Test1 *[]rune }) {},
Err: nil, Err: nil,
}, },
} }
for _, tcase := range tcases { for _, tcase := range tcases {
t.Run(tcase.Name, func(t *testing.T) { t.Run(tcase.Name, func(t *testing.T) {
t.Parallel()
// mock spec // mock spec
s := spec{ s := spec{
Input: tcase.Input, Input: tcase.Input,
Output: nil, Output: nil,
} }
err := s.checkInput(reflect.ValueOf(tcase.Fn)) t.Run("with-context", func(t *testing.T) {
err := s.checkInput(reflect.TypeOf(tcase.FnCtx), 1)
if err == nil && tcase.Err != nil { if err == nil && tcase.Err != nil {
t.Errorf("expected an error: '%s'", tcase.Err.Error()) t.Errorf("expected an error: '%s'", tcase.Err.Error())
t.FailNow() t.FailNow()
@ -221,6 +248,25 @@ func TestInputCheck(t *testing.T) {
} }
} }
}) })
t.Run("without-context", func(t *testing.T) {
err := s.checkInput(reflect.TypeOf(tcase.Fn), 0)
if err == nil && tcase.Err != nil {
t.Errorf("expected an error: '%s'", tcase.Err.Error())
t.FailNow()
}
if err != nil && tcase.Err == nil {
t.Errorf("unexpected error: '%s'", err.Error())
t.FailNow()
}
if err != nil && tcase.Err != nil {
if !errors.Is(err, tcase.Err) {
t.Errorf("expected the error <%s> got <%s>", tcase.Err, err)
t.FailNow()
}
}
})
})
} }
} }
@ -322,13 +368,15 @@ func TestOutputCheck(t *testing.T) {
for i, tcase := range tcases { for i, tcase := range tcases {
t.Run(fmt.Sprintf("case.%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("case.%d", i), func(t *testing.T) {
t.Parallel()
// mock spec // mock spec
s := spec{ s := spec{
Input: nil, Input: nil,
Output: tcase.Output, Output: tcase.Output,
} }
err := s.checkOutput(reflect.ValueOf(tcase.Fn)) err := s.checkOutput(reflect.TypeOf(tcase.Fn))
if err == nil && tcase.Err != nil { if err == nil && tcase.Err != nil {
t.Errorf("expected an error: '%s'", tcase.Err.Error()) t.Errorf("expected an error: '%s'", tcase.Err.Error())
t.FailNow() t.FailNow()