Skip to main content

uni_sparse_vector/
encode.rs

1use crate::error::SparseError;
2use crate::sparse::SparseVector;
3
4/// Canonical binary form of a [`SparseVector`].
5///
6/// Layout (little-endian, variable length):
7/// ```text
8/// [count: u32] [indices: count × u32] [values: count × f32]
9/// ```
10/// Indices and values are stored in separate runs (not interleaved) to match
11/// the Arrow `Struct{indices: List<UInt32>, values: List<Float32>}` lowering
12/// and to keep each run contiguous for scoring. Weights are lossless `f32`;
13/// quantization is applied by the storage engine at the postings boundary, not
14/// here.
15pub fn encode(sv: &SparseVector) -> Vec<u8> {
16    let n = sv.len();
17    let mut buf = Vec::with_capacity(4 + n * 8);
18    buf.extend_from_slice(&(n as u32).to_le_bytes());
19    for &idx in sv.indices() {
20        buf.extend_from_slice(&idx.to_le_bytes());
21    }
22    for &val in sv.values() {
23        buf.extend_from_slice(&val.to_le_bytes());
24    }
25    buf
26}
27
28/// Decode a [`SparseVector`] from its canonical binary form, re-validating all
29/// invariants (so a corrupted or hand-built buffer cannot smuggle in unsorted
30/// indices or NaN weights).
31pub fn decode_slice(bytes: &[u8]) -> Result<SparseVector, SparseError> {
32    if bytes.len() < 4 {
33        return Err(SparseError::Truncated {
34            need: 4,
35            got: bytes.len(),
36        });
37    }
38    let count = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
39    let need = 4 + count * 8;
40    if bytes.len() < need {
41        return Err(SparseError::Truncated {
42            need,
43            got: bytes.len(),
44        });
45    }
46    if bytes.len() > need {
47        return Err(SparseError::TrailingBytes {
48            trailing: bytes.len() - need,
49        });
50    }
51
52    // The payload is two contiguous runs of `count` little-endian 4-byte words:
53    // the indices first, then the values. `need` was checked above, so each run
54    // splits cleanly into 4-byte chunks.
55    let (index_bytes, value_bytes) = bytes[4..need].split_at(count * 4);
56    let indices = index_bytes
57        .chunks_exact(4)
58        .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
59        .collect();
60    let values = value_bytes
61        .chunks_exact(4)
62        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
63        .collect();
64
65    SparseVector::new(indices, values)
66}
67
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roundtrip_basic() {
        let original = SparseVector::new(vec![1, 7, 42], vec![0.25, -1.5, 3.0]).unwrap();
        let bytes = encode(&original);
        let decoded = decode_slice(&bytes).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn roundtrip_empty() {
        let original = SparseVector::new(vec![], vec![]).unwrap();
        let bytes = encode(&original);
        assert_eq!(bytes.len(), 4);
        assert_eq!(decode_slice(&bytes).unwrap(), original);
    }

    #[test]
    fn encode_layout_matches_spec() {
        // [count] [indices...] [values...], all little-endian, not interleaved.
        let sv = SparseVector::new(vec![3, 9], vec![0.5, -2.0]).unwrap();
        let mut expected = Vec::new();
        expected.extend_from_slice(&2u32.to_le_bytes());
        expected.extend_from_slice(&3u32.to_le_bytes());
        expected.extend_from_slice(&9u32.to_le_bytes());
        expected.extend_from_slice(&0.5f32.to_le_bytes());
        expected.extend_from_slice(&(-2.0f32).to_le_bytes());
        assert_eq!(encode(&sv), expected);
    }

    #[test]
    fn truncated_header_rejected() {
        assert!(matches!(
            decode_slice(&[0u8; 3]).unwrap_err(),
            SparseError::Truncated { .. }
        ));
    }

    #[test]
    fn truncated_payload_rejected() {
        let original = SparseVector::new(vec![1, 2], vec![1.0, 2.0]).unwrap();
        let mut bytes = encode(&original);
        bytes.truncate(bytes.len() - 1);
        assert!(matches!(
            decode_slice(&bytes).unwrap_err(),
            SparseError::Truncated { .. }
        ));
    }

    #[test]
    fn oversized_count_rejected_without_panic() {
        // 0x2000_0000 * 8 == 2^32, which overflows `usize` on 32-bit targets.
        // A bare header claiming that many entries must be reported as
        // truncated, never panic.
        let bytes = 0x2000_0000u32.to_le_bytes();
        assert!(matches!(
            decode_slice(&bytes).unwrap_err(),
            SparseError::Truncated { .. }
        ));
    }

    #[test]
    fn max_count_header_rejected() {
        let bytes = u32::MAX.to_le_bytes();
        assert!(matches!(
            decode_slice(&bytes).unwrap_err(),
            SparseError::Truncated { .. }
        ));
    }

    #[test]
    fn trailing_bytes_rejected() {
        let original = SparseVector::new(vec![1], vec![1.0]).unwrap();
        let mut bytes = encode(&original);
        bytes.push(0xFF);
        assert!(matches!(
            decode_slice(&bytes).unwrap_err(),
            SparseError::TrailingBytes { .. }
        ));
    }

    #[test]
    fn decode_revalidates_invariants() {
        // Hand-build a buffer with unsorted indices [5, 1]; decode must reject.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&2u32.to_le_bytes());
        bytes.extend_from_slice(&5u32.to_le_bytes());
        bytes.extend_from_slice(&1u32.to_le_bytes());
        bytes.extend_from_slice(&1.0f32.to_le_bytes());
        bytes.extend_from_slice(&2.0f32.to_le_bytes());
        assert!(matches!(
            decode_slice(&bytes).unwrap_err(),
            SparseError::UnsortedIndices { .. }
        ));
    }

    #[test]
    fn decode_rejects_nan_weight() {
        // Hand-build a one-entry buffer whose weight is NaN; decode must reject.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&1u32.to_le_bytes());
        bytes.extend_from_slice(&3u32.to_le_bytes());
        bytes.extend_from_slice(&f32::NAN.to_le_bytes());
        assert!(decode_slice(&bytes).is_err());
    }
}