spacetimedb_table/static_bsatn_validator.rs
1//! To efficiently implement a fast-path BSATN -> BFLATN,
2//! we use a `StaticLayout` but in reverse of the read path.
3//! This however leaves us with no way to validate
4//! that the BSATN satisfies the row type of a given table.
5//!
6//! More specifically, we must validate that:
7//! 1. The length of the BSATN-encoded row matches the expected length.
8//! 2. All `bool`s in the row type only receive the values 0 or 1.
9//! 3. All sum tags are valid.
//! 4. The payload of every sum satisfies 2 and 3, recursively.
11//!
12//! That is where this module comes in,
13//! which provides two functions:
14//! - [`static_bsatn_validator`], which compiles a validator program given the table's row type.
15//! - [`validate_bsatn`], which executes the validator program against a row encoded in BSATN.
16//!
17//! The compilation uses the same strategy as for row type visitors,
18//! first simplifying to a rose-tree and then flattening that to
19//! a simple forward-progress-only byte code instruction set.
20
21#![allow(unused)]
22
23use crate::layout::ProductTypeLayoutView;
24
25use super::{
26 layout::{AlgebraicTypeLayout, HasLayout as _, ProductTypeLayout, RowTypeLayout},
27 static_layout::StaticLayout,
28 MemoryUsage,
29};
30use itertools::{repeat_n, Itertools as _};
31use spacetimedb_sats::bsatn::DecodeError;
32use spacetimedb_schema::type_for_generate::PrimitiveType;
33use std::sync::Arc;
34
35/// Constructs a validator for a row encoded in BSATN
36/// that checks that the row satisfies the type `ty`
37/// when `ty` has `StaticLayout`.
38///
39/// This is a potentially expensive operation,
40/// so the resulting `StaticBsatnValidator` should be stored and re-used.
41pub(crate) fn static_bsatn_validator(ty: &RowTypeLayout) -> StaticBsatnValidator {
42 let tree = row_type_to_tree(ty.product());
43 let insns = tree_to_insns(&tree).into();
44 StaticBsatnValidator { insns }
45}
46
47/// Construct a `Tree` from `ty`.
48///
49/// See [`extend_trees_for_algebraic_type`] for more details.
50fn row_type_to_tree(ty: ProductTypeLayoutView<'_>) -> Tree {
51 let mut sub_trees = Vec::new();
52 extend_trees_for_product_type(ty, &mut 0, &mut sub_trees);
53 sub_trees_to_tree(sub_trees)
54}
55
56/// Convert a list of `sub_trees` to one tree.
57fn sub_trees_to_tree(mut sub_trees: Vec<Tree>) -> Tree {
58 match sub_trees.len() {
59 // No trees is `Empty`.
60 0 => Tree::Empty,
61 // A single subtree can be collapsed.
62 // so prune the intermediate node.
63 1 => sub_trees.pop().unwrap(),
64 // For more than one children, sequence them doing one after the other.
65 _ => Tree::Sequence { sub_trees },
66 }
67}
68
69/// Extend `sub_trees` with checks for `ty`.
70///
71/// See [`extend_trees_for_algebraic_type`] for more details.
72fn extend_trees_for_product_type(ty: ProductTypeLayoutView<'_>, current_offset: &mut usize, sub_trees: &mut Vec<Tree>) {
73 for elem in ty.elements {
74 extend_trees_for_algebraic_type(&elem.ty, current_offset, sub_trees);
75 }
76}
77
78/// Extend `sub_trees` with checks for `ty`.
79///
80/// `current_offset` should be passed as `&mut 0` upon entry to the row-type,
81/// and will be incremented as appropriate during recursive traversal
82/// to track the offset in bytes of the member currently being visited.
83fn extend_trees_for_algebraic_type(ty: &AlgebraicTypeLayout, current_offset: &mut usize, sub_trees: &mut Vec<Tree>) {
84 match ty {
85 AlgebraicTypeLayout::Primitive(PrimitiveType::Bool) => {
86 // The `Bool` type is special, as it only allows a BSATN byte to be 0 or 1.
87 let offset = *current_offset as u16;
88 *current_offset += 1;
89 sub_trees.push(Tree::CheckBool { offset });
90 }
91 AlgebraicTypeLayout::Primitive(prim_ty) => {
92 // For primitive types, increment `current_offset` past this member.
93 // Primitive types have no padding, so we can use `prim_ty.size()` for bsatn.
94 *current_offset += prim_ty.size();
95 }
96 AlgebraicTypeLayout::Product(prod_ty) => {
97 extend_trees_for_product_type(prod_ty.view(), current_offset, sub_trees)
98 }
99 AlgebraicTypeLayout::Sum(sum_ty) => {
100 // Record the tag's offset and the number of variants.
101 let num_variants = sum_ty.variants.len() as u8;
102 let tag_offset = *current_offset as u16;
103 *current_offset += 1;
104
105 // For each variant, collect that variant's sub-tree.
106 // All variants are stored overlapping at the offset of the sum
107 // so we must reset `current_offset` each time to the before-variant value.
108 // We also need to create a fresh `sub_tree` context.
109 // Note that BSATN stores sums with tag first,
110 // followed by data/payload.
111 let mut child_offset = *current_offset;
112 let mut variants = sum_ty
113 .variants
114 .iter()
115 .map(|variant| {
116 let var_ty = &variant.ty;
117 let mut sub_trees = Vec::new();
118 child_offset = *current_offset;
119 extend_trees_for_algebraic_type(var_ty, &mut child_offset, &mut sub_trees);
120 sub_trees_to_tree(sub_trees)
121 })
122 .collect::<Vec<_>>();
123 // Having dealt with all variants,
124 // we must now move `current_offset` forward to the size of the payload
125 // which we know to be same for all variants.
126 *current_offset = child_offset;
127
128 if variants.iter().all_equal() {
129 // When all variants have the same set checks,
130 // there's no need to switch on the tag, so prune the intermediate node.
131 // A special case of this is single-variant sums.
132 sub_trees.push(Tree::CheckTag {
133 tag_offset,
134 num_variants,
135 });
136 if let Some(tree) = variants.pop() {
137 sub_trees.push(tree);
138 }
139 } else {
140 sub_trees.push(Tree::Sum {
141 tag_offset,
142 tag_data_processors: variants,
143 });
144 }
145 }
146
147 // There are no var-len members when there's a static fixed bsatn length.
148 AlgebraicTypeLayout::VarLen(_) => unreachable!(),
149 }
150}
151
/// A [Rose Tree](https://en.wikipedia.org/wiki/Rose_tree)
/// describing the validation steps needed to check BSATN
/// for a statically known, fixed-size `AlgebraicType`.
///
/// All offsets are in bytes, relative to the start of the BSATN-encoded row (`row`).
#[derive(Debug, PartialEq, Eq)]
enum Tree {
    /// Nothing to check.
    Empty,

    /// Perform the checks of each of the `sub_trees`, one after the other.
    Sequence { sub_trees: Vec<Tree> },

    /// Check that the byte at `row + offset` is a valid `bool`, i.e., 0 or 1.
    CheckBool { offset: u16 },

    /// Check that the sum tag at `row + tag_offset` is `< num_variants`,
    /// without branching on the tag.
    CheckTag {
        /// The sum's tag is at `row + tag_offset` bytes.
        tag_offset: u16,
        /// The number of variants there are.
        /// The read tag must be `< num_variants`.
        num_variants: u8,
    },

    /// A choice between several variants, selected by the sum's tag.
    Sum {
        /// The sum's tag is at `row + tag_offset` bytes.
        tag_offset: u16,
        /// The checks for variant `N` are described in `tag_data_processors[N]`.
        /// The number of variants is `tag_data_processors.len()`.
        tag_data_processors: Vec<Tree>,
    },
}
183
184/// Compile the [`Tree`] to a list of [`Insn`].
185fn tree_to_insns(tree: &Tree) -> Vec<Insn> {
186 let mut program = Vec::new();
187
188 fn compile_tree(tree: &Tree, into: &mut Vec<Insn>) {
189 match tree {
190 Tree::Empty => {}
191 &Tree::CheckBool { offset } => into.push(Insn::CheckBool(offset)),
192 Tree::Sequence { sub_trees } => {
193 for tree in &**sub_trees {
194 compile_tree(tree, into);
195 }
196 }
197 &Tree::CheckTag {
198 tag_offset,
199 num_variants,
200 } => into.push(Insn::CheckTag(CheckTag {
201 tag_offset,
202 num_variants,
203 })),
204 Tree::Sum {
205 tag_offset,
206 tag_data_processors,
207 } => {
208 // Add the branching instruction itself.
209 let num_variants = tag_data_processors.len();
210 into.push(Insn::CheckReadTagRelBranch(CheckTag {
211 tag_offset: *tag_offset,
212 num_variants: num_variants as u8,
213 }));
214 // Add N slots for "to variant goto"s.
215 let to_branches = into.len();
216 into.extend(repeat_n(Insn::FIXUP, num_variants));
217 // Compile the branches.
218 let mut from_variant_gotos = Vec::with_capacity(num_variants);
219 for (tag, branch) in tag_data_processors.iter().enumerate() {
220 // Fixup the to-variant jump address.
221 into[to_branches + tag] = Insn::Goto(into.len() as u16);
222 // Compile the branch.
223 compile_tree(branch, into);
224 // Add jump-out gotos that we'll fixup later to store the after-sum address.
225 from_variant_gotos.push(into.len());
226 into.push(Insn::FIXUP);
227 }
228 // Fixup the jump-out-from-variant addresses.
229 let goto_addr = into.len();
230 for idx in from_variant_gotos {
231 into[idx] = Insn::Goto(goto_addr as u16);
232 }
233 }
234 }
235 }
236
237 compile_tree(tree, &mut program);
238 remove_trailing_gotos(&mut program);
239 program
240}
241
242/// Remove any trailing gotos.
243///
244/// They are not needed as they will only go towards the end,
245/// so we can just cut them out.
246fn remove_trailing_gotos(program: &mut Vec<Insn>) {
247 for idx in (0..program.len()).rev() {
248 match program[idx] {
249 Insn::Goto(_) => program.pop(),
250 _ => break,
251 };
252 }
253}
254
/// Operands shared by [`Insn::CheckTag`] and [`Insn::CheckReadTagRelBranch`]:
/// where a sum's tag byte lives and how many variants the sum has.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct CheckTag {
    /// Offset, in bytes from the start of the row, of the tag byte.
    tag_offset: u16,
    /// How many variants the sum has; a tag is valid iff `tag < num_variants`.
    num_variants: u8,
}
263
264/// The instruction set of a [`StaticBsatnValidator`].
265#[derive(Debug, Clone, Copy, PartialEq, Eq)]
266enum Insn {
267 /// Visit the byte at offset `start + N`
268 /// and assert that it is 0 or 1, i.e., a valid `bool`.
269 CheckBool(u16),
270
271 /// Read the `tag` at `start + tag_offset`
272 /// and validate that `tag < num_variants`.
273 CheckTag(CheckTag),
274
275 /// Read the `tag` at `start + tag_offset`
276 /// and validate that `tag < num_variants`.
277 /// Then move the instruction pointer forward by `tag + 1`.
278 /// The branch logic for the variant payload continues there.
279 CheckReadTagRelBranch(CheckTag),
280
281 /// Unconditionally branch to the instruction at `program[N]`
282 /// where `N > instruction pointer`.
283 Goto(u16),
284}
285
286impl Insn {
287 const FIXUP: Self = Self::Goto(u16::MAX);
288}
289
// `Insn` relies entirely on the default methods of `MemoryUsage`.
// NOTE(review): presumably those defaults report no heap usage, which suits this `Copy` enum — confirm.
impl MemoryUsage for Insn {}
291
292#[derive(Clone, Debug, PartialEq, Eq)]
293pub struct StaticBsatnValidator {
294 /// The list of instructions that make up this program.
295 insns: Arc<[Insn]>,
296}
297
298impl MemoryUsage for StaticBsatnValidator {
299 fn heap_usage(&self) -> usize {
300 let Self { insns } = self;
301 insns.heap_usage()
302 }
303}
304
305/// Check that `bytes[tag_offset] < num_variants`.
306///
307/// SAFETY: `tag_offset < bytes.len()`.
308unsafe fn check_tag(bytes: &[u8], check: CheckTag) -> Result<u8, DecodeError> {
309 // SAFETY: the caller has guaranteed that `tag_offset < bytes.len()`.
310 let tag = *unsafe { bytes.get_unchecked(check.tag_offset as usize) };
311 if tag < check.num_variants {
312 Ok(tag)
313 } else {
314 Err(DecodeError::InvalidTag { tag, sum_name: None })
315 }
316}
317
318/// Validates that `bytes`, encoded in BSATN,
319/// is valid according to the validation `program`
320/// and a corresponding `static_layout`,
321///
322/// # Safety
323///
324/// The caller must guarantee that
325/// all offsets in `program` are `< static_layout.bsatn_length`.
326pub(crate) unsafe fn validate_bsatn(
327 program: &StaticBsatnValidator,
328 static_layout: &StaticLayout,
329 bytes: &[u8],
330) -> Result<(), DecodeError> {
331 // Validate length of BSATN `bytes` against the expected length.
332 let expected = static_layout.bsatn_length as usize;
333 let given = bytes.len();
334 if expected != given {
335 return Err(DecodeError::InvalidLen { expected, given });
336 }
337
338 let program = &*program.insns;
339 let mut instr_ptr = 0;
340 loop {
341 match program.get(instr_ptr as usize).copied() {
342 None => break,
343 Some(Insn::CheckBool(offset)) => {
344 instr_ptr += 1;
345 // SAFETY: the caller has guaranteed
346 // that all offsets in `program` are `< expected`
347 // which we by now know is `= bytes.len()`.
348 let byte = *unsafe { bytes.get_unchecked(offset as usize) };
349 if byte > 1 {
350 return Err(DecodeError::InvalidBool(byte));
351 }
352 }
353 Some(Insn::Goto(new_insn)) => instr_ptr = new_insn,
354 Some(Insn::CheckTag(check)) => {
355 // SAFETY: the caller has guaranteed
356 // that all offsets in `program` are `< expected`
357 // which we by now know is `= bytes.len()`.
358 unsafe { check_tag(bytes, check) }?;
359 instr_ptr += 1;
360 }
361 Some(Insn::CheckReadTagRelBranch(check)) => {
362 // SAFETY: the caller has guaranteed
363 // that all offsets in `program` are `< expected`
364 // which we by now know is `= bytes.len()`.
365 let tag = unsafe { check_tag(bytes, check) }?;
366 instr_ptr += tag as u16 + 1;
367 }
368 }
369 }
370
371 Ok(())
372}
373
// Property tests comparing the BSATN fast-path validator against the regular write path.
#[cfg(test)]
pub mod test {
    use super::*;
    use crate::{
        bflatn_to::write_row_to_page, blob_store::HashMapBlobStore, page::Page, row_type_visitor::row_type_visitor,
    };
    use proptest::{prelude::*, prop_assert_eq, proptest};
    use spacetimedb_sats::bsatn::to_vec;
    use spacetimedb_sats::proptest::generate_typed_row;
    use spacetimedb_sats::{AlgebraicType, ProductType};

    proptest! {
        #![proptest_config(ProptestConfig {
            // Row types without a static layout are rejected below,
            // so allow many global rejects before proptest gives up.
            max_global_rejects: 65536,
            // Run far fewer cases under Miri (presumably because it is much slower).
            cases: if cfg!(miri) { 8 } else { 2048 },
            ..<_>::default()
        })]

        // Checks that `validate_bsatn(...).is_ok() == write_row_to_page(..).is_ok()`.
        // In other words, for row types with a static layout, the fast-path validator
        // accepts a row exactly when the regular write path does.
        #[test]
        fn validation_same_as_write_row_to_pages((ty, val) in generate_typed_row()) {
            let ty: RowTypeLayout = ty.into();
            let Some(static_layout) = StaticLayout::for_row_type(&ty) else {
                // `ty` has a var-len member or a sum with different payload lengths,
                // so the fast path doesn't apply.
                return Err(TestCaseError::reject("Var-length type"));
            };
            let validator = static_bsatn_validator(&ty);
            let bsatn = to_vec(&val).unwrap();
            // NOTE(review): soundness relies on `validator` and `static_layout`
            // both being derived from `ty`, so that all program offsets are in bounds.
            let res_validate = unsafe { validate_bsatn(&validator, &static_layout, &bsatn) };

            // Write the same value through the regular BFLATN write path for comparison.
            let mut page = Page::new(ty.size());
            let visitor = row_type_visitor(&ty);
            let blob_store = &mut HashMapBlobStore::default();
            let res_write = unsafe { write_row_to_page(&mut page, blob_store, &visitor, &ty, &val) };

            prop_assert_eq!(res_validate.is_ok(), res_write.is_ok());
        }

        // A single-`bool` row whose byte is anything but 0 or 1
        // must be rejected with `InvalidBool` carrying that byte.
        #[test]
        fn bad_bool_validates_to_error(byte in 2u8..) {
            let ty: RowTypeLayout = ProductType::from([AlgebraicType::Bool]).into();
            let static_layout = StaticLayout::for_row_type(&ty).unwrap();
            let validator = static_bsatn_validator(&ty);

            let bsatn = [byte];
            let res_validate = unsafe { validate_bsatn(&validator, &static_layout, &bsatn) };
            prop_assert_eq!(res_validate, Err(DecodeError::InvalidBool(byte)));
        }
    }
}