1use serde_json::Value;
4use std::any::{Any, TypeId};
5use std::collections::{HashMap, VecDeque};
6use std::path::PathBuf;
7
/// Lifecycle state of an agent run.
///
/// [`AgentContext::new`] starts every context in [`AgentState::Running`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AgentState {
    Running,
    Completed,
    Failed,
    Cancelled,
    WaitingInput,
}
17
/// Key in `AgentContext::custom` under which a max-tokens override is stored;
/// read back by `AgentContext::max_tokens_override`.
pub const MAX_TOKENS_OVERRIDE_KEY: &str = "_max_tokens_override";
20
/// Per-run agent context: working directory and write permissions, free-form
/// JSON and type-keyed storage, per-tool configuration, a tool-result cache,
/// and a bounded observation log.
///
/// `Clone` produces an independent copy; values in the typed store are
/// duplicated through each slot's stored clone function.
#[derive(Clone)]
pub struct AgentContext {
    /// Iteration counter; `new` initializes it to 0.
    pub iteration: usize,
    /// Current lifecycle state; `new` initializes it to `AgentState::Running`.
    pub state: AgentState,
    /// Base directory against which `is_writable` resolves relative paths.
    pub cwd: PathBuf,
    /// Free-form JSON values, accessed through `set` / `get`.
    pub custom: HashMap<String, Value>,
    /// Per-tool JSON configuration keyed by tool name.
    pub tool_configs: HashMap<String, Value>,
    /// Directories under which `is_writable` permits writes; empty means unrestricted.
    pub writable_roots: Vec<PathBuf>,
    /// Observation log, oldest entry first; appended by `observe`.
    observations: VecDeque<String>,
    /// Maximum number of entries `observe` retains; the oldest are evicted first.
    pub observation_limit: usize,
    /// Cached tool results keyed by caller-chosen strings.
    pub tool_cache: HashMap<String, String>,
    /// Type-keyed store holding at most one value per concrete type.
    typed: HashMap<TypeId, TypedSlot>,
}
45
/// A type-erased value in `AgentContext`'s typed store, paired with a function
/// that can clone it (`dyn Any` itself is not `Clone`).
struct TypedSlot {
    /// The stored value; its concrete type is the one whose `TypeId` keys this slot.
    data: Box<dyn Any + Send + Sync>,
    /// Clones `data`; produced by `make_clone_fn::<T>` for the stored type `T`.
    clone_fn: fn(&Box<dyn Any + Send + Sync>) -> Box<dyn Any + Send + Sync>,
}
53
54impl Clone for TypedSlot {
55 fn clone(&self) -> Self {
56 Self {
57 data: (self.clone_fn)(&self.data),
58 clone_fn: self.clone_fn,
59 }
60 }
61}
62
/// Returns a function that clones a type-erased box known to hold a `T`.
///
/// # Panics
///
/// The returned function panics with `"TypedSlot type mismatch"` if the box
/// it is given does not actually hold a `T`.
fn make_clone_fn<T: Clone + Send + Sync + 'static>()
-> fn(&Box<dyn Any + Send + Sync>) -> Box<dyn Any + Send + Sync> {
    // A generic item fn (rather than a closure) monomorphizes to a plain fn pointer.
    fn clone_as<U: Clone + Send + Sync + 'static>(
        boxed: &Box<dyn Any + Send + Sync>,
    ) -> Box<dyn Any + Send + Sync> {
        let Some(inner) = boxed.downcast_ref::<U>() else {
            panic!("TypedSlot type mismatch");
        };
        Box::new(inner.clone())
    }
    clone_as::<T>
}
70
71impl std::fmt::Debug for AgentContext {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 f.debug_struct("AgentContext")
74 .field("iteration", &self.iteration)
75 .field("state", &self.state)
76 .field("cwd", &self.cwd)
77 .field("custom_keys", &self.custom.keys().collect::<Vec<_>>())
78 .field("typed_count", &self.typed.len())
79 .field("observations", &self.observations.len())
80 .finish()
81 }
82}
83
84impl AgentContext {
85 pub fn new() -> Self {
86 Self {
87 iteration: 0,
88 state: AgentState::Running,
89 cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
90 custom: HashMap::new(),
91 tool_configs: HashMap::new(),
92 writable_roots: Vec::new(),
93 observations: VecDeque::new(),
94 observation_limit: 30,
95 tool_cache: HashMap::new(),
96 typed: HashMap::new(),
97 }
98 }
99
100 pub fn insert<T: Clone + Send + Sync + 'static>(&mut self, value: T) {
110 self.typed.insert(
111 TypeId::of::<T>(),
112 TypedSlot {
113 data: Box::new(value),
114 clone_fn: make_clone_fn::<T>(),
115 },
116 );
117 }
118
119 pub fn get_typed<T: Clone + Send + Sync + 'static>(&self) -> Option<&T> {
121 self.typed
122 .get(&TypeId::of::<T>())
123 .and_then(|slot| slot.data.downcast_ref())
124 }
125
126 pub fn remove_typed<T: Clone + Send + Sync + 'static>(&mut self) -> Option<T> {
128 self.typed
129 .remove(&TypeId::of::<T>())
130 .and_then(|slot| slot.data.downcast::<T>().ok().map(|b| *b))
131 }
132
133 pub fn observe(&mut self, entry: impl Into<String>) {
137 self.observations.push_back(entry.into());
138 while self.observations.len() > self.observation_limit {
139 self.observations.pop_front();
140 }
141 }
142
143 pub fn observation_summary(&self) -> Option<String> {
145 if self.observations.is_empty() {
146 None
147 } else {
148 let joined: String = self
149 .observations
150 .iter()
151 .cloned()
152 .collect::<Vec<_>>()
153 .join("\n");
154 Some(format!("OBSERVATION LOG:\n{joined}"))
155 }
156 }
157
158 pub fn cache_tool_result(&mut self, key: impl Into<String>, result: impl Into<String>) {
161 self.tool_cache.insert(key.into(), result.into());
162 }
163
164 pub fn cached_tool_result(&self, key: &str) -> Option<&str> {
165 self.tool_cache.get(key).map(|s| s.as_str())
166 }
167
168 pub fn invalidate_cache(&mut self, key: &str) {
169 self.tool_cache.remove(key);
170 }
171
172 pub fn with_cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
175 self.cwd = cwd.into();
176 self
177 }
178
179 pub fn with_writable_roots(mut self, roots: Vec<PathBuf>) -> Self {
180 self.writable_roots = roots;
181 self
182 }
183
184 pub fn is_writable(&self, path: &std::path::Path) -> bool {
187 if self.writable_roots.is_empty() {
188 return true;
189 }
190 let abs_path = if path.is_absolute() {
191 path.to_path_buf()
192 } else {
193 self.cwd.join(path)
194 };
195 let resolved = std::fs::canonicalize(&abs_path).unwrap_or_else(|_| {
196 if let Some(parent) = abs_path.parent()
197 && let Ok(canon_parent) = std::fs::canonicalize(parent)
198 && let Some(name) = abs_path.file_name()
199 {
200 return canon_parent.join(name);
201 }
202 abs_path.clone()
203 });
204 self.writable_roots.iter().any(|root| {
205 let canon_root = std::fs::canonicalize(root).unwrap_or_else(|_| root.clone());
206 resolved.starts_with(&canon_root)
207 })
208 }
209
210 pub fn set(&mut self, key: impl Into<String>, value: Value) {
213 self.custom.insert(key.into(), value);
214 }
215
216 pub fn get(&self, key: &str) -> Option<&Value> {
217 self.custom.get(key)
218 }
219
220 pub fn max_tokens_override(&self) -> Option<u32> {
221 self.custom
222 .get(MAX_TOKENS_OVERRIDE_KEY)
223 .and_then(|v| v.as_u64())
224 .map(|v| v as u32)
225 }
226
227 pub fn set_tool_config(&mut self, tool_name: impl Into<String>, config: Value) {
230 self.tool_configs.insert(tool_name.into(), config);
231 }
232
233 pub fn tool_config(&self, tool_name: &str) -> Option<&Value> {
234 self.tool_configs.get(tool_name)
235 }
236
237 pub fn merged_tool_config(&self, tool_name: &str, base: &Value) -> Value {
238 match (base, self.tool_configs.get(tool_name)) {
239 (Value::Object(base_obj), Some(Value::Object(override_obj))) => {
240 let mut merged = base_obj.clone();
241 for (k, v) in override_obj {
242 merged.insert(k.clone(), v.clone());
243 }
244 Value::Object(merged)
245 }
246 (_, Some(override_val)) => override_val.clone(),
247 _ => base.clone(),
248 }
249 }
250}
251
252impl Default for AgentContext {
253 fn default() -> Self {
254 Self::new()
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn context_default_state() {
        let ctx = AgentContext::new();
        assert_eq!(ctx.state, AgentState::Running);
        assert_eq!(ctx.iteration, 0);
    }

    #[test]
    fn context_custom_data() {
        let mut ctx = AgentContext::new();
        ctx.set("project", serde_json::json!("my-project"));
        assert_eq!(ctx.get("project").unwrap(), "my-project");
        assert!(ctx.get("missing").is_none());
    }

    #[test]
    fn context_with_cwd() {
        let ctx = AgentContext::new().with_cwd("/tmp/test");
        assert_eq!(ctx.cwd, PathBuf::from("/tmp/test"));
    }

    #[test]
    fn tool_config_merge() {
        let mut ctx = AgentContext::new();
        ctx.set_tool_config("bash", serde_json::json!({"timeout": 60, "shell": "zsh"}));
        let base = serde_json::json!({"timeout": 30, "cwd": "/tmp"});
        let merged = ctx.merged_tool_config("bash", &base);
        assert_eq!(merged["timeout"], 60);
        assert_eq!(merged["cwd"], "/tmp");
        assert_eq!(merged["shell"], "zsh");
    }

    #[test]
    fn tool_config_non_object_override_replaces_base() {
        let mut ctx = AgentContext::new();
        ctx.set_tool_config("bash", serde_json::json!("raw"));
        let base = serde_json::json!({"timeout": 30});
        assert_eq!(ctx.merged_tool_config("bash", &base), serde_json::json!("raw"));
        // No override stored: base comes back unchanged.
        assert_eq!(ctx.merged_tool_config("other", &base), base);
    }

    #[test]
    fn typed_store() {
        #[derive(Clone, Debug, PartialEq)]
        struct MyState {
            count: usize,
        }

        let mut ctx = AgentContext::new();
        assert!(ctx.get_typed::<MyState>().is_none());

        ctx.insert(MyState { count: 42 });
        assert_eq!(ctx.get_typed::<MyState>().unwrap().count, 42);

        #[derive(Clone)]
        struct OtherState(String);
        ctx.insert(OtherState("hello".into()));
        // Inserting a different type must not disturb the first one.
        assert_eq!(ctx.get_typed::<MyState>().unwrap().count, 42);
        assert_eq!(ctx.get_typed::<OtherState>().unwrap().0, "hello");
    }

    #[test]
    fn typed_store_remove() {
        let mut ctx = AgentContext::new();
        ctx.insert(5u8);
        assert_eq!(ctx.remove_typed::<u8>(), Some(5));
        assert!(ctx.get_typed::<u8>().is_none());
        assert_eq!(ctx.remove_typed::<u8>(), None);
    }

    #[test]
    fn typed_store_clone() {
        #[derive(Clone, PartialEq, Debug)]
        struct S(u32);

        let mut ctx = AgentContext::new();
        ctx.insert(S(7));

        let ctx2 = ctx.clone();
        assert_eq!(ctx2.get_typed::<S>().unwrap(), &S(7));

        // The clone is independent of later changes to the original.
        ctx.insert(S(8));
        assert_eq!(ctx2.get_typed::<S>().unwrap(), &S(7));
        assert_eq!(ctx.get_typed::<S>().unwrap(), &S(8));
    }

    #[test]
    fn observations_fifo() {
        let mut ctx = AgentContext::new();
        ctx.observation_limit = 3;
        ctx.observe("a");
        ctx.observe("b");
        ctx.observe("c");
        ctx.observe("d");
        assert_eq!(
            ctx.observation_summary().as_deref(),
            Some("OBSERVATION LOG:\nb\nc\nd")
        );
    }

    #[test]
    fn observation_summary_empty() {
        assert!(AgentContext::new().observation_summary().is_none());
    }

    #[test]
    fn tool_cache_roundtrip() {
        let mut ctx = AgentContext::new();
        assert!(ctx.cached_tool_result("k").is_none());
        ctx.cache_tool_result("k", "v");
        assert_eq!(ctx.cached_tool_result("k"), Some("v"));
        ctx.invalidate_cache("k");
        assert!(ctx.cached_tool_result("k").is_none());
    }

    #[test]
    fn max_tokens_override_reads_key() {
        let mut ctx = AgentContext::new();
        assert_eq!(ctx.max_tokens_override(), None);
        ctx.set(MAX_TOKENS_OVERRIDE_KEY, serde_json::json!(4096));
        assert_eq!(ctx.max_tokens_override(), Some(4096));
        ctx.set(MAX_TOKENS_OVERRIDE_KEY, serde_json::json!("not a number"));
        assert_eq!(ctx.max_tokens_override(), None);
    }

    #[test]
    fn unrestricted_when_no_writable_roots() {
        let ctx = AgentContext::new();
        assert!(ctx.is_writable(std::path::Path::new("/anywhere/at/all")));
    }

    #[test]
    fn writable_roots_allow_inside_and_reject_sibling() {
        let pid = std::process::id();
        let root = std::env::temp_dir().join(format!("agent_ctx_writable_{pid}"));
        std::fs::create_dir_all(&root).unwrap();
        let ctx = AgentContext::new().with_writable_roots(vec![root.clone()]);

        assert!(ctx.is_writable(&root.join("new_file.txt")));
        let sibling = root.with_file_name(format!("agent_ctx_outside_{pid}"));
        assert!(!ctx.is_writable(&sibling.join("file.txt")));

        let _ = std::fs::remove_dir_all(&root);
    }
}