Skip to main content

sqlrite/sql/pager/
hnsw_cell.rs

1//! On-disk format for a single HNSW graph node (Phase 7d.3).
2//!
3//! Each cell carries one node's per-layer neighbor lists. The cells live
4//! on `TableLeaf`-style pages identical to a regular table's data tree —
5//! same slot directory, same sibling `next_page` chain, same interior-
6//! page mechanics from Phase 3d. The only thing different is the per-cell
7//! body, signaled by `KIND_HNSW`.
8//!
9//! Reusing the table-tree shape lets `Cell::peek_rowid` work uniformly
10//! across all cell kinds: it skips `cell_length | kind_tag` and reads the
11//! first varint, which is `node_id` here. So slot-directory binary
12//! search by node_id works without HNSW-specific code in the page-level
13//! plumbing.
14//!
15//! ```text
16//!   cell_length   varint          bytes after this field
17//!   kind_tag      u8 = 0x05       (KIND_HNSW)
18//!   node_id       zigzag varint   the rowid this graph node represents
19//!   max_layer     varint          highest layer this node lives in
20//!   for layer in 0..=max_layer:
21//!     count       varint          number of neighbors at this layer
22//!     for each neighbor:
23//!       neighbor  zigzag varint   neighbor's node_id
24//! ```
25//!
26//! No null bitmap — every field is always present. No type tag — every
27//! field has a fixed type (varint or zigzag varint). The encoding is
28//! deliberately minimal because HNSW indexes can have N nodes each with
29//! up to ~M·log(N) total neighbors, and we don't want the per-cell
30//! overhead to dominate disk usage.
31
32use crate::error::{Result, SQLRiteError};
33use crate::sql::pager::cell::KIND_HNSW;
34use crate::sql::pager::varint;
35
/// Persisted form of a single HNSW graph node.
///
/// `layers[i]` holds the neighbor node_ids of this node at layer `i`, so
/// the node participates in every layer `0..=layers.len() - 1`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HnswNodeCell {
    /// Identifier of the graph node. Written as the first varint after the
    /// kind tag in the encoded cell.
    pub node_id: i64,
    /// Per-layer neighbor lists. `layers[0]` is the densest layer and
    /// `layers.len()` equals the node's `max_layer + 1`. Must be non-empty
    /// for the cell to be encodable; individual layers may be empty.
    pub layers: Vec<Vec<i64>>,
}
45
46impl HnswNodeCell {
47    pub fn new(node_id: i64, layers: Vec<Vec<i64>>) -> Self {
48        Self { node_id, layers }
49    }
50
51    /// Encodes the cell into a freshly-allocated `Vec<u8>`. The result
52    /// starts with the shared `cell_length | kind_tag` prefix and is
53    /// directly usable as a slot-directory entry on a `TableLeaf`-style
54    /// page.
55    pub fn encode(&self) -> Result<Vec<u8>> {
56        if self.layers.is_empty() {
57            return Err(SQLRiteError::Internal(format!(
58                "HNSW node {} has zero layers — every node lives at layer 0 minimum",
59                self.node_id
60            )));
61        }
62
63        // Body capacity guess: 1 (kind) + 10 (node_id) + 5 (max_layer)
64        // + per-layer overhead. Most nodes are layer-0-only so the
65        // typical body is ~1 + 10 + 1 + 1 + M·10 ≈ 175 bytes for M=16.
66        let layer_bytes = self.layers.iter().map(|l| 5 + l.len() * 10).sum::<usize>();
67        let mut body = Vec::with_capacity(1 + 10 + 5 + layer_bytes);
68
69        body.push(KIND_HNSW);
70        varint::write_i64(&mut body, self.node_id);
71        // max_layer = layers.len() - 1
72        varint::write_u64(&mut body, (self.layers.len() - 1) as u64);
73        for layer in &self.layers {
74            varint::write_u64(&mut body, layer.len() as u64);
75            for n in layer {
76                varint::write_i64(&mut body, *n);
77            }
78        }
79
80        let mut out = Vec::with_capacity(body.len() + varint::MAX_VARINT_BYTES);
81        varint::write_u64(&mut out, body.len() as u64);
82        out.extend_from_slice(&body);
83        Ok(out)
84    }
85
86    /// Decodes one cell starting at `pos`. Returns the cell plus the
87    /// total bytes consumed (including the leading length varint).
88    pub fn decode(buf: &[u8], pos: usize) -> Result<(HnswNodeCell, usize)> {
89        let (body_len, len_bytes) = varint::read_u64(buf, pos)?;
90        let body_start = pos + len_bytes;
91        let body_end = body_start
92            .checked_add(body_len as usize)
93            .ok_or_else(|| SQLRiteError::Internal("HNSW cell length overflow".to_string()))?;
94        if body_end > buf.len() {
95            return Err(SQLRiteError::Internal(format!(
96                "HNSW cell extends past buffer: needs {body_start}..{body_end}, have {}",
97                buf.len()
98            )));
99        }
100        let body = &buf[body_start..body_end];
101        if body.first().copied() != Some(KIND_HNSW) {
102            return Err(SQLRiteError::Internal(format!(
103                "HnswNodeCell::decode called on non-HNSW entry (kind_tag = {:#x})",
104                body.first().copied().unwrap_or(0)
105            )));
106        }
107
108        let mut cur = 1usize;
109        let (node_id, n) = varint::read_i64(body, cur)?;
110        cur += n;
111        let (max_layer_u64, n) = varint::read_u64(body, cur)?;
112        cur += n;
113
114        let layer_count = (max_layer_u64 as usize)
115            .checked_add(1)
116            .ok_or_else(|| SQLRiteError::Internal("HNSW max_layer overflow".to_string()))?;
117        // Sanity: max_layer is in practice ≤ ~10 for N ≤ 1B with
118        // m_l ≈ 0.36. A wildly-large value almost certainly means a
119        // corrupt cell — bail before allocating an enormous Vec.
120        if layer_count > 64 {
121            return Err(SQLRiteError::Internal(format!(
122                "HNSW node {node_id} claims max_layer {} (>= 64) — corrupt cell?",
123                layer_count - 1
124            )));
125        }
126
127        let mut layers = Vec::with_capacity(layer_count);
128        for _ in 0..layer_count {
129            let (count, n) = varint::read_u64(body, cur)?;
130            cur += n;
131            // Same sanity bound — a single layer's neighbor list shouldn't
132            // exceed `2 · M_max0` even after pruning bugs. 256 is a
133            // generous cap.
134            if count > 256 {
135                return Err(SQLRiteError::Internal(format!(
136                    "HNSW node {node_id} layer claims {count} neighbors (>256) — corrupt cell?"
137                )));
138            }
139            let mut neighbors = Vec::with_capacity(count as usize);
140            for _ in 0..count {
141                let (id, n) = varint::read_i64(body, cur)?;
142                cur += n;
143                neighbors.push(id);
144            }
145            layers.push(neighbors);
146        }
147
148        if cur != body.len() {
149            return Err(SQLRiteError::Internal(format!(
150                "HNSW cell had {} trailing bytes",
151                body.len() - cur
152            )));
153        }
154
155        Ok((
156            HnswNodeCell { node_id, layers },
157            len_bytes + body_len as usize,
158        ))
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Prepends the `cell_length` varint to a hand-built cell body so the
    /// result has the same framing `encode` produces.
    fn frame(body: &[u8]) -> Vec<u8> {
        let mut out = Vec::new();
        varint::write_u64(&mut out, body.len() as u64);
        out.extend_from_slice(body);
        out
    }

    /// Encodes `cell`, decodes it from offset 0, and checks that the whole
    /// buffer was consumed and the decoded value equals the input.
    fn round_trip(cell: &HnswNodeCell) {
        let bytes = cell.encode().expect("encode");
        let (decoded, consumed) = HnswNodeCell::decode(&bytes, 0).expect("decode");
        assert_eq!(
            consumed,
            bytes.len(),
            "decode should consume the whole cell"
        );
        assert_eq!(&decoded, cell);
    }

    #[test]
    fn single_layer_node_round_trips() {
        // Most common case: a layer-0-only node with a handful of neighbors.
        let cell = HnswNodeCell::new(42, vec![vec![1, 2, 3, 5, 8]]);
        round_trip(&cell);
    }

    #[test]
    fn multi_layer_node_round_trips() {
        let cell = HnswNodeCell::new(
            17,
            vec![
                vec![1, 2, 3, 4, 5, 6, 7, 8], // layer 0 (densest)
                vec![1, 3, 7],                // layer 1
                vec![3],                      // layer 2 (sparsest)
            ],
        );
        round_trip(&cell);
    }

    #[test]
    fn empty_neighbor_layer_round_trips() {
        // A node can have an empty layer (e.g. if its only neighbor was
        // pruned away). The encoding must still survive.
        let cell = HnswNodeCell::new(5, vec![vec![1, 2], vec![]]);
        round_trip(&cell);
    }

    #[test]
    fn node_id_negative_and_large() {
        // node_id is zigzag-encoded; cover both signs.
        round_trip(&HnswNodeCell::new(-1, vec![vec![]]));
        round_trip(&HnswNodeCell::new(i64::MAX, vec![vec![1, 2]]));
        round_trip(&HnswNodeCell::new(i64::MIN, vec![vec![3, 4]]));
    }

    #[test]
    fn zero_layers_is_rejected_at_encode() {
        let bad = HnswNodeCell::new(1, vec![]);
        let err = bad.encode().unwrap_err();
        assert!(format!("{err}").contains("zero layers"));
    }

    #[test]
    fn decode_honors_nonzero_start_offset() {
        // Cells sit at arbitrary offsets inside a page: bytes before `pos`
        // and after the cell's end must not affect the result.
        let cell = HnswNodeCell::new(99, vec![vec![4, 5, 6], vec![5]]);
        let bytes = cell.encode().expect("encode");

        let mut page = vec![0xAA, 0xBB, 0xCC];
        let offset = page.len();
        page.extend_from_slice(&bytes);
        page.push(0xFF);

        let (decoded, consumed) = HnswNodeCell::decode(&page, offset).expect("decode");
        assert_eq!(consumed, bytes.len());
        assert_eq!(decoded, cell);
    }

    #[test]
    fn decode_rejects_wrong_kind_tag() {
        // Build something that looks like a cell with an arbitrary
        // (non-HNSW) tag byte and confirm decode bails.
        let bad = frame(&[0x01]); // KIND_LOCAL, not KIND_HNSW
        let err = HnswNodeCell::decode(&bad, 0).unwrap_err();
        assert!(format!("{err}").contains("non-HNSW entry"));
    }

    #[test]
    fn decode_rejects_empty_body() {
        // A zero-length body has no kind tag at all.
        let bad = frame(&[]);
        let err = HnswNodeCell::decode(&bad, 0).unwrap_err();
        assert!(format!("{err}").contains("non-HNSW entry"));
    }

    #[test]
    fn decode_rejects_truncated_buffer() {
        let cell = HnswNodeCell::new(1, vec![vec![10, 20, 30]]);
        let bytes = cell.encode().expect("encode");
        for chop in 1..=3 {
            let truncated = &bytes[..bytes.len() - chop];
            assert!(
                HnswNodeCell::decode(truncated, 0).is_err(),
                "expected error chopping {chop} byte(s) from end of {} byte cell",
                bytes.len()
            );
        }
    }

    #[test]
    fn decode_rejects_trailing_bytes() {
        // A well-formed single-layer, zero-neighbor node followed by one
        // extra byte that the length prefix claims as part of the body.
        let mut body = Vec::new();
        body.push(KIND_HNSW);
        varint::write_i64(&mut body, 9); // node_id
        varint::write_u64(&mut body, 0); // max_layer = 0 → 1 layer
        varint::write_u64(&mut body, 0); // layer 0: no neighbors
        body.push(0x00); // stray byte
        let err = HnswNodeCell::decode(&frame(&body), 0).unwrap_err();
        assert!(format!("{err}").contains("trailing bytes"));
    }

    #[test]
    fn decode_rejects_implausible_max_layer() {
        // Hand-craft a cell whose max_layer is 100 (above the 64 sanity bound).
        let mut body = Vec::new();
        body.push(KIND_HNSW);
        varint::write_i64(&mut body, 0); // node_id
        varint::write_u64(&mut body, 100); // max_layer = 100 → 101 layers
        let err = HnswNodeCell::decode(&frame(&body), 0).unwrap_err();
        assert!(format!("{err}").to_lowercase().contains("corrupt"));
    }

    #[test]
    fn decode_rejects_implausible_neighbor_count() {
        // Layer 0 claims 257 neighbors, one past the 256 sanity bound.
        let mut body = Vec::new();
        body.push(KIND_HNSW);
        varint::write_i64(&mut body, 0); // node_id
        varint::write_u64(&mut body, 0); // max_layer = 0 → 1 layer
        varint::write_u64(&mut body, 257); // neighbor count
        let err = HnswNodeCell::decode(&frame(&body), 0).unwrap_err();
        assert!(format!("{err}").to_lowercase().contains("corrupt"));
    }
}