1use std::collections::HashMap;
2use std::ops::{Deref, DerefMut};
3use std::sync::Arc;
4
5use crate::expr::{Expr, ProjectList, ProjectName, Relvar};
6use crate::{expr::LeftDeepJoin, statement::Statement};
7use spacetimedb_lib::AlgebraicType;
8use spacetimedb_primitives::TableId;
9use spacetimedb_schema::schema::TableSchema;
10use spacetimedb_sql_parser::ast::BinOp;
11use spacetimedb_sql_parser::{
12 ast::{sub::SqlSelect, SqlFrom, SqlIdent, SqlJoin},
13 parser::sub::parse_subscription,
14};
15
16use super::{
17 errors::{DuplicateName, TypingError, Unresolved, Unsupported},
18 expr::RelExpr,
19 type_expr, type_proj, type_select, StatementCtx, StatementSource,
20};
21
22pub type TypingResult<T> = core::result::Result<T, TypingError>;
24
25pub trait SchemaView {
27 fn table_id(&self, name: &str) -> Option<TableId>;
28 fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableSchema>>;
29
30 fn schema(&self, name: &str) -> Option<Arc<TableSchema>> {
31 self.table_id(name).and_then(|table_id| self.schema_for_table(table_id))
32 }
33}
34
35#[derive(Default)]
36pub struct Relvars(HashMap<Box<str>, Arc<TableSchema>>);
37
38impl Deref for Relvars {
39 type Target = HashMap<Box<str>, Arc<TableSchema>>;
40 fn deref(&self) -> &Self::Target {
41 &self.0
42 }
43}
44
45impl DerefMut for Relvars {
46 fn deref_mut(&mut self) -> &mut Self::Target {
47 &mut self.0
48 }
49}
50
51pub trait TypeChecker {
52 type Ast;
53 type Set;
54
55 fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList>;
56
57 fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList>;
58
59 fn type_from(from: SqlFrom, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<RelExpr> {
60 match from {
61 SqlFrom::Expr(SqlIdent(name), SqlIdent(alias)) => {
62 let schema = Self::type_relvar(tx, &name)?;
63 vars.insert(alias.clone(), schema.clone());
64 Ok(RelExpr::RelVar(Relvar {
65 schema,
66 alias,
67 delta: None,
68 }))
69 }
70 SqlFrom::Join(SqlIdent(name), SqlIdent(alias), joins) => {
71 let schema = Self::type_relvar(tx, &name)?;
72 vars.insert(alias.clone(), schema.clone());
73 let mut join = RelExpr::RelVar(Relvar {
74 schema,
75 alias,
76 delta: None,
77 });
78
79 for SqlJoin {
80 var: SqlIdent(name),
81 alias: SqlIdent(alias),
82 on,
83 } in joins
84 {
85 if vars.contains_key(&alias) {
87 return Err(DuplicateName(alias.into_string()).into());
88 }
89
90 let lhs = Box::new(join);
91 let rhs = Relvar {
92 schema: Self::type_relvar(tx, &name)?,
93 alias,
94 delta: None,
95 };
96
97 vars.insert(rhs.alias.clone(), rhs.schema.clone());
98
99 if let Some(on) = on {
100 if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool))? {
101 if let (Expr::Field(a), Expr::Field(b)) = (*a, *b) {
102 join = RelExpr::EqJoin(LeftDeepJoin { lhs, rhs }, a, b);
103 continue;
104 }
105 }
106 unreachable!("Unreachability guaranteed by parser")
107 }
108
109 join = RelExpr::LeftDeepJoin(LeftDeepJoin { lhs, rhs });
110 }
111
112 Ok(join)
113 }
114 }
115 }
116
117 fn type_relvar(tx: &impl SchemaView, name: &str) -> TypingResult<Arc<TableSchema>> {
118 tx.schema(name)
119 .ok_or_else(|| Unresolved::table(name))
120 .map_err(TypingError::from)
121 }
122}
123
124struct SubChecker;
126
127impl TypeChecker for SubChecker {
128 type Ast = SqlSelect;
129 type Set = SqlSelect;
130
131 fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList> {
132 Self::type_set(ast, &mut Relvars::default(), tx)
133 }
134
135 fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList> {
136 match ast {
137 SqlSelect {
138 project,
139 from,
140 filter: None,
141 } => {
142 let input = Self::type_from(from, vars, tx)?;
143 type_proj(input, project, vars)
144 }
145 SqlSelect {
146 project,
147 from,
148 filter: Some(expr),
149 } => {
150 let input = Self::type_from(from, vars, tx)?;
151 type_proj(type_select(input, expr, vars)?, project, vars)
152 }
153 }
154 }
155}
156
157pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult<ProjectName> {
159 expect_table_type(SubChecker::type_ast(parse_subscription(sql)?, tx)?)
160}
161
162pub fn type_subscription(ast: SqlSelect, tx: &impl SchemaView) -> TypingResult<ProjectName> {
164 expect_table_type(SubChecker::type_ast(ast, tx)?)
165}
166
167pub fn compile_sql_sub<'a>(sql: &'a str, tx: &impl SchemaView) -> TypingResult<StatementCtx<'a>> {
169 Ok(StatementCtx {
170 statement: Statement::Select(ProjectList::Name(parse_and_type_sub(sql, tx)?)),
171 sql,
172 source: StatementSource::Subscription,
173 })
174}
175
176fn expect_table_type(expr: ProjectList) -> TypingResult<ProjectName> {
178 match expr {
179 ProjectList::Name(proj) => Ok(proj),
180 ProjectList::Limit(input, _) => expect_table_type(*input),
181 ProjectList::List(..) | ProjectList::Agg(..) => Err(Unsupported::ReturnType.into()),
182 }
183}
184
185pub mod test_utils {
186 use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, ProductType};
187 use spacetimedb_primitives::TableId;
188 use spacetimedb_schema::{
189 def::ModuleDef,
190 schema::{Schema, TableSchema},
191 };
192 use std::sync::Arc;
193
194 use super::SchemaView;
195
196 pub fn build_module_def(types: Vec<(&str, ProductType)>) -> ModuleDef {
197 let mut builder = RawModuleDefV9Builder::new();
198 for (name, ty) in types {
199 builder.build_table_with_new_type(name, ty, true);
200 }
201 builder.finish().try_into().expect("failed to generate module def")
202 }
203
204 pub struct SchemaViewer(pub ModuleDef);
205
206 impl SchemaView for SchemaViewer {
207 fn table_id(&self, name: &str) -> Option<TableId> {
208 match name {
209 "t" => Some(TableId(0)),
210 "s" => Some(TableId(1)),
211 _ => None,
212 }
213 }
214
215 fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableSchema>> {
216 match table_id.idx() {
217 0 => Some((TableId(0), "t")),
218 1 => Some((TableId(1), "s")),
219 _ => None,
220 }
221 .and_then(|(table_id, name)| {
222 self.0
223 .table(name)
224 .map(|def| Arc::new(TableSchema::from_module_def(&self.0, def, (), table_id)))
225 })
226 }
227 }
228}
229
#[cfg(test)]
mod tests {
    use crate::check::test_utils::{build_module_def, SchemaViewer};
    use spacetimedb_lib::{AlgebraicType, ProductType};
    use spacetimedb_schema::def::ModuleDef;

    use super::parse_and_type_sub;

    /// Two tables: `t` with one column per primitive type, and `s` with special types.
    fn module_def() -> ModuleDef {
        build_module_def(vec![
            (
                "t",
                ProductType::from([
                    ("i8", AlgebraicType::I8),
                    ("u8", AlgebraicType::U8),
                    ("i16", AlgebraicType::I16),
                    ("u16", AlgebraicType::U16),
                    ("i32", AlgebraicType::I32),
                    ("u32", AlgebraicType::U32),
                    ("i64", AlgebraicType::I64),
                    ("u64", AlgebraicType::U64),
                    ("int", AlgebraicType::U32),
                    ("f32", AlgebraicType::F32),
                    ("f64", AlgebraicType::F64),
                    ("i128", AlgebraicType::I128),
                    ("u128", AlgebraicType::U128),
                    ("i256", AlgebraicType::I256),
                    ("u256", AlgebraicType::U256),
                    ("str", AlgebraicType::String),
                    ("arr", AlgebraicType::array(AlgebraicType::String)),
                ]),
            ),
            (
                "s",
                ProductType::from([
                    ("id", AlgebraicType::identity()),
                    ("u32", AlgebraicType::U32),
                    ("arr", AlgebraicType::array(AlgebraicType::String)),
                    ("bytes", AlgebraicType::bytes()),
                ]),
            ),
        ])
    }

    #[test]
    fn valid_literals() {
        let tx = SchemaViewer(module_def());

        struct TestCase {
            sql: &'static str,
            msg: &'static str,
        }

        for TestCase { sql, msg } in [
            TestCase {
                sql: "select * from t where i32 = -1",
                msg: "Leading `-`",
            },
            TestCase {
                sql: "select * from t where u32 = +1",
                msg: "Leading `+`",
            },
            TestCase {
                sql: "select * from t where u32 = 1e3",
                msg: "Scientific notation",
            },
            TestCase {
                sql: "select * from t where u32 = 1E3",
                msg: "Case insensitive scientific notation",
            },
            TestCase {
                sql: "select * from t where f32 = 1e3",
                msg: "Integers can parse as floats",
            },
            TestCase {
                sql: "select * from t where f32 = 1e-3",
                msg: "Negative exponent",
            },
            TestCase {
                sql: "select * from t where f32 = 0.1",
                msg: "Standard decimal notation",
            },
            TestCase {
                sql: "select * from t where f32 = .1",
                msg: "Leading `.`",
            },
            TestCase {
                sql: "select * from t where f32 = 1e40",
                msg: "Infinity",
            },
            TestCase {
                sql: "select * from t where u256 = 1e40",
                msg: "u256",
            },
        ] {
            let result = parse_and_type_sub(sql, &tx);
            // The message arguments are only evaluated on failure, so `unwrap_err` cannot panic here.
            assert!(result.is_ok(), "{msg}: {}", result.unwrap_err());
        }
    }

    #[test]
    fn valid_literals_for_type() {
        let tx = SchemaViewer(module_def());

        for ty in [
            "i8", "u8", "i16", "u16", "i32", "u32", "i64", "u64", "f32", "f64", "i128", "u128", "i256", "u256",
        ] {
            let sql = format!("select * from t where {ty} = 127");
            let result = parse_and_type_sub(&sql, &tx);
            assert!(result.is_ok(), "Failed to parse {ty}: {}", result.unwrap_err());
        }
    }

    #[test]
    fn invalid_literals() {
        let tx = SchemaViewer(module_def());

        struct TestCase {
            sql: &'static str,
            msg: &'static str,
        }

        for TestCase { sql, msg } in [
            TestCase {
                sql: "select * from t where u8 = -1",
                msg: "Negative integer for unsigned column",
            },
            TestCase {
                sql: "select * from t where u8 = 1e3",
                msg: "Out of bounds",
            },
            TestCase {
                sql: "select * from t where u8 = 0.1",
                msg: "Float as integer",
            },
            TestCase {
                sql: "select * from t where u32 = 1e-3",
                msg: "Float as integer",
            },
            TestCase {
                sql: "select * from t where i32 = 1e-3",
                msg: "Float as integer",
            },
        ] {
            let result = parse_and_type_sub(sql, &tx);
            assert!(result.is_err(), "{msg}: {sql}");
        }
    }

    #[test]
    fn valid() {
        let tx = SchemaViewer(module_def());

        struct TestCase {
            sql: &'static str,
            msg: &'static str,
        }

        for TestCase { sql, msg } in [
            TestCase {
                sql: "select * from t",
                msg: "Can select * on any table",
            },
            TestCase {
                sql: "select * from t where true",
                msg: "Boolean literals are valid in WHERE clause",
            },
            TestCase {
                sql: "select * from t where t.u32 = 1",
                msg: "Can qualify column references with table name",
            },
            TestCase {
                sql: "select * from t where u32 = 1",
                msg: "Can leave columns unqualified when unambiguous",
            },
            TestCase {
                sql: "select * from t where t.u32 = 1 or t.str = ''",
                msg: "Type OR with qualified column references",
            },
            TestCase {
                sql: "select * from s where s.bytes = 0xABCD or bytes = X'ABCD'",
                msg: "Type OR with mixed qualified and unqualified column references",
            },
            TestCase {
                sql: "select * from s as r where r.bytes = 0xABCD or bytes = X'ABCD'",
                msg: "Type OR with table alias",
            },
            TestCase {
                sql: "select t.* from t join s",
                msg: "Type cross join + projection",
            },
            TestCase {
                sql: "select t.* from t join s join s as r where t.u32 = s.u32 and s.u32 = r.u32",
                msg: "Type self join + projection",
            },
            TestCase {
                sql: "select t.* from t join s on t.u32 = s.u32 where t.f32 = 0.1",
                msg: "Type inner join + projection",
            },
        ] {
            let result = parse_and_type_sub(sql, &tx);
            // The message arguments are only evaluated on failure, so `unwrap_err` cannot panic here.
            assert!(result.is_ok(), "{msg}: {}", result.unwrap_err());
        }
    }

    #[test]
    fn invalid() {
        let tx = SchemaViewer(module_def());

        struct TestCase {
            sql: &'static str,
            msg: &'static str,
        }

        for TestCase { sql, msg } in [
            TestCase {
                sql: "select * from r",
                msg: "Table r does not exist",
            },
            TestCase {
                sql: "select * from t where t.a = 1",
                msg: "Field a does not exist on table t",
            },
            TestCase {
                sql: "select * from t as r where r.a = 1",
                msg: "Field a does not exist on table t",
            },
            TestCase {
                sql: "select * from t where u32 = 'str'",
                msg: "Field u32 is not a string",
            },
            TestCase {
                sql: "select * from t where t.u32 = 1.3",
                msg: "Field u32 is not a float",
            },
            TestCase {
                sql: "select * from t as r where t.u32 = 5",
                msg: "t is not in scope after alias",
            },
            TestCase {
                sql: "select u32 from t",
                msg: "Subscriptions must be typed to a single table",
            },
            TestCase {
                sql: "select * from t join s",
                msg: "Subscriptions must be typed to a single table",
            },
            TestCase {
                sql: "select t.* from t join t",
                msg: "Self join requires aliases",
            },
            TestCase {
                sql: "select t.* from t join s on t.arr = s.arr",
                msg: "Product values are not comparable",
            },
            TestCase {
                sql: "select t.* from t join s on t.u32 = r.u32 join s as r",
                msg: "Alias r is not in scope when it is referenced",
            },
            TestCase {
                sql: "select * from t limit 5",
                msg: "Subscriptions do not support limit",
            },
        ] {
            let result = parse_and_type_sub(sql, &tx);
            assert!(result.is_err(), "{msg}: {sql}");
        }
    }
}