// wp_knowledge/mem/query_util.rs

1use orion_error::ErrorOwe;
2use rusqlite::Params;
3use std::collections::hash_map::DefaultHasher;
4use std::future::Future;
5use std::hash::{Hash, Hasher};
6use std::num::NonZeroUsize;
7use std::sync::{Arc, RwLock};
8use wp_error::KnowledgeResult;
9use wp_log::debug_kdb;
10use wp_model_core::model::{self, DataField};
11
12use lazy_static::lazy_static;
13use lru::LruCache;
14
15use crate::loader::ProviderKind;
16use crate::mem::RowData;
17use crate::runtime::{DatasourceId, Generation, MetadataCacheScope, runtime};
18use crate::telemetry::{
19    CacheLayer, CacheOutcome, CacheTelemetryEvent, telemetry, telemetry_enabled,
20};
21
22#[derive(Debug, Clone, PartialEq, Eq, Hash)]
23pub struct MetadataCacheKey {
24    pub datasource_id: DatasourceId,
25    pub generation: Generation,
26    pub query_hash: u64,
27}
28
29lazy_static! {
30    /// Global column metadata cache keyed by datasource/generation/query hash.
31    pub static ref COLNAME_CACHE: RwLock<LruCache<MetadataCacheKey, Arc<Vec<String>>>> =
32        RwLock::new(LruCache::new(
33            NonZeroUsize::new(512).expect("non-zero metadata cache size")
34        ));
35}
36
37pub fn column_metadata_cache_snapshot() -> (usize, usize) {
38    COLNAME_CACHE
39        .read()
40        .map(|cache| (cache.len(), cache.cap().get()))
41        .unwrap_or((0, 0))
42}
43
/// Hashes `value` with a freshly constructed [`DefaultHasher`] and returns the 64-bit digest.
///
/// Equal inputs always produce equal digests within one build of the program.
fn stable_hash(value: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(value, &mut state);
    Hasher::finish(&state)
}
49
/// Test-only helper: builds the cache key for `sql` under the runtime's current metadata scope.
#[cfg(test)]
pub(crate) fn metadata_cache_key_for_current_scope(sql: &str) -> MetadataCacheKey {
    let current_scope = runtime().current_metadata_scope();
    metadata_cache_key_for_scope(&current_scope, sql)
}
55
56pub(crate) fn metadata_cache_key_for_scope(
57    scope: &MetadataCacheScope,
58    sql: &str,
59) -> MetadataCacheKey {
60    MetadataCacheKey {
61        datasource_id: scope.datasource_id.clone(),
62        generation: scope.generation,
63        query_hash: stable_hash(sql),
64    }
65}
66
67pub(crate) fn metadata_cache_get_or_try_init<F>(sql: &str, load: F) -> KnowledgeResult<Vec<String>>
68where
69    F: FnOnce() -> KnowledgeResult<Option<Vec<String>>>,
70{
71    let scope = runtime().current_metadata_scope();
72    let provider_kind = runtime().current_provider_kind();
73    metadata_cache_get_or_try_init_for_scope(&scope, provider_kind, sql, load)
74}
75
76pub(crate) fn metadata_cache_get_or_try_init_for_scope<F>(
77    scope: &MetadataCacheScope,
78    provider_kind: Option<ProviderKind>,
79    sql: &str,
80    load: F,
81) -> KnowledgeResult<Vec<String>>
82where
83    F: FnOnce() -> KnowledgeResult<Option<Vec<String>>>,
84{
85    let cache_key = metadata_cache_key_for_scope(scope, sql);
86    if let Some(names) = COLNAME_CACHE
87        .read()
88        .ok()
89        .and_then(|m| m.peek(&cache_key).cloned())
90    {
91        runtime().record_metadata_cache_hit();
92        if telemetry_enabled() {
93            telemetry().on_cache(&CacheTelemetryEvent {
94                layer: CacheLayer::Metadata,
95                outcome: CacheOutcome::Hit,
96                provider_kind: provider_kind.clone(),
97            });
98        }
99        debug_kdb!(
100            "[kdb] metadata cache hit datasource_id={} generation={}",
101            cache_key.datasource_id.0,
102            cache_key.generation.0
103        );
104        return Ok((*names).clone());
105    }
106
107    runtime().record_metadata_cache_miss();
108    if telemetry_enabled() {
109        telemetry().on_cache(&CacheTelemetryEvent {
110            layer: CacheLayer::Metadata,
111            outcome: CacheOutcome::Miss,
112            provider_kind,
113        });
114    }
115    debug_kdb!(
116        "[kdb] metadata cache miss datasource_id={} generation={}",
117        cache_key.datasource_id.0,
118        cache_key.generation.0
119    );
120
121    let Some(names) = load()? else {
122        return Ok(Vec::new());
123    };
124    if let Ok(mut m) = COLNAME_CACHE.write() {
125        m.put(cache_key, Arc::new(names.clone()));
126    }
127    Ok(names)
128}
129
130pub(crate) async fn metadata_cache_get_or_try_init_async_for_scope<F, Fut>(
131    scope: &MetadataCacheScope,
132    provider_kind: Option<ProviderKind>,
133    sql: &str,
134    load: F,
135) -> KnowledgeResult<Vec<String>>
136where
137    F: FnOnce() -> Fut,
138    Fut: Future<Output = KnowledgeResult<Option<Vec<String>>>>,
139{
140    let cache_key = metadata_cache_key_for_scope(scope, sql);
141    if let Some(names) = COLNAME_CACHE
142        .read()
143        .ok()
144        .and_then(|m| m.peek(&cache_key).cloned())
145    {
146        runtime().record_metadata_cache_hit();
147        if telemetry_enabled() {
148            telemetry().on_cache(&CacheTelemetryEvent {
149                layer: CacheLayer::Metadata,
150                outcome: CacheOutcome::Hit,
151                provider_kind: provider_kind.clone(),
152            });
153        }
154        return Ok((*names).clone());
155    }
156
157    runtime().record_metadata_cache_miss();
158    if telemetry_enabled() {
159        telemetry().on_cache(&CacheTelemetryEvent {
160            layer: CacheLayer::Metadata,
161            outcome: CacheOutcome::Miss,
162            provider_kind,
163        });
164    }
165
166    let Some(names) = load().await? else {
167        return Ok(Vec::new());
168    };
169    if let Ok(mut m) = COLNAME_CACHE.write() {
170        m.put(cache_key, Arc::new(names.clone()));
171    }
172    Ok(names)
173}
174
175/// 将一行数据映射为 RowData
176fn map_row(row: &rusqlite::Row<'_>, col_names: &[String]) -> KnowledgeResult<RowData> {
177    let mut result = Vec::with_capacity(col_names.len());
178    for (i, col_name) in col_names.iter().enumerate() {
179        let value = row.get_ref(i).owe_rule()?;
180        let field = match value {
181            rusqlite::types::ValueRef::Null => {
182                DataField::new(model::DataType::default(), col_name, model::Value::Null)
183            }
184            rusqlite::types::ValueRef::Integer(v) => DataField::from_digit(col_name, v),
185            rusqlite::types::ValueRef::Real(v) => DataField::from_float(col_name, v),
186            rusqlite::types::ValueRef::Text(v) => {
187                DataField::from_chars(col_name, String::from_utf8(v.to_vec()).owe_rule()?)
188            }
189            rusqlite::types::ValueRef::Blob(v) => {
190                DataField::from_chars(col_name, String::from_utf8_lossy(v).to_string())
191            }
192        };
193        result.push(field);
194    }
195    Ok(result)
196}
197
198/// 从 statement 获取列名(普通版,带 debug 日志)
199fn extract_col_names(stmt: &rusqlite::Statement<'_>) -> Vec<String> {
200    let col_cnt = stmt.column_count();
201    debug_kdb!("[memdb] col_cnt={}", col_cnt);
202    let mut col_names = Vec::with_capacity(col_cnt);
203    for i in 0..col_cnt {
204        let name = stmt.column_name(i).unwrap_or("").to_string();
205        debug_kdb!("[memdb] col[{}] name='{}'", i, name);
206        col_names.push(name);
207    }
208    col_names
209}
210
211/// 从 statement 获取列名(cached 版,使用全局缓存)
212fn extract_col_names_cached(
213    stmt: &rusqlite::Statement<'_>,
214    sql: &str,
215) -> KnowledgeResult<Vec<String>> {
216    metadata_cache_get_or_try_init(sql, || {
217        let col_cnt = stmt.column_count();
218        let mut names = Vec::with_capacity(col_cnt);
219        for i in 0..col_cnt {
220            names.push(stmt.column_name(i).owe_rule()?.to_string());
221        }
222        Ok(Some(names))
223    })
224}
225
226fn extract_col_names_cached_with_scope(
227    stmt: &rusqlite::Statement<'_>,
228    scope: &MetadataCacheScope,
229    provider_kind: Option<ProviderKind>,
230    sql: &str,
231) -> KnowledgeResult<Vec<String>> {
232    metadata_cache_get_or_try_init_for_scope(scope, provider_kind, sql, || {
233        let col_cnt = stmt.column_count();
234        let mut names = Vec::with_capacity(col_cnt);
235        for i in 0..col_cnt {
236            names.push(stmt.column_name(i).owe_rule()?.to_string());
237        }
238        Ok(Some(names))
239    })
240}
241
242pub fn query<P: Params>(
243    conn: &rusqlite::Connection,
244    sql: &str,
245    params: P,
246) -> KnowledgeResult<Vec<RowData>> {
247    let mut stmt = conn.prepare_cached(sql).owe_rule()?;
248    let col_names = extract_col_names(&stmt);
249    let mut rows = stmt.query(params).owe_rule()?;
250    let mut all_result = Vec::new();
251    while let Some(row) = rows.next().owe_rule()? {
252        all_result.push(map_row(row, &col_names)?);
253    }
254    Ok(all_result)
255}
256
257/// Query first row and map columns into RowData with column names preserved.
258pub fn query_first_row<P: Params>(
259    conn: &rusqlite::Connection,
260    sql: &str,
261    params: P,
262) -> KnowledgeResult<RowData> {
263    let mut stmt = conn.prepare_cached(sql).owe_rule()?;
264    let col_names = extract_col_names(&stmt);
265    let mut rows = stmt.query(params).owe_rule()?;
266    if let Some(row) = rows.next().owe_rule()? {
267        map_row(row, &col_names)
268    } else {
269        debug_kdb!("[memdb] no row for sql");
270        Ok(Vec::new())
271    }
272}
273
274pub fn query_cached<P: Params>(
275    conn: &rusqlite::Connection,
276    sql: &str,
277    params: P,
278) -> KnowledgeResult<Vec<RowData>> {
279    let mut stmt = conn.prepare_cached(sql).owe_rule()?;
280    // Column names cache (per SQL)
281    let col_names = extract_col_names_cached(&stmt, sql)?;
282    let mut rows = stmt.query(params).owe_rule()?;
283    let mut all_result = Vec::new();
284    while let Some(row) = rows.next().owe_rule()? {
285        all_result.push(map_row(row, &col_names)?);
286    }
287    Ok(all_result)
288}
289
290pub fn query_cached_with_scope<P: Params>(
291    conn: &rusqlite::Connection,
292    scope: &MetadataCacheScope,
293    provider_kind: Option<ProviderKind>,
294    sql: &str,
295    params: P,
296) -> KnowledgeResult<Vec<RowData>> {
297    let mut stmt = conn.prepare_cached(sql).owe_rule()?;
298    let col_names = extract_col_names_cached_with_scope(&stmt, scope, provider_kind, sql)?;
299    let mut rows = stmt.query(params).owe_rule()?;
300    let mut all_result = Vec::new();
301    while let Some(row) = rows.next().owe_rule()? {
302        all_result.push(map_row(row, &col_names)?);
303    }
304    Ok(all_result)
305}
306
307/// Same as `query_first_row` but with a shared column-names cache to reduce metadata lookups.
308pub fn query_first_row_cached<P: Params>(
309    conn: &rusqlite::Connection,
310    sql: &str,
311    params: P,
312) -> KnowledgeResult<RowData> {
313    let mut stmt = conn.prepare_cached(sql).owe_rule()?;
314    let col_names = extract_col_names_cached(&stmt, sql)?;
315    let mut rows = stmt.query(params).owe_rule()?;
316    if let Some(row) = rows.next().owe_rule()? {
317        map_row(row, &col_names)
318    } else {
319        Ok(Vec::new())
320    }
321}
322
323pub fn query_first_row_cached_with_scope<P: Params>(
324    conn: &rusqlite::Connection,
325    scope: &MetadataCacheScope,
326    provider_kind: Option<ProviderKind>,
327    sql: &str,
328    params: P,
329) -> KnowledgeResult<RowData> {
330    let mut stmt = conn.prepare_cached(sql).owe_rule()?;
331    let col_names = extract_col_names_cached_with_scope(&stmt, scope, provider_kind, sql)?;
332    let mut rows = stmt.query(params).owe_rule()?;
333    if let Some(row) = rows.next().owe_rule()? {
334        map_row(row, &col_names)
335    } else {
336        Ok(Vec::new())
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Opens an in-memory database containing a single empty `test` table with one
    /// column per SQLite storage class plus an untyped column.
    fn setup_test_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute(
            "CREATE TABLE test (id INTEGER, name TEXT, score REAL, data BLOB, empty)",
            [],
        )
        .unwrap();
        conn
    }

    #[test]
    fn test_query_returns_all_rows() {
        let conn = setup_test_db();
        let rows = query(&conn, "SELECT * FROM test", []).unwrap();
        assert!(rows.is_empty());
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'alice')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (2, 'bob')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (3, 'charlie')", [])
            .unwrap();

        let rows = query(&conn, "SELECT id, name FROM test ORDER BY id", []).unwrap();
        assert_eq!(rows.len(), 3);
        // Rows must arrive in query order with both columns mapped.
        assert_eq!(rows[0][0].to_string(), "digit(1)");
        assert_eq!(rows[0][1].to_string(), "chars(alice)");
        assert_eq!(rows[2][0].to_string(), "digit(3)");
        assert_eq!(rows[2][1].to_string(), "chars(charlie)");
    }

    #[test]
    fn test_query_first_row_returns_single_row() {
        let conn = setup_test_db();
        let row = query_first_row(&conn, "SELECT * FROM test", []).unwrap();
        assert!(row.is_empty());
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'first')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (2, 'second')", [])
            .unwrap();

        let row = query_first_row(&conn, "SELECT id, name FROM test ORDER BY id", []).unwrap();
        assert_eq!(row.len(), 2);
        assert_eq!(row[0].to_string(), "digit(1)");
        assert_eq!(row[1].to_string(), "chars(first)");
    }

    #[test]
    fn test_map_row_handles_all_types() {
        let conn = setup_test_db();
        conn.execute(
            "INSERT INTO test (id, name, score, data, empty) VALUES (42, 'hello', 3.14, X'414243', NULL)",
            [],
        )
        .unwrap();

        let row =
            query_first_row(&conn, "SELECT id, name, score, data, empty FROM test", []).unwrap();
        assert_eq!(row.len(), 5);

        // Every field keeps the name of the column it was read from.
        for (idx, expected) in ["id", "name", "score", "data", "empty"].iter().enumerate() {
            assert_eq!(row[idx].get_name(), *expected);
        }

        // Integer and text columns map to digit / chars fields.
        assert_eq!(row[0].to_string(), "digit(42)");
        assert_eq!(row[1].to_string(), "chars(hello)");
        // The blob X'414243' is the bytes "ABC" and is decoded into a chars field.
        assert_eq!(row[3].to_string(), "chars(ABC)");
    }

    #[test]
    fn test_extract_col_names_preserves_aliases() {
        let conn = setup_test_db();
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'x')", [])
            .unwrap();

        let row = query_first_row(
            &conn,
            "SELECT id AS user_id, name AS user_name FROM test",
            [],
        )
        .unwrap();
        assert_eq!(row[0].get_name(), "user_id");
        assert_eq!(row[1].get_name(), "user_name");
    }

    #[test]
    fn test_query_cached_uses_cache() {
        let _guard = crate::runtime::runtime_test_guard()
            .lock()
            .expect("runtime test guard");
        let conn = setup_test_db();
        conn.execute("INSERT INTO test (id) VALUES (1)", [])
            .unwrap();

        let sql = "SELECT id FROM test WHERE id = 1";
        // First query populates the cache.
        let _ = query_cached(&conn, sql, []).unwrap();
        // Second query should be served from the cache.
        let rows = query_cached(&conn, sql, []).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0][0].get_name(), "id");
        assert_eq!(rows[0][0].to_string(), "digit(1)");

        // Verify the cache holds the expected column list for this SQL.
        let key = metadata_cache_key_for_current_scope(sql);
        let cache = COLNAME_CACHE.read().unwrap();
        assert!(cache.contains(&key));
        assert_eq!(
            cache.peek(&key).map(|names| names.to_vec()),
            Some(vec!["id".to_string()])
        );
    }

    #[test]
    fn test_query_cached_with_scope_populates_scoped_key() {
        let _guard = crate::runtime::runtime_test_guard()
            .lock()
            .expect("runtime test guard");
        let conn = setup_test_db();
        conn.execute("INSERT INTO test (id, name) VALUES (7, 'scoped')", [])
            .unwrap();

        let scope = MetadataCacheScope {
            datasource_id: DatasourceId("sqlite:query_util_scope_test".to_string()),
            generation: Generation(1),
        };
        let sql = "SELECT id AS scoped_id, name AS scoped_name FROM test";
        let key = metadata_cache_key_for_scope(&scope, sql);

        let rows = query_cached_with_scope(&conn, &scope, None, sql, []).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0][0].get_name(), "scoped_id");
        assert_eq!(rows[0][1].get_name(), "scoped_name");
        assert!(COLNAME_CACHE.read().unwrap().contains(&key));

        // The first-row variant shares the same scoped cache entry.
        let first = query_first_row_cached_with_scope(&conn, &scope, None, sql, []).unwrap();
        assert_eq!(first.len(), 2);
        assert_eq!(first[0].to_string(), "digit(7)");
        assert_eq!(first[1].to_string(), "chars(scoped)");
    }

    #[test]
    fn test_query_with_params() {
        let conn = setup_test_db();
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'alice')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (2, 'bob')", [])
            .unwrap();

        let rows = query(&conn, "SELECT name FROM test WHERE id = ?1", [2]).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0][0].to_string(), "chars(bob)");
    }

    #[test]
    fn test_metadata_cache_key_for_scope_is_explicit() {
        let sql = "SELECT id FROM test";
        let scope_a = MetadataCacheScope {
            datasource_id: DatasourceId("postgres:aaaa".to_string()),
            generation: Generation(1),
        };
        let scope_b = MetadataCacheScope {
            datasource_id: DatasourceId("postgres:bbbb".to_string()),
            generation: Generation(2),
        };
        let key_a = metadata_cache_key_for_scope(&scope_a, sql);
        let key_b = metadata_cache_key_for_scope(&scope_b, sql);
        assert_ne!(key_a, key_b);

        // Same scope and SQL must always yield the same key.
        assert_eq!(key_a, metadata_cache_key_for_scope(&scope_a, sql));

        // Each component contributes on its own: generation only ...
        let scope_a_next_generation = MetadataCacheScope {
            datasource_id: DatasourceId("postgres:aaaa".to_string()),
            generation: Generation(2),
        };
        assert_ne!(
            key_a,
            metadata_cache_key_for_scope(&scope_a_next_generation, sql)
        );
        // ... datasource only ...
        assert_ne!(
            key_b,
            metadata_cache_key_for_scope(&scope_a_next_generation, sql)
        );
        // ... and SQL text only.
        assert_ne!(
            key_a,
            metadata_cache_key_for_scope(&scope_a, "SELECT name FROM test")
        );
    }
}