Skip to main content

sci_form/ani/
weights.rs

//! Binary weight file loader (and writer) for pre-trained ANI models.
//!
//! Simple binary format; every multi-byte integer and float is little-endian:
//! - u32: number of element types
//! - For each element type:
//!   - u8: atomic number
//!   - u32: number of layers
//!   - For each layer:
//!     - u32: rows (output_dim)
//!     - u32: cols (input_dim)
//!     - f64 × (rows × cols): weights (row-major)
//!     - f64 × rows: bias
//!     - u8: activation (0=None, 1=Gelu, 2=Celu)
14
15use super::nn::{Activation, DenseLayer, FeedForwardNet};
16use nalgebra::{DMatrix, DVector};
17use std::collections::HashMap;
18use std::io::{Read, Write};
19
20/// Load ANI model weights from a binary reader.
21pub fn load_weights<R: Read>(reader: &mut R) -> Result<HashMap<u8, FeedForwardNet>, String> {
22    let mut buf4 = [0u8; 4];
23    let mut buf1 = [0u8; 1];
24    let mut buf8 = [0u8; 8];
25
26    reader.read_exact(&mut buf4).map_err(|e| e.to_string())?;
27    let n_elements = u32::from_le_bytes(buf4) as usize;
28
29    let mut models = HashMap::new();
30
31    for _ in 0..n_elements {
32        reader.read_exact(&mut buf1).map_err(|e| e.to_string())?;
33        let element = buf1[0];
34
35        reader.read_exact(&mut buf4).map_err(|e| e.to_string())?;
36        let n_layers = u32::from_le_bytes(buf4) as usize;
37
38        let mut layers = Vec::with_capacity(n_layers);
39
40        for _ in 0..n_layers {
41            reader.read_exact(&mut buf4).map_err(|e| e.to_string())?;
42            let rows = u32::from_le_bytes(buf4) as usize;
43
44            reader.read_exact(&mut buf4).map_err(|e| e.to_string())?;
45            let cols = u32::from_le_bytes(buf4) as usize;
46
47            let mut weights = vec![0.0f64; rows * cols];
48            for w in &mut weights {
49                reader.read_exact(&mut buf8).map_err(|e| e.to_string())?;
50                *w = f64::from_le_bytes(buf8);
51            }
52
53            let mut bias = vec![0.0f64; rows];
54            for b in &mut bias {
55                reader.read_exact(&mut buf8).map_err(|e| e.to_string())?;
56                *b = f64::from_le_bytes(buf8);
57            }
58
59            reader.read_exact(&mut buf1).map_err(|e| e.to_string())?;
60            let activation = match buf1[0] {
61                1 => Activation::Gelu,
62                2 => Activation::Celu,
63                _ => Activation::None,
64            };
65
66            layers.push(DenseLayer {
67                weights: DMatrix::from_row_slice(rows, cols, &weights),
68                bias: DVector::from_vec(bias),
69                activation,
70            });
71        }
72
73        models.insert(element, FeedForwardNet::new(layers));
74    }
75
76    Ok(models)
77}
78
79/// Save ANI model weights to a binary writer.
80pub fn save_weights<W: Write>(
81    writer: &mut W,
82    models: &HashMap<u8, FeedForwardNet>,
83) -> Result<(), String> {
84    writer
85        .write_all(&(models.len() as u32).to_le_bytes())
86        .map_err(|e| e.to_string())?;
87
88    for (&element, net) in models {
89        writer.write_all(&[element]).map_err(|e| e.to_string())?;
90        writer
91            .write_all(&(net.layers.len() as u32).to_le_bytes())
92            .map_err(|e| e.to_string())?;
93
94        for layer in &net.layers {
95            let rows = layer.weights.nrows();
96            let cols = layer.weights.ncols();
97            writer
98                .write_all(&(rows as u32).to_le_bytes())
99                .map_err(|e| e.to_string())?;
100            writer
101                .write_all(&(cols as u32).to_le_bytes())
102                .map_err(|e| e.to_string())?;
103
104            // Weights row-major
105            for r in 0..rows {
106                for c in 0..cols {
107                    writer
108                        .write_all(&layer.weights[(r, c)].to_le_bytes())
109                        .map_err(|e| e.to_string())?;
110                }
111            }
112            // Bias
113            for b in layer.bias.iter() {
114                writer
115                    .write_all(&b.to_le_bytes())
116                    .map_err(|e| e.to_string())?;
117            }
118
119            let act_byte: u8 = match layer.activation {
120                Activation::Gelu => 1,
121                Activation::Celu => 2,
122                Activation::None => 0,
123            };
124            writer.write_all(&[act_byte]).map_err(|e| e.to_string())?;
125        }
126    }
127    Ok(())
128}
129
130/// Create a tiny test model for one element (for unit tests).
131pub fn make_test_model(input_dim: usize) -> FeedForwardNet {
132    let l1 = DenseLayer {
133        weights: DMatrix::from_fn(16, input_dim, |r, c| {
134            ((r * input_dim + c) as f64 * 0.01).sin() * 0.1
135        }),
136        bias: DVector::from_element(16, 0.01),
137        activation: Activation::Gelu,
138    };
139    let l2 = DenseLayer {
140        weights: DMatrix::from_fn(1, 16, |_, c| (c as f64 * 0.1).cos() * 0.05),
141        bias: DVector::from_element(1, 0.0),
142        activation: Activation::None,
143    };
144    FeedForwardNet::new(vec![l1, l2])
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Two small models (H and C) with the same architecture.
    fn sample_models() -> HashMap<u8, FeedForwardNet> {
        let mut models = HashMap::new();
        models.insert(1u8, make_test_model(8));
        models.insert(6u8, make_test_model(8));
        models
    }

    /// Serializes `models` into an in-memory buffer.
    fn encode(models: &HashMap<u8, FeedForwardNet>) -> Vec<u8> {
        let mut buf = Vec::new();
        save_weights(&mut buf, models).unwrap();
        buf
    }

    #[test]
    fn test_roundtrip() {
        let models = sample_models();
        let mut cursor = Cursor::new(encode(&models));
        let loaded = load_weights(&mut cursor).unwrap();

        // The loader must consume exactly the bytes the saver produced.
        assert_eq!(cursor.position() as usize, cursor.get_ref().len());
        assert_eq!(loaded.len(), 2);

        let input = DVector::from_element(8, 0.5);
        for &z in &[1u8, 6] {
            let orig = &models[&z];
            let load = loaded.get(&z).expect("element missing after roundtrip");

            // Parameters are stored as raw f64 bits, so they must match exactly.
            assert_eq!(orig.layers.len(), load.layers.len());
            for (a, b) in orig.layers.iter().zip(load.layers.iter()) {
                assert_eq!(a.weights, b.weights);
                assert_eq!(a.bias, b.bias);
            }

            // And the networks must evaluate identically.
            let diff = orig.forward(&input) - load.forward(&input);
            assert!(diff.abs() < 1e-12);
        }
    }

    #[test]
    fn test_empty_roundtrip() {
        let buf = encode(&HashMap::new());
        assert_eq!(buf, 0u32.to_le_bytes());
        let loaded = load_weights(&mut Cursor::new(buf)).unwrap();
        assert!(loaded.is_empty());
    }

    #[test]
    fn test_truncated_input_is_error() {
        let buf = encode(&sample_models());
        // Every strict prefix of a valid file is missing data and must be rejected.
        for len in 0..buf.len() {
            let mut cursor = Cursor::new(&buf[..len]);
            assert!(
                load_weights(&mut cursor).is_err(),
                "prefix of {} bytes should fail to load",
                len
            );
        }
    }
}