Skip to main content

sea_orm_migration/
manager.rs

1use super::{IntoSchemaManagerConnection, SchemaManagerConnection};
2use sea_orm::sea_query::{
3    ForeignKeyCreateStatement, ForeignKeyDropStatement, IndexCreateStatement, IndexDropStatement,
4    SelectStatement, TableAlterStatement, TableCreateStatement, TableDropStatement,
5    TableRenameStatement, TableTruncateStatement,
6    extension::postgres::{TypeAlterStatement, TypeCreateStatement, TypeDropStatement},
7};
8use sea_orm::{ConnectionTrait, DbBackend, DbErr, StatementBuilder, TransactionTrait};
9#[allow(unused_imports)]
10use sea_schema::probe::SchemaProbe;
11
12/// Helper struct for writing migration scripts in migration file
13pub struct SchemaManager<'c> {
14    conn: SchemaManagerConnection<'c>,
15}
16
17impl<'c> SchemaManager<'c> {
18    pub fn new<T>(conn: T) -> Self
19    where
20        T: IntoSchemaManagerConnection<'c>,
21    {
22        Self {
23            conn: conn.into_database_executor(),
24        }
25    }
26
27    pub async fn execute<S>(&self, stmt: S) -> Result<(), DbErr>
28    where
29        S: StatementBuilder,
30    {
31        self.conn.execute(&stmt).await.map(|_| ())
32    }
33
34    #[doc(hidden)]
35    pub async fn exec_stmt<S>(&self, stmt: S) -> Result<(), DbErr>
36    where
37        S: StatementBuilder,
38    {
39        self.conn.execute(&stmt).await.map(|_| ())
40    }
41
42    pub fn get_database_backend(&self) -> DbBackend {
43        self.conn.get_database_backend()
44    }
45
46    pub fn get_connection(&self) -> &SchemaManagerConnection<'c> {
47        &self.conn
48    }
49}
50
51/// Transaction Control
52impl SchemaManager<'_> {
53    /// Begin a new transaction, returning an owned `SchemaManager` backed by it.
54    ///
55    /// Useful in migrations with `use_transaction() -> Some(false)` for manual
56    /// transaction management (e.g., separating DDL and DML into distinct transactions).
57    pub async fn begin(&self) -> Result<SchemaManager<'static>, DbErr> {
58        let txn = self.conn.begin().await?;
59        Ok(SchemaManager {
60            conn: SchemaManagerConnection::OwnedTransaction(txn),
61        })
62    }
63
64    /// Commit the owned transaction. Only valid on a `SchemaManager` created by [`begin()`](Self::begin).
65    pub async fn commit(self) -> Result<(), DbErr> {
66        match self.conn {
67            SchemaManagerConnection::OwnedTransaction(txn) => txn.commit().await,
68            _ => Err(DbErr::Custom(
69                "Cannot commit: SchemaManager does not own a transaction".into(),
70            )),
71        }
72    }
73}
74
75/// Schema Creation
76impl SchemaManager<'_> {
77    pub async fn create_table(&self, stmt: TableCreateStatement) -> Result<(), DbErr> {
78        self.execute(stmt).await
79    }
80
81    pub async fn create_index(&self, stmt: IndexCreateStatement) -> Result<(), DbErr> {
82        self.execute(stmt).await
83    }
84
85    pub async fn create_foreign_key(&self, stmt: ForeignKeyCreateStatement) -> Result<(), DbErr> {
86        self.execute(stmt).await
87    }
88
89    pub async fn create_type(&self, stmt: TypeCreateStatement) -> Result<(), DbErr> {
90        self.execute(stmt).await
91    }
92}
93
94/// Schema Mutation
95impl SchemaManager<'_> {
96    pub async fn alter_table(&self, stmt: TableAlterStatement) -> Result<(), DbErr> {
97        self.execute(stmt).await
98    }
99
100    pub async fn drop_table(&self, stmt: TableDropStatement) -> Result<(), DbErr> {
101        self.execute(stmt).await
102    }
103
104    pub async fn rename_table(&self, stmt: TableRenameStatement) -> Result<(), DbErr> {
105        self.execute(stmt).await
106    }
107
108    pub async fn truncate_table(&self, stmt: TableTruncateStatement) -> Result<(), DbErr> {
109        self.execute(stmt).await
110    }
111
112    pub async fn drop_index(&self, stmt: IndexDropStatement) -> Result<(), DbErr> {
113        self.execute(stmt).await
114    }
115
116    pub async fn drop_foreign_key(&self, stmt: ForeignKeyDropStatement) -> Result<(), DbErr> {
117        self.execute(stmt).await
118    }
119
120    pub async fn alter_type(&self, stmt: TypeAlterStatement) -> Result<(), DbErr> {
121        self.execute(stmt).await
122    }
123
124    pub async fn drop_type(&self, stmt: TypeDropStatement) -> Result<(), DbErr> {
125        self.execute(stmt).await
126    }
127}
128
129/// Schema Inspection.
130impl SchemaManager<'_> {
131    pub async fn has_table<T>(&self, table: T) -> Result<bool, DbErr>
132    where
133        T: AsRef<str>,
134    {
135        has_table(&self.conn, table).await
136    }
137
138    pub async fn has_column<T, C>(&self, _table: T, _column: C) -> Result<bool, DbErr>
139    where
140        T: AsRef<str>,
141        C: AsRef<str>,
142    {
143        let _stmt: SelectStatement = match self.conn.get_database_backend() {
144            #[cfg(feature = "sqlx-mysql")]
145            DbBackend::MySql => sea_schema::mysql::MySql.has_column(_table, _column),
146            #[cfg(feature = "sqlx-postgres")]
147            DbBackend::Postgres => sea_schema::postgres::Postgres.has_column(_table, _column),
148            #[cfg(feature = "sqlx-sqlite")]
149            DbBackend::Sqlite => sea_schema::sqlite::Sqlite.has_column(_table, _column),
150            #[allow(unreachable_patterns)]
151            other => {
152                return Err(DbErr::BackendNotSupported {
153                    db: other.as_str(),
154                    ctx: "has_column",
155                });
156            }
157        };
158
159        #[allow(unreachable_code)]
160        let res = self
161            .conn
162            .query_one(&_stmt)
163            .await?
164            .ok_or_else(|| DbErr::Custom("Failed to check column exists".to_owned()))?;
165
166        res.try_get("", "has_column")
167    }
168
169    pub async fn has_index<T, I>(&self, _table: T, _index: I) -> Result<bool, DbErr>
170    where
171        T: AsRef<str>,
172        I: AsRef<str>,
173    {
174        let _stmt: SelectStatement = match self.conn.get_database_backend() {
175            #[cfg(feature = "sqlx-mysql")]
176            DbBackend::MySql => sea_schema::mysql::MySql.has_index(_table, _index),
177            #[cfg(feature = "sqlx-postgres")]
178            DbBackend::Postgres => sea_schema::postgres::Postgres.has_index(_table, _index),
179            #[cfg(feature = "sqlx-sqlite")]
180            DbBackend::Sqlite => sea_schema::sqlite::Sqlite.has_index(_table, _index),
181            #[allow(unreachable_patterns)]
182            other => {
183                return Err(DbErr::BackendNotSupported {
184                    db: other.as_str(),
185                    ctx: "has_index",
186                });
187            }
188        };
189
190        #[allow(unreachable_code)]
191        let res = self
192            .conn
193            .query_one(&_stmt)
194            .await?
195            .ok_or_else(|| DbErr::Custom("Failed to check index exists".to_owned()))?;
196
197        res.try_get("", "has_index")
198    }
199}
200
201pub(crate) async fn has_table<C, T>(conn: &C, _table: T) -> Result<bool, DbErr>
202where
203    C: ConnectionTrait,
204    T: AsRef<str>,
205{
206    let _stmt: SelectStatement = match conn.get_database_backend() {
207        #[cfg(feature = "sqlx-mysql")]
208        DbBackend::MySql => sea_schema::mysql::MySql.has_table(_table),
209        #[cfg(feature = "sqlx-postgres")]
210        DbBackend::Postgres => sea_schema::postgres::Postgres.has_table(_table),
211        #[cfg(feature = "sqlx-sqlite")]
212        DbBackend::Sqlite => sea_schema::sqlite::Sqlite.has_table(_table),
213        #[allow(unreachable_patterns)]
214        other => {
215            return Err(DbErr::BackendNotSupported {
216                db: other.as_str(),
217                ctx: "has_table",
218            });
219        }
220    };
221
222    #[allow(unreachable_code)]
223    let res = conn
224        .query_one(&_stmt)
225        .await?
226        .ok_or_else(|| DbErr::Custom("Failed to check table exists".to_owned()))?;
227
228    res.try_get("", "has_table")
229}