use std::sync::Arc;
use spacetimedb_schema::schema::{ColumnSchema, TableSchema};
use spacetimedb_sql_parser::{
ast::{
self,
sub::{SqlAst, SqlSelect},
SqlFrom, SqlIdent, SqlJoin,
},
parser::sub::parse_subscription,
};
use crate::ty::TyId;
use super::{
assert_eq_types,
errors::{DuplicateName, TypingError, Unresolved, Unsupported},
expr::{Expr, Let, RelExpr},
ty::{Symbol, TyCtx, TyEnv},
type_expr, type_proj, type_select,
};
/// Result type used throughout type checking: a value of type `T` or a [`TypingError`].
pub type TypingResult<T> = core::result::Result<T, TypingError>;
/// Source of table schemas consulted while type checking.
pub trait SchemaView {
    /// Looks up the schema of the table called `name`.
    ///
    /// Returns `None` when no schema is available for that name; the type
    /// checker reports this as an unresolved table.
    fn schema(&self, name: &str) -> Option<Arc<TableSchema>>;
}
/// Type checking of a SQL AST into typed [`RelExpr`]s.
///
/// Implementors provide [`TypeChecker::type_ast`] and [`TypeChecker::type_set`];
/// the handling of `FROM` clauses and of individual relations is supplied by
/// the default methods below.
pub trait TypeChecker {
    /// AST node type accepted by [`TypeChecker::type_ast`].
    type Ast;
    /// AST node type accepted by [`TypeChecker::type_set`].
    type Set;

    /// Type checks `ast`, resolving table names through `tx`.
    fn type_ast(ctx: &mut TyCtx, ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<RelExpr>;

    /// Type checks the set-level node `ast`, resolving table names through `tx`.
    fn type_set(ctx: &mut TyCtx, ast: Self::Set, tx: &impl SchemaView) -> TypingResult<RelExpr>;

    /// Type checks a `FROM` clause.
    ///
    /// Returns the typed relation together with an optional symbol:
    /// * for a single relation with an alias, the symbol generated for that alias;
    /// * for a single relation without an alias, whatever [`TypeChecker::type_rel`] returned;
    /// * for a join, always `None`.
    ///
    /// # Errors
    ///
    /// Propagates errors from typing the individual relations and `ON`
    /// expressions. A join alias for which the environment already holds a
    /// binding (`TyEnv::add` returns `Some`) is rejected with `DuplicateName`.
    fn type_from(
        ctx: &mut TyCtx,
        from: SqlFrom<Self::Ast>,
        tx: &impl SchemaView,
    ) -> TypingResult<(RelExpr, Option<Symbol>)> {
        match from {
            SqlFrom::Expr(expr, alias) => {
                let (rel, default_name) = Self::type_rel(ctx, expr, tx)?;
                let name = match alias {
                    // An explicit alias replaces the name derived from the relation.
                    Some(SqlIdent(alias)) => Some(ctx.gen_symbol(alias)),
                    None => default_name,
                };
                Ok((rel, name))
            }
            SqlFrom::Join(lhs, SqlIdent(lhs_alias), joins) => {
                // The left-most relation seeds the environment and the row type.
                let lhs = Self::type_rel(ctx, lhs, tx)?.0;
                let lhs_ty = lhs.ty_id();
                let lhs_name = ctx.gen_symbol(lhs_alias);

                let mut env = TyEnv::default();
                env.add(lhs_name, lhs_ty);
                let mut relations = vec![lhs];
                let mut fields = vec![(lhs_name, lhs_ty)];
                let mut predicates = Vec::new();

                for join in joins {
                    let SqlJoin {
                        expr,
                        alias: SqlIdent(alias),
                        on,
                    } = join;

                    let rhs = Self::type_rel(ctx, expr, tx)?.0;
                    let rhs_ty = rhs.ty_id();
                    let rhs_name = ctx.gen_symbol(&alias);
                    if env.add(rhs_name, rhs_ty).is_some() {
                        return Err(DuplicateName(alias.into_string()).into());
                    }
                    relations.push(rhs);
                    fields.push((rhs_name, rhs_ty));

                    // The `ON` expression is typed against `TyId::BOOL`, with every
                    // alias introduced so far (including this one) in the environment.
                    if let Some(on) = on {
                        let predicate = type_expr(ctx, &env, on, Some(TyId::BOOL))?;
                        predicates.push(predicate);
                    }
                }

                let row_ty = ctx.add_row_type(fields.clone());
                let joined = RelExpr::Join(relations.into(), row_ty);
                let joined_ty = joined.ty_id();

                // Bind each alias to the field at the same position of the join's input.
                let vars = fields
                    .into_iter()
                    .enumerate()
                    .map(|(pos, (name, ty))| {
                        let row = Box::new(Expr::Input(joined_ty));
                        (name, Expr::Field(row, pos, ty))
                    })
                    .collect();

                let bindings = Let {
                    vars,
                    exprs: predicates,
                };
                Ok((RelExpr::select(joined, bindings), None))
            }
        }
    }

    /// Type checks a single relation appearing in a `FROM` clause.
    ///
    /// A table name is resolved through `tx`; the result is a `RelExpr::RelVar`
    /// paired with a symbol generated from the table name. A nested query is
    /// typed with [`TypeChecker::type_ast`] and paired with `None`.
    ///
    /// # Errors
    ///
    /// Returns an `Unresolved` table error when `tx` has no schema for the name,
    /// and propagates any error from typing a nested query.
    fn type_rel(
        ctx: &mut TyCtx,
        expr: ast::RelExpr<Self::Ast>,
        tx: &impl SchemaView,
    ) -> TypingResult<(RelExpr, Option<Symbol>)> {
        match expr {
            ast::RelExpr::Ast(ast) => {
                let typed = Self::type_ast(ctx, *ast, tx)?;
                Ok((typed, None))
            }
            ast::RelExpr::Var(SqlIdent(table)) => {
                let schema = tx
                    .schema(&table)
                    .ok_or_else(|| TypingError::from(Unresolved::table(&table)))?;

                // Register one (symbol, type) pair per column, in schema order.
                let mut columns = Vec::new();
                for ColumnSchema { col_name, col_type, .. } in schema.columns() {
                    let col_ty = ctx.add_algebraic_type(col_type);
                    let col_symbol = ctx.gen_symbol(col_name);
                    columns.push((col_symbol, col_ty));
                }

                let table_ty = ctx.add_var_type(schema.table_id, columns);
                let table_symbol = ctx.gen_symbol(table);
                Ok((RelExpr::RelVar(schema, table_ty), Some(table_symbol)))
            }
        }
    }
}
/// Type checker for the ASTs produced by `parse_subscription`.
struct SubChecker;

impl SubChecker {
    /// Types both operands of a binary set operation, left operand first,
    /// and checks that their types are equal according to `assert_eq_types`.
    fn type_operands(
        ctx: &mut TyCtx,
        lhs: SqlAst,
        rhs: SqlAst,
        tx: &impl SchemaView,
    ) -> TypingResult<(RelExpr, RelExpr)> {
        let lhs = <Self as TypeChecker>::type_ast(ctx, lhs, tx)?;
        let rhs = <Self as TypeChecker>::type_ast(ctx, rhs, tx)?;
        assert_eq_types(ctx, lhs.ty_id(), rhs.ty_id())?;
        Ok((lhs, rhs))
    }
}

impl TypeChecker for SubChecker {
    type Ast = SqlAst;
    type Set = SqlAst;

    /// Both associated AST types are `SqlAst`, so this simply delegates to `type_set`.
    fn type_ast(ctx: &mut TyCtx, ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<RelExpr> {
        Self::type_set(ctx, ast, tx)
    }

    /// Types a union, a difference, or a `SELECT`.
    ///
    /// For `Union` and `Minus` both operands must have equal types. For a
    /// `SELECT`, the `FROM` clause is typed first, then the optional filter,
    /// then the projection.
    fn type_set(ctx: &mut TyCtx, ast: Self::Set, tx: &impl SchemaView) -> TypingResult<RelExpr> {
        match ast {
            SqlAst::Union(lhs, rhs) => {
                let (lhs, rhs) = Self::type_operands(ctx, *lhs, *rhs, tx)?;
                Ok(RelExpr::Union(Box::new(lhs), Box::new(rhs)))
            }
            SqlAst::Minus(lhs, rhs) => {
                let (lhs, rhs) = Self::type_operands(ctx, *lhs, *rhs, tx)?;
                Ok(RelExpr::Minus(Box::new(lhs), Box::new(rhs)))
            }
            SqlAst::Select(SqlSelect { project, from, filter }) => {
                let (source, alias) = Self::type_from(ctx, from, tx)?;
                // Only wrap the source in a selection when a filter is present.
                let input = match filter {
                    Some(predicate) => type_select(ctx, source, alias, predicate)?,
                    None => source,
                };
                type_proj(ctx, input, alias, project)
            }
        }
    }
}
/// Parses `sql` with `parse_subscription`, type checks the resulting AST, and
/// verifies the outcome with `expect_table_type`.
///
/// # Errors
///
/// Returns a [`TypingError`] if parsing fails, if type checking fails, or if
/// `expect_table_type` rejects the typed expression.
pub fn parse_and_type_sub(ctx: &mut TyCtx, sql: &str, tx: &impl SchemaView) -> TypingResult<RelExpr> {
    let ast = parse_subscription(sql)?;
    let typed = SubChecker::type_ast(ctx, ast, tx)?;
    expect_table_type(ctx, typed)
}
/// Returns `expr` unchanged if `expect_relvar` succeeds on its type.
///
/// # Errors
///
/// Propagates any error from computing the type of `expr`, and returns
/// `Unsupported::ReturnType` when `expect_relvar` fails.
fn expect_table_type(ctx: &TyCtx, expr: RelExpr) -> TypingResult<RelExpr> {
    // Evaluate the check in its own statement so that no temporary derived
    // from `expr` is still alive when `expr` is moved into the result.
    let is_relvar = expr.ty(ctx)?.expect_relvar().is_ok();
    if is_relvar {
        Ok(expr)
    } else {
        Err(Unsupported::ReturnType.into())
    }
}
#[cfg(test)]
mod tests {
    use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, AlgebraicType, ProductType};
    use spacetimedb_primitives::TableId;
    use spacetimedb_schema::{
        def::ModuleDef,
        schema::{Schema, TableSchema},
    };
    use std::sync::Arc;

    use crate::ty::TyCtx;

    use super::{parse_and_type_sub, SchemaView};

    /// Builds a module definition with the two tables, `t` and `s`, used by the tests.
    fn module_def() -> ModuleDef {
        let mut builder = RawModuleDefV9Builder::new();
        builder.build_table_with_new_type(
            "t",
            ProductType::from([
                ("u32", AlgebraicType::U32),
                ("f32", AlgebraicType::F32),
                ("str", AlgebraicType::String),
                ("arr", AlgebraicType::array(AlgebraicType::String)),
            ]),
            true,
        );
        builder.build_table_with_new_type(
            "s",
            ProductType::from([
                ("id", AlgebraicType::identity()),
                ("u32", AlgebraicType::U32),
                ("arr", AlgebraicType::array(AlgebraicType::String)),
                ("bytes", AlgebraicType::bytes()),
            ]),
            true,
        );
        builder.finish().try_into().expect("failed to generate module def")
    }

    /// A [`SchemaView`] backed directly by a [`ModuleDef`].
    struct SchemaViewer(ModuleDef);

    impl SchemaView for SchemaViewer {
        fn schema(&self, name: &str) -> Option<Arc<TableSchema>> {
            self.0.table(name).map(|def| {
                Arc::new(TableSchema::from_module_def(
                    &self.0,
                    def,
                    (),
                    // Table `t` gets id 0; any other table gets id 1.
                    TableId(if *def.name == *"t" { 0 } else { 1 }),
                ))
            })
        }
    }

    /// Every query listed here must parse and type check.
    #[test]
    fn valid() {
        let tx = SchemaViewer(module_def());

        for sql in [
            "select * from t",
            "select * from t where true",
            "select * from t where t.u32 = 1",
            "select * from t where u32 = 1",
            "select * from t where t.u32 = 1 or t.str = ''",
            "select * from s where s.bytes = 0xABCD or bytes = X'ABCD'",
            "select * from s as r where r.bytes = 0xABCD or bytes = X'ABCD'",
            "select * from (select t.* from t join s)",
            "select * from (select t.* from t join s join s as r where t.u32 = s.u32 and s.u32 = r.u32)",
            "select * from (select t.* from t join s on t.u32 = s.u32 where t.f32 = 0.1)",
            "select * from (select t.* from t join (select s.u32 from s) s on t.u32 = s.u32)",
            "select * from (select t.* from t join (select u32 as a from s) s on t.u32 = s.a)",
            "select * from (select * from t union all select * from t)",
        ] {
            // Name the failing query and its error so a failure is diagnosable
            // without bisecting the list by hand.
            if let Err(err) = parse_and_type_sub(&mut TyCtx::default(), sql, &tx) {
                panic!("expected `{}` to type check, but got error: {}", sql, err);
            }
        }
    }

    /// Every query listed here must be rejected by parsing or type checking.
    #[test]
    fn invalid() {
        let tx = SchemaViewer(module_def());

        for sql in [
            "select * from r",
            "select * from t where t.a = 1",
            "select * from t as r where r.a = 1",
            "select * from t where u32 = 'str'",
            "select * from t where t.u32 = 1.3",
            "select * from t as r where t.u32 = 5",
            "select u32 from t",
            "select * from t join s",
            "select * from (select t.* from t join t)",
            "select * from (select t.* from t join s on t.arr = s.arr)",
            "select * from (select s.* from t join (select u32 from s) s on t.u32 = s.u32)",
            "select * from (select t.* from t join (select u32 as a from s) s on t.u32 = s.u32)",
            "select * from (select t.* from t join (select u32 from s) s on s.bytes = 0xABCD)",
            "select * from (select t.* from t join s on t.u32 = r.u32 join s as r)",
            "select * from (select * from t union all select * from s)",
        ] {
            let result = parse_and_type_sub(&mut TyCtx::default(), sql, &tx);
            // Name the query that was unexpectedly accepted.
            assert!(result.is_err(), "expected a typing error for `{}`", sql);
        }
    }
}