// ruvector_nervous_system/separate/projection.rs
use crate::{NervousSystemError, Result};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

/// A fixed random sparse projection from an `input_dim`-dimensional space to an
/// `output_dim`-dimensional space.
///
/// Equality compares every field, including `seed` and `sparsity`. Two projections are
/// therefore equal only when they hold the same weights *and* carry the same configuration.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseProjection {
    /// Weight matrix stored one row per input: `weights[i][j]` scales input `i` into
    /// output `j`. Zero entries are stored explicitly (dense storage).
    weights: Vec<Vec<f32>>,

    /// Probability that a given weight was assigned a random (non-zero) value; `new`
    /// rejects values `<= 0` or `> 1`.
    sparsity: f32,

    /// Seed the weight matrix was generated from.
    seed: u64,

    /// Required length of vectors passed to projection.
    input_dim: usize,

    /// Length of vectors produced by projection.
    output_dim: usize,
}
42
43impl SparseProjection {
44 pub fn new(input_dim: usize, output_dim: usize, sparsity: f32, seed: u64) -> Result<Self> {
61 if input_dim == 0 {
62 return Err(NervousSystemError::InvalidDimension(
63 "Input dimension must be > 0".to_string(),
64 ));
65 }
66
67 if output_dim == 0 {
68 return Err(NervousSystemError::InvalidDimension(
69 "Output dimension must be > 0".to_string(),
70 ));
71 }
72
73 if sparsity <= 0.0 || sparsity > 1.0 {
74 return Err(NervousSystemError::InvalidSparsity(format!(
75 "Sparsity must be in (0, 1], got {}",
76 sparsity
77 )));
78 }
79
80 let mut rng = StdRng::seed_from_u64(seed);
81 let mut weights = Vec::with_capacity(input_dim);
82
83 for _ in 0..input_dim {
85 let mut row = Vec::with_capacity(output_dim);
86 for _ in 0..output_dim {
87 if rng.gen::<f32>() < sparsity {
88 let weight: f32 = rng.gen_range(-1.0..1.0);
90 row.push(weight);
91 } else {
92 row.push(0.0);
93 }
94 }
95 weights.push(row);
96 }
97
98 Ok(Self {
99 weights,
100 sparsity,
101 seed,
102 input_dim,
103 output_dim,
104 })
105 }
106
107 pub fn project(&self, input: &[f32]) -> Result<Vec<f32>> {
128 if input.len() != self.input_dim {
129 return Err(NervousSystemError::DimensionMismatch {
130 expected: self.input_dim,
131 actual: input.len(),
132 });
133 }
134
135 let mut output = vec![0.0; self.output_dim];
136
137 for i in 0..self.input_dim {
139 let input_val = input[i];
140 if input_val != 0.0 {
141 for j in 0..self.output_dim {
142 let weight = self.weights[i][j];
143 if weight != 0.0 {
144 output[j] += input_val * weight;
145 }
146 }
147 }
148 }
149
150 Ok(output)
151 }
152
153 pub fn input_dim(&self) -> usize {
155 self.input_dim
156 }
157
158 pub fn output_dim(&self) -> usize {
160 self.output_dim
161 }
162
163 pub fn sparsity(&self) -> f32 {
165 self.sparsity
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_projection_creation() {
        let proj = SparseProjection::new(128, 1000, 0.15, 42).unwrap();
        assert_eq!(proj.input_dim(), 128);
        assert_eq!(proj.output_dim(), 1000);
        assert_eq!(proj.sparsity(), 0.15);
    }

    #[test]
    fn test_invalid_dimensions() {
        assert!(matches!(
            SparseProjection::new(0, 1000, 0.15, 42),
            Err(NervousSystemError::InvalidDimension(_))
        ));
        assert!(matches!(
            SparseProjection::new(128, 0, 0.15, 42),
            Err(NervousSystemError::InvalidDimension(_))
        ));
    }

    #[test]
    fn test_invalid_sparsity() {
        for bad in [0.0, -0.5, 1.5] {
            assert!(
                matches!(
                    SparseProjection::new(128, 1000, bad, 42),
                    Err(NervousSystemError::InvalidSparsity(_))
                ),
                "sparsity {} should be rejected",
                bad
            );
        }
    }

    #[test]
    fn test_full_density_is_valid() {
        // The valid range is (0, 1]: the upper bound is inclusive.
        let proj = SparseProjection::new(16, 32, 1.0, 42).unwrap();
        assert_eq!(proj.sparsity(), 1.0);
    }

    #[test]
    fn test_projection_dimensions() {
        let proj = SparseProjection::new(128, 1000, 0.15, 42).unwrap();
        let input = vec![1.0; 128];
        let output = proj.project(&input).unwrap();
        assert_eq!(output.len(), 1000);
    }

    #[test]
    fn test_projection_dimension_mismatch() {
        let proj = SparseProjection::new(128, 1000, 0.15, 42).unwrap();
        // Too short, too long, and empty inputs must all be rejected with the actual length.
        for len in [0, 64, 129] {
            let input = vec![1.0; len];
            assert!(
                matches!(
                    proj.project(&input),
                    Err(NervousSystemError::DimensionMismatch { expected: 128, actual })
                        if actual == len
                ),
                "input of length {} should be rejected",
                len
            );
        }
    }

    #[test]
    fn test_projection_deterministic() {
        let proj1 = SparseProjection::new(128, 1000, 0.15, 42).unwrap();
        let proj2 = SparseProjection::new(128, 1000, 0.15, 42).unwrap();
        assert_eq!(proj1.weights, proj2.weights);

        let input = vec![1.0; 128];
        let output1 = proj1.project(&input).unwrap();
        let output2 = proj2.project(&input).unwrap();

        assert_eq!(output1, output2);
    }

    #[test]
    fn test_different_seeds_produce_different_weights() {
        let a = SparseProjection::new(128, 1000, 0.15, 1).unwrap();
        let b = SparseProjection::new(128, 1000, 0.15, 2).unwrap();
        assert_ne!(a.weights, b.weights);
    }

    #[test]
    fn test_projection_matches_dense_matmul() {
        let proj = SparseProjection::new(32, 48, 0.3, 7).unwrap();
        let input: Vec<f32> = (0..32).map(|i| (i % 7) as f32 - 3.0).collect();
        let output = proj.project(&input).unwrap();

        for (j, &actual) in output.iter().enumerate() {
            let expected: f32 = input
                .iter()
                .zip(&proj.weights)
                .map(|(&x, row)| x * row[j])
                .sum();
            assert!(
                (actual - expected).abs() <= 1e-4,
                "output[{}] = {}, expected {}",
                j,
                actual,
                expected
            );
        }
    }

    #[test]
    fn test_weight_density_matches_sparsity() {
        let proj = SparseProjection::new(128, 1000, 0.15, 42).unwrap();
        let nonzero = proj
            .weights
            .iter()
            .flatten()
            .filter(|&&w| w != 0.0)
            .count();
        let density = nonzero as f32 / (128 * 1000) as f32;
        // 128k Bernoulli(0.15) draws: the standard deviation is ~0.001, so 0.01 is very loose.
        assert!(
            (density - 0.15).abs() < 0.01,
            "observed density {} should be close to 0.15",
            density
        );
    }

    #[test]
    fn test_projection_sparsity_effect() {
        let proj_sparse = SparseProjection::new(128, 1000, 0.1, 42).unwrap();
        let proj_dense = SparseProjection::new(128, 1000, 0.9, 42).unwrap();

        let input = vec![1.0; 128];
        let output_sparse = proj_sparse.project(&input).unwrap();
        let output_dense = proj_dense.project(&input).unwrap();

        // Mean absolute activation: more non-zero weights per output -> larger magnitudes.
        let mean_abs = |v: &[f32]| v.iter().map(|x| x.abs()).sum::<f32>() / v.len() as f32;
        let avg_sparse = mean_abs(&output_sparse);
        let avg_dense = mean_abs(&output_dense);

        assert!(
            avg_dense > avg_sparse,
            "Dense avg={} should be > sparse avg={}",
            avg_dense,
            avg_sparse
        );
    }

    #[test]
    fn test_zero_input_produces_zero_output() {
        let proj = SparseProjection::new(128, 1000, 0.15, 42).unwrap();
        let input = vec![0.0; 128];
        let output = proj.project(&input).unwrap();

        assert!(output.iter().all(|&x| x == 0.0));
    }
}