unistore_sqlite/
query.rs

1//! 查询构建器
2//!
3//! 职责:提供流畅的 SQL 查询构建 API
4
5use crate::connection::Connection;
6use crate::error::SqliteError;
7use crate::types::{Param, Row, Rows};
8
9/// SELECT 查询构建器
10pub struct SelectBuilder<'a> {
11    conn: &'a Connection,
12    table: String,
13    columns: Vec<String>,
14    where_clause: Option<String>,
15    where_params: Vec<Param>,
16    order_by: Option<String>,
17    limit: Option<usize>,
18    offset: Option<usize>,
19}
20
21impl<'a> SelectBuilder<'a> {
22    /// 创建新的 SELECT 构建器
23    pub fn new(conn: &'a Connection, table: impl Into<String>) -> Self {
24        Self {
25            conn,
26            table: table.into(),
27            columns: vec!["*".to_string()],
28            where_clause: None,
29            where_params: Vec::new(),
30            order_by: None,
31            limit: None,
32            offset: None,
33        }
34    }
35
36    /// 指定要查询的列
37    pub fn columns(mut self, cols: &[&str]) -> Self {
38        self.columns = cols.iter().map(|s| s.to_string()).collect();
39        self
40    }
41
42    /// 添加 WHERE 条件
43    pub fn filter(mut self, condition: &str, params: impl IntoIterator<Item = Param>) -> Self {
44        self.where_clause = Some(condition.to_string());
45        self.where_params = params.into_iter().collect();
46        self
47    }
48
49    /// 添加 WHERE 条件(单参数便捷方法)
50    pub fn filter_eq(self, column: &str, value: impl Into<Param>) -> Self {
51        self.filter(&format!("{} = ?", column), [value.into()])
52    }
53
54    /// 添加 ORDER BY
55    pub fn order_by(mut self, column: &str, desc: bool) -> Self {
56        let direction = if desc { "DESC" } else { "ASC" };
57        self.order_by = Some(format!("{} {}", column, direction));
58        self
59    }
60
61    /// 设置 LIMIT
62    pub fn limit(mut self, limit: usize) -> Self {
63        self.limit = Some(limit);
64        self
65    }
66
67    /// 设置 OFFSET
68    pub fn offset(mut self, offset: usize) -> Self {
69        self.offset = Some(offset);
70        self
71    }
72
73    /// 构建 SQL 语句
74    pub fn build(&self) -> (String, Vec<Param>) {
75        let mut sql = format!("SELECT {} FROM {}", self.columns.join(", "), self.table);
76
77        let params = self.where_params.clone();
78
79        if let Some(ref where_clause) = self.where_clause {
80            sql.push_str(" WHERE ");
81            sql.push_str(where_clause);
82        }
83
84        if let Some(ref order) = self.order_by {
85            sql.push_str(" ORDER BY ");
86            sql.push_str(order);
87        }
88
89        if let Some(limit) = self.limit {
90            sql.push_str(&format!(" LIMIT {}", limit));
91        }
92
93        if let Some(offset) = self.offset {
94            sql.push_str(&format!(" OFFSET {}", offset));
95        }
96
97        (sql, params)
98    }
99
100    /// 执行查询,返回所有行
101    pub fn fetch_all(self) -> Result<Rows, SqliteError> {
102        let (sql, params) = self.build();
103        self.conn.query(&sql, &params)
104    }
105
106    /// 执行查询,返回第一行
107    pub fn fetch_one(self) -> Result<Option<Row>, SqliteError> {
108        let (sql, params) = self.build();
109        self.conn.query_row(&sql, &params)
110    }
111
112    /// 执行查询,返回行数
113    pub fn count(self) -> Result<i64, SqliteError> {
114        let sql = format!(
115            "SELECT COUNT(*) as cnt FROM {}{}",
116            self.table,
117            self.where_clause
118                .as_ref()
119                .map(|w| format!(" WHERE {}", w))
120                .unwrap_or_default()
121        );
122        let row = self.conn.query_row(&sql, &self.where_params)?;
123        Ok(row.and_then(|r| r.get_i64("cnt")).unwrap_or(0))
124    }
125}
126
127/// INSERT 查询构建器
128pub struct InsertBuilder<'a> {
129    conn: &'a Connection,
130    table: String,
131    columns: Vec<String>,
132    values: Vec<Param>,
133}
134
135impl<'a> InsertBuilder<'a> {
136    /// 创建新的 INSERT 构建器
137    pub fn new(conn: &'a Connection, table: impl Into<String>) -> Self {
138        Self {
139            conn,
140            table: table.into(),
141            columns: Vec::new(),
142            values: Vec::new(),
143        }
144    }
145
146    /// 设置列值
147    pub fn set(mut self, column: &str, value: impl Into<Param>) -> Self {
148        self.columns.push(column.to_string());
149        self.values.push(value.into());
150        self
151    }
152
153    /// 构建 SQL 语句
154    pub fn build(&self) -> (String, Vec<Param>) {
155        let placeholders = vec!["?"; self.columns.len()].join(", ");
156        let sql = format!(
157            "INSERT INTO {} ({}) VALUES ({})",
158            self.table,
159            self.columns.join(", "),
160            placeholders
161        );
162        (sql, self.values.clone())
163    }
164
165    /// 执行插入,返回最后插入的行 ID
166    pub fn execute(self) -> Result<i64, SqliteError> {
167        let (sql, params) = self.build();
168        self.conn.execute(&sql, &params)?;
169        self.conn.last_insert_rowid()
170    }
171}
172
173/// UPDATE 查询构建器
174pub struct UpdateBuilder<'a> {
175    conn: &'a Connection,
176    table: String,
177    sets: Vec<(String, Param)>,
178    where_clause: Option<String>,
179    where_params: Vec<Param>,
180}
181
182impl<'a> UpdateBuilder<'a> {
183    /// 创建新的 UPDATE 构建器
184    pub fn new(conn: &'a Connection, table: impl Into<String>) -> Self {
185        Self {
186            conn,
187            table: table.into(),
188            sets: Vec::new(),
189            where_clause: None,
190            where_params: Vec::new(),
191        }
192    }
193
194    /// 设置列值
195    pub fn set(mut self, column: &str, value: impl Into<Param>) -> Self {
196        self.sets.push((column.to_string(), value.into()));
197        self
198    }
199
200    /// 添加 WHERE 条件
201    pub fn filter(mut self, condition: &str, params: impl IntoIterator<Item = Param>) -> Self {
202        self.where_clause = Some(condition.to_string());
203        self.where_params = params.into_iter().collect();
204        self
205    }
206
207    /// 添加 WHERE 条件(单参数便捷方法)
208    pub fn filter_eq(self, column: &str, value: impl Into<Param>) -> Self {
209        self.filter(&format!("{} = ?", column), [value.into()])
210    }
211
212    /// 构建 SQL 语句
213    pub fn build(&self) -> (String, Vec<Param>) {
214        let set_clause: Vec<String> = self.sets.iter().map(|(col, _)| format!("{} = ?", col)).collect();
215
216        let mut sql = format!("UPDATE {} SET {}", self.table, set_clause.join(", "));
217
218        let mut params: Vec<Param> = self.sets.iter().map(|(_, v)| v.clone()).collect();
219
220        if let Some(ref where_clause) = self.where_clause {
221            sql.push_str(" WHERE ");
222            sql.push_str(where_clause);
223            params.extend(self.where_params.clone());
224        }
225
226        (sql, params)
227    }
228
229    /// 执行更新,返回影响的行数
230    pub fn execute(self) -> Result<usize, SqliteError> {
231        let (sql, params) = self.build();
232        self.conn.execute(&sql, &params)
233    }
234}
235
236/// DELETE 查询构建器
237pub struct DeleteBuilder<'a> {
238    conn: &'a Connection,
239    table: String,
240    where_clause: Option<String>,
241    where_params: Vec<Param>,
242}
243
244impl<'a> DeleteBuilder<'a> {
245    /// 创建新的 DELETE 构建器
246    pub fn new(conn: &'a Connection, table: impl Into<String>) -> Self {
247        Self {
248            conn,
249            table: table.into(),
250            where_clause: None,
251            where_params: Vec::new(),
252        }
253    }
254
255    /// 添加 WHERE 条件
256    pub fn filter(mut self, condition: &str, params: impl IntoIterator<Item = Param>) -> Self {
257        self.where_clause = Some(condition.to_string());
258        self.where_params = params.into_iter().collect();
259        self
260    }
261
262    /// 添加 WHERE 条件(单参数便捷方法)
263    pub fn filter_eq(self, column: &str, value: impl Into<Param>) -> Self {
264        self.filter(&format!("{} = ?", column), [value.into()])
265    }
266
267    /// 构建 SQL 语句
268    pub fn build(&self) -> (String, Vec<Param>) {
269        let mut sql = format!("DELETE FROM {}", self.table);
270
271        if let Some(ref where_clause) = self.where_clause {
272            sql.push_str(" WHERE ");
273            sql.push_str(where_clause);
274        }
275
276        (sql, self.where_params.clone())
277    }
278
279    /// 执行删除,返回影响的行数
280    pub fn execute(self) -> Result<usize, SqliteError> {
281        let (sql, params) = self.build();
282        self.conn.execute(&sql, &params)
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an in-memory database containing an empty `users` table.
    fn setup_test_db() -> Connection {
        let db = Connection::open_in_memory().unwrap();
        db.execute_batch(
            "CREATE TABLE users (
                id INTEGER PRIMARY KEY,
                name TEXT NOT NULL,
                age INTEGER
            )",
        )
        .unwrap();
        db
    }

    /// Inserts one `users` row and returns the id reported by `InsertBuilder::execute`.
    fn add_user(conn: &Connection, name: impl Into<Param>, age: impl Into<Param>) -> i64 {
        InsertBuilder::new(conn, "users")
            .set("name", name)
            .set("age", age)
            .execute()
            .unwrap()
    }

    #[test]
    fn test_insert() {
        let conn = setup_test_db();

        let new_id = add_user(&conn, "Alice", 30);

        assert_eq!(new_id, 1);
    }

    #[test]
    fn test_select() {
        let conn = setup_test_db();
        add_user(&conn, "Alice", 30);
        add_user(&conn, "Bob", 25);

        // Unfiltered query returns every row.
        let everyone = SelectBuilder::new(&conn, "users").fetch_all().unwrap();
        assert_eq!(everyone.len(), 2);

        // Equality filter narrows the result to the matching row.
        let alices = SelectBuilder::new(&conn, "users")
            .filter_eq("name", "Alice")
            .fetch_all()
            .unwrap();
        assert_eq!(alices.len(), 1);

        // Ascending order by age with a limit of one yields the youngest user.
        let youngest = SelectBuilder::new(&conn, "users")
            .order_by("age", false)
            .limit(1)
            .fetch_all()
            .unwrap();
        assert_eq!(youngest.len(), 1);
        assert_eq!(youngest[0].get_str("name"), Some("Bob"));
    }

    #[test]
    fn test_update() {
        let conn = setup_test_db();
        add_user(&conn, "Alice", 30);

        let changed = UpdateBuilder::new(&conn, "users")
            .set("age", 31)
            .filter_eq("name", "Alice")
            .execute()
            .unwrap();
        assert_eq!(changed, 1);

        // Read the row back to confirm the new value was stored.
        let alice = SelectBuilder::new(&conn, "users")
            .filter_eq("name", "Alice")
            .fetch_one()
            .unwrap()
            .unwrap();
        assert_eq!(alice.get_i64("age"), Some(31));
    }

    #[test]
    fn test_delete() {
        let conn = setup_test_db();
        add_user(&conn, "Alice", 30);
        add_user(&conn, "Bob", 25);

        let removed = DeleteBuilder::new(&conn, "users")
            .filter_eq("name", "Alice")
            .execute()
            .unwrap();
        assert_eq!(removed, 1);

        // Only the other row should remain.
        let remaining = SelectBuilder::new(&conn, "users").count().unwrap();
        assert_eq!(remaining, 1);
    }

    #[test]
    fn test_count() {
        let conn = setup_test_db();

        // Ages 20 through 24.
        for i in 0..5 {
            add_user(&conn, format!("User{}", i), 20 + i);
        }

        let total = SelectBuilder::new(&conn, "users").count().unwrap();
        assert_eq!(total, 5);

        let at_least_22 = SelectBuilder::new(&conn, "users")
            .filter("age >= ?", [22i32.into()])
            .count()
            .unwrap();
        assert_eq!(at_least_22, 3);
    }
}