spacetimedb_sql_parser/ast/
sql.rs

1use spacetimedb_lib::Identity;
2
3use crate::parser::{errors::SqlUnsupported, SqlParseResult};
4
5use super::{Project, SqlExpr, SqlFrom, SqlIdent, SqlLiteral};
6
7/// The AST for the SQL DML and query language
8#[derive(Debug)]
9pub enum SqlAst {
10    /// SELECT ...
11    Select(SqlSelect),
12    /// INSERT INTO ...
13    Insert(SqlInsert),
14    /// UPDATE ...
15    Update(SqlUpdate),
16    /// DELETE FROM ...
17    Delete(SqlDelete),
18    /// SET var TO ...
19    Set(SqlSet),
20    /// SHOW var
21    Show(SqlShow),
22}
23
24impl SqlAst {
25    pub fn qualify_vars(self) -> Self {
26        match self {
27            Self::Select(select) => Self::Select(select.qualify_vars()),
28            Self::Update(SqlUpdate {
29                table: with,
30                assignments,
31                filter,
32            }) => Self::Update(SqlUpdate {
33                table: with.clone(),
34                filter: filter.map(|expr| expr.qualify_vars(with)),
35                assignments,
36            }),
37            Self::Delete(SqlDelete { table: with, filter }) => Self::Delete(SqlDelete {
38                table: with.clone(),
39                filter: filter.map(|expr| expr.qualify_vars(with)),
40            }),
41            _ => self,
42        }
43    }
44
45    pub fn find_unqualified_vars(self) -> SqlParseResult<Self> {
46        match self {
47            Self::Select(select) => select.find_unqualified_vars().map(Self::Select),
48            _ => Ok(self),
49        }
50    }
51
52    /// Replace the `:sender` parameter with the [Identity] it represents
53    pub fn resolve_sender(self, sender_identity: Identity) -> Self {
54        match self {
55            Self::Select(select) => Self::Select(select.resolve_sender(sender_identity)),
56            Self::Update(update) => Self::Update(update.resolve_sender(sender_identity)),
57            Self::Delete(delete) => Self::Delete(delete.resolve_sender(sender_identity)),
58            _ => self,
59        }
60    }
61}
62
63/// A SELECT statement in the SQL query language
64#[derive(Debug)]
65pub struct SqlSelect {
66    pub project: Project,
67    pub from: SqlFrom,
68    pub filter: Option<SqlExpr>,
69    pub limit: Option<Box<str>>,
70}
71
72impl SqlSelect {
73    pub fn qualify_vars(self) -> Self {
74        match &self.from {
75            SqlFrom::Expr(_, alias) => Self {
76                project: self.project.qualify_vars(alias.clone()),
77                filter: self.filter.map(|expr| expr.qualify_vars(alias.clone())),
78                ..self
79            },
80            SqlFrom::Join(..) => self,
81        }
82    }
83
84    pub fn find_unqualified_vars(self) -> SqlParseResult<Self> {
85        if self.from.has_unqualified_vars() {
86            return Err(SqlUnsupported::UnqualifiedNames.into());
87        }
88        if self.project.has_unqualified_vars() {
89            return Err(SqlUnsupported::UnqualifiedNames.into());
90        }
91        if let Some(expr) = &self.filter {
92            if expr.has_unqualified_vars() {
93                return Err(SqlUnsupported::UnqualifiedNames.into());
94            }
95        }
96        Ok(self)
97    }
98
99    /// Replace the `:sender` parameter with the [Identity] it represents
100    pub fn resolve_sender(self, sender_identity: Identity) -> Self {
101        Self {
102            filter: self.filter.map(|expr| expr.resolve_sender(sender_identity)),
103            ..self
104        }
105    }
106}
107
108/// INSERT INTO table cols VALUES literals
109#[derive(Debug)]
110pub struct SqlInsert {
111    pub table: SqlIdent,
112    pub fields: Vec<SqlIdent>,
113    pub values: SqlValues,
114}
115
116/// VALUES literals
117#[derive(Debug)]
118pub struct SqlValues(pub Vec<Vec<SqlLiteral>>);
119
120/// UPDATE table SET cols [ WHERE predicate ]
121#[derive(Debug)]
122pub struct SqlUpdate {
123    pub table: SqlIdent,
124    pub assignments: Vec<SqlSet>,
125    pub filter: Option<SqlExpr>,
126}
127
128impl SqlUpdate {
129    /// Replace the `:sender` parameter with the [Identity] it represents
130    fn resolve_sender(self, sender_identity: Identity) -> Self {
131        Self {
132            filter: self.filter.map(|expr| expr.resolve_sender(sender_identity)),
133            ..self
134        }
135    }
136}
137
138/// DELETE FROM table [ WHERE predicate ]
139#[derive(Debug)]
140pub struct SqlDelete {
141    pub table: SqlIdent,
142    pub filter: Option<SqlExpr>,
143}
144
145impl SqlDelete {
146    /// Replace the `:sender` parameter with the [Identity] it represents
147    fn resolve_sender(self, sender_identity: Identity) -> Self {
148        Self {
149            filter: self.filter.map(|expr| expr.resolve_sender(sender_identity)),
150            ..self
151        }
152    }
153}
154
155/// SET var '=' literal
156#[derive(Debug)]
157pub struct SqlSet(pub SqlIdent, pub SqlLiteral);
158
159/// SHOW var
160#[derive(Debug)]
161pub struct SqlShow(pub SqlIdent);