1use std::collections::HashMap;
2use std::ops::{Deref, DerefMut};
3use std::sync::Arc;
4
5use crate::expr::{Expr, ProjectList, ProjectName, Relvar};
6use crate::{expr::LeftDeepJoin, statement::Statement};
7use spacetimedb_lib::AlgebraicType;
8use spacetimedb_primitives::TableId;
9use spacetimedb_schema::schema::TableSchema;
10use spacetimedb_sql_parser::ast::BinOp;
11use spacetimedb_sql_parser::{
12 ast::{sub::SqlSelect, SqlFrom, SqlIdent, SqlJoin},
13 parser::sub::parse_subscription,
14};
15
16use super::{
17 errors::{DuplicateName, TypingError, Unresolved, Unsupported},
18 expr::RelExpr,
19 type_expr, type_proj, type_select, StatementCtx, StatementSource,
20};
21
22pub type TypingResult<T> = core::result::Result<T, TypingError>;
24
25pub trait SchemaView {
27 fn table_id(&self, name: &str) -> Option<TableId>;
28 fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableSchema>>;
29
30 fn schema(&self, name: &str) -> Option<Arc<TableSchema>> {
31 self.table_id(name).and_then(|table_id| self.schema_for_table(table_id))
32 }
33}
34
35#[derive(Default)]
36pub struct Relvars(HashMap<Box<str>, Arc<TableSchema>>);
37
38impl Deref for Relvars {
39 type Target = HashMap<Box<str>, Arc<TableSchema>>;
40 fn deref(&self) -> &Self::Target {
41 &self.0
42 }
43}
44
45impl DerefMut for Relvars {
46 fn deref_mut(&mut self) -> &mut Self::Target {
47 &mut self.0
48 }
49}
50
51pub trait TypeChecker {
52 type Ast;
53 type Set;
54
55 fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList>;
56
57 fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList>;
58
59 fn type_from(from: SqlFrom, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<RelExpr> {
60 match from {
61 SqlFrom::Expr(SqlIdent(name), SqlIdent(alias)) => {
62 let schema = Self::type_relvar(tx, &name)?;
63 vars.insert(alias.clone(), schema.clone());
64 Ok(RelExpr::RelVar(Relvar {
65 schema,
66 alias,
67 delta: None,
68 }))
69 }
70 SqlFrom::Join(SqlIdent(name), SqlIdent(alias), joins) => {
71 let schema = Self::type_relvar(tx, &name)?;
72 vars.insert(alias.clone(), schema.clone());
73 let mut join = RelExpr::RelVar(Relvar {
74 schema,
75 alias,
76 delta: None,
77 });
78
79 for SqlJoin {
80 var: SqlIdent(name),
81 alias: SqlIdent(alias),
82 on,
83 } in joins
84 {
85 if vars.contains_key(&alias) {
87 return Err(DuplicateName(alias.into_string()).into());
88 }
89
90 let lhs = Box::new(join);
91 let rhs = Relvar {
92 schema: Self::type_relvar(tx, &name)?,
93 alias,
94 delta: None,
95 };
96
97 vars.insert(rhs.alias.clone(), rhs.schema.clone());
98
99 if let Some(on) = on {
100 if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool))? {
101 if let (Expr::Field(a), Expr::Field(b)) = (*a, *b) {
102 join = RelExpr::EqJoin(LeftDeepJoin { lhs, rhs }, a, b);
103 continue;
104 }
105 }
106 unreachable!("Unreachability guaranteed by parser")
107 }
108
109 join = RelExpr::LeftDeepJoin(LeftDeepJoin { lhs, rhs });
110 }
111
112 Ok(join)
113 }
114 }
115 }
116
117 fn type_relvar(tx: &impl SchemaView, name: &str) -> TypingResult<Arc<TableSchema>> {
118 tx.schema(name)
119 .ok_or_else(|| Unresolved::table(name))
120 .map_err(TypingError::from)
121 }
122}
123
124struct SubChecker;
126
127impl TypeChecker for SubChecker {
128 type Ast = SqlSelect;
129 type Set = SqlSelect;
130
131 fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList> {
132 Self::type_set(ast, &mut Relvars::default(), tx)
133 }
134
135 fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList> {
136 match ast {
137 SqlSelect {
138 project,
139 from,
140 filter: None,
141 } => {
142 let input = Self::type_from(from, vars, tx)?;
143 type_proj(input, project, vars)
144 }
145 SqlSelect {
146 project,
147 from,
148 filter: Some(expr),
149 } => {
150 let input = Self::type_from(from, vars, tx)?;
151 type_proj(type_select(input, expr, vars)?, project, vars)
152 }
153 }
154 }
155}
156
157pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult<ProjectName> {
159 expect_table_type(SubChecker::type_ast(parse_subscription(sql)?, tx)?)
160}
161
162pub fn type_subscription(ast: SqlSelect, tx: &impl SchemaView) -> TypingResult<ProjectName> {
164 expect_table_type(SubChecker::type_ast(ast, tx)?)
165}
166
167pub fn compile_sql_sub<'a>(sql: &'a str, tx: &impl SchemaView) -> TypingResult<StatementCtx<'a>> {
169 Ok(StatementCtx {
170 statement: Statement::Select(ProjectList::Name(parse_and_type_sub(sql, tx)?)),
171 sql,
172 source: StatementSource::Subscription,
173 })
174}
175
176fn expect_table_type(expr: ProjectList) -> TypingResult<ProjectName> {
178 match expr {
179 ProjectList::Name(proj) => Ok(proj),
180 ProjectList::Limit(input, _) => expect_table_type(*input),
181 ProjectList::List(..) | ProjectList::Agg(..) => Err(Unsupported::ReturnType.into()),
182 }
183}
184
185pub mod test_utils {
186 use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, ProductType};
187 use spacetimedb_primitives::TableId;
188 use spacetimedb_schema::{
189 def::ModuleDef,
190 schema::{Schema, TableSchema},
191 };
192 use std::sync::Arc;
193
194 use super::SchemaView;
195
196 pub fn build_module_def(types: Vec<(&str, ProductType)>) -> ModuleDef {
197 let mut builder = RawModuleDefV9Builder::new();
198 for (name, ty) in types {
199 builder.build_table_with_new_type(name, ty, true);
200 }
201 builder.finish().try_into().expect("failed to generate module def")
202 }
203
204 pub struct SchemaViewer(pub ModuleDef);
205
206 impl SchemaView for SchemaViewer {
207 fn table_id(&self, name: &str) -> Option<TableId> {
208 match name {
209 "t" => Some(TableId(0)),
210 "s" => Some(TableId(1)),
211 _ => None,
212 }
213 }
214
215 fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableSchema>> {
216 match table_id.idx() {
217 0 => Some((TableId(0), "t")),
218 1 => Some((TableId(1), "s")),
219 _ => None,
220 }
221 .and_then(|(table_id, name)| {
222 self.0
223 .table(name)
224 .map(|def| Arc::new(TableSchema::from_module_def(&self.0, def, (), table_id)))
225 })
226 }
227 }
228}
229
#[cfg(test)]
mod tests {
    use crate::check::test_utils::{build_module_def, SchemaViewer};
    use spacetimedb_lib::{AlgebraicType, ProductType};
    use spacetimedb_schema::def::ModuleDef;

    use super::parse_and_type_sub;

    /// Builds the module used by every test, with two tables:
    /// - `t`: one column per numeric primitive, plus `ts`, `int`, `str`, and `arr`;
    /// - `s`: `id`, `u32`, `arr`, and `bytes`.
    fn module_def() -> ModuleDef {
        build_module_def(vec![
            (
                "t",
                ProductType::from([
                    ("ts", AlgebraicType::timestamp()),
                    ("i8", AlgebraicType::I8),
                    ("u8", AlgebraicType::U8),
                    ("i16", AlgebraicType::I16),
                    ("u16", AlgebraicType::U16),
                    ("i32", AlgebraicType::I32),
                    ("u32", AlgebraicType::U32),
                    ("i64", AlgebraicType::I64),
                    ("u64", AlgebraicType::U64),
                    ("int", AlgebraicType::U32),
                    ("f32", AlgebraicType::F32),
                    ("f64", AlgebraicType::F64),
                    ("i128", AlgebraicType::I128),
                    ("u128", AlgebraicType::U128),
                    ("i256", AlgebraicType::I256),
                    ("u256", AlgebraicType::U256),
                    ("str", AlgebraicType::String),
                    ("arr", AlgebraicType::array(AlgebraicType::String)),
                ]),
            ),
            (
                "s",
                ProductType::from([
                    ("id", AlgebraicType::identity()),
                    ("u32", AlgebraicType::U32),
                    ("arr", AlgebraicType::array(AlgebraicType::String)),
                    ("bytes", AlgebraicType::bytes()),
                ]),
            ),
        ])
    }

    /// Literal forms that must type check against their columns.
    #[test]
    fn valid_literals() {
        let tx = SchemaViewer(module_def());

        struct TestCase {
            sql: &'static str,
            msg: &'static str,
        }

        for TestCase { sql, msg } in [
            TestCase {
                sql: "select * from t where i32 = -1",
                msg: "Leading `-`",
            },
            TestCase {
                sql: "select * from t where u32 = +1",
                msg: "Leading `+`",
            },
            TestCase {
                sql: "select * from t where u32 = 1e3",
                msg: "Scientific notation",
            },
            TestCase {
                sql: "select * from t where u32 = 1E3",
                msg: "Case insensitive scientific notation",
            },
            TestCase {
                sql: "select * from t where f32 = 1e3",
                msg: "Integers can parse as floats",
            },
            TestCase {
                sql: "select * from t where f32 = 1e-3",
                msg: "Negative exponent",
            },
            TestCase {
                sql: "select * from t where f32 = 0.1",
                msg: "Standard decimal notation",
            },
            TestCase {
                sql: "select * from t where f32 = .1",
                msg: "Leading `.`",
            },
            TestCase {
                sql: "select * from t where f32 = 1e40",
                msg: "Infinity",
            },
            TestCase {
                sql: "select * from t where u256 = 1e40",
                msg: "u256",
            },
            TestCase {
                sql: "select * from t where ts = '2025-02-10T15:45:30Z'",
                msg: "timestamp",
            },
            TestCase {
                sql: "select * from t where ts = '2025-02-10T15:45:30.123Z'",
                msg: "timestamp ms",
            },
            TestCase {
                sql: "select * from t where ts = '2025-02-10T15:45:30.123456789Z'",
                msg: "timestamp ns",
            },
            TestCase {
                sql: "select * from t where ts = '2025-02-10 15:45:30+02:00'",
                msg: "timestamp with timezone",
            },
            TestCase {
                sql: "select * from t where ts = '2025-02-10 15:45:30.123+02:00'",
                msg: "timestamp ms with timezone",
            },
        ] {
            let result = parse_and_type_sub(sql, &tx);
            assert!(result.is_ok(), "name: {}, error: {}", msg, result.unwrap_err());
        }
    }

    /// A small integer literal must type check against every numeric column type.
    #[test]
    fn valid_literals_for_type() {
        let tx = SchemaViewer(module_def());

        for ty in [
            "i8", "u8", "i16", "u16", "i32", "u32", "i64", "u64", "f32", "f64", "i128", "u128", "i256", "u256",
        ] {
            let sql = format!("select * from t where {ty} = 127");
            let result = parse_and_type_sub(&sql, &tx);
            assert!(result.is_ok(), "Failed to parse {ty}: {}", result.unwrap_err());
        }
    }

    /// Literals that must be rejected for their column's type.
    #[test]
    fn invalid_literals() {
        let tx = SchemaViewer(module_def());

        struct TestCase {
            sql: &'static str,
            msg: &'static str,
        }

        for TestCase { sql, msg } in [
            TestCase {
                sql: "select * from t where u8 = -1",
                msg: "Negative integer for unsigned column",
            },
            TestCase {
                sql: "select * from t where u8 = 1e3",
                msg: "Out of bounds",
            },
            TestCase {
                sql: "select * from t where u8 = 0.1",
                msg: "Float as integer",
            },
            TestCase {
                sql: "select * from t where u32 = 1e-3",
                msg: "Float as integer",
            },
            TestCase {
                sql: "select * from t where i32 = 1e-3",
                msg: "Float as integer",
            },
        ] {
            let result = parse_and_type_sub(sql, &tx);
            assert!(result.is_err(), "{msg}: `{sql}` unexpectedly type checked");
        }
    }

    /// Subscription queries that must type check.
    #[test]
    fn valid() {
        let tx = SchemaViewer(module_def());

        struct TestCase {
            sql: &'static str,
            msg: &'static str,
        }

        for TestCase { sql, msg } in [
            TestCase {
                sql: "select * from t",
                msg: "Can select * on any table",
            },
            TestCase {
                sql: "select * from t where true",
                msg: "Boolean literals are valid in WHERE clause",
            },
            TestCase {
                sql: "select * from t where t.u32 = 1",
                msg: "Can qualify column references with table name",
            },
            TestCase {
                sql: "select * from t where u32 = 1",
                msg: "Can leave columns unqualified when unambiguous",
            },
            TestCase {
                sql: "select * from t where t.u32 = 1 or t.str = ''",
                msg: "Type OR with qualified column references",
            },
            TestCase {
                sql: "select * from s where s.bytes = 0xABCD or bytes = X'ABCD'",
                msg: "Type OR with mixed qualified and unqualified column references",
            },
            TestCase {
                sql: "select * from s as r where r.bytes = 0xABCD or bytes = X'ABCD'",
                msg: "Type OR with table alias",
            },
            TestCase {
                sql: "select t.* from t join s",
                msg: "Type cross join + projection",
            },
            TestCase {
                sql: "select t.* from t join s join s as r where t.u32 = s.u32 and s.u32 = r.u32",
                msg: "Type self join + projection",
            },
            TestCase {
                sql: "select t.* from t join s on t.u32 = s.u32 where t.f32 = 0.1",
                msg: "Type inner join + projection",
            },
        ] {
            let result = parse_and_type_sub(sql, &tx);
            // Report the typing error itself so a failure is diagnosable.
            assert!(result.is_ok(), "{msg}: {}", result.unwrap_err());
        }
    }

    /// Subscription queries that must be rejected.
    #[test]
    fn invalid() {
        let tx = SchemaViewer(module_def());

        struct TestCase {
            sql: &'static str,
            msg: &'static str,
        }

        for TestCase { sql, msg } in [
            TestCase {
                sql: "select * from r",
                msg: "Table r does not exist",
            },
            TestCase {
                sql: "select * from t where t.a = 1",
                msg: "Field a does not exist on table t",
            },
            TestCase {
                sql: "select * from t as r where r.a = 1",
                msg: "Field a does not exist on table t",
            },
            TestCase {
                sql: "select * from t where u32 = 'str'",
                msg: "Field u32 is not a string",
            },
            TestCase {
                sql: "select * from t where t.u32 = 1.3",
                msg: "Field u32 is not a float",
            },
            TestCase {
                sql: "select * from t as r where t.u32 = 5",
                msg: "t is not in scope after alias",
            },
            TestCase {
                sql: "select u32 from t",
                msg: "Subscriptions must be typed to a single table",
            },
            TestCase {
                sql: "select * from t join s",
                msg: "Subscriptions must be typed to a single table",
            },
            TestCase {
                sql: "select t.* from t join t",
                msg: "Self join requires aliases",
            },
            TestCase {
                sql: "select t.* from t join s join s",
                msg: "Alias s is already in scope from an earlier join",
            },
            TestCase {
                sql: "select t.* from t join s on t.arr = s.arr",
                msg: "Product values are not comparable",
            },
            TestCase {
                sql: "select t.* from t join s on t.u32 = r.u32 join s as r",
                msg: "Alias r is not in scope when it is referenced",
            },
            TestCase {
                sql: "select * from t limit 5",
                msg: "Subscriptions do not support limit",
            },
        ] {
            let result = parse_and_type_sub(sql, &tx);
            assert!(result.is_err(), "{msg}: `{sql}` unexpectedly type checked");
        }
    }
}
519}