1use crate::kernels::{self, BackendConfig, BackendPreference, TernaryBackend};
105use crate::{PackedTritVec, Result, SparseVec, TernaryError, Trit};
106
/// Storage format used when dispatching a ternary operation.
///
/// `Auto` defers the choice to dispatch time, where it is resolved per operation from the
/// operation kind and the sparsity of the operands.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Format {
    /// Packed (trit-sliced) representation backed by `PackedTritVec`; the `Default` variant.
    #[default]
    Tritsliced,
    /// Packed representation; the dispatch methods currently route it exactly like
    /// `Tritsliced`.
    Tritpacked,
    /// Sparse representation backed by `SparseVec`.
    Sparse,
    /// Let the dispatcher pick a concrete format.
    Auto,
}
120
/// Which compute device the caller would like operations to run on.
///
/// This is translated one-to-one into the kernel layer's backend preference.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DevicePreference {
    /// Let the kernel layer decide; the `Default` variant.
    #[default]
    Auto,
    /// Request a CPU backend.
    Cpu,
    /// Request a GPU backend.
    Gpu,
}
132
133#[derive(Debug, Clone)]
135pub struct DispatchConfig {
136 pub format: Format,
138 pub device: DevicePreference,
140 pub sparse_threshold: f32,
142 pub gpu_threshold: usize,
144 pub cache_conversions: bool,
146}
147
148impl Default for DispatchConfig {
149 fn default() -> Self {
150 Self::auto()
151 }
152}
153
154impl DispatchConfig {
155 #[must_use]
157 pub fn auto() -> Self {
158 Self {
159 format: Format::Auto,
160 device: DevicePreference::Auto,
161 sparse_threshold: 0.90,
162 gpu_threshold: 4096,
163 cache_conversions: true,
164 }
165 }
166
167 #[must_use]
169 pub fn cpu_only() -> Self {
170 Self {
171 device: DevicePreference::Cpu,
172 ..Self::auto()
173 }
174 }
175
176 #[must_use]
178 pub fn with_format(mut self, format: Format) -> Self {
179 self.format = format;
180 self
181 }
182
183 #[must_use]
185 pub fn with_device(mut self, device: DevicePreference) -> Self {
186 self.device = device;
187 self
188 }
189
190 #[must_use]
192 pub fn with_sparse_threshold(mut self, threshold: f32) -> Self {
193 self.sparse_threshold = threshold;
194 self
195 }
196
197 #[must_use]
199 pub fn with_gpu_threshold(mut self, threshold: usize) -> Self {
200 self.gpu_threshold = threshold;
201 self
202 }
203}
204
/// The kinds of ternary-vector operations the dispatcher knows about.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Operation {
    /// Integer dot product.
    Dot,
    /// Cosine similarity.
    Similarity,
    /// Binding of two vectors.
    Bind,
    /// Inverse of `Bind`.
    Unbind,
    /// Bundling (superposition) of vectors.
    Bundle,
    /// Element-wise negation.
    Negate,
    /// Hamming distance.
    Hamming,
}
223
224impl Operation {
225 #[must_use]
227 pub fn preferred_format(self) -> Format {
228 match self {
229 Operation::Dot | Operation::Similarity | Operation::Hamming => Format::Tritsliced,
231 Operation::Bind | Operation::Unbind | Operation::Negate => Format::Tritsliced,
233 Operation::Bundle => Format::Tritsliced,
235 }
236 }
237
238 #[must_use]
240 pub fn benefits_from_sparse(self) -> bool {
241 matches!(self, Operation::Dot | Operation::Similarity)
242 }
243}
244
245#[derive(Debug, Clone)]
247pub enum TritVector {
248 Sliced(PackedTritVec),
250 Sparse(SparseVec),
252}
253
254impl TritVector {
255 #[must_use]
257 pub fn new(dims: usize) -> Self {
258 Self::Sliced(PackedTritVec::new(dims))
259 }
260
261 #[must_use]
263 pub fn from_packed(packed: PackedTritVec) -> Self {
264 Self::Sliced(packed)
265 }
266
267 #[must_use]
269 pub fn from_sparse(sparse: SparseVec) -> Self {
270 Self::Sparse(sparse)
271 }
272
273 #[must_use]
275 pub fn dims(&self) -> usize {
276 match self {
277 Self::Sliced(p) => p.len(),
278 Self::Sparse(s) => s.num_dims(),
279 }
280 }
281
282 #[must_use]
284 pub fn sparsity(&self) -> f32 {
285 match self {
286 Self::Sliced(p) => p.sparsity(),
287 Self::Sparse(s) => s.sparsity(),
288 }
289 }
290
291 #[must_use]
293 pub fn get(&self, idx: usize) -> Trit {
294 match self {
295 Self::Sliced(p) => p.get(idx),
296 Self::Sparse(s) => s.get(idx),
297 }
298 }
299
300 pub fn set(&mut self, idx: usize, value: Trit) {
302 match self {
303 Self::Sliced(p) => p.set(idx, value),
304 Self::Sparse(s) => s.set(idx, value),
305 }
306 }
307
308 #[must_use]
310 pub fn to_packed(&self) -> PackedTritVec {
311 match self {
312 Self::Sliced(p) => p.clone(),
313 Self::Sparse(s) => s.to_packed(),
314 }
315 }
316
317 #[must_use]
319 pub fn to_sparse(&self) -> SparseVec {
320 match self {
321 Self::Sliced(p) => SparseVec::from_packed(p),
322 Self::Sparse(s) => s.clone(),
323 }
324 }
325
326 fn select_format(
328 &self,
329 other: Option<&Self>,
330 op: Operation,
331 config: &DispatchConfig,
332 ) -> Format {
333 if config.format != Format::Auto {
335 return config.format;
336 }
337
338 let self_sparse = self.sparsity() > config.sparse_threshold;
340 let other_sparse = other.is_some_and(|o| o.sparsity() > config.sparse_threshold);
341
342 if op.benefits_from_sparse() && self_sparse && other_sparse {
343 return Format::Sparse;
344 }
345
346 op.preferred_format()
348 }
349
350 fn to_backend_config(config: &DispatchConfig) -> BackendConfig {
352 let preferred = match config.device {
353 DevicePreference::Cpu => BackendPreference::Cpu,
354 DevicePreference::Gpu => BackendPreference::Gpu,
355 DevicePreference::Auto => BackendPreference::Auto,
356 };
357
358 BackendConfig {
359 preferred,
360 gpu_threshold: config.gpu_threshold,
361 use_simd: true,
362 }
363 }
364
365 fn get_backend_for_config(&self, config: &DispatchConfig) -> kernels::DynamicBackend {
367 let backend_config = Self::to_backend_config(config);
368 kernels::get_backend_for_size(&backend_config, self.dims())
369 }
370
371 pub fn dot(&self, other: &Self, config: &DispatchConfig) -> Result<i32> {
379 if self.dims() != other.dims() {
380 return Err(TernaryError::DimensionMismatch {
381 expected: self.dims(),
382 actual: other.dims(),
383 });
384 }
385
386 let format = self.select_format(Some(other), Operation::Dot, config);
387
388 match format {
389 Format::Sparse => {
390 let a = self.to_sparse();
392 let b = other.to_sparse();
393 Ok(a.dot(&b))
394 }
395 Format::Tritsliced | Format::Tritpacked | Format::Auto => {
396 let a = self.to_packed();
398 let b = other.to_packed();
399 let backend = self.get_backend_for_config(config);
400 backend.dot_similarity(&a, &b)
401 }
402 }
403 }
404
405 pub fn cosine_similarity(&self, other: &Self, config: &DispatchConfig) -> Result<f32> {
413 if self.dims() != other.dims() {
414 return Err(TernaryError::DimensionMismatch {
415 expected: self.dims(),
416 actual: other.dims(),
417 });
418 }
419
420 let format = self.select_format(Some(other), Operation::Similarity, config);
421
422 match format {
423 Format::Sparse => {
424 let a = self.to_sparse();
426 let b = other.to_sparse();
427 Ok(crate::vsa::cosine_similarity_sparse(&a, &b))
428 }
429 Format::Tritsliced | Format::Tritpacked | Format::Auto => {
430 let a = self.to_packed();
432 let b = other.to_packed();
433 let backend = self.get_backend_for_config(config);
434 backend.cosine_similarity(&a, &b)
435 }
436 }
437 }
438
439 pub fn bind(&self, other: &Self, config: &DispatchConfig) -> Result<Self> {
447 if self.dims() != other.dims() {
448 return Err(TernaryError::DimensionMismatch {
449 expected: self.dims(),
450 actual: other.dims(),
451 });
452 }
453
454 let a = self.to_packed();
455 let b = other.to_packed();
456
457 let backend = self.get_backend_for_config(config);
459 let result = backend.bind(&a, &b)?;
460 Ok(Self::Sliced(result))
461 }
462
463 pub fn unbind(&self, other: &Self, config: &DispatchConfig) -> Result<Self> {
471 if self.dims() != other.dims() {
472 return Err(TernaryError::DimensionMismatch {
473 expected: self.dims(),
474 actual: other.dims(),
475 });
476 }
477
478 let a = self.to_packed();
479 let b = other.to_packed();
480
481 let backend = self.get_backend_for_config(config);
483 let result = backend.unbind(&a, &b)?;
484 Ok(Self::Sliced(result))
485 }
486
487 pub fn bundle(&self, other: &Self, config: &DispatchConfig) -> Result<Self> {
495 if self.dims() != other.dims() {
496 return Err(TernaryError::DimensionMismatch {
497 expected: self.dims(),
498 actual: other.dims(),
499 });
500 }
501
502 let a = self.to_packed();
503 let b = other.to_packed();
504
505 let backend = self.get_backend_for_config(config);
507 let result = backend.bundle(&[&a, &b])?;
508 Ok(Self::Sliced(result))
509 }
510
511 pub fn hamming_distance(&self, other: &Self, config: &DispatchConfig) -> Result<usize> {
519 if self.dims() != other.dims() {
520 return Err(TernaryError::DimensionMismatch {
521 expected: self.dims(),
522 actual: other.dims(),
523 });
524 }
525
526 let a = self.to_packed();
527 let b = other.to_packed();
528
529 let backend = self.get_backend_for_config(config);
531 backend.hamming_distance(&a, &b)
532 }
533
534 #[must_use]
536 pub fn negate(&self) -> Self {
537 match self {
538 Self::Sliced(p) => Self::Sliced(p.negated()),
539 Self::Sparse(s) => Self::Sparse(s.negated()),
540 }
541 }
542}
543
544impl From<PackedTritVec> for TritVector {
545 fn from(packed: PackedTritVec) -> Self {
546 Self::Sliced(packed)
547 }
548}
549
550impl From<SparseVec> for TritVector {
551 fn from(sparse: SparseVec) -> Self {
552 Self::Sparse(sparse)
553 }
554}
555
/// Counters describing how operations were dispatched.
///
/// NOTE(review): nothing in this module updates these counters. Presumably a caller fills
/// them in, or they are reserved for future instrumentation — confirm before relying on them.
#[derive(Clone, Debug, Default)]
pub struct DispatchStats {
    /// Operations run in the trit-sliced format (per the field name).
    pub tritsliced_count: usize,
    /// Operations run in the sparse format (per the field name).
    pub sparse_count: usize,
    /// Operations run on a GPU backend (per the field name).
    pub gpu_count: usize,
    /// Operations run on a CPU backend (per the field name).
    pub cpu_count: usize,
    /// Representation conversions performed (per the field name).
    pub conversion_count: usize,
}
570
#[cfg(test)]
mod tests {
    use super::*;

    /// Maps a `-1` / `0` / `1` literal onto the matching `Trit`; anything else panics.
    fn to_trit(raw: i8) -> Trit {
        match raw {
            1 => Trit::P,
            0 => Trit::Z,
            -1 => Trit::N,
            _ => panic!("Invalid trit value"),
        }
    }

    /// Builds a `Sliced` test vector whose length and contents come from `values`.
    fn make_test_vector(values: &[i8]) -> TritVector {
        let mut packed = PackedTritVec::new(values.len());
        for (idx, raw) in values.iter().copied().enumerate() {
            packed.set(idx, to_trit(raw));
        }
        TritVector::Sliced(packed)
    }

    #[test]
    fn test_dispatch_config_default() {
        let cfg = DispatchConfig::auto();
        assert_eq!(cfg.format, Format::Auto);
        assert_eq!(cfg.device, DevicePreference::Auto);
        assert!((cfg.sparse_threshold - 0.90).abs() < f32::EPSILON);
        assert_eq!(cfg.gpu_threshold, 4096);
    }

    #[test]
    fn test_trit_vector_from_packed() {
        let vector = TritVector::from_packed(PackedTritVec::new(100));
        assert_eq!(vector.dims(), 100);
        assert!(matches!(vector, TritVector::Sliced(_)));
    }

    #[test]
    fn test_operation_preferred_format() {
        for op in [Operation::Dot, Operation::Similarity, Operation::Bind] {
            assert_eq!(op.preferred_format(), Format::Tritsliced, "{op:?}");
        }
    }

    #[test]
    fn test_dot_product_dispatch() {
        let config = DispatchConfig::cpu_only();
        let lhs = TritVector::new(100);
        let rhs = TritVector::new(100);

        let dot = lhs.dot(&rhs, &config);
        assert!(dot.is_ok());
        assert_eq!(dot.unwrap(), 0);
    }

    #[test]
    fn test_dimension_mismatch() {
        let short = TritVector::new(100);
        let long = TritVector::new(200);

        assert!(short.dot(&long, &DispatchConfig::cpu_only()).is_err());
    }

    #[test]
    fn test_bind_unbind_with_backend() {
        let a = make_test_vector(&[1, -1, 0, 1, -1, 0, 1]);
        let key = make_test_vector(&[-1, 1, 0, -1, 1, 0, -1]);
        let config = DispatchConfig::cpu_only();

        let bound = a.bind(&key, &config).unwrap();
        let recovered = bound.unbind(&key, &config).unwrap();

        for i in 0..a.dims() {
            assert_eq!(recovered.get(i), a.get(i), "mismatch at position {i}");
        }
    }

    #[test]
    fn test_bundle_with_backend() {
        let first = make_test_vector(&[1, 1, -1, 0, 0]);
        let second = make_test_vector(&[1, -1, -1, 1, -1]);

        let majority = first.bundle(&second, &DispatchConfig::cpu_only()).unwrap();

        assert_eq!(majority.get(0), Trit::P);
        assert_eq!(majority.get(1), Trit::Z);
        assert_eq!(majority.get(2), Trit::N);
    }

    #[test]
    fn test_cosine_similarity_with_backend() {
        let v = make_test_vector(&[1, 1, -1, -1]);

        let sim = v.cosine_similarity(&v, &DispatchConfig::cpu_only()).unwrap();
        assert!((sim - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_hamming_distance_with_backend() {
        let lhs = make_test_vector(&[1, 0, -1, 1]);
        let rhs = make_test_vector(&[1, -1, -1, 0]);

        let distance = lhs.hamming_distance(&rhs, &DispatchConfig::cpu_only()).unwrap();
        assert_eq!(distance, 2);
    }

    #[test]
    fn test_backend_config_conversion() {
        let cpu = TritVector::to_backend_config(&DispatchConfig::cpu_only());
        assert_eq!(cpu.preferred, BackendPreference::Cpu);

        let gpu_config = DispatchConfig::auto().with_device(DevicePreference::Gpu);
        let gpu = TritVector::to_backend_config(&gpu_config);
        assert_eq!(gpu.preferred, BackendPreference::Gpu);
    }

    #[test]
    fn test_auto_backend_selection() {
        let config = DispatchConfig::auto().with_gpu_threshold(5000);
        let small = TritVector::new(100);
        let large = TritVector::new(10000);

        let small_backend = small.get_backend_for_config(&config);
        assert!(small_backend.name().starts_with("cpu"));

        let large_backend = large.get_backend_for_config(&config);
        assert!(large_backend.is_available());
    }
}