Skip to main content

wp_knowledge/mem/
query_util.rs

1use crate::error::{KnowReason, KnowledgeResult};
2use orion_error::conversion::SourceRawErr;
3use rusqlite::Params;
4use std::collections::hash_map::DefaultHasher;
5use std::future::Future;
6use std::hash::{Hash, Hasher};
7use std::num::NonZeroUsize;
8use std::sync::{Arc, RwLock};
9use wp_log::debug_kdb;
10use wp_model_core::model::{self, DataField};
11
12use lazy_static::lazy_static;
13use lru::LruCache;
14
15use crate::loader::ProviderKind;
16use crate::mem::RowData;
17use crate::runtime::{DatasourceId, Generation, MetadataCacheScope, runtime};
18use crate::telemetry::{
19    CacheLayer, CacheOutcome, CacheTelemetryEvent, telemetry, telemetry_enabled,
20};
21
22#[derive(Debug, Clone, PartialEq, Eq, Hash)]
23pub struct MetadataCacheKey {
24    pub datasource_id: DatasourceId,
25    pub generation: Generation,
26    pub query_hash: u64,
27}
28
29lazy_static! {
30    /// Global column metadata cache keyed by datasource/generation/query hash.
31    pub static ref COLNAME_CACHE: RwLock<LruCache<MetadataCacheKey, Arc<Vec<String>>>> =
32        RwLock::new(LruCache::new(
33            NonZeroUsize::new(512).expect("non-zero metadata cache size")
34        ));
35}
36
37pub fn column_metadata_cache_snapshot() -> (usize, usize) {
38    COLNAME_CACHE
39        .read()
40        .map(|cache| (cache.len(), cache.cap().get()))
41        .unwrap_or((0, 0))
42}
43
/// Hashes `value` with the standard library's `DefaultHasher`.
///
/// Every `DefaultHasher::new()` starts from the same state, so the result is deterministic
/// within one build of the program. The algorithm is not guaranteed across Rust releases, so
/// the value is only suitable for in-process keys and must not be persisted.
fn stable_hash(value: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(value, &mut state);
    state.finish()
}
49
/// Test helper: builds the cache key `sql` would use under the runtime's current scope.
#[cfg(test)]
pub(crate) fn metadata_cache_key_for_current_scope(sql: &str) -> MetadataCacheKey {
    metadata_cache_key_for_scope(&runtime().current_metadata_scope(), sql)
}
55
56pub(crate) fn metadata_cache_key_for_scope(
57    scope: &MetadataCacheScope,
58    sql: &str,
59) -> MetadataCacheKey {
60    MetadataCacheKey {
61        datasource_id: scope.datasource_id.clone(),
62        generation: scope.generation,
63        query_hash: stable_hash(sql),
64    }
65}
66
67pub(crate) fn metadata_cache_get_or_try_init<F>(sql: &str, load: F) -> KnowledgeResult<Vec<String>>
68where
69    F: FnOnce() -> KnowledgeResult<Option<Vec<String>>>,
70{
71    let scope = runtime().current_metadata_scope();
72    let provider_kind = runtime().current_provider_kind();
73    metadata_cache_get_or_try_init_for_scope(&scope, provider_kind, sql, load)
74}
75
76pub(crate) fn metadata_cache_get_or_try_init_for_scope<F>(
77    scope: &MetadataCacheScope,
78    provider_kind: Option<ProviderKind>,
79    sql: &str,
80    load: F,
81) -> KnowledgeResult<Vec<String>>
82where
83    F: FnOnce() -> KnowledgeResult<Option<Vec<String>>>,
84{
85    let cache_key = metadata_cache_key_for_scope(scope, sql);
86    if let Some(names) = COLNAME_CACHE
87        .read()
88        .ok()
89        .and_then(|m| m.peek(&cache_key).cloned())
90    {
91        runtime().record_metadata_cache_hit();
92        if telemetry_enabled() {
93            telemetry().on_cache(&CacheTelemetryEvent {
94                layer: CacheLayer::Metadata,
95                outcome: CacheOutcome::Hit,
96                provider_kind: provider_kind.clone(),
97            });
98        }
99        debug_kdb!(
100            "[kdb] metadata cache hit datasource_id={} generation={}",
101            cache_key.datasource_id.0,
102            cache_key.generation.0
103        );
104        return Ok((*names).clone());
105    }
106
107    runtime().record_metadata_cache_miss();
108    if telemetry_enabled() {
109        telemetry().on_cache(&CacheTelemetryEvent {
110            layer: CacheLayer::Metadata,
111            outcome: CacheOutcome::Miss,
112            provider_kind,
113        });
114    }
115    debug_kdb!(
116        "[kdb] metadata cache miss datasource_id={} generation={}",
117        cache_key.datasource_id.0,
118        cache_key.generation.0
119    );
120
121    let Some(names) = load()? else {
122        return Ok(Vec::new());
123    };
124    if let Ok(mut m) = COLNAME_CACHE.write() {
125        m.put(cache_key, Arc::new(names.clone()));
126    }
127    Ok(names)
128}
129
130pub(crate) async fn metadata_cache_get_or_try_init_async_for_scope<F, Fut>(
131    scope: &MetadataCacheScope,
132    provider_kind: Option<ProviderKind>,
133    sql: &str,
134    load: F,
135) -> KnowledgeResult<Vec<String>>
136where
137    F: FnOnce() -> Fut,
138    Fut: Future<Output = KnowledgeResult<Option<Vec<String>>>>,
139{
140    metadata_cache_get_or_try_init_async_for_scope_typed(scope, provider_kind, sql, load).await
141}
142
143pub(crate) async fn metadata_cache_get_or_try_init_async_for_scope_typed<F, Fut, E>(
144    scope: &MetadataCacheScope,
145    provider_kind: Option<ProviderKind>,
146    sql: &str,
147    load: F,
148) -> Result<Vec<String>, E>
149where
150    F: FnOnce() -> Fut,
151    Fut: Future<Output = Result<Option<Vec<String>>, E>>,
152{
153    let cache_key = metadata_cache_key_for_scope(scope, sql);
154    if let Some(names) = COLNAME_CACHE
155        .read()
156        .ok()
157        .and_then(|m| m.peek(&cache_key).cloned())
158    {
159        runtime().record_metadata_cache_hit();
160        if telemetry_enabled() {
161            telemetry().on_cache(&CacheTelemetryEvent {
162                layer: CacheLayer::Metadata,
163                outcome: CacheOutcome::Hit,
164                provider_kind: provider_kind.clone(),
165            });
166        }
167        return Ok((*names).clone());
168    }
169
170    runtime().record_metadata_cache_miss();
171    if telemetry_enabled() {
172        telemetry().on_cache(&CacheTelemetryEvent {
173            layer: CacheLayer::Metadata,
174            outcome: CacheOutcome::Miss,
175            provider_kind,
176        });
177    }
178
179    let Some(names) = load().await? else {
180        return Ok(Vec::new());
181    };
182    if let Ok(mut m) = COLNAME_CACHE.write() {
183        m.put(cache_key, Arc::new(names.clone()));
184    }
185    Ok(names)
186}
187
188/// 将一行数据映射为 RowData
189fn map_row(row: &rusqlite::Row<'_>, col_names: &[String]) -> KnowledgeResult<RowData> {
190    let mut result = Vec::with_capacity(col_names.len());
191    for (i, col_name) in col_names.iter().enumerate() {
192        let value = row
193            .get_ref(i)
194            .source_raw_err(KnowReason::from_rule(), "source error")?;
195        let field = match value {
196            rusqlite::types::ValueRef::Null => {
197                DataField::new(model::DataType::default(), col_name, model::Value::Null)
198            }
199            rusqlite::types::ValueRef::Integer(v) => DataField::from_digit(col_name, v),
200            rusqlite::types::ValueRef::Real(v) => DataField::from_float(col_name, v),
201            rusqlite::types::ValueRef::Text(v) => DataField::from_chars(
202                col_name,
203                String::from_utf8(v.to_vec())
204                    .source_raw_err(KnowReason::from_rule(), "source error")?,
205            ),
206            rusqlite::types::ValueRef::Blob(v) => {
207                DataField::from_chars(col_name, String::from_utf8_lossy(v).to_string())
208            }
209        };
210        result.push(field);
211    }
212    Ok(result)
213}
214
215/// 从 statement 获取列名(普通版,带 debug 日志)
216fn extract_col_names(stmt: &rusqlite::Statement<'_>) -> Vec<String> {
217    let col_cnt = stmt.column_count();
218    debug_kdb!("[memdb] col_cnt={}", col_cnt);
219    let mut col_names = Vec::with_capacity(col_cnt);
220    for i in 0..col_cnt {
221        let name = stmt.column_name(i).unwrap_or("").to_string();
222        debug_kdb!("[memdb] col[{}] name='{}'", i, name);
223        col_names.push(name);
224    }
225    col_names
226}
227
228/// 从 statement 获取列名(cached 版,使用全局缓存)
229fn extract_col_names_cached(
230    stmt: &rusqlite::Statement<'_>,
231    sql: &str,
232) -> KnowledgeResult<Vec<String>> {
233    metadata_cache_get_or_try_init(sql, || {
234        let col_cnt = stmt.column_count();
235        let mut names = Vec::with_capacity(col_cnt);
236        for i in 0..col_cnt {
237            names.push(
238                stmt.column_name(i)
239                    .source_raw_err(KnowReason::from_rule(), "source error")?
240                    .to_string(),
241            );
242        }
243        Ok(Some(names))
244    })
245}
246
247fn extract_col_names_cached_with_scope(
248    stmt: &rusqlite::Statement<'_>,
249    scope: &MetadataCacheScope,
250    provider_kind: Option<ProviderKind>,
251    sql: &str,
252) -> KnowledgeResult<Vec<String>> {
253    metadata_cache_get_or_try_init_for_scope(scope, provider_kind, sql, || {
254        let col_cnt = stmt.column_count();
255        let mut names = Vec::with_capacity(col_cnt);
256        for i in 0..col_cnt {
257            names.push(
258                stmt.column_name(i)
259                    .source_raw_err(KnowReason::from_rule(), "source error")?
260                    .to_string(),
261            );
262        }
263        Ok(Some(names))
264    })
265}
266
267pub fn query<P: Params>(
268    conn: &rusqlite::Connection,
269    sql: &str,
270    params: P,
271) -> KnowledgeResult<Vec<RowData>> {
272    let mut stmt = conn
273        .prepare_cached(sql)
274        .source_raw_err(KnowReason::from_rule(), "source error")?;
275    let col_names = extract_col_names(&stmt);
276    let mut rows = stmt
277        .query(params)
278        .source_raw_err(KnowReason::from_rule(), "source error")?;
279    let mut all_result = Vec::new();
280    while let Some(row) = rows
281        .next()
282        .source_raw_err(KnowReason::from_rule(), "source error")?
283    {
284        all_result.push(map_row(row, &col_names)?);
285    }
286    Ok(all_result)
287}
288
289/// Query first row and map columns into RowData with column names preserved.
290pub fn query_first_row<P: Params>(
291    conn: &rusqlite::Connection,
292    sql: &str,
293    params: P,
294) -> KnowledgeResult<RowData> {
295    let mut stmt = conn
296        .prepare_cached(sql)
297        .source_raw_err(KnowReason::from_rule(), "source error")?;
298    let col_names = extract_col_names(&stmt);
299    let mut rows = stmt
300        .query(params)
301        .source_raw_err(KnowReason::from_rule(), "source error")?;
302    if let Some(row) = rows
303        .next()
304        .source_raw_err(KnowReason::from_rule(), "source error")?
305    {
306        map_row(row, &col_names)
307    } else {
308        debug_kdb!("[memdb] no row for sql");
309        Ok(Vec::new())
310    }
311}
312
313pub fn query_cached<P: Params>(
314    conn: &rusqlite::Connection,
315    sql: &str,
316    params: P,
317) -> KnowledgeResult<Vec<RowData>> {
318    let mut stmt = conn
319        .prepare_cached(sql)
320        .source_raw_err(KnowReason::from_rule(), "source error")?;
321    // Column names cache (per SQL)
322    let col_names = extract_col_names_cached(&stmt, sql)?;
323    let mut rows = stmt
324        .query(params)
325        .source_raw_err(KnowReason::from_rule(), "source error")?;
326    let mut all_result = Vec::new();
327    while let Some(row) = rows
328        .next()
329        .source_raw_err(KnowReason::from_rule(), "source error")?
330    {
331        all_result.push(map_row(row, &col_names)?);
332    }
333    Ok(all_result)
334}
335
336pub fn query_cached_with_scope<P: Params>(
337    conn: &rusqlite::Connection,
338    scope: &MetadataCacheScope,
339    provider_kind: Option<ProviderKind>,
340    sql: &str,
341    params: P,
342) -> KnowledgeResult<Vec<RowData>> {
343    let mut stmt = conn
344        .prepare_cached(sql)
345        .source_raw_err(KnowReason::from_rule(), "source error")?;
346    let col_names = extract_col_names_cached_with_scope(&stmt, scope, provider_kind, sql)?;
347    let mut rows = stmt
348        .query(params)
349        .source_raw_err(KnowReason::from_rule(), "source error")?;
350    let mut all_result = Vec::new();
351    while let Some(row) = rows
352        .next()
353        .source_raw_err(KnowReason::from_rule(), "source error")?
354    {
355        all_result.push(map_row(row, &col_names)?);
356    }
357    Ok(all_result)
358}
359
360/// Same as `query_first_row` but with a shared column-names cache to reduce metadata lookups.
361pub fn query_first_row_cached<P: Params>(
362    conn: &rusqlite::Connection,
363    sql: &str,
364    params: P,
365) -> KnowledgeResult<RowData> {
366    let mut stmt = conn
367        .prepare_cached(sql)
368        .source_raw_err(KnowReason::from_rule(), "source error")?;
369    let col_names = extract_col_names_cached(&stmt, sql)?;
370    let mut rows = stmt
371        .query(params)
372        .source_raw_err(KnowReason::from_rule(), "source error")?;
373    if let Some(row) = rows
374        .next()
375        .source_raw_err(KnowReason::from_rule(), "source error")?
376    {
377        map_row(row, &col_names)
378    } else {
379        Ok(Vec::new())
380    }
381}
382
383pub fn query_first_row_cached_with_scope<P: Params>(
384    conn: &rusqlite::Connection,
385    scope: &MetadataCacheScope,
386    provider_kind: Option<ProviderKind>,
387    sql: &str,
388    params: P,
389) -> KnowledgeResult<RowData> {
390    let mut stmt = conn
391        .prepare_cached(sql)
392        .source_raw_err(KnowReason::from_rule(), "source error")?;
393    let col_names = extract_col_names_cached_with_scope(&stmt, scope, provider_kind, sql)?;
394    let mut rows = stmt
395        .query(params)
396        .source_raw_err(KnowReason::from_rule(), "source error")?;
397    if let Some(row) = rows
398        .next()
399        .source_raw_err(KnowReason::from_rule(), "source error")?
400    {
401        map_row(row, &col_names)
402    } else {
403        Ok(Vec::new())
404    }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Creates an in-memory database with one table covering every SQLite storage class.
    fn setup_test_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute(
            "CREATE TABLE test (id INTEGER, name TEXT, score REAL, data BLOB, empty)",
            [],
        )
        .unwrap();
        conn
    }

    #[test]
    fn test_query_returns_all_rows() {
        let conn = setup_test_db();
        let rows = query(&conn, "SELECT * FROM test", []).unwrap();
        assert!(rows.is_empty());
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'alice')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (2, 'bob')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (3, 'charlie')", [])
            .unwrap();

        let rows = query(&conn, "SELECT id, name FROM test ORDER BY id", []).unwrap();
        assert_eq!(rows.len(), 3);
        // Rows keep the order produced by the SQL.
        assert_eq!(rows[0][0].to_string(), "digit(1)");
        assert_eq!(rows[2][1].to_string(), "chars(charlie)");
    }

    #[test]
    fn test_query_first_row_returns_single_row() {
        let conn = setup_test_db();
        let row = query_first_row(&conn, "SELECT * FROM test", []).unwrap();
        assert!(row.is_empty());
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'first')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (2, 'second')", [])
            .unwrap();

        let row = query_first_row(&conn, "SELECT id, name FROM test ORDER BY id", []).unwrap();
        assert_eq!(row.len(), 2);
        assert_eq!(row[0].to_string(), "digit(1)");
        assert_eq!(row[1].to_string(), "chars(first)");
    }

    #[test]
    fn test_map_row_handles_all_types() {
        let conn = setup_test_db();
        conn.execute(
            "INSERT INTO test (id, name, score, data, empty) VALUES (42, 'hello', 3.14, X'414243', NULL)",
            [],
        )
        .unwrap();

        let row =
            query_first_row(&conn, "SELECT id, name, score, data, empty FROM test", []).unwrap();
        assert_eq!(row.len(), 5);
        assert_eq!(row[0].to_string(), "digit(42)");
        assert_eq!(row[1].to_string(), "chars(hello)");
        // BLOB bytes are decoded (lossily) as UTF-8 text.
        assert_eq!(row[3].to_string(), "chars(ABC)");
        assert_eq!(row[2].get_name(), "score");
        assert_eq!(row[4].get_name(), "empty");
    }

    #[test]
    fn test_extract_col_names_preserves_aliases() {
        let conn = setup_test_db();
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'x')", [])
            .unwrap();

        let row = query_first_row(
            &conn,
            "SELECT id AS user_id, name AS user_name FROM test",
            [],
        )
        .unwrap();
        assert_eq!(row[0].get_name(), "user_id");
        assert_eq!(row[1].get_name(), "user_name");
    }

    #[test]
    fn test_query_cached_uses_cache() {
        let _guard = crate::runtime::runtime_test_guard()
            .lock()
            .expect("runtime test guard");
        let conn = setup_test_db();
        conn.execute("INSERT INTO test (id) VALUES (1)", [])
            .unwrap();

        let sql = "SELECT id FROM test WHERE id = 1";
        // First query populates the cache.
        let _ = query_cached(&conn, sql, []).unwrap();
        // Second query should hit the cache.
        let rows = query_cached(&conn, sql, []).unwrap();
        assert_eq!(rows.len(), 1);

        // The cache now holds an entry for this SQL under the current scope.
        let cache = COLNAME_CACHE.read().unwrap();
        assert!(cache.contains(&metadata_cache_key_for_current_scope(sql)));
    }

    #[test]
    fn test_query_cached_with_scope_uses_explicit_scope() {
        let _guard = crate::runtime::runtime_test_guard()
            .lock()
            .expect("runtime test guard");
        let conn = setup_test_db();
        conn.execute("INSERT INTO test (id, name) VALUES (7, 'scoped')", [])
            .unwrap();

        let scope = MetadataCacheScope {
            datasource_id: DatasourceId("sqlite:query_util_scoped_test".to_string()),
            generation: Generation(1),
        };
        let sql = "SELECT id, name FROM test WHERE id = 7";

        let rows = query_cached_with_scope(&conn, &scope, None, sql, []).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0][1].to_string(), "chars(scoped)");

        let first = query_first_row_cached_with_scope(&conn, &scope, None, sql, []).unwrap();
        assert_eq!(first[0].get_name(), "id");
        assert_eq!(first[1].to_string(), "chars(scoped)");

        let missing = query_first_row_cached_with_scope(
            &conn,
            &scope,
            None,
            "SELECT id FROM test WHERE id = -1",
            [],
        )
        .unwrap();
        assert!(missing.is_empty());

        let cache = COLNAME_CACHE.read().unwrap();
        assert!(cache.contains(&metadata_cache_key_for_scope(&scope, sql)));
    }

    #[test]
    fn test_query_with_params() {
        let conn = setup_test_db();
        conn.execute("INSERT INTO test (id, name) VALUES (1, 'alice')", [])
            .unwrap();
        conn.execute("INSERT INTO test (id, name) VALUES (2, 'bob')", [])
            .unwrap();

        let rows = query(&conn, "SELECT name FROM test WHERE id = ?1", [2]).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0][0].to_string(), "chars(bob)");
    }

    #[test]
    fn test_metadata_cache_key_for_scope_is_explicit() {
        let sql = "SELECT id FROM test";
        let scope_a = MetadataCacheScope {
            datasource_id: DatasourceId("postgres:aaaa".to_string()),
            generation: Generation(1),
        };
        let scope_b = MetadataCacheScope {
            datasource_id: DatasourceId("postgres:bbbb".to_string()),
            generation: Generation(2),
        };
        let key_a = metadata_cache_key_for_scope(&scope_a, sql);
        let key_b = metadata_cache_key_for_scope(&scope_b, sql);
        assert_ne!(key_a, key_b);
        // The same scope and SQL always produce the same key.
        assert_eq!(key_a, metadata_cache_key_for_scope(&scope_a, sql));
    }
}