1use serde::{Deserialize, Serialize};
7use std::fmt;
8
9use crate::error::{Result, TernaryError};
10use crate::packed::PackedTritVec;
11use crate::trit::Trit;
12
13#[derive(Clone, Serialize, Deserialize)]
45pub struct SparseVec {
46 positive_indices: Vec<usize>,
48 negative_indices: Vec<usize>,
50 num_dims: usize,
52}
53
54impl SparseVec {
55 #[must_use]
59 pub fn new(num_dims: usize) -> Self {
60 Self {
61 positive_indices: Vec::new(),
62 negative_indices: Vec::new(),
63 num_dims,
64 }
65 }
66
67 pub fn from_indices(
80 mut positive_indices: Vec<usize>,
81 mut negative_indices: Vec<usize>,
82 num_dims: usize,
83 ) -> Result<Self> {
84 positive_indices.sort_unstable();
86 negative_indices.sort_unstable();
87
88 if let Some(&max) = positive_indices.last() {
90 if max >= num_dims {
91 return Err(TernaryError::IndexOutOfBounds {
92 index: max,
93 size: num_dims,
94 });
95 }
96 }
97 if let Some(&max) = negative_indices.last() {
98 if max >= num_dims {
99 return Err(TernaryError::IndexOutOfBounds {
100 index: max,
101 size: num_dims,
102 });
103 }
104 }
105
106 let mut pi = 0;
108 let mut ni = 0;
109 while pi < positive_indices.len() && ni < negative_indices.len() {
110 match positive_indices[pi].cmp(&negative_indices[ni]) {
111 std::cmp::Ordering::Equal => {
112 return Err(TernaryError::InvalidValue(positive_indices[pi] as i32));
113 }
114 std::cmp::Ordering::Less => pi += 1,
115 std::cmp::Ordering::Greater => ni += 1,
116 }
117 }
118
119 Ok(Self {
120 positive_indices,
121 negative_indices,
122 num_dims,
123 })
124 }
125
126 #[must_use]
128 pub fn from_trits(trits: &[Trit]) -> Self {
129 let mut positive_indices = Vec::new();
130 let mut negative_indices = Vec::new();
131
132 for (i, &trit) in trits.iter().enumerate() {
133 match trit {
134 Trit::P => positive_indices.push(i),
135 Trit::N => negative_indices.push(i),
136 Trit::Z => {}
137 }
138 }
139
140 Self {
141 positive_indices,
142 negative_indices,
143 num_dims: trits.len(),
144 }
145 }
146
147 #[must_use]
149 pub fn from_packed(packed: &PackedTritVec) -> Self {
150 let mut positive_indices = Vec::new();
151 let mut negative_indices = Vec::new();
152
153 for i in 0..packed.len() {
154 match packed.get(i) {
155 Trit::P => positive_indices.push(i),
156 Trit::N => negative_indices.push(i),
157 Trit::Z => {}
158 }
159 }
160
161 Self {
162 positive_indices,
163 negative_indices,
164 num_dims: packed.len(),
165 }
166 }
167
168 #[must_use]
170 pub const fn len(&self) -> usize {
171 self.num_dims
172 }
173
174 #[must_use]
176 pub const fn is_empty(&self) -> bool {
177 self.num_dims == 0
178 }
179
180 pub fn set(&mut self, dim: usize, value: Trit) {
186 assert!(dim < self.num_dims, "dimension out of bounds");
187
188 self.positive_indices.retain(|&i| i != dim);
190 self.negative_indices.retain(|&i| i != dim);
191
192 match value {
194 Trit::P => {
195 let pos = self.positive_indices.partition_point(|&x| x < dim);
196 self.positive_indices.insert(pos, dim);
197 }
198 Trit::N => {
199 let pos = self.negative_indices.partition_point(|&x| x < dim);
200 self.negative_indices.insert(pos, dim);
201 }
202 Trit::Z => {} }
204 }
205
206 #[must_use]
212 pub fn get(&self, dim: usize) -> Trit {
213 assert!(dim < self.num_dims, "dimension out of bounds");
214
215 if self.positive_indices.binary_search(&dim).is_ok() {
216 Trit::P
217 } else if self.negative_indices.binary_search(&dim).is_ok() {
218 Trit::N
219 } else {
220 Trit::Z
221 }
222 }
223
224 #[must_use]
226 pub fn num_dims(&self) -> usize {
227 self.num_dims
228 }
229
230 #[must_use]
232 pub fn count_nonzero(&self) -> usize {
233 self.positive_indices.len() + self.negative_indices.len()
234 }
235
236 #[must_use]
238 pub fn count_positive(&self) -> usize {
239 self.positive_indices.len()
240 }
241
242 #[must_use]
244 pub fn count_negative(&self) -> usize {
245 self.negative_indices.len()
246 }
247
248 #[must_use]
250 #[allow(clippy::cast_precision_loss)]
251 pub fn sparsity(&self) -> f32 {
252 if self.num_dims == 0 {
253 return 1.0;
254 }
255 1.0 - (self.count_nonzero() as f32 / self.num_dims as f32)
256 }
257
258 #[must_use]
266 pub fn dot(&self, other: &SparseVec) -> i32 {
267 assert_eq!(
268 self.num_dims, other.num_dims,
269 "vectors must have same dimensions"
270 );
271
272 let mut result: i32 = 0;
273
274 result += Self::count_intersection(&self.positive_indices, &other.positive_indices) as i32;
276 result += Self::count_intersection(&self.negative_indices, &other.negative_indices) as i32;
277
278 result -= Self::count_intersection(&self.positive_indices, &other.negative_indices) as i32;
280 result -= Self::count_intersection(&self.negative_indices, &other.positive_indices) as i32;
281
282 result
283 }
284
285 #[must_use]
293 pub fn dot_packed(&self, other: &PackedTritVec) -> i32 {
294 assert_eq!(
295 self.num_dims,
296 other.len(),
297 "vectors must have same dimensions"
298 );
299
300 let mut result: i32 = 0;
301
302 for &idx in &self.positive_indices {
304 result += other.get(idx).value() as i32;
305 }
306
307 for &idx in &self.negative_indices {
309 result -= other.get(idx).value() as i32;
310 }
311
312 result
313 }
314
315 #[must_use]
317 pub fn sum(&self) -> i32 {
318 self.positive_indices.len() as i32 - self.negative_indices.len() as i32
319 }
320
321 #[must_use]
323 pub fn negated(&self) -> Self {
324 Self {
325 positive_indices: self.negative_indices.clone(),
326 negative_indices: self.positive_indices.clone(),
327 num_dims: self.num_dims,
328 }
329 }
330
331 #[must_use]
333 pub fn positive_indices(&self) -> &[usize] {
334 &self.positive_indices
335 }
336
337 #[must_use]
339 pub fn negative_indices(&self) -> &[usize] {
340 &self.negative_indices
341 }
342
343 #[must_use]
345 pub fn to_packed(&self) -> PackedTritVec {
346 let mut packed = PackedTritVec::new(self.num_dims);
347 for &idx in &self.positive_indices {
348 packed.set(idx, Trit::P);
349 }
350 for &idx in &self.negative_indices {
351 packed.set(idx, Trit::N);
352 }
353 packed
354 }
355
356 #[must_use]
358 pub fn to_trits(&self) -> Vec<Trit> {
359 let mut result = vec![Trit::Z; self.num_dims];
360 for &idx in &self.positive_indices {
361 result[idx] = Trit::P;
362 }
363 for &idx in &self.negative_indices {
364 result[idx] = Trit::N;
365 }
366 result
367 }
368
369 #[must_use]
371 pub fn memory_bytes(&self) -> usize {
372 std::mem::size_of::<Self>()
374 + self.positive_indices.capacity() * std::mem::size_of::<usize>()
375 + self.negative_indices.capacity() * std::mem::size_of::<usize>()
376 }
377
378 fn count_intersection(a: &[usize], b: &[usize]) -> usize {
380 let mut count = 0;
381 let mut ai = 0;
382 let mut bi = 0;
383
384 while ai < a.len() && bi < b.len() {
385 match a[ai].cmp(&b[bi]) {
386 std::cmp::Ordering::Equal => {
387 count += 1;
388 ai += 1;
389 bi += 1;
390 }
391 std::cmp::Ordering::Less => ai += 1,
392 std::cmp::Ordering::Greater => bi += 1,
393 }
394 }
395
396 count
397 }
398}
399
400impl fmt::Debug for SparseVec {
401 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
402 write!(
403 f,
404 "SparseVec(dims={}, pos={}, neg={}, sparsity={:.2}%)",
405 self.num_dims,
406 self.positive_indices.len(),
407 self.negative_indices.len(),
408 self.sparsity() * 100.0
409 )
410 }
411}
412
413impl PartialEq for SparseVec {
414 fn eq(&self, other: &Self) -> bool {
415 self.num_dims == other.num_dims
416 && self.positive_indices == other.positive_indices
417 && self.negative_indices == other.negative_indices
418 }
419}
420
421impl Eq for SparseVec {}
422
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_new() {
        let vec = SparseVec::new(1000);
        assert_eq!(vec.len(), 1000);
        assert_eq!(vec.count_nonzero(), 0);
        assert!((vec.sparsity() - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_sparse_zero_dims() {
        // A zero-dimension vector is "empty" and reports full sparsity.
        let vec = SparseVec::new(0);
        assert!(vec.is_empty());
        assert_eq!(vec.count_nonzero(), 0);
        assert!((vec.sparsity() - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_sparse_set_get() {
        let mut vec = SparseVec::new(100);

        vec.set(10, Trit::P);
        vec.set(20, Trit::N);
        vec.set(50, Trit::P);

        assert_eq!(vec.get(10), Trit::P);
        assert_eq!(vec.get(20), Trit::N);
        assert_eq!(vec.get(50), Trit::P);
        assert_eq!(vec.get(0), Trit::Z);
        assert_eq!(vec.get(99), Trit::Z);
    }

    #[test]
    #[should_panic(expected = "dimension out of bounds")]
    fn test_sparse_set_out_of_bounds_panics() {
        let mut vec = SparseVec::new(10);
        vec.set(10, Trit::P);
    }

    #[test]
    fn test_sparse_overwrite() {
        let mut vec = SparseVec::new(10);

        vec.set(0, Trit::P);
        assert_eq!(vec.get(0), Trit::P);
        assert_eq!(vec.count_nonzero(), 1);

        vec.set(0, Trit::N);
        assert_eq!(vec.get(0), Trit::N);
        assert_eq!(vec.count_nonzero(), 1);

        vec.set(0, Trit::Z);
        assert_eq!(vec.get(0), Trit::Z);
        assert_eq!(vec.count_nonzero(), 0);
    }

    #[test]
    fn test_sparse_dot() {
        let mut a = SparseVec::new(100);
        let mut b = SparseVec::new(100);

        // a = [+ at 0, - at 1, + at 10]
        a.set(0, Trit::P);
        a.set(1, Trit::N);
        a.set(10, Trit::P);

        // b = [+ at 0, + at 1, - at 20]
        b.set(0, Trit::P);
        b.set(1, Trit::P);
        b.set(20, Trit::N);

        // (+1)(+1) + (-1)(+1) = 0
        assert_eq!(a.dot(&b), 0);

        // Flipping b[1] to - makes dimension 1 agree: (+1)(+1) + (-1)(-1) = 2
        b.set(1, Trit::N);
        assert_eq!(a.dot(&b), 2);
    }

    #[test]
    #[should_panic(expected = "vectors must have same dimensions")]
    fn test_sparse_dot_dimension_mismatch_panics() {
        let a = SparseVec::new(10);
        let b = SparseVec::new(11);
        let _ = a.dot(&b);
    }

    #[test]
    fn test_sparse_dot_packed() {
        let mut sparse = SparseVec::new(64);
        let mut packed = PackedTritVec::new(64);

        sparse.set(0, Trit::P);
        sparse.set(1, Trit::N);

        packed.set(0, Trit::P);
        packed.set(1, Trit::P);
        packed.set(2, Trit::N);

        // (+1)(+1) + (-1)(+1) = 0; dimension 2 is zero in `sparse`.
        assert_eq!(sparse.dot_packed(&packed), 0);

        // (+1)(+1) + (-1)(-1) = 2
        packed.set(1, Trit::N);
        assert_eq!(sparse.dot_packed(&packed), 2);
    }

    #[test]
    fn test_sparse_from_trits() {
        let trits = [Trit::P, Trit::N, Trit::Z, Trit::P, Trit::Z];
        let vec = SparseVec::from_trits(&trits);

        assert_eq!(vec.len(), 5);
        assert_eq!(vec.count_positive(), 2);
        assert_eq!(vec.count_negative(), 1);

        assert_eq!(vec.to_trits(), trits);
    }

    #[test]
    fn test_sparse_to_packed_roundtrip() {
        let mut sparse = SparseVec::new(100);
        sparse.set(0, Trit::P);
        sparse.set(50, Trit::N);
        sparse.set(99, Trit::P);

        let packed = sparse.to_packed();
        let back = SparseVec::from_packed(&packed);

        assert_eq!(sparse, back);
    }

    #[test]
    fn test_sparse_negated() {
        let mut vec = SparseVec::new(10);
        vec.set(0, Trit::P);
        vec.set(1, Trit::N);

        let neg = vec.negated();

        assert_eq!(neg.get(0), Trit::N);
        assert_eq!(neg.get(1), Trit::P);
    }

    #[test]
    fn test_sparse_from_indices() {
        let pos = vec![0, 10, 50];
        let neg = vec![5, 20];
        let vec = SparseVec::from_indices(pos, neg, 100).unwrap();

        assert_eq!(vec.get(0), Trit::P);
        assert_eq!(vec.get(10), Trit::P);
        assert_eq!(vec.get(50), Trit::P);
        assert_eq!(vec.get(5), Trit::N);
        assert_eq!(vec.get(20), Trit::N);
        assert_eq!(vec.get(1), Trit::Z);
    }

    #[test]
    fn test_sparse_from_indices_sorts_input() {
        let vec = SparseVec::from_indices(vec![50, 0, 10], vec![20, 5], 100).unwrap();

        assert_eq!(vec.positive_indices(), &[0_usize, 10, 50][..]);
        assert_eq!(vec.negative_indices(), &[5_usize, 20][..]);
    }

    #[test]
    fn test_sparse_from_indices_overlap_error() {
        // Index 10 cannot be both positive and negative.
        let pos = vec![0, 10];
        let neg = vec![10, 20];
        let result = SparseVec::from_indices(pos, neg, 100);
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_from_indices_bounds_error() {
        // Valid indices are 0..100, so 100 is out of range.
        let pos = vec![100];
        let neg = vec![];
        let result = SparseVec::from_indices(pos, neg, 100);
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_from_indices_negative_bounds_error() {
        let pos = vec![];
        let neg = vec![100];
        let result = SparseVec::from_indices(pos, neg, 100);
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_sum() {
        let mut vec = SparseVec::new(100);
        vec.set(0, Trit::P);
        vec.set(1, Trit::P);
        vec.set(2, Trit::N);

        // Two positives minus one negative.
        assert_eq!(vec.sum(), 1);
    }
}