spacetimedb/
vm.rs

//! The [DbProgram] that executes arbitrary queries & code against the database.
2
3use crate::db::datastore::locking_tx_datastore::state_view::IterByColRangeMutTx;
4use crate::db::datastore::locking_tx_datastore::tx::TxId;
5use crate::db::datastore::locking_tx_datastore::IterByColRangeTx;
6use crate::db::datastore::system_tables::{st_var_schema, StVarName, StVarRow};
7use crate::db::relational_db::{MutTx, RelationalDB, Tx};
8use crate::error::DBError;
9use crate::estimation;
10use crate::execution_context::ExecutionContext;
11use core::ops::{Bound, RangeBounds};
12use itertools::Itertools;
13use spacetimedb_data_structures::map::IntMap;
14use spacetimedb_lib::identity::AuthCtx;
15use spacetimedb_lib::relation::{ColExpr, DbTable};
16use spacetimedb_primitives::*;
17use spacetimedb_sats::{AlgebraicValue, ProductValue};
18use spacetimedb_table::static_assert_size;
19use spacetimedb_table::table::RowRef;
20use spacetimedb_vm::errors::ErrorVm;
21use spacetimedb_vm::eval::{box_iter, build_project, build_select, join_inner, IterRows};
22use spacetimedb_vm::expr::*;
23use spacetimedb_vm::iterators::RelIter;
24use spacetimedb_vm::program::{ProgramVm, Sources};
25use spacetimedb_vm::rel_ops::{EmptyRelOps, RelOps};
26use spacetimedb_vm::relation::{MemTable, RelValue};
27use std::str::FromStr;
28use std::sync::Arc;
29
/// The transaction a [`DbProgram`] runs against.
///
/// Inserts, updates, deletes and `SET` of variables require [`TxMode::MutTx`];
/// they reach the transaction through `TxMode::unwrap_mut`, which panics on [`TxMode::Tx`].
pub enum TxMode<'a> {
    /// A mutable transaction.
    MutTx(&'a mut MutTx),
    /// A read-only transaction.
    Tx(&'a Tx),
}
34
35impl TxMode<'_> {
36    /// Unwraps `self`, ensuring we are in a mutable tx.
37    fn unwrap_mut(&mut self) -> &mut MutTx {
38        match self {
39            Self::MutTx(tx) => tx,
40            Self::Tx(_) => unreachable!("mutable operation is invalid with read tx"),
41        }
42    }
43
44    pub(crate) fn ctx(&self) -> &ExecutionContext {
45        match self {
46            Self::MutTx(tx) => &tx.ctx,
47            Self::Tx(tx) => &tx.ctx,
48        }
49    }
50}
51
52impl<'a> From<&'a mut MutTx> for TxMode<'a> {
53    fn from(tx: &'a mut MutTx) -> Self {
54        TxMode::MutTx(tx)
55    }
56}
57
58impl<'a> From<&'a Tx> for TxMode<'a> {
59    fn from(tx: &'a Tx) -> Self {
60        TxMode::Tx(tx)
61    }
62}
63
64impl<'a> From<&'a mut Tx> for TxMode<'a> {
65    fn from(tx: &'a mut Tx) -> Self {
66        TxMode::Tx(tx)
67    }
68}
69
/// Returns whether the range described by `(lower, upper)` can contain any value at all.
///
/// A range is unsatisfiable when its lower bound lies above its upper bound,
/// or when both bounds sit on the same value and at least one of them excludes it.
/// Index scans use this to bail out early instead of handing the bounds to a `BTreeMap`,
/// whose range methods panic when `start > end` (for any combination of bound kinds)
/// or when `start == end` with both bounds excluded.
fn bound_is_satisfiable<T: PartialOrd>(lower: &Bound<T>, upper: &Bound<T>) -> bool {
    match (lower, upper) {
        // An open end can never make the range empty.
        (Bound::Unbounded, _) | (_, Bound::Unbounded) => true,
        // `[l, u]` contains `l` whenever `l <= u`.
        // Previously `l > u` was accepted here, letting an inverted range reach the index.
        (Bound::Included(lower), Bound::Included(upper)) => lower <= upper,
        // With at least one end excluded, equal bounds leave nothing in between.
        (Bound::Included(lower), Bound::Excluded(upper))
        | (Bound::Excluded(lower), Bound::Included(upper))
        | (Bound::Excluded(lower), Bound::Excluded(upper)) => lower < upper,
    }
}
81
//TODO: This is partially duplicated from the `vm` crate to avoid borrow checker issues
//and pull all that crate in core. Will be revisited after trait refactor
/// Builds a row iterator that evaluates `query` within `tx`.
///
/// The operators in `query.query` are applied in order on top of the rows of `query.source`.
/// Physical tables are read through `db`; in-memory sources are taken from `sources`
/// (see `get_table`).
///
/// # Panics
///
/// Panics if an `IndexJoin` is not the first operator of the query,
/// if an in-memory source referenced by the plan is missing from `sources`,
/// or if a `table_id` in the plan is rejected by the datastore.
pub fn build_query<'a>(
    db: &'a RelationalDB,
    tx: &'a TxMode<'a>,
    query: &'a QueryExpr,
    sources: &mut impl SourceProvider<'a>,
) -> Box<IterRows<'a>> {
    // Only a physical source has real indexes that an `IndexScan` can use directly.
    let db_table = query.source.is_db_table();

    // We're incrementally building a query iterator by applying each operation in the `query.query`.
    // Most such operations will modify their parent, but certain operations (i.e. `IndexJoin`s)
    // are only valid as the first operation in the list,
    // and construct a new base query.
    //
    // Branches which use `result` call `result_or_base`, which takes `result` or,
    // if it is still `None`, falls back to `get_table(db, tx, &query.source, sources)`
    // to get an `IterRows` over `query.source`.
    //
    // Branches which do not use the `result` are expected to be the first operator.
    // `IndexJoin` asserts this by panicking when `result` is `Some`.
    // NOTE(review): the db-table `IndexScan` branch below does not check,
    // and would silently discard any earlier `result`.
    //
    // TODO(bikeshedding): Avoid duplication of the ugly `result.take().map(...).unwrap_or_else(...)?` expr?
    // TODO(bikeshedding): Refactor `QueryExpr` to separate `IndexJoin` from other `Query` variants,
    //   removing the need for this convoluted logic?
    let mut result = None;

    let result_or_base = |sources: &mut _, result: &mut Option<_>| {
        result
            .take()
            .unwrap_or_else(|| get_table(db, tx, &query.source, sources))
    };

    for op in &query.query {
        result = Some(match op {
            Query::IndexScan(IndexScan { table, columns, bounds }) if db_table => {
                if !bound_is_satisfiable(&bounds.0, &bounds.1) {
                    // If the bounds describe an empty range (see `bound_is_satisfiable`),
                    // return an empty iterator.
                    // This avoids a panic in `BTreeMap`'s `NodeRef::search_tree_for_bifurcation`,
                    // which is very unhappy about unsatisfiable bounds.
                    Box::new(EmptyRelOps) as Box<IterRows<'a>>
                } else {
                    let bounds = (bounds.start_bound(), bounds.end_bound());
                    iter_by_col_range(db, tx, table, columns.clone(), bounds)
                }
            }
            // The source is not a physical table, so there is no index to use:
            // evaluate the scan as a filter over the rows produced so far.
            Query::IndexScan(index_scan) => {
                let result = result_or_base(sources, &mut result);
                let cols = &index_scan.columns;
                let bounds = &index_scan.bounds;

                if !bound_is_satisfiable(&bounds.0, &bounds.1) {
                    // If the bounds describe an empty range (see `bound_is_satisfiable`),
                    // return an empty iterator.
                    // Unlike the above case, this is not necessary, as the below `select` will never panic,
                    // but it's still nice to avoid needlessly traversing a bunch of rows.
                    // TODO: We should change the compiler to not emit an `IndexScan` in this case,
                    // so that this branch is unreachable.
                    // The current behavior is a hack
                    // because this patch was written (2024-04-01 pgoldman) a short time before the BitCraft alpha,
                    // and a more invasive change was infeasible.
                    Box::new(EmptyRelOps) as Box<IterRows<'a>>
                } else if let Some(head) = cols.as_singleton() {
                    // For singleton constraints, we compare the column directly against `bounds`.
                    let head = head.idx();
                    let iter = result.select(move |row| bounds.contains(&*row.read_column(head).unwrap()));
                    Box::new(iter) as Box<IterRows<'a>>
                } else {
                    // For multi-col constraints, these are stored as bounds of product values,
                    // so we need to project these into single-col bounds and compare against the column.
                    // Project start/end `Bound<AV>`s to `Bound<Vec<AV>>`s.
                    let start_bound = bounds.0.as_ref().map(|av| &av.as_product().unwrap().elements);
                    let end_bound = bounds.1.as_ref().map(|av| &av.as_product().unwrap().elements);
                    // Construct the query:
                    Box::new(result.select(move |row| {
                        // Go through each column position,
                        // project to a `Bound<AV>` for the position,
                        // and compare against the column in the row.
                        // All columns must match to include the row,
                        // which is essentially the same as a big `AND` of `ColumnOp`s.
                        cols.iter().enumerate().all(|(idx, col)| {
                            let start_bound = start_bound.map(|pv| &pv[idx]);
                            let end_bound = end_bound.map(|pv| &pv[idx]);
                            let read_col = row.read_column(col.idx()).unwrap();
                            (start_bound, end_bound).contains(&*read_col)
                        })
                    }))
                }
            }
            Query::IndexJoin(_) if result.is_some() => panic!("Invalid query: `IndexJoin` must be the first operator"),
            Query::IndexJoin(IndexJoin {
                probe_side,
                probe_col,
                index_side,
                index_select,
                index_col,
                return_index_rows,
            }) => {
                let probe_side = build_query(db, tx, probe_side, sources);
                // The compiler guarantees that the index side is a db table,
                // and therefore this unwrap is always safe.
                let index_table = index_side.table_id().unwrap();

                // `return_index_rows` selects which side of the semi-join is yielded.
                if *return_index_rows {
                    index_semi_join_left(db, tx, probe_side, *probe_col, index_select, index_table, *index_col)
                } else {
                    index_semi_join_right(db, tx, probe_side, *probe_col, index_select, index_table, *index_col)
                }
            }
            Query::Select(cmp) => build_select(result_or_base(sources, &mut result), cmp),
            Query::Project(proj) => build_project(result_or_base(sources, &mut result), proj),
            Query::JoinInner(join) => join_inner(
                result_or_base(sources, &mut result),
                build_query(db, tx, &join.rhs, sources),
                join,
            ),
        })
    }

    // An empty operator list yields the base table itself.
    result_or_base(sources, &mut result)
}
205
206/// Resolve `query` to a table iterator,
207/// either taken from an in-memory table, in the case of [`SourceExpr::InMemory`],
208/// or from a physical table, in the case of [`SourceExpr::DbTable`].
209///
210/// If `query` refers to an in memory table,
211/// `sources` will be used to fetch the table `I`.
212/// Examples of `I` could be derived from `MemTable` or `&'a [ProductValue]`
213/// whereas `sources` could a [`SourceSet`].
214///
215/// On the other hand, if the `query` is a `SourceExpr::DbTable`, `sources` is unused.
216fn get_table<'a>(
217    stdb: &'a RelationalDB,
218    tx: &'a TxMode,
219    query: &'a SourceExpr,
220    sources: &mut impl SourceProvider<'a>,
221) -> Box<IterRows<'a>> {
222    match query {
223        // Extracts an in-memory table with `source_id` from `sources` and builds a query for the table.
224        SourceExpr::InMemory { source_id, .. } => build_iter(
225            sources
226                .take_source(*source_id)
227                .unwrap_or_else(|| {
228                    panic!("Query plan specifies in-mem table for {source_id:?}, but found a `DbTable` or nothing")
229                })
230                .into_iter(),
231        ),
232        SourceExpr::DbTable(db_table) => build_iter_from_db(match tx {
233            TxMode::MutTx(tx) => stdb.iter_mut(tx, db_table.table_id).map(box_iter),
234            TxMode::Tx(tx) => stdb.iter(tx, db_table.table_id).map(box_iter),
235        }),
236    }
237}
238
239fn iter_by_col_range<'a>(
240    db: &'a RelationalDB,
241    tx: &'a TxMode,
242    table: &'a DbTable,
243    columns: ColList,
244    range: impl RangeBounds<AlgebraicValue> + 'a,
245) -> Box<IterRows<'a>> {
246    build_iter_from_db(match tx {
247        TxMode::MutTx(tx) => db
248            .iter_by_col_range_mut(tx, table.table_id, columns, range)
249            .map(box_iter),
250        TxMode::Tx(tx) => db.iter_by_col_range(tx, table.table_id, columns, range).map(box_iter),
251    })
252}
253
254fn build_iter_from_db<'a>(iter: Result<impl 'a + Iterator<Item = RowRef<'a>>, DBError>) -> Box<IterRows<'a>> {
255    build_iter(iter.expect(TABLE_ID_EXPECTED_VALID).map(RelValue::Row))
256}
257
258fn build_iter<'a>(iter: impl 'a + Iterator<Item = RelValue<'a>>) -> Box<IterRows<'a>> {
259    Box::new(RelIter::new(iter)) as Box<IterRows<'_>>
260}
261
/// Panic message used when `expect`-ing datastore lookups made with `table_id`s
/// from a compiled query plan, which are assumed to be valid.
const TABLE_ID_EXPECTED_VALID: &str = "all `table_id`s in compiled query should be valid";
263
/// An index join operator that returns matching rows from the index side.
///
/// For every row produced by `probe_side`, the value in `probe_col` is used to look up
/// index-side rows through `index_function`; those that pass `index_select` are yielded.
pub struct IndexSemiJoinLeft<'c, Rhs, IndexIter, F> {
    /// An iterator for the probe side.
    /// The values returned will be used to probe the index.
    probe_side: Rhs,
    /// The column whose value will be used to probe the index.
    probe_col: ColId,
    /// An optional predicate to evaluate over the matching rows of the index.
    index_select: &'c Option<ColumnOp>,
    /// An iterator for the index side.
    /// A new iterator will be instantiated for each row on the probe side.
    /// It is `None` until some probe row has produced a match.
    index_iter: Option<IndexIter>,
    /// The function that returns an iterator for the index side.
    index_function: F,
}
279
280impl<'a, Rhs, IndexIter, F> IndexSemiJoinLeft<'_, Rhs, IndexIter, F>
281where
282    F: Fn(AlgebraicValue) -> Result<IndexIter, DBError>,
283    IndexIter: Iterator<Item = RowRef<'a>>,
284    Rhs: RelOps<'a>,
285{
286    fn filter(&self, index_row: &RelValue<'_>) -> bool {
287        self.index_select.as_ref().map_or(true, |op| op.eval_bool(index_row))
288    }
289}
290
impl<'a, Rhs, IndexIter, F> RelOps<'a> for IndexSemiJoinLeft<'_, Rhs, IndexIter, F>
where
    F: Fn(AlgebraicValue) -> Result<IndexIter, DBError>,
    IndexIter: Iterator<Item = RowRef<'a>>,
    Rhs: RelOps<'a>,
{
    /// Returns the next index-side row that matches some probe row and passes `index_select`.
    ///
    /// Remaining matches for the current probe value are drained first.
    /// Only then is the next probe row pulled from `probe_side`.
    ///
    /// # Panics
    ///
    /// Panics if an index lookup fails (see `TABLE_ID_EXPECTED_VALID`).
    fn next(&mut self) -> Option<RelValue<'a>> {
        // Return a value from the current index iterator, if not exhausted.
        while let Some(index_row) = self.index_iter.as_mut().and_then(|iter| iter.next()).map(RelValue::Row) {
            if self.filter(&index_row) {
                return Some(index_row);
            }
        }

        // Otherwise probe the index with a row from the probe side.
        let probe_col = self.probe_col.idx();
        while let Some(mut row) = self.probe_side.next() {
            // Probe rows without a value at `probe_col` are skipped.
            // The probe row is discarded afterwards, so its column may be taken rather than cloned.
            if let Some(value) = row.read_or_take_column(probe_col) {
                let mut index_iter = (self.index_function)(value).expect(TABLE_ID_EXPECTED_VALID);
                while let Some(index_row) = index_iter.next().map(RelValue::Row) {
                    if self.filter(&index_row) {
                        // Stash the partially consumed iterator so that the remaining matches
                        // for this probe value are returned by subsequent calls.
                        self.index_iter = Some(index_iter);
                        return Some(index_row);
                    }
                }
            }
        }
        None
    }
}
321
322/// Return an iterator index join operator that returns matching rows from the index side.
323pub fn index_semi_join_left<'a>(
324    db: &'a RelationalDB,
325    tx: &'a TxMode<'a>,
326    probe_side: Box<IterRows<'a>>,
327    probe_col: ColId,
328    index_select: &'a Option<ColumnOp>,
329    index_table: TableId,
330    index_col: ColId,
331) -> Box<IterRows<'a>> {
332    match tx {
333        TxMode::MutTx(tx) => Box::new(IndexSemiJoinLeft {
334            probe_side,
335            probe_col,
336            index_select,
337            index_iter: None,
338            index_function: move |value| db.iter_by_col_range_mut(tx, index_table, index_col, value),
339        }),
340        TxMode::Tx(tx) => Box::new(IndexSemiJoinLeft {
341            probe_side,
342            probe_col,
343            index_select,
344            index_iter: None,
345            index_function: move |value| db.iter_by_col_range(tx, index_table, index_col, value),
346        }),
347    }
348}
349
// Size checks for `IndexSemiJoinLeft`, one per transaction kind's index iterator.
// NOTE(review): the declaration is `IndexSemiJoinLeft<'c, Rhs, IndexIter, F>`,
// but below the `fn` type is passed in the `IndexIter` position and the iterator type as `F`,
// i.e. the last two type arguments look swapped.
// The real `F` built by `index_semi_join_left` is a closure, not a `fn` pointer.
// So these sizes may not describe the instantiation used at runtime — verify before relying on them.
static_assert_size!(
    IndexSemiJoinLeft<
        Box<IterRows<'static>>,
        fn(AlgebraicValue) -> Result<IterByColRangeTx<'static, AlgebraicValue>, DBError>,
        IterByColRangeTx<'static, AlgebraicValue>,
    >,
    144
);
static_assert_size!(
    IndexSemiJoinLeft<
        Box<IterRows<'static>>,
        fn(AlgebraicValue) -> Result<IterByColRangeMutTx<'static, AlgebraicValue>, DBError>,
        IterByColRangeMutTx<'static, AlgebraicValue>,
    >,
    240
);
366
/// An index join operator that returns matching rows from the probe side.
///
/// A probe row is yielded when at least one index-side row matching its `probe_col` value
/// passes `index_select`. The index rows themselves are not returned.
pub struct IndexSemiJoinRight<'c, Rhs: RelOps<'c>, F> {
    /// An iterator for the probe side.
    /// The values returned will be used to probe the index.
    probe_side: Rhs,
    /// The column whose value will be used to probe the index.
    probe_col: ColId,
    /// An optional predicate to evaluate over the matching rows of the index.
    index_select: &'c Option<ColumnOp>,
    /// A function that returns an iterator for the index side.
    /// It is called for each probe row that has a value at `probe_col`.
    index_function: F,
}
379
380impl<'a, Rhs: RelOps<'a>, F, IndexIter> IndexSemiJoinRight<'a, Rhs, F>
381where
382    F: Fn(AlgebraicValue) -> Result<IndexIter, DBError>,
383    IndexIter: Iterator<Item = RowRef<'a>>,
384{
385    fn filter(&self, index_row: &RelValue<'_>) -> bool {
386        self.index_select.as_ref().map_or(true, |op| op.eval_bool(index_row))
387    }
388}
389
390impl<'a, Rhs: RelOps<'a>, F, IndexIter> RelOps<'a> for IndexSemiJoinRight<'a, Rhs, F>
391where
392    F: Fn(AlgebraicValue) -> Result<IndexIter, DBError>,
393    IndexIter: Iterator<Item = RowRef<'a>>,
394{
395    fn next(&mut self) -> Option<RelValue<'a>> {
396        // Otherwise probe the index with a row from the probe side.
397        let probe_col = self.probe_col.idx();
398        while let Some(mut row) = self.probe_side.next() {
399            if let Some(value) = row.read_or_take_column(probe_col) {
400                let mut index_iter = (self.index_function)(value).expect(TABLE_ID_EXPECTED_VALID);
401                while let Some(index_row) = index_iter.next().map(RelValue::Row) {
402                    if self.filter(&index_row) {
403                        return Some(row);
404                    }
405                }
406            }
407        }
408        None
409    }
410}
411
412/// Return an iterator index join operator that returns matching rows from the probe side.
413pub fn index_semi_join_right<'a>(
414    db: &'a RelationalDB,
415    tx: &'a TxMode<'a>,
416    probe_side: Box<IterRows<'a>>,
417    probe_col: ColId,
418    index_select: &'a Option<ColumnOp>,
419    index_table: TableId,
420    index_col: ColId,
421) -> Box<IterRows<'a>> {
422    match tx {
423        TxMode::MutTx(tx) => Box::new(IndexSemiJoinRight {
424            probe_side,
425            probe_col,
426            index_select,
427            index_function: move |value| db.iter_by_col_range_mut(tx, index_table, index_col, value),
428        }),
429        TxMode::Tx(tx) => Box::new(IndexSemiJoinRight {
430            probe_side,
431            probe_col,
432            index_select,
433            index_function: move |value| db.iter_by_col_range(tx, index_table, index_col, value),
434        }),
435    }
436}
// Size checks for `IndexSemiJoinRight`, one per transaction kind's index iterator.
// NOTE(review): these use a `fn` pointer for `F`, whereas `index_semi_join_right` builds a
// capturing closure; the asserted sizes therefore describe the `fn`-pointer instantiation only.
static_assert_size!(
    IndexSemiJoinRight<
        Box<IterRows<'static>>,
        fn(AlgebraicValue) -> Result<IterByColRangeTx<'static, AlgebraicValue>, DBError>,
    >,
    40
);
static_assert_size!(
    IndexSemiJoinRight<
        Box<IterRows<'static>>,
        fn(AlgebraicValue) -> Result<IterByColRangeMutTx<'static, AlgebraicValue>, DBError>,
    >,
    40
);
451
/// A [ProgramVm] implementation that carries a [RelationalDB] and a transaction
/// for query execution.
pub struct DbProgram<'db, 'tx> {
    /// The database that queries are executed against.
    pub(crate) db: &'db RelationalDB,
    /// The transaction that every query of this program runs in.
    /// Mutating expressions require [`TxMode::MutTx`].
    pub(crate) tx: &'tx mut TxMode<'tx>,
    /// The caller and owner identities, used for authorization and row-limit checks.
    pub(crate) auth: AuthCtx,
}
459
460/// If the subscriber is not the database owner,
461/// reject the request if the estimated cardinality exceeds the limit.
462pub fn check_row_limit<Query>(
463    queries: &[Query],
464    db: &RelationalDB,
465    tx: &TxId,
466    row_est: impl Fn(&Query, &TxId) -> u64,
467    auth: &AuthCtx,
468) -> Result<(), DBError> {
469    if auth.caller != auth.owner {
470        if let Some(limit) = db.row_limit(tx)? {
471            let mut estimate: u64 = 0;
472            for query in queries {
473                estimate = estimate.saturating_add(row_est(query, tx));
474            }
475            if estimate > limit {
476                return Err(DBError::Other(anyhow::anyhow!(
477                    "Estimated cardinality ({estimate} rows) exceeds limit ({limit} rows)"
478                )));
479            }
480        }
481    }
482    Ok(())
483}
484
485impl<'db, 'tx> DbProgram<'db, 'tx> {
486    pub fn new(db: &'db RelationalDB, tx: &'tx mut TxMode<'tx>, auth: AuthCtx) -> Self {
487        Self { db, tx, auth }
488    }
489
490    fn _eval_query<const N: usize>(&mut self, query: &QueryExpr, sources: Sources<'_, N>) -> Result<Code, ErrorVm> {
491        if let TxMode::Tx(tx) = self.tx {
492            check_row_limit(
493                &[query],
494                self.db,
495                tx,
496                |expr, tx| estimation::num_rows(tx, expr),
497                &self.auth,
498            )?;
499        }
500
501        let table_access = query.source.table_access();
502        tracing::trace!(table = query.source.table_name());
503
504        let head = query.head().clone();
505        let rows = build_query(self.db, self.tx, query, &mut |id| {
506            sources.take(id).map(|mt| mt.into_iter().map(RelValue::Projection))
507        })
508        .collect_vec(|row| row.into_product_value());
509
510        Ok(Code::Table(MemTable::new(head, table_access, rows)))
511    }
512
513    // TODO(centril): investigate taking bsatn as input instead.
514    fn _execute_insert(&mut self, table: &DbTable, inserts: Vec<ProductValue>) -> Result<Code, ErrorVm> {
515        let tx = self.tx.unwrap_mut();
516        let mut scratch = Vec::new();
517        for row in &inserts {
518            row.encode(&mut scratch);
519            self.db.insert(tx, table.table_id, &scratch)?;
520            scratch.clear();
521        }
522        Ok(Code::Pass(Some(Update {
523            table_id: table.table_id,
524            table_name: table.head.table_name.clone(),
525            inserts,
526            deletes: Vec::default(),
527        })))
528    }
529
530    fn _execute_update<const N: usize>(
531        &mut self,
532        delete: &QueryExpr,
533        mut assigns: IntMap<ColId, ColExpr>,
534        sources: Sources<'_, N>,
535    ) -> Result<Code, ErrorVm> {
536        let result = self._eval_query(delete, sources)?;
537        let Code::Table(deleted) = result else {
538            return Ok(result);
539        };
540
541        let table = delete
542            .source
543            .get_db_table()
544            .expect("source for Update should be a DbTable");
545
546        self._execute_delete(table, deleted.data.clone())?;
547
548        // Replace the columns in the matched rows with the assigned
549        // values. No typechecking is performed here, nor that all
550        // assignments are consumed.
551        let deletes = deleted.data.clone();
552        let exprs: Vec<Option<ColExpr>> = (0..table.head.fields.len())
553            .map(ColId::from)
554            .map(|c| assigns.remove(&c))
555            .collect();
556
557        let insert_rows = deleted
558            .data
559            .into_iter()
560            .map(|row| {
561                let elements = row
562                    .into_iter()
563                    .zip(&exprs)
564                    .map(|(val, expr)| {
565                        if let Some(ColExpr::Value(assigned)) = expr {
566                            assigned.clone()
567                        } else {
568                            val
569                        }
570                    })
571                    .collect();
572
573                ProductValue { elements }
574            })
575            .collect_vec();
576
577        let result = self._execute_insert(table, insert_rows);
578        let Ok(Code::Pass(Some(insert))) = result else {
579            return result;
580        };
581
582        Ok(Code::Pass(Some(Update { deletes, ..insert })))
583    }
584
585    fn _execute_delete(&mut self, table: &DbTable, rows: Vec<ProductValue>) -> Result<Code, ErrorVm> {
586        let deletes = rows.clone();
587        self.db.delete_by_rel(self.tx.unwrap_mut(), table.table_id, rows);
588
589        Ok(Code::Pass(Some(Update {
590            table_id: table.table_id,
591            table_name: table.head.table_name.clone(),
592            inserts: Vec::default(),
593            deletes,
594        })))
595    }
596
597    fn _delete_query<const N: usize>(&mut self, query: &QueryExpr, sources: Sources<'_, N>) -> Result<Code, ErrorVm> {
598        match self._eval_query(query, sources)? {
599            Code::Table(result) => self._execute_delete(query.source.get_db_table().unwrap(), result.data),
600            r => Ok(r),
601        }
602    }
603
604    fn _set_var(&mut self, name: String, literal: String) -> Result<Code, ErrorVm> {
605        let tx = self.tx.unwrap_mut();
606        self.db.write_var(tx, StVarName::from_str(&name)?, &literal)?;
607        Ok(Code::Pass(None))
608    }
609
610    fn _read_var(&self, name: String) -> Result<Code, ErrorVm> {
611        fn read_key_into_table(env: &DbProgram, name: &str) -> Result<MemTable, ErrorVm> {
612            if let TxMode::Tx(tx) = &env.tx {
613                let name = StVarName::from_str(name)?;
614                if let Some(value) = env.db.read_var(tx, name)? {
615                    return Ok(MemTable::from_iter(
616                        Arc::new(st_var_schema().into()),
617                        [ProductValue::from(StVarRow { name, value })],
618                    ));
619                }
620            }
621            Ok(MemTable::from_iter(Arc::new(st_var_schema().into()), []))
622        }
623        Ok(Code::Table(read_key_into_table(self, &name)?))
624    }
625}
626
627impl ProgramVm for DbProgram<'_, '_> {
628    // Safety: For DbProgram with tx = TxMode::Tx variant, all queries must match to CrudCode::Query and no other branch.
629    fn eval_query<const N: usize>(&mut self, query: CrudExpr, sources: Sources<'_, N>) -> Result<Code, ErrorVm> {
630        query.check_auth(self.auth.owner, self.auth.caller)?;
631
632        match query {
633            CrudExpr::Query(query) => self._eval_query(&query, sources),
634            CrudExpr::Insert { table, rows } => self._execute_insert(&table, rows),
635            CrudExpr::Update { delete, assignments } => self._execute_update(&delete, assignments, sources),
636            CrudExpr::Delete { query } => self._delete_query(&query, sources),
637            CrudExpr::SetVar { name, literal } => self._set_var(name, literal),
638            CrudExpr::ReadVar { name } => self._read_var(name),
639        }
640    }
641}
642
643#[cfg(test)]
644pub(crate) mod tests {
645    use super::*;
646    use crate::db::datastore::system_tables::{
647        StColumnFields, StColumnRow, StFields as _, StIndexAlgorithm, StIndexFields, StIndexRow, StSequenceFields,
648        StSequenceRow, StTableFields, StTableRow, ST_COLUMN_ID, ST_COLUMN_NAME, ST_INDEX_ID, ST_INDEX_NAME,
649        ST_RESERVED_SEQUENCE_RANGE, ST_SEQUENCE_ID, ST_SEQUENCE_NAME, ST_TABLE_ID, ST_TABLE_NAME,
650    };
651    use crate::db::relational_db::tests_utils::{begin_tx, insert, with_auto_commit, with_read_only, TestDB};
652    use pretty_assertions::assert_eq;
653    use spacetimedb_lib::db::auth::{StAccess, StTableType};
654    use spacetimedb_lib::error::ResultTest;
655    use spacetimedb_lib::relation::{FieldName, Header};
656    use spacetimedb_sats::{product, AlgebraicType, ProductType, ProductValue};
657    use spacetimedb_schema::def::{BTreeAlgorithm, IndexAlgorithm};
658    use spacetimedb_schema::schema::{ColumnSchema, IndexSchema, TableSchema};
659    use spacetimedb_vm::eval::run_ast;
660    use spacetimedb_vm::eval::test_helpers::{mem_table, mem_table_one_u64, scalar};
661    use spacetimedb_vm::operator::OpCmp;
662    use std::sync::Arc;
663
664    pub(crate) fn create_table_with_rows(
665        db: &RelationalDB,
666        tx: &mut MutTx,
667        table_name: &str,
668        schema: ProductType,
669        rows: &[ProductValue],
670        access: StAccess,
671    ) -> ResultTest<Arc<TableSchema>> {
672        let columns = schema
673            .elements
674            .iter()
675            .enumerate()
676            .map(|(i, element)| ColumnSchema {
677                table_id: TableId::SENTINEL,
678                col_name: element.name.as_ref().unwrap().clone(),
679                col_type: element.algebraic_type.clone(),
680                col_pos: ColId(i as _),
681            })
682            .collect();
683
684        let table_id = db.create_table(
685            tx,
686            TableSchema::new(
687                TableId::SENTINEL,
688                table_name.into(),
689                columns,
690                vec![],
691                vec![],
692                vec![],
693                StTableType::User,
694                access,
695                None,
696                None,
697            ),
698        )?;
699        let schema = db.schema_for_table_mut(tx, table_id)?;
700
701        for row in rows {
702            insert(db, tx, table_id, &row)?;
703        }
704
705        Ok(schema)
706    }
707
708    /// Creates a table "inventory" with `(inventory_id: u64, name : String)` as columns.
709    fn create_inv_table(db: &RelationalDB, tx: &mut MutTx) -> ResultTest<(Arc<TableSchema>, ProductValue)> {
710        let schema_ty = ProductType::from([("inventory_id", AlgebraicType::U64), ("name", AlgebraicType::String)]);
711        let row = product!(1u64, "health");
712        let schema = create_table_with_rows(db, tx, "inventory", schema_ty.clone(), &[row.clone()], StAccess::Public)?;
713        Ok((schema, row))
714    }
715
716    fn run_query<const N: usize>(
717        db: &RelationalDB,
718        q: QueryExpr,
719        sources: SourceSet<Vec<ProductValue>, N>,
720    ) -> MemTable {
721        with_read_only(db, |tx| {
722            let mut tx_mode = (&*tx).into();
723            let p = &mut DbProgram::new(db, &mut tx_mode, AuthCtx::for_testing());
724            match run_ast(p, q.into(), sources) {
725                Code::Table(x) => x,
726                x => panic!("invalid result {x}"),
727            }
728        })
729    }
730
    /// Inner-joins the "inventory" table with a one-column in-memory table on column 0 of each side.
    /// Checks that the joined row is the inventory row followed by the in-memory column.
    #[test]
    fn test_db_query_inner_join() -> ResultTest<()> {
        let stdb = TestDB::durable()?;

        let (schema, _) = with_auto_commit(&stdb, |tx| create_inv_table(&stdb, tx))?;
        let table_id = schema.table_id;

        let data = mem_table_one_u64(u32::MAX.into());
        let mut sources = SourceSet::<_, 1>::empty();
        let rhs_source_expr = sources.add_mem_table(data);
        // The final `false` requests a full inner join rather than a semi-join.
        let q = QueryExpr::new(&*schema).with_join_inner(rhs_source_expr, 0.into(), 0.into(), false);
        let result = run_query(&stdb, q, sources);

        // The expected result.
        let inv = ProductType::from([AlgebraicType::U64, AlgebraicType::String, AlgebraicType::U64]);
        let row = product![1u64, "health", 1u64];
        let input = mem_table(table_id, inv, vec![row]);

        assert_eq!(result.data, input.data, "Inventory");

        Ok(())
    }
753
    /// Same join as `test_db_query_inner_join`, but as a semi-join.
    /// Only the matching inventory row itself is expected back.
    #[test]
    fn test_db_query_semijoin() -> ResultTest<()> {
        let stdb = TestDB::durable()?;

        let (schema, row) = with_auto_commit(&stdb, |tx| create_inv_table(&stdb, tx))?;

        let data = mem_table_one_u64(u32::MAX.into());
        let mut sources = SourceSet::<_, 1>::empty();
        let rhs_source_expr = sources.add_mem_table(data);
        // The final `true` requests a semi-join: only left-hand columns are returned.
        let q = QueryExpr::new(&*schema).with_join_inner(rhs_source_expr, 0.into(), 0.into(), true);
        let result = run_query(&stdb, q, sources);

        // The expected result.
        let input = mem_table(schema.table_id, schema.get_row_type().clone(), vec![row]);
        assert_eq!(result.data, input.data, "Inventory");

        Ok(())
    }
772
773    fn check_catalog(db: &RelationalDB, name: &str, row: ProductValue, q: QueryExpr, schema: &TableSchema) {
774        let result = run_query(db, q, [].into());
775        let input = MemTable::from_iter(Header::from(schema).into(), [row]);
776        assert_eq!(result, input, "{}", name);
777    }
778
779    #[test]
780    fn test_query_catalog_tables() -> ResultTest<()> {
781        let stdb = TestDB::durable()?;
782        let schema = &*stdb.schema_for_table(&begin_tx(&stdb), ST_TABLE_ID).unwrap();
783
784        let q = QueryExpr::new(schema)
785            .with_select_cmp(
786                OpCmp::Eq,
787                FieldName::new(ST_TABLE_ID, StTableFields::TableName.into()),
788                scalar(ST_TABLE_NAME),
789            )
790            .unwrap();
791        let st_table_row = StTableRow {
792            table_id: ST_TABLE_ID,
793            table_name: ST_TABLE_NAME.into(),
794            table_type: StTableType::System,
795            table_access: StAccess::Public,
796            table_primary_key: Some(StTableFields::TableId.into()),
797        }
798        .into();
799        check_catalog(&stdb, ST_TABLE_NAME, st_table_row, q, schema);
800
801        Ok(())
802    }
803
804    #[test]
805    fn test_query_catalog_columns() -> ResultTest<()> {
806        let stdb = TestDB::durable()?;
807        let schema = &*stdb.schema_for_table(&begin_tx(&stdb), ST_COLUMN_ID).unwrap();
808
809        let q = QueryExpr::new(schema)
810            .with_select_cmp(
811                OpCmp::Eq,
812                FieldName::new(ST_COLUMN_ID, StColumnFields::TableId.into()),
813                scalar(ST_COLUMN_ID),
814            )
815            .unwrap()
816            .with_select_cmp(
817                OpCmp::Eq,
818                FieldName::new(ST_COLUMN_ID, StColumnFields::ColPos.into()),
819                scalar(StColumnFields::TableId as u16),
820            )
821            .unwrap();
822        let st_column_row = StColumnRow {
823            table_id: ST_COLUMN_ID,
824            col_pos: StColumnFields::TableId.col_id(),
825            col_name: StColumnFields::TableId.col_name(),
826            col_type: AlgebraicType::U32.into(),
827        }
828        .into();
829        check_catalog(&stdb, ST_COLUMN_NAME, st_column_row, q, schema);
830
831        Ok(())
832    }
833
834    #[test]
835    fn test_query_catalog_indexes() -> ResultTest<()> {
836        let db = TestDB::durable()?;
837
838        let (schema, _) = with_auto_commit(&db, |tx| create_inv_table(&db, tx))?;
839        let table_id = schema.table_id;
840        let columns = ColList::from(ColId(0));
841        let index_name = "idx_1";
842        let is_unique = false;
843
844        let index = IndexSchema {
845            table_id,
846            index_id: IndexId::SENTINEL,
847            index_name: index_name.into(),
848            index_algorithm: IndexAlgorithm::BTree(BTreeAlgorithm {
849                columns: columns.clone(),
850            }),
851        };
852        let index_id = with_auto_commit(&db, |tx| db.create_index(tx, index, is_unique))?;
853
854        let indexes_schema = &*db.schema_for_table(&begin_tx(&db), ST_INDEX_ID).unwrap();
855        let q = QueryExpr::new(indexes_schema)
856            .with_select_cmp(
857                OpCmp::Eq,
858                FieldName::new(ST_INDEX_ID, StIndexFields::IndexName.into()),
859                scalar(index_name),
860            )
861            .unwrap();
862
863        let st_index_row = StIndexRow {
864            index_id,
865            index_name: index_name.into(),
866            table_id,
867            index_algorithm: StIndexAlgorithm::BTree { columns },
868        }
869        .into();
870        check_catalog(&db, ST_INDEX_NAME, st_index_row, q, indexes_schema);
871
872        Ok(())
873    }
874
875    #[test]
876    fn test_query_catalog_sequences() -> ResultTest<()> {
877        let db = TestDB::durable()?;
878
879        let schema = &*db.schema_for_table(&begin_tx(&db), ST_SEQUENCE_ID).unwrap();
880        let q = QueryExpr::new(schema)
881            .with_select_cmp(
882                OpCmp::Eq,
883                FieldName::new(ST_SEQUENCE_ID, StSequenceFields::TableId.into()),
884                scalar(ST_SEQUENCE_ID),
885            )
886            .unwrap();
887        let st_sequence_row = StSequenceRow {
888            sequence_id: 5.into(),
889            sequence_name: "st_sequence_sequence_id_seq".into(),
890            table_id: ST_SEQUENCE_ID,
891            col_pos: 0.into(),
892            increment: 1,
893            start: ST_RESERVED_SEQUENCE_RANGE as i128 + 1,
894            min_value: 1,
895            max_value: i128::MAX,
896            allocated: ST_RESERVED_SEQUENCE_RANGE as i128,
897        }
898        .into();
899        check_catalog(&db, ST_SEQUENCE_NAME, st_sequence_row, q, schema);
900
901        Ok(())
902    }
903}