diff --git a/internal/cnf/loader.go b/internal/cnf/loader.go index 81b1808..808017f 100644 --- a/internal/cnf/loader.go +++ b/internal/cnf/loader.go @@ -71,6 +71,8 @@ func loadFromExtension(ext string) ConfigurationFormat { return new(json) case ".ini": return new(ini) + case ".yaml": + return new(yaml) case ".conf": return new(confl) default: diff --git a/internal/cnf/yaml.go b/internal/cnf/yaml.go new file mode 100644 index 0000000..87c9440 --- /dev/null +++ b/internal/cnf/yaml.go @@ -0,0 +1,136 @@ +package cnf + +import ( + lib "gopkg.in/yaml.v2" + "io" + "strings" +) + +type yaml struct { + data interface{} + parsed bool +} + +// ReadFrom implements io.ReaderFrom +func (d *yaml) ReadFrom(_reader io.Reader) (int64, error) { + + // 1. get yaml decoder + decoder := lib.NewDecoder(_reader) + err := decoder.Decode(&d.data) + if err != nil { + return 0, err + } + d.parsed = true + return 0, nil + +} + +// WriteTo implements io.WriterTo +func (d *yaml) WriteTo(_writer io.Writer) (int64, error) { + encoder := lib.NewEncoder(_writer) + defer encoder.Close() + return 0, encoder.Encode(&d.data) +} + +// browse returns the target of a dot-separated path (as an interface{} chain where the last is the target if found) +func (d *yaml) browse(_path string) ([]interface{}, bool) { + + // 1. extract path + path := strings.Split(_path, ".") + + // 2. init output chain + current := d.data + chain := make([]interface{}, 0, len(path)+1) + chain = append(chain, current) + + // 3. iterate over path / nested fields + for _, field := range path { + fmap, ok := current.(map[interface{}]interface{}) + if !ok { // incomplete path + return chain, false + } + + child, ok := fmap[field] + if !ok { // incomplete path + return chain, false + } + + current = child + chain = append(chain, current) + } + + return chain, true +} + +// Get returns the value of a dot-separated path, and if it exists +func (d *yaml) Get(_path string) (string, bool) { + + // 1. browse path + chain, found := d.browse(_path) + if !found { + return "", false + } + + // 2. return if string value + value, ok := chain[len(chain)-1].(string) + return value, ok + +} + +// Set the value of a dot-separated path, and creates it if not found +func (d *yaml) Set(_path, _value string) bool { + + // 1. browse path + create it if does not exist + path := strings.Split(_path, ".") + lp := len(path) + chain, found := d.browse(_path) + + // 2. if found -> set value + if found { + mapWrapper, ok := chain[len(chain)-2].(map[interface{}]interface{}) + if !ok { // impossible + return false + } + key := path[lp-1] + mapWrapper[key] = _value + return true + } + + // 3. create path until the end to set value + root := make(map[interface{}]interface{}) + current := root + + // create children until second to last + for i, l := len(chain)-1, lp-1; i < l; i++ { + child := make(map[interface{}]interface{}) + current[path[i]] = child + current = child + } + + // set value + current[path[lp-1]] = _value + + // replace whole object if empty + if len(chain) < 2 { + wrapper, ok := d.data.(map[interface{}]interface{}) + if !ok { // impossible + return false + } + key := path[0] + wrapper[key] = root[key] // store with key ; eitherway it will erase all brother keys + return true + } + + // update last found object to add the value + wrapper, ok := chain[len(chain)-1].(map[interface{}]interface{}) + if !ok { // impossible + return false + } + + // add each subkey + for subkey, subvalue := range root { + wrapper[subkey] = subvalue + } + return true + +} diff --git a/internal/cnf/yaml_test.go b/internal/cnf/yaml_test.go new file mode 100644 index 0000000..a569a9b --- /dev/null +++ b/internal/cnf/yaml_test.go @@ -0,0 +1,192 @@ +package cnf + +import ( + "bytes" + "testing" +) + +func TestYamlGet(t *testing.T) { + + tests := []struct { + raw string + key string + }{ + {"key: value\n", "key"}, + {"ignore: xxx\nkey: value\n", "key"}, + {"parent:\n child: value\n", "parent.child"}, + {"ignore: xxx\nparent:\n child: value\n", "parent.child"}, + } + + for _, test := range tests { + + parser := new(yaml) + reader := bytes.NewBufferString(test.raw) + + // try to extract value + _, err := parser.ReadFrom(reader) + if err != nil { + t.Errorf("parse error: %s", err) + continue + } + + // extract value + value, found := parser.Get(test.key) + if !found { + t.Errorf("expected a result, got none") + continue + } + + // check value + if value != "value" { + t.Errorf("expected 'value' got '%s'", value) + } + + } + +} + +func TestYamlGetNotString(t *testing.T) { + + tests := []struct { + raw string + key string + }{ + {"key:\n- value\n", "key"}, + {"key:\n subkey: value\n", "key"}, + {"parent:\n child:\n - value\n", "parent.child"}, + {"parent:\n child:\n subkey: value\n", "parent.child"}, + } + + for i, test := range tests { + + parser := new(yaml) + reader := bytes.NewBufferString(test.raw) + + // try to extract value + _, err := parser.ReadFrom(reader) + if err != nil { + t.Errorf("[%d] parse error: %s", i, err) + continue + } + + // extract value + value, found := parser.Get(test.key) + if found || len(value) > 0 { + t.Errorf("[%d] expected no result, got '%s'", i, value) + continue + } + + } + +} + +func TestYamlSetPathExistsAndIsString(t *testing.T) { + + tests := []struct { + raw string + key string + value string + }{ + {"key: value\n", "key", "newvalue"}, + {"ignore: xxx\nkey: value\n", "key", "newvalue"}, + {"parent:\n child: value\n", "parent.child", "newvalue"}, + {"ignore: xxx\nparent:\n child: value\n", "parent.child", "newvalue"}, + } + + for i, test := range tests { + + parser := new(yaml) + reader := bytes.NewBufferString(test.raw) + + // try to extract value + _, err := parser.ReadFrom(reader) + if err != nil { + t.Errorf("parse error: %s", err) + continue + } + + // update value + if !parser.Set(test.key, test.value) { + t.Errorf("[%d] cannot set '%s' to '%s'", i, test.key, test.value) + continue + } + + // check new value + value, found := parser.Get(test.key) + if !found { + t.Errorf("[%d] expected a result, got none", i) + continue + } + + // check value + if value != test.value { + t.Errorf("[%d] expected '%s' got '%s'", i, test.value, value) + } + + } + +} + +func TestYamlSetCreatePath(t *testing.T) { + + tests := []struct { + raw string + key string + ignore string // path to field that must be present after transformation + value string + }{ + {"ignore: xxx\n", "key", "ignore", "newvalue"}, + {"ignore: xxx\n", "parent.child", "ignore", "newvalue"}, + {"ignore: xxx\n", "parent.child.subchild", "ignore", "newvalue"}, + + {"ignore: xxx\n", "key", "ignore", "newvalue"}, + {"parent:\n ignore: xxx\n", "parent.child", "parent.ignore", "newvalue"}, + {"ignore: xxx\nparent: {}\n", "parent.child", "ignore", "newvalue"}, + } + + for i, test := range tests { + + parser := new(yaml) + reader := bytes.NewBufferString(test.raw) + + // try to extract value + _, err := parser.ReadFrom(reader) + if err != nil { + t.Errorf("[%d] parse error: %s", i, err) + continue + } + + // update value + if !parser.Set(test.key, test.value) { + t.Errorf("[%d] cannot set '%s' to '%s'", i, test.key, test.value) + continue + } + + // check new value + value, found := parser.Get(test.key) + if !found { + t.Errorf("[%d] expected a result, got none", i) + continue + } + + // check value + if value != test.value { + t.Errorf("[%d] expected '%s' got '%s'", i, test.value, value) + continue + } + + // check that ignore field is still there + value, found = parser.Get(test.ignore) + if !found { + t.Errorf("[%d] expected ignore field, got none", i) + continue + } + + // check value + if value != "xxx" { + t.Errorf("[%d] expected ignore value to be '%s' got '%s'", i, "xxx", value) + continue + } + } + +}