1use std::{collections::HashSet, ops::Deref, str::FromStr};
2
3use crate::statement::Statement;
4use anyhow::anyhow;
5use anyhow::bail;
6use anyhow::Context;
7use bigdecimal::BigDecimal;
8use bigdecimal::ToPrimitive;
9use check::{Relvars, TypingResult};
10use errors::{DuplicateName, InvalidLiteral, InvalidOp, InvalidWildcard, UnexpectedType, Unresolved};
11use ethnum::i256;
12use ethnum::u256;
13use expr::AggType;
14use expr::{Expr, FieldProject, ProjectList, ProjectName, RelExpr};
15use spacetimedb_lib::ser::Serialize;
16use spacetimedb_lib::Timestamp;
17use spacetimedb_lib::{from_hex_pad, AlgebraicType, AlgebraicValue, ConnectionId, Identity};
18use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type;
19use spacetimedb_sats::algebraic_value::ser::ValueSerializer;
20use spacetimedb_schema::schema::ColumnSchema;
21use spacetimedb_sql_parser::ast::{self, BinOp, ProjectElem, SqlExpr, SqlIdent, SqlLiteral};
22
23pub mod check;
24pub mod errors;
25pub mod expr;
26pub mod statement;
27
28pub(crate) fn type_select(input: RelExpr, expr: SqlExpr, vars: &Relvars) -> TypingResult<RelExpr> {
30 Ok(RelExpr::Select(
31 Box::new(input),
32 type_expr(vars, expr, Some(&AlgebraicType::Bool))?,
33 ))
34}
35
36pub(crate) fn type_limit(input: ProjectList, limit: &str) -> TypingResult<ProjectList> {
38 Ok(
39 parse_int(limit, AlgebraicType::U64, BigDecimal::to_u64, AlgebraicValue::U64)
40 .map_err(|_| InvalidLiteral::new(limit.to_owned(), &AlgebraicType::U64))
41 .and_then(|n| {
42 n.into_u64()
43 .map_err(|_| InvalidLiteral::new(limit.to_owned(), &AlgebraicType::U64))
44 })
45 .map(|n| ProjectList::Limit(Box::new(input), n))?,
46 )
47}
48
49pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> TypingResult<ProjectList> {
51 match proj {
52 ast::Project::Star(None) if input.nfields() > 1 => Err(InvalidWildcard::Join.into()),
53 ast::Project::Star(None) => Ok(ProjectList::Name(ProjectName::None(input))),
54 ast::Project::Star(Some(SqlIdent(var))) if input.has_field(&var) => {
55 Ok(ProjectList::Name(ProjectName::Some(input, var)))
56 }
57 ast::Project::Star(Some(SqlIdent(var))) => Err(Unresolved::var(&var).into()),
58 ast::Project::Count(SqlIdent(alias)) => Ok(ProjectList::Agg(input, AggType::Count, alias, AlgebraicType::U64)),
59 ast::Project::Exprs(elems) => {
60 let mut projections = vec![];
61 let mut names = HashSet::new();
62
63 for ProjectElem(expr, SqlIdent(alias)) in elems {
64 if !names.insert(alias.clone()) {
65 return Err(DuplicateName(alias.into_string()).into());
66 }
67
68 if let Expr::Field(p) = type_expr(vars, expr.into(), None)? {
69 projections.push((alias, p));
70 }
71 }
72
73 Ok(ProjectList::List(input, projections))
74 }
75 }
76}
77
78pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>) -> TypingResult<Expr> {
80 match (expr, expected) {
81 (SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(AlgebraicType::Bool)) => Ok(Expr::bool(v)),
82 (SqlExpr::Lit(SqlLiteral::Bool(_)), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
83 (SqlExpr::Lit(SqlLiteral::Str(_) | SqlLiteral::Num(_) | SqlLiteral::Hex(_)), None) => {
84 Err(Unresolved::Literal.into())
85 }
86 (SqlExpr::Lit(SqlLiteral::Str(v) | SqlLiteral::Num(v) | SqlLiteral::Hex(v)), Some(ty)) => Ok(Expr::Value(
87 parse(&v, ty).map_err(|_| InvalidLiteral::new(v.into_string(), ty))?,
88 ty.clone(),
89 )),
90 (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), None) => {
91 let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?;
92 let ColumnSchema { col_pos, col_type, .. } = table_type
93 .get_column_by_name(&field)
94 .ok_or_else(|| Unresolved::var(&field))?;
95 Ok(Expr::Field(FieldProject {
96 table,
97 field: col_pos.idx(),
98 ty: col_type.clone(),
99 }))
100 }
101 (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), Some(ty)) => {
102 let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?;
103 let ColumnSchema { col_pos, col_type, .. } = table_type
104 .as_ref()
105 .get_column_by_name(&field)
106 .ok_or_else(|| Unresolved::var(&field))?;
107 if col_type != ty {
108 return Err(UnexpectedType::new(col_type, ty).into());
109 }
110 Ok(Expr::Field(FieldProject {
111 table,
112 field: col_pos.idx(),
113 ty: col_type.clone(),
114 }))
115 }
116 (SqlExpr::Log(a, b, op), None | Some(AlgebraicType::Bool)) => {
117 let a = type_expr(vars, *a, Some(&AlgebraicType::Bool))?;
118 let b = type_expr(vars, *b, Some(&AlgebraicType::Bool))?;
119 Ok(Expr::LogOp(op, Box::new(a), Box::new(b)))
120 }
121 (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) => match (*a, *b) {
122 (a, b @ SqlExpr::Lit(_)) | (b @ SqlExpr::Lit(_), a) | (a, b) => {
123 let a = type_expr(vars, a, None)?;
124 let b = type_expr(vars, b, Some(a.ty()))?;
125 if !op_supports_type(op, a.ty()) {
126 return Err(InvalidOp::new(op, a.ty()).into());
127 }
128 Ok(Expr::BinOp(op, Box::new(a), Box::new(b)))
129 }
130 },
131 (SqlExpr::Bin(..) | SqlExpr::Log(..), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
132 (SqlExpr::Var(_), _) => unreachable!(),
133 }
134}
135
136fn op_supports_type(_op: BinOp, t: &AlgebraicType) -> bool {
138 t.is_bool()
139 || t.is_integer()
140 || t.is_float()
141 || t.is_string()
142 || t.is_bytes()
143 || t.is_identity()
144 || t.is_connection_id()
145 || t.is_timestamp()
146}
147
148fn parse_int<Int, Val, ToInt, ToVal>(
150 literal: &str,
151 ty: AlgebraicType,
152 to_int: ToInt,
153 to_val: ToVal,
154) -> anyhow::Result<AlgebraicValue>
155where
156 Int: Into<Val>,
157 ToInt: FnOnce(&BigDecimal) -> Option<Int>,
158 ToVal: FnOnce(Val) -> AlgebraicValue,
159{
160 BigDecimal::from_str(literal)
163 .ok()
164 .filter(|decimal| decimal.is_integer())
165 .ok_or_else(|| anyhow!("{literal} is not an integer"))
166 .map(|decimal| to_int(&decimal).map(|val| val.into()).map(to_val))
167 .transpose()
168 .ok_or_else(|| anyhow!("{literal} is out of bounds for type {}", fmt_algebraic_type(&ty)))?
169}
170
171fn parse_float<Float, Value, ToFloat, ToValue>(
173 literal: &str,
174 ty: AlgebraicType,
175 to_float: ToFloat,
176 to_value: ToValue,
177) -> anyhow::Result<AlgebraicValue>
178where
179 Float: Into<Value>,
180 ToFloat: FnOnce(&BigDecimal) -> Option<Float>,
181 ToValue: FnOnce(Value) -> AlgebraicValue,
182{
183 BigDecimal::from_str(literal)
184 .ok()
185 .and_then(|decimal| to_float(&decimal))
186 .map(|value| value.into())
187 .map(to_value)
188 .ok_or_else(|| anyhow!("{literal} is not a valid {}", fmt_algebraic_type(&ty)))
189}
190
191pub(crate) fn parse(value: &str, ty: &AlgebraicType) -> anyhow::Result<AlgebraicValue> {
193 let to_timestamp = || {
194 Timestamp::parse_from_str(value)?
195 .serialize(ValueSerializer)
196 .with_context(|| "Could not parse timestamp")
197 };
198 let to_bytes = || {
199 from_hex_pad::<Vec<u8>, _>(value)
200 .map(|v| v.into_boxed_slice())
201 .map(AlgebraicValue::Bytes)
202 .with_context(|| "Could not parse hex value")
203 };
204 let to_identity = || {
205 Identity::from_hex(value)
206 .map(AlgebraicValue::from)
207 .with_context(|| "Could not parse identity")
208 };
209 let to_connection_id = || {
210 ConnectionId::from_hex(value)
211 .map(AlgebraicValue::from)
212 .with_context(|| "Could not parse connection id")
213 };
214 let to_i256 = |decimal: &BigDecimal| {
215 i256::from_str_radix(
216 &decimal.to_plain_string(),
218 10,
219 )
220 .ok()
221 };
222 let to_u256 = |decimal: &BigDecimal| {
223 u256::from_str_radix(
224 &decimal.to_plain_string(),
226 10,
227 )
228 .ok()
229 };
230 match ty {
231 AlgebraicType::I8 => parse_int(
232 value,
234 AlgebraicType::I8,
235 BigDecimal::to_i8,
236 AlgebraicValue::I8,
237 ),
238 AlgebraicType::U8 => parse_int(
239 value,
241 AlgebraicType::U8,
242 BigDecimal::to_u8,
243 AlgebraicValue::U8,
244 ),
245 AlgebraicType::I16 => parse_int(
246 value,
248 AlgebraicType::I16,
249 BigDecimal::to_i16,
250 AlgebraicValue::I16,
251 ),
252 AlgebraicType::U16 => parse_int(
253 value,
255 AlgebraicType::U16,
256 BigDecimal::to_u16,
257 AlgebraicValue::U16,
258 ),
259 AlgebraicType::I32 => parse_int(
260 value,
262 AlgebraicType::I32,
263 BigDecimal::to_i32,
264 AlgebraicValue::I32,
265 ),
266 AlgebraicType::U32 => parse_int(
267 value,
269 AlgebraicType::U32,
270 BigDecimal::to_u32,
271 AlgebraicValue::U32,
272 ),
273 AlgebraicType::I64 => parse_int(
274 value,
276 AlgebraicType::I64,
277 BigDecimal::to_i64,
278 AlgebraicValue::I64,
279 ),
280 AlgebraicType::U64 => parse_int(
281 value,
283 AlgebraicType::U64,
284 BigDecimal::to_u64,
285 AlgebraicValue::U64,
286 ),
287 AlgebraicType::F32 => parse_float(
288 value,
290 AlgebraicType::F32,
291 BigDecimal::to_f32,
292 AlgebraicValue::F32,
293 ),
294 AlgebraicType::F64 => parse_float(
295 value,
297 AlgebraicType::F64,
298 BigDecimal::to_f64,
299 AlgebraicValue::F64,
300 ),
301 AlgebraicType::I128 => parse_int(
302 value,
304 AlgebraicType::I128,
305 BigDecimal::to_i128,
306 AlgebraicValue::I128,
307 ),
308 AlgebraicType::U128 => parse_int(
309 value,
311 AlgebraicType::U128,
312 BigDecimal::to_u128,
313 AlgebraicValue::U128,
314 ),
315 AlgebraicType::I256 => parse_int(
316 value,
318 AlgebraicType::I256,
319 to_i256,
320 AlgebraicValue::I256,
321 ),
322 AlgebraicType::U256 => parse_int(
323 value,
325 AlgebraicType::U256,
326 to_u256,
327 AlgebraicValue::U256,
328 ),
329 AlgebraicType::String => Ok(AlgebraicValue::String(value.into())),
330 t if t.is_timestamp() => to_timestamp(),
331 t if t.is_bytes() => to_bytes(),
332 t if t.is_identity() => to_identity(),
333 t if t.is_connection_id() => to_connection_id(),
334 t => bail!("Literal values for type {} are not supported", fmt_algebraic_type(t)),
335 }
336}
337
/// Where a SQL statement originated: a subscription or a query.
///
/// Derives `Debug` for logging and `Copy`/`Eq` so callers can compare and pass the
/// source by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementSource {
    /// The statement was submitted as a subscription.
    Subscription,
    /// The statement was submitted as a query.
    Query,
}
343
344pub struct StatementCtx<'a> {
348 pub statement: Statement,
349 pub sql: &'a str,
350 pub source: StatementSource,
351}