diff --git a/internal/reqdata/param_reflect.go b/internal/reqdata/param_reflect.go deleted file mode 100644 index 6cbd403..0000000 --- a/internal/reqdata/param_reflect.go +++ /dev/null @@ -1,88 +0,0 @@ -package reqdata - -import ( - "encoding/json" - "fmt" - "reflect" -) - -// parseParameter parses http GET/POST data -// - []string -// - size = 1 : return json of first element -// - size > 1 : return array of json elements -// - string : return json if valid, else return raw string -func parseParameter(data interface{}) interface{} { - dtype := reflect.TypeOf(data) - dvalue := reflect.ValueOf(data) - - switch dtype.Kind() { - - /* (1) []string -> recursive */ - case reflect.Slice: - - // 1. Return nothing if empty - if dvalue.Len() == 0 { - return nil - } - - // 2. only return first element if alone - if dvalue.Len() == 1 { - - element := dvalue.Index(0) - if element.Kind() != reflect.String { - return nil - } - return parseParameter(element.String()) - - } - - // 3. Return all elements if more than 1 - result := make([]interface{}, dvalue.Len()) - - for i, l := 0, dvalue.Len(); i < l; i++ { - element := dvalue.Index(i) - - // ignore non-string - if element.Kind() != reflect.String { - continue - } - - result[i] = parseParameter(element.String()) - } - return result - - /* (2) string -> parse */ - case reflect.String: - - // build json wrapper - wrapper := fmt.Sprintf("{\"wrapped\":%s}", dvalue.String()) - - // try to parse as json - var result interface{} - err := json.Unmarshal([]byte(wrapper), &result) - - // return if success - if err == nil { - - mapval, ok := result.(map[string]interface{}) - if !ok { - return dvalue.String() - } - - wrapped, ok := mapval["wrapped"] - if !ok { - return dvalue.String() - } - - return wrapped - } - - // else return as string - return dvalue.String() - - } - - /* (3) NIL if unknown type */ - return dvalue - -} diff --git a/internal/reqdata/parameter.go b/internal/reqdata/parameter.go index 79201b9..4bceb95 100644 --- a/internal/reqdata/parameter.go +++ b/internal/reqdata/parameter.go @@ -1,5 +1,11 @@ package reqdata +import ( + "encoding/json" + "fmt" + "reflect" +) + // Parameter represents an http request parameter // that can be of type URL, GET, or FORM (multipart, json, urlencoded) type Parameter struct { @@ -27,3 +33,84 @@ func (i *Parameter) Parse() { i.Value = parseParameter(i.Value) } + +// parseParameter parses http GET/POST data +// - []string +// - size = 1 : return json of first element +// - size > 1 : return array of json elements +// - string : return json if valid, else return raw string +func parseParameter(data interface{}) interface{} { + dtype := reflect.TypeOf(data) + dvalue := reflect.ValueOf(data) + + switch dtype.Kind() { + + /* (1) []string -> recursive */ + case reflect.Slice: + + // 1. Return nothing if empty + if dvalue.Len() == 0 { + return nil + } + + // 2. only return first element if alone + if dvalue.Len() == 1 { + + element := dvalue.Index(0) + if element.Kind() != reflect.String { + return nil + } + return parseParameter(element.String()) + + } + + // 3. Return all elements if more than 1 + result := make([]interface{}, dvalue.Len()) + + for i, l := 0, dvalue.Len(); i < l; i++ { + element := dvalue.Index(i) + + // ignore non-string + if element.Kind() != reflect.String { + continue + } + + result[i] = parseParameter(element.String()) + } + return result + + /* (2) string -> parse */ + case reflect.String: + + // build json wrapper + wrapper := fmt.Sprintf("{\"wrapped\":%s}", dvalue.String()) + + // try to parse as json + var result interface{} + err := json.Unmarshal([]byte(wrapper), &result) + + // return if success + if err == nil { + + mapval, ok := result.(map[string]interface{}) + if !ok { + return dvalue.String() + } + + wrapped, ok := mapval["wrapped"] + if !ok { + return dvalue.String() + } + + return wrapped + } + + // else return as string + return dvalue.String() + + } + + /* (3) NIL if unknown type */ + return dvalue + +} diff --git a/internal/reqdata/store.go b/internal/reqdata/store.go index 514d4d4..0a4d81e 100644 --- a/internal/reqdata/store.go +++ b/internal/reqdata/store.go @@ -57,6 +57,11 @@ func New(uriParams []string, req *http.Request) *Store { // 1. set URI parameters ds.setURIParams(uriParams) + // ignore nil requests + if req == nil { + return ds + } + // 2. GET (query) data ds.readQuery(req) diff --git a/internal/reqdata/store_test.go b/internal/reqdata/store_test.go new file mode 100644 index 0000000..6866a8f --- /dev/null +++ b/internal/reqdata/store_test.go @@ -0,0 +1,547 @@ +package reqdata + +import ( + "bytes" + "fmt" + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +func TestEmptyStore(t *testing.T) { + store := New(nil, nil) + + if store.URI == nil { + t.Errorf("store 'URI' list should be initialized") + t.Fail() + } + if len(store.URI) != 0 { + t.Errorf("store 'URI' list should be empty") + t.Fail() + } + + if store.Get == nil { + t.Errorf("store 'Get' map should be initialized") + t.Fail() + } + if store.Form == nil { + t.Errorf("store 'Form' map should be initialized") + t.Fail() + } + if store.Set == nil { + t.Errorf("store 'Set' map should be initialized") + t.Fail() + } +} + +func TestStoreWithUri(t *testing.T) { + urilist := []string{"abc", "def"} + store := New(urilist, nil) + + if len(store.URI) != len(urilist) { + t.Errorf("store 'Set' should contain %d elements (got %d)", len(urilist), len(store.URI)) + t.Fail() + } + if len(store.Set) != len(urilist) { + t.Errorf("store 'Set' should contain %d elements (got %d)", len(urilist), len(store.Set)) + t.Fail() + } + + for i, value := range urilist { + + t.Run(fmt.Sprintf("URL#%d='%s'", i, value), func(t *testing.T) { + key := fmt.Sprintf("URL#%d", i) + element, isset := store.Set[key] + + if !isset { + t.Errorf("store should contain element with key '%s'", key) + t.Failed() + } + + if element.Value != value { + t.Errorf("store[%s] should return '%s' (got '%s')", key, value, element.Value) + t.Failed() + } + }) + + } + +} + +func TestStoreWithGet(t *testing.T) { + tests := []struct { + Query string + + InvalidNames []string + ParamNames []string + ParamValues [][]string + }{ + { + Query: "", + InvalidNames: []string{}, + ParamNames: []string{}, + ParamValues: [][]string{}, + }, + { + Query: "a", + InvalidNames: []string{}, + ParamNames: []string{"a"}, + ParamValues: [][]string{[]string{""}}, + }, + { + Query: "a&b", + InvalidNames: []string{}, + ParamNames: []string{"a", "b"}, + ParamValues: [][]string{[]string{""}, []string{""}}, + }, + { + Query: "a=", + InvalidNames: []string{}, + ParamNames: []string{"a"}, + ParamValues: [][]string{[]string{""}}, + }, + { + Query: "a=&b=x", + InvalidNames: []string{}, + ParamNames: []string{"a", "b"}, + ParamValues: [][]string{[]string{""}, []string{"x"}}, + }, + { + Query: "a=b&c=d", + InvalidNames: []string{}, + ParamNames: []string{"a", "c"}, + ParamValues: [][]string{[]string{"b"}, []string{"d"}}, + }, + { + Query: "a=b&c=d&a=x", + InvalidNames: []string{}, + ParamNames: []string{"a", "c"}, + ParamValues: [][]string{[]string{"b", "x"}, []string{"d"}}, + }, + { + Query: "a=b&_invalid=x", + InvalidNames: []string{"_invalid"}, + ParamNames: []string{"a", "_invalid"}, + ParamValues: [][]string{[]string{"b"}, []string{""}}, + }, + { + Query: "a=b&invalid_=x", + InvalidNames: []string{"invalid_"}, + ParamNames: []string{"a", "invalid_"}, + ParamValues: [][]string{[]string{"b"}, []string{""}}, + }, + { + Query: "a=b&GET@injection=x", + InvalidNames: []string{"GET@injection"}, + ParamNames: []string{"a", "GET@injection"}, + ParamValues: [][]string{[]string{"b"}, []string{""}}, + }, + { // not really useful as all after '#' should be ignored by http clients + Query: "a=b&URL#injection=x", + InvalidNames: []string{"URL#injection"}, + ParamNames: []string{"a", "URL#injection"}, + ParamValues: [][]string{[]string{"b"}, []string{""}}, + }, + } + + for i, test := range tests { + t.Run(fmt.Sprintf("request.%d", i), func(t *testing.T) { + + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://host.com?%s", test.Query), nil) + store := New(nil, req) + + if test.ParamNames == nil || test.ParamValues == nil { + if len(store.Set) != 0 { + t.Errorf("expected no GET parameters and got %d", len(store.Get)) + t.Failed() + } + + // no param to check + return + } + + if len(test.ParamNames) != len(test.ParamValues) { + t.Errorf("invalid test: names and values differ in size (%d vs %d)", len(test.ParamNames), len(test.ParamValues)) + t.Failed() + } + + for pi, pName := range test.ParamNames { + key := fmt.Sprintf("GET@%s", pName) + values := test.ParamValues[pi] + + isNameValid := true + for _, invalid := range test.InvalidNames { + if pName == invalid { + isNameValid = false + } + } + + t.Run(key, func(t *testing.T) { + + param, isset := store.Set[key] + if !isset { + if isNameValid { + t.Errorf("store should contain element with key '%s'", key) + t.Failed() + } + return + } + + // if should be invalid + if isset && !isNameValid { + t.Errorf("store should NOT contain element with key '%s' (invalid name)", key) + t.Failed() + } + + cast, canCast := param.Value.([]string) + + if !canCast { + t.Errorf("should return a []string (got '%v')", cast) + t.Failed() + } + + if len(cast) != len(values) { + t.Errorf("should return %d string(s) (got '%d')", len(values), len(cast)) + t.Failed() + } + + for vi, value := range values { + + t.Run(fmt.Sprintf("value.%d", vi), func(t *testing.T) { + if value != cast[vi] { + t.Errorf("should return '%s' (got '%s')", value, cast[vi]) + t.Failed() + } + }) + } + }) + + } + }) + } + +} + +func TestStoreWithUrlEncodedForm(t *testing.T) { + tests := []struct { + URLEncoded string + + InvalidNames []string + ParamNames []string + ParamValues [][]string + }{ + { + URLEncoded: "", + InvalidNames: []string{}, + ParamNames: []string{}, + ParamValues: [][]string{}, + }, + { + URLEncoded: "a", + InvalidNames: []string{}, + ParamNames: []string{"a"}, + ParamValues: [][]string{[]string{""}}, + }, + { + URLEncoded: "a&b", + InvalidNames: []string{}, + ParamNames: []string{"a", "b"}, + ParamValues: [][]string{[]string{""}, []string{""}}, + }, + { + URLEncoded: "a=", + InvalidNames: []string{}, + ParamNames: []string{"a"}, + ParamValues: [][]string{[]string{""}}, + }, + { + URLEncoded: "a=&b=x", + InvalidNames: []string{}, + ParamNames: []string{"a", "b"}, + ParamValues: [][]string{[]string{""}, []string{"x"}}, + }, + { + URLEncoded: "a=b&c=d", + InvalidNames: []string{}, + ParamNames: []string{"a", "c"}, + ParamValues: [][]string{[]string{"b"}, []string{"d"}}, + }, + { + URLEncoded: "a=b&c=d&a=x", + InvalidNames: []string{}, + ParamNames: []string{"a", "c"}, + ParamValues: [][]string{[]string{"b", "x"}, []string{"d"}}, + }, + { + URLEncoded: "a=b&_invalid=x", + InvalidNames: []string{"_invalid"}, + ParamNames: []string{"a", "_invalid"}, + ParamValues: [][]string{[]string{"b"}, []string{""}}, + }, + { + URLEncoded: "a=b&invalid_=x", + InvalidNames: []string{"invalid_"}, + ParamNames: []string{"a", "invalid_"}, + ParamValues: [][]string{[]string{"b"}, []string{""}}, + }, + { + URLEncoded: "a=b&GET@injection=x", + InvalidNames: []string{"GET@injection"}, + ParamNames: []string{"a", "GET@injection"}, + ParamValues: [][]string{[]string{"b"}, []string{""}}, + }, + { + URLEncoded: "a=b&URL#injection=x", + InvalidNames: []string{"URL#injection"}, + ParamNames: []string{"a", "URL#injection"}, + ParamValues: [][]string{[]string{"b"}, []string{""}}, + }, + } + + for i, test := range tests { + t.Run(fmt.Sprintf("request.%d", i), func(t *testing.T) { + body := bytes.NewBufferString(test.URLEncoded) + req := httptest.NewRequest(http.MethodPost, "http://host.com", body) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + defer req.Body.Close() + store := New(nil, req) + + if test.ParamNames == nil || test.ParamValues == nil { + if len(store.Set) != 0 { + t.Errorf("expected no FORM parameters and got %d", len(store.Get)) + t.Failed() + } + + // no param to check + return + } + + if len(test.ParamNames) != len(test.ParamValues) { + t.Errorf("invalid test: names and values differ in size (%d vs %d)", len(test.ParamNames), len(test.ParamValues)) + t.Failed() + } + + for pi, pName := range test.ParamNames { + key := pName + values := test.ParamValues[pi] + + isNameValid := true + for _, invalid := range test.InvalidNames { + if pName == invalid { + isNameValid = false + } + } + + t.Run(key, func(t *testing.T) { + + param, isset := store.Set[key] + if !isset { + if isNameValid { + t.Errorf("store should contain element with key '%s'", key) + t.Failed() + } + return + } + + // if should be invalid + if isset && !isNameValid { + t.Errorf("store should NOT contain element with key '%s' (invalid name)", key) + t.Failed() + } + + cast, canCast := param.Value.([]string) + + if !canCast { + t.Errorf("should return a []string (got '%v')", cast) + t.Failed() + } + + if len(cast) != len(values) { + t.Errorf("should return %d string(s) (got '%d')", len(values), len(cast)) + t.Failed() + } + + for vi, value := range values { + + t.Run(fmt.Sprintf("value.%d", vi), func(t *testing.T) { + if value != cast[vi] { + t.Errorf("should return '%s' (got '%s')", value, cast[vi]) + t.Failed() + } + }) + } + }) + + } + }) + } + +} + +func TestJsonParameters(t *testing.T) { + tests := []struct { + RawJson string + + InvalidNames []string + ParamNames []string + ParamValues []interface{} + }{ + // no need to fully check json because it is parsed with the standard library + { + RawJson: "", + InvalidNames: []string{}, + ParamNames: []string{}, + ParamValues: []interface{}{}, + }, + { + RawJson: "{}", + InvalidNames: []string{}, + ParamNames: []string{}, + ParamValues: []interface{}{}, + }, + { + RawJson: "{ \"a\": \"b\" }", + InvalidNames: []string{}, + ParamNames: []string{"a"}, + ParamValues: []interface{}{"b"}, + }, + { + RawJson: "{ \"a\": \"b\", \"c\": \"d\" }", + InvalidNames: []string{}, + ParamNames: []string{"a", "c"}, + ParamValues: []interface{}{"b", "d"}, + }, + { + RawJson: "{ \"_invalid\": \"x\" }", + InvalidNames: []string{"_invalid"}, + ParamNames: []string{"_invalid"}, + ParamValues: []interface{}{nil}, + }, + { + RawJson: "{ \"a\": \"b\", \"_invalid\": \"x\" }", + InvalidNames: []string{"_invalid"}, + ParamNames: []string{"a", "_invalid"}, + ParamValues: []interface{}{"b", nil}, + }, + + { + RawJson: "{ \"invalid_\": \"x\" }", + InvalidNames: []string{"invalid_"}, + ParamNames: []string{"invalid_"}, + ParamValues: []interface{}{nil}, + }, + { + RawJson: "{ \"a\": \"b\", \"invalid_\": \"x\" }", + InvalidNames: []string{"invalid_"}, + ParamNames: []string{"a", "invalid_"}, + ParamValues: []interface{}{"b", nil}, + }, + + { + RawJson: "{ \"GET@injection\": \"x\" }", + InvalidNames: []string{"GET@injection"}, + ParamNames: []string{"GET@injection"}, + ParamValues: []interface{}{nil}, + }, + { + RawJson: "{ \"a\": \"b\", \"GET@injection\": \"x\" }", + InvalidNames: []string{"GET@injection"}, + ParamNames: []string{"a", "GET@injection"}, + ParamValues: []interface{}{"b", nil}, + }, + + { + RawJson: "{ \"URL#injection\": \"x\" }", + InvalidNames: []string{"URL#injection"}, + ParamNames: []string{"URL#injection"}, + ParamValues: []interface{}{nil}, + }, + { + RawJson: "{ \"a\": \"b\", \"URL#injection\": \"x\" }", + InvalidNames: []string{"URL#injection"}, + ParamNames: []string{"a", "URL#injection"}, + ParamValues: []interface{}{"b", nil}, + }, + // json parse error + { + RawJson: "{ \"a\": \"b\", }", + InvalidNames: []string{}, + ParamNames: []string{}, + ParamValues: []interface{}{}, + }, + } + + for i, test := range tests { + t.Run(fmt.Sprintf("request.%d", i), func(t *testing.T) { + body := bytes.NewBufferString(test.RawJson) + req := httptest.NewRequest(http.MethodPost, "http://host.com", body) + req.Header.Add("Content-Type", "application/json") + defer req.Body.Close() + store := New(nil, req) + + if test.ParamNames == nil || test.ParamValues == nil { + if len(store.Set) != 0 { + t.Errorf("expected no JSON parameters and got %d", len(store.Get)) + t.Failed() + } + + // no param to check + return + } + + if len(test.ParamNames) != len(test.ParamValues) { + t.Errorf("invalid test: names and values differ in size (%d vs %d)", len(test.ParamNames), len(test.ParamValues)) + t.Failed() + } + + for pi, pName := range test.ParamNames { + key := pName + value := test.ParamValues[pi] + + isNameValid := true + for _, invalid := range test.InvalidNames { + if pName == invalid { + isNameValid = false + } + } + + t.Run(key, func(t *testing.T) { + + param, isset := store.Set[key] + if !isset { + if isNameValid { + t.Errorf("store should contain element with key '%s'", key) + t.Failed() + } + return + } + + // if should be invalid + if isset && !isNameValid { + t.Errorf("store should NOT contain element with key '%s' (invalid name)", key) + t.Failed() + } + + valueType := reflect.TypeOf(value) + + paramValue := param.Value + paramValueType := reflect.TypeOf(param.Value) + + if valueType != paramValueType { + t.Errorf("should be of type %v (got '%v')", valueType, paramValueType) + t.Failed() + } + + if paramValue != value { + t.Errorf("should return %v (got '%v')", value, paramValue) + t.Failed() + } + + }) + + } + }) + } + +}