unistore_sqlite/
db.rs

1//! 嵌入式数据库主门面
2//!
3//! 职责:提供高级 API,整合所有底层功能
4
5use crate::config::SqliteConfig;
6use crate::connection::Connection;
7use crate::error::SqliteError;
8use crate::migration::{MigrationBuilder, MigrationReport, Migrator};
9use crate::query::{DeleteBuilder, InsertBuilder, SelectBuilder, UpdateBuilder};
10use crate::transaction::{with_transaction, Transaction};
11use crate::types::{Param, Row, Rows};
12use std::path::PathBuf;
13
14/// 嵌入式数据库
15///
16/// 提供简化的数据库操作 API
17pub struct EmbeddedDb {
18    conn: Connection,
19    name: String,
20}
21
22impl EmbeddedDb {
23    /// 打开或创建数据库
24    ///
25    /// # Arguments
26    ///
27    /// * `name` - 数据库名称(用于生成文件路径)
28    ///
29    /// # Example
30    ///
31    /// ```ignore
32    /// let db = EmbeddedDb::open("my_app")?;
33    /// ```
34    pub fn open(name: &str) -> Result<Self, SqliteError> {
35        let path = Self::default_path(name);
36
37        // 确保目录存在
38        if let Some(parent) = path.parent() {
39            std::fs::create_dir_all(parent)?;
40        }
41
42        let config = SqliteConfig::default().with_path(path);
43        let conn = Connection::open(config)?;
44
45        Ok(Self {
46            conn,
47            name: name.to_string(),
48        })
49    }
50
51    /// 打开指定路径的数据库
52    pub fn open_path(path: impl Into<PathBuf>) -> Result<Self, SqliteError> {
53        let path = path.into();
54        let name = path
55            .file_stem()
56            .and_then(|s| s.to_str())
57            .unwrap_or("db")
58            .to_string();
59
60        let config = SqliteConfig::default().with_path(path);
61        let conn = Connection::open(config)?;
62
63        Ok(Self { conn, name })
64    }
65
66    /// 打开带配置的数据库
67    pub fn open_with_config(name: &str, config: SqliteConfig) -> Result<Self, SqliteError> {
68        let config = if config.path.is_none() {
69            config.with_path(Self::default_path(name))
70        } else {
71            config
72        };
73
74        let conn = Connection::open(config)?;
75
76        Ok(Self {
77            conn,
78            name: name.to_string(),
79        })
80    }
81
82    /// 创建内存数据库
83    pub fn memory() -> Result<Self, SqliteError> {
84        let conn = Connection::open(SqliteConfig::memory())?;
85        Ok(Self {
86            conn,
87            name: ":memory:".to_string(),
88        })
89    }
90
91    /// 获取默认数据库路径
92    fn default_path(name: &str) -> PathBuf {
93        let data_dir = dirs::data_dir()
94            .unwrap_or_else(|| PathBuf::from("."))
95            .join("unistore")
96            .join("db");
97
98        data_dir.join(format!("{}.db", name))
99    }
100
101    /// 获取数据库名称
102    pub fn name(&self) -> &str {
103        &self.name
104    }
105
106    /// 获取底层连接
107    pub fn connection(&self) -> &Connection {
108        &self.conn
109    }
110
111    /// 执行迁移
112    pub fn migrate<F>(&self, f: F) -> Result<MigrationReport, SqliteError>
113    where
114        F: FnOnce(&mut MigrationBuilder),
115    {
116        let migrator = Migrator::new(&self.conn);
117        migrator.migrate_with(f)
118    }
119
120    /// 获取当前 schema 版本
121    pub fn schema_version(&self) -> Result<u32, SqliteError> {
122        let migrator = Migrator::new(&self.conn);
123        migrator.current_version()
124    }
125
126    // ========== 查询构建器 ==========
127
128    /// 创建 SELECT 查询
129    pub fn select(&self, table: &str) -> SelectBuilder<'_> {
130        SelectBuilder::new(&self.conn, table)
131    }
132
133    /// 创建 INSERT 查询
134    pub fn insert(&self, table: &str) -> InsertBuilder<'_> {
135        InsertBuilder::new(&self.conn, table)
136    }
137
138    /// 创建 UPDATE 查询
139    pub fn update(&self, table: &str) -> UpdateBuilder<'_> {
140        UpdateBuilder::new(&self.conn, table)
141    }
142
143    /// 创建 DELETE 查询
144    pub fn delete(&self, table: &str) -> DeleteBuilder<'_> {
145        DeleteBuilder::new(&self.conn, table)
146    }
147
148    // ========== 便捷方法 ==========
149
150    /// 执行原始 SQL
151    pub fn execute(&self, sql: &str, params: &[Param]) -> Result<usize, SqliteError> {
152        self.conn.execute(sql, params)
153    }
154
155    /// 执行批量 SQL
156    pub fn execute_batch(&self, sql: &str) -> Result<(), SqliteError> {
157        self.conn.execute_batch(sql)
158    }
159
160    /// 查询单行
161    pub fn query_row(&self, sql: &str, params: &[Param]) -> Result<Option<Row>, SqliteError> {
162        self.conn.query_row(sql, params)
163    }
164
165    /// 查询多行
166    pub fn query(&self, sql: &str, params: &[Param]) -> Result<Rows, SqliteError> {
167        self.conn.query(sql, params)
168    }
169
170    /// 获取最后插入的行 ID
171    pub fn last_insert_id(&self) -> Result<i64, SqliteError> {
172        self.conn.last_insert_rowid()
173    }
174
175    // ========== 事务 ==========
176
177    /// 开始事务
178    pub fn begin_transaction(&self) -> Result<Transaction<'_>, SqliteError> {
179        Transaction::begin(&self.conn)
180    }
181
182    /// 在事务中执行操作
183    pub fn with_transaction<F, T>(&self, f: F) -> Result<T, SqliteError>
184    where
185        F: FnOnce(&Transaction<'_>) -> Result<T, SqliteError>,
186    {
187        with_transaction(&self.conn, f)
188    }
189
190    // ========== 工具方法 ==========
191
192    /// 检查表是否存在
193    pub fn table_exists(&self, table: &str) -> Result<bool, SqliteError> {
194        let row = self.conn.query_row(
195            "SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
196            &[table.into()],
197        )?;
198        Ok(row.is_some())
199    }
200
201    /// 获取表的列信息
202    pub fn table_columns(&self, table: &str) -> Result<Vec<ColumnInfo>, SqliteError> {
203        let rows = self.conn.query(&format!("PRAGMA table_info({})", table), &[])?;
204
205        Ok(rows
206            .into_iter()
207            .map(|row| ColumnInfo {
208                name: row.get_string("name").unwrap_or_default(),
209                type_name: row.get_string("type").unwrap_or_default(),
210                not_null: row.get_bool("notnull").unwrap_or(false),
211                default_value: row.get_string("dflt_value"),
212                is_primary_key: row.get_bool("pk").unwrap_or(false),
213            })
214            .collect())
215    }
216
217    /// 获取数据库大小(字节)
218    pub fn database_size(&self) -> Result<i64, SqliteError> {
219        let row = self
220            .conn
221            .query_row("SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()", &[])?;
222
223        Ok(row.and_then(|r| r.get_i64("size")).unwrap_or(0))
224    }
225
226    /// 执行 VACUUM 优化
227    pub fn vacuum(&self) -> Result<(), SqliteError> {
228        self.conn.execute_batch("VACUUM")
229    }
230
231    /// 关闭数据库
232    pub fn close(self) -> Result<(), SqliteError> {
233        self.conn.close()
234    }
235}
236
/// Metadata for a single table column, as produced by
/// [`EmbeddedDb::table_columns`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub struct ColumnInfo {
    /// Column name.
    pub name: String,
    /// Declared type name.
    pub type_name: String,
    /// Whether the column is declared `NOT NULL`.
    pub not_null: bool,
    /// Default value as text, if the column has one.
    pub default_value: Option<String>,
    /// Whether the column is part of the primary key.
    pub is_primary_key: bool,
}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a brand-new in-memory database for a single test.
    fn fresh_db() -> EmbeddedDb {
        EmbeddedDb::memory().unwrap()
    }

    #[test]
    fn test_embedded_db_memory() {
        let mem_db = fresh_db();
        let label = mem_db.name();
        assert_eq!(label, ":memory:");
    }

    #[test]
    fn test_crud_operations() {
        let store = fresh_db();

        // Schema: version 1 creates the `users` table.
        store
            .migrate(|plan| {
                plan.version(1, "创建用户表", |schema| {
                    schema.create_table("users", |tbl| {
                        tbl.id()
                            .text_not_null("name")
                            .integer("age")
                            .timestamps()
                    })
                });
            })
            .unwrap();

        // Create.
        let new_id = store
            .insert("users")
            .set("name", "Alice")
            .set("age", 30)
            .execute()
            .unwrap();
        assert_eq!(new_id, 1);

        // Read back what was inserted.
        let fetched = store.select("users").filter_eq("id", 1).fetch_one().unwrap();
        let alice = fetched.unwrap();
        assert_eq!(alice.get_str("name"), Some("Alice"));
        assert_eq!(alice.get_i64("age"), Some(30));

        // Update the age and make sure exactly one row changed.
        let updated_rows = store
            .update("users")
            .set("age", 31)
            .filter_eq("id", 1)
            .execute()
            .unwrap();
        assert_eq!(updated_rows, 1);

        let refetched = store.select("users").filter_eq("id", 1).fetch_one().unwrap();
        let alice_after = refetched.unwrap();
        assert_eq!(alice_after.get_i64("age"), Some(31));

        // Delete, then confirm the table is empty.
        let deleted_rows = store.delete("users").filter_eq("id", 1).execute().unwrap();
        assert_eq!(deleted_rows, 1);

        let remaining = store.select("users").count().unwrap();
        assert_eq!(remaining, 0);
    }

    #[test]
    fn test_transaction() {
        let store = fresh_db();
        store
            .execute_batch("CREATE TABLE test (id INTEGER PRIMARY KEY, value INTEGER)")
            .unwrap();

        // A transaction whose closure succeeds keeps both inserts.
        store
            .with_transaction(|txn| {
                txn.execute("INSERT INTO test (value) VALUES (?)", &[1i32.into()])?;
                txn.execute("INSERT INTO test (value) VALUES (?)", &[2i32.into()])?;
                Ok(())
            })
            .unwrap();

        let after_commit = store.select("test").count().unwrap();
        assert_eq!(after_commit, 2);

        // A transaction whose closure fails must leave no trace.
        let outcome = store.with_transaction(|txn| -> Result<(), SqliteError> {
            txn.execute("INSERT INTO test (value) VALUES (?)", &[3i32.into()])?;
            Err(SqliteError::Internal("test".into()))
        });
        assert!(outcome.is_err());

        // Still two rows: the failed transaction was rolled back.
        let after_rollback = store.select("test").count().unwrap();
        assert_eq!(after_rollback, 2);
    }

    #[test]
    fn test_table_info() {
        let store = fresh_db();
        store
            .migrate(|plan| {
                plan.version(1, "创建表", |schema| {
                    schema.create_table("test", |tbl| {
                        tbl.id().text_not_null("name").integer("value")
                    })
                });
            })
            .unwrap();

        let present = store.table_exists("test").unwrap();
        assert!(present);
        let missing = store.table_exists("nonexistent").unwrap();
        assert!(!missing);

        let cols = store.table_columns("test").unwrap();
        assert_eq!(cols.len(), 3);

        let id_is_key = cols.iter().any(|col| col.name == "id" && col.is_primary_key);
        assert!(id_is_key);
        let name_required = cols.iter().any(|col| col.name == "name" && col.not_null);
        assert!(name_required);
    }
}