trit_vsa/kernels/mod.rs

// SPDX-License-Identifier: MIT
// Copyright 2026 Tyler Zervas

//! Modular kernel architecture for ternary VSA operations.
//!
//! This module provides a backend-agnostic interface for ternary vector operations,
//! allowing easy swapping between CPU (scalar/SIMD), CUDA (CubeCL), and future
//! backends (e.g., Burn) at runtime.
//!
//! # Architecture
//!
//! ```text
//! +-------------------+
//! |  TernaryBackend   |  <- Trait defining all operations
//! +-------------------+
//!          |
//!    +-----+-----+-----+
//!    |           |     |
//!    v           v     v
//! +------+  +-------+  +------+
//! |  CPU |  | CubeCL|  | Burn |
//! +------+  +-------+  +------+
//! ```
//!
//! # Backend Selection
//!
//! Backends can be selected based on:
//! - Feature flags (`cuda`, `burn`)
//! - Runtime detection (GPU availability)
//! - User configuration (force CPU/GPU)
//! - Problem size thresholds
//!
//! # Usage
//!
//! ```rust,ignore
//! use trit_vsa::kernels::{TernaryBackend, get_backend, BackendConfig};
//!
//! let config = BackendConfig::auto();
//! let backend = get_backend(&config);
//!
//! let result = backend.bind(&vec_a, &vec_b)?;
//! ```
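//!
//! For size-aware dispatch, `get_backend_for_size` consults the configured GPU
//! threshold. A minimal sketch (doctest-ignored), assuming `vec_a` and `vec_b` are
//! `PackedTritVec` values of equal dimension:
//!
//! ```rust,ignore
//! let backend = get_backend_for_size(&config, vec_a.len());
//! let sim = backend.cosine_similarity(&vec_a, &vec_b)?;
//! ```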

pub mod cpu;

#[cfg(feature = "cuda")]
pub mod cubecl;

// Burn backend stub for future integration
pub mod burn;

use crate::{PackedTritVec, Result, TernaryError};

// Re-export key types
pub use cpu::CpuBackend;

#[cfg(feature = "cuda")]
pub use cubecl::CubeclBackend;

pub use burn::BurnBackend;

/// Configuration for backend selection.
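///
/// # Examples
///
/// A minimal sketch of the presets and builder-style overrides (values are
/// illustrative, not recommendations):
///
/// ```rust,ignore
/// let auto = BackendConfig::auto();
/// let cpu = BackendConfig::cpu_only();
/// let tuned = BackendConfig::auto().with_gpu_threshold(8192).with_simd(false);
/// ```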
#[derive(Debug, Clone)]
pub struct BackendConfig {
    /// Preferred backend type.
    pub preferred: BackendPreference,
    /// Minimum vector dimension for GPU dispatch (default: 4096).
    pub gpu_threshold: usize,
    /// Whether to use SIMD on CPU (default: true).
    pub use_simd: bool,
}

impl Default for BackendConfig {
    fn default() -> Self {
        Self::auto()
    }
}

impl BackendConfig {
    /// Create configuration with automatic backend selection.
    #[must_use]
    pub fn auto() -> Self {
        Self {
            preferred: BackendPreference::Auto,
            gpu_threshold: 4096,
            use_simd: true,
        }
    }

    /// Force CPU backend.
    #[must_use]
    pub fn cpu_only() -> Self {
        Self {
            preferred: BackendPreference::Cpu,
            gpu_threshold: usize::MAX,
            use_simd: true,
        }
    }

    /// Force GPU backend (requires `cuda` feature).
    #[must_use]
    pub fn gpu_only() -> Self {
        Self {
            preferred: BackendPreference::Gpu,
            gpu_threshold: 0,
            use_simd: false,
        }
    }

    /// Set GPU threshold for automatic selection.
    #[must_use]
    pub fn with_gpu_threshold(mut self, threshold: usize) -> Self {
        self.gpu_threshold = threshold;
        self
    }

    /// Enable or disable SIMD on CPU.
    #[must_use]
    pub fn with_simd(mut self, enabled: bool) -> Self {
        self.use_simd = enabled;
        self
    }
}

/// Backend preference for kernel execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BackendPreference {
    /// Automatically select based on problem size and availability.
    #[default]
    Auto,
    /// Force CPU execution.
    Cpu,
    /// Force GPU execution (requires `cuda` feature).
    Gpu,
    /// Use Burn backend (for future integration).
    Burn,
}

/// Configuration for random vector generation.
#[derive(Debug, Clone)]
pub struct RandomConfig {
    /// Vector dimension.
    pub dim: usize,
    /// Random seed.
    pub seed: u64,
}

impl RandomConfig {
    /// Create a new random configuration.
    #[must_use]
    pub fn new(dim: usize, seed: u64) -> Self {
        Self { dim, seed }
    }
}

/// Backend-agnostic trait for ternary VSA operations.
///
/// This trait defines all core operations that can be implemented by different
/// backends (CPU, CubeCL/CUDA, Burn, etc.).
///
/// # Implementors
///
/// - [`CpuBackend`]: CPU implementation with optional SIMD acceleration
/// - [`CubeclBackend`]: CUDA implementation via CubeCL (requires `cuda` feature)
/// - [`BurnBackend`]: Burn framework integration (stub for future)
///
/// # Thread Safety
///
/// All implementations must be `Send + Sync` to allow use across threads.
pub trait TernaryBackend: Send + Sync {
    /// Returns the backend name for debugging/logging.
    fn name(&self) -> &'static str;

    /// Returns `true` if this backend is available on the current system.
    fn is_available(&self) -> bool;

    /// Bind two vectors (composition operation).
    ///
    /// Implements balanced ternary binding: `result[i] = (a[i] - b[i]) mod 3`
    ///
    /// # Properties
    /// - `unbind(bind(a, b), b) == a`
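    ///
    /// A minimal sketch (doctest-ignored), assuming `a` and `b` are same-dimension
    /// `PackedTritVec`s and `backend` implements this trait:
    ///
    /// ```rust,ignore
    /// let bound = backend.bind(&a, &b)?;
    /// let recovered = backend.unbind(&bound, &b)?; // recovers `a`
    /// ```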
    ///
    /// # Errors
    /// Returns error if vectors have mismatched dimensions.
    fn bind(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<PackedTritVec>;

    /// Unbind a vector (inverse of bind).
    ///
    /// Implements: `result[i] = (a[i] + b[i]) mod 3`
    ///
    /// # Errors
    /// Returns error if vectors have mismatched dimensions.
    fn unbind(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<PackedTritVec>;

    /// Bundle multiple vectors using majority voting.
    ///
    /// For each dimension, selects the majority trit value.
    /// Ties resolve to zero.
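    ///
    /// A minimal sketch (doctest-ignored), assuming `a`, `b`, and `c` share one dimension:
    ///
    /// ```rust,ignore
    /// // Per-dimension majority vote; tied positions become zero.
    /// let prototype = backend.bundle(&[&a, &b, &c])?;
    /// ```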
    ///
    /// # Errors
    /// Returns error if vectors have mismatched dimensions or input is empty.
    fn bundle(&self, vectors: &[&PackedTritVec]) -> Result<PackedTritVec>;

    /// Compute dot product similarity.
    ///
    /// Returns the sum of element-wise products.
    /// Range: [-n, +n] where n = dimension.
    ///
    /// # Errors
    /// Returns error if vectors have mismatched dimensions.
    fn dot_similarity(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<i32>;

    /// Compute cosine similarity.
    ///
    /// Returns: `dot(a, b) / (||a|| * ||b||)`
    /// Range: [-1.0, +1.0]
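    ///
    /// A minimal sketch (doctest-ignored); a non-zero vector scores approximately 1.0
    /// against itself:
    ///
    /// ```rust,ignore
    /// let sim = backend.cosine_similarity(&a, &a)?;
    /// assert!((sim - 1.0).abs() < 1e-3);
    /// ```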
    ///
    /// # Errors
    /// Returns error if vectors have mismatched dimensions.
    fn cosine_similarity(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<f32>;

    /// Compute Hamming distance.
    ///
    /// Counts positions where vectors differ.
    /// Range: [0, n] where n = dimension.
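    ///
    /// A minimal sketch (doctest-ignored); identical vectors are at distance zero:
    ///
    /// ```rust,ignore
    /// assert_eq!(backend.hamming_distance(&a, &a)?, 0);
    /// ```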
    ///
    /// # Errors
    /// Returns error if vectors have mismatched dimensions.
    fn hamming_distance(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<usize>;

    /// Generate a random ternary vector.
    ///
    /// Uses the provided seed for reproducibility.
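    ///
    /// A minimal sketch (doctest-ignored); on the same backend, equal `(dim, seed)`
    /// pairs yield equal vectors:
    ///
    /// ```rust,ignore
    /// let v1 = backend.random(&RandomConfig::new(10_000, 42))?;
    /// let v2 = backend.random(&RandomConfig::new(10_000, 42))?;
    /// ```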
    fn random(&self, config: &RandomConfig) -> Result<PackedTritVec>;

    /// Negate a vector element-wise.
    ///
    /// Returns a new vector where all values are negated.
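    ///
    /// A minimal sketch (doctest-ignored); element-wise negation of trits is its own inverse:
    ///
    /// ```rust,ignore
    /// let neg = backend.negate(&a)?;
    /// let back = backend.negate(&neg)?; // equals `a`
    /// ```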
    fn negate(&self, a: &PackedTritVec) -> Result<PackedTritVec>;
}

/// Dynamic backend dispatcher.
///
/// Wraps any `TernaryBackend` implementation for dynamic dispatch.
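///
/// # Examples
///
/// A minimal sketch; `CpuBackend::new` takes a `use_simd` flag:
///
/// ```rust,ignore
/// let backend = DynamicBackend::new(CpuBackend::new(true));
/// assert!(backend.is_available());
/// ```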
pub struct DynamicBackend {
    inner: Box<dyn TernaryBackend>,
}

impl DynamicBackend {
    /// Create a new dynamic backend from a concrete implementation.
    pub fn new<B: TernaryBackend + 'static>(backend: B) -> Self {
        Self {
            inner: Box::new(backend),
        }
    }

    /// Get the underlying backend reference.
    #[must_use]
    pub fn inner(&self) -> &dyn TernaryBackend {
        &*self.inner
    }
}

impl TernaryBackend for DynamicBackend {
    fn name(&self) -> &'static str {
        self.inner.name()
    }

    fn is_available(&self) -> bool {
        self.inner.is_available()
    }

    fn bind(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<PackedTritVec> {
        self.inner.bind(a, b)
    }

    fn unbind(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<PackedTritVec> {
        self.inner.unbind(a, b)
    }

    fn bundle(&self, vectors: &[&PackedTritVec]) -> Result<PackedTritVec> {
        self.inner.bundle(vectors)
    }

    fn dot_similarity(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<i32> {
        self.inner.dot_similarity(a, b)
    }

    fn cosine_similarity(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<f32> {
        self.inner.cosine_similarity(a, b)
    }

    fn hamming_distance(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<usize> {
        self.inner.hamming_distance(a, b)
    }

    fn random(&self, config: &RandomConfig) -> Result<PackedTritVec> {
        self.inner.random(config)
    }

    fn negate(&self, a: &PackedTritVec) -> Result<PackedTritVec> {
        self.inner.negate(a)
    }
}

/// Get the appropriate backend based on configuration.
///
/// This function selects the best available backend based on:
/// 1. User preference (if specified)
/// 2. Feature availability (`cuda`, `burn`)
/// 3. Hardware detection (GPU presence)
///
/// # Arguments
///
/// * `config` - Backend configuration
///
/// # Returns
///
/// A [`DynamicBackend`] wrapping the selected backend implementation.
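///
/// # Examples
///
/// A minimal sketch forcing the CPU backend:
///
/// ```rust,ignore
/// let backend = get_backend(&BackendConfig::cpu_only());
/// assert_eq!(backend.name(), "cpu");
/// ```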
#[must_use]
pub fn get_backend(config: &BackendConfig) -> DynamicBackend {
    match config.preferred {
        BackendPreference::Cpu => DynamicBackend::new(CpuBackend::new(config.use_simd)),

        #[cfg(feature = "cuda")]
        BackendPreference::Gpu => {
            let cubecl = CubeclBackend::new();
            if cubecl.is_available() {
                DynamicBackend::new(cubecl)
            } else {
                // Fall back to CPU if GPU not available
                DynamicBackend::new(CpuBackend::new(config.use_simd))
            }
        }

        #[cfg(not(feature = "cuda"))]
        BackendPreference::Gpu => {
            // No CUDA support compiled in, fall back to CPU
            DynamicBackend::new(CpuBackend::new(config.use_simd))
        }

        BackendPreference::Burn => {
            // Burn backend is a stub; fall back to CPU for now
            DynamicBackend::new(CpuBackend::new(config.use_simd))
        }

        BackendPreference::Auto => {
            // Auto-selection: prefer the GPU when the `cuda` feature is compiled in and a device is available
            #[cfg(feature = "cuda")]
            {
                let cubecl = CubeclBackend::new();
                if cubecl.is_available() {
                    return DynamicBackend::new(cubecl);
                }
            }
            // Fall back to CPU
            DynamicBackend::new(CpuBackend::new(config.use_simd))
        }
    }
}

/// Get a backend appropriate for the given problem size.
///
/// This is a convenience function that considers both configuration
/// and problem size when selecting a backend.
///
/// # Arguments
///
/// * `config` - Backend configuration
/// * `problem_size` - Size of the operation (typically vector dimension)
///
/// # Returns
///
/// A backend optimized for the given problem size.
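///
/// # Examples
///
/// A minimal sketch; below the threshold the CPU backend is always chosen:
///
/// ```rust,ignore
/// let config = BackendConfig::auto().with_gpu_threshold(1000);
/// let backend = get_backend_for_size(&config, 500);
/// assert_eq!(backend.name(), "cpu");
/// ```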
#[must_use]
pub fn get_backend_for_size(config: &BackendConfig, problem_size: usize) -> DynamicBackend {
    // For auto selection, consider the GPU only when the problem size reaches the threshold
    if config.preferred == BackendPreference::Auto && problem_size < config.gpu_threshold {
        return DynamicBackend::new(CpuBackend::new(config.use_simd));
    }

    get_backend(config)
}

/// Check dimension compatibility and return error if mismatched.
pub(crate) fn check_dimensions(a: &PackedTritVec, b: &PackedTritVec) -> Result<()> {
    if a.len() != b.len() {
        return Err(TernaryError::DimensionMismatch {
            expected: a.len(),
            actual: b.len(),
        });
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Trit;

    fn make_test_vector(values: &[i8]) -> PackedTritVec {
        let mut vec = PackedTritVec::new(values.len());
        for (i, &v) in values.iter().enumerate() {
            let trit = match v {
                -1 => Trit::N,
                0 => Trit::Z,
                1 => Trit::P,
                _ => panic!("Invalid trit value"),
            };
            vec.set(i, trit);
        }
        vec
    }

    #[test]
    fn test_backend_config_default() {
        let config = BackendConfig::default();
        assert_eq!(config.preferred, BackendPreference::Auto);
        assert_eq!(config.gpu_threshold, 4096);
        assert!(config.use_simd);
    }

    #[test]
    fn test_backend_config_cpu_only() {
        let config = BackendConfig::cpu_only();
        assert_eq!(config.preferred, BackendPreference::Cpu);
    }

    #[test]
    fn test_get_backend_cpu() {
        let config = BackendConfig::cpu_only();
        let backend = get_backend(&config);
        assert_eq!(backend.name(), "cpu");
        assert!(backend.is_available());
    }

    #[test]
    fn test_cpu_backend_bind() {
        let config = BackendConfig::cpu_only();
        let backend = get_backend(&config);

        let a = make_test_vector(&[1, 0, -1, 1]);
        let b = make_test_vector(&[1, -1, 0, -1]);

        let result = backend.bind(&a, &b).unwrap();
        assert_eq!(result.len(), 4);

        // Verify bind/unbind inverse property
        let recovered = backend.unbind(&result, &b).unwrap();
        for i in 0..4 {
            assert_eq!(recovered.get(i), a.get(i), "mismatch at position {i}");
        }
    }

    #[test]
    fn test_cpu_backend_bundle() {
        let config = BackendConfig::cpu_only();
        let backend = get_backend(&config);

        let a = make_test_vector(&[1, 1, -1, 0]);
        let b = make_test_vector(&[1, -1, -1, 1]);
        let c = make_test_vector(&[1, 0, 1, -1]);

        let result = backend.bundle(&[&a, &b, &c]).unwrap();

        // Position 0: 1, 1, 1 -> majority is 1
        assert_eq!(result.get(0), Trit::P);
        // Position 2: -1, -1, 1 -> majority is -1
        assert_eq!(result.get(2), Trit::N);
    }

    #[test]
    fn test_cpu_backend_dot_similarity() {
        let config = BackendConfig::cpu_only();
        let backend = get_backend(&config);

        let a = make_test_vector(&[1, 0, -1, 1]);
        let b = make_test_vector(&[1, -1, -1, 0]);

        let dot = backend.dot_similarity(&a, &b).unwrap();
        // Expected: 1*1 + 0*(-1) + (-1)*(-1) + 1*0 = 1 + 0 + 1 + 0 = 2
        assert_eq!(dot, 2);
    }

    #[test]
    fn test_cpu_backend_hamming_distance() {
        let config = BackendConfig::cpu_only();
        let backend = get_backend(&config);

        let a = make_test_vector(&[1, 0, -1, 1]);
        let b = make_test_vector(&[1, -1, -1, 0]);

        let dist = backend.hamming_distance(&a, &b).unwrap();
        // Positions 1 and 3 differ
        assert_eq!(dist, 2);
    }

    #[test]
    fn test_cpu_backend_random() {
        let config = BackendConfig::cpu_only();
        let backend = get_backend(&config);

        let random_config = RandomConfig::new(100, 42);
        let result = backend.random(&random_config).unwrap();

        assert_eq!(result.len(), 100);

        // Check distribution (statistical test)
        let pos = result.count_positive();
        let neg = result.count_negative();
        let zero = result.len() - pos - neg;

        assert!(pos > 10, "too few positive: {pos}");
        assert!(neg > 10, "too few negative: {neg}");
        assert!(zero > 10, "too few zero: {zero}");
    }

    #[test]
    fn test_dimension_mismatch() {
        let config = BackendConfig::cpu_only();
        let backend = get_backend(&config);

        let a = make_test_vector(&[1, 0, -1]);
        let b = make_test_vector(&[1, -1]);

        assert!(backend.bind(&a, &b).is_err());
        assert!(backend.unbind(&a, &b).is_err());
        assert!(backend.dot_similarity(&a, &b).is_err());
        assert!(backend.hamming_distance(&a, &b).is_err());
    }

    #[test]
    fn test_get_backend_for_size() {
        let config = BackendConfig::auto().with_gpu_threshold(1000);

        // Small problem should use CPU
        let backend_small = get_backend_for_size(&config, 500);
        assert_eq!(backend_small.name(), "cpu");

        // Large problem would use GPU if available, but falls back to CPU
        let backend_large = get_backend_for_size(&config, 2000);
        // On systems without CUDA, this will still be CPU
        assert!(backend_large.is_available());
    }
}