sql_fun_sqlast/sem/
create_table.rs1mod column_def;
2mod column_name;
3mod table_name;
4mod view_name;
5
6pub use self::{
7 column_def::ColumnDefinition, column_name::ColumnName, table_name::TableName,
8 view_name::ViewName,
9};
10use std::collections::HashMap;
11
12use crate::{
13 sem::{
14 AlterTable, AnalysisError, AnalysisProblem, AstAndContextPair, Constraint, FullName,
15 ImplicitChange, ParseContext, SemAst, alter_table::AlterObjSubCommand,
16 },
17 syn::{CreateStmt, ListOpt, Opt},
18};
19
20#[macro_use]
21#[path = "./macros.rs"]
22mod macros;
23
24impl_from_ref!(TableName);
25impl_from_ref!(ColumnName);
26impl_from_ref!(ViewName);
27
28#[derive(Debug, Clone, Default, Eq, Hash, PartialEq, serde::Serialize, serde::Deserialize)]
30pub struct ColumnCollection(Vec<ColumnDefinition>);
31
32impl ColumnCollection {
33 #[must_use]
35 pub fn new(columns: Vec<ColumnDefinition>) -> Self {
36 Self(columns)
37 }
38
39 pub fn get_column_mut(&mut self, column_name: &ColumnName) -> Option<&mut ColumnDefinition> {
41 self.0
42 .iter_mut()
43 .find(|c| c.name().as_ref() == Some(column_name))
44 }
45
46 #[must_use]
48 pub fn get_column_def(&self, column_name: &ColumnName) -> Option<&ColumnDefinition> {
49 self.0
50 .iter()
51 .find(|c| c.name().as_ref() == Some(column_name))
52 }
53
54 #[must_use]
56 pub fn has_column(&self, column: &ColumnName) -> bool {
57 self.0.iter().any(|c| c.name().as_ref() == Some(column))
58 }
59
60 #[must_use]
62 pub fn get_at(&self, column_index: usize) -> Option<&ColumnDefinition> {
63 self.0.get(column_index)
64 }
65
66 #[must_use]
68 pub fn get_column_def_with_skip_count(
69 &self,
70 column_name: &ColumnName,
71 skip_count: usize,
72 ) -> Option<&ColumnDefinition> {
73 self.0[skip_count..]
74 .iter()
75 .find(|c| c.name().as_ref() == Some(column_name))
76 }
77
78 pub fn push(&mut self, column: &ColumnDefinition) {
80 self.0.push(column.clone());
81 }
82}
83
/// Semantic representation of a `CREATE TABLE` statement, plus the state
/// accumulated from `ALTER TABLE` statements applied to it via `apply_alter`.
#[derive(Debug)]
pub struct CreateTable {
    /// Name of the table.
    name: TableName,
    /// Column definitions, in the order they were supplied.
    columns: ColumnCollection,
    /// Every `AlterTable` passed to `apply_alter`, in the order applied.
    alter_statements: Vec<AlterTable>,
    /// Owner set by an `AlterObjSubCommand::ChangeOwner` command, if any.
    owner: Option<String>,
    /// Index name set by an `AlterObjSubCommand::ClusterOn` command, if any.
    cluster_on: Option<String>,
    /// Constraints added by `AlterObjSubCommand::AddConstraint`, keyed by constraint name.
    constraints: HashMap<String, Constraint>,
}
94
95impl Clone for CreateTable {
96 fn clone(&self) -> Self {
97 Self {
98 name: self.name.clone(),
99 columns: self.columns.clone(),
100 alter_statements: Default::default(),
101 owner: Default::default(),
102 cluster_on: Default::default(),
103 constraints: Default::default(),
104 }
105 }
106}
107
108impl CreateTable {
109 #[must_use]
111 pub fn new(name: TableName, columns: Vec<ColumnDefinition>) -> Self {
112 Self {
113 name,
114 columns: ColumnCollection(columns),
115 alter_statements: Default::default(),
116 owner: Default::default(),
117 cluster_on: Default::default(),
118 constraints: Default::default(),
119 }
120 }
121
122 #[must_use]
124 pub fn name(&self) -> &TableName {
125 &self.name
126 }
127
128 #[must_use]
130 pub fn get_column_def(&self, column_name: &ColumnName) -> Option<&ColumnDefinition> {
131 self.columns.get_column_def(column_name)
132 }
133
134 #[must_use]
151 pub fn get_column_def_with_skip_count(
152 &self,
153 column_name: &ColumnName,
154 skip_count: usize,
155 ) -> Option<&ColumnDefinition> {
156 self.columns
157 .get_column_def_with_skip_count(column_name, skip_count)
158 }
159
160 #[must_use]
162 pub fn get_column_at(&self, column_index: usize) -> Option<&ColumnDefinition> {
163 self.columns.get_at(column_index)
164 }
165
166 #[must_use]
168 pub fn has_column(&self, column: &ColumnName) -> bool {
169 self.columns.has_column(column)
170 }
171
172 pub fn apply_alter(
174 &mut self,
175 alter_table: &AlterTable,
176 ) -> Result<Vec<ImplicitChange>, AnalysisError> {
177 self.alter_statements.push(alter_table.clone());
178
179 for cmd in alter_table.commands() {
180 match cmd {
181 AlterObjSubCommand::ChangeOwner(new_owner) => self.owner = Some(new_owner.clone()),
182 AlterObjSubCommand::SetColumnDefault(column_name, sem_scalar_expr) => {
183 if let Some(column) = self.columns.get_column_mut(column_name) {
184 column.set_default(sem_scalar_expr);
185 }
186 }
187 AlterObjSubCommand::AddConstraint(constraint) => {
188 self.constraints
189 .insert(constraint.name().to_string(), constraint.clone());
190 return constraint.implicit_changes();
191 }
192 AlterObjSubCommand::ClusterOn(index_name) => {
193 self.cluster_on = Some(index_name.clone())
194 }
195 }
196 }
197
198 Ok(Vec::new())
199 }
200
201 #[must_use]
203 pub fn primary_key_columns(&self) -> Vec<&ColumnName> {
204 if let Some(pk) = self
205 .constraints
206 .iter()
207 .find_map(|(_, c)| c.as_primary_key_constraint())
208 {
209 pk.keys().iter().collect()
210 } else {
211 Vec::new()
212 }
213 }
214}
215
216#[tracing::instrument(skip(context))]
218pub fn analyze_create_stmt<TParseContext>(
219 mut context: TParseContext,
220 parent_schema: &Option<String>,
221 create: CreateStmt,
222) -> Result<AstAndContextPair<TParseContext>, AnalysisError>
223where
224 TParseContext: ParseContext,
225{
226 let table_name = create.get_relation().unwrap();
227 let table_name = TableName::try_from(table_name)?;
228 let Some(column_defs) = create.get_table_elts().map(|f| f.as_column_def()) else {
229 AnalysisError::raise_unexpected_none("create_stmt.get_table_elts")?
230 };
231 let mut cols: Vec<ColumnDefinition> = Vec::new();
232 for column_def in column_defs {
233 let colname = column_def.get_colname();
234 let colname = ColumnName::new(&colname);
235 let Some(type_name) = column_def.get_type_name().as_inner() else {
236 AnalysisError::raise_unexpected_none("column_def.get_type_name")?
237 };
238 let is_not_null = column_def.get_is_not_null();
239
240 let key = FullName::try_from(type_name)?;
241 let type_ref = if let Some(sem_type) = context.get_type(&key).cloned() {
242 if let Some(type_ref) = sem_type.type_reference().cloned() {
243 Some(type_ref)
244 } else {
245 context.report_problem(AnalysisProblem::unexpected_dynamic_type(&sem_type))?;
246 None
247 }
248 } else {
249 context.report_problem(AnalysisProblem::column_type_not_found(
250 &table_name,
251 &colname,
252 &key,
253 ))?;
254 None
255 };
256
257 cols.push(ColumnDefinition::new(
258 &Some(colname),
259 type_ref.as_ref(),
260 Some(is_not_null),
261 ));
262 }
263 let table_def = CreateTable::new(table_name, cols);
264 let result_context = context.apply_create_table(&table_def)?;
265 Ok(AstAndContextPair::new(
266 SemAst::CreateTable(table_def),
267 result_context,
268 ))
269}