ruvector_sparse_inference/sparse/ffn.rs

//! Sparse Feed-Forward Network implementation.

use ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use tracing::trace;

use crate::backend::{get_backend, Backend};
use crate::config::ActivationType;
use crate::error::{InferenceError, Result};

/// Sparse Feed-Forward Network computation.
///
/// This implements a two-layer FFN that can compute using only a subset of neurons:
/// - W1: [hidden_dim, input_dim] - first projection (row-major for neuron access)
/// - W2_T: [hidden_dim, output_dim] - second projection TRANSPOSED (row-major for contiguous access)
/// - Activation function applied between layers
///
/// The sparse forward pass:
/// 1. Sparse first layer: only compute active neurons
/// 2. Apply activation function
/// 3. Sparse second layer: accumulate only active neuron contributions (now contiguous!)
///
/// # Performance Optimization
///
/// W2 is stored transposed so that accessing columns (by neuron index) becomes row access,
/// which is contiguous in memory. This provides a 15-25% speedup in the sparse accumulation step.
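///
/// # Example
///
/// A minimal usage sketch. It is marked `ignore` because it assumes these
/// items are re-exported at the crate root, which this file alone does not
/// establish:
///
/// ```ignore
/// use ruvector_sparse_inference::{ActivationType, SparseFfn};
///
/// // 128-dim input, 512 hidden neurons, 128-dim output.
/// let ffn = SparseFfn::new(128, 512, 128, ActivationType::Gelu).unwrap();
///
/// // Compute using only a predicted subset of the hidden neurons.
/// let input = vec![0.1_f32; 128];
/// let active: Vec<usize> = (0..64).collect();
/// let output = ffn.forward_sparse(&input, &active).unwrap();
/// assert_eq!(output.len(), 128);
/// ```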
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseFfn {
    /// W1: [hidden_dim, input_dim] - first projection.
    /// Row-major layout for efficient neuron access.
    w1: Array2<f32>,

    /// W2_T: [hidden_dim, output_dim] - second projection TRANSPOSED.
    /// Row-major layout for contiguous neuron weight access.
    /// Original W2 shape was [output_dim, hidden_dim].
    #[serde(with = "w2_serde")]
    w2_t: Array2<f32>,

    /// Bias for first layer.
    b1: Array1<f32>,

    /// Bias for second layer.
    b2: Array1<f32>,

    /// Activation function type.
    activation: ActivationType,

    /// Output dimension (cached for efficiency).
    output_dim: usize,
}

// Custom serialization for w2_t - stores as original W2 for compatibility.
mod w2_serde {
    use super::*;
    use ndarray::Array2;

    pub fn serialize<S>(w2_t: &Array2<f32>, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Transpose back to original W2 shape for serialization compatibility.
        let w2 = w2_t.t().to_owned();
        w2.serialize(serializer)
    }

    pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<Array2<f32>, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Load as original W2 and transpose for optimized storage.
        let w2 = Array2::<f32>::deserialize(deserializer)?;
        Ok(w2.t().to_owned())
    }
}

impl SparseFfn {
    /// Create a new sparse FFN with given dimensions.
    pub fn new(
        input_dim: usize,
        hidden_dim: usize,
        output_dim: usize,
        activation: ActivationType,
    ) -> Result<Self> {
        use rand::Rng;
        let mut rng = rand::thread_rng();

        // Initialize with small uniform random values in [0, 0.01).
        let w1 = Array2::from_shape_fn((hidden_dim, input_dim), |_| {
            rng.gen::<f32>() * 0.01
        });

        // Store W2 transposed: [hidden_dim, output_dim] instead of [output_dim, hidden_dim].
        // This allows contiguous row access when iterating by neuron index.
        let w2_t = Array2::from_shape_fn((hidden_dim, output_dim), |_| {
            rng.gen::<f32>() * 0.01
        });

        let b1 = Array1::zeros(hidden_dim);
        let b2 = Array1::zeros(output_dim);

        Ok(Self {
            w1,
            w2_t,
            b1,
            b2,
            activation,
            output_dim,
        })
    }

    /// Create from existing weights.
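    ///
    /// `w1` must be `[hidden_dim, input_dim]` and `w2` must be in its original
    /// `[output_dim, hidden_dim]` orientation; `w2` is transposed internally
    /// for contiguous per-neuron row access.
    ///
    /// A shape-only sketch with hypothetical dimensions (`ignore`d for the
    /// same crate-visibility reason as the struct-level example):
    ///
    /// ```ignore
    /// use ndarray::{Array1, Array2};
    /// use ruvector_sparse_inference::{ActivationType, SparseFfn};
    ///
    /// let w1 = Array2::<f32>::zeros((256, 64)); // [hidden_dim, input_dim]
    /// let w2 = Array2::<f32>::zeros((64, 256)); // [output_dim, hidden_dim]
    /// let ffn = SparseFfn::from_weights(
    ///     w1,
    ///     w2,
    ///     Array1::zeros(256), // b1: [hidden_dim]
    ///     Array1::zeros(64),  // b2: [output_dim]
    ///     ActivationType::Relu,
    /// ).unwrap();
    /// assert_eq!(ffn.hidden_dim(), 256);
    /// ```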
    pub fn from_weights(
        w1: Array2<f32>,
        w2: Array2<f32>,
        b1: Array1<f32>,
        b2: Array1<f32>,
        activation: ActivationType,
    ) -> Result<Self> {
        let (hidden_dim, _input_dim) = w1.dim();
        let (output_dim, w2_hidden) = w2.dim();

        if hidden_dim != w2_hidden {
            return Err(InferenceError::Failed(
                format!("Hidden dimension mismatch: W1 has {}, W2 has {}",
                    hidden_dim, w2_hidden)
            ).into());
        }

        if b1.len() != hidden_dim {
            return Err(InferenceError::Failed(
                format!("b1 dimension mismatch: expected {}, got {}",
                    hidden_dim, b1.len())
            ).into());
        }

        if b2.len() != output_dim {
            return Err(InferenceError::Failed(
                format!("b2 dimension mismatch: expected {}, got {}",
                    output_dim, b2.len())
            ).into());
        }

        // Transpose W2 for optimized storage.
        let w2_t = w2.t().to_owned();

        Ok(Self {
            w1,
            w2_t,
            b1,
            b2,
            activation,
            output_dim,
        })
    }

    /// Get input dimension.
    pub fn input_dim(&self) -> usize {
        self.w1.ncols()
    }

    /// Get hidden dimension.
    pub fn hidden_dim(&self) -> usize {
        self.w1.nrows()
    }

    /// Get output dimension.
    pub fn output_dim(&self) -> usize {
        self.output_dim
    }

    /// Compute FFN using only active neurons (sparse computation).
    ///
    /// This is the main optimization: only compute activations for predicted neurons.
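    ///
    /// For `k` active neurons the cost is roughly `k` dot products of length
    /// `input_dim` plus `k` axpy updates of length `output_dim`, versus the
    /// full `hidden_dim` of each on the dense path, so the speedup scales with
    /// the sparsity of the predicted neuron set.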
    pub fn forward_sparse(&self, input: &[f32], active_neurons: &[usize]) -> Result<Vec<f32>> {
        if input.len() != self.input_dim() {
            return Err(InferenceError::InputDimensionMismatch {
                expected: self.input_dim(),
                actual: input.len(),
            }.into());
        }

        if active_neurons.is_empty() {
            return Err(InferenceError::NoActiveNeurons.into());
        }

        trace!("Sparse forward: {} active neurons ({:.1}% sparsity)",
            active_neurons.len(),
            100.0 * (1.0 - active_neurons.len() as f32 / self.hidden_dim() as f32)
        );

        let backend = get_backend();

        // 1. Sparse first layer: only compute active neurons.
        let mut hidden = Vec::with_capacity(active_neurons.len());
        for &neuron_idx in active_neurons {
            if neuron_idx >= self.hidden_dim() {
                return Err(InferenceError::Failed(
                    format!("Invalid neuron index: {}", neuron_idx)
                ).into());
            }

            let row = self.w1.row(neuron_idx);
            let dot = backend.dot_product(row.as_slice().unwrap(), input);
            hidden.push(dot + self.b1[neuron_idx]);
        }

        // 2. Apply activation function.
        backend.activation(&mut hidden, self.activation);

        // 3. Sparse second layer: accumulate only active neuron contributions.
        // W2_T is [hidden_dim, output_dim], so row access by neuron_idx is CONTIGUOUS.
        let mut output = self.b2.to_vec();

        for (i, &neuron_idx) in active_neurons.iter().enumerate() {
            // Row access is contiguous in memory - a major optimization.
            let weights = self.w2_t.row(neuron_idx);
            let h_val = hidden[i];

            // SIMD-optimized axpy: output += h_val * weights.
            backend.axpy(&mut output, weights.as_slice().unwrap(), h_val);
        }

        Ok(output)
    }

    /// Compute FFN using all neurons (dense computation).
    ///
    /// This is the baseline for comparison and correctness checking.
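    ///
    /// Computes `output = W2 · activation(W1 · input + b1) + b2` over the full
    /// hidden dimension.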
    pub fn forward_dense(&self, input: &[f32]) -> Result<Vec<f32>> {
        if input.len() != self.input_dim() {
            return Err(InferenceError::InputDimensionMismatch {
                expected: self.input_dim(),
                actual: input.len(),
            }.into());
        }

        let backend = get_backend();
        let input_arr = Array1::from_vec(input.to_vec());

        // 1. First layer: hidden = activation(W1 · input + b1).
        let mut hidden = self.w1.dot(&input_arr) + &self.b1;
        backend.activation(hidden.as_slice_mut().unwrap(), self.activation);

        // 2. Second layer: output = W2 · hidden + b2.
        // W2 is stored transposed as W2_T [hidden_dim, output_dim], so W2 = W2_T.t().
        let output = self.w2_t.t().dot(&hidden) + &self.b2;

        Ok(output.to_vec())
    }

    /// Compute both sparse and dense, returning the mean absolute error for validation.
    #[cfg(test)]
    pub fn validate_sparse(&self, input: &[f32], active_neurons: &[usize]) -> Result<f32> {
        let sparse_output = self.forward_sparse(input, active_neurons)?;
        let dense_output = self.forward_dense(input)?;

        // Compute mean absolute error.
        let mae: f32 = sparse_output.iter()
            .zip(dense_output.iter())
            .map(|(s, d)| (s - d).abs())
            .sum::<f32>() / sparse_output.len() as f32;

        Ok(mae)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ffn_creation() {
        let ffn = SparseFfn::new(128, 512, 128, ActivationType::Gelu).unwrap();

        assert_eq!(ffn.input_dim(), 128);
        assert_eq!(ffn.hidden_dim(), 512);
        assert_eq!(ffn.output_dim(), 128);
    }

    #[test]
    fn test_dense_forward() {
        let ffn = SparseFfn::new(64, 256, 64, ActivationType::Relu).unwrap();
        let input = vec![0.1; 64];

        let output = ffn.forward_dense(&input).unwrap();
        assert_eq!(output.len(), 64);
    }

    #[test]
    fn test_sparse_forward() {
        let ffn = SparseFfn::new(64, 256, 64, ActivationType::Relu).unwrap();
        let input = vec![0.1; 64];
        let active_neurons: Vec<usize> = (0..64).collect(); // First 64 of 256 hidden neurons.

        let output = ffn.forward_sparse(&input, &active_neurons).unwrap();
        assert_eq!(output.len(), 64);
    }

    #[test]
    fn test_sparse_vs_dense() {
        let ffn = SparseFfn::new(32, 128, 32, ActivationType::Relu).unwrap();
        let input = vec![0.5; 32];

        // Use all neurons - should match the dense computation.
        let all_neurons: Vec<usize> = (0..128).collect();
        let mae = ffn.validate_sparse(&input, &all_neurons).unwrap();

        // Should be very close (allowing for floating-point precision).
        assert!(mae < 1e-5, "MAE too large: {}", mae);
    }

    #[test]
    fn test_empty_active_neurons() {
        let ffn = SparseFfn::new(32, 128, 32, ActivationType::Relu).unwrap();
        let input = vec![0.1; 32];
        let empty: Vec<usize> = vec![];

        let result = ffn.forward_sparse(&input, &empty);
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_neuron_index() {
        let ffn = SparseFfn::new(32, 128, 32, ActivationType::Relu).unwrap();
        let input = vec![0.1; 32];
        let invalid = vec![200]; // Out of bounds (hidden_dim is 128).

        let result = ffn.forward_sparse(&input, &invalid);
        assert!(result.is_err());
    }
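
    // Added sketch: round-trip check for the transposed W2 storage. Weights
    // are supplied in the original [output_dim, hidden_dim] orientation that
    // `from_weights` expects; the deterministic values below are illustrative,
    // not a fixture from this codebase.
    #[test]
    fn test_from_weights_matches_dense() {
        let w1 = Array2::from_shape_fn((8, 4), |(i, j)| 0.1 * i as f32 - 0.05 * j as f32);
        let w2 = Array2::from_shape_fn((4, 8), |(i, j)| 0.01 * (i * 8 + j) as f32);
        let ffn = SparseFfn::from_weights(
            w1,
            w2,
            Array1::zeros(8),
            Array1::zeros(4),
            ActivationType::Relu,
        )
        .unwrap();

        // With every neuron active, the sparse path must agree with dense.
        let input = vec![1.0; 4];
        let all: Vec<usize> = (0..8).collect();
        let mae = ffn.validate_sparse(&input, &all).unwrap();
        assert!(mae < 1e-5, "MAE too large: {}", mae);
    }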
}