// vsa_optim_rs/vsa/compressor.rs

//! VSA Gradient Compressor implementation.
//!
//! Uses proper Vector Symbolic Architecture operations:
//! - **Bind**: Associates each gradient with a unique random key
//! - **Bundle**: Combines all bound gradients into a single superposition
//! - **Unbind**: Extracts individual gradients by binding with the inverse key
//!
//! This approach maintains the memory benefits of bundling while enabling
//! accurate reconstruction via the quasi-orthogonality of random keys in
//! high-dimensional space.
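//!
//! A minimal sketch of the algebra (hypothetical vectors; the actual
//! operations come from the `trit_vsa` crate imported below):
//!
//! ```ignore
//! // Hypervectors a and b, bound to random keys k_a and k_b, then bundled:
//! let bundle = bundle_many(&[&bind(&a, &k_a), &bind(&b, &k_b)]);
//! // Unbinding with k_a yields a plus cross-talk from b's term; the
//! // cross-talk behaves as near-orthogonal noise in high dimensions.
//! let a_hat = unbind(&bundle, &k_a); // ≈ a
//! ```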

use std::collections::HashMap;

use candle_core::{DType, Device, Tensor};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use trit_vsa::{PackedTritVec, Trit, vsa as trit_vsa_ops};

use crate::config::VSAConfig;
use crate::error::{OptimError, Result};

/// Gradient metadata for reconstruction.
#[derive(Debug, Clone)]
pub struct GradientMetadata {
    /// Index in the bundling order (used for key generation).
    pub key_index: usize,
    /// Scale factor from quantization.
    pub scale: f32,
    /// Original shape of the gradient tensor.
    pub shape: Vec<usize>,
}

/// Compress gradients using Vector Symbolic Architecture.
///
/// This compressor uses proper VSA operations:
/// 1. **Project** each gradient to hyperdimensional space
/// 2. **Bind** each projected gradient with a unique random key
/// 3. **Bundle** all bound vectors into a single superposition
/// 4. **Unbind** during decompression to extract individual gradients
///
/// The bundled representation achieves significant compression, while the
/// bind/unbind operations enable accurate reconstruction thanks to the
/// quasi-orthogonality of random keys in high dimensions.
///
/// # Example
///
/// ```ignore
/// use vsa_optim_rs::vsa::VSAGradientCompressor;
/// use vsa_optim_rs::VSAConfig;
///
/// // `compress` and `decompress` take `&mut self`, so bind as mutable.
/// let mut compressor = VSAGradientCompressor::new(1_000_000, VSAConfig::default());
///
/// // After computing gradients
/// let (compressed, metadata) = compressor.compress(&gradients)?;
/// let reconstructed = compressor.decompress(&compressed, &metadata)?;
/// ```
pub struct VSAGradientCompressor {
    config: VSAConfig,
    param_count: usize,
    hypervector_dim: usize,
    /// Cache of binding keys, keyed by gradient index
    key_cache: HashMap<usize, PackedTritVec>,
    /// Cache of projection matrices, keyed by gradient size
    projection_cache: HashMap<usize, Tensor>,
}

impl VSAGradientCompressor {
    /// Create a new VSA gradient compressor.
    ///
    /// # Arguments
    ///
    /// * `param_count` - Total number of model parameters
    /// * `config` - VSA configuration
    #[allow(clippy::cast_possible_truncation)]
    #[allow(clippy::cast_sign_loss)]
    #[must_use]
    pub fn new(param_count: usize, config: VSAConfig) -> Self {
        // Use the larger of the configured dimension and the dimension implied
        // by the compression ratio, with a floor of 256
        let hypervector_dim = config.dimension.max(
            (param_count as f32 * config.compression_ratio).max(256.0) as usize,
        );
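        // Illustrative sizing (hypothetical numbers): param_count = 1_000_000
        // with compression_ratio = 0.5 gives 500_000 dimensions, unless
        // `config.dimension` is larger.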

        Self {
            config,
            param_count,
            hypervector_dim,
            key_cache: HashMap::new(),
            projection_cache: HashMap::new(),
        }
    }

    /// Get the hypervector dimension.
    #[must_use]
    pub const fn compressed_dim(&self) -> usize {
        self.hypervector_dim
    }

    /// Generate or retrieve the deterministic random binding key for a
    /// gradient index.
    fn get_binding_key(&mut self, index: usize) -> PackedTritVec {
        if let Some(key) = self.key_cache.get(&index) {
            return key.clone();
        }

        // Derive a per-index seed so the same key is regenerated during
        // decompression
        let seed = self.config.seed.wrapping_add((index as u64).wrapping_mul(12_345));
        let mut rng = ChaCha8Rng::seed_from_u64(seed);

        let mut key = PackedTritVec::new(self.hypervector_dim);
        for i in 0..self.hypervector_dim {
            // Random ternary symbol: roughly 1/3 each of -1, 0, +1
            let r: f32 = rng.gen();
            let trit = if r < 0.33 {
                Trit::N
            } else if r < 0.66 {
                Trit::Z
            } else {
                Trit::P
            };
            key.set(i, trit);
        }

        self.key_cache.insert(index, key.clone());
        key
    }

    /// Generate (or fetch from cache) the projection matrix that maps a
    /// gradient of `grad_size` elements into hypervector space.
    fn get_projection(&mut self, grad_size: usize, device: &Device) -> Result<Tensor> {
        if let Some(proj) = self.projection_cache.get(&grad_size) {
            return Ok(proj.clone());
        }

        let seed = self.config.seed.wrapping_add((grad_size as u64).wrapping_mul(54_321));
        let mut rng = ChaCha8Rng::seed_from_u64(seed);

        // Johnson-Lindenstrauss scaling: 1/sqrt(d) preserves dot products in expectation
        let scale = 1.0 / (self.hypervector_dim as f32).sqrt();

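        // Why sqrt(3) below: entries are ±sqrt(3)·scale with probability ~1/6
        // each and 0 otherwise, so E[entry²] ≈ 2·(1/6)·3·scale² = scale²,
        // matching the variance of a dense Gaussian JL projection
        // (an Achlioptas-style sparse projection).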
        let data: Vec<f32> = (0..grad_size * self.hypervector_dim)
            .map(|_| {
                // Sparse random projection for efficiency: ~68% zeros
                let r: f32 = rng.gen();
                if r < 0.16 {
                    scale * 3.0_f32.sqrt()
                } else if r < 0.32 {
                    -scale * 3.0_f32.sqrt()
                } else {
                    0.0
                }
            })
            .collect();

        let proj = Tensor::from_vec(data, (grad_size, self.hypervector_dim), device)?;
        self.projection_cache.insert(grad_size, proj.clone());
        Ok(proj)
    }

    /// Project a gradient to a hypervector, returning the ternary
    /// representation and its quantization scale.
    fn project_to_hypervector(
        &mut self,
        gradient: &Tensor,
    ) -> Result<(PackedTritVec, f32)> {
        let device = gradient.device();
        let flat = gradient.flatten_all()?.to_dtype(DType::F32)?;
        let grad_size = flat.elem_count();

        // Get projection matrix
        let proj = self.get_projection(grad_size, device)?;

        // Project: (1, grad_size) @ (grad_size, dim) -> (1, dim)
        let projected = flat.unsqueeze(0)?.matmul(&proj)?.squeeze(0)?;
        let data: Vec<f32> = projected.to_vec1()?;

        // Quantization scale: mean absolute value of the projected vector
        let scale = if data.is_empty() {
            0.0
        } else {
            data.iter().map(|v| v.abs()).sum::<f32>() / data.len() as f32
        };

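        // Worked example (illustrative values): data = [0.5, -2.0, 0.1] gives
        // scale = (0.5 + 2.0 + 0.1) / 3 ≈ 0.87, so the trits are [Z, N, Z].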
        // Quantize to ternary: values beyond ±scale become ±1, the rest 0
        let mut packed = PackedTritVec::new(self.hypervector_dim);
        if scale > 0.0 {
            for (i, &v) in data.iter().enumerate() {
                let trit = if v > scale {
                    Trit::P
                } else if v < -scale {
                    Trit::N
                } else {
                    Trit::Z
                };
                packed.set(i, trit);
            }
        }

        Ok((packed, scale))
    }

    /// Compress gradients to a bundled hyperdimensional representation.
    ///
    /// # Algorithm
    ///
    /// 1. Project each gradient to hypervector space
    /// 2. Quantize to ternary {-1, 0, +1}
    /// 3. Bind with a unique random key
    /// 4. Bundle all bound vectors via majority vote (the ternary result of
    ///    the element-wise sum)
    ///
    /// # Arguments
    ///
    /// * `gradients` - Map of parameter names to gradient tensors
    ///
    /// # Returns
    ///
    /// Tuple of (bundled hypervector, metadata for reconstruction).
    ///
    /// # Errors
    ///
    /// Returns an error if `gradients` is empty or a tensor operation fails.
    pub fn compress(
        &mut self,
        gradients: &HashMap<String, Tensor>,
    ) -> Result<(PackedTritVec, HashMap<String, GradientMetadata>)> {
        if gradients.is_empty() {
            return Err(OptimError::EmptyInput("No gradients to compress".to_string()));
        }

        let mut metadata = HashMap::new();
        let mut bound_vectors: Vec<PackedTritVec> = Vec::new();

        // HashMap iteration order is arbitrary, but harmless: each gradient's
        // key index is recorded in its metadata for decompression
        for (index, (name, grad)) in gradients.iter().enumerate() {
            // Project gradient to hypervector
            let (projected, scale) = self.project_to_hypervector(grad)?;

            // Get binding key for this gradient
            let key = self.get_binding_key(index);

            // Bind: gradient ⊛ key
            let bound = trit_vsa_ops::bind(&projected, &key);
            bound_vectors.push(bound);

            metadata.insert(
                name.clone(),
                GradientMetadata {
                    key_index: index,
                    scale,
                    shape: grad.dims().to_vec(),
                },
            );
        }

        // Bundle all bound vectors via majority voting
        let refs: Vec<&PackedTritVec> = bound_vectors.iter().collect();
        let bundled = trit_vsa_ops::bundle_many(&refs);

        Ok((bundled, metadata))
    }

    /// Decompress gradients from the bundled hypervector.
    ///
    /// # Algorithm
    ///
    /// For each gradient:
    /// 1. Unbind with the gradient's key to extract it from the bundle
    /// 2. Inverse-project back to gradient space
    /// 3. Apply the stored scale factor
    ///
    /// # Arguments
    ///
    /// * `bundled` - Bundled hypervector from [`Self::compress`]
    /// * `metadata` - Metadata from compression
    ///
    /// # Returns
    ///
    /// Map of reconstructed gradients.
    ///
    /// # Errors
    ///
    /// Returns an error if a tensor operation fails.
    pub fn decompress(
        &mut self,
        bundled: &PackedTritVec,
        metadata: &HashMap<String, GradientMetadata>,
    ) -> Result<HashMap<String, Tensor>> {
        let device = Device::Cpu; // Ternary ops are CPU-based
        let mut gradients = HashMap::new();

        for (name, meta) in metadata {
            // Get the binding key used during compression
            let key = self.get_binding_key(meta.key_index);

            // Unbind to extract this gradient's contribution
            let unbound = trit_vsa_ops::unbind(bundled, &key);

            // Fetch the same projection matrix used during compression (the
            // cache is keyed by size only, so compress and decompress are
            // assumed to run on the same device)
            let grad_size: usize = meta.shape.iter().product();
            let proj = self.get_projection(grad_size, &device)?;

            // Convert the unbound ternary vector to floats, applying the
            // stored quantization scale
            let unbound_float: Vec<f32> = (0..self.hypervector_dim)
                .map(|i| unbound.get(i).value() as f32 * meta.scale)
                .collect();

            let unbound_tensor = Tensor::from_vec(
                unbound_float,
                self.hypervector_dim,
                &device,
            )?;

            // Inverse project: (1, dim) @ (dim, grad_size) -> (1, grad_size)
            let reconstructed = unbound_tensor.unsqueeze(0)?
                .matmul(&proj.t()?)?
                .squeeze(0)?;

            // Reshape to the original gradient shape
            let grad = reconstructed.reshape(meta.shape.as_slice())?;
            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }

    /// Get compression statistics.
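    ///
    /// The figures below are illustrative only (actual values depend on
    /// `VSAConfig`):
    ///
    /// ```ignore
    /// let stats = compressor.get_compression_stats();
    /// // With 1_000_000 params and a 500_000-dim hypervector, packed ternary
    /// // needs 2 bits/element: 1 - (500_000 * 2 / 32) / 1_000_000 = 0.969
    /// println!("{stats}"); // "Compression: 1000000 → 500000 (96.9% saved)"
    /// ```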
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn get_compression_stats(&self) -> CompressionStats {
        CompressionStats {
            original_params: self.param_count,
            compressed_dim: self.hypervector_dim,
            compression_ratio: self.hypervector_dim as f32 / self.param_count as f32,
            // Packed ternary uses 2 bits per element vs 32 bits per f32
            memory_saving: 1.0 - (self.hypervector_dim as f32 * 2.0 / 32.0) / self.param_count as f32,
        }
    }

    /// Clear the key and projection caches to free memory.
    pub fn clear_cache(&mut self) {
        self.key_cache.clear();
        self.projection_cache.clear();
    }
}

/// Compression statistics.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Original parameter count.
    pub original_params: usize,
    /// Compressed dimension.
    pub compressed_dim: usize,
    /// Compression ratio (compressed / original element count).
    pub compression_ratio: f32,
    /// Fraction of memory saved, accounting for packed ternary storage
    /// (2 bits per element) versus 32-bit floats.
    pub memory_saving: f32,
}

impl std::fmt::Display for CompressionStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Compression: {} → {} ({:.1}% saved)",
            self.original_params,
            self.compressed_dim,
            self.memory_saving * 100.0
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_mock_gradients(device: &Device) -> HashMap<String, Tensor> {
        let mut gradients = HashMap::new();
        gradients.insert(
            "layer1.weight".to_string(),
            Tensor::randn(0.0f32, 1.0, (64, 128), device).unwrap(),
        );
        gradients.insert(
            "layer1.bias".to_string(),
            Tensor::randn(0.0f32, 1.0, 64, device).unwrap(),
        );
        gradients.insert(
            "layer2.weight".to_string(),
            Tensor::randn(0.0f32, 1.0, (32, 64), device).unwrap(),
        );
        gradients
    }

    #[test]
    fn test_compressor_creation() {
        let compressor = VSAGradientCompressor::new(1_000_000, VSAConfig::default());
        assert!(compressor.compressed_dim() >= 256);
    }

    #[test]
    fn test_compress_decompress_roundtrip() {
        let device = Device::Cpu;
        let gradients = create_mock_gradients(&device);

        let param_count: usize = gradients.values().map(|g| g.elem_count()).sum();
        let mut compressor = VSAGradientCompressor::new(
            param_count,
            VSAConfig::default().with_compression_ratio(0.5),
        );

        // Compress
        let (bundled, metadata) = compressor.compress(&gradients).unwrap();
        assert_eq!(bundled.len(), compressor.compressed_dim());
        assert_eq!(metadata.len(), 3);

        // Decompress
        let reconstructed = compressor.decompress(&bundled, &metadata).unwrap();
        assert_eq!(reconstructed.len(), 3);

        // Check shapes match
        for (name, orig) in &gradients {
            let recon = reconstructed.get(name).unwrap();
            assert_eq!(orig.dims(), recon.dims());
        }
    }

    #[test]
    fn test_compression_stats() {
        let compressor = VSAGradientCompressor::new(1_000_000, VSAConfig::default());
        let stats = compressor.get_compression_stats();

        assert_eq!(stats.original_params, 1_000_000);
        // Packed ternary (2 bits per element) should give a large saving
        assert!(stats.memory_saving > 0.9);
    }

    #[test]
    fn test_direction_preservation() {
        let device = Device::Cpu;
        let gradients = create_mock_gradients(&device);

        let param_count: usize = gradients.values().map(|g| g.elem_count()).sum();
        let mut compressor = VSAGradientCompressor::new(
            param_count,
            VSAConfig::default()
                .with_dimension(8192) // Larger dimension improves reconstruction
                .with_compression_ratio(0.5),
        );

        let (bundled, metadata) = compressor.compress(&gradients).unwrap();
        let reconstructed = compressor.decompress(&bundled, &metadata).unwrap();

        // Check that cosine similarity is positive (direction preserved)
        for (name, orig) in &gradients {
            let recon = reconstructed.get(name).unwrap();

            let orig_flat = orig.flatten_all().unwrap();
            let recon_flat = recon.flatten_all().unwrap();

            let orig_data: Vec<f32> = orig_flat.to_vec1().unwrap();
            let recon_data: Vec<f32> = recon_flat.to_vec1().unwrap();

            let dot: f32 = orig_data.iter().zip(recon_data.iter()).map(|(a, b)| a * b).sum();
            let norm_orig: f32 = orig_data.iter().map(|x| x * x).sum::<f32>().sqrt();
            let norm_recon: f32 = recon_data.iter().map(|x| x * x).sum::<f32>().sqrt();

            // Skip near-zero tensors where the cosine is numerically meaningless
            if norm_orig < 1e-6 || norm_recon < 1e-6 {
                continue;
            }

            let cosine = dot / (norm_orig * norm_recon + 1e-8);

            // Direction should be roughly preserved for larger tensors;
            // VSA reconstruction is approximate due to bundling interference
            if orig.elem_count() >= 1024 {
                assert!(
                    cosine > 0.1, // Loose threshold due to bundling noise
                    "Gradient direction not preserved for {name}: cosine = {cosine}"
                );
            }
        }
    }

    #[test]
    fn test_bind_unbind_property() {
        // Test that bind followed by unbind recovers the original vector
        let mut compressor =
            VSAGradientCompressor::new(1000, VSAConfig::default().with_dimension(1024));

        let key0 = compressor.get_binding_key(0);
        let key1 = compressor.get_binding_key(1);

        // Keys should differ: two random ternary vectors agree at ~1/3 of
        // positions by chance (each of the 3 symbols matches with probability
        // ~1/9, so ~1/3 overall)
        let mut same_count = 0;
        for i in 0..key0.len() {
            if key0.get(i) == key1.get(i) {
                same_count += 1;
            }
        }
        assert!(same_count < key0.len() * 2 / 3);

        // Bind and unbind should recover the original
        let test_vec = key0.clone();
        let bound = trit_vsa_ops::bind(&test_vec, &key1);
        let recovered = trit_vsa_ops::unbind(&bound, &key1);

        for i in 0..test_vec.len() {
            assert_eq!(test_vec.get(i), recovered.get(i));
        }
    }

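    // A minimal extra check (sketch): decompress regenerates keys from
    // `metadata.key_index`, so key generation must be deterministic for a
    // given config seed and index.
    #[test]
    fn test_binding_key_determinism() {
        let mut a = VSAGradientCompressor::new(1000, VSAConfig::default().with_dimension(1024));
        let mut b = VSAGradientCompressor::new(1000, VSAConfig::default().with_dimension(1024));

        // Two independently constructed compressors with the same config must
        // produce identical keys for the same index
        let ka = a.get_binding_key(7);
        let kb = b.get_binding_key(7);
        for i in 0..ka.len() {
            assert_eq!(ka.get(i), kb.get(i));
        }
    }
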
    #[test]
    fn test_empty_gradients() {
        let mut compressor = VSAGradientCompressor::new(1000, VSAConfig::default());
        let gradients = HashMap::new();

        let result = compressor.compress(&gradients);
        assert!(result.is_err());
    }
}