spacetimedb_sql_parser/ast/
sql.rs

1use crate::parser::{errors::SqlUnsupported, SqlParseResult};
2
3use super::{Project, SqlExpr, SqlFrom, SqlIdent, SqlLiteral};
4
/// The AST for the SQL DML and query language
///
/// Each variant wraps the node for one kind of SQL statement.
#[derive(Debug)]
pub enum SqlAst {
    /// SELECT ...
    Select(SqlSelect),
    /// INSERT INTO ...
    Insert(SqlInsert),
    /// UPDATE ...
    Update(SqlUpdate),
    /// DELETE FROM ...
    Delete(SqlDelete),
    /// SET var TO ...
    Set(SqlSet),
    /// SHOW var
    Show(SqlShow),
}
21
22impl SqlAst {
23    pub fn qualify_vars(self) -> Self {
24        match self {
25            Self::Select(select) => Self::Select(select.qualify_vars()),
26            Self::Update(SqlUpdate {
27                table: with,
28                assignments,
29                filter,
30            }) => Self::Update(SqlUpdate {
31                table: with.clone(),
32                filter: filter.map(|expr| expr.qualify_vars(with)),
33                assignments,
34            }),
35            Self::Delete(SqlDelete { table: with, filter }) => Self::Delete(SqlDelete {
36                table: with.clone(),
37                filter: filter.map(|expr| expr.qualify_vars(with)),
38            }),
39            _ => self,
40        }
41    }
42
43    pub fn find_unqualified_vars(self) -> SqlParseResult<Self> {
44        match self {
45            Self::Select(select) => select.find_unqualified_vars().map(Self::Select),
46            _ => Ok(self),
47        }
48    }
49}
50
/// A SELECT statement in the SQL query language
#[derive(Debug)]
pub struct SqlSelect {
    /// The projection list (`SELECT ...`).
    pub project: Project,
    /// The `FROM` clause: either `SqlFrom::Expr` (a single relation, carried
    /// with an alias) or `SqlFrom::Join`.
    pub from: SqlFrom,
    /// The optional `WHERE` predicate.
    pub filter: Option<SqlExpr>,
    /// The optional `LIMIT`, kept as raw text.
    /// NOTE(review): presumably the unparsed numeric literal — confirm in the parser.
    pub limit: Option<Box<str>>,
}
59
60impl SqlSelect {
61    pub fn qualify_vars(self) -> Self {
62        match &self.from {
63            SqlFrom::Expr(_, alias) => Self {
64                project: self.project.qualify_vars(alias.clone()),
65                filter: self.filter.map(|expr| expr.qualify_vars(alias.clone())),
66                ..self
67            },
68            SqlFrom::Join(..) => self,
69        }
70    }
71
72    pub fn find_unqualified_vars(self) -> SqlParseResult<Self> {
73        if self.from.has_unqualified_vars() {
74            return Err(SqlUnsupported::UnqualifiedNames.into());
75        }
76        if self.project.has_unqualified_vars() {
77            return Err(SqlUnsupported::UnqualifiedNames.into());
78        }
79        Ok(self)
80    }
81}
82
/// INSERT INTO table cols VALUES literals
#[derive(Debug)]
pub struct SqlInsert {
    /// The table being inserted into.
    pub table: SqlIdent,
    /// The column names listed after the table name.
    pub fields: Vec<SqlIdent>,
    /// The literal values from the `VALUES` clause.
    pub values: SqlValues,
}
90
/// VALUES literals
///
/// A list of lists of literals.
/// NOTE(review): presumably one inner `Vec` per row tuple in the `VALUES` list — confirm in the parser.
#[derive(Debug)]
pub struct SqlValues(pub Vec<Vec<SqlLiteral>>);
94
/// UPDATE table SET cols [ WHERE predicate ]
#[derive(Debug)]
pub struct SqlUpdate {
    /// The table being updated.
    pub table: SqlIdent,
    /// The `SET` assignments, each pairing a column name with a literal value.
    pub assignments: Vec<SqlSet>,
    /// The optional `WHERE` predicate.
    pub filter: Option<SqlExpr>,
}
102
/// DELETE FROM table [ WHERE predicate ]
#[derive(Debug)]
pub struct SqlDelete {
    /// The table rows are deleted from.
    pub table: SqlIdent,
    /// The optional `WHERE` predicate.
    pub filter: Option<SqlExpr>,
}
109
/// SET var '=' literal
///
/// A name paired with a literal value. Used both for the top-level `SET`
/// statement and for each assignment in `SqlUpdate::assignments`.
#[derive(Debug)]
pub struct SqlSet(pub SqlIdent, pub SqlLiteral);
113
/// SHOW var
///
/// Wraps the name of the variable being shown.
#[derive(Debug)]
pub struct SqlShow(pub SqlIdent);