Skip to main content

sci_form/ani/
nn.rs

1//! Pure-Rust feed-forward neural network for ANI atomic energy prediction.
2//!
3//! Minimal inference engine: matrix multiply + bias + activation.
4//! Supports GELU and CELU activation functions used in ANI models.
5
6use nalgebra::{DMatrix, DVector};
7
/// A single dense (fully-connected) layer: `y = act(W·x + b)`.
#[derive(Debug, Clone)]
pub struct DenseLayer {
    /// Weight matrix `W` (output_dim × input_dim).
    pub weights: DMatrix<f64>,
    /// Bias vector `b` (output_dim).
    pub bias: DVector<f64>,
    /// Activation applied element-wise to `W·x + b`.
    pub activation: Activation,
}
18
/// Activation function type.
///
/// Note: `Celu` is applied with a fixed α = 1.0 (see `apply_activation`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Activation {
    /// Gaussian Error Linear Unit: x · Φ(x).
    Gelu,
    /// Continuously Differentiable Exponential Linear Unit.
    Celu,
    /// No activation (identity, used on the output layer).
    None,
}
29
/// Feed-forward neural network with multiple dense layers.
///
/// Layers are applied in order; the last layer must emit a single
/// scalar for `forward` to succeed.
#[derive(Debug, Clone)]
pub struct FeedForwardNet {
    /// Layers in application order (the first consumes the raw input).
    pub layers: Vec<DenseLayer>,
}
35
36impl FeedForwardNet {
37    /// Create a new network from a list of layers.
38    pub fn new(layers: Vec<DenseLayer>) -> Self {
39        FeedForwardNet { layers }
40    }
41
42    /// Forward pass: compute scalar output from input vector.
43    pub fn forward(&self, input: &DVector<f64>) -> f64 {
44        let mut x = input.clone();
45        for layer in &self.layers {
46            x = &layer.weights * &x + &layer.bias;
47            apply_activation(&mut x, layer.activation);
48        }
49        assert_eq!(x.len(), 1, "Output layer must produce a scalar");
50        x[0]
51    }
52
53    /// Forward pass returning all intermediate activations (for backprop).
54    pub fn forward_with_intermediates(&self, input: &DVector<f64>) -> Vec<DVector<f64>> {
55        let mut activations = Vec::with_capacity(self.layers.len() + 1);
56        activations.push(input.clone());
57
58        let mut x = input.clone();
59        for layer in &self.layers {
60            let z = &layer.weights * &x + &layer.bias;
61            let mut a = z.clone();
62            apply_activation(&mut a, layer.activation);
63            x = a.clone();
64            activations.push(a);
65        }
66        activations
67    }
68
69    /// Backward pass: compute gradient of output w.r.t. input.
70    pub fn backward(&self, input: &DVector<f64>) -> DVector<f64> {
71        // Forward pass storing pre-activation values
72        let mut pre_acts = Vec::with_capacity(self.layers.len());
73        let mut acts = Vec::with_capacity(self.layers.len() + 1);
74        acts.push(input.clone());
75
76        let mut x = input.clone();
77        for layer in &self.layers {
78            let z = &layer.weights * &x + &layer.bias;
79            pre_acts.push(z.clone());
80            let mut a = z;
81            apply_activation(&mut a, layer.activation);
82            x = a.clone();
83            acts.push(a);
84        }
85
86        // Backward: dL/dz for output layer is 1.0
87        let n_layers = self.layers.len();
88        let mut grad = DVector::from_element(1, 1.0);
89
90        for l in (0..n_layers).rev() {
91            // Apply activation derivative
92            let act_deriv = activation_derivative(&pre_acts[l], self.layers[l].activation);
93            grad = grad.component_mul(&act_deriv);
94            // Propagate through weights
95            grad = self.layers[l].weights.transpose() * &grad;
96        }
97        grad
98    }
99
100    /// Input dimension expected by the network.
101    pub fn input_dim(&self) -> usize {
102        if self.layers.is_empty() {
103            0
104        } else {
105            self.layers[0].weights.ncols()
106        }
107    }
108}
109
110fn apply_activation(x: &mut DVector<f64>, act: Activation) {
111    match act {
112        Activation::Gelu => {
113            for v in x.iter_mut() {
114                *v = gelu(*v);
115            }
116        }
117        Activation::Celu => {
118            for v in x.iter_mut() {
119                *v = celu(*v, 1.0);
120            }
121        }
122        Activation::None => {}
123    }
124}
125
126fn activation_derivative(z: &DVector<f64>, act: Activation) -> DVector<f64> {
127    match act {
128        Activation::Gelu => DVector::from_iterator(z.len(), z.iter().map(|&v| gelu_deriv(v))),
129        Activation::Celu => DVector::from_iterator(z.len(), z.iter().map(|&v| celu_deriv(v, 1.0))),
130        Activation::None => DVector::from_element(z.len(), 1.0),
131    }
132}
133
134/// GELU activation: x · Φ(x) where Φ is the standard normal CDF.
135#[inline]
136fn gelu(x: f64) -> f64 {
137    0.5 * x * (1.0 + erf(x / std::f64::consts::SQRT_2))
138}
139
140/// Approximate GELU derivative.
141#[inline]
142fn gelu_deriv(x: f64) -> f64 {
143    let s2 = std::f64::consts::SQRT_2;
144    let phi = 0.5 * (1.0 + erf(x / s2));
145    let pdf = (-0.5 * x * x).exp() / (2.0 * std::f64::consts::PI).sqrt();
146    phi + x * pdf
147}
148
/// CELU activation: max(0, x) + min(0, α·(e^{x/α} − 1)).
#[inline]
fn celu(x: f64, alpha: f64) -> f64 {
    if x < 0.0 {
        alpha * ((x / alpha).exp() - 1.0)
    } else {
        x
    }
}
158
/// CELU derivative: 1 for x ≥ 0, else e^{x/α}.
#[inline]
fn celu_deriv(x: f64, alpha: f64) -> f64 {
    if x < 0.0 {
        (x / alpha).exp()
    } else {
        1.0
    }
}
168
/// Fast erf approximation (Abramowitz & Stegun 7.1.26, max error 1.5e-7).
fn erf(x: f64) -> f64 {
    // Coefficients of the rational approximation in t = 1 / (1 + p·|x|).
    const P: f64 = 0.3275911;
    const A: [f64; 5] = [0.254829592, -0.284496736, 1.421413741, -1.453152027, 1.061405429];

    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);

    // Horner evaluation of a1·t + a2·t² + … + a5·t⁵.
    let mut poly = 0.0;
    for &a in A.iter().rev() {
        poly = t * (a + poly);
    }

    let magnitude = 1.0 - poly * (-ax * ax).exp();
    if x < 0.0 {
        -magnitude
    } else {
        magnitude
    }
}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed 2 → 3 → 1 network (GELU hidden layer) shared by the tests.
    fn build_net() -> FeedForwardNet {
        let hidden = DenseLayer {
            weights: DMatrix::from_row_slice(3, 2, &[1.0, 0.5, -0.3, 0.8, 0.2, -0.1]),
            bias: DVector::from_vec(vec![0.1, -0.2, 0.05]),
            activation: Activation::Gelu,
        };
        let output = DenseLayer {
            weights: DMatrix::from_row_slice(1, 3, &[0.4, -0.6, 0.3]),
            bias: DVector::from_vec(vec![0.0]),
            activation: Activation::None,
        };
        FeedForwardNet::new(vec![hidden, output])
    }

    #[test]
    fn test_forward_deterministic() {
        let net = build_net();
        let input = DVector::from_vec(vec![1.0, -0.5]);
        // Two identical evaluations must agree to machine precision.
        let first = net.forward(&input);
        let second = net.forward(&input);
        assert!((first - second).abs() < 1e-15);
    }

    #[test]
    fn test_backward_numerical() {
        let net = build_net();
        let input = DVector::from_vec(vec![1.0, -0.5]);
        let analytical = net.backward(&input);

        // Central finite differences as an independent gradient check.
        let h = 1e-6;
        for dim in 0..input.len() {
            let mut plus = input.clone();
            plus[dim] += h;
            let mut minus = input.clone();
            minus[dim] -= h;
            let numerical = (net.forward(&plus) - net.forward(&minus)) / (2.0 * h);
            assert!(
                (numerical - analytical[dim]).abs() < 1e-4,
                "dim {dim}: numerical={numerical}, analytical={}",
                analytical[dim]
            );
        }
    }
}