package nn

import (
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"

	"gonum.org/v1/gonum/mat"
)

// jsonNetwork is the JSON representation of a Network.
//
// Biases and Weights hold one entry per non-input layer: entry l-1 belongs to
// layer l. Each weight matrix is flattened row by row.
type jsonNetwork struct {
	Layers  []uint      `json:"layers"`
	Biases  [][]float64 `json:"biases"`
	Weights [][]float64 `json:"weights"`
}

// MarshalJSON implements the json.Marshaler interface.
//
// The layer sizes are written as-is, the first column of every bias matrix is
// written as a flat list, and every weight matrix is written in row-major
// order.
func (net *Network) MarshalJSON() ([]byte, error) {
	raw := jsonNetwork{
		Layers: net.layers,
		// Non-nil slices so an empty network encodes as [] rather than null.
		Biases:  make([][]float64, 0, len(net.biases)),
		Weights: make([][]float64, 0, len(net.weights)),
	}

	for _, bias := range net.biases {
		column := bias.ColView(0)
		n := column.Len()
		values := make([]float64, 0, n)
		for i := 0; i < n; i++ {
			values = append(values, column.AtVec(i))
		}
		raw.Biases = append(raw.Biases, values)
	}

	for _, weight := range net.weights {
		rows, cols := weight.Dims()
		values := make([]float64, 0, rows*cols)
		for row := 0; row < rows; row++ {
			for col := 0; col < cols; col++ {
				values = append(values, weight.At(row, col))
			}
		}
		raw.Weights = append(raw.Weights, values)
	}

	return json.Marshal(&raw)
}

// UnmarshalJSON implements the json.Unmarshaler interface.
//
// The input is validated before any matrix is built: every layer must have a
// usable size, there must be a bias and a weight entry for every non-input
// layer, and each entry must contain exactly as many values as the layer
// sizes require. Malformed input yields an error instead of a panic, and the
// receiver is only modified once the whole document has been accepted.
// Surplus bias or weight entries are ignored.
func (net *Network) UnmarshalJSON(in []byte) error {
	var raw jsonNetwork
	if err := json.Unmarshal(in, &raw); err != nil {
		return err
	}

	// Reject sizes that cannot be used as a matrix dimension: zero, or a
	// value that does not fit in an int.
	// NOTE(review): recent gonum releases panic in mat.NewDense on a zero
	// dimension - confirm against the gonum version this module pins.
	sizes := make([]int, len(raw.Layers))
	for l, layer := range raw.Layers {
		size := int(layer)
		if size <= 0 {
			return fmt.Errorf("nn: layer %d has invalid size %d", l, layer)
		}
		sizes[l] = size
	}

	// Every layer except the input layer owns one bias and one weight entry.
	links := 0
	if len(sizes) > 0 {
		links = len(sizes) - 1
	}
	if len(raw.Biases) < links {
		return fmt.Errorf("nn: got %d bias entries, want %d", len(raw.Biases), links)
	}
	if len(raw.Weights) < links {
		return fmt.Errorf("nn: got %d weight entries, want %d", len(raw.Weights), links)
	}

	biases := make([]*mat.Dense, 0, links)
	weights := make([]*mat.Dense, 0, links)
	for l := 1; l < len(sizes); l++ {
		rows, cols := sizes[l], sizes[l-1]

		bias := raw.Biases[l-1]
		if len(bias) != rows {
			return fmt.Errorf("nn: bias entry %d has %d values, want %d", l-1, len(bias), rows)
		}
		weight := raw.Weights[l-1]
		if len(weight) != rows*cols {
			return fmt.Errorf("nn: weight entry %d has %d values, want %d", l-1, len(weight), rows*cols)
		}

		biases = append(biases, mat.NewDense(rows, 1, bias))
		weights = append(weights, mat.NewDense(rows, cols, weight))
	}

	// Neuron values are not part of the JSON document; allocate zeroed
	// placeholders, one column per layer.
	neurons := make([]*mat.Dense, 0, len(sizes))
	for _, size := range sizes {
		neurons = append(neurons, mat.NewDense(size, 1, nil))
	}

	net.fed = false
	net.layers = raw.Layers
	net.biases = biases
	net.weights = weights
	net.neurons = neurons
	return nil
}

// WriteTo implements the io.WriterTo interface.
//
// It encodes the network as JSON and writes it to w with a single Write call.
// It returns the number of bytes written and the first error encountered,
// whether from encoding or from writing.
func (net *Network) WriteTo(w io.Writer) (int64, error) {
	raw, err := json.Marshal(net)
	if err != nil {
		// Nothing was written; report the encoding failure to the caller.
		return 0, err
	}

	written, err := w.Write(raw)
	return int64(written), err
}

// ReadFrom implements the io.ReaderFrom interface.
//
// It reads r until EOF, decodes the bytes as a JSON network and, on success,
// replaces the receiver's layers, fed flag, neurons, biases and weights. The
// returned count is the number of bytes consumed from r, including when
// reading or decoding fails. The receiver is left untouched on error.
func (net *Network) ReadFrom(r io.Reader) (int64, error) {
	raw, err := ioutil.ReadAll(r)
	if err != nil {
		return int64(len(raw)), err
	}

	// Decode into a scratch value so a failed decode cannot affect net.
	readNet := new(Network)
	if err := json.Unmarshal(raw, readNet); err != nil {
		return int64(len(raw)), err
	}

	net.layers = readNet.layers
	net.fed = readNet.fed
	net.neurons = readNet.neurons
	net.biases = readNet.biases
	net.weights = readNet.weights
	return int64(len(raw)), nil
}