Skip to main content

sql_fun_sqlast/sem/
create_table.rs

1mod column_def;
2mod column_name;
3mod table_name;
4mod view_name;
5
6pub use self::{
7    column_def::ColumnDefinition, column_name::ColumnName, table_name::TableName,
8    view_name::ViewName,
9};
10use std::collections::HashMap;
11
12use crate::{
13    sem::{
14        AlterTable, AnalysisError, AnalysisProblem, AstAndContextPair, Constraint, FullName,
15        ImplicitChange, ParseContext, SemAst, alter_table::AlterObjSubCommand,
16    },
17    syn::{CreateStmt, ListOpt, Opt},
18};
19
20#[macro_use]
21#[path = "./macros.rs"]
22mod macros;
23
24impl_from_ref!(TableName);
25impl_from_ref!(ColumnName);
26impl_from_ref!(ViewName);
27
28/// columns in table
29#[derive(Debug, Clone, Default, Eq, Hash, PartialEq, serde::Serialize, serde::Deserialize)]
30pub struct ColumnCollection(Vec<ColumnDefinition>);
31
32impl ColumnCollection {
33    /// add column
34    #[must_use]
35    pub fn new(columns: Vec<ColumnDefinition>) -> Self {
36        Self(columns)
37    }
38
39    /// find mutateble column by name
40    pub fn get_column_mut(&mut self, column_name: &ColumnName) -> Option<&mut ColumnDefinition> {
41        self.0
42            .iter_mut()
43            .find(|c| c.name().as_ref() == Some(column_name))
44    }
45
46    /// get column definition by name
47    #[must_use]
48    pub fn get_column_def(&self, column_name: &ColumnName) -> Option<&ColumnDefinition> {
49        self.0
50            .iter()
51            .find(|c| c.name().as_ref() == Some(column_name))
52    }
53
54    /// test column existing
55    #[must_use]
56    pub fn has_column(&self, column: &ColumnName) -> bool {
57        self.0.iter().any(|c| c.name().as_ref() == Some(column))
58    }
59
60    /// get column definition by index
61    #[must_use]
62    pub fn get_at(&self, column_index: usize) -> Option<&ColumnDefinition> {
63        self.0.get(column_index)
64    }
65
66    /// get column definition with skip count
67    #[must_use]
68    pub fn get_column_def_with_skip_count(
69        &self,
70        column_name: &ColumnName,
71        skip_count: usize,
72    ) -> Option<&ColumnDefinition> {
73        self.0[skip_count..]
74            .iter()
75            .find(|c| c.name().as_ref() == Some(column_name))
76    }
77
78    /// append column definition
79    pub fn push(&mut self, column: &ColumnDefinition) {
80        self.0.push(column.clone());
81    }
82}
83
84/// `create table` SQL statement
85#[derive(Debug)]
86pub struct CreateTable {
87    name: TableName,
88    columns: ColumnCollection,
89    alter_statements: Vec<AlterTable>,
90    owner: Option<String>,
91    cluster_on: Option<String>,
92    constraints: HashMap<String, Constraint>,
93}
94
95impl Clone for CreateTable {
96    fn clone(&self) -> Self {
97        Self {
98            name: self.name.clone(),
99            columns: self.columns.clone(),
100            alter_statements: Default::default(),
101            owner: Default::default(),
102            cluster_on: Default::default(),
103            constraints: Default::default(),
104        }
105    }
106}
107
108impl CreateTable {
109    /// new instance with column definitions
110    #[must_use]
111    pub fn new(name: TableName, columns: Vec<ColumnDefinition>) -> Self {
112        Self {
113            name,
114            columns: ColumnCollection(columns),
115            alter_statements: Default::default(),
116            owner: Default::default(),
117            cluster_on: Default::default(),
118            constraints: Default::default(),
119        }
120    }
121
122    /// get table name
123    #[must_use]
124    pub fn name(&self) -> &TableName {
125        &self.name
126    }
127
128    /// get column definition by column name
129    #[must_use]
130    pub fn get_column_def(&self, column_name: &ColumnName) -> Option<&ColumnDefinition> {
131        self.columns.get_column_def(column_name)
132    }
133
134    /// Look up a column definition by *original* column name, ignoring the first `skip_count` columns.
135    ///
136    /// This is used to implement `PostgreSQL`'s behavior for a table alias with a column alias list:
137    ///
138    /// ```sql
139    /// SELECT ... FROM users AS u(id, name);
140    /// ```
141    ///
142    /// When a column alias list is present, the original names of the first N columns (N = the length
143    /// of the alias list) are not visible through that alias. Only the alias names are visible for
144    /// those columns. Columns after the first N keep their original names.
145    ///
146    /// For example, if `users` has at least two columns and the 1st column is named `user_id`,
147    /// then `u.user_id` is not accessible when `skip_count >= 1` (e.g. `u(id, name)` makes `skip_count = 2`).
148    /// However, columns after the first `skip_count` columns are still accessible by their original names.
149    ///
150    #[must_use]
151    pub fn get_column_def_with_skip_count(
152        &self,
153        column_name: &ColumnName,
154        skip_count: usize,
155    ) -> Option<&ColumnDefinition> {
156        self.columns
157            .get_column_def_with_skip_count(column_name, skip_count)
158    }
159
160    /// get column definition by index
161    #[must_use]
162    pub fn get_column_at(&self, column_index: usize) -> Option<&ColumnDefinition> {
163        self.columns.get_at(column_index)
164    }
165
166    /// check column existing
167    #[must_use]
168    pub fn has_column(&self, column: &ColumnName) -> bool {
169        self.columns.has_column(column)
170    }
171
172    /// apply `alter table` statement
173    pub fn apply_alter(
174        &mut self,
175        alter_table: &AlterTable,
176    ) -> Result<Vec<ImplicitChange>, AnalysisError> {
177        self.alter_statements.push(alter_table.clone());
178
179        for cmd in alter_table.commands() {
180            match cmd {
181                AlterObjSubCommand::ChangeOwner(new_owner) => self.owner = Some(new_owner.clone()),
182                AlterObjSubCommand::SetColumnDefault(column_name, sem_scalar_expr) => {
183                    if let Some(column) = self.columns.get_column_mut(column_name) {
184                        column.set_default(sem_scalar_expr);
185                    }
186                }
187                AlterObjSubCommand::AddConstraint(constraint) => {
188                    self.constraints
189                        .insert(constraint.name().to_string(), constraint.clone());
190                    return constraint.implicit_changes();
191                }
192                AlterObjSubCommand::ClusterOn(index_name) => {
193                    self.cluster_on = Some(index_name.clone())
194                }
195            }
196        }
197
198        Ok(Vec::new())
199    }
200
201    /// get primary key column names
202    #[must_use]
203    pub fn primary_key_columns(&self) -> Vec<&ColumnName> {
204        if let Some(pk) = self
205            .constraints
206            .iter()
207            .find_map(|(_, c)| c.as_primary_key_constraint())
208        {
209            pk.keys().iter().collect()
210        } else {
211            Vec::new()
212        }
213    }
214}
215
216/// analyze [`crate::syn::CreateStmt`]
217#[tracing::instrument(skip(context))]
218pub fn analyze_create_stmt<TParseContext>(
219    mut context: TParseContext,
220    parent_schema: &Option<String>,
221    create: CreateStmt,
222) -> Result<AstAndContextPair<TParseContext>, AnalysisError>
223where
224    TParseContext: ParseContext,
225{
226    let table_name = create.get_relation().unwrap();
227    let table_name = TableName::try_from(table_name)?;
228    let Some(column_defs) = create.get_table_elts().map(|f| f.as_column_def()) else {
229        AnalysisError::raise_unexpected_none("create_stmt.get_table_elts")?
230    };
231    let mut cols: Vec<ColumnDefinition> = Vec::new();
232    for column_def in column_defs {
233        let colname = column_def.get_colname();
234        let colname = ColumnName::new(&colname);
235        let Some(type_name) = column_def.get_type_name().as_inner() else {
236            AnalysisError::raise_unexpected_none("column_def.get_type_name")?
237        };
238        let is_not_null = column_def.get_is_not_null();
239
240        let key = FullName::try_from(type_name)?;
241        let type_ref = if let Some(sem_type) = context.get_type(&key).cloned() {
242            if let Some(type_ref) = sem_type.type_reference().cloned() {
243                Some(type_ref)
244            } else {
245                context.report_problem(AnalysisProblem::unexpected_dynamic_type(&sem_type))?;
246                None
247            }
248        } else {
249            context.report_problem(AnalysisProblem::column_type_not_found(
250                &table_name,
251                &colname,
252                &key,
253            ))?;
254            None
255        };
256
257        cols.push(ColumnDefinition::new(
258            &Some(colname),
259            type_ref.as_ref(),
260            Some(is_not_null),
261        ));
262    }
263    let table_def = CreateTable::new(table_name, cols);
264    let result_context = context.apply_create_table(&table_def)?;
265    Ok(AstAndContextPair::new(
266        SemAst::CreateTable(table_def),
267        result_context,
268    ))
269}