spacetimedb/sql/
execute.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use super::ast::SchemaViewer;
5use crate::db::relational_db::{RelationalDB, Tx};
6use crate::energy::EnergyQuanta;
7use crate::error::DBError;
8use crate::estimation::estimate_rows_scanned;
9use crate::host::module_host::{DatabaseTableUpdate, DatabaseUpdate, EventStatus, ModuleEvent, ModuleFunctionCall};
10use crate::host::ArgsTuple;
11use crate::subscription::module_subscription_actor::{ModuleSubscriptions, WriteConflict};
12use crate::subscription::tx::DeltaTx;
13use crate::util::slow::SlowQueryLogger;
14use crate::vm::{check_row_limit, DbProgram, TxMode};
15use anyhow::anyhow;
16use spacetimedb_datastore::execution_context::Workload;
17use spacetimedb_datastore::locking_tx_datastore::state_view::StateView;
18use spacetimedb_datastore::traits::IsolationLevel;
19use spacetimedb_expr::statement::Statement;
20use spacetimedb_lib::identity::AuthCtx;
21use spacetimedb_lib::metrics::ExecutionMetrics;
22use spacetimedb_lib::Timestamp;
23use spacetimedb_lib::{AlgebraicType, ProductType, ProductValue};
24use spacetimedb_query::{compile_sql_stmt, execute_dml_stmt, execute_select_stmt};
25use spacetimedb_schema::relation::FieldName;
26use spacetimedb_vm::eval::run_ast;
27use spacetimedb_vm::expr::{CodeResult, CrudExpr, Expr};
28use spacetimedb_vm::relation::MemTable;
29
30pub struct StmtResult {
31    pub schema: ProductType,
32    pub rows: Vec<ProductValue>,
33}
34
// TODO(cloutiertyler): we could do this the Swift parsing way, in which
// we always generate a plan, even though it may contain errors.
37
38pub(crate) fn collect_result(
39    result: &mut Vec<MemTable>,
40    updates: &mut Vec<DatabaseTableUpdate>,
41    r: CodeResult,
42) -> Result<(), DBError> {
43    match r {
44        CodeResult::Value(_) => {}
45        CodeResult::Table(x) => result.push(x),
46        CodeResult::Block(lines) => {
47            for x in lines {
48                collect_result(result, updates, x)?;
49            }
50        }
51        CodeResult::Halt(err) => return Err(DBError::VmUser(err)),
52        CodeResult::Pass(x) => match x {
53            None => {}
54            Some(update) => {
55                updates.push(DatabaseTableUpdate {
56                    table_name: update.table_name,
57                    table_id: update.table_id,
58                    inserts: update.inserts.into(),
59                    deletes: update.deletes.into(),
60                });
61            }
62        },
63    }
64
65    Ok(())
66}
67
68fn execute(
69    p: &mut DbProgram<'_, '_>,
70    ast: Vec<CrudExpr>,
71    sql: &str,
72    updates: &mut Vec<DatabaseTableUpdate>,
73) -> Result<Vec<MemTable>, DBError> {
74    let slow_query_threshold = if let TxMode::Tx(tx) = p.tx {
75        p.db.query_limit(tx)?.map(Duration::from_millis)
76    } else {
77        None
78    };
79    let _slow_query_logger = SlowQueryLogger::new(sql, slow_query_threshold, p.tx.ctx().workload()).log_guard();
80    let mut result = Vec::with_capacity(ast.len());
81    let query = Expr::Block(ast.into_iter().map(|x| Expr::Crud(Box::new(x))).collect());
82    // SQL queries can never reference `MemTable`s, so pass an empty `SourceSet`.
83    collect_result(&mut result, updates, run_ast(p, query, [].into()).into())?;
84    Ok(result)
85}
86
87/// Run the compiled `SQL` expression inside the `vm` created by [DbProgram]
88///
89/// Evaluates `ast` and accordingly triggers mutable or read tx to execute
90///
91/// Also, in case the execution takes more than x, log it as `slow query`
92pub fn execute_sql(
93    db: &RelationalDB,
94    sql: &str,
95    ast: Vec<CrudExpr>,
96    auth: AuthCtx,
97    subs: Option<&ModuleSubscriptions>,
98) -> Result<Vec<MemTable>, DBError> {
99    if CrudExpr::is_reads(&ast) {
100        let mut updates = Vec::new();
101        db.with_read_only(Workload::Sql, |tx| {
102            execute(
103                &mut DbProgram::new(db, &mut TxMode::Tx(tx), auth),
104                ast,
105                sql,
106                &mut updates,
107            )
108        })
109    } else if subs.is_none() {
110        let mut updates = Vec::new();
111        db.with_auto_commit(Workload::Sql, |mut_tx| {
112            execute(
113                &mut DbProgram::new(db, &mut mut_tx.into(), auth),
114                ast,
115                sql,
116                &mut updates,
117            )
118        })
119    } else {
120        let mut tx = db.begin_mut_tx(IsolationLevel::Serializable, Workload::Sql);
121        let mut updates = Vec::with_capacity(ast.len());
122        let res = execute(
123            &mut DbProgram::new(db, &mut (&mut tx).into(), auth),
124            ast,
125            sql,
126            &mut updates,
127        );
128        if res.is_ok() && !updates.is_empty() {
129            let event = ModuleEvent {
130                timestamp: Timestamp::now(),
131                caller_identity: auth.caller,
132                caller_connection_id: None,
133                function_call: ModuleFunctionCall {
134                    reducer: String::new(),
135                    reducer_id: u32::MAX.into(),
136                    args: ArgsTuple::default(),
137                },
138                status: EventStatus::Committed(DatabaseUpdate { tables: updates }),
139                energy_quanta_used: EnergyQuanta::ZERO,
140                host_execution_duration: Duration::ZERO,
141                request_id: None,
142                timer: None,
143            };
144            match subs.unwrap().commit_and_broadcast_event(None, event, tx).unwrap() {
145                Ok(_) => res,
146                Err(WriteConflict) => todo!("See module_host_actor::call_reducer_with_tx"),
147            }
148        } else {
149            db.finish_tx(tx, res)
150        }
151    }
152}
153
154/// Like [`execute_sql`], but for providing your own `tx`.
155///
156/// Returns None if you pass a mutable query with an immutable tx.
157pub fn execute_sql_tx<'a>(
158    db: &RelationalDB,
159    tx: impl Into<TxMode<'a>>,
160    sql: &str,
161    ast: Vec<CrudExpr>,
162    auth: AuthCtx,
163) -> Result<Option<Vec<MemTable>>, DBError> {
164    let mut tx = tx.into();
165
166    if matches!(tx, TxMode::Tx(_)) && !CrudExpr::is_reads(&ast) {
167        return Ok(None);
168    }
169
170    let mut updates = Vec::new(); // No subscription updates in this path, because it requires owning the tx.
171    execute(&mut DbProgram::new(db, &mut tx, auth), ast, sql, &mut updates).map(Some)
172}
173
174pub struct SqlResult {
175    pub rows: Vec<ProductValue>,
176    /// These metrics will be reported via `report_tx_metrics`.
177    /// They should not be reported separately to avoid double counting.
178    pub metrics: ExecutionMetrics,
179}
180
181/// Run the `SQL` string using the `auth` credentials
182pub fn run(
183    db: &RelationalDB,
184    sql_text: &str,
185    auth: AuthCtx,
186    subs: Option<&ModuleSubscriptions>,
187    head: &mut Vec<(Box<str>, AlgebraicType)>,
188) -> Result<SqlResult, DBError> {
189    // We parse the sql statement in a mutable transaction.
190    // If it turns out to be a query, we downgrade the tx.
191    let (tx, stmt) = db.with_auto_rollback(db.begin_mut_tx(IsolationLevel::Serializable, Workload::Sql), |tx| {
192        compile_sql_stmt(sql_text, &SchemaViewer::new(tx, &auth), &auth)
193    })?;
194
195    let mut metrics = ExecutionMetrics::default();
196
197    match stmt {
198        Statement::Select(stmt) => {
199            // Up to this point, the tx has been read-only,
200            // and hence there are no deltas to process.
201            let (tx_data, tx_metrics_mut, tx) = tx.commit_downgrade(Workload::Sql);
202
203            // Release the tx on drop, so that we record metrics.
204            let mut tx = scopeguard::guard(tx, |tx| {
205                let (tx_metrics_downgrade, reducer) = db.release_tx(tx);
206                db.report_tx_metrics(
207                    reducer,
208                    Some(Arc::new(tx_data)),
209                    Some(tx_metrics_mut),
210                    Some(tx_metrics_downgrade),
211                );
212            });
213
214            // Compute the header for the result set
215            stmt.for_each_return_field(|col_name, col_type| {
216                head.push((col_name.into(), col_type.clone()));
217            });
218
219            // Evaluate the query
220            let rows = execute_select_stmt(stmt, &DeltaTx::from(&*tx), &mut metrics, |plan| {
221                check_row_limit(
222                    &[&plan],
223                    db,
224                    &tx,
225                    |plan, tx| plan.plan_iter().map(|plan| estimate_rows_scanned(tx, plan)).sum(),
226                    &auth,
227                )?;
228                Ok(plan)
229            })?;
230
231            // Update transaction metrics
232            tx.metrics.merge(metrics);
233
234            Ok(SqlResult {
235                rows,
236                metrics: tx.metrics,
237            })
238        }
239        Statement::DML(stmt) => {
240            // An extra layer of auth is required for DML
241            if auth.caller != auth.owner {
242                return Err(anyhow!("Only owners are authorized to run SQL DML statements").into());
243            }
244
245            // Evaluate the mutation
246            let (mut tx, _) = db.with_auto_rollback(tx, |tx| execute_dml_stmt(stmt, tx, &mut metrics))?;
247
248            // Update transaction metrics
249            tx.metrics.merge(metrics);
250
251            // Commit the tx if there are no deltas to process
252            if subs.is_none() {
253                let metrics = tx.metrics;
254                return db.commit_tx(tx).map(|tx_opt| {
255                    if let Some((tx_data, tx_metrics, reducer)) = tx_opt {
256                        db.report_mut_tx_metrics(reducer, tx_metrics, Some(tx_data));
257                    }
258                    SqlResult { rows: vec![], metrics }
259                });
260            }
261
262            // Otherwise downgrade the tx and process the deltas.
263            // Note, we get the delta by downgrading the tx.
264            // Hence we just pass a default `DatabaseUpdate` here.
265            // It will ultimately be replaced with the correct one.
266            match subs
267                .unwrap()
268                .commit_and_broadcast_event(
269                    None,
270                    ModuleEvent {
271                        timestamp: Timestamp::now(),
272                        caller_identity: auth.caller,
273                        caller_connection_id: None,
274                        function_call: ModuleFunctionCall {
275                            reducer: String::new(),
276                            reducer_id: u32::MAX.into(),
277                            args: ArgsTuple::default(),
278                        },
279                        status: EventStatus::Committed(DatabaseUpdate::default()),
280                        energy_quanta_used: EnergyQuanta::ZERO,
281                        host_execution_duration: Duration::ZERO,
282                        request_id: None,
283                        timer: None,
284                    },
285                    tx,
286                )
287                .unwrap()
288            {
289                Err(WriteConflict) => {
290                    todo!("See module_host_actor::call_reducer_with_tx")
291                }
292                Ok(_) => Ok(SqlResult { rows: vec![], metrics }),
293            }
294        }
295    }
296}
297
298/// Translates a `FieldName` to the field's name.
299pub fn translate_col(tx: &Tx, field: FieldName) -> Option<Box<str>> {
300    Some(
301        tx.get_schema(field.table)?
302            .get_column(field.col.idx())?
303            .col_name
304            .clone(),
305    )
306}
307
308#[cfg(test)]
309pub(crate) mod tests {
310    use std::sync::Arc;
311
312    use super::*;
313    use crate::db::relational_db::tests_utils::{begin_tx, insert, with_auto_commit, TestDB};
314    use crate::vm::tests::create_table_with_rows;
315    use itertools::Itertools;
316    use pretty_assertions::assert_eq;
317    use spacetimedb_datastore::system_tables::{
318        StRowLevelSecurityRow, StTableFields, ST_ROW_LEVEL_SECURITY_ID, ST_TABLE_ID, ST_TABLE_NAME,
319    };
320    use spacetimedb_lib::bsatn::ToBsatn;
321    use spacetimedb_lib::db::auth::{StAccess, StTableType};
322    use spacetimedb_lib::error::{ResultTest, TestError};
323    use spacetimedb_lib::{AlgebraicValue, Identity};
324    use spacetimedb_primitives::{col_list, ColId, TableId};
325    use spacetimedb_sats::{product, AlgebraicType, ArrayValue, ProductType};
326    use spacetimedb_schema::relation::Header;
327    use spacetimedb_vm::eval::test_helpers::create_game_data;
328
329    pub(crate) fn execute_for_testing(
330        db: &RelationalDB,
331        sql_text: &str,
332        q: Vec<CrudExpr>,
333    ) -> Result<Vec<MemTable>, DBError> {
334        let (subs, _runtime) = ModuleSubscriptions::for_test_new_runtime(Arc::new(db.clone()));
335        execute_sql(db, sql_text, q, AuthCtx::for_testing(), Some(&subs))
336    }
337
338    /// Short-cut for simplify test execution
339    pub(crate) fn run_for_testing(db: &RelationalDB, sql_text: &str) -> Result<Vec<ProductValue>, DBError> {
340        let (subs, _runtime) = ModuleSubscriptions::for_test_new_runtime(Arc::new(db.clone()));
341        run(db, sql_text, AuthCtx::for_testing(), Some(&subs), &mut vec![]).map(|x| x.rows)
342    }
343
344    fn create_data(total_rows: u64) -> ResultTest<(TestDB, MemTable)> {
345        let stdb = TestDB::durable()?;
346
347        let rows: Vec<_> = (1..=total_rows)
348            .map(|i| product!(i, format!("health{i}").into_boxed_str()))
349            .collect();
350        let head = ProductType::from([("inventory_id", AlgebraicType::U64), ("name", AlgebraicType::String)]);
351
352        let schema = with_auto_commit(&stdb, |tx| {
353            create_table_with_rows(&stdb, tx, "inventory", head.clone(), &rows, StAccess::Public)
354        })?;
355        let header = Header::from(&*schema).into();
356
357        Ok((stdb, MemTable::new(header, schema.table_access, rows)))
358    }
359
360    fn create_identity_table(table_name: &str) -> ResultTest<(TestDB, MemTable)> {
361        let stdb = TestDB::durable()?;
362        let head = ProductType::from([("identity", AlgebraicType::identity())]);
363        let rows = vec![product!(Identity::ZERO), product!(Identity::ONE)];
364
365        let schema = with_auto_commit(&stdb, |tx| {
366            create_table_with_rows(&stdb, tx, table_name, head.clone(), &rows, StAccess::Public)
367        })?;
368        let header = Header::from(&*schema).into();
369
370        Ok((stdb, MemTable::new(header, schema.table_access, rows)))
371    }
372
373    #[test]
374    fn test_select_star() -> ResultTest<()> {
375        let (db, input) = create_data(1)?;
376
377        let result = run_for_testing(&db, "SELECT * FROM inventory")?;
378
379        assert_eq!(result, input.data, "Inventory");
380        Ok(())
381    }
382
383    #[test]
384    fn test_limit() -> ResultTest<()> {
385        let (db, _) = create_data(5)?;
386
387        let result = run_for_testing(&db, "SELECT * FROM inventory limit 2")?;
388
389        let (_, input) = create_data(2)?;
390
391        assert_eq!(result, input.data, "Inventory");
392        Ok(())
393    }
394
395    #[test]
396    fn test_count() -> ResultTest<()> {
397        let (db, _) = create_data(5)?;
398
399        let sql = "SELECT count(*) as n FROM inventory";
400        let result = run_for_testing(&db, sql)?;
401        assert_eq!(result, vec![product![5u64]], "Inventory");
402
403        let sql = "SELECT count(*) as n FROM inventory limit 2";
404        let result = run_for_testing(&db, sql)?;
405        assert_eq!(result, vec![product![5u64]], "Inventory");
406
407        let sql = "SELECT count(*) as n FROM inventory WHERE inventory_id = 4 or inventory_id = 5";
408        let result = run_for_testing(&db, sql)?;
409        assert_eq!(result, vec![product![2u64]], "Inventory");
410        Ok(())
411    }
412
413    /// Test the evaluation of SELECT, UPDATE, and DELETE parameterized with `:sender`
414    #[test]
415    fn test_sender_param() -> ResultTest<()> {
416        let (db, _) = create_identity_table("user")?;
417
418        const SELECT_ALL: &str = "SELECT * FROM user";
419
420        let sql = "SELECT * FROM user WHERE identity = :sender";
421        let result = run_for_testing(&db, sql)?;
422        assert_eq!(result, vec![product![Identity::ZERO]]);
423
424        let sql = "DELETE FROM user WHERE identity = :sender";
425        run_for_testing(&db, sql)?;
426        let result = run_for_testing(&db, SELECT_ALL)?;
427        assert_eq!(result, vec![product![Identity::ONE]]);
428
429        let zero = "0".repeat(64);
430        let one = "0".repeat(63) + "1";
431
432        let sql = format!("UPDATE user SET identity = 0x{zero}");
433        run_for_testing(&db, &sql)?;
434        let sql = format!("UPDATE user SET identity = 0x{one} WHERE identity = :sender");
435        run_for_testing(&db, &sql)?;
436        let result = run_for_testing(&db, SELECT_ALL)?;
437        assert_eq!(result, vec![product![Identity::ONE]]);
438
439        Ok(())
440    }
441
442    /// Create an [Identity] from a [u8]
443    fn identity_from_u8(v: u8) -> Identity {
444        Identity::from_byte_array([v; 32])
445    }
446
447    /// Insert rules into the RLS system table
448    fn insert_rls_rules(
449        db: &RelationalDB,
450        table_ids: impl IntoIterator<Item = TableId>,
451        rules: impl IntoIterator<Item = &'static str>,
452    ) -> anyhow::Result<()> {
453        with_auto_commit(db, |tx| {
454            for (table_id, sql) in table_ids.into_iter().zip(rules) {
455                db.insert(
456                    tx,
457                    ST_ROW_LEVEL_SECURITY_ID,
458                    &ProductValue::from(StRowLevelSecurityRow {
459                        table_id,
460                        sql: sql.into(),
461                    })
462                    .to_bsatn_vec()?,
463                )?;
464            }
465            Ok(())
466        })
467    }
468
469    /// Insert product values into a table
470    fn insert_rows(
471        db: &RelationalDB,
472        table_id: TableId,
473        rows: impl IntoIterator<Item = ProductValue>,
474    ) -> anyhow::Result<()> {
475        with_auto_commit(db, |tx| {
476            for row in rows.into_iter() {
477                db.insert(tx, table_id, &row.to_bsatn_vec()?)?;
478            }
479            Ok(())
480        })
481    }
482
483    /// Assert this query returns the expected rows for this user
484    fn assert_query_results(
485        db: &RelationalDB,
486        sql: &str,
487        auth: &AuthCtx,
488        expected: impl IntoIterator<Item = ProductValue>,
489    ) {
490        assert_eq!(
491            run(db, sql, *auth, None, &mut vec![])
492                .unwrap()
493                .rows
494                .into_iter()
495                .sorted()
496                .dedup()
497                .collect::<Vec<_>>(),
498            expected.into_iter().sorted().dedup().collect::<Vec<_>>()
499        );
500    }
501
502    /// Test a query that uses a multi-column index
503    #[test]
504    fn test_multi_column_index() -> anyhow::Result<()> {
505        let db = TestDB::in_memory()?;
506
507        let schema = [
508            ("a", AlgebraicType::U64),
509            ("b", AlgebraicType::U64),
510            ("c", AlgebraicType::U64),
511        ];
512
513        let table_id = db.create_table_for_test_multi_column("t", &schema, [1, 2].into())?;
514
515        insert_rows(
516            &db,
517            table_id,
518            vec![
519                product![0_u64, 1_u64, 2_u64],
520                product![1_u64, 2_u64, 1_u64],
521                product![2_u64, 2_u64, 2_u64],
522            ],
523        )?;
524
525        assert_query_results(
526            &db,
527            "select * from t where c = 1 and b = 2",
528            &AuthCtx::for_testing(),
529            [product![1_u64, 2_u64, 1_u64]],
530        );
531
532        Ok(())
533    }
534
535    /// Test querying a table with RLS rules
536    #[test]
537    fn test_rls_rules() -> anyhow::Result<()> {
538        let db = TestDB::in_memory()?;
539
540        let id_for_a = identity_from_u8(1);
541        let id_for_b = identity_from_u8(2);
542
543        let users_schema = [("identity", AlgebraicType::identity())];
544        let sales_schema = [
545            ("order_id", AlgebraicType::U64),
546            ("customer", AlgebraicType::identity()),
547        ];
548
549        let users_table_id = db.create_table_for_test("users", &users_schema, &[])?;
550        let sales_table_id = db.create_table_for_test("sales", &sales_schema, &[])?;
551
552        insert_rows(&db, users_table_id, vec![product![id_for_a], product![id_for_b]])?;
553        insert_rows(
554            &db,
555            sales_table_id,
556            vec![
557                product![1u64, id_for_a],
558                product![2u64, id_for_b],
559                product![3u64, id_for_a],
560                product![4u64, id_for_b],
561            ],
562        )?;
563
564        insert_rls_rules(
565            &db,
566            [users_table_id, sales_table_id],
567            [
568                "select * from users where identity = :sender",
569                "select s.* from users u join sales s on u.identity = s.customer",
570            ],
571        )?;
572
573        let auth_for_a = AuthCtx::new(Identity::ZERO, id_for_a);
574        let auth_for_b = AuthCtx::new(Identity::ZERO, id_for_b);
575
576        assert_query_results(
577            &db,
578            // Should only return the identity for sender "a"
579            "select * from users",
580            &auth_for_a,
581            [product![id_for_a]],
582        );
583        assert_query_results(
584            &db,
585            // Should only return the identity for sender "b"
586            "select * from users",
587            &auth_for_b,
588            [product![id_for_b]],
589        );
590        assert_query_results(
591            &db,
592            // Should only return the orders for sender "a"
593            "select * from users where identity = :sender",
594            &auth_for_a,
595            [product![id_for_a]],
596        );
597        assert_query_results(
598            &db,
599            // Should only return the orders for sender "b"
600            "select * from users where identity = :sender",
601            &auth_for_b,
602            [product![id_for_b]],
603        );
604        assert_query_results(
605            &db,
606            // Should only return the orders for sender "a"
607            &format!("select * from users where identity = 0x{}", id_for_a.to_hex()),
608            &auth_for_a,
609            [product![id_for_a]],
610        );
611        assert_query_results(
612            &db,
613            // Should only return the orders for sender "b"
614            &format!("select * from users where identity = 0x{}", id_for_b.to_hex()),
615            &auth_for_b,
616            [product![id_for_b]],
617        );
618        assert_query_results(
619            &db,
620            // Should only return the orders for sender "a"
621            &format!(
622                "select * from users where identity = :sender and identity = 0x{}",
623                id_for_a.to_hex()
624            ),
625            &auth_for_a,
626            [product![id_for_a]],
627        );
628        assert_query_results(
629            &db,
630            // Should only return the orders for sender "b"
631            &format!(
632                "select * from users where identity = :sender and identity = 0x{}",
633                id_for_b.to_hex()
634            ),
635            &auth_for_b,
636            [product![id_for_b]],
637        );
638        assert_query_results(
639            &db,
640            // Should only return the orders for sender "a"
641            &format!(
642                "select * from users where identity = :sender or identity = 0x{}",
643                id_for_b.to_hex()
644            ),
645            &auth_for_a,
646            [product![id_for_a]],
647        );
648        assert_query_results(
649            &db,
650            // Should only return the orders for sender "b"
651            &format!(
652                "select * from users where identity = :sender or identity = 0x{}",
653                id_for_a.to_hex()
654            ),
655            &auth_for_b,
656            [product![id_for_b]],
657        );
658        assert_query_results(
659            &db,
660            // Should not return any rows.
661            // Querying as sender "a", but filtering on sender "b".
662            &format!("select * from users where identity = 0x{}", id_for_b.to_hex()),
663            &auth_for_a,
664            [],
665        );
666        assert_query_results(
667            &db,
668            // Should not return any rows.
669            // Querying as sender "b", but filtering on sender "a".
670            &format!("select * from users where identity = 0x{}", id_for_a.to_hex()),
671            &auth_for_b,
672            [],
673        );
674        assert_query_results(
675            &db,
676            // Should not return any rows.
677            // Querying as sender "a", but filtering on sender "b".
678            &format!(
679                "select * from users where identity = :sender and identity = 0x{}",
680                id_for_b.to_hex()
681            ),
682            &auth_for_a,
683            [],
684        );
685        assert_query_results(
686            &db,
687            // Should not return any rows.
688            // Querying as sender "b", but filtering on sender "a".
689            &format!(
690                "select * from users where identity = :sender and identity = 0x{}",
691                id_for_a.to_hex()
692            ),
693            &auth_for_b,
694            [],
695        );
696        assert_query_results(
697            &db,
698            // Should only return the orders for sender "a"
699            "select * from sales",
700            &auth_for_a,
701            [product![1u64, id_for_a], product![3u64, id_for_a]],
702        );
703        assert_query_results(
704            &db,
705            // Should only return the orders for sender "b"
706            "select * from sales",
707            &auth_for_b,
708            [product![2u64, id_for_b], product![4u64, id_for_b]],
709        );
710        assert_query_results(
711            &db,
712            // Should only return the orders for sender "a"
713            "select s.* from users u join sales s on u.identity = s.customer",
714            &auth_for_a,
715            [product![1u64, id_for_a], product![3u64, id_for_a]],
716        );
717        assert_query_results(
718            &db,
719            // Should only return the orders for sender "b"
720            "select s.* from users u join sales s on u.identity = s.customer",
721            &auth_for_b,
722            [product![2u64, id_for_b], product![4u64, id_for_b]],
723        );
724        assert_query_results(
725            &db,
726            // Should only return the orders for sender "a"
727            "select s.* from users u join sales s on u.identity = s.customer where u.identity = :sender",
728            &auth_for_a,
729            [product![1u64, id_for_a], product![3u64, id_for_a]],
730        );
731        assert_query_results(
732            &db,
733            // Should only return the orders for sender "b"
734            "select s.* from users u join sales s on u.identity = s.customer where u.identity = :sender",
735            &auth_for_b,
736            [product![2u64, id_for_b], product![4u64, id_for_b]],
737        );
738
739        Ok(())
740    }
741
742    /// Test querying tables with multiple levels of RLS rules
743    #[test]
744    fn test_nested_rls_rules() -> anyhow::Result<()> {
745        let db = TestDB::in_memory()?;
746
747        let id_for_a = identity_from_u8(1);
748        let id_for_b = identity_from_u8(2);
749        let id_for_c = identity_from_u8(3);
750
751        let users_schema = [("identity", AlgebraicType::identity())];
752        let sales_schema = [
753            ("order_id", AlgebraicType::U64),
754            ("product_id", AlgebraicType::U64),
755            ("customer", AlgebraicType::identity()),
756        ];
757
758        let users_table_id = db.create_table_for_test("users", &users_schema, &[0.into()])?;
759        let admin_table_id = db.create_table_for_test("admins", &users_schema, &[0.into()])?;
760        let sales_table_id = db.create_table_for_test("sales", &sales_schema, &[0.into()])?;
761
762        insert_rows(&db, admin_table_id, [product![id_for_c]])?;
763        insert_rows(
764            &db,
765            users_table_id,
766            [product![id_for_a], product![id_for_b], product![id_for_c]],
767        )?;
768        insert_rows(
769            &db,
770            sales_table_id,
771            [product![1u64, 1u64, id_for_a], product![2u64, 2u64, id_for_b]],
772        )?;
773
774        insert_rls_rules(
775            &db,
776            [admin_table_id, users_table_id, users_table_id, sales_table_id],
777            [
778                "select * from admins where identity = :sender",
779                "select * from users where identity = :sender",
780                "select users.* from admins join users",
781                "select s.* from users u join sales s on u.identity = s.customer",
782            ],
783        )?;
784
785        let auth_for_a = AuthCtx::new(Identity::ZERO, id_for_a);
786        let auth_for_b = AuthCtx::new(Identity::ZERO, id_for_b);
787        let auth_for_c = AuthCtx::new(Identity::ZERO, id_for_c);
788
789        assert_query_results(
790            &db,
791            "select * from admins",
792            &auth_for_a,
793            // Identity "a" is not an admin
794            [],
795        );
796        assert_query_results(
797            &db,
798            "select * from admins",
799            &auth_for_b,
800            // Identity "b" is not an admin
801            [],
802        );
803        assert_query_results(
804            &db,
805            "select * from admins",
806            &auth_for_c,
807            // Identity "c" is an admin
808            [product![id_for_c]],
809        );
810
811        assert_query_results(
812            &db,
813            "select * from users",
814            &auth_for_a,
815            // Identity "a" can only see its own user
816            vec![product![id_for_a]],
817        );
818        assert_query_results(
819            &db,
820            "select * from users",
821            &auth_for_b,
822            // Identity "b" can only see its own user
823            vec![product![id_for_b]],
824        );
825        assert_query_results(
826            &db,
827            "select * from users",
828            &auth_for_c,
829            // Identity "c" is an admin so it can see everyone's users
830            [product![id_for_a], product![id_for_b], product![id_for_c]],
831        );
832
833        assert_query_results(
834            &db,
835            "select * from sales",
836            &auth_for_a,
837            // Identity "a" can only see its own orders
838            [product![1u64, 1u64, id_for_a]],
839        );
840        assert_query_results(
841            &db,
842            "select * from sales",
843            &auth_for_b,
844            // Identity "b" can only see its own orders
845            [product![2u64, 2u64, id_for_b]],
846        );
847        assert_query_results(
848            &db,
849            "select * from sales",
850            &auth_for_c,
851            // Identity "c" is an admin so it can see everyone's orders
852            [product![1u64, 1u64, id_for_a], product![2u64, 2u64, id_for_b]],
853        );
854
855        Ok(())
856    }
857
858    /// Test projecting columns from both tables in join
859    #[test]
860    fn test_project_join() -> anyhow::Result<()> {
861        let db = TestDB::in_memory()?;
862
863        let t_schema = [("id", AlgebraicType::U8), ("x", AlgebraicType::U8)];
864        let s_schema = [("id", AlgebraicType::U8), ("y", AlgebraicType::U8)];
865
866        let t_id = db.create_table_for_test("t", &t_schema, &[0.into()])?;
867        let s_id = db.create_table_for_test("s", &s_schema, &[0.into()])?;
868
869        insert_rows(&db, t_id, [product![1_u8, 2_u8]])?;
870        insert_rows(&db, s_id, [product![1_u8, 3_u8]])?;
871
872        let id = identity_from_u8(1);
873        let auth = AuthCtx::new(Identity::ZERO, id);
874
875        assert_query_results(
876            &db,
877            "select t.x, s.y from t join s on t.id = s.id",
878            &auth,
879            [product![2_u8, 3_u8]],
880        );
881
882        Ok(())
883    }
884
885    #[test]
886    fn test_select_star_table() -> ResultTest<()> {
887        let (db, input) = create_data(1)?;
888
889        let result = run_for_testing(&db, "SELECT inventory.* FROM inventory")?;
890
891        assert_eq!(result, input.data, "Inventory");
892
893        let result = run_for_testing(
894            &db,
895            "SELECT inventory.inventory_id FROM inventory WHERE inventory.inventory_id = 1",
896        )?;
897
898        assert_eq!(result, vec![product!(1u64)], "Inventory");
899
900        Ok(())
901    }
902
903    #[test]
904    fn test_select_catalog() -> ResultTest<()> {
905        let (db, _) = create_data(1)?;
906
907        let tx = begin_tx(&db);
908        let _ = db.release_tx(tx);
909
910        let result = run_for_testing(
911            &db,
912            &format!("SELECT * FROM {ST_TABLE_NAME} WHERE table_id = {ST_TABLE_ID}"),
913        )?;
914
915        let pk_col_id: ColId = StTableFields::TableId.into();
916        let row = product![
917            ST_TABLE_ID,
918            ST_TABLE_NAME,
919            StTableType::System.as_str(),
920            StAccess::Public.as_str(),
921            Some(AlgebraicValue::Array(ArrayValue::U16(vec![pk_col_id.0].into()))),
922        ];
923
924        assert_eq!(result, vec![row], "st_table");
925        Ok(())
926    }
927
928    #[test]
929    fn test_select_column() -> ResultTest<()> {
930        let (db, _) = create_data(1)?;
931
932        let result = run_for_testing(&db, "SELECT inventory_id FROM inventory")?;
933
934        let row = product![1u64];
935
936        assert_eq!(result, vec![row], "Inventory");
937        Ok(())
938    }
939
940    #[test]
941    fn test_where() -> ResultTest<()> {
942        let (db, _) = create_data(1)?;
943
944        let result = run_for_testing(&db, "SELECT inventory_id FROM inventory WHERE inventory_id = 1")?;
945
946        let row = product![1u64];
947
948        assert_eq!(result, vec![row], "Inventory");
949        Ok(())
950    }
951
952    #[test]
953    fn test_or() -> ResultTest<()> {
954        let (db, _) = create_data(2)?;
955
956        let mut result = run_for_testing(
957            &db,
958            "SELECT inventory_id FROM inventory WHERE inventory_id = 1 OR inventory_id = 2",
959        )?;
960
961        result.sort();
962
963        assert_eq!(result, vec![product![1u64], product![2u64]], "Inventory");
964        Ok(())
965    }
966
967    #[test]
968    fn test_nested() -> ResultTest<()> {
969        let (db, _) = create_data(2)?;
970
971        let mut result = run_for_testing(
972            &db,
973            "SELECT inventory_id FROM inventory WHERE (inventory_id = 1 OR inventory_id = 2 AND (true))",
974        )?;
975
976        result.sort();
977
978        assert_eq!(result, vec![product![1u64], product![2u64]], "Inventory");
979        Ok(())
980    }
981
982    #[test]
983    fn test_inner_join() -> ResultTest<()> {
984        let data = create_game_data();
985
986        let db = TestDB::durable()?;
987
988        with_auto_commit::<_, TestError>(&db, |tx| {
989            let i = create_table_with_rows(&db, tx, "Inventory", data.inv_ty, &data.inv.data, StAccess::Public)?;
990            let p = create_table_with_rows(&db, tx, "Player", data.player_ty, &data.player.data, StAccess::Public)?;
991            create_table_with_rows(
992                &db,
993                tx,
994                "Location",
995                data.location_ty,
996                &data.location.data,
997                StAccess::Public,
998            )?;
999            Ok((p, i))
1000        })?;
1001
1002        let result = run_for_testing(
1003            &db,
1004            "SELECT
1005        Player.*
1006            FROM
1007        Player
1008        JOIN Location
1009        ON Location.entity_id = Player.entity_id
1010        WHERE Location.x > 0 AND Location.x <= 32 AND Location.z > 0 AND Location.z <= 32",
1011        )?;
1012
1013        let row1 = product!(100u64, 1u64);
1014
1015        assert_eq!(result, vec![row1], "Player JOIN Location");
1016
1017        let result = run_for_testing(
1018            &db,
1019            "SELECT
1020        Inventory.*
1021            FROM
1022        Inventory
1023        JOIN Player
1024        ON Inventory.inventory_id = Player.inventory_id
1025        JOIN Location
1026        ON Player.entity_id = Location.entity_id
1027        WHERE Location.x > 0 AND Location.x <= 32 AND Location.z > 0 AND Location.z <= 32",
1028        )?;
1029
1030        let row1 = product!(1u64, "health");
1031
1032        assert_eq!(result, vec![row1], "Inventory JOIN Player JOIN Location");
1033        Ok(())
1034    }
1035
1036    #[test]
1037    fn test_insert() -> ResultTest<()> {
1038        let (db, mut input) = create_data(1)?;
1039
1040        let result = run_for_testing(&db, "INSERT INTO inventory (inventory_id, name) VALUES (2, 'test')")?;
1041
1042        assert_eq!(result.len(), 0, "Return results");
1043
1044        let mut result = run_for_testing(&db, "SELECT * FROM inventory")?;
1045
1046        input.data.push(product![2u64, "test"]);
1047        input.data.sort();
1048        result.sort();
1049
1050        assert_eq!(result, input.data, "Inventory");
1051
1052        Ok(())
1053    }
1054
1055    #[test]
1056    fn test_delete() -> ResultTest<()> {
1057        let (db, _input) = create_data(1)?;
1058
1059        run_for_testing(&db, "INSERT INTO inventory (inventory_id, name) VALUES (2, 't2')")?;
1060        run_for_testing(&db, "INSERT INTO inventory (inventory_id, name) VALUES (3, 't3')")?;
1061
1062        let result = run_for_testing(&db, "SELECT * FROM inventory")?;
1063        assert_eq!(result.len(), 3, "Not return results");
1064
1065        run_for_testing(&db, "DELETE FROM inventory WHERE inventory.inventory_id = 3")?;
1066
1067        let result = run_for_testing(&db, "SELECT * FROM inventory")?;
1068        assert_eq!(result.len(), 2, "Not delete correct row?");
1069
1070        run_for_testing(&db, "DELETE FROM inventory")?;
1071
1072        let result = run_for_testing(&db, "SELECT * FROM inventory")?;
1073        assert_eq!(result.len(), 0, "Not delete all rows");
1074
1075        Ok(())
1076    }
1077
1078    #[test]
1079    fn test_update() -> ResultTest<()> {
1080        let (db, input) = create_data(1)?;
1081
1082        run_for_testing(&db, "INSERT INTO inventory (inventory_id, name) VALUES (2, 't2')")?;
1083        run_for_testing(&db, "INSERT INTO inventory (inventory_id, name) VALUES (3, 't3')")?;
1084
1085        run_for_testing(&db, "UPDATE inventory SET name = 'c2' WHERE inventory_id = 2")?;
1086
1087        let result = run_for_testing(&db, "SELECT * FROM inventory WHERE inventory_id = 2")?;
1088
1089        let mut change = input;
1090        change.data.clear();
1091        change.data.push(product![2u64, "c2"]);
1092
1093        assert_eq!(result, change.data, "Update Inventory 2");
1094
1095        run_for_testing(&db, "UPDATE inventory SET name = 'c3'")?;
1096
1097        let result = run_for_testing(&db, "SELECT * FROM inventory")?;
1098
1099        let updated: Vec<_> = result
1100            .into_iter()
1101            .map(|x| x.field_as_str(1, None).unwrap().to_string())
1102            .collect();
1103        assert_eq!(vec!["c3"; 3], updated);
1104
1105        Ok(())
1106    }
1107
1108    #[test]
1109    fn test_multi_column() -> ResultTest<()> {
1110        let (db, _input) = create_data(1)?;
1111
1112        // Create table [test] with index on [a, b]
1113        let schema = &[
1114            ("a", AlgebraicType::I32),
1115            ("b", AlgebraicType::I32),
1116            ("c", AlgebraicType::I32),
1117            ("d", AlgebraicType::I32),
1118        ];
1119        let table_id = db.create_table_for_test_multi_column("test", schema, col_list![0, 1])?;
1120        with_auto_commit(&db, |tx| insert(&db, tx, table_id, &product![1, 1, 1, 1]).map(drop))?;
1121
1122        let result = run_for_testing(&db, "select * from test where b = 1 and a = 1")?;
1123
1124        assert_eq!(result, vec![product![1, 1, 1, 1]]);
1125
1126        Ok(())
1127    }
1128
1129    /// Test we are protected against stack overflows when:
1130    /// 1. The query is too large (too many characters)
1131    /// 2. The AST is too deep
1132    ///
1133    /// Exercise the limit [`recursion::MAX_RECURSION_EXPR`]
1134    #[test]
1135    fn test_large_query_no_panic() -> ResultTest<()> {
1136        let db = TestDB::durable()?;
1137
1138        let _table_id = db
1139            .create_table_for_test_multi_column(
1140                "test",
1141                &[("x", AlgebraicType::I32), ("y", AlgebraicType::I32)],
1142                col_list![0, 1],
1143            )
1144            .unwrap();
1145
1146        let build_query = |total| {
1147            let mut sql = "select * from test where ".to_string();
1148            for x in 1..total {
1149                let fragment = format!("x = {x} or ");
1150                sql.push_str(&fragment.repeat((total - 1) as usize));
1151            }
1152            sql.push_str("(y = 0)");
1153            sql
1154        };
1155        let run = |db: &RelationalDB, sep: char, sql_text: &str| {
1156            run_for_testing(db, sql_text).map_err(|e| e.to_string().split(sep).next().unwrap_or_default().to_string())
1157        };
1158        let sql = build_query(1_000);
1159        assert_eq!(
1160            run(&db, ':', &sql),
1161            Err("SQL query exceeds maximum allowed length".to_string())
1162        );
1163
1164        let sql = build_query(41); // This causes stack overflow without the limit
1165        assert_eq!(run(&db, ',', &sql), Err("Recursion limit exceeded".to_string()));
1166
1167        let sql = build_query(40); // The max we can with the current limit
1168        assert!(run(&db, ',', &sql).is_ok(), "Expected query to run without panic");
1169
1170        // Check no overflow with lot of joins
1171        let mut sql = "SELECT test.* FROM test ".to_string();
1172        // We could push up to 700 joins without overflow as long we don't have any conditions,
1173        // but here execution become too slow.
1174        // TODO: Move this test to the `Plan`
1175        for i in 0..200 {
1176            sql.push_str(&format!("JOIN test AS m{i} ON test.x = m{i}.y "));
1177        }
1178
1179        assert!(
1180            run(&db, ',', &sql).is_ok(),
1181            "Query with many joins and conditions should not overflow"
1182        );
1183        Ok(())
1184    }
1185
1186    #[test]
1187    fn test_impossible_bounds_no_panic() -> ResultTest<()> {
1188        let db = TestDB::durable()?;
1189
1190        let table_id = db
1191            .create_table_for_test("test", &[("x", AlgebraicType::I32)], &[ColId(0)])
1192            .unwrap();
1193
1194        with_auto_commit(&db, |tx| {
1195            for i in 0..1000i32 {
1196                insert(&db, tx, table_id, &product!(i)).unwrap();
1197            }
1198            Ok::<(), DBError>(())
1199        })
1200        .unwrap();
1201
1202        let result = run_for_testing(&db, "select * from test where x > 5 and x < 5").unwrap();
1203        assert!(result.is_empty());
1204
1205        let result = run_for_testing(&db, "select * from test where x >= 5 and x < 4").unwrap();
1206        assert!(result.is_empty(), "Expected no rows but found {result:#?}");
1207
1208        let result = run_for_testing(&db, "select * from test where x > 5 and x <= 4").unwrap();
1209        assert!(result.is_empty());
1210        Ok(())
1211    }
1212
1213    #[test]
1214    fn test_multi_column_two_ranges() -> ResultTest<()> {
1215        let db = TestDB::durable()?;
1216
1217        // Create table [test] with index on [a, b]
1218        let schema = &[("a", AlgebraicType::U8), ("b", AlgebraicType::U8)];
1219        let table_id = db.create_table_for_test_multi_column("test", schema, col_list![0, 1])?;
1220        let row = product![4u8, 8u8];
1221        with_auto_commit(&db, |tx| insert(&db, tx, table_id, &row.clone()).map(drop))?;
1222
1223        let result = run_for_testing(&db, "select * from test where a >= 3 and a <= 5 and b >= 3 and b <= 5")?;
1224
1225        assert!(result.is_empty());
1226
1227        Ok(())
1228    }
1229
1230    #[test]
1231    fn test_row_limit() -> ResultTest<()> {
1232        let db = TestDB::durable()?;
1233
1234        let table_id = db.create_table_for_test("T", &[("a", AlgebraicType::U8)], &[])?;
1235        with_auto_commit(&db, |tx| -> Result<_, DBError> {
1236            for i in 0..5u8 {
1237                insert(&db, tx, table_id, &product!(i))?;
1238            }
1239            Ok(())
1240        })?;
1241
1242        let server = Identity::from_claims("issuer", "server");
1243        let client = Identity::from_claims("issuer", "client");
1244
1245        let internal_auth = AuthCtx::new(server, server);
1246        let external_auth = AuthCtx::new(server, client);
1247
1248        let run = |db, sql, auth, subs| run(db, sql, auth, subs, &mut vec![]);
1249
1250        // No row limit, both queries pass.
1251        assert!(run(&db, "SELECT * FROM T", internal_auth, None).is_ok());
1252        assert!(run(&db, "SELECT * FROM T", external_auth, None).is_ok());
1253
1254        // Set row limit.
1255        assert!(run(&db, "SET row_limit = 4", internal_auth, None).is_ok());
1256
1257        // External query fails.
1258        assert!(run(&db, "SELECT * FROM T", internal_auth, None).is_ok());
1259        assert!(run(&db, "SELECT * FROM T", external_auth, None).is_err());
1260
1261        // Increase row limit.
1262        assert!(run(&db, "DELETE FROM st_var WHERE name = 'row_limit'", internal_auth, None).is_ok());
1263        assert!(run(&db, "SET row_limit = 5", internal_auth, None).is_ok());
1264
1265        // Both queries pass.
1266        assert!(run(&db, "SELECT * FROM T", internal_auth, None).is_ok());
1267        assert!(run(&db, "SELECT * FROM T", external_auth, None).is_ok());
1268
1269        Ok(())
1270    }
1271
1272    // Verify we don't return rows on DML
1273    #[test]
1274    fn test_row_dml() -> ResultTest<()> {
1275        let db = TestDB::durable()?;
1276
1277        let table_id = db.create_table_for_test("T", &[("a", AlgebraicType::U8)], &[])?;
1278        with_auto_commit(&db, |tx| -> Result<_, DBError> {
1279            for i in 0..4u8 {
1280                insert(&db, tx, table_id, &product!(i))?;
1281            }
1282            Ok(())
1283        })?;
1284
1285        let server = Identity::from_claims("issuer", "server");
1286
1287        let internal_auth = AuthCtx::new(server, server);
1288
1289        let run = |db, sql, auth, subs| run(db, sql, auth, subs, &mut vec![]);
1290
1291        let check = |db, sql, auth, metrics: ExecutionMetrics| {
1292            let result = run(db, sql, auth, None)?;
1293            assert_eq!(result.rows, vec![]);
1294            assert_eq!(result.metrics.rows_inserted, metrics.rows_inserted);
1295            assert_eq!(result.metrics.rows_deleted, metrics.rows_deleted);
1296            assert_eq!(result.metrics.rows_updated, metrics.rows_updated);
1297
1298            Ok::<(), DBError>(())
1299        };
1300
1301        let ins = ExecutionMetrics {
1302            rows_inserted: 1,
1303            ..ExecutionMetrics::default()
1304        };
1305        let upd = ExecutionMetrics {
1306            rows_updated: 5,
1307            ..ExecutionMetrics::default()
1308        };
1309        let del = ExecutionMetrics {
1310            rows_deleted: 1,
1311            ..ExecutionMetrics::default()
1312        };
1313
1314        check(&db, "INSERT INTO T (a) VALUES (5)", internal_auth, ins)?;
1315        check(&db, "UPDATE T SET a = 2", internal_auth, upd)?;
1316        assert_eq!(
1317            run(&db, "SELECT * FROM T", internal_auth, None)?.rows,
1318            vec![product!(2u8)]
1319        );
1320        check(&db, "DELETE FROM T", internal_auth, del)?;
1321
1322        Ok(())
1323    }
1324}