1use super::nn::{Activation, DenseLayer, FeedForwardNet};
16use nalgebra::{DMatrix, DVector};
17use std::collections::HashMap;
18use std::io::{Read, Write};
19
20pub fn load_weights<R: Read>(reader: &mut R) -> Result<HashMap<u8, FeedForwardNet>, String> {
22 let mut buf4 = [0u8; 4];
23 let mut buf1 = [0u8; 1];
24 let mut buf8 = [0u8; 8];
25
26 reader.read_exact(&mut buf4).map_err(|e| e.to_string())?;
27 let n_elements = u32::from_le_bytes(buf4) as usize;
28
29 let mut models = HashMap::new();
30
31 for _ in 0..n_elements {
32 reader.read_exact(&mut buf1).map_err(|e| e.to_string())?;
33 let element = buf1[0];
34
35 reader.read_exact(&mut buf4).map_err(|e| e.to_string())?;
36 let n_layers = u32::from_le_bytes(buf4) as usize;
37
38 let mut layers = Vec::with_capacity(n_layers);
39
40 for _ in 0..n_layers {
41 reader.read_exact(&mut buf4).map_err(|e| e.to_string())?;
42 let rows = u32::from_le_bytes(buf4) as usize;
43
44 reader.read_exact(&mut buf4).map_err(|e| e.to_string())?;
45 let cols = u32::from_le_bytes(buf4) as usize;
46
47 let mut weights = vec![0.0f64; rows * cols];
48 for w in &mut weights {
49 reader.read_exact(&mut buf8).map_err(|e| e.to_string())?;
50 *w = f64::from_le_bytes(buf8);
51 }
52
53 let mut bias = vec![0.0f64; rows];
54 for b in &mut bias {
55 reader.read_exact(&mut buf8).map_err(|e| e.to_string())?;
56 *b = f64::from_le_bytes(buf8);
57 }
58
59 reader.read_exact(&mut buf1).map_err(|e| e.to_string())?;
60 let activation = match buf1[0] {
61 1 => Activation::Gelu,
62 2 => Activation::Celu,
63 _ => Activation::None,
64 };
65
66 layers.push(DenseLayer {
67 weights: DMatrix::from_row_slice(rows, cols, &weights),
68 bias: DVector::from_vec(bias),
69 activation,
70 });
71 }
72
73 models.insert(element, FeedForwardNet::new(layers));
74 }
75
76 Ok(models)
77}
78
79pub fn save_weights<W: Write>(
81 writer: &mut W,
82 models: &HashMap<u8, FeedForwardNet>,
83) -> Result<(), String> {
84 writer
85 .write_all(&(models.len() as u32).to_le_bytes())
86 .map_err(|e| e.to_string())?;
87
88 for (&element, net) in models {
89 writer.write_all(&[element]).map_err(|e| e.to_string())?;
90 writer
91 .write_all(&(net.layers.len() as u32).to_le_bytes())
92 .map_err(|e| e.to_string())?;
93
94 for layer in &net.layers {
95 let rows = layer.weights.nrows();
96 let cols = layer.weights.ncols();
97 writer
98 .write_all(&(rows as u32).to_le_bytes())
99 .map_err(|e| e.to_string())?;
100 writer
101 .write_all(&(cols as u32).to_le_bytes())
102 .map_err(|e| e.to_string())?;
103
104 for r in 0..rows {
106 for c in 0..cols {
107 writer
108 .write_all(&layer.weights[(r, c)].to_le_bytes())
109 .map_err(|e| e.to_string())?;
110 }
111 }
112 for b in layer.bias.iter() {
114 writer
115 .write_all(&b.to_le_bytes())
116 .map_err(|e| e.to_string())?;
117 }
118
119 let act_byte: u8 = match layer.activation {
120 Activation::Gelu => 1,
121 Activation::Celu => 2,
122 Activation::None => 0,
123 };
124 writer.write_all(&[act_byte]).map_err(|e| e.to_string())?;
125 }
126 }
127 Ok(())
128}
129
130pub fn make_test_model(input_dim: usize) -> FeedForwardNet {
132 let l1 = DenseLayer {
133 weights: DMatrix::from_fn(16, input_dim, |r, c| {
134 ((r * input_dim + c) as f64 * 0.01).sin() * 0.1
135 }),
136 bias: DVector::from_element(16, 0.01),
137 activation: Activation::Gelu,
138 };
139 let l2 = DenseLayer {
140 weights: DMatrix::from_fn(1, 16, |_, c| (c as f64 * 0.1).cos() * 0.05),
141 bias: DVector::from_element(1, 0.0),
142 activation: Activation::None,
143 };
144 FeedForwardNet::new(vec![l1, l2])
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Saves two models, loads them back, and checks that every stored model
    /// survives the round trip: same layer count, identical weights and biases,
    /// and the same forward output on a fixed input.
    #[test]
    fn test_roundtrip() {
        let mut models = HashMap::new();
        models.insert(1u8, make_test_model(8));
        models.insert(6u8, make_test_model(8));

        let mut buf = Vec::new();
        save_weights(&mut buf, &models).unwrap();

        let mut cursor = Cursor::new(buf);
        let loaded = load_weights(&mut cursor).unwrap();

        assert_eq!(loaded.len(), 2);
        assert!(loaded.contains_key(&1));
        assert!(loaded.contains_key(&6));

        let input = DVector::from_element(8, 0.5);

        // Check both elements, not just the first one that was inserted.
        for element in &[1u8, 6u8] {
            let original = &models[element];
            let restored = &loaded[element];

            assert_eq!(original.layers.len(), restored.layers.len());
            for (before, after) in original.layers.iter().zip(restored.layers.iter()) {
                // Values are stored as raw f64 bytes, so they must come back exactly.
                assert_eq!(before.weights, after.weights);
                assert_eq!(before.bias, after.bias);
            }

            let orig = original.forward(&input);
            let load = restored.forward(&input);
            assert!((orig - load).abs() < 1e-12);
        }
    }
}