spacetimedb_expr/
check.rs

1use std::collections::HashMap;
2use std::ops::{Deref, DerefMut};
3use std::sync::Arc;
4
5use crate::expr::LeftDeepJoin;
6use crate::expr::{Expr, ProjectList, ProjectName, Relvar};
7use spacetimedb_lib::identity::AuthCtx;
8use spacetimedb_lib::AlgebraicType;
9use spacetimedb_primitives::TableId;
10use spacetimedb_schema::schema::TableSchema;
11use spacetimedb_sql_parser::ast::BinOp;
12use spacetimedb_sql_parser::{
13    ast::{sub::SqlSelect, SqlFrom, SqlIdent, SqlJoin},
14    parser::sub::parse_subscription,
15};
16
17use super::{
18    errors::{DuplicateName, TypingError, Unresolved, Unsupported},
19    expr::RelExpr,
20    type_expr, type_proj, type_select,
21};
22
23/// The result of type checking and name resolution
24pub type TypingResult<T> = core::result::Result<T, TypingError>;
25
26/// A view of the database schema
27pub trait SchemaView {
28    fn table_id(&self, name: &str) -> Option<TableId>;
29    fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableSchema>>;
30    fn rls_rules_for_table(&self, table_id: TableId) -> anyhow::Result<Vec<Box<str>>>;
31
32    fn schema(&self, name: &str) -> Option<Arc<TableSchema>> {
33        self.table_id(name).and_then(|table_id| self.schema_for_table(table_id))
34    }
35}
36
37#[derive(Default)]
38pub struct Relvars(HashMap<Box<str>, Arc<TableSchema>>);
39
40impl Deref for Relvars {
41    type Target = HashMap<Box<str>, Arc<TableSchema>>;
42    fn deref(&self) -> &Self::Target {
43        &self.0
44    }
45}
46
47impl DerefMut for Relvars {
48    fn deref_mut(&mut self) -> &mut Self::Target {
49        &mut self.0
50    }
51}
52
53pub trait TypeChecker {
54    type Ast;
55    type Set;
56
57    fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList>;
58
59    fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList>;
60
61    fn type_from(from: SqlFrom, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<RelExpr> {
62        match from {
63            SqlFrom::Expr(SqlIdent(name), SqlIdent(alias)) => {
64                let schema = Self::type_relvar(tx, &name)?;
65                vars.insert(alias.clone(), schema.clone());
66                Ok(RelExpr::RelVar(Relvar {
67                    schema,
68                    alias,
69                    delta: None,
70                }))
71            }
72            SqlFrom::Join(SqlIdent(name), SqlIdent(alias), joins) => {
73                let schema = Self::type_relvar(tx, &name)?;
74                vars.insert(alias.clone(), schema.clone());
75                let mut join = RelExpr::RelVar(Relvar {
76                    schema,
77                    alias,
78                    delta: None,
79                });
80
81                for SqlJoin {
82                    var: SqlIdent(name),
83                    alias: SqlIdent(alias),
84                    on,
85                } in joins
86                {
87                    // Check for duplicate aliases
88                    if vars.contains_key(&alias) {
89                        return Err(DuplicateName(alias.into_string()).into());
90                    }
91
92                    let lhs = Box::new(join);
93                    let rhs = Relvar {
94                        schema: Self::type_relvar(tx, &name)?,
95                        alias,
96                        delta: None,
97                    };
98
99                    vars.insert(rhs.alias.clone(), rhs.schema.clone());
100
101                    if let Some(on) = on {
102                        if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool))? {
103                            if let (Expr::Field(a), Expr::Field(b)) = (*a, *b) {
104                                join = RelExpr::EqJoin(LeftDeepJoin { lhs, rhs }, a, b);
105                                continue;
106                            }
107                        }
108                        unreachable!("Unreachability guaranteed by parser")
109                    }
110
111                    join = RelExpr::LeftDeepJoin(LeftDeepJoin { lhs, rhs });
112                }
113
114                Ok(join)
115            }
116        }
117    }
118
119    fn type_relvar(tx: &impl SchemaView, name: &str) -> TypingResult<Arc<TableSchema>> {
120        tx.schema(name)
121            .ok_or_else(|| Unresolved::table(name))
122            .map_err(TypingError::from)
123    }
124}
125
126/// Type checker for subscriptions
127struct SubChecker;
128
129impl TypeChecker for SubChecker {
130    type Ast = SqlSelect;
131    type Set = SqlSelect;
132
133    fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList> {
134        Self::type_set(ast, &mut Relvars::default(), tx)
135    }
136
137    fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList> {
138        match ast {
139            SqlSelect {
140                project,
141                from,
142                filter: None,
143            } => {
144                let input = Self::type_from(from, vars, tx)?;
145                type_proj(input, project, vars)
146            }
147            SqlSelect {
148                project,
149                from,
150                filter: Some(expr),
151            } => {
152                let input = Self::type_from(from, vars, tx)?;
153                type_proj(type_select(input, expr, vars)?, project, vars)
154            }
155        }
156    }
157}
158
159/// Parse and type check a subscription query
160pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult<(ProjectName, bool)> {
161    let ast = parse_subscription(sql)?;
162    let has_param = ast.has_parameter();
163    let ast = ast.resolve_sender(auth.caller);
164    expect_table_type(SubChecker::type_ast(ast, tx)?).map(|plan| (plan, has_param))
165}
166
167/// Returns an error if the input type is not a table type or relvar
168fn expect_table_type(expr: ProjectList) -> TypingResult<ProjectName> {
169    match expr {
170        // Note, this is called before we do any RLS resolution.
171        // Hence this length should always be 1.
172        ProjectList::Name(mut proj) if proj.len() == 1 => Ok(proj.pop().unwrap()),
173        ProjectList::Limit(input, _) => expect_table_type(*input),
174        ProjectList::Name(..) | ProjectList::List(..) | ProjectList::Agg(..) => Err(Unsupported::ReturnType.into()),
175    }
176}
177
178pub mod test_utils {
179    use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, ProductType};
180    use spacetimedb_primitives::TableId;
181    use spacetimedb_schema::{
182        def::ModuleDef,
183        schema::{Schema, TableSchema},
184    };
185    use std::sync::Arc;
186
187    use super::SchemaView;
188
189    pub fn build_module_def(types: Vec<(&str, ProductType)>) -> ModuleDef {
190        let mut builder = RawModuleDefV9Builder::new();
191        for (name, ty) in types {
192            builder.build_table_with_new_type(name, ty, true);
193        }
194        builder.finish().try_into().expect("failed to generate module def")
195    }
196
197    pub struct SchemaViewer(pub ModuleDef);
198
199    impl SchemaView for SchemaViewer {
200        fn table_id(&self, name: &str) -> Option<TableId> {
201            match name {
202                "t" => Some(TableId(0)),
203                "s" => Some(TableId(1)),
204                _ => None,
205            }
206        }
207
208        fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableSchema>> {
209            match table_id.idx() {
210                0 => Some((TableId(0), "t")),
211                1 => Some((TableId(1), "s")),
212                _ => None,
213            }
214            .and_then(|(table_id, name)| {
215                self.0
216                    .table(name)
217                    .map(|def| Arc::new(TableSchema::from_module_def(&self.0, def, (), table_id)))
218            })
219        }
220
221        fn rls_rules_for_table(&self, _: TableId) -> anyhow::Result<Vec<Box<str>>> {
222            Ok(vec![])
223        }
224    }
225}
226
227#[cfg(test)]
228mod tests {
229    use crate::{
230        check::test_utils::{build_module_def, SchemaViewer},
231        expr::ProjectName,
232    };
233    use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, ProductType};
234    use spacetimedb_schema::def::ModuleDef;
235
236    use super::{SchemaView, TypingResult};
237
238    fn module_def() -> ModuleDef {
239        build_module_def(vec![
240            (
241                "t",
242                ProductType::from([
243                    ("ts", AlgebraicType::timestamp()),
244                    ("i8", AlgebraicType::I8),
245                    ("u8", AlgebraicType::U8),
246                    ("i16", AlgebraicType::I16),
247                    ("u16", AlgebraicType::U16),
248                    ("i32", AlgebraicType::I32),
249                    ("u32", AlgebraicType::U32),
250                    ("i64", AlgebraicType::I64),
251                    ("u64", AlgebraicType::U64),
252                    ("int", AlgebraicType::U32),
253                    ("f32", AlgebraicType::F32),
254                    ("f64", AlgebraicType::F64),
255                    ("i128", AlgebraicType::I128),
256                    ("u128", AlgebraicType::U128),
257                    ("i256", AlgebraicType::I256),
258                    ("u256", AlgebraicType::U256),
259                    ("str", AlgebraicType::String),
260                    ("arr", AlgebraicType::array(AlgebraicType::String)),
261                ]),
262            ),
263            (
264                "s",
265                ProductType::from([
266                    ("id", AlgebraicType::identity()),
267                    ("u32", AlgebraicType::U32),
268                    ("arr", AlgebraicType::array(AlgebraicType::String)),
269                    ("bytes", AlgebraicType::bytes()),
270                ]),
271            ),
272        ])
273    }
274
275    /// A wrapper around [super::parse_and_type_sub] that takes a dummy [AuthCtx]
276    fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult<ProjectName> {
277        super::parse_and_type_sub(sql, tx, &AuthCtx::for_testing()).map(|(plan, _)| plan)
278    }
279
280    #[test]
281    fn valid_literals() {
282        let tx = SchemaViewer(module_def());
283
284        struct TestCase {
285            sql: &'static str,
286            msg: &'static str,
287        }
288
289        for TestCase { sql, msg } in [
290            TestCase {
291                sql: "select * from t where i32 = -1",
292                msg: "Leading `-`",
293            },
294            TestCase {
295                sql: "select * from t where u32 = +1",
296                msg: "Leading `+`",
297            },
298            TestCase {
299                sql: "select * from t where u32 = 1e3",
300                msg: "Scientific notation",
301            },
302            TestCase {
303                sql: "select * from t where u32 = 1E3",
304                msg: "Case insensitive scientific notation",
305            },
306            TestCase {
307                sql: "select * from t where f32 = 1e3",
308                msg: "Integers can parse as floats",
309            },
310            TestCase {
311                sql: "select * from t where f32 = 1e-3",
312                msg: "Negative exponent",
313            },
314            TestCase {
315                sql: "select * from t where f32 = 0.1",
316                msg: "Standard decimal notation",
317            },
318            TestCase {
319                sql: "select * from t where f32 = .1",
320                msg: "Leading `.`",
321            },
322            TestCase {
323                sql: "select * from t where f32 = 1e40",
324                msg: "Infinity",
325            },
326            TestCase {
327                sql: "select * from t where u256 = 1e40",
328                msg: "u256",
329            },
330            TestCase {
331                sql: "select * from t where ts = '2025-02-10T15:45:30Z'",
332                msg: "timestamp",
333            },
334            TestCase {
335                sql: "select * from t where ts = '2025-02-10T15:45:30.123Z'",
336                msg: "timestamp ms",
337            },
338            TestCase {
339                sql: "select * from t where ts = '2025-02-10T15:45:30.123456789Z'",
340                msg: "timestamp ns",
341            },
342            TestCase {
343                sql: "select * from t where ts = '2025-02-10 15:45:30+02:00'",
344                msg: "timestamp with timezone",
345            },
346            TestCase {
347                sql: "select * from t where ts = '2025-02-10 15:45:30.123+02:00'",
348                msg: "timestamp ms with timezone",
349            },
350        ] {
351            let result = parse_and_type_sub(sql, &tx);
352            assert!(result.is_ok(), "name: {}, error: {}", msg, result.unwrap_err());
353        }
354    }
355
356    #[test]
357    fn valid_literals_for_type() {
358        let tx = SchemaViewer(module_def());
359
360        for ty in [
361            "i8", "u8", "i16", "u16", "i32", "u32", "i64", "u64", "f32", "f64", "i128", "u128", "i256", "u256",
362        ] {
363            let sql = format!("select * from t where {ty} = 127");
364            let result = parse_and_type_sub(&sql, &tx);
365            assert!(result.is_ok(), "Failed to parse {ty}: {}", result.unwrap_err());
366        }
367    }
368
369    #[test]
370    fn invalid_literals() {
371        let tx = SchemaViewer(module_def());
372
373        struct TestCase {
374            sql: &'static str,
375            msg: &'static str,
376        }
377
378        for TestCase { sql, msg } in [
379            TestCase {
380                sql: "select * from t where u8 = -1",
381                msg: "Negative integer for unsigned column",
382            },
383            TestCase {
384                sql: "select * from t where u8 = 1e3",
385                msg: "Out of bounds",
386            },
387            TestCase {
388                sql: "select * from t where u8 = 0.1",
389                msg: "Float as integer",
390            },
391            TestCase {
392                sql: "select * from t where u32 = 1e-3",
393                msg: "Float as integer",
394            },
395            TestCase {
396                sql: "select * from t where i32 = 1e-3",
397                msg: "Float as integer",
398            },
399        ] {
400            let result = parse_and_type_sub(sql, &tx);
401            assert!(result.is_err(), "{msg}");
402        }
403    }
404
405    #[test]
406    fn valid() {
407        let tx = SchemaViewer(module_def());
408
409        struct TestCase {
410            sql: &'static str,
411            msg: &'static str,
412        }
413
414        for TestCase { sql, msg } in [
415            TestCase {
416                sql: "select * from t",
417                msg: "Can select * on any table",
418            },
419            TestCase {
420                sql: "select * from t where true",
421                msg: "Boolean literals are valid in WHERE clause",
422            },
423            TestCase {
424                sql: "select * from t where t.u32 = 1",
425                msg: "Can qualify column references with table name",
426            },
427            TestCase {
428                sql: "select * from t where u32 = 1",
429                msg: "Can leave columns unqualified when unambiguous",
430            },
431            TestCase {
432                sql: "select * from s where id = :sender",
433                msg: "Can use :sender as an Identity",
434            },
435            TestCase {
436                sql: "select * from s where bytes = :sender",
437                msg: "Can use :sender as a byte array",
438            },
439            TestCase {
440                sql: "select * from t where t.u32 = 1 or t.str = ''",
441                msg: "Type OR with qualified column references",
442            },
443            TestCase {
444                sql: "select * from s where s.bytes = 0xABCD or bytes = X'ABCD'",
445                msg: "Type OR with mixed qualified and unqualified column references",
446            },
447            TestCase {
448                sql: "select * from s as r where r.bytes = 0xABCD or bytes = X'ABCD'",
449                msg: "Type OR with table alias",
450            },
451            TestCase {
452                sql: "select t.* from t join s",
453                msg: "Type cross join + projection",
454            },
455            TestCase {
456                sql: "select t.* from t join s join s as r where t.u32 = s.u32 and s.u32 = r.u32",
457                msg: "Type self join + projection",
458            },
459            TestCase {
460                sql: "select t.* from t join s on t.u32 = s.u32 where t.f32 = 0.1",
461                msg: "Type inner join + projection",
462            },
463        ] {
464            let result = parse_and_type_sub(sql, &tx);
465            assert!(result.is_ok(), "{msg}");
466        }
467    }
468
469    #[test]
470    fn invalid() {
471        let tx = SchemaViewer(module_def());
472
473        struct TestCase {
474            sql: &'static str,
475            msg: &'static str,
476        }
477
478        for TestCase { sql, msg } in [
479            TestCase {
480                sql: "select * from r",
481                msg: "Table r does not exist",
482            },
483            TestCase {
484                sql: "select * from t where arr = :sender",
485                msg: "The :sender param is an identity",
486            },
487            TestCase {
488                sql: "select * from t where t.a = 1",
489                msg: "Field a does not exist on table t",
490            },
491            TestCase {
492                sql: "select * from t as r where r.a = 1",
493                msg: "Field a does not exist on table t",
494            },
495            TestCase {
496                sql: "select * from t where u32 = 'str'",
497                msg: "Field u32 is not a string",
498            },
499            TestCase {
500                sql: "select * from t where t.u32 = 1.3",
501                msg: "Field u32 is not a float",
502            },
503            TestCase {
504                sql: "select * from t as r where t.u32 = 5",
505                msg: "t is not in scope after alias",
506            },
507            TestCase {
508                sql: "select u32 from t",
509                msg: "Subscriptions must be typed to a single table",
510            },
511            TestCase {
512                sql: "select * from t join s",
513                msg: "Subscriptions must be typed to a single table",
514            },
515            TestCase {
516                sql: "select t.* from t join t",
517                msg: "Self join requires aliases",
518            },
519            TestCase {
520                sql: "select t.* from t join s on t.arr = s.arr",
521                msg: "Product values are not comparable",
522            },
523            TestCase {
524                sql: "select t.* from t join s on t.u32 = r.u32 join s as r",
525                msg: "Alias r is not in scope when it is referenced",
526            },
527            TestCase {
528                sql: "select * from t limit 5",
529                msg: "Subscriptions do not support limit",
530            },
531            TestCase {
532                sql: "select t.* from t join s on t.u32 = s.u32 where bytes = 0xABCD",
533                msg: "Columns must be qualified in join expressions",
534            },
535        ] {
536            let result = parse_and_type_sub(sql, &tx);
537            assert!(result.is_err(), "{msg}");
538        }
539    }
540}