1use std::collections::hash_map::DefaultHasher;
2use std::hash::{Hash, Hasher};
3use std::num::NonZeroUsize;
4use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
5use std::sync::{Arc, OnceLock, RwLock};
6use std::time::{Duration, Instant};
7
8use async_trait::async_trait;
9use lru::LruCache;
10use orion_error::{ToStructError, UvsFrom};
11use tokio::task;
12use wp_error::{KnowledgeReason, KnowledgeResult};
13use wp_log::{debug_kdb, warn_kdb};
14use wp_model_core::model::{DataField, DataType, Value};
15
16use crate::loader::ProviderKind;
17use crate::mem::RowData;
18use crate::telemetry::{
19 CacheLayer, CacheOutcome, CacheTelemetryEvent, QueryTelemetryEvent, ReloadOutcome,
20 ReloadTelemetryEvent, telemetry, telemetry_enabled,
21};
22
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct DatasourceId(pub String);
25
26impl DatasourceId {
27 pub fn from_seed(kind: ProviderKind, seed: &str) -> Self {
28 let mut hasher = DefaultHasher::new();
29 seed.hash(&mut hasher);
30 let kind_str = match kind {
31 ProviderKind::SqliteAuthority => "sqlite",
32 ProviderKind::Postgres => "postgres",
33 ProviderKind::Mysql => "mysql",
34 };
35 Self(format!("{kind_str}:{:016x}", hasher.finish()))
36 }
37}
38
/// Sequence number identifying one installed provider instance.
///
/// `KnowledgeRuntime::install_provider` allocates values starting at 1; the value 0 is used
/// in this file as the "no provider installed" placeholder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Generation(pub u64);
41
/// Result shape requested by a query.
#[derive(Debug, Clone)]
pub enum QueryMode {
    /// Return all rows; the runtime answers with `QueryResponse::Rows`.
    Many,
    /// Return a single row; the runtime answers with `QueryResponse::Row`.
    FirstRow,
}
47
/// Per-request caching behaviour.
///
/// Within this file only `UseGlobal` is inspected: it makes the runtime consult and populate
/// its global result cache (when that cache is enabled). The other variants skip it.
#[derive(Debug, Clone, Copy)]
pub enum CachePolicy {
    /// The global result cache is not used.
    Bypass,
    /// Look up and store results in the runtime's global result cache.
    UseGlobal,
    /// The global result cache is not used.
    /// NOTE(review): presumably a caller-managed, call-scoped cache — confirm at call sites.
    UseCallScope,
}
54
/// Scalar value bound to a query parameter.
///
/// `params_to_fields` converts each variant into the matching `DataField` value.
#[derive(Debug, Clone)]
pub enum QueryValue {
    /// No value; converted to `Value::Null`.
    Null,
    /// Boolean value.
    Bool(bool),
    /// Signed 64-bit integer.
    Int(i64),
    /// 64-bit float.
    Float(f64),
    /// Owned text.
    Text(String),
}
63
/// A named query parameter.
#[derive(Debug, Clone)]
pub struct QueryParam {
    /// Parameter name as passed to the provider (the tests in this file use names such as `:id`).
    pub name: String,
    /// Value bound to `name`.
    pub value: QueryValue,
}
69
70#[derive(Debug, Clone)]
71pub struct QueryRequest {
72 pub sql: String,
73 pub params: Vec<QueryParam>,
74 pub mode: QueryMode,
75 pub cache_policy: CachePolicy,
76}
77
78impl QueryRequest {
79 pub fn many(
80 sql: impl Into<String>,
81 params: Vec<QueryParam>,
82 cache_policy: CachePolicy,
83 ) -> Self {
84 Self {
85 sql: sql.into(),
86 params,
87 mode: QueryMode::Many,
88 cache_policy,
89 }
90 }
91
92 pub fn first_row(
93 sql: impl Into<String>,
94 params: Vec<QueryParam>,
95 cache_policy: CachePolicy,
96 ) -> Self {
97 Self {
98 sql: sql.into(),
99 params,
100 mode: QueryMode::FirstRow,
101 cache_policy,
102 }
103 }
104}
105
106#[derive(Debug, Clone)]
107pub enum QueryResponse {
108 Rows(Vec<RowData>),
109 Row(RowData),
110}
111
112impl QueryResponse {
113 pub fn into_rows(self) -> Vec<RowData> {
114 match self {
115 QueryResponse::Rows(rows) => rows,
116 QueryResponse::Row(row) => vec![row],
117 }
118 }
119
120 pub fn into_row(self) -> RowData {
121 match self {
122 QueryResponse::Rows(rows) => rows.into_iter().next().unwrap_or_default(),
123 QueryResponse::Row(row) => row,
124 }
125 }
126}
127
/// Query interface implemented by every knowledge backend.
///
/// The four synchronous methods are required. Each `*_async` method has a default body that
/// calls its synchronous counterpart directly on the calling task, so a backend with a
/// genuinely asynchronous driver may override them.
#[async_trait]
pub trait ProviderExecutor: Send + Sync {
    /// Runs `sql` without parameters and returns all rows.
    fn query(&self, sql: &str) -> KnowledgeResult<Vec<RowData>>;
    /// Runs `sql` with the given parameters and returns all rows.
    fn query_fields(&self, sql: &str, params: &[DataField]) -> KnowledgeResult<Vec<RowData>>;
    /// Runs `sql` without parameters and returns a single row.
    fn query_row(&self, sql: &str) -> KnowledgeResult<RowData>;
    /// Runs `sql` with the given parameters and returns a single row.
    fn query_named_fields(&self, sql: &str, params: &[DataField]) -> KnowledgeResult<RowData>;

    /// Async form of [`ProviderExecutor::query`]; the default delegates to the sync method.
    async fn query_async(&self, sql: &str) -> KnowledgeResult<Vec<RowData>> {
        self.query(sql)
    }

    /// Async form of [`ProviderExecutor::query_fields`]; the default delegates to the sync method.
    async fn query_fields_async(
        &self,
        sql: &str,
        params: &[DataField],
    ) -> KnowledgeResult<Vec<RowData>> {
        self.query_fields(sql, params)
    }

    /// Async form of [`ProviderExecutor::query_row`]; the default delegates to the sync method.
    async fn query_row_async(&self, sql: &str) -> KnowledgeResult<RowData> {
        self.query_row(sql)
    }

    /// Async form of [`ProviderExecutor::query_named_fields`]; the default delegates to the
    /// sync method.
    async fn query_named_fields_async(
        &self,
        sql: &str,
        params: &[DataField],
    ) -> KnowledgeResult<RowData> {
        self.query_named_fields(sql, params)
    }
}
159
/// Copyable, hashable counterpart of `QueryMode`, used in cache keys and telemetry events.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryModeTag {
    /// Corresponds to `QueryMode::Many`.
    Many,
    /// Corresponds to `QueryMode::FirstRow`.
    FirstRow,
}
165
/// Key of an entry in the runtime's global result cache.
///
/// The SQL text and parameters are represented only by 64-bit hashes, so two different
/// queries whose hashes collide would share an entry.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ResultCacheKey {
    /// Datasource the result was read from.
    pub datasource_id: DatasourceId,
    /// Provider generation the result was read from.
    pub generation: Generation,
    /// Hash of the SQL text.
    pub query_hash: u64,
    /// Hash of the bound parameters.
    pub params_hash: u64,
    /// Result shape that was requested.
    pub mode: QueryModeTag,
}
174
/// An installed provider together with the identity it was installed under.
pub struct ProviderHandle {
    /// Backend that executes queries.
    pub provider: Arc<dyn ProviderExecutor>,
    /// Datasource identifier supplied at install time.
    pub datasource_id: DatasourceId,
    /// Generation allocated for this installation.
    pub generation: Generation,
    /// Kind of backend behind `provider`.
    pub kind: ProviderKind,
}
181
/// Point-in-time view of the runtime's provider, cache state and counters.
///
/// Produced by `KnowledgeRuntime::snapshot`. The three provider fields are `None` when no
/// provider is installed.
#[derive(Debug, Clone)]
pub struct RuntimeSnapshot {
    /// Kind of the installed provider.
    pub provider_kind: Option<ProviderKind>,
    /// Datasource identifier of the installed provider.
    pub datasource_id: Option<DatasourceId>,
    /// Generation of the installed provider.
    pub generation: Option<Generation>,
    /// Whether the global result cache is enabled.
    pub result_cache_enabled: bool,
    /// Number of entries currently in the result cache.
    pub result_cache_len: usize,
    /// Maximum number of entries the result cache holds.
    pub result_cache_capacity: usize,
    /// Configured result-cache time-to-live, in milliseconds.
    pub result_cache_ttl_ms: u64,
    /// Number of entries in the column-metadata cache.
    pub metadata_cache_len: usize,
    /// Capacity of the column-metadata cache.
    pub metadata_cache_capacity: usize,
    /// Count of recorded result-cache hits.
    pub result_cache_hits: u64,
    /// Count of recorded result-cache misses.
    pub result_cache_misses: u64,
    /// Count of recorded metadata-cache hits.
    pub metadata_cache_hits: u64,
    /// Count of recorded metadata-cache misses.
    pub metadata_cache_misses: u64,
    /// Count of recorded local-cache hits.
    pub local_cache_hits: u64,
    /// Count of recorded local-cache misses.
    pub local_cache_misses: u64,
    /// Number of successful provider installations.
    pub reload_successes: u64,
    /// Number of failed provider installations.
    pub reload_failures: u64,
}
202
/// Datasource/generation pair returned by `KnowledgeRuntime::current_metadata_scope`.
/// NOTE(review): presumably used to scope column-metadata cache entries — confirm at call sites.
#[derive(Debug, Clone)]
pub struct MetadataCacheScope {
    /// Datasource the scope refers to.
    pub datasource_id: DatasourceId,
    /// Provider generation the scope refers to.
    pub generation: Generation,
}
208
/// Settings of the global result cache.
#[derive(Debug, Clone, Copy)]
pub struct ResultCacheConfig {
    /// Whether the cache is consulted at all.
    pub enabled: bool,
    /// Maximum number of cached responses.
    pub capacity: usize,
    /// How long a cached response stays valid.
    pub ttl: Duration,
}

impl Default for ResultCacheConfig {
    /// Enabled cache holding 1024 entries with a 30-second time-to-live.
    fn default() -> Self {
        const DEFAULT_CAPACITY: usize = 1024;
        const DEFAULT_TTL: Duration = Duration::from_secs(30);
        Self {
            enabled: true,
            capacity: DEFAULT_CAPACITY,
            ttl: DEFAULT_TTL,
        }
    }
}
225
/// A cached response plus the instant it was stored, used for TTL checks.
#[derive(Debug, Clone)]
struct CachedQueryResponse {
    /// The cached response; shared so that cloning an entry does not copy the rows.
    response: Arc<QueryResponse>,
    /// When the entry was written to the cache.
    cached_at: Instant,
}
231
/// Holds the active provider, the global result cache and runtime counters.
pub struct KnowledgeRuntime {
    /// Currently installed provider; `None` until `install_provider` succeeds.
    provider: RwLock<Option<Arc<ProviderHandle>>>,
    /// Counter from which `install_provider` allocates generation numbers.
    next_generation: AtomicU64,
    /// Incremented once before and once after every provider swap; `current_generation`
    /// treats an odd value as "swap in progress".
    provider_epoch: AtomicU64,
    /// Generation of the installed provider, readable without the lock; 0 means none.
    current_generation_value: AtomicU64,
    /// Result-cache settings as last configured.
    result_cache_config: RwLock<ResultCacheConfig>,
    /// Atomic mirror of `result_cache_config.enabled` for lock-free reads.
    result_cache_enabled: AtomicBool,
    /// Atomic mirror of the configured TTL, in milliseconds.
    result_cache_ttl_ms: AtomicU64,
    /// The global result cache.
    result_cache: RwLock<LruCache<ResultCacheKey, CachedQueryResponse>>,
    /// Number of result-cache hits recorded.
    result_cache_hits: AtomicU64,
    /// Number of result-cache misses recorded.
    result_cache_misses: AtomicU64,
    /// Number of metadata-cache hits recorded.
    metadata_cache_hits: AtomicU64,
    /// Number of metadata-cache misses recorded.
    metadata_cache_misses: AtomicU64,
    /// Number of local-cache hits recorded.
    local_cache_hits: AtomicU64,
    /// Number of local-cache misses recorded.
    local_cache_misses: AtomicU64,
    /// Number of successful provider installations.
    reload_successes: AtomicU64,
    /// Number of failed provider installations.
    reload_failures: AtomicU64,
}
250
251impl KnowledgeRuntime {
252 pub fn new(result_cache_capacity: usize) -> Self {
253 let config = ResultCacheConfig {
254 capacity: result_cache_capacity.max(1),
255 ..ResultCacheConfig::default()
256 };
257 let capacity = NonZeroUsize::new(config.capacity).expect("non-zero capacity");
258 Self {
259 provider: RwLock::new(None),
260 next_generation: AtomicU64::new(0),
261 provider_epoch: AtomicU64::new(0),
262 current_generation_value: AtomicU64::new(0),
263 result_cache_config: RwLock::new(config),
264 result_cache_enabled: AtomicBool::new(config.enabled),
265 result_cache_ttl_ms: AtomicU64::new(config.ttl.as_millis() as u64),
266 result_cache: RwLock::new(LruCache::new(capacity)),
267 result_cache_hits: AtomicU64::new(0),
268 result_cache_misses: AtomicU64::new(0),
269 metadata_cache_hits: AtomicU64::new(0),
270 metadata_cache_misses: AtomicU64::new(0),
271 local_cache_hits: AtomicU64::new(0),
272 local_cache_misses: AtomicU64::new(0),
273 reload_successes: AtomicU64::new(0),
274 reload_failures: AtomicU64::new(0),
275 }
276 }
277
278 pub fn install_provider<F>(
279 &self,
280 kind: ProviderKind,
281 datasource_id: DatasourceId,
282 build: F,
283 ) -> KnowledgeResult<Generation>
284 where
285 F: FnOnce(Generation) -> KnowledgeResult<Arc<dyn ProviderExecutor>>,
286 {
287 let generation = Generation(self.next_generation.fetch_add(1, Ordering::SeqCst) + 1);
288 let previous = self
289 .provider
290 .read()
291 .ok()
292 .and_then(|guard| guard.as_ref().cloned());
293 debug_kdb!(
294 "[kdb] reload provider start kind={kind:?} datasource_id={} target_generation={} previous_generation={}",
295 datasource_id.0,
296 generation.0,
297 previous
298 .as_ref()
299 .map(|handle| handle.generation.0.to_string())
300 .unwrap_or_else(|| "none".to_string())
301 );
302 let provider = match build(generation) {
303 Ok(provider) => provider,
304 Err(err) => {
305 self.reload_failures.fetch_add(1, Ordering::Relaxed);
306 warn_kdb!(
307 "[kdb] reload provider failed kind={kind:?} datasource_id={} target_generation={} err={}",
308 datasource_id.0,
309 generation.0,
310 err
311 );
312 if telemetry_enabled() {
313 telemetry().on_reload(&ReloadTelemetryEvent {
314 outcome: ReloadOutcome::Failure,
315 provider_kind: kind.clone(),
316 });
317 }
318 return Err(err);
319 }
320 };
321 debug_kdb!(
322 "[kdb] install provider kind={kind:?} datasource_id={} generation={}",
323 datasource_id.0,
324 generation.0
325 );
326 let kind_for_handle = kind.clone();
327 let datasource_id_for_handle = datasource_id.clone();
328 let handle = Arc::new(ProviderHandle {
329 provider,
330 datasource_id: datasource_id_for_handle,
331 generation,
332 kind: kind_for_handle,
333 });
334 self.provider_epoch.fetch_add(1, Ordering::AcqRel);
335 {
336 let mut guard = self
337 .provider
338 .write()
339 .expect("runtime provider lock poisoned");
340 *guard = Some(handle);
341 }
342 self.current_generation_value
343 .store(generation.0, Ordering::Release);
344 self.provider_epoch.fetch_add(1, Ordering::Release);
345 self.reload_successes.fetch_add(1, Ordering::Relaxed);
346 if telemetry_enabled() {
347 telemetry().on_reload(&ReloadTelemetryEvent {
348 outcome: ReloadOutcome::Success,
349 provider_kind: kind.clone(),
350 });
351 }
352 debug_kdb!(
353 "[kdb] reload provider success kind={kind:?} datasource_id={} generation={}",
354 datasource_id.0,
355 generation.0
356 );
357 Ok(generation)
358 }
359
360 pub fn configure_result_cache(&self, enabled: bool, capacity: usize, ttl: Duration) {
361 let new_config = ResultCacheConfig {
362 enabled,
363 capacity: capacity.max(1),
364 ttl: ttl.max(Duration::from_millis(1)),
365 };
366 let mut should_reset_cache = false;
367 {
368 let mut guard = self
369 .result_cache_config
370 .write()
371 .expect("runtime result cache config lock poisoned");
372 if guard.capacity != new_config.capacity || (!new_config.enabled && guard.enabled) {
373 should_reset_cache = true;
374 }
375 *guard = new_config;
376 }
377 self.result_cache_enabled
378 .store(new_config.enabled, Ordering::Relaxed);
379 self.result_cache_ttl_ms.store(
380 new_config.ttl.as_millis().min(u128::from(u64::MAX)) as u64,
381 Ordering::Relaxed,
382 );
383
384 if should_reset_cache {
385 let mut cache = self
386 .result_cache
387 .write()
388 .expect("runtime result cache lock poisoned");
389 *cache = LruCache::new(
390 NonZeroUsize::new(new_config.capacity).expect("non-zero result cache capacity"),
391 );
392 }
393 }
394
    /// Returns the generation of the installed provider, or `None` when there is none.
    ///
    /// The fast path avoids the provider lock: it reads the epoch, then the published
    /// generation, then the epoch again. If the first epoch is odd (a swap is in progress) or
    /// the epoch changed between the two reads, the value is re-read under the provider lock.
    pub fn current_generation(&self) -> Option<Generation> {
        let epoch_before = self.provider_epoch.load(Ordering::Acquire);
        // Odd epoch: `install_provider` has started but not finished a swap.
        if epoch_before % 2 == 1 {
            return self.current_generation_from_provider();
        }
        let generation = self.current_generation_value.load(Ordering::Acquire);
        let epoch_after = self.provider_epoch.load(Ordering::Acquire);
        // A swap happened while we were reading; the loaded value may be stale.
        if epoch_before != epoch_after {
            return self.current_generation_from_provider();
        }
        match generation {
            // 0 is the initial value, i.e. nothing has been installed yet.
            0 => None,
            generation => Some(Generation(generation)),
        }
    }
410
411 pub fn snapshot(&self) -> RuntimeSnapshot {
412 let provider = self
413 .provider
414 .read()
415 .ok()
416 .and_then(|guard| guard.as_ref().cloned());
417 let result_cache_config = self
418 .result_cache_config
419 .read()
420 .map(|guard| *guard)
421 .unwrap_or_default();
422 let (result_cache_len, result_cache_capacity) = self
423 .result_cache
424 .read()
425 .map(|cache| (cache.len(), cache.cap().get()))
426 .unwrap_or((0, 0));
427 let (metadata_cache_len, metadata_cache_capacity) =
428 crate::mem::query_util::column_metadata_cache_snapshot();
429 RuntimeSnapshot {
430 provider_kind: provider.as_ref().map(|handle| handle.kind.clone()),
431 datasource_id: provider.as_ref().map(|handle| handle.datasource_id.clone()),
432 generation: provider.as_ref().map(|handle| handle.generation),
433 result_cache_enabled: result_cache_config.enabled,
434 result_cache_len,
435 result_cache_capacity,
436 result_cache_ttl_ms: result_cache_config.ttl.as_millis() as u64,
437 metadata_cache_len,
438 metadata_cache_capacity,
439 result_cache_hits: self.result_cache_hits.load(Ordering::Relaxed),
440 result_cache_misses: self.result_cache_misses.load(Ordering::Relaxed),
441 metadata_cache_hits: self.metadata_cache_hits.load(Ordering::Relaxed),
442 metadata_cache_misses: self.metadata_cache_misses.load(Ordering::Relaxed),
443 local_cache_hits: self.local_cache_hits.load(Ordering::Relaxed),
444 local_cache_misses: self.local_cache_misses.load(Ordering::Relaxed),
445 reload_successes: self.reload_successes.load(Ordering::Relaxed),
446 reload_failures: self.reload_failures.load(Ordering::Relaxed),
447 }
448 }
449
450 pub fn current_metadata_scope(&self) -> MetadataCacheScope {
451 self.provider
452 .read()
453 .ok()
454 .and_then(|guard| guard.as_ref().cloned())
455 .map(|handle| MetadataCacheScope {
456 datasource_id: handle.datasource_id.clone(),
457 generation: handle.generation,
458 })
459 .unwrap_or_else(|| MetadataCacheScope {
460 datasource_id: DatasourceId("sqlite:standalone".to_string()),
461 generation: Generation(0),
462 })
463 }
464
465 pub fn current_provider_kind(&self) -> Option<ProviderKind> {
466 self.provider
467 .read()
468 .ok()
469 .and_then(|guard| guard.as_ref().map(|handle| handle.kind.clone()))
470 }
471
472 pub fn record_result_cache_hit(&self) {
473 self.result_cache_hits.fetch_add(1, Ordering::Relaxed);
474 }
475
476 pub fn record_result_cache_miss(&self) {
477 self.result_cache_misses.fetch_add(1, Ordering::Relaxed);
478 }
479
480 pub fn record_metadata_cache_hit(&self) {
481 self.metadata_cache_hits.fetch_add(1, Ordering::Relaxed);
482 }
483
484 pub fn record_metadata_cache_miss(&self) {
485 self.metadata_cache_misses.fetch_add(1, Ordering::Relaxed);
486 }
487
488 pub fn record_local_cache_hit(&self) {
489 self.local_cache_hits.fetch_add(1, Ordering::Relaxed);
490 }
491
492 pub fn record_local_cache_miss(&self) {
493 self.local_cache_misses.fetch_add(1, Ordering::Relaxed);
494 }
495
496 pub fn execute(&self, req: &QueryRequest) -> KnowledgeResult<QueryResponse> {
497 let handle = self.current_handle()?;
498 self.execute_with_handle(&handle, req)
499 }
500
501 fn execute_with_handle(
502 &self,
503 handle: &Arc<ProviderHandle>,
504 req: &QueryRequest,
505 ) -> KnowledgeResult<QueryResponse> {
506 let use_global_cache =
507 matches!(req.cache_policy, CachePolicy::UseGlobal) && self.result_cache_enabled();
508 if use_global_cache && let Some(hit) = self.fetch_result_cache(handle, req) {
509 self.record_result_cache_hit();
510 if telemetry_enabled() {
511 telemetry().on_cache(&CacheTelemetryEvent {
512 layer: CacheLayer::Result,
513 outcome: CacheOutcome::Hit,
514 provider_kind: Some(handle.kind.clone()),
515 });
516 }
517 debug_kdb!(
518 "[kdb] global result cache hit kind={:?} generation={}",
519 handle.kind,
520 handle.generation.0
521 );
522 return Ok(hit);
523 }
524 if use_global_cache {
525 self.record_result_cache_miss();
526 if telemetry_enabled() {
527 telemetry().on_cache(&CacheTelemetryEvent {
528 layer: CacheLayer::Result,
529 outcome: CacheOutcome::Miss,
530 provider_kind: Some(handle.kind.clone()),
531 });
532 }
533 debug_kdb!(
534 "[kdb] global result cache miss kind={:?} generation={}",
535 handle.kind,
536 handle.generation.0
537 );
538 }
539
540 let params = params_to_fields(&req.params);
541 let mode_tag = query_mode_tag(&req.mode);
542 let started = Instant::now();
543 debug_kdb!(
544 "[kdb] execute query kind={:?} generation={} mode={:?} cache_policy={:?}",
545 handle.kind,
546 handle.generation.0,
547 req.mode,
548 req.cache_policy
549 );
550 let response = match match req.mode {
551 QueryMode::Many => {
552 if params.is_empty() {
553 handle.provider.query(&req.sql).map(QueryResponse::Rows)
554 } else {
555 handle
556 .provider
557 .query_fields(&req.sql, ¶ms)
558 .map(QueryResponse::Rows)
559 }
560 }
561 QueryMode::FirstRow => {
562 if params.is_empty() {
563 handle.provider.query_row(&req.sql).map(QueryResponse::Row)
564 } else {
565 handle
566 .provider
567 .query_named_fields(&req.sql, ¶ms)
568 .map(QueryResponse::Row)
569 }
570 }
571 } {
572 Ok(response) => {
573 if telemetry_enabled() {
574 telemetry().on_query(&QueryTelemetryEvent {
575 provider_kind: handle.kind.clone(),
576 mode: mode_tag,
577 success: true,
578 elapsed: started.elapsed(),
579 });
580 }
581 response
582 }
583 Err(err) => {
584 if telemetry_enabled() {
585 telemetry().on_query(&QueryTelemetryEvent {
586 provider_kind: handle.kind.clone(),
587 mode: mode_tag,
588 success: false,
589 elapsed: started.elapsed(),
590 });
591 }
592 return Err(err);
593 }
594 };
595
596 if use_global_cache {
597 self.save_result_cache(handle, req, response.clone());
598 debug_kdb!(
599 "[kdb] global result cache store kind={:?} generation={}",
600 handle.kind,
601 handle.generation.0
602 );
603 }
604
605 Ok(response)
606 }
607
608 pub fn execute_first_row_fields(
609 &self,
610 sql: &str,
611 params: &[DataField],
612 cache_policy: CachePolicy,
613 ) -> KnowledgeResult<RowData> {
614 let handle = self.current_handle()?;
615 self.execute_first_row_fields_with_handle(&handle, sql, params, cache_policy)
616 }
617
618 fn execute_first_row_fields_with_handle(
619 &self,
620 handle: &Arc<ProviderHandle>,
621 sql: &str,
622 params: &[DataField],
623 cache_policy: CachePolicy,
624 ) -> KnowledgeResult<RowData> {
625 let use_global_cache =
626 matches!(cache_policy, CachePolicy::UseGlobal) && self.result_cache_enabled();
627 if use_global_cache
628 && let Some(hit) = self.fetch_result_cache_by_key(result_cache_key_fields(
629 handle,
630 sql,
631 params,
632 QueryModeTag::FirstRow,
633 ))
634 {
635 self.record_result_cache_hit();
636 if telemetry_enabled() {
637 telemetry().on_cache(&CacheTelemetryEvent {
638 layer: CacheLayer::Result,
639 outcome: CacheOutcome::Hit,
640 provider_kind: Some(handle.kind.clone()),
641 });
642 }
643 return Ok(hit.into_row());
644 }
645 if use_global_cache {
646 self.record_result_cache_miss();
647 if telemetry_enabled() {
648 telemetry().on_cache(&CacheTelemetryEvent {
649 layer: CacheLayer::Result,
650 outcome: CacheOutcome::Miss,
651 provider_kind: Some(handle.kind.clone()),
652 });
653 }
654 }
655
656 let started = Instant::now();
657 let row = if params.is_empty() {
658 handle.provider.query_row(sql)
659 } else {
660 handle.provider.query_named_fields(sql, params)
661 };
662 let row = match row {
663 Ok(row) => {
664 if telemetry_enabled() {
665 telemetry().on_query(&QueryTelemetryEvent {
666 provider_kind: handle.kind.clone(),
667 mode: QueryModeTag::FirstRow,
668 success: true,
669 elapsed: started.elapsed(),
670 });
671 }
672 row
673 }
674 Err(err) => {
675 if telemetry_enabled() {
676 telemetry().on_query(&QueryTelemetryEvent {
677 provider_kind: handle.kind.clone(),
678 mode: QueryModeTag::FirstRow,
679 success: false,
680 elapsed: started.elapsed(),
681 });
682 }
683 return Err(err);
684 }
685 };
686
687 if use_global_cache {
688 self.save_result_cache_by_key(
689 result_cache_key_fields(handle, sql, params, QueryModeTag::FirstRow),
690 QueryResponse::Row(row.clone()),
691 );
692 }
693
694 Ok(row)
695 }
696
697 pub async fn execute_async(&self, req: &QueryRequest) -> KnowledgeResult<QueryResponse> {
698 let handle = self.current_handle()?;
699 if matches!(handle.kind, ProviderKind::SqliteAuthority) {
700 let handle = handle.clone();
701 let req = req.clone();
702 return task::spawn_blocking(move || runtime().execute_with_handle(&handle, &req))
703 .await
704 .map_err(|err| {
705 KnowledgeReason::from_logic()
706 .to_err()
707 .with_detail(format!("knowledge async sqlite query join failed: {err}"))
708 })?;
709 }
710 let use_global_cache =
711 matches!(req.cache_policy, CachePolicy::UseGlobal) && self.result_cache_enabled();
712 if use_global_cache && let Some(hit) = self.fetch_result_cache(&handle, req) {
713 self.record_result_cache_hit();
714 if telemetry_enabled() {
715 telemetry().on_cache(&CacheTelemetryEvent {
716 layer: CacheLayer::Result,
717 outcome: CacheOutcome::Hit,
718 provider_kind: Some(handle.kind.clone()),
719 });
720 }
721 return Ok(hit);
722 }
723 if use_global_cache {
724 self.record_result_cache_miss();
725 if telemetry_enabled() {
726 telemetry().on_cache(&CacheTelemetryEvent {
727 layer: CacheLayer::Result,
728 outcome: CacheOutcome::Miss,
729 provider_kind: Some(handle.kind.clone()),
730 });
731 }
732 }
733
734 let params = params_to_fields(&req.params);
735 let mode_tag = query_mode_tag(&req.mode);
736 let started = Instant::now();
737 let response = match req.mode {
738 QueryMode::Many => {
739 if params.is_empty() {
740 handle
741 .provider
742 .query_async(&req.sql)
743 .await
744 .map(QueryResponse::Rows)
745 } else {
746 handle
747 .provider
748 .query_fields_async(&req.sql, ¶ms)
749 .await
750 .map(QueryResponse::Rows)
751 }
752 }
753 QueryMode::FirstRow => {
754 if params.is_empty() {
755 handle
756 .provider
757 .query_row_async(&req.sql)
758 .await
759 .map(QueryResponse::Row)
760 } else {
761 handle
762 .provider
763 .query_named_fields_async(&req.sql, ¶ms)
764 .await
765 .map(QueryResponse::Row)
766 }
767 }
768 };
769 let response = match response {
770 Ok(response) => {
771 if telemetry_enabled() {
772 telemetry().on_query(&QueryTelemetryEvent {
773 provider_kind: handle.kind.clone(),
774 mode: mode_tag,
775 success: true,
776 elapsed: started.elapsed(),
777 });
778 }
779 response
780 }
781 Err(err) => {
782 if telemetry_enabled() {
783 telemetry().on_query(&QueryTelemetryEvent {
784 provider_kind: handle.kind.clone(),
785 mode: mode_tag,
786 success: false,
787 elapsed: started.elapsed(),
788 });
789 }
790 return Err(err);
791 }
792 };
793
794 if use_global_cache {
795 self.save_result_cache(&handle, req, response.clone());
796 }
797
798 Ok(response)
799 }
800
801 pub async fn execute_first_row_fields_async(
802 &self,
803 sql: &str,
804 params: &[DataField],
805 cache_policy: CachePolicy,
806 ) -> KnowledgeResult<RowData> {
807 let handle = self.current_handle()?;
808 if matches!(handle.kind, ProviderKind::SqliteAuthority) {
809 let handle = handle.clone();
810 let sql = sql.to_string();
811 let params = params.to_vec();
812 return task::spawn_blocking(move || {
813 runtime().execute_first_row_fields_with_handle(&handle, &sql, ¶ms, cache_policy)
814 })
815 .await
816 .map_err(|err| {
817 KnowledgeReason::from_logic().to_err().with_detail(format!(
818 "knowledge async sqlite first-row query join failed: {err}"
819 ))
820 })?;
821 }
822 let use_global_cache =
823 matches!(cache_policy, CachePolicy::UseGlobal) && self.result_cache_enabled();
824 if use_global_cache
825 && let Some(hit) = self.fetch_result_cache_by_key(result_cache_key_fields(
826 &handle,
827 sql,
828 params,
829 QueryModeTag::FirstRow,
830 ))
831 {
832 self.record_result_cache_hit();
833 if telemetry_enabled() {
834 telemetry().on_cache(&CacheTelemetryEvent {
835 layer: CacheLayer::Result,
836 outcome: CacheOutcome::Hit,
837 provider_kind: Some(handle.kind.clone()),
838 });
839 }
840 return Ok(hit.into_row());
841 }
842 if use_global_cache {
843 self.record_result_cache_miss();
844 if telemetry_enabled() {
845 telemetry().on_cache(&CacheTelemetryEvent {
846 layer: CacheLayer::Result,
847 outcome: CacheOutcome::Miss,
848 provider_kind: Some(handle.kind.clone()),
849 });
850 }
851 }
852
853 let started = Instant::now();
854 let row = if params.is_empty() {
855 handle.provider.query_row_async(sql).await
856 } else {
857 handle.provider.query_named_fields_async(sql, params).await
858 };
859 let row = match row {
860 Ok(row) => {
861 if telemetry_enabled() {
862 telemetry().on_query(&QueryTelemetryEvent {
863 provider_kind: handle.kind.clone(),
864 mode: QueryModeTag::FirstRow,
865 success: true,
866 elapsed: started.elapsed(),
867 });
868 }
869 row
870 }
871 Err(err) => {
872 if telemetry_enabled() {
873 telemetry().on_query(&QueryTelemetryEvent {
874 provider_kind: handle.kind.clone(),
875 mode: QueryModeTag::FirstRow,
876 success: false,
877 elapsed: started.elapsed(),
878 });
879 }
880 return Err(err);
881 }
882 };
883
884 if use_global_cache {
885 self.save_result_cache_by_key(
886 result_cache_key_fields(&handle, sql, params, QueryModeTag::FirstRow),
887 QueryResponse::Row(row.clone()),
888 );
889 }
890
891 Ok(row)
892 }
893
894 fn current_handle(&self) -> KnowledgeResult<Arc<ProviderHandle>> {
895 self.provider
896 .read()
897 .expect("runtime provider lock poisoned")
898 .clone()
899 .ok_or_else(|| {
900 KnowledgeReason::from_logic()
901 .to_err()
902 .with_detail("knowledge provider not initialized")
903 })
904 }
905
906 fn current_generation_from_provider(&self) -> Option<Generation> {
907 self.provider
908 .read()
909 .ok()
910 .and_then(|guard| guard.as_ref().map(|handle| handle.generation))
911 }
912
913 fn fetch_result_cache(
914 &self,
915 handle: &ProviderHandle,
916 req: &QueryRequest,
917 ) -> Option<QueryResponse> {
918 self.fetch_result_cache_by_key(result_cache_key(handle, req))
919 }
920
921 fn fetch_result_cache_by_key(&self, key: ResultCacheKey) -> Option<QueryResponse> {
922 if !self.result_cache_enabled() {
923 return None;
924 }
925 let cached = self
926 .result_cache
927 .read()
928 .ok()
929 .and_then(|cache| cache.peek(&key).cloned())?;
930 if cached.cached_at.elapsed() > self.result_cache_ttl() {
931 if let Ok(mut cache) = self.result_cache.write() {
932 let _ = cache.pop(&key);
933 }
934 return None;
935 }
936 Some((*cached.response).clone())
937 }
938
939 fn save_result_cache(
940 &self,
941 handle: &ProviderHandle,
942 req: &QueryRequest,
943 response: QueryResponse,
944 ) {
945 self.save_result_cache_by_key(result_cache_key(handle, req), response);
946 }
947
948 fn save_result_cache_by_key(&self, key: ResultCacheKey, response: QueryResponse) {
949 if let Ok(mut cache) = self.result_cache.write() {
950 cache.put(
951 key,
952 CachedQueryResponse {
953 response: Arc::new(response),
954 cached_at: Instant::now(),
955 },
956 );
957 }
958 }
959
960 #[inline]
961 fn result_cache_enabled(&self) -> bool {
962 self.result_cache_enabled.load(Ordering::Relaxed)
963 }
964
965 #[inline]
966 fn result_cache_ttl(&self) -> Duration {
967 Duration::from_millis(self.result_cache_ttl_ms.load(Ordering::Relaxed))
968 }
969}
970
971pub fn runtime() -> &'static KnowledgeRuntime {
972 static RUNTIME: OnceLock<KnowledgeRuntime> = OnceLock::new();
973 RUNTIME.get_or_init(|| KnowledgeRuntime::new(1024))
974}
975
/// Test-only mutex wrapper used to serialize tests that touch the process-wide runtime.
#[cfg(test)]
pub(crate) struct RuntimeTestGuard(tokio::sync::Mutex<()>);

#[cfg(test)]
impl RuntimeTestGuard {
    /// Acquires the guard from synchronous code via `blocking_lock`.
    ///
    /// The `Result` wrapper never carries an error (the error type is `Infallible`).
    pub(crate) fn lock(&self) -> Result<tokio::sync::MutexGuard<'_, ()>, std::convert::Infallible> {
        let held = self.0.blocking_lock();
        Ok(held)
    }

    /// Acquires the guard from asynchronous code.
    pub(crate) async fn lock_async(&self) -> tokio::sync::MutexGuard<'_, ()> {
        let Self(mutex) = self;
        mutex.lock().await
    }
}
989
/// Returns the process-wide test guard, creating it on first use.
#[cfg(test)]
pub(crate) fn runtime_test_guard() -> &'static RuntimeTestGuard {
    static GUARD: OnceLock<RuntimeTestGuard> = OnceLock::new();
    GUARD.get_or_init(|| {
        let mutex = tokio::sync::Mutex::new(());
        RuntimeTestGuard(mutex)
    })
}
995
996fn result_cache_key(handle: &ProviderHandle, req: &QueryRequest) -> ResultCacheKey {
997 ResultCacheKey {
998 datasource_id: handle.datasource_id.clone(),
999 generation: handle.generation,
1000 query_hash: stable_hash(&req.sql),
1001 params_hash: stable_params_hash(&req.params),
1002 mode: match req.mode {
1003 QueryMode::Many => QueryModeTag::Many,
1004 QueryMode::FirstRow => QueryModeTag::FirstRow,
1005 },
1006 }
1007}
1008
1009fn result_cache_key_fields(
1010 handle: &ProviderHandle,
1011 sql: &str,
1012 params: &[DataField],
1013 mode: QueryModeTag,
1014) -> ResultCacheKey {
1015 ResultCacheKey {
1016 datasource_id: handle.datasource_id.clone(),
1017 generation: handle.generation,
1018 query_hash: stable_hash(sql),
1019 params_hash: stable_field_params_hash(params),
1020 mode,
1021 }
1022}
1023
1024fn query_mode_tag(mode: &QueryMode) -> QueryModeTag {
1025 match mode {
1026 QueryMode::Many => QueryModeTag::Many,
1027 QueryMode::FirstRow => QueryModeTag::FirstRow,
1028 }
1029}
1030
/// Hashes a string with a freshly created `DefaultHasher`.
///
/// The same input gives the same output within one build of the program.
/// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed across Rust releases, so the
/// value should not be persisted — confirm no caller does.
fn stable_hash(value: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(value, &mut state);
    state.finish()
}
1036
1037fn stable_params_hash(params: &[QueryParam]) -> u64 {
1038 let mut hasher = DefaultHasher::new();
1039 for param in params {
1040 param.name.hash(&mut hasher);
1041 match ¶m.value {
1042 QueryValue::Null => 0u8.hash(&mut hasher),
1043 QueryValue::Bool(value) => {
1044 1u8.hash(&mut hasher);
1045 value.hash(&mut hasher);
1046 }
1047 QueryValue::Int(value) => {
1048 2u8.hash(&mut hasher);
1049 value.hash(&mut hasher);
1050 }
1051 QueryValue::Float(value) => {
1052 3u8.hash(&mut hasher);
1053 value.to_bits().hash(&mut hasher);
1054 }
1055 QueryValue::Text(value) => {
1056 4u8.hash(&mut hasher);
1057 value.hash(&mut hasher);
1058 }
1059 }
1060 }
1061 hasher.finish()
1062}
1063
1064fn stable_field_params_hash(params: &[DataField]) -> u64 {
1065 let mut hasher = DefaultHasher::new();
1066 for field in params {
1067 field.get_name().hash(&mut hasher);
1068 match field.get_value() {
1069 Value::Null | Value::Ignore(_) => 0u8.hash(&mut hasher),
1070 Value::Bool(value) => {
1071 1u8.hash(&mut hasher);
1072 value.hash(&mut hasher);
1073 }
1074 Value::Digit(value) => {
1075 2u8.hash(&mut hasher);
1076 value.hash(&mut hasher);
1077 }
1078 Value::Float(value) => {
1079 3u8.hash(&mut hasher);
1080 value.to_bits().hash(&mut hasher);
1081 }
1082 Value::Chars(value) => {
1083 4u8.hash(&mut hasher);
1084 value.hash(&mut hasher);
1085 }
1086 Value::Symbol(value) => {
1087 5u8.hash(&mut hasher);
1088 value.hash(&mut hasher);
1089 }
1090 Value::Time(value) => {
1091 6u8.hash(&mut hasher);
1092 value.hash(&mut hasher);
1093 }
1094 Value::Hex(value) => {
1095 7u8.hash(&mut hasher);
1096 value.to_string().hash(&mut hasher);
1097 }
1098 Value::IpNet(value) => {
1099 8u8.hash(&mut hasher);
1100 value.to_string().hash(&mut hasher);
1101 }
1102 Value::IpAddr(value) => {
1103 9u8.hash(&mut hasher);
1104 value.hash(&mut hasher);
1105 }
1106 Value::Obj(value) => {
1107 10u8.hash(&mut hasher);
1108 format!("{:?}", value).hash(&mut hasher);
1109 }
1110 Value::Array(value) => {
1111 11u8.hash(&mut hasher);
1112 format!("{:?}", value).hash(&mut hasher);
1113 }
1114 Value::Domain(value) => {
1115 12u8.hash(&mut hasher);
1116 value.0.hash(&mut hasher);
1117 }
1118 Value::Url(value) => {
1119 13u8.hash(&mut hasher);
1120 value.0.hash(&mut hasher);
1121 }
1122 Value::Email(value) => {
1123 14u8.hash(&mut hasher);
1124 value.0.hash(&mut hasher);
1125 }
1126 Value::IdCard(value) => {
1127 15u8.hash(&mut hasher);
1128 value.0.hash(&mut hasher);
1129 }
1130 Value::MobilePhone(value) => {
1131 16u8.hash(&mut hasher);
1132 value.0.hash(&mut hasher);
1133 }
1134 }
1135 }
1136 hasher.finish()
1137}
1138
1139pub fn fields_to_params(params: &[DataField]) -> Vec<QueryParam> {
1140 params
1141 .iter()
1142 .map(|field| {
1143 let value = match field.get_value() {
1144 Value::Null | Value::Ignore(_) => QueryValue::Null,
1145 Value::Bool(value) => QueryValue::Bool(*value),
1146 Value::Digit(value) => QueryValue::Int(*value),
1147 Value::Float(value) => QueryValue::Float(*value),
1148 Value::Chars(value) => QueryValue::Text(value.to_string()),
1149 Value::Symbol(value) => QueryValue::Text(value.to_string()),
1150 Value::Time(value) => QueryValue::Text(value.to_string()),
1151 Value::Hex(value) => QueryValue::Text(value.to_string()),
1152 Value::IpNet(value) => QueryValue::Text(value.to_string()),
1153 Value::IpAddr(value) => QueryValue::Text(value.to_string()),
1154 Value::Obj(value) => QueryValue::Text(format!("{:?}", value)),
1155 Value::Array(value) => QueryValue::Text(format!("{:?}", value)),
1156 Value::Domain(value) => QueryValue::Text(value.0.to_string()),
1157 Value::Url(value) => QueryValue::Text(value.0.to_string()),
1158 Value::Email(value) => QueryValue::Text(value.0.to_string()),
1159 Value::IdCard(value) => QueryValue::Text(value.0.to_string()),
1160 Value::MobilePhone(value) => QueryValue::Text(value.0.to_string()),
1161 };
1162 QueryParam {
1163 name: field.get_name().to_string(),
1164 value,
1165 }
1166 })
1167 .collect()
1168}
1169
1170pub fn params_to_fields(params: &[QueryParam]) -> Vec<DataField> {
1171 params
1172 .iter()
1173 .map(|param| match ¶m.value {
1174 QueryValue::Null => {
1175 DataField::new(DataType::default(), param.name.clone(), Value::Null)
1176 }
1177 QueryValue::Bool(value) => {
1178 DataField::new(DataType::default(), param.name.clone(), Value::Bool(*value))
1179 }
1180 QueryValue::Int(value) => DataField::from_digit(param.name.clone(), *value),
1181 QueryValue::Float(value) => DataField::from_float(param.name.clone(), *value),
1182 QueryValue::Text(value) => DataField::from_chars(param.name.clone(), value.clone()),
1183 })
1184 .collect()
1185}
1186
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Arc;
    use wp_model_core::model::Value;

    /// Provider stub that answers every query with a single `value` field.
    struct TestProvider {
        value: &'static str,
    }

    #[async_trait]
    impl ProviderExecutor for TestProvider {
        fn query(&self, _sql: &str) -> KnowledgeResult<Vec<RowData>> {
            Ok(vec![vec![DataField::from_chars("value", self.value)]])
        }

        fn query_fields(&self, _sql: &str, _params: &[DataField]) -> KnowledgeResult<Vec<RowData>> {
            self.query("")
        }

        fn query_row(&self, _sql: &str) -> KnowledgeResult<RowData> {
            Ok(vec![DataField::from_chars("value", self.value)])
        }

        fn query_named_fields(
            &self,
            _sql: &str,
            _params: &[DataField],
        ) -> KnowledgeResult<RowData> {
            self.query_row("")
        }
    }

    /// Hashing the same parameter list twice gives the same value.
    #[test]
    fn query_param_hash_is_stable() {
        let params = vec![
            QueryParam {
                name: ":id".to_string(),
                value: QueryValue::Int(7),
            },
            QueryParam {
                name: ":name".to_string(),
                value: QueryValue::Text("abc".to_string()),
            },
        ];
        assert_eq!(stable_params_hash(&params), stable_params_hash(&params));
    }

    /// Text survives the field -> param -> field round trip unchanged.
    #[test]
    fn fields_to_params_preserves_raw_chars_value() {
        let fields = [DataField::from_chars(
            ":name".to_string(),
            "令狐冲".to_string(),
        )];
        let params = fields_to_params(&fields);
        assert_eq!(params.len(), 1);
        match &params[0].value {
            QueryValue::Text(value) => assert_eq!(value, "令狐冲"),
            other => panic!("unexpected param value: {other:?}"),
        }
        let roundtrip = params_to_fields(&params);
        assert!(matches!(roundtrip[0].get_value(), Value::Chars(_)));
    }

    /// A handle captured before a reload keeps querying the provider it was created with.
    #[tokio::test(flavor = "current_thread")]
    async fn sqlite_async_bridge_keeps_captured_handle_after_reload() {
        let _guard = runtime_test_guard().lock_async().await;
        runtime()
            .install_provider(
                ProviderKind::SqliteAuthority,
                DatasourceId("sqlite:old".to_string()),
                |_generation| Ok(Arc::new(TestProvider { value: "old" })),
            )
            .expect("install old provider");
        let old_handle = runtime().current_handle().expect("current old handle");

        runtime()
            .install_provider(
                ProviderKind::SqliteAuthority,
                DatasourceId("sqlite:new".to_string()),
                |_generation| Ok(Arc::new(TestProvider { value: "new" })),
            )
            .expect("install new provider");

        let req = QueryRequest::first_row("SELECT value", Vec::new(), CachePolicy::Bypass);
        let row = task::spawn_blocking(move || runtime().execute_with_handle(&old_handle, &req))
            .await
            .expect("join sqlite bridge")
            .expect("execute old handle")
            .into_row();
        assert_eq!(row[0].to_string(), "chars(old)");
    }
}