neuralnet/storage.go

137 lines
2.7 KiB
Go
Raw Normal View History

package nn
import (
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"

	"gonum.org/v1/gonum/mat"
)
// jsonNetwork is the serialized form of a Network, shared by MarshalJSON and
// UnmarshalJSON. All gonum matrices are flattened into plain float slices so
// encoding/json can handle them directly.
type jsonNetwork struct {
	// Layers holds the neuron count of every layer, input layer first.
	Layers []uint `json:"layers"`

	// Biases holds one bias vector per layer after the input layer.
	Biases [][]float64 `json:"biases"`

	// Weights holds one weight matrix per layer after the input layer,
	// flattened in row-major order.
	Weights [][]float64 `json:"weights"`
}
// MarshalJSON implements the json.Marshaler interface.
//
// The network is encoded as a jsonNetwork: the layer sizes, then the first
// column of every bias matrix as a flat slice, then every weight matrix
// flattened in row-major order. UnmarshalJSON reads the same layout back.
// Like json.Marshal, it fails if any parameter is NaN or ±Inf, since JSON
// cannot represent those values.
func (net *Network) MarshalJSON() ([]byte, error) {
	raw := &jsonNetwork{
		Layers:  net.layers,
		Biases:  make([][]float64, 0, len(net.biases)),
		Weights: make([][]float64, 0, len(net.weights)),
	}

	// Biases are stored as column vectors; mat.Col copies column 0 into a
	// freshly allocated slice of the right length.
	for _, bias := range net.biases {
		raw.Biases = append(raw.Biases, mat.Col(nil, 0, bias))
	}

	// Weights are flattened row by row into a slice sized up front.
	for _, weight := range net.weights {
		rows, cols := weight.Dims()
		flat := make([]float64, 0, rows*cols)
		for row := 0; row < rows; row++ {
			for col := 0; col < cols; col++ {
				flat = append(flat, weight.At(row, col))
			}
		}
		raw.Weights = append(raw.Weights, flat)
	}

	return json.Marshal(raw)
}
// UnmarshalJSON implements the json.Unmarshaler interface.
//
// It decodes the layout written by MarshalJSON (see jsonNetwork), then
// replaces the receiver's layers, biases and weights. It also allocates
// zeroed neuron buffers and clears the fed flag. The decoded data is
// validated before net is modified. Bad layer sizes, or slices that are
// missing or have the wrong length, produce an error instead of a gonum
// panic, and net is left unchanged.
func (net *Network) UnmarshalJSON(in []byte) error {
	raw := new(jsonNetwork)
	if err := json.Unmarshal(in, raw); err != nil {
		return err
	}

	// Convert and check every layer size up front. mat.NewDense panics on a
	// zero dimension, and a uint above math.MaxInt turns negative in int().
	sizes := make([]int, len(raw.Layers))
	for l, layer := range raw.Layers {
		size := int(layer)
		if size <= 0 {
			return fmt.Errorf("nn: layer %d has invalid size %d", l, layer)
		}
		sizes[l] = size
	}

	// Every layer after the input layer needs one bias vector and one weight
	// matrix; surplus entries are ignored.
	links := 0
	if len(sizes) > 0 {
		links = len(sizes) - 1
	}
	if len(raw.Biases) < links {
		return fmt.Errorf("nn: expected %d bias vectors, got %d", links, len(raw.Biases))
	}
	if len(raw.Weights) < links {
		return fmt.Errorf("nn: expected %d weight matrices, got %d", links, len(raw.Weights))
	}

	biases := make([]*mat.Dense, 0, links)
	weights := make([]*mat.Dense, 0, links)
	for l := 1; l < len(sizes); l++ {
		rows, cols := sizes[l], sizes[l-1]

		// A JSON null decodes to a nil slice, which mat.NewDense accepts and
		// backs with fresh zeroed storage. Any non-nil slice must match the
		// matrix shape exactly, otherwise mat.NewDense would panic.
		bias := raw.Biases[l-1]
		if bias != nil && len(bias) != rows {
			return fmt.Errorf("nn: layer %d: expected %d biases, got %d", l, rows, len(bias))
		}
		weight := raw.Weights[l-1]
		if weight != nil && len(weight) != rows*cols {
			return fmt.Errorf("nn: layer %d: expected %d weights, got %d", l, rows*cols, len(weight))
		}

		biases = append(biases, mat.NewDense(rows, 1, bias))
		weights = append(weights, mat.NewDense(rows, cols, weight))
	}

	// One zeroed column vector of neuron values per layer.
	neurons := make([]*mat.Dense, 0, len(sizes))
	for _, size := range sizes {
		neurons = append(neurons, mat.NewDense(size, 1, nil))
	}

	// Everything is valid: commit to the receiver in one go.
	net.fed = false
	net.layers = raw.Layers
	net.biases = biases
	net.weights = weights
	net.neurons = neurons
	return nil
}
// WriteTo implements the io.WriterTo interface.
//
// It writes the JSON encoding of the network (as produced by MarshalJSON) to
// w in a single Write call. It returns the number of bytes written and any
// error from encoding or writing. If encoding fails, nothing is written.
func (net *Network) WriteTo(w io.Writer) (int64, error) {
	raw, err := json.Marshal(net)
	if err != nil {
		// Report encoding failures (e.g. NaN parameters) rather than a
		// silent zero-byte "success".
		return 0, err
	}

	written, err := w.Write(raw)
	return int64(written), err
}
// ReadFrom implements the io.ReaderFrom interface.
//
// It reads r until EOF and decodes the JSON into a scratch Network via its
// UnmarshalJSON method. On success, it copies the decoded layers, fed flag,
// neurons, biases and weights into the receiver. On any error the receiver
// is left untouched. The returned count is the number of bytes consumed from
// r, even when an error is returned, as the io.ReaderFrom contract requires.
func (net *Network) ReadFrom(r io.Reader) (int64, error) {
	data, err := ioutil.ReadAll(r)
	n := int64(len(data))
	if err != nil {
		return n, err
	}

	decoded := new(Network)
	if err := json.Unmarshal(data, decoded); err != nil {
		return n, err
	}

	net.layers = decoded.layers
	net.fed = decoded.fed
	net.neurons = decoded.neurons
	net.biases = decoded.biases
	net.weights = decoded.weights
	return n, nil
}