sqlrite/sql/pager/
hnsw_cell.rs1use crate::error::{Result, SQLRiteError};
33use crate::sql::pager::cell::KIND_HNSW;
34use crate::sql::pager::varint;
35
/// One HNSW graph node as stored in a pager cell: the node's id plus the
/// neighbor ids it links to at every layer it participates in.
#[derive(Debug, Clone, PartialEq)]
pub struct HnswNodeCell {
    /// Logical id of the vector/row this graph node refers to.
    pub node_id: i64,
    /// Neighbor ids per layer; index 0 is the base layer. `encode` rejects
    /// an empty layer list — every node lives at layer 0 minimum.
    pub layers: Vec<Vec<i64>>,
}
45
46impl HnswNodeCell {
47 pub fn new(node_id: i64, layers: Vec<Vec<i64>>) -> Self {
48 Self { node_id, layers }
49 }
50
51 pub fn encode(&self) -> Result<Vec<u8>> {
56 if self.layers.is_empty() {
57 return Err(SQLRiteError::Internal(format!(
58 "HNSW node {} has zero layers — every node lives at layer 0 minimum",
59 self.node_id
60 )));
61 }
62
63 let layer_bytes = self.layers.iter().map(|l| 5 + l.len() * 10).sum::<usize>();
67 let mut body = Vec::with_capacity(1 + 10 + 5 + layer_bytes);
68
69 body.push(KIND_HNSW);
70 varint::write_i64(&mut body, self.node_id);
71 varint::write_u64(&mut body, (self.layers.len() - 1) as u64);
73 for layer in &self.layers {
74 varint::write_u64(&mut body, layer.len() as u64);
75 for n in layer {
76 varint::write_i64(&mut body, *n);
77 }
78 }
79
80 let mut out = Vec::with_capacity(body.len() + varint::MAX_VARINT_BYTES);
81 varint::write_u64(&mut out, body.len() as u64);
82 out.extend_from_slice(&body);
83 Ok(out)
84 }
85
86 pub fn decode(buf: &[u8], pos: usize) -> Result<(HnswNodeCell, usize)> {
89 let (body_len, len_bytes) = varint::read_u64(buf, pos)?;
90 let body_start = pos + len_bytes;
91 let body_end = body_start
92 .checked_add(body_len as usize)
93 .ok_or_else(|| SQLRiteError::Internal("HNSW cell length overflow".to_string()))?;
94 if body_end > buf.len() {
95 return Err(SQLRiteError::Internal(format!(
96 "HNSW cell extends past buffer: needs {body_start}..{body_end}, have {}",
97 buf.len()
98 )));
99 }
100 let body = &buf[body_start..body_end];
101 if body.first().copied() != Some(KIND_HNSW) {
102 return Err(SQLRiteError::Internal(format!(
103 "HnswNodeCell::decode called on non-HNSW entry (kind_tag = {:#x})",
104 body.first().copied().unwrap_or(0)
105 )));
106 }
107
108 let mut cur = 1usize;
109 let (node_id, n) = varint::read_i64(body, cur)?;
110 cur += n;
111 let (max_layer_u64, n) = varint::read_u64(body, cur)?;
112 cur += n;
113
114 let layer_count = (max_layer_u64 as usize)
115 .checked_add(1)
116 .ok_or_else(|| SQLRiteError::Internal("HNSW max_layer overflow".to_string()))?;
117 if layer_count > 64 {
121 return Err(SQLRiteError::Internal(format!(
122 "HNSW node {node_id} claims max_layer {} (>= 64) — corrupt cell?",
123 layer_count - 1
124 )));
125 }
126
127 let mut layers = Vec::with_capacity(layer_count);
128 for _ in 0..layer_count {
129 let (count, n) = varint::read_u64(body, cur)?;
130 cur += n;
131 if count > 256 {
135 return Err(SQLRiteError::Internal(format!(
136 "HNSW node {node_id} layer claims {count} neighbors (>256) — corrupt cell?"
137 )));
138 }
139 let mut neighbors = Vec::with_capacity(count as usize);
140 for _ in 0..count {
141 let (id, n) = varint::read_i64(body, cur)?;
142 cur += n;
143 neighbors.push(id);
144 }
145 layers.push(neighbors);
146 }
147
148 if cur != body.len() {
149 return Err(SQLRiteError::Internal(format!(
150 "HNSW cell had {} trailing bytes",
151 body.len() - cur
152 )));
153 }
154
155 Ok((
156 HnswNodeCell { node_id, layers },
157 len_bytes + body_len as usize,
158 ))
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `cell`, decodes the bytes back, and verifies both the decoded
    /// value and that decode reports consuming the entire buffer.
    fn round_trip(cell: &HnswNodeCell) {
        let encoded = cell.encode().expect("encode");
        let (decoded, used) = HnswNodeCell::decode(&encoded, 0).expect("decode");
        assert_eq!(used, encoded.len(), "decode should consume the whole cell");
        assert_eq!(&decoded, cell);
    }

    #[test]
    fn single_layer_node_round_trips() {
        round_trip(&HnswNodeCell::new(42, vec![vec![1, 2, 3, 5, 8]]));
    }

    #[test]
    fn multi_layer_node_round_trips() {
        let layers = vec![vec![1, 2, 3, 4, 5, 6, 7, 8], vec![1, 3, 7], vec![3]];
        round_trip(&HnswNodeCell::new(17, layers));
    }

    #[test]
    fn empty_neighbor_layer_round_trips() {
        round_trip(&HnswNodeCell::new(5, vec![vec![1, 2], vec![]]));
    }

    #[test]
    fn node_id_negative_and_large() {
        for cell in [
            HnswNodeCell::new(-1, vec![vec![]]),
            HnswNodeCell::new(i64::MAX, vec![vec![1, 2]]),
            HnswNodeCell::new(i64::MIN, vec![vec![3, 4]]),
        ] {
            round_trip(&cell);
        }
    }

    #[test]
    fn zero_layers_is_rejected_at_encode() {
        let err = HnswNodeCell::new(1, vec![]).encode().unwrap_err();
        assert!(err.to_string().contains("zero layers"));
    }

    #[test]
    fn decode_rejects_wrong_kind_tag() {
        // One-byte body whose kind tag (0x01) is not KIND_HNSW.
        let mut bytes = Vec::new();
        varint::write_u64(&mut bytes, 1);
        bytes.push(0x01);
        let err = HnswNodeCell::decode(&bytes, 0).unwrap_err();
        assert!(err.to_string().contains("non-HNSW entry"));
    }

    #[test]
    fn decode_rejects_truncated_buffer() {
        let encoded = HnswNodeCell::new(1, vec![vec![10, 20, 30]])
            .encode()
            .expect("encode");
        for chop in 1..=3 {
            let shortened = &encoded[..encoded.len() - chop];
            assert!(
                HnswNodeCell::decode(shortened, 0).is_err(),
                "expected error chopping {chop} byte(s) from end of {} byte cell",
                encoded.len()
            );
        }
    }

    #[test]
    fn decode_rejects_implausible_max_layer() {
        // Hand-build a framed cell whose header claims max_layer = 100.
        let mut body = Vec::new();
        body.push(KIND_HNSW);
        varint::write_i64(&mut body, 0);
        varint::write_u64(&mut body, 100);
        let mut framed = Vec::new();
        varint::write_u64(&mut framed, body.len() as u64);
        framed.extend_from_slice(&body);
        let err = HnswNodeCell::decode(&framed, 0).unwrap_err();
        assert!(err.to_string().to_lowercase().contains("corrupt"));
    }
}