Skip to main content

simple_agents_workflow/state/
mod.rs

1use std::collections::{BTreeMap, HashMap};
2
3use serde_json::{Map, Value};
4use thiserror::Error;
5
/// Opaque token used to gate scoped state reads/writes.
///
/// The wrapped capability is a private field, so code outside this module can only
/// obtain a token through the named constructors (`llm_read`, `tool_write`, ...).
/// Each token carries exactly one capability.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CapabilityToken {
    /// The single read or write capability this token represents.
    capability: ScopeCapability,
}
11
12impl CapabilityToken {
13    pub fn llm_read() -> Self {
14        Self {
15            capability: ScopeCapability::LlmRead,
16        }
17    }
18
19    pub fn tool_read() -> Self {
20        Self {
21            capability: ScopeCapability::ToolRead,
22        }
23    }
24
25    pub fn condition_read() -> Self {
26        Self {
27            capability: ScopeCapability::ConditionRead,
28        }
29    }
30
31    pub fn subgraph_read() -> Self {
32        Self {
33            capability: ScopeCapability::SubgraphRead,
34        }
35    }
36
37    pub fn batch_read() -> Self {
38        Self {
39            capability: ScopeCapability::BatchRead,
40        }
41    }
42
43    pub fn filter_read() -> Self {
44        Self {
45            capability: ScopeCapability::FilterRead,
46        }
47    }
48
49    pub fn llm_write() -> Self {
50        Self {
51            capability: ScopeCapability::LlmWrite,
52        }
53    }
54
55    pub fn tool_write() -> Self {
56        Self {
57            capability: ScopeCapability::ToolWrite,
58        }
59    }
60
61    pub fn condition_write() -> Self {
62        Self {
63            capability: ScopeCapability::ConditionWrite,
64        }
65    }
66
67    pub fn subgraph_write() -> Self {
68        Self {
69            capability: ScopeCapability::SubgraphWrite,
70        }
71    }
72
73    pub fn batch_write() -> Self {
74        Self {
75            capability: ScopeCapability::BatchWrite,
76        }
77    }
78
79    pub fn filter_write() -> Self {
80        Self {
81            capability: ScopeCapability::FilterWrite,
82        }
83    }
84}
85
/// The capability a [`CapabilityToken`] carries.
///
/// Each variant name pairs a category (LLM, tool, condition, subgraph, batch, filter)
/// with an access direction (read or write). The enum is private to this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScopeCapability {
    // Read capabilities: the ones `allows_read` accepts.
    LlmRead,
    ToolRead,
    ConditionRead,
    SubgraphRead,
    BatchRead,
    FilterRead,
    // Write capabilities: each is required by the matching `record_*_output` method.
    LlmWrite,
    ToolWrite,
    ConditionWrite,
    SubgraphWrite,
    BatchWrite,
    FilterWrite,
}
101
102impl ScopeCapability {
103    fn as_str(self) -> &'static str {
104        match self {
105            Self::LlmRead => "llm_read",
106            Self::ToolRead => "tool_read",
107            Self::ConditionRead => "condition_read",
108            Self::SubgraphRead => "subgraph_read",
109            Self::BatchRead => "batch_read",
110            Self::FilterRead => "filter_read",
111            Self::LlmWrite => "llm_write",
112            Self::ToolWrite => "tool_write",
113            Self::ConditionWrite => "condition_write",
114            Self::SubgraphWrite => "subgraph_write",
115            Self::BatchWrite => "batch_write",
116            Self::FilterWrite => "filter_write",
117        }
118    }
119
120    fn allows_read(self) -> bool {
121        matches!(
122            self,
123            Self::LlmRead
124                | Self::ToolRead
125                | Self::ConditionRead
126                | Self::SubgraphRead
127                | Self::BatchRead
128                | Self::FilterRead
129        )
130    }
131}
132
/// Typed read/write boundary failures for scoped runtime state.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum ScopeAccessError {
    /// A read was attempted with a token whose capability is not one of the
    /// read capabilities.
    #[error("scope read denied for capability '{capability}'")]
    ReadDenied {
        /// Label of the capability that was presented for the read
        /// (for example `"llm_write"`).
        capability: &'static str,
    },
    /// A write was attempted with a token whose capability differs from the one the
    /// called `record_*_output` method requires.
    #[error("scope write denied for capability '{capability}'")]
    WriteDenied {
        /// Label of the capability that was presented for the write.
        capability: &'static str,
    },
    /// Attempted to exit the root scope, which has no parent to return to.
    #[error("cannot exit root scope")]
    CannotExitRoot,
}
152
/// One scope's worth of recorded state, linked to its enclosing scope by index.
#[derive(Debug)]
struct ScopeFrame {
    /// Index (into `ScopedState::frames`) of the enclosing scope; `None` for the root.
    parent: Option<usize>,
    /// Outputs recorded in this scope, keyed by node id.
    node_outputs: BTreeMap<String, Value>,
    /// Per-node loop counters stored in this scope, keyed by node id.
    loop_iterations: HashMap<String, u32>,
    /// Most recent LLM output recorded in this scope, if any.
    last_llm_output: Option<String>,
    /// Most recent tool output recorded in this scope, if any.
    last_tool_output: Option<Value>,
}
161
/// Hierarchical scoped runtime state with capability-checked access.
///
/// Scopes form a parent chain. Reads through `scoped_input` see the active scope
/// plus all of its ancestors; writes always land in the active scope.
#[derive(Debug)]
pub struct ScopedState {
    /// The input supplied at construction; exposed to readers under the `"input"` key.
    workflow_input: Value,
    /// Storage for all scope frames; index 0 is the root scope.
    frames: Vec<ScopeFrame>,
    /// Index into `frames` of the scope that reads and writes currently target.
    active_frame: usize,
}
169
170impl ScopedState {
171    pub fn new(workflow_input: Value) -> Self {
172        Self {
173            workflow_input,
174            frames: vec![ScopeFrame {
175                parent: None,
176                node_outputs: BTreeMap::new(),
177                loop_iterations: HashMap::new(),
178                last_llm_output: None,
179                last_tool_output: None,
180            }],
181            active_frame: 0,
182        }
183    }
184
185    pub fn enter_child_scope(&mut self) {
186        let parent = self.active_frame;
187        self.frames.push(ScopeFrame {
188            parent: Some(parent),
189            node_outputs: BTreeMap::new(),
190            loop_iterations: HashMap::new(),
191            last_llm_output: None,
192            last_tool_output: None,
193        });
194        self.active_frame = self.frames.len().saturating_sub(1);
195    }
196
197    pub fn exit_to_parent_scope(&mut self) -> Result<(), ScopeAccessError> {
198        let parent = self.frames[self.active_frame]
199            .parent
200            .ok_or(ScopeAccessError::CannotExitRoot)?;
201        self.active_frame = parent;
202        Ok(())
203    }
204
205    pub fn scoped_input(&self, token: &CapabilityToken) -> Result<Value, ScopeAccessError> {
206        if !token.capability.allows_read() {
207            return Err(ScopeAccessError::ReadDenied {
208                capability: token.capability.as_str(),
209            });
210        }
211
212        let mut object = Map::new();
213        object.insert("input".to_string(), self.workflow_input.clone());
214        object.insert(
215            "last_llm_output".to_string(),
216            self.visible_last_llm_output()
217                .map_or(Value::Null, Value::String),
218        );
219        object.insert(
220            "last_tool_output".to_string(),
221            self.visible_last_tool_output().unwrap_or(Value::Null),
222        );
223        object.insert(
224            "node_outputs".to_string(),
225            Value::Object(
226                self.visible_node_outputs()
227                    .into_iter()
228                    .collect::<Map<String, Value>>(),
229            ),
230        );
231        Ok(Value::Object(object))
232    }
233
234    pub fn record_llm_output(
235        &mut self,
236        node_id: &str,
237        output: String,
238        token: &CapabilityToken,
239    ) -> Result<(), ScopeAccessError> {
240        self.ensure_write(token, ScopeCapability::LlmWrite)?;
241        let frame = &mut self.frames[self.active_frame];
242        frame.last_llm_output = Some(output.clone());
243        frame
244            .node_outputs
245            .insert(node_id.to_string(), Value::String(output));
246        Ok(())
247    }
248
249    pub fn record_tool_output(
250        &mut self,
251        node_id: &str,
252        output: Value,
253        token: &CapabilityToken,
254    ) -> Result<(), ScopeAccessError> {
255        self.ensure_write(token, ScopeCapability::ToolWrite)?;
256        let frame = &mut self.frames[self.active_frame];
257        frame.last_tool_output = Some(output.clone());
258        frame.node_outputs.insert(node_id.to_string(), output);
259        Ok(())
260    }
261
262    pub fn record_condition_output(
263        &mut self,
264        node_id: &str,
265        evaluated: bool,
266        token: &CapabilityToken,
267    ) -> Result<(), ScopeAccessError> {
268        self.ensure_write(token, ScopeCapability::ConditionWrite)?;
269        self.frames[self.active_frame]
270            .node_outputs
271            .insert(node_id.to_string(), Value::Bool(evaluated));
272        Ok(())
273    }
274
275    pub fn record_subgraph_output(
276        &mut self,
277        node_id: &str,
278        output: Value,
279        token: &CapabilityToken,
280    ) -> Result<(), ScopeAccessError> {
281        self.ensure_write(token, ScopeCapability::SubgraphWrite)?;
282        self.frames[self.active_frame]
283            .node_outputs
284            .insert(node_id.to_string(), output);
285        Ok(())
286    }
287
288    pub fn record_batch_output(
289        &mut self,
290        node_id: &str,
291        output: Value,
292        token: &CapabilityToken,
293    ) -> Result<(), ScopeAccessError> {
294        self.ensure_write(token, ScopeCapability::BatchWrite)?;
295        self.frames[self.active_frame]
296            .node_outputs
297            .insert(node_id.to_string(), output);
298        Ok(())
299    }
300
301    pub fn record_filter_output(
302        &mut self,
303        node_id: &str,
304        output: Value,
305        token: &CapabilityToken,
306    ) -> Result<(), ScopeAccessError> {
307        self.ensure_write(token, ScopeCapability::FilterWrite)?;
308        self.frames[self.active_frame]
309            .node_outputs
310            .insert(node_id.to_string(), output);
311        Ok(())
312    }
313
314    pub fn loop_iteration(&self, node_id: &str) -> u32 {
315        self.frames[self.active_frame]
316            .loop_iterations
317            .get(node_id)
318            .copied()
319            .unwrap_or(0)
320    }
321
322    pub fn set_loop_iteration(&mut self, node_id: &str, iteration: u32) {
323        self.frames[self.active_frame]
324            .loop_iterations
325            .insert(node_id.to_string(), iteration);
326    }
327
328    pub fn clear_loop_iteration(&mut self, node_id: &str) {
329        self.frames[self.active_frame]
330            .loop_iterations
331            .remove(node_id);
332    }
333
334    pub fn current_scope_node_outputs(&self) -> Value {
335        Value::Object(
336            self.frames[self.active_frame]
337                .node_outputs
338                .iter()
339                .map(|(k, v)| (k.clone(), v.clone()))
340                .collect(),
341        )
342    }
343
344    pub fn visible_node_outputs(&self) -> BTreeMap<String, Value> {
345        let mut chain = Vec::new();
346        let mut cursor = Some(self.active_frame);
347        while let Some(index) = cursor {
348            chain.push(index);
349            cursor = self.frames[index].parent;
350        }
351        chain.reverse();
352
353        let mut merged = BTreeMap::new();
354        for index in chain {
355            for (key, value) in &self.frames[index].node_outputs {
356                merged.insert(key.clone(), value.clone());
357            }
358        }
359        merged
360    }
361
362    fn visible_last_llm_output(&self) -> Option<String> {
363        let mut cursor = Some(self.active_frame);
364        while let Some(index) = cursor {
365            if let Some(value) = &self.frames[index].last_llm_output {
366                return Some(value.clone());
367            }
368            cursor = self.frames[index].parent;
369        }
370        None
371    }
372
373    fn visible_last_tool_output(&self) -> Option<Value> {
374        let mut cursor = Some(self.active_frame);
375        while let Some(index) = cursor {
376            if let Some(value) = &self.frames[index].last_tool_output {
377                return Some(value.clone());
378            }
379            cursor = self.frames[index].parent;
380        }
381        None
382    }
383
384    fn ensure_write(
385        &self,
386        token: &CapabilityToken,
387        required: ScopeCapability,
388    ) -> Result<(), ScopeAccessError> {
389        if token.capability != required {
390            return Err(ScopeAccessError::WriteDenied {
391                capability: token.capability.as_str(),
392            });
393        }
394        Ok(())
395    }
396}
397
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::{CapabilityToken, ScopeAccessError, ScopedState};

    /// A write token must be refused for reads, and a write token of the wrong kind
    /// must be refused by a different recorder.
    #[test]
    fn enforces_capabilities() {
        let llm_writer = CapabilityToken::llm_write();
        let mut scoped = ScopedState::new(json!({"input": true}));

        let denied_read = scoped
            .scoped_input(&llm_writer)
            .expect_err("write token cannot read");
        assert!(matches!(denied_read, ScopeAccessError::ReadDenied { .. }));

        let denied_write = scoped
            .record_tool_output("tool", json!({"ok": true}), &llm_writer)
            .expect_err("wrong write token should be rejected");
        assert!(matches!(denied_write, ScopeAccessError::WriteDenied { .. }));
    }

    /// A child scope sees its parent's outputs; after exiting, the parent does not
    /// see what the child recorded.
    #[test]
    fn supports_parent_child_visibility() {
        let reader = CapabilityToken::condition_read();
        let mut scoped = ScopedState::new(json!({"req": 1}));

        scoped
            .record_tool_output("root_tool", json!({"x": 1}), &CapabilityToken::tool_write())
            .expect("root write should succeed");

        scoped.enter_child_scope();
        scoped
            .record_llm_output("child_llm", "ok".to_string(), &CapabilityToken::llm_write())
            .expect("child write should succeed");

        let from_child = scoped
            .scoped_input(&reader)
            .expect("read should include parent and child outputs");
        assert!(from_child["node_outputs"]["root_tool"].is_object());
        assert_eq!(from_child["node_outputs"]["child_llm"], json!("ok"));

        scoped
            .exit_to_parent_scope()
            .expect("child scope should exit to parent");

        let from_parent = scoped
            .scoped_input(&reader)
            .expect("parent read should work");
        assert!(from_parent["node_outputs"].get("child_llm").is_none());
    }
}
445}