1use base64::Engine as _;
2use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::sync::{Arc, Mutex, OnceLock};
6
7#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
10#[serde(rename_all = "snake_case")]
11pub enum LuaRuntimeDatabaseProviderMode {
12 #[default]
15 DynamicLibrary,
16 HostCallback,
19 SpaceController,
22}
23
24#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
27#[serde(rename_all = "snake_case")]
28pub enum LuaRuntimeDatabaseCallbackMode {
29 #[default]
32 Standard,
33 Json,
36}
37
38#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
41#[serde(rename_all = "snake_case")]
42pub enum RuntimeDatabaseKind {
43 Sqlite,
46 LanceDb,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
54pub struct RuntimeDatabaseBindingContext {
55 pub space_label: String,
58 pub skill_id: String,
61 pub binding_tag: String,
64 pub root_name: String,
67 pub space_root: String,
70 pub skill_dir: String,
73 pub skill_dir_name: String,
76 pub database_kind: RuntimeDatabaseKind,
79 pub default_database_path: String,
82}
83
84impl RuntimeDatabaseBindingContext {
85 pub fn new(
88 space_label: impl Into<String>,
89 skill_id: impl Into<String>,
90 root_name: impl Into<String>,
91 space_root: impl Into<String>,
92 skill_dir: impl Into<String>,
93 skill_dir_name: impl Into<String>,
94 database_kind: RuntimeDatabaseKind,
95 default_database_path: impl Into<String>,
96 ) -> Self {
97 let space_label = space_label.into();
98 let skill_id = skill_id.into();
99 Self {
100 binding_tag: format!("{}-{}", space_label, skill_id),
101 space_label,
102 skill_id,
103 root_name: root_name.into(),
104 space_root: space_root.into(),
105 skill_dir: skill_dir.into(),
106 skill_dir_name: skill_dir_name.into(),
107 database_kind,
108 default_database_path: default_database_path.into(),
109 }
110 }
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
116#[serde(rename_all = "snake_case")]
117pub enum RuntimeSqliteProviderAction {
118 ExecuteScript,
121 ExecuteBatch,
124 QueryJson,
127 QueryStream,
130 QueryStreamWaitMetrics,
133 QueryStreamChunk,
136 QueryStreamClose,
139 TokenizeText,
142 UpsertCustomWord,
145 RemoveCustomWord,
148 ListCustomWords,
151 EnsureFtsIndex,
154 RebuildFtsIndex,
157 UpsertFtsDocument,
160 DeleteFtsDocument,
163 SearchFts,
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
171#[serde(rename_all = "snake_case")]
172pub enum RuntimeLanceDbProviderAction {
173 CreateTable,
176 VectorUpsert,
179 VectorSearch,
182 Delete,
185 DropTable,
188}
189
190#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
193pub struct RuntimeSqliteProviderRequest {
194 pub action: RuntimeSqliteProviderAction,
197 pub binding: RuntimeDatabaseBindingContext,
200 pub input: Value,
203}
204
205#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
208pub struct RuntimeLanceDbProviderRequest {
209 pub action: RuntimeLanceDbProviderAction,
212 pub binding: RuntimeDatabaseBindingContext,
215 pub input: Value,
218}
219
220pub type RuntimeSqliteProviderCallback =
223 Arc<dyn Fn(&RuntimeSqliteProviderRequest) -> Result<Value, String> + Send + Sync>;
224
225pub type RuntimeLanceDbProviderCallback = Arc<
228 dyn Fn(&RuntimeLanceDbProviderRequest) -> Result<RuntimeLanceDbProviderResult, String>
229 + Send
230 + Sync,
231>;
232
233pub type RuntimeSqliteProviderJsonCallback =
236 Arc<dyn Fn(&str) -> Result<String, String> + Send + Sync>;
237
238pub type RuntimeLanceDbProviderJsonCallback =
241 Arc<dyn Fn(&str) -> Result<String, String> + Send + Sync>;
242
243#[derive(Clone, Default)]
246pub(crate) struct RuntimeDatabaseProviderCallbacks {
247 sqlite_standard: Option<RuntimeSqliteProviderCallback>,
250 lancedb_standard: Option<RuntimeLanceDbProviderCallback>,
253 sqlite_json: Option<RuntimeSqliteProviderJsonCallback>,
256 lancedb_json: Option<RuntimeLanceDbProviderJsonCallback>,
259}
260
261impl RuntimeDatabaseProviderCallbacks {
262 pub(crate) fn capture_process_defaults() -> Result<Self, String> {
265 Ok(Self {
266 sqlite_standard: take_optional_callback(sqlite_provider_callback_registry())?,
267 lancedb_standard: take_optional_callback(lancedb_provider_callback_registry())?,
268 sqlite_json: take_optional_callback(sqlite_provider_json_callback_registry())?,
269 lancedb_json: take_optional_callback(lancedb_provider_json_callback_registry())?,
270 })
271 }
272
273 pub(crate) fn has_sqlite_provider_callback_for_mode(
276 &self,
277 callback_mode: LuaRuntimeDatabaseCallbackMode,
278 ) -> bool {
279 match callback_mode {
280 LuaRuntimeDatabaseCallbackMode::Standard => self.sqlite_standard.is_some(),
281 LuaRuntimeDatabaseCallbackMode::Json => self.sqlite_json.is_some(),
282 }
283 }
284
285 pub(crate) fn has_lancedb_provider_callback_for_mode(
288 &self,
289 callback_mode: LuaRuntimeDatabaseCallbackMode,
290 ) -> bool {
291 match callback_mode {
292 LuaRuntimeDatabaseCallbackMode::Standard => self.lancedb_standard.is_some(),
293 LuaRuntimeDatabaseCallbackMode::Json => self.lancedb_json.is_some(),
294 }
295 }
296
297 pub(crate) fn dispatch_sqlite_provider_request(
300 &self,
301 request: &RuntimeSqliteProviderRequest,
302 callback_mode: LuaRuntimeDatabaseCallbackMode,
303 ) -> Result<Value, String> {
304 match callback_mode {
305 LuaRuntimeDatabaseCallbackMode::Standard => {
306 let callback = self.sqlite_standard.clone().ok_or_else(|| {
307 "SQLite host-callback mode requires one registered standard callback"
308 .to_string()
309 })?;
310 callback(request)
311 }
312 LuaRuntimeDatabaseCallbackMode::Json => {
313 let callback = self.sqlite_json.clone().ok_or_else(|| {
314 "SQLite host-callback JSON mode requires one registered JSON callback"
315 .to_string()
316 })?;
317 let request_json = serde_json::to_string(request).map_err(|error| {
318 format!("failed to encode sqlite provider request: {}", error)
319 })?;
320 let response_json = callback(&request_json)?;
321 serde_json::from_str::<Value>(&response_json).map_err(|error| {
322 format!("failed to parse sqlite provider response json: {}", error)
323 })
324 }
325 }
326 }
327
328 pub(crate) fn dispatch_lancedb_provider_request(
331 &self,
332 request: &RuntimeLanceDbProviderRequest,
333 callback_mode: LuaRuntimeDatabaseCallbackMode,
334 ) -> Result<RuntimeLanceDbProviderResult, String> {
335 match callback_mode {
336 LuaRuntimeDatabaseCallbackMode::Standard => {
337 let callback = self.lancedb_standard.clone().ok_or_else(|| {
338 "LanceDB host-callback mode requires one registered standard callback"
339 .to_string()
340 })?;
341 callback(request)
342 }
343 LuaRuntimeDatabaseCallbackMode::Json => {
344 let callback = self.lancedb_json.clone().ok_or_else(|| {
345 "LanceDB host-callback JSON mode requires one registered JSON callback"
346 .to_string()
347 })?;
348 let request_json = serde_json::to_string(request).map_err(|error| {
349 format!("failed to encode lancedb provider request: {}", error)
350 })?;
351 let response_json = callback(&request_json)?;
352 let value: Value = serde_json::from_str(&response_json).map_err(|error| {
353 format!("failed to parse lancedb provider response json: {}", error)
354 })?;
355 let meta = value
356 .get("meta")
357 .cloned()
358 .unwrap_or_else(|| Value::Object(Default::default()));
359 let bytes = value
360 .get("data_base64")
361 .and_then(Value::as_str)
362 .map(|text| {
363 BASE64_STANDARD.decode(text.as_bytes()).map_err(|error| {
364 format!("failed to decode lancedb provider data_base64: {}", error)
365 })
366 })
367 .transpose()?
368 .unwrap_or_default();
369 Ok(RuntimeLanceDbProviderResult::binary(meta, bytes))
370 }
371 }
372 }
373}
374
375#[derive(Debug, Clone, PartialEq)]
378pub struct RuntimeLanceDbProviderResult {
379 pub meta: Value,
382 pub bytes: Vec<u8>,
385}
386
387impl RuntimeLanceDbProviderResult {
388 pub fn json(meta: Value) -> Self {
391 Self {
392 meta,
393 bytes: Vec::new(),
394 }
395 }
396
397 pub fn binary(meta: Value, bytes: Vec<u8>) -> Self {
400 Self { meta, bytes }
401 }
402}
403
404pub fn set_sqlite_provider_callback(callback: Option<RuntimeSqliteProviderCallback>) {
407 let registry = sqlite_provider_callback_registry();
408 let mut guard = registry.lock().unwrap();
409 *guard = callback;
410}
411
412pub fn set_lancedb_provider_callback(callback: Option<RuntimeLanceDbProviderCallback>) {
415 let registry = lancedb_provider_callback_registry();
416 let mut guard = registry.lock().unwrap();
417 *guard = callback;
418}
419
420pub fn set_sqlite_provider_json_callback(callback: Option<RuntimeSqliteProviderJsonCallback>) {
423 let registry = sqlite_provider_json_callback_registry();
424 let mut guard = registry.lock().unwrap();
425 *guard = callback;
426}
427
428pub fn set_lancedb_provider_json_callback(callback: Option<RuntimeLanceDbProviderJsonCallback>) {
431 let registry = lancedb_provider_json_callback_registry();
432 let mut guard = registry.lock().unwrap();
433 *guard = callback;
434}
435
/// Returns a clone of whatever value `registry` currently holds.
///
/// Despite the name, the registry is left untouched: the stored value is
/// cloned, not moved out. For the `Arc` callback types this only bumps a
/// reference count.
///
/// # Errors
///
/// Returns an error string if the registry mutex is poisoned.
fn take_optional_callback<T: Clone>(
    registry: &'static Mutex<Option<T>>,
) -> Result<Option<T>, String> {
    match registry.lock() {
        Ok(current) => Ok(current.as_ref().cloned()),
        Err(_) => Err("Database provider callback registry lock poisoned".to_string()),
    }
}
446
447fn sqlite_provider_callback_registry() -> &'static Mutex<Option<RuntimeSqliteProviderCallback>> {
450 static REGISTRY: OnceLock<Mutex<Option<RuntimeSqliteProviderCallback>>> = OnceLock::new();
451 REGISTRY.get_or_init(|| Mutex::new(None))
452}
453
454fn lancedb_provider_callback_registry() -> &'static Mutex<Option<RuntimeLanceDbProviderCallback>> {
457 static REGISTRY: OnceLock<Mutex<Option<RuntimeLanceDbProviderCallback>>> = OnceLock::new();
458 REGISTRY.get_or_init(|| Mutex::new(None))
459}
460
461fn sqlite_provider_json_callback_registry()
464-> &'static Mutex<Option<RuntimeSqliteProviderJsonCallback>> {
465 static REGISTRY: OnceLock<Mutex<Option<RuntimeSqliteProviderJsonCallback>>> = OnceLock::new();
466 REGISTRY.get_or_init(|| Mutex::new(None))
467}
468
469fn lancedb_provider_json_callback_registry()
472-> &'static Mutex<Option<RuntimeLanceDbProviderJsonCallback>> {
473 static REGISTRY: OnceLock<Mutex<Option<RuntimeLanceDbProviderJsonCallback>>> = OnceLock::new();
474 REGISTRY.get_or_init(|| Mutex::new(None))
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::{Mutex, OnceLock};

    /// Serializes tests that mutate the process-wide callback registries.
    fn database_callback_test_lock() -> &'static Mutex<()> {
        static SERIAL: OnceLock<Mutex<()>> = OnceLock::new();
        SERIAL.get_or_init(Mutex::default)
    }

    /// Captures the process-wide callbacks on creation and writes them back on
    /// drop, so a test leaves the registries as it found them even if it panics.
    struct ProcessCallbackRestoreGuard {
        snapshot: RuntimeDatabaseProviderCallbacks,
    }

    impl ProcessCallbackRestoreGuard {
        fn capture() -> Self {
            let snapshot = RuntimeDatabaseProviderCallbacks::capture_process_defaults()
                .expect("capture callback snapshot");
            Self { snapshot }
        }
    }

    impl Drop for ProcessCallbackRestoreGuard {
        fn drop(&mut self) {
            let RuntimeDatabaseProviderCallbacks {
                sqlite_standard,
                lancedb_standard,
                sqlite_json,
                lancedb_json,
            } = &self.snapshot;
            set_sqlite_provider_callback(sqlite_standard.clone());
            set_lancedb_provider_callback(lancedb_standard.clone());
            set_sqlite_provider_json_callback(sqlite_json.clone());
            set_lancedb_provider_json_callback(lancedb_json.clone());
        }
    }

    /// Builds a binding context rooted at a fixed fake test directory.
    fn sample_binding_context(kind: RuntimeDatabaseKind) -> RuntimeDatabaseBindingContext {
        RuntimeDatabaseBindingContext::new(
            "ROOT",
            "test-skill",
            "ROOT",
            "D:/runtime-test-root/__database",
            "D:/runtime-test-root/skills/test-skill",
            "test-skill",
            kind,
            "D:/runtime-test-root/__database/default.db",
        )
    }

    /// Registers one callback of every flavour, each answering with a source
    /// label ending in `tag` (e.g. `sqlite-standard-a`), then snapshots the
    /// registries.
    fn install_tagged_callbacks(
        tag: &'static str,
        capture_message: &str,
    ) -> RuntimeDatabaseProviderCallbacks {
        set_sqlite_provider_callback(Some(Arc::new(move |_| {
            Ok(json!({ "source": format!("sqlite-standard-{tag}") }))
        })));
        set_sqlite_provider_json_callback(Some(Arc::new(move |_| {
            Ok(json!({ "source": format!("sqlite-json-{tag}") }).to_string())
        })));
        set_lancedb_provider_callback(Some(Arc::new(move |_| {
            Ok(RuntimeLanceDbProviderResult::json(
                json!({ "source": format!("lancedb-standard-{tag}") }),
            ))
        })));
        set_lancedb_provider_json_callback(Some(Arc::new(move |_| {
            Ok(json!({ "meta": { "source": format!("lancedb-json-{tag}") } }).to_string())
        })));
        RuntimeDatabaseProviderCallbacks::capture_process_defaults().expect(capture_message)
    }

    #[test]
    fn captured_callback_snapshots_stay_engine_scoped() {
        let _serial_guard = database_callback_test_lock()
            .lock()
            .expect("lock callback test guard");
        // Declared after the serial guard so it is dropped (and restores) first.
        let _restore_guard = ProcessCallbackRestoreGuard::capture();

        let snapshot_a = install_tagged_callbacks("a", "capture callback snapshot A");
        let snapshot_b = install_tagged_callbacks("b", "capture callback snapshot B");

        let sqlite_request = RuntimeSqliteProviderRequest {
            action: RuntimeSqliteProviderAction::QueryJson,
            binding: sample_binding_context(RuntimeDatabaseKind::Sqlite),
            input: json!({ "sql": "select 1" }),
        };
        let lancedb_request = RuntimeLanceDbProviderRequest {
            action: RuntimeLanceDbProviderAction::VectorSearch,
            binding: sample_binding_context(RuntimeDatabaseKind::LanceDb),
            input: json!({ "table": "demo" }),
        };

        let snapshots = [(&snapshot_a, "a", "A"), (&snapshot_b, "b", "B")];
        let modes = [
            (LuaRuntimeDatabaseCallbackMode::Standard, "standard"),
            (LuaRuntimeDatabaseCallbackMode::Json, "json"),
        ];

        // Each snapshot must keep answering with the callbacks it captured,
        // even though the registries were overwritten after snapshot A.
        for (snapshot, tag, label) in snapshots {
            for (mode, flavour) in modes {
                let response = snapshot
                    .dispatch_sqlite_provider_request(&sqlite_request, mode)
                    .unwrap_or_else(|error| {
                        panic!("dispatch sqlite {flavour} {label}: {error:?}")
                    });
                assert_eq!(
                    response,
                    json!({ "source": format!("sqlite-{flavour}-{tag}") })
                );
            }
        }

        for (snapshot, tag, label) in snapshots {
            for (mode, flavour) in modes {
                let result = snapshot
                    .dispatch_lancedb_provider_request(&lancedb_request, mode)
                    .unwrap_or_else(|error| {
                        panic!("dispatch lancedb {flavour} {label}: {error:?}")
                    });
                assert_eq!(
                    result.meta,
                    json!({ "source": format!("lancedb-{flavour}-{tag}") })
                );
            }
        }
    }
}