Skip to main content

swink_agent/
state.rs

//! Session key-value state store with delta tracking.
//!
//! Provides [`SessionState`] for per-session structured data that tools can
//! read and write during execution, and [`StateDelta`] for tracking mutations
//! since the last flush. State is shared via `Arc<RwLock<SessionState>>`.
6#![forbid(unsafe_code)]
7
8use std::collections::HashMap;
9
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12
13// ─── StateDelta ─────────────────────────────────────────────────────────────
14
15/// Record of mutations since the last flush.
16///
17/// `Some(value)` = set/update, `None` = removed.
18#[derive(Debug, Default, Clone, Serialize, Deserialize)]
19pub struct StateDelta {
20    /// Map of changed keys. `Some(v)` means the key was set to `v`;
21    /// `None` means the key was removed.
22    pub changes: HashMap<String, Option<Value>>,
23}
24
25impl StateDelta {
26    /// True if no changes recorded.
27    pub fn is_empty(&self) -> bool {
28        self.changes.is_empty()
29    }
30
31    /// Number of changed keys.
32    pub fn len(&self) -> usize {
33        self.changes.len()
34    }
35}
36
37// ─── SessionState ───────────────────────────────────────────────────────────
38
39/// Key-value store with change tracking for session-attached structured data.
40///
41/// Tools receive an `Arc<RwLock<SessionState>>` during execution and can
42/// read/write arbitrary typed values. Changes are tracked in a [`StateDelta`]
43/// that is flushed at the end of each turn.
44#[derive(Debug, Default, Clone, Serialize, Deserialize)]
45pub struct SessionState {
46    data: HashMap<String, Value>,
47    #[serde(skip)]
48    delta: StateDelta,
49}
50
51impl SessionState {
52    /// Create a new empty session state.
53    pub fn new() -> Self {
54        Self::default()
55    }
56
57    /// Create session state pre-populated with the given data.
58    ///
59    /// Pre-seeded data does NOT appear in the delta (baseline semantics).
60    pub fn with_data(data: HashMap<String, Value>) -> Self {
61        Self {
62            data,
63            delta: StateDelta::default(),
64        }
65    }
66
67    /// Get a typed value by key. Returns `None` if key is missing or
68    /// deserialization fails.
69    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
70        self.data
71            .get(key)
72            .and_then(|v| serde_json::from_value(v.clone()).ok())
73    }
74
75    /// Get the raw JSON value by key without deserialization.
76    pub fn get_raw(&self, key: &str) -> Option<&Value> {
77        self.data.get(key)
78    }
79
80    /// Set a typed value. Serializes to `Value` and records in delta.
81    ///
82    /// Returns an error if the value cannot be serialized to JSON.
83    pub fn set<T: Serialize>(&mut self, key: &str, value: T) -> Result<(), serde_json::Error> {
84        let val = serde_json::to_value(value)?;
85        self.data.insert(key.to_string(), val.clone());
86        self.delta.changes.insert(key.to_string(), Some(val));
87        Ok(())
88    }
89
90    /// Remove a key. Records removal in delta. No-op if key absent.
91    pub fn remove(&mut self, key: &str) {
92        if self.data.remove(key).is_some() {
93            self.delta.changes.insert(key.to_string(), None);
94        }
95    }
96
97    /// Check if a key exists.
98    pub fn contains(&self, key: &str) -> bool {
99        self.data.contains_key(key)
100    }
101
102    /// Iterate over all keys.
103    pub fn keys(&self) -> impl Iterator<Item = &str> {
104        self.data.keys().map(String::as_str)
105    }
106
107    /// Number of key-value pairs.
108    pub fn len(&self) -> usize {
109        self.data.len()
110    }
111
112    /// True if no key-value pairs.
113    pub fn is_empty(&self) -> bool {
114        self.data.is_empty()
115    }
116
117    /// Remove all key-value pairs. Records all existing keys as removed in delta.
118    pub fn clear(&mut self) {
119        for key in self.data.keys() {
120            self.delta.changes.insert(key.clone(), None);
121        }
122        self.data.clear();
123    }
124
125    /// Read-only reference to pending delta.
126    pub const fn delta(&self) -> &StateDelta {
127        &self.delta
128    }
129
130    /// Take the pending delta and reset tracking. Returns the delta.
131    pub fn flush_delta(&mut self) -> StateDelta {
132        std::mem::take(&mut self.delta)
133    }
134
135    /// Snapshot the materialized data as a JSON Value (for persistence).
136    pub fn snapshot(&self) -> Value {
137        serde_json::to_value(&self.data).expect("HashMap<String, Value> is always serializable")
138    }
139
140    /// Restore from a JSON Value snapshot. Returns a new `SessionState` with
141    /// empty delta.
142    pub fn restore_from_snapshot(snapshot: Value) -> Result<Self, serde_json::Error> {
143        let data: HashMap<String, Value> = serde_json::from_value(snapshot)?;
144        Ok(Self {
145            data,
146            delta: StateDelta::default(),
147        })
148    }
149}
150
151// ─── Compile-time Send + Sync assertions ────────────────────────────────────
152
153const _: () = {
154    const fn assert_send_sync<T: Send + Sync>() {}
155    assert_send_sync::<SessionState>();
156    assert_send_sync::<StateDelta>();
157    assert_send_sync::<std::sync::Arc<std::sync::RwLock<SessionState>>>();
158};
159
160// ─── Tests ──────────────────────────────────────────────────────────────────
161
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ── StateDelta ──

    #[test]
    fn delta_default_is_empty() {
        let d = StateDelta::default();
        assert!(d.is_empty());
        assert_eq!(d.len(), 0);
    }

    #[test]
    fn delta_serde_roundtrip() {
        let mut d = StateDelta::default();
        d.changes.insert("a".into(), Some(json!(1)));
        d.changes.insert("b".into(), None);
        let json = serde_json::to_string(&d).unwrap();
        let d2: StateDelta = serde_json::from_str(&json).unwrap();
        assert_eq!(d2.len(), 2);
        assert_eq!(d2.changes["a"], Some(json!(1)));
        assert_eq!(d2.changes["b"], None);
    }

    #[test]
    fn delta_serde_collapses_null_set_into_removal() {
        // Documents a known limitation of the JSON wire format: both
        // `Some(Value::Null)` and `None` serialize as `null`, and `null`
        // deserializes as `None`.
        let mut d = StateDelta::default();
        d.changes.insert("k".into(), Some(Value::Null));
        let json = serde_json::to_string(&d).unwrap();
        let d2: StateDelta = serde_json::from_str(&json).unwrap();
        assert_eq!(d2.changes["k"], None);
    }

    // ── SessionState get/set/remove ──

    #[test]
    fn set_and_get_typed() {
        let mut s = SessionState::new();
        s.set("count", 42_i64).unwrap();
        assert_eq!(s.get::<i64>("count"), Some(42));
    }

    #[test]
    fn get_deserializes_structured_value() {
        #[derive(Debug, PartialEq, Serialize, Deserialize)]
        struct Profile {
            name: String,
            scores: Vec<u32>,
        }

        let mut s = SessionState::new();
        let p = Profile {
            name: "bob".into(),
            scores: vec![1, 2, 3],
        };
        s.set("profile", &p).unwrap();
        assert_eq!(s.get::<Profile>("profile"), Some(p));
        // Reading does not consume or alter the stored value.
        assert!(s.contains("profile"));
    }

    #[test]
    fn get_raw_returns_value_ref() {
        let mut s = SessionState::new();
        s.set("key", "hello").unwrap();
        assert_eq!(s.get_raw("key"), Some(&json!("hello")));
    }

    #[test]
    fn get_missing_returns_none() {
        let s = SessionState::new();
        assert_eq!(s.get::<String>("nope"), None);
    }

    #[test]
    fn get_wrong_type_returns_none() {
        let mut s = SessionState::new();
        s.set("key", "hello").unwrap();
        // Try to get as i64 — should fail gracefully
        assert_eq!(s.get::<i64>("key"), None);
        // Original value still intact
        assert_eq!(s.get::<String>("key"), Some("hello".to_string()));
    }

    #[test]
    fn set_null_is_tracked_as_some_null() {
        let mut s = SessionState::new();
        s.set("k", Value::Null).unwrap();
        assert!(s.contains("k"));
        assert_eq!(s.get_raw("k"), Some(&Value::Null));
        assert_eq!(s.delta().changes["k"], Some(Value::Null));
    }

    #[test]
    fn remove_existing_key() {
        let mut s = SessionState::new();
        s.set("x", 1).unwrap();
        s.remove("x");
        assert!(!s.contains("x"));
        assert!(s.is_empty());
    }

    #[test]
    fn remove_absent_key_is_noop() {
        let mut s = SessionState::new();
        s.remove("nope");
        assert!(s.delta().is_empty());
    }

    #[test]
    fn contains_keys_len_is_empty() {
        let mut s = SessionState::new();
        assert!(s.is_empty());
        s.set("a", 1).unwrap();
        s.set("b", 2).unwrap();
        assert!(s.contains("a"));
        assert!(!s.contains("c"));
        assert_eq!(s.len(), 2);
        assert!(!s.is_empty());
        let keys: Vec<&str> = s.keys().collect();
        assert!(keys.contains(&"a"));
        assert!(keys.contains(&"b"));
    }

    #[test]
    fn clear_records_all_removals() {
        let mut s = SessionState::new();
        s.set("a", 1).unwrap();
        s.set("b", 2).unwrap();
        s.flush_delta(); // reset
        s.clear();
        assert!(s.is_empty());
        assert_eq!(s.delta().len(), 2);
        assert_eq!(s.delta().changes["a"], None);
        assert_eq!(s.delta().changes["b"], None);
    }

    #[test]
    fn clear_on_empty_state_records_nothing() {
        let mut s = SessionState::new();
        s.clear();
        assert!(s.is_empty());
        assert!(s.delta().is_empty());
    }

    // ── Delta collapse ──

    #[test]
    fn delta_set_set_last_wins() {
        let mut s = SessionState::new();
        s.set("k", 1).unwrap();
        s.set("k", 2).unwrap();
        assert_eq!(s.delta().changes["k"], Some(json!(2)));
        assert_eq!(s.delta().len(), 1);
    }

    #[test]
    fn delta_set_remove_is_none() {
        let mut s = SessionState::new();
        s.set("k", 1).unwrap();
        s.remove("k");
        assert_eq!(s.delta().changes["k"], None);
    }

    #[test]
    fn delta_remove_set_is_some() {
        let mut s = SessionState::with_data(std::iter::once(("k".to_string(), json!(1))).collect());
        s.remove("k");
        s.set("k", 99).unwrap();
        assert_eq!(s.delta().changes["k"], Some(json!(99)));
    }

    // ── flush_delta ──

    #[test]
    fn flush_delta_returns_and_resets() {
        let mut s = SessionState::new();
        s.set("a", 1).unwrap();
        let d = s.flush_delta();
        assert_eq!(d.len(), 1);
        assert!(s.delta().is_empty());
    }

    #[test]
    fn flush_empty_delta_returns_empty() {
        let mut s = SessionState::new();
        let d = s.flush_delta();
        assert!(d.is_empty());
    }

    // ── with_data (baseline semantics) ──

    #[test]
    fn with_data_pre_seeds_without_delta() {
        let data: HashMap<String, Value> = std::iter::once(("x".into(), json!(42))).collect();
        let s = SessionState::with_data(data);
        assert_eq!(s.get::<i64>("x"), Some(42));
        assert!(s.delta().is_empty());
    }

    // ── snapshot / restore ──

    #[test]
    fn snapshot_restore_roundtrip() {
        let mut s = SessionState::new();
        s.set("name", "alice").unwrap();
        s.set("age", 30).unwrap();
        let snap = s.snapshot();
        let s2 = SessionState::restore_from_snapshot(snap).unwrap();
        assert_eq!(s2.get::<String>("name"), Some("alice".to_string()));
        assert_eq!(s2.get::<i64>("age"), Some(30));
        assert!(s2.delta().is_empty());
    }

    #[test]
    fn snapshot_of_empty_state_is_empty_object() {
        assert_eq!(SessionState::new().snapshot(), json!({}));
    }

    // ── Serialize roundtrip (delta skipped) ──

    #[test]
    fn serde_roundtrip_skips_delta() {
        let mut s = SessionState::new();
        s.set("k", "v").unwrap();
        // Delta has an entry
        assert!(!s.delta().is_empty());
        let json = serde_json::to_string(&s).unwrap();
        let s2: SessionState = serde_json::from_str(&json).unwrap();
        assert_eq!(s2.get::<String>("k"), Some("v".to_string()));
        // Delta is empty after deserialization (skipped)
        assert!(s2.delta().is_empty());
    }

    // ── Serialization error handling ──

    #[test]
    fn set_returns_error_on_serialization_failure() {
        use serde::ser::{self, Serializer};

        /// A type whose `Serialize` impl always fails.
        struct Unserializable;

        impl Serialize for Unserializable {
            fn serialize<S: Serializer>(&self, _s: S) -> Result<S::Ok, S::Error> {
                Err(ser::Error::custom("intentional serialization failure"))
            }
        }

        let mut s = SessionState::new();
        let result = s.set("bad", Unserializable);
        assert!(result.is_err());
        // State must remain unchanged after a failed set.
        assert!(!s.contains("bad"));
        assert!(s.delta().is_empty());
    }

    // ── Nested JSON values ──

    #[test]
    fn nested_json_roundtrip() {
        let mut s = SessionState::new();
        let nested = json!({
            "user": {"name": "bob", "scores": [1, 2, 3]},
            "active": true
        });
        s.set("profile", nested.clone()).unwrap();
        let snap = s.snapshot();
        let s2 = SessionState::restore_from_snapshot(snap).unwrap();
        assert_eq!(s2.get_raw("profile"), Some(&nested));
    }

    #[test]
    fn restore_from_corrupt_snapshot_returns_error() {
        let err = SessionState::restore_from_snapshot(json!(["not", "an", "object"])).unwrap_err();
        assert!(err.to_string().contains("map"));
    }
}