add reader|writer storage using JSON format with linearization of matrices (?)
This commit is contained in:
parent
be55de113c
commit
a180e09cf0
|
@ -0,0 +1,138 @@
|
|||
package nn
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gonum.org/v1/gonum/mat"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
type jsonNetwork struct {
|
||||
Layers []uint `json:"layers"`
|
||||
Biases [][]float64 `json:"biases"`
|
||||
Weights [][]float64 `json:"weights"`
|
||||
}
|
||||
|
||||
// MarshallJSON implements the json.Marshaler interface
|
||||
func (net *Network) MarshalJSON() ([]byte, error) {
|
||||
|
||||
raw := new(jsonNetwork)
|
||||
|
||||
// 1. Layers
|
||||
raw.Layers = net.layers
|
||||
|
||||
// 2. Biases
|
||||
raw.Biases = make([][]float64, 0)
|
||||
for _, bias := range net.Biases {
|
||||
|
||||
vector := bias.ColView(0)
|
||||
biasJSON := make([]float64, 0)
|
||||
|
||||
for i, l := 0, vector.Len(); i < l; i++ {
|
||||
biasJSON = append(biasJSON, vector.AtVec(i))
|
||||
}
|
||||
|
||||
raw.Biases = append(raw.Biases, biasJSON)
|
||||
}
|
||||
|
||||
// 3. Weights
|
||||
raw.Weights = make([][]float64, 0)
|
||||
for _, weight := range net.Weights {
|
||||
|
||||
rows, cols := weight.Dims()
|
||||
weightJSON := make([]float64, 0)
|
||||
|
||||
for row := 0; row < rows; row++ {
|
||||
for col := 0; col < cols; col++ {
|
||||
weightJSON = append(weightJSON, weight.At(row, col))
|
||||
}
|
||||
}
|
||||
|
||||
raw.Weights = append(raw.Weights, weightJSON)
|
||||
}
|
||||
|
||||
return json.Marshal(raw)
|
||||
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface
|
||||
func (net *Network) UnmarshalJSON(in []byte) error {
|
||||
|
||||
// parse as 'jsonNetwork' struct
|
||||
raw := new(jsonNetwork)
|
||||
err := json.Unmarshal(in, raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
net.fed = false
|
||||
net.layers = raw.Layers
|
||||
|
||||
// extract biases
|
||||
net.Biases = make([]*mat.Dense, 0)
|
||||
for l, layer := range net.layers {
|
||||
if l == 0 {
|
||||
continue
|
||||
}
|
||||
net.Biases = append(net.Biases, mat.NewDense(int(layer), 1, raw.Biases[l-1]))
|
||||
}
|
||||
// extract weights
|
||||
net.Weights = make([]*mat.Dense, 0)
|
||||
for l, layer := range net.layers {
|
||||
if l == 0 {
|
||||
continue
|
||||
}
|
||||
net.Weights = append(net.Weights, mat.NewDense(int(layer), int(net.layers[l-1]), raw.Weights[l-1]))
|
||||
}
|
||||
|
||||
// mockup neurons
|
||||
net.Neurons = make([]*mat.Dense, 0)
|
||||
for _, layer := range net.layers {
|
||||
net.Neurons = append(net.Neurons, mat.NewDense(int(layer), 1, nil))
|
||||
}
|
||||
fmt.Printf("neurons: %v\n", net.Neurons)
|
||||
|
||||
// extract into the current network receiver (net)
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteTo implements the io.WriterTo interface
|
||||
func (net *Network) WriteTo(w io.Writer) (int64, error) {
|
||||
|
||||
// get json
|
||||
raw, err := json.Marshal(net)
|
||||
if err != nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// write to file
|
||||
written, err := w.Write(raw)
|
||||
return int64(written), err
|
||||
}
|
||||
|
||||
// ReadFrom implements the io.ReaderFrom interface
|
||||
func (net *Network) ReadFrom(r io.Reader) (int64, error) {
|
||||
|
||||
// read
|
||||
raw, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// parse json
|
||||
readNet := new(Network)
|
||||
err = json.Unmarshal(raw, readNet)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// copy values
|
||||
net.layers = readNet.layers
|
||||
net.fed = readNet.fed
|
||||
net.Neurons = readNet.Neurons
|
||||
net.Biases = readNet.Biases
|
||||
net.Weights = readNet.Weights
|
||||
|
||||
return int64(len(raw)), nil
|
||||
}
|
Loading…
Reference in New Issue