Skip to main content

uni_sparse_vector/
sparse.rs

1use crate::error::SparseError;
2use serde::{Deserialize, Serialize};
3
4/// A learned-sparse vector: `{term_id -> weight}` over a high-cardinality
5/// vocabulary (e.g. SPLADE-v3 / BGE-M3 sparse head over a ~30k-term BERT vocab).
6///
7/// Stored as two parallel arrays — `indices` (term ids) and `values` (weights) —
8/// with `indices` kept **strictly ascending** so that the canonical scoring
9/// kernel ([`crate::ops::sparse_dot`]) is a linear merge-join. The invariant is
10/// enforced at construction by [`SparseVector::new`].
11///
12/// The in-memory and binary forms keep weights as lossless `f32`. Weight
13/// quantization (8-bit, etc.) is a storage-engine concern applied at the index
14/// postings boundary, never in this type, so a brute-force scorer over
15/// `SparseVector` is always an exact ground truth.
16#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
17pub struct SparseVector {
18    indices: Vec<u32>,
19    values: Vec<f32>,
20}
21
22impl SparseVector {
23    /// Construct a sparse vector, validating its invariants:
24    /// - `indices.len() == values.len()` (SV-1),
25    /// - `indices` strictly ascending — sorted and unique (SV-2),
26    /// - every weight finite — no NaN / ±inf (SV-3).
27    ///
28    /// Use [`SparseVector::from_pairs`] for unsorted input with duplicate
29    /// term ids (the typical embedding-producer shape).
30    pub fn new(indices: Vec<u32>, values: Vec<f32>) -> Result<Self, SparseError> {
31        if indices.len() != values.len() {
32            return Err(SparseError::LengthMismatch {
33                indices: indices.len(),
34                values: values.len(),
35            });
36        }
37        for (prev_pos, pair) in indices.windows(2).enumerate() {
38            let [prev, curr] = [pair[0], pair[1]];
39            if curr <= prev {
40                return Err(SparseError::UnsortedIndices {
41                    // `pair` starts at index `prev_pos`, so the violating
42                    // element `curr` sits at `prev_pos + 1`.
43                    position: prev_pos + 1,
44                    prev,
45                    curr,
46                });
47            }
48        }
49        for (position, &value) in values.iter().enumerate() {
50            if !value.is_finite() {
51                return Err(SparseError::NonFiniteWeight { position, value });
52            }
53        }
54        Ok(Self { indices, values })
55    }
56
57    /// Build from arbitrary `(term_id, weight)` pairs: sorts by term id and
58    /// sums the weights of duplicate term ids, then validates. This is the
59    /// ingestion-friendly constructor for embedding-model output and for
60    /// `dict[int, float]` from the Python surface.
61    ///
62    /// Non-finite input weights are still rejected (after summation).
63    pub fn from_pairs(mut pairs: Vec<(u32, f32)>) -> Result<Self, SparseError> {
64        pairs.sort_by_key(|&(term, _)| term);
65        let mut indices = Vec::with_capacity(pairs.len());
66        let mut values = Vec::with_capacity(pairs.len());
67        for (term, weight) in pairs {
68            if indices.last() == Some(&term) {
69                *values
70                    .last_mut()
71                    .expect("indices non-empty implies values non-empty") += weight;
72            } else {
73                indices.push(term);
74                values.push(weight);
75            }
76        }
77        Self::new(indices, values)
78    }
79
80    /// The (strictly ascending) term ids.
81    #[inline]
82    pub fn indices(&self) -> &[u32] {
83        &self.indices
84    }
85
86    /// The weights, parallel to [`SparseVector::indices`].
87    #[inline]
88    pub fn values(&self) -> &[f32] {
89        &self.values
90    }
91
92    /// Number of non-zero terms.
93    #[inline]
94    pub fn len(&self) -> usize {
95        self.indices.len()
96    }
97
98    /// Whether the vector has no non-zero terms.
99    #[inline]
100    pub fn is_empty(&self) -> bool {
101        self.indices.is_empty()
102    }
103
104    /// Iterate over `(term_id, weight)` pairs in ascending term-id order.
105    pub fn iter(&self) -> impl Iterator<Item = (u32, f32)> + '_ {
106        self.indices
107            .iter()
108            .copied()
109            .zip(self.values.iter().copied())
110    }
111
112    /// Consume the vector into its parallel arrays.
113    pub fn into_parts(self) -> (Vec<u32>, Vec<f32>) {
114        (self.indices, self.values)
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_accepts_sorted_finite() {
        let v = SparseVector::new(vec![1, 5, 9], vec![0.5, -1.0, 2.0]).unwrap();
        assert_eq!(v.len(), 3);
        assert!(!v.is_empty());
    }

    #[test]
    fn new_accepts_empty() {
        let v = SparseVector::new(vec![], vec![]).unwrap();
        assert!(v.is_empty());
        assert_eq!(v.len(), 0);
    }

    #[test]
    fn new_keeps_explicit_zero_weights_as_entries() {
        let v = SparseVector::new(vec![3], vec![0.0]).unwrap();
        assert_eq!(v.len(), 1);
        assert!(!v.is_empty());
    }

    #[test]
    fn new_rejects_length_mismatch() {
        let err = SparseVector::new(vec![1, 2], vec![0.5]).unwrap_err();
        assert!(matches!(
            err,
            SparseError::LengthMismatch { indices: 2, values: 1 }
        ));
    }

    #[test]
    fn new_rejects_unsorted() {
        let err = SparseVector::new(vec![5, 1], vec![0.5, 0.5]).unwrap_err();
        assert!(matches!(
            err,
            SparseError::UnsortedIndices { position: 1, prev: 5, curr: 1 }
        ));
    }

    #[test]
    fn new_reports_first_unsorted_position() {
        // The first violation is the repeated `4` at index 2; the later `2` is
        // never reached.
        let err = SparseVector::new(vec![1, 4, 4, 2], vec![1.0; 4]).unwrap_err();
        assert!(matches!(
            err,
            SparseError::UnsortedIndices { position: 2, prev: 4, curr: 4 }
        ));
    }

    #[test]
    fn new_rejects_duplicate_indices() {
        let err = SparseVector::new(vec![3, 3], vec![0.5, 0.5]).unwrap_err();
        assert!(matches!(
            err,
            SparseError::UnsortedIndices { position: 1, prev: 3, curr: 3 }
        ));
    }

    #[test]
    fn new_rejects_nan_and_inf() {
        assert!(matches!(
            SparseVector::new(vec![1, 2], vec![1.0, f32::NAN]).unwrap_err(),
            SparseError::NonFiniteWeight { position: 1, value } if value.is_nan()
        ));
        assert!(matches!(
            SparseVector::new(vec![1], vec![f32::INFINITY]).unwrap_err(),
            SparseError::NonFiniteWeight { position: 0, value } if value == f32::INFINITY
        ));
        assert!(matches!(
            SparseVector::new(vec![1], vec![f32::NEG_INFINITY]).unwrap_err(),
            SparseError::NonFiniteWeight { position: 0, value } if value == f32::NEG_INFINITY
        ));
    }

    #[test]
    fn from_pairs_sorts_and_sums_duplicates() {
        let v = SparseVector::from_pairs(vec![(9, 1.0), (1, 2.0), (9, 0.5), (1, -0.5)]).unwrap();
        assert_eq!(v.indices(), &[1, 9]);
        assert_eq!(v.values(), &[1.5, 1.5]);
    }

    #[test]
    fn from_pairs_keeps_cancelled_duplicates_as_zero() {
        let v = SparseVector::from_pairs(vec![(4, 1.0), (4, -1.0)]).unwrap();
        assert_eq!(v.indices(), &[4]);
        assert_eq!(v.values(), &[0.0]);
    }

    #[test]
    fn from_pairs_rejects_overflowing_sum() {
        // Both inputs are finite; their sum is not. The position refers to the
        // deduplicated output ([2, 7]), not the input.
        let err = SparseVector::from_pairs(vec![(7, f32::MAX), (2, 1.0), (7, f32::MAX)])
            .unwrap_err();
        assert!(matches!(
            err,
            SparseError::NonFiniteWeight { position: 1, value } if value == f32::INFINITY
        ));
    }

    #[test]
    fn from_pairs_rejects_non_finite_input() {
        let err = SparseVector::from_pairs(vec![(1, f32::NAN)]).unwrap_err();
        assert!(matches!(err, SparseError::NonFiniteWeight { position: 0, .. }));
    }

    #[test]
    fn from_pairs_empty() {
        let v = SparseVector::from_pairs(vec![]).unwrap();
        assert!(v.is_empty());
    }

    #[test]
    fn iter_yields_pairs_in_order() {
        let v = SparseVector::new(vec![2, 4], vec![1.0, 2.0]).unwrap();
        let pairs: Vec<_> = v.iter().collect();
        assert_eq!(pairs, vec![(2, 1.0), (4, 2.0)]);
    }

    #[test]
    fn into_parts_returns_parallel_arrays() {
        let v = SparseVector::new(vec![2, 4], vec![1.0, 2.0]).unwrap();
        assert_eq!(v.into_parts(), (vec![2, 4], vec![1.0, 2.0]));
    }
}