unistore_sqlite/
connection.rs

1//! SQLite 连接管理
2//!
3//! 职责:封装 rusqlite 连接,提供底层 SQL 执行能力
4
5use crate::config::SqliteConfig;
6use crate::error::SqliteError;
7use crate::types::{Param, Row, Rows, SqlValue};
8use parking_lot::Mutex;
9use rusqlite::Connection as RawConnection;
10use std::sync::Arc;
11
/// Lifecycle state of a `Connection`.
///
/// The type is a plain, field-less enum, so it is cheap to copy, compare and hash;
/// deriving `Hash` lets it be used as a key in `HashMap`/`HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionState {
    /// The connection is open.
    Open,
    /// The connection has been closed.
    Closed,
}
20
/// Wrapper around a rusqlite connection.
///
/// Thread-safe connection handle that provides statement-execution and query
/// methods; every access to the underlying connection goes through the `inner` mutex.
pub struct Connection {
    /// Underlying rusqlite connection, protected by a mutex. Holds `None` once
    /// `close` has taken the connection out.
    inner: Arc<Mutex<Option<RawConnection>>>,
    /// Configuration the connection was opened with.
    config: SqliteConfig,
    /// Open/closed state, kept behind its own mutex so it can be read without
    /// locking `inner`.
    state: Arc<Mutex<ConnectionState>>,
}
32
33impl Connection {
34    /// 打开数据库连接
35    pub fn open(config: SqliteConfig) -> Result<Self, SqliteError> {
36        let path = config.path_string();
37
38        // 打开连接
39        let conn = if config.read_only {
40            RawConnection::open_with_flags(
41                &path,
42                rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY | rusqlite::OpenFlags::SQLITE_OPEN_URI,
43            )
44        } else if config.create_if_missing {
45            RawConnection::open(&path)
46        } else {
47            RawConnection::open_with_flags(
48                &path,
49                rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE | rusqlite::OpenFlags::SQLITE_OPEN_URI,
50            )
51        }
52        .map_err(|e| SqliteError::OpenFailed(e.to_string()))?;
53
54        // 应用 PRAGMA 配置
55        for pragma in config.to_pragmas() {
56            conn.execute_batch(&pragma)
57                .map_err(|e| SqliteError::OpenFailed(format!("PRAGMA failed: {}", e)))?;
58        }
59
60        Ok(Self {
61            inner: Arc::new(Mutex::new(Some(conn))),
62            config,
63            state: Arc::new(Mutex::new(ConnectionState::Open)),
64        })
65    }
66
67    /// 打开内存数据库
68    pub fn open_in_memory() -> Result<Self, SqliteError> {
69        Self::open(SqliteConfig::memory())
70    }
71
72    /// 获取连接状态
73    pub fn state(&self) -> ConnectionState {
74        *self.state.lock()
75    }
76
77    /// 判断连接是否打开
78    pub fn is_open(&self) -> bool {
79        self.state() == ConnectionState::Open
80    }
81
82    /// 获取配置引用
83    pub fn config(&self) -> &SqliteConfig {
84        &self.config
85    }
86
87    /// 执行 SQL(无返回结果)
88    pub fn execute(&self, sql: &str, params: &[Param]) -> Result<usize, SqliteError> {
89        let guard = self.inner.lock();
90        let conn = guard.as_ref().ok_or(SqliteError::DatabaseClosed)?;
91
92        let params_refs: Vec<&dyn rusqlite::ToSql> =
93            params.iter().map(|p| p as &dyn rusqlite::ToSql).collect();
94
95        conn.execute(sql, params_refs.as_slice())
96            .map_err(SqliteError::from)
97    }
98
99    /// 执行多条 SQL 语句
100    pub fn execute_batch(&self, sql: &str) -> Result<(), SqliteError> {
101        let guard = self.inner.lock();
102        let conn = guard.as_ref().ok_or(SqliteError::DatabaseClosed)?;
103
104        conn.execute_batch(sql).map_err(SqliteError::from)
105    }
106
107    /// 查询单行
108    pub fn query_row(&self, sql: &str, params: &[Param]) -> Result<Option<Row>, SqliteError> {
109        let guard = self.inner.lock();
110        let conn = guard.as_ref().ok_or(SqliteError::DatabaseClosed)?;
111
112        let params_refs: Vec<&dyn rusqlite::ToSql> =
113            params.iter().map(|p| p as &dyn rusqlite::ToSql).collect();
114
115        let mut stmt = conn.prepare(sql).map_err(SqliteError::from)?;
116        let columns: Vec<String> = stmt.column_names().iter().map(|s: &&str| s.to_string()).collect();
117
118        let result = stmt.query_row(params_refs.as_slice(), |row: &rusqlite::Row| {
119            let mut r = Row::new();
120            for (i, col) in columns.iter().enumerate() {
121                let value: SqlValue = row.get(i)?;
122                r.push(col.to_string(), value);
123            }
124            Ok(r)
125        });
126
127        match result {
128            Ok(row) => Ok(Some(row)),
129            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
130            Err(e) => Err(SqliteError::from(e)),
131        }
132    }
133
134    /// 查询多行
135    pub fn query(&self, sql: &str, params: &[Param]) -> Result<Rows, SqliteError> {
136        let guard = self.inner.lock();
137        let conn = guard.as_ref().ok_or(SqliteError::DatabaseClosed)?;
138
139        let params_refs: Vec<&dyn rusqlite::ToSql> =
140            params.iter().map(|p| p as &dyn rusqlite::ToSql).collect();
141
142        let mut stmt = conn.prepare(sql).map_err(SqliteError::from)?;
143        let columns: Vec<String> = stmt.column_names().iter().map(|s: &&str| s.to_string()).collect();
144
145        let rows = stmt
146            .query_map(params_refs.as_slice(), |row: &rusqlite::Row| {
147                let mut r = Row::new();
148                for (i, col) in columns.iter().enumerate() {
149                    let value: SqlValue = row.get(i)?;
150                    r.push(col.to_string(), value);
151                }
152                Ok(r)
153            })
154            .map_err(SqliteError::from)?;
155
156        let mut result: Vec<Row> = Vec::new();
157        for row in rows {
158            result.push(row.map_err(SqliteError::from)?);
159        }
160
161        Ok(result)
162    }
163
164    /// 获取最后插入的行 ID
165    pub fn last_insert_rowid(&self) -> Result<i64, SqliteError> {
166        let guard = self.inner.lock();
167        let conn = guard.as_ref().ok_or(SqliteError::DatabaseClosed)?;
168        Ok(conn.last_insert_rowid())
169    }
170
171    /// 获取上次操作影响的行数
172    pub fn changes(&self) -> Result<usize, SqliteError> {
173        let guard = self.inner.lock();
174        let conn = guard.as_ref().ok_or(SqliteError::DatabaseClosed)?;
175        Ok(conn.changes() as usize)
176    }
177
178    /// 关闭连接
179    pub fn close(&self) -> Result<(), SqliteError> {
180        let mut guard = self.inner.lock();
181        let mut state = self.state.lock();
182
183        if let Some(conn) = guard.take() {
184            // 尝试关闭,忽略错误(连接会自动清理)
185            drop(conn);
186        }
187
188        *state = ConnectionState::Closed;
189        Ok(())
190    }
191
192    /// 访问底层连接(高级用法)
193    ///
194    /// # Safety
195    ///
196    /// 调用者需要确保不会破坏连接状态
197    pub fn with_raw<F, R>(&self, f: F) -> Result<R, SqliteError>
198    where
199        F: FnOnce(&RawConnection) -> Result<R, SqliteError>,
200    {
201        let guard = self.inner.lock();
202        let conn = guard.as_ref().ok_or(SqliteError::DatabaseClosed)?;
203        f(conn)
204    }
205}
206
207impl Clone for Connection {
208    fn clone(&self) -> Self {
209        Self {
210            inner: self.inner.clone(),
211            config: self.config.clone(),
212            state: self.state.clone(),
213        }
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly opened in-memory database reports itself as open.
    #[test]
    fn test_open_memory() {
        let db = Connection::open_in_memory().unwrap();
        assert!(db.is_open());
    }

    /// Inserts two rows and reads them back in insertion order.
    #[test]
    fn test_execute_and_query() {
        let db = Connection::open_in_memory().unwrap();

        // Create the table.
        db.execute_batch("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)")
            .unwrap();

        // Insert the rows.
        db.execute("INSERT INTO test (name) VALUES (?)", &["Alice".into()])
            .unwrap();
        db.execute("INSERT INTO test (name) VALUES (?)", &["Bob".into()])
            .unwrap();

        // The second insert received row id 2.
        let newest_id = db.last_insert_rowid().unwrap();
        assert_eq!(newest_id, 2);

        // Query everything back.
        let fetched = db.query("SELECT * FROM test", &[]).unwrap();
        assert_eq!(fetched.len(), 2);
        assert_eq!(fetched[0].get_str("name"), Some("Alice"));
        assert_eq!(fetched[1].get_str("name"), Some("Bob"));
    }

    /// `query_row` yields `Some` for a matching id and `None` otherwise.
    #[test]
    fn test_query_row() {
        let db = Connection::open_in_memory().unwrap();
        db.execute_batch("CREATE TABLE test (id INTEGER PRIMARY KEY, value INTEGER)")
            .unwrap();
        db.execute("INSERT INTO test (value) VALUES (?)", &[42i32.into()])
            .unwrap();

        // Existing id: a row comes back carrying the inserted value.
        let hit = db
            .query_row("SELECT * FROM test WHERE id = ?", &[1i32.into()])
            .unwrap();
        assert!(hit.is_some());
        assert_eq!(hit.unwrap().get_i64("value"), Some(42));

        // Unknown id: no row, but no error either.
        let miss = db
            .query_row("SELECT * FROM test WHERE id = ?", &[999i32.into()])
            .unwrap();
        assert!(miss.is_none());
    }

    /// After `close`, the connection reports closed and rejects further work.
    #[test]
    fn test_close() {
        let db = Connection::open_in_memory().unwrap();
        assert!(db.is_open());

        db.close().unwrap();
        assert!(!db.is_open());

        // Operations after close must fail.
        let after_close = db.execute("SELECT 1", &[]);
        assert!(after_close.is_err());
    }
}