spacetimedb/sql/
compiler.rs

1use super::ast::TableSchemaView;
2use super::ast::{compile_to_ast, Column, From, Join, Selection, SqlAst};
3use super::type_check::TypeCheck;
4use crate::db::relational_db::RelationalDB;
5use crate::error::{DBError, PlanError};
6use core::ops::Deref;
7use spacetimedb_data_structures::map::IntMap;
8use spacetimedb_datastore::locking_tx_datastore::state_view::StateView;
9use spacetimedb_lib::identity::AuthCtx;
10use spacetimedb_primitives::ColId;
11use spacetimedb_schema::relation::{self, ColExpr, DbTable, FieldName, Header};
12use spacetimedb_schema::schema::TableSchema;
13use spacetimedb_vm::expr::{CrudExpr, Expr, FieldExpr, QueryExpr, SourceExpr};
14use spacetimedb_vm::operator::OpCmp;
15use std::sync::Arc;
16
17/// DIRTY HACK ALERT: Maximum allowed length, in UTF-8 bytes, of SQL queries.
18/// Any query longer than this will be rejected.
19/// This prevents a stack overflow when compiling queries with deeply-nested `AND` and `OR` conditions.
20const MAX_SQL_LENGTH: usize = 50_000;
21
22/// Compile the `SQL` expression into an `ast`
23pub fn compile_sql<T: TableSchemaView + StateView>(
24    db: &RelationalDB,
25    auth: &AuthCtx,
26    tx: &T,
27    sql_text: &str,
28) -> Result<Vec<CrudExpr>, DBError> {
29    if sql_text.len() > MAX_SQL_LENGTH {
30        return Err(anyhow::anyhow!("SQL query exceeds maximum allowed length: \"{sql_text:.120}...\"").into());
31    }
32    tracing::trace!(sql = sql_text);
33    let ast = compile_to_ast(db, auth, tx, sql_text)?;
34
35    // TODO(perf, bikeshedding): SmallVec?
36    let mut results = Vec::with_capacity(ast.len());
37
38    for sql in ast {
39        results.push(compile_statement(db, sql).map_err(|error| DBError::Plan {
40            sql: sql_text.to_string(),
41            error,
42        })?);
43    }
44
45    Ok(results)
46}
47
48fn expr_for_projection(table: &From, of: Expr) -> Result<FieldExpr, PlanError> {
49    match of {
50        Expr::Ident(x) => table.find_field(&x).map(|(f, _)| FieldExpr::Name(f)),
51        Expr::Value(x) => Ok(FieldExpr::Value(x)),
52        x => unreachable!("Wrong expression in SQL query {:?}", x),
53    }
54}
55
56/// Compiles a `WHERE ...` clause
57fn compile_where(mut q: QueryExpr, filter: Selection) -> Result<QueryExpr, PlanError> {
58    for op in filter.clause.flatten_ands() {
59        q = q.with_select(op)?;
60    }
61    Ok(q)
62}
63
64/// Compiles a `SELECT ...` clause
65fn compile_select(table: From, project: Box<[Column]>, selection: Option<Selection>) -> Result<QueryExpr, PlanError> {
66    let mut not_found = Vec::with_capacity(project.len());
67    let mut col_ids = Vec::new();
68    let mut qualified_wildcards = Vec::new();
69    //Match columns to their tables...
70    for select_item in Vec::from(project) {
71        match select_item {
72            Column::UnnamedExpr(x) => match expr_for_projection(&table, x) {
73                Ok(field) => col_ids.push(field),
74                Err(PlanError::UnknownField { field, tables: _ }) => not_found.push(field),
75                Err(err) => return Err(err),
76            },
77            Column::QualifiedWildcard { table: name } => {
78                if let Some(t) = table.iter_tables().find(|x| *x.table_name == name) {
79                    for c in t.columns().iter() {
80                        col_ids.push(FieldName::new(t.table_id, c.col_pos).into());
81                    }
82                    qualified_wildcards.push(t.table_id);
83                } else {
84                    return Err(PlanError::TableNotFoundQualified { expect: name });
85                }
86            }
87            Column::Wildcard => {}
88        }
89    }
90
91    if !not_found.is_empty() {
92        return Err(PlanError::UnknownFields {
93            fields: not_found,
94            tables: table.table_names(),
95        });
96    }
97
98    let source_expr: SourceExpr = table.root.deref().into();
99    let mut q = QueryExpr::new(source_expr);
100
101    for join in table.joins {
102        match join {
103            Join::Inner { rhs, on } => {
104                let col_lhs = q.head().column_pos_or_err(on.lhs)?;
105                let rhs_source_expr: SourceExpr = rhs.deref().into();
106                let col_rhs = rhs_source_expr.head().column_pos_or_err(on.rhs)?;
107
108                match on.op {
109                    OpCmp::Eq => {}
110                    x => unreachable!("Unsupported operator `{x}` for joins"),
111                }
112                // Always construct inner joins, never semijoins.
113                // The query optimizer can rewrite certain inner joins into semijoins later in the pipeline.
114                // The full pipeline for a query like `SELECT lhs.* FROM lhs JOIN rhs ON lhs.a = rhs.a` is:
115                // - We produce `[JoinInner(semi: false), Project]`.
116                // - Optimizer rewrites to `[JoinInner(semi: true)]`.
117                // - Optimizer rewrites to `[IndexJoin]`.
118                // For incremental queries, this all happens on the original query with `DbTable` sources.
119                // Then, the query is "incrementalized" by replacing the sources with `MemTable`s,
120                // and the `IndexJoin` is rewritten back into a `JoinInner(semi: true)`.
121                q = q.with_join_inner(rhs_source_expr, col_lhs, col_rhs, false);
122            }
123        }
124    }
125
126    if let Some(filter) = selection {
127        q = compile_where(q, filter)?;
128    }
129    // It is important to project at the end.
130    // This is so joins and filters see fields that are not projected.
131    // It is also important to identify a wildcard project of the form `table.*`.
132    // This implies a potential semijoin and additional optimization opportunities.
133    let qualified_wildcard = (qualified_wildcards.len() == 1).then(|| qualified_wildcards[0]);
134    q = q.with_project(col_ids, qualified_wildcard)?;
135
136    Ok(q)
137}
138
139/// Builds the schema description [DbTable] from the [TableSchema] and their list of columns
140fn compile_columns(table: &TableSchema, cols: &[ColId]) -> DbTable {
141    let mut columns = Vec::with_capacity(cols.len());
142    let cols = cols
143        .iter()
144        // TODO: should we error here instead?
145        // When would the user be passing in columns that aren't present?
146        .filter_map(|col| table.get_column(col.idx()))
147        .map(|col| relation::Column::new(FieldName::new(table.table_id, col.col_pos), col.col_type.clone()));
148    columns.extend(cols);
149
150    let header = Header::from(table).project_col_list(&columns.iter().map(|x| x.field.col).collect());
151
152    DbTable::new(Arc::new(header), table.table_id, table.table_type, table.table_access)
153}
154
155/// Compiles a `INSERT ...` clause
156fn compile_insert(table: &TableSchema, cols: &[ColId], values: Box<[Box<[ColExpr]>]>) -> CrudExpr {
157    let table = compile_columns(table, cols);
158
159    let mut rows = Vec::with_capacity(values.len());
160    for x in Vec::from(values) {
161        let mut row = Vec::with_capacity(x.len());
162        for v in Vec::from(x) {
163            match v {
164                ColExpr::Col(x) => {
165                    todo!("Deal with idents in insert?: {}", x)
166                }
167                ColExpr::Value(x) => {
168                    row.push(x);
169                }
170            }
171        }
172        rows.push(row.into())
173    }
174
175    CrudExpr::Insert { table, rows }
176}
177
178/// Compiles a `DELETE ...` clause
179fn compile_delete(table: Arc<TableSchema>, selection: Option<Selection>) -> Result<CrudExpr, PlanError> {
180    let query = QueryExpr::new(&*table);
181    let query = if let Some(filter) = selection {
182        compile_where(query, filter)?
183    } else {
184        query
185    };
186    Ok(CrudExpr::Delete { query })
187}
188
189/// Compiles a `UPDATE ...` clause
190fn compile_update(
191    table: Arc<TableSchema>,
192    assignments: IntMap<ColId, ColExpr>,
193    selection: Option<Selection>,
194) -> Result<CrudExpr, PlanError> {
195    let query = QueryExpr::new(&*table);
196    let delete = if let Some(filter) = selection {
197        compile_where(query, filter)?
198    } else {
199        query
200    };
201
202    Ok(CrudExpr::Update { delete, assignments })
203}
204
205/// Compiles a `SQL` clause
206fn compile_statement(db: &RelationalDB, statement: SqlAst) -> Result<CrudExpr, PlanError> {
207    statement.type_check()?;
208
209    let q = match statement {
210        SqlAst::Select {
211            from,
212            project,
213            selection,
214        } => CrudExpr::Query(compile_select(from, project, selection)?),
215        SqlAst::Insert { table, columns, values } => compile_insert(&table, &columns, values),
216        SqlAst::Update {
217            table,
218            assignments,
219            selection,
220        } => compile_update(table, assignments, selection)?,
221        SqlAst::Delete { table, selection } => compile_delete(table, selection)?,
222        SqlAst::SetVar { name, literal } => CrudExpr::SetVar { name, literal },
223        SqlAst::ReadVar { name } => CrudExpr::ReadVar { name },
224    };
225
226    Ok(q.optimize(&|table_id, table_name| db.row_count(table_id, table_name)))
227}
228
229#[cfg(test)]
230mod tests {
231    use super::*;
232    use crate::db::relational_db::tests_utils::{begin_mut_tx, begin_tx, insert, with_auto_commit, TestDB};
233    use crate::sql::execute::tests::run_for_testing;
234    use spacetimedb_lib::error::{ResultTest, TestError};
235    use spacetimedb_lib::{ConnectionId, Identity};
236    use spacetimedb_primitives::{col_list, ColList, TableId};
237    use spacetimedb_sats::{product, AlgebraicType, AlgebraicValue, GroundSpacetimeType as _};
238    use spacetimedb_vm::expr::{ColumnOp, IndexJoin, IndexScan, JoinExpr, Query};
239    use std::convert::From;
240    use std::ops::Bound;
241
242    fn assert_index_scan(
243        op: &Query,
244        cols: impl Into<ColList>,
245        low_bound: Bound<AlgebraicValue>,
246        up_bound: Bound<AlgebraicValue>,
247    ) -> TableId {
248        if let Query::IndexScan(IndexScan { table, columns, bounds }) = op {
249            assert_eq!(columns, &cols.into(), "Columns don't match");
250            assert_eq!(bounds.0, low_bound, "Lower bound don't match");
251            assert_eq!(bounds.1, up_bound, "Upper bound don't match");
252            table.table_id
253        } else {
254            panic!("Expected IndexScan, got {op}");
255        }
256    }
257
258    fn assert_one_eq_index_scan(op: &Query, cols: impl Into<ColList>, val: AlgebraicValue) -> TableId {
259        let val = Bound::Included(val);
260        assert_index_scan(op, cols, val.clone(), val)
261    }
262
    /// Asserts that `op` is a [`Query::Select`]; panics for any other operator.
    fn assert_select(op: &Query) {
        assert!(matches!(op, Query::Select(_)));
    }
266
267    fn compile_sql<T: TableSchemaView + StateView>(
268        db: &RelationalDB,
269        tx: &T,
270        sql: &str,
271    ) -> Result<Vec<CrudExpr>, DBError> {
272        super::compile_sql(db, &AuthCtx::for_testing(), tx, sql)
273    }
274
275    #[test]
276    fn compile_eq() -> ResultTest<()> {
277        let db = TestDB::durable()?;
278
279        // Create table [test] without any indexes
280        let schema = &[("a", AlgebraicType::U64)];
281        let indexes = &[];
282        db.create_table_for_test("test", schema, indexes)?;
283
284        let tx = begin_tx(&db);
285        // Compile query
286        let sql = "select * from test where a = 1";
287        let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
288            panic!("Expected QueryExpr");
289        };
290        assert_eq!(1, query.len());
291        assert_select(&query[0]);
292        Ok(())
293    }
294
295    #[test]
296    fn compile_not_eq() -> ResultTest<()> {
297        let db = TestDB::durable()?;
298
299        // Create table [test] with cols [a, b] and index on [b].
300        db.create_table_for_test(
301            "test",
302            &[("a", AlgebraicType::U64), ("b", AlgebraicType::U64)],
303            &[1.into(), 0.into()],
304        )?;
305
306        let tx = begin_tx(&db);
307        // Should work with any qualified field.
308        let sql = "select * from test where a = 1 and b <> 3";
309        let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
310            panic!("Expected QueryExpr");
311        };
312        assert_eq!(2, query.len());
313        assert_one_eq_index_scan(&query[0], 0, 1u64.into());
314        assert_select(&query[1]);
315        Ok(())
316    }
317
318    #[test]
319    fn compile_index_eq_basic() -> ResultTest<()> {
320        let db = TestDB::durable()?;
321
322        // Create table [test] with index on [a]
323        let schema = &[("a", AlgebraicType::U64)];
324        let indexes = &[0.into()];
325        db.create_table_for_test("test", schema, indexes)?;
326
327        let tx = begin_tx(&db);
328        //Compile query
329        let sql = "select * from test where a = 1";
330        let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
331            panic!("Expected QueryExpr");
332        };
333        assert_eq!(1, query.len());
334        assert_one_eq_index_scan(&query[0], 0, 1u64.into());
335        Ok(())
336    }
337
338    #[test]
339    fn compile_eq_identity_connection_id() -> ResultTest<()> {
340        let db = TestDB::durable()?;
341
342        // Create table [test] without any indexes
343        let schema = &[
344            ("identity", Identity::get_type()),
345            ("identity_mix", Identity::get_type()),
346            ("connection_id", ConnectionId::get_type()),
347        ];
348        let indexes = &[];
349        let table_id = db.create_table_for_test("test", schema, indexes)?;
350
351        let row = product![
352            Identity::__dummy(),
353            Identity::from_hex("93dda09db9a56d8fa6c024d843e805d8262191db3b4ba84c5efcd1ad451fed4e").unwrap(),
354            ConnectionId::ZERO,
355        ];
356
357        with_auto_commit(&db, |tx| {
358            insert(&db, tx, table_id, &row.clone())?;
359            Ok::<(), TestError>(())
360        })?;
361
362        // Check can be used by CRUD ops:
363        let sql = &format!(
364            "INSERT INTO test (identity, identity_mix, connection_id) VALUES ({}, x'91DDA09DB9A56D8FA6C024D843E805D8262191DB3B4BA84C5EFCD1AD451FED4E', {})",
365            Identity::__dummy(),
366            ConnectionId::ZERO,
367        );
368        run_for_testing(&db, sql)?;
369
370        // Compile query, check for both hex formats and it to be case-insensitive...
371        let sql = &format!(
372            "select * from test where identity = {} AND identity_mix = x'93dda09db9a56d8fa6c024d843e805D8262191db3b4bA84c5efcd1ad451fed4e' AND connection_id = x'{}' AND connection_id = {}",
373            Identity::__dummy(),
374            ConnectionId::ZERO,
375            ConnectionId::ZERO,
376        );
377
378        let rows = run_for_testing(&db, sql)?;
379
380        let tx = begin_tx(&db);
381        let CrudExpr::Query(QueryExpr {
382            source: _,
383            query: mut ops,
384        }) = compile_sql(&db, &tx, sql)?.remove(0)
385        else {
386            panic!("Expected QueryExpr");
387        };
388
389        assert_eq!(1, ops.len());
390
391        // Assert no index scan
392        let Query::Select(_) = ops.remove(0) else {
393            panic!("Expected Select");
394        };
395
396        assert_eq!(rows, vec![row]);
397
398        Ok(())
399    }
400
401    #[test]
402    fn compile_eq_and_eq() -> ResultTest<()> {
403        let db = TestDB::durable()?;
404
405        // Create table [test] with index on [b]
406        let schema = &[("a", AlgebraicType::U64), ("b", AlgebraicType::U64)];
407        let indexes = &[1.into()];
408        db.create_table_for_test("test", schema, indexes)?;
409
410        let tx = begin_tx(&db);
411        // Note, order does not matter.
412        // The sargable predicate occurs last, but we can still generate an index scan.
413        let sql = "select * from test where a = 1 and b = 2";
414        let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
415            panic!("Expected QueryExpr");
416        };
417        assert_eq!(2, query.len());
418        assert_one_eq_index_scan(&query[0], 1, 2u64.into());
419        assert_select(&query[1]);
420        Ok(())
421    }
422
423    #[test]
424    fn compile_index_eq_and_eq() -> ResultTest<()> {
425        let db = TestDB::durable()?;
426
427        // Create table [test] with index on [b]
428        let schema = &[("a", AlgebraicType::U64), ("b", AlgebraicType::U64)];
429        let indexes = &[1.into()];
430        db.create_table_for_test("test", schema, indexes)?;
431
432        let tx = begin_tx(&db);
433        // Note, order does not matter.
434        // The sargable predicate occurs first and we can generate an index scan.
435        let sql = "select * from test where b = 2 and a = 1";
436        let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
437            panic!("Expected QueryExpr");
438        };
439        assert_eq!(2, query.len());
440        assert_one_eq_index_scan(&query[0], 1, 2u64.into());
441        assert_select(&query[1]);
442        Ok(())
443    }
444
445    #[test]
446    fn compile_index_multi_eq_and_eq() -> ResultTest<()> {
447        let db = TestDB::durable()?;
448
449        // Create table [test] with index on [b]
450        let schema = &[
451            ("a", AlgebraicType::U64),
452            ("b", AlgebraicType::U64),
453            ("c", AlgebraicType::U64),
454            ("d", AlgebraicType::U64),
455        ];
456        db.create_table_for_test_multi_column("test", schema, col_list![0, 1])?;
457
458        let tx = begin_mut_tx(&db);
459        let sql = "select * from test where b = 2 and a = 1";
460        let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
461            panic!("Expected QueryExpr");
462        };
463        assert_eq!(1, query.len());
464        assert_one_eq_index_scan(&query[0], col_list![0, 1], product![1u64, 2u64].into());
465        Ok(())
466    }
467
468    #[test]
469    fn compile_eq_or_eq() -> ResultTest<()> {
470        let db = TestDB::durable()?;
471
472        // Create table [test] with indexes on [a] and [b]
473        let schema = &[("a", AlgebraicType::U64), ("b", AlgebraicType::U64)];
474        let indexes = &[0.into(), 1.into()];
475        db.create_table_for_test("test", schema, indexes)?;
476
477        let tx = begin_tx(&db);
478        // Compile query
479        let sql = "select * from test where a = 1 or b = 2";
480        let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
481            panic!("Expected QueryExpr");
482        };
483        assert_eq!(1, query.len());
484        // Assert no index scan because OR is not sargable.
485        assert_select(&query[0]);
486        Ok(())
487    }
488
489    #[test]
490    fn compile_index_range_open() -> ResultTest<()> {
491        let db = TestDB::durable()?;
492
493        // Create table [test] with indexes on [b]
494        let schema = &[("a", AlgebraicType::U64), ("b", AlgebraicType::U64)];
495        let indexes = &[1.into()];
496        db.create_table_for_test("test", schema, indexes)?;
497
498        let tx = begin_tx(&db);
499        // Compile query
500        let sql = "select * from test where b > 2";
501        let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
502            panic!("Expected QueryExpr");
503        };
504        assert_eq!(1, query.len());
505        assert_index_scan(&query[0], 1, Bound::Excluded(AlgebraicValue::U64(2)), Bound::Unbounded);
506
507        Ok(())
508    }
509
510    #[test]
511    fn compile_index_range_closed() -> ResultTest<()> {
512        let db = TestDB::durable()?;
513
514        // Create table [test] with indexes on [b]
515        let schema = &[("a", AlgebraicType::U64), ("b", AlgebraicType::U64)];
516        let indexes = &[1.into()];
517        db.create_table_for_test("test", schema, indexes)?;
518
519        let tx = begin_tx(&db);
520        // Compile query
521        let sql = "select * from test where b > 2 and b < 5";
522        let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
523            panic!("Expected QueryExpr");
524        };
525        assert_eq!(1, query.len());
526        assert_index_scan(
527            &query[0],
528            1,
529            Bound::Excluded(AlgebraicValue::U64(2)),
530            Bound::Excluded(AlgebraicValue::U64(5)),
531        );
532
533        Ok(())
534    }
535
536    #[test]
537    fn compile_index_eq_select_range() -> ResultTest<()> {
538        let db = TestDB::durable()?;
539
540        // Create table [test] with indexes on [a] and [b]
541        let schema = &[("a", AlgebraicType::U64), ("b", AlgebraicType::U64)];
542        let indexes = &[0.into(), 1.into()];
543        db.create_table_for_test("test", schema, indexes)?;
544
545        let tx = begin_tx(&db);
546        // Note, order matters - the equality condition occurs first which
547        // means an index scan will be generated rather than the range condition.
548        let sql = "select * from test where a = 3 and b > 2 and b < 5";
549        let CrudExpr::Query(QueryExpr { source: _, query }) = compile_sql(&db, &tx, sql)?.remove(0) else {
550            panic!("Expected QueryExpr");
551        };
552        assert_eq!(2, query.len());
553        assert_one_eq_index_scan(&query[0], 0, 3u64.into());
554        assert_select(&query[1]);
555        Ok(())
556    }
557
558    #[test]
559    fn compile_join_lhs_push_down() -> ResultTest<()> {
560        let db = TestDB::durable()?;
561
562        // Create table [lhs] with index on [a]
563        let schema = &[("a", AlgebraicType::U64), ("b", AlgebraicType::U64)];
564        let indexes = &[0.into()];
565        let lhs_id = db.create_table_for_test("lhs", schema, indexes)?;
566
567        // Create table [rhs] with no indexes
568        let schema = &[("b", AlgebraicType::U64), ("c", AlgebraicType::U64)];
569        let indexes = &[];
570        let rhs_id = db.create_table_for_test("rhs", schema, indexes)?;
571
572        let tx = begin_tx(&db);
573        // Should push sargable equality condition below join
574        let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where lhs.a = 3";
575        let exp = compile_sql(&db, &tx, sql)?.remove(0);
576
577        let CrudExpr::Query(QueryExpr {
578            source: source_lhs,
579            query,
580            ..
581        }) = exp
582        else {
583            panic!("unexpected expression: {exp:#?}");
584        };
585
586        assert_eq!(source_lhs.table_id().unwrap(), lhs_id);
587        assert_eq!(query.len(), 3);
588
589        // First operation in the pipeline should be an index scan
590        let table_id = assert_one_eq_index_scan(&query[0], 0, 3u64.into());
591
592        assert_eq!(table_id, lhs_id);
593
594        // Followed by a join with the rhs table
595        let Query::JoinInner(JoinExpr {
596            ref rhs,
597            col_lhs,
598            col_rhs,
599            inner: Some(ref inner_header),
600        }) = query[1]
601        else {
602            panic!("unexpected operator {:#?}", query[1]);
603        };
604
605        assert_eq!(rhs.source.table_id().unwrap(), rhs_id);
606        assert_eq!(col_lhs, 1.into());
607        assert_eq!(col_rhs, 0.into());
608        assert_eq!(&**inner_header, &source_lhs.head().extend(rhs.source.head()));
609        Ok(())
610    }
611
612    #[test]
613    fn compile_join_lhs_push_down_no_index() -> ResultTest<()> {
614        let db = TestDB::durable()?;
615
616        // Create table [lhs] with no indexes
617        let schema = &[("a", AlgebraicType::U64), ("b", AlgebraicType::U64)];
618        let lhs_id = db.create_table_for_test("lhs", schema, &[])?;
619
620        // Create table [rhs] with no indexes
621        let schema = &[("b", AlgebraicType::U64), ("c", AlgebraicType::U64)];
622        let rhs_id = db.create_table_for_test("rhs", schema, &[])?;
623
624        let tx = begin_tx(&db);
625        // Should push equality condition below join
626        let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where lhs.a = 3";
627        let exp = compile_sql(&db, &tx, sql)?.remove(0);
628
629        let CrudExpr::Query(QueryExpr {
630            source: source_lhs,
631            query,
632            ..
633        }) = exp
634        else {
635            panic!("unexpected expression: {exp:#?}");
636        };
637        assert_eq!(source_lhs.table_id().unwrap(), lhs_id);
638        assert_eq!(query.len(), 3);
639
640        // The first operation in the pipeline should be a selection with `col#0 = 3`
641        let Query::Select(ColumnOp::ColCmpVal {
642            cmp: OpCmp::Eq,
643            lhs: ColId(0),
644            rhs: AlgebraicValue::U64(3),
645        }) = query[0]
646        else {
647            panic!("unexpected operator {:#?}", query[0]);
648        };
649
650        // The join should follow the selection
651        let Query::JoinInner(JoinExpr {
652            ref rhs,
653            col_lhs,
654            col_rhs,
655            inner: Some(ref inner_header),
656        }) = query[1]
657        else {
658            panic!("unexpected operator {:#?}", query[1]);
659        };
660
661        assert_eq!(rhs.source.table_id().unwrap(), rhs_id);
662        assert_eq!(col_lhs, 1.into());
663        assert_eq!(col_rhs, 0.into());
664        assert_eq!(&**inner_header, &source_lhs.head().extend(rhs.source.head()));
665        assert!(rhs.query.is_empty());
666        Ok(())
667    }
668
669    #[test]
670    fn compile_join_rhs_push_down_no_index() -> ResultTest<()> {
671        let db = TestDB::durable()?;
672
673        // Create table [lhs] with no indexes
674        let schema = &[("a", AlgebraicType::U64), ("b", AlgebraicType::U64)];
675        let lhs_id = db.create_table_for_test("lhs", schema, &[])?;
676
677        // Create table [rhs] with no indexes
678        let schema = &[("b", AlgebraicType::U64), ("c", AlgebraicType::U64)];
679        let rhs_id = db.create_table_for_test("rhs", schema, &[])?;
680
681        let tx = begin_tx(&db);
682        // Should push equality condition below join
683        let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c = 3";
684        let exp = compile_sql(&db, &tx, sql)?.remove(0);
685
686        let CrudExpr::Query(QueryExpr {
687            source: source_lhs,
688            query,
689            ..
690        }) = exp
691        else {
692            panic!("unexpected expression: {exp:#?}");
693        };
694
695        assert_eq!(source_lhs.table_id().unwrap(), lhs_id);
696        assert_eq!(query.len(), 1);
697
698        // First and only operation in the pipeline should be a join
699        let Query::JoinInner(JoinExpr {
700            ref rhs,
701            col_lhs,
702            col_rhs,
703            inner: None,
704        }) = query[0]
705        else {
706            panic!("unexpected operator {:#?}", query[0]);
707        };
708
709        assert_eq!(rhs.source.table_id().unwrap(), rhs_id);
710        assert_eq!(col_lhs, 1.into());
711        assert_eq!(col_rhs, 0.into());
712
713        // The selection should be pushed onto the rhs of the join
714        let Query::Select(ColumnOp::ColCmpVal {
715            cmp: OpCmp::Eq,
716            lhs: ColId(1),
717            rhs: AlgebraicValue::U64(3),
718        }) = rhs.query[0]
719        else {
720            panic!("unexpected operator {:#?}", rhs.query[0]);
721        };
722        Ok(())
723    }
724
725    #[test]
726    fn compile_join_lhs_and_rhs_push_down() -> ResultTest<()> {
727        let db = TestDB::durable()?;
728
729        // Create table [lhs] with index on [a]
730        let schema = &[("a", AlgebraicType::U64), ("b", AlgebraicType::U64)];
731        let indexes = &[0.into()];
732        let lhs_id = db.create_table_for_test("lhs", schema, indexes)?;
733
734        // Create table [rhs] with index on [c]
735        let schema = &[("b", AlgebraicType::U64), ("c", AlgebraicType::U64)];
736        let indexes = &[1.into()];
737        let rhs_id = db.create_table_for_test("rhs", schema, indexes)?;
738
739        let tx = begin_tx(&db);
740        // Should push the sargable equality condition into the join's left arg.
741        // Should push the sargable range condition into the join's right arg.
742        let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where lhs.a = 3 and rhs.c < 4";
743        let exp = compile_sql(&db, &tx, sql)?.remove(0);
744
745        let CrudExpr::Query(QueryExpr {
746            source: source_lhs,
747            query,
748            ..
749        }) = exp
750        else {
751            panic!("unexpected result from compilation: {exp:?}");
752        };
753
754        assert_eq!(source_lhs.table_id().unwrap(), lhs_id);
755        assert_eq!(query.len(), 3);
756
757        // First operation in the pipeline should be an index scan
758        let table_id = assert_one_eq_index_scan(&query[0], 0, 3u64.into());
759
760        assert_eq!(table_id, lhs_id);
761
762        // Followed by a join
763        let Query::JoinInner(JoinExpr {
764            ref rhs,
765            col_lhs,
766            col_rhs,
767            inner: Some(ref inner_header),
768        }) = query[1]
769        else {
770            panic!("unexpected operator {:#?}", query[1]);
771        };
772
773        assert_eq!(rhs.source.table_id().unwrap(), rhs_id);
774        assert_eq!(col_lhs, 1.into());
775        assert_eq!(col_rhs, 0.into());
776        assert_eq!(&**inner_header, &source_lhs.head().extend(rhs.source.head()));
777
778        assert_eq!(1, rhs.query.len());
779
780        // The right side of the join should be an index scan
781        let table_id = assert_index_scan(
782            &rhs.query[0],
783            1,
784            Bound::Unbounded,
785            Bound::Excluded(AlgebraicValue::U64(4)),
786        );
787
788        assert_eq!(table_id, rhs_id);
789        Ok(())
790    }
791
792    #[test]
793    fn compile_index_join() -> ResultTest<()> {
794        let db = TestDB::durable()?;
795
796        // Create table [lhs] with index on [b]
797        let schema = &[("a", AlgebraicType::U64), ("b", AlgebraicType::U64)];
798        let indexes = &[1.into()];
799        let lhs_id = db.create_table_for_test("lhs", schema, indexes)?;
800
801        // Create table [rhs] with index on [b, c]
802        let schema = &[
803            ("b", AlgebraicType::U64),
804            ("c", AlgebraicType::U64),
805            ("d", AlgebraicType::U64),
806        ];
807        let indexes = &[0.into(), 1.into()];
808        let rhs_id = db.create_table_for_test("rhs", schema, indexes)?;
809
810        let tx = begin_tx(&db);
811        // Should generate an index join since there is an index on `lhs.b`.
812        // Should push the sargable range condition into the index join's probe side.
813        let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c > 2 and rhs.c < 4 and rhs.d = 3";
814        let exp = compile_sql(&db, &tx, sql)?.remove(0);
815
816        let CrudExpr::Query(QueryExpr {
817            source: SourceExpr::DbTable(DbTable { table_id, .. }),
818            query,
819            ..
820        }) = exp
821        else {
822            panic!("unexpected result from compilation: {exp:?}");
823        };
824
825        assert_eq!(table_id, lhs_id);
826        assert_eq!(query.len(), 1);
827
828        let Query::IndexJoin(IndexJoin {
829            probe_side:
830                QueryExpr {
831                    source: SourceExpr::DbTable(DbTable { table_id, .. }),
832                    query: rhs,
833                },
834            probe_col,
835            index_side: SourceExpr::DbTable(DbTable {
836                table_id: index_table, ..
837            }),
838            index_col,
839            ..
840        }) = &query[0]
841        else {
842            panic!("unexpected operator {:#?}", query[0]);
843        };
844
845        assert_eq!(*table_id, rhs_id);
846        assert_eq!(*index_table, lhs_id);
847        assert_eq!(index_col, &1.into());
848        assert_eq!(*probe_col, 0.into());
849
850        assert_eq!(2, rhs.len());
851
852        // The probe side of the join should be an index scan
853        let table_id = assert_index_scan(
854            &rhs[0],
855            1,
856            Bound::Excluded(AlgebraicValue::U64(2)),
857            Bound::Excluded(AlgebraicValue::U64(4)),
858        );
859
860        assert_eq!(table_id, rhs_id);
861
862        // Followed by a selection
863        let Query::Select(ColumnOp::ColCmpVal {
864            cmp: OpCmp::Eq,
865            lhs: ColId(2),
866            rhs: AlgebraicValue::U64(3),
867        }) = rhs[1]
868        else {
869            panic!("unexpected operator {:#?}", rhs[0]);
870        };
871        Ok(())
872    }
873
874    #[test]
875    fn compile_index_multi_join() -> ResultTest<()> {
876        let db = TestDB::durable()?;
877
878        // Create table [lhs] with index on [b]
879        let schema = &[("a", AlgebraicType::U64), ("b", AlgebraicType::U64)];
880        let indexes = &[1.into()];
881        let lhs_id = db.create_table_for_test("lhs", schema, indexes)?;
882
883        // Create table [rhs] with index on [b, c]
884        let schema = &[
885            ("b", AlgebraicType::U64),
886            ("c", AlgebraicType::U64),
887            ("d", AlgebraicType::U64),
888        ];
889        let indexes = col_list![0, 1];
890        let rhs_id = db.create_table_for_test_multi_column("rhs", schema, indexes)?;
891
892        let tx = begin_tx(&db);
893        // Should generate an index join since there is an index on `lhs.b`.
894        // Should push the sargable range condition into the index join's probe side.
895        let sql = "select lhs.* from lhs join rhs on lhs.b = rhs.b where rhs.c = 2 and rhs.b = 4 and rhs.d = 3";
896        let exp = compile_sql(&db, &tx, sql)?.remove(0);
897
898        let CrudExpr::Query(QueryExpr {
899            source: SourceExpr::DbTable(DbTable { table_id, .. }),
900            query,
901            ..
902        }) = exp
903        else {
904            panic!("unexpected result from compilation: {exp:?}");
905        };
906
907        assert_eq!(table_id, lhs_id);
908        assert_eq!(query.len(), 1);
909
910        let Query::IndexJoin(IndexJoin {
911            probe_side:
912                QueryExpr {
913                    source: SourceExpr::DbTable(DbTable { table_id, .. }),
914                    query: rhs,
915                },
916            probe_col,
917            index_side: SourceExpr::DbTable(DbTable {
918                table_id: index_table, ..
919            }),
920            index_col,
921            ..
922        }) = &query[0]
923        else {
924            panic!("unexpected operator {:#?}", query[0]);
925        };
926
927        assert_eq!(*table_id, rhs_id);
928        assert_eq!(*index_table, lhs_id);
929        assert_eq!(index_col, &1.into());
930        assert_eq!(*probe_col, 0.into());
931
932        assert_eq!(2, rhs.len());
933
934        // The probe side of the join should be an index scan
935        let table_id = assert_one_eq_index_scan(&rhs[0], col_list![0, 1], product![4u64, 2u64].into());
936
937        assert_eq!(table_id, rhs_id);
938
939        // Followed by a selection
940        let Query::Select(ColumnOp::ColCmpVal {
941            cmp: OpCmp::Eq,
942            lhs: ColId(2),
943            rhs: AlgebraicValue::U64(3),
944        }) = rhs[1]
945        else {
946            panic!("unexpected operator {:#?}", rhs[0]);
947        };
948        Ok(())
949    }
950
951    #[test]
952    fn compile_join_with_diff_col_names() -> ResultTest<()> {
953        let db = TestDB::durable()?;
954        db.create_table_for_test("A", &[("x", AlgebraicType::U64)], &[])?;
955        db.create_table_for_test("B", &[("y", AlgebraicType::U64)], &[])?;
956        assert!(compile_sql(&db, &begin_tx(&db), "select B.* from B join A on B.y = A.x").is_ok());
957        Ok(())
958    }
959
960    #[test]
961    fn compile_type_check() -> ResultTest<()> {
962        let db = TestDB::durable()?;
963        db.create_table_for_test("PlayerState", &[("entity_id", AlgebraicType::U64)], &[0.into()])?;
964        db.create_table_for_test("EnemyState", &[("entity_id", AlgebraicType::I8)], &[0.into()])?;
965        db.create_table_for_test("FriendState", &[("entity_id", AlgebraicType::U64)], &[0.into()])?;
966        let sql = "SELECT * FROM PlayerState WHERE entity_id = '161853'";
967
968        // Should fail with type mismatch for selections and joins.
969        //
970        // TODO: Type check other operations deferred for the new query engine.
971
972        assert!(
973            compile_sql(&db, &begin_tx(&db), sql).is_err(),
974            // Err("SqlError: Type Mismatch: `PlayerState.entity_id: U64` != `String(\"161853\"): String`, executing: `SELECT * FROM PlayerState WHERE entity_id = '161853'`".into())
975        );
976
977        // Check we can still compile the query if we remove the type mismatch and have multiple logical operations.
978        let sql = "SELECT * FROM PlayerState WHERE entity_id = 1 AND entity_id = 2 AND entity_id = 3 OR entity_id = 4 OR entity_id = 5";
979
980        assert!(compile_sql(&db, &begin_tx(&db), sql).is_ok());
981
982        // Now verify when we have a type mismatch in the middle of the logical operations.
983        let sql = "SELECT * FROM PlayerState WHERE entity_id = 1 AND entity_id";
984
985        assert!(
986            compile_sql(&db, &begin_tx(&db), sql).is_err(),
987            // Err("SqlError: Type Mismatch: `PlayerState.entity_id: U64 == U64(1): U64` and `PlayerState.entity_id: U64`, both sides must be an `Bool` expression, executing: `SELECT * FROM PlayerState WHERE entity_id = 1 AND entity_id`".into())
988        );
989        // Verify that all operands of `AND` must be `Bool`.
990        let sql = "SELECT * FROM PlayerState WHERE entity_id AND entity_id";
991
992        assert!(
993            compile_sql(&db, &begin_tx(&db), sql).is_err(),
994            // Err("SqlError: Type Mismatch: `PlayerState.entity_id: U64` and `PlayerState.entity_id: U64`, both sides must be an `Bool` expression, executing: `SELECT * FROM PlayerState WHERE entity_id AND entity_id`".into())
995        );
996        Ok(())
997    }
998}