1use nalgebra::{DMatrix, DVector};
7
/// A single fully connected layer computing `activation(weights * x + bias)`.
#[derive(Debug, Clone)]
pub struct DenseLayer {
    /// Weight matrix with shape (output_dim, input_dim).
    pub weights: DMatrix<f64>,
    /// Bias vector with length output_dim.
    pub bias: DVector<f64>,
    /// Elementwise activation applied after the affine transform.
    pub activation: Activation,
}
18
/// Elementwise activation functions supported by [`apply_activation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Activation {
    /// Exact Gaussian Error Linear Unit: `0.5 * x * (1 + erf(x / sqrt(2)))`.
    Gelu,
    /// CELU with `alpha = 1.0` (see `celu` below).
    Celu,
    /// Identity: the pre-activation values pass through unchanged.
    None,
}
29
/// A feed-forward network: an ordered stack of [`DenseLayer`]s applied in sequence.
#[derive(Debug, Clone)]
pub struct FeedForwardNet {
    /// Layers in application order; the last layer is expected to output a scalar.
    pub layers: Vec<DenseLayer>,
}
35
36impl FeedForwardNet {
37 pub fn new(layers: Vec<DenseLayer>) -> Self {
39 FeedForwardNet { layers }
40 }
41
42 pub fn forward(&self, input: &DVector<f64>) -> f64 {
44 let mut x = input.clone();
45 for layer in &self.layers {
46 x = &layer.weights * &x + &layer.bias;
47 apply_activation(&mut x, layer.activation);
48 }
49 assert_eq!(x.len(), 1, "Output layer must produce a scalar");
50 x[0]
51 }
52
53 pub fn forward_with_intermediates(&self, input: &DVector<f64>) -> Vec<DVector<f64>> {
55 let mut activations = Vec::with_capacity(self.layers.len() + 1);
56 activations.push(input.clone());
57
58 let mut x = input.clone();
59 for layer in &self.layers {
60 let z = &layer.weights * &x + &layer.bias;
61 let mut a = z.clone();
62 apply_activation(&mut a, layer.activation);
63 x = a.clone();
64 activations.push(a);
65 }
66 activations
67 }
68
69 pub fn backward(&self, input: &DVector<f64>) -> DVector<f64> {
71 let mut pre_acts = Vec::with_capacity(self.layers.len());
73 let mut acts = Vec::with_capacity(self.layers.len() + 1);
74 acts.push(input.clone());
75
76 let mut x = input.clone();
77 for layer in &self.layers {
78 let z = &layer.weights * &x + &layer.bias;
79 pre_acts.push(z.clone());
80 let mut a = z;
81 apply_activation(&mut a, layer.activation);
82 x = a.clone();
83 acts.push(a);
84 }
85
86 let n_layers = self.layers.len();
88 let mut grad = DVector::from_element(1, 1.0);
89
90 for l in (0..n_layers).rev() {
91 let act_deriv = activation_derivative(&pre_acts[l], self.layers[l].activation);
93 grad = grad.component_mul(&act_deriv);
94 grad = self.layers[l].weights.transpose() * &grad;
96 }
97 grad
98 }
99
100 pub fn input_dim(&self) -> usize {
102 if self.layers.is_empty() {
103 0
104 } else {
105 self.layers[0].weights.ncols()
106 }
107 }
108}
109
110fn apply_activation(x: &mut DVector<f64>, act: Activation) {
111 match act {
112 Activation::Gelu => {
113 for v in x.iter_mut() {
114 *v = gelu(*v);
115 }
116 }
117 Activation::Celu => {
118 for v in x.iter_mut() {
119 *v = celu(*v, 1.0);
120 }
121 }
122 Activation::None => {}
123 }
124}
125
126fn activation_derivative(z: &DVector<f64>, act: Activation) -> DVector<f64> {
127 match act {
128 Activation::Gelu => DVector::from_iterator(z.len(), z.iter().map(|&v| gelu_deriv(v))),
129 Activation::Celu => DVector::from_iterator(z.len(), z.iter().map(|&v| celu_deriv(v, 1.0))),
130 Activation::None => DVector::from_element(z.len(), 1.0),
131 }
132}
133
134#[inline]
136fn gelu(x: f64) -> f64 {
137 0.5 * x * (1.0 + erf(x / std::f64::consts::SQRT_2))
138}
139
140#[inline]
142fn gelu_deriv(x: f64) -> f64 {
143 let s2 = std::f64::consts::SQRT_2;
144 let phi = 0.5 * (1.0 + erf(x / s2));
145 let pdf = (-0.5 * x * x).exp() / (2.0 * std::f64::consts::PI).sqrt();
146 phi + x * pdf
147}
148
/// CELU activation: identity for non-negative inputs, the smooth exponential
/// ramp `alpha * (exp(x / alpha) - 1)` for negative inputs.
#[inline]
fn celu(x: f64, alpha: f64) -> f64 {
    if x < 0.0 {
        alpha * ((x / alpha).exp() - 1.0)
    } else {
        x
    }
}
158
/// Derivative of [`celu`]: 1 for `x >= 0`, `exp(x / alpha)` below zero
/// (continuous at 0, since `exp(0) = 1`).
#[inline]
fn celu_deriv(x: f64, alpha: f64) -> f64 {
    if x < 0.0 {
        (x / alpha).exp()
    } else {
        1.0
    }
}
168
/// Error function via the Abramowitz & Stegun 7.1.26 rational approximation
/// (max absolute error ~1.5e-7). Odd symmetry is obtained by evaluating at
/// `|x|` and restoring the sign at the end.
fn erf(x: f64) -> f64 {
    const COEFFS: [f64; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];
    let negative = x < 0.0;
    let ax = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * ax);
    // Horner evaluation; the fold reproduces the nested form
    // t*(a1 + t*(a2 + t*(a3 + t*(a4 + t*a5)))) with the same operation order.
    let poly = COEFFS.iter().rev().fold(0.0, |acc, &c| t * (c + acc));
    let magnitude = 1.0 - poly * (-ax * ax).exp();
    if negative {
        -magnitude
    } else {
        magnitude
    }
}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fixed 2 -> 3 -> 1 network (GELU hidden layer, linear scalar
    /// output) so both tests run against deterministic weights.
    fn make_test_net() -> FeedForwardNet {
        let l1 = DenseLayer {
            weights: DMatrix::from_row_slice(3, 2, &[1.0, 0.5, -0.3, 0.8, 0.2, -0.1]),
            bias: DVector::from_vec(vec![0.1, -0.2, 0.05]),
            activation: Activation::Gelu,
        };
        let l2 = DenseLayer {
            weights: DMatrix::from_row_slice(1, 3, &[0.4, -0.6, 0.3]),
            bias: DVector::from_vec(vec![0.0]),
            activation: Activation::None,
        };
        FeedForwardNet::new(vec![l1, l2])
    }

    #[test]
    fn test_forward_deterministic() {
        // Two identical forward passes must agree to within f64 noise
        // (the computation is pure, so they should in fact be bit-equal).
        let net = make_test_net();
        let input = DVector::from_vec(vec![1.0, -0.5]);
        let out1 = net.forward(&input);
        let out2 = net.forward(&input);
        assert!((out1 - out2).abs() < 1e-15);
    }

    #[test]
    fn test_backward_numerical() {
        // Checks the analytical input gradient against a central finite
        // difference in each input dimension.
        let net = make_test_net();
        let input = DVector::from_vec(vec![1.0, -0.5]);
        let grad = net.backward(&input);
        // Step h = 1e-6 balances truncation (O(h^2)) against round-off;
        // tolerance 1e-4 leaves headroom for both error sources.
        let h = 1e-6;
        for d in 0..input.len() {
            let mut inp_p = input.clone();
            let mut inp_m = input.clone();
            inp_p[d] += h;
            inp_m[d] -= h;
            let num = (net.forward(&inp_p) - net.forward(&inp_m)) / (2.0 * h);
            assert!(
                (num - grad[d]).abs() < 1e-4,
                "dim {d}: numerical={num}, analytical={}",
                grad[d]
            );
        }
    }
}