Skip to main content

uni_sparse_vector/
sparse.rs

1use crate::error::SparseError;
2use serde::{Deserialize, Serialize};
3
4/// A learned-sparse vector: `{term_id -> weight}` over a high-cardinality
5/// vocabulary (e.g. SPLADE-v3 / BGE-M3 sparse head over a ~30k-term BERT vocab).
6///
7/// Stored as two parallel arrays — `indices` (term ids) and `values` (weights) —
8/// with `indices` kept **strictly ascending** so that the canonical scoring
9/// kernel ([`crate::ops::sparse_dot`]) is a linear merge-join. The invariant is
10/// enforced at construction by [`SparseVector::new`].
11///
12/// The in-memory and binary forms keep weights as lossless `f32`. Weight
13/// quantization (8-bit, etc.) is a storage-engine concern applied at the index
14/// postings boundary, never in this type, so a brute-force scorer over
15/// `SparseVector` is always an exact ground truth.
16#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
17pub struct SparseVector {
18    indices: Vec<u32>,
19    values: Vec<f32>,
20}
21
22impl SparseVector {
23    /// Construct a sparse vector, validating its invariants:
24    /// - `indices.len() == values.len()` (SV-1),
25    /// - `indices` strictly ascending — sorted and unique (SV-2),
26    /// - every weight finite — no NaN / ±inf (SV-3).
27    ///
28    /// Use [`SparseVector::from_pairs`] for unsorted input with duplicate
29    /// term ids (the typical embedding-producer shape).
30    pub fn new(indices: Vec<u32>, values: Vec<f32>) -> Result<Self, SparseError> {
31        if indices.len() != values.len() {
32            return Err(SparseError::LengthMismatch {
33                indices: indices.len(),
34                values: values.len(),
35            });
36        }
37        for i in 1..indices.len() {
38            if indices[i] <= indices[i - 1] {
39                return Err(SparseError::UnsortedIndices {
40                    position: i,
41                    prev: indices[i - 1],
42                    curr: indices[i],
43                });
44            }
45        }
46        for (position, &value) in values.iter().enumerate() {
47            if !value.is_finite() {
48                return Err(SparseError::NonFiniteWeight { position, value });
49            }
50        }
51        Ok(Self { indices, values })
52    }
53
54    /// Build from arbitrary `(term_id, weight)` pairs: sorts by term id and
55    /// sums the weights of duplicate term ids, then validates. This is the
56    /// ingestion-friendly constructor for embedding-model output and for
57    /// `dict[int, float]` from the Python surface.
58    ///
59    /// Non-finite input weights are still rejected (after summation).
60    pub fn from_pairs(mut pairs: Vec<(u32, f32)>) -> Result<Self, SparseError> {
61        pairs.sort_by_key(|&(term, _)| term);
62        let mut indices = Vec::with_capacity(pairs.len());
63        let mut values = Vec::with_capacity(pairs.len());
64        for (term, weight) in pairs {
65            if indices.last() == Some(&term) {
66                *values
67                    .last_mut()
68                    .expect("indices non-empty implies values non-empty") += weight;
69            } else {
70                indices.push(term);
71                values.push(weight);
72            }
73        }
74        Self::new(indices, values)
75    }
76
77    /// The (strictly ascending) term ids.
78    #[inline]
79    pub fn indices(&self) -> &[u32] {
80        &self.indices
81    }
82
83    /// The weights, parallel to [`SparseVector::indices`].
84    #[inline]
85    pub fn values(&self) -> &[f32] {
86        &self.values
87    }
88
89    /// Number of non-zero terms.
90    #[inline]
91    pub fn len(&self) -> usize {
92        self.indices.len()
93    }
94
95    /// Whether the vector has no non-zero terms.
96    #[inline]
97    pub fn is_empty(&self) -> bool {
98        self.indices.is_empty()
99    }
100
101    /// Iterate over `(term_id, weight)` pairs in ascending term-id order.
102    pub fn iter(&self) -> impl Iterator<Item = (u32, f32)> + '_ {
103        self.indices
104            .iter()
105            .copied()
106            .zip(self.values.iter().copied())
107    }
108
109    /// Consume the vector into its parallel arrays.
110    pub fn into_parts(self) -> (Vec<u32>, Vec<f32>) {
111        (self.indices, self.values)
112    }
113}
114
#[cfg(test)]
mod tests {
    //! Unit tests for `SparseVector` construction, validation and accessors.
    //!
    //! Error-path tests assert the payload of the returned error (positions,
    //! offending values), not just the variant, so a regression in what gets
    //! reported is caught as well.

    use super::*;

    // ----- SparseVector::new -------------------------------------------------

    #[test]
    fn new_accepts_sorted_finite() {
        let v = SparseVector::new(vec![1, 5, 9], vec![0.5, -1.0, 2.0]).unwrap();
        assert_eq!(v.len(), 3);
        assert!(!v.is_empty());
        assert_eq!(v.indices(), &[1, 5, 9]);
        assert_eq!(v.values(), &[0.5, -1.0, 2.0]);
    }

    #[test]
    fn new_accepts_empty() {
        let v = SparseVector::new(vec![], vec![]).unwrap();
        assert!(v.is_empty());
        assert_eq!(v.len(), 0);
        assert_eq!(v.iter().count(), 0);
    }

    #[test]
    fn new_rejects_length_mismatch() {
        let err = SparseVector::new(vec![1, 2], vec![0.5]).unwrap_err();
        assert!(matches!(
            err,
            SparseError::LengthMismatch {
                indices: 2,
                values: 1
            }
        ));
    }

    #[test]
    fn new_rejects_unsorted() {
        let err = SparseVector::new(vec![5, 1], vec![0.5, 0.5]).unwrap_err();
        assert!(matches!(
            err,
            SparseError::UnsortedIndices {
                position: 1,
                prev: 5,
                curr: 1
            }
        ));
    }

    #[test]
    fn new_reports_first_unsorted_position() {
        // Two violations (4 -> 2 and 2 -> 0); only the first must be reported.
        let err = SparseVector::new(vec![1, 4, 2, 0], vec![0.1, 0.2, 0.3, 0.4]).unwrap_err();
        assert!(matches!(
            err,
            SparseError::UnsortedIndices {
                position: 2,
                prev: 4,
                curr: 2
            }
        ));
    }

    #[test]
    fn new_rejects_duplicate_indices() {
        // Equal neighbours violate *strict* ascent and use the same variant.
        let err = SparseVector::new(vec![3, 3], vec![0.5, 0.5]).unwrap_err();
        assert!(matches!(
            err,
            SparseError::UnsortedIndices {
                position: 1,
                prev: 3,
                curr: 3
            }
        ));
    }

    #[test]
    fn new_rejects_nan_and_inf() {
        assert!(matches!(
            SparseVector::new(vec![1], vec![f32::NAN]).unwrap_err(),
            SparseError::NonFiniteWeight { .. }
        ));
        assert!(matches!(
            SparseVector::new(vec![1], vec![f32::INFINITY]).unwrap_err(),
            SparseError::NonFiniteWeight { .. }
        ));
        assert!(matches!(
            SparseVector::new(vec![1], vec![f32::NEG_INFINITY]).unwrap_err(),
            SparseError::NonFiniteWeight { .. }
        ));
    }

    #[test]
    fn new_reports_non_finite_position() {
        let err = SparseVector::new(vec![1, 2, 3], vec![0.5, f32::NAN, 1.0]).unwrap_err();
        // NaN never compares equal, so only the position is matched here.
        assert!(matches!(
            err,
            SparseError::NonFiniteWeight { position: 1, .. }
        ));
    }

    // ----- SparseVector::from_pairs ------------------------------------------

    #[test]
    fn from_pairs_sorts_and_sums_duplicates() {
        let v = SparseVector::from_pairs(vec![(9, 1.0), (1, 2.0), (9, 0.5), (1, -0.5)]).unwrap();
        assert_eq!(v.indices(), &[1, 9]);
        assert_eq!(v.values(), &[1.5, 1.5]);
    }

    #[test]
    fn from_pairs_empty() {
        let v = SparseVector::from_pairs(vec![]).unwrap();
        assert!(v.is_empty());
    }

    #[test]
    fn from_pairs_rejects_non_finite_input() {
        let err = SparseVector::from_pairs(vec![(2, 1.0), (1, f32::NAN)]).unwrap_err();
        // After sorting, term 1 (the NaN) sits at position 0.
        assert!(matches!(
            err,
            SparseError::NonFiniteWeight { position: 0, .. }
        ));
    }

    #[test]
    fn from_pairs_rejects_non_finite_sum() {
        // Each input is finite, but the f32 sum overflows to +inf.
        let overflow = SparseVector::from_pairs(vec![(4, f32::MAX), (4, f32::MAX)]).unwrap_err();
        assert!(matches!(
            overflow,
            SparseError::NonFiniteWeight { position: 0, .. }
        ));

        // +inf + -inf collapses to NaN, which must be rejected as well.
        let cancelled =
            SparseVector::from_pairs(vec![(4, f32::INFINITY), (4, f32::NEG_INFINITY)])
                .unwrap_err();
        assert!(matches!(
            cancelled,
            SparseError::NonFiniteWeight { position: 0, .. }
        ));
    }

    // ----- accessors ---------------------------------------------------------

    #[test]
    fn iter_yields_pairs_in_order() {
        let v = SparseVector::new(vec![2, 4], vec![1.0, 2.0]).unwrap();
        let pairs: Vec<_> = v.iter().collect();
        assert_eq!(pairs, vec![(2, 1.0), (4, 2.0)]);
    }

    #[test]
    fn into_parts_returns_parallel_arrays() {
        let (indices, values) = SparseVector::new(vec![2, 4], vec![1.0, 2.0])
            .unwrap()
            .into_parts();
        assert_eq!(indices, vec![2, 4]);
        assert_eq!(values, vec![1.0, 2.0]);
    }
}
182}