spacetimedb_vm/
expr.rs

1use crate::errors::{ErrorKind, ErrorLang};
2use crate::operator::{OpCmp, OpLogic, OpQuery};
3use crate::relation::{MemTable, RelValue};
4use arrayvec::ArrayVec;
5use core::slice::from_ref;
6use derive_more::From;
7use itertools::Itertools;
8use smallvec::SmallVec;
9use spacetimedb_data_structures::map::{HashSet, IntMap};
10use spacetimedb_lib::db::auth::{StAccess, StTableType};
11use spacetimedb_lib::db::error::{AuthError, RelationError};
12use spacetimedb_lib::relation::{ColExpr, DbTable, FieldName, Header};
13use spacetimedb_lib::{AlgebraicType, Identity};
14use spacetimedb_primitives::*;
15use spacetimedb_sats::algebraic_value::AlgebraicValue;
16use spacetimedb_sats::satn::Satn;
17use spacetimedb_sats::ProductValue;
18use spacetimedb_schema::schema::TableSchema;
19use std::borrow::Cow;
20use std::cmp::Reverse;
21use std::collections::btree_map::Entry;
22use std::collections::BTreeMap;
23use std::ops::Bound;
24use std::sync::Arc;
25use std::{fmt, iter, mem};
26
27/// Trait for checking if the `caller` have access to `Self`
28pub trait AuthAccess {
29    fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError>;
30}
31
32#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, From)]
33pub enum FieldExpr {
34    Name(FieldName),
35    Value(AlgebraicValue),
36}
37
38impl FieldExpr {
39    pub fn strip_table(self) -> ColExpr {
40        match self {
41            Self::Name(field) => ColExpr::Col(field.col),
42            Self::Value(value) => ColExpr::Value(value),
43        }
44    }
45
46    pub fn name_to_col(self, head: &Header) -> Result<ColExpr, RelationError> {
47        match self {
48            Self::Value(val) => Ok(ColExpr::Value(val)),
49            Self::Name(field) => head.column_pos_or_err(field).map(ColExpr::Col),
50        }
51    }
52}
53
54impl fmt::Display for FieldExpr {
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        match self {
57            FieldExpr::Name(x) => write!(f, "{x}"),
58            FieldExpr::Value(x) => write!(f, "{}", x.to_satn()),
59        }
60    }
61}
62
63#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, From)]
64pub enum FieldOp {
65    #[from]
66    Field(FieldExpr),
67    Cmp {
68        op: OpQuery,
69        lhs: Box<FieldOp>,
70        rhs: Box<FieldOp>,
71    },
72}
73
74type FieldOpFlat = SmallVec<[FieldOp; 1]>;
75
76impl FieldOp {
77    pub fn new(op: OpQuery, lhs: Self, rhs: Self) -> Self {
78        Self::Cmp {
79            op,
80            lhs: Box::new(lhs),
81            rhs: Box::new(rhs),
82        }
83    }
84
85    pub fn cmp(field: impl Into<FieldName>, op: OpCmp, value: impl Into<AlgebraicValue>) -> Self {
86        Self::new(
87            OpQuery::Cmp(op),
88            Self::Field(FieldExpr::Name(field.into())),
89            Self::Field(FieldExpr::Value(value.into())),
90        )
91    }
92
93    pub fn names_to_cols(self, head: &Header) -> Result<ColumnOp, RelationError> {
94        match self {
95            Self::Field(field) => field.name_to_col(head).map(ColumnOp::from),
96            Self::Cmp { op, lhs, rhs } => {
97                let lhs = lhs.names_to_cols(head)?;
98                let rhs = rhs.names_to_cols(head)?;
99                Ok(ColumnOp::new(op, lhs, rhs))
100            }
101        }
102    }
103
104    /// Flattens a nested conjunction of AND expressions.
105    ///
106    /// For example, `a = 1 AND b = 2 AND c = 3` becomes `[a = 1, b = 2, c = 3]`.
107    ///
108    /// This helps with splitting the kinds of `queries`,
109    /// that *could* be answered by a `index`,
110    /// from the ones that need to be executed with a `scan`.
111    pub fn flatten_ands(self) -> FieldOpFlat {
112        fn fill_vec(buf: &mut FieldOpFlat, op: FieldOp) {
113            match op {
114                FieldOp::Cmp {
115                    op: OpQuery::Logic(OpLogic::And),
116                    lhs,
117                    rhs,
118                } => {
119                    fill_vec(buf, *lhs);
120                    fill_vec(buf, *rhs);
121                }
122                op => buf.push(op),
123            }
124        }
125        let mut buf = SmallVec::new();
126        fill_vec(&mut buf, self);
127        buf
128    }
129}
130
131impl fmt::Display for FieldOp {
132    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133        match self {
134            Self::Field(x) => {
135                write!(f, "{}", x)
136            }
137            Self::Cmp { op, lhs, rhs } => {
138                write!(f, "{} {} {}", lhs, op, rhs)
139            }
140        }
141    }
142}
143
144#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, From)]
145pub enum ColumnOp {
146    /// The value is the the column at `to_index(col)` in the row, i.e., `row.read_column(to_index(col))`.
147    #[from]
148    Col(ColId),
149    /// The value is the embedded value.
150    #[from]
151    Val(AlgebraicValue),
152    /// The value is `eval_cmp(cmp, row.read_column(to_index(lhs)), rhs)`.
153    /// This is an optimized version of `Cmp`, avoiding one depth of nesting.
154    ColCmpVal {
155        lhs: ColId,
156        cmp: OpCmp,
157        rhs: AlgebraicValue,
158    },
159    /// The value is `eval_cmp(cmp, eval(row, lhs), eval(row, rhs))`.
160    Cmp {
161        lhs: Box<ColumnOp>,
162        cmp: OpCmp,
163        rhs: Box<ColumnOp>,
164    },
165    /// Let `conds = eval(row, operands_i)`.
166    /// For `op = OpLogic::And`, the value is `all(conds)`.
167    /// For `op = OpLogic::Or`, the value is `any(conds)`.
168    Log { op: OpLogic, operands: Box<[ColumnOp]> },
169}
170
171impl ColumnOp {
172    pub fn new(op: OpQuery, lhs: Self, rhs: Self) -> Self {
173        match op {
174            OpQuery::Cmp(cmp) => match (lhs, rhs) {
175                (ColumnOp::Col(lhs), ColumnOp::Val(rhs)) => Self::cmp(lhs, cmp, rhs),
176                (lhs, rhs) => Self::Cmp {
177                    lhs: Box::new(lhs),
178                    cmp,
179                    rhs: Box::new(rhs),
180                },
181            },
182            OpQuery::Logic(op) => Self::Log {
183                op,
184                operands: [lhs, rhs].into(),
185            },
186        }
187    }
188
189    pub fn cmp(col: impl Into<ColId>, cmp: OpCmp, val: impl Into<AlgebraicValue>) -> Self {
190        let lhs = col.into();
191        let rhs = val.into();
192        Self::ColCmpVal { lhs, cmp, rhs }
193    }
194
195    /// Returns a new op where `lhs` and `rhs` are logically AND-ed together.
196    fn and(lhs: Self, rhs: Self) -> Self {
197        let ands = |operands| {
198            let op = OpLogic::And;
199            Self::Log { op, operands }
200        };
201
202        match (lhs, rhs) {
203            // Merge a pair of ⋀ into a single ⋀.
204            (
205                Self::Log {
206                    op: OpLogic::And,
207                    operands: lhs,
208                },
209                Self::Log {
210                    op: OpLogic::And,
211                    operands: rhs,
212                },
213            ) => {
214                let mut operands = Vec::from(lhs);
215                operands.append(&mut Vec::from(rhs));
216                ands(operands.into())
217            }
218            // Merge ⋀ with a single operand.
219            (
220                Self::Log {
221                    op: OpLogic::And,
222                    operands: lhs,
223                },
224                rhs,
225            ) => {
226                let mut operands = Vec::from(lhs);
227                operands.push(rhs);
228                ands(operands.into())
229            }
230            // And together lhs and rhs.
231            (lhs, rhs) => ands([lhs, rhs].into()),
232        }
233    }
234
235    /// Returns an op where `col_i op value_i` are all `AND`ed together.
236    fn and_cmp(op: OpCmp, cols: &ColList, value: AlgebraicValue) -> Self {
237        let cmp = |(col, value): (ColId, _)| Self::cmp(col, op, value);
238
239        // For singleton constraints, the `value` must be used directly.
240        if let Some(head) = cols.as_singleton() {
241            return cmp((head, value));
242        }
243
244        // Otherwise, pair column ids and product fields together.
245        let operands = cols.iter().zip(value.into_product().unwrap()).map(cmp).collect();
246        Self::Log {
247            op: OpLogic::And,
248            operands,
249        }
250    }
251
252    /// Returns an op where `cols` must be within bounds.
253    /// This handles both the case of single-col bounds and multi-col bounds.
254    fn from_op_col_bounds(cols: &ColList, bounds: (Bound<AlgebraicValue>, Bound<AlgebraicValue>)) -> Self {
255        let (cmp, value) = match bounds {
256            // Equality; field <= value && field >= value <=> field = value
257            (Bound::Included(a), Bound::Included(b)) if a == b => (OpCmp::Eq, a),
258            // Inclusive lower bound => field >= value
259            (Bound::Included(value), Bound::Unbounded) => (OpCmp::GtEq, value),
260            // Exclusive lower bound => field > value
261            (Bound::Excluded(value), Bound::Unbounded) => (OpCmp::Gt, value),
262            // Inclusive upper bound => field <= value
263            (Bound::Unbounded, Bound::Included(value)) => (OpCmp::LtEq, value),
264            // Exclusive upper bound => field < value
265            (Bound::Unbounded, Bound::Excluded(value)) => (OpCmp::Lt, value),
266            (Bound::Unbounded, Bound::Unbounded) => unreachable!(),
267            (lower_bound, upper_bound) => {
268                let lhs = Self::from_op_col_bounds(cols, (lower_bound, Bound::Unbounded));
269                let rhs = Self::from_op_col_bounds(cols, (Bound::Unbounded, upper_bound));
270                return ColumnOp::and(lhs, rhs);
271            }
272        };
273        ColumnOp::and_cmp(cmp, cols, value)
274    }
275
276    /// Converts `self` to the lhs `ColId` and the `OpCmp` if this is a comparison.
277    fn as_col_cmp(&self) -> Option<(ColId, OpCmp)> {
278        match self {
279            Self::ColCmpVal { lhs, cmp, rhs: _ } => Some((*lhs, *cmp)),
280            Self::Cmp { lhs, cmp, rhs: _ } => match &**lhs {
281                ColumnOp::Col(col) => Some((*col, *cmp)),
282                _ => None,
283            },
284            _ => None,
285        }
286    }
287
288    /// Evaluate `self` where `ColId`s are translated to values by indexing into `row`.
289    fn eval<'a>(&'a self, row: &'a RelValue<'_>) -> Cow<'a, AlgebraicValue> {
290        let into = |b| Cow::Owned(AlgebraicValue::Bool(b));
291
292        match self {
293            Self::Col(col) => row.read_column(col.idx()).unwrap(),
294            Self::Val(val) => Cow::Borrowed(val),
295            Self::ColCmpVal { lhs, cmp, rhs } => into(Self::eval_cmp_col_val(row, *cmp, *lhs, rhs)),
296            Self::Cmp { lhs, cmp, rhs } => into(Self::eval_cmp(row, *cmp, lhs, rhs)),
297            Self::Log { op, operands } => into(Self::eval_log(row, *op, operands)),
298        }
299    }
300
301    /// Evaluate `self` to a `bool` where `ColId`s are translated to values by indexing into `row`.
302    pub fn eval_bool(&self, row: &RelValue<'_>) -> bool {
303        match self {
304            Self::Col(col) => *row.read_column(col.idx()).unwrap().as_bool().unwrap(),
305            Self::Val(val) => *val.as_bool().unwrap(),
306            Self::ColCmpVal { lhs, cmp, rhs } => Self::eval_cmp_col_val(row, *cmp, *lhs, rhs),
307            Self::Cmp { lhs, cmp, rhs } => Self::eval_cmp(row, *cmp, lhs, rhs),
308            Self::Log { op, operands } => Self::eval_log(row, *op, operands),
309        }
310    }
311
312    /// Evaluates `lhs cmp rhs` according to `Ord for AlgebraicValue`.
313    fn eval_op_cmp(cmp: OpCmp, lhs: &AlgebraicValue, rhs: &AlgebraicValue) -> bool {
314        match cmp {
315            OpCmp::Eq => lhs == rhs,
316            OpCmp::NotEq => lhs != rhs,
317            OpCmp::Lt => lhs < rhs,
318            OpCmp::LtEq => lhs <= rhs,
319            OpCmp::Gt => lhs > rhs,
320            OpCmp::GtEq => lhs >= rhs,
321        }
322    }
323
324    /// Evaluates `lhs` to an [`AlgebraicValue`] and runs the comparison `lhs_av op rhs`.
325    fn eval_cmp_col_val(row: &RelValue<'_>, cmp: OpCmp, lhs: ColId, rhs: &AlgebraicValue) -> bool {
326        let lhs = row.read_column(lhs.idx()).unwrap();
327        Self::eval_op_cmp(cmp, &lhs, rhs)
328    }
329
330    /// Evaluates `lhs` and `rhs` to [`AlgebraicValue`]s
331    /// and then runs the comparison `cmp` on them,
332    /// returning the final `bool` result.
333    fn eval_cmp(row: &RelValue<'_>, cmp: OpCmp, lhs: &Self, rhs: &Self) -> bool {
334        let lhs = lhs.eval(row);
335        let rhs = rhs.eval(row);
336        Self::eval_op_cmp(cmp, &lhs, &rhs)
337    }
338
339    /// Evaluates if
340    /// - `op = OpLogic::And` the conjunctions (`⋀`) of `opers`
341    /// - `op = OpLogic::Or` the disjunctions (`⋁`) of `opers`
342    fn eval_log(row: &RelValue<'_>, op: OpLogic, opers: &[ColumnOp]) -> bool {
343        match op {
344            OpLogic::And => opers.iter().all(|o| o.eval_bool(row)),
345            OpLogic::Or => opers.iter().any(|o| o.eval_bool(row)),
346        }
347    }
348}
349
350impl fmt::Display for ColumnOp {
351    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352        match self {
353            Self::Col(col) => write!(f, "{col}"),
354            Self::Val(val) => write!(f, "{}", val.to_satn()),
355            Self::ColCmpVal { lhs, cmp, rhs } => write!(f, "{lhs} {cmp} {}", rhs.to_satn()),
356            Self::Cmp { cmp, lhs, rhs } => write!(f, "{lhs} {cmp} {rhs}"),
357            Self::Log { op, operands } => write!(f, "{}", operands.iter().format((*op).into())),
358        }
359    }
360}
361
362impl From<ColExpr> for ColumnOp {
363    fn from(ce: ColExpr) -> Self {
364        match ce {
365            ColExpr::Col(c) => c.into(),
366            ColExpr::Value(v) => v.into(),
367        }
368    }
369}
370
371impl From<Query> for Option<ColumnOp> {
372    fn from(value: Query) -> Self {
373        match value {
374            Query::IndexScan(op) => Some(ColumnOp::from_op_col_bounds(&op.columns, op.bounds)),
375            Query::Select(op) => Some(op),
376            _ => None,
377        }
378    }
379}
380
381/// An identifier for a data source (i.e. a table) in a query plan.
382///
383/// When compiling a query plan, rather than embedding the inputs in the plan,
384/// we annotate each input with a `SourceId`, and the compiled plan refers to its inputs by id.
385/// This allows the plan to be re-used with distinct inputs,
386/// assuming the inputs obey the same schema.
387///
388/// Note that re-using a query plan is only a good idea
389/// if the new inputs are similar to those used for compilation
390/// in terms of cardinality and distribution.
391#[derive(Debug, Copy, Clone, PartialEq, Eq, From, Hash)]
392pub struct SourceId(pub usize);
393
394/// Types that relate [`SourceId`]s to their in-memory tables.
395///
396/// Rather than embedding tables in query plans, we store a [`SourceExpr::InMemory`],
397/// which contains the information necessary for optimization along with a `SourceId`.
398/// Query execution then executes the plan, and when it encounters a `SourceExpr::InMemory`,
399/// retrieves the `Self::Source` table from the corresponding provider.
400/// This allows query plans to be re-used, though each execution might require a new provider.
401///
402/// An in-memory table `Self::Source` is a type capable of producing [`RelValue<'a>`]s.
403/// The general form of this is `Iterator<Item = RelValue<'a>>`.
404/// Depending on the situation, this could be e.g.,
405/// - [`MemTable`], producing [`RelValue::Projection`],
406/// - `&'a [ProductValue]` producing [`RelValue::ProjRef`].
407pub trait SourceProvider<'a> {
408    /// The type of in-memory tables that this provider uses.
409    type Source: 'a + IntoIterator<Item = RelValue<'a>>;
410
411    /// Retrieve the `Self::Source` associated with `id`, if any.
412    ///
413    /// Taking the same `id` a second time may or may not yield the same source.
414    /// Callers should not assume that a generic provider will yield it more than once.
415    /// This means that a query plan may not include multiple references to the same [`SourceId`].
416    ///
417    /// Implementations are also not obligated to inspect `id`, e.g., if there's only one option.
418    fn take_source(&mut self, id: SourceId) -> Option<Self::Source>;
419}
420
421impl<'a, I: 'a + IntoIterator<Item = RelValue<'a>>, F: FnMut(SourceId) -> Option<I>> SourceProvider<'a> for F {
422    type Source = I;
423    fn take_source(&mut self, id: SourceId) -> Option<Self::Source> {
424        self(id)
425    }
426}
427
428impl<'a, I: 'a + IntoIterator<Item = RelValue<'a>>> SourceProvider<'a> for Option<I> {
429    type Source = I;
430    fn take_source(&mut self, _: SourceId) -> Option<Self::Source> {
431        self.take()
432    }
433}
434
/// A [`SourceProvider`] that has no sources: `take_source` always returns `None`.
pub struct NoInMemUsed;
436
437impl<'a> SourceProvider<'a> for NoInMemUsed {
438    type Source = iter::Empty<RelValue<'a>>;
439    fn take_source(&mut self, _: SourceId) -> Option<Self::Source> {
440        None
441    }
442}
443
444/// A [`SourceProvider`] backed by an `ArrayVec`.
445///
446/// Internally, the `SourceSet` stores an `Option<T>` for each planned [`SourceId`]
447/// which are [`Option::take`]n out of the set.
448#[derive(Debug, PartialEq, Eq, Clone)]
449#[repr(transparent)]
450pub struct SourceSet<T, const N: usize>(
451    // Benchmarks showed an improvement in performance
452    // on incr-select by ~10% by not using `Vec<Option<T>>`.
453    ArrayVec<Option<T>, N>,
454);
455
456impl<'a, T: 'a + IntoIterator<Item = RelValue<'a>>, const N: usize> SourceProvider<'a> for SourceSet<T, N> {
457    type Source = T;
458    fn take_source(&mut self, id: SourceId) -> Option<T> {
459        self.take(id)
460    }
461}
462
463impl<T, const N: usize> From<[T; N]> for SourceSet<T, N> {
464    #[inline]
465    fn from(sources: [T; N]) -> Self {
466        Self(sources.map(Some).into())
467    }
468}
469
470impl<T, const N: usize> SourceSet<T, N> {
471    /// Returns an empty source set.
472    pub fn empty() -> Self {
473        Self(ArrayVec::new())
474    }
475
476    /// Get a fresh `SourceId` which can be used as the id for a new entry.
477    fn next_id(&self) -> SourceId {
478        SourceId(self.0.len())
479    }
480
481    /// Insert an entry into this `SourceSet` so it can be used in a query plan,
482    /// and return a [`SourceId`] which can be embedded in that plan.
483    pub fn add(&mut self, table: T) -> SourceId {
484        let source_id = self.next_id();
485        self.0.push(Some(table));
486        source_id
487    }
488
489    /// Extract the entry referred to by `id` from this `SourceSet`,
490    /// leaving a "gap" in its place.
491    ///
492    /// Subsequent calls to `take` on the same `id` will return `None`.
493    pub fn take(&mut self, id: SourceId) -> Option<T> {
494        self.0.get_mut(id.0).map(mem::take).unwrap_or_default()
495    }
496
497    /// Returns the number of slots for [`MemTable`]s in this set.
498    ///
499    /// Calling `self.take_mem_table(...)` or `self.take_table(...)` won't affect this number.
500    pub fn len(&self) -> usize {
501        self.0.len()
502    }
503
504    /// Returns whether this set has any slots for [`MemTable`]s.
505    ///
506    /// Calling `self.take_mem_table(...)` or `self.take_table(...)` won't affect whether the set is empty.
507    pub fn is_empty(&self) -> bool {
508        self.0.is_empty()
509    }
510}
511
512impl<T, const N: usize> std::ops::Index<SourceId> for SourceSet<T, N> {
513    type Output = Option<T>;
514
515    fn index(&self, idx: SourceId) -> &Self::Output {
516        &self.0[idx.0]
517    }
518}
519
520impl<T, const N: usize> std::ops::IndexMut<SourceId> for SourceSet<T, N> {
521    fn index_mut(&mut self, idx: SourceId) -> &mut Self::Output {
522        &mut self.0[idx.0]
523    }
524}
525
526impl<const N: usize> SourceSet<Vec<ProductValue>, N> {
527    /// Insert a [`MemTable`] into this `SourceSet` so it can be used in a query plan,
528    /// and return a [`SourceExpr`] which can be embedded in that plan.
529    pub fn add_mem_table(&mut self, table: MemTable) -> SourceExpr {
530        let id = self.add(table.data);
531        SourceExpr::from_mem_table(table.head, table.table_access, id)
532    }
533}
534
535/// A reference to a table within a query plan,
536/// used as the source for selections, scans, filters and joins.
537#[derive(Debug, Clone, Eq, PartialEq, Hash)]
538pub enum SourceExpr {
539    /// A plan for a "virtual" or projected table.
540    ///
541    /// The actual in-memory table, e.g., [`MemTable`] or `&'a [ProductValue]`
542    /// is not stored within the query plan;
543    /// rather, the `source_id` is an index which corresponds to the table in e.g., a [`SourceSet`].
544    ///
545    /// This allows query plans to be reused by supplying e.g., a new [`SourceSet`].
546    InMemory {
547        source_id: SourceId,
548        header: Arc<Header>,
549        table_type: StTableType,
550        table_access: StAccess,
551    },
552    /// A plan for a database table. Because [`DbTable`] is small and efficiently cloneable,
553    /// no indirection into a [`SourceSet`] is required.
554    DbTable(DbTable),
555}
556
557impl SourceExpr {
558    /// If `self` refers to a [`MemTable`], returns the [`SourceId`] for its location in the plan's [`SourceSet`].
559    ///
560    /// Returns `None` if `self` refers to a [`DbTable`], as [`DbTable`]s are stored directly in the `SourceExpr`,
561    /// rather than indirected through the [`SourceSet`].
562    pub fn source_id(&self) -> Option<SourceId> {
563        if let SourceExpr::InMemory { source_id, .. } = self {
564            Some(*source_id)
565        } else {
566            None
567        }
568    }
569
570    pub fn table_name(&self) -> &str {
571        &self.head().table_name
572    }
573
574    pub fn table_type(&self) -> StTableType {
575        match self {
576            SourceExpr::InMemory { table_type, .. } => *table_type,
577            SourceExpr::DbTable(db_table) => db_table.table_type,
578        }
579    }
580
581    pub fn table_access(&self) -> StAccess {
582        match self {
583            SourceExpr::InMemory { table_access, .. } => *table_access,
584            SourceExpr::DbTable(db_table) => db_table.table_access,
585        }
586    }
587
588    pub fn head(&self) -> &Arc<Header> {
589        match self {
590            SourceExpr::InMemory { header, .. } => header,
591            SourceExpr::DbTable(db_table) => &db_table.head,
592        }
593    }
594
595    pub fn is_mem_table(&self) -> bool {
596        matches!(self, SourceExpr::InMemory { .. })
597    }
598
599    pub fn is_db_table(&self) -> bool {
600        matches!(self, SourceExpr::DbTable(_))
601    }
602
603    pub fn from_mem_table(header: Arc<Header>, table_access: StAccess, id: SourceId) -> Self {
604        SourceExpr::InMemory {
605            source_id: id,
606            header,
607            table_type: StTableType::User,
608            table_access,
609        }
610    }
611
612    pub fn table_id(&self) -> Option<TableId> {
613        if let SourceExpr::DbTable(db_table) = self {
614            Some(db_table.table_id)
615        } else {
616            None
617        }
618    }
619
620    /// If `self` refers to a [`DbTable`], get a reference to it.
621    ///
622    /// Returns `None` if `self` refers to a [`MemTable`].
623    /// In that case, retrieving the [`MemTable`] requires inspecting the plan's corresponding [`SourceSet`]
624    /// via [`SourceSet::take_mem_table`] or [`SourceSet::take_table`].
625    pub fn get_db_table(&self) -> Option<&DbTable> {
626        if let SourceExpr::DbTable(db_table) = self {
627            Some(db_table)
628        } else {
629            None
630        }
631    }
632}
633
634impl From<&TableSchema> for SourceExpr {
635    fn from(value: &TableSchema) -> Self {
636        SourceExpr::DbTable(value.into())
637    }
638}
639
640/// A descriptor for an index semi join operation.
641///
642/// The semantics are those of a semijoin with rows from the index or the probe side being returned.
643#[derive(Debug, Clone, Eq, PartialEq, Hash)]
644pub struct IndexJoin {
645    pub probe_side: QueryExpr,
646    pub probe_col: ColId,
647    pub index_side: SourceExpr,
648    pub index_select: Option<ColumnOp>,
649    pub index_col: ColId,
650    /// If true, returns rows from the `index_side`.
651    /// Otherwise, returns rows from the `probe_side`.
652    pub return_index_rows: bool,
653}
654
655impl From<IndexJoin> for QueryExpr {
656    fn from(join: IndexJoin) -> Self {
657        let source: SourceExpr = if join.return_index_rows {
658            join.index_side.clone()
659        } else {
660            join.probe_side.source.clone()
661        };
662        QueryExpr {
663            source,
664            query: vec![Query::IndexJoin(join)],
665        }
666    }
667}
668
669impl IndexJoin {
670    // Reorder the index and probe sides of an index join.
671    // This is necessary if the indexed table has been replaced by a delta table.
672    // A delta table is a virtual table consisting of changes or updates to a physical table.
673    pub fn reorder(self, row_count: impl Fn(TableId, &str) -> i64) -> Self {
674        // The probe table must be a physical table.
675        if self.probe_side.source.is_mem_table() {
676            return self;
677        }
678        // It must have an index defined on the join field.
679        if !self
680            .probe_side
681            .source
682            .head()
683            .has_constraint(self.probe_col, Constraints::indexed())
684        {
685            return self;
686        }
687        // It must be a linear pipeline of selections.
688        if !self
689            .probe_side
690            .query
691            .iter()
692            .all(|op| matches!(op, Query::Select(_) | Query::IndexScan(_)))
693        {
694            return self;
695        }
696        match self.index_side.get_db_table() {
697            // If the size of the indexed table is sufficiently large,
698            // do not reorder.
699            //
700            // TODO: This determination is quite arbitrary.
701            // Ultimately we should be using cardinality estimation.
702            Some(DbTable { head, table_id, .. }) if row_count(*table_id, &head.table_name) > 500 => self,
703            // If this is a delta table, we must reorder.
704            // If this is a sufficiently small physical table, we should reorder.
705            _ => {
706                // Merge all selections from the original probe side into a single predicate.
707                // This includes an index scan if present.
708                let predicate = self
709                    .probe_side
710                    .query
711                    .into_iter()
712                    .filter_map(<Query as Into<Option<ColumnOp>>>::into)
713                    .reduce(ColumnOp::and);
714                // Push any selections on the index side to the probe side.
715                let probe_side = if let Some(predicate) = self.index_select {
716                    QueryExpr {
717                        source: self.index_side,
718                        query: vec![predicate.into()],
719                    }
720                } else {
721                    self.index_side.into()
722                };
723                IndexJoin {
724                    // The new probe side consists of the updated rows.
725                    // Plus any selections from the original index probe.
726                    probe_side,
727                    // The new probe field is the previous index field.
728                    probe_col: self.index_col,
729                    // The original probe table is now the table that is being probed.
730                    index_side: self.probe_side.source,
731                    // Any selections from the original probe side are pulled above the index lookup.
732                    index_select: predicate,
733                    // The new index field is the previous probe field.
734                    index_col: self.probe_col,
735                    // Because we have swapped the original index and probe sides of the join,
736                    // the new index join needs to return rows from the opposite side.
737                    return_index_rows: !self.return_index_rows,
738                }
739            }
740        }
741    }
742
743    // Convert this index join to an inner join, followed by a projection.
744    // This is needed for incremental evaluation of index joins.
745    // In particular when there are updates to both the left and right tables.
746    // In other words, when an index join has two delta tables.
747    pub fn to_inner_join(self) -> QueryExpr {
748        if self.return_index_rows {
749            let (col_lhs, col_rhs) = (self.index_col, self.probe_col);
750            let rhs = self.probe_side;
751
752            let source = self.index_side;
753            let inner_join = Query::JoinInner(JoinExpr::new(rhs, col_lhs, col_rhs, None));
754            let query = if let Some(predicate) = self.index_select {
755                vec![predicate.into(), inner_join]
756            } else {
757                vec![inner_join]
758            };
759            QueryExpr { source, query }
760        } else {
761            let (col_lhs, col_rhs) = (self.probe_col, self.index_col);
762            let mut rhs: QueryExpr = self.index_side.into();
763
764            if let Some(predicate) = self.index_select {
765                rhs.query.push(predicate.into());
766            }
767
768            let source = self.probe_side.source;
769            let inner_join = Query::JoinInner(JoinExpr::new(rhs, col_lhs, col_rhs, None));
770            let query = vec![inner_join];
771            QueryExpr { source, query }
772        }
773    }
774}
775
776#[derive(Debug, Clone, Eq, PartialEq, Hash)]
777pub struct JoinExpr {
778    pub rhs: QueryExpr,
779    pub col_lhs: ColId,
780    pub col_rhs: ColId,
781    /// If None, this is a left semi-join, returning rows only from the source table,
782    /// using the `rhs` as a filter.
783    ///
784    /// If Some(_), this is an inner join, returning the concatenation of the matching rows.
785    pub inner: Option<Arc<Header>>,
786}
787
788impl JoinExpr {
789    pub fn new(rhs: QueryExpr, col_lhs: ColId, col_rhs: ColId, inner: Option<Arc<Header>>) -> Self {
790        Self {
791            rhs,
792            col_lhs,
793            col_rhs,
794            inner,
795        }
796    }
797}
798
/// The kinds of database objects that [`Crud::Create`] and [`Crud::Drop`] can refer to.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum DbType {
    /// A table.
    Table,
    /// An index.
    Index,
    /// A sequence.
    Sequence,
    /// A constraint.
    Constraint,
}
806
807#[derive(Debug, Clone, Copy, Eq, PartialEq)]
808pub enum Crud {
809    Query,
810    Insert,
811    Update,
812    Delete,
813    Create(DbType),
814    Drop(DbType),
815    Config,
816}
817
818#[derive(Debug, Eq, PartialEq)]
819pub enum CrudExpr {
820    Query(QueryExpr),
821    Insert {
822        table: DbTable,
823        rows: Vec<ProductValue>,
824    },
825    Update {
826        delete: QueryExpr,
827        assignments: IntMap<ColId, ColExpr>,
828    },
829    Delete {
830        query: QueryExpr,
831    },
832    SetVar {
833        name: String,
834        literal: String,
835    },
836    ReadVar {
837        name: String,
838    },
839}
840
841impl CrudExpr {
842    pub fn optimize(self, row_count: &impl Fn(TableId, &str) -> i64) -> Self {
843        match self {
844            CrudExpr::Query(x) => CrudExpr::Query(x.optimize(row_count)),
845            _ => self,
846        }
847    }
848
849    pub fn is_reads<'a>(exprs: impl IntoIterator<Item = &'a CrudExpr>) -> bool {
850        exprs
851            .into_iter()
852            .all(|expr| matches!(expr, CrudExpr::Query(_) | CrudExpr::ReadVar { .. }))
853    }
854}
855
856#[derive(Debug, Clone, Eq, PartialEq, Hash)]
857pub struct IndexScan {
858    pub table: DbTable,
859    pub columns: ColList,
860    pub bounds: (Bound<AlgebraicValue>, Bound<AlgebraicValue>),
861}
862
863impl IndexScan {
864    /// Returns whether this is a point range.
865    pub fn is_point(&self) -> bool {
866        match &self.bounds {
867            (Bound::Included(lower), Bound::Included(upper)) => lower == upper,
868            _ => false,
869        }
870    }
871}
872
/// A projection operation in a query.
#[derive(Debug, Clone, Eq, PartialEq, From, Hash)]
pub struct ProjectExpr {
    /// The column expressions to project.
    pub cols: Vec<ColExpr>,
    /// The table id for a qualified wildcard project, if any.
    /// If present, further optimizations are possible.
    pub wildcard_table: Option<TableId>,
    /// The header of the relation after the projection;
    /// this is what [`QueryExpr::head`] reports once a projection has been applied.
    pub header_after: Arc<Header>,
}
882
/// An individual operation in a query.
#[derive(Debug, Clone, Eq, PartialEq, From, Hash)]
pub enum Query {
    /// Fetching rows via an index.
    IndexScan(IndexScan),
    /// Joining rows via an index.
    /// Equivalent to Index Nested Loop Join.
    IndexJoin(IndexJoin),
    /// A filter over an intermediate relation.
    /// In particular it does not utilize any indexes.
    /// If it could it would have already been transformed into an IndexScan.
    Select(ColumnOp),
    /// Projects a set of columns.
    Project(ProjectExpr),
    /// A join of two relations (base or intermediate) based on equality.
    /// Equivalent to a Nested Loop Join.
    /// Its operands may use indexes but the join itself does not.
    JoinInner(JoinExpr),
}
902
903impl Query {
904    /// Iterate over all [`SourceExpr`]s involved in the [`Query`].
905    ///
906    /// Sources are yielded from left to right. Duplicates are not filtered out.
907    pub fn walk_sources<E>(&self, on_source: &mut impl FnMut(&SourceExpr) -> Result<(), E>) -> Result<(), E> {
908        match self {
909            Self::Select(..) | Self::Project(..) => Ok(()),
910            Self::IndexScan(scan) => on_source(&SourceExpr::DbTable(scan.table.clone())),
911            Self::IndexJoin(join) => join.probe_side.walk_sources(on_source),
912            Self::JoinInner(join) => join.rhs.walk_sources(on_source),
913        }
914    }
915}
916
/// `IndexArgument` represents an equality or range predicate that can be answered
/// using an index.
///
/// In every variant, `columns` borrows the column list of the index that serves the predicate.
#[derive(Debug, PartialEq, Clone)]
enum IndexArgument<'a> {
    /// `columns = value`.
    Eq {
        columns: &'a ColList,
        value: AlgebraicValue,
    },
    /// `columns > value`, or `columns >= value` when `inclusive`.
    LowerBound {
        columns: &'a ColList,
        value: AlgebraicValue,
        inclusive: bool,
    },
    /// `columns < value`, or `columns <= value` when `inclusive`.
    UpperBound {
        columns: &'a ColList,
        value: AlgebraicValue,
        inclusive: bool,
    },
}
936
/// The outcome of index selection for one constraint:
/// either it can be served by an index, or it must be evaluated by scanning.
#[derive(Debug, PartialEq, Clone)]
enum IndexColumnOp<'a> {
    /// The constraint can be answered by an index, as described by the [`IndexArgument`].
    Index(IndexArgument<'a>),
    /// The constraint could not be matched to an index; the borrowed [`ColumnOp`] is the filter to apply.
    Scan(&'a ColumnOp),
}
942
943fn make_index_arg(cmp: OpCmp, columns: &ColList, value: AlgebraicValue) -> IndexColumnOp<'_> {
944    let arg = match cmp {
945        OpCmp::Eq => IndexArgument::Eq { columns, value },
946        OpCmp::NotEq => unreachable!("No IndexArgument for NotEq, caller should've filtered out"),
947        // a < 5 => exclusive upper bound
948        OpCmp::Lt => IndexArgument::UpperBound {
949            columns,
950            value,
951            inclusive: false,
952        },
953        // a > 5 => exclusive lower bound
954        OpCmp::Gt => IndexArgument::LowerBound {
955            columns,
956            value,
957            inclusive: false,
958        },
959        // a <= 5 => inclusive upper bound
960        OpCmp::LtEq => IndexArgument::UpperBound {
961            columns,
962            value,
963            inclusive: true,
964        },
965        // a >= 5 => inclusive lower bound
966        OpCmp::GtEq => IndexArgument::LowerBound {
967            columns,
968            value,
969            inclusive: true,
970        },
971    };
972    IndexColumnOp::Index(arg)
973}
974
975#[derive(Debug)]
976struct ColValue<'a> {
977    parent: &'a ColumnOp,
978    col: ColId,
979    cmp: OpCmp,
980    value: &'a AlgebraicValue,
981}
982
983impl<'a> ColValue<'a> {
984    pub fn new(parent: &'a ColumnOp, col: ColId, cmp: OpCmp, value: &'a AlgebraicValue) -> Self {
985        Self {
986            parent,
987            col,
988            cmp,
989            value,
990        }
991    }
992}
993
/// The list of operations produced by index selection; stores one element inline.
type IndexColumnOpSink<'a> = SmallVec<[IndexColumnOp<'a>; 1]>;
/// The set of `(column, operator)` pairs that were found to be servable by an index.
type ColsIndexed = HashSet<(ColId, OpCmp)>;
996
/// Pick the best indices that can serve the constraints in `op`
/// where the indices are taken from `header`.
///
/// This function is designed to handle complex scenarios when selecting the optimal index for a query.
/// The scenarios include:
///
/// - Combinations of multi- and single-column indexes that could refer to the same column.
///   For example, the table could have indexes `[a]` and `[a, b]`
///   and a user could query for `WHERE a = 1 AND b = 2 AND a = 3`.
///
/// - Query constraints can be supplied in any order;
///   i.e., both `WHERE a = 1 AND b = 2`
///   and `WHERE b = 2 AND a = 1` are valid.
///
/// - Queries against multi-col indices must use `=`, for now, in their constraints.
///   Otherwise, the index cannot be used.
///
/// - The use of multiple tables could generate redundant/duplicate operations like
///   `[IndexColumnOp::Index(a = 1), IndexColumnOp::Index(a = 1), IndexColumnOp::Scan(a = 1)]`.
///   This *cannot* be handled here.
///
/// # Returns
///
/// - The returned list of [`IndexColumnOp`]s represents the selected `index` OR `scan` operations.
///   Operations that `extract_cols` immediately classified as scans come first,
///   followed by the index operations (longest indices first),
///   followed by scans for the constraints that no index could serve.
///
/// - `cols_indexed` is an in/out parameter. Every `(ColId, OpCmp)` pair
///   that could be served by an index is inserted into it.
///
///   This is required to remove the redundant operation on e.g.,
///   `[IndexColumnOp::Index(a = 1), IndexColumnOp::Index(a = 1), IndexColumnOp::Scan(a = 1)]`,
///   that could be generated by calling this function several times by using multiple `JOINS`.
///
/// # Example
///
/// If we have a table with `indexes`: `[a], [b], [b, c]` and then try to
/// optimize `WHERE a = 1 AND d > 2 AND c = 2 AND b = 1` we should return
///
/// -`IndexColumnOp::Index([b, c] = [1, 2])`
/// -`IndexColumnOp::Index(a = 1)`
/// -`IndexColumnOp::Scan(d > 2)`
///
/// # Note
///
/// NOTE: For a query like `SELECT * FROM students WHERE age > 18 AND height < 180`
/// we cannot serve this with a single `IndexScan`,
/// but rather, `select_best_index`
/// would give us two separate `IndexScan`s.
/// However, the upper layers of `QueryExpr` building will convert both of those into `Select`s.
fn select_best_index<'a>(
    cols_indexed: &mut ColsIndexed,
    header: &'a Header,
    op: &'a ColumnOp,
) -> IndexColumnOpSink<'a> {
    // Collect and sort indices by their lengths, with longest first.
    // We do this so that multi-col indices are used first, as they are more efficient.
    // The sort is unstable, so the relative order of indices of equal length is unspecified.
    // TODO(Centril): This could be computed when `Header` is constructed.
    let mut indices = header
        .constraints
        .iter()
        .filter(|(_, c)| c.has_indexed())
        .map(|(cl, _)| cl)
        .collect::<SmallVec<[_; 1]>>();
    indices.sort_unstable_by_key(|cl| Reverse(cl.len()));

    let mut found: IndexColumnOpSink = IndexColumnOpSink::default();

    // Collect fields into a multi-map `(col_id, cmp) -> [col value]`.
    // This gives us `log(N)` seek + deletion.
    // Anything that is not a simple column constraint goes straight to `found` as a scan.
    // TODO(Centril): Consider https://docs.rs/small-map/0.1.3/small_map/enum.SmallMap.html
    let mut col_map = BTreeMap::<_, SmallVec<[_; 1]>>::new();
    extract_cols(op, &mut col_map, &mut found);

    // Go through each index,
    // consuming all column constraints that can be served by an index.
    for col_list in indices {
        // (1) No columns left? We're done.
        if col_map.is_empty() {
            break;
        }

        if let Some(head) = col_list.as_singleton() {
            // Go through each operator.
            // NOTE: We do not consider `OpCmp::NotEq` at the moment
            // since those are typically not answered using an index.
            for cmp in [OpCmp::Eq, OpCmp::Lt, OpCmp::LtEq, OpCmp::Gt, OpCmp::GtEq] {
                // For a single column index,
                // we want to avoid the `ProductValue` indirection of below.
                // Every constraint on `(head, cmp)` is consumed, each yielding its own index operation.
                for ColValue { cmp, value, col, .. } in col_map.remove(&(head, cmp)).into_iter().flatten() {
                    found.push(make_index_arg(cmp, col_list, value.clone()));
                    cols_indexed.insert((col, cmp));
                }
            }
        } else {
            // We have a multi column index.
            // Try to fit constraints `c_0 = v_0, ..., c_n = v_n` to this index.
            //
            // For the time being, we restrict multi-col index scans to `=` only.
            // This is what our infrastructure is set-up to handle soundly.
            // To extend this support to ranges requires deeper changes.
            // TODO(Centril, 2024-05-30): extend this support to ranges.
            let cmp = OpCmp::Eq;

            // Compute the minimum number of `=` constraints that every column in the index has.
            // This is the number of complete `(c_0, ..., c_n)` tuples we can form;
            // it is `0` as soon as one column of the index has no `=` constraint.
            let mut min_all_cols_num_eq = col_list
                .iter()
                .map(|col| col_map.get(&(col, cmp)).map_or(0, |fs| fs.len()))
                .min()
                .unwrap_or_default();

            // For all of these sets of constraints,
            // construct the value to compare against.
            while min_all_cols_num_eq > 0 {
                let mut elems = Vec::with_capacity(col_list.len() as usize);
                for col in col_list.iter() {
                    // Cannot panic as `min_all_cols_num_eq > 0`.
                    let col_val = pop_multimap(&mut col_map, (col, cmp)).unwrap();
                    cols_indexed.insert((col_val.col, cmp));
                    // Add the column value to the product value.
                    elems.push(col_val.value.clone());
                }
                // Construct the index scan.
                let value = AlgebraicValue::product(elems);
                found.push(make_index_arg(cmp, col_list, value));
                min_all_cols_num_eq -= 1;
            }
        }
    }

    // The remaining constraints must be served by a scan.
    // Each one is reported through the `ColumnOp` it was extracted from.
    found.extend(
        col_map
            .into_iter()
            .flat_map(|(_, fs)| fs)
            .map(|f| IndexColumnOp::Scan(f.parent)),
    );

    found
}
1135
1136/// Pop an element from `map[key]` in the multimap `map`,
1137/// removing the entry entirely if there are no more elements left after popping.
1138fn pop_multimap<K: Ord, V, const N: usize>(map: &mut BTreeMap<K, SmallVec<[V; N]>>, key: K) -> Option<V> {
1139    let Entry::Occupied(mut entry) = map.entry(key) else {
1140        return None;
1141    };
1142    let fields = entry.get_mut();
1143    let val = fields.pop();
1144    if fields.is_empty() {
1145        entry.remove();
1146    }
1147    val
1148}
1149
1150/// Extracts a list of `col = val` constraints that *could* be answered by an index
1151/// and populates those into `col_map`.
1152/// The [`ColumnOp`]s that don't fit `col = val`
1153/// are made into [`IndexColumnOp::Scan`]s immediately which are added to `found`.
1154fn extract_cols<'a>(
1155    op: &'a ColumnOp,
1156    col_map: &mut BTreeMap<(ColId, OpCmp), SmallVec<[ColValue<'a>; 1]>>,
1157    found: &mut IndexColumnOpSink<'a>,
1158) {
1159    let mut add_field = |parent, op, col, val| {
1160        let fv = ColValue::new(parent, col, op, val);
1161        col_map.entry((col, op)).or_default().push(fv);
1162    };
1163
1164    match op {
1165        ColumnOp::Cmp { cmp, lhs, rhs } => {
1166            if let (ColumnOp::Col(col), ColumnOp::Val(val)) = (&**lhs, &**rhs) {
1167                // `lhs` must be a field that exists and `rhs` must be a value.
1168                add_field(op, *cmp, *col, val);
1169            }
1170        }
1171        ColumnOp::ColCmpVal { lhs, cmp, rhs } => add_field(op, *cmp, *lhs, rhs),
1172        ColumnOp::Log {
1173            op: OpLogic::And,
1174            operands,
1175        } => {
1176            for oper in operands.iter() {
1177                extract_cols(oper, col_map, found);
1178            }
1179        }
1180        ColumnOp::Log { op: OpLogic::Or, .. } | ColumnOp::Col(_) | ColumnOp::Val(_) => {
1181            found.push(IndexColumnOp::Scan(op));
1182        }
1183    }
1184}
1185
/// A query: a `source` relation followed by a pipeline of operations applied to it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
// TODO(bikeshedding): Refactor this struct so that `IndexJoin`s replace the `table`,
// rather than appearing as the first element of the `query`.
//
// `IndexJoin`s do not behave like filters; in fact they behave more like data sources.
// A query conceptually starts with either a single table or an `IndexJoin`,
// and then stacks a set of filters on top of that.
pub struct QueryExpr {
    /// The relation the query starts from.
    pub source: SourceExpr,
    /// The operations of the query, in the order they were added.
    pub query: Vec<Query>,
}
1197
1198impl From<SourceExpr> for QueryExpr {
1199    fn from(source: SourceExpr) -> Self {
1200        QueryExpr { source, query: vec![] }
1201    }
1202}
1203
1204impl QueryExpr {
1205    pub fn new<T: Into<SourceExpr>>(source: T) -> Self {
1206        Self {
1207            source: source.into(),
1208            query: vec![],
1209        }
1210    }
1211
1212    /// Iterate over all [`SourceExpr`]s involved in the [`QueryExpr`].
1213    ///
1214    /// Sources are yielded from left to right. Duplicates are not filtered out.
1215    pub fn walk_sources<E>(&self, on_source: &mut impl FnMut(&SourceExpr) -> Result<(), E>) -> Result<(), E> {
1216        on_source(&self.source)?;
1217        self.query.iter().try_for_each(|q| q.walk_sources(on_source))
1218    }
1219
1220    /// Returns the last [`Header`] of this query.
1221    ///
1222    /// Starts the scan from the back to the front,
1223    /// looking for query operations that change the `Header`.
1224    /// These are `JoinInner` and `Project`.
1225    /// If there are no operations that alter the `Header`,
1226    /// this falls back to the origin `self.source.head()`.
1227    pub fn head(&self) -> &Arc<Header> {
1228        self.query
1229            .iter()
1230            .rev()
1231            .find_map(|op| match op {
1232                Query::Select(_) => None,
1233                Query::IndexScan(scan) => Some(&scan.table.head),
1234                Query::IndexJoin(join) if join.return_index_rows => Some(join.index_side.head()),
1235                Query::IndexJoin(join) => Some(join.probe_side.head()),
1236                Query::Project(proj) => Some(&proj.header_after),
1237                Query::JoinInner(join) => join.inner.as_ref(),
1238            })
1239            .unwrap_or_else(|| self.source.head())
1240    }
1241
1242    /// Does this query read from a given table?
1243    pub fn reads_from_table(&self, id: &TableId) -> bool {
1244        self.source.table_id() == Some(*id)
1245            || self.query.iter().any(|q| match q {
1246                Query::Select(_) | Query::Project(..) => false,
1247                Query::IndexScan(scan) => scan.table.table_id == *id,
1248                Query::JoinInner(join) => join.rhs.reads_from_table(id),
1249                Query::IndexJoin(join) => {
1250                    join.index_side.table_id() == Some(*id) || join.probe_side.reads_from_table(id)
1251                }
1252            })
1253    }
1254
1255    // Generate an index scan for an equality predicate if this is the first operator.
1256    // Otherwise generate a select.
1257    // TODO: Replace these methods with a proper query optimization pass.
1258    pub fn with_index_eq(mut self, table: DbTable, columns: ColList, value: AlgebraicValue) -> Self {
1259        let point = |v: AlgebraicValue| (Bound::Included(v.clone()), Bound::Included(v));
1260
1261        // if this is the first operator in the list, generate index scan
1262        let Some(query) = self.query.pop() else {
1263            let bounds = point(value);
1264            self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
1265            return self;
1266        };
1267        match query {
1268            // try to push below join's lhs
1269            Query::JoinInner(JoinExpr {
1270                rhs:
1271                    QueryExpr {
1272                        source: SourceExpr::DbTable(ref db_table),
1273                        ..
1274                    },
1275                ..
1276            }) if table.table_id != db_table.table_id => {
1277                self = self.with_index_eq(db_table.clone(), columns, value);
1278                self.query.push(query);
1279                self
1280            }
1281            // try to push below join's rhs
1282            Query::JoinInner(JoinExpr {
1283                rhs,
1284                col_lhs,
1285                col_rhs,
1286                inner: semi,
1287            }) => {
1288                self.query.push(Query::JoinInner(JoinExpr {
1289                    rhs: rhs.with_index_eq(table, columns, value),
1290                    col_lhs,
1291                    col_rhs,
1292                    inner: semi,
1293                }));
1294                self
1295            }
1296            // merge with a preceding select
1297            Query::Select(filter) => {
1298                let op = ColumnOp::and_cmp(OpCmp::Eq, &columns, value);
1299                self.query.push(Query::Select(ColumnOp::and(filter, op)));
1300                self
1301            }
1302            // else generate a new select
1303            query => {
1304                self.query.push(query);
1305                let op = ColumnOp::and_cmp(OpCmp::Eq, &columns, value);
1306                self.query.push(Query::Select(op));
1307                self
1308            }
1309        }
1310    }
1311
    /// Adds the lower bound `columns > value` (or `columns >= value` when `inclusive`)
    /// on `table` to this query.
    ///
    /// Generate an index scan for a range predicate or try merging with a previous index scan.
    /// Otherwise generate a select. In detail, depending on the last operator:
    ///
    /// - No operator: a [`Query::IndexScan`] without an upper bound is generated.
    /// - [`Query::JoinInner`]: the bound is pushed below the join, onto its left-hand side
    ///   when the right-hand side starts from a table other than `table`,
    ///   and into its right-hand side otherwise.
    /// - [`Query::IndexScan`] on the same `columns` with only an upper bound:
    ///   both bounds are merged into a single index scan.
    /// - [`Query::Select`]: the bound is `AND`ed onto the existing filter.
    /// - Anything else: a new [`Query::Select`] is appended.
    // TODO: Replace these methods with a proper query optimization pass.
    pub fn with_index_lower_bound(
        mut self,
        table: DbTable,
        columns: ColList,
        value: AlgebraicValue,
        inclusive: bool,
    ) -> Self {
        // if this is the first operator in the list, generate an index scan
        let Some(query) = self.query.pop() else {
            let bounds = (Self::bound(value, inclusive), Bound::Unbounded);
            self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
            return self;
        };
        match query {
            // try to push below join's lhs
            Query::JoinInner(JoinExpr {
                rhs:
                    QueryExpr {
                        source: SourceExpr::DbTable(ref db_table),
                        ..
                    },
                ..
            }) if table.table_id != db_table.table_id => {
                self = self.with_index_lower_bound(table, columns, value, inclusive);
                self.query.push(query);
                self
            }
            // try to push below join's rhs
            Query::JoinInner(JoinExpr {
                rhs,
                col_lhs,
                col_rhs,
                inner: semi,
            }) => {
                self.query.push(Query::JoinInner(JoinExpr {
                    rhs: rhs.with_index_lower_bound(table, columns, value, inclusive),
                    col_lhs,
                    col_rhs,
                    inner: semi,
                }));
                self
            }
            // merge with a preceding upper bounded index scan (inclusive)
            Query::IndexScan(IndexScan {
                columns: lhs_col_id,
                bounds: (Bound::Unbounded, Bound::Included(upper)),
                ..
            }) if columns == lhs_col_id => {
                let bounds = (Self::bound(value, inclusive), Bound::Included(upper));
                self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
                self
            }
            // merge with a preceding upper bounded index scan (exclusive)
            Query::IndexScan(IndexScan {
                columns: lhs_col_id,
                bounds: (Bound::Unbounded, Bound::Excluded(upper)),
                ..
            }) if columns == lhs_col_id => {
                // Queries like `WHERE x < 5 AND x > 5` never return any rows and are likely mistakes.
                // Detect such queries and log a warning.
                // Compute this condition early, then compute the resulting query and log it.
                // TODO: We should not emit an `IndexScan` in this case.
                // Further design work is necessary to decide whether this should be an error at query compile time,
                // or whether we should emit a query plan which explicitly says that it will return 0 rows.
                // The current behavior is a hack
                // because this patch was written (2024-04-01 pgoldman) a short time before the BitCraft alpha,
                // and a more invasive change was infeasible.
                let is_never = !inclusive && value == upper;

                let bounds = (Self::bound(value, inclusive), Bound::Excluded(upper));
                self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));

                if is_never {
                    log::warn!("Query will select no rows due to equal excluded bounds: {self:?}")
                }

                self
            }
            // merge with a preceding select
            Query::Select(filter) => {
                let bounds = (Self::bound(value, inclusive), Bound::Unbounded);
                let op = ColumnOp::from_op_col_bounds(&columns, bounds);
                self.query.push(Query::Select(ColumnOp::and(filter, op)));
                self
            }
            // else generate a new select
            // (this also covers a preceding index scan that could not be merged)
            query => {
                self.query.push(query);
                let bounds = (Self::bound(value, inclusive), Bound::Unbounded);
                let op = ColumnOp::from_op_col_bounds(&columns, bounds);
                self.query.push(Query::Select(op));
                self
            }
        }
    }
1410
    /// Adds the upper bound `columns < value` (or `columns <= value` when `inclusive`)
    /// on `table` to this query.
    ///
    /// Generate an index scan for a range predicate or try merging with a previous index scan.
    /// Otherwise generate a select. In detail, depending on the last operator:
    ///
    /// - No operator: a [`Query::IndexScan`] without a lower bound is generated.
    /// - [`Query::JoinInner`]: the bound is pushed below the join, onto its left-hand side
    ///   when the right-hand side starts from a table other than `table`,
    ///   and into its right-hand side otherwise.
    /// - [`Query::IndexScan`] on the same `columns` with only a lower bound:
    ///   both bounds are merged into a single index scan.
    /// - [`Query::Select`]: the bound is `AND`ed onto the existing filter.
    /// - Anything else: a new [`Query::Select`] is appended.
    // TODO: Replace these methods with a proper query optimization pass.
    pub fn with_index_upper_bound(
        mut self,
        table: DbTable,
        columns: ColList,
        value: AlgebraicValue,
        inclusive: bool,
    ) -> Self {
        // if this is the first operator in the list, generate an index scan
        let Some(query) = self.query.pop() else {
            self.query.push(Query::IndexScan(IndexScan {
                table,
                columns,
                bounds: (Bound::Unbounded, Self::bound(value, inclusive)),
            }));
            return self;
        };
        match query {
            // try to push below join's lhs
            Query::JoinInner(JoinExpr {
                rhs:
                    QueryExpr {
                        source: SourceExpr::DbTable(ref db_table),
                        ..
                    },
                ..
            }) if table.table_id != db_table.table_id => {
                self = self.with_index_upper_bound(table, columns, value, inclusive);
                self.query.push(query);
                self
            }
            // try to push below join's rhs
            Query::JoinInner(JoinExpr {
                rhs,
                col_lhs,
                col_rhs,
                inner: semi,
            }) => {
                self.query.push(Query::JoinInner(JoinExpr {
                    rhs: rhs.with_index_upper_bound(table, columns, value, inclusive),
                    col_lhs,
                    col_rhs,
                    inner: semi,
                }));
                self
            }
            // merge with a preceding lower bounded index scan (inclusive)
            Query::IndexScan(IndexScan {
                columns: lhs_col_id,
                bounds: (Bound::Included(lower), Bound::Unbounded),
                ..
            }) if columns == lhs_col_id => {
                let bounds = (Bound::Included(lower), Self::bound(value, inclusive));
                self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
                self
            }
            // merge with a preceding lower bounded index scan (exclusive)
            Query::IndexScan(IndexScan {
                columns: lhs_col_id,
                bounds: (Bound::Excluded(lower), Bound::Unbounded),
                ..
            }) if columns == lhs_col_id => {
                // Queries like `WHERE x < 5 AND x > 5` never return any rows and are likely mistakes.
                // Detect such queries and log a warning.
                // Compute this condition early, then compute the resulting query and log it.
                // TODO: We should not emit an `IndexScan` in this case.
                // Further design work is necessary to decide whether this should be an error at query compile time,
                // or whether we should emit a query plan which explicitly says that it will return 0 rows.
                // The current behavior is a hack
                // because this patch was written (2024-04-01 pgoldman) a short time before the BitCraft alpha,
                // and a more invasive change was infeasible.
                let is_never = !inclusive && value == lower;

                let bounds = (Bound::Excluded(lower), Self::bound(value, inclusive));
                self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));

                if is_never {
                    log::warn!("Query will select no rows due to equal excluded bounds: {self:?}")
                }

                self
            }
            // merge with a preceding select
            Query::Select(filter) => {
                let bounds = (Bound::Unbounded, Self::bound(value, inclusive));
                let op = ColumnOp::from_op_col_bounds(&columns, bounds);
                self.query.push(Query::Select(ColumnOp::and(filter, op)));
                self
            }
            // else generate a new select
            // (this also covers a preceding index scan that could not be merged)
            query => {
                self.query.push(query);
                let bounds = (Bound::Unbounded, Self::bound(value, inclusive));
                let op = ColumnOp::from_op_col_bounds(&columns, bounds);
                self.query.push(Query::Select(op));
                self
            }
        }
    }
1512
    /// Adds the filter `op` to this query.
    ///
    /// Depending on the last operator of the query:
    ///
    /// - No operator: a base [`Query::Select`] is added.
    /// - [`Query::JoinInner`], with `op` being a comparison `field cmp value`:
    ///   the filter is pushed below the join, onto the side whose header contains `field`.
    ///   Comparisons of any other shape are added as a [`Query::Select`] after the join.
    /// - [`Query::Select`]: `op` is `AND`ed onto the existing filter.
    /// - Anything else: a base [`Query::Select`] is appended.
    ///
    /// # Errors
    ///
    /// Returns the [`RelationError`] produced by checking `op` against this query.
    //
    // NOTE(review): the `unwrap()`s on `names_to_cols` presumably rely on the preceding checks
    // having established that every name resolves -- confirm.
    pub fn with_select<O>(mut self, op: O) -> Result<Self, RelationError>
    where
        O: Into<FieldOp>,
    {
        let op = op.into();
        let Some(query) = self.query.pop() else {
            return self.add_base_select(op);
        };

        // NOTE: `query` has been popped, so within the arms below
        // `self.head()` is the header *before* `query` is applied.
        match (query, op) {
            (
                Query::JoinInner(JoinExpr {
                    rhs,
                    col_lhs,
                    col_rhs,
                    inner,
                }),
                FieldOp::Cmp {
                    op: OpQuery::Cmp(cmp),
                    lhs: field,
                    rhs: value,
                },
            ) => match (*field, *value) {
                (FieldOp::Field(FieldExpr::Name(field)), FieldOp::Field(FieldExpr::Value(value)))
                // Field is from lhs, so push onto join's left arg
                if self.head().column_pos(field).is_some() =>
                    {
                        // No typing restrictions on `field cmp value`,
                        // and there are no binary operators to recurse into.
                        self = self.with_select(FieldOp::cmp(field, cmp, value))?;
                        self.query.push(Query::JoinInner(JoinExpr { rhs, col_lhs, col_rhs, inner }));
                        Ok(self)
                    }
                (FieldOp::Field(FieldExpr::Name(field)), FieldOp::Field(FieldExpr::Value(value)))
                // Field is from rhs, so push onto join's right arg
                if rhs.head().column_pos(field).is_some() =>
                    {
                        // No typing restrictions on `field cmp value`,
                        // and there are no binary operators to recurse into.
                        let rhs = rhs.with_select(FieldOp::cmp(field, cmp, value))?;
                        self.query.push(Query::JoinInner(JoinExpr {
                            rhs,
                            col_lhs,
                            col_rhs,
                            inner,
                        }));
                        Ok(self)
                    }
                (field, value) => {
                    // Restore the join first, so that `self.head()` below is the header after the join.
                    self.query.push(Query::JoinInner(JoinExpr { rhs, col_lhs, col_rhs, inner, }));

                    // As we have `field op value` we need not demand `bool`,
                    // but we must still recurse into each side.
                    self.check_field_op_logics(&field)?;
                    self.check_field_op_logics(&value)?;
                    // Convert to `ColumnOp`.
                    let col = field.names_to_cols(self.head()).unwrap();
                    let value = value.names_to_cols(self.head()).unwrap();
                    // Add `col op value` filter to query.
                    self.query.push(Query::Select(ColumnOp::new(OpQuery::Cmp(cmp), col, value)));
                    Ok(self)
                }
            },
            // We have a previous filter `lhs`, so join with `rhs` forming `lhs AND rhs`.
            (Query::Select(lhs), rhs) => {
                // Type check `rhs`, demanding `bool`.
                self.check_field_op(&rhs)?;
                // Convert to `ColumnOp`.
                let rhs = rhs.names_to_cols(self.head()).unwrap();
                // Add `lhs AND op` to query.
                self.query.push(Query::Select(ColumnOp::and(lhs, rhs)));
                Ok(self)
            }
            // No previous filter, so add a base one.
            (query, op) => {
                self.query.push(query);
                self.add_base_select(op)
            }
        }
    }
1593
1594    /// Add a base `Select` query that filters according to `op`.
1595    /// The `op` is checked to produce a `bool` value.
1596    fn add_base_select(mut self, op: FieldOp) -> Result<Self, RelationError> {
1597        // Type check the filter, demanding `bool`.
1598        self.check_field_op(&op)?;
1599        // Convert to `ColumnOp`.
1600        let op = op.names_to_cols(self.head()).unwrap();
1601        // Add the filter.
1602        self.query.push(Query::Select(op));
1603        Ok(self)
1604    }
1605
1606    /// Type checks a `FieldOp` with respect to `self`,
1607    /// ensuring that query evaluation cannot get stuck or panic due to `reduce_bool`.
1608    fn check_field_op(&self, op: &FieldOp) -> Result<(), RelationError> {
1609        use OpQuery::*;
1610        match op {
1611            // `lhs` and `rhs` must both be typed at `bool`.
1612            FieldOp::Cmp { op: Logic(_), lhs, rhs } => {
1613                self.check_field_op(lhs)?;
1614                self.check_field_op(rhs)?;
1615                Ok(())
1616            }
1617            // `lhs` and `rhs` have no typing restrictions.
1618            // The result of `lhs op rhs` will always be a `bool`
1619            // either by `Eq` or `Ord` on `AlgebraicValue` (see `ColumnOp::compare_bin_op`).
1620            // However, we still have to recurse into `lhs` and `rhs`
1621            // in case we have e.g., `a == (b == c)`.
1622            FieldOp::Cmp { op: Cmp(_), lhs, rhs } => {
1623                self.check_field_op_logics(lhs)?;
1624                self.check_field_op_logics(rhs)?;
1625                Ok(())
1626            }
1627            FieldOp::Field(FieldExpr::Value(AlgebraicValue::Bool(_))) => Ok(()),
1628            FieldOp::Field(FieldExpr::Value(v)) => Err(RelationError::NotBoolValue { val: v.clone() }),
1629            FieldOp::Field(FieldExpr::Name(field)) => {
1630                let field = *field;
1631                let head = self.head();
1632                let col_id = head.column_pos_or_err(field)?;
1633                let col_ty = &head.fields[col_id.idx()].algebraic_type;
1634                match col_ty {
1635                    &AlgebraicType::Bool => Ok(()),
1636                    ty => Err(RelationError::NotBoolType { field, ty: ty.clone() }),
1637                }
1638            }
1639        }
1640    }
1641
1642    /// Traverses `op`, checking any logical operators for bool-typed operands.
1643    fn check_field_op_logics(&self, op: &FieldOp) -> Result<(), RelationError> {
1644        use OpQuery::*;
1645        match op {
1646            FieldOp::Field(_) => Ok(()),
1647            FieldOp::Cmp { op: Cmp(_), lhs, rhs } => {
1648                self.check_field_op_logics(lhs)?;
1649                self.check_field_op_logics(rhs)?;
1650                Ok(())
1651            }
1652            FieldOp::Cmp { op: Logic(_), lhs, rhs } => {
1653                self.check_field_op(lhs)?;
1654                self.check_field_op(rhs)?;
1655                Ok(())
1656            }
1657        }
1658    }
1659
1660    pub fn with_select_cmp<LHS, RHS, O>(self, op: O, lhs: LHS, rhs: RHS) -> Result<Self, RelationError>
1661    where
1662        LHS: Into<FieldExpr>,
1663        RHS: Into<FieldExpr>,
1664        O: Into<OpQuery>,
1665    {
1666        let op = FieldOp::new(op.into(), FieldOp::Field(lhs.into()), FieldOp::Field(rhs.into()));
1667        self.with_select(op)
1668    }
1669
1670    // Appends a project operation to the query operator pipeline.
1671    // The `wildcard_table_id` represents a projection of the form `table.*`.
1672    // This is used to determine if an inner join can be rewritten as an index join.
1673    pub fn with_project(
1674        mut self,
1675        fields: Vec<FieldExpr>,
1676        wildcard_table: Option<TableId>,
1677    ) -> Result<Self, RelationError> {
1678        if !fields.is_empty() {
1679            let header_before = self.head();
1680
1681            // Translate the field expressions to column expressions.
1682            let mut cols = Vec::with_capacity(fields.len());
1683            for field in fields {
1684                cols.push(field.name_to_col(header_before)?);
1685            }
1686
1687            // Project the header.
1688            // We'll store that so subsequent operations use that as a base.
1689            let header_after = Arc::new(header_before.project(&cols)?);
1690
1691            // Add the projection.
1692            self.query.push(Query::Project(ProjectExpr {
1693                cols,
1694                wildcard_table,
1695                header_after,
1696            }));
1697        }
1698        Ok(self)
1699    }
1700
1701    pub fn with_join_inner_raw(
1702        mut self,
1703        q_rhs: QueryExpr,
1704        c_lhs: ColId,
1705        c_rhs: ColId,
1706        inner: Option<Arc<Header>>,
1707    ) -> Self {
1708        self.query
1709            .push(Query::JoinInner(JoinExpr::new(q_rhs, c_lhs, c_rhs, inner)));
1710        self
1711    }
1712
1713    pub fn with_join_inner(self, q_rhs: impl Into<QueryExpr>, c_lhs: ColId, c_rhs: ColId, semi: bool) -> Self {
1714        let q_rhs = q_rhs.into();
1715        let inner = (!semi).then(|| Arc::new(self.head().extend(q_rhs.head())));
1716        self.with_join_inner_raw(q_rhs, c_lhs, c_rhs, inner)
1717    }
1718
1719    fn bound(value: AlgebraicValue, inclusive: bool) -> Bound<AlgebraicValue> {
1720        if inclusive {
1721            Bound::Included(value)
1722        } else {
1723            Bound::Excluded(value)
1724        }
1725    }
1726
1727    /// Try to turn an inner join followed by a projection into a semijoin.
1728    ///
1729    /// This optimization recognizes queries of the form:
1730    ///
1731    /// ```ignore
1732    /// QueryExpr {
1733    ///   source: LHS,
1734    ///   query: [
1735    ///     JoinInner(JoinExpr {
1736    ///       rhs: RHS,
1737    ///       semi: false,
1738    ///       ..
1739    ///     }),
1740    ///     Project(LHS.*),
1741    ///     ...
1742    ///   ]
1743    /// }
1744    /// ```
1745    ///
1746    /// And combines the `JoinInner` with the `Project` into a `JoinInner` with `semi: true`.
1747    ///
1748    /// Current limitations of this optimization:
1749    /// - The `JoinInner` must be the first (0th) element of the `query`.
1750    ///   Future work could search through the `query` to find any applicable `JoinInner`s,
1751    ///   but the current implementation inspects only the first expr.
1752    ///   This is likely sufficient because this optimization is primarily useful for enabling `try_index_join`,
1753    ///   which is fundamentally limited to operate on the first expr.
1754    ///   Note that we still get to optimize incremental joins, because we first optimize the original query
1755    ///   with [`DbTable`] sources, which results in an [`IndexJoin`]
1756    ///   then we replace the sources with [`MemTable`]s and go back to a [`JoinInner`] with `semi: true`.
1757    /// - The `Project` must immediately follow the `JoinInner`, with no intervening exprs.
1758    ///   Future work could search through intervening exprs to detect that the RHS table is unused.
1759    /// - The LHS/source table must be a [`DbTable`], not a [`MemTable`].
1760    ///   This is so we can recognize a wildcard project by its table id.
1761    ///   Future work could inspect the set of projected fields and compare them to the LHS table's header instead.
1762    pub fn try_semi_join(self) -> QueryExpr {
1763        let QueryExpr { source, query } = self;
1764
1765        let Some(source_table_id) = source.table_id() else {
1766            // Source is a `MemTable`, so we can't recognize a wildcard projection. Bail.
1767            return QueryExpr { source, query };
1768        };
1769
1770        let mut exprs = query.into_iter();
1771        let Some(join_candidate) = exprs.next() else {
1772            // No first (0th) expr to be the join; bail.
1773            return QueryExpr { source, query: vec![] };
1774        };
1775        let Query::JoinInner(join) = join_candidate else {
1776            // First (0th) expr is not an inner join. Bail.
1777            return QueryExpr {
1778                source,
1779                query: itertools::chain![Some(join_candidate), exprs].collect(),
1780            };
1781        };
1782
1783        let Some(project_candidate) = exprs.next() else {
1784            // No second (1st) expr to be the project. Bail.
1785            return QueryExpr {
1786                source,
1787                query: vec![Query::JoinInner(join)],
1788            };
1789        };
1790
1791        let Query::Project(proj) = project_candidate else {
1792            // Second (1st) expr is not a wildcard projection. Bail.
1793            return QueryExpr {
1794                source,
1795                query: itertools::chain![Some(Query::JoinInner(join)), Some(project_candidate), exprs].collect(),
1796            };
1797        };
1798
1799        if proj.wildcard_table != Some(source_table_id) {
1800            // Projection is selecting the RHS table. Bail.
1801            return QueryExpr {
1802                source,
1803                query: itertools::chain![Some(Query::JoinInner(join)), Some(Query::Project(proj)), exprs].collect(),
1804            };
1805        };
1806
1807        // All conditions met; return a semijoin.
1808        let semijoin = JoinExpr { inner: None, ..join };
1809
1810        QueryExpr {
1811            source,
1812            query: itertools::chain![Some(Query::JoinInner(semijoin)), exprs].collect(),
1813        }
1814    }
1815
1816    // Try to turn an applicable join into an index join.
1817    // An applicable join is one that can use an index to probe the lhs.
1818    // It must also project only the columns from the lhs.
1819    //
1820    // Ex. SELECT Left.* FROM Left JOIN Right ON Left.id = Right.id ...
1821    // where `Left` has an index defined on `id`.
1822    fn try_index_join(self) -> QueryExpr {
1823        let mut query = self;
1824        // We expect a single operation - an inner join with `semi: true`.
1825        // These can be transformed by `try_semi_join` from a sequence of two queries, an inner join followed by a wildcard project.
1826        if query.query.len() != 1 {
1827            return query;
1828        }
1829
1830        // If the source is a `MemTable`, it doesn't have any indexes,
1831        // so we can't plan an index join.
1832        if query.source.is_mem_table() {
1833            return query;
1834        }
1835        let source = query.source;
1836        let join = query.query.pop().unwrap();
1837
1838        match join {
1839            Query::JoinInner(join @ JoinExpr { inner: None, .. }) => {
1840                if !join.rhs.query.is_empty() {
1841                    // An applicable join must have an index defined on the correct field.
1842                    if source.head().has_constraint(join.col_lhs, Constraints::indexed()) {
1843                        let index_join = IndexJoin {
1844                            probe_side: join.rhs,
1845                            probe_col: join.col_rhs,
1846                            index_side: source.clone(),
1847                            index_select: None,
1848                            index_col: join.col_lhs,
1849                            return_index_rows: true,
1850                        };
1851                        let query = [Query::IndexJoin(index_join)].into();
1852                        return QueryExpr { source, query };
1853                    }
1854                }
1855                QueryExpr {
1856                    source,
1857                    query: vec![Query::JoinInner(join)],
1858                }
1859            }
1860            first => QueryExpr {
1861                source,
1862                query: vec![first],
1863            },
1864        }
1865    }
1866
    /// Look for filters that could use indexes
    ///
    /// The filter `op` is split by `select_best_index`, once per entry of `tables`,
    /// and each resulting operation is appended to `q`:
    /// index operations through `with_index_eq`, `with_index_lower_bound`, and `with_index_upper_bound`,
    /// and scan operations as a `Query::Select`,
    /// merged with `ColumnOp::and` into a `Select` that is already last in `q`.
    fn optimize_select(mut q: QueryExpr, op: ColumnOp, tables: &[SourceExpr]) -> QueryExpr {
        // Go through each table schema referenced in the query.
        // Find the first sargable condition and short-circuit.
        // `fields_found` is shared with `select_best_index` across all tables.
        let mut fields_found = HashSet::default();
        for schema in tables {
            for op in select_best_index(&mut fields_found, schema.head(), &op) {
                if let IndexColumnOp::Scan(op) = &op {
                    // Remove a duplicated/redundant operation on the same `field` and `op`
                    // like `[Index(a = 1), Index(a = 1), Scan(a = 1)]`
                    // `insert` returns `false` when the comparison was already recorded.
                    if op.as_col_cmp().is_some_and(|cc| !fields_found.insert(cc)) {
                        continue;
                    }
                }

                match op {
                    // A sargable condition on one of the table schemas,
                    // either an equality or range condition.
                    IndexColumnOp::Index(idx) => {
                        let table = schema
                            .get_db_table()
                            .expect("find_sargable_ops(schema, op) implies `schema.is_db_table()`")
                            .clone();

                        q = match idx {
                            IndexArgument::Eq { columns, value } => q.with_index_eq(table, columns.clone(), value),
                            IndexArgument::LowerBound {
                                columns,
                                value,
                                inclusive,
                            } => q.with_index_lower_bound(table, columns.clone(), value, inclusive),
                            IndexArgument::UpperBound {
                                columns,
                                value,
                                inclusive,
                            } => q.with_index_upper_bound(table, columns.clone(), value, inclusive),
                        };
                    }
                    // Filter condition cannot be answered using an index.
                    IndexColumnOp::Scan(rhs) => {
                        let rhs = rhs.clone();
                        // Inspect the last operation of `q`; it is put back unless it is a `Select`.
                        let op = match q.query.pop() {
                            // Merge condition into any pre-existing `Select`.
                            Some(Query::Select(lhs)) => ColumnOp::and(lhs, rhs),
                            None => rhs,
                            Some(other) => {
                                q.query.push(other);
                                rhs
                            }
                        };
                        q.query.push(Query::Select(op));
                    }
                }
            }
        }

        q
    }
1925
1926    pub fn optimize(mut self, row_count: &impl Fn(TableId, &str) -> i64) -> Self {
1927        let mut q = Self {
1928            source: self.source.clone(),
1929            query: Vec::with_capacity(self.query.len()),
1930        };
1931
1932        if matches!(&*self.query, [Query::IndexJoin(_)]) {
1933            if let Some(Query::IndexJoin(join)) = self.query.pop() {
1934                q.query.push(Query::IndexJoin(join.reorder(row_count)));
1935                return q;
1936            }
1937        }
1938
1939        for query in self.query {
1940            match query {
1941                Query::Select(op) => {
1942                    q = Self::optimize_select(q, op, from_ref(&self.source));
1943                }
1944                Query::JoinInner(join) => {
1945                    q = q.with_join_inner_raw(join.rhs.optimize(row_count), join.col_lhs, join.col_rhs, join.inner);
1946                }
1947                _ => q.query.push(query),
1948            };
1949        }
1950
1951        // Make sure to `try_semi_join` before `try_index_join`, as the latter depends on the former.
1952        let q = q.try_semi_join();
1953        let q = q.try_index_join();
1954        if matches!(&*q.query, [Query::IndexJoin(_)]) {
1955            return q.optimize(row_count);
1956        }
1957        q
1958    }
1959}
1960
1961impl AuthAccess for Query {
1962    fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> {
1963        if owner == caller {
1964            return Ok(());
1965        }
1966
1967        self.walk_sources(&mut |s| s.check_auth(owner, caller))
1968    }
1969}
1970
/// An expression, as one of the variants below.
///
/// `From` is derived with `#[from]` placed on [`Expr::Value`];
/// the conversion from [`QueryExpr`] is implemented by hand.
#[derive(Debug, Eq, PartialEq, From)]
pub enum Expr {
    /// An [`AlgebraicValue`].
    #[from]
    Value(AlgebraicValue),
    /// A list of nested expressions.
    Block(Vec<Expr>),
    /// An identifier, presumably stored by its name — TODO confirm.
    Ident(String),
    /// A boxed [`CrudExpr`].
    Crud(Box<CrudExpr>),
    /// Carries an [`ErrorLang`]; presumably halts evaluation with that error — TODO confirm.
    Halt(ErrorLang),
}
1980
1981impl From<QueryExpr> for Expr {
1982    fn from(x: QueryExpr) -> Self {
1983        Expr::Crud(Box::new(CrudExpr::Query(x)))
1984    }
1985}
1986
1987impl fmt::Display for Query {
1988    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1989        match self {
1990            Query::IndexScan(op) => {
1991                write!(f, "index_scan {:?}", op)
1992            }
1993            Query::IndexJoin(op) => {
1994                write!(f, "index_join {:?}", op)
1995            }
1996            Query::Select(q) => {
1997                write!(f, "select {q}")
1998            }
1999            Query::Project(proj) => {
2000                let q = &proj.cols;
2001                write!(f, "project")?;
2002                if !q.is_empty() {
2003                    write!(f, " ")?;
2004                }
2005                for (pos, x) in q.iter().enumerate() {
2006                    write!(f, "{x}")?;
2007                    if pos + 1 < q.len() {
2008                        write!(f, ", ")?;
2009                    }
2010                }
2011                Ok(())
2012            }
2013            Query::JoinInner(q) => {
2014                write!(f, "&inner {:?} ON {} = {}", q.rhs, q.col_lhs, q.col_rhs)
2015            }
2016        }
2017    }
2018}
2019
2020impl AuthAccess for SourceExpr {
2021    fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> {
2022        if owner == caller || self.table_access() == StAccess::Public {
2023            return Ok(());
2024        }
2025
2026        Err(AuthError::TablePrivate {
2027            named: self.table_name().to_string(),
2028        })
2029    }
2030}
2031
2032impl AuthAccess for QueryExpr {
2033    fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> {
2034        if owner == caller {
2035            return Ok(());
2036        }
2037        self.walk_sources(&mut |s| s.check_auth(owner, caller))
2038    }
2039}
2040
2041impl AuthAccess for CrudExpr {
2042    fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> {
2043        if owner == caller {
2044            return Ok(());
2045        }
2046        // Anyone may query, so as long as the tables involved are public.
2047        if let CrudExpr::Query(q) = self {
2048            return q.check_auth(owner, caller);
2049        }
2050
2051        // Mutating operations require `owner == caller`.
2052        Err(AuthError::OwnerRequired)
2053    }
2054}
2055
/// Rows associated with a single table, split into `inserts` and `deletes`.
///
/// NOTE(review): presumably the outcome of a mutating statement — confirm against the producers of `Code::Pass`.
#[derive(Debug, PartialEq)]
pub struct Update {
    /// Id of the table the rows belong to.
    pub table_id: TableId,
    /// Name of that table.
    pub table_name: Box<str>,
    /// Rows listed as inserted.
    pub inserts: Vec<ProductValue>,
    /// Rows listed as deleted.
    pub deletes: Vec<ProductValue>,
}
2063
/// A piece of code, which converts into a [`CodeResult`] through `From<Code>`.
///
/// Every variant but [`Code::Crud`] has a counterpart in [`CodeResult`].
#[derive(Debug, PartialEq)]
pub enum Code {
    /// An [`AlgebraicValue`].
    Value(AlgebraicValue),
    /// An in-memory table.
    Table(MemTable),
    /// Carries an [`ErrorLang`].
    Halt(ErrorLang),
    /// A list of nested codes.
    Block(Vec<Code>),
    /// A [`CrudExpr`]; converting it to a [`CodeResult`] yields a `Halt`.
    Crud(CrudExpr),
    /// Carries an optional [`Update`].
    Pass(Option<Update>),
}
2073
2074impl fmt::Display for Code {
2075    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2076        match self {
2077            Code::Value(x) => {
2078                write!(f, "{:?}", &x)
2079            }
2080            Code::Block(_) => write!(f, "Block"),
2081            x => todo!("{:?}", x),
2082        }
2083    }
2084}
2085
/// The result form of a [`Code`], obtained through `From<Code>`.
///
/// It has the same variants as [`Code`], except that there is no `Crud`.
#[derive(Debug, PartialEq)]
pub enum CodeResult {
    /// An [`AlgebraicValue`].
    Value(AlgebraicValue),
    /// An in-memory table.
    Table(MemTable),
    /// A list of nested results.
    Block(Vec<CodeResult>),
    /// Carries an [`ErrorLang`].
    Halt(ErrorLang),
    /// Carries an optional [`Update`]; an empty `Code::Block` converts to `Pass(None)`.
    Pass(Option<Update>),
}
2094
2095impl From<Code> for CodeResult {
2096    fn from(code: Code) -> Self {
2097        match code {
2098            Code::Value(x) => Self::Value(x),
2099            Code::Table(x) => Self::Table(x),
2100            Code::Halt(x) => Self::Halt(x),
2101            Code::Block(x) => {
2102                if x.is_empty() {
2103                    Self::Pass(None)
2104                } else {
2105                    Self::Block(x.into_iter().map(CodeResult::from).collect())
2106                }
2107            }
2108            Code::Pass(x) => Self::Pass(x),
2109            x => Self::Halt(ErrorLang::new(
2110                ErrorKind::Compiler,
2111                Some(&format!("Invalid result: {x}")),
2112            )),
2113        }
2114    }
2115}
2116
2117#[cfg(test)]
2118mod tests {
2119    use super::*;
2120
2121    use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, relation::Column};
2122    use spacetimedb_sats::{product, AlgebraicType, ProductType};
2123    use spacetimedb_schema::{def::ModuleDef, schema::Schema};
2124    use typed_arena::Arena;
2125
2126    const ALICE: Identity = Identity::from_byte_array([1; 32]);
2127    const BOB: Identity = Identity::from_byte_array([2; 32]);
2128
2129    // TODO(kim): Should better do property testing here, but writing generators
2130    // on recursive types (ie. `Query` and friends) is tricky.
2131
    /// Returns two private sources named `foo` with table id 42 and no fields:
    /// first an in-memory one, then a db table with an index constraint on column 42.
    fn tables() -> [SourceExpr; 2] {
        [
            SourceExpr::InMemory {
                source_id: SourceId(0),
                header: Arc::new(Header {
                    table_id: 42.into(),
                    table_name: "foo".into(),
                    fields: vec![],
                    constraints: Default::default(),
                }),
                table_type: StTableType::User,
                table_access: StAccess::Private,
            },
            SourceExpr::DbTable(DbTable {
                head: Arc::new(Header {
                    table_id: 42.into(),
                    table_name: "foo".into(),
                    fields: vec![],
                    constraints: [(ColId(42).into(), Constraints::indexed())].into_iter().collect(),
                }),
                table_id: 42.into(),
                table_type: StTableType::User,
                table_access: StAccess::Private,
            }),
        ]
    }
2158
    /// Returns one query of each kind that refers to a table:
    /// an index scan, an index join, and an inner join, built from the sources of `tables()`.
    fn queries() -> impl IntoIterator<Item = Query> {
        let [mem_table, db_table] = tables();
        // Skip `Query::Select` and `Query::Project` -- they don't have table
        // information
        [
            Query::IndexScan(IndexScan {
                table: db_table.get_db_table().unwrap().clone(),
                columns: ColList::new(42.into()),
                bounds: (Bound::Included(22.into()), Bound::Unbounded),
            }),
            Query::IndexJoin(IndexJoin {
                // The probe side is the private in-memory table.
                probe_side: mem_table.clone().into(),
                probe_col: 0.into(),
                // The index side reuses the id and name of the db table, but is public.
                index_side: SourceExpr::DbTable(DbTable {
                    head: Arc::new(Header {
                        table_id: db_table.head().table_id,
                        table_name: db_table.table_name().into(),
                        fields: vec![],
                        constraints: Default::default(),
                    }),
                    table_id: db_table.head().table_id,
                    table_type: StTableType::User,
                    table_access: StAccess::Public,
                }),
                index_select: None,
                index_col: 22.into(),
                return_index_rows: true,
            }),
            Query::JoinInner(JoinExpr {
                col_rhs: 1.into(),
                rhs: mem_table.into(),
                col_lhs: 1.into(),
                inner: None,
            }),
        ]
    }
2195
2196    fn query_exprs() -> impl IntoIterator<Item = QueryExpr> {
2197        tables().map(|table| {
2198            let mut expr = QueryExpr::from(table);
2199            expr.query = queries().into_iter().collect();
2200            expr
2201        })
2202    }
2203
    /// Asserts that `auth` grants access when owner and caller are both `ALICE`,
    /// and fails with `AuthError::TablePrivate` when `BOB` is the caller.
    fn assert_owner_private<T: AuthAccess>(auth: &T) {
        assert!(auth.check_auth(ALICE, ALICE).is_ok());
        assert!(matches!(
            auth.check_auth(ALICE, BOB),
            Err(AuthError::TablePrivate { .. })
        ));
    }
2211
    /// Asserts that `auth` grants access when owner and caller are both `ALICE`,
    /// and fails with `AuthError::OwnerRequired` when `BOB` is the caller.
    fn assert_owner_required<T: AuthAccess>(auth: T) {
        assert!(auth.check_auth(ALICE, ALICE).is_ok());
        assert!(matches!(auth.check_auth(ALICE, BOB), Err(AuthError::OwnerRequired)));
    }
2216
    /// Builds a public in-memory source with the given table `id` and `name`.
    ///
    /// Each entry of `fields` is `(column, type, indexed)`.
    fn mem_table(id: TableId, name: &str, fields: &[(u16, AlgebraicType, bool)]) -> SourceExpr {
        let table_access = StAccess::Public;
        let head = Header::new(
            id,
            name.into(),
            // One column per entry, named by the table id and the entry's column.
            fields
                .iter()
                .map(|(col, ty, _)| Column::new(FieldName::new(id, (*col).into()), ty.clone()))
                .collect(),
            // One index constraint per entry flagged as indexed.
            // NOTE(review): the constraint is keyed by the entry's position in `fields`,
            // not by its column; the callers visible here pass columns equal to their position.
            fields
                .iter()
                .enumerate()
                .filter(|(_, (_, _, indexed))| *indexed)
                .map(|(i, _)| (ColId::from(i).into(), Constraints::indexed())),
        );
        SourceExpr::InMemory {
            source_id: SourceId(0),
            header: Arc::new(head),
            table_access,
            table_type: StTableType::User,
        }
    }
2239
    /// Checks `IndexJoin::to_inner_join` for a join that does not return the index rows:
    /// the probe side becomes the source, and the index side, together with its selection,
    /// becomes the rhs of a join without an `inner` header.
    #[test]
    fn test_index_to_inner_join() {
        let index_side = mem_table(
            0.into(),
            "index",
            &[(0, AlgebraicType::U8, false), (1, AlgebraicType::U8, true)],
        );
        let probe_side = mem_table(
            1.into(),
            "probe",
            &[(0, AlgebraicType::U8, false), (1, AlgebraicType::U8, true)],
        );

        let index_col = 1.into();
        let probe_col = 1.into();
        let index_select = ColumnOp::cmp(0, OpCmp::Eq, 0u8);
        let join = IndexJoin {
            probe_side: probe_side.clone().into(),
            probe_col,
            index_side: index_side.clone(),
            index_select: Some(index_select.clone()),
            index_col,
            return_index_rows: false,
        };

        let expr = join.to_inner_join();

        // The probe side is the source of the resulting query, which holds a single operation.
        assert_eq!(expr.source, probe_side);
        assert_eq!(expr.query.len(), 1);

        let Query::JoinInner(ref join) = expr.query[0] else {
            panic!("expected an inner join, but got {:#?}", expr.query[0]);
        };

        // The probe column is on the lhs of the join, and the index column on the rhs.
        assert_eq!(join.col_lhs, probe_col);
        assert_eq!(join.col_rhs, index_col);
        // The index selection has moved into the rhs query.
        assert_eq!(
            join.rhs,
            QueryExpr {
                source: index_side,
                query: vec![index_select.into()]
            }
        );
        assert_eq!(join.inner, None);
    }
2285
    /// Builds the fixture for the `select_best_index` tests.
    ///
    /// Returns the header of a table `t1` with five `I8` columns (ids 0 to 4, called `a` to `e` by the tests),
    /// those column ids, and the values 1 to 5 as `AlgebraicValue::U64`.
    /// The header has constraints on `a`, `b`, `b + c`, and `a + b + c + d`; the fifth column has none.
    fn setup_best_index() -> (Header, [ColId; 5], [AlgebraicValue; 5]) {
        let table_id = 0.into();

        let vals = [1, 2, 3, 4, 5].map(AlgebraicValue::U64);
        let col_ids = [0, 1, 2, 3, 4].map(ColId);
        let [a, b, c, d, _] = col_ids;
        let columns = col_ids.map(|c| Column::new(FieldName::new(table_id, c), AlgebraicType::I8));

        let head1 = Header::new(
            table_id,
            "t1".into(),
            columns.to_vec(),
            vec![
                // Index a
                (a.into(), Constraints::primary_key()),
                // Index b
                (b.into(), Constraints::indexed()),
                // Index b + c
                (col_list![b, c], Constraints::unique()),
                // Index a + b + c + d
                (col_list![a, b, c, d], Constraints::indexed()),
            ],
        );

        (head1, col_ids, vals)
    }
2312
2313    fn make_field_value((cmp, col, value): (OpCmp, ColId, &AlgebraicValue)) -> ColumnOp {
2314        ColumnOp::cmp(col, cmp, value.clone())
2315    }
2316
2317    fn scan_eq<'a>(arena: &'a Arena<ColumnOp>, col: ColId, val: &'a AlgebraicValue) -> IndexColumnOp<'a> {
2318        scan(arena, OpCmp::Eq, col, val)
2319    }
2320
2321    fn scan<'a>(arena: &'a Arena<ColumnOp>, cmp: OpCmp, col: ColId, val: &'a AlgebraicValue) -> IndexColumnOp<'a> {
2322        IndexColumnOp::Scan(arena.alloc(make_field_value((cmp, col, val))))
2323    }
2324
    /// Checks which operations `select_best_index` produces for conjunctions of equalities,
    /// using the header of `setup_best_index`.
    #[test]
    fn best_index() {
        let (head1, fields, vals) = setup_best_index();
        let [col_a, col_b, col_c, col_d, col_e] = fields;
        let [val_a, val_b, val_c, val_d, val_e] = vals;

        let arena = Arena::new();
        // Shadows the function under test:
        // `AND`s together `col = val` for every pair given, and runs the selection with a fresh set.
        let select_best_index = |fields: &[_]| {
            let fields = fields
                .iter()
                .copied()
                .map(|(col, val): (ColId, _)| make_field_value((OpCmp::Eq, col, val)))
                .reduce(ColumnOp::and)
                .unwrap();
            select_best_index(&mut <_>::default(), &head1, arena.alloc(fields))
        };

        let col_list_arena = Arena::new();
        // The expected index operation for an equality on `cols`.
        let idx_eq = |cols, val| make_index_arg(OpCmp::Eq, col_list_arena.alloc(cols), val);

        // Check for simple scan
        assert_eq!(
            select_best_index(&[(col_d, &val_e)]),
            [scan_eq(&arena, col_d, &val_e)].into(),
        );

        // Single columns that have a constraint of their own.
        assert_eq!(
            select_best_index(&[(col_a, &val_a)]),
            [idx_eq(col_a.into(), val_a.clone())].into(),
        );

        assert_eq!(
            select_best_index(&[(col_b, &val_b)]),
            [idx_eq(col_b.into(), val_b.clone())].into(),
        );

        // Check for permutation
        assert_eq!(
            select_best_index(&[(col_b, &val_b), (col_c, &val_c)]),
            [idx_eq(
                col_list![col_b, col_c],
                product![val_b.clone(), val_c.clone()].into()
            )]
            .into(),
        );

        assert_eq!(
            select_best_index(&[(col_c, &val_c), (col_b, &val_b)]),
            [idx_eq(
                col_list![col_b, col_c],
                product![val_b.clone(), val_c.clone()].into()
            )]
            .into(),
        );

        // Check for permutation
        assert_eq!(
            select_best_index(&[(col_a, &val_a), (col_b, &val_b), (col_c, &val_c), (col_d, &val_d)]),
            [idx_eq(
                col_list![col_a, col_b, col_c, col_d],
                product![val_a.clone(), val_b.clone(), val_c.clone(), val_d.clone()].into(),
            )]
            .into(),
        );

        assert_eq!(
            select_best_index(&[(col_b, &val_b), (col_a, &val_a), (col_d, &val_d), (col_c, &val_c)]),
            [idx_eq(
                col_list![col_a, col_b, col_c, col_d],
                product![val_a.clone(), val_b.clone(), val_c.clone(), val_d.clone()].into(),
            )]
            .into()
        );

        // Check mix scan + index
        assert_eq!(
            select_best_index(&[(col_b, &val_b), (col_a, &val_a), (col_e, &val_e), (col_d, &val_d)]),
            [
                idx_eq(col_a.into(), val_a.clone()),
                idx_eq(col_b.into(), val_b.clone()),
                scan_eq(&arena, col_d, &val_d),
                scan_eq(&arena, col_e, &val_e),
            ]
            .into()
        );

        // `b + c` is answered by the multi-column index, leaving `d` to a scan.
        assert_eq!(
            select_best_index(&[(col_b, &val_b), (col_c, &val_c), (col_d, &val_d)]),
            [
                idx_eq(col_list![col_b, col_c], product![val_b.clone(), val_c.clone()].into(),),
                scan_eq(&arena, col_d, &val_d),
            ]
            .into()
        );
    }
2420
2421    #[test]
2422    fn best_index_range() {
2423        let arena = Arena::new();
2424
2425        let (head1, cols, vals) = setup_best_index();
2426        let [col_a, col_b, col_c, col_d, _] = cols;
2427        let [val_a, val_b, val_c, val_d, _] = vals;
2428
2429        let select_best_index = |cols: &[_]| {
2430            let fields = cols.iter().map(|x| make_field_value(*x)).reduce(ColumnOp::and).unwrap();
2431            select_best_index(&mut <_>::default(), &head1, arena.alloc(fields))
2432        };
2433
2434        let col_list_arena = Arena::new();
2435        let idx = |cmp, cols: &[ColId], val: &AlgebraicValue| {
2436            let columns = cols.iter().copied().collect::<ColList>();
2437            let columns = col_list_arena.alloc(columns);
2438            make_index_arg(cmp, columns, val.clone())
2439        };
2440
2441        // `a > va AND a < vb` => `[index(a), index(a)]`
2442        assert_eq!(
2443            select_best_index(&[(OpCmp::Gt, col_a, &val_a), (OpCmp::Lt, col_a, &val_b)]),
2444            [idx(OpCmp::Lt, &[col_a], &val_b), idx(OpCmp::Gt, &[col_a], &val_a)].into()
2445        );
2446
2447        // `d > vd AND d < vb` => `[scan(d), scan(d)]`
2448        assert_eq!(
2449            select_best_index(&[(OpCmp::Gt, col_d, &val_d), (OpCmp::Lt, col_d, &val_b)]),
2450            [
2451                scan(&arena, OpCmp::Lt, col_d, &val_b),
2452                scan(&arena, OpCmp::Gt, col_d, &val_d)
2453            ]
2454            .into()
2455        );
2456
2457        // `b > vb AND c < vc` => `[index(b), scan(c)]`.
2458        assert_eq!(
2459            select_best_index(&[(OpCmp::Gt, col_b, &val_b), (OpCmp::Lt, col_c, &val_c)]),
2460            [idx(OpCmp::Gt, &[col_b], &val_b), scan(&arena, OpCmp::Lt, col_c, &val_c)].into()
2461        );
2462
2463        // `b = vb AND a >= va AND c = vc` => `[index(b, c), index(a)]`
2464        let idx_bc = idx(
2465            OpCmp::Eq,
2466            &[col_b, col_c],
2467            &product![val_b.clone(), val_c.clone()].into(),
2468        );
2469        assert_eq!(
2470            //
2471            select_best_index(&[
2472                (OpCmp::Eq, col_b, &val_b),
2473                (OpCmp::GtEq, col_a, &val_a),
2474                (OpCmp::Eq, col_c, &val_c),
2475            ]),
2476            [idx_bc.clone(), idx(OpCmp::GtEq, &[col_a], &val_a),].into()
2477        );
2478
2479        // `b > vb AND a = va AND c = vc` => `[index(a), index(b), scan(c)]`
2480        assert_eq!(
2481            select_best_index(&[
2482                (OpCmp::Gt, col_b, &val_b),
2483                (OpCmp::Eq, col_a, &val_a),
2484                (OpCmp::Lt, col_c, &val_c),
2485            ]),
2486            [
2487                idx(OpCmp::Eq, &[col_a], &val_a),
2488                idx(OpCmp::Gt, &[col_b], &val_b),
2489                scan(&arena, OpCmp::Lt, col_c, &val_c),
2490            ]
2491            .into()
2492        );
2493
2494        // `a = va AND b = vb AND c = vc AND d > vd` => `[index(b, c), index(a), scan(d)]`
2495        assert_eq!(
2496            select_best_index(&[
2497                (OpCmp::Eq, col_a, &val_a),
2498                (OpCmp::Eq, col_b, &val_b),
2499                (OpCmp::Eq, col_c, &val_c),
2500                (OpCmp::Gt, col_d, &val_d),
2501            ]),
2502            [
2503                idx_bc.clone(),
2504                idx(OpCmp::Eq, &[col_a], &val_a),
2505                scan(&arena, OpCmp::Gt, col_d, &val_d),
2506            ]
2507            .into()
2508        );
2509
2510        // `b = vb AND c = vc AND b = vb AND c = vc` => `[index(b, c), index(b, c)]`
2511        assert_eq!(
2512            select_best_index(&[
2513                (OpCmp::Eq, col_b, &val_b),
2514                (OpCmp::Eq, col_c, &val_c),
2515                (OpCmp::Eq, col_b, &val_b),
2516                (OpCmp::Eq, col_c, &val_c),
2517            ]),
2518            [idx_bc.clone(), idx_bc].into()
2519        );
2520    }
2521
2522    #[test]
2523    fn test_auth_table() {
2524        tables().iter().for_each(assert_owner_private)
2525    }
2526
2527    #[test]
2528    fn test_auth_query_code() {
2529        for code in query_exprs() {
2530            assert_owner_private(&code)
2531        }
2532    }
2533
2534    #[test]
2535    fn test_auth_query() {
2536        for query in queries() {
2537            assert_owner_private(&query);
2538        }
2539    }
2540
2541    #[test]
2542    fn test_auth_crud_code_query() {
2543        for query in query_exprs() {
2544            let crud = CrudExpr::Query(query);
2545            assert_owner_private(&crud);
2546        }
2547    }
2548
2549    #[test]
2550    fn test_auth_crud_code_insert() {
2551        for table in tables().into_iter().filter_map(|s| s.get_db_table().cloned()) {
2552            let crud = CrudExpr::Insert { table, rows: vec![] };
2553            assert_owner_required(crud);
2554        }
2555    }
2556
2557    #[test]
2558    fn test_auth_crud_code_update() {
2559        for qc in query_exprs() {
2560            let crud = CrudExpr::Update {
2561                delete: qc,
2562                assignments: Default::default(),
2563            };
2564            assert_owner_required(crud);
2565        }
2566    }
2567
2568    #[test]
2569    fn test_auth_crud_code_delete() {
2570        for query in query_exprs() {
2571            let crud = CrudExpr::Delete { query };
2572            assert_owner_required(crud);
2573        }
2574    }
2575
2576    fn test_def() -> ModuleDef {
2577        let mut builder = RawModuleDefV9Builder::new();
2578        builder.build_table_with_new_type(
2579            "lhs",
2580            ProductType::from([("a", AlgebraicType::I32), ("b", AlgebraicType::String)]),
2581            true,
2582        );
2583        builder.build_table_with_new_type(
2584            "rhs",
2585            ProductType::from([("c", AlgebraicType::I32), ("d", AlgebraicType::I64)]),
2586            true,
2587        );
2588        builder.finish().try_into().expect("test def should be valid")
2589    }
2590
2591    #[test]
2592    /// Tests that [`QueryExpr::optimize`] can rewrite inner joins followed by projections into semijoins.
2593    fn optimize_inner_join_to_semijoin() {
2594        let def: ModuleDef = test_def();
2595        let lhs = TableSchema::from_module_def(&def, def.table("lhs").unwrap(), (), 0.into());
2596        let rhs = TableSchema::from_module_def(&def, def.table("rhs").unwrap(), (), 1.into());
2597
2598        let lhs_source = SourceExpr::from(&lhs);
2599        let rhs_source = SourceExpr::from(&rhs);
2600
2601        let q = QueryExpr::new(lhs_source.clone())
2602            .with_join_inner(rhs_source.clone(), 0.into(), 0.into(), false)
2603            .with_project(
2604                [0, 1]
2605                    .map(|c| FieldExpr::Name(FieldName::new(lhs.table_id, c.into())))
2606                    .into(),
2607                Some(TableId::SENTINEL),
2608            )
2609            .unwrap();
2610        let q = q.optimize(&|_, _| 0);
2611
2612        assert_eq!(q.source, lhs_source, "Optimized query should read from lhs");
2613
2614        assert_eq!(
2615            q.query.len(),
2616            1,
2617            "Optimized query should have a single member, a semijoin"
2618        );
2619        match &q.query[0] {
2620            Query::JoinInner(JoinExpr { rhs, inner: semi, .. }) => {
2621                assert_eq!(semi, &None, "Optimized query should be a semijoin");
2622                assert_eq!(rhs.source, rhs_source, "Optimized query should filter with rhs");
2623                assert!(
2624                    rhs.query.is_empty(),
2625                    "Optimized query should not filter rhs before joining"
2626                );
2627            }
2628            wrong => panic!("Expected an inner join, but found {wrong:?}"),
2629        }
2630    }
2631
2632    #[test]
2633    /// Tests that [`QueryExpr::optimize`] will not rewrite inner joins which are not followed by projections to the LHS table.
2634    fn optimize_inner_join_no_project() {
2635        let def: ModuleDef = test_def();
2636        let lhs = TableSchema::from_module_def(&def, def.table("lhs").unwrap(), (), 0.into());
2637        let rhs = TableSchema::from_module_def(&def, def.table("rhs").unwrap(), (), 1.into());
2638
2639        let lhs_source = SourceExpr::from(&lhs);
2640        let rhs_source = SourceExpr::from(&rhs);
2641
2642        let q = QueryExpr::new(lhs_source.clone()).with_join_inner(rhs_source.clone(), 0.into(), 0.into(), false);
2643        let optimized = q.clone().optimize(&|_, _| 0);
2644        assert_eq!(q, optimized);
2645    }
2646
2647    #[test]
2648    /// Tests that [`QueryExpr::optimize`] will not rewrite inner joins followed by projections to the RHS rather than LHS table.
2649    fn optimize_inner_join_wrong_project() {
2650        let def: ModuleDef = test_def();
2651        let lhs = TableSchema::from_module_def(&def, def.table("lhs").unwrap(), (), 0.into());
2652        let rhs = TableSchema::from_module_def(&def, def.table("rhs").unwrap(), (), 1.into());
2653
2654        let lhs_source = SourceExpr::from(&lhs);
2655        let rhs_source = SourceExpr::from(&rhs);
2656
2657        let q = QueryExpr::new(lhs_source.clone())
2658            .with_join_inner(rhs_source.clone(), 0.into(), 0.into(), false)
2659            .with_project(
2660                [0, 1]
2661                    .map(|c| FieldExpr::Name(FieldName::new(rhs.table_id, c.into())))
2662                    .into(),
2663                Some(TableId(1)),
2664            )
2665            .unwrap();
2666        let optimized = q.clone().optimize(&|_, _| 0);
2667        assert_eq!(q, optimized);
2668    }
2669}