1use std::{collections::HashSet, ops::Deref, str::FromStr};
2
3use crate::statement::Statement;
4use anyhow::anyhow;
5use anyhow::bail;
6use anyhow::Context;
7use bigdecimal::BigDecimal;
8use bigdecimal::ToPrimitive;
9use check::{Relvars, TypingResult};
10use errors::{DuplicateName, InvalidLiteral, InvalidOp, InvalidWildcard, UnexpectedType, Unresolved};
11use ethnum::i256;
12use ethnum::u256;
13use expr::AggType;
14use expr::{Expr, FieldProject, ProjectList, ProjectName, RelExpr};
15use spacetimedb_lib::{from_hex_pad, AlgebraicType, AlgebraicValue, ConnectionId, Identity};
16use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type;
17use spacetimedb_schema::schema::ColumnSchema;
18use spacetimedb_sql_parser::ast::{self, BinOp, ProjectElem, SqlExpr, SqlIdent, SqlLiteral};
19
20pub mod check;
21pub mod errors;
22pub mod expr;
23pub mod statement;
24
25pub(crate) fn type_select(input: RelExpr, expr: SqlExpr, vars: &Relvars) -> TypingResult<RelExpr> {
27 Ok(RelExpr::Select(
28 Box::new(input),
29 type_expr(vars, expr, Some(&AlgebraicType::Bool))?,
30 ))
31}
32
33pub(crate) fn type_limit(input: ProjectList, limit: &str) -> TypingResult<ProjectList> {
35 Ok(
36 parse_int(limit, AlgebraicType::U64, BigDecimal::to_u64, AlgebraicValue::U64)
37 .map_err(|_| InvalidLiteral::new(limit.to_owned(), &AlgebraicType::U64))
38 .and_then(|n| {
39 n.into_u64()
40 .map_err(|_| InvalidLiteral::new(limit.to_owned(), &AlgebraicType::U64))
41 })
42 .map(|n| ProjectList::Limit(Box::new(input), n))?,
43 )
44}
45
46pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> TypingResult<ProjectList> {
48 match proj {
49 ast::Project::Star(None) if input.nfields() > 1 => Err(InvalidWildcard::Join.into()),
50 ast::Project::Star(None) => Ok(ProjectList::Name(ProjectName::None(input))),
51 ast::Project::Star(Some(SqlIdent(var))) if input.has_field(&var) => {
52 Ok(ProjectList::Name(ProjectName::Some(input, var)))
53 }
54 ast::Project::Star(Some(SqlIdent(var))) => Err(Unresolved::var(&var).into()),
55 ast::Project::Count(SqlIdent(alias)) => Ok(ProjectList::Agg(input, AggType::Count, alias, AlgebraicType::U64)),
56 ast::Project::Exprs(elems) => {
57 let mut projections = vec![];
58 let mut names = HashSet::new();
59
60 for ProjectElem(expr, SqlIdent(alias)) in elems {
61 if !names.insert(alias.clone()) {
62 return Err(DuplicateName(alias.into_string()).into());
63 }
64
65 if let Expr::Field(p) = type_expr(vars, expr.into(), None)? {
66 projections.push((alias, p));
67 }
68 }
69
70 Ok(ProjectList::List(input, projections))
71 }
72 }
73}
74
75pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>) -> TypingResult<Expr> {
77 match (expr, expected) {
78 (SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(AlgebraicType::Bool)) => Ok(Expr::bool(v)),
79 (SqlExpr::Lit(SqlLiteral::Bool(_)), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
80 (SqlExpr::Lit(SqlLiteral::Str(v)), None | Some(AlgebraicType::String)) => Ok(Expr::str(v)),
81 (SqlExpr::Lit(SqlLiteral::Str(_)), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::String, ty).into()),
82 (SqlExpr::Lit(SqlLiteral::Num(_) | SqlLiteral::Hex(_)), None) => Err(Unresolved::Literal.into()),
83 (SqlExpr::Lit(SqlLiteral::Num(v) | SqlLiteral::Hex(v)), Some(ty)) => Ok(Expr::Value(
84 parse(&v, ty).map_err(|_| InvalidLiteral::new(v.into_string(), ty))?,
85 ty.clone(),
86 )),
87 (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), None) => {
88 let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?;
89 let ColumnSchema { col_pos, col_type, .. } = table_type
90 .get_column_by_name(&field)
91 .ok_or_else(|| Unresolved::var(&field))?;
92 Ok(Expr::Field(FieldProject {
93 table,
94 field: col_pos.idx(),
95 ty: col_type.clone(),
96 }))
97 }
98 (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), Some(ty)) => {
99 let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?;
100 let ColumnSchema { col_pos, col_type, .. } = table_type
101 .as_ref()
102 .get_column_by_name(&field)
103 .ok_or_else(|| Unresolved::var(&field))?;
104 if col_type != ty {
105 return Err(UnexpectedType::new(col_type, ty).into());
106 }
107 Ok(Expr::Field(FieldProject {
108 table,
109 field: col_pos.idx(),
110 ty: col_type.clone(),
111 }))
112 }
113 (SqlExpr::Log(a, b, op), None | Some(AlgebraicType::Bool)) => {
114 let a = type_expr(vars, *a, Some(&AlgebraicType::Bool))?;
115 let b = type_expr(vars, *b, Some(&AlgebraicType::Bool))?;
116 Ok(Expr::LogOp(op, Box::new(a), Box::new(b)))
117 }
118 (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) => match (*a, *b) {
119 (a, b @ SqlExpr::Lit(_)) | (b @ SqlExpr::Lit(_), a) | (a, b) => {
120 let a = type_expr(vars, a, None)?;
121 let b = type_expr(vars, b, Some(a.ty()))?;
122 if !op_supports_type(op, a.ty()) {
123 return Err(InvalidOp::new(op, a.ty()).into());
124 }
125 Ok(Expr::BinOp(op, Box::new(a), Box::new(b)))
126 }
127 },
128 (SqlExpr::Bin(..) | SqlExpr::Log(..), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
129 (SqlExpr::Var(_), _) => unreachable!(),
130 }
131}
132
133fn op_supports_type(_op: BinOp, t: &AlgebraicType) -> bool {
135 t.is_bool()
136 || t.is_integer()
137 || t.is_float()
138 || t.is_string()
139 || t.is_bytes()
140 || t.is_identity()
141 || t.is_connection_id()
142}
143
144fn parse_int<Int, Val, ToInt, ToVal>(
146 literal: &str,
147 ty: AlgebraicType,
148 to_int: ToInt,
149 to_val: ToVal,
150) -> anyhow::Result<AlgebraicValue>
151where
152 Int: Into<Val>,
153 ToInt: FnOnce(&BigDecimal) -> Option<Int>,
154 ToVal: FnOnce(Val) -> AlgebraicValue,
155{
156 BigDecimal::from_str(literal)
159 .ok()
160 .filter(|decimal| decimal.is_integer())
161 .ok_or_else(|| anyhow!("{literal} is not an integer"))
162 .map(|decimal| to_int(&decimal).map(|val| val.into()).map(to_val))
163 .transpose()
164 .ok_or_else(|| anyhow!("{literal} is out of bounds for type {}", fmt_algebraic_type(&ty)))?
165}
166
167fn parse_float<Float, Value, ToFloat, ToValue>(
169 literal: &str,
170 ty: AlgebraicType,
171 to_float: ToFloat,
172 to_value: ToValue,
173) -> anyhow::Result<AlgebraicValue>
174where
175 Float: Into<Value>,
176 ToFloat: FnOnce(&BigDecimal) -> Option<Float>,
177 ToValue: FnOnce(Value) -> AlgebraicValue,
178{
179 BigDecimal::from_str(literal)
180 .ok()
181 .and_then(|decimal| to_float(&decimal))
182 .map(|value| value.into())
183 .map(to_value)
184 .ok_or_else(|| anyhow!("{literal} is not a valid {}", fmt_algebraic_type(&ty)))
185}
186
187pub(crate) fn parse(value: &str, ty: &AlgebraicType) -> anyhow::Result<AlgebraicValue> {
189 let to_bytes = || {
190 from_hex_pad::<Vec<u8>, _>(value)
191 .map(|v| v.into_boxed_slice())
192 .map(AlgebraicValue::Bytes)
193 .with_context(|| "Could not parse hex value")
194 };
195 let to_identity = || {
196 Identity::from_hex(value)
197 .map(AlgebraicValue::from)
198 .with_context(|| "Could not parse identity")
199 };
200 let to_connection_id = || {
201 ConnectionId::from_hex(value)
202 .map(AlgebraicValue::from)
203 .with_context(|| "Could not parse connection id")
204 };
205 let to_i256 = |decimal: &BigDecimal| {
206 i256::from_str_radix(
207 &decimal.to_plain_string(),
209 10,
210 )
211 .ok()
212 };
213 let to_u256 = |decimal: &BigDecimal| {
214 u256::from_str_radix(
215 &decimal.to_plain_string(),
217 10,
218 )
219 .ok()
220 };
221 match ty {
222 AlgebraicType::I8 => parse_int(
223 value,
225 AlgebraicType::I8,
226 BigDecimal::to_i8,
227 AlgebraicValue::I8,
228 ),
229 AlgebraicType::U8 => parse_int(
230 value,
232 AlgebraicType::U8,
233 BigDecimal::to_u8,
234 AlgebraicValue::U8,
235 ),
236 AlgebraicType::I16 => parse_int(
237 value,
239 AlgebraicType::I16,
240 BigDecimal::to_i16,
241 AlgebraicValue::I16,
242 ),
243 AlgebraicType::U16 => parse_int(
244 value,
246 AlgebraicType::U16,
247 BigDecimal::to_u16,
248 AlgebraicValue::U16,
249 ),
250 AlgebraicType::I32 => parse_int(
251 value,
253 AlgebraicType::I32,
254 BigDecimal::to_i32,
255 AlgebraicValue::I32,
256 ),
257 AlgebraicType::U32 => parse_int(
258 value,
260 AlgebraicType::U32,
261 BigDecimal::to_u32,
262 AlgebraicValue::U32,
263 ),
264 AlgebraicType::I64 => parse_int(
265 value,
267 AlgebraicType::I64,
268 BigDecimal::to_i64,
269 AlgebraicValue::I64,
270 ),
271 AlgebraicType::U64 => parse_int(
272 value,
274 AlgebraicType::U64,
275 BigDecimal::to_u64,
276 AlgebraicValue::U64,
277 ),
278 AlgebraicType::F32 => parse_float(
279 value,
281 AlgebraicType::F32,
282 BigDecimal::to_f32,
283 AlgebraicValue::F32,
284 ),
285 AlgebraicType::F64 => parse_float(
286 value,
288 AlgebraicType::F64,
289 BigDecimal::to_f64,
290 AlgebraicValue::F64,
291 ),
292 AlgebraicType::I128 => parse_int(
293 value,
295 AlgebraicType::I128,
296 BigDecimal::to_i128,
297 AlgebraicValue::I128,
298 ),
299 AlgebraicType::U128 => parse_int(
300 value,
302 AlgebraicType::U128,
303 BigDecimal::to_u128,
304 AlgebraicValue::U128,
305 ),
306 AlgebraicType::I256 => parse_int(
307 value,
309 AlgebraicType::I256,
310 to_i256,
311 AlgebraicValue::I256,
312 ),
313 AlgebraicType::U256 => parse_int(
314 value,
316 AlgebraicType::U256,
317 to_u256,
318 AlgebraicValue::U256,
319 ),
320 AlgebraicType::String => Ok(AlgebraicValue::String(value.into())),
321 t if t.is_bytes() => to_bytes(),
322 t if t.is_identity() => to_identity(),
323 t if t.is_connection_id() => to_connection_id(),
324 t => bail!("Literal values for type {} are not supported", fmt_algebraic_type(t)),
325 }
326}
327
/// Where a SQL statement originated: a subscription or a one-off query.
pub enum StatementSource {
    /// The statement was submitted as a subscription.
    Subscription,
    /// The statement was submitted as a query.
    Query,
}
333
334pub struct StatementCtx<'a> {
338 pub statement: Statement,
339 pub sql: &'a str,
340 pub source: StatementSource,
341}