1pub mod cpu;
45
46#[cfg(feature = "cuda")]
47pub mod cubecl;
48
49pub mod burn;
51
52use crate::{PackedTritVec, Result, TernaryError};
53
54pub use cpu::CpuBackend;
56
57#[cfg(feature = "cuda")]
58pub use cubecl::CubeclBackend;
59
60pub use burn::BurnBackend;
61
62#[derive(Debug, Clone)]
64pub struct BackendConfig {
65 pub preferred: BackendPreference,
67 pub gpu_threshold: usize,
69 pub use_simd: bool,
71}
72
73impl Default for BackendConfig {
74 fn default() -> Self {
75 Self::auto()
76 }
77}
78
79impl BackendConfig {
80 #[must_use]
82 pub fn auto() -> Self {
83 Self {
84 preferred: BackendPreference::Auto,
85 gpu_threshold: 4096,
86 use_simd: true,
87 }
88 }
89
90 #[must_use]
92 pub fn cpu_only() -> Self {
93 Self {
94 preferred: BackendPreference::Cpu,
95 gpu_threshold: usize::MAX,
96 use_simd: true,
97 }
98 }
99
100 #[must_use]
102 pub fn gpu_only() -> Self {
103 Self {
104 preferred: BackendPreference::Gpu,
105 gpu_threshold: 0,
106 use_simd: false,
107 }
108 }
109
110 #[must_use]
112 pub fn with_gpu_threshold(mut self, threshold: usize) -> Self {
113 self.gpu_threshold = threshold;
114 self
115 }
116
117 #[must_use]
119 pub fn with_simd(mut self, enabled: bool) -> Self {
120 self.use_simd = enabled;
121 self
122 }
123}
124
/// Which backend family [`get_backend`] should try to use.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum BackendPreference {
    /// Default. Use the CubeCL GPU backend when the `cuda` feature is enabled
    /// and it reports itself available; otherwise use the CPU.
    /// [`get_backend_for_size`] additionally keeps small problems on the CPU.
    #[default]
    Auto,
    /// Always use the CPU backend.
    Cpu,
    /// Prefer the CubeCL GPU backend. Falls back to the CPU when it is
    /// unavailable or the `cuda` feature is disabled.
    Gpu,
    /// Burn-based backend. NOTE(review): [`get_backend`] currently resolves
    /// this variant to the CPU backend.
    Burn,
}
138
/// Parameters for [`TernaryBackend::random`].
#[derive(Clone, Debug)]
pub struct RandomConfig {
    /// Number of trits in the generated vector.
    pub dim: usize,
    /// Seed passed to the backend's random generator.
    pub seed: u64,
}
147
148impl RandomConfig {
149 #[must_use]
151 pub fn new(dim: usize, seed: u64) -> Self {
152 Self { dim, seed }
153 }
154}
155
156pub trait TernaryBackend: Send + Sync {
171 fn name(&self) -> &'static str;
173
174 fn is_available(&self) -> bool;
176
177 fn bind(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<PackedTritVec>;
187
188 fn unbind(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<PackedTritVec>;
195
196 fn bundle(&self, vectors: &[&PackedTritVec]) -> Result<PackedTritVec>;
204
205 fn dot_similarity(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<i32>;
213
214 fn cosine_similarity(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<f32>;
222
223 fn hamming_distance(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<usize>;
231
232 fn random(&self, config: &RandomConfig) -> Result<PackedTritVec>;
236
237 fn negate(&self, a: &PackedTritVec) -> Result<PackedTritVec>;
241}
242
243pub struct DynamicBackend {
247 inner: Box<dyn TernaryBackend>,
248}
249
250impl DynamicBackend {
251 pub fn new<B: TernaryBackend + 'static>(backend: B) -> Self {
253 Self {
254 inner: Box::new(backend),
255 }
256 }
257
258 #[must_use]
260 pub fn inner(&self) -> &dyn TernaryBackend {
261 &*self.inner
262 }
263}
264
265impl TernaryBackend for DynamicBackend {
266 fn name(&self) -> &'static str {
267 self.inner.name()
268 }
269
270 fn is_available(&self) -> bool {
271 self.inner.is_available()
272 }
273
274 fn bind(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<PackedTritVec> {
275 self.inner.bind(a, b)
276 }
277
278 fn unbind(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<PackedTritVec> {
279 self.inner.unbind(a, b)
280 }
281
282 fn bundle(&self, vectors: &[&PackedTritVec]) -> Result<PackedTritVec> {
283 self.inner.bundle(vectors)
284 }
285
286 fn dot_similarity(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<i32> {
287 self.inner.dot_similarity(a, b)
288 }
289
290 fn cosine_similarity(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<f32> {
291 self.inner.cosine_similarity(a, b)
292 }
293
294 fn hamming_distance(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<usize> {
295 self.inner.hamming_distance(a, b)
296 }
297
298 fn random(&self, config: &RandomConfig) -> Result<PackedTritVec> {
299 self.inner.random(config)
300 }
301
302 fn negate(&self, a: &PackedTritVec) -> Result<PackedTritVec> {
303 self.inner.negate(a)
304 }
305}
306
307#[must_use]
322pub fn get_backend(config: &BackendConfig) -> DynamicBackend {
323 match config.preferred {
324 BackendPreference::Cpu => DynamicBackend::new(CpuBackend::new(config.use_simd)),
325
326 #[cfg(feature = "cuda")]
327 BackendPreference::Gpu => {
328 let cubecl = CubeclBackend::new();
329 if cubecl.is_available() {
330 DynamicBackend::new(cubecl)
331 } else {
332 DynamicBackend::new(CpuBackend::new(config.use_simd))
334 }
335 }
336
337 #[cfg(not(feature = "cuda"))]
338 BackendPreference::Gpu => {
339 DynamicBackend::new(CpuBackend::new(config.use_simd))
341 }
342
343 BackendPreference::Burn => {
344 DynamicBackend::new(CpuBackend::new(config.use_simd))
346 }
347
348 BackendPreference::Auto => {
349 #[cfg(feature = "cuda")]
351 {
352 let cubecl = CubeclBackend::new();
353 if cubecl.is_available() {
354 return DynamicBackend::new(cubecl);
355 }
356 }
357 DynamicBackend::new(CpuBackend::new(config.use_simd))
359 }
360 }
361}
362
363#[must_use]
377pub fn get_backend_for_size(config: &BackendConfig, problem_size: usize) -> DynamicBackend {
378 if config.preferred == BackendPreference::Auto && problem_size < config.gpu_threshold {
380 return DynamicBackend::new(CpuBackend::new(config.use_simd));
381 }
382
383 get_backend(config)
384}
385
386pub(crate) fn check_dimensions(a: &PackedTritVec, b: &PackedTritVec) -> Result<()> {
388 if a.len() != b.len() {
389 return Err(TernaryError::DimensionMismatch {
390 expected: a.len(),
391 actual: b.len(),
392 });
393 }
394 Ok(())
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Trit;

    /// Builds a packed vector from `-1` / `0` / `1` literals.
    fn make_test_vector(values: &[i8]) -> PackedTritVec {
        let mut vec = PackedTritVec::new(values.len());
        for (i, &v) in values.iter().enumerate() {
            let trit = match v {
                -1 => Trit::N,
                0 => Trit::Z,
                1 => Trit::P,
                _ => panic!("Invalid trit value"),
            };
            vec.set(i, trit);
        }
        vec
    }

    /// The CPU-only backend that most tests run against.
    fn cpu_backend() -> DynamicBackend {
        get_backend(&BackendConfig::cpu_only())
    }

    #[test]
    fn test_backend_config_default() {
        let config = BackendConfig::default();
        assert_eq!(config.preferred, BackendPreference::Auto);
        assert_eq!(config.gpu_threshold, 4096);
        assert!(config.use_simd);
    }

    #[test]
    fn test_backend_config_cpu_only() {
        let config = BackendConfig::cpu_only();
        assert_eq!(config.preferred, BackendPreference::Cpu);
        assert_eq!(config.gpu_threshold, usize::MAX);
        assert!(config.use_simd);
    }

    #[test]
    fn test_backend_config_gpu_only() {
        let config = BackendConfig::gpu_only();
        assert_eq!(config.preferred, BackendPreference::Gpu);
        assert_eq!(config.gpu_threshold, 0);
        assert!(!config.use_simd);
    }

    #[test]
    fn test_backend_config_builders() {
        let config = BackendConfig::auto()
            .with_gpu_threshold(123)
            .with_simd(false);
        assert_eq!(config.preferred, BackendPreference::Auto);
        assert_eq!(config.gpu_threshold, 123);
        assert!(!config.use_simd);
    }

    #[test]
    fn test_get_backend_cpu() {
        let backend = cpu_backend();
        assert_eq!(backend.name(), "cpu");
        assert!(backend.is_available());
        // The wrapper must report exactly what the wrapped backend reports.
        assert_eq!(backend.inner().name(), backend.name());
    }

    // Without the `cuda` feature, a GPU preference must degrade to the CPU.
    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_gpu_preference_falls_back_to_cpu_without_cuda() {
        let config = BackendConfig::gpu_only().with_simd(true);
        let backend = get_backend(&config);
        assert_eq!(backend.name(), "cpu");
        assert!(backend.is_available());
    }

    #[test]
    fn test_cpu_backend_bind() {
        let backend = cpu_backend();

        let a = make_test_vector(&[1, 0, -1, 1]);
        let b = make_test_vector(&[1, -1, 0, -1]);

        let result = backend.bind(&a, &b).unwrap();
        assert_eq!(result.len(), 4);

        let recovered = backend.unbind(&result, &b).unwrap();
        for i in 0..4 {
            assert_eq!(recovered.get(i), a.get(i), "mismatch at position {i}");
        }
    }

    #[test]
    fn test_cpu_backend_bundle() {
        let backend = cpu_backend();

        let a = make_test_vector(&[1, 1, -1, 0]);
        let b = make_test_vector(&[1, -1, -1, 1]);
        let c = make_test_vector(&[1, 0, 1, -1]);

        let result = backend.bundle(&[&a, &b, &c]).unwrap();

        assert_eq!(result.get(0), Trit::P);
        assert_eq!(result.get(2), Trit::N);
    }

    #[test]
    fn test_cpu_backend_dot_similarity() {
        let backend = cpu_backend();

        let a = make_test_vector(&[1, 0, -1, 1]);
        let b = make_test_vector(&[1, -1, -1, 0]);

        let dot = backend.dot_similarity(&a, &b).unwrap();
        assert_eq!(dot, 2);
    }

    #[test]
    fn test_cpu_backend_hamming_distance() {
        let backend = cpu_backend();

        let a = make_test_vector(&[1, 0, -1, 1]);
        let b = make_test_vector(&[1, -1, -1, 0]);

        let dist = backend.hamming_distance(&a, &b).unwrap();
        assert_eq!(dist, 2);
    }

    #[test]
    fn test_cpu_backend_random() {
        let backend = cpu_backend();

        let random_config = RandomConfig::new(100, 42);
        let result = backend.random(&random_config).unwrap();

        assert_eq!(result.len(), 100);

        let pos = result.count_positive();
        let neg = result.count_negative();
        let zero = result.len() - pos - neg;

        assert!(pos > 10, "too few positive: {pos}");
        assert!(neg > 10, "too few negative: {neg}");
        assert!(zero > 10, "too few zero: {zero}");
    }

    #[test]
    fn test_dimension_mismatch() {
        let backend = cpu_backend();

        let a = make_test_vector(&[1, 0, -1]);
        let b = make_test_vector(&[1, -1]);

        assert!(backend.bind(&a, &b).is_err());
        assert!(backend.unbind(&a, &b).is_err());
        assert!(backend.dot_similarity(&a, &b).is_err());
        assert!(backend.hamming_distance(&a, &b).is_err());
    }

    #[test]
    fn test_get_backend_for_size() {
        let config = BackendConfig::auto().with_gpu_threshold(1000);

        let backend_small = get_backend_for_size(&config, 500);
        assert_eq!(backend_small.name(), "cpu");

        let backend_large = get_backend_for_size(&config, 2000);
        assert!(backend_large.is_available());
    }
}