// ruvector_sparse_inference/sparse/ffn.rs
use ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use tracing::{debug, trace};

use crate::backend::{get_backend, Backend};
use crate::config::ActivationType;
use crate::error::{InferenceError, Result};
/// Two-layer feed-forward network supporting both a dense forward pass and a
/// sparse pass that evaluates only a caller-selected subset of hidden neurons.
///
/// Weight layout: `w1` is `(hidden_dim, input_dim)` — one row per hidden
/// neuron; `w2` is `(output_dim, hidden_dim)` — one column per hidden neuron.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseFfn {
    // First-layer weights, shape (hidden_dim, input_dim).
    w1: Array2<f32>,

    // Second-layer weights, shape (output_dim, hidden_dim). Serialized via
    // the `w2_serde` module so the on-wire form is always standard layout.
    #[serde(with = "w2_serde")]
    w2: Array2<f32>,

    // First-layer bias, length hidden_dim.
    b1: Array1<f32>,

    // Second-layer bias, length output_dim.
    b2: Array1<f32>,

    // Nonlinearity applied to the hidden layer.
    activation: ActivationType,
}
42
43mod w2_serde {
45 use super::*;
46 use ndarray::Array2;
47
48 pub fn serialize<S>(w2: &Array2<f32>, serializer: S) -> std::result::Result<S::Ok, S::Error>
49 where
50 S: serde::Serializer,
51 {
52 let standard = w2.as_standard_layout();
54 standard.serialize(serializer)
55 }
56
57 pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<Array2<f32>, D::Error>
58 where
59 D: serde::Deserializer<'de>,
60 {
61 let standard = Array2::<f32>::deserialize(deserializer)?;
62 Ok(standard)
63 }
64}
65
66impl SparseFfn {
67 pub fn new(
69 input_dim: usize,
70 hidden_dim: usize,
71 output_dim: usize,
72 activation: ActivationType,
73 ) -> Result<Self> {
74 use rand::Rng;
75 let mut rng = rand::thread_rng();
76
77 let w1 = Array2::from_shape_fn((hidden_dim, input_dim), |_| {
79 rng.gen::<f32>() * 0.01
80 });
81
82 let w2 = Array2::from_shape_fn((output_dim, hidden_dim), |_| {
83 rng.gen::<f32>() * 0.01
84 });
85
86 let b1 = Array1::zeros(hidden_dim);
87 let b2 = Array1::zeros(output_dim);
88
89 Ok(Self {
90 w1,
91 w2,
92 b1,
93 b2,
94 activation,
95 })
96 }
97
98 pub fn from_weights(
100 w1: Array2<f32>,
101 w2: Array2<f32>,
102 b1: Array1<f32>,
103 b2: Array1<f32>,
104 activation: ActivationType,
105 ) -> Result<Self> {
106 let (hidden_dim, input_dim) = w1.dim();
107 let (output_dim, w2_hidden) = w2.dim();
108
109 if hidden_dim != w2_hidden {
110 return Err(InferenceError::Failed(
111 format!("Hidden dimension mismatch: W1 has {}, W2 has {}",
112 hidden_dim, w2_hidden)
113 ).into());
114 }
115
116 if b1.len() != hidden_dim {
117 return Err(InferenceError::Failed(
118 format!("b1 dimension mismatch: expected {}, got {}",
119 hidden_dim, b1.len())
120 ).into());
121 }
122
123 if b2.len() != output_dim {
124 return Err(InferenceError::Failed(
125 format!("b2 dimension mismatch: expected {}, got {}",
126 output_dim, b2.len())
127 ).into());
128 }
129
130 Ok(Self {
131 w1,
132 w2,
133 b1,
134 b2,
135 activation,
136 })
137 }
138
139 pub fn input_dim(&self) -> usize {
141 self.w1.ncols()
142 }
143
144 pub fn hidden_dim(&self) -> usize {
146 self.w1.nrows()
147 }
148
149 pub fn output_dim(&self) -> usize {
151 self.w2.nrows()
152 }
153
154 pub fn forward_sparse(&self, input: &[f32], active_neurons: &[usize]) -> Result<Vec<f32>> {
158 if input.len() != self.input_dim() {
159 return Err(InferenceError::InputDimensionMismatch {
160 expected: self.input_dim(),
161 actual: input.len(),
162 }.into());
163 }
164
165 if active_neurons.is_empty() {
166 return Err(InferenceError::NoActiveNeurons.into());
167 }
168
169 trace!("Sparse forward: {} active neurons ({:.1}% sparsity)",
170 active_neurons.len(),
171 100.0 * (1.0 - active_neurons.len() as f32 / self.hidden_dim() as f32)
172 );
173
174 let backend = get_backend();
175
176 let mut hidden = Vec::with_capacity(active_neurons.len());
178 for &neuron_idx in active_neurons {
179 if neuron_idx >= self.hidden_dim() {
180 return Err(InferenceError::Failed(
181 format!("Invalid neuron index: {}", neuron_idx)
182 ).into());
183 }
184
185 let row = self.w1.row(neuron_idx);
186 let dot = backend.dot_product(row.as_slice().unwrap(), input);
187 hidden.push(dot + self.b1[neuron_idx]);
188 }
189
190 backend.activation(&mut hidden, self.activation);
192
193 let mut output = self.b2.to_vec();
195
196 for (i, &neuron_idx) in active_neurons.iter().enumerate() {
197 let col = self.w2.column(neuron_idx);
198 let h_val = hidden[i];
199
200 for (j, &w) in col.iter().enumerate() {
201 output[j] += h_val * w;
202 }
203 }
204
205 Ok(output)
206 }
207
208 pub fn forward_dense(&self, input: &[f32]) -> Result<Vec<f32>> {
212 if input.len() != self.input_dim() {
213 return Err(InferenceError::InputDimensionMismatch {
214 expected: self.input_dim(),
215 actual: input.len(),
216 }.into());
217 }
218
219 let backend = get_backend();
220 let input_arr = Array1::from_vec(input.to_vec());
221
222 let mut hidden = self.w1.dot(&input_arr) + &self.b1;
224 backend.activation(hidden.as_slice_mut().unwrap(), self.activation);
225
226 let output = self.w2.dot(&hidden) + &self.b2;
228
229 Ok(output.to_vec())
230 }
231
232 #[cfg(test)]
234 pub fn validate_sparse(&self, input: &[f32], active_neurons: &[usize]) -> Result<f32> {
235 let sparse_output = self.forward_sparse(input, active_neurons)?;
236 let dense_output = self.forward_dense(input)?;
237
238 let mae: f32 = sparse_output.iter()
240 .zip(dense_output.iter())
241 .map(|(s, d)| (s - d).abs())
242 .sum::<f32>() / sparse_output.len() as f32;
243
244 Ok(mae)
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ffn_creation() {
        // Accessors must report the dimensions used at construction.
        let net = SparseFfn::new(128, 512, 128, ActivationType::Gelu).unwrap();
        assert_eq!(net.input_dim(), 128);
        assert_eq!(net.hidden_dim(), 512);
        assert_eq!(net.output_dim(), 128);
    }

    #[test]
    fn test_dense_forward() {
        // A dense pass yields exactly one value per output unit.
        let net = SparseFfn::new(64, 256, 64, ActivationType::Relu).unwrap();
        let x = vec![0.1_f32; 64];
        assert_eq!(net.forward_dense(&x).unwrap().len(), 64);
    }

    #[test]
    fn test_sparse_forward() {
        // A sparse pass over a subset of neurons still yields a full output.
        let net = SparseFfn::new(64, 256, 64, ActivationType::Relu).unwrap();
        let x = vec![0.1_f32; 64];
        let active: Vec<usize> = (0..64).collect();
        assert_eq!(net.forward_sparse(&x, &active).unwrap().len(), 64);
    }

    #[test]
    fn test_sparse_vs_dense() {
        // With every hidden neuron active, sparse and dense must agree.
        let net = SparseFfn::new(32, 128, 32, ActivationType::Relu).unwrap();
        let x = vec![0.5_f32; 32];
        let all: Vec<usize> = (0..128).collect();
        let mae = net.validate_sparse(&x, &all).unwrap();
        assert!(mae < 1e-5, "MAE too large: {}", mae);
    }

    #[test]
    fn test_empty_active_neurons() {
        // An empty active set is rejected, not silently accepted.
        let net = SparseFfn::new(32, 128, 32, ActivationType::Relu).unwrap();
        let x = vec![0.1_f32; 32];
        assert!(net.forward_sparse(&x, &[]).is_err());
    }

    #[test]
    fn test_invalid_neuron_index() {
        // Out-of-range neuron indices are rejected.
        let net = SparseFfn::new(32, 128, 32, ActivationType::Relu).unwrap();
        let x = vec![0.1_f32; 32];
        assert!(net.forward_sparse(&x, &[200]).is_err());
    }
}