1use crate::error::{Result, SQLRiteError};
40use crate::sql::pager::cell::KIND_FTS_POSTING;
41use crate::sql::pager::varint;
42
/// One FTS index cell: either a term's posting list or the doc-lengths sidecar.
///
/// Both kinds share a single layout and are told apart only by `term`: the
/// doc-lengths sidecar is the cell whose term is empty.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FtsPostingCell {
    /// Identifier written first in the encoded body.
    /// NOTE(review): presumably the cell's key within the FTS index — confirm.
    pub cell_id: i64,
    /// The indexed term; empty for the doc-lengths sidecar.
    pub term: String,
    /// `(rowid, value)` pairs, kept in the order they were supplied.
    /// NOTE(review): presumably `value` is a term frequency for postings and a
    /// document length for the sidecar — confirm against the indexer.
    pub entries: Vec<(i64, u32)>,
}
56
57impl FtsPostingCell {
58 pub fn posting(cell_id: i64, term: String, entries: Vec<(i64, u32)>) -> Self {
59 Self {
60 cell_id,
61 term,
62 entries,
63 }
64 }
65
66 pub fn doc_lengths(cell_id: i64, entries: Vec<(i64, u32)>) -> Self {
68 Self {
69 cell_id,
70 term: String::new(),
71 entries,
72 }
73 }
74
75 pub fn encode(&self) -> Result<Vec<u8>> {
80 let pair_bytes = self.entries.len() * 15;
83 let mut body = Vec::with_capacity(1 + 10 + 5 + self.term.len() + 5 + pair_bytes);
84
85 body.push(KIND_FTS_POSTING);
86 varint::write_i64(&mut body, self.cell_id);
87 varint::write_u64(&mut body, self.term.len() as u64);
88 body.extend_from_slice(self.term.as_bytes());
89 varint::write_u64(&mut body, self.entries.len() as u64);
90 for (rowid, value) in &self.entries {
91 varint::write_i64(&mut body, *rowid);
92 varint::write_u64(&mut body, *value as u64);
93 }
94
95 let mut out = Vec::with_capacity(body.len() + varint::MAX_VARINT_BYTES);
96 varint::write_u64(&mut out, body.len() as u64);
97 out.extend_from_slice(&body);
98 Ok(out)
99 }
100
101 pub fn decode(buf: &[u8], pos: usize) -> Result<(FtsPostingCell, usize)> {
104 let (body_len, len_bytes) = varint::read_u64(buf, pos)?;
105 let body_start = pos + len_bytes;
106 let body_end = body_start
107 .checked_add(body_len as usize)
108 .ok_or_else(|| SQLRiteError::Internal("FTS cell length overflow".to_string()))?;
109 if body_end > buf.len() {
110 return Err(SQLRiteError::Internal(format!(
111 "FTS cell extends past buffer: needs {body_start}..{body_end}, have {}",
112 buf.len()
113 )));
114 }
115 let body = &buf[body_start..body_end];
116 if body.first().copied() != Some(KIND_FTS_POSTING) {
117 return Err(SQLRiteError::Internal(format!(
118 "FtsPostingCell::decode called on non-FTS entry (kind_tag = {:#x})",
119 body.first().copied().unwrap_or(0)
120 )));
121 }
122
123 let mut cur = 1usize;
124 let (cell_id, n) = varint::read_i64(body, cur)?;
125 cur += n;
126
127 let (term_len, n) = varint::read_u64(body, cur)?;
128 cur += n;
129 if term_len as usize > body.len().saturating_sub(cur) {
134 return Err(SQLRiteError::Internal(format!(
135 "FTS cell {cell_id}: term_len {term_len} exceeds remaining body \
136 ({}) — corrupt cell?",
137 body.len() - cur
138 )));
139 }
140 let term_bytes = &body[cur..cur + term_len as usize];
141 cur += term_len as usize;
142 let term = std::str::from_utf8(term_bytes)
143 .map_err(|e| {
144 SQLRiteError::Internal(format!("FTS cell {cell_id}: term not valid UTF-8: {e}"))
145 })?
146 .to_string();
147
148 let (count, n) = varint::read_u64(body, cur)?;
149 cur += n;
150 if count > 1 << 28 {
154 return Err(SQLRiteError::Internal(format!(
155 "FTS cell {cell_id}: claims {count} entries (>2^28) — corrupt cell?"
156 )));
157 }
158 let mut entries = Vec::with_capacity(count as usize);
159 for _ in 0..count {
160 let (rowid, n) = varint::read_i64(body, cur)?;
161 cur += n;
162 let (value_u64, n) = varint::read_u64(body, cur)?;
163 cur += n;
164 if value_u64 > u32::MAX as u64 {
168 return Err(SQLRiteError::Internal(format!(
169 "FTS cell {cell_id}: value {value_u64} exceeds u32::MAX — corrupt cell?"
170 )));
171 }
172 entries.push((rowid, value_u64 as u32));
173 }
174
175 if cur != body.len() {
176 return Err(SQLRiteError::Internal(format!(
177 "FTS cell {cell_id} had {} trailing bytes",
178 body.len() - cur
179 )));
180 }
181
182 Ok((
183 FtsPostingCell {
184 cell_id,
185 term,
186 entries,
187 },
188 len_bytes + body_len as usize,
189 ))
190 }
191
192 pub fn is_doc_lengths(&self) -> bool {
194 self.term.is_empty()
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `cell`, decodes the bytes again, and asserts that the decoder
    /// consumed the entire encoding and rebuilt an identical cell.
    fn assert_round_trip(cell: &FtsPostingCell) {
        let encoded = cell.encode().expect("encode");
        let (decoded, used) = FtsPostingCell::decode(&encoded, 0).expect("decode");
        assert_eq!(used, encoded.len(), "decode should consume the whole cell");
        assert_eq!(&decoded, cell);
    }

    /// Wraps a hand-built cell body in its varint body-length prefix.
    fn framed(body: &[u8]) -> Vec<u8> {
        let mut cell = Vec::with_capacity(body.len() + varint::MAX_VARINT_BYTES);
        varint::write_u64(&mut cell, body.len() as u64);
        cell.extend_from_slice(body);
        cell
    }

    /// Hand-builds the start of a cell body (kind tag, cell id, term) so a test
    /// can append a deliberately malformed entry section.
    fn body_prefix(cell_id: i64, term: &[u8]) -> Vec<u8> {
        let mut body = vec![KIND_FTS_POSTING];
        varint::write_i64(&mut body, cell_id);
        varint::write_u64(&mut body, term.len() as u64);
        body.extend_from_slice(term);
        body
    }

    #[test]
    fn posting_cell_round_trips() {
        let postings = vec![(1, 2), (3, 1), (5, 7)];
        assert_round_trip(&FtsPostingCell::posting(7, "rust".to_string(), postings));
    }

    #[test]
    fn doc_lengths_sidecar_round_trips() {
        let sidecar = FtsPostingCell::doc_lengths(1, vec![(1, 12), (2, 20), (3, 0), (4, 7)]);
        assert!(sidecar.is_doc_lengths());
        assert_round_trip(&sidecar);
    }

    #[test]
    fn empty_postings_round_trips() {
        // A term that matches no rows must still survive encode/decode.
        assert_round_trip(&FtsPostingCell::posting(2, "ghost".to_string(), Vec::new()));
    }

    #[test]
    fn negative_and_large_rowids_round_trip() {
        let extremes = vec![(-1, 1), (i64::MAX, 99), (i64::MIN, 1)];
        assert_round_trip(&FtsPostingCell::posting(3, "x".to_string(), extremes));
    }

    #[test]
    fn long_term_round_trips() {
        let long_term = "a".repeat(1024);
        assert_round_trip(&FtsPostingCell::posting(4, long_term, vec![(1, 1)]));
    }

    #[test]
    fn long_posting_list_round_trips() {
        let postings: Vec<(i64, u32)> = (0..5000_u32)
            .map(|k| (i64::from(k), 3 * k + 1))
            .collect();
        assert_round_trip(&FtsPostingCell::posting(5, "common".to_string(), postings));
    }

    #[test]
    fn decode_rejects_wrong_kind_tag() {
        // One-byte body whose kind tag is 0x01 instead of the FTS tag.
        let err = FtsPostingCell::decode(&framed(&[0x01]), 0).unwrap_err();
        assert!(err.to_string().contains("non-FTS entry"));
    }

    #[test]
    fn decode_rejects_truncated_buffer() {
        let full = FtsPostingCell::posting(1, "rust".to_string(), vec![(1, 2), (5, 3)])
            .encode()
            .expect("encode");
        for chop in 1..=3 {
            let short = &full[..full.len() - chop];
            assert!(
                FtsPostingCell::decode(short, 0).is_err(),
                "expected error chopping {chop} byte(s) from end of {} byte cell",
                full.len()
            );
        }
    }

    #[test]
    fn decode_rejects_invalid_utf8_term() {
        let mut body = body_prefix(1, &[0xFF, 0xFE]);
        // Zero entries, so the only problem is the term bytes.
        varint::write_u64(&mut body, 0);
        let err = FtsPostingCell::decode(&framed(&body), 0).unwrap_err();
        assert!(err.to_string().to_lowercase().contains("utf-8"));
    }

    #[test]
    fn decode_rejects_implausible_count() {
        let mut body = body_prefix(1, b"term");
        varint::write_u64(&mut body, 1u64 << 29);
        let err = FtsPostingCell::decode(&framed(&body), 0).unwrap_err();
        assert!(err.to_string().to_lowercase().contains("corrupt"));
    }
}