spacetimedb_expr/
check.rs

1use std::collections::HashMap;
2use std::ops::{Deref, DerefMut};
3use std::sync::Arc;
4
5use crate::expr::{Expr, ProjectList, ProjectName, Relvar};
6use crate::{expr::LeftDeepJoin, statement::Statement};
7use spacetimedb_lib::AlgebraicType;
8use spacetimedb_primitives::TableId;
9use spacetimedb_schema::schema::TableSchema;
10use spacetimedb_sql_parser::ast::BinOp;
11use spacetimedb_sql_parser::{
12    ast::{sub::SqlSelect, SqlFrom, SqlIdent, SqlJoin},
13    parser::sub::parse_subscription,
14};
15
16use super::{
17    errors::{DuplicateName, TypingError, Unresolved, Unsupported},
18    expr::RelExpr,
19    type_expr, type_proj, type_select, StatementCtx, StatementSource,
20};
21
22/// The result of type checking and name resolution
23pub type TypingResult<T> = core::result::Result<T, TypingError>;
24
25/// A view of the database schema
26pub trait SchemaView {
27    fn table_id(&self, name: &str) -> Option<TableId>;
28    fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableSchema>>;
29
30    fn schema(&self, name: &str) -> Option<Arc<TableSchema>> {
31        self.table_id(name).and_then(|table_id| self.schema_for_table(table_id))
32    }
33}
34
35#[derive(Default)]
36pub struct Relvars(HashMap<Box<str>, Arc<TableSchema>>);
37
38impl Deref for Relvars {
39    type Target = HashMap<Box<str>, Arc<TableSchema>>;
40    fn deref(&self) -> &Self::Target {
41        &self.0
42    }
43}
44
45impl DerefMut for Relvars {
46    fn deref_mut(&mut self) -> &mut Self::Target {
47        &mut self.0
48    }
49}
50
51pub trait TypeChecker {
52    type Ast;
53    type Set;
54
55    fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList>;
56
57    fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList>;
58
59    fn type_from(from: SqlFrom, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<RelExpr> {
60        match from {
61            SqlFrom::Expr(SqlIdent(name), SqlIdent(alias)) => {
62                let schema = Self::type_relvar(tx, &name)?;
63                vars.insert(alias.clone(), schema.clone());
64                Ok(RelExpr::RelVar(Relvar {
65                    schema,
66                    alias,
67                    delta: None,
68                }))
69            }
70            SqlFrom::Join(SqlIdent(name), SqlIdent(alias), joins) => {
71                let schema = Self::type_relvar(tx, &name)?;
72                vars.insert(alias.clone(), schema.clone());
73                let mut join = RelExpr::RelVar(Relvar {
74                    schema,
75                    alias,
76                    delta: None,
77                });
78
79                for SqlJoin {
80                    var: SqlIdent(name),
81                    alias: SqlIdent(alias),
82                    on,
83                } in joins
84                {
85                    // Check for duplicate aliases
86                    if vars.contains_key(&alias) {
87                        return Err(DuplicateName(alias.into_string()).into());
88                    }
89
90                    let lhs = Box::new(join);
91                    let rhs = Relvar {
92                        schema: Self::type_relvar(tx, &name)?,
93                        alias,
94                        delta: None,
95                    };
96
97                    vars.insert(rhs.alias.clone(), rhs.schema.clone());
98
99                    if let Some(on) = on {
100                        if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool))? {
101                            if let (Expr::Field(a), Expr::Field(b)) = (*a, *b) {
102                                join = RelExpr::EqJoin(LeftDeepJoin { lhs, rhs }, a, b);
103                                continue;
104                            }
105                        }
106                        unreachable!("Unreachability guaranteed by parser")
107                    }
108
109                    join = RelExpr::LeftDeepJoin(LeftDeepJoin { lhs, rhs });
110                }
111
112                Ok(join)
113            }
114        }
115    }
116
117    fn type_relvar(tx: &impl SchemaView, name: &str) -> TypingResult<Arc<TableSchema>> {
118        tx.schema(name)
119            .ok_or_else(|| Unresolved::table(name))
120            .map_err(TypingError::from)
121    }
122}
123
124/// Type checker for subscriptions
125struct SubChecker;
126
127impl TypeChecker for SubChecker {
128    type Ast = SqlSelect;
129    type Set = SqlSelect;
130
131    fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList> {
132        Self::type_set(ast, &mut Relvars::default(), tx)
133    }
134
135    fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList> {
136        match ast {
137            SqlSelect {
138                project,
139                from,
140                filter: None,
141            } => {
142                let input = Self::type_from(from, vars, tx)?;
143                type_proj(input, project, vars)
144            }
145            SqlSelect {
146                project,
147                from,
148                filter: Some(expr),
149            } => {
150                let input = Self::type_from(from, vars, tx)?;
151                type_proj(type_select(input, expr, vars)?, project, vars)
152            }
153        }
154    }
155}
156
157/// Parse and type check a subscription query
158pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult<ProjectName> {
159    expect_table_type(SubChecker::type_ast(parse_subscription(sql)?, tx)?)
160}
161
162/// Type check a subscription query
163pub fn type_subscription(ast: SqlSelect, tx: &impl SchemaView) -> TypingResult<ProjectName> {
164    expect_table_type(SubChecker::type_ast(ast, tx)?)
165}
166
167/// Parse and type check a *subscription* query into a `StatementCtx`
168pub fn compile_sql_sub<'a>(sql: &'a str, tx: &impl SchemaView) -> TypingResult<StatementCtx<'a>> {
169    Ok(StatementCtx {
170        statement: Statement::Select(ProjectList::Name(parse_and_type_sub(sql, tx)?)),
171        sql,
172        source: StatementSource::Subscription,
173    })
174}
175
176/// Returns an error if the input type is not a table type or relvar
177fn expect_table_type(expr: ProjectList) -> TypingResult<ProjectName> {
178    match expr {
179        ProjectList::Name(proj) => Ok(proj),
180        ProjectList::Limit(input, _) => expect_table_type(*input),
181        ProjectList::List(..) | ProjectList::Agg(..) => Err(Unsupported::ReturnType.into()),
182    }
183}
184
185pub mod test_utils {
186    use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, ProductType};
187    use spacetimedb_primitives::TableId;
188    use spacetimedb_schema::{
189        def::ModuleDef,
190        schema::{Schema, TableSchema},
191    };
192    use std::sync::Arc;
193
194    use super::SchemaView;
195
196    pub fn build_module_def(types: Vec<(&str, ProductType)>) -> ModuleDef {
197        let mut builder = RawModuleDefV9Builder::new();
198        for (name, ty) in types {
199            builder.build_table_with_new_type(name, ty, true);
200        }
201        builder.finish().try_into().expect("failed to generate module def")
202    }
203
204    pub struct SchemaViewer(pub ModuleDef);
205
206    impl SchemaView for SchemaViewer {
207        fn table_id(&self, name: &str) -> Option<TableId> {
208            match name {
209                "t" => Some(TableId(0)),
210                "s" => Some(TableId(1)),
211                _ => None,
212            }
213        }
214
215        fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableSchema>> {
216            match table_id.idx() {
217                0 => Some((TableId(0), "t")),
218                1 => Some((TableId(1), "s")),
219                _ => None,
220            }
221            .and_then(|(table_id, name)| {
222                self.0
223                    .table(name)
224                    .map(|def| Arc::new(TableSchema::from_module_def(&self.0, def, (), table_id)))
225            })
226        }
227    }
228}
229
#[cfg(test)]
mod tests {
    use crate::check::test_utils::{build_module_def, SchemaViewer};
    use spacetimedb_lib::{AlgebraicType, ProductType};
    use spacetimedb_schema::def::ModuleDef;

    use super::parse_and_type_sub;

    /// A query paired with a description of why it should pass or fail.
    struct TestCase {
        sql: &'static str,
        msg: &'static str,
    }

    /// Asserts that every case parses and type checks.
    ///
    /// On failure, the actual typing error is reported next to the case description.
    fn assert_all_ok(tx: &SchemaViewer, cases: &[TestCase]) {
        for TestCase { sql, msg } in cases {
            if let Err(err) = parse_and_type_sub(sql, tx) {
                panic!("{msg}: {err}");
            }
        }
    }

    /// Asserts that every case is rejected, either by the parser or by the type checker.
    fn assert_all_err(tx: &SchemaViewer, cases: &[TestCase]) {
        for TestCase { sql, msg } in cases {
            assert!(parse_and_type_sub(sql, tx).is_err(), "{msg}");
        }
    }

    /// Two test tables: `t`, with one column per primitive type, and `s`.
    fn module_def() -> ModuleDef {
        build_module_def(vec![
            (
                "t",
                ProductType::from([
                    ("i8", AlgebraicType::I8),
                    ("u8", AlgebraicType::U8),
                    ("i16", AlgebraicType::I16),
                    ("u16", AlgebraicType::U16),
                    ("i32", AlgebraicType::I32),
                    ("u32", AlgebraicType::U32),
                    ("i64", AlgebraicType::I64),
                    ("u64", AlgebraicType::U64),
                    ("int", AlgebraicType::U32),
                    ("f32", AlgebraicType::F32),
                    ("f64", AlgebraicType::F64),
                    ("i128", AlgebraicType::I128),
                    ("u128", AlgebraicType::U128),
                    ("i256", AlgebraicType::I256),
                    ("u256", AlgebraicType::U256),
                    ("str", AlgebraicType::String),
                    ("arr", AlgebraicType::array(AlgebraicType::String)),
                ]),
            ),
            (
                "s",
                ProductType::from([
                    ("id", AlgebraicType::identity()),
                    ("u32", AlgebraicType::U32),
                    ("arr", AlgebraicType::array(AlgebraicType::String)),
                    ("bytes", AlgebraicType::bytes()),
                ]),
            ),
        ])
    }

    #[test]
    fn valid_literals() {
        let tx = SchemaViewer(module_def());

        assert_all_ok(
            &tx,
            &[
                TestCase {
                    sql: "select * from t where i32 = -1",
                    msg: "Leading `-`",
                },
                TestCase {
                    sql: "select * from t where u32 = +1",
                    msg: "Leading `+`",
                },
                TestCase {
                    sql: "select * from t where u32 = 1e3",
                    msg: "Scientific notation",
                },
                TestCase {
                    sql: "select * from t where u32 = 1E3",
                    msg: "Case insensitive scientific notation",
                },
                TestCase {
                    sql: "select * from t where f32 = 1e3",
                    msg: "Integers can parse as floats",
                },
                TestCase {
                    sql: "select * from t where f32 = 1e-3",
                    msg: "Negative exponent",
                },
                TestCase {
                    sql: "select * from t where f32 = 0.1",
                    msg: "Standard decimal notation",
                },
                TestCase {
                    sql: "select * from t where f32 = .1",
                    msg: "Leading `.`",
                },
                TestCase {
                    sql: "select * from t where f32 = 1e40",
                    msg: "Infinity",
                },
                TestCase {
                    sql: "select * from t where u256 = 1e40",
                    msg: "u256",
                },
            ],
        );
    }

    #[test]
    fn valid_literals_for_type() {
        let tx = SchemaViewer(module_def());

        for ty in [
            "i8", "u8", "i16", "u16", "i32", "u32", "i64", "u64", "f32", "f64", "i128", "u128", "i256", "u256",
        ] {
            let sql = format!("select * from t where {ty} = 127");
            if let Err(err) = parse_and_type_sub(&sql, &tx) {
                panic!("Failed to parse {ty}: {err}");
            }
        }
    }

    #[test]
    fn invalid_literals() {
        let tx = SchemaViewer(module_def());

        assert_all_err(
            &tx,
            &[
                TestCase {
                    sql: "select * from t where u8 = -1",
                    msg: "Negative integer for unsigned column",
                },
                TestCase {
                    sql: "select * from t where u8 = 1e3",
                    msg: "Out of bounds",
                },
                TestCase {
                    sql: "select * from t where u8 = 0.1",
                    msg: "Float as integer",
                },
                TestCase {
                    sql: "select * from t where u32 = 1e-3",
                    msg: "Float as integer",
                },
                TestCase {
                    sql: "select * from t where i32 = 1e-3",
                    msg: "Float as integer",
                },
            ],
        );
    }

    #[test]
    fn valid() {
        let tx = SchemaViewer(module_def());

        assert_all_ok(
            &tx,
            &[
                TestCase {
                    sql: "select * from t",
                    msg: "Can select * on any table",
                },
                TestCase {
                    sql: "select * from t where true",
                    msg: "Boolean literals are valid in WHERE clause",
                },
                TestCase {
                    sql: "select * from t where t.u32 = 1",
                    msg: "Can qualify column references with table name",
                },
                TestCase {
                    sql: "select * from t where u32 = 1",
                    msg: "Can leave columns unqualified when unambiguous",
                },
                TestCase {
                    sql: "select * from t where t.u32 = 1 or t.str = ''",
                    msg: "Type OR with qualified column references",
                },
                TestCase {
                    sql: "select * from s where s.bytes = 0xABCD or bytes = X'ABCD'",
                    msg: "Type OR with mixed qualified and unqualified column references",
                },
                TestCase {
                    sql: "select * from s as r where r.bytes = 0xABCD or bytes = X'ABCD'",
                    msg: "Type OR with table alias",
                },
                TestCase {
                    sql: "select t.* from t join s",
                    msg: "Type cross join + projection",
                },
                TestCase {
                    sql: "select t.* from t join s join s as r where t.u32 = s.u32 and s.u32 = r.u32",
                    msg: "Type self join + projection",
                },
                TestCase {
                    sql: "select t.* from t join s on t.u32 = s.u32 where t.f32 = 0.1",
                    msg: "Type inner join + projection",
                },
            ],
        );
    }

    #[test]
    fn invalid() {
        let tx = SchemaViewer(module_def());

        assert_all_err(
            &tx,
            &[
                TestCase {
                    sql: "select * from r",
                    msg: "Table r does not exist",
                },
                TestCase {
                    sql: "select * from t where t.a = 1",
                    msg: "Field a does not exist on table t",
                },
                TestCase {
                    sql: "select * from t as r where r.a = 1",
                    msg: "Field a does not exist on table t",
                },
                TestCase {
                    sql: "select * from t where u32 = 'str'",
                    msg: "Field u32 is not a string",
                },
                TestCase {
                    sql: "select * from t where t.u32 = 1.3",
                    msg: "Field u32 is not a float",
                },
                TestCase {
                    sql: "select * from t as r where t.u32 = 5",
                    msg: "t is not in scope after alias",
                },
                TestCase {
                    sql: "select u32 from t",
                    msg: "Subscriptions must be typed to a single table",
                },
                TestCase {
                    sql: "select * from t join s",
                    msg: "Subscriptions must be typed to a single table",
                },
                TestCase {
                    sql: "select t.* from t join t",
                    msg: "Self join requires aliases",
                },
                TestCase {
                    sql: "select t.* from t join s on t.arr = s.arr",
                    msg: "Product values are not comparable",
                },
                TestCase {
                    sql: "select t.* from t join s on t.u32 = r.u32 join s as r",
                    msg: "Alias r is not in scope when it is referenced",
                },
                TestCase {
                    sql: "select * from t limit 5",
                    msg: "Subscriptions do not support limit",
                },
            ],
        );
    }
}