spacetimedb_sql_parser/ast/
sql.rsuse crate::parser::{errors::SqlUnsupported, SqlParseResult};
use super::{Project, SqlExpr, SqlFrom, SqlIdent, SqlLiteral};
/// A single parsed SQL statement.
///
/// Each variant wraps the statement-specific AST node defined in this module.
#[derive(Debug)]
pub enum SqlAst {
/// A `SELECT` statement.
Select(SqlSelect),
/// An `INSERT` statement.
Insert(SqlInsert),
/// An `UPDATE` statement.
Update(SqlUpdate),
/// A `DELETE` statement.
Delete(SqlDelete),
/// A `SET` statement (identifier paired with a literal).
Set(SqlSet),
/// A `SHOW` statement.
Show(SqlShow),
}
impl SqlAst {
    /// Qualifies bare column references inside the statement's expressions.
    ///
    /// - `SELECT` delegates to [`SqlSelect::qualify_vars`].
    /// - `UPDATE` and `DELETE` qualify their optional filter with the statement's
    ///   target table; an `UPDATE`'s assignments are left as they are.
    /// - `INSERT`, `SET` and `SHOW` are returned unchanged.
    pub fn qualify_vars(self) -> Self {
        match self {
            Self::Select(select) => Self::Select(select.qualify_vars()),
            Self::Update(SqlUpdate {
                table,
                assignments,
                filter,
            }) => {
                // Qualify first, then rebuild the node around the original table.
                let filter = filter.map(|expr| expr.qualify_vars(table.clone()));
                Self::Update(SqlUpdate {
                    table,
                    assignments,
                    filter,
                })
            }
            Self::Delete(SqlDelete { table, filter }) => {
                let filter = filter.map(|expr| expr.qualify_vars(table.clone()));
                Self::Delete(SqlDelete { table, filter })
            }
            // Listed explicitly (rather than `_`) so a new variant forces a decision here.
            unchanged @ (Self::Insert(_) | Self::Set(_) | Self::Show(_)) => unchanged,
        }
    }

    /// Rejects statements that still contain unqualified column references.
    ///
    /// Only `SELECT` is inspected, via [`SqlSelect::find_unqualified_vars`]. Every
    /// other statement is passed through as `Ok`. `UPDATE` and `DELETE` filters are
    /// qualified by [`SqlAst::qualify_vars`]; this presumes that method ran first
    /// (NOTE(review): confirm call order at the call site).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`SqlSelect::find_unqualified_vars`].
    pub fn find_unqualified_vars(self) -> SqlParseResult<Self> {
        match self {
            Self::Select(select) => select.find_unqualified_vars().map(Self::Select),
            other @ (Self::Insert(_)
            | Self::Update(_)
            | Self::Delete(_)
            | Self::Set(_)
            | Self::Show(_)) => Ok(other),
        }
    }
}
/// A `SELECT` statement: a projection over a `FROM` clause with an optional filter.
#[derive(Debug)]
pub struct SqlSelect {
/// The projection (what the query returns).
pub project: Project,
/// The `FROM` clause: either a single relation expression or a join (see `SqlFrom`).
pub from: SqlFrom,
/// Optional filter predicate (presumably the `WHERE` clause).
pub filter: Option<SqlExpr>,
}
impl SqlSelect {
    /// Qualifies bare column references using the alias of a single-relation `FROM`.
    ///
    /// When `from` is [`SqlFrom::Expr`], both the projection and the optional filter
    /// are rewritten with that relation's alias. When `from` is [`SqlFrom::Join`],
    /// nothing is rewritten and the select is returned unchanged. Any bare names
    /// that remain are left for [`Self::find_unqualified_vars`] to reject.
    pub fn qualify_vars(self) -> Self {
        match &self.from {
            SqlFrom::Expr(_, alias) => Self {
                project: self.project.qualify_vars(alias.clone()),
                filter: self.filter.map(|expr| expr.qualify_vars(alias.clone())),
                from: self.from,
            },
            SqlFrom::Join(..) => self,
        }
    }

    /// Rejects a select that still contains unqualified column references.
    ///
    /// Inspects the `FROM` clause, the projection, and the optional filter. On
    /// success the select is returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`SqlUnsupported::UnqualifiedNames`] if any of those parts contains
    /// an unqualified name.
    pub fn find_unqualified_vars(self) -> SqlParseResult<Self> {
        // The filter must be checked as well: `qualify_vars` does not rewrite it
        // for joins, so a bare name in a join's filter would otherwise be accepted.
        let filter_is_unqualified =
            matches!(&self.filter, Some(expr) if expr.has_unqualified_vars());

        if self.from.has_unqualified_vars()
            || self.project.has_unqualified_vars()
            || filter_is_unqualified
        {
            return Err(SqlUnsupported::UnqualifiedNames.into());
        }
        Ok(self)
    }
}
/// An `INSERT` statement.
#[derive(Debug)]
pub struct SqlInsert {
/// The table being inserted into.
pub table: SqlIdent,
/// Column names listed in the statement. Presumably matched positionally against
/// each row in `values`; verify against the parser.
pub fields: Vec<SqlIdent>,
/// The literal rows to insert.
pub values: SqlValues,
}
/// Literal values of an `INSERT`, stored as a list of rows of literals.
/// Presumably there is one inner `Vec` per `VALUES` tuple; confirm in the parser.
#[derive(Debug)]
pub struct SqlValues(pub Vec<Vec<SqlLiteral>>);
/// An `UPDATE` statement.
#[derive(Debug)]
pub struct SqlUpdate {
/// The target table. `SqlAst::qualify_vars` also uses it to qualify names in `filter`.
pub table: SqlIdent,
/// Identifier/literal pairs to assign.
pub assignments: Vec<SqlSet>,
/// Optional filter predicate (presumably the `WHERE` clause).
pub filter: Option<SqlExpr>,
}
/// A `DELETE` statement.
#[derive(Debug)]
pub struct SqlDelete {
/// The target table. `SqlAst::qualify_vars` also uses it to qualify names in `filter`.
pub table: SqlIdent,
/// Optional filter predicate (presumably the `WHERE` clause).
pub filter: Option<SqlExpr>,
}
/// An identifier paired with a literal value.
///
/// Used both as a standalone `SET` statement (`SqlAst::Set`) and as one
/// assignment in `SqlUpdate::assignments`.
#[derive(Debug)]
pub struct SqlSet(pub SqlIdent, pub SqlLiteral);
/// A `SHOW` statement naming a single identifier.
#[derive(Debug)]
pub struct SqlShow(pub SqlIdent);