spacetimedb/sql/
execute.rs

1use std::time::Duration;
2
3use super::ast::SchemaViewer;
4use crate::db::datastore::locking_tx_datastore::state_view::StateView;
5use crate::db::datastore::traits::IsolationLevel;
6use crate::db::relational_db::{RelationalDB, Tx};
7use crate::energy::EnergyQuanta;
8use crate::error::DBError;
9use crate::estimation::estimate_rows_scanned;
10use crate::execution_context::Workload;
11use crate::host::module_host::{DatabaseTableUpdate, DatabaseUpdate, EventStatus, ModuleEvent, ModuleFunctionCall};
12use crate::host::ArgsTuple;
13use crate::subscription::module_subscription_actor::{ModuleSubscriptions, WriteConflict};
14use crate::subscription::tx::DeltaTx;
15use crate::util::slow::SlowQueryLogger;
16use crate::vm::{check_row_limit, DbProgram, TxMode};
17use anyhow::anyhow;
18use spacetimedb_expr::statement::Statement;
19use spacetimedb_lib::identity::AuthCtx;
20use spacetimedb_lib::metrics::ExecutionMetrics;
21use spacetimedb_lib::relation::FieldName;
22use spacetimedb_lib::Timestamp;
23use spacetimedb_lib::{AlgebraicType, ProductType, ProductValue};
24use spacetimedb_query::{compile_sql_stmt, execute_dml_stmt, execute_select_stmt};
25use spacetimedb_vm::eval::run_ast;
26use spacetimedb_vm::expr::{CodeResult, CrudExpr, Expr};
27use spacetimedb_vm::relation::MemTable;
28
29pub struct StmtResult {
30    pub schema: ProductType,
31    pub rows: Vec<ProductValue>,
32}
33
34// TODO(cloutiertyler): we could do this the swift parsing way in which
35// we always generate a plan, but it may contain errors
36
37pub(crate) fn collect_result(
38    result: &mut Vec<MemTable>,
39    updates: &mut Vec<DatabaseTableUpdate>,
40    r: CodeResult,
41) -> Result<(), DBError> {
42    match r {
43        CodeResult::Value(_) => {}
44        CodeResult::Table(x) => result.push(x),
45        CodeResult::Block(lines) => {
46            for x in lines {
47                collect_result(result, updates, x)?;
48            }
49        }
50        CodeResult::Halt(err) => return Err(DBError::VmUser(err)),
51        CodeResult::Pass(x) => match x {
52            None => {}
53            Some(update) => {
54                updates.push(DatabaseTableUpdate {
55                    table_name: update.table_name,
56                    table_id: update.table_id,
57                    inserts: update.inserts.into(),
58                    deletes: update.deletes.into(),
59                });
60            }
61        },
62    }
63
64    Ok(())
65}
66
67fn execute(
68    p: &mut DbProgram<'_, '_>,
69    ast: Vec<CrudExpr>,
70    sql: &str,
71    updates: &mut Vec<DatabaseTableUpdate>,
72) -> Result<Vec<MemTable>, DBError> {
73    let slow_query_threshold = if let TxMode::Tx(tx) = p.tx {
74        p.db.query_limit(tx)?.map(Duration::from_millis)
75    } else {
76        None
77    };
78    let _slow_query_logger = SlowQueryLogger::new(sql, slow_query_threshold, p.tx.ctx().workload()).log_guard();
79    let mut result = Vec::with_capacity(ast.len());
80    let query = Expr::Block(ast.into_iter().map(|x| Expr::Crud(Box::new(x))).collect());
81    // SQL queries can never reference `MemTable`s, so pass an empty `SourceSet`.
82    collect_result(&mut result, updates, run_ast(p, query, [].into()).into())?;
83    Ok(result)
84}
85
86/// Run the compiled `SQL` expression inside the `vm` created by [DbProgram]
87///
88/// Evaluates `ast` and accordingly triggers mutable or read tx to execute
89///
90/// Also, in case the execution takes more than x, log it as `slow query`
91pub fn execute_sql(
92    db: &RelationalDB,
93    sql: &str,
94    ast: Vec<CrudExpr>,
95    auth: AuthCtx,
96    subs: Option<&ModuleSubscriptions>,
97) -> Result<Vec<MemTable>, DBError> {
98    if CrudExpr::is_reads(&ast) {
99        let mut updates = Vec::new();
100        db.with_read_only(Workload::Sql, |tx| {
101            execute(
102                &mut DbProgram::new(db, &mut TxMode::Tx(tx), auth),
103                ast,
104                sql,
105                &mut updates,
106            )
107        })
108    } else if subs.is_none() {
109        let mut updates = Vec::new();
110        db.with_auto_commit(Workload::Sql, |mut_tx| {
111            execute(
112                &mut DbProgram::new(db, &mut mut_tx.into(), auth),
113                ast,
114                sql,
115                &mut updates,
116            )
117        })
118    } else {
119        let mut tx = db.begin_mut_tx(IsolationLevel::Serializable, Workload::Sql);
120        let mut updates = Vec::with_capacity(ast.len());
121        let res = execute(
122            &mut DbProgram::new(db, &mut (&mut tx).into(), auth),
123            ast,
124            sql,
125            &mut updates,
126        );
127        if res.is_ok() && !updates.is_empty() {
128            let event = ModuleEvent {
129                timestamp: Timestamp::now(),
130                caller_identity: auth.caller,
131                caller_connection_id: None,
132                function_call: ModuleFunctionCall {
133                    reducer: String::new(),
134                    reducer_id: u32::MAX.into(),
135                    args: ArgsTuple::default(),
136                },
137                status: EventStatus::Committed(DatabaseUpdate { tables: updates }),
138                energy_quanta_used: EnergyQuanta::ZERO,
139                host_execution_duration: Duration::ZERO,
140                request_id: None,
141                timer: None,
142            };
143            match subs.unwrap().commit_and_broadcast_event(None, event, tx).unwrap() {
144                Ok(_) => res,
145                Err(WriteConflict) => todo!("See module_host_actor::call_reducer_with_tx"),
146            }
147        } else {
148            db.finish_tx(tx, res)
149        }
150    }
151}
152
153/// Like [`execute_sql`], but for providing your own `tx`.
154///
155/// Returns None if you pass a mutable query with an immutable tx.
156pub fn execute_sql_tx<'a>(
157    db: &RelationalDB,
158    tx: impl Into<TxMode<'a>>,
159    sql: &str,
160    ast: Vec<CrudExpr>,
161    auth: AuthCtx,
162) -> Result<Option<Vec<MemTable>>, DBError> {
163    let mut tx = tx.into();
164
165    if matches!(tx, TxMode::Tx(_)) && !CrudExpr::is_reads(&ast) {
166        return Ok(None);
167    }
168
169    let mut updates = Vec::new(); // No subscription updates in this path, because it requires owning the tx.
170    execute(&mut DbProgram::new(db, &mut tx, auth), ast, sql, &mut updates).map(Some)
171}
172
173pub struct SqlResult {
174    pub rows: Vec<ProductValue>,
175    /// These metrics will be reported via `report_tx_metrics`.
176    /// They should not be reported separately to avoid double counting.
177    pub metrics: ExecutionMetrics,
178}
179
180/// Run the `SQL` string using the `auth` credentials
181pub fn run(
182    db: &RelationalDB,
183    sql_text: &str,
184    auth: AuthCtx,
185    subs: Option<&ModuleSubscriptions>,
186    head: &mut Vec<(Box<str>, AlgebraicType)>,
187) -> Result<SqlResult, DBError> {
188    // We parse the sql statement in a mutable transaction.
189    // If it turns out to be a query, we downgrade the tx.
190    let (tx, stmt) = db.with_auto_rollback(db.begin_mut_tx(IsolationLevel::Serializable, Workload::Sql), |tx| {
191        compile_sql_stmt(sql_text, &SchemaViewer::new(tx, &auth), &auth)
192    })?;
193
194    let mut metrics = ExecutionMetrics::default();
195
196    match stmt {
197        Statement::Select(stmt) => {
198            // Up to this point, the tx has been read-only,
199            // and hence there are no deltas to process.
200            let (tx_data, tx_metrics_mut, tx) = tx.commit_downgrade(Workload::Sql);
201
202            // Release the tx on drop, so that we record metrics.
203            let mut tx = scopeguard::guard(tx, |tx| {
204                let (tx_metrics_downgrade, reducer) = db.release_tx(tx);
205                db.report_tx_metricses(&reducer, Some(&tx_data), Some(&tx_metrics_mut), &tx_metrics_downgrade);
206            });
207
208            // Compute the header for the result set
209            stmt.for_each_return_field(|col_name, col_type| {
210                head.push((col_name.into(), col_type.clone()));
211            });
212
213            // Evaluate the query
214            let rows = execute_select_stmt(stmt, &DeltaTx::from(&*tx), &mut metrics, |plan| {
215                check_row_limit(
216                    &[&plan],
217                    db,
218                    &tx,
219                    |plan, tx| plan.plan_iter().map(|plan| estimate_rows_scanned(tx, plan)).sum(),
220                    &auth,
221                )?;
222                Ok(plan)
223            })?;
224
225            // Update transaction metrics
226            tx.metrics.merge(metrics);
227
228            Ok(SqlResult {
229                rows,
230                metrics: tx.metrics,
231            })
232        }
233        Statement::DML(stmt) => {
234            // An extra layer of auth is required for DML
235            if auth.caller != auth.owner {
236                return Err(anyhow!("Only owners are authorized to run SQL DML statements").into());
237            }
238
239            // Evaluate the mutation
240            let (mut tx, _) = db.with_auto_rollback(tx, |tx| execute_dml_stmt(stmt, tx, &mut metrics))?;
241
242            // Update transaction metrics
243            tx.metrics.merge(metrics);
244
245            // Commit the tx if there are no deltas to process
246            if subs.is_none() {
247                let metrics = tx.metrics;
248                return db.commit_tx(tx).map(|tx_opt| {
249                    if let Some((tx_data, tx_metrics, reducer)) = tx_opt {
250                        db.report(&reducer, &tx_metrics, Some(&tx_data));
251                    }
252                    SqlResult { rows: vec![], metrics }
253                });
254            }
255
256            // Otherwise downgrade the tx and process the deltas.
257            // Note, we get the delta by downgrading the tx.
258            // Hence we just pass a default `DatabaseUpdate` here.
259            // It will ultimately be replaced with the correct one.
260            match subs
261                .unwrap()
262                .commit_and_broadcast_event(
263                    None,
264                    ModuleEvent {
265                        timestamp: Timestamp::now(),
266                        caller_identity: auth.caller,
267                        caller_connection_id: None,
268                        function_call: ModuleFunctionCall {
269                            reducer: String::new(),
270                            reducer_id: u32::MAX.into(),
271                            args: ArgsTuple::default(),
272                        },
273                        status: EventStatus::Committed(DatabaseUpdate::default()),
274                        energy_quanta_used: EnergyQuanta::ZERO,
275                        host_execution_duration: Duration::ZERO,
276                        request_id: None,
277                        timer: None,
278                    },
279                    tx,
280                )
281                .unwrap()
282            {
283                Err(WriteConflict) => {
284                    todo!("See module_host_actor::call_reducer_with_tx")
285                }
286                Ok(_) => Ok(SqlResult { rows: vec![], metrics }),
287            }
288        }
289    }
290}
291
292/// Translates a `FieldName` to the field's name.
293pub fn translate_col(tx: &Tx, field: FieldName) -> Option<Box<str>> {
294    Some(
295        tx.get_schema(field.table)?
296            .get_column(field.col.idx())?
297            .col_name
298            .clone(),
299    )
300}
301
302#[cfg(test)]
303pub(crate) mod tests {
304    use std::sync::Arc;
305
306    use super::*;
307    use crate::db::datastore::system_tables::{
308        StRowLevelSecurityRow, StTableFields, ST_ROW_LEVEL_SECURITY_ID, ST_TABLE_ID, ST_TABLE_NAME,
309    };
310    use crate::db::relational_db::tests_utils::{begin_tx, insert, with_auto_commit, TestDB};
311    use crate::vm::tests::create_table_with_rows;
312    use itertools::Itertools;
313    use pretty_assertions::assert_eq;
314    use spacetimedb_lib::bsatn::ToBsatn;
315    use spacetimedb_lib::db::auth::{StAccess, StTableType};
316    use spacetimedb_lib::error::{ResultTest, TestError};
317    use spacetimedb_lib::relation::Header;
318    use spacetimedb_lib::{AlgebraicValue, Identity};
319    use spacetimedb_primitives::{col_list, ColId, TableId};
320    use spacetimedb_sats::{product, AlgebraicType, ArrayValue, ProductType};
321    use spacetimedb_vm::eval::test_helpers::create_game_data;
322
323    pub(crate) fn execute_for_testing(
324        db: &RelationalDB,
325        sql_text: &str,
326        q: Vec<CrudExpr>,
327    ) -> Result<Vec<MemTable>, DBError> {
328        let (subs, _runtime) = ModuleSubscriptions::for_test_new_runtime(Arc::new(db.clone()));
329        execute_sql(db, sql_text, q, AuthCtx::for_testing(), Some(&subs))
330    }
331
332    /// Short-cut for simplify test execution
333    pub(crate) fn run_for_testing(db: &RelationalDB, sql_text: &str) -> Result<Vec<ProductValue>, DBError> {
334        let (subs, _runtime) = ModuleSubscriptions::for_test_new_runtime(Arc::new(db.clone()));
335        run(db, sql_text, AuthCtx::for_testing(), Some(&subs), &mut vec![]).map(|x| x.rows)
336    }
337
338    fn create_data(total_rows: u64) -> ResultTest<(TestDB, MemTable)> {
339        let stdb = TestDB::durable()?;
340
341        let rows: Vec<_> = (1..=total_rows)
342            .map(|i| product!(i, format!("health{i}").into_boxed_str()))
343            .collect();
344        let head = ProductType::from([("inventory_id", AlgebraicType::U64), ("name", AlgebraicType::String)]);
345
346        let schema = with_auto_commit(&stdb, |tx| {
347            create_table_with_rows(&stdb, tx, "inventory", head.clone(), &rows, StAccess::Public)
348        })?;
349        let header = Header::from(&*schema).into();
350
351        Ok((stdb, MemTable::new(header, schema.table_access, rows)))
352    }
353
354    fn create_identity_table(table_name: &str) -> ResultTest<(TestDB, MemTable)> {
355        let stdb = TestDB::durable()?;
356        let head = ProductType::from([("identity", AlgebraicType::identity())]);
357        let rows = vec![product!(Identity::ZERO), product!(Identity::ONE)];
358
359        let schema = with_auto_commit(&stdb, |tx| {
360            create_table_with_rows(&stdb, tx, table_name, head.clone(), &rows, StAccess::Public)
361        })?;
362        let header = Header::from(&*schema).into();
363
364        Ok((stdb, MemTable::new(header, schema.table_access, rows)))
365    }
366
367    #[test]
368    fn test_select_star() -> ResultTest<()> {
369        let (db, input) = create_data(1)?;
370
371        let result = run_for_testing(&db, "SELECT * FROM inventory")?;
372
373        assert_eq!(result, input.data, "Inventory");
374        Ok(())
375    }
376
377    #[test]
378    fn test_limit() -> ResultTest<()> {
379        let (db, _) = create_data(5)?;
380
381        let result = run_for_testing(&db, "SELECT * FROM inventory limit 2")?;
382
383        let (_, input) = create_data(2)?;
384
385        assert_eq!(result, input.data, "Inventory");
386        Ok(())
387    }
388
389    #[test]
390    fn test_count() -> ResultTest<()> {
391        let (db, _) = create_data(5)?;
392
393        let sql = "SELECT count(*) as n FROM inventory";
394        let result = run_for_testing(&db, sql)?;
395        assert_eq!(result, vec![product![5u64]], "Inventory");
396
397        let sql = "SELECT count(*) as n FROM inventory limit 2";
398        let result = run_for_testing(&db, sql)?;
399        assert_eq!(result, vec![product![5u64]], "Inventory");
400
401        let sql = "SELECT count(*) as n FROM inventory WHERE inventory_id = 4 or inventory_id = 5";
402        let result = run_for_testing(&db, sql)?;
403        assert_eq!(result, vec![product![2u64]], "Inventory");
404        Ok(())
405    }
406
407    /// Test the evaluation of SELECT, UPDATE, and DELETE parameterized with `:sender`
408    #[test]
409    fn test_sender_param() -> ResultTest<()> {
410        let (db, _) = create_identity_table("user")?;
411
412        const SELECT_ALL: &str = "SELECT * FROM user";
413
414        let sql = "SELECT * FROM user WHERE identity = :sender";
415        let result = run_for_testing(&db, sql)?;
416        assert_eq!(result, vec![product![Identity::ZERO]]);
417
418        let sql = "DELETE FROM user WHERE identity = :sender";
419        run_for_testing(&db, sql)?;
420        let result = run_for_testing(&db, SELECT_ALL)?;
421        assert_eq!(result, vec![product![Identity::ONE]]);
422
423        let zero = "0".repeat(64);
424        let one = "0".repeat(63) + "1";
425
426        let sql = format!("UPDATE user SET identity = 0x{zero}");
427        run_for_testing(&db, &sql)?;
428        let sql = format!("UPDATE user SET identity = 0x{one} WHERE identity = :sender");
429        run_for_testing(&db, &sql)?;
430        let result = run_for_testing(&db, SELECT_ALL)?;
431        assert_eq!(result, vec![product![Identity::ONE]]);
432
433        Ok(())
434    }
435
436    /// Create an [Identity] from a [u8]
437    fn identity_from_u8(v: u8) -> Identity {
438        Identity::from_byte_array([v; 32])
439    }
440
441    /// Insert rules into the RLS system table
442    fn insert_rls_rules(
443        db: &RelationalDB,
444        table_ids: impl IntoIterator<Item = TableId>,
445        rules: impl IntoIterator<Item = &'static str>,
446    ) -> anyhow::Result<()> {
447        with_auto_commit(db, |tx| {
448            for (table_id, sql) in table_ids.into_iter().zip(rules) {
449                db.insert(
450                    tx,
451                    ST_ROW_LEVEL_SECURITY_ID,
452                    &ProductValue::from(StRowLevelSecurityRow {
453                        table_id,
454                        sql: sql.into(),
455                    })
456                    .to_bsatn_vec()?,
457                )?;
458            }
459            Ok(())
460        })
461    }
462
463    /// Insert product values into a table
464    fn insert_rows(
465        db: &RelationalDB,
466        table_id: TableId,
467        rows: impl IntoIterator<Item = ProductValue>,
468    ) -> anyhow::Result<()> {
469        with_auto_commit(db, |tx| {
470            for row in rows.into_iter() {
471                db.insert(tx, table_id, &row.to_bsatn_vec()?)?;
472            }
473            Ok(())
474        })
475    }
476
477    /// Assert this query returns the expected rows for this user
478    fn assert_query_results(
479        db: &RelationalDB,
480        sql: &str,
481        auth: &AuthCtx,
482        expected: impl IntoIterator<Item = ProductValue>,
483    ) {
484        assert_eq!(
485            run(db, sql, *auth, None, &mut vec![])
486                .unwrap()
487                .rows
488                .into_iter()
489                .sorted()
490                .dedup()
491                .collect::<Vec<_>>(),
492            expected.into_iter().sorted().dedup().collect::<Vec<_>>()
493        );
494    }
495
496    /// Test a query that uses a multi-column index
497    #[test]
498    fn test_multi_column_index() -> anyhow::Result<()> {
499        let db = TestDB::in_memory()?;
500
501        let schema = [
502            ("a", AlgebraicType::U64),
503            ("b", AlgebraicType::U64),
504            ("c", AlgebraicType::U64),
505        ];
506
507        let table_id = db.create_table_for_test_multi_column("t", &schema, [1, 2].into())?;
508
509        insert_rows(
510            &db,
511            table_id,
512            vec![
513                product![0_u64, 1_u64, 2_u64],
514                product![1_u64, 2_u64, 1_u64],
515                product![2_u64, 2_u64, 2_u64],
516            ],
517        )?;
518
519        assert_query_results(
520            &db,
521            "select * from t where c = 1 and b = 2",
522            &AuthCtx::for_testing(),
523            [product![1_u64, 2_u64, 1_u64]],
524        );
525
526        Ok(())
527    }
528
529    /// Test querying a table with RLS rules
530    #[test]
531    fn test_rls_rules() -> anyhow::Result<()> {
532        let db = TestDB::in_memory()?;
533
534        let id_for_a = identity_from_u8(1);
535        let id_for_b = identity_from_u8(2);
536
537        let users_schema = [("identity", AlgebraicType::identity())];
538        let sales_schema = [
539            ("order_id", AlgebraicType::U64),
540            ("customer", AlgebraicType::identity()),
541        ];
542
543        let users_table_id = db.create_table_for_test("users", &users_schema, &[])?;
544        let sales_table_id = db.create_table_for_test("sales", &sales_schema, &[])?;
545
546        insert_rows(&db, users_table_id, vec![product![id_for_a], product![id_for_b]])?;
547        insert_rows(
548            &db,
549            sales_table_id,
550            vec![
551                product![1u64, id_for_a],
552                product![2u64, id_for_b],
553                product![3u64, id_for_a],
554                product![4u64, id_for_b],
555            ],
556        )?;
557
558        insert_rls_rules(
559            &db,
560            [users_table_id, sales_table_id],
561            [
562                "select * from users where identity = :sender",
563                "select s.* from users u join sales s on u.identity = s.customer",
564            ],
565        )?;
566
567        let auth_for_a = AuthCtx::new(Identity::ZERO, id_for_a);
568        let auth_for_b = AuthCtx::new(Identity::ZERO, id_for_b);
569
570        assert_query_results(
571            &db,
572            // Should only return the identity for sender "a"
573            "select * from users",
574            &auth_for_a,
575            [product![id_for_a]],
576        );
577        assert_query_results(
578            &db,
579            // Should only return the identity for sender "b"
580            "select * from users",
581            &auth_for_b,
582            [product![id_for_b]],
583        );
584        assert_query_results(
585            &db,
586            // Should only return the orders for sender "a"
587            "select * from users where identity = :sender",
588            &auth_for_a,
589            [product![id_for_a]],
590        );
591        assert_query_results(
592            &db,
593            // Should only return the orders for sender "b"
594            "select * from users where identity = :sender",
595            &auth_for_b,
596            [product![id_for_b]],
597        );
598        assert_query_results(
599            &db,
600            // Should only return the orders for sender "a"
601            &format!("select * from users where identity = 0x{}", id_for_a.to_hex()),
602            &auth_for_a,
603            [product![id_for_a]],
604        );
605        assert_query_results(
606            &db,
607            // Should only return the orders for sender "b"
608            &format!("select * from users where identity = 0x{}", id_for_b.to_hex()),
609            &auth_for_b,
610            [product![id_for_b]],
611        );
612        assert_query_results(
613            &db,
614            // Should only return the orders for sender "a"
615            &format!(
616                "select * from users where identity = :sender and identity = 0x{}",
617                id_for_a.to_hex()
618            ),
619            &auth_for_a,
620            [product![id_for_a]],
621        );
622        assert_query_results(
623            &db,
624            // Should only return the orders for sender "b"
625            &format!(
626                "select * from users where identity = :sender and identity = 0x{}",
627                id_for_b.to_hex()
628            ),
629            &auth_for_b,
630            [product![id_for_b]],
631        );
632        assert_query_results(
633            &db,
634            // Should only return the orders for sender "a"
635            &format!(
636                "select * from users where identity = :sender or identity = 0x{}",
637                id_for_b.to_hex()
638            ),
639            &auth_for_a,
640            [product![id_for_a]],
641        );
642        assert_query_results(
643            &db,
644            // Should only return the orders for sender "b"
645            &format!(
646                "select * from users where identity = :sender or identity = 0x{}",
647                id_for_a.to_hex()
648            ),
649            &auth_for_b,
650            [product![id_for_b]],
651        );
652        assert_query_results(
653            &db,
654            // Should not return any rows.
655            // Querying as sender "a", but filtering on sender "b".
656            &format!("select * from users where identity = 0x{}", id_for_b.to_hex()),
657            &auth_for_a,
658            [],
659        );
660        assert_query_results(
661            &db,
662            // Should not return any rows.
663            // Querying as sender "b", but filtering on sender "a".
664            &format!("select * from users where identity = 0x{}", id_for_a.to_hex()),
665            &auth_for_b,
666            [],
667        );
668        assert_query_results(
669            &db,
670            // Should not return any rows.
671            // Querying as sender "a", but filtering on sender "b".
672            &format!(
673                "select * from users where identity = :sender and identity = 0x{}",
674                id_for_b.to_hex()
675            ),
676            &auth_for_a,
677            [],
678        );
679        assert_query_results(
680            &db,
681            // Should not return any rows.
682            // Querying as sender "b", but filtering on sender "a".
683            &format!(
684                "select * from users where identity = :sender and identity = 0x{}",
685                id_for_a.to_hex()
686            ),
687            &auth_for_b,
688            [],
689        );
690        assert_query_results(
691            &db,
692            // Should only return the orders for sender "a"
693            "select * from sales",
694            &auth_for_a,
695            [product![1u64, id_for_a], product![3u64, id_for_a]],
696        );
697        assert_query_results(
698            &db,
699            // Should only return the orders for sender "b"
700            "select * from sales",
701            &auth_for_b,
702            [product![2u64, id_for_b], product![4u64, id_for_b]],
703        );
704        assert_query_results(
705            &db,
706            // Should only return the orders for sender "a"
707            "select s.* from users u join sales s on u.identity = s.customer",
708            &auth_for_a,
709            [product![1u64, id_for_a], product![3u64, id_for_a]],
710        );
711        assert_query_results(
712            &db,
713            // Should only return the orders for sender "b"
714            "select s.* from users u join sales s on u.identity = s.customer",
715            &auth_for_b,
716            [product![2u64, id_for_b], product![4u64, id_for_b]],
717        );
718        assert_query_results(
719            &db,
720            // Should only return the orders for sender "a"
721            "select s.* from users u join sales s on u.identity = s.customer where u.identity = :sender",
722            &auth_for_a,
723            [product![1u64, id_for_a], product![3u64, id_for_a]],
724        );
725        assert_query_results(
726            &db,
727            // Should only return the orders for sender "b"
728            "select s.* from users u join sales s on u.identity = s.customer where u.identity = :sender",
729            &auth_for_b,
730            [product![2u64, id_for_b], product![4u64, id_for_b]],
731        );
732
733        Ok(())
734    }
735
736    /// Test querying tables with multiple levels of RLS rules
737    #[test]
738    fn test_nested_rls_rules() -> anyhow::Result<()> {
739        let db = TestDB::in_memory()?;
740
741        let id_for_a = identity_from_u8(1);
742        let id_for_b = identity_from_u8(2);
743        let id_for_c = identity_from_u8(3);
744
745        let users_schema = [("identity", AlgebraicType::identity())];
746        let sales_schema = [
747            ("order_id", AlgebraicType::U64),
748            ("product_id", AlgebraicType::U64),
749            ("customer", AlgebraicType::identity()),
750        ];
751
752        let users_table_id = db.create_table_for_test("users", &users_schema, &[0.into()])?;
753        let admin_table_id = db.create_table_for_test("admins", &users_schema, &[0.into()])?;
754        let sales_table_id = db.create_table_for_test("sales", &sales_schema, &[0.into()])?;
755
756        insert_rows(&db, admin_table_id, [product![id_for_c]])?;
757        insert_rows(
758            &db,
759            users_table_id,
760            [product![id_for_a], product![id_for_b], product![id_for_c]],
761        )?;
762        insert_rows(
763            &db,
764            sales_table_id,
765            [product![1u64, 1u64, id_for_a], product![2u64, 2u64, id_for_b]],
766        )?;
767
768        insert_rls_rules(
769            &db,
770            [admin_table_id, users_table_id, users_table_id, sales_table_id],
771            [
772                "select * from admins where identity = :sender",
773                "select * from users where identity = :sender",
774                "select users.* from admins join users",
775                "select s.* from users u join sales s on u.identity = s.customer",
776            ],
777        )?;
778
779        let auth_for_a = AuthCtx::new(Identity::ZERO, id_for_a);
780        let auth_for_b = AuthCtx::new(Identity::ZERO, id_for_b);
781        let auth_for_c = AuthCtx::new(Identity::ZERO, id_for_c);
782
783        assert_query_results(
784            &db,
785            "select * from admins",
786            &auth_for_a,
787            // Identity "a" is not an admin
788            [],
789        );
790        assert_query_results(
791            &db,
792            "select * from admins",
793            &auth_for_b,
794            // Identity "b" is not an admin
795            [],
796        );
797        assert_query_results(
798            &db,
799            "select * from admins",
800            &auth_for_c,
801            // Identity "c" is an admin
802            [product![id_for_c]],
803        );
804
805        assert_query_results(
806            &db,
807            "select * from users",
808            &auth_for_a,
809            // Identity "a" can only see its own user
810            vec![product![id_for_a]],
811        );
812        assert_query_results(
813            &db,
814            "select * from users",
815            &auth_for_b,
816            // Identity "b" can only see its own user
817            vec![product![id_for_b]],
818        );
819        assert_query_results(
820            &db,
821            "select * from users",
822            &auth_for_c,
823            // Identity "c" is an admin so it can see everyone's users
824            [product![id_for_a], product![id_for_b], product![id_for_c]],
825        );
826
827        assert_query_results(
828            &db,
829            "select * from sales",
830            &auth_for_a,
831            // Identity "a" can only see its own orders
832            [product![1u64, 1u64, id_for_a]],
833        );
834        assert_query_results(
835            &db,
836            "select * from sales",
837            &auth_for_b,
838            // Identity "b" can only see its own orders
839            [product![2u64, 2u64, id_for_b]],
840        );
841        assert_query_results(
842            &db,
843            "select * from sales",
844            &auth_for_c,
845            // Identity "c" is an admin so it can see everyone's orders
846            [product![1u64, 1u64, id_for_a], product![2u64, 2u64, id_for_b]],
847        );
848
849        Ok(())
850    }
851
852    /// Test projecting columns from both tables in join
853    #[test]
854    fn test_project_join() -> anyhow::Result<()> {
855        let db = TestDB::in_memory()?;
856
857        let t_schema = [("id", AlgebraicType::U8), ("x", AlgebraicType::U8)];
858        let s_schema = [("id", AlgebraicType::U8), ("y", AlgebraicType::U8)];
859
860        let t_id = db.create_table_for_test("t", &t_schema, &[0.into()])?;
861        let s_id = db.create_table_for_test("s", &s_schema, &[0.into()])?;
862
863        insert_rows(&db, t_id, [product![1_u8, 2_u8]])?;
864        insert_rows(&db, s_id, [product![1_u8, 3_u8]])?;
865
866        let id = identity_from_u8(1);
867        let auth = AuthCtx::new(Identity::ZERO, id);
868
869        assert_query_results(
870            &db,
871            "select t.x, s.y from t join s on t.id = s.id",
872            &auth,
873            [product![2_u8, 3_u8]],
874        );
875
876        Ok(())
877    }
878
879    #[test]
880    fn test_select_star_table() -> ResultTest<()> {
881        let (db, input) = create_data(1)?;
882
883        let result = run_for_testing(&db, "SELECT inventory.* FROM inventory")?;
884
885        assert_eq!(result, input.data, "Inventory");
886
887        let result = run_for_testing(
888            &db,
889            "SELECT inventory.inventory_id FROM inventory WHERE inventory.inventory_id = 1",
890        )?;
891
892        assert_eq!(result, vec![product!(1u64)], "Inventory");
893
894        Ok(())
895    }
896
897    #[test]
898    fn test_select_catalog() -> ResultTest<()> {
899        let (db, _) = create_data(1)?;
900
901        let tx = begin_tx(&db);
902        let _ = db.release_tx(tx);
903
904        let result = run_for_testing(
905            &db,
906            &format!("SELECT * FROM {} WHERE table_id = {}", ST_TABLE_NAME, ST_TABLE_ID),
907        )?;
908
909        let pk_col_id: ColId = StTableFields::TableId.into();
910        let row = product![
911            ST_TABLE_ID,
912            ST_TABLE_NAME,
913            StTableType::System.as_str(),
914            StAccess::Public.as_str(),
915            Some(AlgebraicValue::Array(ArrayValue::U16(vec![pk_col_id.0].into()))),
916        ];
917
918        assert_eq!(result, vec![row], "st_table");
919        Ok(())
920    }
921
922    #[test]
923    fn test_select_column() -> ResultTest<()> {
924        let (db, _) = create_data(1)?;
925
926        let result = run_for_testing(&db, "SELECT inventory_id FROM inventory")?;
927
928        let row = product![1u64];
929
930        assert_eq!(result, vec![row], "Inventory");
931        Ok(())
932    }
933
934    #[test]
935    fn test_where() -> ResultTest<()> {
936        let (db, _) = create_data(1)?;
937
938        let result = run_for_testing(&db, "SELECT inventory_id FROM inventory WHERE inventory_id = 1")?;
939
940        let row = product![1u64];
941
942        assert_eq!(result, vec![row], "Inventory");
943        Ok(())
944    }
945
946    #[test]
947    fn test_or() -> ResultTest<()> {
948        let (db, _) = create_data(2)?;
949
950        let mut result = run_for_testing(
951            &db,
952            "SELECT inventory_id FROM inventory WHERE inventory_id = 1 OR inventory_id = 2",
953        )?;
954
955        result.sort();
956
957        assert_eq!(result, vec![product![1u64], product![2u64]], "Inventory");
958        Ok(())
959    }
960
961    #[test]
962    fn test_nested() -> ResultTest<()> {
963        let (db, _) = create_data(2)?;
964
965        let mut result = run_for_testing(
966            &db,
967            "SELECT inventory_id FROM inventory WHERE (inventory_id = 1 OR inventory_id = 2 AND (true))",
968        )?;
969
970        result.sort();
971
972        assert_eq!(result, vec![product![1u64], product![2u64]], "Inventory");
973        Ok(())
974    }
975
976    #[test]
977    fn test_inner_join() -> ResultTest<()> {
978        let data = create_game_data();
979
980        let db = TestDB::durable()?;
981
982        with_auto_commit::<_, TestError>(&db, |tx| {
983            let i = create_table_with_rows(&db, tx, "Inventory", data.inv_ty, &data.inv.data, StAccess::Public)?;
984            let p = create_table_with_rows(&db, tx, "Player", data.player_ty, &data.player.data, StAccess::Public)?;
985            create_table_with_rows(
986                &db,
987                tx,
988                "Location",
989                data.location_ty,
990                &data.location.data,
991                StAccess::Public,
992            )?;
993            Ok((p, i))
994        })?;
995
996        let result = run_for_testing(
997            &db,
998            "SELECT
999        Player.*
1000            FROM
1001        Player
1002        JOIN Location
1003        ON Location.entity_id = Player.entity_id
1004        WHERE Location.x > 0 AND Location.x <= 32 AND Location.z > 0 AND Location.z <= 32",
1005        )?;
1006
1007        let row1 = product!(100u64, 1u64);
1008
1009        assert_eq!(result, vec![row1], "Player JOIN Location");
1010
1011        let result = run_for_testing(
1012            &db,
1013            "SELECT
1014        Inventory.*
1015            FROM
1016        Inventory
1017        JOIN Player
1018        ON Inventory.inventory_id = Player.inventory_id
1019        JOIN Location
1020        ON Player.entity_id = Location.entity_id
1021        WHERE Location.x > 0 AND Location.x <= 32 AND Location.z > 0 AND Location.z <= 32",
1022        )?;
1023
1024        let row1 = product!(1u64, "health");
1025
1026        assert_eq!(result, vec![row1], "Inventory JOIN Player JOIN Location");
1027        Ok(())
1028    }
1029
1030    #[test]
1031    fn test_insert() -> ResultTest<()> {
1032        let (db, mut input) = create_data(1)?;
1033
1034        let result = run_for_testing(&db, "INSERT INTO inventory (inventory_id, name) VALUES (2, 'test')")?;
1035
1036        assert_eq!(result.len(), 0, "Return results");
1037
1038        let mut result = run_for_testing(&db, "SELECT * FROM inventory")?;
1039
1040        input.data.push(product![2u64, "test"]);
1041        input.data.sort();
1042        result.sort();
1043
1044        assert_eq!(result, input.data, "Inventory");
1045
1046        Ok(())
1047    }
1048
1049    #[test]
1050    fn test_delete() -> ResultTest<()> {
1051        let (db, _input) = create_data(1)?;
1052
1053        run_for_testing(&db, "INSERT INTO inventory (inventory_id, name) VALUES (2, 't2')")?;
1054        run_for_testing(&db, "INSERT INTO inventory (inventory_id, name) VALUES (3, 't3')")?;
1055
1056        let result = run_for_testing(&db, "SELECT * FROM inventory")?;
1057        assert_eq!(result.len(), 3, "Not return results");
1058
1059        run_for_testing(&db, "DELETE FROM inventory WHERE inventory.inventory_id = 3")?;
1060
1061        let result = run_for_testing(&db, "SELECT * FROM inventory")?;
1062        assert_eq!(result.len(), 2, "Not delete correct row?");
1063
1064        run_for_testing(&db, "DELETE FROM inventory")?;
1065
1066        let result = run_for_testing(&db, "SELECT * FROM inventory")?;
1067        assert_eq!(result.len(), 0, "Not delete all rows");
1068
1069        Ok(())
1070    }
1071
1072    #[test]
1073    fn test_update() -> ResultTest<()> {
1074        let (db, input) = create_data(1)?;
1075
1076        run_for_testing(&db, "INSERT INTO inventory (inventory_id, name) VALUES (2, 't2')")?;
1077        run_for_testing(&db, "INSERT INTO inventory (inventory_id, name) VALUES (3, 't3')")?;
1078
1079        run_for_testing(&db, "UPDATE inventory SET name = 'c2' WHERE inventory_id = 2")?;
1080
1081        let result = run_for_testing(&db, "SELECT * FROM inventory WHERE inventory_id = 2")?;
1082
1083        let mut change = input;
1084        change.data.clear();
1085        change.data.push(product![2u64, "c2"]);
1086
1087        assert_eq!(result, change.data, "Update Inventory 2");
1088
1089        run_for_testing(&db, "UPDATE inventory SET name = 'c3'")?;
1090
1091        let result = run_for_testing(&db, "SELECT * FROM inventory")?;
1092
1093        let updated: Vec<_> = result
1094            .into_iter()
1095            .map(|x| x.field_as_str(1, None).unwrap().to_string())
1096            .collect();
1097        assert_eq!(vec!["c3"; 3], updated);
1098
1099        Ok(())
1100    }
1101
1102    #[test]
1103    fn test_multi_column() -> ResultTest<()> {
1104        let (db, _input) = create_data(1)?;
1105
1106        // Create table [test] with index on [a, b]
1107        let schema = &[
1108            ("a", AlgebraicType::I32),
1109            ("b", AlgebraicType::I32),
1110            ("c", AlgebraicType::I32),
1111            ("d", AlgebraicType::I32),
1112        ];
1113        let table_id = db.create_table_for_test_multi_column("test", schema, col_list![0, 1])?;
1114        with_auto_commit(&db, |tx| insert(&db, tx, table_id, &product![1, 1, 1, 1]).map(drop))?;
1115
1116        let result = run_for_testing(&db, "select * from test where b = 1 and a = 1")?;
1117
1118        assert_eq!(result, vec![product![1, 1, 1, 1]]);
1119
1120        Ok(())
1121    }
1122
1123    #[test]
1124    fn test_large_query_no_panic() -> ResultTest<()> {
1125        let db = TestDB::durable()?;
1126
1127        let _table_id = db
1128            .create_table_for_test_multi_column(
1129                "test",
1130                &[("x", AlgebraicType::I32), ("y", AlgebraicType::I32)],
1131                col_list![0, 1],
1132            )
1133            .unwrap();
1134
1135        let mut query = "select * from test where ".to_string();
1136        for x in 0..1_000 {
1137            for y in 0..1_000 {
1138                let fragment = format!("((x = {x}) and y = {y}) or");
1139                query.push_str(&fragment);
1140            }
1141        }
1142        query.push_str("((x = 1000) and (y = 1000))");
1143
1144        assert!(run_for_testing(&db, &query).is_err());
1145        Ok(())
1146    }
1147
1148    #[test]
1149    fn test_impossible_bounds_no_panic() -> ResultTest<()> {
1150        let db = TestDB::durable()?;
1151
1152        let table_id = db
1153            .create_table_for_test("test", &[("x", AlgebraicType::I32)], &[ColId(0)])
1154            .unwrap();
1155
1156        with_auto_commit(&db, |tx| {
1157            for i in 0..1000i32 {
1158                insert(&db, tx, table_id, &product!(i)).unwrap();
1159            }
1160            Ok::<(), DBError>(())
1161        })
1162        .unwrap();
1163
1164        let result = run_for_testing(&db, "select * from test where x > 5 and x < 5").unwrap();
1165        assert!(result.is_empty());
1166
1167        let result = run_for_testing(&db, "select * from test where x >= 5 and x < 4").unwrap();
1168        assert!(result.is_empty(), "Expected no rows but found {:#?}", result);
1169
1170        let result = run_for_testing(&db, "select * from test where x > 5 and x <= 4").unwrap();
1171        assert!(result.is_empty());
1172        Ok(())
1173    }
1174
1175    #[test]
1176    fn test_multi_column_two_ranges() -> ResultTest<()> {
1177        let db = TestDB::durable()?;
1178
1179        // Create table [test] with index on [a, b]
1180        let schema = &[("a", AlgebraicType::U8), ("b", AlgebraicType::U8)];
1181        let table_id = db.create_table_for_test_multi_column("test", schema, col_list![0, 1])?;
1182        let row = product![4u8, 8u8];
1183        with_auto_commit(&db, |tx| insert(&db, tx, table_id, &row.clone()).map(drop))?;
1184
1185        let result = run_for_testing(&db, "select * from test where a >= 3 and a <= 5 and b >= 3 and b <= 5")?;
1186
1187        assert!(result.is_empty());
1188
1189        Ok(())
1190    }
1191
1192    #[test]
1193    fn test_row_limit() -> ResultTest<()> {
1194        let db = TestDB::durable()?;
1195
1196        let table_id = db.create_table_for_test("T", &[("a", AlgebraicType::U8)], &[])?;
1197        with_auto_commit(&db, |tx| -> Result<_, DBError> {
1198            for i in 0..5u8 {
1199                insert(&db, tx, table_id, &product!(i))?;
1200            }
1201            Ok(())
1202        })?;
1203
1204        let server = Identity::from_claims("issuer", "server");
1205        let client = Identity::from_claims("issuer", "client");
1206
1207        let internal_auth = AuthCtx::new(server, server);
1208        let external_auth = AuthCtx::new(server, client);
1209
1210        let run = |db, sql, auth, subs| run(db, sql, auth, subs, &mut vec![]);
1211
1212        // No row limit, both queries pass.
1213        assert!(run(&db, "SELECT * FROM T", internal_auth, None).is_ok());
1214        assert!(run(&db, "SELECT * FROM T", external_auth, None).is_ok());
1215
1216        // Set row limit.
1217        assert!(run(&db, "SET row_limit = 4", internal_auth, None).is_ok());
1218
1219        // External query fails.
1220        assert!(run(&db, "SELECT * FROM T", internal_auth, None).is_ok());
1221        assert!(run(&db, "SELECT * FROM T", external_auth, None).is_err());
1222
1223        // Increase row limit.
1224        assert!(run(&db, "DELETE FROM st_var WHERE name = 'row_limit'", internal_auth, None).is_ok());
1225        assert!(run(&db, "SET row_limit = 5", internal_auth, None).is_ok());
1226
1227        // Both queries pass.
1228        assert!(run(&db, "SELECT * FROM T", internal_auth, None).is_ok());
1229        assert!(run(&db, "SELECT * FROM T", external_auth, None).is_ok());
1230
1231        Ok(())
1232    }
1233
1234    // Verify we don't return rows on DML
1235    #[test]
1236    fn test_row_dml() -> ResultTest<()> {
1237        let db = TestDB::durable()?;
1238
1239        let table_id = db.create_table_for_test("T", &[("a", AlgebraicType::U8)], &[])?;
1240        with_auto_commit(&db, |tx| -> Result<_, DBError> {
1241            for i in 0..4u8 {
1242                insert(&db, tx, table_id, &product!(i))?;
1243            }
1244            Ok(())
1245        })?;
1246
1247        let server = Identity::from_claims("issuer", "server");
1248
1249        let internal_auth = AuthCtx::new(server, server);
1250
1251        let run = |db, sql, auth, subs| run(db, sql, auth, subs, &mut vec![]);
1252
1253        let check = |db, sql, auth, metrics: ExecutionMetrics| {
1254            let result = run(db, sql, auth, None)?;
1255            assert_eq!(result.rows, vec![]);
1256            assert_eq!(result.metrics.rows_inserted, metrics.rows_inserted);
1257            assert_eq!(result.metrics.rows_deleted, metrics.rows_deleted);
1258            assert_eq!(result.metrics.rows_updated, metrics.rows_updated);
1259
1260            Ok::<(), DBError>(())
1261        };
1262
1263        let ins = ExecutionMetrics {
1264            rows_inserted: 1,
1265            ..ExecutionMetrics::default()
1266        };
1267        let upd = ExecutionMetrics {
1268            rows_updated: 5,
1269            ..ExecutionMetrics::default()
1270        };
1271        let del = ExecutionMetrics {
1272            rows_deleted: 1,
1273            ..ExecutionMetrics::default()
1274        };
1275
1276        check(&db, "INSERT INTO T (a) VALUES (5)", internal_auth, ins)?;
1277        check(&db, "UPDATE T SET a = 2", internal_auth, upd)?;
1278        assert_eq!(
1279            run(&db, "SELECT * FROM T", internal_auth, None)?.rows,
1280            vec![product!(2u8)]
1281        );
1282        check(&db, "DELETE FROM T", internal_auth, del)?;
1283
1284        Ok(())
1285    }
1286}