// spacetimedb_sql_parser/ast/sql.rs

use crate::parser::{errors::SqlUnsupported, SqlParseResult};

use super::{Project, SqlExpr, SqlFrom, SqlIdent, SqlLiteral};

/// The AST for the SQL DML and query language
#[derive(Debug)]
pub enum SqlAst {
    /// A query: `SELECT ...`
    Select(SqlSelect),
    /// An insertion: `INSERT INTO ...`
    Insert(SqlInsert),
    /// An update: `UPDATE ...`
    Update(SqlUpdate),
    /// A deletion: `DELETE FROM ...`
    Delete(SqlDelete),
    /// A variable assignment: `SET var TO ...`
    Set(SqlSet),
    /// A variable lookup: `SHOW var`
    Show(SqlShow),
}

impl SqlAst {
    /// Qualifies unqualified variable references with the relation they must refer to.
    ///
    /// - `SELECT` defers to [`SqlSelect::qualify_vars`].
    /// - `UPDATE` and `DELETE` qualify the variables of their WHERE clause with the
    ///   target table's name; the `UPDATE` assignments are carried over unchanged.
    /// - `INSERT`, `SET` and `SHOW` are returned as-is.
    pub fn qualify_vars(self) -> Self {
        match self {
            Self::Select(select) => Self::Select(select.qualify_vars()),
            Self::Update(SqlUpdate {
                table,
                assignments,
                filter,
            }) => {
                let filter = filter.map(|expr| expr.qualify_vars(table.clone()));
                Self::Update(SqlUpdate {
                    table,
                    assignments,
                    filter,
                })
            }
            Self::Delete(SqlDelete { table, filter }) => {
                let filter = filter.map(|expr| expr.qualify_vars(table.clone()));
                Self::Delete(SqlDelete { table, filter })
            }
            // Listed explicitly so that adding a variant forces a decision here.
            Self::Insert(_) | Self::Set(_) | Self::Show(_) => self,
        }
    }

    /// Rejects statements that still contain unqualified variable references.
    ///
    /// Only `SELECT` is inspected (see [`SqlSelect::find_unqualified_vars`]); every
    /// other statement kind is passed through unchanged.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by [`SqlSelect::find_unqualified_vars`].
    pub fn find_unqualified_vars(self) -> SqlParseResult<Self> {
        match self {
            Self::Select(select) => select.find_unqualified_vars().map(Self::Select),
            Self::Insert(_) | Self::Update(_) | Self::Delete(_) | Self::Set(_) | Self::Show(_) => {
                Ok(self)
            }
        }
    }
}

/// A SELECT statement in the SQL query language
#[derive(Debug)]
pub struct SqlSelect {
    /// The projection list
    pub project: Project,
    /// The FROM clause: a single aliased relation (`SqlFrom::Expr`) or a join (`SqlFrom::Join`)
    pub from: SqlFrom,
    /// The optional WHERE clause
    pub filter: Option<SqlExpr>,
}

impl SqlSelect {
    /// Qualifies unqualified variables in the projection and the WHERE clause.
    ///
    /// When the FROM clause is a single [`SqlFrom::Expr`], its alias is used to qualify
    /// every unqualified variable. Joins are returned unchanged: with more than one
    /// relation in scope there is no single alias to qualify with, so any unqualified
    /// variables are left for [`SqlSelect::find_unqualified_vars`] to detect.
    pub fn qualify_vars(self) -> Self {
        match &self.from {
            SqlFrom::Expr(_, alias) => Self {
                project: self.project.qualify_vars(alias.clone()),
                filter: self.filter.map(|expr| expr.qualify_vars(alias.clone())),
                // Moved last: `alias` borrows `self.from` until the filter is rewritten.
                from: self.from,
            },
            SqlFrom::Join(..) => self,
        }
    }

    /// Rejects the query if an unqualified variable remains anywhere in it:
    /// the FROM clause, the projection, or the WHERE clause.
    ///
    /// # Errors
    ///
    /// Returns an error converted from `SqlUnsupported::UnqualifiedNames` if any
    /// unqualified variable is found.
    pub fn find_unqualified_vars(self) -> SqlParseResult<Self> {
        // The WHERE clause must be checked as well. `qualify_vars` leaves join queries
        // untouched, so without this `... JOIN s ON t.id = s.id WHERE x = 1` would slip
        // through with an unqualified `x`.
        let has_unqualified_vars = self.from.has_unqualified_vars()
            || self.project.has_unqualified_vars()
            || matches!(&self.filter, Some(expr) if expr.has_unqualified_vars());
        if has_unqualified_vars {
            return Err(SqlUnsupported::UnqualifiedNames.into());
        }
        Ok(self)
    }
}

/// `INSERT INTO table (cols) VALUES (literals), ...`
#[derive(Debug)]
pub struct SqlInsert {
    /// The table receiving the new rows
    pub table: SqlIdent,
    /// The column names listed after the table name
    pub fields: Vec<SqlIdent>,
    /// The literal values to insert
    pub values: SqlValues,
}

/// The literals of a `VALUES` clause
///
/// NOTE(review): presumably each inner `Vec` is one parenthesized row, in the column
/// order given by `SqlInsert::fields` — confirm against the parser.
#[derive(Debug)]
pub struct SqlValues(
    /// The literals, as a list of lists
    pub Vec<Vec<SqlLiteral>>,
);

/// `UPDATE table SET col = literal, ... [ WHERE predicate ]`
#[derive(Debug)]
pub struct SqlUpdate {
    /// The table being updated
    pub table: SqlIdent,
    /// The `col = literal` assignments of the SET clause
    pub assignments: Vec<SqlSet>,
    /// The optional WHERE clause selecting the rows to update
    pub filter: Option<SqlExpr>,
}

/// `DELETE FROM table [ WHERE predicate ]`
#[derive(Debug)]
pub struct SqlDelete {
    /// The table rows are deleted from
    pub table: SqlIdent,
    /// The optional WHERE clause selecting the rows to delete
    pub filter: Option<SqlExpr>,
}

/// A `name = literal` pair: `SET var '=' literal`
///
/// Besides the standalone `SET` statement, this type also represents each
/// assignment in an `UPDATE` statement's SET clause (`SqlUpdate::assignments`).
#[derive(Debug)]
pub struct SqlSet(
    /// The variable or column being assigned
    pub SqlIdent,
    /// The literal value assigned to it
    pub SqlLiteral,
);

/// `SHOW var`
#[derive(Debug)]
pub struct SqlShow(
    /// The name of the variable the statement refers to
    pub SqlIdent,
);