use std::collections::HashSet;
use check::TypingResult;
use errors::{DuplicateName, InvalidLiteral, InvalidWildcard, UnexpectedType, Unresolved};
use expr::{Expr, Let, RelExpr};
use spacetimedb_lib::{from_hex_pad, Address, AlgebraicType, AlgebraicValue, Identity};
use spacetimedb_sql_parser::ast::{self, ProjectElem, ProjectExpr, SqlExpr, SqlIdent, SqlLiteral};
use ty::{Symbol, TyCtx, TyEnv, TyId, Type, TypeWithCtx};
pub mod check;
pub mod errors;
pub mod expr;
pub mod statement;
pub mod ty;
/// Compile-time assertion that `size_of::<$t>()` is exactly `$bytes` bytes.
///
/// Expands to an anonymous `const` whose declared array type has length
/// `size_of::<$t>()` while its value has length `$bytes`. If the two lengths
/// differ, the array types do not match and the crate fails to compile.
#[macro_export]
macro_rules! static_assert_size {
    ($t:ty, $bytes:expr) => {
        const _: [(); ::core::mem::size_of::<$t>()] = [(); $bytes];
    };
}
/// Types a selection (filter) over `input` whose predicate is `expr`.
///
/// The predicate is checked in a scope containing:
/// * `alias`, if given, bound to the whole input row, and
/// * every column of the input relation, bound by its name.
///
/// The predicate must have type `TyId::BOOL`. The same bindings are recorded in
/// the resulting `Let` so later stages can evaluate the predicate.
pub(crate) fn type_select(
    ctx: &mut TyCtx,
    input: RelExpr,
    alias: Option<Symbol>,
    expr: SqlExpr,
) -> TypingResult<RelExpr> {
    let input_id = input.ty_id();
    let mut env = TyEnv::default();
    let mut bindings = Vec::new();

    // The alias (when present) names the input row itself.
    if let Some(row) = alias {
        env.add(row, input_id);
        bindings.push((row, Expr::Input(input_id)));
    }

    // Each column is reachable by name as a field projection of the input row.
    for (pos, col, col_ty) in input.ty(ctx)?.expect_relation()?.iter() {
        env.add(col, col_ty);
        bindings.push((col, Expr::Field(Box::new(Expr::Input(input_id)), pos, col_ty)));
    }

    let predicate = type_expr(ctx, &env, expr, Some(TyId::BOOL))?;
    Ok(RelExpr::select(
        input,
        Let {
            vars: bindings,
            exprs: vec![predicate],
        },
    ))
}
/// Types the projection `proj` applied to the relation `input`.
///
/// * `Star(None)` requires every column of `input` to resolve to a scalar type
///   (failing with `InvalidWildcard::Join` otherwise) and returns `input` as is.
/// * `Star(Some(var))` projects the whole input row when `var` is `alias`, or the
///   input column named `var` otherwise. Either way the projected value must be a
///   relation, failing with `InvalidWildcard::Scalar` otherwise.
/// * `Exprs(elems)` types each element in a scope holding `alias` (bound to the
///   whole input row) and every input column. It builds a new row type from the
///   element labels. A label is the element's alias if one is given, otherwise
///   the referenced field name. Repeated labels fail with `DuplicateName`.
pub(crate) fn type_proj(
    ctx: &mut TyCtx,
    input: RelExpr,
    alias: Option<Symbol>,
    proj: ast::Project,
) -> TypingResult<RelExpr> {
    // Both wildcard forms are fully handled (and returned) inside this match;
    // only an explicit element list falls through to the code below.
    let elems = match proj {
        ast::Project::Star(None) => {
            for (_, _, col_ty) in input.ty(ctx)?.expect_relation()?.iter() {
                ctx.try_resolve(col_ty)?
                    .expect_scalar()
                    .map_err(|_| InvalidWildcard::Join)?;
            }
            return Ok(input);
        }
        ast::Project::Star(Some(SqlIdent(var))) => {
            let name = ctx.get_symbol(&var).ok_or_else(|| Unresolved::var(&var))?;
            let input_id = input.ty_id();
            let projected = if alias == Some(name) {
                // `var` refers to the input row as a whole.
                ctx.try_resolve(input_id)?
                    .expect_relation()
                    .map_err(|_| InvalidWildcard::Scalar)?;
                Expr::Input(input_id)
            } else {
                // `var` must be an input column whose type is itself a relation.
                let (pos, col_ty) = input
                    .ty(ctx)?
                    .expect_relation()?
                    .find(name)
                    .ok_or_else(|| Unresolved::var(&var))?;
                ctx.try_resolve(col_ty)?
                    .expect_relation()
                    .map_err(|_| InvalidWildcard::Scalar)?;
                Expr::Field(Box::new(Expr::Input(input_id)), pos, col_ty)
            };
            return Ok(RelExpr::project(
                input,
                Let {
                    vars: vec![],
                    exprs: vec![projected],
                },
            ));
        }
        ast::Project::Exprs(elems) => elems,
    };

    // Scope for the projected expressions: the aliased row plus every column.
    let input_id = input.ty_id();
    let mut env = TyEnv::default();
    let mut bindings = Vec::new();
    if let Some(row) = alias {
        env.add(row, input_id);
        bindings.push((row, Expr::Input(input_id)));
    }
    for (pos, col, col_ty) in input.ty(ctx)?.expect_relation()?.iter() {
        env.add(col, col_ty);
        bindings.push((col, Expr::Field(Box::new(Expr::Input(input_id)), pos, col_ty)));
    }

    let mut seen = HashSet::new();
    let mut out_types = Vec::new();
    let mut out_exprs = Vec::new();
    for elem in elems {
        // Each arm picks the output label, rejects it if already used, and
        // yields the SQL expression that computes the column.
        let (label, sql) = match elem {
            ProjectElem(ProjectExpr::Var(SqlIdent(field)), None) => {
                let label = ctx.gen_symbol(&field);
                if !seen.insert(label) {
                    return Err(DuplicateName(field.into_string()).into());
                }
                (label, SqlExpr::Var(SqlIdent(field)))
            }
            ProjectElem(ProjectExpr::Var(field), Some(SqlIdent(rename))) => {
                let label = ctx.gen_symbol(&rename);
                if !seen.insert(label) {
                    return Err(DuplicateName(rename.into_string()).into());
                }
                (label, SqlExpr::Var(field))
            }
            ProjectElem(ProjectExpr::Field(table, SqlIdent(field)), None) => {
                let label = ctx.gen_symbol(&field);
                if !seen.insert(label) {
                    return Err(DuplicateName(field.into_string()).into());
                }
                (label, SqlExpr::Field(table, SqlIdent(field)))
            }
            ProjectElem(ProjectExpr::Field(table, field), Some(SqlIdent(rename))) => {
                let label = ctx.gen_symbol(&rename);
                if !seen.insert(label) {
                    return Err(DuplicateName(rename.into_string()).into());
                }
                (label, SqlExpr::Field(table, field))
            }
        };
        let typed = type_expr(ctx, &env, sql, None)?;
        out_types.push((label, typed.ty_id()));
        out_exprs.push((label, typed));
    }

    let row_ty = ctx.add_row_type(out_types);
    Ok(RelExpr::project(
        input,
        Let {
            vars: bindings,
            exprs: vec![Expr::Row(out_exprs.into_boxed_slice(), row_ty)],
        },
    ))
}
pub(crate) fn type_expr(ctx: &TyCtx, vars: &TyEnv, expr: SqlExpr, expected: Option<TyId>) -> TypingResult<Expr> {
match (expr, expected) {
(SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(TyId::BOOL)) => Ok(Expr::bool(v)),
(SqlExpr::Lit(SqlLiteral::Bool(_)), Some(id)) => {
let expected = ctx.bool();
let inferred = ctx.try_resolve(id)?;
Err(UnexpectedType::new(&expected, &inferred).into())
}
(SqlExpr::Lit(SqlLiteral::Str(v)), None | Some(TyId::STR)) => Ok(Expr::str(v)),
(SqlExpr::Lit(SqlLiteral::Str(_)), Some(id)) => {
let expected = ctx.str();
let inferred = ctx.try_resolve(id)?;
Err(UnexpectedType::new(&expected, &inferred).into())
}
(SqlExpr::Lit(SqlLiteral::Num(_) | SqlLiteral::Hex(_)), None) => Err(Unresolved::Literal.into()),
(SqlExpr::Lit(SqlLiteral::Num(v) | SqlLiteral::Hex(v)), Some(id)) => {
let t = ctx.try_resolve(id)?;
let v = parse(v.into_string(), t)?;
Ok(Expr::Lit(v, id))
}
(SqlExpr::Var(SqlIdent(var)), None) => {
let var_name = ctx.get_symbol(&var).ok_or_else(|| Unresolved::var(&var))?;
let var_type = vars.find(var_name).ok_or_else(|| Unresolved::var(&var))?;
Ok(Expr::Var(var_name, var_type))
}
(SqlExpr::Var(SqlIdent(var)), Some(id)) => {
let var_name = ctx.get_symbol(&var).ok_or_else(|| Unresolved::var(&var))?;
let var_type = vars.find(var_name).ok_or_else(|| Unresolved::var(&var))?;
assert_eq_types(ctx, var_type, id)?;
Ok(Expr::Var(var_name, var_type))
}
(SqlExpr::Field(SqlIdent(table), SqlIdent(field)), None) => {
let table_name = ctx.get_symbol(&table).ok_or_else(|| Unresolved::var(&table))?;
let field_name = ctx.get_symbol(&field).ok_or_else(|| Unresolved::var(&field))?;
let table_type = vars.find(table_name).ok_or_else(|| Unresolved::var(&table))?;
let (i, field_type) = ctx
.try_resolve(table_type)?
.expect_relation()?
.find(field_name)
.ok_or_else(|| Unresolved::field(&table, &field))?;
Ok(Expr::Field(Box::new(Expr::Var(table_name, table_type)), i, field_type))
}
(SqlExpr::Field(SqlIdent(table), SqlIdent(field)), Some(id)) => {
let table_name = ctx.get_symbol(&table).ok_or_else(|| Unresolved::var(&table))?;
let field_name = ctx.get_symbol(&field).ok_or_else(|| Unresolved::var(&field))?;
let table_type = vars.find(table_name).ok_or_else(|| Unresolved::var(&table))?;
let (i, field_type) = ctx
.try_resolve(table_type)?
.expect_relation()?
.find(field_name)
.ok_or_else(|| Unresolved::field(&table, &field))?;
assert_eq_types(ctx, field_type, id)?;
Ok(Expr::Field(Box::new(Expr::Var(table_name, table_type)), i, field_type))
}
(SqlExpr::Bin(a, b, op), None | Some(TyId::BOOL)) => match (*a, *b) {
(a, b @ SqlExpr::Lit(_)) | (b @ SqlExpr::Lit(_), a) | (a, b) => {
let a = type_expr(ctx, vars, a, None)?;
let b = type_expr(ctx, vars, b, Some(a.ty_id()))?;
a.ty(ctx)?.expect_op(op)?;
Ok(Expr::Bin(op, Box::new(a), Box::new(b)))
}
},
(SqlExpr::Bin(..), Some(id)) => {
let expected = ctx.bool();
let inferred = ctx.try_resolve(id)?;
Err(UnexpectedType::new(&expected, &inferred).into())
}
}
}
/// Checks that the types `a` and `b` are equal according to `ctx.eq`.
///
/// On a mismatch, both ids are resolved (`a` first, then `b`). The function
/// returns an [`UnexpectedType`] error built from them in that order. Errors
/// from `ctx.eq` or from resolving either id are propagated.
pub(crate) fn assert_eq_types(ctx: &TyCtx, a: TyId, b: TyId) -> TypingResult<()> {
    if ctx.eq(a, b)? {
        Ok(())
    } else {
        let lhs = ctx.try_resolve(a)?;
        let rhs = ctx.try_resolve(b)?;
        Err(UnexpectedType::new(&lhs, &rhs).into())
    }
}
pub(crate) fn parse(value: String, ty: TypeWithCtx) -> Result<AlgebraicValue, InvalidLiteral> {
match &*ty {
Type::Alg(AlgebraicType::I8) => value
.parse::<i8>()
.map(AlgebraicValue::I8)
.map_err(|_| InvalidLiteral::new(value, &ty)),
Type::Alg(AlgebraicType::U8) => value
.parse::<u8>()
.map(AlgebraicValue::U8)
.map_err(|_| InvalidLiteral::new(value, &ty)),
Type::Alg(AlgebraicType::I16) => value
.parse::<i16>()
.map(AlgebraicValue::I16)
.map_err(|_| InvalidLiteral::new(value, &ty)),
Type::Alg(AlgebraicType::U16) => value
.parse::<u16>()
.map(AlgebraicValue::U16)
.map_err(|_| InvalidLiteral::new(value, &ty)),
Type::Alg(AlgebraicType::I32) => value
.parse::<i32>()
.map(AlgebraicValue::I32)
.map_err(|_| InvalidLiteral::new(value, &ty)),
Type::Alg(AlgebraicType::U32) => value
.parse::<u32>()
.map(AlgebraicValue::U32)
.map_err(|_| InvalidLiteral::new(value, &ty)),
Type::Alg(AlgebraicType::I64) => value
.parse::<i64>()
.map(AlgebraicValue::I64)
.map_err(|_| InvalidLiteral::new(value, &ty)),
Type::Alg(AlgebraicType::U64) => value
.parse::<u64>()
.map(AlgebraicValue::U64)
.map_err(|_| InvalidLiteral::new(value, &ty)),
Type::Alg(AlgebraicType::F32) => value
.parse::<f32>()
.map(|value| AlgebraicValue::F32(value.into()))
.map_err(|_| InvalidLiteral::new(value, &ty)),
Type::Alg(AlgebraicType::F64) => value
.parse::<f64>()
.map(|value| AlgebraicValue::F64(value.into()))
.map_err(|_| InvalidLiteral::new(value, &ty)),
Type::Alg(AlgebraicType::I128) => value
.parse::<i128>()
.map(|value| AlgebraicValue::I128(value.into()))
.map_err(|_| InvalidLiteral::new(value, &ty)),
Type::Alg(AlgebraicType::U128) => value
.parse::<u128>()
.map(|value| AlgebraicValue::U128(value.into()))
.map_err(|_| InvalidLiteral::new(value, &ty)),
Type::Alg(t) if t.is_bytes() => from_hex_pad::<Vec<u8>, _>(&value)
.map(|value| AlgebraicValue::Bytes(value.into_boxed_slice()))
.map_err(|_| InvalidLiteral::new(value, &ty)),
Type::Alg(t) if t.is_identity() => Identity::from_hex(&value)
.map(AlgebraicValue::from)
.map_err(|_| InvalidLiteral::new(value, &ty)),
Type::Alg(t) if t.is_address() => Address::from_hex(&value)
.map(AlgebraicValue::from)
.map_err(|_| InvalidLiteral::new(value, &ty)),
_ => Err(InvalidLiteral::new(value, &ty)),
}
}