Skip to main content

scud/attractor/
context.rs

1//! Thread-safe key-value execution context for pipeline runs.
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
8/// Thread-safe key-value context shared across pipeline execution.
9///
10/// Stores arbitrary JSON values keyed by string names. Supports
11/// isolated clones for parallel branches and atomic update application.
12#[derive(Debug, Clone)]
13pub struct Context {
14    inner: Arc<RwLock<HashMap<String, serde_json::Value>>>,
15}
16
17impl Context {
18    /// Create an empty context.
19    pub fn new() -> Self {
20        Self {
21            inner: Arc::new(RwLock::new(HashMap::new())),
22        }
23    }
24
25    /// Create a context with initial values.
26    pub fn with_values(values: HashMap<String, serde_json::Value>) -> Self {
27        Self {
28            inner: Arc::new(RwLock::new(values)),
29        }
30    }
31
32    /// Set a value in the context.
33    pub async fn set(&self, key: impl Into<String>, value: serde_json::Value) {
34        self.inner.write().await.insert(key.into(), value);
35    }
36
37    /// Get a value from the context.
38    pub async fn get(&self, key: &str) -> Option<serde_json::Value> {
39        self.inner.read().await.get(key).cloned()
40    }
41
42    /// Get a string value from the context.
43    pub async fn get_str(&self, key: &str) -> Option<String> {
44        self.get(key)
45            .await
46            .and_then(|v| v.as_str().map(String::from))
47    }
48
49    /// Take a snapshot of the current context state.
50    pub async fn snapshot(&self) -> HashMap<String, serde_json::Value> {
51        self.inner.read().await.clone()
52    }
53
54    /// Create an isolated clone for parallel branches.
55    ///
56    /// Changes to the clone do not affect the original.
57    pub async fn clone_isolated(&self) -> Self {
58        Self::with_values(self.snapshot().await)
59    }
60
61    /// Apply a batch of updates atomically.
62    pub async fn apply_updates(&self, updates: &HashMap<String, serde_json::Value>) {
63        let mut inner = self.inner.write().await;
64        for (key, value) in updates {
65            inner.insert(key.clone(), value.clone());
66        }
67    }
68
69    /// Check if the context contains a key.
70    pub async fn contains_key(&self, key: &str) -> bool {
71        self.inner.read().await.contains_key(key)
72    }
73
74    /// Get the number of entries.
75    pub async fn len(&self) -> usize {
76        self.inner.read().await.len()
77    }
78
79    /// Check if the context is empty.
80    pub async fn is_empty(&self) -> bool {
81        self.inner.read().await.is_empty()
82    }
83}
84
85impl Default for Context {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91/// Serializable snapshot of a context for checkpointing.
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct ContextSnapshot {
94    pub values: HashMap<String, serde_json::Value>,
95}
96
97impl From<HashMap<String, serde_json::Value>> for ContextSnapshot {
98    fn from(values: HashMap<String, serde_json::Value>) -> Self {
99        Self { values }
100    }
101}
102
103impl ContextSnapshot {
104    /// Restore a Context from this snapshot.
105    pub fn restore(&self) -> Context {
106        Context::with_values(self.values.clone())
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_set_and_get() {
        let ctx = Context::new();
        ctx.set("name", serde_json::json!("Alice")).await;
        assert_eq!(ctx.get("name").await, Some(serde_json::json!("Alice")));
    }

    #[tokio::test]
    async fn test_get_missing_key() {
        let ctx = Context::new();
        assert_eq!(ctx.get("missing").await, None);
        assert_eq!(ctx.get_str("missing").await, None);
    }

    #[tokio::test]
    async fn test_get_str() {
        let ctx = Context::new();
        ctx.set("greeting", serde_json::json!("hello")).await;
        assert_eq!(ctx.get_str("greeting").await, Some("hello".to_string()));

        // Non-string JSON values are not stringified.
        ctx.set("number", serde_json::json!(42)).await;
        assert_eq!(ctx.get_str("number").await, None);

        ctx.set("list", serde_json::json!(["a", "b"])).await;
        assert_eq!(ctx.get_str("list").await, None);
    }

    #[tokio::test]
    async fn test_set_overwrites() {
        let ctx = Context::new();
        ctx.set("k", serde_json::json!(1)).await;
        ctx.set("k", serde_json::json!(2)).await;
        assert_eq!(ctx.len().await, 1);
        assert_eq!(ctx.get("k").await, Some(serde_json::json!(2)));
    }

    #[tokio::test]
    async fn test_contains_key_and_is_empty() {
        let ctx = Context::default();
        assert!(ctx.is_empty().await);
        assert!(!ctx.contains_key("k").await);

        ctx.set("k", serde_json::Value::Null).await;
        assert!(!ctx.is_empty().await);
        assert!(ctx.contains_key("k").await);
        assert_eq!(ctx.len().await, 1);
    }

    #[tokio::test]
    async fn test_snapshot() {
        let ctx = Context::new();
        ctx.set("a", serde_json::json!(1)).await;
        ctx.set("b", serde_json::json!(2)).await;
        let snap = ctx.snapshot().await;
        assert_eq!(snap.len(), 2);
        assert_eq!(snap.get("a"), Some(&serde_json::json!(1)));
        assert_eq!(snap.get("b"), Some(&serde_json::json!(2)));
    }

    #[tokio::test]
    async fn test_clone_isolated() {
        let ctx = Context::new();
        ctx.set("shared", serde_json::json!("original")).await;

        let clone = ctx.clone_isolated().await;
        clone.set("shared", serde_json::json!("modified")).await;

        // Original should be unchanged
        assert_eq!(ctx.get_str("shared").await, Some("original".to_string()));
        assert_eq!(clone.get_str("shared").await, Some("modified".to_string()));
    }

    #[tokio::test]
    async fn test_derived_clone_shares_state() {
        // The derived `Clone` copies the `Arc` handle, not the map.
        let ctx = Context::new();
        let handle = ctx.clone();
        handle.set("k", serde_json::json!(true)).await;
        assert_eq!(ctx.get("k").await, Some(serde_json::json!(true)));
    }

    #[tokio::test]
    async fn test_apply_updates() {
        let ctx = Context::new();
        ctx.set("x", serde_json::json!(1)).await;

        let mut updates = HashMap::new();
        updates.insert("x".into(), serde_json::json!(10));
        updates.insert("y".into(), serde_json::json!(20));
        ctx.apply_updates(&updates).await;

        assert_eq!(ctx.len().await, 2);
        // Existing keys are overwritten, new keys are added.
        assert_eq!(ctx.get("x").await, Some(serde_json::json!(10)));
        assert_eq!(ctx.get("y").await, Some(serde_json::json!(20)));
    }

    #[tokio::test]
    async fn test_apply_empty_updates() {
        let ctx = Context::new();
        ctx.set("keep", serde_json::json!("me")).await;
        ctx.apply_updates(&HashMap::new()).await;
        assert_eq!(ctx.len().await, 1);
        assert_eq!(ctx.get_str("keep").await, Some("me".to_string()));
    }

    #[tokio::test]
    async fn test_snapshot_roundtrip() {
        let ctx = Context::new();
        ctx.set("key", serde_json::json!("value")).await;
        let snap = ContextSnapshot::from(ctx.snapshot().await);
        let json = serde_json::to_string(&snap).unwrap();
        let restored_snap: ContextSnapshot = serde_json::from_str(&json).unwrap();
        let restored = restored_snap.restore();
        assert_eq!(restored.get_str("key").await, Some("value".to_string()));
    }
}