Skip to main content

wp_knowledge/mem/
query_util.rs

1use orion_error::ErrorOwe;
2use rusqlite::Params;
3use std::collections::hash_map::DefaultHasher;
4use std::hash::{Hash, Hasher};
5use std::num::NonZeroUsize;
6use std::sync::{Arc, RwLock};
7use wp_error::KnowledgeResult;
8use wp_log::debug_kdb;
9use wp_model_core::model::{self, DataField};
10
11use lazy_static::lazy_static;
12use lru::LruCache;
13
14use crate::mem::RowData;
15use crate::runtime::{DatasourceId, Generation, runtime};
16use crate::telemetry::{CacheLayer, CacheOutcome, CacheTelemetryEvent, telemetry};
17
18#[derive(Debug, Clone, PartialEq, Eq, Hash)]
19pub struct MetadataCacheKey {
20    pub datasource_id: DatasourceId,
21    pub generation: Generation,
22    pub query_hash: u64,
23}
24
25lazy_static! {
26    /// Global column metadata cache keyed by datasource/generation/query hash.
27    pub static ref COLNAME_CACHE: RwLock<LruCache<MetadataCacheKey, Arc<Vec<String>>>> =
28        RwLock::new(LruCache::new(
29            NonZeroUsize::new(512).expect("non-zero metadata cache size")
30        ));
31}
32
33pub fn column_metadata_cache_snapshot() -> (usize, usize) {
34    COLNAME_CACHE
35        .read()
36        .map(|cache| (cache.len(), cache.cap().get()))
37        .unwrap_or((0, 0))
38}
39
/// Hashes SQL text into the `u64` stored in `MetadataCacheKey::query_hash`.
///
/// All `DefaultHasher::new()` instances start from the same state, so equal
/// inputs hash equally within one build of the program. Despite the name, the
/// standard library does not promise the algorithm is stable across Rust
/// releases, so these values should only be used in-process (never persisted).
fn stable_hash(value: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(value, &mut state);
    Hasher::finish(&state)
}
45
46pub(crate) fn metadata_cache_key_for_current_scope(sql: &str) -> MetadataCacheKey {
47    let scope = runtime().current_metadata_scope();
48    MetadataCacheKey {
49        datasource_id: scope.datasource_id,
50        generation: scope.generation,
51        query_hash: stable_hash(sql),
52    }
53}
54
55pub(crate) fn metadata_cache_get_or_try_init<F>(sql: &str, load: F) -> KnowledgeResult<Vec<String>>
56where
57    F: FnOnce() -> KnowledgeResult<Option<Vec<String>>>,
58{
59    let cache_key = metadata_cache_key_for_current_scope(sql);
60    if let Some(names) = COLNAME_CACHE
61        .read()
62        .ok()
63        .and_then(|m| m.peek(&cache_key).cloned())
64    {
65        runtime().record_metadata_cache_hit();
66        telemetry().on_cache(&CacheTelemetryEvent {
67            layer: CacheLayer::Metadata,
68            outcome: CacheOutcome::Hit,
69            provider_kind: runtime().current_provider_kind(),
70        });
71        debug_kdb!(
72            "[kdb] metadata cache hit datasource_id={} generation={}",
73            cache_key.datasource_id.0,
74            cache_key.generation.0
75        );
76        return Ok((*names).clone());
77    }
78
79    runtime().record_metadata_cache_miss();
80    telemetry().on_cache(&CacheTelemetryEvent {
81        layer: CacheLayer::Metadata,
82        outcome: CacheOutcome::Miss,
83        provider_kind: runtime().current_provider_kind(),
84    });
85    debug_kdb!(
86        "[kdb] metadata cache miss datasource_id={} generation={}",
87        cache_key.datasource_id.0,
88        cache_key.generation.0
89    );
90
91    let Some(names) = load()? else {
92        return Ok(Vec::new());
93    };
94    if let Ok(mut m) = COLNAME_CACHE.write() {
95        m.put(cache_key, Arc::new(names.clone()));
96    }
97    Ok(names)
98}
99
100/// 将一行数据映射为 RowData
101fn map_row(row: &rusqlite::Row<'_>, col_names: &[String]) -> KnowledgeResult<RowData> {
102    let mut result = Vec::with_capacity(col_names.len());
103    for (i, col_name) in col_names.iter().enumerate() {
104        let value = row.get_ref(i).owe_rule()?;
105        let field = match value {
106            rusqlite::types::ValueRef::Null => {
107                DataField::new(model::DataType::default(), col_name, model::Value::Null)
108            }
109            rusqlite::types::ValueRef::Integer(v) => DataField::from_digit(col_name, v),
110            rusqlite::types::ValueRef::Real(v) => DataField::from_float(col_name, v),
111            rusqlite::types::ValueRef::Text(v) => {
112                DataField::from_chars(col_name, String::from_utf8(v.to_vec()).owe_rule()?)
113            }
114            rusqlite::types::ValueRef::Blob(v) => {
115                DataField::from_chars(col_name, String::from_utf8_lossy(v).to_string())
116            }
117        };
118        result.push(field);
119    }
120    Ok(result)
121}
122
123/// 从 statement 获取列名(普通版,带 debug 日志)
124fn extract_col_names(stmt: &rusqlite::Statement<'_>) -> Vec<String> {
125    let col_cnt = stmt.column_count();
126    debug_kdb!("[memdb] col_cnt={}", col_cnt);
127    let mut col_names = Vec::with_capacity(col_cnt);
128    for i in 0..col_cnt {
129        let name = stmt.column_name(i).unwrap_or("").to_string();
130        debug_kdb!("[memdb] col[{}] name='{}'", i, name);
131        col_names.push(name);
132    }
133    col_names
134}
135
136/// 从 statement 获取列名(cached 版,使用全局缓存)
137fn extract_col_names_cached(
138    stmt: &rusqlite::Statement<'_>,
139    sql: &str,
140) -> KnowledgeResult<Vec<String>> {
141    metadata_cache_get_or_try_init(sql, || {
142        let col_cnt = stmt.column_count();
143        let mut names = Vec::with_capacity(col_cnt);
144        for i in 0..col_cnt {
145            names.push(stmt.column_name(i).owe_rule()?.to_string());
146        }
147        Ok(Some(names))
148    })
149}
150
151pub fn query<P: Params>(
152    conn: &rusqlite::Connection,
153    sql: &str,
154    params: P,
155) -> KnowledgeResult<Vec<RowData>> {
156    let mut stmt = conn.prepare_cached(sql).owe_rule()?;
157    let col_names = extract_col_names(&stmt);
158    let mut rows = stmt.query(params).owe_rule()?;
159    let mut all_result = Vec::new();
160    while let Some(row) = rows.next().owe_rule()? {
161        all_result.push(map_row(row, &col_names)?);
162    }
163    Ok(all_result)
164}
165
166/// Query first row and map columns into RowData with column names preserved.
167pub fn query_first_row<P: Params>(
168    conn: &rusqlite::Connection,
169    sql: &str,
170    params: P,
171) -> KnowledgeResult<RowData> {
172    let mut stmt = conn.prepare_cached(sql).owe_rule()?;
173    let col_names = extract_col_names(&stmt);
174    let mut rows = stmt.query(params).owe_rule()?;
175    if let Some(row) = rows.next().owe_rule()? {
176        map_row(row, &col_names)
177    } else {
178        debug_kdb!("[memdb] no row for sql");
179        Ok(Vec::new())
180    }
181}
182
183pub fn query_cached<P: Params>(
184    conn: &rusqlite::Connection,
185    sql: &str,
186    params: P,
187) -> KnowledgeResult<Vec<RowData>> {
188    let mut stmt = conn.prepare_cached(sql).owe_rule()?;
189    // Column names cache (per SQL)
190    let col_names = extract_col_names_cached(&stmt, sql)?;
191    let mut rows = stmt.query(params).owe_rule()?;
192    let mut all_result = Vec::new();
193    while let Some(row) = rows.next().owe_rule()? {
194        all_result.push(map_row(row, &col_names)?);
195    }
196    Ok(all_result)
197}
198
199/// Same as `query_first_row` but with a shared column-names cache to reduce metadata lookups.
200pub fn query_first_row_cached<P: Params>(
201    conn: &rusqlite::Connection,
202    sql: &str,
203    params: P,
204) -> KnowledgeResult<RowData> {
205    let mut stmt = conn.prepare_cached(sql).owe_rule()?;
206    let col_names = extract_col_names_cached(&stmt, sql)?;
207    let mut rows = stmt.query(params).owe_rule()?;
208    if let Some(row) = rows.next().owe_rule()? {
209        map_row(row, &col_names)
210    } else {
211        Ok(Vec::new())
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Opens an in-memory database with a `test` table that has INTEGER,
    /// TEXT, REAL and BLOB columns plus one column with no declared type.
    fn setup_test_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute(
            "CREATE TABLE test (id INTEGER, name TEXT, score REAL, data BLOB, empty)",
            [],
        )
        .unwrap();
        conn
    }

    #[test]
    fn test_query_returns_all_rows() {
        let conn = setup_test_db();
        let rows = query(&conn, "SELECT * FROM test", []).unwrap();
        assert!(rows.is_empty());
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'alice')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (2, 'bob')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (3, 'charlie')", [])
            .unwrap();

        let rows = query(&conn, "SELECT id, name FROM test ORDER BY id", []).unwrap();
        assert_eq!(rows.len(), 3);
    }

    #[test]
    fn test_query_first_row_returns_single_row() {
        let conn = setup_test_db();
        let row = query_first_row(&conn, "SELECT * FROM test", []).unwrap();
        assert!(row.is_empty());
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'first')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (2, 'second')", [])
            .unwrap();

        let row = query_first_row(&conn, "SELECT id, name FROM test ORDER BY id", []).unwrap();
        assert_eq!(row.len(), 2);
        assert_eq!(row[0].to_string(), "digit(1)");
        assert_eq!(row[1].to_string(), "chars(first)");
    }

    #[test]
    fn test_map_row_handles_all_types() {
        let conn = setup_test_db();
        conn.execute(
            "INSERT INTO test (id, name, score, data, empty) VALUES (42, 'hello', 3.14, X'414243', NULL)",
            [],
        )
        .unwrap();

        let row =
            query_first_row(&conn, "SELECT id, name, score, data, empty FROM test", []).unwrap();
        assert_eq!(row.len(), 5);
        // Every field keeps the name of the column it was selected from.
        for (field, expected) in row.iter().zip(["id", "name", "score", "data", "empty"]) {
            assert_eq!(field.get_name(), expected);
        }
        assert_eq!(row[0].to_string(), "digit(42)");
        assert_eq!(row[1].to_string(), "chars(hello)");
        // BLOBs are decoded (lossily) as UTF-8 text: X'414243' is "ABC".
        assert_eq!(row[3].to_string(), "chars(ABC)");
    }

    #[test]
    fn test_extract_col_names_preserves_aliases() {
        let conn = setup_test_db();
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'x')", [])
            .unwrap();

        let row = query_first_row(
            &conn,
            "SELECT id AS user_id, name AS user_name FROM test",
            [],
        )
        .unwrap();
        assert_eq!(row[0].get_name(), "user_id");
        assert_eq!(row[1].get_name(), "user_name");
    }

    #[test]
    fn test_query_cached_uses_cache() {
        let _guard = crate::runtime::runtime_test_guard()
            .lock()
            .expect("runtime test guard");
        let conn = setup_test_db();
        conn.execute("INSERT INTO test (id) VALUES (1)", [])
            .unwrap();

        let sql = "SELECT id FROM test WHERE id = 1";
        // The first call populates the cache.
        let first = query_cached(&conn, sql, []).unwrap();
        // The second call should be served from the cache.
        let second = query_cached(&conn, sql, []).unwrap();
        assert_eq!(first.len(), 1);
        assert_eq!(second.len(), 1);
        // Column names coming from the cache must still label the fields.
        assert_eq!(second[0][0].get_name(), "id");
        assert_eq!(second[0][0].to_string(), "digit(1)");

        // Verify the cache now holds an entry for this SQL in the current scope.
        let cache = COLNAME_CACHE.read().unwrap();
        assert!(cache.contains(&metadata_cache_key_for_current_scope(sql)));
    }

    #[test]
    fn test_query_with_params() {
        let conn = setup_test_db();
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'alice')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (2, 'bob')", [])
            .unwrap();

        let rows = query(&conn, "SELECT name FROM test WHERE id = ?1", [2]).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0][0].to_string(), "chars(bob)");
    }
}