// spacetimedb_sql_parser/ast/mod.rs

1use std::fmt::{Display, Formatter};
2
3use spacetimedb_lib::Identity;
4use sqlparser::ast::Ident;
5
6pub mod sql;
7pub mod sub;
8
9/// The FROM clause is either a relvar or a JOIN
10#[derive(Debug)]
11pub enum SqlFrom {
12    Expr(SqlIdent, SqlIdent),
13    Join(SqlIdent, SqlIdent, Vec<SqlJoin>),
14}
15
16impl SqlFrom {
17    pub fn has_unqualified_vars(&self) -> bool {
18        match self {
19            Self::Join(_, _, joins) => joins.iter().any(|join| join.has_unqualified_vars()),
20            _ => false,
21        }
22    }
23}
24
25/// An inner join in a FROM clause
26#[derive(Debug)]
27pub struct SqlJoin {
28    pub var: SqlIdent,
29    pub alias: SqlIdent,
30    pub on: Option<SqlExpr>,
31}
32
33impl SqlJoin {
34    pub fn has_unqualified_vars(&self) -> bool {
35        self.on.as_ref().is_some_and(|expr| expr.has_unqualified_vars())
36    }
37}
38
39/// A projection expression in a SELECT clause
40#[derive(Debug)]
41pub struct ProjectElem(pub ProjectExpr, pub SqlIdent);
42
43impl ProjectElem {
44    pub fn qualify_vars(self, with: SqlIdent) -> Self {
45        let Self(expr, alias) = self;
46        Self(expr.qualify_vars(with), alias)
47    }
48}
49
50/// A column projection in a SELECT clause
51#[derive(Debug)]
52pub enum ProjectExpr {
53    Var(SqlIdent),
54    Field(SqlIdent, SqlIdent),
55}
56
57impl From<ProjectExpr> for SqlExpr {
58    fn from(value: ProjectExpr) -> Self {
59        match value {
60            ProjectExpr::Var(name) => Self::Var(name),
61            ProjectExpr::Field(table, field) => Self::Field(table, field),
62        }
63    }
64}
65
66impl ProjectExpr {
67    pub fn qualify_vars(self, with: SqlIdent) -> Self {
68        match self {
69            Self::Var(name) => Self::Field(with, name),
70            Self::Field(_, _) => self,
71        }
72    }
73}
74
75/// A SQL SELECT clause
76#[derive(Debug)]
77pub enum Project {
78    /// SELECT *
79    /// SELECT a.*
80    Star(Option<SqlIdent>),
81    /// SELECT a, b
82    Exprs(Vec<ProjectElem>),
83    /// SELECT COUNT(*)
84    Count(SqlIdent),
85}
86
87impl Project {
88    pub fn qualify_vars(self, with: SqlIdent) -> Self {
89        match self {
90            Self::Star(..) | Self::Count(..) => self,
91            Self::Exprs(elems) => Self::Exprs(elems.into_iter().map(|elem| elem.qualify_vars(with.clone())).collect()),
92        }
93    }
94
95    pub fn has_unqualified_vars(&self) -> bool {
96        match self {
97            Self::Exprs(exprs) => exprs
98                .iter()
99                .any(|ProjectElem(expr, _)| matches!(expr, ProjectExpr::Var(_))),
100            _ => false,
101        }
102    }
103}
104
105/// A scalar SQL expression
106#[derive(Debug)]
107pub enum SqlExpr {
108    /// A constant expression
109    Lit(SqlLiteral),
110    /// Unqualified column ref
111    Var(SqlIdent),
112    /// A parameter prefixed with `:`
113    Param(Parameter),
114    /// Qualified column ref
115    Field(SqlIdent, SqlIdent),
116    /// A binary infix expression
117    Bin(Box<SqlExpr>, Box<SqlExpr>, BinOp),
118    /// A binary logic expression
119    Log(Box<SqlExpr>, Box<SqlExpr>, LogOp),
120}
121
122impl SqlExpr {
123    pub fn qualify_vars(self, with: SqlIdent) -> Self {
124        match self {
125            Self::Var(name) => Self::Field(with, name),
126            Self::Lit(..) | Self::Field(..) | Self::Param(..) => self,
127            Self::Bin(a, b, op) => Self::Bin(
128                Box::new(a.qualify_vars(with.clone())),
129                Box::new(b.qualify_vars(with)),
130                op,
131            ),
132            Self::Log(a, b, op) => Self::Log(
133                Box::new(a.qualify_vars(with.clone())),
134                Box::new(b.qualify_vars(with)),
135                op,
136            ),
137        }
138    }
139
140    pub fn has_unqualified_vars(&self) -> bool {
141        match self {
142            Self::Var(_) => true,
143            Self::Bin(a, b, _) | Self::Log(a, b, _) => a.has_unqualified_vars() || b.has_unqualified_vars(),
144            _ => false,
145        }
146    }
147
148    /// Is this AST parameterized?
149    /// We need to know in order to hash subscription queries correctly.
150    pub fn has_parameter(&self) -> bool {
151        match self {
152            Self::Lit(_) | Self::Var(_) | Self::Field(..) => false,
153            Self::Param(Parameter::Sender) => true,
154            Self::Bin(a, b, _) | Self::Log(a, b, _) => a.has_parameter() || b.has_parameter(),
155        }
156    }
157
158    /// Replace the `:sender` parameter with the [Identity] it represents
159    pub fn resolve_sender(self, sender_identity: Identity) -> Self {
160        match self {
161            Self::Lit(_) | Self::Var(_) | Self::Field(..) => self,
162            Self::Param(Parameter::Sender) => {
163                Self::Lit(SqlLiteral::Hex(String::from(sender_identity.to_hex()).into_boxed_str()))
164            }
165
166            Self::Bin(a, b, op) => Self::Bin(
167                Box::new(a.resolve_sender(sender_identity)),
168                Box::new(b.resolve_sender(sender_identity)),
169                op,
170            ),
171            Self::Log(a, b, op) => Self::Log(
172                Box::new(a.resolve_sender(sender_identity)),
173                Box::new(b.resolve_sender(sender_identity)),
174                op,
175            ),
176        }
177    }
178}
179
/// A named parameter prefixed with `:`.
#[derive(Debug)]
pub enum Parameter {
    /// `:sender`, which `SqlExpr::resolve_sender` replaces with the caller's
    /// [`Identity`] as a hex literal.
    Sender,
}
186
/// A SQL identifier or named reference.
///
/// Currently case sensitive: the identifier text is kept exactly as written.
#[derive(Debug, Clone)]
pub struct SqlIdent(pub Box<str>);
191
192/// Case insensitivity should be implemented here if at all
193impl From<Ident> for SqlIdent {
194    fn from(Ident { value, .. }: Ident) -> Self {
195        SqlIdent(value.into_boxed_str())
196    }
197}
198
/// A SQL constant expression.
///
/// Non-boolean literals are kept as unparsed text.
#[derive(Debug)]
pub enum SqlLiteral {
    /// A boolean constant
    Bool(bool),
    /// A hex value like 0xFF or x'FF'
    /// NOTE(review): whether the `0x` / `x'…'` wrapper is kept in the stored
    /// text is decided by the parser — confirm.
    Hex(Box<str>),
    /// An integer or float value
    Num(Box<str>),
    /// A string value
    Str(Box<str>),
}
211
/// Binary infix comparison operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    /// `=`
    Eq,
    /// `<>`
    Ne,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    Lte,
    /// `>=`
    Gte,
}
222
223impl Display for BinOp {
224    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
225        match self {
226            Self::Eq => write!(f, "="),
227            Self::Ne => write!(f, "<>"),
228            Self::Lt => write!(f, "<"),
229            Self::Gt => write!(f, ">"),
230            Self::Lte => write!(f, "<="),
231            Self::Gte => write!(f, ">="),
232        }
233    }
234}
235
/// Binary logic operators joining two expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogOp {
    /// `AND`
    And,
    /// `OR`
    Or,
}
241
242impl Display for LogOp {
243    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
244        match self {
245            Self::And => write!(f, "AND"),
246            Self::Or => write!(f, "OR"),
247        }
248    }
249}