1use std::collections::HashMap;
13
14use candle_core::{DType, Device, Tensor};
15use rand::SeedableRng;
16use rand_chacha::ChaCha8Rng;
17use trit_vsa::{PackedTritVec, Trit, vsa as trit_vsa_ops};
18
19use crate::config::VSAConfig;
20use crate::error::{OptimError, Result};
21
/// Bookkeeping recorded for one gradient during compression.
///
/// `compress` fills one of these per named gradient and `decompress` reads it back to
/// pick the binding key, rescale the recovered trits and restore the tensor shape.
#[derive(Clone, Debug)]
pub struct GradientMetadata {
    /// Index of the binding key the gradient was bound with.
    pub key_index: usize,
    /// Mean absolute value of the projected gradient; multiplies the trits on the way back.
    pub scale: f32,
    /// Dimensions of the original gradient tensor.
    pub shape: Vec<usize>,
}
32
33pub struct VSAGradientCompressor {
58 config: VSAConfig,
59 param_count: usize,
60 hypervector_dim: usize,
61 key_cache: HashMap<usize, PackedTritVec>,
63 projection_cache: HashMap<usize, Tensor>,
65}
66
67impl VSAGradientCompressor {
68 #[allow(clippy::cast_possible_truncation)]
75 #[allow(clippy::cast_sign_loss)]
76 #[must_use]
77 pub fn new(param_count: usize, config: VSAConfig) -> Self {
78 let hypervector_dim = config.dimension.max(
80 (param_count as f32 * config.compression_ratio).max(256.0) as usize
81 );
82
83 Self {
84 config,
85 param_count,
86 hypervector_dim,
87 key_cache: HashMap::new(),
88 projection_cache: HashMap::new(),
89 }
90 }
91
92 #[must_use]
94 pub const fn compressed_dim(&self) -> usize {
95 self.hypervector_dim
96 }
97
98 fn get_binding_key(&mut self, index: usize) -> PackedTritVec {
100 if let Some(key) = self.key_cache.get(&index) {
101 return key.clone();
102 }
103
104 let seed = self.config.seed.wrapping_add(index as u64 * 12345);
106 let mut rng = ChaCha8Rng::seed_from_u64(seed);
107
108 let mut key = PackedTritVec::new(self.hypervector_dim);
109 for i in 0..self.hypervector_dim {
110 use rand::Rng;
111 let r: f32 = rng.gen();
113 let trit = if r < 0.33 {
114 Trit::N
115 } else if r < 0.66 {
116 Trit::Z
117 } else {
118 Trit::P
119 };
120 key.set(i, trit);
121 }
122
123 self.key_cache.insert(index, key.clone());
124 key
125 }
126
127 fn get_projection(&mut self, grad_size: usize, device: &Device) -> Result<Tensor> {
129 if let Some(proj) = self.projection_cache.get(&grad_size) {
130 return Ok(proj.clone());
131 }
132
133 let seed = self.config.seed.wrapping_add(grad_size as u64 * 54321);
134 let mut rng = ChaCha8Rng::seed_from_u64(seed);
135
136 let scale = 1.0 / (self.hypervector_dim as f32).sqrt();
138
139 let data: Vec<f32> = (0..grad_size * self.hypervector_dim)
140 .map(|_| {
141 use rand::Rng;
142 let r: f32 = rng.gen();
144 if r < 0.16 {
145 scale * 3.0_f32.sqrt() } else if r < 0.32 {
147 -scale * 3.0_f32.sqrt()
148 } else {
149 0.0
150 }
151 })
152 .collect();
153
154 let proj = Tensor::from_vec(data, (grad_size, self.hypervector_dim), device)?;
155 self.projection_cache.insert(grad_size, proj.clone());
156 Ok(proj)
157 }
158
159 fn project_to_hypervector(
161 &mut self,
162 gradient: &Tensor,
163 ) -> Result<(PackedTritVec, f32)> {
164 let device = gradient.device();
165 let flat = gradient.flatten_all()?.to_dtype(DType::F32)?;
166 let grad_size = flat.elem_count();
167
168 let proj = self.get_projection(grad_size, device)?;
170
171 let projected = flat.unsqueeze(0)?.matmul(&proj)?.squeeze(0)?;
173 let data: Vec<f32> = projected.to_vec1()?;
174
175 let scale = if data.is_empty() {
177 0.0
178 } else {
179 data.iter().map(|v| v.abs()).sum::<f32>() / data.len() as f32
180 };
181
182 let mut packed = PackedTritVec::new(self.hypervector_dim);
184 if scale > 0.0 {
185 for (i, &v) in data.iter().enumerate() {
186 let trit = if v > scale {
187 Trit::P
188 } else if v < -scale {
189 Trit::N
190 } else {
191 Trit::Z
192 };
193 packed.set(i, trit);
194 }
195 }
196
197 Ok((packed, scale))
198 }
199
200 pub fn compress(
217 &mut self,
218 gradients: &HashMap<String, Tensor>,
219 ) -> Result<(PackedTritVec, HashMap<String, GradientMetadata>)> {
220 if gradients.is_empty() {
221 return Err(OptimError::EmptyInput("No gradients to compress".to_string()));
222 }
223
224 let mut metadata = HashMap::new();
225 let mut bound_vectors: Vec<PackedTritVec> = Vec::new();
226
227 for (index, (name, grad)) in gradients.iter().enumerate() {
228 let (projected, scale) = self.project_to_hypervector(grad)?;
230
231 let key = self.get_binding_key(index);
233
234 let bound = trit_vsa_ops::bind(&projected, &key);
236 bound_vectors.push(bound);
237
238 metadata.insert(
239 name.clone(),
240 GradientMetadata {
241 key_index: index,
242 scale,
243 shape: grad.dims().to_vec(),
244 },
245 );
246 }
247
248 let refs: Vec<&PackedTritVec> = bound_vectors.iter().collect();
250 let bundled = trit_vsa_ops::bundle_many(&refs);
251
252 Ok((bundled, metadata))
253 }
254
255 pub fn decompress(
273 &mut self,
274 bundled: &PackedTritVec,
275 metadata: &HashMap<String, GradientMetadata>,
276 ) -> Result<HashMap<String, Tensor>> {
277 let device = Device::Cpu; let mut gradients = HashMap::new();
279
280 for (name, meta) in metadata {
281 let key = self.get_binding_key(meta.key_index);
283
284 let unbound = trit_vsa_ops::unbind(bundled, &key);
286
287 let grad_size: usize = meta.shape.iter().product();
289 let proj = self.get_projection(grad_size, &device)?;
290
291 let unbound_float: Vec<f32> = (0..self.hypervector_dim)
294 .map(|i| unbound.get(i).value() as f32 * meta.scale)
295 .collect();
296
297 let unbound_tensor = Tensor::from_vec(
298 unbound_float,
299 self.hypervector_dim,
300 &device,
301 )?;
302
303 let reconstructed = unbound_tensor.unsqueeze(0)?
305 .matmul(&proj.t()?)?
306 .squeeze(0)?;
307
308 let grad = reconstructed.reshape(meta.shape.as_slice())?;
310 gradients.insert(name.clone(), grad);
311 }
312
313 Ok(gradients)
314 }
315
316 #[must_use]
318 #[allow(clippy::cast_precision_loss)]
319 pub fn get_compression_stats(&self) -> CompressionStats {
320 CompressionStats {
321 original_params: self.param_count,
322 compressed_dim: self.hypervector_dim,
323 compression_ratio: self.hypervector_dim as f32 / self.param_count as f32,
324 memory_saving: 1.0 - (self.hypervector_dim as f32 * 2.0 / 32.0) / self.param_count as f32,
326 }
327 }
328
329 pub fn clear_cache(&mut self) {
331 self.key_cache.clear();
332 self.projection_cache.clear();
333 }
334}
335
/// Size figures describing a compressor's configuration.
#[derive(Clone, Debug)]
pub struct CompressionStats {
    /// Number of parameters before compression.
    pub original_params: usize,
    /// Length of the compressed hypervector.
    pub compressed_dim: usize,
    /// Compressed dimension divided by the original parameter count.
    pub compression_ratio: f32,
    /// Fraction of memory saved; multiplied by 100 when displayed.
    pub memory_saving: f32,
}
348
349impl std::fmt::Display for CompressionStats {
350 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
351 write!(
352 f,
353 "Compression: {} → {} ({:.1}% saved)",
354 self.original_params,
355 self.compressed_dim,
356 self.memory_saving * 100.0
357 )
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces three random-normal gradients with fixed names and shapes on `device`.
    fn create_mock_gradients(device: &Device) -> HashMap<String, Tensor> {
        let layer1_weight = Tensor::randn(0.0f32, 1.0, (64, 128), device).unwrap();
        let layer1_bias = Tensor::randn(0.0f32, 1.0, 64, device).unwrap();
        let layer2_weight = Tensor::randn(0.0f32, 1.0, (32, 64), device).unwrap();

        HashMap::from([
            ("layer1.weight".to_string(), layer1_weight),
            ("layer1.bias".to_string(), layer1_bias),
            ("layer2.weight".to_string(), layer2_weight),
        ])
    }

    /// A freshly built compressor never reports fewer than 256 dimensions.
    #[test]
    fn test_compressor_creation() {
        let dim = VSAGradientCompressor::new(1_000_000, VSAConfig::default()).compressed_dim();
        assert!(dim >= 256);
    }

    /// Compressing then decompressing keeps the entry count and every tensor shape.
    #[test]
    fn test_compress_decompress_roundtrip() {
        let device = Device::Cpu;
        let gradients = create_mock_gradients(&device);

        let param_count: usize = gradients.values().map(|t| t.elem_count()).sum();
        let config = VSAConfig::default().with_compression_ratio(0.5);
        let mut compressor = VSAGradientCompressor::new(param_count, config);

        let (bundled, metadata) = compressor.compress(&gradients).unwrap();
        assert_eq!(bundled.len(), compressor.compressed_dim());
        assert_eq!(metadata.len(), 3);

        let reconstructed = compressor.decompress(&bundled, &metadata).unwrap();
        assert_eq!(reconstructed.len(), 3);

        for (name, orig) in &gradients {
            let recon = reconstructed.get(name).unwrap();
            assert_eq!(orig.dims(), recon.dims());
        }
    }

    /// Statistics echo the parameter count and report a large saving for a big model.
    #[test]
    fn test_compression_stats() {
        let stats =
            VSAGradientCompressor::new(1_000_000, VSAConfig::default()).get_compression_stats();

        assert_eq!(stats.original_params, 1_000_000);
        assert!(stats.memory_saving > 0.9);
    }

    /// Reconstructed gradients of at least 1024 elements stay positively correlated
    /// with the originals.
    #[test]
    fn test_direction_preservation() {
        let l2_norm = |values: &[f32]| values.iter().map(|x| x * x).sum::<f32>().sqrt();

        let device = Device::Cpu;
        let gradients = create_mock_gradients(&device);

        let param_count: usize = gradients.values().map(|t| t.elem_count()).sum();
        let config = VSAConfig::default()
            .with_dimension(8192)
            .with_compression_ratio(0.5);
        let mut compressor = VSAGradientCompressor::new(param_count, config);

        let (bundled, metadata) = compressor.compress(&gradients).unwrap();
        let reconstructed = compressor.decompress(&bundled, &metadata).unwrap();

        for (name, orig) in &gradients {
            let recon = reconstructed.get(name).unwrap();

            let orig_data: Vec<f32> = orig.flatten_all().unwrap().to_vec1().unwrap();
            let recon_data: Vec<f32> = recon.flatten_all().unwrap().to_vec1().unwrap();

            let dot: f32 = orig_data
                .iter()
                .zip(recon_data.iter())
                .map(|(a, b)| a * b)
                .sum();
            let norm_orig: f32 = l2_norm(&orig_data);
            let norm_recon: f32 = l2_norm(&recon_data);

            // A (near) zero vector has no direction to compare.
            if norm_orig < 1e-6 || norm_recon < 1e-6 {
                continue;
            }

            let cosine = dot / (norm_orig * norm_recon + 1e-8);

            // Only the larger gradients are held to the threshold.
            if orig.elem_count() >= 1024 {
                assert!(
                    cosine > 0.1,
                    "Gradient direction not preserved for {name}: cosine = {cosine}"
                );
            }
        }
    }

    /// Distinct key indices give mostly different keys, and unbinding undoes binding.
    #[test]
    fn test_bind_unbind_property() {
        let config = VSAConfig::default().with_dimension(1024);
        let mut compressor = VSAGradientCompressor::new(1000, config);

        let key0 = compressor.get_binding_key(0);
        let key1 = compressor.get_binding_key(1);

        let same_count = (0..key0.len())
            .filter(|&i| key0.get(i) == key1.get(i))
            .count();
        assert!(same_count < key0.len() * 2 / 3);

        let test_vec = key0.clone();
        let bound = trit_vsa_ops::bind(&test_vec, &key1);
        let recovered = trit_vsa_ops::unbind(&bound, &key1);

        for i in 0..test_vec.len() {
            assert_eq!(test_vec.get(i), recovered.get(i));
        }
    }

    /// Compressing an empty gradient map is rejected.
    #[test]
    fn test_empty_gradients() {
        let mut compressor = VSAGradientCompressor::new(1000, VSAConfig::default());
        let no_gradients: HashMap<String, Tensor> = HashMap::new();

        assert!(compressor.compress(&no_gradients).is_err());
    }
}