spacetimedb_table/
static_bsatn_validator.rs

1//! To efficiently implement a fast-path BSATN -> BFLATN,
2//! we use a `StaticLayout` but in reverse of the read path.
3//! This however leaves us with no way to validate
4//! that the BSATN satisfies the row type of a given table.
5//!
6//! More specifically, we must validate that:
7//! 1. The length of the BSATN-encoded row matches the expected length.
8//! 2. All `bool`s in the row type only receive the values 0 or 1.
9//! 3. All sum tags are valid.
//! 4. The payload of each sum is itself validated against 2 and 3, recursively.
11//!
12//! That is where this module comes in,
13//! which provides two functions:
14//! - [`static_bsatn_validator`], which compiles a validator program given the table's row type.
15//! - [`validate_bsatn`], which executes the validator program against a row encoded in BSATN.
16//!
17//! The compilation uses the same strategy as for row type visitors,
18//! first simplifying to a rose-tree and then flattening that to
19//! a simple forward-progress-only byte code instruction set.
20
21#![allow(unused)]
22
23use crate::layout::ProductTypeLayoutView;
24
25use super::{
26    layout::{AlgebraicTypeLayout, HasLayout as _, ProductTypeLayout, RowTypeLayout},
27    static_layout::StaticLayout,
28    MemoryUsage,
29};
30use itertools::{repeat_n, Itertools as _};
31use spacetimedb_sats::bsatn::DecodeError;
32use spacetimedb_schema::type_for_generate::PrimitiveType;
33use std::sync::Arc;
34
35/// Constructs a validator for a row encoded in BSATN
36/// that checks that the row satisfies the type `ty`
37/// when `ty` has `StaticLayout`.
38///
39/// This is a potentially expensive operation,
40/// so the resulting `StaticBsatnValidator` should be stored and re-used.
41pub(crate) fn static_bsatn_validator(ty: &RowTypeLayout) -> StaticBsatnValidator {
42    let tree = row_type_to_tree(ty.product());
43    let insns = tree_to_insns(&tree).into();
44    StaticBsatnValidator { insns }
45}
46
47/// Construct a `Tree` from `ty`.
48///
49/// See [`extend_trees_for_algebraic_type`] for more details.
50fn row_type_to_tree(ty: ProductTypeLayoutView<'_>) -> Tree {
51    let mut sub_trees = Vec::new();
52    extend_trees_for_product_type(ty, &mut 0, &mut sub_trees);
53    sub_trees_to_tree(sub_trees)
54}
55
56/// Convert a list of `sub_trees` to one tree.
57fn sub_trees_to_tree(mut sub_trees: Vec<Tree>) -> Tree {
58    match sub_trees.len() {
59        // No trees is `Empty`.
60        0 => Tree::Empty,
61        // A single subtree can be collapsed.
62        // so prune the intermediate node.
63        1 => sub_trees.pop().unwrap(),
64        // For more than one children, sequence them doing one after the other.
65        _ => Tree::Sequence { sub_trees },
66    }
67}
68
69/// Extend `sub_trees` with checks for `ty`.
70///
71/// See [`extend_trees_for_algebraic_type`] for more details.
72fn extend_trees_for_product_type(ty: ProductTypeLayoutView<'_>, current_offset: &mut usize, sub_trees: &mut Vec<Tree>) {
73    for elem in ty.elements {
74        extend_trees_for_algebraic_type(&elem.ty, current_offset, sub_trees);
75    }
76}
77
78/// Extend `sub_trees` with checks for `ty`.
79///
80/// `current_offset` should be passed as `&mut 0` upon entry to the row-type,
81/// and will be incremented as appropriate during recursive traversal
82/// to track the offset in bytes of the member currently being visited.
83fn extend_trees_for_algebraic_type(ty: &AlgebraicTypeLayout, current_offset: &mut usize, sub_trees: &mut Vec<Tree>) {
84    match ty {
85        AlgebraicTypeLayout::Primitive(PrimitiveType::Bool) => {
86            // The `Bool` type is special, as it only allows a BSATN byte to be 0 or 1.
87            let offset = *current_offset as u16;
88            *current_offset += 1;
89            sub_trees.push(Tree::CheckBool { offset });
90        }
91        AlgebraicTypeLayout::Primitive(prim_ty) => {
92            // For primitive types, increment `current_offset` past this member.
93            // Primitive types have no padding, so we can use `prim_ty.size()` for bsatn.
94            *current_offset += prim_ty.size();
95        }
96        AlgebraicTypeLayout::Product(prod_ty) => {
97            extend_trees_for_product_type(prod_ty.view(), current_offset, sub_trees)
98        }
99        AlgebraicTypeLayout::Sum(sum_ty) => {
100            // Record the tag's offset and the number of variants.
101            let num_variants = sum_ty.variants.len() as u8;
102            let tag_offset = *current_offset as u16;
103            *current_offset += 1;
104
105            // For each variant, collect that variant's sub-tree.
106            // All variants are stored overlapping at the offset of the sum
107            // so we must reset `current_offset` each time to the before-variant value.
108            // We also need to create a fresh `sub_tree` context.
109            // Note that BSATN stores sums with tag first,
110            // followed by data/payload.
111            let mut child_offset = *current_offset;
112            let mut variants = sum_ty
113                .variants
114                .iter()
115                .map(|variant| {
116                    let var_ty = &variant.ty;
117                    let mut sub_trees = Vec::new();
118                    child_offset = *current_offset;
119                    extend_trees_for_algebraic_type(var_ty, &mut child_offset, &mut sub_trees);
120                    sub_trees_to_tree(sub_trees)
121                })
122                .collect::<Vec<_>>();
123            // Having dealt with all variants,
124            // we must now move `current_offset` forward to the size of the payload
125            // which we know to be same for all variants.
126            *current_offset = child_offset;
127
128            if variants.iter().all_equal() {
129                // When all variants have the same set checks,
130                // there's no need to switch on the tag, so prune the intermediate node.
131                // A special case of this is single-variant sums.
132                sub_trees.push(Tree::CheckTag {
133                    tag_offset,
134                    num_variants,
135                });
136                if let Some(tree) = variants.pop() {
137                    sub_trees.push(tree);
138                }
139            } else {
140                sub_trees.push(Tree::Sum {
141                    tag_offset,
142                    tag_data_processors: variants,
143                });
144            }
145        }
146
147        // There are no var-len members when there's a static fixed bsatn length.
148        AlgebraicTypeLayout::VarLen(_) => unreachable!(),
149    }
150}
151
/// A [Rose Tree](https://en.wikipedia.org/wiki/Rose_tree)
/// containing information about validation steps for
/// decoding BSATN for statically known fixed size `AlgebraicType`s.
///
/// All offsets are in bytes, relative to `row`, the start of the BSATN-encoded row.
#[derive(Debug, PartialEq, Eq)]
enum Tree {
    /// Nothing to check.
    Empty,

    /// Perform each of `sub_trees` in order, one after the other.
    Sequence { sub_trees: Vec<Tree> },

    /// Check that the byte at `row + offset` is a valid `bool`, i.e., 0 or 1.
    CheckBool { offset: u16 },

    /// Check that the sum tag byte at `row + tag_offset` is `< num_variants`,
    /// without branching on it.
    CheckTag {
        /// The sum's tag is at `row + tag_offset` bytes.
        tag_offset: u16,
        /// The number of variants there are.
        /// The read tag must be `< num_variants`.
        num_variants: u8,
    },

    /// A choice between several variants whose payloads need different checks.
    Sum {
        /// The sum's tag is at `row + tag_offset` bytes.
        tag_offset: u16,
        /// The checks for variant `N` are described in `tag_data_processors[N]`.
        tag_data_processors: Vec<Tree>,
    },
}
183
184/// Compile the [`Tree`] to a list of [`Insn`].
185fn tree_to_insns(tree: &Tree) -> Vec<Insn> {
186    let mut program = Vec::new();
187
188    fn compile_tree(tree: &Tree, into: &mut Vec<Insn>) {
189        match tree {
190            Tree::Empty => {}
191            &Tree::CheckBool { offset } => into.push(Insn::CheckBool(offset)),
192            Tree::Sequence { sub_trees } => {
193                for tree in &**sub_trees {
194                    compile_tree(tree, into);
195                }
196            }
197            &Tree::CheckTag {
198                tag_offset,
199                num_variants,
200            } => into.push(Insn::CheckTag(CheckTag {
201                tag_offset,
202                num_variants,
203            })),
204            Tree::Sum {
205                tag_offset,
206                tag_data_processors,
207            } => {
208                // Add the branching instruction itself.
209                let num_variants = tag_data_processors.len();
210                into.push(Insn::CheckReadTagRelBranch(CheckTag {
211                    tag_offset: *tag_offset,
212                    num_variants: num_variants as u8,
213                }));
214                // Add N slots for "to variant goto"s.
215                let to_branches = into.len();
216                into.extend(repeat_n(Insn::FIXUP, num_variants));
217                // Compile the branches.
218                let mut from_variant_gotos = Vec::with_capacity(num_variants);
219                for (tag, branch) in tag_data_processors.iter().enumerate() {
220                    // Fixup the to-variant jump address.
221                    into[to_branches + tag] = Insn::Goto(into.len() as u16);
222                    // Compile the branch.
223                    compile_tree(branch, into);
224                    // Add jump-out gotos that we'll fixup later to store the after-sum address.
225                    from_variant_gotos.push(into.len());
226                    into.push(Insn::FIXUP);
227                }
228                // Fixup the jump-out-from-variant addresses.
229                let goto_addr = into.len();
230                for idx in from_variant_gotos {
231                    into[idx] = Insn::Goto(goto_addr as u16);
232                }
233            }
234        }
235    }
236
237    compile_tree(tree, &mut program);
238    remove_trailing_gotos(&mut program);
239    program
240}
241
242/// Remove any trailing gotos.
243///
244/// They are not needed as they will only go towards the end,
245/// so we can just cut them out.
246fn remove_trailing_gotos(program: &mut Vec<Insn>) {
247    for idx in (0..program.len()).rev() {
248        match program[idx] {
249            Insn::Goto(_) => program.pop(),
250            _ => break,
251        };
252    }
253}
254
/// The operands of a sum-tag check,
/// shared by [`Insn::CheckTag`] and [`Insn::CheckReadTagRelBranch`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct CheckTag {
    /// The tag to check is stored at `start + tag_offset`,
    /// where `start` is the beginning of the BSATN-encoded row.
    tag_offset: u16,
    /// The number of variants there are.
    /// The read tag must be `< num_variants`.
    num_variants: u8,
}
263
/// The instruction set of a [`StaticBsatnValidator`].
///
/// Byte offsets are relative to `start`, the beginning of the BSATN-encoded row.
/// Control flow only ever moves forward,
/// and validation succeeds once execution runs past the last instruction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Insn {
    /// Visit the byte at offset `start + N`
    /// and assert that it is 0 or 1, i.e., a valid `bool`.
    CheckBool(u16),

    /// Read the `tag` at `start + tag_offset`
    /// and validate that `tag < num_variants`.
    CheckTag(CheckTag),

    /// Read the `tag` at `start + tag_offset`
    /// and validate that `tag < num_variants`.
    /// Then move the instruction pointer forward by `tag + 1`.
    /// The branch logic for the variant payload continues there.
    ///
    /// As emitted by `tree_to_insns`, this instruction is followed by
    /// one `Goto` slot per variant, each jumping to that variant's checks.
    CheckReadTagRelBranch(CheckTag),

    /// Unconditionally branch to the instruction at `program[N]`
    /// where `N > instruction pointer`.
    Goto(u16),
}
285
286impl Insn {
287    const FIXUP: Self = Self::Goto(u16::MAX);
288}
289
// `Insn` is `Copy` and holds only integers, so it owns no heap memory.
// The empty impl relies on `MemoryUsage`'s default methods
// (presumably reporting zero heap usage — the trait is not visible here).
impl MemoryUsage for Insn {}
291
/// A compiled program that validates BSATN-encoded rows of a row type with a fixed BSATN length.
///
/// Built by [`static_bsatn_validator`] and executed by [`validate_bsatn`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StaticBsatnValidator {
    /// The list of instructions that make up this program.
    /// Held in an `Arc`, so clones share the program rather than copying it.
    insns: Arc<[Insn]>,
}
297
298impl MemoryUsage for StaticBsatnValidator {
299    fn heap_usage(&self) -> usize {
300        let Self { insns } = self;
301        insns.heap_usage()
302    }
303}
304
305/// Check that `bytes[tag_offset] < num_variants`.
306///
307/// SAFETY: `tag_offset < bytes.len()`.
308unsafe fn check_tag(bytes: &[u8], check: CheckTag) -> Result<u8, DecodeError> {
309    // SAFETY: the caller has guaranteed that `tag_offset < bytes.len()`.
310    let tag = *unsafe { bytes.get_unchecked(check.tag_offset as usize) };
311    if tag < check.num_variants {
312        Ok(tag)
313    } else {
314        Err(DecodeError::InvalidTag { tag, sum_name: None })
315    }
316}
317
318/// Validates that `bytes`, encoded in BSATN,
319/// is valid according to the validation `program`
320/// and a corresponding `static_layout`,
321///
322/// # Safety
323///
324/// The caller must guarantee that
325/// all offsets in `program` are `< static_layout.bsatn_length`.
326pub(crate) unsafe fn validate_bsatn(
327    program: &StaticBsatnValidator,
328    static_layout: &StaticLayout,
329    bytes: &[u8],
330) -> Result<(), DecodeError> {
331    // Validate length of BSATN `bytes` against the expected length.
332    let expected = static_layout.bsatn_length as usize;
333    let given = bytes.len();
334    if expected != given {
335        return Err(DecodeError::InvalidLen { expected, given });
336    }
337
338    let program = &*program.insns;
339    let mut instr_ptr = 0;
340    loop {
341        match program.get(instr_ptr as usize).copied() {
342            None => break,
343            Some(Insn::CheckBool(offset)) => {
344                instr_ptr += 1;
345                // SAFETY: the caller has guaranteed
346                // that all offsets in `program` are `< expected`
347                // which we by now know is `= bytes.len()`.
348                let byte = *unsafe { bytes.get_unchecked(offset as usize) };
349                if byte > 1 {
350                    return Err(DecodeError::InvalidBool(byte));
351                }
352            }
353            Some(Insn::Goto(new_insn)) => instr_ptr = new_insn,
354            Some(Insn::CheckTag(check)) => {
355                // SAFETY: the caller has guaranteed
356                // that all offsets in `program` are `< expected`
357                // which we by now know is `= bytes.len()`.
358                unsafe { check_tag(bytes, check) }?;
359                instr_ptr += 1;
360            }
361            Some(Insn::CheckReadTagRelBranch(check)) => {
362                // SAFETY: the caller has guaranteed
363                // that all offsets in `program` are `< expected`
364                // which we by now know is `= bytes.len()`.
365                let tag = unsafe { check_tag(bytes, check) }?;
366                instr_ptr += tag as u16 + 1;
367            }
368        }
369    }
370
371    Ok(())
372}
373
/// Property tests for the static BSATN validator.
#[cfg(test)]
pub mod test {
    use super::*;
    use crate::{
        bflatn_to::write_row_to_page, blob_store::HashMapBlobStore, page::Page, row_type_visitor::row_type_visitor,
    };
    use proptest::{prelude::*, prop_assert_eq, proptest};
    use spacetimedb_sats::bsatn::to_vec;
    use spacetimedb_sats::proptest::generate_typed_row;
    use spacetimedb_sats::{AlgebraicType, ProductType};

    proptest! {
        #![proptest_config(ProptestConfig {
            // Row types without a `StaticLayout` are rejected below,
            // so allow many global rejections.
            max_global_rejects: 65536,
            // Run far fewer cases under Miri.
            cases: if cfg!(miri) { 8 } else { 2048 },
            ..<_>::default()
        })]
        // This test checks that `validate_bsatn(...).is_ok() == write_row_to_page(..).is_ok()`,
        // i.e., that the validator's verdict on the BSATN encoding of `val`
        // agrees with whether `write_row_to_page` accepts `val` itself.
        #[test]
        fn validation_same_as_write_row_to_pages((ty, val) in generate_typed_row()) {
            let ty: RowTypeLayout = ty.into();
            let Some(static_layout) = StaticLayout::for_row_type(&ty) else {
                // `ty` has a var-len member or a sum with different payload lengths,
                // so the fast path doesn't apply.
                return Err(TestCaseError::reject("Var-length type"));
            };
            let validator = static_bsatn_validator(&ty);
            let bsatn = to_vec(&val).unwrap();
            // NOTE(review): soundness assumes a validator compiled from `ty`
            // only has offsets `< static_layout.bsatn_length` for the same `ty`.
            let res_validate = unsafe { validate_bsatn(&validator, &static_layout, &bsatn) };

            let mut page = Page::new(ty.size());
            let visitor = row_type_visitor(&ty);
            let blob_store = &mut HashMapBlobStore::default();
            // NOTE(review): `write_row_to_page`'s safety contract is not visible here;
            // presumably satisfied as `page` and `visitor` are both derived from `ty`.
            let res_write = unsafe { write_row_to_page(&mut page, blob_store, &visitor, &ty, &val) };

            prop_assert_eq!(res_validate.is_ok(), res_write.is_ok());
        }

        // Any byte `>= 2` for a lone `bool` column must be rejected with `InvalidBool`.
        #[test]
        fn bad_bool_validates_to_error(byte in 2u8..) {
            let ty: RowTypeLayout = ProductType::from([AlgebraicType::Bool]).into();
            let static_layout = StaticLayout::for_row_type(&ty).unwrap();
            let validator = static_bsatn_validator(&ty);

            let bsatn = [byte];
            let res_validate = unsafe { validate_bsatn(&validator, &static_layout, &bsatn) };
            prop_assert_eq!(res_validate, Err(DecodeError::InvalidBool(byte)));
        }
    }
}