unistore_sqlite/
migration.rs

1//! 数据库迁移系统
2//!
3//! 职责:管理数据库 schema 版本和迁移
4
5use crate::connection::Connection;
6use crate::error::SqliteError;
7
/// A single, versioned schema migration step.
///
/// `sql` holds the statement batch that is executed when the migration is
/// applied; `version` identifies the step in the `_migrations` table.
#[derive(Debug, Clone)]
pub struct Migration {
    /// Version number of this migration.
    pub version: u32,
    /// Human-readable description of what the migration does.
    pub description: String,
    /// SQL statement(s) executed by the migration.
    pub sql: String,
}

impl Migration {
    /// Builds a migration from its version, description and SQL text.
    pub fn new(version: u32, description: impl Into<String>, sql: impl Into<String>) -> Self {
        let description: String = description.into();
        let sql: String = sql.into();
        Migration {
            version,
            description,
            sql,
        }
    }
}
29
/// Outcome of one `Migrator::migrate` run.
#[derive(Default, Debug)]
pub struct MigrationReport {
    /// How many migrations were executed during this run.
    pub applied: usize,
    /// How many migrations were skipped because their version was already recorded.
    pub skipped: usize,
    /// Highest recorded migration version after the run (0 when none is recorded).
    pub current_version: u32,
}
40
/// Fluent builder for one `CREATE TABLE IF NOT EXISTS` statement.
///
/// Table, column and referenced names are interpolated verbatim (they are not
/// quoted or validated), so they must be trusted SQLite identifiers.
///
/// Table-level constraints added with [`TableBuilder::foreign_key`] and
/// [`TableBuilder::unique`] are stored apart from the column definitions and
/// are always emitted after them, because SQLite's `CREATE TABLE` grammar only
/// accepts table constraints after the last column definition.
#[must_use = "a TableBuilder only produces SQL when `build` is called"]
pub struct TableBuilder {
    name: String,
    /// Column definitions (including raw `column` fragments), in call order.
    columns: Vec<String>,
    /// Table constraints in call order; emitted after every column.
    constraints: Vec<String>,
}

impl TableBuilder {
    /// Creates a builder for a table called `name` with no columns yet.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            columns: Vec::new(),
            constraints: Vec::new(),
        }
    }

    /// Adds an auto-incrementing `INTEGER PRIMARY KEY` column named `id`.
    pub fn id(mut self) -> Self {
        self.columns
            .push("id INTEGER PRIMARY KEY AUTOINCREMENT".to_string());
        self
    }

    /// Adds a nullable `INTEGER` column.
    pub fn integer(mut self, name: &str) -> Self {
        self.columns.push(format!("{} INTEGER", name));
        self
    }

    /// Adds an `INTEGER NOT NULL` column.
    pub fn integer_not_null(mut self, name: &str) -> Self {
        self.columns.push(format!("{} INTEGER NOT NULL", name));
        self
    }

    /// Adds an `INTEGER` column with a numeric default value.
    pub fn integer_default(mut self, name: &str, default: i64) -> Self {
        self.columns
            .push(format!("{} INTEGER DEFAULT {}", name, default));
        self
    }

    /// Adds a nullable `TEXT` column.
    pub fn text(mut self, name: &str) -> Self {
        self.columns.push(format!("{} TEXT", name));
        self
    }

    /// Adds a `TEXT NOT NULL` column.
    pub fn text_not_null(mut self, name: &str) -> Self {
        self.columns.push(format!("{} TEXT NOT NULL", name));
        self
    }

    /// Adds a `TEXT` column with a literal default value.
    ///
    /// Single quotes inside `default` are doubled so the value is emitted as a
    /// well-formed SQL string literal (`it's` becomes `'it''s'`). Without this
    /// escaping a quote would terminate the literal early and corrupt the
    /// statement.
    pub fn text_default(mut self, name: &str, default: &str) -> Self {
        let literal = default.replace('\'', "''");
        self.columns
            .push(format!("{} TEXT DEFAULT '{}'", name, literal));
        self
    }

    /// Adds a nullable `REAL` column.
    pub fn real(mut self, name: &str) -> Self {
        self.columns.push(format!("{} REAL", name));
        self
    }

    /// Adds a `REAL NOT NULL` column.
    pub fn real_not_null(mut self, name: &str) -> Self {
        self.columns.push(format!("{} REAL NOT NULL", name));
        self
    }

    /// Adds a nullable `BLOB` column.
    pub fn blob(mut self, name: &str) -> Self {
        self.columns.push(format!("{} BLOB", name));
        self
    }

    /// Adds a boolean column, stored as `INTEGER`.
    pub fn boolean(mut self, name: &str) -> Self {
        self.columns.push(format!("{} INTEGER", name));
        self
    }

    /// Adds a boolean column (stored as `INTEGER`) whose default is `1` for
    /// `true` and `0` for `false`.
    pub fn boolean_default(mut self, name: &str, default: bool) -> Self {
        self.columns
            .push(format!("{} INTEGER DEFAULT {}", name, if default { 1 } else { 0 }));
        self
    }

    /// Adds a `created_at` `TEXT` column defaulting to `datetime('now')`.
    pub fn created_at(mut self) -> Self {
        self.columns
            .push("created_at TEXT DEFAULT (datetime('now'))".to_string());
        self
    }

    /// Adds an `updated_at` `TEXT` column defaulting to `datetime('now')`.
    pub fn updated_at(mut self) -> Self {
        self.columns
            .push("updated_at TEXT DEFAULT (datetime('now'))".to_string());
        self
    }

    /// Adds both `created_at` and `updated_at`, in that order.
    pub fn timestamps(self) -> Self {
        self.created_at().updated_at()
    }

    /// Adds a `FOREIGN KEY (column) REFERENCES ref_table(ref_column)` table
    /// constraint. It is emitted after all columns, regardless of call order.
    pub fn foreign_key(mut self, column: &str, ref_table: &str, ref_column: &str) -> Self {
        self.constraints.push(format!(
            "FOREIGN KEY ({}) REFERENCES {}({})",
            column, ref_table, ref_column
        ));
        self
    }

    /// Adds a `UNIQUE (...)` table constraint over `columns`. It is emitted
    /// after all columns, regardless of call order.
    pub fn unique(mut self, columns: &[&str]) -> Self {
        self.constraints
            .push(format!("UNIQUE ({})", columns.join(", ")));
        self
    }

    /// Adds a raw column definition. It is placed among the columns in call
    /// order, ahead of any `foreign_key`/`unique` constraints.
    pub fn column(mut self, definition: &str) -> Self {
        self.columns.push(definition.to_string());
        self
    }

    /// Renders the `CREATE TABLE IF NOT EXISTS` statement.
    ///
    /// The output lists all columns first, then all table constraints. Items
    /// are separated by `",\n  "`.
    pub fn build(&self) -> String {
        let body = self
            .columns
            .iter()
            .chain(self.constraints.iter())
            .map(String::as_str)
            .collect::<Vec<_>>()
            .join(",\n  ");
        format!("CREATE TABLE IF NOT EXISTS {} (\n  {}\n)", self.name, body)
    }
}
182
183/// Schema 构建器
184pub struct SchemaBuilder {
185    statements: Vec<String>,
186}
187
188impl SchemaBuilder {
189    /// 创建新的 schema 构建器
190    pub fn new() -> Self {
191        Self {
192            statements: Vec::new(),
193        }
194    }
195
196    /// 创建表
197    pub fn create_table<F>(&mut self, name: &str, f: F) -> Result<(), SqliteError>
198    where
199        F: FnOnce(TableBuilder) -> TableBuilder,
200    {
201        let builder = TableBuilder::new(name);
202        let builder = f(builder);
203        self.statements.push(builder.build());
204        Ok(())
205    }
206
207    /// 创建索引
208    pub fn create_index(&mut self, name: &str, table: &str, columns: &[&str]) -> Result<(), SqliteError> {
209        self.statements.push(format!(
210            "CREATE INDEX IF NOT EXISTS {} ON {} ({})",
211            name,
212            table,
213            columns.join(", ")
214        ));
215        Ok(())
216    }
217
218    /// 创建唯一索引
219    pub fn create_unique_index(
220        &mut self,
221        name: &str,
222        table: &str,
223        columns: &[&str],
224    ) -> Result<(), SqliteError> {
225        self.statements.push(format!(
226            "CREATE UNIQUE INDEX IF NOT EXISTS {} ON {} ({})",
227            name,
228            table,
229            columns.join(", ")
230        ));
231        Ok(())
232    }
233
234    /// 删除表
235    pub fn drop_table(&mut self, name: &str) -> Result<(), SqliteError> {
236        self.statements.push(format!("DROP TABLE IF EXISTS {}", name));
237        Ok(())
238    }
239
240    /// 删除索引
241    pub fn drop_index(&mut self, name: &str) -> Result<(), SqliteError> {
242        self.statements.push(format!("DROP INDEX IF EXISTS {}", name));
243        Ok(())
244    }
245
246    /// 添加原始 SQL
247    pub fn raw(&mut self, sql: &str) -> Result<(), SqliteError> {
248        self.statements.push(sql.to_string());
249        Ok(())
250    }
251
252    /// 构建所有语句
253    pub fn build(&self) -> String {
254        self.statements.join(";\n")
255    }
256}
257
258impl Default for SchemaBuilder {
259    fn default() -> Self {
260        Self::new()
261    }
262}
263
264/// 迁移构建器
265pub struct MigrationBuilder {
266    migrations: Vec<Migration>,
267}
268
269impl MigrationBuilder {
270    /// 创建新的迁移构建器
271    pub fn new() -> Self {
272        Self {
273            migrations: Vec::new(),
274        }
275    }
276
277    /// 添加版本迁移
278    pub fn version<F>(&mut self, version: u32, description: &str, f: F)
279    where
280        F: FnOnce(&mut SchemaBuilder) -> Result<(), SqliteError>,
281    {
282        let mut schema = SchemaBuilder::new();
283        if f(&mut schema).is_ok() {
284            self.migrations.push(Migration::new(version, description, schema.build()));
285        }
286    }
287
288    /// 获取所有迁移
289    pub fn build(self) -> Vec<Migration> {
290        self.migrations
291    }
292}
293
294impl Default for MigrationBuilder {
295    fn default() -> Self {
296        Self::new()
297    }
298}
299
300/// 迁移执行器
301pub struct Migrator<'a> {
302    conn: &'a Connection,
303}
304
305impl<'a> Migrator<'a> {
306    /// 创建新的迁移执行器
307    pub fn new(conn: &'a Connection) -> Self {
308        Self { conn }
309    }
310
311    /// 确保迁移表存在
312    fn ensure_migration_table(&self) -> Result<(), SqliteError> {
313        self.conn.execute_batch(
314            "CREATE TABLE IF NOT EXISTS _migrations (
315                version INTEGER PRIMARY KEY,
316                description TEXT NOT NULL,
317                applied_at TEXT DEFAULT (datetime('now'))
318            )",
319        )
320    }
321
322    /// 获取当前版本
323    pub fn current_version(&self) -> Result<u32, SqliteError> {
324        self.ensure_migration_table()?;
325
326        let row = self
327            .conn
328            .query_row("SELECT MAX(version) as v FROM _migrations", &[])?;
329
330        Ok(row.and_then(|r| r.get_i64("v")).unwrap_or(0) as u32)
331    }
332
333    /// 检查迁移是否已应用
334    pub fn is_applied(&self, version: u32) -> Result<bool, SqliteError> {
335        self.ensure_migration_table()?;
336
337        let row = self.conn.query_row(
338            "SELECT 1 FROM _migrations WHERE version = ?",
339            &[version.into()],
340        )?;
341
342        Ok(row.is_some())
343    }
344
345    /// 执行单个迁移
346    fn apply_migration(&self, migration: &Migration) -> Result<(), SqliteError> {
347        // 执行迁移 SQL
348        self.conn
349            .execute_batch(&migration.sql)
350            .map_err(|e| SqliteError::MigrationFailed(format!("v{}: {}", migration.version, e)))?;
351
352        // 记录迁移
353        self.conn.execute(
354            "INSERT INTO _migrations (version, description) VALUES (?, ?)",
355            &[migration.version.into(), migration.description.clone().into()],
356        )?;
357
358        Ok(())
359    }
360
361    /// 执行所有待执行的迁移
362    pub fn migrate(&self, migrations: &[Migration]) -> Result<MigrationReport, SqliteError> {
363        self.ensure_migration_table()?;
364
365        let mut report = MigrationReport::default();
366
367        // 按版本排序
368        let mut sorted: Vec<_> = migrations.iter().collect();
369        sorted.sort_by_key(|m| m.version);
370
371        for migration in sorted {
372            if self.is_applied(migration.version)? {
373                report.skipped += 1;
374            } else {
375                self.apply_migration(migration)?;
376                report.applied += 1;
377            }
378        }
379
380        report.current_version = self.current_version()?;
381
382        Ok(report)
383    }
384
385    /// 使用构建器执行迁移
386    pub fn migrate_with<F>(&self, f: F) -> Result<MigrationReport, SqliteError>
387    where
388        F: FnOnce(&mut MigrationBuilder),
389    {
390        let mut builder = MigrationBuilder::new();
391        f(&mut builder);
392        self.migrate(&builder.build())
393    }
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    /// The table builder emits the expected column definitions.
    #[test]
    fn test_table_builder() {
        let ddl = TableBuilder::new("users")
            .id()
            .text_not_null("name")
            .integer("age")
            .timestamps()
            .build();

        let expected_fragments = [
            "CREATE TABLE",
            "id INTEGER PRIMARY KEY",
            "name TEXT NOT NULL",
            "created_at",
        ];
        for fragment in expected_fragments.iter() {
            assert!(ddl.contains(*fragment));
        }
    }

    /// Migrations apply once and are skipped when run again.
    #[test]
    fn test_migration() {
        fn create_users(schema: &mut SchemaBuilder) -> Result<(), SqliteError> {
            schema.create_table("users", |t| t.id().text_not_null("name").timestamps())
        }

        let conn = Connection::open_in_memory().unwrap();
        let migrator = Migrator::new(&conn);

        let first_run = migrator
            .migrate_with(|m| {
                m.version(1, "创建用户表", create_users);
                m.version(2, "添加索引", |s| {
                    s.create_index("idx_users_name", "users", &["name"])
                });
            })
            .unwrap();
        assert_eq!(first_run.applied, 2);
        assert_eq!(first_run.current_version, 2);

        // Re-running an already recorded version must be skipped, not re-applied.
        let second_run = migrator
            .migrate_with(|m| m.version(1, "创建用户表", create_users))
            .unwrap();
        assert_eq!(second_run.applied, 0);
        assert_eq!(second_run.skipped, 1);
    }

    /// The schema builder combines table and index statements.
    #[test]
    fn test_schema_builder() {
        let mut schema = SchemaBuilder::new();

        let created = schema.create_table("posts", |t| {
            t.id()
                .text_not_null("title")
                .text("content")
                .integer_not_null("user_id")
                .foreign_key("user_id", "users", "id")
        });
        created.unwrap();

        let indexed = schema.create_index("idx_posts_user", "posts", &["user_id"]);
        indexed.unwrap();

        let sql = schema.build();
        for fragment in ["CREATE TABLE", "FOREIGN KEY", "CREATE INDEX"].iter() {
            assert!(sql.contains(*fragment));
        }
    }
}