// ruvector_nervous_system/separate/projection.rs

//! Sparse random projection for dimensionality expansion.
//!
//! Implements sparse random matrices for efficient high-dimensional projections
//! with controlled sparsity (connection probability).
5
6use crate::{NervousSystemError, Result};
7use rand::rngs::StdRng;
8use rand::{Rng, SeedableRng};
9
/// Sparse random projection matrix for dimensionality expansion.
///
/// Uses a sparse random matrix to project low-dimensional inputs into a
/// high-dimensional space.
///
/// # Properties
///
/// - Sparse connectivity: each weight is non-zero with probability
///   `sparsity` (typically 10-20% connections)
/// - Non-zero weights are drawn uniformly from `[-1.0, 1.0)`
/// - Deterministic (seeded) for reproducibility
///
/// # Performance
///
/// - Time complexity of [`SparseProjection::project`]: O(input_dim × output_dim).
///   The matrix is stored densely, so every entry of a row is visited for each
///   non-zero input element; only about a `sparsity` fraction of those entries
///   contribute a multiply-add.
/// - Space complexity: O(input_dim × output_dim) (dense storage, zeros included)
#[derive(Debug, Clone)]
pub struct SparseProjection {
    /// Projection weights, indexed as `weights[input_index][output_index]`
    /// (one row of length `output_dim` per input dimension)
    weights: Vec<Vec<f32>>,

    /// Connection probability (0.0 exclusive to 1.0 inclusive)
    sparsity: f32,

    /// Random seed the weights were generated from
    seed: u64,

    /// Input dimension (number of rows in `weights`)
    input_dim: usize,

    /// Output dimension (length of each row in `weights`)
    output_dim: usize,
}
42
43impl SparseProjection {
44    /// Create a new sparse random projection
45    ///
46    /// # Arguments
47    ///
48    /// * `input_dim` - Input vector dimension
49    /// * `output_dim` - Output vector dimension (should be >> input_dim)
50    /// * `sparsity` - Connection probability (typically 0.1-0.2)
51    /// * `seed` - Random seed for reproducibility
52    ///
53    /// # Example
54    ///
55    /// ```
56    /// use ruvector_nervous_system::SparseProjection;
57    ///
58    /// let projection = SparseProjection::new(128, 10000, 0.15, 42);
59    /// ```
60    pub fn new(input_dim: usize, output_dim: usize, sparsity: f32, seed: u64) -> Result<Self> {
61        if input_dim == 0 {
62            return Err(NervousSystemError::InvalidDimension(
63                "Input dimension must be > 0".to_string(),
64            ));
65        }
66
67        if output_dim == 0 {
68            return Err(NervousSystemError::InvalidDimension(
69                "Output dimension must be > 0".to_string(),
70            ));
71        }
72
73        if sparsity <= 0.0 || sparsity > 1.0 {
74            return Err(NervousSystemError::InvalidSparsity(format!(
75                "Sparsity must be in (0, 1], got {}",
76                sparsity
77            )));
78        }
79
80        let mut rng = StdRng::seed_from_u64(seed);
81        let mut weights = Vec::with_capacity(input_dim);
82
83        // Initialize sparse random weights
84        for _ in 0..input_dim {
85            let mut row = Vec::with_capacity(output_dim);
86            for _ in 0..output_dim {
87                if rng.gen::<f32>() < sparsity {
88                    // Gaussian random weight
89                    let weight: f32 = rng.gen_range(-1.0..1.0);
90                    row.push(weight);
91                } else {
92                    row.push(0.0);
93                }
94            }
95            weights.push(row);
96        }
97
98        Ok(Self {
99            weights,
100            sparsity,
101            seed,
102            input_dim,
103            output_dim,
104        })
105    }
106
107    /// Project input vector to high-dimensional space
108    ///
109    /// # Arguments
110    ///
111    /// * `input` - Input vector of size input_dim
112    ///
113    /// # Returns
114    ///
115    /// Output vector of size output_dim
116    ///
117    /// # Example
118    ///
119    /// ```
120    /// use ruvector_nervous_system::SparseProjection;
121    ///
122    /// let projection = SparseProjection::new(128, 10000, 0.15, 42).unwrap();
123    /// let input = vec![1.0; 128];
124    /// let output = projection.project(&input).unwrap();
125    /// assert_eq!(output.len(), 10000);
126    /// ```
127    pub fn project(&self, input: &[f32]) -> Result<Vec<f32>> {
128        if input.len() != self.input_dim {
129            return Err(NervousSystemError::DimensionMismatch {
130                expected: self.input_dim,
131                actual: input.len(),
132            });
133        }
134
135        let mut output = vec![0.0; self.output_dim];
136
137        // Matrix-vector multiplication: output = weights^T × input
138        for i in 0..self.input_dim {
139            let input_val = input[i];
140            if input_val != 0.0 {
141                for j in 0..self.output_dim {
142                    let weight = self.weights[i][j];
143                    if weight != 0.0 {
144                        output[j] += input_val * weight;
145                    }
146                }
147            }
148        }
149
150        Ok(output)
151    }
152
153    /// Get input dimension
154    pub fn input_dim(&self) -> usize {
155        self.input_dim
156    }
157
158    /// Get output dimension
159    pub fn output_dim(&self) -> usize {
160        self.output_dim
161    }
162
163    /// Get sparsity level
164    pub fn sparsity(&self) -> f32 {
165        self.sparsity
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Mean absolute value of the elements of `values`.
    fn mean_abs(values: &[f32]) -> f32 {
        values.iter().map(|x| x.abs()).sum::<f32>() / values.len() as f32
    }

    #[test]
    fn test_sparse_projection_creation() {
        let proj = SparseProjection::new(128, 1000, 0.15, 42).unwrap();
        assert_eq!(proj.input_dim(), 128);
        assert_eq!(proj.output_dim(), 1000);
        assert_eq!(proj.sparsity(), 0.15);
    }

    #[test]
    fn test_invalid_dimensions() {
        assert!(SparseProjection::new(0, 1000, 0.15, 42).is_err());
        assert!(SparseProjection::new(128, 0, 0.15, 42).is_err());
    }

    #[test]
    fn test_invalid_sparsity() {
        assert!(SparseProjection::new(128, 1000, 0.0, 42).is_err());
        assert!(SparseProjection::new(128, 1000, -0.1, 42).is_err());
        assert!(SparseProjection::new(128, 1000, 1.5, 42).is_err());
    }

    #[test]
    fn test_full_connectivity_is_accepted() {
        // The valid range is (0, 1], so the upper bound itself must be allowed.
        assert!(SparseProjection::new(16, 32, 1.0, 42).is_ok());
    }

    #[test]
    fn test_projection_dimensions() {
        let proj = SparseProjection::new(128, 1000, 0.15, 42).unwrap();
        let input = vec![1.0; 128];
        let output = proj.project(&input).unwrap();
        assert_eq!(output.len(), 1000);
    }

    #[test]
    fn test_projection_dimension_mismatch() {
        let proj = SparseProjection::new(128, 1000, 0.15, 42).unwrap();
        let input = vec![1.0; 64]; // Wrong size
        assert!(proj.project(&input).is_err());
    }

    #[test]
    fn test_projection_deterministic() {
        let proj1 = SparseProjection::new(128, 1000, 0.15, 42).unwrap();
        let proj2 = SparseProjection::new(128, 1000, 0.15, 42).unwrap();

        let input = vec![1.0; 128];
        let output1 = proj1.project(&input).unwrap();
        let output2 = proj2.project(&input).unwrap();

        // Same seed should produce same results
        assert_eq!(output1, output2);
    }

    #[test]
    fn test_one_hot_input_selects_weight_row() {
        let proj = SparseProjection::new(16, 64, 0.5, 7).unwrap();
        let mut input = vec![0.0; 16];
        input[3] = 1.0;
        let output = proj.project(&input).unwrap();

        // A single unit input contributes exactly its own weight row.
        assert_eq!(output, proj.weights[3]);
    }

    #[test]
    fn test_projection_sparsity_effect() {
        let proj_sparse = SparseProjection::new(128, 1000, 0.1, 42).unwrap();
        let proj_dense = SparseProjection::new(128, 1000, 0.9, 42).unwrap();

        let input = vec![1.0; 128];
        let output_sparse = proj_sparse.project(&input).unwrap();
        let output_dense = proj_dense.project(&input).unwrap();

        // Dense projection should have larger average magnitude
        // (more connections contributing to each output)
        let avg_sparse = mean_abs(&output_sparse);
        let avg_dense = mean_abs(&output_dense);

        // Connection probability 0.9 vs 0.1 means 9x more connections, so
        // roughly sqrt(9) = 3x larger magnitude is expected.
        assert!(
            avg_dense > avg_sparse,
            "Dense avg={} should be > sparse avg={}",
            avg_dense,
            avg_sparse
        );
    }

    #[test]
    fn test_zero_input_produces_zero_output() {
        let proj = SparseProjection::new(128, 1000, 0.15, 42).unwrap();
        let input = vec![0.0; 128];
        let output = proj.project(&input).unwrap();

        assert!(output.iter().all(|&x| x == 0.0));
    }
}
252}