1use std::collections::HashMap;
2use std::ops::{Deref, DerefMut};
3use std::sync::Arc;
4
5use crate::expr::LeftDeepJoin;
6use crate::expr::{Expr, ProjectList, ProjectName, Relvar};
7use spacetimedb_lib::identity::AuthCtx;
8use spacetimedb_lib::AlgebraicType;
9use spacetimedb_primitives::TableId;
10use spacetimedb_schema::schema::TableSchema;
11use spacetimedb_sql_parser::ast::BinOp;
12use spacetimedb_sql_parser::{
13 ast::{sub::SqlSelect, SqlFrom, SqlIdent, SqlJoin},
14 parser::sub::parse_subscription,
15};
16
17use super::{
18 errors::{DuplicateName, TypingError, Unresolved, Unsupported},
19 expr::RelExpr,
20 type_expr, type_proj, type_select,
21};
22
23pub type TypingResult<T> = core::result::Result<T, TypingError>;
25
26pub trait SchemaView {
28 fn table_id(&self, name: &str) -> Option<TableId>;
29 fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableSchema>>;
30 fn rls_rules_for_table(&self, table_id: TableId) -> anyhow::Result<Vec<Box<str>>>;
31
32 fn schema(&self, name: &str) -> Option<Arc<TableSchema>> {
33 self.table_id(name).and_then(|table_id| self.schema_for_table(table_id))
34 }
35}
36
37#[derive(Default)]
38pub struct Relvars(HashMap<Box<str>, Arc<TableSchema>>);
39
40impl Deref for Relvars {
41 type Target = HashMap<Box<str>, Arc<TableSchema>>;
42 fn deref(&self) -> &Self::Target {
43 &self.0
44 }
45}
46
47impl DerefMut for Relvars {
48 fn deref_mut(&mut self) -> &mut Self::Target {
49 &mut self.0
50 }
51}
52
53pub trait TypeChecker {
54 type Ast;
55 type Set;
56
57 fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList>;
58
59 fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList>;
60
61 fn type_from(from: SqlFrom, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<RelExpr> {
62 match from {
63 SqlFrom::Expr(SqlIdent(name), SqlIdent(alias)) => {
64 let schema = Self::type_relvar(tx, &name)?;
65 vars.insert(alias.clone(), schema.clone());
66 Ok(RelExpr::RelVar(Relvar {
67 schema,
68 alias,
69 delta: None,
70 }))
71 }
72 SqlFrom::Join(SqlIdent(name), SqlIdent(alias), joins) => {
73 let schema = Self::type_relvar(tx, &name)?;
74 vars.insert(alias.clone(), schema.clone());
75 let mut join = RelExpr::RelVar(Relvar {
76 schema,
77 alias,
78 delta: None,
79 });
80
81 for SqlJoin {
82 var: SqlIdent(name),
83 alias: SqlIdent(alias),
84 on,
85 } in joins
86 {
87 if vars.contains_key(&alias) {
89 return Err(DuplicateName(alias.into_string()).into());
90 }
91
92 let lhs = Box::new(join);
93 let rhs = Relvar {
94 schema: Self::type_relvar(tx, &name)?,
95 alias,
96 delta: None,
97 };
98
99 vars.insert(rhs.alias.clone(), rhs.schema.clone());
100
101 if let Some(on) = on {
102 if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool))? {
103 if let (Expr::Field(a), Expr::Field(b)) = (*a, *b) {
104 join = RelExpr::EqJoin(LeftDeepJoin { lhs, rhs }, a, b);
105 continue;
106 }
107 }
108 unreachable!("Unreachability guaranteed by parser")
109 }
110
111 join = RelExpr::LeftDeepJoin(LeftDeepJoin { lhs, rhs });
112 }
113
114 Ok(join)
115 }
116 }
117 }
118
119 fn type_relvar(tx: &impl SchemaView, name: &str) -> TypingResult<Arc<TableSchema>> {
120 tx.schema(name)
121 .ok_or_else(|| Unresolved::table(name))
122 .map_err(TypingError::from)
123 }
124}
125
126struct SubChecker;
128
129impl TypeChecker for SubChecker {
130 type Ast = SqlSelect;
131 type Set = SqlSelect;
132
133 fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList> {
134 Self::type_set(ast, &mut Relvars::default(), tx)
135 }
136
137 fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList> {
138 match ast {
139 SqlSelect {
140 project,
141 from,
142 filter: None,
143 } => {
144 let input = Self::type_from(from, vars, tx)?;
145 type_proj(input, project, vars)
146 }
147 SqlSelect {
148 project,
149 from,
150 filter: Some(expr),
151 } => {
152 let input = Self::type_from(from, vars, tx)?;
153 type_proj(type_select(input, expr, vars)?, project, vars)
154 }
155 }
156 }
157}
158
159pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult<(ProjectName, bool)> {
161 let ast = parse_subscription(sql)?;
162 let has_param = ast.has_parameter();
163 let ast = ast.resolve_sender(auth.caller);
164 expect_table_type(SubChecker::type_ast(ast, tx)?).map(|plan| (plan, has_param))
165}
166
167fn expect_table_type(expr: ProjectList) -> TypingResult<ProjectName> {
169 match expr {
170 ProjectList::Name(mut proj) if proj.len() == 1 => Ok(proj.pop().unwrap()),
173 ProjectList::Limit(input, _) => expect_table_type(*input),
174 ProjectList::Name(..) | ProjectList::List(..) | ProjectList::Agg(..) => Err(Unsupported::ReturnType.into()),
175 }
176}
177
178pub mod test_utils {
179 use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, ProductType};
180 use spacetimedb_primitives::TableId;
181 use spacetimedb_schema::{
182 def::ModuleDef,
183 schema::{Schema, TableSchema},
184 };
185 use std::sync::Arc;
186
187 use super::SchemaView;
188
189 pub fn build_module_def(types: Vec<(&str, ProductType)>) -> ModuleDef {
190 let mut builder = RawModuleDefV9Builder::new();
191 for (name, ty) in types {
192 builder.build_table_with_new_type(name, ty, true);
193 }
194 builder.finish().try_into().expect("failed to generate module def")
195 }
196
197 pub struct SchemaViewer(pub ModuleDef);
198
199 impl SchemaView for SchemaViewer {
200 fn table_id(&self, name: &str) -> Option<TableId> {
201 match name {
202 "t" => Some(TableId(0)),
203 "s" => Some(TableId(1)),
204 _ => None,
205 }
206 }
207
208 fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableSchema>> {
209 match table_id.idx() {
210 0 => Some((TableId(0), "t")),
211 1 => Some((TableId(1), "s")),
212 _ => None,
213 }
214 .and_then(|(table_id, name)| {
215 self.0
216 .table(name)
217 .map(|def| Arc::new(TableSchema::from_module_def(&self.0, def, (), table_id)))
218 })
219 }
220
221 fn rls_rules_for_table(&self, _: TableId) -> anyhow::Result<Vec<Box<str>>> {
222 Ok(vec![])
223 }
224 }
225}
226
#[cfg(test)]
mod tests {
    use crate::{
        check::test_utils::{build_module_def, SchemaViewer},
        expr::ProjectName,
    };
    use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, ProductType};
    use spacetimedb_schema::def::ModuleDef;

    use super::{SchemaView, TypingResult};

    /// Two tables: `t` has one column per primitive type, and `s` has identity/bytes columns.
    fn module_def() -> ModuleDef {
        build_module_def(vec![
            (
                "t",
                ProductType::from([
                    ("ts", AlgebraicType::timestamp()),
                    ("i8", AlgebraicType::I8),
                    ("u8", AlgebraicType::U8),
                    ("i16", AlgebraicType::I16),
                    ("u16", AlgebraicType::U16),
                    ("i32", AlgebraicType::I32),
                    ("u32", AlgebraicType::U32),
                    ("i64", AlgebraicType::I64),
                    ("u64", AlgebraicType::U64),
                    ("int", AlgebraicType::U32),
                    ("f32", AlgebraicType::F32),
                    ("f64", AlgebraicType::F64),
                    ("i128", AlgebraicType::I128),
                    ("u128", AlgebraicType::U128),
                    ("i256", AlgebraicType::I256),
                    ("u256", AlgebraicType::U256),
                    ("str", AlgebraicType::String),
                    ("arr", AlgebraicType::array(AlgebraicType::String)),
                ]),
            ),
            (
                "s",
                ProductType::from([
                    ("id", AlgebraicType::identity()),
                    ("u32", AlgebraicType::U32),
                    ("arr", AlgebraicType::array(AlgebraicType::String)),
                    ("bytes", AlgebraicType::bytes()),
                ]),
            ),
        ])
    }

    /// Runs the real entry point with a test auth context and drops the `has_param` flag.
    fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult<ProjectName> {
        super::parse_and_type_sub(sql, tx, &AuthCtx::for_testing()).map(|(plan, _)| plan)
    }

    #[test]
    fn valid_literals() {
        let tx = SchemaViewer(module_def());

        struct TestCase {
            sql: &'static str,
            msg: &'static str,
        }

        for TestCase { sql, msg } in [
            TestCase {
                sql: "select * from t where i32 = -1",
                msg: "Leading `-`",
            },
            TestCase {
                sql: "select * from t where u32 = +1",
                msg: "Leading `+`",
            },
            TestCase {
                sql: "select * from t where u32 = 1e3",
                msg: "Scientific notation",
            },
            TestCase {
                sql: "select * from t where u32 = 1E3",
                msg: "Case insensitive scientific notation",
            },
            TestCase {
                sql: "select * from t where f32 = 1e3",
                msg: "Integers can parse as floats",
            },
            TestCase {
                sql: "select * from t where f32 = 1e-3",
                msg: "Negative exponent",
            },
            TestCase {
                sql: "select * from t where f32 = 0.1",
                msg: "Standard decimal notation",
            },
            TestCase {
                sql: "select * from t where f32 = .1",
                msg: "Leading `.`",
            },
            TestCase {
                sql: "select * from t where f32 = 1e40",
                msg: "Infinity",
            },
            TestCase {
                sql: "select * from t where u256 = 1e40",
                msg: "u256",
            },
            TestCase {
                sql: "select * from t where ts = '2025-02-10T15:45:30Z'",
                msg: "timestamp",
            },
            TestCase {
                sql: "select * from t where ts = '2025-02-10T15:45:30.123Z'",
                msg: "timestamp ms",
            },
            TestCase {
                sql: "select * from t where ts = '2025-02-10T15:45:30.123456789Z'",
                msg: "timestamp ns",
            },
            TestCase {
                sql: "select * from t where ts = '2025-02-10 15:45:30+02:00'",
                msg: "timestamp with timezone",
            },
            TestCase {
                sql: "select * from t where ts = '2025-02-10 15:45:30.123+02:00'",
                msg: "timestamp ms with timezone",
            },
        ] {
            let result = parse_and_type_sub(sql, &tx);
            assert!(result.is_ok(), "name: {}, error: {}", msg, result.unwrap_err());
        }
    }

    #[test]
    fn valid_literals_for_type() {
        let tx = SchemaViewer(module_def());

        for ty in [
            "i8", "u8", "i16", "u16", "i32", "u32", "i64", "u64", "f32", "f64", "i128", "u128", "i256", "u256",
        ] {
            let sql = format!("select * from t where {ty} = 127");
            let result = parse_and_type_sub(&sql, &tx);
            assert!(result.is_ok(), "Failed to parse {ty}: {}", result.unwrap_err());
        }
    }

    #[test]
    fn invalid_literals() {
        let tx = SchemaViewer(module_def());

        struct TestCase {
            sql: &'static str,
            msg: &'static str,
        }

        for TestCase { sql, msg } in [
            TestCase {
                sql: "select * from t where u8 = -1",
                msg: "Negative integer for unsigned column",
            },
            TestCase {
                sql: "select * from t where u8 = 1e3",
                msg: "Out of bounds",
            },
            TestCase {
                sql: "select * from t where u8 = 0.1",
                msg: "Float as integer",
            },
            TestCase {
                sql: "select * from t where u32 = 1e-3",
                msg: "Float as integer",
            },
            TestCase {
                sql: "select * from t where i32 = 1e-3",
                msg: "Float as integer",
            },
        ] {
            let result = parse_and_type_sub(sql, &tx);
            assert!(result.is_err(), "{msg}");
        }
    }

    #[test]
    fn valid() {
        let tx = SchemaViewer(module_def());

        struct TestCase {
            sql: &'static str,
            msg: &'static str,
        }

        for TestCase { sql, msg } in [
            TestCase {
                sql: "select * from t",
                msg: "Can select * on any table",
            },
            TestCase {
                sql: "select * from t where true",
                msg: "Boolean literals are valid in WHERE clause",
            },
            TestCase {
                sql: "select * from t where t.u32 = 1",
                msg: "Can qualify column references with table name",
            },
            TestCase {
                sql: "select * from t where u32 = 1",
                msg: "Can leave columns unqualified when unambiguous",
            },
            TestCase {
                sql: "select * from s where id = :sender",
                msg: "Can use :sender as an Identity",
            },
            TestCase {
                sql: "select * from s where bytes = :sender",
                msg: "Can use :sender as a byte array",
            },
            TestCase {
                sql: "select * from t where t.u32 = 1 or t.str = ''",
                msg: "Type OR with qualified column references",
            },
            TestCase {
                sql: "select * from s where s.bytes = 0xABCD or bytes = X'ABCD'",
                msg: "Type OR with mixed qualified and unqualified column references",
            },
            TestCase {
                sql: "select * from s as r where r.bytes = 0xABCD or bytes = X'ABCD'",
                msg: "Type OR with table alias",
            },
            TestCase {
                sql: "select t.* from t join s",
                msg: "Type cross join + projection",
            },
            TestCase {
                sql: "select t.* from t join s join s as r where t.u32 = s.u32 and s.u32 = r.u32",
                msg: "Type self join + projection",
            },
            TestCase {
                sql: "select t.* from t join s on t.u32 = s.u32 where t.f32 = 0.1",
                msg: "Type inner join + projection",
            },
        ] {
            let result = parse_and_type_sub(sql, &tx);
            // Report the typing error on failure; the bare case name alone is not enough to diagnose it.
            assert!(result.is_ok(), "name: {msg}, error: {}", result.unwrap_err());
        }
    }

    #[test]
    fn sender_parameter_is_reported() {
        let tx = SchemaViewer(module_def());
        let auth = AuthCtx::for_testing();

        // The flag must survive `:sender` being resolved to the caller's identity.
        let with_param = super::parse_and_type_sub("select * from s where id = :sender", &tx, &auth);
        assert!(matches!(with_param, Ok((_, true))), "query using :sender must report a parameter");

        let without_param = super::parse_and_type_sub("select * from s", &tx, &auth);
        assert!(
            matches!(without_param, Ok((_, false))),
            "query without :sender must not report a parameter"
        );
    }

    #[test]
    fn invalid() {
        let tx = SchemaViewer(module_def());

        struct TestCase {
            sql: &'static str,
            msg: &'static str,
        }

        for TestCase { sql, msg } in [
            TestCase {
                sql: "select * from r",
                msg: "Table r does not exist",
            },
            TestCase {
                sql: "select * from t where arr = :sender",
                msg: "The :sender param is an identity",
            },
            TestCase {
                sql: "select * from t where t.a = 1",
                msg: "Field a does not exist on table t",
            },
            TestCase {
                sql: "select * from t as r where r.a = 1",
                msg: "Field a does not exist on table t",
            },
            TestCase {
                sql: "select * from t where u32 = 'str'",
                msg: "Field u32 is not a string",
            },
            TestCase {
                sql: "select * from t where t.u32 = 1.3",
                msg: "Field u32 is not a float",
            },
            TestCase {
                sql: "select * from t as r where t.u32 = 5",
                msg: "t is not in scope after alias",
            },
            TestCase {
                sql: "select u32 from t",
                msg: "Subscriptions must be typed to a single table",
            },
            TestCase {
                sql: "select * from t join s",
                msg: "Subscriptions must be typed to a single table",
            },
            TestCase {
                sql: "select t.* from t join t",
                msg: "Self join requires aliases",
            },
            TestCase {
                sql: "select t.* from t join s on t.arr = s.arr",
                msg: "Product values are not comparable",
            },
            TestCase {
                sql: "select t.* from t join s on t.u32 = r.u32 join s as r",
                msg: "Alias r is not in scope when it is referenced",
            },
            TestCase {
                sql: "select * from t limit 5",
                msg: "Subscriptions do not support limit",
            },
            TestCase {
                sql: "select t.* from t join s on t.u32 = s.u32 where bytes = 0xABCD",
                msg: "Columns must be qualified in join expressions",
            },
        ] {
            let result = parse_and_type_sub(sql, &tx);
            assert!(result.is_err(), "{msg}");
        }
    }
}