Skip to main content

trit_vsa/
sparse.rs

//! Sparse ternary vector storage using COO format.
//!
//! This module provides `SparseVec`, an efficient representation for highly
//! sparse ternary vectors. It stores only the indices of non-zero entries,
//! kept in separate lists according to their sign.
5
6use serde::{Deserialize, Serialize};
7use std::fmt;
8
9use crate::error::{Result, TernaryError};
10use crate::packed::PackedTritVec;
11use crate::trit::Trit;
12
13/// A sparse ternary vector using COO (Coordinate) format.
14///
15/// Only non-zero values are stored, making this efficient for vectors where
16/// most elements are zero (high sparsity).
17///
18/// # Storage
19///
20/// Non-zero indices are stored separately for positive and negative values:
21/// - `positive_indices`: indices where value is +1
22/// - `negative_indices`: indices where value is -1
23///
24/// # When to Use
25///
26/// Use `SparseVec` when sparsity > 90% for memory efficiency.
27/// Use `PackedTritVec` for denser vectors or when operations like dot product
28/// need consistent O(n) time regardless of sparsity.
29///
30/// # Examples
31///
32/// ```
33/// use trit_vsa::{SparseVec, Trit};
34///
35/// let mut vec = SparseVec::new(1000);
36/// vec.set(10, Trit::P);
37/// vec.set(500, Trit::N);
38///
39/// assert_eq!(vec.get(10), Trit::P);
40/// assert_eq!(vec.get(500), Trit::N);
41/// assert_eq!(vec.get(0), Trit::Z);
42/// assert_eq!(vec.count_nonzero(), 2);
43/// ```
44#[derive(Clone, Serialize, Deserialize)]
45pub struct SparseVec {
46    /// Indices where value is +1 (sorted).
47    positive_indices: Vec<usize>,
48    /// Indices where value is -1 (sorted).
49    negative_indices: Vec<usize>,
50    /// Logical dimension count.
51    num_dims: usize,
52}
53
54impl SparseVec {
55    /// Create a new sparse vector with given dimension count.
56    ///
57    /// All values are initialized to zero (no storage needed).
58    #[must_use]
59    pub fn new(num_dims: usize) -> Self {
60        Self {
61            positive_indices: Vec::new(),
62            negative_indices: Vec::new(),
63            num_dims,
64        }
65    }
66
67    /// Create from separate index lists.
68    ///
69    /// # Arguments
70    ///
71    /// * `positive_indices` - Indices where value is +1
72    /// * `negative_indices` - Indices where value is -1
73    /// * `num_dims` - Logical dimension count
74    ///
75    /// # Errors
76    ///
77    /// Returns error if any index is out of bounds or if there are duplicates
78    /// across positive and negative lists.
79    pub fn from_indices(
80        mut positive_indices: Vec<usize>,
81        mut negative_indices: Vec<usize>,
82        num_dims: usize,
83    ) -> Result<Self> {
84        // Validate and sort
85        positive_indices.sort_unstable();
86        negative_indices.sort_unstable();
87
88        // Check bounds
89        if let Some(&max) = positive_indices.last() {
90            if max >= num_dims {
91                return Err(TernaryError::IndexOutOfBounds {
92                    index: max,
93                    size: num_dims,
94                });
95            }
96        }
97        if let Some(&max) = negative_indices.last() {
98            if max >= num_dims {
99                return Err(TernaryError::IndexOutOfBounds {
100                    index: max,
101                    size: num_dims,
102                });
103            }
104        }
105
106        // Check for overlap (same index can't be both positive and negative)
107        let mut pi = 0;
108        let mut ni = 0;
109        while pi < positive_indices.len() && ni < negative_indices.len() {
110            match positive_indices[pi].cmp(&negative_indices[ni]) {
111                std::cmp::Ordering::Equal => {
112                    return Err(TernaryError::InvalidValue(positive_indices[pi] as i32));
113                }
114                std::cmp::Ordering::Less => pi += 1,
115                std::cmp::Ordering::Greater => ni += 1,
116            }
117        }
118
119        Ok(Self {
120            positive_indices,
121            negative_indices,
122            num_dims,
123        })
124    }
125
126    /// Create from a slice of trits.
127    #[must_use]
128    pub fn from_trits(trits: &[Trit]) -> Self {
129        let mut positive_indices = Vec::new();
130        let mut negative_indices = Vec::new();
131
132        for (i, &trit) in trits.iter().enumerate() {
133            match trit {
134                Trit::P => positive_indices.push(i),
135                Trit::N => negative_indices.push(i),
136                Trit::Z => {}
137            }
138        }
139
140        Self {
141            positive_indices,
142            negative_indices,
143            num_dims: trits.len(),
144        }
145    }
146
147    /// Create from a [`PackedTritVec`].
148    #[must_use]
149    pub fn from_packed(packed: &PackedTritVec) -> Self {
150        let mut positive_indices = Vec::new();
151        let mut negative_indices = Vec::new();
152
153        for i in 0..packed.len() {
154            match packed.get(i) {
155                Trit::P => positive_indices.push(i),
156                Trit::N => negative_indices.push(i),
157                Trit::Z => {}
158            }
159        }
160
161        Self {
162            positive_indices,
163            negative_indices,
164            num_dims: packed.len(),
165        }
166    }
167
168    /// Get the number of logical dimensions.
169    #[must_use]
170    pub const fn len(&self) -> usize {
171        self.num_dims
172    }
173
174    /// Check if the vector is empty.
175    #[must_use]
176    pub const fn is_empty(&self) -> bool {
177        self.num_dims == 0
178    }
179
180    /// Set a dimension to a trit value.
181    ///
182    /// # Panics
183    ///
184    /// Panics if `dim >= len()`.
185    pub fn set(&mut self, dim: usize, value: Trit) {
186        assert!(dim < self.num_dims, "dimension out of bounds");
187
188        // Remove from current lists
189        self.positive_indices.retain(|&i| i != dim);
190        self.negative_indices.retain(|&i| i != dim);
191
192        // Add to appropriate list
193        match value {
194            Trit::P => {
195                let pos = self.positive_indices.partition_point(|&x| x < dim);
196                self.positive_indices.insert(pos, dim);
197            }
198            Trit::N => {
199                let pos = self.negative_indices.partition_point(|&x| x < dim);
200                self.negative_indices.insert(pos, dim);
201            }
202            Trit::Z => {} // Already removed
203        }
204    }
205
206    /// Get the trit value at a dimension.
207    ///
208    /// # Panics
209    ///
210    /// Panics if `dim >= len()`.
211    #[must_use]
212    pub fn get(&self, dim: usize) -> Trit {
213        assert!(dim < self.num_dims, "dimension out of bounds");
214
215        if self.positive_indices.binary_search(&dim).is_ok() {
216            Trit::P
217        } else if self.negative_indices.binary_search(&dim).is_ok() {
218            Trit::N
219        } else {
220            Trit::Z
221        }
222    }
223
224    /// Count non-zero elements.
225    #[must_use]
226    pub fn count_nonzero(&self) -> usize {
227        self.positive_indices.len() + self.negative_indices.len()
228    }
229
230    /// Count positive (+1) elements.
231    #[must_use]
232    pub fn count_positive(&self) -> usize {
233        self.positive_indices.len()
234    }
235
236    /// Count negative (-1) elements.
237    #[must_use]
238    pub fn count_negative(&self) -> usize {
239        self.negative_indices.len()
240    }
241
242    /// Calculate sparsity (fraction of zeros).
243    #[must_use]
244    #[allow(clippy::cast_precision_loss)]
245    pub fn sparsity(&self) -> f32 {
246        if self.num_dims == 0 {
247            return 1.0;
248        }
249        1.0 - (self.count_nonzero() as f32 / self.num_dims as f32)
250    }
251
252    /// Compute dot product with another sparse vector.
253    ///
254    /// This is O(k1 + k2) where k1 and k2 are the number of non-zero elements.
255    ///
256    /// # Panics
257    ///
258    /// Panics if vectors have different dimensions.
259    #[must_use]
260    pub fn dot(&self, other: &SparseVec) -> i32 {
261        assert_eq!(
262            self.num_dims, other.num_dims,
263            "vectors must have same dimensions"
264        );
265
266        let mut result: i32 = 0;
267
268        // Count intersections between same-sign indices
269        result += Self::count_intersection(&self.positive_indices, &other.positive_indices) as i32;
270        result += Self::count_intersection(&self.negative_indices, &other.negative_indices) as i32;
271
272        // Subtract intersections between opposite-sign indices
273        result -= Self::count_intersection(&self.positive_indices, &other.negative_indices) as i32;
274        result -= Self::count_intersection(&self.negative_indices, &other.positive_indices) as i32;
275
276        result
277    }
278
279    /// Compute dot product with a packed vector.
280    ///
281    /// Efficient when this sparse vector has few non-zeros.
282    ///
283    /// # Panics
284    ///
285    /// Panics if vectors have different dimensions.
286    #[must_use]
287    pub fn dot_packed(&self, other: &PackedTritVec) -> i32 {
288        assert_eq!(
289            self.num_dims,
290            other.len(),
291            "vectors must have same dimensions"
292        );
293
294        let mut result: i32 = 0;
295
296        // Sum contributions from positive indices
297        for &idx in &self.positive_indices {
298            result += other.get(idx).value() as i32;
299        }
300
301        // Sum contributions from negative indices (note: we add negative of other's value)
302        for &idx in &self.negative_indices {
303            result -= other.get(idx).value() as i32;
304        }
305
306        result
307    }
308
309    /// Compute the sum of all elements.
310    #[must_use]
311    pub fn sum(&self) -> i32 {
312        self.positive_indices.len() as i32 - self.negative_indices.len() as i32
313    }
314
315    /// Return a negated copy.
316    #[must_use]
317    pub fn negated(&self) -> Self {
318        Self {
319            positive_indices: self.negative_indices.clone(),
320            negative_indices: self.positive_indices.clone(),
321            num_dims: self.num_dims,
322        }
323    }
324
325    /// Get reference to positive indices.
326    #[must_use]
327    pub fn positive_indices(&self) -> &[usize] {
328        &self.positive_indices
329    }
330
331    /// Get reference to negative indices.
332    #[must_use]
333    pub fn negative_indices(&self) -> &[usize] {
334        &self.negative_indices
335    }
336
337    /// Convert to a [`PackedTritVec`].
338    #[must_use]
339    pub fn to_packed(&self) -> PackedTritVec {
340        let mut packed = PackedTritVec::new(self.num_dims);
341        for &idx in &self.positive_indices {
342            packed.set(idx, Trit::P);
343        }
344        for &idx in &self.negative_indices {
345            packed.set(idx, Trit::N);
346        }
347        packed
348    }
349
350    /// Convert to a vector of trits.
351    #[must_use]
352    pub fn to_trits(&self) -> Vec<Trit> {
353        let mut result = vec![Trit::Z; self.num_dims];
354        for &idx in &self.positive_indices {
355            result[idx] = Trit::P;
356        }
357        for &idx in &self.negative_indices {
358            result[idx] = Trit::N;
359        }
360        result
361    }
362
363    /// Memory size in bytes (approximate).
364    #[must_use]
365    pub fn memory_bytes(&self) -> usize {
366        // Vec overhead + index storage
367        std::mem::size_of::<Self>()
368            + self.positive_indices.capacity() * std::mem::size_of::<usize>()
369            + self.negative_indices.capacity() * std::mem::size_of::<usize>()
370    }
371
372    // Internal: count intersection of two sorted lists
373    fn count_intersection(a: &[usize], b: &[usize]) -> usize {
374        let mut count = 0;
375        let mut ai = 0;
376        let mut bi = 0;
377
378        while ai < a.len() && bi < b.len() {
379            match a[ai].cmp(&b[bi]) {
380                std::cmp::Ordering::Equal => {
381                    count += 1;
382                    ai += 1;
383                    bi += 1;
384                }
385                std::cmp::Ordering::Less => ai += 1,
386                std::cmp::Ordering::Greater => bi += 1,
387            }
388        }
389
390        count
391    }
392}
393
394impl fmt::Debug for SparseVec {
395    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
396        write!(
397            f,
398            "SparseVec(dims={}, pos={}, neg={}, sparsity={:.2}%)",
399            self.num_dims,
400            self.positive_indices.len(),
401            self.negative_indices.len(),
402            self.sparsity() * 100.0
403        )
404    }
405}
406
407impl PartialEq for SparseVec {
408    fn eq(&self, other: &Self) -> bool {
409        self.num_dims == other.num_dims
410            && self.positive_indices == other.positive_indices
411            && self.negative_indices == other.negative_indices
412    }
413}
414
415impl Eq for SparseVec {}
416
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_new() {
        let vec = SparseVec::new(1000);
        assert_eq!(vec.len(), 1000);
        assert_eq!(vec.count_nonzero(), 0);
        assert!((vec.sparsity() - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_sparse_zero_dims() {
        // A zero-dimensional vector is empty and reports full sparsity.
        let vec = SparseVec::new(0);
        assert!(vec.is_empty());
        assert!((vec.sparsity() - 1.0).abs() < f32::EPSILON);
        assert!(vec.to_trits().is_empty());
    }

    #[test]
    fn test_sparse_set_get() {
        let mut vec = SparseVec::new(100);

        vec.set(10, Trit::P);
        vec.set(20, Trit::N);
        vec.set(50, Trit::P);

        assert_eq!(vec.get(10), Trit::P);
        assert_eq!(vec.get(20), Trit::N);
        assert_eq!(vec.get(50), Trit::P);
        assert_eq!(vec.get(0), Trit::Z);
        assert_eq!(vec.get(99), Trit::Z);
    }

    #[test]
    #[should_panic(expected = "dimension out of bounds")]
    fn test_sparse_get_out_of_bounds_panics() {
        let vec = SparseVec::new(10);
        let _ = vec.get(10);
    }

    #[test]
    fn test_sparse_overwrite() {
        let mut vec = SparseVec::new(10);

        vec.set(0, Trit::P);
        assert_eq!(vec.get(0), Trit::P);
        assert_eq!(vec.count_nonzero(), 1);

        vec.set(0, Trit::N);
        assert_eq!(vec.get(0), Trit::N);
        assert_eq!(vec.count_nonzero(), 1);

        vec.set(0, Trit::Z);
        assert_eq!(vec.get(0), Trit::Z);
        assert_eq!(vec.count_nonzero(), 0);
    }

    #[test]
    fn test_sparse_dot() {
        let mut a = SparseVec::new(100);
        let mut b = SparseVec::new(100);

        // a = [+1 at 0, -1 at 1, +1 at 10]
        a.set(0, Trit::P);
        a.set(1, Trit::N);
        a.set(10, Trit::P);

        // b = [+1 at 0, +1 at 1, -1 at 20]
        b.set(0, Trit::P);
        b.set(1, Trit::P);
        b.set(20, Trit::N);

        // dot = 1*1 + (-1)*1 + 1*0 + 0*(-1) = 1 - 1 = 0
        assert_eq!(a.dot(&b), 0);

        // Modify b[1] to -1
        b.set(1, Trit::N);
        // dot = 1*1 + (-1)*(-1) + 1*0 + 0*(-1) = 1 + 1 = 2
        assert_eq!(a.dot(&b), 2);
    }

    #[test]
    fn test_sparse_dot_with_self_and_negation() {
        // v·v counts every non-zero once; v·(-v) is its negation.
        let vec = SparseVec::from_trits(&[Trit::P, Trit::Z, Trit::N, Trit::P]);
        assert_eq!(vec.dot(&vec), 3);
        assert_eq!(vec.dot(&vec.negated()), -3);
    }

    #[test]
    fn test_sparse_dot_packed() {
        let mut sparse = SparseVec::new(64);
        let mut packed = PackedTritVec::new(64);

        sparse.set(0, Trit::P);
        sparse.set(1, Trit::N);

        packed.set(0, Trit::P);
        packed.set(1, Trit::P);
        packed.set(2, Trit::N);

        // dot = 1*1 + (-1)*1 = 0
        assert_eq!(sparse.dot_packed(&packed), 0);

        packed.set(1, Trit::N);
        // dot = 1*1 + (-1)*(-1) = 2
        assert_eq!(sparse.dot_packed(&packed), 2);
    }

    #[test]
    fn test_sparse_from_trits() {
        let trits = [Trit::P, Trit::N, Trit::Z, Trit::P, Trit::Z];
        let vec = SparseVec::from_trits(&trits);

        assert_eq!(vec.len(), 5);
        assert_eq!(vec.count_positive(), 2);
        assert_eq!(vec.count_negative(), 1);

        assert_eq!(vec.to_trits(), trits);
    }

    #[test]
    fn test_sparse_to_packed_roundtrip() {
        let mut sparse = SparseVec::new(100);
        sparse.set(0, Trit::P);
        sparse.set(50, Trit::N);
        sparse.set(99, Trit::P);

        let packed = sparse.to_packed();
        let back = SparseVec::from_packed(&packed);

        assert_eq!(sparse, back);
    }

    #[test]
    fn test_sparse_negated() {
        let mut vec = SparseVec::new(10);
        vec.set(0, Trit::P);
        vec.set(1, Trit::N);

        let neg = vec.negated();

        assert_eq!(neg.get(0), Trit::N);
        assert_eq!(neg.get(1), Trit::P);
    }

    #[test]
    fn test_sparse_from_indices() {
        let pos = vec![0, 10, 50];
        let neg = vec![5, 20];
        let vec = SparseVec::from_indices(pos, neg, 100).unwrap();

        assert_eq!(vec.get(0), Trit::P);
        assert_eq!(vec.get(10), Trit::P);
        assert_eq!(vec.get(50), Trit::P);
        assert_eq!(vec.get(5), Trit::N);
        assert_eq!(vec.get(20), Trit::N);
        assert_eq!(vec.get(1), Trit::Z);
    }

    #[test]
    fn test_sparse_from_indices_unsorted_input() {
        // Input order must not matter: the lists are sorted on construction.
        let vec = SparseVec::from_indices(vec![50, 0, 10], vec![20, 5], 100).unwrap();
        assert_eq!(vec.positive_indices().to_vec(), vec![0usize, 10, 50]);
        assert_eq!(vec.negative_indices().to_vec(), vec![5usize, 20]);
        assert_eq!(vec.get(50), Trit::P);
        assert_eq!(vec.get(5), Trit::N);
    }

    #[test]
    fn test_sparse_from_indices_overlap_error() {
        let pos = vec![0, 10];
        let neg = vec![10, 20]; // 10 is in both - invalid
        let result = SparseVec::from_indices(pos, neg, 100);
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_from_indices_bounds_error() {
        let pos = vec![100]; // Out of bounds for dim=100
        let neg = vec![];
        let result = SparseVec::from_indices(pos, neg, 100);
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_from_indices_negative_bounds_error() {
        // The bounds check must cover the negative list as well.
        let result = SparseVec::from_indices(vec![0], vec![100], 100);
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_sum() {
        let mut vec = SparseVec::new(100);
        vec.set(0, Trit::P);
        vec.set(1, Trit::P);
        vec.set(2, Trit::N);

        assert_eq!(vec.sum(), 1); // 1 + 1 - 1 = 1
    }
}