Skip to main content

wp_knowledge/
runtime.rs

1use std::collections::hash_map::DefaultHasher;
2use std::hash::{Hash, Hasher};
3use std::num::NonZeroUsize;
4use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
5use std::sync::{Arc, OnceLock, RwLock};
6use std::time::{Duration, Instant};
7
8use async_trait::async_trait;
9use lru::LruCache;
10use orion_error::{ToStructError, UvsFrom};
11use tokio::task;
12use wp_error::{KnowledgeReason, KnowledgeResult};
13use wp_log::{debug_kdb, warn_kdb};
14use wp_model_core::model::{DataField, DataType, Value};
15
16use crate::loader::ProviderKind;
17use crate::mem::RowData;
18use crate::telemetry::{
19    CacheLayer, CacheOutcome, CacheTelemetryEvent, QueryTelemetryEvent, ReloadOutcome,
20    ReloadTelemetryEvent, telemetry, telemetry_enabled,
21};
22
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct DatasourceId(pub String);
25
26impl DatasourceId {
27    pub fn from_seed(kind: ProviderKind, seed: &str) -> Self {
28        let mut hasher = DefaultHasher::new();
29        seed.hash(&mut hasher);
30        let kind_str = match kind {
31            ProviderKind::SqliteAuthority => "sqlite",
32            ProviderKind::Postgres => "postgres",
33            ProviderKind::Mysql => "mysql",
34        };
35        Self(format!("{kind_str}:{:016x}", hasher.finish()))
36    }
37}
38
/// Generation number assigned to an installed provider.
///
/// `KnowledgeRuntime::install_provider` hands out generations starting at 1;
/// the value 0 stands for "no provider installed".
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Generation(pub u64);
41
/// Shape of the result a query should produce.
#[derive(Clone, Debug)]
pub enum QueryMode {
    /// All rows returned by the statement.
    Many,
    /// Only the first row returned by the statement.
    FirstRow,
}

/// How a query interacts with result caching.
#[derive(Clone, Copy, Debug)]
pub enum CachePolicy {
    /// Do not consult or populate the runtime-wide result cache.
    Bypass,
    /// Consult and populate the runtime-wide result cache (when it is enabled).
    UseGlobal,
    /// Call-scoped caching; handled outside this file — TODO confirm semantics.
    UseCallScope,
}
54
/// A value that can be bound to a query parameter.
#[derive(Clone, Debug)]
pub enum QueryValue {
    /// Absence of a value (presumably bound as SQL `NULL` — confirm per provider).
    Null,
    Bool(bool),
    Int(i64),
    Float(f64),
    Text(String),
}

/// A named query parameter and the value bound to it.
#[derive(Clone, Debug)]
pub struct QueryParam {
    /// Parameter name; placeholder syntax is provider-specific — TODO confirm.
    pub name: String,
    /// Value bound to the parameter.
    pub value: QueryValue,
}
69
70#[derive(Debug, Clone)]
71pub struct QueryRequest {
72    pub sql: String,
73    pub params: Vec<QueryParam>,
74    pub mode: QueryMode,
75    pub cache_policy: CachePolicy,
76}
77
78impl QueryRequest {
79    pub fn many(
80        sql: impl Into<String>,
81        params: Vec<QueryParam>,
82        cache_policy: CachePolicy,
83    ) -> Self {
84        Self {
85            sql: sql.into(),
86            params,
87            mode: QueryMode::Many,
88            cache_policy,
89        }
90    }
91
92    pub fn first_row(
93        sql: impl Into<String>,
94        params: Vec<QueryParam>,
95        cache_policy: CachePolicy,
96    ) -> Self {
97        Self {
98            sql: sql.into(),
99            params,
100            mode: QueryMode::FirstRow,
101            cache_policy,
102        }
103    }
104}
105
106#[derive(Debug, Clone)]
107pub enum QueryResponse {
108    Rows(Vec<RowData>),
109    Row(RowData),
110}
111
112impl QueryResponse {
113    pub fn into_rows(self) -> Vec<RowData> {
114        match self {
115            QueryResponse::Rows(rows) => rows,
116            QueryResponse::Row(row) => vec![row],
117        }
118    }
119
120    pub fn into_row(self) -> RowData {
121        match self {
122            QueryResponse::Rows(rows) => rows.into_iter().next().unwrap_or_default(),
123            QueryResponse::Row(row) => row,
124        }
125    }
126}
127
128#[async_trait]
129pub trait ProviderExecutor: Send + Sync {
130    fn query(&self, sql: &str) -> KnowledgeResult<Vec<RowData>>;
131    fn query_fields(&self, sql: &str, params: &[DataField]) -> KnowledgeResult<Vec<RowData>>;
132    fn query_row(&self, sql: &str) -> KnowledgeResult<RowData>;
133    fn query_named_fields(&self, sql: &str, params: &[DataField]) -> KnowledgeResult<RowData>;
134
135    async fn query_async(&self, sql: &str) -> KnowledgeResult<Vec<RowData>> {
136        self.query(sql)
137    }
138
139    async fn query_fields_async(
140        &self,
141        sql: &str,
142        params: &[DataField],
143    ) -> KnowledgeResult<Vec<RowData>> {
144        self.query_fields(sql, params)
145    }
146
147    async fn query_row_async(&self, sql: &str) -> KnowledgeResult<RowData> {
148        self.query_row(sql)
149    }
150
151    async fn query_named_fields_async(
152        &self,
153        sql: &str,
154        params: &[DataField],
155    ) -> KnowledgeResult<RowData> {
156        self.query_named_fields(sql, params)
157    }
158}
159
/// Copyable, hashable counterpart of [`QueryMode`], used in cache keys and telemetry.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum QueryModeTag {
    /// All rows.
    Many,
    /// First row only.
    FirstRow,
}
165
166#[derive(Debug, Clone, PartialEq, Eq, Hash)]
167pub struct ResultCacheKey {
168    pub datasource_id: DatasourceId,
169    pub generation: Generation,
170    pub query_hash: u64,
171    pub params_hash: u64,
172    pub mode: QueryModeTag,
173}
174
175pub struct ProviderHandle {
176    pub provider: Arc<dyn ProviderExecutor>,
177    pub datasource_id: DatasourceId,
178    pub generation: Generation,
179    pub kind: ProviderKind,
180}
181
182#[derive(Debug, Clone)]
183pub struct RuntimeSnapshot {
184    pub provider_kind: Option<ProviderKind>,
185    pub datasource_id: Option<DatasourceId>,
186    pub generation: Option<Generation>,
187    pub result_cache_enabled: bool,
188    pub result_cache_len: usize,
189    pub result_cache_capacity: usize,
190    pub result_cache_ttl_ms: u64,
191    pub metadata_cache_len: usize,
192    pub metadata_cache_capacity: usize,
193    pub result_cache_hits: u64,
194    pub result_cache_misses: u64,
195    pub metadata_cache_hits: u64,
196    pub metadata_cache_misses: u64,
197    pub local_cache_hits: u64,
198    pub local_cache_misses: u64,
199    pub reload_successes: u64,
200    pub reload_failures: u64,
201}
202
203#[derive(Debug, Clone)]
204pub struct MetadataCacheScope {
205    pub datasource_id: DatasourceId,
206    pub generation: Generation,
207}
208
/// Settings for the global result cache.
#[derive(Clone, Copy, Debug)]
pub struct ResultCacheConfig {
    /// Whether lookups and stores are performed at all.
    pub enabled: bool,
    /// Maximum number of cached responses.
    pub capacity: usize,
    /// How long a cached response stays valid.
    pub ttl: Duration,
}

impl Default for ResultCacheConfig {
    /// Enabled, 1024 entries, 30-second TTL.
    fn default() -> Self {
        let ttl = Duration::from_secs(30);
        Self {
            enabled: true,
            capacity: 1024,
            ttl,
        }
    }
}
225
226#[derive(Debug, Clone)]
227struct CachedQueryResponse {
228    response: Arc<QueryResponse>,
229    cached_at: Instant,
230}
231
232pub struct KnowledgeRuntime {
233    provider: RwLock<Option<Arc<ProviderHandle>>>,
234    next_generation: AtomicU64,
235    provider_epoch: AtomicU64,
236    current_generation_value: AtomicU64,
237    result_cache_config: RwLock<ResultCacheConfig>,
238    result_cache_enabled: AtomicBool,
239    result_cache_ttl_ms: AtomicU64,
240    result_cache: RwLock<LruCache<ResultCacheKey, CachedQueryResponse>>,
241    result_cache_hits: AtomicU64,
242    result_cache_misses: AtomicU64,
243    metadata_cache_hits: AtomicU64,
244    metadata_cache_misses: AtomicU64,
245    local_cache_hits: AtomicU64,
246    local_cache_misses: AtomicU64,
247    reload_successes: AtomicU64,
248    reload_failures: AtomicU64,
249}
250
251impl KnowledgeRuntime {
252    pub fn new(result_cache_capacity: usize) -> Self {
253        let config = ResultCacheConfig {
254            capacity: result_cache_capacity.max(1),
255            ..ResultCacheConfig::default()
256        };
257        let capacity = NonZeroUsize::new(config.capacity).expect("non-zero capacity");
258        Self {
259            provider: RwLock::new(None),
260            next_generation: AtomicU64::new(0),
261            provider_epoch: AtomicU64::new(0),
262            current_generation_value: AtomicU64::new(0),
263            result_cache_config: RwLock::new(config),
264            result_cache_enabled: AtomicBool::new(config.enabled),
265            result_cache_ttl_ms: AtomicU64::new(config.ttl.as_millis() as u64),
266            result_cache: RwLock::new(LruCache::new(capacity)),
267            result_cache_hits: AtomicU64::new(0),
268            result_cache_misses: AtomicU64::new(0),
269            metadata_cache_hits: AtomicU64::new(0),
270            metadata_cache_misses: AtomicU64::new(0),
271            local_cache_hits: AtomicU64::new(0),
272            local_cache_misses: AtomicU64::new(0),
273            reload_successes: AtomicU64::new(0),
274            reload_failures: AtomicU64::new(0),
275        }
276    }
277
278    pub fn install_provider<F>(
279        &self,
280        kind: ProviderKind,
281        datasource_id: DatasourceId,
282        build: F,
283    ) -> KnowledgeResult<Generation>
284    where
285        F: FnOnce(Generation) -> KnowledgeResult<Arc<dyn ProviderExecutor>>,
286    {
287        let generation = Generation(self.next_generation.fetch_add(1, Ordering::SeqCst) + 1);
288        let previous = self
289            .provider
290            .read()
291            .ok()
292            .and_then(|guard| guard.as_ref().cloned());
293        debug_kdb!(
294            "[kdb] reload provider start kind={kind:?} datasource_id={} target_generation={} previous_generation={}",
295            datasource_id.0,
296            generation.0,
297            previous
298                .as_ref()
299                .map(|handle| handle.generation.0.to_string())
300                .unwrap_or_else(|| "none".to_string())
301        );
302        let provider = match build(generation) {
303            Ok(provider) => provider,
304            Err(err) => {
305                self.reload_failures.fetch_add(1, Ordering::Relaxed);
306                warn_kdb!(
307                    "[kdb] reload provider failed kind={kind:?} datasource_id={} target_generation={} err={}",
308                    datasource_id.0,
309                    generation.0,
310                    err
311                );
312                if telemetry_enabled() {
313                    telemetry().on_reload(&ReloadTelemetryEvent {
314                        outcome: ReloadOutcome::Failure,
315                        provider_kind: kind.clone(),
316                    });
317                }
318                return Err(err);
319            }
320        };
321        debug_kdb!(
322            "[kdb] install provider kind={kind:?} datasource_id={} generation={}",
323            datasource_id.0,
324            generation.0
325        );
326        let kind_for_handle = kind.clone();
327        let datasource_id_for_handle = datasource_id.clone();
328        let handle = Arc::new(ProviderHandle {
329            provider,
330            datasource_id: datasource_id_for_handle,
331            generation,
332            kind: kind_for_handle,
333        });
334        self.provider_epoch.fetch_add(1, Ordering::AcqRel);
335        {
336            let mut guard = self
337                .provider
338                .write()
339                .expect("runtime provider lock poisoned");
340            *guard = Some(handle);
341        }
342        self.current_generation_value
343            .store(generation.0, Ordering::Release);
344        self.provider_epoch.fetch_add(1, Ordering::Release);
345        self.reload_successes.fetch_add(1, Ordering::Relaxed);
346        if telemetry_enabled() {
347            telemetry().on_reload(&ReloadTelemetryEvent {
348                outcome: ReloadOutcome::Success,
349                provider_kind: kind.clone(),
350            });
351        }
352        debug_kdb!(
353            "[kdb] reload provider success kind={kind:?} datasource_id={} generation={}",
354            datasource_id.0,
355            generation.0
356        );
357        Ok(generation)
358    }
359
360    pub fn configure_result_cache(&self, enabled: bool, capacity: usize, ttl: Duration) {
361        let new_config = ResultCacheConfig {
362            enabled,
363            capacity: capacity.max(1),
364            ttl: ttl.max(Duration::from_millis(1)),
365        };
366        let mut should_reset_cache = false;
367        {
368            let mut guard = self
369                .result_cache_config
370                .write()
371                .expect("runtime result cache config lock poisoned");
372            if guard.capacity != new_config.capacity || (!new_config.enabled && guard.enabled) {
373                should_reset_cache = true;
374            }
375            *guard = new_config;
376        }
377        self.result_cache_enabled
378            .store(new_config.enabled, Ordering::Relaxed);
379        self.result_cache_ttl_ms.store(
380            new_config.ttl.as_millis().min(u128::from(u64::MAX)) as u64,
381            Ordering::Relaxed,
382        );
383
384        if should_reset_cache {
385            let mut cache = self
386                .result_cache
387                .write()
388                .expect("runtime result cache lock poisoned");
389            *cache = LruCache::new(
390                NonZeroUsize::new(new_config.capacity).expect("non-zero result cache capacity"),
391            );
392        }
393    }
394
395    pub fn current_generation(&self) -> Option<Generation> {
396        let epoch_before = self.provider_epoch.load(Ordering::Acquire);
397        if epoch_before % 2 == 1 {
398            return self.current_generation_from_provider();
399        }
400        let generation = self.current_generation_value.load(Ordering::Acquire);
401        let epoch_after = self.provider_epoch.load(Ordering::Acquire);
402        if epoch_before != epoch_after {
403            return self.current_generation_from_provider();
404        }
405        match generation {
406            0 => None,
407            generation => Some(Generation(generation)),
408        }
409    }
410
411    pub fn snapshot(&self) -> RuntimeSnapshot {
412        let provider = self
413            .provider
414            .read()
415            .ok()
416            .and_then(|guard| guard.as_ref().cloned());
417        let result_cache_config = self
418            .result_cache_config
419            .read()
420            .map(|guard| *guard)
421            .unwrap_or_default();
422        let (result_cache_len, result_cache_capacity) = self
423            .result_cache
424            .read()
425            .map(|cache| (cache.len(), cache.cap().get()))
426            .unwrap_or((0, 0));
427        let (metadata_cache_len, metadata_cache_capacity) =
428            crate::mem::query_util::column_metadata_cache_snapshot();
429        RuntimeSnapshot {
430            provider_kind: provider.as_ref().map(|handle| handle.kind.clone()),
431            datasource_id: provider.as_ref().map(|handle| handle.datasource_id.clone()),
432            generation: provider.as_ref().map(|handle| handle.generation),
433            result_cache_enabled: result_cache_config.enabled,
434            result_cache_len,
435            result_cache_capacity,
436            result_cache_ttl_ms: result_cache_config.ttl.as_millis() as u64,
437            metadata_cache_len,
438            metadata_cache_capacity,
439            result_cache_hits: self.result_cache_hits.load(Ordering::Relaxed),
440            result_cache_misses: self.result_cache_misses.load(Ordering::Relaxed),
441            metadata_cache_hits: self.metadata_cache_hits.load(Ordering::Relaxed),
442            metadata_cache_misses: self.metadata_cache_misses.load(Ordering::Relaxed),
443            local_cache_hits: self.local_cache_hits.load(Ordering::Relaxed),
444            local_cache_misses: self.local_cache_misses.load(Ordering::Relaxed),
445            reload_successes: self.reload_successes.load(Ordering::Relaxed),
446            reload_failures: self.reload_failures.load(Ordering::Relaxed),
447        }
448    }
449
450    pub fn current_metadata_scope(&self) -> MetadataCacheScope {
451        self.provider
452            .read()
453            .ok()
454            .and_then(|guard| guard.as_ref().cloned())
455            .map(|handle| MetadataCacheScope {
456                datasource_id: handle.datasource_id.clone(),
457                generation: handle.generation,
458            })
459            .unwrap_or_else(|| MetadataCacheScope {
460                datasource_id: DatasourceId("sqlite:standalone".to_string()),
461                generation: Generation(0),
462            })
463    }
464
465    pub fn current_provider_kind(&self) -> Option<ProviderKind> {
466        self.provider
467            .read()
468            .ok()
469            .and_then(|guard| guard.as_ref().map(|handle| handle.kind.clone()))
470    }
471
472    pub fn record_result_cache_hit(&self) {
473        self.result_cache_hits.fetch_add(1, Ordering::Relaxed);
474    }
475
476    pub fn record_result_cache_miss(&self) {
477        self.result_cache_misses.fetch_add(1, Ordering::Relaxed);
478    }
479
480    pub fn record_metadata_cache_hit(&self) {
481        self.metadata_cache_hits.fetch_add(1, Ordering::Relaxed);
482    }
483
484    pub fn record_metadata_cache_miss(&self) {
485        self.metadata_cache_misses.fetch_add(1, Ordering::Relaxed);
486    }
487
488    pub fn record_local_cache_hit(&self) {
489        self.local_cache_hits.fetch_add(1, Ordering::Relaxed);
490    }
491
492    pub fn record_local_cache_miss(&self) {
493        self.local_cache_misses.fetch_add(1, Ordering::Relaxed);
494    }
495
496    pub fn execute(&self, req: &QueryRequest) -> KnowledgeResult<QueryResponse> {
497        let handle = self.current_handle()?;
498        self.execute_with_handle(&handle, req)
499    }
500
501    fn execute_with_handle(
502        &self,
503        handle: &Arc<ProviderHandle>,
504        req: &QueryRequest,
505    ) -> KnowledgeResult<QueryResponse> {
506        let use_global_cache =
507            matches!(req.cache_policy, CachePolicy::UseGlobal) && self.result_cache_enabled();
508        if use_global_cache && let Some(hit) = self.fetch_result_cache(handle, req) {
509            self.record_result_cache_hit();
510            if telemetry_enabled() {
511                telemetry().on_cache(&CacheTelemetryEvent {
512                    layer: CacheLayer::Result,
513                    outcome: CacheOutcome::Hit,
514                    provider_kind: Some(handle.kind.clone()),
515                });
516            }
517            debug_kdb!(
518                "[kdb] global result cache hit kind={:?} generation={}",
519                handle.kind,
520                handle.generation.0
521            );
522            return Ok(hit);
523        }
524        if use_global_cache {
525            self.record_result_cache_miss();
526            if telemetry_enabled() {
527                telemetry().on_cache(&CacheTelemetryEvent {
528                    layer: CacheLayer::Result,
529                    outcome: CacheOutcome::Miss,
530                    provider_kind: Some(handle.kind.clone()),
531                });
532            }
533            debug_kdb!(
534                "[kdb] global result cache miss kind={:?} generation={}",
535                handle.kind,
536                handle.generation.0
537            );
538        }
539
540        let params = params_to_fields(&req.params);
541        let mode_tag = query_mode_tag(&req.mode);
542        let started = Instant::now();
543        debug_kdb!(
544            "[kdb] execute query kind={:?} generation={} mode={:?} cache_policy={:?}",
545            handle.kind,
546            handle.generation.0,
547            req.mode,
548            req.cache_policy
549        );
550        let response = match match req.mode {
551            QueryMode::Many => {
552                if params.is_empty() {
553                    handle.provider.query(&req.sql).map(QueryResponse::Rows)
554                } else {
555                    handle
556                        .provider
557                        .query_fields(&req.sql, &params)
558                        .map(QueryResponse::Rows)
559                }
560            }
561            QueryMode::FirstRow => {
562                if params.is_empty() {
563                    handle.provider.query_row(&req.sql).map(QueryResponse::Row)
564                } else {
565                    handle
566                        .provider
567                        .query_named_fields(&req.sql, &params)
568                        .map(QueryResponse::Row)
569                }
570            }
571        } {
572            Ok(response) => {
573                if telemetry_enabled() {
574                    telemetry().on_query(&QueryTelemetryEvent {
575                        provider_kind: handle.kind.clone(),
576                        mode: mode_tag,
577                        success: true,
578                        elapsed: started.elapsed(),
579                    });
580                }
581                response
582            }
583            Err(err) => {
584                if telemetry_enabled() {
585                    telemetry().on_query(&QueryTelemetryEvent {
586                        provider_kind: handle.kind.clone(),
587                        mode: mode_tag,
588                        success: false,
589                        elapsed: started.elapsed(),
590                    });
591                }
592                return Err(err);
593            }
594        };
595
596        if use_global_cache {
597            self.save_result_cache(handle, req, response.clone());
598            debug_kdb!(
599                "[kdb] global result cache store kind={:?} generation={}",
600                handle.kind,
601                handle.generation.0
602            );
603        }
604
605        Ok(response)
606    }
607
608    pub fn execute_first_row_fields(
609        &self,
610        sql: &str,
611        params: &[DataField],
612        cache_policy: CachePolicy,
613    ) -> KnowledgeResult<RowData> {
614        let handle = self.current_handle()?;
615        self.execute_first_row_fields_with_handle(&handle, sql, params, cache_policy)
616    }
617
618    fn execute_first_row_fields_with_handle(
619        &self,
620        handle: &Arc<ProviderHandle>,
621        sql: &str,
622        params: &[DataField],
623        cache_policy: CachePolicy,
624    ) -> KnowledgeResult<RowData> {
625        let use_global_cache =
626            matches!(cache_policy, CachePolicy::UseGlobal) && self.result_cache_enabled();
627        if use_global_cache
628            && let Some(hit) = self.fetch_result_cache_by_key(result_cache_key_fields(
629                handle,
630                sql,
631                params,
632                QueryModeTag::FirstRow,
633            ))
634        {
635            self.record_result_cache_hit();
636            if telemetry_enabled() {
637                telemetry().on_cache(&CacheTelemetryEvent {
638                    layer: CacheLayer::Result,
639                    outcome: CacheOutcome::Hit,
640                    provider_kind: Some(handle.kind.clone()),
641                });
642            }
643            return Ok(hit.into_row());
644        }
645        if use_global_cache {
646            self.record_result_cache_miss();
647            if telemetry_enabled() {
648                telemetry().on_cache(&CacheTelemetryEvent {
649                    layer: CacheLayer::Result,
650                    outcome: CacheOutcome::Miss,
651                    provider_kind: Some(handle.kind.clone()),
652                });
653            }
654        }
655
656        let started = Instant::now();
657        let row = if params.is_empty() {
658            handle.provider.query_row(sql)
659        } else {
660            handle.provider.query_named_fields(sql, params)
661        };
662        let row = match row {
663            Ok(row) => {
664                if telemetry_enabled() {
665                    telemetry().on_query(&QueryTelemetryEvent {
666                        provider_kind: handle.kind.clone(),
667                        mode: QueryModeTag::FirstRow,
668                        success: true,
669                        elapsed: started.elapsed(),
670                    });
671                }
672                row
673            }
674            Err(err) => {
675                if telemetry_enabled() {
676                    telemetry().on_query(&QueryTelemetryEvent {
677                        provider_kind: handle.kind.clone(),
678                        mode: QueryModeTag::FirstRow,
679                        success: false,
680                        elapsed: started.elapsed(),
681                    });
682                }
683                return Err(err);
684            }
685        };
686
687        if use_global_cache {
688            self.save_result_cache_by_key(
689                result_cache_key_fields(handle, sql, params, QueryModeTag::FirstRow),
690                QueryResponse::Row(row.clone()),
691            );
692        }
693
694        Ok(row)
695    }
696
697    pub async fn execute_async(&self, req: &QueryRequest) -> KnowledgeResult<QueryResponse> {
698        let handle = self.current_handle()?;
699        if matches!(handle.kind, ProviderKind::SqliteAuthority) {
700            let handle = handle.clone();
701            let req = req.clone();
702            return task::spawn_blocking(move || runtime().execute_with_handle(&handle, &req))
703                .await
704                .map_err(|err| {
705                    KnowledgeReason::from_logic()
706                        .to_err()
707                        .with_detail(format!("knowledge async sqlite query join failed: {err}"))
708                })?;
709        }
710        let use_global_cache =
711            matches!(req.cache_policy, CachePolicy::UseGlobal) && self.result_cache_enabled();
712        if use_global_cache && let Some(hit) = self.fetch_result_cache(&handle, req) {
713            self.record_result_cache_hit();
714            if telemetry_enabled() {
715                telemetry().on_cache(&CacheTelemetryEvent {
716                    layer: CacheLayer::Result,
717                    outcome: CacheOutcome::Hit,
718                    provider_kind: Some(handle.kind.clone()),
719                });
720            }
721            return Ok(hit);
722        }
723        if use_global_cache {
724            self.record_result_cache_miss();
725            if telemetry_enabled() {
726                telemetry().on_cache(&CacheTelemetryEvent {
727                    layer: CacheLayer::Result,
728                    outcome: CacheOutcome::Miss,
729                    provider_kind: Some(handle.kind.clone()),
730                });
731            }
732        }
733
734        let params = params_to_fields(&req.params);
735        let mode_tag = query_mode_tag(&req.mode);
736        let started = Instant::now();
737        let response = match req.mode {
738            QueryMode::Many => {
739                if params.is_empty() {
740                    handle
741                        .provider
742                        .query_async(&req.sql)
743                        .await
744                        .map(QueryResponse::Rows)
745                } else {
746                    handle
747                        .provider
748                        .query_fields_async(&req.sql, &params)
749                        .await
750                        .map(QueryResponse::Rows)
751                }
752            }
753            QueryMode::FirstRow => {
754                if params.is_empty() {
755                    handle
756                        .provider
757                        .query_row_async(&req.sql)
758                        .await
759                        .map(QueryResponse::Row)
760                } else {
761                    handle
762                        .provider
763                        .query_named_fields_async(&req.sql, &params)
764                        .await
765                        .map(QueryResponse::Row)
766                }
767            }
768        };
769        let response = match response {
770            Ok(response) => {
771                if telemetry_enabled() {
772                    telemetry().on_query(&QueryTelemetryEvent {
773                        provider_kind: handle.kind.clone(),
774                        mode: mode_tag,
775                        success: true,
776                        elapsed: started.elapsed(),
777                    });
778                }
779                response
780            }
781            Err(err) => {
782                if telemetry_enabled() {
783                    telemetry().on_query(&QueryTelemetryEvent {
784                        provider_kind: handle.kind.clone(),
785                        mode: mode_tag,
786                        success: false,
787                        elapsed: started.elapsed(),
788                    });
789                }
790                return Err(err);
791            }
792        };
793
794        if use_global_cache {
795            self.save_result_cache(&handle, req, response.clone());
796        }
797
798        Ok(response)
799    }
800
801    pub async fn execute_first_row_fields_async(
802        &self,
803        sql: &str,
804        params: &[DataField],
805        cache_policy: CachePolicy,
806    ) -> KnowledgeResult<RowData> {
807        let handle = self.current_handle()?;
808        if matches!(handle.kind, ProviderKind::SqliteAuthority) {
809            let handle = handle.clone();
810            let sql = sql.to_string();
811            let params = params.to_vec();
812            return task::spawn_blocking(move || {
813                runtime().execute_first_row_fields_with_handle(&handle, &sql, &params, cache_policy)
814            })
815            .await
816            .map_err(|err| {
817                KnowledgeReason::from_logic().to_err().with_detail(format!(
818                    "knowledge async sqlite first-row query join failed: {err}"
819                ))
820            })?;
821        }
822        let use_global_cache =
823            matches!(cache_policy, CachePolicy::UseGlobal) && self.result_cache_enabled();
824        if use_global_cache
825            && let Some(hit) = self.fetch_result_cache_by_key(result_cache_key_fields(
826                &handle,
827                sql,
828                params,
829                QueryModeTag::FirstRow,
830            ))
831        {
832            self.record_result_cache_hit();
833            if telemetry_enabled() {
834                telemetry().on_cache(&CacheTelemetryEvent {
835                    layer: CacheLayer::Result,
836                    outcome: CacheOutcome::Hit,
837                    provider_kind: Some(handle.kind.clone()),
838                });
839            }
840            return Ok(hit.into_row());
841        }
842        if use_global_cache {
843            self.record_result_cache_miss();
844            if telemetry_enabled() {
845                telemetry().on_cache(&CacheTelemetryEvent {
846                    layer: CacheLayer::Result,
847                    outcome: CacheOutcome::Miss,
848                    provider_kind: Some(handle.kind.clone()),
849                });
850            }
851        }
852
853        let started = Instant::now();
854        let row = if params.is_empty() {
855            handle.provider.query_row_async(sql).await
856        } else {
857            handle.provider.query_named_fields_async(sql, params).await
858        };
859        let row = match row {
860            Ok(row) => {
861                if telemetry_enabled() {
862                    telemetry().on_query(&QueryTelemetryEvent {
863                        provider_kind: handle.kind.clone(),
864                        mode: QueryModeTag::FirstRow,
865                        success: true,
866                        elapsed: started.elapsed(),
867                    });
868                }
869                row
870            }
871            Err(err) => {
872                if telemetry_enabled() {
873                    telemetry().on_query(&QueryTelemetryEvent {
874                        provider_kind: handle.kind.clone(),
875                        mode: QueryModeTag::FirstRow,
876                        success: false,
877                        elapsed: started.elapsed(),
878                    });
879                }
880                return Err(err);
881            }
882        };
883
884        if use_global_cache {
885            self.save_result_cache_by_key(
886                result_cache_key_fields(&handle, sql, params, QueryModeTag::FirstRow),
887                QueryResponse::Row(row.clone()),
888            );
889        }
890
891        Ok(row)
892    }
893
894    fn current_handle(&self) -> KnowledgeResult<Arc<ProviderHandle>> {
895        self.provider
896            .read()
897            .expect("runtime provider lock poisoned")
898            .clone()
899            .ok_or_else(|| {
900                KnowledgeReason::from_logic()
901                    .to_err()
902                    .with_detail("knowledge provider not initialized")
903            })
904    }
905
906    fn current_generation_from_provider(&self) -> Option<Generation> {
907        self.provider
908            .read()
909            .ok()
910            .and_then(|guard| guard.as_ref().map(|handle| handle.generation))
911    }
912
913    fn fetch_result_cache(
914        &self,
915        handle: &ProviderHandle,
916        req: &QueryRequest,
917    ) -> Option<QueryResponse> {
918        self.fetch_result_cache_by_key(result_cache_key(handle, req))
919    }
920
921    fn fetch_result_cache_by_key(&self, key: ResultCacheKey) -> Option<QueryResponse> {
922        if !self.result_cache_enabled() {
923            return None;
924        }
925        let cached = self
926            .result_cache
927            .read()
928            .ok()
929            .and_then(|cache| cache.peek(&key).cloned())?;
930        if cached.cached_at.elapsed() > self.result_cache_ttl() {
931            if let Ok(mut cache) = self.result_cache.write() {
932                let _ = cache.pop(&key);
933            }
934            return None;
935        }
936        Some((*cached.response).clone())
937    }
938
939    fn save_result_cache(
940        &self,
941        handle: &ProviderHandle,
942        req: &QueryRequest,
943        response: QueryResponse,
944    ) {
945        self.save_result_cache_by_key(result_cache_key(handle, req), response);
946    }
947
948    fn save_result_cache_by_key(&self, key: ResultCacheKey, response: QueryResponse) {
949        if let Ok(mut cache) = self.result_cache.write() {
950            cache.put(
951                key,
952                CachedQueryResponse {
953                    response: Arc::new(response),
954                    cached_at: Instant::now(),
955                },
956            );
957        }
958    }
959
960    #[inline]
961    fn result_cache_enabled(&self) -> bool {
962        self.result_cache_enabled.load(Ordering::Relaxed)
963    }
964
965    #[inline]
966    fn result_cache_ttl(&self) -> Duration {
967        Duration::from_millis(self.result_cache_ttl_ms.load(Ordering::Relaxed))
968    }
969}
970
971pub fn runtime() -> &'static KnowledgeRuntime {
972    static RUNTIME: OnceLock<KnowledgeRuntime> = OnceLock::new();
973    RUNTIME.get_or_init(|| KnowledgeRuntime::new(1024))
974}
975
/// Test-only async-aware mutex, held by tests that mutate the global [`runtime()`]
/// (for example by installing providers) so that those tests do not interleave.
#[cfg(test)]
pub(crate) struct RuntimeTestGuard(tokio::sync::Mutex<()>);
978
#[cfg(test)]
impl RuntimeTestGuard {
    /// Blocks the current thread until the guard is acquired.
    ///
    /// Uses tokio's `blocking_lock`, which tokio documents as panicking when called from an
    /// asynchronous execution context; use [`RuntimeTestGuard::lock_async`] there instead.
    /// The result is always `Ok`. The `Infallible` error presumably keeps call sites shaped
    /// like `std::sync::Mutex::lock()`.
    pub(crate) fn lock(&self) -> Result<tokio::sync::MutexGuard<'_, ()>, std::convert::Infallible> {
        let guard = self.0.blocking_lock();
        Ok(guard)
    }

    /// Asynchronously waits for and returns the guard.
    pub(crate) async fn lock_async(&self) -> tokio::sync::MutexGuard<'_, ()> {
        let inner = &self.0;
        inner.lock().await
    }
}
989
/// Returns the single shared [`RuntimeTestGuard`], creating it lazily on first use.
#[cfg(test)]
pub(crate) fn runtime_test_guard() -> &'static RuntimeTestGuard {
    static SHARED_GUARD: OnceLock<RuntimeTestGuard> = OnceLock::new();
    SHARED_GUARD.get_or_init(|| {
        let mutex = tokio::sync::Mutex::new(());
        RuntimeTestGuard(mutex)
    })
}
995
996fn result_cache_key(handle: &ProviderHandle, req: &QueryRequest) -> ResultCacheKey {
997    ResultCacheKey {
998        datasource_id: handle.datasource_id.clone(),
999        generation: handle.generation,
1000        query_hash: stable_hash(&req.sql),
1001        params_hash: stable_params_hash(&req.params),
1002        mode: match req.mode {
1003            QueryMode::Many => QueryModeTag::Many,
1004            QueryMode::FirstRow => QueryModeTag::FirstRow,
1005        },
1006    }
1007}
1008
1009fn result_cache_key_fields(
1010    handle: &ProviderHandle,
1011    sql: &str,
1012    params: &[DataField],
1013    mode: QueryModeTag,
1014) -> ResultCacheKey {
1015    ResultCacheKey {
1016        datasource_id: handle.datasource_id.clone(),
1017        generation: handle.generation,
1018        query_hash: stable_hash(sql),
1019        params_hash: stable_field_params_hash(params),
1020        mode,
1021    }
1022}
1023
1024fn query_mode_tag(mode: &QueryMode) -> QueryModeTag {
1025    match mode {
1026        QueryMode::Many => QueryModeTag::Many,
1027        QueryMode::FirstRow => QueryModeTag::FirstRow,
1028    }
1029}
1030
/// Hashes a string with `DefaultHasher` for use as a cache-key component.
///
/// "Stable" here means deterministic for a given build: `DefaultHasher::new()` always
/// starts from the same state. The standard library leaves the algorithm unspecified across
/// Rust releases, so these values must not be persisted or compared between builds.
fn stable_hash(value: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(value, &mut state);
    state.finish()
}
1036
1037fn stable_params_hash(params: &[QueryParam]) -> u64 {
1038    let mut hasher = DefaultHasher::new();
1039    for param in params {
1040        param.name.hash(&mut hasher);
1041        match &param.value {
1042            QueryValue::Null => 0u8.hash(&mut hasher),
1043            QueryValue::Bool(value) => {
1044                1u8.hash(&mut hasher);
1045                value.hash(&mut hasher);
1046            }
1047            QueryValue::Int(value) => {
1048                2u8.hash(&mut hasher);
1049                value.hash(&mut hasher);
1050            }
1051            QueryValue::Float(value) => {
1052                3u8.hash(&mut hasher);
1053                value.to_bits().hash(&mut hasher);
1054            }
1055            QueryValue::Text(value) => {
1056                4u8.hash(&mut hasher);
1057                value.hash(&mut hasher);
1058            }
1059        }
1060    }
1061    hasher.finish()
1062}
1063
1064fn stable_field_params_hash(params: &[DataField]) -> u64 {
1065    let mut hasher = DefaultHasher::new();
1066    for field in params {
1067        field.get_name().hash(&mut hasher);
1068        match field.get_value() {
1069            Value::Null | Value::Ignore(_) => 0u8.hash(&mut hasher),
1070            Value::Bool(value) => {
1071                1u8.hash(&mut hasher);
1072                value.hash(&mut hasher);
1073            }
1074            Value::Digit(value) => {
1075                2u8.hash(&mut hasher);
1076                value.hash(&mut hasher);
1077            }
1078            Value::Float(value) => {
1079                3u8.hash(&mut hasher);
1080                value.to_bits().hash(&mut hasher);
1081            }
1082            Value::Chars(value) => {
1083                4u8.hash(&mut hasher);
1084                value.hash(&mut hasher);
1085            }
1086            Value::Symbol(value) => {
1087                5u8.hash(&mut hasher);
1088                value.hash(&mut hasher);
1089            }
1090            Value::Time(value) => {
1091                6u8.hash(&mut hasher);
1092                value.hash(&mut hasher);
1093            }
1094            Value::Hex(value) => {
1095                7u8.hash(&mut hasher);
1096                value.to_string().hash(&mut hasher);
1097            }
1098            Value::IpNet(value) => {
1099                8u8.hash(&mut hasher);
1100                value.to_string().hash(&mut hasher);
1101            }
1102            Value::IpAddr(value) => {
1103                9u8.hash(&mut hasher);
1104                value.hash(&mut hasher);
1105            }
1106            Value::Obj(value) => {
1107                10u8.hash(&mut hasher);
1108                format!("{:?}", value).hash(&mut hasher);
1109            }
1110            Value::Array(value) => {
1111                11u8.hash(&mut hasher);
1112                format!("{:?}", value).hash(&mut hasher);
1113            }
1114            Value::Domain(value) => {
1115                12u8.hash(&mut hasher);
1116                value.0.hash(&mut hasher);
1117            }
1118            Value::Url(value) => {
1119                13u8.hash(&mut hasher);
1120                value.0.hash(&mut hasher);
1121            }
1122            Value::Email(value) => {
1123                14u8.hash(&mut hasher);
1124                value.0.hash(&mut hasher);
1125            }
1126            Value::IdCard(value) => {
1127                15u8.hash(&mut hasher);
1128                value.0.hash(&mut hasher);
1129            }
1130            Value::MobilePhone(value) => {
1131                16u8.hash(&mut hasher);
1132                value.0.hash(&mut hasher);
1133            }
1134        }
1135    }
1136    hasher.finish()
1137}
1138
1139pub fn fields_to_params(params: &[DataField]) -> Vec<QueryParam> {
1140    params
1141        .iter()
1142        .map(|field| {
1143            let value = match field.get_value() {
1144                Value::Null | Value::Ignore(_) => QueryValue::Null,
1145                Value::Bool(value) => QueryValue::Bool(*value),
1146                Value::Digit(value) => QueryValue::Int(*value),
1147                Value::Float(value) => QueryValue::Float(*value),
1148                Value::Chars(value) => QueryValue::Text(value.to_string()),
1149                Value::Symbol(value) => QueryValue::Text(value.to_string()),
1150                Value::Time(value) => QueryValue::Text(value.to_string()),
1151                Value::Hex(value) => QueryValue::Text(value.to_string()),
1152                Value::IpNet(value) => QueryValue::Text(value.to_string()),
1153                Value::IpAddr(value) => QueryValue::Text(value.to_string()),
1154                Value::Obj(value) => QueryValue::Text(format!("{:?}", value)),
1155                Value::Array(value) => QueryValue::Text(format!("{:?}", value)),
1156                Value::Domain(value) => QueryValue::Text(value.0.to_string()),
1157                Value::Url(value) => QueryValue::Text(value.0.to_string()),
1158                Value::Email(value) => QueryValue::Text(value.0.to_string()),
1159                Value::IdCard(value) => QueryValue::Text(value.0.to_string()),
1160                Value::MobilePhone(value) => QueryValue::Text(value.0.to_string()),
1161            };
1162            QueryParam {
1163                name: field.get_name().to_string(),
1164                value,
1165            }
1166        })
1167        .collect()
1168}
1169
1170pub fn params_to_fields(params: &[QueryParam]) -> Vec<DataField> {
1171    params
1172        .iter()
1173        .map(|param| match &param.value {
1174            QueryValue::Null => {
1175                DataField::new(DataType::default(), param.name.clone(), Value::Null)
1176            }
1177            QueryValue::Bool(value) => {
1178                DataField::new(DataType::default(), param.name.clone(), Value::Bool(*value))
1179            }
1180            QueryValue::Int(value) => DataField::from_digit(param.name.clone(), *value),
1181            QueryValue::Float(value) => DataField::from_float(param.name.clone(), *value),
1182            QueryValue::Text(value) => DataField::from_chars(param.name.clone(), value.clone()),
1183        })
1184        .collect()
1185}
1186
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Arc;
    use wp_model_core::model::Value;

    /// Provider stub that answers every query with a single `value` column.
    struct TestProvider {
        value: &'static str,
    }

    #[async_trait]
    impl ProviderExecutor for TestProvider {
        fn query(&self, _sql: &str) -> KnowledgeResult<Vec<RowData>> {
            Ok(vec![vec![DataField::from_chars("value", self.value)]])
        }

        fn query_fields(&self, _sql: &str, _params: &[DataField]) -> KnowledgeResult<Vec<RowData>> {
            self.query("")
        }

        fn query_row(&self, _sql: &str) -> KnowledgeResult<RowData> {
            Ok(vec![DataField::from_chars("value", self.value)])
        }

        fn query_named_fields(
            &self,
            _sql: &str,
            _params: &[DataField],
        ) -> KnowledgeResult<RowData> {
            self.query_row("")
        }
    }

    #[test]
    fn query_param_hash_is_stable() {
        let params = vec![
            QueryParam {
                name: ":id".to_string(),
                value: QueryValue::Int(7),
            },
            QueryParam {
                name: ":name".to_string(),
                value: QueryValue::Text("abc".to_string()),
            },
        ];
        assert_eq!(stable_params_hash(&params), stable_params_hash(&params));
    }

    // Comparing a hash with itself cannot fail, so also check that the cache key actually
    // tells apart the inputs it must keep separate: value type, parameter name and order.
    #[test]
    fn query_param_hash_distinguishes_type_name_and_order() {
        let int_param = QueryParam {
            name: ":id".to_string(),
            value: QueryValue::Int(7),
        };
        let text_param = QueryParam {
            name: ":id".to_string(),
            value: QueryValue::Text("7".to_string()),
        };
        let renamed = QueryParam {
            name: ":key".to_string(),
            value: QueryValue::Int(7),
        };
        assert_ne!(
            stable_params_hash(&[int_param.clone()]),
            stable_params_hash(&[text_param.clone()])
        );
        assert_ne!(
            stable_params_hash(&[int_param.clone()]),
            stable_params_hash(&[renamed])
        );
        assert_ne!(
            stable_params_hash(&[int_param.clone(), text_param.clone()]),
            stable_params_hash(&[text_param, int_param])
        );
    }

    #[test]
    fn query_mode_tag_maps_every_mode() {
        assert!(matches!(query_mode_tag(&QueryMode::Many), QueryModeTag::Many));
        assert!(matches!(
            query_mode_tag(&QueryMode::FirstRow),
            QueryModeTag::FirstRow
        ));
    }

    #[test]
    fn fields_to_params_preserves_raw_chars_value() {
        let fields = [DataField::from_chars(
            ":name".to_string(),
            "令狐冲".to_string(),
        )];
        let params = fields_to_params(&fields);
        assert_eq!(params.len(), 1);
        match &params[0].value {
            QueryValue::Text(value) => assert_eq!(value, "令狐冲"),
            other => panic!("unexpected param value: {other:?}"),
        }
        let roundtrip = params_to_fields(&params);
        assert!(matches!(roundtrip[0].get_value(), Value::Chars(_)));
    }

    #[test]
    fn fields_to_params_roundtrips_digit_as_int() {
        let fields = [DataField::from_digit(":id".to_string(), 7)];
        let params = fields_to_params(&fields);
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].name, ":id");
        assert!(matches!(params[0].value, QueryValue::Int(7)));
        let roundtrip = params_to_fields(&params);
        assert!(matches!(roundtrip[0].get_value(), Value::Digit(7)));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn sqlite_async_bridge_keeps_captured_handle_after_reload() {
        let _guard = runtime_test_guard().lock_async().await;
        runtime()
            .install_provider(
                ProviderKind::SqliteAuthority,
                DatasourceId("sqlite:old".to_string()),
                |_generation| Ok(Arc::new(TestProvider { value: "old" })),
            )
            .expect("install old provider");
        let old_handle = runtime().current_handle().expect("current old handle");

        runtime()
            .install_provider(
                ProviderKind::SqliteAuthority,
                DatasourceId("sqlite:new".to_string()),
                |_generation| Ok(Arc::new(TestProvider { value: "new" })),
            )
            .expect("install new provider");

        let req = QueryRequest::first_row("SELECT value", Vec::new(), CachePolicy::Bypass);
        let row = task::spawn_blocking(move || runtime().execute_with_handle(&old_handle, &req))
            .await
            .expect("join sqlite bridge")
            .expect("execute old handle")
            .into_row();
        assert_eq!(row[0].to_string(), "chars(old)");
    }
}