139 lines
2.8 KiB
Go
139 lines
2.8 KiB
Go
package nn
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"gonum.org/v1/gonum/mat"
|
|
"io"
|
|
"io/ioutil"
|
|
)
|
|
|
|
// jsonNetwork is the on-the-wire JSON layout shared by Network.MarshalJSON
// and Network.UnmarshalJSON.
//
// Field order matters: encoding/json emits fields in declaration order, so
// the document always reads "layers", then "biases", then "weights".
type jsonNetwork struct {
	// Layers holds the number of neurons in each layer, in order.
	Layers []uint `json:"layers"`
	// Biases holds one flat bias vector per layer after layer 0.
	Biases [][]float64 `json:"biases"`
	// Weights holds one row-major, flattened weight matrix per layer after
	// layer 0 (rows = that layer's size, cols = the previous layer's size).
	Weights [][]float64 `json:"weights"`
}
|
|
|
|
// MarshalJSON implements the json.Marshaler interface.
//
// The network is encoded as a jsonNetwork:
//   - "layers" is net.layers as-is;
//   - "biases" has one entry per element of net.Biases, holding column 0 of
//     that bias matrix;
//   - "weights" has one entry per element of net.Weights, holding the matrix
//     flattened in row-major order (the layout mat.NewDense expects, which is
//     what UnmarshalJSON uses to rebuild it).
func (net *Network) MarshalJSON() ([]byte, error) {
	raw := &jsonNetwork{
		Layers:  net.layers,
		Biases:  make([][]float64, 0, len(net.Biases)),
		Weights: make([][]float64, 0, len(net.Weights)),
	}

	for _, bias := range net.Biases {
		// With a nil dst, mat.Col allocates a fresh slice sized to the row count.
		raw.Biases = append(raw.Biases, mat.Col(nil, 0, bias))
	}

	for _, weight := range net.Weights {
		rows, cols := weight.Dims()
		flat := make([]float64, 0, rows*cols)
		for row := 0; row < rows; row++ {
			for col := 0; col < cols; col++ {
				flat = append(flat, weight.At(row, col))
			}
		}
		raw.Weights = append(raw.Weights, flat)
	}

	return json.Marshal(raw)
}
|
|
|
|
// UnmarshalJSON implements the json.Unmarshaler interface.
//
// It decodes the document written by MarshalJSON and rebuilds net from it:
//   - net.layers is taken verbatim from "layers";
//   - for every layer l > 0, "biases"[l-1] becomes a layers[l]x1 matrix and
//     "weights"[l-1] becomes a layers[l]xlayers[l-1] matrix (row-major data,
//     as mat.NewDense expects);
//   - net.Neurons is re-created as zero-valued layers[l]x1 matrices, because
//     neurons are not part of the JSON document;
//   - net.fed is reset to false.
//
// The decoded document is validated before net is modified. Malformed input
// therefore returns an error and leaves net unchanged, instead of panicking
// inside slice indexing or mat.NewDense. Malformed input here means a
// bias/weight entry count that does not match the layer count, a wrongly
// sized slice, or a zero-sized layer.
func (net *Network) UnmarshalJSON(in []byte) error {
	raw := new(jsonNetwork)
	if err := json.Unmarshal(in, raw); err != nil {
		return err
	}
	if err := checkJSONNetwork(raw); err != nil {
		return err
	}

	// Layer 0 has no biases or weights of its own, so entry l-1 of
	// raw.Biases / raw.Weights belongs to layer l.
	biases := make([]*mat.Dense, 0, len(raw.Biases))
	weights := make([]*mat.Dense, 0, len(raw.Weights))
	for l := 1; l < len(raw.Layers); l++ {
		rows, cols := int(raw.Layers[l]), int(raw.Layers[l-1])
		biases = append(biases, mat.NewDense(rows, 1, raw.Biases[l-1]))
		weights = append(weights, mat.NewDense(rows, cols, raw.Weights[l-1]))
	}

	neurons := make([]*mat.Dense, 0, len(raw.Layers))
	for _, size := range raw.Layers {
		neurons = append(neurons, mat.NewDense(int(size), 1, nil))
	}

	net.fed = false
	net.layers = raw.Layers
	net.Biases = biases
	net.Weights = weights
	net.Neurons = neurons
	return nil
}

// checkJSONNetwork reports an error if raw cannot be turned into matrices by
// UnmarshalJSON without mat.NewDense or slice indexing panicking.
func checkJSONNetwork(raw *jsonNetwork) error {
	transitions := 0
	if len(raw.Layers) > 0 {
		transitions = len(raw.Layers) - 1
	}

	for i, size := range raw.Layers {
		if size == 0 {
			return fmt.Errorf("nn: layer %d has size 0", i)
		}
	}
	if len(raw.Biases) != transitions {
		return fmt.Errorf("nn: got %d bias vectors for %d layers, want %d",
			len(raw.Biases), len(raw.Layers), transitions)
	}
	if len(raw.Weights) != transitions {
		return fmt.Errorf("nn: got %d weight matrices for %d layers, want %d",
			len(raw.Weights), len(raw.Layers), transitions)
	}

	for l := 1; l < len(raw.Layers); l++ {
		rows, cols := int(raw.Layers[l]), int(raw.Layers[l-1])
		if got := len(raw.Biases[l-1]); got != rows {
			return fmt.Errorf("nn: bias vector for layer %d has %d values, want %d", l, got, rows)
		}
		if got := len(raw.Weights[l-1]); got != rows*cols {
			return fmt.Errorf("nn: weight matrix for layer %d has %d values, want %d", l, got, rows*cols)
		}
	}
	return nil
}
|
|
|
|
// WriteTo implements the io.WriterTo interface.
//
// It writes the JSON encoding of net (see MarshalJSON) to w and returns the
// number of bytes written. If encoding fails, nothing is written and the
// encoding error is returned. Previously that error was silently dropped.
func (net *Network) WriteTo(w io.Writer) (int64, error) {
	raw, err := json.Marshal(net)
	if err != nil {
		return 0, err
	}

	n, err := w.Write(raw)
	return int64(n), err
}
|
|
|
|
// ReadFrom implements the io.ReaderFrom interface.
//
// It reads r until EOF and decodes the data as a Network (see UnmarshalJSON).
// Only if decoding succeeds are net's layers, fed flag, neurons, biases and
// weights replaced. On a decode error, net is left unchanged.
//
// As io.ReaderFrom requires, the returned count is the number of bytes
// consumed from r. This holds even when a read or decode error is returned.
func (net *Network) ReadFrom(r io.Reader) (int64, error) {
	// ioutil.ReadAll (equivalent to io.ReadAll) is used so the file's
	// io/ioutil import stays in use.
	raw, err := ioutil.ReadAll(r)
	n := int64(len(raw))
	if err != nil {
		return n, err
	}

	// Decode into a scratch value first so a failure cannot clobber net.
	decoded := new(Network)
	if err := json.Unmarshal(raw, decoded); err != nil {
		return n, err
	}

	net.layers = decoded.layers
	net.fed = decoded.fed
	net.Neurons = decoded.Neurons
	net.Biases = decoded.Biases
	net.Weights = decoded.Weights

	return n, nil
}
|