ruvector_sparse_inference/sparse/

//! Sparse Feed-Forward Network implementation.

use ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use tracing::trace;

use crate::backend::{get_backend, Backend};
use crate::config::ActivationType;
use crate::error::{InferenceError, Result};

/// Sparse Feed-Forward Network computation.
///
/// This implements a two-layer FFN that can compute using only a subset of neurons:
/// - W1: [hidden_dim, input_dim] - first projection (row-major for neuron access)
/// - W2: [output_dim, hidden_dim] - second projection (column-major for accumulation)
/// - Activation function applied between layers
///
/// The sparse forward pass:
/// 1. Sparse first layer: only compute active neurons
/// 2. Apply activation function
/// 3. Sparse second layer: accumulate only active neuron contributions
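///
/// # Example
///
/// A minimal usage sketch. The dimensions and the active-neuron subset are
/// illustrative, and the import paths are assumed from this file's location
/// in the crate.
///
/// ```ignore
/// use ruvector_sparse_inference::config::ActivationType;
/// use ruvector_sparse_inference::sparse::ffn::SparseFfn;
///
/// // 64-dim input/output, 256 hidden neurons, ReLU between the layers.
/// let ffn = SparseFfn::new(64, 256, 64, ActivationType::Relu).unwrap();
///
/// // Compute using only the first 32 hidden neurons.
/// let input = vec![0.1f32; 64];
/// let active: Vec<usize> = (0..32).collect();
/// let output = ffn.forward_sparse(&input, &active).unwrap();
/// assert_eq!(output.len(), 64);
/// ```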
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseFfn {
    /// W1: [hidden_dim, input_dim] - first projection.
    /// Row-major layout for efficient neuron access.
    w1: Array2<f32>,

    /// W2: [output_dim, hidden_dim] - second projection.
    /// Column-major layout for efficient accumulation.
    #[serde(with = "w2_serde")]
    w2: Array2<f32>,

    /// Bias for first layer.
    b1: Array1<f32>,

    /// Bias for second layer.
    b2: Array1<f32>,

    /// Activation function type.
    activation: ActivationType,
}

// Custom serialization for w2 to handle layout
mod w2_serde {
    use super::*;
    use ndarray::Array2;

    pub fn serialize<S>(w2: &Array2<f32>, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Convert to standard layout for serialization
        let standard = w2.as_standard_layout();
        standard.serialize(serializer)
    }

    pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<Array2<f32>, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let standard = Array2::<f32>::deserialize(deserializer)?;
        Ok(standard)
    }
}

impl SparseFfn {
    /// Create a new sparse FFN with given dimensions.
    pub fn new(
        input_dim: usize,
        hidden_dim: usize,
        output_dim: usize,
        activation: ActivationType,
    ) -> Result<Self> {
        use rand::Rng;
        let mut rng = rand::thread_rng();

        // Initialize with small random values
        let w1 = Array2::from_shape_fn((hidden_dim, input_dim), |_| {
            rng.gen::<f32>() * 0.01
        });

        let w2 = Array2::from_shape_fn((output_dim, hidden_dim), |_| {
            rng.gen::<f32>() * 0.01
        });

        let b1 = Array1::zeros(hidden_dim);
        let b2 = Array1::zeros(output_dim);

        Ok(Self {
            w1,
            w2,
            b1,
            b2,
            activation,
        })
    }

    /// Create from existing weights.
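    ///
    /// W1 is expected as [hidden_dim, input_dim] and W2 as [output_dim, hidden_dim];
    /// mismatched matrix and bias dimensions are rejected with an error.
    ///
    /// An illustrative sketch (the all-zero weights are placeholders only):
    ///
    /// ```ignore
    /// use ndarray::{Array1, Array2};
    ///
    /// let w1 = Array2::<f32>::zeros((256, 64)); // [hidden_dim, input_dim]
    /// let w2 = Array2::<f32>::zeros((64, 256)); // [output_dim, hidden_dim]
    /// let b1 = Array1::<f32>::zeros(256);
    /// let b2 = Array1::<f32>::zeros(64);
    /// let ffn = SparseFfn::from_weights(w1, w2, b1, b2, ActivationType::Relu).unwrap();
    /// ```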
    pub fn from_weights(
        w1: Array2<f32>,
        w2: Array2<f32>,
        b1: Array1<f32>,
        b2: Array1<f32>,
        activation: ActivationType,
    ) -> Result<Self> {
        let (hidden_dim, _input_dim) = w1.dim();
        let (output_dim, w2_hidden) = w2.dim();

        if hidden_dim != w2_hidden {
            return Err(InferenceError::Failed(
                format!("Hidden dimension mismatch: W1 has {}, W2 has {}",
                    hidden_dim, w2_hidden)
            ).into());
        }

        if b1.len() != hidden_dim {
            return Err(InferenceError::Failed(
                format!("b1 dimension mismatch: expected {}, got {}",
                    hidden_dim, b1.len())
            ).into());
        }

        if b2.len() != output_dim {
            return Err(InferenceError::Failed(
                format!("b2 dimension mismatch: expected {}, got {}",
                    output_dim, b2.len())
            ).into());
        }

        Ok(Self {
            w1,
            w2,
            b1,
            b2,
            activation,
        })
    }

    /// Get input dimension.
    pub fn input_dim(&self) -> usize {
        self.w1.ncols()
    }

    /// Get hidden dimension.
    pub fn hidden_dim(&self) -> usize {
        self.w1.nrows()
    }

    /// Get output dimension.
    pub fn output_dim(&self) -> usize {
        self.w2.nrows()
    }

    /// Compute FFN using only active neurons (sparse computation).
    ///
    /// This is the main optimization: only compute activations for predicted neurons.
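    ///
    /// When `active_neurons` covers every hidden neuron, the result matches
    /// [`forward_dense`](Self::forward_dense) up to floating-point rounding
    /// (see `test_sparse_vs_dense`). An illustrative sketch with arbitrary
    /// dimensions:
    ///
    /// ```ignore
    /// let ffn = SparseFfn::new(32, 128, 32, ActivationType::Relu).unwrap();
    /// let input = vec![0.5f32; 32];
    ///
    /// // Using all 128 hidden neurons reproduces the dense result.
    /// let all: Vec<usize> = (0..128).collect();
    /// let sparse = ffn.forward_sparse(&input, &all).unwrap();
    /// let dense = ffn.forward_dense(&input).unwrap();
    /// assert!(sparse.iter().zip(&dense).all(|(s, d)| (s - d).abs() < 1e-5));
    /// ```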
    pub fn forward_sparse(&self, input: &[f32], active_neurons: &[usize]) -> Result<Vec<f32>> {
        if input.len() != self.input_dim() {
            return Err(InferenceError::InputDimensionMismatch {
                expected: self.input_dim(),
                actual: input.len(),
            }.into());
        }

        if active_neurons.is_empty() {
            return Err(InferenceError::NoActiveNeurons.into());
        }

        trace!("Sparse forward: {} active neurons ({:.1}% sparsity)",
            active_neurons.len(),
            100.0 * (1.0 - active_neurons.len() as f32 / self.hidden_dim() as f32)
        );

        let backend = get_backend();

        // 1. Sparse first layer: only compute active neurons
        let mut hidden = Vec::with_capacity(active_neurons.len());
        for &neuron_idx in active_neurons {
            if neuron_idx >= self.hidden_dim() {
                return Err(InferenceError::Failed(
                    format!("Invalid neuron index: {}", neuron_idx)
                ).into());
            }

            let row = self.w1.row(neuron_idx);
            let dot = backend.dot_product(row.as_slice().unwrap(), input);
            hidden.push(dot + self.b1[neuron_idx]);
        }

        // 2. Apply activation function
        backend.activation(&mut hidden, self.activation);

        // 3. Sparse second layer: accumulate only active neuron contributions
        let mut output = self.b2.to_vec();

        for (i, &neuron_idx) in active_neurons.iter().enumerate() {
            let col = self.w2.column(neuron_idx);
            let h_val = hidden[i];

            for (j, &w) in col.iter().enumerate() {
                output[j] += h_val * w;
            }
        }

        Ok(output)
    }

    /// Compute FFN using all neurons (dense computation).
    ///
    /// This is the baseline for comparison and correctness checking.
    pub fn forward_dense(&self, input: &[f32]) -> Result<Vec<f32>> {
        if input.len() != self.input_dim() {
            return Err(InferenceError::InputDimensionMismatch {
                expected: self.input_dim(),
                actual: input.len(),
            }.into());
        }

        let backend = get_backend();
        let input_arr = Array1::from_vec(input.to_vec());

        // 1. First layer: hidden = activation(W1 · input + b1)
        let mut hidden = self.w1.dot(&input_arr) + &self.b1;
        backend.activation(hidden.as_slice_mut().unwrap(), self.activation);

        // 2. Second layer: output = W2 · hidden + b2
        let output = self.w2.dot(&hidden) + &self.b2;

        Ok(output.to_vec())
    }

    /// Compute both sparse and dense outputs, returning their mean absolute error for validation.
    #[cfg(test)]
    pub fn validate_sparse(&self, input: &[f32], active_neurons: &[usize]) -> Result<f32> {
        let sparse_output = self.forward_sparse(input, active_neurons)?;
        let dense_output = self.forward_dense(input)?;

        // Compute mean absolute error
        let mae: f32 = sparse_output.iter()
            .zip(dense_output.iter())
            .map(|(s, d)| (s - d).abs())
            .sum::<f32>() / sparse_output.len() as f32;

        Ok(mae)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ffn_creation() {
        let ffn = SparseFfn::new(128, 512, 128, ActivationType::Gelu).unwrap();

        assert_eq!(ffn.input_dim(), 128);
        assert_eq!(ffn.hidden_dim(), 512);
        assert_eq!(ffn.output_dim(), 128);
    }

    #[test]
    fn test_dense_forward() {
        let ffn = SparseFfn::new(64, 256, 64, ActivationType::Relu).unwrap();
        let input = vec![0.1; 64];

        let output = ffn.forward_dense(&input).unwrap();
        assert_eq!(output.len(), 64);
    }

    #[test]
    fn test_sparse_forward() {
        let ffn = SparseFfn::new(64, 256, 64, ActivationType::Relu).unwrap();
        let input = vec![0.1; 64];
        let active_neurons: Vec<usize> = (0..64).collect(); // First 64 neurons

        let output = ffn.forward_sparse(&input, &active_neurons).unwrap();
        assert_eq!(output.len(), 64);
    }

    #[test]
    fn test_sparse_vs_dense() {
        let ffn = SparseFfn::new(32, 128, 32, ActivationType::Relu).unwrap();
        let input = vec![0.5; 32];

        // Use all neurons - should match dense computation
        let all_neurons: Vec<usize> = (0..128).collect();
        let mae = ffn.validate_sparse(&input, &all_neurons).unwrap();

        // Should be very close (allowing for floating point precision)
        assert!(mae < 1e-5, "MAE too large: {}", mae);
    }

    #[test]
    fn test_empty_active_neurons() {
        let ffn = SparseFfn::new(32, 128, 32, ActivationType::Relu).unwrap();
        let input = vec![0.1; 32];
        let empty: Vec<usize> = vec![];

        let result = ffn.forward_sparse(&input, &empty);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_neuron_index() {
        let ffn = SparseFfn::new(32, 128, 32, ActivationType::Relu).unwrap();
        let input = vec![0.1; 32];
        let invalid = vec![200]; // Out of bounds

        let result = ffn.forward_sparse(&input, &invalid);
        assert!(result.is_err());
    }
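
    // A hedged extra check derived from `from_weights`'s validation logic:
    // a W2 whose hidden dimension does not match W1 should be rejected.
    // The concrete dimensions here are illustrative.
    #[test]
    fn test_from_weights_dimension_mismatch() {
        let w1 = Array2::<f32>::zeros((128, 32)); // hidden_dim = 128
        let w2 = Array2::<f32>::zeros((32, 64)); // hidden dim here is 64, not 128
        let b1 = Array1::<f32>::zeros(128);
        let b2 = Array1::<f32>::zeros(32);

        let result = SparseFfn::from_weights(w1, w2, b1, b2, ActivationType::Relu);
        assert!(result.is_err());
    }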
}