// ruvector_sparse_inference/sparse/ffn.rs

use ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use tracing::{debug, trace};

use crate::backend::{get_backend, Backend};
use crate::config::ActivationType;
use crate::error::{InferenceError, Result};
/// A feed-forward network layer supporting sparse, per-neuron evaluation.
///
/// Weights are stored row-major for efficient per-neuron access: `w1` is
/// `(hidden_dim, input_dim)` and `w2_t` holds W2 transposed as
/// `(hidden_dim, output_dim)`, so a single hidden neuron's output weights
/// occupy one contiguous row.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseFfn {
    // First-layer weights, shape (hidden_dim, input_dim); one row per hidden neuron.
    w1: Array2<f32>,

    // Second-layer weights stored transposed, shape (hidden_dim, output_dim).
    // Serialized in the conventional (output_dim, hidden_dim) orientation via
    // the `w2_serde` helper module.
    #[serde(with = "w2_serde")]
    w2_t: Array2<f32>,

    // First-layer bias, length hidden_dim.
    b1: Array1<f32>,

    // Second-layer bias, length output_dim.
    b2: Array1<f32>,

    // Nonlinearity applied after the first layer.
    activation: ActivationType,

    // Cached output dimension (matches b2.len()).
    output_dim: usize,
}

52mod w2_serde {
54 use super::*;
55 use ndarray::Array2;
56
57 pub fn serialize<S>(w2_t: &Array2<f32>, serializer: S) -> std::result::Result<S::Ok, S::Error>
58 where
59 S: serde::Serializer,
60 {
61 let w2 = w2_t.t().to_owned();
63 w2.serialize(serializer)
64 }
65
66 pub fn deserialize<'de, D>(deserializer: D) -> std::result::Result<Array2<f32>, D::Error>
67 where
68 D: serde::Deserializer<'de>,
69 {
70 let w2 = Array2::<f32>::deserialize(deserializer)?;
72 Ok(w2.t().to_owned())
73 }
74}
75
76impl SparseFfn {
77 pub fn new(
79 input_dim: usize,
80 hidden_dim: usize,
81 output_dim: usize,
82 activation: ActivationType,
83 ) -> Result<Self> {
84 use rand::Rng;
85 let mut rng = rand::thread_rng();
86
87 let w1 = Array2::from_shape_fn((hidden_dim, input_dim), |_| {
89 rng.gen::<f32>() * 0.01
90 });
91
92 let w2_t = Array2::from_shape_fn((hidden_dim, output_dim), |_| {
95 rng.gen::<f32>() * 0.01
96 });
97
98 let b1 = Array1::zeros(hidden_dim);
99 let b2 = Array1::zeros(output_dim);
100
101 Ok(Self {
102 w1,
103 w2_t,
104 b1,
105 b2,
106 activation,
107 output_dim,
108 })
109 }
110
111 pub fn from_weights(
113 w1: Array2<f32>,
114 w2: Array2<f32>,
115 b1: Array1<f32>,
116 b2: Array1<f32>,
117 activation: ActivationType,
118 ) -> Result<Self> {
119 let (hidden_dim, _input_dim) = w1.dim();
120 let (output_dim, w2_hidden) = w2.dim();
121
122 if hidden_dim != w2_hidden {
123 return Err(InferenceError::Failed(
124 format!("Hidden dimension mismatch: W1 has {}, W2 has {}",
125 hidden_dim, w2_hidden)
126 ).into());
127 }
128
129 if b1.len() != hidden_dim {
130 return Err(InferenceError::Failed(
131 format!("b1 dimension mismatch: expected {}, got {}",
132 hidden_dim, b1.len())
133 ).into());
134 }
135
136 if b2.len() != output_dim {
137 return Err(InferenceError::Failed(
138 format!("b2 dimension mismatch: expected {}, got {}",
139 output_dim, b2.len())
140 ).into());
141 }
142
143 let w2_t = w2.t().to_owned();
145
146 Ok(Self {
147 w1,
148 w2_t,
149 b1,
150 b2,
151 activation,
152 output_dim,
153 })
154 }
155
156 pub fn input_dim(&self) -> usize {
158 self.w1.ncols()
159 }
160
161 pub fn hidden_dim(&self) -> usize {
163 self.w1.nrows()
164 }
165
166 pub fn output_dim(&self) -> usize {
168 self.output_dim
169 }
170
171 pub fn forward_sparse(&self, input: &[f32], active_neurons: &[usize]) -> Result<Vec<f32>> {
175 if input.len() != self.input_dim() {
176 return Err(InferenceError::InputDimensionMismatch {
177 expected: self.input_dim(),
178 actual: input.len(),
179 }.into());
180 }
181
182 if active_neurons.is_empty() {
183 return Err(InferenceError::NoActiveNeurons.into());
184 }
185
186 trace!("Sparse forward: {} active neurons ({:.1}% sparsity)",
187 active_neurons.len(),
188 100.0 * (1.0 - active_neurons.len() as f32 / self.hidden_dim() as f32)
189 );
190
191 let backend = get_backend();
192
193 let mut hidden = Vec::with_capacity(active_neurons.len());
195 for &neuron_idx in active_neurons {
196 if neuron_idx >= self.hidden_dim() {
197 return Err(InferenceError::Failed(
198 format!("Invalid neuron index: {}", neuron_idx)
199 ).into());
200 }
201
202 let row = self.w1.row(neuron_idx);
203 let dot = backend.dot_product(row.as_slice().unwrap(), input);
204 hidden.push(dot + self.b1[neuron_idx]);
205 }
206
207 backend.activation(&mut hidden, self.activation);
209
210 let mut output = self.b2.to_vec();
213 let backend = get_backend();
214
215 for (i, &neuron_idx) in active_neurons.iter().enumerate() {
216 let weights = self.w2_t.row(neuron_idx);
218 let h_val = hidden[i];
219
220 backend.axpy(&mut output, weights.as_slice().unwrap(), h_val);
222 }
223
224 Ok(output)
225 }
226
227 pub fn forward_dense(&self, input: &[f32]) -> Result<Vec<f32>> {
231 if input.len() != self.input_dim() {
232 return Err(InferenceError::InputDimensionMismatch {
233 expected: self.input_dim(),
234 actual: input.len(),
235 }.into());
236 }
237
238 let backend = get_backend();
239 let input_arr = Array1::from_vec(input.to_vec());
240
241 let mut hidden = self.w1.dot(&input_arr) + &self.b1;
243 backend.activation(hidden.as_slice_mut().unwrap(), self.activation);
244
245 let output = self.w2_t.t().dot(&hidden) + &self.b2;
249
250 Ok(output.to_vec())
251 }
252
253 #[cfg(test)]
255 pub fn validate_sparse(&self, input: &[f32], active_neurons: &[usize]) -> Result<f32> {
256 let sparse_output = self.forward_sparse(input, active_neurons)?;
257 let dense_output = self.forward_dense(input)?;
258
259 let mae: f32 = sparse_output.iter()
261 .zip(dense_output.iter())
262 .map(|(s, d)| (s - d).abs())
263 .sum::<f32>() / sparse_output.len() as f32;
264
265 Ok(mae)
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed network reports the constructor's dimensions.
    #[test]
    fn test_ffn_creation() {
        let ffn = SparseFfn::new(128, 512, 128, ActivationType::Gelu).unwrap();

        assert_eq!(ffn.input_dim(), 128);
        assert_eq!(ffn.hidden_dim(), 512);
        assert_eq!(ffn.output_dim(), 128);
    }

    /// The dense forward pass yields an output of the configured size.
    #[test]
    fn test_dense_forward() {
        let ffn = SparseFfn::new(64, 256, 64, ActivationType::Relu).unwrap();
        let input = vec![0.1; 64];

        let output = ffn.forward_dense(&input).unwrap();
        assert_eq!(output.len(), 64);
    }

    /// A sparse forward pass over a subset of neurons yields an output of
    /// the configured size.
    #[test]
    fn test_sparse_forward() {
        let ffn = SparseFfn::new(64, 256, 64, ActivationType::Relu).unwrap();
        let input = vec![0.1; 64];
        let active_neurons: Vec<usize> = (0..64).collect();

        let output = ffn.forward_sparse(&input, &active_neurons).unwrap();
        assert_eq!(output.len(), 64);
    }

    /// With every neuron active, the sparse and dense paths agree closely.
    #[test]
    fn test_sparse_vs_dense() {
        let ffn = SparseFfn::new(32, 128, 32, ActivationType::Relu).unwrap();
        let input = vec![0.5; 32];
        let all_neurons: Vec<usize> = (0..128).collect();

        let mae = ffn.validate_sparse(&input, &all_neurons).unwrap();
        assert!(mae < 1e-5, "MAE too large: {}", mae);
    }

    /// An empty active-neuron set is rejected.
    #[test]
    fn test_empty_active_neurons() {
        let ffn = SparseFfn::new(32, 128, 32, ActivationType::Relu).unwrap();
        let input = vec![0.1; 32];
        let no_neurons: Vec<usize> = Vec::new();

        assert!(ffn.forward_sparse(&input, &no_neurons).is_err());
    }

    /// An out-of-range neuron index is rejected.
    #[test]
    fn test_invalid_neuron_index() {
        let ffn = SparseFfn::new(32, 128, 32, ActivationType::Relu).unwrap();
        let input = vec![0.1; 32];
        let out_of_range = vec![200];

        assert!(ffn.forward_sparse(&input, &out_of_range).is_err());
    }
}