1use crate::errors::{ErrorKind, ErrorLang};
2use crate::operator::{OpCmp, OpLogic, OpQuery};
3use crate::relation::{MemTable, RelValue};
4use arrayvec::ArrayVec;
5use core::slice::from_ref;
6use derive_more::From;
7use itertools::Itertools;
8use smallvec::SmallVec;
9use spacetimedb_data_structures::map::{HashSet, IntMap};
10use spacetimedb_lib::db::auth::{StAccess, StTableType};
11use spacetimedb_lib::Identity;
12use spacetimedb_primitives::*;
13use spacetimedb_sats::satn::Satn;
14use spacetimedb_sats::{AlgebraicType, AlgebraicValue, ProductValue};
15use spacetimedb_schema::def::error::{AuthError, RelationError};
16use spacetimedb_schema::relation::{ColExpr, DbTable, FieldName, Header};
17use spacetimedb_schema::schema::TableSchema;
18use std::borrow::Cow;
19use std::cmp::Reverse;
20use std::collections::btree_map::Entry;
21use std::collections::BTreeMap;
22use std::ops::Bound;
23use std::sync::Arc;
24use std::{fmt, iter, mem};
25
/// Authorization check performed by an implementor for an owner/caller pair.
pub trait AuthAccess {
    /// Checks authorization given the `owner` and `caller` identities,
    /// returning an [`AuthError`] on failure.
    // NOTE(review): the concrete policy is implementor-defined and not visible here.
    fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError>;
}
30
/// An operand of a field-level expression: a named field or a constant value.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, From)]
pub enum FieldExpr {
    /// A reference to a field by its [`FieldName`].
    Name(FieldName),
    /// A constant [`AlgebraicValue`].
    Value(AlgebraicValue),
}
36
37impl FieldExpr {
38 pub fn strip_table(self) -> ColExpr {
39 match self {
40 Self::Name(field) => ColExpr::Col(field.col),
41 Self::Value(value) => ColExpr::Value(value),
42 }
43 }
44
45 pub fn name_to_col(self, head: &Header) -> Result<ColExpr, RelationError> {
46 match self {
47 Self::Value(val) => Ok(ColExpr::Value(val)),
48 Self::Name(field) => head.column_pos_or_err(field).map(ColExpr::Col),
49 }
50 }
51}
52
53impl fmt::Display for FieldExpr {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 match self {
56 FieldExpr::Name(x) => write!(f, "{x}"),
57 FieldExpr::Value(x) => write!(f, "{}", x.to_satn()),
58 }
59 }
60}
61
/// An expression tree over fields: either a leaf [`FieldExpr`]
/// or a binary operation combining two sub-expressions.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, From)]
pub enum FieldOp {
    /// A leaf operand.
    #[from]
    Field(FieldExpr),
    /// The binary operation `lhs op rhs`.
    Cmp {
        /// The operator (a comparison or a logical connective, see [`OpQuery`]).
        op: OpQuery,
        /// Left operand.
        lhs: Box<FieldOp>,
        /// Right operand.
        rhs: Box<FieldOp>,
    },
}
72
/// The result of [`FieldOp::flatten_ands`]; stores one element inline.
type FieldOpFlat = SmallVec<[FieldOp; 1]>;
74
75impl FieldOp {
76 pub fn new(op: OpQuery, lhs: Self, rhs: Self) -> Self {
77 Self::Cmp {
78 op,
79 lhs: Box::new(lhs),
80 rhs: Box::new(rhs),
81 }
82 }
83
84 pub fn cmp(field: impl Into<FieldName>, op: OpCmp, value: impl Into<AlgebraicValue>) -> Self {
85 Self::new(
86 OpQuery::Cmp(op),
87 Self::Field(FieldExpr::Name(field.into())),
88 Self::Field(FieldExpr::Value(value.into())),
89 )
90 }
91
92 pub fn names_to_cols(self, head: &Header) -> Result<ColumnOp, RelationError> {
93 match self {
94 Self::Field(field) => field.name_to_col(head).map(ColumnOp::from),
95 Self::Cmp { op, lhs, rhs } => {
96 let lhs = lhs.names_to_cols(head)?;
97 let rhs = rhs.names_to_cols(head)?;
98 Ok(ColumnOp::new(op, lhs, rhs))
99 }
100 }
101 }
102
103 pub fn flatten_ands(self) -> FieldOpFlat {
111 fn fill_vec(buf: &mut FieldOpFlat, op: FieldOp) {
112 match op {
113 FieldOp::Cmp {
114 op: OpQuery::Logic(OpLogic::And),
115 lhs,
116 rhs,
117 } => {
118 fill_vec(buf, *lhs);
119 fill_vec(buf, *rhs);
120 }
121 op => buf.push(op),
122 }
123 }
124 let mut buf = SmallVec::new();
125 fill_vec(&mut buf, self);
126 buf
127 }
128}
129
130impl fmt::Display for FieldOp {
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 match self {
133 Self::Field(x) => {
134 write!(f, "{x}")
135 }
136 Self::Cmp { op, lhs, rhs } => {
137 write!(f, "{lhs} {op} {rhs}")
138 }
139 }
140 }
141}
142
/// An expression tree over columns (identified by [`ColId`]) and constant values.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, From)]
pub enum ColumnOp {
    /// The value of the column `ColId` in the row being evaluated.
    #[from]
    Col(ColId),
    /// A constant value.
    #[from]
    Val(AlgebraicValue),
    /// The comparison `lhs cmp rhs` of a column against a constant.
    ColCmpVal {
        /// The column read from the row.
        lhs: ColId,
        /// The comparison operator.
        cmp: OpCmp,
        /// The constant compared against.
        rhs: AlgebraicValue,
    },
    /// The comparison `lhs cmp rhs` of two arbitrary sub-expressions.
    Cmp {
        /// Left operand.
        lhs: Box<ColumnOp>,
        /// The comparison operator.
        cmp: OpCmp,
        /// Right operand.
        rhs: Box<ColumnOp>,
    },
    /// The logical connective `op` applied across all `operands`.
    Log { op: OpLogic, operands: Box<[ColumnOp]> },
}
169
170impl ColumnOp {
171 pub fn new(op: OpQuery, lhs: Self, rhs: Self) -> Self {
172 match op {
173 OpQuery::Cmp(cmp) => match (lhs, rhs) {
174 (ColumnOp::Col(lhs), ColumnOp::Val(rhs)) => Self::cmp(lhs, cmp, rhs),
175 (lhs, rhs) => Self::Cmp {
176 lhs: Box::new(lhs),
177 cmp,
178 rhs: Box::new(rhs),
179 },
180 },
181 OpQuery::Logic(op) => Self::Log {
182 op,
183 operands: [lhs, rhs].into(),
184 },
185 }
186 }
187
188 pub fn cmp(col: impl Into<ColId>, cmp: OpCmp, val: impl Into<AlgebraicValue>) -> Self {
189 let lhs = col.into();
190 let rhs = val.into();
191 Self::ColCmpVal { lhs, cmp, rhs }
192 }
193
194 fn and(lhs: Self, rhs: Self) -> Self {
196 let ands = |operands| {
197 let op = OpLogic::And;
198 Self::Log { op, operands }
199 };
200
201 match (lhs, rhs) {
202 (
204 Self::Log {
205 op: OpLogic::And,
206 operands: lhs,
207 },
208 Self::Log {
209 op: OpLogic::And,
210 operands: rhs,
211 },
212 ) => {
213 let mut operands = Vec::from(lhs);
214 operands.append(&mut Vec::from(rhs));
215 ands(operands.into())
216 }
217 (
219 Self::Log {
220 op: OpLogic::And,
221 operands: lhs,
222 },
223 rhs,
224 ) => {
225 let mut operands = Vec::from(lhs);
226 operands.push(rhs);
227 ands(operands.into())
228 }
229 (lhs, rhs) => ands([lhs, rhs].into()),
231 }
232 }
233
234 fn and_cmp(op: OpCmp, cols: &ColList, value: AlgebraicValue) -> Self {
236 let cmp = |(col, value): (ColId, _)| Self::cmp(col, op, value);
237
238 if let Some(head) = cols.as_singleton() {
240 return cmp((head, value));
241 }
242
243 let operands = cols.iter().zip(value.into_product().unwrap()).map(cmp).collect();
245 Self::Log {
246 op: OpLogic::And,
247 operands,
248 }
249 }
250
251 fn from_op_col_bounds(cols: &ColList, bounds: (Bound<AlgebraicValue>, Bound<AlgebraicValue>)) -> Self {
254 let (cmp, value) = match bounds {
255 (Bound::Included(a), Bound::Included(b)) if a == b => (OpCmp::Eq, a),
257 (Bound::Included(value), Bound::Unbounded) => (OpCmp::GtEq, value),
259 (Bound::Excluded(value), Bound::Unbounded) => (OpCmp::Gt, value),
261 (Bound::Unbounded, Bound::Included(value)) => (OpCmp::LtEq, value),
263 (Bound::Unbounded, Bound::Excluded(value)) => (OpCmp::Lt, value),
265 (Bound::Unbounded, Bound::Unbounded) => unreachable!(),
266 (lower_bound, upper_bound) => {
267 let lhs = Self::from_op_col_bounds(cols, (lower_bound, Bound::Unbounded));
268 let rhs = Self::from_op_col_bounds(cols, (Bound::Unbounded, upper_bound));
269 return ColumnOp::and(lhs, rhs);
270 }
271 };
272 ColumnOp::and_cmp(cmp, cols, value)
273 }
274
275 fn as_col_cmp(&self) -> Option<(ColId, OpCmp)> {
277 match self {
278 Self::ColCmpVal { lhs, cmp, rhs: _ } => Some((*lhs, *cmp)),
279 Self::Cmp { lhs, cmp, rhs: _ } => match &**lhs {
280 ColumnOp::Col(col) => Some((*col, *cmp)),
281 _ => None,
282 },
283 _ => None,
284 }
285 }
286
287 fn eval<'a>(&'a self, row: &'a RelValue<'_>) -> Cow<'a, AlgebraicValue> {
289 let into = |b| Cow::Owned(AlgebraicValue::Bool(b));
290
291 match self {
292 Self::Col(col) => row.read_column(col.idx()).unwrap(),
293 Self::Val(val) => Cow::Borrowed(val),
294 Self::ColCmpVal { lhs, cmp, rhs } => into(Self::eval_cmp_col_val(row, *cmp, *lhs, rhs)),
295 Self::Cmp { lhs, cmp, rhs } => into(Self::eval_cmp(row, *cmp, lhs, rhs)),
296 Self::Log { op, operands } => into(Self::eval_log(row, *op, operands)),
297 }
298 }
299
300 pub fn eval_bool(&self, row: &RelValue<'_>) -> bool {
302 match self {
303 Self::Col(col) => *row.read_column(col.idx()).unwrap().as_bool().unwrap(),
304 Self::Val(val) => *val.as_bool().unwrap(),
305 Self::ColCmpVal { lhs, cmp, rhs } => Self::eval_cmp_col_val(row, *cmp, *lhs, rhs),
306 Self::Cmp { lhs, cmp, rhs } => Self::eval_cmp(row, *cmp, lhs, rhs),
307 Self::Log { op, operands } => Self::eval_log(row, *op, operands),
308 }
309 }
310
311 fn eval_op_cmp(cmp: OpCmp, lhs: &AlgebraicValue, rhs: &AlgebraicValue) -> bool {
313 match cmp {
314 OpCmp::Eq => lhs == rhs,
315 OpCmp::NotEq => lhs != rhs,
316 OpCmp::Lt => lhs < rhs,
317 OpCmp::LtEq => lhs <= rhs,
318 OpCmp::Gt => lhs > rhs,
319 OpCmp::GtEq => lhs >= rhs,
320 }
321 }
322
323 fn eval_cmp_col_val(row: &RelValue<'_>, cmp: OpCmp, lhs: ColId, rhs: &AlgebraicValue) -> bool {
325 let lhs = row.read_column(lhs.idx()).unwrap();
326 Self::eval_op_cmp(cmp, &lhs, rhs)
327 }
328
329 fn eval_cmp(row: &RelValue<'_>, cmp: OpCmp, lhs: &Self, rhs: &Self) -> bool {
333 let lhs = lhs.eval(row);
334 let rhs = rhs.eval(row);
335 Self::eval_op_cmp(cmp, &lhs, &rhs)
336 }
337
338 fn eval_log(row: &RelValue<'_>, op: OpLogic, opers: &[ColumnOp]) -> bool {
342 match op {
343 OpLogic::And => opers.iter().all(|o| o.eval_bool(row)),
344 OpLogic::Or => opers.iter().any(|o| o.eval_bool(row)),
345 }
346 }
347}
348
349impl fmt::Display for ColumnOp {
350 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351 match self {
352 Self::Col(col) => write!(f, "{col}"),
353 Self::Val(val) => write!(f, "{}", val.to_satn()),
354 Self::ColCmpVal { lhs, cmp, rhs } => write!(f, "{lhs} {cmp} {}", rhs.to_satn()),
355 Self::Cmp { cmp, lhs, rhs } => write!(f, "{lhs} {cmp} {rhs}"),
356 Self::Log { op, operands } => write!(f, "{}", operands.iter().format((*op).into())),
357 }
358 }
359}
360
361impl From<ColExpr> for ColumnOp {
362 fn from(ce: ColExpr) -> Self {
363 match ce {
364 ColExpr::Col(c) => c.into(),
365 ColExpr::Value(v) => v.into(),
366 }
367 }
368}
369
370impl From<Query> for Option<ColumnOp> {
371 fn from(value: Query) -> Self {
372 match value {
373 Query::IndexScan(op) => Some(ColumnOp::from_op_col_bounds(&op.columns, op.bounds)),
374 Query::Select(op) => Some(op),
375 _ => None,
376 }
377 }
378}
379
/// Identifies a source held by a [`SourceProvider`].
///
/// In [`SourceSet`] this is the position at which the source was added.
#[derive(Debug, Copy, Clone, PartialEq, Eq, From, Hash)]
pub struct SourceId(pub usize);
392
/// Something that can hand out row sources by [`SourceId`].
pub trait SourceProvider<'a> {
    /// The type of a single source: anything iterable over [`RelValue`]s.
    type Source: 'a + IntoIterator<Item = RelValue<'a>>;

    /// Takes the source identified by `id` out of the provider,
    /// or returns `None` if the provider has nothing to give for `id`.
    fn take_source(&mut self, id: SourceId) -> Option<Self::Source>;
}
419
420impl<'a, I: 'a + IntoIterator<Item = RelValue<'a>>, F: FnMut(SourceId) -> Option<I>> SourceProvider<'a> for F {
421 type Source = I;
422 fn take_source(&mut self, id: SourceId) -> Option<Self::Source> {
423 self(id)
424 }
425}
426
427impl<'a, I: 'a + IntoIterator<Item = RelValue<'a>>> SourceProvider<'a> for Option<I> {
428 type Source = I;
429 fn take_source(&mut self, _: SourceId) -> Option<Self::Source> {
430 self.take()
431 }
432}
433
/// A [`SourceProvider`] that never provides any source.
pub struct NoInMemUsed;
435
/// Always answers `None`; the nominal source type is an empty iterator.
impl<'a> SourceProvider<'a> for NoInMemUsed {
    type Source = iter::Empty<RelValue<'a>>;
    fn take_source(&mut self, _: SourceId) -> Option<Self::Source> {
        None
    }
}
442
/// A set of at most `N` sources of type `T`, addressed by [`SourceId`].
///
/// Each slot is an `Option<T>` so that a source can be taken out
/// (see [`SourceSet::take`]) without shifting the ids of the others.
#[derive(Debug, PartialEq, Eq, Clone)]
#[repr(transparent)]
pub struct SourceSet<T, const N: usize>(
    // Slot `i` belongs to `SourceId(i)`; `None` means already taken.
    ArrayVec<Option<T>, N>,
);
454
455impl<'a, T: 'a + IntoIterator<Item = RelValue<'a>>, const N: usize> SourceProvider<'a> for SourceSet<T, N> {
456 type Source = T;
457 fn take_source(&mut self, id: SourceId) -> Option<T> {
458 self.take(id)
459 }
460}
461
462impl<T, const N: usize> From<[T; N]> for SourceSet<T, N> {
463 #[inline]
464 fn from(sources: [T; N]) -> Self {
465 Self(sources.map(Some).into())
466 }
467}
468
469impl<T, const N: usize> SourceSet<T, N> {
470 pub fn empty() -> Self {
472 Self(ArrayVec::new())
473 }
474
475 fn next_id(&self) -> SourceId {
477 SourceId(self.0.len())
478 }
479
480 pub fn add(&mut self, table: T) -> SourceId {
483 let source_id = self.next_id();
484 self.0.push(Some(table));
485 source_id
486 }
487
488 pub fn take(&mut self, id: SourceId) -> Option<T> {
493 self.0.get_mut(id.0).map(mem::take).unwrap_or_default()
494 }
495
496 pub fn len(&self) -> usize {
500 self.0.len()
501 }
502
503 pub fn is_empty(&self) -> bool {
507 self.0.is_empty()
508 }
509}
510
511impl<T, const N: usize> std::ops::Index<SourceId> for SourceSet<T, N> {
512 type Output = Option<T>;
513
514 fn index(&self, idx: SourceId) -> &Self::Output {
515 &self.0[idx.0]
516 }
517}
518
519impl<T, const N: usize> std::ops::IndexMut<SourceId> for SourceSet<T, N> {
520 fn index_mut(&mut self, idx: SourceId) -> &mut Self::Output {
521 &mut self.0[idx.0]
522 }
523}
524
525impl<const N: usize> SourceSet<Vec<ProductValue>, N> {
526 pub fn add_mem_table(&mut self, table: MemTable) -> SourceExpr {
529 let id = self.add(table.data);
530 SourceExpr::from_mem_table(table.head, table.table_access, id)
531 }
532}
533
/// The source of rows for a query: an in-memory source or a database table.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum SourceExpr {
    /// Rows that are supplied separately and looked up by `source_id`.
    InMemory {
        /// Id under which the rows can be obtained from a [`SourceProvider`].
        source_id: SourceId,
        /// Header describing the rows.
        header: Arc<Header>,
        /// Table type reported by [`SourceExpr::table_type`].
        table_type: StTableType,
        /// Access level reported by [`SourceExpr::table_access`].
        table_access: StAccess,
    },
    /// A table in the database.
    DbTable(DbTable),
}
555
556impl SourceExpr {
557 pub fn source_id(&self) -> Option<SourceId> {
562 if let SourceExpr::InMemory { source_id, .. } = self {
563 Some(*source_id)
564 } else {
565 None
566 }
567 }
568
569 pub fn table_name(&self) -> &str {
570 &self.head().table_name
571 }
572
573 pub fn table_type(&self) -> StTableType {
574 match self {
575 SourceExpr::InMemory { table_type, .. } => *table_type,
576 SourceExpr::DbTable(db_table) => db_table.table_type,
577 }
578 }
579
580 pub fn table_access(&self) -> StAccess {
581 match self {
582 SourceExpr::InMemory { table_access, .. } => *table_access,
583 SourceExpr::DbTable(db_table) => db_table.table_access,
584 }
585 }
586
587 pub fn head(&self) -> &Arc<Header> {
588 match self {
589 SourceExpr::InMemory { header, .. } => header,
590 SourceExpr::DbTable(db_table) => &db_table.head,
591 }
592 }
593
594 pub fn is_mem_table(&self) -> bool {
595 matches!(self, SourceExpr::InMemory { .. })
596 }
597
598 pub fn is_db_table(&self) -> bool {
599 matches!(self, SourceExpr::DbTable(_))
600 }
601
602 pub fn from_mem_table(header: Arc<Header>, table_access: StAccess, id: SourceId) -> Self {
603 SourceExpr::InMemory {
604 source_id: id,
605 header,
606 table_type: StTableType::User,
607 table_access,
608 }
609 }
610
611 pub fn table_id(&self) -> Option<TableId> {
612 if let SourceExpr::DbTable(db_table) = self {
613 Some(db_table.table_id)
614 } else {
615 None
616 }
617 }
618
619 pub fn get_db_table(&self) -> Option<&DbTable> {
625 if let SourceExpr::DbTable(db_table) = self {
626 Some(db_table)
627 } else {
628 None
629 }
630 }
631}
632
633impl From<&TableSchema> for SourceExpr {
634 fn from(value: &TableSchema) -> Self {
635 SourceExpr::DbTable(value.into())
636 }
637}
638
/// A join between a probe-side query and an index-side source.
// NOTE(review): presumably executed by probing an index on `index_col`
// with each probe-side row — the execution itself is not visible here.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct IndexJoin {
    /// The query producing the probe-side rows.
    pub probe_side: QueryExpr,
    /// The join column on the probe side.
    pub probe_col: ColId,
    /// The source on the index side.
    pub index_side: SourceExpr,
    /// An optional filter applied to the index side.
    pub index_select: Option<ColumnOp>,
    /// The join column on the index side.
    pub index_col: ColId,
    /// If `true` the join's result rows come from the index side,
    /// otherwise from the probe side.
    pub return_index_rows: bool,
}
653
654impl From<IndexJoin> for QueryExpr {
655 fn from(join: IndexJoin) -> Self {
656 let source: SourceExpr = if join.return_index_rows {
657 join.index_side.clone()
658 } else {
659 join.probe_side.source.clone()
660 };
661 QueryExpr {
662 source,
663 query: vec![Query::IndexJoin(join)],
664 }
665 }
666}
667
668impl IndexJoin {
669 pub fn reorder(self, row_count: impl Fn(TableId, &str) -> i64) -> Self {
673 if self.probe_side.source.is_mem_table() {
675 return self;
676 }
677 if !self
679 .probe_side
680 .source
681 .head()
682 .has_constraint(self.probe_col, Constraints::indexed())
683 {
684 return self;
685 }
686 if !self
688 .probe_side
689 .query
690 .iter()
691 .all(|op| matches!(op, Query::Select(_) | Query::IndexScan(_)))
692 {
693 return self;
694 }
695 match self.index_side.get_db_table() {
696 Some(DbTable { head, table_id, .. }) if row_count(*table_id, &head.table_name) > 500 => self,
702 _ => {
705 let predicate = self
708 .probe_side
709 .query
710 .into_iter()
711 .filter_map(<Query as Into<Option<ColumnOp>>>::into)
712 .reduce(ColumnOp::and);
713 let probe_side = if let Some(predicate) = self.index_select {
715 QueryExpr {
716 source: self.index_side,
717 query: vec![predicate.into()],
718 }
719 } else {
720 self.index_side.into()
721 };
722 IndexJoin {
723 probe_side,
726 probe_col: self.index_col,
728 index_side: self.probe_side.source,
730 index_select: predicate,
732 index_col: self.probe_col,
734 return_index_rows: !self.return_index_rows,
737 }
738 }
739 }
740 }
741
742 pub fn to_inner_join(self) -> QueryExpr {
747 if self.return_index_rows {
748 let (col_lhs, col_rhs) = (self.index_col, self.probe_col);
749 let rhs = self.probe_side;
750
751 let source = self.index_side;
752 let inner_join = Query::JoinInner(JoinExpr::new(rhs, col_lhs, col_rhs, None));
753 let query = if let Some(predicate) = self.index_select {
754 vec![predicate.into(), inner_join]
755 } else {
756 vec![inner_join]
757 };
758 QueryExpr { source, query }
759 } else {
760 let (col_lhs, col_rhs) = (self.probe_col, self.index_col);
761 let mut rhs: QueryExpr = self.index_side.into();
762
763 if let Some(predicate) = self.index_select {
764 rhs.query.push(predicate.into());
765 }
766
767 let source = self.probe_side.source;
768 let inner_join = Query::JoinInner(JoinExpr::new(rhs, col_lhs, col_rhs, None));
769 let query = vec![inner_join];
770 QueryExpr { source, query }
771 }
772 }
773}
774
/// An inner join of the preceding query pipeline with the query `rhs`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct JoinExpr {
    /// The right-hand side of the join.
    pub rhs: QueryExpr,
    /// The join column on the left-hand side.
    pub col_lhs: ColId,
    /// The join column on the right-hand side.
    pub col_rhs: ColId,
    /// When `Some`, the header `QueryExpr::head` reports for this join;
    /// when `None`, the join does not affect the reported header.
    // NOTE(review): presumably `None` marks a semijoin that returns only lhs rows — confirm.
    pub inner: Option<Arc<Header>>,
}
786
787impl JoinExpr {
788 pub fn new(rhs: QueryExpr, col_lhs: ColId, col_rhs: ColId, inner: Option<Arc<Header>>) -> Self {
789 Self {
790 rhs,
791 col_lhs,
792 col_rhs,
793 inner,
794 }
795 }
796}
797
/// The kind of database object a `Create`/`Drop` operation targets (see [`Crud`]).
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum DbType {
    Table,
    Index,
    Sequence,
    Constraint,
}
805
/// The category of an operation.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Crud {
    Query,
    Insert,
    Update,
    Delete,
    /// Creation of a database object of the given kind.
    Create(DbType),
    /// Removal of a database object of the given kind.
    Drop(DbType),
    Config,
}
816
/// A statement: a query, a data modification, or a variable access.
#[derive(Debug, Eq, PartialEq)]
pub enum CrudExpr {
    /// A read-only query.
    Query(QueryExpr),
    /// Insertion of `rows` into `table`.
    Insert {
        table: DbTable,
        rows: Vec<ProductValue>,
    },
    /// An update described by a query and per-column assignments.
    // NOTE(review): presumably `delete` selects the rows to be replaced — confirm.
    Update {
        delete: QueryExpr,
        assignments: IntMap<ColId, ColExpr>,
    },
    /// Deletion of the rows selected by `query`.
    Delete {
        query: QueryExpr,
    },
    /// Assignment of `literal` to the variable `name`.
    SetVar {
        name: String,
        literal: String,
    },
    /// Read of the variable `name`.
    ReadVar {
        name: String,
    },
}
839
840impl CrudExpr {
841 pub fn optimize(self, row_count: &impl Fn(TableId, &str) -> i64) -> Self {
842 match self {
843 CrudExpr::Query(x) => CrudExpr::Query(x.optimize(row_count)),
844 _ => self,
845 }
846 }
847
848 pub fn is_reads<'a>(exprs: impl IntoIterator<Item = &'a CrudExpr>) -> bool {
849 exprs
850 .into_iter()
851 .all(|expr| matches!(expr, CrudExpr::Query(_) | CrudExpr::ReadVar { .. }))
852 }
853}
854
/// A scan of `table` restricted to the rows whose `columns` fall within `bounds`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct IndexScan {
    /// The table being scanned.
    pub table: DbTable,
    /// The columns the bounds apply to.
    pub columns: ColList,
    /// The `(lower, upper)` bounds on the value of `columns`.
    pub bounds: (Bound<AlgebraicValue>, Bound<AlgebraicValue>),
}
861
862impl IndexScan {
863 pub fn is_point(&self) -> bool {
865 match &self.bounds {
866 (Bound::Included(lower), Bound::Included(upper)) => lower == upper,
867 _ => false,
868 }
869 }
870}
871
/// A projection onto a list of column expressions.
#[derive(Debug, Clone, Eq, PartialEq, From, Hash)]
pub struct ProjectExpr {
    /// The expressions to project.
    pub cols: Vec<ColExpr>,
    // NOTE(review): presumably the table a `table.*` wildcard refers to — confirm.
    pub wildcard_table: Option<TableId>,
    /// The header of the rows after the projection.
    pub header_after: Arc<Header>,
}
881
/// A single operator in the pipeline of a [`QueryExpr`].
#[derive(Debug, Clone, Eq, PartialEq, From, Hash)]
pub enum Query {
    /// Scan a table through bounds on some of its columns.
    IndexScan(IndexScan),
    /// Join through an index, see [`IndexJoin`].
    IndexJoin(IndexJoin),
    /// Keep only the rows for which the predicate evaluates to `true`.
    Select(ColumnOp),
    /// Project onto a list of column expressions.
    Project(ProjectExpr),
    /// Inner join with another query, see [`JoinExpr`].
    JoinInner(JoinExpr),
}
901
902impl Query {
903 pub fn walk_sources<E>(&self, on_source: &mut impl FnMut(&SourceExpr) -> Result<(), E>) -> Result<(), E> {
907 match self {
908 Self::Select(..) | Self::Project(..) => Ok(()),
909 Self::IndexScan(scan) => on_source(&SourceExpr::DbTable(scan.table.clone())),
910 Self::IndexJoin(join) => join.probe_side.walk_sources(on_source),
911 Self::JoinInner(join) => join.rhs.walk_sources(on_source),
912 }
913 }
914}
915
/// A constraint on `columns` that can be answered through an index.
#[derive(Debug, PartialEq, Clone)]
enum IndexArgument<'a> {
    /// `columns = value`.
    Eq {
        columns: &'a ColList,
        value: AlgebraicValue,
    },
    /// `columns > value`, or `columns >= value` when `inclusive`.
    LowerBound {
        columns: &'a ColList,
        value: AlgebraicValue,
        inclusive: bool,
    },
    /// `columns < value`, or `columns <= value` when `inclusive`.
    UpperBound {
        columns: &'a ColList,
        value: AlgebraicValue,
        inclusive: bool,
    },
}
935
/// How a single condition of a filter is going to be answered.
#[derive(Debug, PartialEq, Clone)]
enum IndexColumnOp<'a> {
    /// Through an index, using the given argument.
    Index(IndexArgument<'a>),
    /// By evaluating the referenced expression as a filter.
    Scan(&'a ColumnOp),
}
941
942fn make_index_arg(cmp: OpCmp, columns: &ColList, value: AlgebraicValue) -> IndexColumnOp<'_> {
943 let arg = match cmp {
944 OpCmp::Eq => IndexArgument::Eq { columns, value },
945 OpCmp::NotEq => unreachable!("No IndexArgument for NotEq, caller should've filtered out"),
946 OpCmp::Lt => IndexArgument::UpperBound {
948 columns,
949 value,
950 inclusive: false,
951 },
952 OpCmp::Gt => IndexArgument::LowerBound {
954 columns,
955 value,
956 inclusive: false,
957 },
958 OpCmp::LtEq => IndexArgument::UpperBound {
960 columns,
961 value,
962 inclusive: true,
963 },
964 OpCmp::GtEq => IndexArgument::LowerBound {
966 columns,
967 value,
968 inclusive: true,
969 },
970 };
971 IndexColumnOp::Index(arg)
972}
973
/// A `col cmp value` condition extracted from a filter,
/// together with the expression it was extracted from.
#[derive(Debug)]
struct ColValue<'a> {
    /// The expression this condition was extracted from.
    parent: &'a ColumnOp,
    /// The column being constrained.
    col: ColId,
    /// The comparison operator.
    cmp: OpCmp,
    /// The value the column is compared against.
    value: &'a AlgebraicValue,
}
981
982impl<'a> ColValue<'a> {
983 pub fn new(parent: &'a ColumnOp, col: ColId, cmp: OpCmp, value: &'a AlgebraicValue) -> Self {
984 Self {
985 parent,
986 col,
987 cmp,
988 value,
989 }
990 }
991}
992
/// The list of operations produced by `select_best_index`; stores one element inline.
type IndexColumnOpSink<'a> = SmallVec<[IndexColumnOp<'a>; 1]>;
/// The set of `(column, operator)` pairs that have been assigned to an index.
type ColsIndexed = HashSet<(ColId, OpCmp)>;
995
996fn select_best_index<'a>(
1045 cols_indexed: &mut ColsIndexed,
1046 header: &'a Header,
1047 op: &'a ColumnOp,
1048) -> IndexColumnOpSink<'a> {
1049 let mut indices = header
1053 .constraints
1054 .iter()
1055 .filter(|(_, c)| c.has_indexed())
1056 .map(|(cl, _)| cl)
1057 .collect::<SmallVec<[_; 1]>>();
1058 indices.sort_unstable_by_key(|cl| Reverse(cl.len()));
1059
1060 let mut found: IndexColumnOpSink = IndexColumnOpSink::default();
1061
1062 let mut col_map = BTreeMap::<_, SmallVec<[_; 1]>>::new();
1066 extract_cols(op, &mut col_map, &mut found);
1067
1068 for col_list in indices {
1071 if col_map.is_empty() {
1073 break;
1074 }
1075
1076 if let Some(head) = col_list.as_singleton() {
1077 for cmp in [OpCmp::Eq, OpCmp::Lt, OpCmp::LtEq, OpCmp::Gt, OpCmp::GtEq] {
1081 for ColValue { cmp, value, col, .. } in col_map.remove(&(head, cmp)).into_iter().flatten() {
1084 found.push(make_index_arg(cmp, col_list, value.clone()));
1085 cols_indexed.insert((col, cmp));
1086 }
1087 }
1088 } else {
1089 let cmp = OpCmp::Eq;
1097
1098 let mut min_all_cols_num_eq = col_list
1100 .iter()
1101 .map(|col| col_map.get(&(col, cmp)).map_or(0, |fs| fs.len()))
1102 .min()
1103 .unwrap_or_default();
1104
1105 while min_all_cols_num_eq > 0 {
1108 let mut elems = Vec::with_capacity(col_list.len() as usize);
1109 for col in col_list.iter() {
1110 let col_val = pop_multimap(&mut col_map, (col, cmp)).unwrap();
1112 cols_indexed.insert((col_val.col, cmp));
1113 elems.push(col_val.value.clone());
1115 }
1116 let value = AlgebraicValue::product(elems);
1118 found.push(make_index_arg(cmp, col_list, value));
1119 min_all_cols_num_eq -= 1;
1120 }
1121 }
1122 }
1123
1124 found.extend(
1126 col_map
1127 .into_iter()
1128 .flat_map(|(_, fs)| fs)
1129 .map(|f| IndexColumnOp::Scan(f.parent)),
1130 );
1131
1132 found
1133}
1134
1135fn pop_multimap<K: Ord, V, const N: usize>(map: &mut BTreeMap<K, SmallVec<[V; N]>>, key: K) -> Option<V> {
1138 let Entry::Occupied(mut entry) = map.entry(key) else {
1139 return None;
1140 };
1141 let fields = entry.get_mut();
1142 let val = fields.pop();
1143 if fields.is_empty() {
1144 entry.remove();
1145 }
1146 val
1147}
1148
1149fn extract_cols<'a>(
1154 op: &'a ColumnOp,
1155 col_map: &mut BTreeMap<(ColId, OpCmp), SmallVec<[ColValue<'a>; 1]>>,
1156 found: &mut IndexColumnOpSink<'a>,
1157) {
1158 let mut add_field = |parent, op, col, val| {
1159 let fv = ColValue::new(parent, col, op, val);
1160 col_map.entry((col, op)).or_default().push(fv);
1161 };
1162
1163 match op {
1164 ColumnOp::Cmp { cmp, lhs, rhs } => {
1165 if let (ColumnOp::Col(col), ColumnOp::Val(val)) = (&**lhs, &**rhs) {
1166 add_field(op, *cmp, *col, val);
1168 }
1169 }
1170 ColumnOp::ColCmpVal { lhs, cmp, rhs } => add_field(op, *cmp, *lhs, rhs),
1171 ColumnOp::Log {
1172 op: OpLogic::And,
1173 operands,
1174 } => {
1175 for oper in operands.iter() {
1176 extract_cols(oper, col_map, found);
1177 }
1178 }
1179 ColumnOp::Log { op: OpLogic::Or, .. } | ColumnOp::Col(_) | ColumnOp::Val(_) => {
1180 found.push(IndexColumnOp::Scan(op));
1181 }
1182 }
1183}
1184
/// A query: a source of rows plus a pipeline of operators applied to it in order.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QueryExpr {
    /// Where the rows come from.
    pub source: SourceExpr,
    /// The operators applied to the rows, first to last.
    pub query: Vec<Query>,
}
1196
1197impl From<SourceExpr> for QueryExpr {
1198 fn from(source: SourceExpr) -> Self {
1199 QueryExpr { source, query: vec![] }
1200 }
1201}
1202
1203impl QueryExpr {
1204 pub fn new<T: Into<SourceExpr>>(source: T) -> Self {
1205 Self {
1206 source: source.into(),
1207 query: vec![],
1208 }
1209 }
1210
1211 pub fn walk_sources<E>(&self, on_source: &mut impl FnMut(&SourceExpr) -> Result<(), E>) -> Result<(), E> {
1215 on_source(&self.source)?;
1216 self.query.iter().try_for_each(|q| q.walk_sources(on_source))
1217 }
1218
1219 pub fn head(&self) -> &Arc<Header> {
1227 self.query
1228 .iter()
1229 .rev()
1230 .find_map(|op| match op {
1231 Query::Select(_) => None,
1232 Query::IndexScan(scan) => Some(&scan.table.head),
1233 Query::IndexJoin(join) if join.return_index_rows => Some(join.index_side.head()),
1234 Query::IndexJoin(join) => Some(join.probe_side.head()),
1235 Query::Project(proj) => Some(&proj.header_after),
1236 Query::JoinInner(join) => join.inner.as_ref(),
1237 })
1238 .unwrap_or_else(|| self.source.head())
1239 }
1240
1241 pub fn reads_from_table(&self, id: &TableId) -> bool {
1243 self.source.table_id() == Some(*id)
1244 || self.query.iter().any(|q| match q {
1245 Query::Select(_) | Query::Project(..) => false,
1246 Query::IndexScan(scan) => scan.table.table_id == *id,
1247 Query::JoinInner(join) => join.rhs.reads_from_table(id),
1248 Query::IndexJoin(join) => {
1249 join.index_side.table_id() == Some(*id) || join.probe_side.reads_from_table(id)
1250 }
1251 })
1252 }
1253
1254 pub fn with_index_eq(mut self, table: DbTable, columns: ColList, value: AlgebraicValue) -> Self {
1258 let point = |v: AlgebraicValue| (Bound::Included(v.clone()), Bound::Included(v));
1259
1260 let Some(query) = self.query.pop() else {
1262 let bounds = point(value);
1263 self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
1264 return self;
1265 };
1266 match query {
1267 Query::JoinInner(JoinExpr {
1269 rhs:
1270 QueryExpr {
1271 source: SourceExpr::DbTable(ref db_table),
1272 ..
1273 },
1274 ..
1275 }) if table.table_id != db_table.table_id => {
1276 self = self.with_index_eq(db_table.clone(), columns, value);
1277 self.query.push(query);
1278 self
1279 }
1280 Query::JoinInner(JoinExpr {
1282 rhs,
1283 col_lhs,
1284 col_rhs,
1285 inner: semi,
1286 }) => {
1287 self.query.push(Query::JoinInner(JoinExpr {
1288 rhs: rhs.with_index_eq(table, columns, value),
1289 col_lhs,
1290 col_rhs,
1291 inner: semi,
1292 }));
1293 self
1294 }
1295 Query::Select(filter) => {
1297 let op = ColumnOp::and_cmp(OpCmp::Eq, &columns, value);
1298 self.query.push(Query::Select(ColumnOp::and(filter, op)));
1299 self
1300 }
1301 query => {
1303 self.query.push(query);
1304 let op = ColumnOp::and_cmp(OpCmp::Eq, &columns, value);
1305 self.query.push(Query::Select(op));
1306 self
1307 }
1308 }
1309 }
1310
1311 pub fn with_index_lower_bound(
1315 mut self,
1316 table: DbTable,
1317 columns: ColList,
1318 value: AlgebraicValue,
1319 inclusive: bool,
1320 ) -> Self {
1321 let Some(query) = self.query.pop() else {
1323 let bounds = (Self::bound(value, inclusive), Bound::Unbounded);
1324 self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
1325 return self;
1326 };
1327 match query {
1328 Query::JoinInner(JoinExpr {
1330 rhs:
1331 QueryExpr {
1332 source: SourceExpr::DbTable(ref db_table),
1333 ..
1334 },
1335 ..
1336 }) if table.table_id != db_table.table_id => {
1337 self = self.with_index_lower_bound(table, columns, value, inclusive);
1338 self.query.push(query);
1339 self
1340 }
1341 Query::JoinInner(JoinExpr {
1343 rhs,
1344 col_lhs,
1345 col_rhs,
1346 inner: semi,
1347 }) => {
1348 self.query.push(Query::JoinInner(JoinExpr {
1349 rhs: rhs.with_index_lower_bound(table, columns, value, inclusive),
1350 col_lhs,
1351 col_rhs,
1352 inner: semi,
1353 }));
1354 self
1355 }
1356 Query::IndexScan(IndexScan {
1358 columns: lhs_col_id,
1359 bounds: (Bound::Unbounded, Bound::Included(upper)),
1360 ..
1361 }) if columns == lhs_col_id => {
1362 let bounds = (Self::bound(value, inclusive), Bound::Included(upper));
1363 self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
1364 self
1365 }
1366 Query::IndexScan(IndexScan {
1368 columns: lhs_col_id,
1369 bounds: (Bound::Unbounded, Bound::Excluded(upper)),
1370 ..
1371 }) if columns == lhs_col_id => {
1372 let is_never = !inclusive && value == upper;
1382
1383 let bounds = (Self::bound(value, inclusive), Bound::Excluded(upper));
1384 self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
1385
1386 if is_never {
1387 log::warn!("Query will select no rows due to equal excluded bounds: {self:?}")
1388 }
1389
1390 self
1391 }
1392 Query::Select(filter) => {
1394 let bounds = (Self::bound(value, inclusive), Bound::Unbounded);
1395 let op = ColumnOp::from_op_col_bounds(&columns, bounds);
1396 self.query.push(Query::Select(ColumnOp::and(filter, op)));
1397 self
1398 }
1399 query => {
1401 self.query.push(query);
1402 let bounds = (Self::bound(value, inclusive), Bound::Unbounded);
1403 let op = ColumnOp::from_op_col_bounds(&columns, bounds);
1404 self.query.push(Query::Select(op));
1405 self
1406 }
1407 }
1408 }
1409
1410 pub fn with_index_upper_bound(
1414 mut self,
1415 table: DbTable,
1416 columns: ColList,
1417 value: AlgebraicValue,
1418 inclusive: bool,
1419 ) -> Self {
1420 let Some(query) = self.query.pop() else {
1422 self.query.push(Query::IndexScan(IndexScan {
1423 table,
1424 columns,
1425 bounds: (Bound::Unbounded, Self::bound(value, inclusive)),
1426 }));
1427 return self;
1428 };
1429 match query {
1430 Query::JoinInner(JoinExpr {
1432 rhs:
1433 QueryExpr {
1434 source: SourceExpr::DbTable(ref db_table),
1435 ..
1436 },
1437 ..
1438 }) if table.table_id != db_table.table_id => {
1439 self = self.with_index_upper_bound(table, columns, value, inclusive);
1440 self.query.push(query);
1441 self
1442 }
1443 Query::JoinInner(JoinExpr {
1445 rhs,
1446 col_lhs,
1447 col_rhs,
1448 inner: semi,
1449 }) => {
1450 self.query.push(Query::JoinInner(JoinExpr {
1451 rhs: rhs.with_index_upper_bound(table, columns, value, inclusive),
1452 col_lhs,
1453 col_rhs,
1454 inner: semi,
1455 }));
1456 self
1457 }
1458 Query::IndexScan(IndexScan {
1460 columns: lhs_col_id,
1461 bounds: (Bound::Included(lower), Bound::Unbounded),
1462 ..
1463 }) if columns == lhs_col_id => {
1464 let bounds = (Bound::Included(lower), Self::bound(value, inclusive));
1465 self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
1466 self
1467 }
1468 Query::IndexScan(IndexScan {
1470 columns: lhs_col_id,
1471 bounds: (Bound::Excluded(lower), Bound::Unbounded),
1472 ..
1473 }) if columns == lhs_col_id => {
1474 let is_never = !inclusive && value == lower;
1484
1485 let bounds = (Bound::Excluded(lower), Self::bound(value, inclusive));
1486 self.query.push(Query::IndexScan(IndexScan { table, columns, bounds }));
1487
1488 if is_never {
1489 log::warn!("Query will select no rows due to equal excluded bounds: {self:?}")
1490 }
1491
1492 self
1493 }
1494 Query::Select(filter) => {
1496 let bounds = (Bound::Unbounded, Self::bound(value, inclusive));
1497 let op = ColumnOp::from_op_col_bounds(&columns, bounds);
1498 self.query.push(Query::Select(ColumnOp::and(filter, op)));
1499 self
1500 }
1501 query => {
1503 self.query.push(query);
1504 let bounds = (Bound::Unbounded, Self::bound(value, inclusive));
1505 let op = ColumnOp::from_op_col_bounds(&columns, bounds);
1506 self.query.push(Query::Select(op));
1507 self
1508 }
1509 }
1510 }
1511
    /// Adds the filter `op` to the query.
    ///
    /// * With an empty pipeline, a base select is added.
    /// * If the last operator is an inner join and `op` is a comparison
    ///   `field cmp value`, the filter is pushed into the side of the join
    ///   whose header contains `field` (the left side is tried first).
    ///   Any other comparison is added as a select after the join.
    /// * If the last operator is a select, `op` is `AND`ed onto it.
    /// * Otherwise a base select is appended.
    ///
    /// # Errors
    ///
    /// Returns the error of `check_field_op` / `check_field_op_logics`
    /// when `op` does not pass those checks.
    pub fn with_select<O>(mut self, op: O) -> Result<Self, RelationError>
    where
        O: Into<FieldOp>,
    {
        let op = op.into();
        // Nothing in the pipeline yet: add a base select.
        let Some(query) = self.query.pop() else {
            return self.add_base_select(op);
        };

        match (query, op) {
            (
                Query::JoinInner(JoinExpr {
                    rhs,
                    col_lhs,
                    col_rhs,
                    inner,
                }),
                FieldOp::Cmp {
                    op: OpQuery::Cmp(cmp),
                    lhs: field,
                    rhs: value,
                },
            ) => match (*field, *value) {
                // The join has been popped, so `self.head()` is the header of the
                // join's lhs: the field belongs to the lhs, push the filter there.
                (FieldOp::Field(FieldExpr::Name(field)), FieldOp::Field(FieldExpr::Value(value)))
                    if self.head().column_pos(field).is_some() =>
                {
                    self = self.with_select(FieldOp::cmp(field, cmp, value))?;
                    // Re-attach the join after the newly added filter.
                    self.query.push(Query::JoinInner(JoinExpr { rhs, col_lhs, col_rhs, inner }));
                    Ok(self)
                }
                // The field belongs to the join's rhs: push the filter there.
                (FieldOp::Field(FieldExpr::Name(field)), FieldOp::Field(FieldExpr::Value(value)))
                    if rhs.head().column_pos(field).is_some() =>
                {
                    let rhs = rhs.with_select(FieldOp::cmp(field, cmp, value))?;
                    self.query.push(Query::JoinInner(JoinExpr {
                        rhs,
                        col_lhs,
                        col_rhs,
                        inner,
                    }));
                    Ok(self)
                }
                (field, value) => {
                    // Re-attach the join first, so that `self.head()` below
                    // takes the join into account.
                    self.query.push(Query::JoinInner(JoinExpr { rhs, col_lhs, col_rhs, inner, }));

                    // Check both operands before converting them.
                    self.check_field_op_logics(&field)?;
                    self.check_field_op_logics(&value)?;
                    // NOTE(review): the `unwrap`s presumably cannot fail after the
                    // checks above — confirm against `check_field_op_logics`.
                    let col = field.names_to_cols(self.head()).unwrap();
                    let value = value.names_to_cols(self.head()).unwrap();
                    // Add the `col cmp value` filter after the join.
                    self.query.push(Query::Select(ColumnOp::new(OpQuery::Cmp(cmp), col, value)));
                    Ok(self)
                }
            },
            // There is a previous filter `lhs`: combine into `lhs AND rhs`.
            (Query::Select(lhs), rhs) => {
                self.check_field_op(&rhs)?;
                // NOTE(review): presumably cannot fail after `check_field_op` — confirm.
                let rhs = rhs.names_to_cols(self.head()).unwrap();
                self.query.push(Query::Select(ColumnOp::and(lhs, rhs)));
                Ok(self)
            }
            // No previous filter: restore the operator and add a base select.
            (query, op) => {
                self.query.push(query);
                self.add_base_select(op)
            }
        }
    }
1592
1593 fn add_base_select(mut self, op: FieldOp) -> Result<Self, RelationError> {
1596 self.check_field_op(&op)?;
1598 let op = op.names_to_cols(self.head()).unwrap();
1600 self.query.push(Query::Select(op));
1602 Ok(self)
1603 }
1604
1605 fn check_field_op(&self, op: &FieldOp) -> Result<(), RelationError> {
1608 use OpQuery::*;
1609 match op {
1610 FieldOp::Cmp { op: Logic(_), lhs, rhs } => {
1612 self.check_field_op(lhs)?;
1613 self.check_field_op(rhs)?;
1614 Ok(())
1615 }
1616 FieldOp::Cmp { op: Cmp(_), lhs, rhs } => {
1622 self.check_field_op_logics(lhs)?;
1623 self.check_field_op_logics(rhs)?;
1624 Ok(())
1625 }
1626 FieldOp::Field(FieldExpr::Value(AlgebraicValue::Bool(_))) => Ok(()),
1627 FieldOp::Field(FieldExpr::Value(v)) => Err(RelationError::NotBoolValue { val: v.clone() }),
1628 FieldOp::Field(FieldExpr::Name(field)) => {
1629 let field = *field;
1630 let head = self.head();
1631 let col_id = head.column_pos_or_err(field)?;
1632 let col_ty = &head.fields[col_id.idx()].algebraic_type;
1633 match col_ty {
1634 &AlgebraicType::Bool => Ok(()),
1635 ty => Err(RelationError::NotBoolType { field, ty: ty.clone() }),
1636 }
1637 }
1638 }
1639 }
1640
1641 fn check_field_op_logics(&self, op: &FieldOp) -> Result<(), RelationError> {
1643 use OpQuery::*;
1644 match op {
1645 FieldOp::Field(_) => Ok(()),
1646 FieldOp::Cmp { op: Cmp(_), lhs, rhs } => {
1647 self.check_field_op_logics(lhs)?;
1648 self.check_field_op_logics(rhs)?;
1649 Ok(())
1650 }
1651 FieldOp::Cmp { op: Logic(_), lhs, rhs } => {
1652 self.check_field_op(lhs)?;
1653 self.check_field_op(rhs)?;
1654 Ok(())
1655 }
1656 }
1657 }
1658
1659 pub fn with_select_cmp<LHS, RHS, O>(self, op: O, lhs: LHS, rhs: RHS) -> Result<Self, RelationError>
1660 where
1661 LHS: Into<FieldExpr>,
1662 RHS: Into<FieldExpr>,
1663 O: Into<OpQuery>,
1664 {
1665 let op = FieldOp::new(op.into(), FieldOp::Field(lhs.into()), FieldOp::Field(rhs.into()));
1666 self.with_select(op)
1667 }
1668
1669 pub fn with_project(
1673 mut self,
1674 fields: Vec<FieldExpr>,
1675 wildcard_table: Option<TableId>,
1676 ) -> Result<Self, RelationError> {
1677 if !fields.is_empty() {
1678 let header_before = self.head();
1679
1680 let mut cols = Vec::with_capacity(fields.len());
1682 for field in fields {
1683 cols.push(field.name_to_col(header_before)?);
1684 }
1685
1686 let header_after = Arc::new(header_before.project(&cols)?);
1689
1690 self.query.push(Query::Project(ProjectExpr {
1692 cols,
1693 wildcard_table,
1694 header_after,
1695 }));
1696 }
1697 Ok(self)
1698 }
1699
1700 pub fn with_join_inner_raw(
1701 mut self,
1702 q_rhs: QueryExpr,
1703 c_lhs: ColId,
1704 c_rhs: ColId,
1705 inner: Option<Arc<Header>>,
1706 ) -> Self {
1707 self.query
1708 .push(Query::JoinInner(JoinExpr::new(q_rhs, c_lhs, c_rhs, inner)));
1709 self
1710 }
1711
1712 pub fn with_join_inner(self, q_rhs: impl Into<QueryExpr>, c_lhs: ColId, c_rhs: ColId, semi: bool) -> Self {
1713 let q_rhs = q_rhs.into();
1714 let inner = (!semi).then(|| Arc::new(self.head().extend(q_rhs.head())));
1715 self.with_join_inner_raw(q_rhs, c_lhs, c_rhs, inner)
1716 }
1717
1718 fn bound(value: AlgebraicValue, inclusive: bool) -> Bound<AlgebraicValue> {
1719 if inclusive {
1720 Bound::Included(value)
1721 } else {
1722 Bound::Excluded(value)
1723 }
1724 }
1725
1726 pub fn try_semi_join(self) -> QueryExpr {
1762 let QueryExpr { source, query } = self;
1763
1764 let Some(source_table_id) = source.table_id() else {
1765 return QueryExpr { source, query };
1767 };
1768
1769 let mut exprs = query.into_iter();
1770 let Some(join_candidate) = exprs.next() else {
1771 return QueryExpr { source, query: vec![] };
1773 };
1774 let Query::JoinInner(join) = join_candidate else {
1775 return QueryExpr {
1777 source,
1778 query: itertools::chain![Some(join_candidate), exprs].collect(),
1779 };
1780 };
1781
1782 let Some(project_candidate) = exprs.next() else {
1783 return QueryExpr {
1785 source,
1786 query: vec![Query::JoinInner(join)],
1787 };
1788 };
1789
1790 let Query::Project(proj) = project_candidate else {
1791 return QueryExpr {
1793 source,
1794 query: itertools::chain![Some(Query::JoinInner(join)), Some(project_candidate), exprs].collect(),
1795 };
1796 };
1797
1798 if proj.wildcard_table != Some(source_table_id) {
1799 return QueryExpr {
1801 source,
1802 query: itertools::chain![Some(Query::JoinInner(join)), Some(Query::Project(proj)), exprs].collect(),
1803 };
1804 };
1805
1806 let semijoin = JoinExpr { inner: None, ..join };
1808
1809 QueryExpr {
1810 source,
1811 query: itertools::chain![Some(Query::JoinInner(semijoin)), exprs].collect(),
1812 }
1813 }
1814
1815 fn try_index_join(self) -> QueryExpr {
1822 let mut query = self;
1823 if query.query.len() != 1 {
1826 return query;
1827 }
1828
1829 if query.source.is_mem_table() {
1832 return query;
1833 }
1834 let source = query.source;
1835 let join = query.query.pop().unwrap();
1836
1837 match join {
1838 Query::JoinInner(join @ JoinExpr { inner: None, .. }) => {
1839 if !join.rhs.query.is_empty() {
1840 if source.head().has_constraint(join.col_lhs, Constraints::indexed()) {
1842 let index_join = IndexJoin {
1843 probe_side: join.rhs,
1844 probe_col: join.col_rhs,
1845 index_side: source.clone(),
1846 index_select: None,
1847 index_col: join.col_lhs,
1848 return_index_rows: true,
1849 };
1850 let query = [Query::IndexJoin(index_join)].into();
1851 return QueryExpr { source, query };
1852 }
1853 }
1854 QueryExpr {
1855 source,
1856 query: vec![Query::JoinInner(join)],
1857 }
1858 }
1859 first => QueryExpr {
1860 source,
1861 query: vec![first],
1862 },
1863 }
1864 }
1865
1866 fn optimize_select(mut q: QueryExpr, op: ColumnOp, tables: &[SourceExpr]) -> QueryExpr {
1868 let mut fields_found = HashSet::default();
1871 for schema in tables {
1872 for op in select_best_index(&mut fields_found, schema.head(), &op) {
1873 if let IndexColumnOp::Scan(op) = &op {
1874 if op.as_col_cmp().is_some_and(|cc| !fields_found.insert(cc)) {
1877 continue;
1878 }
1879 }
1880
1881 match op {
1882 IndexColumnOp::Index(idx) => {
1885 let table = schema
1886 .get_db_table()
1887 .expect("find_sargable_ops(schema, op) implies `schema.is_db_table()`")
1888 .clone();
1889
1890 q = match idx {
1891 IndexArgument::Eq { columns, value } => q.with_index_eq(table, columns.clone(), value),
1892 IndexArgument::LowerBound {
1893 columns,
1894 value,
1895 inclusive,
1896 } => q.with_index_lower_bound(table, columns.clone(), value, inclusive),
1897 IndexArgument::UpperBound {
1898 columns,
1899 value,
1900 inclusive,
1901 } => q.with_index_upper_bound(table, columns.clone(), value, inclusive),
1902 };
1903 }
1904 IndexColumnOp::Scan(rhs) => {
1906 let rhs = rhs.clone();
1907 let op = match q.query.pop() {
1908 Some(Query::Select(lhs)) => ColumnOp::and(lhs, rhs),
1910 None => rhs,
1911 Some(other) => {
1912 q.query.push(other);
1913 rhs
1914 }
1915 };
1916 q.query.push(Query::Select(op));
1917 }
1918 }
1919 }
1920 }
1921
1922 q
1923 }
1924
1925 pub fn optimize(mut self, row_count: &impl Fn(TableId, &str) -> i64) -> Self {
1926 let mut q = Self {
1927 source: self.source.clone(),
1928 query: Vec::with_capacity(self.query.len()),
1929 };
1930
1931 if matches!(&*self.query, [Query::IndexJoin(_)]) {
1932 if let Some(Query::IndexJoin(join)) = self.query.pop() {
1933 q.query.push(Query::IndexJoin(join.reorder(row_count)));
1934 return q;
1935 }
1936 }
1937
1938 for query in self.query {
1939 match query {
1940 Query::Select(op) => {
1941 q = Self::optimize_select(q, op, from_ref(&self.source));
1942 }
1943 Query::JoinInner(join) => {
1944 q = q.with_join_inner_raw(join.rhs.optimize(row_count), join.col_lhs, join.col_rhs, join.inner);
1945 }
1946 _ => q.query.push(query),
1947 };
1948 }
1949
1950 let q = q.try_semi_join();
1952 let q = q.try_index_join();
1953 if matches!(&*q.query, [Query::IndexJoin(_)]) {
1954 return q.optimize(row_count);
1955 }
1956 q
1957 }
1958}
1959
1960impl AuthAccess for Query {
1961 fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> {
1962 if owner == caller {
1963 return Ok(());
1964 }
1965
1966 self.walk_sources(&mut |s| s.check_auth(owner, caller))
1967 }
1968}
1969
/// An expression: a literal value, a block of expressions, an identifier,
/// a CRUD operation, or a halt carrying an error.
#[derive(Debug, Eq, PartialEq, From)]
pub enum Expr {
    /// A literal value. The `#[from]` attribute asks the `From` derive for a conversion into this variant.
    #[from]
    Value(AlgebraicValue),
    /// A sequence of expressions.
    Block(Vec<Expr>),
    /// An identifier, stored by name.
    Ident(String),
    /// A boxed CRUD expression.
    Crud(Box<CrudExpr>),
    /// A halt carrying the error that caused it.
    Halt(ErrorLang),
}
1979
1980impl From<QueryExpr> for Expr {
1981 fn from(x: QueryExpr) -> Self {
1982 Expr::Crud(Box::new(CrudExpr::Query(x)))
1983 }
1984}
1985
1986impl fmt::Display for Query {
1987 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1988 match self {
1989 Query::IndexScan(op) => {
1990 write!(f, "index_scan {op:?}")
1991 }
1992 Query::IndexJoin(op) => {
1993 write!(f, "index_join {op:?}")
1994 }
1995 Query::Select(q) => {
1996 write!(f, "select {q}")
1997 }
1998 Query::Project(proj) => {
1999 let q = &proj.cols;
2000 write!(f, "project")?;
2001 if !q.is_empty() {
2002 write!(f, " ")?;
2003 }
2004 for (pos, x) in q.iter().enumerate() {
2005 write!(f, "{x}")?;
2006 if pos + 1 < q.len() {
2007 write!(f, ", ")?;
2008 }
2009 }
2010 Ok(())
2011 }
2012 Query::JoinInner(q) => {
2013 write!(f, "&inner {:?} ON {} = {}", q.rhs, q.col_lhs, q.col_rhs)
2014 }
2015 }
2016 }
2017}
2018
2019impl AuthAccess for SourceExpr {
2020 fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> {
2021 if owner == caller || self.table_access() == StAccess::Public {
2022 return Ok(());
2023 }
2024
2025 Err(AuthError::TablePrivate {
2026 named: self.table_name().to_string(),
2027 })
2028 }
2029}
2030
2031impl AuthAccess for QueryExpr {
2032 fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> {
2033 if owner == caller {
2034 return Ok(());
2035 }
2036 self.walk_sources(&mut |s| s.check_auth(owner, caller))
2037 }
2038}
2039
2040impl AuthAccess for CrudExpr {
2041 fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> {
2042 if owner == caller {
2043 return Ok(());
2044 }
2045 if let CrudExpr::Query(q) = self {
2047 return q.check_auth(owner, caller);
2048 }
2049
2050 Err(AuthError::OwnerRequired)
2052 }
2053}
2054
/// Rows inserted into and deleted from one table, which is identified by both id and name.
#[derive(Debug, PartialEq)]
pub struct Update {
    /// Id of the table the rows belong to.
    pub table_id: TableId,
    /// Name of the table the rows belong to.
    pub table_name: Box<str>,
    /// The inserted rows.
    pub inserts: Vec<ProductValue>,
    /// The deleted rows.
    pub deletes: Vec<ProductValue>,
}
2062
/// A unit of code: a value, a table, a halt, a block of code, a CRUD operation, or a pass.
#[derive(Debug, PartialEq)]
pub enum Code {
    /// A single value.
    Value(AlgebraicValue),
    /// An in-memory table.
    Table(MemTable),
    /// A halt carrying the error that caused it.
    Halt(ErrorLang),
    /// A sequence of code.
    Block(Vec<Code>),
    /// A CRUD operation.
    Crud(CrudExpr),
    /// A pass, optionally carrying an [`Update`].
    Pass(Option<Update>),
}
2072
2073impl fmt::Display for Code {
2074 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2075 match self {
2076 Code::Value(x) => {
2077 write!(f, "{:?}", &x)
2078 }
2079 Code::Block(_) => write!(f, "Block"),
2080 x => todo!("{:?}", x),
2081 }
2082 }
2083}
2084
/// The result form of [`Code`]: it has the same variants except `Crud`.
#[derive(Debug, PartialEq)]
pub enum CodeResult {
    /// A single value.
    Value(AlgebraicValue),
    /// An in-memory table.
    Table(MemTable),
    /// A sequence of results.
    Block(Vec<CodeResult>),
    /// A halt carrying the error that caused it.
    Halt(ErrorLang),
    /// A pass, optionally carrying an [`Update`].
    Pass(Option<Update>),
}
2093
2094impl From<Code> for CodeResult {
2095 fn from(code: Code) -> Self {
2096 match code {
2097 Code::Value(x) => Self::Value(x),
2098 Code::Table(x) => Self::Table(x),
2099 Code::Halt(x) => Self::Halt(x),
2100 Code::Block(x) => {
2101 if x.is_empty() {
2102 Self::Pass(None)
2103 } else {
2104 Self::Block(x.into_iter().map(CodeResult::from).collect())
2105 }
2106 }
2107 Code::Pass(x) => Self::Pass(x),
2108 x => Self::Halt(ErrorLang::new(
2109 ErrorKind::Compiler,
2110 Some(&format!("Invalid result: {x}")),
2111 )),
2112 }
2113 }
2114}
2115
2116#[cfg(test)]
2117mod tests {
2118 use super::*;
2119
2120 use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9Builder;
2121 use spacetimedb_sats::{product, AlgebraicType, ProductType};
2122 use spacetimedb_schema::{def::ModuleDef, relation::Column, schema::Schema};
2123 use typed_arena::Arena;
2124
/// Test identity built from 32 bytes of `1`; the auth tests use it as the owner.
const ALICE: Identity = Identity::from_byte_array([1; 32]);
/// Test identity built from 32 bytes of `2`; the auth tests use it as a caller who is not the owner.
const BOB: Identity = Identity::from_byte_array([2; 32]);
2127
2128 fn tables() -> [SourceExpr; 2] {
2132 [
2133 SourceExpr::InMemory {
2134 source_id: SourceId(0),
2135 header: Arc::new(Header {
2136 table_id: 42.into(),
2137 table_name: "foo".into(),
2138 fields: vec![],
2139 constraints: Default::default(),
2140 }),
2141 table_type: StTableType::User,
2142 table_access: StAccess::Private,
2143 },
2144 SourceExpr::DbTable(DbTable {
2145 head: Arc::new(Header {
2146 table_id: 42.into(),
2147 table_name: "foo".into(),
2148 fields: vec![],
2149 constraints: [(ColId(42).into(), Constraints::indexed())].into_iter().collect(),
2150 }),
2151 table_id: 42.into(),
2152 table_type: StTableType::User,
2153 table_access: StAccess::Private,
2154 }),
2155 ]
2156 }
2157
2158 fn queries() -> impl IntoIterator<Item = Query> {
2159 let [mem_table, db_table] = tables();
2160 [
2163 Query::IndexScan(IndexScan {
2164 table: db_table.get_db_table().unwrap().clone(),
2165 columns: ColList::new(42.into()),
2166 bounds: (Bound::Included(22.into()), Bound::Unbounded),
2167 }),
2168 Query::IndexJoin(IndexJoin {
2169 probe_side: mem_table.clone().into(),
2170 probe_col: 0.into(),
2171 index_side: SourceExpr::DbTable(DbTable {
2172 head: Arc::new(Header {
2173 table_id: db_table.head().table_id,
2174 table_name: db_table.table_name().into(),
2175 fields: vec![],
2176 constraints: Default::default(),
2177 }),
2178 table_id: db_table.head().table_id,
2179 table_type: StTableType::User,
2180 table_access: StAccess::Public,
2181 }),
2182 index_select: None,
2183 index_col: 22.into(),
2184 return_index_rows: true,
2185 }),
2186 Query::JoinInner(JoinExpr {
2187 col_rhs: 1.into(),
2188 rhs: mem_table.into(),
2189 col_lhs: 1.into(),
2190 inner: None,
2191 }),
2192 ]
2193 }
2194
2195 fn query_exprs() -> impl IntoIterator<Item = QueryExpr> {
2196 tables().map(|table| {
2197 let mut expr = QueryExpr::from(table);
2198 expr.query = queries().into_iter().collect();
2199 expr
2200 })
2201 }
2202
2203 fn assert_owner_private<T: AuthAccess>(auth: &T) {
2204 assert!(auth.check_auth(ALICE, ALICE).is_ok());
2205 assert!(matches!(
2206 auth.check_auth(ALICE, BOB),
2207 Err(AuthError::TablePrivate { .. })
2208 ));
2209 }
2210
2211 fn assert_owner_required<T: AuthAccess>(auth: T) {
2212 assert!(auth.check_auth(ALICE, ALICE).is_ok());
2213 assert!(matches!(auth.check_auth(ALICE, BOB), Err(AuthError::OwnerRequired)));
2214 }
2215
2216 fn mem_table(id: TableId, name: &str, fields: &[(u16, AlgebraicType, bool)]) -> SourceExpr {
2217 let table_access = StAccess::Public;
2218 let head = Header::new(
2219 id,
2220 name.into(),
2221 fields
2222 .iter()
2223 .map(|(col, ty, _)| Column::new(FieldName::new(id, (*col).into()), ty.clone()))
2224 .collect(),
2225 fields
2226 .iter()
2227 .enumerate()
2228 .filter(|(_, (_, _, indexed))| *indexed)
2229 .map(|(i, _)| (ColId::from(i).into(), Constraints::indexed())),
2230 );
2231 SourceExpr::InMemory {
2232 source_id: SourceId(0),
2233 header: Arc::new(head),
2234 table_access,
2235 table_type: StTableType::User,
2236 }
2237 }
2238
/// `IndexJoin::to_inner_join` should produce a semijoin that reads from the probe side
/// and joins against the index side, with the index selection applied to the index side.
#[test]
fn test_index_to_inner_join() {
    let columns = [(0, AlgebraicType::U8, false), (1, AlgebraicType::U8, true)];
    let index_side = mem_table(0.into(), "index", &columns);
    let probe_side = mem_table(1.into(), "probe", &columns);

    let index_col = 1.into();
    let probe_col = 1.into();
    let index_select = ColumnOp::cmp(0, OpCmp::Eq, 0u8);

    let expr = IndexJoin {
        probe_side: probe_side.clone().into(),
        probe_col,
        index_side: index_side.clone(),
        index_select: Some(index_select.clone()),
        index_col,
        return_index_rows: false,
    }
    .to_inner_join();

    assert_eq!(expr.source, probe_side);
    assert_eq!(expr.query.len(), 1);

    let join = match &expr.query[0] {
        Query::JoinInner(join) => join,
        other => panic!("expected an inner join, but got {:#?}", other),
    };

    assert_eq!(join.col_lhs, probe_col);
    assert_eq!(join.col_rhs, index_col);

    let expected_rhs = QueryExpr {
        source: index_side,
        query: vec![index_select.into()],
    };
    assert_eq!(join.rhs, expected_rhs);
    assert_eq!(join.inner, None);
}
2284
2285 fn setup_best_index() -> (Header, [ColId; 5], [AlgebraicValue; 5]) {
2286 let table_id = 0.into();
2287
2288 let vals = [1, 2, 3, 4, 5].map(AlgebraicValue::U64);
2289 let col_ids = [0, 1, 2, 3, 4].map(ColId);
2290 let [a, b, c, d, _] = col_ids;
2291 let columns = col_ids.map(|c| Column::new(FieldName::new(table_id, c), AlgebraicType::I8));
2292
2293 let head1 = Header::new(
2294 table_id,
2295 "t1".into(),
2296 columns.to_vec(),
2297 vec![
2298 (a.into(), Constraints::primary_key()),
2300 (b.into(), Constraints::indexed()),
2302 (col_list![b, c], Constraints::unique()),
2304 (col_list![a, b, c, d], Constraints::indexed()),
2306 ],
2307 );
2308
2309 (head1, col_ids, vals)
2310 }
2311
2312 fn make_field_value((cmp, col, value): (OpCmp, ColId, &AlgebraicValue)) -> ColumnOp {
2313 ColumnOp::cmp(col, cmp, value.clone())
2314 }
2315
2316 fn scan_eq<'a>(arena: &'a Arena<ColumnOp>, col: ColId, val: &'a AlgebraicValue) -> IndexColumnOp<'a> {
2317 scan(arena, OpCmp::Eq, col, val)
2318 }
2319
2320 fn scan<'a>(arena: &'a Arena<ColumnOp>, cmp: OpCmp, col: ColId, val: &'a AlgebraicValue) -> IndexColumnOp<'a> {
2321 IndexColumnOp::Scan(arena.alloc(make_field_value((cmp, col, val))))
2322 }
2323
/// `select_best_index` on conjunctions of equality conditions,
/// against the constraints declared by `setup_best_index`.
#[test]
fn best_index() {
    let (head1, fields, vals) = setup_best_index();
    let [col_a, col_b, col_c, col_d, col_e] = fields;
    let [val_a, val_b, val_c, val_d, val_e] = vals;

    let arena = Arena::new();
    // ANDs `col = val` for every pair and asks for the best index plan.
    let best_for = |conditions: &[_]| {
        let filter = conditions
            .iter()
            .copied()
            .map(|(col, val): (ColId, _)| make_field_value((OpCmp::Eq, col, val)))
            .reduce(ColumnOp::and)
            .unwrap();
        select_best_index(&mut <_>::default(), &head1, arena.alloc(filter))
    };

    let col_list_arena = Arena::new();
    let idx_eq = |cols, val| make_index_arg(OpCmp::Eq, col_list_arena.alloc(cols), val);
    // Expected index operations on the two multi-column constraints.
    let idx_bc = || {
        idx_eq(
            col_list![col_b, col_c],
            product![val_b.clone(), val_c.clone()].into(),
        )
    };
    let idx_abcd = || {
        idx_eq(
            col_list![col_a, col_b, col_c, col_d],
            product![val_a.clone(), val_b.clone(), val_c.clone(), val_d.clone()].into(),
        )
    };

    // `d` has no constraint of its own: plain scan.
    assert_eq!(
        best_for(&[(col_d, &val_e)]),
        [scan_eq(&arena, col_d, &val_e)].into(),
    );

    // `a` and `b` each have a single-column constraint.
    assert_eq!(
        best_for(&[(col_a, &val_a)]),
        [idx_eq(col_a.into(), val_a.clone())].into(),
    );
    assert_eq!(
        best_for(&[(col_b, &val_b)]),
        [idx_eq(col_b.into(), val_b.clone())].into(),
    );

    // `(b, c)` is matched regardless of the order of the conditions.
    assert_eq!(best_for(&[(col_b, &val_b), (col_c, &val_c)]), [idx_bc()].into(),);
    assert_eq!(best_for(&[(col_c, &val_c), (col_b, &val_b)]), [idx_bc()].into(),);

    // `(a, b, c, d)` is matched regardless of the order of the conditions.
    assert_eq!(
        best_for(&[(col_a, &val_a), (col_b, &val_b), (col_c, &val_c), (col_d, &val_d)]),
        [idx_abcd()].into(),
    );
    assert_eq!(
        best_for(&[(col_b, &val_b), (col_a, &val_a), (col_d, &val_d), (col_c, &val_c)]),
        [idx_abcd()].into()
    );

    // Without `c`, only the single-column constraints apply; `d` and `e` are scanned.
    assert_eq!(
        best_for(&[(col_b, &val_b), (col_a, &val_a), (col_e, &val_e), (col_d, &val_d)]),
        [
            idx_eq(col_a.into(), val_a.clone()),
            idx_eq(col_b.into(), val_b.clone()),
            scan_eq(&arena, col_d, &val_d),
            scan_eq(&arena, col_e, &val_e),
        ]
        .into()
    );

    // `(b, c)` is used and the remaining `d` is scanned.
    assert_eq!(
        best_for(&[(col_b, &val_b), (col_c, &val_c), (col_d, &val_d)]),
        [idx_bc(), scan_eq(&arena, col_d, &val_d),].into()
    );
}
2419
/// `select_best_index` on conjunctions that mix equality and range conditions,
/// against the constraints declared by `setup_best_index`.
#[test]
fn best_index_range() {
    let arena = Arena::new();

    let (head1, cols, vals) = setup_best_index();
    let [col_a, col_b, col_c, col_d, _] = cols;
    let [val_a, val_b, val_c, val_d, _] = vals;

    // ANDs all `(cmp, col, val)` conditions and asks for the best index plan.
    let best_for = |conditions: &[_]| {
        let filter = conditions
            .iter()
            .map(|condition| make_field_value(*condition))
            .reduce(ColumnOp::and)
            .unwrap();
        select_best_index(&mut <_>::default(), &head1, arena.alloc(filter))
    };

    let col_list_arena = Arena::new();
    let idx = |cmp, cols: &[ColId], val: &AlgebraicValue| {
        let columns = col_list_arena.alloc(cols.iter().copied().collect::<ColList>());
        make_index_arg(cmp, columns, val.clone())
    };

    // Both bounds on the constrained column `a` become index operations.
    assert_eq!(
        best_for(&[(OpCmp::Gt, col_a, &val_a), (OpCmp::Lt, col_a, &val_b)]),
        [idx(OpCmp::Lt, &[col_a], &val_b), idx(OpCmp::Gt, &[col_a], &val_a)].into()
    );

    // `d` has no constraint of its own: both bounds are scanned.
    assert_eq!(
        best_for(&[(OpCmp::Gt, col_d, &val_d), (OpCmp::Lt, col_d, &val_b)]),
        [
            scan(&arena, OpCmp::Lt, col_d, &val_b),
            scan(&arena, OpCmp::Gt, col_d, &val_d)
        ]
        .into()
    );

    // A range on `b` uses its index; a range on `c` alone is scanned.
    assert_eq!(
        best_for(&[(OpCmp::Gt, col_b, &val_b), (OpCmp::Lt, col_c, &val_c)]),
        [idx(OpCmp::Gt, &[col_b], &val_b), scan(&arena, OpCmp::Lt, col_c, &val_c)].into()
    );

    // Expected index operation for `b = .. AND c = ..` on the `(b, c)` constraint.
    let idx_bc = idx(
        OpCmp::Eq,
        &[col_b, col_c],
        &product![val_b.clone(), val_c.clone()].into(),
    );

    // Equalities on `(b, c)` combine, next to a range on `a`.
    assert_eq!(
        best_for(&[
            (OpCmp::Eq, col_b, &val_b),
            (OpCmp::GtEq, col_a, &val_a),
            (OpCmp::Eq, col_c, &val_c),
        ]),
        [idx_bc.clone(), idx(OpCmp::GtEq, &[col_a], &val_a),].into()
    );

    // Ranges on `b` and `c` do not combine into the `(b, c)` constraint.
    assert_eq!(
        best_for(&[
            (OpCmp::Gt, col_b, &val_b),
            (OpCmp::Eq, col_a, &val_a),
            (OpCmp::Lt, col_c, &val_c),
        ]),
        [
            idx(OpCmp::Eq, &[col_a], &val_a),
            idx(OpCmp::Gt, &[col_b], &val_b),
            scan(&arena, OpCmp::Lt, col_c, &val_c),
        ]
        .into()
    );

    // A range on `d` rules out `(a, b, c, d)`; `(b, c)` and `a` are used and `d` is scanned.
    assert_eq!(
        best_for(&[
            (OpCmp::Eq, col_a, &val_a),
            (OpCmp::Eq, col_b, &val_b),
            (OpCmp::Eq, col_c, &val_c),
            (OpCmp::Gt, col_d, &val_d),
        ]),
        [
            idx_bc.clone(),
            idx(OpCmp::Eq, &[col_a], &val_a),
            scan(&arena, OpCmp::Gt, col_d, &val_d),
        ]
        .into()
    );

    // Repeated conditions yield the `(b, c)` index operation twice.
    assert_eq!(
        best_for(&[
            (OpCmp::Eq, col_b, &val_b),
            (OpCmp::Eq, col_c, &val_c),
            (OpCmp::Eq, col_b, &val_b),
            (OpCmp::Eq, col_c, &val_c),
        ]),
        [idx_bc.clone(), idx_bc].into()
    );
}
2520
/// Every private source may only be accessed by its owner.
#[test]
fn test_auth_table() {
    for table in &tables() {
        assert_owner_private(table);
    }
}
2525
/// A query expression over private sources may only be run by the owner.
#[test]
fn test_auth_query_code() {
    query_exprs()
        .into_iter()
        .for_each(|expr| assert_owner_private(&expr));
}
2532
/// Each query operation that references a private source may only be run by the owner.
#[test]
fn test_auth_query() {
    queries()
        .into_iter()
        .for_each(|query| assert_owner_private(&query));
}
2539
/// A CRUD query defers to the auth check of the wrapped query expression.
#[test]
fn test_auth_crud_code_query() {
    query_exprs()
        .into_iter()
        .map(CrudExpr::Query)
        .for_each(|crud| assert_owner_private(&crud));
}
2547
/// Inserting requires the caller to be the owner.
#[test]
fn test_auth_crud_code_insert() {
    let db_tables = tables().into_iter().filter_map(|source| source.get_db_table().cloned());
    for table in db_tables {
        assert_owner_required(CrudExpr::Insert { table, rows: vec![] });
    }
}
2555
/// Updating requires the caller to be the owner.
#[test]
fn test_auth_crud_code_update() {
    query_exprs().into_iter().for_each(|delete| {
        assert_owner_required(CrudExpr::Update {
            delete,
            assignments: Default::default(),
        })
    });
}
2566
/// Deleting requires the caller to be the owner.
#[test]
fn test_auth_crud_code_delete() {
    query_exprs()
        .into_iter()
        .for_each(|query| assert_owner_required(CrudExpr::Delete { query }));
}
2574
2575 fn test_def() -> ModuleDef {
2576 let mut builder = RawModuleDefV9Builder::new();
2577 builder.build_table_with_new_type(
2578 "lhs",
2579 ProductType::from([("a", AlgebraicType::I32), ("b", AlgebraicType::String)]),
2580 true,
2581 );
2582 builder.build_table_with_new_type(
2583 "rhs",
2584 ProductType::from([("c", AlgebraicType::I32), ("d", AlgebraicType::I64)]),
2585 true,
2586 );
2587 builder.finish().try_into().expect("test def should be valid")
2588 }
2589
/// An inner join followed by a wildcard projection of the source table
/// should be optimized into a single semijoin.
#[test]
fn optimize_inner_join_to_semijoin() {
    let def: ModuleDef = test_def();
    let lhs = TableSchema::from_module_def(&def, def.table("lhs").unwrap(), (), 0.into());
    let rhs = TableSchema::from_module_def(&def, def.table("rhs").unwrap(), (), 1.into());

    let lhs_source = SourceExpr::from(&lhs);
    let rhs_source = SourceExpr::from(&rhs);

    let lhs_fields = Vec::from([0, 1].map(|c| FieldExpr::Name(FieldName::new(lhs.table_id, c.into()))));
    let optimized = QueryExpr::new(lhs_source.clone())
        .with_join_inner(rhs_source.clone(), 0.into(), 0.into(), false)
        .with_project(lhs_fields, Some(TableId::SENTINEL))
        .unwrap()
        .optimize(&|_, _| 0);

    assert_eq!(optimized.source, lhs_source, "Optimized query should read from lhs");

    assert_eq!(
        optimized.query.len(),
        1,
        "Optimized query should have a single member, a semijoin"
    );
    match &optimized.query[0] {
        Query::JoinInner(JoinExpr { rhs, inner: semi, .. }) => {
            assert_eq!(semi, &None, "Optimized query should be a semijoin");
            assert_eq!(rhs.source, rhs_source, "Optimized query should filter with rhs");
            assert!(
                rhs.query.is_empty(),
                "Optimized query should not filter rhs before joining"
            );
        }
        wrong => panic!("Expected an inner join, but found {wrong:?}"),
    }
}
2630
/// An inner join without a projection must be left alone by the optimizer.
#[test]
fn optimize_inner_join_no_project() {
    let def: ModuleDef = test_def();
    let lhs = TableSchema::from_module_def(&def, def.table("lhs").unwrap(), (), 0.into());
    let rhs = TableSchema::from_module_def(&def, def.table("rhs").unwrap(), (), 1.into());

    let original =
        QueryExpr::new(SourceExpr::from(&lhs)).with_join_inner(SourceExpr::from(&rhs), 0.into(), 0.into(), false);

    let optimized = original.clone().optimize(&|_, _| 0);
    assert_eq!(original, optimized);
}
2645
/// An inner join followed by a projection of the *other* table (here `rhs`)
/// must be left alone by the optimizer.
#[test]
fn optimize_inner_join_wrong_project() {
    let def: ModuleDef = test_def();
    let lhs = TableSchema::from_module_def(&def, def.table("lhs").unwrap(), (), 0.into());
    let rhs = TableSchema::from_module_def(&def, def.table("rhs").unwrap(), (), 1.into());

    let rhs_fields = Vec::from([0, 1].map(|c| FieldExpr::Name(FieldName::new(rhs.table_id, c.into()))));
    let original = QueryExpr::new(SourceExpr::from(&lhs))
        .with_join_inner(SourceExpr::from(&rhs), 0.into(), 0.into(), false)
        .with_project(rhs_fields, Some(TableId(1)))
        .unwrap();

    let optimized = original.clone().optimize(&|_, _| 0);
    assert_eq!(original, optimized);
}
2668}