spacetimedb_table/static_bsatn_validator.rs
1//! To efficiently implement a fast-path BSATN -> BFLATN,
2//! we use a `StaticLayout` but in reverse of the read path.
3//! This however leaves us with no way to validate
4//! that the BSATN satisfies the row type of a given table.
5//!
6//! More specifically, we must validate that:
7//! 1. The length of the BSATN-encoded row matches the expected length.
8//! 2. All `bool`s in the row type only receive the values 0 or 1.
9//! 3. All sum tags are valid.
//! 4. A sum's payload satisfies 2-3 recursively.
11//!
12//! That is where this module comes in,
13//! which provides two functions:
14//! - [`static_bsatn_validator`], which compiles a validator program given the table's row type.
15//! - [`validate_bsatn`], which executes the validator program against a row encoded in BSATN.
16//!
17//! The compilation uses the same strategy as for row type visitors,
18//! first simplifying to a rose-tree and then flattening that to
19//! a simple forward-progress-only byte code instruction set.
20
21#![allow(unused)]
22
23use super::{
24 layout::{AlgebraicTypeLayout, HasLayout as _, ProductTypeLayout, RowTypeLayout},
25 static_layout::StaticLayout,
26 MemoryUsage,
27};
28use itertools::{repeat_n, Itertools as _};
29use spacetimedb_sats::bsatn::DecodeError;
30use spacetimedb_schema::type_for_generate::PrimitiveType;
31use std::sync::Arc;
32
33/// Constructs a validator for a row encoded in BSATN
34/// that checks that the row satisfies the type `ty`
35/// when `ty` has `StaticLayout`.
36///
37/// This is a potentially expensive operation,
38/// so the resulting `StaticBsatnValidator` should be stored and re-used.
39pub(crate) fn static_bsatn_validator(ty: &RowTypeLayout) -> StaticBsatnValidator {
40 let tree = row_type_to_tree(ty.product());
41 let insns = tree_to_insns(&tree).into();
42 StaticBsatnValidator { insns }
43}
44
45/// Construct a `Tree` from `ty`.
46///
47/// See [`extend_trees_for_algebraic_type`] for more details.
48fn row_type_to_tree(ty: &ProductTypeLayout) -> Tree {
49 let mut sub_trees = Vec::new();
50 extend_trees_for_product_type(ty, &mut 0, &mut sub_trees);
51 sub_trees_to_tree(sub_trees)
52}
53
54/// Convert a list of `sub_trees` to one tree.
55fn sub_trees_to_tree(mut sub_trees: Vec<Tree>) -> Tree {
56 match sub_trees.len() {
57 // No trees is `Empty`.
58 0 => Tree::Empty,
59 // A single subtree can be collapsed.
60 // so prune the intermediate node.
61 1 => sub_trees.pop().unwrap(),
62 // For more than one children, sequence them doing one after the other.
63 _ => Tree::Sequence { sub_trees },
64 }
65}
66
67/// Extend `sub_trees` with checks for `ty`.
68///
69/// See [`extend_trees_for_algebraic_type`] for more details.
70fn extend_trees_for_product_type(ty: &ProductTypeLayout, current_offset: &mut usize, sub_trees: &mut Vec<Tree>) {
71 for elem in &*ty.elements {
72 extend_trees_for_algebraic_type(&elem.ty, current_offset, sub_trees);
73 }
74}
75
76/// Extend `sub_trees` with checks for `ty`.
77///
78/// `current_offset` should be passed as `&mut 0` upon entry to the row-type,
79/// and will be incremented as appropriate during recursive traversal
80/// to track the offset in bytes of the member currently being visited.
81fn extend_trees_for_algebraic_type(ty: &AlgebraicTypeLayout, current_offset: &mut usize, sub_trees: &mut Vec<Tree>) {
82 match ty {
83 AlgebraicTypeLayout::Primitive(PrimitiveType::Bool) => {
84 // The `Bool` type is special, as it only allows a BSATN byte to be 0 or 1.
85 let offset = *current_offset as u16;
86 *current_offset += 1;
87 sub_trees.push(Tree::CheckBool { offset });
88 }
89 AlgebraicTypeLayout::Primitive(prim_ty) => {
90 // For primitive types, increment `current_offset` past this member.
91 // Primitive types have no padding, so we can use `prim_ty.size()` for bsatn.
92 *current_offset += prim_ty.size();
93 }
94 AlgebraicTypeLayout::Product(prod_ty) => extend_trees_for_product_type(prod_ty, current_offset, sub_trees),
95 AlgebraicTypeLayout::Sum(sum_ty) => {
96 // Record the tag's offset and the number of variants.
97 let num_variants = sum_ty.variants.len() as u8;
98 let tag_offset = *current_offset as u16;
99 *current_offset += 1;
100
101 // For each variant, collect that variant's sub-tree.
102 // All variants are stored overlapping at the offset of the sum
103 // so we must reset `current_offset` each time to the before-variant value.
104 // We also need to create a fresh `sub_tree` context.
105 // Note that BSATN stores sums with tag first,
106 // followed by data/payload.
107 let mut child_offset = *current_offset;
108 let mut variants = sum_ty
109 .variants
110 .iter()
111 .map(|variant| {
112 let var_ty = &variant.ty;
113 let mut sub_trees = Vec::new();
114 child_offset = *current_offset;
115 extend_trees_for_algebraic_type(var_ty, &mut child_offset, &mut sub_trees);
116 sub_trees_to_tree(sub_trees)
117 })
118 .collect::<Vec<_>>();
119 // Having dealt with all variants,
120 // we must now move `current_offset` forward to the size of the payload
121 // which we know to be same for all variants.
122 *current_offset = child_offset;
123
124 if variants.iter().all_equal() {
125 // When all variants have the same set checks,
126 // there's no need to switch on the tag, so prune the intermediate node.
127 // A special case of this is single-variant sums.
128 sub_trees.push(Tree::CheckTag {
129 tag_offset,
130 num_variants,
131 });
132 if let Some(tree) = variants.pop() {
133 sub_trees.push(tree);
134 }
135 } else {
136 sub_trees.push(Tree::Sum {
137 tag_offset,
138 tag_data_processors: variants,
139 });
140 }
141 }
142
143 // There are no var-len members when there's a static fixed bsatn length.
144 AlgebraicTypeLayout::VarLen(_) => unreachable!(),
145 }
146}
147
/// A [Rose Tree](https://en.wikipedia.org/wiki/Rose_tree) of validation steps
/// for BSATN-encoded values of statically known, fixed-size `AlgebraicType`s.
///
/// All offsets are in bytes, relative to the start of the BSATN-encoded row.
#[derive(Debug, PartialEq, Eq)]
enum Tree {
    /// No checks are required.
    Empty,

    /// Require the byte at `row + offset` to be a valid `bool`, i.e., 0 or 1.
    CheckBool { offset: u16 },

    /// Require the sum tag at `row + tag_offset` to be `< num_variants`,
    /// without branching on its value.
    CheckTag {
        /// Where the sum's tag is stored, relative to the start of the row.
        tag_offset: u16,
        /// How many variants the sum has; a valid tag is `< num_variants`.
        num_variants: u8,
    },

    /// Run every sub-tree, in order, one after the other.
    Sequence { sub_trees: Vec<Tree> },

    /// Branch on the sum tag, running only the checks of the selected variant.
    Sum {
        /// Where the sum's tag is stored, relative to the start of the row.
        tag_offset: u16,
        /// `tag_data_processors[N]` holds the checks for variant `N`.
        /// The number of variants is `tag_data_processors.len()`.
        tag_data_processors: Vec<Tree>,
    },
}
179
180/// Compile the [`Tree`] to a list of [`Insn`].
181fn tree_to_insns(tree: &Tree) -> Vec<Insn> {
182 let mut program = Vec::new();
183
184 fn compile_tree(tree: &Tree, into: &mut Vec<Insn>) {
185 match tree {
186 Tree::Empty => {}
187 &Tree::CheckBool { offset } => into.push(Insn::CheckBool(offset)),
188 Tree::Sequence { sub_trees } => {
189 for tree in &**sub_trees {
190 compile_tree(tree, into);
191 }
192 }
193 &Tree::CheckTag {
194 tag_offset,
195 num_variants,
196 } => into.push(Insn::CheckTag(CheckTag {
197 tag_offset,
198 num_variants,
199 })),
200 Tree::Sum {
201 tag_offset,
202 tag_data_processors,
203 } => {
204 // Add the branching instruction itself.
205 let num_variants = tag_data_processors.len();
206 into.push(Insn::CheckReadTagRelBranch(CheckTag {
207 tag_offset: *tag_offset,
208 num_variants: num_variants as u8,
209 }));
210 // Add N slots for "to variant goto"s.
211 let to_branches = into.len();
212 into.extend(repeat_n(Insn::FIXUP, num_variants));
213 // Compile the branches.
214 let mut from_variant_gotos = Vec::with_capacity(num_variants);
215 for (tag, branch) in tag_data_processors.iter().enumerate() {
216 // Fixup the to-variant jump address.
217 into[to_branches + tag] = Insn::Goto(into.len() as u16);
218 // Compile the branch.
219 compile_tree(branch, into);
220 // Add jump-out gotos that we'll fixup later to store the after-sum address.
221 from_variant_gotos.push(into.len());
222 into.push(Insn::FIXUP);
223 }
224 // Fixup the jump-out-from-variant addresses.
225 let goto_addr = into.len();
226 for idx in from_variant_gotos {
227 into[idx] = Insn::Goto(goto_addr as u16);
228 }
229 }
230 }
231 }
232
233 compile_tree(tree, &mut program);
234 remove_trailing_gotos(&mut program);
235 program
236}
237
238/// Remove any trailing gotos.
239///
240/// They are not needed as they will only go towards the end,
241/// so we can just cut them out.
242fn remove_trailing_gotos(program: &mut Vec<Insn>) {
243 for idx in (0..program.len()).rev() {
244 match program[idx] {
245 Insn::Goto(_) => program.pop(),
246 _ => break,
247 };
248 }
249}
250
/// Operands shared by the tag-checking instructions,
/// [`Insn::CheckTag`] and [`Insn::CheckReadTagRelBranch`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct CheckTag {
    /// Offset in bytes, from `start`, of the tag to check.
    tag_offset: u16,
    /// How many variants the sum has; a tag is valid iff it is `< num_variants`.
    num_variants: u8,
}
259
260/// The instruction set of a [`StaticBsatnValidator`].
261#[derive(Debug, Clone, Copy, PartialEq, Eq)]
262enum Insn {
263 /// Visit the byte at offset `start + N`
264 /// and assert that it is 0 or 1, i.e., a valid `bool`.
265 CheckBool(u16),
266
267 /// Read the `tag` at `start + tag_offset`
268 /// and validate that `tag < num_variants`.
269 CheckTag(CheckTag),
270
271 /// Read the `tag` at `start + tag_offset`
272 /// and validate that `tag < num_variants`.
273 /// Then move the instruction pointer forward by `tag + 1`.
274 /// The branch logic for the variant payload continues there.
275 CheckReadTagRelBranch(CheckTag),
276
277 /// Unconditionally branch to the instruction at `program[N]`
278 /// where `N > instruction pointer`.
279 Goto(u16),
280}
281
282impl Insn {
283 const FIXUP: Self = Self::Goto(u16::MAX);
284}
285
286impl MemoryUsage for Insn {}
287
288#[derive(Clone, Debug, PartialEq, Eq)]
289pub struct StaticBsatnValidator {
290 /// The list of instructions that make up this program.
291 insns: Arc<[Insn]>,
292}
293
294impl MemoryUsage for StaticBsatnValidator {
295 fn heap_usage(&self) -> usize {
296 let Self { insns } = self;
297 insns.heap_usage()
298 }
299}
300
301/// Check that `bytes[tag_offset] < num_variants`.
302///
303/// SAFETY: `tag_offset < bytes.len()`.
304unsafe fn check_tag(bytes: &[u8], check: CheckTag) -> Result<u8, DecodeError> {
305 // SAFETY: the caller has guaranteed that `tag_offset < bytes.len()`.
306 let tag = *unsafe { bytes.get_unchecked(check.tag_offset as usize) };
307 if tag < check.num_variants {
308 Ok(tag)
309 } else {
310 Err(DecodeError::InvalidTag { tag, sum_name: None })
311 }
312}
313
314/// Validates that `bytes`, encoded in BSATN,
315/// is valid according to the validation `program`
316/// and a corresponding `static_layout`,
317///
318/// # Safety
319///
320/// The caller must guarantee that
321/// all offsets in `program` are `< static_layout.bsatn_length`.
322pub(crate) unsafe fn validate_bsatn(
323 program: &StaticBsatnValidator,
324 static_layout: &StaticLayout,
325 bytes: &[u8],
326) -> Result<(), DecodeError> {
327 // Validate length of BSATN `bytes` against the expected length.
328 let expected = static_layout.bsatn_length as usize;
329 let given = bytes.len();
330 if expected != given {
331 return Err(DecodeError::InvalidLen { expected, given });
332 }
333
334 let program = &*program.insns;
335 let mut instr_ptr = 0;
336 loop {
337 match program.get(instr_ptr as usize).copied() {
338 None => break,
339 Some(Insn::CheckBool(offset)) => {
340 instr_ptr += 1;
341 // SAFETY: the caller has guaranteed
342 // that all offsets in `program` are `< expected`
343 // which we by now know is `= bytes.len()`.
344 let byte = *unsafe { bytes.get_unchecked(offset as usize) };
345 if byte > 1 {
346 return Err(DecodeError::InvalidBool(byte));
347 }
348 }
349 Some(Insn::Goto(new_insn)) => instr_ptr = new_insn,
350 Some(Insn::CheckTag(check)) => {
351 // SAFETY: the caller has guaranteed
352 // that all offsets in `program` are `< expected`
353 // which we by now know is `= bytes.len()`.
354 unsafe { check_tag(bytes, check) }?;
355 instr_ptr += 1;
356 }
357 Some(Insn::CheckReadTagRelBranch(check)) => {
358 // SAFETY: the caller has guaranteed
359 // that all offsets in `program` are `< expected`
360 // which we by now know is `= bytes.len()`.
361 let tag = unsafe { check_tag(bytes, check) }?;
362 instr_ptr += tag as u16 + 1;
363 }
364 }
365 }
366
367 Ok(())
368}
369
/// Tests checking [`validate_bsatn`] against `write_row_to_page`.
#[cfg(test)]
pub mod test {
    use super::*;
    use crate::{
        bflatn_to::write_row_to_page, blob_store::HashMapBlobStore, page::Page, row_type_visitor::row_type_visitor,
    };
    use proptest::{prelude::*, prop_assert_eq, proptest};
    use spacetimedb_sats::bsatn::to_vec;
    use spacetimedb_sats::proptest::generate_typed_row;
    use spacetimedb_sats::{AlgebraicType, ProductType};

    proptest! {
        // Configuration shared by every property test in this block.
        #![proptest_config(ProptestConfig {
            // Row types without a static layout are rejected in
            // `validation_same_as_write_row_to_pages`, so allow plenty of rejects.
            max_global_rejects: 65536,
            // Run far fewer cases under Miri.
            cases: if cfg!(miri) { 8 } else { 2048 },
            ..<_>::default()
        })]

        // This test checks that `validate_bsatn(...).is_ok() == write_row_to_page(..).is_ok()`.
        // In other words, the validator accepts exactly the rows that `write_row_to_page` accepts.
        #[test]
        fn validation_same_as_write_row_to_pages((ty, val) in generate_typed_row()) {
            let ty: RowTypeLayout = ty.into();
            let Some(static_layout) = StaticLayout::for_row_type(&ty) else {
                // `ty` has a var-len member or a sum with different payload lengths,
                // so the fast path doesn't apply.
                return Err(TestCaseError::reject("Var-length type"));
            };
            let validator = static_bsatn_validator(&ty);
            let bsatn = to_vec(&val).unwrap();
            // SAFETY (presumed): `validator` and `static_layout` are both derived from the same `ty`,
            // so the program's offsets should be within `static_layout.bsatn_length`.
            let res_validate = unsafe { validate_bsatn(&validator, &static_layout, &bsatn) };

            // Run `write_row_to_page` on the same row for comparison.
            let mut page = Page::new(ty.size());
            let visitor = row_type_visitor(&ty);
            let blob_store = &mut HashMapBlobStore::default();
            // NOTE(review): `write_row_to_page`'s safety contract is defined elsewhere;
            // presumably `page`, `visitor` and `ty` must agree, as they do here.
            let res_write = unsafe { write_row_to_page(&mut page, blob_store, &visitor, &ty, &val) };

            // Both paths must agree on whether the row is valid.
            prop_assert_eq!(res_validate.is_ok(), res_write.is_ok());
        }

        // A lone `bool` column whose byte is anything other than 0 or 1
        // must be rejected with `DecodeError::InvalidBool` carrying that byte.
        #[test]
        fn bad_bool_validates_to_error(byte in 2u8..) {
            let ty: RowTypeLayout = ProductType::from([AlgebraicType::Bool]).into();
            let static_layout = StaticLayout::for_row_type(&ty).unwrap();
            let validator = static_bsatn_validator(&ty);

            let bsatn = [byte];
            let res_validate = unsafe { validate_bsatn(&validator, &static_layout, &bsatn) };
            prop_assert_eq!(res_validate, Err(DecodeError::InvalidBool(byte)));
        }
    }
}