Skip to main content

wp_knowledge/mem/
query_util.rs

1use orion_error::ErrorOwe;
2use rusqlite::Params;
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5use wp_error::KnowledgeResult;
6use wp_log::debug_kdb;
7use wp_model_core::model::{self, DataField};
8
9use lazy_static::lazy_static;
10
11use crate::mem::RowData;
12
13lazy_static! {
14    /// Global column-name cache keyed by raw SQL text, shared by MemDB queries.
15    pub static ref COLNAME_CACHE: RwLock<HashMap<String, Arc<Vec<String>>>> =
16        RwLock::new(HashMap::new());
17}
18
19/// 将一行数据映射为 RowData
20fn map_row(row: &rusqlite::Row<'_>, col_names: &[String]) -> KnowledgeResult<RowData> {
21    let mut result = Vec::with_capacity(col_names.len());
22    for (i, col_name) in col_names.iter().enumerate() {
23        let value = row.get_ref(i).owe_rule()?;
24        let field = match value {
25            rusqlite::types::ValueRef::Null => {
26                DataField::new(model::DataType::default(), col_name, model::Value::Null)
27            }
28            rusqlite::types::ValueRef::Integer(v) => DataField::from_digit(col_name, v),
29            rusqlite::types::ValueRef::Real(v) => DataField::from_float(col_name, v),
30            rusqlite::types::ValueRef::Text(v) => {
31                DataField::from_chars(col_name, String::from_utf8(v.to_vec()).owe_rule()?)
32            }
33            rusqlite::types::ValueRef::Blob(v) => {
34                DataField::from_chars(col_name, String::from_utf8_lossy(v).to_string())
35            }
36        };
37        result.push(field);
38    }
39    Ok(result)
40}
41
42/// 从 statement 获取列名(普通版,带 debug 日志)
43fn extract_col_names(stmt: &rusqlite::Statement<'_>) -> Vec<String> {
44    let col_cnt = stmt.column_count();
45    debug_kdb!("[memdb] col_cnt={}", col_cnt);
46    let mut col_names = Vec::with_capacity(col_cnt);
47    for i in 0..col_cnt {
48        let name = stmt.column_name(i).unwrap_or("").to_string();
49        debug_kdb!("[memdb] col[{}] name='{}'", i, name);
50        col_names.push(name);
51    }
52    col_names
53}
54
55/// 从 statement 获取列名(cached 版,使用全局缓存)
56fn extract_col_names_cached(
57    stmt: &rusqlite::Statement<'_>,
58    sql: &str,
59) -> KnowledgeResult<Vec<String>> {
60    if let Some(names) = COLNAME_CACHE.read().ok().and_then(|m| m.get(sql).cloned()) {
61        return Ok((*names).clone());
62    }
63    let col_cnt = stmt.column_count();
64    let mut names = Vec::with_capacity(col_cnt);
65    for i in 0..col_cnt {
66        names.push(stmt.column_name(i).owe_rule()?.to_string());
67    }
68    if let Ok(mut m) = COLNAME_CACHE.write() {
69        m.insert(sql.to_string(), Arc::new(names.clone()));
70    }
71    Ok(names)
72}
73
74pub fn query<P: Params>(
75    conn: &rusqlite::Connection,
76    sql: &str,
77    params: P,
78) -> KnowledgeResult<Vec<RowData>> {
79    let mut stmt = conn.prepare_cached(sql).owe_rule()?;
80    let col_names = extract_col_names(&stmt);
81    let mut rows = stmt.query(params).owe_rule()?;
82    let mut all_result = Vec::new();
83    while let Some(row) = rows.next().owe_rule()? {
84        all_result.push(map_row(row, &col_names)?);
85    }
86    Ok(all_result)
87}
88
89/// Query first row and map columns into RowData with column names preserved.
90pub fn query_first_row<P: Params>(
91    conn: &rusqlite::Connection,
92    sql: &str,
93    params: P,
94) -> KnowledgeResult<RowData> {
95    let mut stmt = conn.prepare_cached(sql).owe_rule()?;
96    let col_names = extract_col_names(&stmt);
97    let mut rows = stmt.query(params).owe_rule()?;
98    if let Some(row) = rows.next().owe_rule()? {
99        map_row(row, &col_names)
100    } else {
101        debug_kdb!("[memdb] no row for sql");
102        Ok(Vec::new())
103    }
104}
105
106pub fn query_cached<P: Params>(
107    conn: &rusqlite::Connection,
108    sql: &str,
109    params: P,
110) -> KnowledgeResult<Vec<RowData>> {
111    let mut stmt = conn.prepare_cached(sql).owe_rule()?;
112    // Column names cache (per SQL)
113    let col_names = extract_col_names_cached(&stmt, sql)?;
114    let mut rows = stmt.query(params).owe_rule()?;
115    let mut all_result = Vec::new();
116    while let Some(row) = rows.next().owe_rule()? {
117        all_result.push(map_row(row, &col_names)?);
118    }
119    Ok(all_result)
120}
121
122/// Same as `query_first_row` but with a shared column-names cache to reduce metadata lookups.
123pub fn query_first_row_cached<P: Params>(
124    conn: &rusqlite::Connection,
125    sql: &str,
126    params: P,
127) -> KnowledgeResult<RowData> {
128    let mut stmt = conn.prepare_cached(sql).owe_rule()?;
129    let col_names = extract_col_names_cached(&stmt, sql)?;
130    let mut rows = stmt.query(params).owe_rule()?;
131    if let Some(row) = rows.next().owe_rule()? {
132        map_row(row, &col_names)
133    } else {
134        Ok(Vec::new())
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Opens a fresh in-memory database containing an empty `test` table with one column per
    /// storage class exercised by the row mapper (plus an untyped `empty` column).
    fn setup_test_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute(
            "CREATE TABLE test (id INTEGER, name TEXT, score REAL, data BLOB, empty)",
            [],
        )
        .unwrap();
        conn
    }

    #[test]
    fn test_query_returns_all_rows() {
        let conn = setup_test_db();
        // An empty table yields no rows rather than an error.
        let rows = query(&conn, "SELECT * FROM test", []).unwrap();
        assert!(rows.is_empty());
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'alice')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (2, 'bob')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (3, 'charlie')", [])
            .unwrap();

        let rows = query(&conn, "SELECT id, name FROM test ORDER BY id", []).unwrap();
        assert_eq!(rows.len(), 3);
        // Rows come back in query order, each with both selected columns.
        assert_eq!(rows[0][0].to_string(), "digit(1)");
        assert_eq!(rows[2][1].to_string(), "chars(charlie)");
    }

    #[test]
    fn test_query_first_row_returns_single_row() {
        let conn = setup_test_db();
        // No rows: the result is an empty RowData, not an error.
        let row = query_first_row(&conn, "SELECT * FROM test", []).unwrap();
        assert!(row.is_empty());
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'first')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (2, 'second')", [])
            .unwrap();

        let row = query_first_row(&conn, "SELECT id, name FROM test ORDER BY id", []).unwrap();
        assert_eq!(row.len(), 2);
        assert_eq!(row[0].to_string(), "digit(1)");
        assert_eq!(row[1].to_string(), "chars(first)");
    }

    #[test]
    fn test_map_row_handles_all_types() {
        let conn = setup_test_db();
        conn.execute(
            "INSERT INTO test (id, name, score, data, empty) VALUES (42, 'hello', 3.14, X'414243', NULL)",
            [],
        )
        .unwrap();

        let row =
            query_first_row(&conn, "SELECT id, name, score, data, empty FROM test", []).unwrap();
        assert_eq!(row.len(), 5);

        // Every field keeps the name of the column it was read from.
        assert_eq!(row[0].get_name(), "id");
        assert_eq!(row[1].get_name(), "name");
        assert_eq!(row[2].get_name(), "score");
        assert_eq!(row[3].get_name(), "data");
        assert_eq!(row[4].get_name(), "empty");

        // INTEGER -> digit, TEXT -> chars, BLOB (bytes "ABC") -> lossily decoded chars.
        assert_eq!(row[0].to_string(), "digit(42)");
        assert_eq!(row[1].to_string(), "chars(hello)");
        assert_eq!(row[3].to_string(), "chars(ABC)");
    }

    #[test]
    fn test_extract_col_names_preserves_aliases() {
        let conn = setup_test_db();
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'x')", [])
            .unwrap();

        let row = query_first_row(
            &conn,
            "SELECT id AS user_id, name AS user_name FROM test",
            [],
        )
        .unwrap();
        assert_eq!(row[0].get_name(), "user_id");
        assert_eq!(row[1].get_name(), "user_name");
    }

    #[test]
    fn test_query_cached_uses_cache() {
        let conn = setup_test_db();
        conn.execute("INSERT INTO test (id) VALUES (1)", [])
            .unwrap();

        let sql = "SELECT id FROM test WHERE id = 1";
        // First query populates the cache.
        let _ = query_cached(&conn, sql, []).unwrap();
        // Second query should be served from the cache.
        let rows = query_cached(&conn, sql, []).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0][0].to_string(), "digit(1)");
        assert_eq!(rows[0][0].get_name(), "id");

        // Verify the cache entry exists and describes the single selected column.
        let cache = COLNAME_CACHE.read().unwrap();
        assert!(cache.contains_key(sql));
        assert_eq!(cache.get(sql).map(|names| names.len()), Some(1));
    }

    #[test]
    fn test_query_first_row_cached_returns_first_row() {
        let conn = setup_test_db();
        let sql = "SELECT id, name FROM test ORDER BY id LIMIT 1";
        // No rows: the cached variant also returns an empty RowData.
        assert!(query_first_row_cached(&conn, sql, []).unwrap().is_empty());

        conn.execute("INSERT INTO test (id, name) VALUES (1, 'first')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (2, 'second')", [])
            .unwrap();

        let row = query_first_row_cached(&conn, sql, []).unwrap();
        assert_eq!(row.len(), 2);
        assert_eq!(row[0].to_string(), "digit(1)");
        assert_eq!(row[1].to_string(), "chars(first)");
        assert_eq!(row[1].get_name(), "name");
    }

    #[test]
    fn test_query_with_params() {
        let conn = setup_test_db();
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'alice')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (2, 'bob')", [])
            .unwrap();

        let rows = query(&conn, "SELECT name FROM test WHERE id = ?1", [2]).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0][0].to_string(), "chars(bob)");
    }
}