1use orion_error::ErrorOwe;
2use rusqlite::Params;
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5use wp_error::KnowledgeResult;
6use wp_log::debug_kdb;
7use wp_model_core::model::{self, DataField};
8
9use lazy_static::lazy_static;
10
11use crate::mem::RowData;
12
13lazy_static! {
14 pub static ref COLNAME_CACHE: RwLock<HashMap<String, Arc<Vec<String>>>> =
16 RwLock::new(HashMap::new());
17}
18
19fn map_row(row: &rusqlite::Row<'_>, col_names: &[String]) -> KnowledgeResult<RowData> {
21 let mut result = Vec::with_capacity(col_names.len());
22 for (i, col_name) in col_names.iter().enumerate() {
23 let value = row.get_ref(i).owe_rule()?;
24 let field = match value {
25 rusqlite::types::ValueRef::Null => {
26 DataField::new(model::DataType::default(), col_name, model::Value::Null)
27 }
28 rusqlite::types::ValueRef::Integer(v) => DataField::from_digit(col_name, v),
29 rusqlite::types::ValueRef::Real(v) => DataField::from_float(col_name, v),
30 rusqlite::types::ValueRef::Text(v) => {
31 DataField::from_chars(col_name, String::from_utf8(v.to_vec()).owe_rule()?)
32 }
33 rusqlite::types::ValueRef::Blob(v) => {
34 DataField::from_chars(col_name, String::from_utf8_lossy(v).to_string())
35 }
36 };
37 result.push(field);
38 }
39 Ok(result)
40}
41
42fn extract_col_names(stmt: &rusqlite::Statement<'_>) -> Vec<String> {
44 let col_cnt = stmt.column_count();
45 debug_kdb!("[memdb] col_cnt={}", col_cnt);
46 let mut col_names = Vec::with_capacity(col_cnt);
47 for i in 0..col_cnt {
48 let name = stmt.column_name(i).unwrap_or("").to_string();
49 debug_kdb!("[memdb] col[{}] name='{}'", i, name);
50 col_names.push(name);
51 }
52 col_names
53}
54
55fn extract_col_names_cached(
57 stmt: &rusqlite::Statement<'_>,
58 sql: &str,
59) -> KnowledgeResult<Vec<String>> {
60 if let Some(names) = COLNAME_CACHE.read().ok().and_then(|m| m.get(sql).cloned()) {
61 return Ok((*names).clone());
62 }
63 let col_cnt = stmt.column_count();
64 let mut names = Vec::with_capacity(col_cnt);
65 for i in 0..col_cnt {
66 names.push(stmt.column_name(i).owe_rule()?.to_string());
67 }
68 if let Ok(mut m) = COLNAME_CACHE.write() {
69 m.insert(sql.to_string(), Arc::new(names.clone()));
70 }
71 Ok(names)
72}
73
74pub fn query<P: Params>(
75 conn: &rusqlite::Connection,
76 sql: &str,
77 params: P,
78) -> KnowledgeResult<Vec<RowData>> {
79 let mut stmt = conn.prepare_cached(sql).owe_rule()?;
80 let col_names = extract_col_names(&stmt);
81 let mut rows = stmt.query(params).owe_rule()?;
82 let mut all_result = Vec::new();
83 while let Some(row) = rows.next().owe_rule()? {
84 all_result.push(map_row(row, &col_names)?);
85 }
86 Ok(all_result)
87}
88
89pub fn query_first_row<P: Params>(
91 conn: &rusqlite::Connection,
92 sql: &str,
93 params: P,
94) -> KnowledgeResult<RowData> {
95 let mut stmt = conn.prepare_cached(sql).owe_rule()?;
96 let col_names = extract_col_names(&stmt);
97 let mut rows = stmt.query(params).owe_rule()?;
98 if let Some(row) = rows.next().owe_rule()? {
99 map_row(row, &col_names)
100 } else {
101 debug_kdb!("[memdb] no row for sql");
102 Ok(Vec::new())
103 }
104}
105
106pub fn query_cached<P: Params>(
107 conn: &rusqlite::Connection,
108 sql: &str,
109 params: P,
110) -> KnowledgeResult<Vec<RowData>> {
111 let mut stmt = conn.prepare_cached(sql).owe_rule()?;
112 let col_names = extract_col_names_cached(&stmt, sql)?;
114 let mut rows = stmt.query(params).owe_rule()?;
115 let mut all_result = Vec::new();
116 while let Some(row) = rows.next().owe_rule()? {
117 all_result.push(map_row(row, &col_names)?);
118 }
119 Ok(all_result)
120}
121
122pub fn query_first_row_cached<P: Params>(
124 conn: &rusqlite::Connection,
125 sql: &str,
126 params: P,
127) -> KnowledgeResult<RowData> {
128 let mut stmt = conn.prepare_cached(sql).owe_rule()?;
129 let col_names = extract_col_names_cached(&stmt, sql)?;
130 let mut rows = stmt.query(params).owe_rule()?;
131 if let Some(row) = rows.next().owe_rule()? {
132 map_row(row, &col_names)
133 } else {
134 Ok(Vec::new())
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Opens a fresh in-memory database with one `test` table that has a column per
    /// SQLite storage class, plus an untyped `empty` column used for NULLs.
    fn setup_test_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute(
            "CREATE TABLE test (id INTEGER, name TEXT, score REAL, data BLOB, empty)",
            [],
        )
        .unwrap();
        conn
    }

    #[test]
    fn test_query_returns_all_rows() {
        let conn = setup_test_db();
        let rows = query(&conn, "SELECT * FROM test", []).unwrap();
        assert!(rows.is_empty());
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'alice')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (2, 'bob')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (3, 'charlie')", [])
            .unwrap();

        let rows = query(&conn, "SELECT id, name FROM test ORDER BY id", []).unwrap();
        assert_eq!(rows.len(), 3);
    }

    #[test]
    fn test_query_first_row_returns_single_row() {
        let conn = setup_test_db();
        let row = query_first_row(&conn, "SELECT * FROM test", []).unwrap();
        assert!(row.is_empty());
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'first')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (2, 'second')", [])
            .unwrap();

        let row = query_first_row(&conn, "SELECT id, name FROM test ORDER BY id", []).unwrap();
        assert_eq!(row.len(), 2);
        assert_eq!(row[0].to_string(), "digit(1)");
        assert_eq!(row[1].to_string(), "chars(first)");
    }

    #[test]
    fn test_map_row_handles_all_types() {
        let conn = setup_test_db();
        conn.execute(
            "INSERT INTO test (id, name, score, data, empty) VALUES (42, 'hello', 3.14, X'414243', NULL)",
            [],
        )
        .unwrap();

        let row =
            query_first_row(&conn, "SELECT id, name, score, data, empty FROM test", []).unwrap();
        assert_eq!(row.len(), 5);
        // INTEGER -> digit, TEXT -> chars.
        assert_eq!(row[0].to_string(), "digit(42)");
        assert_eq!(row[1].to_string(), "chars(hello)");
        // BLOB X'414243' is decoded (lossily) as UTF-8 into a chars field.
        assert_eq!(row[3].to_string(), "chars(ABC)");
        // Every field keeps its source column name, including REAL and NULL columns.
        assert_eq!(row[0].get_name(), "id");
        assert_eq!(row[1].get_name(), "name");
        assert_eq!(row[2].get_name(), "score");
        assert_eq!(row[3].get_name(), "data");
        assert_eq!(row[4].get_name(), "empty");
    }

    #[test]
    fn test_extract_col_names_preserves_aliases() {
        let conn = setup_test_db();
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'x')", [])
            .unwrap();

        let row = query_first_row(
            &conn,
            "SELECT id AS user_id, name AS user_name FROM test",
            [],
        )
        .unwrap();
        assert_eq!(row[0].get_name(), "user_id");
        assert_eq!(row[1].get_name(), "user_name");
    }

    #[test]
    fn test_query_cached_uses_cache() {
        let conn = setup_test_db();
        conn.execute("INSERT INTO test (id) VALUES (1)", [])
            .unwrap();

        let sql = "SELECT id FROM test WHERE id = 1";
        let _ = query_cached(&conn, sql, []).unwrap();
        let rows = query_cached(&conn, sql, []).unwrap();
        assert_eq!(rows.len(), 1);

        // COLNAME_CACHE is global and shared with concurrently running tests; release the
        // read guard immediately instead of holding it for the rest of the test.
        let cached = COLNAME_CACHE.read().unwrap().contains_key(sql);
        assert!(cached);
    }

    #[test]
    fn test_query_first_row_cached_returns_first_row() {
        let conn = setup_test_db();
        // SQL text unique to this test so the shared cache entry is not touched elsewhere.
        let sql = "SELECT id, name FROM test ORDER BY id LIMIT 1";
        let row = query_first_row_cached(&conn, sql, []).unwrap();
        assert!(row.is_empty());

        conn.execute("INSERT INTO test (id, name) VALUES (7, 'seven')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (8, 'eight')", [])
            .unwrap();

        let row = query_first_row_cached(&conn, sql, []).unwrap();
        assert_eq!(row.len(), 2);
        assert_eq!(row[0].to_string(), "digit(7)");
        assert_eq!(row[1].to_string(), "chars(seven)");
        assert_eq!(row[1].get_name(), "name");
    }

    #[test]
    fn test_query_with_params() {
        let conn = setup_test_db();
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'alice')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (2, 'bob')", [])
            .unwrap();

        let rows = query(&conn, "SELECT name FROM test WHERE id = ?1", [2]).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0][0].to_string(), "chars(bob)");
    }
}