sql_migration_sim/
lib.rs

1//! This library is meant to parse multiple related SQL migration files, and calculate the final
2//! schema that results from running them in order.
3//!
4//! ## Example
5//!
6//! ```
7//! use sql_migration_sim::{Schema, Error, ast::DataType};
8//!
9//! let mut schema = Schema::new();
10//!
11//! let create_statement = r##"CREATE TABLE ships (
12//!    id BIGINT PRIMARY KEY,
13//!    name TEXT NOT NULL,
14//!    mast_count INT not null
15//! );"##;
16//!
17//! let alter = r##"
18//!     ALTER TABLE ships ALTER COLUMN mast_count DROP NOT NULL;
19//!     ALTER TABLE ships ADD COLUMN has_motor BOOLEAN NOT NULL;
20//!     "##;
21//!
22//! schema.apply_sql(create_statement)?;
23//! schema.apply_sql(alter)?;
24//!
25//!
26//! let result = schema.tables.get("ships").unwrap();
27//! assert_eq!(result.columns.len(), 4);
28//! assert_eq!(result.columns[0].name(), "id");
29//! assert!(matches!(result.columns[0].data_type, DataType::BigInt(_)));
30//! assert_eq!(result.columns[0].not_null(), true);
31//! assert_eq!(result.columns[1].name(), "name");
32//! assert_eq!(result.columns[1].not_null(), true);
33//! assert_eq!(result.columns[2].name(), "mast_count");
34//! assert_eq!(result.columns[2].not_null(), false);
35//! assert_eq!(result.columns[3].name(), "has_motor");
36//! assert_eq!(result.columns[3].not_null(), true);
37//!
38//! # Ok::<(), Error>(())
39//!
40//! ```
41//!
42
43#![warn(missing_docs)]
44use std::{
45    borrow::Cow,
46    collections::HashMap,
47    ops::{Deref, DerefMut},
48    path::Path,
49};
50
51#[cfg(feature = "serde")]
52use serde::{Deserialize, Serialize};
53use sqlparser::ast::{
54    AlterColumnOperation, AlterIndexOperation, AlterTableOperation, ColumnDef, ColumnOption,
55    ColumnOptionDef, CreateFunctionBody, DataType, Ident, ObjectName, ObjectType,
56    OperateFunctionArg, Statement, TableConstraint,
57};
58pub use sqlparser::{ast, dialect};
59
60/// A column in a database table
61#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
62#[derive(Debug, Clone)]
63pub struct Column(pub ColumnDef);
64
65impl Deref for Column {
66    type Target = ColumnDef;
67
68    fn deref(&self) -> &Self::Target {
69        &self.0
70    }
71}
72
73impl DerefMut for Column {
74    fn deref_mut(&mut self) -> &mut Self::Target {
75        &mut self.0
76    }
77}
78
79impl Column {
80    /// The name of the column
81    pub fn name(&self) -> &str {
82        self.name.value.as_str()
83    }
84
85    /// Whether the column is nullable or not
86    pub fn not_null(&self) -> bool {
87        self.options
88            .iter()
89            .find_map(|o| match o.option {
90                ColumnOption::Null => Some(false),
91                ColumnOption::NotNull => Some(true),
92                ColumnOption::Unique { is_primary, .. } => is_primary.then_some(true),
93                _ => None,
94            })
95            .unwrap_or(false)
96    }
97
98    /// Returns the default value of the column
99    pub fn default_value(&self) -> Option<&ast::Expr> {
100        self.options.iter().find_map(|o| match &o.option {
101            ColumnOption::Default(expr) => Some(expr),
102            _ => None,
103        })
104    }
105}
106
107/// A function in the database
108#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
109#[derive(Debug, Clone, PartialEq, Eq)]
110pub struct Function {
111    /// The name of the function
112    pub name: ObjectName,
113    /// The arguments of the function
114    pub args: Option<Vec<OperateFunctionArg>>,
115    /// The return type of the function
116    pub return_type: Option<DataType>,
117    /// The options and body of the function
118    pub params: CreateFunctionBody,
119}
120
121/// A table in the database
122#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
123#[derive(Debug, Clone)]
124pub struct Table {
125    /// The name of the table
126    pub name: ObjectName,
127    /// The columns in the table
128    pub columns: Vec<Column>,
129    /// Constraints on this table
130    pub constraints: Vec<TableConstraint>,
131}
132
133impl Table {
134    /// The name of the table
135    pub fn name(&self) -> String {
136        self.name.to_string()
137    }
138
139    /// The PostgreSQL schema of the table, if set.
140    pub fn schema(&self) -> Option<&str> {
141        if self.name.0.len() == 1 {
142            None
143        } else {
144            Some(self.name.0[0].value.as_str())
145        }
146    }
147}
148
149/// A view in the database
150#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
151#[derive(Debug, Clone)]
152pub struct View {
153    /// The name of the view
154    pub name: ObjectName,
155    /// The columns in the view
156    pub columns: Vec<String>,
157}
158
159impl View {
160    /// The name of the view
161    pub fn name(&self) -> String {
162        self.name.to_string()
163    }
164}
165
166/// Errors that can occur while parsing SQL and updating the schema
167#[derive(Debug, thiserror::Error)]
168pub enum Error {
169    /// Encountered an ALTER TABLE statement on a nonexistent table.
170    #[error("Attempted to alter a table {0} that does not exist")]
171    AlteredMissingTable(String),
172    /// Encountered an ALTER COLUMN statement on a nonexistent column.
173    #[error("Attempted to alter a column {0} that does not exist in table {1}")]
174    AlteredMissingColumn(String, String),
175    /// Attempted to create a table that already exists
176    #[error("Attempted to create table {0} that already exists")]
177    TableAlreadyExists(String),
178    /// Attempted to create a function that already exists
179    #[error("Attempted to create function {0} that already exists")]
180    FunctionAlreadyExists(String),
181    /// Attempted to create a column that already exists
182    #[error("Attempted to create column {0} that already exists in table {1}")]
183    ColumnAlreadyExists(String, String),
184    /// Attempted to rename an index that doesn't exist
185    #[error("Attempted to rename index {0} that does not exist")]
186    RenameMissingIndex(String),
187    /// The SQL parser encountered an error
188    #[error("SQL Parse Error {0}")]
189    Parse(#[from] sqlparser::parser::ParserError),
190    /// Error reading a file
191    #[error("Failed to read file {filename}")]
192    File {
193        /// The underlying error
194        #[source]
195        source: std::io::Error,
196        /// The name of the file on which the error occurred
197        filename: String,
198    },
199}
200
201/// The database schema, built from parsing one or more SQL statements.
202#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
203#[derive(Debug)]
204pub struct Schema {
205    #[cfg_attr(feature = "serde", serde(skip, default = "default_dialect"))]
206    /// The SQL dialect used for pasing
207    pub dialect: Box<dyn dialect::Dialect>,
208    /// The tables in the schema
209    pub tables: HashMap<String, Table>,
210    /// The views in the schema
211    pub views: HashMap<String, View>,
212    /// The created indices. The key is the index name and the value is the table the index is on.
213    pub indices: HashMap<String, String>,
214    /// Functions in the schema
215    pub functions: HashMap<String, Function>,
216    /// References to the schema objects, in the order they were created.
217    pub creation_order: Vec<ObjectNameAndType>,
218}
219
/// The dialect given to a [Schema] when it is deserialized: the generic SQL dialect.
#[cfg(feature = "serde")]
fn default_dialect() -> Box<dyn dialect::Dialect> {
    let generic = dialect::GenericDialect {};
    Box::new(generic)
}
224
225/// An object and its type
226#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
227#[derive(Debug, Clone, PartialEq, Eq)]
228pub struct ObjectNameAndType {
229    /// The name of the object
230    pub name: String,
231    /// The type of object this is
232    pub object_type: SchemaObjectType,
233}
234
/// The kinds of object tracked by a [Schema]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SchemaObjectType {
    /// A table
    Table,
    /// A view
    View,
    /// An index
    Index,
    /// A function
    Function,
}
249
250impl std::fmt::Display for SchemaObjectType {
251    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
252        match self {
253            SchemaObjectType::Table => write!(f, "table"),
254            SchemaObjectType::View => write!(f, "view"),
255            SchemaObjectType::Index => write!(f, "index"),
256            SchemaObjectType::Function => write!(f, "function"),
257        }
258    }
259}
260
261impl Schema {
262    /// Create a new [Schema] that parses with a generic SQL dialect
263    pub fn new() -> Self {
264        Self::new_with_dialect(sqlparser::dialect::GenericDialect {})
265    }
266
267    /// Create a new [Schema] that parses with the given SQL dialect
268    pub fn new_with_dialect<D: dialect::Dialect>(dialect: D) -> Self {
269        let dialect = Box::new(dialect);
270        Self {
271            tables: HashMap::new(),
272            views: HashMap::new(),
273            indices: HashMap::new(),
274            functions: HashMap::new(),
275            creation_order: Vec::new(),
276            dialect,
277        }
278    }
279
280    fn create_table(&mut self, name: ObjectName, mut columns: Vec<ColumnDef>) -> Result<(), Error> {
281        let name_str = normalized_name(&name).to_string();
282        if self.tables.contains_key(&name_str) {
283            return Err(Error::TableAlreadyExists(name_str));
284        }
285
286        for c in &mut columns {
287            for option in &mut c.options {
288                match &mut option.option {
289                    ColumnOption::ForeignKey { foreign_table, .. } => {
290                        let table = normalized_name(&foreign_table);
291                        if table.as_ref() != foreign_table {
292                            *foreign_table = table.into_owned();
293                        }
294                    }
295                    _ => {}
296                }
297            }
298        }
299
300        self.tables.insert(
301            name_str.clone(),
302            Table {
303                name,
304                columns: columns.into_iter().map(Column).collect(),
305                constraints: Vec::new(),
306            },
307        );
308
309        self.creation_order.push(ObjectNameAndType {
310            name: name_str,
311            object_type: SchemaObjectType::Table,
312        });
313
314        Ok(())
315    }
316
317    fn create_view(
318        &mut self,
319        name: ObjectName,
320        or_replace: bool,
321        columns: Vec<String>,
322    ) -> Result<(), Error> {
323        let name_str = normalized_name(&name).to_string();
324        if !or_replace && self.views.contains_key(&name_str) {
325            return Err(Error::TableAlreadyExists(name_str));
326        }
327
328        self.views.insert(name_str.clone(), View { name, columns });
329
330        self.creation_order.push(ObjectNameAndType {
331            name: name_str,
332            object_type: SchemaObjectType::View,
333        });
334
335        Ok(())
336    }
337
338    fn create_function(
339        &mut self,
340        name: ObjectName,
341        or_replace: bool,
342        args: Option<Vec<OperateFunctionArg>>,
343        return_type: Option<DataType>,
344        params: CreateFunctionBody,
345    ) -> Result<(), Error> {
346        let name_str = normalized_name(&name).to_string();
347        if !or_replace && self.functions.contains_key(&name_str) {
348            return Err(Error::TableAlreadyExists(name_str));
349        }
350
351        self.functions.insert(
352            name_str.clone(),
353            Function {
354                name,
355                args,
356                return_type,
357                params,
358            },
359        );
360
361        self.creation_order.push(ObjectNameAndType {
362            name: name_str,
363            object_type: SchemaObjectType::Function,
364        });
365
366        Ok(())
367    }
368
369    fn handle_alter_table(
370        &mut self,
371        name: &str,
372        name_ident: &ObjectName,
373        operation: AlterTableOperation,
374    ) -> Result<(), Error> {
375        match operation {
376            AlterTableOperation::AddColumn {
377                if_not_exists,
378                column_def,
379                ..
380            } => {
381                let table = self
382                    .tables
383                    .get_mut(name)
384                    .ok_or_else(|| Error::AlteredMissingTable(name.to_string()))?;
385
386                let existing_column = table.columns.iter().find(|c| c.name == column_def.name);
387
388                if existing_column.is_none() {
389                    table.columns.push(Column(column_def));
390                } else if !if_not_exists {
391                    return Err(Error::ColumnAlreadyExists(
392                        column_def.name.value,
393                        name.to_string(),
394                    ));
395                }
396            }
397
398            AlterTableOperation::DropColumn { column_name, .. } => {
399                let table = self
400                    .tables
401                    .get_mut(name)
402                    .ok_or_else(|| Error::AlteredMissingTable(name.to_string()))?;
403                table.columns.retain(|c| c.name != column_name);
404            }
405
406            AlterTableOperation::RenameColumn {
407                old_column_name,
408                new_column_name,
409            } => {
410                let table = self
411                    .tables
412                    .get_mut(name)
413                    .ok_or_else(|| Error::AlteredMissingTable(name.to_string()))?;
414
415                let column = table
416                    .columns
417                    .iter_mut()
418                    .find(|c| c.name == old_column_name)
419                    .ok_or_else(|| {
420                        Error::AlteredMissingColumn(old_column_name.value.clone(), name.to_string())
421                    })?;
422                column.name = new_column_name;
423            }
424
425            AlterTableOperation::RenameTable {
426                table_name: new_table_name,
427            } => {
428                let mut table = self
429                    .tables
430                    .remove(name)
431                    .ok_or_else(|| Error::AlteredMissingTable(name.to_string()))?;
432
433                let (schema, _) = object_schema_and_name(&name_ident);
434                let (_, new_table_name) = object_schema_and_name(&new_table_name);
435                let new_table_name = name_with_schema(schema.cloned(), new_table_name.clone());
436
437                let new_name_str = new_table_name.to_string();
438                table.name = new_table_name;
439
440                self.tables.insert(new_name_str.clone(), table);
441                // Update the name in creation_order to match
442                if let Some(i) = self
443                    .creation_order
444                    .iter_mut()
445                    .find(|o| o.name == name && o.object_type == SchemaObjectType::Table)
446                {
447                    i.name = new_name_str;
448                }
449            }
450
451            AlterTableOperation::AlterColumn { column_name, op } => {
452                let table = self
453                    .tables
454                    .get_mut(name)
455                    .ok_or_else(|| Error::AlteredMissingTable(name.to_string()))?;
456
457                let column = table
458                    .columns
459                    .iter_mut()
460                    .find(|c| c.name == column_name)
461                    .ok_or_else(|| {
462                        Error::AlteredMissingColumn(
463                            table.name.to_string(),
464                            column_name.value.clone(),
465                        )
466                    })?;
467
468                match op {
469                    AlterColumnOperation::SetNotNull => {
470                        if column
471                            .options
472                            .iter()
473                            .find(|o| o.option == ColumnOption::NotNull)
474                            .is_none()
475                        {
476                            column.options.push(ColumnOptionDef {
477                                name: None,
478                                option: ColumnOption::NotNull,
479                            });
480                        }
481
482                        column.options.retain(|o| o.option != ColumnOption::Null);
483                    }
484                    AlterColumnOperation::DropNotNull => {
485                        column.options.retain(|o| o.option != ColumnOption::NotNull);
486                    }
487                    AlterColumnOperation::SetDefault { value } => {
488                        if let Some(default_option) = column
489                            .options
490                            .iter_mut()
491                            .find(|o| matches!(o.option, ColumnOption::Default(_)))
492                        {
493                            default_option.option = ColumnOption::Default(value);
494                        } else {
495                            column.options.push(ColumnOptionDef {
496                                name: None,
497                                option: ColumnOption::Default(value),
498                            })
499                        }
500                    }
501                    AlterColumnOperation::DropDefault => {
502                        column
503                            .options
504                            .retain(|o| !matches!(o.option, ColumnOption::Default(_)));
505                    }
506
507                    AlterColumnOperation::SetDataType { data_type, .. } => {
508                        column.data_type = data_type
509                    }
510                    _ => {}
511                }
512            }
513
514            AlterTableOperation::AddConstraint(mut c) => {
515                let table = self
516                    .tables
517                    .get_mut(name)
518                    .ok_or_else(|| Error::AlteredMissingTable(name.to_string()))?;
519
520                match &mut c {
521                    TableConstraint::ForeignKey { foreign_table, .. } => {
522                        let table = normalized_name(&foreign_table);
523                        if table.as_ref() != foreign_table {
524                            *foreign_table = table.into_owned();
525                        }
526                    }
527                    _ => {}
528                }
529
530                table.constraints.push(c);
531            }
532
533            AlterTableOperation::DropConstraint {
534                name: constraint_name,
535                ..
536            } => {
537                let table = self
538                    .tables
539                    .get_mut(name)
540                    .ok_or_else(|| Error::AlteredMissingTable(name.to_string()))?;
541
542                table.constraints.retain(|c| {
543                    let name = table_constraint_name(c);
544                    name.as_ref().map(|n| n != &constraint_name).unwrap_or(true)
545                });
546            }
547
548            _ => {}
549        }
550
551        Ok(())
552    }
553
554    /// Apply a parsed statement to the schema
555    pub fn apply_statement(&mut self, statement: Statement) -> Result<(), Error> {
556        match statement {
557            Statement::CreateTable { name, columns, .. } => {
558                self.create_table(name, columns)?;
559            }
560            Statement::AlterTable {
561                name: name_ident,
562                operations,
563                ..
564            } => {
565                let name = normalized_name(&name_ident).to_string();
566                for operation in operations {
567                    self.handle_alter_table(&name, &name_ident, operation)?;
568                }
569            }
570            Statement::CreateView {
571                name,
572                columns,
573                or_replace,
574                ..
575            } => {
576                self.create_view(
577                    name,
578                    or_replace,
579                    columns.into_iter().map(|c| c.name.value).collect(),
580                )?;
581            }
582
583            Statement::CreateFunction {
584                or_replace,
585                temporary,
586                name,
587                args,
588                return_type,
589                params,
590            } => {
591                if !temporary {
592                    self.create_function(name, or_replace, args, return_type, params)?;
593                }
594            }
595
596            Statement::Drop {
597                object_type, names, ..
598            } => {
599                for name in names {
600                    let name = name.to_string();
601                    match object_type {
602                        ObjectType::Table => {
603                            self.tables.remove(&name);
604                            self.creation_order.retain(|c| {
605                                c.object_type != SchemaObjectType::Table || c.name != name
606                            });
607                        }
608                        ObjectType::View => {
609                            self.views.remove(&name);
610                            self.creation_order.retain(|c| {
611                                c.object_type != SchemaObjectType::View || c.name != name
612                            });
613                        }
614                        ObjectType::Index => {
615                            self.indices.remove(&name);
616                            self.creation_order.retain(|c| {
617                                c.object_type != SchemaObjectType::Index || c.name != name
618                            });
619                        }
620                        _ => {}
621                    }
622                }
623            }
624
625            Statement::CreateIndex {
626                name, table_name, ..
627            } => {
628                // For now we ignore indexes without names.
629                if let Some(name) = name {
630                    let (schema, _) = object_schema_and_name(&table_name);
631                    let (_, name) = object_schema_and_name(&name);
632                    let full_name = name_with_schema(schema.cloned(), name.clone());
633                    self.indices
634                        .insert(full_name.to_string(), table_name.to_string());
635                    self.creation_order.push(ObjectNameAndType {
636                        name: full_name.to_string(),
637                        object_type: SchemaObjectType::Index,
638                    });
639                }
640            }
641
642            Statement::AlterIndex { name, operation } => {
643                match operation {
644                    AlterIndexOperation::RenameIndex { index_name } => {
645                        let Some(table_name) = self.indices.remove(&name.to_string()) else {
646                            return Err(Error::RenameMissingIndex(name.to_string()));
647                        };
648
649                        let (schema, _) = object_schema_and_name(&name);
650                        let (_, index_name) = object_schema_and_name(&index_name);
651                        let new_name = name_with_schema(schema.cloned(), index_name.clone());
652                        let new_name = new_name.to_string();
653                        let old_name = name.to_string();
654                        self.indices.insert(new_name.clone(), table_name);
655
656                        // Update the name in creation_order to match
657                        if let Some(i) = self.creation_order.iter_mut().find(|o| {
658                            o.name == old_name && o.object_type == SchemaObjectType::Index
659                        }) {
660                            i.name = new_name;
661                        }
662                    }
663                }
664            }
665            _ => {}
666        }
667
668        Ok(())
669    }
670
671    /// Parse some SQL into a list of statements
672    pub fn parse_sql(&self, sql: &str) -> Result<Vec<Statement>, Error> {
673        sqlparser::parser::Parser::new(self.dialect.as_ref())
674            .try_with_sql(sql)?
675            .parse_statements()
676            .map_err(Error::from)
677    }
678
679    /// Apply one or more SQL statements to the schema
680    pub fn apply_sql(&mut self, sql: &str) -> Result<(), Error> {
681        self.parse_sql(sql)?
682            .into_iter()
683            .try_for_each(|statement| self.apply_statement(statement))
684    }
685
686    /// Read a SQL file and apply its contents to the schema
687    pub fn apply_file(&mut self, filename: &Path) -> Result<(), Error> {
688        let contents = std::fs::read_to_string(filename).map_err(|e| Error::File {
689            source: e,
690            filename: filename.display().to_string(),
691        })?;
692
693        self.apply_sql(&contents)
694    }
695}
696
697/// Apply a schema to a name
698pub fn name_with_schema(schema: Option<Ident>, name: Ident) -> ObjectName {
699    if let Some(schema) = schema {
700        ObjectName(vec![schema, name])
701    } else {
702        ObjectName(vec![name])
703    }
704}
705
706/// Extract a name into the name and an optional schema.
707pub fn object_schema_and_name(name: &ObjectName) -> (Option<&Ident>, &Ident) {
708    if name.0.len() == 2 {
709        (Some(&name.0[0]).filter(|s| s.value != "public"), &name.0[1])
710    } else {
711        (None, &name.0[0])
712    }
713}
714
715/// Name, buf if the schema is "public" then remove it.
716pub fn normalized_name(name: &ObjectName) -> Cow<'_, ObjectName> {
717    if name.0.len() == 2 && name.0[0].value == "public" {
718        Cow::Owned(ObjectName(vec![name.0[1].clone()]))
719    } else {
720        Cow::Borrowed(name)
721    }
722}
723
724/// Given an index name and the table it's on calculate the name with schema.
725pub fn index_full_name(index_name: &ObjectName, table_name: &ObjectName) -> ObjectName {
726    if index_name.0.len() > 1 {
727        return index_name.clone();
728    }
729
730    let (schema, _) = object_schema_and_name(&table_name);
731    return name_with_schema(schema.cloned(), index_name.0[0].clone());
732}
733
734/// Get the name of a table constraint
735pub fn table_constraint_name(constraint: &TableConstraint) -> &Option<Ident> {
736    match constraint {
737        TableConstraint::Unique { name, .. } => name,
738        TableConstraint::PrimaryKey { name, .. } => name,
739        TableConstraint::ForeignKey { name, .. } => name,
740        TableConstraint::Check { name, .. } => name,
741        TableConstraint::Index { name, .. } => name,
742        TableConstraint::FulltextOrSpatial { .. } => &None,
743    }
744}
745
#[cfg(test)]
mod test {
    use sqlparser::{ast::DataType, dialect};

    use super::*;

    const CREATE: &str = r##"
    CREATE TABLE ships (
        id BIGINT PRIMARY KEY,
        name TEXT NOT NULL,
        mast_count INT not null
    );"##;

    const CREATE_WITH_SCHEMA: &str = r##"
    CREATE TABLE sch.ships (
        id BIGINT PRIMARY KEY,
        name TEXT NOT NULL,
        mast_count INT not null
    );"##;

    /// A generic-dialect schema with every script applied in order.
    fn generic_schema(scripts: &[&str]) -> Schema {
        let mut schema = Schema::new();
        for sql in scripts {
            schema.apply_sql(sql).unwrap();
        }
        schema
    }

    /// A PostgreSQL-dialect schema that contains only the `ships` table.
    fn postgres_ships() -> Schema {
        let mut schema = Schema::new_with_dialect(dialect::PostgreSqlDialect {});
        schema.apply_sql(CREATE).unwrap();
        schema
    }

    /// Shorthand for an expected `creation_order` entry.
    fn created(name: &str, object_type: SchemaObjectType) -> ObjectNameAndType {
        ObjectNameAndType {
            name: name.to_string(),
            object_type,
        }
    }

    #[test]
    fn rename_table() {
        let schema = generic_schema(&[CREATE, "ALTER TABLE ships RENAME TO ships_2;"]);

        assert!(schema.tables.contains_key("ships_2"));
        assert!(!schema.tables.contains_key("ships"));
        assert_eq!(
            schema.creation_order,
            [created("ships_2", SchemaObjectType::Table)]
        );
    }

    #[test]
    fn rename_table_with_schema() {
        let schema = generic_schema(&[
            CREATE_WITH_SCHEMA,
            "ALTER TABLE sch.ships RENAME TO ships_2;",
        ]);

        assert!(schema.tables.contains_key("sch.ships_2"));
        assert!(!schema.tables.contains_key("sch.ships"));
        assert_eq!(
            schema.creation_order,
            [created("sch.ships_2", SchemaObjectType::Table)]
        );
    }

    #[test]
    fn drop_table() {
        let schema = generic_schema(&[CREATE, CREATE_WITH_SCHEMA, "DROP TABLE ships"]);

        assert!(!schema.tables.contains_key("ships"));
        assert!(schema.tables.contains_key("sch.ships"));
        assert_eq!(
            schema.creation_order,
            [created("sch.ships", SchemaObjectType::Table)]
        );
    }

    #[test]
    fn create_index() {
        let schema = generic_schema(&[
                "
            CREATE INDEX idx_name ON ships(name);
            CREATE INDEX idx_name_2 ON sch.ships(name);
        ",
        ]);

        assert_eq!(schema.indices["idx_name"], "ships");
        assert_eq!(schema.indices["sch.idx_name_2"], "sch.ships");
        assert_eq!(
            schema.creation_order,
            [
                created("idx_name", SchemaObjectType::Index),
                created("sch.idx_name_2", SchemaObjectType::Index),
            ]
        );
    }

    #[test]
    fn drop_index() {
        let schema = generic_schema(&[
            CREATE,
            "CREATE INDEX idx_name ON sch.ships(name);",
            "DROP INDEX sch.idx_name;",
        ]);

        assert!(schema.indices.is_empty());
        assert_eq!(
            schema.creation_order,
            [created("ships", SchemaObjectType::Table)]
        );
    }

    #[test]
    fn add_column() {
        let schema = generic_schema(&[
            CREATE,
            "ALTER TABLE ships ADD COLUMN has_motor BOOLEAN NOT NULL;",
        ]);

        assert_eq!(schema.tables["ships"].columns[3].name(), "has_motor");
    }

    #[test]
    fn drop_column() {
        let schema = generic_schema(&[CREATE, "ALTER TABLE ships DROP COLUMN name;"]);
        let columns = &schema.tables["ships"].columns;

        assert_eq!(columns.len(), 2);
        assert!(!columns.iter().any(|c| c.name() == "name"));
    }

    #[test]
    fn rename_column() {
        let schema = generic_schema(&[
            CREATE,
            "ALTER TABLE ships RENAME COLUMN mast_count TO mast_count_2;",
        ]);

        assert_eq!(schema.tables["ships"].columns[2].name(), "mast_count_2");
    }

    #[test]
    fn alter_column_change_nullable() {
        let mut schema = postgres_ships();

        schema
            .apply_sql("ALTER TABLE ships ALTER COLUMN mast_count DROP NOT NULL")
            .unwrap();
        assert!(!schema.tables["ships"].columns[2].not_null());

        schema
            .apply_sql("ALTER TABLE ships ALTER COLUMN mast_count SET NOT NULL")
            .unwrap();
        assert!(schema.tables["ships"].columns[2].not_null());
    }

    #[test]
    fn alter_column_default() {
        let mut schema = postgres_ships();

        schema
            .apply_sql("ALTER TABLE ships ALTER COLUMN mast_count SET DEFAULT 2")
            .unwrap();
        let default = schema.tables["ships"].columns[2].default_value().unwrap();
        assert_eq!(default.to_string(), "2");

        schema
            .apply_sql("ALTER TABLE ships ALTER COLUMN mast_count DROP DEFAULT")
            .unwrap();
        assert!(schema.tables["ships"].columns[2].default_value().is_none());
    }

    #[test]
    fn alter_column_data_type() {
        let mut schema = postgres_ships();

        schema
            .apply_sql(
                "ALTER TABLE ships ALTER COLUMN mast_count TYPE JSON USING(mast_count::json);",
            )
            .unwrap();

        let column = &schema.tables["ships"].columns[2];
        println!("{:?}", column);
        assert_eq!(column.data_type, DataType::JSON);
    }
}