Skip to main content

sgr_agent_core/
context.rs

1//! Agent execution context — shared state passed to tools during execution.
2
3use serde_json::Value;
4use std::any::{Any, TypeId};
5use std::collections::{HashMap, VecDeque};
6use std::path::PathBuf;
7
/// Lifecycle state of an agent run.
///
/// `AgentContext::new` starts every context in [`AgentState::Running`]; moving between
/// states is done by code outside this module. Variant meanings below follow their names.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AgentState {
    /// The agent loop is executing (initial state).
    Running,
    /// The run has completed.
    Completed,
    /// The run has failed.
    Failed,
    /// The run was cancelled.
    Cancelled,
    /// The run is waiting for external input.
    WaitingInput,
}
17
/// Key in `AgentContext::custom` holding an unsigned-integer `max_tokens` override.
/// Retained for callers that still use the legacy string-keyed store.
pub const MAX_TOKENS_OVERRIDE_KEY: &str = "_max_tokens_override";
20
21/// Shared context passed to tools during execution.
22///
23/// Two ways to store custom state:
24/// - **Typed** (preferred): `ctx.insert::<MyState>(state)` / `ctx.get_typed::<MyState>()`
25/// - **String-keyed** (legacy): `ctx.set("key", json_value)` / `ctx.get("key")`
26#[derive(Clone)]
27pub struct AgentContext {
28    pub iteration: usize,
29    pub state: AgentState,
30    pub cwd: PathBuf,
31    /// String-keyed extensible state (legacy — prefer typed store).
32    pub custom: HashMap<String, Value>,
33    /// Per-tool configuration overrides.
34    pub tool_configs: HashMap<String, Value>,
35    /// Sandbox: writable directory roots (empty = no restriction).
36    pub writable_roots: Vec<PathBuf>,
37    /// Compressed observation log (FIFO, capped at `observation_limit`).
38    observations: VecDeque<String>,
39    pub observation_limit: usize,
40    /// Tool result cache — keyed by "tool_name:arg_hash".
41    pub tool_cache: HashMap<String, String>,
42    /// Type-safe extensible store. Projects store typed data without string-key collisions.
43    typed: HashMap<TypeId, TypedSlot>,
44}
45
// -- Typed store support --

/// Type-erased payload held by a [`TypedSlot`].
type ErasedValue = dyn Any + Send + Sync;

/// Per-type function that deep-clones an [`ErasedValue`] of one specific concrete type.
type CloneFn = fn(&ErasedValue) -> Box<ErasedValue>;

/// A type-erased `Clone + Send + Sync + 'static` value plus the function that clones it.
///
/// `Box<dyn Any>` is not `Clone`, so the concrete type's clone is captured as a plain
/// function pointer (built by `make_clone_fn::<T>`) when the value is stored. That is
/// what lets a map of slots be cloned even though it holds trait objects.
struct TypedSlot {
    data: Box<ErasedValue>,
    clone_fn: CloneFn,
}

impl Clone for TypedSlot {
    fn clone(&self) -> Self {
        Self {
            // Hand the trait object itself (not `&Box<_>`) to the clone function.
            data: (self.clone_fn)(&*self.data),
            clone_fn: self.clone_fn,
        }
    }
}

/// Build the clone function for a slot whose payload is a `T`.
///
/// # Panics
///
/// The returned function panics with "TypedSlot type mismatch" when given a value
/// that is not a `T`. That only happens if a slot's `data` and `clone_fn` were built
/// for different types, which would be a bug in this module.
fn make_clone_fn<T: Clone + Send + Sync + 'static>() -> CloneFn {
    |data| {
        let val = data
            .downcast_ref::<T>()
            .expect("TypedSlot type mismatch");
        Box::new(val.clone())
    }
}
70
71impl std::fmt::Debug for AgentContext {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.debug_struct("AgentContext")
74            .field("iteration", &self.iteration)
75            .field("state", &self.state)
76            .field("cwd", &self.cwd)
77            .field("custom_keys", &self.custom.keys().collect::<Vec<_>>())
78            .field("typed_count", &self.typed.len())
79            .field("observations", &self.observations.len())
80            .finish()
81    }
82}
83
84impl AgentContext {
85    pub fn new() -> Self {
86        Self {
87            iteration: 0,
88            state: AgentState::Running,
89            cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
90            custom: HashMap::new(),
91            tool_configs: HashMap::new(),
92            writable_roots: Vec::new(),
93            observations: VecDeque::new(),
94            observation_limit: 30,
95            tool_cache: HashMap::new(),
96            typed: HashMap::new(),
97        }
98    }
99
100    // -- Typed store (preferred API) --
101
102    /// Insert typed state. Each type T gets exactly one slot — no string-key collisions.
103    ///
104    /// ```rust,ignore
105    /// #[derive(Clone)]
106    /// struct MyToolState { count: usize }
107    /// ctx.insert(MyToolState { count: 0 });
108    /// ```
109    pub fn insert<T: Clone + Send + Sync + 'static>(&mut self, value: T) {
110        self.typed.insert(
111            TypeId::of::<T>(),
112            TypedSlot {
113                data: Box::new(value),
114                clone_fn: make_clone_fn::<T>(),
115            },
116        );
117    }
118
119    /// Get typed state by type.
120    pub fn get_typed<T: Clone + Send + Sync + 'static>(&self) -> Option<&T> {
121        self.typed
122            .get(&TypeId::of::<T>())
123            .and_then(|slot| slot.data.downcast_ref())
124    }
125
126    /// Remove typed state, returning it if present.
127    pub fn remove_typed<T: Clone + Send + Sync + 'static>(&mut self) -> Option<T> {
128        self.typed
129            .remove(&TypeId::of::<T>())
130            .and_then(|slot| slot.data.downcast::<T>().ok().map(|b| *b))
131    }
132
133    // -- Observations (VecDeque, O(1) eviction) --
134
135    /// Record a compressed observation.
136    pub fn observe(&mut self, entry: impl Into<String>) {
137        self.observations.push_back(entry.into());
138        while self.observations.len() > self.observation_limit {
139            self.observations.pop_front();
140        }
141    }
142
143    /// Get observation log as a single string for LLM context injection.
144    pub fn observation_summary(&self) -> Option<String> {
145        if self.observations.is_empty() {
146            None
147        } else {
148            let joined: String = self
149                .observations
150                .iter()
151                .cloned()
152                .collect::<Vec<_>>()
153                .join("\n");
154            Some(format!("OBSERVATION LOG:\n{joined}"))
155        }
156    }
157
158    // -- Tool cache --
159
160    pub fn cache_tool_result(&mut self, key: impl Into<String>, result: impl Into<String>) {
161        self.tool_cache.insert(key.into(), result.into());
162    }
163
164    pub fn cached_tool_result(&self, key: &str) -> Option<&str> {
165        self.tool_cache.get(key).map(|s| s.as_str())
166    }
167
168    pub fn invalidate_cache(&mut self, key: &str) {
169        self.tool_cache.remove(key);
170    }
171
172    // -- Builders --
173
174    pub fn with_cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
175        self.cwd = cwd.into();
176        self
177    }
178
179    pub fn with_writable_roots(mut self, roots: Vec<PathBuf>) -> Self {
180        self.writable_roots = roots;
181        self
182    }
183
184    // -- Sandbox --
185
186    pub fn is_writable(&self, path: &std::path::Path) -> bool {
187        if self.writable_roots.is_empty() {
188            return true;
189        }
190        let abs_path = if path.is_absolute() {
191            path.to_path_buf()
192        } else {
193            self.cwd.join(path)
194        };
195        let resolved = std::fs::canonicalize(&abs_path).unwrap_or_else(|_| {
196            if let Some(parent) = abs_path.parent()
197                && let Ok(canon_parent) = std::fs::canonicalize(parent)
198                && let Some(name) = abs_path.file_name()
199            {
200                return canon_parent.join(name);
201            }
202            abs_path.clone()
203        });
204        self.writable_roots.iter().any(|root| {
205            let canon_root = std::fs::canonicalize(root).unwrap_or_else(|_| root.clone());
206            resolved.starts_with(&canon_root)
207        })
208    }
209
210    // -- Legacy string-keyed custom data --
211
212    pub fn set(&mut self, key: impl Into<String>, value: Value) {
213        self.custom.insert(key.into(), value);
214    }
215
216    pub fn get(&self, key: &str) -> Option<&Value> {
217        self.custom.get(key)
218    }
219
220    pub fn max_tokens_override(&self) -> Option<u32> {
221        self.custom
222            .get(MAX_TOKENS_OVERRIDE_KEY)
223            .and_then(|v| v.as_u64())
224            .map(|v| v as u32)
225    }
226
227    // -- Per-tool config --
228
229    pub fn set_tool_config(&mut self, tool_name: impl Into<String>, config: Value) {
230        self.tool_configs.insert(tool_name.into(), config);
231    }
232
233    pub fn tool_config(&self, tool_name: &str) -> Option<&Value> {
234        self.tool_configs.get(tool_name)
235    }
236
237    pub fn merged_tool_config(&self, tool_name: &str, base: &Value) -> Value {
238        match (base, self.tool_configs.get(tool_name)) {
239            (Value::Object(base_obj), Some(Value::Object(override_obj))) => {
240                let mut merged = base_obj.clone();
241                for (k, v) in override_obj {
242                    merged.insert(k.clone(), v.clone());
243                }
244                Value::Object(merged)
245            }
246            (_, Some(override_val)) => override_val.clone(),
247            _ => base.clone(),
248        }
249    }
250}
251
252impl Default for AgentContext {
253    fn default() -> Self {
254        Self::new()
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn context_default_state() {
        let ctx = AgentContext::new();
        assert_eq!(ctx.state, AgentState::Running);
        assert_eq!(ctx.iteration, 0);
    }

    #[test]
    fn context_custom_data() {
        let mut ctx = AgentContext::new();
        ctx.set("project", serde_json::json!("my-project"));
        assert_eq!(ctx.get("project").unwrap(), "my-project");
        assert!(ctx.get("missing").is_none());
    }

    #[test]
    fn context_with_cwd() {
        let ctx = AgentContext::new().with_cwd("/tmp/test");
        assert_eq!(ctx.cwd, PathBuf::from("/tmp/test"));
    }

    #[test]
    fn tool_config_merge() {
        let mut ctx = AgentContext::new();
        ctx.set_tool_config("bash", serde_json::json!({"timeout": 60, "shell": "zsh"}));
        let base = serde_json::json!({"timeout": 30, "cwd": "/tmp"});
        let merged = ctx.merged_tool_config("bash", &base);
        assert_eq!(merged["timeout"], 60);
        assert_eq!(merged["cwd"], "/tmp");
        assert_eq!(merged["shell"], "zsh");
    }

    #[test]
    fn tool_config_non_object_override_replaces_base() {
        let mut ctx = AgentContext::new();
        ctx.set_tool_config("bash", serde_json::json!("disabled"));
        let base = serde_json::json!({"timeout": 30});
        assert_eq!(ctx.merged_tool_config("bash", &base), serde_json::json!("disabled"));
        // No override registered: base comes back unchanged.
        assert_eq!(ctx.merged_tool_config("other", &base), base);
    }

    #[test]
    fn typed_store() {
        #[derive(Clone, Debug, PartialEq)]
        struct MyState {
            count: usize,
        }

        let mut ctx = AgentContext::new();
        assert!(ctx.get_typed::<MyState>().is_none());

        ctx.insert(MyState { count: 42 });
        assert_eq!(ctx.get_typed::<MyState>().unwrap().count, 42);

        // Different types don't collide
        #[derive(Clone)]
        struct OtherState(String);
        ctx.insert(OtherState("hello".into()));
        assert_eq!(ctx.get_typed::<MyState>().unwrap().count, 42);
        assert_eq!(ctx.get_typed::<OtherState>().unwrap().0, "hello");
    }

    #[test]
    fn typed_store_insert_replaces_and_remove_takes() {
        #[derive(Clone, Debug, PartialEq)]
        struct Counter(u32);

        let mut ctx = AgentContext::new();
        ctx.insert(Counter(1));
        ctx.insert(Counter(2));
        assert_eq!(ctx.get_typed::<Counter>(), Some(&Counter(2)));
        assert_eq!(ctx.remove_typed::<Counter>(), Some(Counter(2)));
        assert!(ctx.get_typed::<Counter>().is_none());
        assert!(ctx.remove_typed::<Counter>().is_none());
    }

    #[test]
    fn typed_store_clone() {
        #[derive(Clone, PartialEq, Debug)]
        struct S(u32);

        let mut ctx = AgentContext::new();
        ctx.insert(S(7));

        let ctx2 = ctx.clone();
        assert_eq!(ctx2.get_typed::<S>().unwrap(), &S(7));

        // The clone is independent of later changes to the original.
        ctx.insert(S(8));
        assert_eq!(ctx2.get_typed::<S>(), Some(&S(7)));
        assert_eq!(ctx.get_typed::<S>(), Some(&S(8)));
    }

    #[test]
    fn observations_fifo() {
        let mut ctx = AgentContext::new();
        ctx.observation_limit = 3;
        ctx.observe("a");
        ctx.observe("b");
        ctx.observe("c");
        ctx.observe("d"); // evicts "a"
        assert_eq!(
            ctx.observation_summary().as_deref(),
            Some("OBSERVATION LOG:\nb\nc\nd")
        );
    }

    #[test]
    fn observation_summary_empty_is_none() {
        let ctx = AgentContext::new();
        assert!(ctx.observation_summary().is_none());
    }

    #[test]
    fn observation_limit_zero_keeps_nothing() {
        let mut ctx = AgentContext::new();
        ctx.observation_limit = 0;
        ctx.observe("x");
        assert!(ctx.observation_summary().is_none());
    }

    #[test]
    fn tool_cache_roundtrip() {
        let mut ctx = AgentContext::new();
        ctx.cache_tool_result("read:abc", "contents");
        assert_eq!(ctx.cached_tool_result("read:abc"), Some("contents"));
        ctx.invalidate_cache("read:abc");
        assert_eq!(ctx.cached_tool_result("read:abc"), None);
    }

    #[test]
    fn max_tokens_override_reads_legacy_key() {
        let mut ctx = AgentContext::new();
        assert_eq!(ctx.max_tokens_override(), None);
        ctx.set(MAX_TOKENS_OVERRIDE_KEY, serde_json::json!(4096));
        assert_eq!(ctx.max_tokens_override(), Some(4096));
        ctx.set(MAX_TOKENS_OVERRIDE_KEY, serde_json::json!("4096"));
        assert_eq!(ctx.max_tokens_override(), None);
    }

    #[test]
    fn sandbox_empty_roots_allow_everything() {
        let ctx = AgentContext::new();
        assert!(ctx.is_writable(std::path::Path::new("/any/where")));
    }

    #[test]
    fn sandbox_allows_new_file_inside_root_and_rejects_outside() {
        let root = std::env::temp_dir();
        let ctx = AgentContext::new().with_writable_roots(vec![root.clone()]);
        assert!(ctx.is_writable(&root.join("sgr_agent_core_new_file.txt")));
        assert!(!ctx.is_writable(std::path::Path::new("/definitely/not/a/sandbox/file")));
    }
}