// tritter_accel/core/vsa.rs
//! Vector Symbolic Architecture (VSA) operations.
//!
//! Provides hyperdimensional computing primitives for gradient compression,
//! associative memory, and symbolic reasoning.
//!
//! # Operations
//!
//! - **Bind**: Associative composition (XOR-like for ternary)
//! - **Bundle**: Superposition via majority voting
//! - **Similarity**: Cosine, dot product, Hamming distance
//!
//! # Example
//!
//! ```rust,ignore
//! use tritter_accel::core::vsa::{VsaOps, VsaConfig};
//! use candle_core::Device;
//!
//! let config = VsaConfig::default();
//! let ops = VsaOps::new(config);
//!
//! // Create random ternary vectors
//! let a = ops.random(10000, 42)?;
//! let b = ops.random(10000, 43)?;
//!
//! // Bind creates association
//! let bound = ops.bind(&a, &b)?;
//!
//! // Unbind recovers original
//! let recovered = ops.unbind(&bound, &b)?;
//! assert!(ops.cosine_similarity(&a, &recovered)? > 0.9);
//! ```
33use candle_core::Device;
34use thiserror::Error;
35use trit_vsa::{PackedTritVec, Trit};
36
37#[cfg(feature = "cuda")]
38use trit_vsa::gpu::{
39    GpuBind, GpuBundle, GpuCosineSimilarity, GpuDispatchable, GpuDotSimilarity, GpuHammingDistance,
40    GpuRandom, GpuUnbind, RandomInput,
41};
42
/// Errors from VSA operations.
#[derive(Debug, Error)]
pub enum VsaError {
    /// Vectors have mismatched dimensions.
    #[error("dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch {
        /// Dimension of the first (reference) operand.
        expected: usize,
        /// Dimension of the offending operand.
        actual: usize,
    },

    /// Invalid ternary value.
    #[error("invalid value {value} at index {index}")]
    InvalidValue {
        /// The rejected value; only -1, 0, and 1 are valid trits.
        value: i8,
        /// Position in the input slice where the value was found.
        index: usize,
    },

    /// GPU operation failed.
    #[error("GPU error: {0}")]
    Gpu(String),

    /// Empty input.
    #[error("empty input")]
    EmptyInput,
}
62
/// Configuration for VSA operations.
///
/// Construct with [`Default::default`] and customize via `with_device`.
#[derive(Debug, Clone)]
pub struct VsaConfig {
    /// Preferred device for computation.
    pub device: DevicePreference,
}
69
/// Device preference for VSA operations.
///
/// GPU paths are only reachable when the crate is compiled with the
/// `cuda` feature; otherwise `Gpu` is an error and `Auto` resolves to CPU.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DevicePreference {
    /// Use GPU if available, fall back to CPU.
    Auto,
    /// Force GPU (error if unavailable).
    Gpu,
    /// Force CPU.
    Cpu,
}
80
81impl Default for VsaConfig {
82    fn default() -> Self {
83        Self {
84            device: DevicePreference::Auto,
85        }
86    }
87}
88
89impl VsaConfig {
90    /// Set device preference.
91    pub fn with_device(mut self, device: DevicePreference) -> Self {
92        self.device = device;
93        self
94    }
95}
96
/// VSA operations handler.
///
/// Provides ternary VSA operations with automatic CPU/GPU dispatch.
///
/// Cheap to clone: holds only the configuration; device resolution happens
/// per call.
#[derive(Debug, Clone)]
pub struct VsaOps {
    // Device preference consulted by every operation.
    config: VsaConfig,
}
104
105impl VsaOps {
106    /// Create new VSA operations handler.
107    pub fn new(config: VsaConfig) -> Self {
108        Self { config }
109    }
110
111    /// Get the effective device for computation.
112    fn get_device(&self) -> Result<Device, VsaError> {
113        match self.config.device {
114            DevicePreference::Cpu => Ok(Device::Cpu),
115            DevicePreference::Gpu => {
116                #[cfg(feature = "cuda")]
117                {
118                    Device::cuda_if_available(0).map_err(|e| VsaError::Gpu(e.to_string()))
119                }
120                #[cfg(not(feature = "cuda"))]
121                {
122                    Err(VsaError::Gpu(
123                        "CUDA not compiled. Rebuild with --features cuda".to_string(),
124                    ))
125                }
126            }
127            DevicePreference::Auto => {
128                #[cfg(feature = "cuda")]
129                {
130                    Ok(Device::cuda_if_available(0).unwrap_or(Device::Cpu))
131                }
132                #[cfg(not(feature = "cuda"))]
133                {
134                    Ok(Device::Cpu)
135                }
136            }
137        }
138    }
139
140    /// Generate a random ternary vector.
141    ///
142    /// # Arguments
143    ///
144    /// * `dim` - Vector dimension
145    /// * `seed` - Random seed for reproducibility
146    pub fn random(&self, dim: usize, seed: u32) -> Result<PackedTritVec, VsaError> {
147        if dim == 0 {
148            return Err(VsaError::EmptyInput);
149        }
150
151        let device = self.get_device()?;
152
153        #[cfg(feature = "cuda")]
154        {
155            if matches!(device, Device::Cuda(_)) {
156                let input = RandomInput::new(dim, seed);
157                return GpuRandom
158                    .dispatch(&input, &device)
159                    .map_err(|e| VsaError::Gpu(e.to_string()));
160            }
161        }
162
163        // CPU fallback
164        let _ = device; // silence unused warning
165        Ok(cpu_random(dim, seed))
166    }
167
168    /// Bind two ternary vectors (association).
169    ///
170    /// Bind is the composition operation, creating associations between vectors.
171    /// It is commutative and associative.
172    pub fn bind(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<PackedTritVec, VsaError> {
173        if a.len() != b.len() {
174            return Err(VsaError::DimensionMismatch {
175                expected: a.len(),
176                actual: b.len(),
177            });
178        }
179
180        let device = self.get_device()?;
181
182        #[cfg(feature = "cuda")]
183        {
184            if matches!(device, Device::Cuda(_)) {
185                return GpuBind
186                    .dispatch(&(a.clone(), b.clone()), &device)
187                    .map_err(|e| VsaError::Gpu(e.to_string()));
188            }
189        }
190
191        let _ = device;
192        Ok(cpu_bind(a, b))
193    }
194
195    /// Unbind two ternary vectors (inverse association).
196    ///
197    /// If bound = bind(a, b), then unbind(bound, b) recovers a.
198    pub fn unbind(&self, bound: &PackedTritVec, key: &PackedTritVec) -> Result<PackedTritVec, VsaError> {
199        if bound.len() != key.len() {
200            return Err(VsaError::DimensionMismatch {
201                expected: bound.len(),
202                actual: key.len(),
203            });
204        }
205
206        let device = self.get_device()?;
207
208        #[cfg(feature = "cuda")]
209        {
210            if matches!(device, Device::Cuda(_)) {
211                return GpuUnbind
212                    .dispatch(&(bound.clone(), key.clone()), &device)
213                    .map_err(|e| VsaError::Gpu(e.to_string()));
214            }
215        }
216
217        let _ = device;
218        // For ternary VSA, unbind is the same as bind (self-inverse)
219        Ok(cpu_bind(bound, key))
220    }
221
222    /// Bundle multiple vectors (superposition).
223    ///
224    /// Combines vectors via majority voting at each dimension.
225    /// The result is similar to all input vectors.
226    pub fn bundle(&self, vectors: &[PackedTritVec]) -> Result<PackedTritVec, VsaError> {
227        if vectors.is_empty() {
228            return Err(VsaError::EmptyInput);
229        }
230
231        let dim = vectors[0].len();
232        for (i, v) in vectors.iter().enumerate() {
233            if v.len() != dim {
234                return Err(VsaError::DimensionMismatch {
235                    expected: dim,
236                    actual: v.len(),
237                });
238            }
239            let _ = i; // used for error reporting if needed
240        }
241
242        let device = self.get_device()?;
243
244        #[cfg(feature = "cuda")]
245        {
246            if matches!(device, Device::Cuda(_)) {
247                return GpuBundle
248                    .dispatch(&vectors.to_vec(), &device)
249                    .map_err(|e| VsaError::Gpu(e.to_string()));
250            }
251        }
252
253        let _ = device;
254        Ok(cpu_bundle(vectors))
255    }
256
257    /// Compute cosine similarity between two vectors.
258    ///
259    /// Returns a value in [-1, 1].
260    pub fn cosine_similarity(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<f32, VsaError> {
261        if a.len() != b.len() {
262            return Err(VsaError::DimensionMismatch {
263                expected: a.len(),
264                actual: b.len(),
265            });
266        }
267
268        let device = self.get_device()?;
269
270        #[cfg(feature = "cuda")]
271        {
272            if matches!(device, Device::Cuda(_)) {
273                return GpuCosineSimilarity
274                    .dispatch(&(a.clone(), b.clone()), &device)
275                    .map_err(|e| VsaError::Gpu(e.to_string()));
276            }
277        }
278
279        let _ = device;
280        Ok(cpu_cosine_similarity(a, b))
281    }
282
283    /// Compute dot product between two vectors.
284    pub fn dot(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<i32, VsaError> {
285        if a.len() != b.len() {
286            return Err(VsaError::DimensionMismatch {
287                expected: a.len(),
288                actual: b.len(),
289            });
290        }
291
292        let device = self.get_device()?;
293
294        #[cfg(feature = "cuda")]
295        {
296            if matches!(device, Device::Cuda(_)) {
297                return GpuDotSimilarity
298                    .dispatch(&(a.clone(), b.clone()), &device)
299                    .map_err(|e| VsaError::Gpu(e.to_string()));
300            }
301        }
302
303        let _ = device;
304        Ok(a.dot(b))
305    }
306
307    /// Compute Hamming distance between two vectors.
308    ///
309    /// Returns the number of positions where the vectors differ.
310    pub fn hamming_distance(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<usize, VsaError> {
311        if a.len() != b.len() {
312            return Err(VsaError::DimensionMismatch {
313                expected: a.len(),
314                actual: b.len(),
315            });
316        }
317
318        let device = self.get_device()?;
319
320        #[cfg(feature = "cuda")]
321        {
322            if matches!(device, Device::Cuda(_)) {
323                return GpuHammingDistance
324                    .dispatch(&(a.clone(), b.clone()), &device)
325                    .map_err(|e| VsaError::Gpu(e.to_string()));
326            }
327        }
328
329        let _ = device;
330        Ok(cpu_hamming_distance(a, b))
331    }
332
333    /// Convert i8 slice to PackedTritVec.
334    pub fn from_i8(&self, values: &[i8]) -> Result<PackedTritVec, VsaError> {
335        let mut packed = PackedTritVec::new(values.len());
336        for (i, &v) in values.iter().enumerate() {
337            let trit = match v {
338                1 => Trit::P,
339                0 => Trit::Z,
340                -1 => Trit::N,
341                _ => return Err(VsaError::InvalidValue { value: v, index: i }),
342            };
343            packed.set(i, trit);
344        }
345        Ok(packed)
346    }
347
348    /// Convert PackedTritVec to i8 Vec.
349    pub fn to_i8(&self, packed: &PackedTritVec) -> Vec<i8> {
350        let mut result = Vec::with_capacity(packed.len());
351        for i in 0..packed.len() {
352            result.push(packed.get(i).value());
353        }
354        result
355    }
356}
357
358// CPU implementations
359
360fn cpu_random(dim: usize, seed: u32) -> PackedTritVec {
361    use rand::{Rng, SeedableRng};
362    use rand_chacha::ChaCha8Rng;
363
364    let mut rng = ChaCha8Rng::seed_from_u64(u64::from(seed));
365    let mut packed = PackedTritVec::new(dim);
366
367    for i in 0..dim {
368        let r: f32 = rng.gen();
369        let trit = if r < 0.333 {
370            Trit::N
371        } else if r < 0.666 {
372            Trit::Z
373        } else {
374            Trit::P
375        };
376        packed.set(i, trit);
377    }
378
379    packed
380}
381
382fn cpu_bind(a: &PackedTritVec, b: &PackedTritVec) -> PackedTritVec {
383    // Ternary multiplication table:
384    // P * P = P, P * Z = Z, P * N = N
385    // Z * _ = Z
386    // N * P = N, N * Z = Z, N * N = P
387    let mut result = PackedTritVec::new(a.len());
388    for i in 0..a.len() {
389        let va = a.get(i).value();
390        let vb = b.get(i).value();
391        let prod = va * vb;
392        let trit = match prod {
393            1 => Trit::P,
394            -1 => Trit::N,
395            _ => Trit::Z,
396        };
397        result.set(i, trit);
398    }
399    result
400}
401
402fn cpu_bundle(vectors: &[PackedTritVec]) -> PackedTritVec {
403    let dim = vectors[0].len();
404    let mut result = PackedTritVec::new(dim);
405
406    for i in 0..dim {
407        let mut pos_count = 0i32;
408        let mut neg_count = 0i32;
409
410        for v in vectors {
411            match v.get(i) {
412                Trit::P => pos_count += 1,
413                Trit::N => neg_count += 1,
414                Trit::Z => {}
415            }
416        }
417
418        let trit = if pos_count > neg_count {
419            Trit::P
420        } else if neg_count > pos_count {
421            Trit::N
422        } else {
423            Trit::Z
424        };
425        result.set(i, trit);
426    }
427
428    result
429}
430
431fn cpu_cosine_similarity(a: &PackedTritVec, b: &PackedTritVec) -> f32 {
432    let dot = a.dot(b) as f32;
433
434    // Count non-zero elements for normalization
435    let mut norm_a_sq = 0i32;
436    let mut norm_b_sq = 0i32;
437
438    for i in 0..a.len() {
439        let va = a.get(i).value() as i32;
440        let vb = b.get(i).value() as i32;
441        norm_a_sq += va * va;
442        norm_b_sq += vb * vb;
443    }
444
445    if norm_a_sq == 0 || norm_b_sq == 0 {
446        return 0.0;
447    }
448
449    dot / ((norm_a_sq as f32).sqrt() * (norm_b_sq as f32).sqrt())
450}
451
452fn cpu_hamming_distance(a: &PackedTritVec, b: &PackedTritVec) -> usize {
453    let mut distance = 0;
454    for i in 0..a.len() {
455        if a.get(i) != b.get(i) {
456            distance += 1;
457        }
458    }
459    distance
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// CPU-only handler used by every test (keeps results GPU-independent).
    fn cpu_ops() -> VsaOps {
        VsaOps::new(VsaConfig::default().with_device(DevicePreference::Cpu))
    }

    #[test]
    fn test_bind_unbind_roundtrip() {
        let ops = cpu_ops();

        // Zeros lose information under bind, so use ±1-only vectors to get
        // exact recovery.
        let a = ops.from_i8(&[1, -1, 1, -1, 1, -1, 1, -1]).unwrap();
        let b = ops.from_i8(&[1, 1, -1, -1, 1, 1, -1, -1]).unwrap();

        let bound = ops.bind(&a, &b).unwrap();
        let recovered = ops.unbind(&bound, &b).unwrap();

        // Exact recovery position-by-position.
        assert_eq!(ops.to_i8(&a), ops.to_i8(&recovered));
    }

    #[test]
    fn test_bundle_majority() {
        let ops = cpu_ops();

        let inputs = [
            ops.from_i8(&[1, 1, -1, 0]).unwrap(),
            ops.from_i8(&[1, -1, -1, 1]).unwrap(),
            ops.from_i8(&[1, 0, 1, -1]).unwrap(),
        ];

        let bundled = ops.bundle(&inputs).unwrap();
        let result = ops.to_i8(&bundled);

        // Position 0 votes [1, 1, 1] -> 1; position 2 votes [-1, -1, 1] -> -1.
        assert_eq!(result[0], 1);
        assert_eq!(result[2], -1);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let ops = cpu_ops();

        let v = ops.random(1000, 42).unwrap();
        let sim = ops.cosine_similarity(&v, &v).unwrap();

        // A vector is maximally similar to itself.
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_hamming_distance() {
        let ops = cpu_ops();

        let a = ops.from_i8(&[1, 0, -1, 1]).unwrap();
        let b = ops.from_i8(&[1, -1, -1, 0]).unwrap();

        // The vectors disagree at positions 1 and 3 only.
        assert_eq!(ops.hamming_distance(&a, &b).unwrap(), 2);
    }
}