Skip to main content

uni_sparse_vector/
encode.rs

1use crate::error::SparseError;
2use crate::sparse::SparseVector;
3
4/// Canonical binary form of a [`SparseVector`].
5///
6/// Layout (little-endian, variable length):
7/// ```text
8/// [count: u32] [indices: count × u32] [values: count × f32]
9/// ```
10/// Indices and values are stored in separate runs (not interleaved) to match
11/// the Arrow `Struct{indices: List<UInt32>, values: List<Float32>}` lowering
12/// and to keep each run contiguous for scoring. Weights are lossless `f32`;
13/// quantization is applied by the storage engine at the postings boundary, not
14/// here.
15pub fn encode(sv: &SparseVector) -> Vec<u8> {
16    let n = sv.len();
17    let mut buf = Vec::with_capacity(4 + n * 8);
18    buf.extend_from_slice(&(n as u32).to_le_bytes());
19    for &idx in sv.indices() {
20        buf.extend_from_slice(&idx.to_le_bytes());
21    }
22    for &val in sv.values() {
23        buf.extend_from_slice(&val.to_le_bytes());
24    }
25    buf
26}
27
28/// Decode a [`SparseVector`] from its canonical binary form, re-validating all
29/// invariants (so a corrupted or hand-built buffer cannot smuggle in unsorted
30/// indices or NaN weights).
31pub fn decode_slice(bytes: &[u8]) -> Result<SparseVector, SparseError> {
32    if bytes.len() < 4 {
33        return Err(SparseError::Truncated {
34            need: 4,
35            got: bytes.len(),
36        });
37    }
38    let count = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
39    let need = 4 + count * 8;
40    if bytes.len() < need {
41        return Err(SparseError::Truncated {
42            need,
43            got: bytes.len(),
44        });
45    }
46    if bytes.len() > need {
47        return Err(SparseError::TrailingBytes {
48            trailing: bytes.len() - need,
49        });
50    }
51
52    let mut indices = Vec::with_capacity(count);
53    let mut off = 4;
54    for _ in 0..count {
55        indices.push(u32::from_le_bytes([
56            bytes[off],
57            bytes[off + 1],
58            bytes[off + 2],
59            bytes[off + 3],
60        ]));
61        off += 4;
62    }
63    let mut values = Vec::with_capacity(count);
64    for _ in 0..count {
65        values.push(f32::from_le_bytes([
66            bytes[off],
67            bytes[off + 1],
68            bytes[off + 2],
69            bytes[off + 3],
70        ]));
71        off += 4;
72    }
73
74    SparseVector::new(indices, values)
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a vector from inputs the test expects the constructor to accept.
    fn valid(indices: Vec<u32>, values: Vec<f32>) -> SparseVector {
        SparseVector::new(indices, values).unwrap()
    }

    #[test]
    fn roundtrip_basic() {
        let original = valid(vec![1, 7, 42], vec![0.25, -1.5, 3.0]);
        let decoded = decode_slice(&encode(&original)).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn roundtrip_empty() {
        let original = valid(vec![], vec![]);
        let bytes = encode(&original);
        // An empty vector encodes to the count header alone.
        assert_eq!(bytes.len(), 4);
        let decoded = decode_slice(&bytes).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn truncated_header_rejected() {
        // Three bytes cannot hold the 4-byte count header.
        let err = decode_slice(&[0u8; 3]).unwrap_err();
        assert!(matches!(err, SparseError::Truncated { .. }));
    }

    #[test]
    fn truncated_payload_rejected() {
        let bytes = encode(&valid(vec![1, 2], vec![1.0, 2.0]));
        // Drop the final byte so the payload is one byte short.
        let short = &bytes[..bytes.len() - 1];
        let err = decode_slice(short).unwrap_err();
        assert!(matches!(err, SparseError::Truncated { .. }));
    }

    #[test]
    fn trailing_bytes_rejected() {
        let mut bytes = encode(&valid(vec![1], vec![1.0]));
        bytes.push(0xFF);
        let err = decode_slice(&bytes).unwrap_err();
        assert!(matches!(err, SparseError::TrailingBytes { .. }));
    }

    #[test]
    fn decode_revalidates_invariants() {
        // Hand-build a buffer with unsorted indices [5, 1]; decode must reject.
        let bytes: Vec<u8> = [
            2u32.to_le_bytes(),
            5u32.to_le_bytes(),
            1u32.to_le_bytes(),
            1.0f32.to_le_bytes(),
            2.0f32.to_le_bytes(),
        ]
        .concat();
        let err = decode_slice(&bytes).unwrap_err();
        assert!(matches!(err, SparseError::UnsortedIndices { .. }));
    }
}