spacetimedb_table/
static_bsatn_validator.rs

//! To efficiently implement a fast-path BSATN -> BFLATN conversion,
//! we use a `StaticLayout`, but in reverse of the read path.
//! This, however, leaves us with no built-in way to validate
//! that the BSATN satisfies the row type of a given table.
//!
//! More specifically, we must validate that:
//! 1. The length of the BSATN-encoded row matches the expected length.
//! 2. All `bool`s in the row type only receive the values 0 or 1.
//! 3. All sum tags are valid.
//! 4. A sum's payload satisfies 2-3 recursively.
//!
//! This module fills that gap by providing two functions:
//!
//! - [`static_bsatn_validator`], which compiles a validator program given the table's row type.
//! - [`validate_bsatn`], which executes the validator program against a row encoded in BSATN.
//!
//! The compilation uses the same strategy as for row type visitors,
//! first simplifying to a rose-tree and then flattening that to
//! a simple forward-progress-only byte code instruction set.
20
21#![allow(unused)]
22
23use super::{
24    layout::{AlgebraicTypeLayout, HasLayout as _, ProductTypeLayout, RowTypeLayout},
25    static_layout::StaticLayout,
26    MemoryUsage,
27};
28use itertools::{repeat_n, Itertools as _};
29use spacetimedb_sats::bsatn::DecodeError;
30use spacetimedb_schema::type_for_generate::PrimitiveType;
31use std::sync::Arc;
32
33/// Constructs a validator for a row encoded in BSATN
34/// that checks that the row satisfies the type `ty`
35/// when `ty` has `StaticLayout`.
36///
37/// This is a potentially expensive operation,
38/// so the resulting `StaticBsatnValidator` should be stored and re-used.
39pub(crate) fn static_bsatn_validator(ty: &RowTypeLayout) -> StaticBsatnValidator {
40    let tree = row_type_to_tree(ty.product());
41    let insns = tree_to_insns(&tree).into();
42    StaticBsatnValidator { insns }
43}
44
45/// Construct a `Tree` from `ty`.
46///
47/// See [`extend_trees_for_algebraic_type`] for more details.
48fn row_type_to_tree(ty: &ProductTypeLayout) -> Tree {
49    let mut sub_trees = Vec::new();
50    extend_trees_for_product_type(ty, &mut 0, &mut sub_trees);
51    sub_trees_to_tree(sub_trees)
52}
53
54/// Convert a list of `sub_trees` to one tree.
55fn sub_trees_to_tree(mut sub_trees: Vec<Tree>) -> Tree {
56    match sub_trees.len() {
57        // No trees is `Empty`.
58        0 => Tree::Empty,
59        // A single subtree can be collapsed.
60        // so prune the intermediate node.
61        1 => sub_trees.pop().unwrap(),
62        // For more than one children, sequence them doing one after the other.
63        _ => Tree::Sequence { sub_trees },
64    }
65}
66
67/// Extend `sub_trees` with checks for `ty`.
68///
69/// See [`extend_trees_for_algebraic_type`] for more details.
70fn extend_trees_for_product_type(ty: &ProductTypeLayout, current_offset: &mut usize, sub_trees: &mut Vec<Tree>) {
71    for elem in &*ty.elements {
72        extend_trees_for_algebraic_type(&elem.ty, current_offset, sub_trees);
73    }
74}
75
76/// Extend `sub_trees` with checks for `ty`.
77///
78/// `current_offset` should be passed as `&mut 0` upon entry to the row-type,
79/// and will be incremented as appropriate during recursive traversal
80/// to track the offset in bytes of the member currently being visited.
81fn extend_trees_for_algebraic_type(ty: &AlgebraicTypeLayout, current_offset: &mut usize, sub_trees: &mut Vec<Tree>) {
82    match ty {
83        AlgebraicTypeLayout::Primitive(PrimitiveType::Bool) => {
84            // The `Bool` type is special, as it only allows a BSATN byte to be 0 or 1.
85            let offset = *current_offset as u16;
86            *current_offset += 1;
87            sub_trees.push(Tree::CheckBool { offset });
88        }
89        AlgebraicTypeLayout::Primitive(prim_ty) => {
90            // For primitive types, increment `current_offset` past this member.
91            // Primitive types have no padding, so we can use `prim_ty.size()` for bsatn.
92            *current_offset += prim_ty.size();
93        }
94        AlgebraicTypeLayout::Product(prod_ty) => extend_trees_for_product_type(prod_ty, current_offset, sub_trees),
95        AlgebraicTypeLayout::Sum(sum_ty) => {
96            // Record the tag's offset and the number of variants.
97            let num_variants = sum_ty.variants.len() as u8;
98            let tag_offset = *current_offset as u16;
99            *current_offset += 1;
100
101            // For each variant, collect that variant's sub-tree.
102            // All variants are stored overlapping at the offset of the sum
103            // so we must reset `current_offset` each time to the before-variant value.
104            // We also need to create a fresh `sub_tree` context.
105            // Note that BSATN stores sums with tag first,
106            // followed by data/payload.
107            let mut child_offset = *current_offset;
108            let mut variants = sum_ty
109                .variants
110                .iter()
111                .map(|variant| {
112                    let var_ty = &variant.ty;
113                    let mut sub_trees = Vec::new();
114                    child_offset = *current_offset;
115                    extend_trees_for_algebraic_type(var_ty, &mut child_offset, &mut sub_trees);
116                    sub_trees_to_tree(sub_trees)
117                })
118                .collect::<Vec<_>>();
119            // Having dealt with all variants,
120            // we must now move `current_offset` forward to the size of the payload
121            // which we know to be same for all variants.
122            *current_offset = child_offset;
123
124            if variants.iter().all_equal() {
125                // When all variants have the same set checks,
126                // there's no need to switch on the tag, so prune the intermediate node.
127                // A special case of this is single-variant sums.
128                sub_trees.push(Tree::CheckTag {
129                    tag_offset,
130                    num_variants,
131                });
132                if let Some(tree) = variants.pop() {
133                    sub_trees.push(tree);
134                }
135            } else {
136                sub_trees.push(Tree::Sum {
137                    tag_offset,
138                    tag_data_processors: variants,
139                });
140            }
141        }
142
143        // There are no var-len members when there's a static fixed bsatn length.
144        AlgebraicTypeLayout::VarLen(_) => unreachable!(),
145    }
146}
147
/// A [Rose Tree](https://en.wikipedia.org/wiki/Rose_tree) describing the checks
/// needed to validate BSATN for a statically-known, fixed-size `AlgebraicType`.
///
/// This is the intermediate form built from the row type
/// before it is flattened into instructions by `tree_to_insns`.
/// All offsets are in bytes, relative to the start of the BSATN-encoded row.
#[derive(Debug, PartialEq, Eq)]
enum Tree {
    /// No checks are required.
    Empty,

    /// Run every sub-tree in order, one after another.
    Sequence { sub_trees: Vec<Tree> },

    /// The byte at `row + offset` must be a valid `bool`, i.e., `0` or `1`.
    CheckBool { offset: u16 },

    /// A sum's tag byte must name an existing variant.
    CheckTag {
        /// Where the tag lives: `row + tag_offset` bytes.
        tag_offset: u16,
        /// How many variants the sum has; a valid tag is `< num_variants`.
        num_variants: u8,
    },

    /// Branch on a sum's tag, then run the checks for the chosen variant's payload.
    Sum {
        /// Where the tag lives: `row + tag_offset` bytes.
        tag_offset: u16,
        /// `tag_data_processors[N]` holds the payload checks for variant `N`;
        /// its length is the number of variants.
        tag_data_processors: Vec<Tree>,
    },
}
179
180/// Compile the [`Tree`] to a list of [`Insn`].
181fn tree_to_insns(tree: &Tree) -> Vec<Insn> {
182    let mut program = Vec::new();
183
184    fn compile_tree(tree: &Tree, into: &mut Vec<Insn>) {
185        match tree {
186            Tree::Empty => {}
187            &Tree::CheckBool { offset } => into.push(Insn::CheckBool(offset)),
188            Tree::Sequence { sub_trees } => {
189                for tree in &**sub_trees {
190                    compile_tree(tree, into);
191                }
192            }
193            &Tree::CheckTag {
194                tag_offset,
195                num_variants,
196            } => into.push(Insn::CheckTag(CheckTag {
197                tag_offset,
198                num_variants,
199            })),
200            Tree::Sum {
201                tag_offset,
202                tag_data_processors,
203            } => {
204                // Add the branching instruction itself.
205                let num_variants = tag_data_processors.len();
206                into.push(Insn::CheckReadTagRelBranch(CheckTag {
207                    tag_offset: *tag_offset,
208                    num_variants: num_variants as u8,
209                }));
210                // Add N slots for "to variant goto"s.
211                let to_branches = into.len();
212                into.extend(repeat_n(Insn::FIXUP, num_variants));
213                // Compile the branches.
214                let mut from_variant_gotos = Vec::with_capacity(num_variants);
215                for (tag, branch) in tag_data_processors.iter().enumerate() {
216                    // Fixup the to-variant jump address.
217                    into[to_branches + tag] = Insn::Goto(into.len() as u16);
218                    // Compile the branch.
219                    compile_tree(branch, into);
220                    // Add jump-out gotos that we'll fixup later to store the after-sum address.
221                    from_variant_gotos.push(into.len());
222                    into.push(Insn::FIXUP);
223                }
224                // Fixup the jump-out-from-variant addresses.
225                let goto_addr = into.len();
226                for idx in from_variant_gotos {
227                    into[idx] = Insn::Goto(goto_addr as u16);
228                }
229            }
230        }
231    }
232
233    compile_tree(tree, &mut program);
234    remove_trailing_gotos(&mut program);
235    program
236}
237
238/// Remove any trailing gotos.
239///
240/// They are not needed as they will only go towards the end,
241/// so we can just cut them out.
242fn remove_trailing_gotos(program: &mut Vec<Insn>) {
243    for idx in (0..program.len()).rev() {
244        match program[idx] {
245            Insn::Goto(_) => program.pop(),
246            _ => break,
247        };
248    }
249}
250
/// Operands of a tag check: where the tag byte is, and how many variants are valid.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct CheckTag {
    /// Byte offset of the tag, relative to the start of the BSATN row.
    tag_offset: u16,
    /// Number of variants in the sum; the tag read must satisfy `tag < num_variants`.
    num_variants: u8,
}
259
260/// The instruction set of a [`StaticBsatnValidator`].
261#[derive(Debug, Clone, Copy, PartialEq, Eq)]
262enum Insn {
263    /// Visit the byte at offset `start + N`
264    /// and assert that it is 0 or 1, i.e., a valid `bool`.
265    CheckBool(u16),
266
267    /// Read the `tag` at `start + tag_offset`
268    /// and validate that `tag < num_variants`.
269    CheckTag(CheckTag),
270
271    /// Read the `tag` at `start + tag_offset`
272    /// and validate that `tag < num_variants`.
273    /// Then move the instruction pointer forward by `tag + 1`.
274    /// The branch logic for the variant payload continues there.
275    CheckReadTagRelBranch(CheckTag),
276
277    /// Unconditionally branch to the instruction at `program[N]`
278    /// where `N > instruction pointer`.
279    Goto(u16),
280}
281
282impl Insn {
283    const FIXUP: Self = Self::Goto(u16::MAX);
284}
285
286impl MemoryUsage for Insn {}
287
288#[derive(Clone, Debug, PartialEq, Eq)]
289pub struct StaticBsatnValidator {
290    /// The list of instructions that make up this program.
291    insns: Arc<[Insn]>,
292}
293
294impl MemoryUsage for StaticBsatnValidator {
295    fn heap_usage(&self) -> usize {
296        let Self { insns } = self;
297        insns.heap_usage()
298    }
299}
300
301/// Check that `bytes[tag_offset] < num_variants`.
302///
303/// SAFETY: `tag_offset < bytes.len()`.
304unsafe fn check_tag(bytes: &[u8], check: CheckTag) -> Result<u8, DecodeError> {
305    // SAFETY: the caller has guaranteed that `tag_offset < bytes.len()`.
306    let tag = *unsafe { bytes.get_unchecked(check.tag_offset as usize) };
307    if tag < check.num_variants {
308        Ok(tag)
309    } else {
310        Err(DecodeError::InvalidTag { tag, sum_name: None })
311    }
312}
313
314/// Validates that `bytes`, encoded in BSATN,
315/// is valid according to the validation `program`
316/// and a corresponding `static_layout`,
317///
318/// # Safety
319///
320/// The caller must guarantee that
321/// all offsets in `program` are `< static_layout.bsatn_length`.
322pub(crate) unsafe fn validate_bsatn(
323    program: &StaticBsatnValidator,
324    static_layout: &StaticLayout,
325    bytes: &[u8],
326) -> Result<(), DecodeError> {
327    // Validate length of BSATN `bytes` against the expected length.
328    let expected = static_layout.bsatn_length as usize;
329    let given = bytes.len();
330    if expected != given {
331        return Err(DecodeError::InvalidLen { expected, given });
332    }
333
334    let program = &*program.insns;
335    let mut instr_ptr = 0;
336    loop {
337        match program.get(instr_ptr as usize).copied() {
338            None => break,
339            Some(Insn::CheckBool(offset)) => {
340                instr_ptr += 1;
341                // SAFETY: the caller has guaranteed
342                // that all offsets in `program` are `< expected`
343                // which we by now know is `= bytes.len()`.
344                let byte = *unsafe { bytes.get_unchecked(offset as usize) };
345                if byte > 1 {
346                    return Err(DecodeError::InvalidBool(byte));
347                }
348            }
349            Some(Insn::Goto(new_insn)) => instr_ptr = new_insn,
350            Some(Insn::CheckTag(check)) => {
351                // SAFETY: the caller has guaranteed
352                // that all offsets in `program` are `< expected`
353                // which we by now know is `= bytes.len()`.
354                unsafe { check_tag(bytes, check) }?;
355                instr_ptr += 1;
356            }
357            Some(Insn::CheckReadTagRelBranch(check)) => {
358                // SAFETY: the caller has guaranteed
359                // that all offsets in `program` are `< expected`
360                // which we by now know is `= bytes.len()`.
361                let tag = unsafe { check_tag(bytes, check) }?;
362                instr_ptr += tag as u16 + 1;
363            }
364        }
365    }
366
367    Ok(())
368}
369
#[cfg(test)]
pub mod test {
    use super::*;
    use crate::{
        bflatn_to::write_row_to_page, blob_store::HashMapBlobStore, page::Page, row_type_visitor::row_type_visitor,
    };
    use proptest::{prelude::*, prop_assert_eq, proptest};
    use spacetimedb_sats::bsatn::to_vec;
    use spacetimedb_sats::proptest::generate_typed_row;
    use spacetimedb_sats::{AlgebraicType, ProductType};

    proptest! {
        #![proptest_config(ProptestConfig {
            max_global_rejects: 65536,
            cases: if cfg!(miri) { 8 } else { 2048 },
            ..<_>::default()
        })]

        // The fast-path validator must agree with the slow path:
        // `validate_bsatn(..).is_ok() == write_row_to_page(..).is_ok()`.
        #[test]
        fn validation_same_as_write_row_to_pages((ty, val) in generate_typed_row()) {
            let ty: RowTypeLayout = ty.into();
            // Types with a var-len member, or with a sum whose payload lengths differ,
            // have no static layout, so the fast path doesn't apply to them.
            let static_layout = match StaticLayout::for_row_type(&ty) {
                Some(layout) => layout,
                None => return Err(TestCaseError::reject("Var-length type")),
            };

            let validator = static_bsatn_validator(&ty);
            let bsatn = to_vec(&val).unwrap();
            let validated = unsafe { validate_bsatn(&validator, &static_layout, &bsatn) };

            let mut page = Page::new(ty.size());
            let mut blob_store = HashMapBlobStore::default();
            let visitor = row_type_visitor(&ty);
            let written = unsafe { write_row_to_page(&mut page, &mut blob_store, &visitor, &ty, &val) };

            prop_assert_eq!(validated.is_ok(), written.is_ok());
        }

        // A lone `bool` column must reject every byte other than `0` and `1`.
        #[test]
        fn bad_bool_validates_to_error(byte in 2u8..) {
            let ty: RowTypeLayout = ProductType::from([AlgebraicType::Bool]).into();
            let static_layout = StaticLayout::for_row_type(&ty).unwrap();
            let validator = static_bsatn_validator(&ty);

            let result = unsafe { validate_bsatn(&validator, &static_layout, &[byte]) };
            prop_assert_eq!(result, Err(DecodeError::InvalidBool(byte)));
        }
    }
}