spacetimedb_expr/
lib.rs

1use std::{collections::HashSet, ops::Deref, str::FromStr};
2
3use crate::statement::Statement;
4use anyhow::anyhow;
5use anyhow::bail;
6use anyhow::Context;
7use bigdecimal::BigDecimal;
8use bigdecimal::ToPrimitive;
9use check::{Relvars, TypingResult};
10use errors::{DuplicateName, InvalidLiteral, InvalidOp, InvalidWildcard, UnexpectedType, Unresolved};
11use ethnum::i256;
12use ethnum::u256;
13use expr::AggType;
14use expr::{Expr, FieldProject, ProjectList, ProjectName, RelExpr};
15use spacetimedb_lib::{from_hex_pad, AlgebraicType, AlgebraicValue, ConnectionId, Identity};
16use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type;
17use spacetimedb_schema::schema::ColumnSchema;
18use spacetimedb_sql_parser::ast::{self, BinOp, ProjectElem, SqlExpr, SqlIdent, SqlLiteral};
19
20pub mod check;
21pub mod errors;
22pub mod expr;
23pub mod statement;
24
25/// Type check and lower a [SqlExpr]
26pub(crate) fn type_select(input: RelExpr, expr: SqlExpr, vars: &Relvars) -> TypingResult<RelExpr> {
27    Ok(RelExpr::Select(
28        Box::new(input),
29        type_expr(vars, expr, Some(&AlgebraicType::Bool))?,
30    ))
31}
32
33/// Type check a LIMIT clause
34pub(crate) fn type_limit(input: ProjectList, limit: &str) -> TypingResult<ProjectList> {
35    Ok(
36        parse_int(limit, AlgebraicType::U64, BigDecimal::to_u64, AlgebraicValue::U64)
37            .map_err(|_| InvalidLiteral::new(limit.to_owned(), &AlgebraicType::U64))
38            .and_then(|n| {
39                n.into_u64()
40                    .map_err(|_| InvalidLiteral::new(limit.to_owned(), &AlgebraicType::U64))
41            })
42            .map(|n| ProjectList::Limit(Box::new(input), n))?,
43    )
44}
45
46/// Type check and lower a [ast::Project]
47pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> TypingResult<ProjectList> {
48    match proj {
49        ast::Project::Star(None) if input.nfields() > 1 => Err(InvalidWildcard::Join.into()),
50        ast::Project::Star(None) => Ok(ProjectList::Name(ProjectName::None(input))),
51        ast::Project::Star(Some(SqlIdent(var))) if input.has_field(&var) => {
52            Ok(ProjectList::Name(ProjectName::Some(input, var)))
53        }
54        ast::Project::Star(Some(SqlIdent(var))) => Err(Unresolved::var(&var).into()),
55        ast::Project::Count(SqlIdent(alias)) => Ok(ProjectList::Agg(input, AggType::Count, alias, AlgebraicType::U64)),
56        ast::Project::Exprs(elems) => {
57            let mut projections = vec![];
58            let mut names = HashSet::new();
59
60            for ProjectElem(expr, SqlIdent(alias)) in elems {
61                if !names.insert(alias.clone()) {
62                    return Err(DuplicateName(alias.into_string()).into());
63                }
64
65                if let Expr::Field(p) = type_expr(vars, expr.into(), None)? {
66                    projections.push((alias, p));
67                }
68            }
69
70            Ok(ProjectList::List(input, projections))
71        }
72    }
73}
74
75/// Type check and lower a [SqlExpr] into a logical [Expr].
76pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>) -> TypingResult<Expr> {
77    match (expr, expected) {
78        (SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(AlgebraicType::Bool)) => Ok(Expr::bool(v)),
79        (SqlExpr::Lit(SqlLiteral::Bool(_)), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
80        (SqlExpr::Lit(SqlLiteral::Str(v)), None | Some(AlgebraicType::String)) => Ok(Expr::str(v)),
81        (SqlExpr::Lit(SqlLiteral::Str(_)), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::String, ty).into()),
82        (SqlExpr::Lit(SqlLiteral::Num(_) | SqlLiteral::Hex(_)), None) => Err(Unresolved::Literal.into()),
83        (SqlExpr::Lit(SqlLiteral::Num(v) | SqlLiteral::Hex(v)), Some(ty)) => Ok(Expr::Value(
84            parse(&v, ty).map_err(|_| InvalidLiteral::new(v.into_string(), ty))?,
85            ty.clone(),
86        )),
87        (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), None) => {
88            let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?;
89            let ColumnSchema { col_pos, col_type, .. } = table_type
90                .get_column_by_name(&field)
91                .ok_or_else(|| Unresolved::var(&field))?;
92            Ok(Expr::Field(FieldProject {
93                table,
94                field: col_pos.idx(),
95                ty: col_type.clone(),
96            }))
97        }
98        (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), Some(ty)) => {
99            let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?;
100            let ColumnSchema { col_pos, col_type, .. } = table_type
101                .as_ref()
102                .get_column_by_name(&field)
103                .ok_or_else(|| Unresolved::var(&field))?;
104            if col_type != ty {
105                return Err(UnexpectedType::new(col_type, ty).into());
106            }
107            Ok(Expr::Field(FieldProject {
108                table,
109                field: col_pos.idx(),
110                ty: col_type.clone(),
111            }))
112        }
113        (SqlExpr::Log(a, b, op), None | Some(AlgebraicType::Bool)) => {
114            let a = type_expr(vars, *a, Some(&AlgebraicType::Bool))?;
115            let b = type_expr(vars, *b, Some(&AlgebraicType::Bool))?;
116            Ok(Expr::LogOp(op, Box::new(a), Box::new(b)))
117        }
118        (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) => match (*a, *b) {
119            (a, b @ SqlExpr::Lit(_)) | (b @ SqlExpr::Lit(_), a) | (a, b) => {
120                let a = type_expr(vars, a, None)?;
121                let b = type_expr(vars, b, Some(a.ty()))?;
122                if !op_supports_type(op, a.ty()) {
123                    return Err(InvalidOp::new(op, a.ty()).into());
124                }
125                Ok(Expr::BinOp(op, Box::new(a), Box::new(b)))
126            }
127        },
128        (SqlExpr::Bin(..) | SqlExpr::Log(..), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
129        (SqlExpr::Var(_), _) => unreachable!(),
130    }
131}
132
133/// Is this type compatible with this binary operator?
134fn op_supports_type(_op: BinOp, t: &AlgebraicType) -> bool {
135    t.is_bool()
136        || t.is_integer()
137        || t.is_float()
138        || t.is_string()
139        || t.is_bytes()
140        || t.is_identity()
141        || t.is_connection_id()
142}
143
144/// Parse an integer literal into an [AlgebraicValue]
145fn parse_int<Int, Val, ToInt, ToVal>(
146    literal: &str,
147    ty: AlgebraicType,
148    to_int: ToInt,
149    to_val: ToVal,
150) -> anyhow::Result<AlgebraicValue>
151where
152    Int: Into<Val>,
153    ToInt: FnOnce(&BigDecimal) -> Option<Int>,
154    ToVal: FnOnce(Val) -> AlgebraicValue,
155{
156    // Why are we using an arbitrary precision type?
157    // For scientific notation as well as i256 and u256.
158    BigDecimal::from_str(literal)
159        .ok()
160        .filter(|decimal| decimal.is_integer())
161        .ok_or_else(|| anyhow!("{literal} is not an integer"))
162        .map(|decimal| to_int(&decimal).map(|val| val.into()).map(to_val))
163        .transpose()
164        .ok_or_else(|| anyhow!("{literal} is out of bounds for type {}", fmt_algebraic_type(&ty)))?
165}
166
167/// Parse a floating point literal into an [AlgebraicValue]
168fn parse_float<Float, Value, ToFloat, ToValue>(
169    literal: &str,
170    ty: AlgebraicType,
171    to_float: ToFloat,
172    to_value: ToValue,
173) -> anyhow::Result<AlgebraicValue>
174where
175    Float: Into<Value>,
176    ToFloat: FnOnce(&BigDecimal) -> Option<Float>,
177    ToValue: FnOnce(Value) -> AlgebraicValue,
178{
179    BigDecimal::from_str(literal)
180        .ok()
181        .and_then(|decimal| to_float(&decimal))
182        .map(|value| value.into())
183        .map(to_value)
184        .ok_or_else(|| anyhow!("{literal} is not a valid {}", fmt_algebraic_type(&ty)))
185}
186
187/// Parses a source text literal as a particular type
188pub(crate) fn parse(value: &str, ty: &AlgebraicType) -> anyhow::Result<AlgebraicValue> {
189    let to_bytes = || {
190        from_hex_pad::<Vec<u8>, _>(value)
191            .map(|v| v.into_boxed_slice())
192            .map(AlgebraicValue::Bytes)
193            .with_context(|| "Could not parse hex value")
194    };
195    let to_identity = || {
196        Identity::from_hex(value)
197            .map(AlgebraicValue::from)
198            .with_context(|| "Could not parse identity")
199    };
200    let to_connection_id = || {
201        ConnectionId::from_hex(value)
202            .map(AlgebraicValue::from)
203            .with_context(|| "Could not parse connection id")
204    };
205    let to_i256 = |decimal: &BigDecimal| {
206        i256::from_str_radix(
207            // Convert to decimal notation
208            &decimal.to_plain_string(),
209            10,
210        )
211        .ok()
212    };
213    let to_u256 = |decimal: &BigDecimal| {
214        u256::from_str_radix(
215            // Convert to decimal notation
216            &decimal.to_plain_string(),
217            10,
218        )
219        .ok()
220    };
221    match ty {
222        AlgebraicType::I8 => parse_int(
223            // Parse literal as I8
224            value,
225            AlgebraicType::I8,
226            BigDecimal::to_i8,
227            AlgebraicValue::I8,
228        ),
229        AlgebraicType::U8 => parse_int(
230            // Parse literal as U8
231            value,
232            AlgebraicType::U8,
233            BigDecimal::to_u8,
234            AlgebraicValue::U8,
235        ),
236        AlgebraicType::I16 => parse_int(
237            // Parse literal as I16
238            value,
239            AlgebraicType::I16,
240            BigDecimal::to_i16,
241            AlgebraicValue::I16,
242        ),
243        AlgebraicType::U16 => parse_int(
244            // Parse literal as U16
245            value,
246            AlgebraicType::U16,
247            BigDecimal::to_u16,
248            AlgebraicValue::U16,
249        ),
250        AlgebraicType::I32 => parse_int(
251            // Parse literal as I32
252            value,
253            AlgebraicType::I32,
254            BigDecimal::to_i32,
255            AlgebraicValue::I32,
256        ),
257        AlgebraicType::U32 => parse_int(
258            // Parse literal as U32
259            value,
260            AlgebraicType::U32,
261            BigDecimal::to_u32,
262            AlgebraicValue::U32,
263        ),
264        AlgebraicType::I64 => parse_int(
265            // Parse literal as I64
266            value,
267            AlgebraicType::I64,
268            BigDecimal::to_i64,
269            AlgebraicValue::I64,
270        ),
271        AlgebraicType::U64 => parse_int(
272            // Parse literal as U64
273            value,
274            AlgebraicType::U64,
275            BigDecimal::to_u64,
276            AlgebraicValue::U64,
277        ),
278        AlgebraicType::F32 => parse_float(
279            // Parse literal as F32
280            value,
281            AlgebraicType::F32,
282            BigDecimal::to_f32,
283            AlgebraicValue::F32,
284        ),
285        AlgebraicType::F64 => parse_float(
286            // Parse literal as F64
287            value,
288            AlgebraicType::F64,
289            BigDecimal::to_f64,
290            AlgebraicValue::F64,
291        ),
292        AlgebraicType::I128 => parse_int(
293            // Parse literal as I128
294            value,
295            AlgebraicType::I128,
296            BigDecimal::to_i128,
297            AlgebraicValue::I128,
298        ),
299        AlgebraicType::U128 => parse_int(
300            // Parse literal as U128
301            value,
302            AlgebraicType::U128,
303            BigDecimal::to_u128,
304            AlgebraicValue::U128,
305        ),
306        AlgebraicType::I256 => parse_int(
307            // Parse literal as I256
308            value,
309            AlgebraicType::I256,
310            to_i256,
311            AlgebraicValue::I256,
312        ),
313        AlgebraicType::U256 => parse_int(
314            // Parse literal as U256
315            value,
316            AlgebraicType::U256,
317            to_u256,
318            AlgebraicValue::U256,
319        ),
320        AlgebraicType::String => Ok(AlgebraicValue::String(value.into())),
321        t if t.is_bytes() => to_bytes(),
322        t if t.is_identity() => to_identity(),
323        t if t.is_connection_id() => to_connection_id(),
324        t => bail!("Literal values for type {} are not supported", fmt_algebraic_type(t)),
325    }
326}
327
/// Where a statement came from.
pub enum StatementSource {
    /// The statement originates from a subscription.
    Subscription,
    /// The statement originates from a query.
    Query,
}
333
334/// A statement context.
335///
336/// This is a wrapper around a statement, its source, and the original SQL text.
337pub struct StatementCtx<'a> {
338    pub statement: Statement,
339    pub sql: &'a str,
340    pub source: StatementSource,
341}