1use candle_core::Device;
34use thiserror::Error;
35use trit_vsa::{PackedTritVec, Trit};
36
37#[cfg(feature = "cuda")]
38use trit_vsa::gpu::{
39 GpuBind, GpuBundle, GpuCosineSimilarity, GpuDispatchable, GpuDotSimilarity, GpuHammingDistance,
40 GpuRandom, GpuUnbind, RandomInput,
41};
42
43#[derive(Debug, Error)]
45pub enum VsaError {
46 #[error("dimension mismatch: expected {expected}, got {actual}")]
48 DimensionMismatch { expected: usize, actual: usize },
49
50 #[error("invalid value {value} at index {index}")]
52 InvalidValue { value: i8, index: usize },
53
54 #[error("GPU error: {0}")]
56 Gpu(String),
57
58 #[error("empty input")]
60 EmptyInput,
61}
62
63#[derive(Debug, Clone)]
65pub struct VsaConfig {
66 pub device: DevicePreference,
68}
69
/// Where [`VsaOps`] should run its computations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DevicePreference {
    /// Use CUDA device 0 when it can be used, otherwise compute on the CPU.
    Auto,
    /// Request the GPU (CUDA device 0); device-selection failures are reported as
    /// [`VsaError::Gpu`].
    Gpu,
    /// Always compute on the CPU.
    Cpu,
}
80
81impl Default for VsaConfig {
82 fn default() -> Self {
83 Self {
84 device: DevicePreference::Auto,
85 }
86 }
87}
88
89impl VsaConfig {
90 pub fn with_device(mut self, device: DevicePreference) -> Self {
92 self.device = device;
93 self
94 }
95}
96
97#[derive(Debug, Clone)]
101pub struct VsaOps {
102 config: VsaConfig,
103}
104
105impl VsaOps {
106 pub fn new(config: VsaConfig) -> Self {
108 Self { config }
109 }
110
111 fn get_device(&self) -> Result<Device, VsaError> {
113 match self.config.device {
114 DevicePreference::Cpu => Ok(Device::Cpu),
115 DevicePreference::Gpu => {
116 #[cfg(feature = "cuda")]
117 {
118 Device::cuda_if_available(0).map_err(|e| VsaError::Gpu(e.to_string()))
119 }
120 #[cfg(not(feature = "cuda"))]
121 {
122 Err(VsaError::Gpu(
123 "CUDA not compiled. Rebuild with --features cuda".to_string(),
124 ))
125 }
126 }
127 DevicePreference::Auto => {
128 #[cfg(feature = "cuda")]
129 {
130 Ok(Device::cuda_if_available(0).unwrap_or(Device::Cpu))
131 }
132 #[cfg(not(feature = "cuda"))]
133 {
134 Ok(Device::Cpu)
135 }
136 }
137 }
138 }
139
140 pub fn random(&self, dim: usize, seed: u32) -> Result<PackedTritVec, VsaError> {
147 if dim == 0 {
148 return Err(VsaError::EmptyInput);
149 }
150
151 let device = self.get_device()?;
152
153 #[cfg(feature = "cuda")]
154 {
155 if matches!(device, Device::Cuda(_)) {
156 let input = RandomInput::new(dim, seed);
157 return GpuRandom
158 .dispatch(&input, &device)
159 .map_err(|e| VsaError::Gpu(e.to_string()));
160 }
161 }
162
163 let _ = device; Ok(cpu_random(dim, seed))
166 }
167
168 pub fn bind(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<PackedTritVec, VsaError> {
173 if a.len() != b.len() {
174 return Err(VsaError::DimensionMismatch {
175 expected: a.len(),
176 actual: b.len(),
177 });
178 }
179
180 let device = self.get_device()?;
181
182 #[cfg(feature = "cuda")]
183 {
184 if matches!(device, Device::Cuda(_)) {
185 return GpuBind
186 .dispatch(&(a.clone(), b.clone()), &device)
187 .map_err(|e| VsaError::Gpu(e.to_string()));
188 }
189 }
190
191 let _ = device;
192 Ok(cpu_bind(a, b))
193 }
194
195 pub fn unbind(&self, bound: &PackedTritVec, key: &PackedTritVec) -> Result<PackedTritVec, VsaError> {
199 if bound.len() != key.len() {
200 return Err(VsaError::DimensionMismatch {
201 expected: bound.len(),
202 actual: key.len(),
203 });
204 }
205
206 let device = self.get_device()?;
207
208 #[cfg(feature = "cuda")]
209 {
210 if matches!(device, Device::Cuda(_)) {
211 return GpuUnbind
212 .dispatch(&(bound.clone(), key.clone()), &device)
213 .map_err(|e| VsaError::Gpu(e.to_string()));
214 }
215 }
216
217 let _ = device;
218 Ok(cpu_bind(bound, key))
220 }
221
222 pub fn bundle(&self, vectors: &[PackedTritVec]) -> Result<PackedTritVec, VsaError> {
227 if vectors.is_empty() {
228 return Err(VsaError::EmptyInput);
229 }
230
231 let dim = vectors[0].len();
232 for (i, v) in vectors.iter().enumerate() {
233 if v.len() != dim {
234 return Err(VsaError::DimensionMismatch {
235 expected: dim,
236 actual: v.len(),
237 });
238 }
239 let _ = i; }
241
242 let device = self.get_device()?;
243
244 #[cfg(feature = "cuda")]
245 {
246 if matches!(device, Device::Cuda(_)) {
247 return GpuBundle
248 .dispatch(&vectors.to_vec(), &device)
249 .map_err(|e| VsaError::Gpu(e.to_string()));
250 }
251 }
252
253 let _ = device;
254 Ok(cpu_bundle(vectors))
255 }
256
257 pub fn cosine_similarity(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<f32, VsaError> {
261 if a.len() != b.len() {
262 return Err(VsaError::DimensionMismatch {
263 expected: a.len(),
264 actual: b.len(),
265 });
266 }
267
268 let device = self.get_device()?;
269
270 #[cfg(feature = "cuda")]
271 {
272 if matches!(device, Device::Cuda(_)) {
273 return GpuCosineSimilarity
274 .dispatch(&(a.clone(), b.clone()), &device)
275 .map_err(|e| VsaError::Gpu(e.to_string()));
276 }
277 }
278
279 let _ = device;
280 Ok(cpu_cosine_similarity(a, b))
281 }
282
283 pub fn dot(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<i32, VsaError> {
285 if a.len() != b.len() {
286 return Err(VsaError::DimensionMismatch {
287 expected: a.len(),
288 actual: b.len(),
289 });
290 }
291
292 let device = self.get_device()?;
293
294 #[cfg(feature = "cuda")]
295 {
296 if matches!(device, Device::Cuda(_)) {
297 return GpuDotSimilarity
298 .dispatch(&(a.clone(), b.clone()), &device)
299 .map_err(|e| VsaError::Gpu(e.to_string()));
300 }
301 }
302
303 let _ = device;
304 Ok(a.dot(b))
305 }
306
307 pub fn hamming_distance(&self, a: &PackedTritVec, b: &PackedTritVec) -> Result<usize, VsaError> {
311 if a.len() != b.len() {
312 return Err(VsaError::DimensionMismatch {
313 expected: a.len(),
314 actual: b.len(),
315 });
316 }
317
318 let device = self.get_device()?;
319
320 #[cfg(feature = "cuda")]
321 {
322 if matches!(device, Device::Cuda(_)) {
323 return GpuHammingDistance
324 .dispatch(&(a.clone(), b.clone()), &device)
325 .map_err(|e| VsaError::Gpu(e.to_string()));
326 }
327 }
328
329 let _ = device;
330 Ok(cpu_hamming_distance(a, b))
331 }
332
333 pub fn from_i8(&self, values: &[i8]) -> Result<PackedTritVec, VsaError> {
335 let mut packed = PackedTritVec::new(values.len());
336 for (i, &v) in values.iter().enumerate() {
337 let trit = match v {
338 1 => Trit::P,
339 0 => Trit::Z,
340 -1 => Trit::N,
341 _ => return Err(VsaError::InvalidValue { value: v, index: i }),
342 };
343 packed.set(i, trit);
344 }
345 Ok(packed)
346 }
347
348 pub fn to_i8(&self, packed: &PackedTritVec) -> Vec<i8> {
350 let mut result = Vec::with_capacity(packed.len());
351 for i in 0..packed.len() {
352 result.push(packed.get(i).value());
353 }
354 result
355 }
356}
357
358fn cpu_random(dim: usize, seed: u32) -> PackedTritVec {
361 use rand::{Rng, SeedableRng};
362 use rand_chacha::ChaCha8Rng;
363
364 let mut rng = ChaCha8Rng::seed_from_u64(u64::from(seed));
365 let mut packed = PackedTritVec::new(dim);
366
367 for i in 0..dim {
368 let r: f32 = rng.gen();
369 let trit = if r < 0.333 {
370 Trit::N
371 } else if r < 0.666 {
372 Trit::Z
373 } else {
374 Trit::P
375 };
376 packed.set(i, trit);
377 }
378
379 packed
380}
381
382fn cpu_bind(a: &PackedTritVec, b: &PackedTritVec) -> PackedTritVec {
383 let mut result = PackedTritVec::new(a.len());
388 for i in 0..a.len() {
389 let va = a.get(i).value();
390 let vb = b.get(i).value();
391 let prod = va * vb;
392 let trit = match prod {
393 1 => Trit::P,
394 -1 => Trit::N,
395 _ => Trit::Z,
396 };
397 result.set(i, trit);
398 }
399 result
400}
401
402fn cpu_bundle(vectors: &[PackedTritVec]) -> PackedTritVec {
403 let dim = vectors[0].len();
404 let mut result = PackedTritVec::new(dim);
405
406 for i in 0..dim {
407 let mut pos_count = 0i32;
408 let mut neg_count = 0i32;
409
410 for v in vectors {
411 match v.get(i) {
412 Trit::P => pos_count += 1,
413 Trit::N => neg_count += 1,
414 Trit::Z => {}
415 }
416 }
417
418 let trit = if pos_count > neg_count {
419 Trit::P
420 } else if neg_count > pos_count {
421 Trit::N
422 } else {
423 Trit::Z
424 };
425 result.set(i, trit);
426 }
427
428 result
429}
430
431fn cpu_cosine_similarity(a: &PackedTritVec, b: &PackedTritVec) -> f32 {
432 let dot = a.dot(b) as f32;
433
434 let mut norm_a_sq = 0i32;
436 let mut norm_b_sq = 0i32;
437
438 for i in 0..a.len() {
439 let va = a.get(i).value() as i32;
440 let vb = b.get(i).value() as i32;
441 norm_a_sq += va * va;
442 norm_b_sq += vb * vb;
443 }
444
445 if norm_a_sq == 0 || norm_b_sq == 0 {
446 return 0.0;
447 }
448
449 dot / ((norm_a_sq as f32).sqrt() * (norm_b_sq as f32).sqrt())
450}
451
452fn cpu_hamming_distance(a: &PackedTritVec, b: &PackedTritVec) -> usize {
453 let mut distance = 0;
454 for i in 0..a.len() {
455 if a.get(i) != b.get(i) {
456 distance += 1;
457 }
458 }
459 distance
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CPU-only `VsaOps` so the tests never depend on GPU availability.
    fn cpu_ops() -> VsaOps {
        VsaOps::new(VsaConfig::default().with_device(DevicePreference::Cpu))
    }

    #[test]
    fn test_bind_unbind_roundtrip() {
        let ops = cpu_ops();

        // The key has no zero entries, so unbinding must recover `a` exactly.
        let a = ops.from_i8(&[1, -1, 1, -1, 1, -1, 1, -1]).unwrap();
        let b = ops.from_i8(&[1, 1, -1, -1, 1, 1, -1, -1]).unwrap();

        let bound = ops.bind(&a, &b).unwrap();
        let recovered = ops.unbind(&bound, &b).unwrap();

        assert_eq!(ops.to_i8(&a), ops.to_i8(&recovered));
    }

    #[test]
    fn test_bundle_majority() {
        let ops = cpu_ops();

        let v1 = ops.from_i8(&[1, 1, -1, 0]).unwrap();
        let v2 = ops.from_i8(&[1, -1, -1, 1]).unwrap();
        let v3 = ops.from_i8(&[1, 0, 1, -1]).unwrap();

        let bundled = ops.bundle(&[v1, v2, v3]).unwrap();

        // Columns 1 and 3 each hold one +1, one -1 and one 0: a tie, which resolves to 0.
        assert_eq!(ops.to_i8(&bundled), vec![1, 0, -1, 0]);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let ops = cpu_ops();

        let a = ops.random(1000, 42).unwrap();
        let sim = ops.cosine_similarity(&a, &a).unwrap();

        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let ops = cpu_ops();

        let zeros = ops.from_i8(&[0, 0, 0, 0]).unwrap();
        let other = ops.from_i8(&[1, -1, 0, 1]).unwrap();

        assert!(ops.cosine_similarity(&zeros, &other).unwrap() == 0.0);
    }

    #[test]
    fn test_hamming_distance() {
        let ops = cpu_ops();

        let a = ops.from_i8(&[1, 0, -1, 1]).unwrap();
        let b = ops.from_i8(&[1, -1, -1, 0]).unwrap();

        assert_eq!(ops.hamming_distance(&a, &b).unwrap(), 2);
    }

    #[test]
    fn test_random_is_deterministic_for_seed() {
        let ops = cpu_ops();

        let a = ops.random(64, 3).unwrap();
        let b = ops.random(64, 3).unwrap();

        assert_eq!(ops.to_i8(&a), ops.to_i8(&b));
    }

    #[test]
    fn test_dimension_mismatch_is_reported() {
        let ops = cpu_ops();

        let a = ops.from_i8(&[1, 0, -1, 1]).unwrap();
        let b = ops.from_i8(&[1, 0, -1]).unwrap();

        assert!(matches!(
            ops.bind(&a, &b),
            Err(VsaError::DimensionMismatch {
                expected: 4,
                actual: 3
            })
        ));
        assert!(matches!(
            ops.bundle(&[a, b]),
            Err(VsaError::DimensionMismatch {
                expected: 4,
                actual: 3
            })
        ));
    }

    #[test]
    fn test_from_i8_rejects_out_of_range() {
        let ops = cpu_ops();

        assert!(matches!(
            ops.from_i8(&[1, 2, 0]),
            Err(VsaError::InvalidValue { value: 2, index: 1 })
        ));
    }

    #[test]
    fn test_empty_inputs_are_rejected() {
        let ops = cpu_ops();

        assert!(matches!(ops.random(0, 7), Err(VsaError::EmptyInput)));
        assert!(matches!(ops.bundle(&[]), Err(VsaError::EmptyInput)));
    }
}