diff --git a/storage.go b/storage.go new file mode 100644 index 0000000..9bac442 --- /dev/null +++ b/storage.go @@ -0,0 +1,138 @@ +package nn + +import ( + "encoding/json" + "fmt" + "gonum.org/v1/gonum/mat" + "io" + "io/ioutil" +) + +type jsonNetwork struct { + Layers []uint `json:"layers"` + Biases [][]float64 `json:"biases"` + Weights [][]float64 `json:"weights"` +} + +// MarshallJSON implements the json.Marshaler interface +func (net *Network) MarshalJSON() ([]byte, error) { + + raw := new(jsonNetwork) + + // 1. Layers + raw.Layers = net.layers + + // 2. Biases + raw.Biases = make([][]float64, 0) + for _, bias := range net.Biases { + + vector := bias.ColView(0) + biasJSON := make([]float64, 0) + + for i, l := 0, vector.Len(); i < l; i++ { + biasJSON = append(biasJSON, vector.AtVec(i)) + } + + raw.Biases = append(raw.Biases, biasJSON) + } + + // 3. Weights + raw.Weights = make([][]float64, 0) + for _, weight := range net.Weights { + + rows, cols := weight.Dims() + weightJSON := make([]float64, 0) + + for row := 0; row < rows; row++ { + for col := 0; col < cols; col++ { + weightJSON = append(weightJSON, weight.At(row, col)) + } + } + + raw.Weights = append(raw.Weights, weightJSON) + } + + return json.Marshal(raw) + +} + +// UnmarshalJSON implements the json.Unmarshaler interface +func (net *Network) UnmarshalJSON(in []byte) error { + + // parse as 'jsonNetwork' struct + raw := new(jsonNetwork) + err := json.Unmarshal(in, raw) + if err != nil { + return err + } + + net.fed = false + net.layers = raw.Layers + + // extract biases + net.Biases = make([]*mat.Dense, 0) + for l, layer := range net.layers { + if l == 0 { + continue + } + net.Biases = append(net.Biases, mat.NewDense(int(layer), 1, raw.Biases[l-1])) + } + // extract weights + net.Weights = make([]*mat.Dense, 0) + for l, layer := range net.layers { + if l == 0 { + continue + } + net.Weights = append(net.Weights, mat.NewDense(int(layer), int(net.layers[l-1]), raw.Weights[l-1])) + } + + // mockup neurons + net.Neurons = make([]*mat.Dense, 0) + for _, layer := range net.layers { + net.Neurons = append(net.Neurons, mat.NewDense(int(layer), 1, nil)) + } + fmt.Printf("neurons: %v\n", net.Neurons) + + // extract into the current network receiver (net) + return nil +} + +// WriteTo implements the io.WriterTo interface +func (net *Network) WriteTo(w io.Writer) (int64, error) { + + // get json + raw, err := json.Marshal(net) + if err != nil { + return 0, nil + } + + // write to file + written, err := w.Write(raw) + return int64(written), err +} + +// ReadFrom implements the io.ReaderFrom interface +func (net *Network) ReadFrom(r io.Reader) (int64, error) { + + // read + raw, err := ioutil.ReadAll(r) + if err != nil { + return 0, err + } + + // parse json + readNet := new(Network) + err = json.Unmarshal(raw, readNet) + if err != nil { + return 0, err + } + + // copy values + net.layers = readNet.layers + net.fed = readNet.fed + net.Neurons = readNet.Neurons + net.Biases = readNet.Biases + net.Weights = readNet.Weights + + return int64(len(raw)), nil +}