1use std::path::Path;
4use std::sync::Mutex;
5
6use rusqlite::Connection;
7use serde_json::Value;
8use uuid::Uuid;
9use worldinterface_core::id::{FlowRunId, NodeId};
10
11use crate::error::ContextStoreError;
12use crate::store::ContextStore;
13
14pub struct SqliteContextStore {
20 conn: Mutex<Connection>,
21}
22
23impl SqliteContextStore {
24 pub fn open(path: impl AsRef<Path>) -> Result<Self, ContextStoreError> {
29 let conn =
30 Connection::open(path).map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
31 Self::initialize(conn)
32 }
33
34 pub fn in_memory() -> Result<Self, ContextStoreError> {
36 let conn = Connection::open_in_memory()
37 .map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
38 Self::initialize(conn)
39 }
40
41 fn initialize(conn: Connection) -> Result<Self, ContextStoreError> {
42 conn.pragma_update(None, "journal_mode", "WAL")
43 .map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
44 conn.pragma_update(None, "busy_timeout", 5000)
45 .map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
46 conn.pragma_update(None, "synchronous", "FULL")
50 .map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
51
52 conn.execute_batch(
53 "CREATE TABLE IF NOT EXISTS outputs (
54 flow_run_id TEXT NOT NULL,
55 node_id TEXT NOT NULL,
56 value BLOB NOT NULL,
57 written_at TEXT NOT NULL,
58 PRIMARY KEY (flow_run_id, node_id)
59 ) STRICT;
60
61 CREATE TABLE IF NOT EXISTS globals (
62 key TEXT NOT NULL PRIMARY KEY,
63 value BLOB NOT NULL,
64 written_at TEXT NOT NULL
65 ) STRICT;",
66 )
67 .map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
68
69 Ok(Self { conn: Mutex::new(conn) })
70 }
71}
72
73impl ContextStore for SqliteContextStore {
74 fn put(
75 &self,
76 flow_run_id: FlowRunId,
77 node_id: NodeId,
78 value: &Value,
79 ) -> Result<(), ContextStoreError> {
80 let conn = self.conn.lock().map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
81 let value_bytes = serde_json::to_vec(value)?;
82 let now = chrono::Utc::now().to_rfc3339();
83
84 match conn.execute(
85 "INSERT INTO outputs (flow_run_id, node_id, value, written_at) VALUES (?1, ?2, ?3, ?4)",
86 rusqlite::params![flow_run_id.to_string(), node_id.to_string(), value_bytes, now,],
87 ) {
88 Ok(_) => Ok(()),
89 Err(rusqlite::Error::SqliteFailure(err, _))
90 if err.code == rusqlite::ErrorCode::ConstraintViolation =>
91 {
92 Err(ContextStoreError::AlreadyExists { flow_run_id, node_id })
93 }
94 Err(e) => Err(ContextStoreError::StorageError(e.to_string())),
95 }
96 }
97
98 fn get(
99 &self,
100 flow_run_id: FlowRunId,
101 node_id: NodeId,
102 ) -> Result<Option<Value>, ContextStoreError> {
103 let conn = self.conn.lock().map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
104
105 let mut stmt = conn
106 .prepare("SELECT value FROM outputs WHERE flow_run_id = ?1 AND node_id = ?2")
107 .map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
108
109 let mut rows = stmt
110 .query(rusqlite::params![flow_run_id.to_string(), node_id.to_string(),])
111 .map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
112
113 match rows.next().map_err(|e| ContextStoreError::StorageError(e.to_string()))? {
114 Some(row) => {
115 let bytes: Vec<u8> =
116 row.get(0).map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
117 let value: Value = serde_json::from_slice(&bytes)
118 .map_err(|e| ContextStoreError::DeserializationFailed { source: e })?;
119 Ok(Some(value))
120 }
121 None => Ok(None),
122 }
123 }
124
125 fn list_keys(&self, flow_run_id: FlowRunId) -> Result<Vec<NodeId>, ContextStoreError> {
126 let conn = self.conn.lock().map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
127
128 let mut stmt = conn
129 .prepare("SELECT node_id FROM outputs WHERE flow_run_id = ?1")
130 .map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
131
132 let rows = stmt
133 .query_map(rusqlite::params![flow_run_id.to_string()], |row| {
134 let node_id_str: String = row.get(0)?;
135 Ok(node_id_str)
136 })
137 .map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
138
139 let mut keys = Vec::new();
140 for row in rows {
141 let node_id_str = row.map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
142 let uuid = Uuid::parse_str(&node_id_str)
143 .map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
144 keys.push(NodeId::from(uuid));
145 }
146 Ok(keys)
147 }
148
149 fn put_global(&self, key: &str, value: &Value) -> Result<(), ContextStoreError> {
150 let conn = self.conn.lock().map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
151 let value_bytes = serde_json::to_vec(value)?;
152 let now = chrono::Utc::now().to_rfc3339();
153
154 match conn.execute(
155 "INSERT INTO globals (key, value, written_at) VALUES (?1, ?2, ?3)",
156 rusqlite::params![key, value_bytes, now],
157 ) {
158 Ok(_) => Ok(()),
159 Err(rusqlite::Error::SqliteFailure(err, _))
160 if err.code == rusqlite::ErrorCode::ConstraintViolation =>
161 {
162 Err(ContextStoreError::GlobalAlreadyExists { key: key.to_string() })
163 }
164 Err(e) => Err(ContextStoreError::StorageError(e.to_string())),
165 }
166 }
167
168 fn upsert_global(&self, key: &str, value: &Value) -> Result<(), ContextStoreError> {
169 let conn = self.conn.lock().map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
170 let value_bytes = serde_json::to_vec(value)?;
171 let now = chrono::Utc::now().to_rfc3339();
172
173 conn.execute(
174 "INSERT INTO globals (key, value, written_at) VALUES (?1, ?2, ?3)
175 ON CONFLICT(key) DO UPDATE SET value = excluded.value, written_at = \
176 excluded.written_at",
177 rusqlite::params![key, value_bytes, now],
178 )
179 .map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
180
181 Ok(())
182 }
183
184 fn get_global(&self, key: &str) -> Result<Option<Value>, ContextStoreError> {
185 let conn = self.conn.lock().map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
186
187 let mut stmt = conn
188 .prepare("SELECT value FROM globals WHERE key = ?1")
189 .map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
190
191 let mut rows = stmt
192 .query(rusqlite::params![key])
193 .map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
194
195 match rows.next().map_err(|e| ContextStoreError::StorageError(e.to_string()))? {
196 Some(row) => {
197 let bytes: Vec<u8> =
198 row.get(0).map_err(|e| ContextStoreError::StorageError(e.to_string()))?;
199 let value: Value = serde_json::from_slice(&bytes)
200 .map_err(|e| ContextStoreError::DeserializationFailed { source: e })?;
201 Ok(Some(value))
202 }
203 None => Ok(None),
204 }
205 }
206}
207
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::thread;

    use serde_json::json;
    use worldinterface_core::id::{FlowRunId, NodeId};

    use super::*;
    use crate::store::ContextStore;

    /// Builds a fresh in-memory store, panicking if SQLite cannot create one.
    fn memory_store() -> SqliteContextStore {
        SqliteContextStore::in_memory().unwrap()
    }

    /// Generates a new flow-run id together with a new node id.
    fn new_ids() -> (FlowRunId, NodeId) {
        (FlowRunId::new(), NodeId::new())
    }

    /// Creates a temporary directory and the path of a not-yet-created database in it.
    ///
    /// The returned guard must stay alive for as long as the path is used.
    fn scratch_db() -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.db");
        (dir, path)
    }

    // --- put / get ---------------------------------------------------------

    #[test]
    fn put_and_get_roundtrip() {
        let store = memory_store();
        let (flow, node) = new_ids();
        let payload = json!({"key": "value", "num": 42});

        store.put(flow, node, &payload).unwrap();

        let fetched = store.get(flow, node).unwrap().unwrap();
        assert_eq!(payload, fetched);
    }

    #[test]
    fn put_returns_ok_on_first_write() {
        let store = memory_store();
        let (flow, node) = new_ids();

        let outcome = store.put(flow, node, &json!("hello"));
        assert!(outcome.is_ok());
    }

    #[test]
    fn put_rejects_duplicate() {
        let store = memory_store();
        let (flow, node) = new_ids();
        store.put(flow, node, &json!(1)).unwrap();

        let err = store.put(flow, node, &json!(2)).unwrap_err();
        assert!(
            matches!(err, ContextStoreError::AlreadyExists { .. }),
            "expected AlreadyExists, got: {err:?}"
        );
    }

    #[test]
    fn put_duplicate_preserves_original() {
        let store = memory_store();
        let (flow, node) = new_ids();
        let original = json!({"original": true});
        let replacement = json!({"original": false});

        store.put(flow, node, &original).unwrap();
        // The second write is expected to be refused; only its effect matters here.
        let _ = store.put(flow, node, &replacement);

        let fetched = store.get(flow, node).unwrap().unwrap();
        assert_eq!(original, fetched);
    }

    #[test]
    fn put_different_nodes_same_flow() {
        let store = memory_store();
        let flow = FlowRunId::new();
        let (first_node, second_node) = (NodeId::new(), NodeId::new());

        store.put(flow, first_node, &json!("a")).unwrap();
        store.put(flow, second_node, &json!("b")).unwrap();

        assert_eq!(store.get(flow, first_node).unwrap().unwrap(), json!("a"));
        assert_eq!(store.get(flow, second_node).unwrap().unwrap(), json!("b"));
    }

    #[test]
    fn put_same_node_different_flows() {
        let store = memory_store();
        let (first_flow, second_flow) = (FlowRunId::new(), FlowRunId::new());
        let node = NodeId::new();

        store.put(first_flow, node, &json!(1)).unwrap();
        store.put(second_flow, node, &json!(2)).unwrap();

        assert_eq!(store.get(first_flow, node).unwrap().unwrap(), json!(1));
        assert_eq!(store.get(second_flow, node).unwrap().unwrap(), json!(2));
    }

    #[test]
    fn put_complex_json_value() {
        let store = memory_store();
        let (flow, node) = new_ids();
        let payload = json!({
            "nested": {"deep": {"value": [1, 2, 3]}},
            "null_field": null,
            "float": 1.234,
            "bool": true,
            "empty_array": [],
            "empty_object": {}
        });

        store.put(flow, node, &payload).unwrap();
        assert_eq!(store.get(flow, node).unwrap().unwrap(), payload);
    }

    #[test]
    fn put_null_value() {
        let store = memory_store();
        let (flow, node) = new_ids();

        store.put(flow, node, &Value::Null).unwrap();
        assert_eq!(store.get(flow, node).unwrap().unwrap(), Value::Null);
    }

    #[test]
    fn get_nonexistent_returns_none() {
        let store = memory_store();
        let (flow, node) = new_ids();

        assert_eq!(store.get(flow, node).unwrap(), None);
    }

    #[test]
    fn get_after_put_returns_some() {
        let store = memory_store();
        let (flow, node) = new_ids();
        let payload = json!("test");

        store.put(flow, node, &payload).unwrap();

        let fetched = store.get(flow, node).unwrap();
        assert!(fetched.is_some());
    }

    // --- list_keys ---------------------------------------------------------

    #[test]
    fn list_keys_empty_flow() {
        let store = memory_store();
        let flow = FlowRunId::new();

        let keys = store.list_keys(flow).unwrap();
        assert!(keys.is_empty());
    }

    #[test]
    fn list_keys_returns_written_nodes() {
        let store = memory_store();
        let flow = FlowRunId::new();
        let first = NodeId::new();
        let second = NodeId::new();
        let third = NodeId::new();

        store.put(flow, first, &json!(1)).unwrap();
        store.put(flow, second, &json!(2)).unwrap();
        store.put(flow, third, &json!(3)).unwrap();

        let listed = store.list_keys(flow).unwrap();
        let keys: HashSet<NodeId> = listed.into_iter().collect();
        assert_eq!(keys.len(), 3);
        for node in [&first, &second, &third] {
            assert!(keys.contains(node));
        }
    }

    #[test]
    fn list_keys_scoped_to_flow() {
        let store = memory_store();
        let (first_flow, first_node) = new_ids();
        let (second_flow, second_node) = new_ids();

        store.put(first_flow, first_node, &json!(1)).unwrap();
        store.put(second_flow, second_node, &json!(2)).unwrap();

        let keys: Vec<NodeId> = store.list_keys(first_flow).unwrap();
        assert_eq!(keys.len(), 1);
        assert_eq!(keys[0], first_node);
    }

    // --- globals -----------------------------------------------------------

    #[test]
    fn global_put_and_get_roundtrip() {
        let store = memory_store();
        let payload = json!({"global": true});

        store.put_global("my_key", &payload).unwrap();

        let fetched = store.get_global("my_key").unwrap().unwrap();
        assert_eq!(fetched, payload);
    }

    #[test]
    fn global_put_rejects_duplicate() {
        let store = memory_store();
        store.put_global("key", &json!(1)).unwrap();

        let err = store.put_global("key", &json!(2)).unwrap_err();
        assert!(
            matches!(err, ContextStoreError::GlobalAlreadyExists { .. }),
            "expected GlobalAlreadyExists, got: {err:?}"
        );
    }

    #[test]
    fn global_get_nonexistent_returns_none() {
        let store = memory_store();

        let fetched = store.get_global("nonexistent").unwrap();
        assert_eq!(fetched, None);
    }

    #[test]
    fn globals_independent_of_outputs() {
        let store = memory_store();
        let (flow, node) = new_ids();

        store.put(flow, node, &json!("output")).unwrap();
        store.put_global("global_key", &json!("global")).unwrap();

        assert_eq!(store.get(flow, node).unwrap().unwrap(), json!("output"));
        assert_eq!(store.get_global("global_key").unwrap().unwrap(), json!("global"));
    }

    // --- upsert_global -----------------------------------------------------

    #[test]
    fn upsert_global_creates_new_entry() {
        let store = memory_store();
        let payload = json!({"version": 1});

        store.upsert_global("my_key", &payload).unwrap();

        let fetched = store.get_global("my_key").unwrap().unwrap();
        assert_eq!(fetched, payload);
    }

    #[test]
    fn upsert_global_overwrites_existing_entry() {
        let store = memory_store();
        let older = json!({"version": 1});
        let newer = json!({"version": 2});

        store.upsert_global("my_key", &older).unwrap();
        store.upsert_global("my_key", &newer).unwrap();

        let fetched = store.get_global("my_key").unwrap().unwrap();
        assert_eq!(fetched, newer);
    }

    #[test]
    fn upsert_global_and_get_global_round_trip() {
        let store = memory_store();
        let current = |store: &SqliteContextStore| store.get_global("key1").unwrap().unwrap();

        store.upsert_global("key1", &json!("first")).unwrap();
        assert_eq!(current(&store), json!("first"));

        store.upsert_global("key1", &json!("second")).unwrap();
        assert_eq!(current(&store), json!("second"));

        // A plain put on an existing key is refused and leaves the value alone.
        let err = store.put_global("key1", &json!("third")).unwrap_err();
        assert!(matches!(err, ContextStoreError::GlobalAlreadyExists { .. }));

        assert_eq!(current(&store), json!("second"));
    }

    // --- opening a file-backed store ----------------------------------------

    #[test]
    fn open_creates_db_file() {
        let (_dir, path) = scratch_db();
        assert!(!path.exists());

        let _store = SqliteContextStore::open(&path).unwrap();
        assert!(path.exists());
    }

    #[test]
    fn open_creates_tables() {
        let (_dir, path) = scratch_db();
        let _store = SqliteContextStore::open(&path).unwrap();

        // Inspect the schema through an independent connection.
        let conn = Connection::open(&path).unwrap();
        let mut stmt = conn
            .prepare("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
            .unwrap();
        let tables: Vec<String> = stmt
            .query_map([], |row| row.get(0))
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        assert!(tables.iter().any(|name| name == "globals"));
        assert!(tables.iter().any(|name| name == "outputs"));
    }

    #[test]
    fn open_is_idempotent() {
        let (_dir, path) = scratch_db();
        let (flow, node) = new_ids();

        let first = SqliteContextStore::open(&path).unwrap();
        first.put(flow, node, &json!("data")).unwrap();
        drop(first);

        let second = SqliteContextStore::open(&path).unwrap();
        assert_eq!(second.get(flow, node).unwrap().unwrap(), json!("data"));
    }

    #[test]
    fn wal_mode_enabled() {
        let (_dir, path) = scratch_db();
        let store = SqliteContextStore::open(&path).unwrap();

        let guard = store.conn.lock().unwrap();
        let mode: String =
            guard.pragma_query_value(None, "journal_mode", |row| row.get(0)).unwrap();
        assert_eq!(mode, "wal");
    }

    #[test]
    fn in_memory_is_isolated() {
        let writer = memory_store();
        let other = memory_store();
        let (flow, node) = new_ids();

        writer.put(flow, node, &json!("only in store1")).unwrap();
        assert_eq!(other.get(flow, node).unwrap(), None);
    }

    // --- persistence across reopen -------------------------------------------

    #[test]
    fn data_survives_reopen() {
        let (_dir, path) = scratch_db();
        let (flow, node) = new_ids();
        let payload = json!({"survived": true});

        let writer = SqliteContextStore::open(&path).unwrap();
        writer.put(flow, node, &payload).unwrap();
        drop(writer);

        let reader = SqliteContextStore::open(&path).unwrap();
        assert_eq!(reader.get(flow, node).unwrap().unwrap(), payload);
    }

    #[test]
    fn crash_after_write_before_complete() {
        let (_dir, path) = scratch_db();
        let (flow, node) = new_ids();
        let payload = json!("written_but_not_completed");

        let before = SqliteContextStore::open(&path).unwrap();
        before.put(flow, node, &payload).unwrap();
        drop(before);

        let after = SqliteContextStore::open(&path).unwrap();
        assert_eq!(after.get(flow, node).unwrap().unwrap(), payload);

        // Retrying the same write after reopening must still be refused.
        let err = after.put(flow, node, &json!("retry")).unwrap_err();
        assert!(matches!(err, ContextStoreError::AlreadyExists { .. }));
    }

    #[test]
    fn crash_before_write() {
        let (_dir, path) = scratch_db();
        let (flow, node) = new_ids();

        // Open and immediately close without writing anything.
        drop(SqliteContextStore::open(&path).unwrap());

        let store = SqliteContextStore::open(&path).unwrap();
        assert_eq!(store.get(flow, node).unwrap(), None);
    }

    #[test]
    fn multiple_flows_survive_reopen() {
        let (_dir, path) = scratch_db();
        let (first_flow, first_node) = new_ids();
        let (second_flow, second_node) = new_ids();

        let writer = SqliteContextStore::open(&path).unwrap();
        writer.put(first_flow, first_node, &json!("flow1")).unwrap();
        writer.put(second_flow, second_node, &json!("flow2")).unwrap();
        drop(writer);

        let reader = SqliteContextStore::open(&path).unwrap();
        assert_eq!(reader.get(first_flow, first_node).unwrap().unwrap(), json!("flow1"));
        assert_eq!(reader.get(second_flow, second_node).unwrap().unwrap(), json!("flow2"));
    }

    // --- edge cases ----------------------------------------------------------

    #[test]
    fn empty_string_value() {
        let store = memory_store();
        let (flow, node) = new_ids();
        let payload = json!("");

        store.put(flow, node, &payload).unwrap();
        assert_eq!(store.get(flow, node).unwrap().unwrap(), payload);
    }

    #[test]
    fn large_value() {
        let store = memory_store();
        let (flow, node) = new_ids();
        let payload = json!({"data": "x".repeat(1_000_000)});

        store.put(flow, node, &payload).unwrap();
        assert_eq!(store.get(flow, node).unwrap().unwrap(), payload);
    }

    #[test]
    fn special_characters_in_global_key() {
        let store = memory_store();
        let keys = [
            "key with spaces",
            "unicode: 你好世界 🌍",
            "slashes/and\\backslashes",
            "quotes\"and'apostrophes",
            "",
        ];

        for (index, key) in keys.iter().copied().enumerate() {
            let payload = json!(index);
            store.put_global(key, &payload).unwrap();
            assert_eq!(store.get_global(key).unwrap().unwrap(), payload);
        }
    }

    #[test]
    fn concurrent_reads_after_write() {
        let store = Arc::new(memory_store());
        let (flow, node) = new_ids();
        let payload = json!({"concurrent": true});

        store.put(flow, node, &payload).unwrap();

        let mut readers = Vec::new();
        for _ in 0..8 {
            let shared = Arc::clone(&store);
            let expected = payload.clone();
            readers.push(thread::spawn(move || {
                let fetched = shared.get(flow, node).unwrap().unwrap();
                assert_eq!(fetched, expected);
            }));
        }

        // A panicking reader surfaces here as a failed join.
        for reader in readers {
            reader.join().unwrap();
        }
    }
}