1use crate::errors::{ErrorKind, ErrorLang};
2use crate::operator::{OpCmp, OpLogic, OpQuery};
3use crate::relation::{MemTable, RelValue};
4use arrayvec::ArrayVec;
5use core::slice::from_ref;
6use derive_more::From;
7use itertools::Itertools;
8use smallvec::SmallVec;
9use spacetimedb_data_structures::map::{HashSet, IntMap};
10use spacetimedb_lib::db::auth::{StAccess, StTableType};
11use spacetimedb_lib::db::error::{AuthError, RelationError};
12use spacetimedb_lib::relation::{ColExpr, DbTable, FieldName, Header};
13use spacetimedb_lib::{AlgebraicType, Identity};
14use spacetimedb_primitives::*;
15use spacetimedb_sats::algebraic_value::AlgebraicValue;
16use spacetimedb_sats::satn::Satn;
17use spacetimedb_sats::ProductValue;
18use spacetimedb_schema::schema::TableSchema;
19use std::borrow::Cow;
20use std::cmp::Reverse;
21use std::collections::btree_map::Entry;
22use std::collections::BTreeMap;
23use std::ops::Bound;
24use std::sync::Arc;
25use std::{fmt, iter, mem};
26
/// Authorization hook for plan elements that may touch access-restricted data.
pub trait AuthAccess {
    /// Checks whether `caller` may access data owned by `owner`.
    ///
    /// An `Err(AuthError)` presumably signals that access is denied — confirm with implementors.
    fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError>;
}
31
/// A leaf operand of a [`FieldOp`]: either a reference to a field by name or a literal value.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, From)]
pub enum FieldExpr {
    /// A field, identified by its [`FieldName`].
    Name(FieldName),
    /// A constant value.
    Value(AlgebraicValue),
}
37
38impl FieldExpr {
39 pub fn strip_table(self) -> ColExpr {
40 match self {
41 Self::Name(field) => ColExpr::Col(field.col),
42 Self::Value(value) => ColExpr::Value(value),
43 }
44 }
45
46 pub fn name_to_col(self, head: &Header) -> Result<ColExpr, RelationError> {
47 match self {
48 Self::Value(val) => Ok(ColExpr::Value(val)),
49 Self::Name(field) => head.column_pos_or_err(field).map(ColExpr::Col),
50 }
51 }
52}
53
54impl fmt::Display for FieldExpr {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 match self {
57 FieldExpr::Name(x) => write!(f, "{x}"),
58 FieldExpr::Value(x) => write!(f, "{}", x.to_satn()),
59 }
60 }
61}
62
/// A predicate tree over *named* fields, before names are resolved to column positions
/// (see [`FieldOp::names_to_cols`]).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, From)]
pub enum FieldOp {
    /// A leaf: a field reference or a literal.
    #[from]
    Field(FieldExpr),
    /// A binary node `lhs op rhs`, where `op` is a comparison or a logical connective.
    Cmp {
        op: OpQuery,
        lhs: Box<FieldOp>,
        rhs: Box<FieldOp>,
    },
}
73
/// Result of [`FieldOp::flatten_ands`]; stores one conjunct inline before spilling to the heap.
type FieldOpFlat = SmallVec<[FieldOp; 1]>;
75
76impl FieldOp {
77 pub fn new(op: OpQuery, lhs: Self, rhs: Self) -> Self {
78 Self::Cmp {
79 op,
80 lhs: Box::new(lhs),
81 rhs: Box::new(rhs),
82 }
83 }
84
85 pub fn cmp(field: impl Into<FieldName>, op: OpCmp, value: impl Into<AlgebraicValue>) -> Self {
86 Self::new(
87 OpQuery::Cmp(op),
88 Self::Field(FieldExpr::Name(field.into())),
89 Self::Field(FieldExpr::Value(value.into())),
90 )
91 }
92
93 pub fn names_to_cols(self, head: &Header) -> Result<ColumnOp, RelationError> {
94 match self {
95 Self::Field(field) => field.name_to_col(head).map(ColumnOp::from),
96 Self::Cmp { op, lhs, rhs } => {
97 let lhs = lhs.names_to_cols(head)?;
98 let rhs = rhs.names_to_cols(head)?;
99 Ok(ColumnOp::new(op, lhs, rhs))
100 }
101 }
102 }
103
104 pub fn flatten_ands(self) -> FieldOpFlat {
112 fn fill_vec(buf: &mut FieldOpFlat, op: FieldOp) {
113 match op {
114 FieldOp::Cmp {
115 op: OpQuery::Logic(OpLogic::And),
116 lhs,
117 rhs,
118 } => {
119 fill_vec(buf, *lhs);
120 fill_vec(buf, *rhs);
121 }
122 op => buf.push(op),
123 }
124 }
125 let mut buf = SmallVec::new();
126 fill_vec(&mut buf, self);
127 buf
128 }
129}
130
131impl fmt::Display for FieldOp {
132 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133 match self {
134 Self::Field(x) => {
135 write!(f, "{}", x)
136 }
137 Self::Cmp { op, lhs, rhs } => {
138 write!(f, "{} {} {}", lhs, op, rhs)
139 }
140 }
141 }
142}
143
/// A row predicate / expression whose column references are resolved to positions.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, From)]
pub enum ColumnOp {
    /// The value of the column at this position.
    #[from]
    Col(ColId),
    /// A constant value.
    #[from]
    Val(AlgebraicValue),
    /// The compact form of `column <cmp> value`, as built by [`ColumnOp::cmp`].
    ColCmpVal {
        lhs: ColId,
        cmp: OpCmp,
        rhs: AlgebraicValue,
    },
    /// A comparison between two arbitrary sub-expressions.
    Cmp {
        lhs: Box<ColumnOp>,
        cmp: OpCmp,
        rhs: Box<ColumnOp>,
    },
    /// A logical `AND` / `OR` over a slice of operands.
    Log { op: OpLogic, operands: Box<[ColumnOp]> },
}
170
171impl ColumnOp {
172 pub fn new(op: OpQuery, lhs: Self, rhs: Self) -> Self {
173 match op {
174 OpQuery::Cmp(cmp) => match (lhs, rhs) {
175 (ColumnOp::Col(lhs), ColumnOp::Val(rhs)) => Self::cmp(lhs, cmp, rhs),
176 (lhs, rhs) => Self::Cmp {
177 lhs: Box::new(lhs),
178 cmp,
179 rhs: Box::new(rhs),
180 },
181 },
182 OpQuery::Logic(op) => Self::Log {
183 op,
184 operands: [lhs, rhs].into(),
185 },
186 }
187 }
188
189 pub fn cmp(col: impl Into<ColId>, cmp: OpCmp, val: impl Into<AlgebraicValue>) -> Self {
190 let lhs = col.into();
191 let rhs = val.into();
192 Self::ColCmpVal { lhs, cmp, rhs }
193 }
194
195 fn and(lhs: Self, rhs: Self) -> Self {
197 let ands = |operands| {
198 let op = OpLogic::And;
199 Self::Log { op, operands }
200 };
201
202 match (lhs, rhs) {
203 (
205 Self::Log {
206 op: OpLogic::And,
207 operands: lhs,
208 },
209 Self::Log {
210 op: OpLogic::And,
211 operands: rhs,
212 },
213 ) => {
214 let mut operands = Vec::from(lhs);
215 operands.append(&mut Vec::from(rhs));
216 ands(operands.into())
217 }
218 (
220 Self::Log {
221 op: OpLogic::And,
222 operands: lhs,
223 },
224 rhs,
225 ) => {
226 let mut operands = Vec::from(lhs);
227 operands.push(rhs);
228 ands(operands.into())
229 }
230 (lhs, rhs) => ands([lhs, rhs].into()),
232 }
233 }
234
235 fn and_cmp(op: OpCmp, cols: &ColList, value: AlgebraicValue) -> Self {
237 let cmp = |(col, value): (ColId, _)| Self::cmp(col, op, value);
238
239 if let Some(head) = cols.as_singleton() {
241 return cmp((head, value));
242 }
243
244 let operands = cols.iter().zip(value.into_product().unwrap()).map(cmp).collect();
246 Self::Log {
247 op: OpLogic::And,
248 operands,
249 }
250 }
251
252 fn from_op_col_bounds(cols: &ColList, bounds: (Bound<AlgebraicValue>, Bound<AlgebraicValue>)) -> Self {
255 let (cmp, value) = match bounds {
256 (Bound::Included(a), Bound::Included(b)) if a == b => (OpCmp::Eq, a),
258 (Bound::Included(value), Bound::Unbounded) => (OpCmp::GtEq, value),
260 (Bound::Excluded(value), Bound::Unbounded) => (OpCmp::Gt, value),
262 (Bound::Unbounded, Bound::Included(value)) => (OpCmp::LtEq, value),
264 (Bound::Unbounded, Bound::Excluded(value)) => (OpCmp::Lt, value),
266 (Bound::Unbounded, Bound::Unbounded) => unreachable!(),
267 (lower_bound, upper_bound) => {
268 let lhs = Self::from_op_col_bounds(cols, (lower_bound, Bound::Unbounded));
269 let rhs = Self::from_op_col_bounds(cols, (Bound::Unbounded, upper_bound));
270 return ColumnOp::and(lhs, rhs);
271 }
272 };
273 ColumnOp::and_cmp(cmp, cols, value)
274 }
275
276 fn as_col_cmp(&self) -> Option<(ColId, OpCmp)> {
278 match self {
279 Self::ColCmpVal { lhs, cmp, rhs: _ } => Some((*lhs, *cmp)),
280 Self::Cmp { lhs, cmp, rhs: _ } => match &**lhs {
281 ColumnOp::Col(col) => Some((*col, *cmp)),
282 _ => None,
283 },
284 _ => None,
285 }
286 }
287
288 fn eval<'a>(&'a self, row: &'a RelValue<'_>) -> Cow<'a, AlgebraicValue> {
290 let into = |b| Cow::Owned(AlgebraicValue::Bool(b));
291
292 match self {
293 Self::Col(col) => row.read_column(col.idx()).unwrap(),
294 Self::Val(val) => Cow::Borrowed(val),
295 Self::ColCmpVal { lhs, cmp, rhs } => into(Self::eval_cmp_col_val(row, *cmp, *lhs, rhs)),
296 Self::Cmp { lhs, cmp, rhs } => into(Self::eval_cmp(row, *cmp, lhs, rhs)),
297 Self::Log { op, operands } => into(Self::eval_log(row, *op, operands)),
298 }
299 }
300
301 pub fn eval_bool(&self, row: &RelValue<'_>) -> bool {
303 match self {
304 Self::Col(col) => *row.read_column(col.idx()).unwrap().as_bool().unwrap(),
305 Self::Val(val) => *val.as_bool().unwrap(),
306 Self::ColCmpVal { lhs, cmp, rhs } => Self::eval_cmp_col_val(row, *cmp, *lhs, rhs),
307 Self::Cmp { lhs, cmp, rhs } => Self::eval_cmp(row, *cmp, lhs, rhs),
308 Self::Log { op, operands } => Self::eval_log(row, *op, operands),
309 }
310 }
311
312 fn eval_op_cmp(cmp: OpCmp, lhs: &AlgebraicValue, rhs: &AlgebraicValue) -> bool {
314 match cmp {
315 OpCmp::Eq => lhs == rhs,
316 OpCmp::NotEq => lhs != rhs,
317 OpCmp::Lt => lhs < rhs,
318 OpCmp::LtEq => lhs <= rhs,
319 OpCmp::Gt => lhs > rhs,
320 OpCmp::GtEq => lhs >= rhs,
321 }
322 }
323
324 fn eval_cmp_col_val(row: &RelValue<'_>, cmp: OpCmp, lhs: ColId, rhs: &AlgebraicValue) -> bool {
326 let lhs = row.read_column(lhs.idx()).unwrap();
327 Self::eval_op_cmp(cmp, &lhs, rhs)
328 }
329
330 fn eval_cmp(row: &RelValue<'_>, cmp: OpCmp, lhs: &Self, rhs: &Self) -> bool {
334 let lhs = lhs.eval(row);
335 let rhs = rhs.eval(row);
336 Self::eval_op_cmp(cmp, &lhs, &rhs)
337 }
338
339 fn eval_log(row: &RelValue<'_>, op: OpLogic, opers: &[ColumnOp]) -> bool {
343 match op {
344 OpLogic::And => opers.iter().all(|o| o.eval_bool(row)),
345 OpLogic::Or => opers.iter().any(|o| o.eval_bool(row)),
346 }
347 }
348}
349
350impl fmt::Display for ColumnOp {
351 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352 match self {
353 Self::Col(col) => write!(f, "{col}"),
354 Self::Val(val) => write!(f, "{}", val.to_satn()),
355 Self::ColCmpVal { lhs, cmp, rhs } => write!(f, "{lhs} {cmp} {}", rhs.to_satn()),
356 Self::Cmp { cmp, lhs, rhs } => write!(f, "{lhs} {cmp} {rhs}"),
357 Self::Log { op, operands } => write!(f, "{}", operands.iter().format((*op).into())),
358 }
359 }
360}
361
362impl From<ColExpr> for ColumnOp {
363 fn from(ce: ColExpr) -> Self {
364 match ce {
365 ColExpr::Col(c) => c.into(),
366 ColExpr::Value(v) => v.into(),
367 }
368 }
369}
370
371impl From<Query> for Option<ColumnOp> {
372 fn from(value: Query) -> Self {
373 match value {
374 Query::IndexScan(op) => Some(ColumnOp::from_op_col_bounds(&op.columns, op.bounds)),
375 Query::Select(op) => Some(op),
376 _ => None,
377 }
378 }
379}
380
/// Identifies an in-memory source (e.g. a slot in a [`SourceSet`]) referenced by
/// [`SourceExpr::InMemory`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, From, Hash)]
pub struct SourceId(pub usize);
393
/// Supplies the row sources that [`SourceExpr::InMemory`] nodes refer to by [`SourceId`].
pub trait SourceProvider<'a> {
    /// The rows of one source.
    type Source: 'a + IntoIterator<Item = RelValue<'a>>;

    /// Returns the source for `id`, or `None` when it is unknown or already taken.
    fn take_source(&mut self, id: SourceId) -> Option<Self::Source>;
}
420
421impl<'a, I: 'a + IntoIterator<Item = RelValue<'a>>, F: FnMut(SourceId) -> Option<I>> SourceProvider<'a> for F {
422 type Source = I;
423 fn take_source(&mut self, id: SourceId) -> Option<Self::Source> {
424 self(id)
425 }
426}
427
428impl<'a, I: 'a + IntoIterator<Item = RelValue<'a>>> SourceProvider<'a> for Option<I> {
429 type Source = I;
430 fn take_source(&mut self, _: SourceId) -> Option<Self::Source> {
431 self.take()
432 }
433}
434
/// A [`SourceProvider`] with no sources at all; every lookup returns `None`.
pub struct NoInMemUsed;
436
437impl<'a> SourceProvider<'a> for NoInMemUsed {
438 type Source = iter::Empty<RelValue<'a>>;
439 fn take_source(&mut self, _: SourceId) -> Option<Self::Source> {
440 None
441 }
442}
443
/// A fixed-capacity (`N`) set of sources addressed by [`SourceId`].
///
/// Each slot holds `Some(source)` until it is taken, after which it holds `None`.
#[derive(Debug, PartialEq, Eq, Clone)]
#[repr(transparent)]
pub struct SourceSet<T, const N: usize>(
    // Slot `i` holds the source with `SourceId(i)`.
    ArrayVec<Option<T>, N>,
);
455
456impl<'a, T: 'a + IntoIterator<Item = RelValue<'a>>, const N: usize> SourceProvider<'a> for SourceSet<T, N> {
457 type Source = T;
458 fn take_source(&mut self, id: SourceId) -> Option<T> {
459 self.take(id)
460 }
461}
462
463impl<T, const N: usize> From<[T; N]> for SourceSet<T, N> {
464 #[inline]
465 fn from(sources: [T; N]) -> Self {
466 Self(sources.map(Some).into())
467 }
468}
469
470impl<T, const N: usize> SourceSet<T, N> {
471 pub fn empty() -> Self {
473 Self(ArrayVec::new())
474 }
475
476 fn next_id(&self) -> SourceId {
478 SourceId(self.0.len())
479 }
480
481 pub fn add(&mut self, table: T) -> SourceId {
484 let source_id = self.next_id();
485 self.0.push(Some(table));
486 source_id
487 }
488
489 pub fn take(&mut self, id: SourceId) -> Option<T> {
494 self.0.get_mut(id.0).map(mem::take).unwrap_or_default()
495 }
496
497 pub fn len(&self) -> usize {
501 self.0.len()
502 }
503
504 pub fn is_empty(&self) -> bool {
508 self.0.is_empty()
509 }
510}
511
512impl<T, const N: usize> std::ops::Index<SourceId> for SourceSet<T, N> {
513 type Output = Option<T>;
514
515 fn index(&self, idx: SourceId) -> &Self::Output {
516 &self.0[idx.0]
517 }
518}
519
520impl<T, const N: usize> std::ops::IndexMut<SourceId> for SourceSet<T, N> {
521 fn index_mut(&mut self, idx: SourceId) -> &mut Self::Output {
522 &mut self.0[idx.0]
523 }
524}
525
526impl<const N: usize> SourceSet<Vec<ProductValue>, N> {
527 pub fn add_mem_table(&mut self, table: MemTable) -> SourceExpr {
530 let id = self.add(table.data);
531 SourceExpr::from_mem_table(table.head, table.table_access, id)
532 }
533}
534
/// Where a query's rows come from.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum SourceExpr {
    /// Rows supplied at evaluation time by a [`SourceProvider`] under `source_id`.
    InMemory {
        source_id: SourceId,
        header: Arc<Header>,
        table_type: StTableType,
        table_access: StAccess,
    },
    /// A table stored in the database.
    DbTable(DbTable),
}
556
557impl SourceExpr {
558 pub fn source_id(&self) -> Option<SourceId> {
563 if let SourceExpr::InMemory { source_id, .. } = self {
564 Some(*source_id)
565 } else {
566 None
567 }
568 }
569
570 pub fn table_name(&self) -> &str {
571 &self.head().table_name
572 }
573
574 pub fn table_type(&self) -> StTableType {
575 match self {
576 SourceExpr::InMemory { table_type, .. } => *table_type,
577 SourceExpr::DbTable(db_table) => db_table.table_type,
578 }
579 }
580
581 pub fn table_access(&self) -> StAccess {
582 match self {
583 SourceExpr::InMemory { table_access, .. } => *table_access,
584 SourceExpr::DbTable(db_table) => db_table.table_access,
585 }
586 }
587
588 pub fn head(&self) -> &Arc<Header> {
589 match self {
590 SourceExpr::InMemory { header, .. } => header,
591 SourceExpr::DbTable(db_table) => &db_table.head,
592 }
593 }
594
595 pub fn is_mem_table(&self) -> bool {
596 matches!(self, SourceExpr::InMemory { .. })
597 }
598
599 pub fn is_db_table(&self) -> bool {
600 matches!(self, SourceExpr::DbTable(_))
601 }
602
603 pub fn from_mem_table(header: Arc<Header>, table_access: StAccess, id: SourceId) -> Self {
604 SourceExpr::InMemory {
605 source_id: id,
606 header,
607 table_type: StTableType::User,
608 table_access,
609 }
610 }
611
612 pub fn table_id(&self) -> Option<TableId> {
613 if let SourceExpr::DbTable(db_table) = self {
614 Some(db_table.table_id)
615 } else {
616 None
617 }
618 }
619
620 pub fn get_db_table(&self) -> Option<&DbTable> {
626 if let SourceExpr::DbTable(db_table) = self {
627 Some(db_table)
628 } else {
629 None
630 }
631 }
632}
633
634impl From<&TableSchema> for SourceExpr {
635 fn from(value: &TableSchema) -> Self {
636 SourceExpr::DbTable(value.into())
637 }
638}
639
/// A join that iterates the rows of `probe_side` and, for each, looks up matching rows of
/// `index_side` via an index on `index_col` (presumably equality on `probe_col` — confirm
/// against the executor).
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct IndexJoin {
    /// The side whose rows are iterated.
    pub probe_side: QueryExpr,
    /// Join column on the probe side.
    pub probe_col: ColId,
    /// The side that is looked up through an index.
    pub index_side: SourceExpr,
    /// Optional filter applied to index-side rows.
    pub index_select: Option<ColumnOp>,
    /// Join column on the index side.
    pub index_col: ColId,
    /// When `true` the join outputs index-side rows, otherwise probe-side rows
    /// (see `From<IndexJoin> for QueryExpr` and `QueryExpr::head`).
    pub return_index_rows: bool,
}
654
655impl From<IndexJoin> for QueryExpr {
656 fn from(join: IndexJoin) -> Self {
657 let source: SourceExpr = if join.return_index_rows {
658 join.index_side.clone()
659 } else {
660 join.probe_side.source.clone()
661 };
662 QueryExpr {
663 source,
664 query: vec![Query::IndexJoin(join)],
665 }
666 }
667}
668
669impl IndexJoin {
670 pub fn reorder(self, row_count: impl Fn(TableId, &str) -> i64) -> Self {
674 if self.probe_side.source.is_mem_table() {
676 return self;
677 }
678 if !self
680 .probe_side
681 .source
682 .head()
683 .has_constraint(self.probe_col, Constraints::indexed())
684 {
685 return self;
686 }
687 if !self
689 .probe_side
690 .query
691 .iter()
692 .all(|op| matches!(op, Query::Select(_) | Query::IndexScan(_)))
693 {
694 return self;
695 }
696 match self.index_side.get_db_table() {
697 Some(DbTable { head, table_id, .. }) if row_count(*table_id, &head.table_name) > 500 => self,
703 _ => {
706 let predicate = self
709 .probe_side
710 .query
711 .into_iter()
712 .filter_map(<Query as Into<Option<ColumnOp>>>::into)
713 .reduce(ColumnOp::and);
714 let probe_side = if let Some(predicate) = self.index_select {
716 QueryExpr {
717 source: self.index_side,
718 query: vec![predicate.into()],
719 }
720 } else {
721 self.index_side.into()
722 };
723 IndexJoin {
724 probe_side,
727 probe_col: self.index_col,
729 index_side: self.probe_side.source,
731 index_select: predicate,
733 index_col: self.probe_col,
735 return_index_rows: !self.return_index_rows,
738 }
739 }
740 }
741 }
742
743 pub fn to_inner_join(self) -> QueryExpr {
748 if self.return_index_rows {
749 let (col_lhs, col_rhs) = (self.index_col, self.probe_col);
750 let rhs = self.probe_side;
751
752 let source = self.index_side;
753 let inner_join = Query::JoinInner(JoinExpr::new(rhs, col_lhs, col_rhs, None));
754 let query = if let Some(predicate) = self.index_select {
755 vec![predicate.into(), inner_join]
756 } else {
757 vec![inner_join]
758 };
759 QueryExpr { source, query }
760 } else {
761 let (col_lhs, col_rhs) = (self.probe_col, self.index_col);
762 let mut rhs: QueryExpr = self.index_side.into();
763
764 if let Some(predicate) = self.index_select {
765 rhs.query.push(predicate.into());
766 }
767
768 let source = self.probe_side.source;
769 let inner_join = Query::JoinInner(JoinExpr::new(rhs, col_lhs, col_rhs, None));
770 let query = vec![inner_join];
771 QueryExpr { source, query }
772 }
773 }
774}
775
/// An inner join of the current plan (lhs) with `rhs` on `lhs.col_lhs = rhs.col_rhs`
/// (equality presumed from the column names — confirm against the executor).
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct JoinExpr {
    /// The right-hand plan.
    pub rhs: QueryExpr,
    /// Join column of the left-hand side.
    pub col_lhs: ColId,
    /// Join column of the right-hand side.
    pub col_rhs: ColId,
    /// Header of the joined rows. When `None`, `QueryExpr::head` skips this join, so the
    /// plan keeps the lhs header (a semijoin-like shape).
    pub inner: Option<Arc<Header>>,
}
787
788impl JoinExpr {
789 pub fn new(rhs: QueryExpr, col_lhs: ColId, col_rhs: ColId, inner: Option<Arc<Header>>) -> Self {
790 Self {
791 rhs,
792 col_lhs,
793 col_rhs,
794 inner,
795 }
796 }
797}
798
/// The kind of database object that a `Crud::Create` / `Crud::Drop` acts on.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum DbType {
    Table,
    Index,
    Sequence,
    Constraint,
}
806
/// Classification of a statement by the kind of operation it performs.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Crud {
    Query,
    Insert,
    Update,
    Delete,
    Create(DbType),
    Drop(DbType),
    Config,
}
817
/// A compiled statement: a query, a data modification, or a session-variable access.
#[derive(Debug, Eq, PartialEq)]
pub enum CrudExpr {
    Query(QueryExpr),
    Insert {
        table: DbTable,
        rows: Vec<ProductValue>,
    },
    /// Presumably updates the rows selected by `delete`, setting each assigned column to its
    /// expression — confirm against the executor.
    Update {
        delete: QueryExpr,
        assignments: IntMap<ColId, ColExpr>,
    },
    Delete {
        query: QueryExpr,
    },
    SetVar {
        name: String,
        literal: String,
    },
    ReadVar {
        name: String,
    },
}
840
841impl CrudExpr {
842 pub fn optimize(self, row_count: &impl Fn(TableId, &str) -> i64) -> Self {
843 match self {
844 CrudExpr::Query(x) => CrudExpr::Query(x.optimize(row_count)),
845 _ => self,
846 }
847 }
848
849 pub fn is_reads<'a>(exprs: impl IntoIterator<Item = &'a CrudExpr>) -> bool {
850 exprs
851 .into_iter()
852 .all(|expr| matches!(expr, CrudExpr::Query(_) | CrudExpr::ReadVar { .. }))
853 }
854}
855
/// A scan of `table` through an index on `columns`, restricted to `bounds`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct IndexScan {
    pub table: DbTable,
    pub columns: ColList,
    /// `(lower, upper)`; for multi-column indexes the values are products.
    pub bounds: (Bound<AlgebraicValue>, Bound<AlgebraicValue>),
}
862
863impl IndexScan {
864 pub fn is_point(&self) -> bool {
866 match &self.bounds {
867 (Bound::Included(lower), Bound::Included(upper)) => lower == upper,
868 _ => false,
869 }
870 }
871}
872
/// A projection onto `cols`.
#[derive(Debug, Clone, Eq, PartialEq, From, Hash)]
pub struct ProjectExpr {
    /// The projected columns or literal values.
    pub cols: Vec<ColExpr>,
    /// Presumably set for a `table.*` style projection — TODO confirm.
    pub wildcard_table: Option<TableId>,
    /// Header of the rows produced by the projection (used by `QueryExpr::head`).
    pub header_after: Arc<Header>,
}
882
/// One operator in a [`QueryExpr`] pipeline.
#[derive(Debug, Clone, Eq, PartialEq, From, Hash)]
pub enum Query {
    /// Read rows through an index within bounds.
    IndexScan(IndexScan),
    /// Join using an index on one side.
    IndexJoin(IndexJoin),
    /// Keep only rows satisfying the predicate.
    Select(ColumnOp),
    /// Project rows onto a set of columns.
    Project(ProjectExpr),
    /// Inner join with another plan.
    JoinInner(JoinExpr),
}
902
903impl Query {
904 pub fn walk_sources<E>(&self, on_source: &mut impl FnMut(&SourceExpr) -> Result<(), E>) -> Result<(), E> {
908 match self {
909 Self::Select(..) | Self::Project(..) => Ok(()),
910 Self::IndexScan(scan) => on_source(&SourceExpr::DbTable(scan.table.clone())),
911 Self::IndexJoin(join) => join.probe_side.walk_sources(on_source),
912 Self::JoinInner(join) => join.rhs.walk_sources(on_source),
913 }
914 }
915}
916
/// A condition on `columns` that an index can answer (see `make_index_arg`).
#[derive(Debug, PartialEq, Clone)]
enum IndexArgument<'a> {
    /// `columns = value`; `value` is a product for multi-column indexes.
    Eq {
        columns: &'a ColList,
        value: AlgebraicValue,
    },
    /// `columns > value`, or `>=` when `inclusive`.
    LowerBound {
        columns: &'a ColList,
        value: AlgebraicValue,
        inclusive: bool,
    },
    /// `columns < value`, or `<=` when `inclusive`.
    UpperBound {
        columns: &'a ColList,
        value: AlgebraicValue,
        inclusive: bool,
    },
}
936
/// Output of index selection: a condition answered by an index, or a predicate that must be
/// evaluated by scanning.
#[derive(Debug, PartialEq, Clone)]
enum IndexColumnOp<'a> {
    Index(IndexArgument<'a>),
    Scan(&'a ColumnOp),
}
942
943fn make_index_arg(cmp: OpCmp, columns: &ColList, value: AlgebraicValue) -> IndexColumnOp<'_> {
944 let arg = match cmp {
945 OpCmp::Eq => IndexArgument::Eq { columns, value },
946 OpCmp::NotEq => unreachable!("No IndexArgument for NotEq, caller should've filtered out"),
947 OpCmp::Lt => IndexArgument::UpperBound {
949 columns,
950 value,
951 inclusive: false,
952 },
953 OpCmp::Gt => IndexArgument::LowerBound {
955 columns,
956 value,
957 inclusive: false,
958 },
959 OpCmp::LtEq => IndexArgument::UpperBound {
961 columns,
962 value,
963 inclusive: true,
964 },
965 OpCmp::GtEq => IndexArgument::LowerBound {
967 columns,
968 value,
969 inclusive: true,
970 },
971 };
972 IndexColumnOp::Index(arg)
973}
974
/// One `col <cmp> value` conjunct extracted from a predicate, together with the predicate node
/// it came from (used when it has to fall back to a scan).
#[derive(Debug)]
struct ColValue<'a> {
    /// The predicate node this comparison was extracted from.
    parent: &'a ColumnOp,
    col: ColId,
    cmp: OpCmp,
    value: &'a AlgebraicValue,
}
982
983impl<'a> ColValue<'a> {
984 pub fn new(parent: &'a ColumnOp, col: ColId, cmp: OpCmp, value: &'a AlgebraicValue) -> Self {
985 Self {
986 parent,
987 col,
988 cmp,
989 value,
990 }
991 }
992}
993
/// Collected output of index selection.
type IndexColumnOpSink<'a> = SmallVec<[IndexColumnOp<'a>; 1]>;
/// The `(column, comparison)` pairs that were answered by an index.
type ColsIndexed = HashSet<(ColId, OpCmp)>;
996
/// Splits predicate `op` into conditions answerable by the indexes declared in `header`, plus
/// leftover predicates that must be evaluated by scanning.
///
/// Indexes are tried widest first. A multi-column index only consumes equality comparisons and
/// only when every one of its columns has one. A single-column index consumes `=`, `<`, `<=`,
/// `>` and `>=` on its column. Each `(column, comparison)` answered by an index is recorded in
/// `cols_indexed`.
fn select_best_index<'a>(
    cols_indexed: &mut ColsIndexed,
    header: &'a Header,
    op: &'a ColumnOp,
) -> IndexColumnOpSink<'a> {
    // Column lists of all indexed constraints, sorted by descending width.
    let mut indices = header
        .constraints
        .iter()
        .filter(|(_, c)| c.has_indexed())
        .map(|(cl, _)| cl)
        .collect::<SmallVec<[_; 1]>>();
    indices.sort_unstable_by_key(|cl| Reverse(cl.len()));

    let mut found: IndexColumnOpSink = IndexColumnOpSink::default();

    // Group the `col <cmp> value` conjuncts by `(col, cmp)`. Conjuncts that cannot be grouped
    // are pushed to `found` as scans by `extract_cols`.
    let mut col_map = BTreeMap::<_, SmallVec<[_; 1]>>::new();
    extract_cols(op, &mut col_map, &mut found);

    for col_list in indices {
        // Every comparison has already been matched to an index.
        if col_map.is_empty() {
            break;
        }

        if let Some(head) = col_list.as_singleton() {
            // Single-column index: answers equality and ranges (not `NotEq`).
            for cmp in [OpCmp::Eq, OpCmp::Lt, OpCmp::LtEq, OpCmp::Gt, OpCmp::GtEq] {
                for ColValue { cmp, value, col, .. } in col_map.remove(&(head, cmp)).into_iter().flatten() {
                    found.push(make_index_arg(cmp, col_list, value.clone()));
                    cols_indexed.insert((col, cmp));
                }
            }
        } else {
            // Multi-column index: usable only with an equality on every column.
            let cmp = OpCmp::Eq;

            // How many complete `(c1 = v1, ..., cn = vn)` tuples can be assembled.
            let mut min_all_cols_num_eq = col_list
                .iter()
                .map(|col| col_map.get(&(col, cmp)).map_or(0, |fs| fs.len()))
                .min()
                .unwrap_or_default();

            while min_all_cols_num_eq > 0 {
                let mut elems = Vec::with_capacity(col_list.len() as usize);
                for col in col_list.iter() {
                    // Cannot fail: every `(col, Eq)` entry holds at least `min_all_cols_num_eq` values.
                    let col_val = pop_multimap(&mut col_map, (col, cmp)).unwrap();
                    cols_indexed.insert((col_val.col, cmp));
                    elems.push(col_val.value.clone());
                }
                // Multi-column index keys are product values.
                let value = AlgebraicValue::product(elems);
                found.push(make_index_arg(cmp, col_list, value));
                min_all_cols_num_eq -= 1;
            }
        }
    }

    // Comparisons no index could answer are evaluated by scanning their original predicate node.
    found.extend(
        col_map
            .into_iter()
            .flat_map(|(_, fs)| fs)
            .map(|f| IndexColumnOp::Scan(f.parent)),
    );

    found
}
1135
1136fn pop_multimap<K: Ord, V, const N: usize>(map: &mut BTreeMap<K, SmallVec<[V; N]>>, key: K) -> Option<V> {
1139 let Entry::Occupied(mut entry) = map.entry(key) else {
1140 return None;
1141 };
1142 let fields = entry.get_mut();
1143 let val = fields.pop();
1144 if fields.is_empty() {
1145 entry.remove();
1146 }
1147 val
1148}
1149
1150fn extract_cols<'a>(
1155 op: &'a ColumnOp,
1156 col_map: &mut BTreeMap<(ColId, OpCmp), SmallVec<[ColValue<'a>; 1]>>,
1157 found: &mut IndexColumnOpSink<'a>,
1158) {
1159 let mut add_field = |parent, op, col, val| {
1160 let fv = ColValue::new(parent, col, op, val);
1161 col_map.entry((col, op)).or_default().push(fv);
1162 };
1163
1164 match op {
1165 ColumnOp::Cmp { cmp, lhs, rhs } => {
1166 if let (ColumnOp::Col(col), ColumnOp::Val(val)) = (&**lhs, &**rhs) {
1167 add_field(op, *cmp, *col, val);
1169 }
1170 }
1171 ColumnOp::ColCmpVal { lhs, cmp, rhs } => add_field(op, *cmp, *lhs, rhs),
1172 ColumnOp::Log {
1173 op: OpLogic::And,
1174 operands,
1175 } => {
1176 for oper in operands.iter() {
1177 extract_cols(oper, col_map, found);
1178 }
1179 }
1180 ColumnOp::Log { op: OpLogic::Or, .. } | ColumnOp::Col(_) | ColumnOp::Val(_) => {
1181 found.push(IndexColumnOp::Scan(op));
1182 }
1183 }
1184}
1185
/// A query plan: a row `source` followed by a pipeline of operators.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QueryExpr {
    /// Where the rows come from.
    pub source: SourceExpr,
    /// Operators applied to the source rows, in order.
    pub query: Vec<Query>,
}
1197
1198impl From<SourceExpr> for QueryExpr {
1199 fn from(source: SourceExpr) -> Self {
1200 QueryExpr { source, query: vec![] }
1201 }
1202}
1203
1204impl QueryExpr {
1205 pub fn new<T: Into<SourceExpr>>(source: T) -> Self {
1206 Self {
1207 source: source.into(),
1208 query: vec![],
1209 }
1210 }
1211
1212 pub fn walk_sources<E>(&self, on_source: &mut impl FnMut(&SourceExpr) -> Result<(), E>) -> Result<(), E> {
1216 on_source(&self.source)?;
1217 self.query.iter().try_for_each(|q| q.walk_sources(on_source))
1218 }
1219
1220 pub fn head(&self) -> &Arc<Header> {
1228 self.query
1229 .iter()
1230 .rev()
1231 .find_map(|op| match op {
1232 Query::Select(_) => None,
1233 Query::IndexScan(scan) => Some(&scan.table.head),
1234 Query::IndexJoin(join) if join.return_index_rows => Some(join.index_side.head()),
1235 Query::IndexJoin(join) => Some(join.probe_side.head()),
1236 Query::Project(proj) => Some(&proj.header_after),
1237 Query::JoinInner(join) => join.inner.as_ref(),
1238 })
1239 .unwrap_or_else(|| self.source.head())
1240 }
1241
1242 pub fn reads_from_table(&self, id: &TableId) -> bool {
1244 self.source.table_id() == Some(*id)
1245 || self.query.iter().any(|q| match q {
1246 Query::Select(_) | Query::Project(..) => false,
1247 Query::IndexScan(scan) => scan.table.table_id == *id,
1248 Query::JoinInner(join) => join.rhs.reads_from_table(id),
1249 Query::IndexJoin(join) => {
1250 join.index_side.table_id() == Some(*id) || join.probe_side.reads_from_table(id)
1251 }
1252 })
1253 }
1254
1255 pub fn with_index_eq(mut self, table: DbTable, columns: ColList, value: AlgebraicValue) -> Self {
1259 let point = |v: AlgebraicValue| (Bound::Included(v.clone()), Bound::Included(v));
1260
1261 let Some(query) = self.query.pop() else {
1263 let bounds = point(value);
1264 self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
1265 return self;
1266 };
1267 match query {
1268 Query::JoinInner(JoinExpr {
1270 rhs:
1271 QueryExpr {
1272 source: SourceExpr::DbTable(ref db_table),
1273 ..
1274 },
1275 ..
1276 }) if table.table_id != db_table.table_id => {
1277 self = self.with_index_eq(db_table.clone(), columns, value);
1278 self.query.push(query);
1279 self
1280 }
1281 Query::JoinInner(JoinExpr {
1283 rhs,
1284 col_lhs,
1285 col_rhs,
1286 inner: semi,
1287 }) => {
1288 self.query.push(Query::JoinInner(JoinExpr {
1289 rhs: rhs.with_index_eq(table, columns, value),
1290 col_lhs,
1291 col_rhs,
1292 inner: semi,
1293 }));
1294 self
1295 }
1296 Query::Select(filter) => {
1298 let op = ColumnOp::and_cmp(OpCmp::Eq, &columns, value);
1299 self.query.push(Query::Select(ColumnOp::and(filter, op)));
1300 self
1301 }
1302 query => {
1304 self.query.push(query);
1305 let op = ColumnOp::and_cmp(OpCmp::Eq, &columns, value);
1306 self.query.push(Query::Select(op));
1307 self
1308 }
1309 }
1310 }
1311
    /// Adds a lower bound on `columns` (`> value`, or `>= value` when `inclusive`), answered by
    /// an index on `table` where possible.
    ///
    /// Behavior depends on the last operator:
    /// * none: emit a lower-bounded [`IndexScan`];
    /// * an inner join whose rhs reads a different db table: push the bound below the join (lhs);
    /// * any other inner join: push the bound into the join's rhs;
    /// * an upper-bounded-only `IndexScan` on the same columns: merge into a two-sided range scan;
    /// * a `Select`: AND the bound into it;
    /// * anything else: append a new `Select`.
    ///
    /// The bound itself comes from `Self::bound(value, inclusive)` (defined elsewhere).
    pub fn with_index_lower_bound(
        mut self,
        table: DbTable,
        columns: ColList,
        value: AlgebraicValue,
        inclusive: bool,
    ) -> Self {
        // First operator: answer the bound directly with an index scan.
        let Some(query) = self.query.pop() else {
            let bounds = (Self::bound(value, inclusive), Bound::Unbounded);
            self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
            return self;
        };
        match query {
            // The join's rhs reads another table, so the index lives on the lhs: push below the join.
            Query::JoinInner(JoinExpr {
                rhs:
                    QueryExpr {
                        source: SourceExpr::DbTable(ref db_table),
                        ..
                    },
                ..
            }) if table.table_id != db_table.table_id => {
                self = self.with_index_lower_bound(table, columns, value, inclusive);
                self.query.push(query);
                self
            }
            // Otherwise push the bound into the join's rhs.
            Query::JoinInner(JoinExpr {
                rhs,
                col_lhs,
                col_rhs,
                inner: semi,
            }) => {
                self.query.push(Query::JoinInner(JoinExpr {
                    rhs: rhs.with_index_lower_bound(table, columns, value, inclusive),
                    col_lhs,
                    col_rhs,
                    inner: semi,
                }));
                self
            }
            // Merge with an existing inclusive upper bound on the same columns.
            Query::IndexScan(IndexScan {
                columns: lhs_col_id,
                bounds: (Bound::Unbounded, Bound::Included(upper)),
                ..
            }) if columns == lhs_col_id => {
                let bounds = (Self::bound(value, inclusive), Bound::Included(upper));
                self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
                self
            }
            // Merge with an existing exclusive upper bound on the same columns.
            Query::IndexScan(IndexScan {
                columns: lhs_col_id,
                bounds: (Bound::Unbounded, Bound::Excluded(upper)),
                ..
            }) if columns == lhs_col_id => {
                // `x > v AND x < v` can never match. Compute this before `value` is moved so
                // the resulting plan can be logged below.
                // NOTE(review): `x >= v AND x < v` is also empty but is not detected here.
                let is_never = !inclusive && value == upper;

                let bounds = (Self::bound(value, inclusive), Bound::Excluded(upper));
                self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));

                if is_never {
                    log::warn!("Query will select no rows due to equal excluded bounds: {self:?}")
                }

                self
            }
            // Merge into the existing filter.
            Query::Select(filter) => {
                let bounds = (Self::bound(value, inclusive), Bound::Unbounded);
                let op = ColumnOp::from_op_col_bounds(&columns, bounds);
                self.query.push(Query::Select(ColumnOp::and(filter, op)));
                self
            }
            // Restore the popped operator and append a filter after it.
            query => {
                self.query.push(query);
                let bounds = (Self::bound(value, inclusive), Bound::Unbounded);
                let op = ColumnOp::from_op_col_bounds(&columns, bounds);
                self.query.push(Query::Select(op));
                self
            }
        }
    }
1410
    /// Adds an upper bound on `columns` (`< value`, or `<= value` when `inclusive`), answered by
    /// an index on `table` where possible.
    ///
    /// Behavior depends on the last operator:
    /// * none: emit an upper-bounded [`IndexScan`];
    /// * an inner join whose rhs reads a different db table: push the bound below the join (lhs);
    /// * any other inner join: push the bound into the join's rhs;
    /// * a lower-bounded-only `IndexScan` on the same columns: merge into a two-sided range scan;
    /// * a `Select`: AND the bound into it;
    /// * anything else: append a new `Select`.
    ///
    /// The bound itself comes from `Self::bound(value, inclusive)` (defined elsewhere).
    pub fn with_index_upper_bound(
        mut self,
        table: DbTable,
        columns: ColList,
        value: AlgebraicValue,
        inclusive: bool,
    ) -> Self {
        // First operator: answer the bound directly with an index scan.
        let Some(query) = self.query.pop() else {
            self.query.push(Query::IndexScan(IndexScan {
                table,
                columns,
                bounds: (Bound::Unbounded, Self::bound(value, inclusive)),
            }));
            return self;
        };
        match query {
            // The join's rhs reads another table, so the index lives on the lhs: push below the join.
            Query::JoinInner(JoinExpr {
                rhs:
                    QueryExpr {
                        source: SourceExpr::DbTable(ref db_table),
                        ..
                    },
                ..
            }) if table.table_id != db_table.table_id => {
                self = self.with_index_upper_bound(table, columns, value, inclusive);
                self.query.push(query);
                self
            }
            // Otherwise push the bound into the join's rhs.
            Query::JoinInner(JoinExpr {
                rhs,
                col_lhs,
                col_rhs,
                inner: semi,
            }) => {
                self.query.push(Query::JoinInner(JoinExpr {
                    rhs: rhs.with_index_upper_bound(table, columns, value, inclusive),
                    col_lhs,
                    col_rhs,
                    inner: semi,
                }));
                self
            }
            // Merge with an existing inclusive lower bound on the same columns.
            Query::IndexScan(IndexScan {
                columns: lhs_col_id,
                bounds: (Bound::Included(lower), Bound::Unbounded),
                ..
            }) if columns == lhs_col_id => {
                let bounds = (Bound::Included(lower), Self::bound(value, inclusive));
                self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
                self
            }
            // Merge with an existing exclusive lower bound on the same columns.
            Query::IndexScan(IndexScan {
                columns: lhs_col_id,
                bounds: (Bound::Excluded(lower), Bound::Unbounded),
                ..
            }) if columns == lhs_col_id => {
                // `x > v AND x < v` can never match. Compute this before `value` is moved so
                // the resulting plan can be logged below.
                // NOTE(review): `x > v AND x <= v` is also empty but is not detected here.
                let is_never = !inclusive && value == lower;

                let bounds = (Bound::Excluded(lower), Self::bound(value, inclusive));
                self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));

                if is_never {
                    log::warn!("Query will select no rows due to equal excluded bounds: {self:?}")
                }

                self
            }
            // Merge into the existing filter.
            Query::Select(filter) => {
                let bounds = (Bound::Unbounded, Self::bound(value, inclusive));
                let op = ColumnOp::from_op_col_bounds(&columns, bounds);
                self.query.push(Query::Select(ColumnOp::and(filter, op)));
                self
            }
            // Restore the popped operator and append a filter after it.
            query => {
                self.query.push(query);
                let bounds = (Bound::Unbounded, Self::bound(value, inclusive));
                let op = ColumnOp::from_op_col_bounds(&columns, bounds);
                self.query.push(Query::Select(op));
                self
            }
        }
    }
1512
    /// Adds the selection `op` to this plan.
    ///
    /// If the last operator is an inner join and `op` is `field <cmp> value`, the filter is pushed
    /// below the join onto whichever side has `field`. Otherwise it is evaluated after the join.
    /// A trailing `Select` is merged with `op` via `AND`.
    ///
    /// # Errors
    /// Returns the [`RelationError`] reported by `check_field_op` / `check_field_op_logics`
    /// (defined elsewhere) when `op` does not fit this plan's header.
    pub fn with_select<O>(mut self, op: O) -> Result<Self, RelationError>
    where
        O: Into<FieldOp>,
    {
        let op = op.into();
        // No operators yet: the selection applies directly to the source.
        let Some(query) = self.query.pop() else {
            return self.add_base_select(op);
        };

        match (query, op) {
            // A comparison directly above an inner join.
            (
                Query::JoinInner(JoinExpr {
                    rhs,
                    col_lhs,
                    col_rhs,
                    inner,
                }),
                FieldOp::Cmp {
                    op: OpQuery::Cmp(cmp),
                    lhs: field,
                    rhs: value,
                },
            ) => match (*field, *value) {
                // `field` belongs to the lhs (the join has been popped): filter before joining.
                (FieldOp::Field(FieldExpr::Name(field)), FieldOp::Field(FieldExpr::Value(value)))
                    if self.head().column_pos(field).is_some() =>
                {
                    self = self.with_select(FieldOp::cmp(field, cmp, value))?;
                    self.query.push(Query::JoinInner(JoinExpr { rhs, col_lhs, col_rhs, inner }));
                    Ok(self)
                }
                // `field` belongs to the rhs: filter the rhs plan before joining.
                (FieldOp::Field(FieldExpr::Name(field)), FieldOp::Field(FieldExpr::Value(value)))
                    if rhs.head().column_pos(field).is_some() =>
                {
                    let rhs = rhs.with_select(FieldOp::cmp(field, cmp, value))?;
                    self.query.push(Query::JoinInner(JoinExpr {
                        rhs,
                        col_lhs,
                        col_rhs,
                        inner,
                    }));
                    Ok(self)
                }
                // Any other comparison is evaluated after the join.
                (field, value) => {
                    self.query.push(Query::JoinInner(JoinExpr { rhs, col_lhs, col_rhs, inner, }));

                    // These checks presumably guarantee that the name resolution below cannot
                    // fail, which the `unwrap`s rely on — confirm.
                    self.check_field_op_logics(&field)?;
                    self.check_field_op_logics(&value)?;
                    let col = field.names_to_cols(self.head()).unwrap();
                    let value = value.names_to_cols(self.head()).unwrap();
                    self.query.push(Query::Select(ColumnOp::new(OpQuery::Cmp(cmp), col, value)));
                    Ok(self)
                }
            },
            // Merge with the existing filter.
            (Query::Select(lhs), rhs) => {
                self.check_field_op(&rhs)?;
                let rhs = rhs.names_to_cols(self.head()).unwrap();
                self.query.push(Query::Select(ColumnOp::and(lhs, rhs)));
                Ok(self)
            }
            // Restore the popped operator and append a new filter after it.
            (query, op) => {
                self.query.push(query);
                self.add_base_select(op)
            }
        }
    }
1593
1594 fn add_base_select(mut self, op: FieldOp) -> Result<Self, RelationError> {
1597 self.check_field_op(&op)?;
1599 let op = op.names_to_cols(self.head()).unwrap();
1601 self.query.push(Query::Select(op));
1603 Ok(self)
1604 }
1605
1606 fn check_field_op(&self, op: &FieldOp) -> Result<(), RelationError> {
1609 use OpQuery::*;
1610 match op {
1611 FieldOp::Cmp { op: Logic(_), lhs, rhs } => {
1613 self.check_field_op(lhs)?;
1614 self.check_field_op(rhs)?;
1615 Ok(())
1616 }
1617 FieldOp::Cmp { op: Cmp(_), lhs, rhs } => {
1623 self.check_field_op_logics(lhs)?;
1624 self.check_field_op_logics(rhs)?;
1625 Ok(())
1626 }
1627 FieldOp::Field(FieldExpr::Value(AlgebraicValue::Bool(_))) => Ok(()),
1628 FieldOp::Field(FieldExpr::Value(v)) => Err(RelationError::NotBoolValue { val: v.clone() }),
1629 FieldOp::Field(FieldExpr::Name(field)) => {
1630 let field = *field;
1631 let head = self.head();
1632 let col_id = head.column_pos_or_err(field)?;
1633 let col_ty = &head.fields[col_id.idx()].algebraic_type;
1634 match col_ty {
1635 &AlgebraicType::Bool => Ok(()),
1636 ty => Err(RelationError::NotBoolType { field, ty: ty.clone() }),
1637 }
1638 }
1639 }
1640 }
1641
1642 fn check_field_op_logics(&self, op: &FieldOp) -> Result<(), RelationError> {
1644 use OpQuery::*;
1645 match op {
1646 FieldOp::Field(_) => Ok(()),
1647 FieldOp::Cmp { op: Cmp(_), lhs, rhs } => {
1648 self.check_field_op_logics(lhs)?;
1649 self.check_field_op_logics(rhs)?;
1650 Ok(())
1651 }
1652 FieldOp::Cmp { op: Logic(_), lhs, rhs } => {
1653 self.check_field_op(lhs)?;
1654 self.check_field_op(rhs)?;
1655 Ok(())
1656 }
1657 }
1658 }
1659
1660 pub fn with_select_cmp<LHS, RHS, O>(self, op: O, lhs: LHS, rhs: RHS) -> Result<Self, RelationError>
1661 where
1662 LHS: Into<FieldExpr>,
1663 RHS: Into<FieldExpr>,
1664 O: Into<OpQuery>,
1665 {
1666 let op = FieldOp::new(op.into(), FieldOp::Field(lhs.into()), FieldOp::Field(rhs.into()));
1667 self.with_select(op)
1668 }
1669
1670 pub fn with_project(
1674 mut self,
1675 fields: Vec<FieldExpr>,
1676 wildcard_table: Option<TableId>,
1677 ) -> Result<Self, RelationError> {
1678 if !fields.is_empty() {
1679 let header_before = self.head();
1680
1681 let mut cols = Vec::with_capacity(fields.len());
1683 for field in fields {
1684 cols.push(field.name_to_col(header_before)?);
1685 }
1686
1687 let header_after = Arc::new(header_before.project(&cols)?);
1690
1691 self.query.push(Query::Project(ProjectExpr {
1693 cols,
1694 wildcard_table,
1695 header_after,
1696 }));
1697 }
1698 Ok(self)
1699 }
1700
1701 pub fn with_join_inner_raw(
1702 mut self,
1703 q_rhs: QueryExpr,
1704 c_lhs: ColId,
1705 c_rhs: ColId,
1706 inner: Option<Arc<Header>>,
1707 ) -> Self {
1708 self.query
1709 .push(Query::JoinInner(JoinExpr::new(q_rhs, c_lhs, c_rhs, inner)));
1710 self
1711 }
1712
1713 pub fn with_join_inner(self, q_rhs: impl Into<QueryExpr>, c_lhs: ColId, c_rhs: ColId, semi: bool) -> Self {
1714 let q_rhs = q_rhs.into();
1715 let inner = (!semi).then(|| Arc::new(self.head().extend(q_rhs.head())));
1716 self.with_join_inner_raw(q_rhs, c_lhs, c_rhs, inner)
1717 }
1718
1719 fn bound(value: AlgebraicValue, inclusive: bool) -> Bound<AlgebraicValue> {
1720 if inclusive {
1721 Bound::Included(value)
1722 } else {
1723 Bound::Excluded(value)
1724 }
1725 }
1726
1727 pub fn try_semi_join(self) -> QueryExpr {
1763 let QueryExpr { source, query } = self;
1764
1765 let Some(source_table_id) = source.table_id() else {
1766 return QueryExpr { source, query };
1768 };
1769
1770 let mut exprs = query.into_iter();
1771 let Some(join_candidate) = exprs.next() else {
1772 return QueryExpr { source, query: vec![] };
1774 };
1775 let Query::JoinInner(join) = join_candidate else {
1776 return QueryExpr {
1778 source,
1779 query: itertools::chain![Some(join_candidate), exprs].collect(),
1780 };
1781 };
1782
1783 let Some(project_candidate) = exprs.next() else {
1784 return QueryExpr {
1786 source,
1787 query: vec![Query::JoinInner(join)],
1788 };
1789 };
1790
1791 let Query::Project(proj) = project_candidate else {
1792 return QueryExpr {
1794 source,
1795 query: itertools::chain![Some(Query::JoinInner(join)), Some(project_candidate), exprs].collect(),
1796 };
1797 };
1798
1799 if proj.wildcard_table != Some(source_table_id) {
1800 return QueryExpr {
1802 source,
1803 query: itertools::chain![Some(Query::JoinInner(join)), Some(Query::Project(proj)), exprs].collect(),
1804 };
1805 };
1806
1807 let semijoin = JoinExpr { inner: None, ..join };
1809
1810 QueryExpr {
1811 source,
1812 query: itertools::chain![Some(Query::JoinInner(semijoin)), exprs].collect(),
1813 }
1814 }
1815
1816 fn try_index_join(self) -> QueryExpr {
1823 let mut query = self;
1824 if query.query.len() != 1 {
1827 return query;
1828 }
1829
1830 if query.source.is_mem_table() {
1833 return query;
1834 }
1835 let source = query.source;
1836 let join = query.query.pop().unwrap();
1837
1838 match join {
1839 Query::JoinInner(join @ JoinExpr { inner: None, .. }) => {
1840 if !join.rhs.query.is_empty() {
1841 if source.head().has_constraint(join.col_lhs, Constraints::indexed()) {
1843 let index_join = IndexJoin {
1844 probe_side: join.rhs,
1845 probe_col: join.col_rhs,
1846 index_side: source.clone(),
1847 index_select: None,
1848 index_col: join.col_lhs,
1849 return_index_rows: true,
1850 };
1851 let query = [Query::IndexJoin(index_join)].into();
1852 return QueryExpr { source, query };
1853 }
1854 }
1855 QueryExpr {
1856 source,
1857 query: vec![Query::JoinInner(join)],
1858 }
1859 }
1860 first => QueryExpr {
1861 source,
1862 query: vec![first],
1863 },
1864 }
1865 }
1866
    /// Applies the filter `op` to `q`. Parts of the filter that can use an index
    /// become index scans; the rest stays as `Select` stages.
    ///
    /// For each source in `tables`, `select_best_index` splits `op` into
    /// `IndexColumnOp`s against that source's header:
    /// - `IndexColumnOp::Index(..)` is applied to the source's `DbTable` via
    ///   `with_index_eq`, `with_index_lower_bound` or `with_index_upper_bound`.
    /// - `IndexColumnOp::Scan(..)` is combined (via `ColumnOp::and`) with a
    ///   trailing `Select` of `q` if there is one. Otherwise it is pushed as a new
    ///   `Select`.
    ///
    /// `fields_found` is shared across all tables. A `Scan` whose `as_col_cmp()`
    /// key is already in the set is skipped as a duplicate. Entries may also be
    /// added by `select_best_index`, which receives the set mutably.
    ///
    /// # Panics
    ///
    /// Panics (via `expect`) if `select_best_index` yields an `Index` argument for a
    /// source that is not a `DbTable`.
    fn optimize_select(mut q: QueryExpr, op: ColumnOp, tables: &[SourceExpr]) -> QueryExpr {
        let mut fields_found = HashSet::default();
        for schema in tables {
            for op in select_best_index(&mut fields_found, schema.head(), &op) {
                // Skip scan predicates that were already recorded.
                if let IndexColumnOp::Scan(op) = &op {
                    if op.as_col_cmp().is_some_and(|cc| !fields_found.insert(cc)) {
                        continue;
                    }
                }

                match op {
                    IndexColumnOp::Index(idx) => {
                        // Index arguments are only expected for `DbTable` sources.
                        let table = schema
                            .get_db_table()
                            .expect("find_sargable_ops(schema, op) implies `schema.is_db_table()`")
                            .clone();

                        q = match idx {
                            IndexArgument::Eq { columns, value } => q.with_index_eq(table, columns.clone(), value),
                            IndexArgument::LowerBound {
                                columns,
                                value,
                                inclusive,
                            } => q.with_index_lower_bound(table, columns.clone(), value, inclusive),
                            IndexArgument::UpperBound {
                                columns,
                                value,
                                inclusive,
                            } => q.with_index_upper_bound(table, columns.clone(), value, inclusive),
                        };
                    }
                    IndexColumnOp::Scan(rhs) => {
                        let rhs = rhs.clone();
                        // Merge into a trailing `Select` if present; otherwise
                        // restore the popped stage and start a new `Select`.
                        let op = match q.query.pop() {
                            Some(Query::Select(lhs)) => ColumnOp::and(lhs, rhs),
                            None => rhs,
                            Some(other) => {
                                q.query.push(other);
                                rhs
                            }
                        };
                        q.query.push(Query::Select(op));
                    }
                }
            }
        }

        q
    }
1925
    /// Returns an optimized version of this query.
    ///
    /// `row_count` is forwarded to `IndexJoin::reorder` and to the recursive
    /// optimization of join right-hand sides. Presumably it maps
    /// `(table_id, table_name)` to an estimated row count; confirm against callers.
    ///
    /// Steps:
    /// 1. A pipeline consisting of exactly one `IndexJoin` is reordered via
    ///    `IndexJoin::reorder(row_count)` and returned.
    /// 2. Otherwise the stages are rebuilt one by one:
    ///    - each `Select` goes through `optimize_select`, which may introduce index
    ///      scans on `self.source`;
    ///    - the right-hand side of each `JoinInner` is optimized recursively;
    ///    - all other stages are copied unchanged.
    /// 3. `try_semi_join` and then `try_index_join` are applied. If the result is a
    ///    single `IndexJoin`, `optimize` runs again so that step 1 reorders it.
    pub fn optimize(mut self, row_count: &impl Fn(TableId, &str) -> i64) -> Self {
        let mut q = Self {
            source: self.source.clone(),
            query: Vec::with_capacity(self.query.len()),
        };

        if matches!(&*self.query, [Query::IndexJoin(_)]) {
            if let Some(Query::IndexJoin(join)) = self.query.pop() {
                q.query.push(Query::IndexJoin(join.reorder(row_count)));
                return q;
            }
        }

        for query in self.query {
            match query {
                Query::Select(op) => {
                    q = Self::optimize_select(q, op, from_ref(&self.source));
                }
                Query::JoinInner(join) => {
                    q = q.with_join_inner_raw(join.rhs.optimize(row_count), join.col_lhs, join.col_rhs, join.inner);
                }
                _ => q.query.push(query),
            };
        }

        let q = q.try_semi_join();
        let q = q.try_index_join();
        // A freshly created index join still needs the reordering from step 1.
        if matches!(&*q.query, [Query::IndexJoin(_)]) {
            return q.optimize(row_count);
        }
        q
    }
1959}
1960
1961impl AuthAccess for Query {
1962 fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> {
1963 if owner == caller {
1964 return Ok(());
1965 }
1966
1967 self.walk_sources(&mut |s| s.check_auth(owner, caller))
1968 }
1969}
1970
/// An expression tree node.
///
/// NOTE(review): presumably the input language of the VM; confirm against callers.
#[derive(Debug, Eq, PartialEq, From)]
pub enum Expr {
    /// A literal value. This is the only variant marked `#[from]`, so the derived
    /// `From` impl converts an `AlgebraicValue` into this variant.
    #[from]
    Value(AlgebraicValue),
    /// A sequence of nested expressions.
    Block(Vec<Expr>),
    /// An identifier, stored by name.
    Ident(String),
    /// A boxed CRUD expression (e.g. `CrudExpr::Query`, `Insert`, `Update`, `Delete`).
    Crud(Box<CrudExpr>),
    /// A halt carrying an `ErrorLang` error.
    Halt(ErrorLang),
}
1980
1981impl From<QueryExpr> for Expr {
1982 fn from(x: QueryExpr) -> Self {
1983 Expr::Crud(Box::new(CrudExpr::Query(x)))
1984 }
1985}
1986
1987impl fmt::Display for Query {
1988 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1989 match self {
1990 Query::IndexScan(op) => {
1991 write!(f, "index_scan {:?}", op)
1992 }
1993 Query::IndexJoin(op) => {
1994 write!(f, "index_join {:?}", op)
1995 }
1996 Query::Select(q) => {
1997 write!(f, "select {q}")
1998 }
1999 Query::Project(proj) => {
2000 let q = &proj.cols;
2001 write!(f, "project")?;
2002 if !q.is_empty() {
2003 write!(f, " ")?;
2004 }
2005 for (pos, x) in q.iter().enumerate() {
2006 write!(f, "{x}")?;
2007 if pos + 1 < q.len() {
2008 write!(f, ", ")?;
2009 }
2010 }
2011 Ok(())
2012 }
2013 Query::JoinInner(q) => {
2014 write!(f, "&inner {:?} ON {} = {}", q.rhs, q.col_lhs, q.col_rhs)
2015 }
2016 }
2017 }
2018}
2019
2020impl AuthAccess for SourceExpr {
2021 fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> {
2022 if owner == caller || self.table_access() == StAccess::Public {
2023 return Ok(());
2024 }
2025
2026 Err(AuthError::TablePrivate {
2027 named: self.table_name().to_string(),
2028 })
2029 }
2030}
2031
2032impl AuthAccess for QueryExpr {
2033 fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> {
2034 if owner == caller {
2035 return Ok(());
2036 }
2037 self.walk_sources(&mut |s| s.check_auth(owner, caller))
2038 }
2039}
2040
2041impl AuthAccess for CrudExpr {
2042 fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> {
2043 if owner == caller {
2044 return Ok(());
2045 }
2046 if let CrudExpr::Query(q) = self {
2048 return q.check_auth(owner, caller);
2049 }
2050
2051 Err(AuthError::OwnerRequired)
2053 }
2054}
2055
/// Rows inserted into and deleted from a single table.
///
/// NOTE(review): presumably the changes produced by one CRUD operation; confirm
/// against callers.
#[derive(Debug, PartialEq)]
pub struct Update {
    /// Id of the affected table.
    pub table_id: TableId,
    /// Name of the affected table.
    pub table_name: Box<str>,
    /// Rows inserted into the table.
    pub inserts: Vec<ProductValue>,
    /// Rows deleted from the table.
    pub deletes: Vec<ProductValue>,
}
2063
/// An intermediate code node.
///
/// Converted into a [`CodeResult`] via `From<Code>`; `Code::Crud` has no
/// counterpart there and becomes `CodeResult::Halt`.
#[derive(Debug, PartialEq)]
pub enum Code {
    /// A single value.
    Value(AlgebraicValue),
    /// An in-memory table of rows.
    Table(MemTable),
    /// A halt carrying an `ErrorLang` error.
    Halt(ErrorLang),
    /// A sequence of nested code nodes.
    Block(Vec<Code>),
    /// A CRUD expression that has not been turned into a result.
    Crud(CrudExpr),
    /// No value, optionally carrying the table changes that were made.
    Pass(Option<Update>),
}
2073
2074impl fmt::Display for Code {
2075 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2076 match self {
2077 Code::Value(x) => {
2078 write!(f, "{:?}", &x)
2079 }
2080 Code::Block(_) => write!(f, "Block"),
2081 x => todo!("{:?}", x),
2082 }
2083 }
2084}
2085
/// The result of evaluating a [`Code`] node, produced via `From<Code>`.
///
/// Unlike `Code` there is no `Crud` variant; a leftover `Code::Crud` is
/// converted into `CodeResult::Halt`.
#[derive(Debug, PartialEq)]
pub enum CodeResult {
    /// A single value.
    Value(AlgebraicValue),
    /// An in-memory table of rows.
    Table(MemTable),
    /// Results of a non-empty `Code::Block` (an empty block becomes `Pass(None)`).
    Block(Vec<CodeResult>),
    /// A halt carrying an `ErrorLang` error.
    Halt(ErrorLang),
    /// No value, optionally carrying the table changes that were made.
    Pass(Option<Update>),
}
2094
2095impl From<Code> for CodeResult {
2096 fn from(code: Code) -> Self {
2097 match code {
2098 Code::Value(x) => Self::Value(x),
2099 Code::Table(x) => Self::Table(x),
2100 Code::Halt(x) => Self::Halt(x),
2101 Code::Block(x) => {
2102 if x.is_empty() {
2103 Self::Pass(None)
2104 } else {
2105 Self::Block(x.into_iter().map(CodeResult::from).collect())
2106 }
2107 }
2108 Code::Pass(x) => Self::Pass(x),
2109 x => Self::Halt(ErrorLang::new(
2110 ErrorKind::Compiler,
2111 Some(&format!("Invalid result: {x}")),
2112 )),
2113 }
2114 }
2115}
2116
#[cfg(test)]
mod tests {
    use super::*;

    use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, relation::Column};
    use spacetimedb_sats::{product, AlgebraicType, ProductType};
    use spacetimedb_schema::{def::ModuleDef, schema::Schema};
    use typed_arena::Arena;

    /// Identity used as the table owner in the auth tests.
    const ALICE: Identity = Identity::from_byte_array([1; 32]);
    /// Identity used as a non-owner caller in the auth tests.
    const BOB: Identity = Identity::from_byte_array([2; 32]);

    /// Two private sources, both with table id 42 and name "foo":
    /// - an in-memory table;
    /// - a `DbTable` with an `indexed` constraint on column 42.
    fn tables() -> [SourceExpr; 2] {
        [
            SourceExpr::InMemory {
                source_id: SourceId(0),
                header: Arc::new(Header {
                    table_id: 42.into(),
                    table_name: "foo".into(),
                    fields: vec![],
                    constraints: Default::default(),
                }),
                table_type: StTableType::User,
                table_access: StAccess::Private,
            },
            SourceExpr::DbTable(DbTable {
                head: Arc::new(Header {
                    table_id: 42.into(),
                    table_name: "foo".into(),
                    fields: vec![],
                    constraints: [(ColId(42).into(), Constraints::indexed())].into_iter().collect(),
                }),
                table_id: 42.into(),
                table_type: StTableType::User,
                table_access: StAccess::Private,
            }),
        ]
    }

    /// One query of each kind, each touching at least one private source:
    /// - an index scan over the private `DbTable`;
    /// - an index join probing from the private in-memory table into a public copy
    ///   of the `DbTable`;
    /// - an inner join against the private in-memory table.
    fn queries() -> impl IntoIterator<Item = Query> {
        let [mem_table, db_table] = tables();
        [
            Query::IndexScan(IndexScan {
                table: db_table.get_db_table().unwrap().clone(),
                columns: ColList::new(42.into()),
                bounds: (Bound::Included(22.into()), Bound::Unbounded),
            }),
            Query::IndexJoin(IndexJoin {
                probe_side: mem_table.clone().into(),
                probe_col: 0.into(),
                index_side: SourceExpr::DbTable(DbTable {
                    head: Arc::new(Header {
                        table_id: db_table.head().table_id,
                        table_name: db_table.table_name().into(),
                        fields: vec![],
                        constraints: Default::default(),
                    }),
                    table_id: db_table.head().table_id,
                    table_type: StTableType::User,
                    table_access: StAccess::Public,
                }),
                index_select: None,
                index_col: 22.into(),
                return_index_rows: true,
            }),
            Query::JoinInner(JoinExpr {
                col_rhs: 1.into(),
                rhs: mem_table.into(),
                col_lhs: 1.into(),
                inner: None,
            }),
        ]
    }

    /// For each source from `tables()`, a query over it whose pipeline is all of
    /// `queries()`.
    fn query_exprs() -> impl IntoIterator<Item = QueryExpr> {
        tables().map(|table| {
            let mut expr = QueryExpr::from(table);
            expr.query = queries().into_iter().collect();
            expr
        })
    }

    /// Asserts that the owner passes and a non-owner gets `TablePrivate`.
    fn assert_owner_private<T: AuthAccess>(auth: &T) {
        assert!(auth.check_auth(ALICE, ALICE).is_ok());
        assert!(matches!(
            auth.check_auth(ALICE, BOB),
            Err(AuthError::TablePrivate { .. })
        ));
    }

    /// Asserts that the owner passes and a non-owner gets `OwnerRequired`.
    fn assert_owner_required<T: AuthAccess>(auth: T) {
        assert!(auth.check_auth(ALICE, ALICE).is_ok());
        assert!(matches!(auth.check_auth(ALICE, BOB), Err(AuthError::OwnerRequired)));
    }

    /// Builds a public in-memory source from `(col, type, indexed)` triples.
    /// The `indexed` flag adds an `indexed` constraint on the column at that
    /// triple's position in `fields`.
    fn mem_table(id: TableId, name: &str, fields: &[(u16, AlgebraicType, bool)]) -> SourceExpr {
        let table_access = StAccess::Public;
        let head = Header::new(
            id,
            name.into(),
            fields
                .iter()
                .map(|(col, ty, _)| Column::new(FieldName::new(id, (*col).into()), ty.clone()))
                .collect(),
            fields
                .iter()
                .enumerate()
                .filter(|(_, (_, _, indexed))| *indexed)
                .map(|(i, _)| (ColId::from(i).into(), Constraints::indexed())),
        );
        SourceExpr::InMemory {
            source_id: SourceId(0),
            header: Arc::new(head),
            table_access,
            table_type: StTableType::User,
        }
    }

    /// `IndexJoin::to_inner_join` should produce a semijoin whose:
    /// - source is the probe side;
    /// - right-hand side is the index side filtered by `index_select`.
    #[test]
    fn test_index_to_inner_join() {
        let index_side = mem_table(
            0.into(),
            "index",
            &[(0, AlgebraicType::U8, false), (1, AlgebraicType::U8, true)],
        );
        let probe_side = mem_table(
            1.into(),
            "probe",
            &[(0, AlgebraicType::U8, false), (1, AlgebraicType::U8, true)],
        );

        let index_col = 1.into();
        let probe_col = 1.into();
        let index_select = ColumnOp::cmp(0, OpCmp::Eq, 0u8);
        let join = IndexJoin {
            probe_side: probe_side.clone().into(),
            probe_col,
            index_side: index_side.clone(),
            index_select: Some(index_select.clone()),
            index_col,
            return_index_rows: false,
        };

        let expr = join.to_inner_join();

        assert_eq!(expr.source, probe_side);
        assert_eq!(expr.query.len(), 1);

        let Query::JoinInner(ref join) = expr.query[0] else {
            panic!("expected an inner join, but got {:#?}", expr.query[0]);
        };

        assert_eq!(join.col_lhs, probe_col);
        assert_eq!(join.col_rhs, index_col);
        assert_eq!(
            join.rhs,
            QueryExpr {
                source: index_side,
                query: vec![index_select.into()]
            }
        );
        assert_eq!(join.inner, None);
    }

    /// Header "t1" with five columns `a..e` and these constraints:
    /// - primary key on `a`;
    /// - index on `b`;
    /// - unique on `(b, c)`;
    /// - index on `(a, b, c, d)`.
    /// Column `e` has no constraint. Also returns the column ids and one test value
    /// per column.
    fn setup_best_index() -> (Header, [ColId; 5], [AlgebraicValue; 5]) {
        let table_id = 0.into();

        let vals = [1, 2, 3, 4, 5].map(AlgebraicValue::U64);
        let col_ids = [0, 1, 2, 3, 4].map(ColId);
        let [a, b, c, d, _] = col_ids;
        let columns = col_ids.map(|c| Column::new(FieldName::new(table_id, c), AlgebraicType::I8));

        let head1 = Header::new(
            table_id,
            "t1".into(),
            columns.to_vec(),
            vec![
                // Single-column primary key.
                (a.into(), Constraints::primary_key()),
                // Single-column index.
                (b.into(), Constraints::indexed()),
                // Two-column unique constraint.
                (col_list![b, c], Constraints::unique()),
                // Four-column index.
                (col_list![a, b, c, d], Constraints::indexed()),
            ],
        );

        (head1, col_ids, vals)
    }

    /// Builds `col <cmp> value` as a `ColumnOp`.
    fn make_field_value((cmp, col, value): (OpCmp, ColId, &AlgebraicValue)) -> ColumnOp {
        ColumnOp::cmp(col, cmp, value.clone())
    }

    /// Expected `Scan` for `col = val`, allocated in `arena`.
    fn scan_eq<'a>(arena: &'a Arena<ColumnOp>, col: ColId, val: &'a AlgebraicValue) -> IndexColumnOp<'a> {
        scan(arena, OpCmp::Eq, col, val)
    }

    /// Expected `Scan` for `col <cmp> val`, allocated in `arena`.
    fn scan<'a>(arena: &'a Arena<ColumnOp>, cmp: OpCmp, col: ColId, val: &'a AlgebraicValue) -> IndexColumnOp<'a> {
        IndexColumnOp::Scan(arena.alloc(make_field_value((cmp, col, val))))
    }

    /// `select_best_index` on conjunctions of equalities.
    #[test]
    fn best_index() {
        let (head1, fields, vals) = setup_best_index();
        let [col_a, col_b, col_c, col_d, col_e] = fields;
        let [val_a, val_b, val_c, val_d, val_e] = vals;

        let arena = Arena::new();
        let select_best_index = |fields: &[_]| {
            let fields = fields
                .iter()
                .copied()
                .map(|(col, val): (ColId, _)| make_field_value((OpCmp::Eq, col, val)))
                .reduce(ColumnOp::and)
                .unwrap();
            select_best_index(&mut <_>::default(), &head1, arena.alloc(fields))
        };

        let col_list_arena = Arena::new();
        let idx_eq = |cols, val| make_index_arg(OpCmp::Eq, col_list_arena.alloc(cols), val);

        // `d` alone has no single-column index: scan.
        assert_eq!(
            select_best_index(&[(col_d, &val_e)]),
            [scan_eq(&arena, col_d, &val_e)].into(),
        );

        // `a` is the primary key.
        assert_eq!(
            select_best_index(&[(col_a, &val_a)]),
            [idx_eq(col_a.into(), val_a.clone())].into(),
        );

        // `b` has its own index.
        assert_eq!(
            select_best_index(&[(col_b, &val_b)]),
            [idx_eq(col_b.into(), val_b.clone())].into(),
        );

        // `(b, c)` uses the two-column constraint ...
        assert_eq!(
            select_best_index(&[(col_b, &val_b), (col_c, &val_c)]),
            [idx_eq(
                col_list![col_b, col_c],
                product![val_b.clone(), val_c.clone()].into()
            )]
            .into(),
        );

        // ... regardless of predicate order.
        assert_eq!(
            select_best_index(&[(col_c, &val_c), (col_b, &val_b)]),
            [idx_eq(
                col_list![col_b, col_c],
                product![val_b.clone(), val_c.clone()].into()
            )]
            .into(),
        );

        // All of `a, b, c, d` select the four-column index ...
        assert_eq!(
            select_best_index(&[(col_a, &val_a), (col_b, &val_b), (col_c, &val_c), (col_d, &val_d)]),
            [idx_eq(
                col_list![col_a, col_b, col_c, col_d],
                product![val_a.clone(), val_b.clone(), val_c.clone(), val_d.clone()].into(),
            )]
            .into(),
        );

        // ... regardless of predicate order.
        assert_eq!(
            select_best_index(&[(col_b, &val_b), (col_a, &val_a), (col_d, &val_d), (col_c, &val_c)]),
            [idx_eq(
                col_list![col_a, col_b, col_c, col_d],
                product![val_a.clone(), val_b.clone(), val_c.clone(), val_d.clone()].into(),
            )]
            .into()
        );

        // Without `c`, the single-column indexes on `a` and `b` are used and
        // `d` and `e` are scanned.
        assert_eq!(
            select_best_index(&[(col_b, &val_b), (col_a, &val_a), (col_e, &val_e), (col_d, &val_d)]),
            [
                idx_eq(col_a.into(), val_a.clone()),
                idx_eq(col_b.into(), val_b.clone()),
                scan_eq(&arena, col_d, &val_d),
                scan_eq(&arena, col_e, &val_e),
            ]
            .into()
        );

        // `(b, c)` uses the two-column constraint, `d` is scanned.
        assert_eq!(
            select_best_index(&[(col_b, &val_b), (col_c, &val_c), (col_d, &val_d)]),
            [
                idx_eq(col_list![col_b, col_c], product![val_b.clone(), val_c.clone()].into(),),
                scan_eq(&arena, col_d, &val_d),
            ]
            .into()
        );
    }

    /// `select_best_index` on conjunctions that mix ranges and equalities.
    #[test]
    fn best_index_range() {
        let arena = Arena::new();

        let (head1, cols, vals) = setup_best_index();
        let [col_a, col_b, col_c, col_d, _] = cols;
        let [val_a, val_b, val_c, val_d, _] = vals;

        let select_best_index = |cols: &[_]| {
            let fields = cols.iter().map(|x| make_field_value(*x)).reduce(ColumnOp::and).unwrap();
            select_best_index(&mut <_>::default(), &head1, arena.alloc(fields))
        };

        let col_list_arena = Arena::new();
        let idx = |cmp, cols: &[ColId], val: &AlgebraicValue| {
            let columns = cols.iter().copied().collect::<ColList>();
            let columns = col_list_arena.alloc(columns);
            make_index_arg(cmp, columns, val.clone())
        };

        // A range on indexed `a` yields both index bounds.
        assert_eq!(
            select_best_index(&[(OpCmp::Gt, col_a, &val_a), (OpCmp::Lt, col_a, &val_b)]),
            [idx(OpCmp::Lt, &[col_a], &val_b), idx(OpCmp::Gt, &[col_a], &val_a)].into()
        );

        // A range on `d` (no single-column index) yields two scans.
        assert_eq!(
            select_best_index(&[(OpCmp::Gt, col_d, &val_d), (OpCmp::Lt, col_d, &val_b)]),
            [
                scan(&arena, OpCmp::Lt, col_d, &val_b),
                scan(&arena, OpCmp::Gt, col_d, &val_d)
            ]
            .into()
        );

        // `b > _` uses `b`'s index; `c < _` is scanned.
        assert_eq!(
            select_best_index(&[(OpCmp::Gt, col_b, &val_b), (OpCmp::Lt, col_c, &val_c)]),
            [idx(OpCmp::Gt, &[col_b], &val_b), scan(&arena, OpCmp::Lt, col_c, &val_c)].into()
        );

        // Expected index argument for `b = val_b AND c = val_c`.
        let idx_bc = idx(
            OpCmp::Eq,
            &[col_b, col_c],
            &product![val_b.clone(), val_c.clone()].into(),
        );
        // Equalities on `b` and `c` use the two-column constraint; `a >= _` uses `a`'s index.
        assert_eq!(
            select_best_index(&[
                (OpCmp::Eq, col_b, &val_b),
                (OpCmp::GtEq, col_a, &val_a),
                (OpCmp::Eq, col_c, &val_c),
            ]),
            [idx_bc.clone(), idx(OpCmp::GtEq, &[col_a], &val_a),].into()
        );

        // `a = _` and `b > _` use single-column indexes; `c < _` is scanned.
        assert_eq!(
            select_best_index(&[
                (OpCmp::Gt, col_b, &val_b),
                (OpCmp::Eq, col_a, &val_a),
                (OpCmp::Lt, col_c, &val_c),
            ]),
            [
                idx(OpCmp::Eq, &[col_a], &val_a),
                idx(OpCmp::Gt, &[col_b], &val_b),
                scan(&arena, OpCmp::Lt, col_c, &val_c),
            ]
            .into()
        );

        // With `d > _` a range, the four-column index is not chosen:
        // `(b, c)` and `a` use indexes and `d` is scanned.
        assert_eq!(
            select_best_index(&[
                (OpCmp::Eq, col_a, &val_a),
                (OpCmp::Eq, col_b, &val_b),
                (OpCmp::Eq, col_c, &val_c),
                (OpCmp::Gt, col_d, &val_d),
            ]),
            [
                idx_bc.clone(),
                idx(OpCmp::Eq, &[col_a], &val_a),
                scan(&arena, OpCmp::Gt, col_d, &val_d),
            ]
            .into()
        );

        // Repeated predicates produce repeated index arguments.
        assert_eq!(
            select_best_index(&[
                (OpCmp::Eq, col_b, &val_b),
                (OpCmp::Eq, col_c, &val_c),
                (OpCmp::Eq, col_b, &val_b),
                (OpCmp::Eq, col_c, &val_c),
            ]),
            [idx_bc.clone(), idx_bc].into()
        );
    }

    /// Private tables: only the owner has access.
    #[test]
    fn test_auth_table() {
        tables().iter().for_each(assert_owner_private)
    }

    /// Query expressions over private sources are owner-only.
    #[test]
    fn test_auth_query_code() {
        for code in query_exprs() {
            assert_owner_private(&code)
        }
    }

    /// Each individual query stage touching a private source is owner-only.
    #[test]
    fn test_auth_query() {
        for query in queries() {
            assert_owner_private(&query);
        }
    }

    /// A CRUD read query is checked like its inner query.
    #[test]
    fn test_auth_crud_code_query() {
        for query in query_exprs() {
            let crud = CrudExpr::Query(query);
            assert_owner_private(&crud);
        }
    }

    /// Inserts require the owner.
    #[test]
    fn test_auth_crud_code_insert() {
        for table in tables().into_iter().filter_map(|s| s.get_db_table().cloned()) {
            let crud = CrudExpr::Insert { table, rows: vec![] };
            assert_owner_required(crud);
        }
    }

    /// Updates require the owner.
    #[test]
    fn test_auth_crud_code_update() {
        for qc in query_exprs() {
            let crud = CrudExpr::Update {
                delete: qc,
                assignments: Default::default(),
            };
            assert_owner_required(crud);
        }
    }

    /// Deletes require the owner.
    #[test]
    fn test_auth_crud_code_delete() {
        for query in query_exprs() {
            let crud = CrudExpr::Delete { query };
            assert_owner_required(crud);
        }
    }

    /// Module definition with two tables:
    /// - `lhs(a: I32, b: String)`;
    /// - `rhs(c: I32, d: I64)`.
    fn test_def() -> ModuleDef {
        let mut builder = RawModuleDefV9Builder::new();
        builder.build_table_with_new_type(
            "lhs",
            ProductType::from([("a", AlgebraicType::I32), ("b", AlgebraicType::String)]),
            true,
        );
        builder.build_table_with_new_type(
            "rhs",
            ProductType::from([("c", AlgebraicType::I32), ("d", AlgebraicType::I64)]),
            true,
        );
        builder.finish().try_into().expect("test def should be valid")
    }

    /// An inner join followed by a wildcard projection of the source table
    /// (the wildcard id `TableId::SENTINEL` is expected to equal `lhs`'s id 0)
    /// is optimized into a single semijoin.
    #[test]
    fn optimize_inner_join_to_semijoin() {
        let def: ModuleDef = test_def();
        let lhs = TableSchema::from_module_def(&def, def.table("lhs").unwrap(), (), 0.into());
        let rhs = TableSchema::from_module_def(&def, def.table("rhs").unwrap(), (), 1.into());

        let lhs_source = SourceExpr::from(&lhs);
        let rhs_source = SourceExpr::from(&rhs);

        let q = QueryExpr::new(lhs_source.clone())
            .with_join_inner(rhs_source.clone(), 0.into(), 0.into(), false)
            .with_project(
                [0, 1]
                    .map(|c| FieldExpr::Name(FieldName::new(lhs.table_id, c.into())))
                    .into(),
                Some(TableId::SENTINEL),
            )
            .unwrap();
        let q = q.optimize(&|_, _| 0);

        assert_eq!(q.source, lhs_source, "Optimized query should read from lhs");

        assert_eq!(
            q.query.len(),
            1,
            "Optimized query should have a single member, a semijoin"
        );
        match &q.query[0] {
            Query::JoinInner(JoinExpr { rhs, inner: semi, .. }) => {
                assert_eq!(semi, &None, "Optimized query should be a semijoin");
                assert_eq!(rhs.source, rhs_source, "Optimized query should filter with rhs");
                assert!(
                    rhs.query.is_empty(),
                    "Optimized query should not filter rhs before joining"
                );
            }
            wrong => panic!("Expected an inner join, but found {wrong:?}"),
        }
    }

    /// Without a projection, an inner join is left as-is.
    #[test]
    fn optimize_inner_join_no_project() {
        let def: ModuleDef = test_def();
        let lhs = TableSchema::from_module_def(&def, def.table("lhs").unwrap(), (), 0.into());
        let rhs = TableSchema::from_module_def(&def, def.table("rhs").unwrap(), (), 1.into());

        let lhs_source = SourceExpr::from(&lhs);
        let rhs_source = SourceExpr::from(&rhs);

        let q = QueryExpr::new(lhs_source.clone()).with_join_inner(rhs_source.clone(), 0.into(), 0.into(), false);
        let optimized = q.clone().optimize(&|_, _| 0);
        assert_eq!(q, optimized);
    }

    /// A wildcard projection of the right-hand table (id 1) is not a semijoin
    /// candidate, so the query is left as-is.
    #[test]
    fn optimize_inner_join_wrong_project() {
        let def: ModuleDef = test_def();
        let lhs = TableSchema::from_module_def(&def, def.table("lhs").unwrap(), (), 0.into());
        let rhs = TableSchema::from_module_def(&def, def.table("rhs").unwrap(), (), 1.into());

        let lhs_source = SourceExpr::from(&lhs);
        let rhs_source = SourceExpr::from(&rhs);

        let q = QueryExpr::new(lhs_source.clone())
            .with_join_inner(rhs_source.clone(), 0.into(), 0.into(), false)
            .with_project(
                [0, 1]
                    .map(|c| FieldExpr::Name(FieldName::new(rhs.table_id, c.into())))
                    .into(),
                Some(TableId(1)),
            )
            .unwrap();
        let optimized = q.clone().optimize(&|_, _| 0);
        assert_eq!(q, optimized);
    }
}