1use serde_json::Value;
4use std::collections::HashMap;
5use std::path::PathBuf;
6
/// Lifecycle state of an agent, stored in `AgentContext::state`.
///
/// `AgentContext::new` starts in [`AgentState::Running`]; transitions between
/// variants happen outside this file. The variant meanings below are inferred
/// from their names — confirm against the code that sets them.
///
/// `Hash` is derived alongside `Eq` so states can be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AgentState {
    /// Still executing (the initial state chosen by `AgentContext::new`).
    Running,
    /// Finished successfully.
    Completed,
    /// Stopped because of an error.
    Failed,
    /// Stopped on request before finishing.
    Cancelled,
    /// Paused until further input is provided.
    WaitingInput,
}
16
/// Context carried by an agent: loop bookkeeping, the working directory,
/// free-form key/value data, per-tool configuration overrides, and the set of
/// directories that writes are confined to.
#[derive(Debug, Clone)]
pub struct AgentContext {
    /// Iteration counter; `AgentContext::new` starts it at 0.
    /// NOTE(review): presumably advanced by the agent loop — not shown here.
    pub iteration: usize,
    /// Current lifecycle state; `AgentContext::new` starts in `AgentState::Running`.
    pub state: AgentState,
    /// Directory that relative paths are joined onto in `is_writable`.
    /// `AgentContext::new` uses the process working directory, or `"."` if it
    /// cannot be read.
    pub cwd: PathBuf,
    /// Arbitrary JSON values keyed by name, accessed via `set` / `get`.
    pub custom: HashMap<String, Value>,
    /// Per-tool JSON configuration keyed by tool name, accessed via
    /// `set_tool_config`, `tool_config` and `merged_tool_config`.
    pub tool_configs: HashMap<String, Value>,
    /// Directories that writes are confined to, checked by `is_writable`.
    /// An empty list means no restriction.
    pub writable_roots: Vec<PathBuf>,
}
35
36impl AgentContext {
37 pub fn new() -> Self {
38 Self {
39 iteration: 0,
40 state: AgentState::Running,
41 cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
42 custom: HashMap::new(),
43 tool_configs: HashMap::new(),
44 writable_roots: Vec::new(),
45 }
46 }
47
48 pub fn with_cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
49 self.cwd = cwd.into();
50 self
51 }
52
53 pub fn with_writable_roots(mut self, roots: Vec<PathBuf>) -> Self {
54 self.writable_roots = roots;
55 self
56 }
57
58 pub fn is_writable(&self, path: &std::path::Path) -> bool {
62 if self.writable_roots.is_empty() {
63 return true;
64 }
65 let abs_path = if path.is_absolute() {
66 path.to_path_buf()
67 } else {
68 self.cwd.join(path)
69 };
70 let resolved = std::fs::canonicalize(&abs_path).unwrap_or_else(|_| {
73 if let Some(parent) = abs_path.parent()
75 && let Ok(canon_parent) = std::fs::canonicalize(parent)
76 && let Some(name) = abs_path.file_name()
77 {
78 return canon_parent.join(name);
79 }
80 abs_path.clone()
81 });
82 self.writable_roots.iter().any(|root| {
83 let canon_root = std::fs::canonicalize(root).unwrap_or_else(|_| root.clone());
85 resolved.starts_with(&canon_root)
86 })
87 }
88
89 pub fn set(&mut self, key: impl Into<String>, value: Value) {
91 self.custom.insert(key.into(), value);
92 }
93
94 pub fn get(&self, key: &str) -> Option<&Value> {
96 self.custom.get(key)
97 }
98
99 pub fn set_tool_config(&mut self, tool_name: impl Into<String>, config: Value) {
101 self.tool_configs.insert(tool_name.into(), config);
102 }
103
104 pub fn tool_config(&self, tool_name: &str) -> Option<&Value> {
106 self.tool_configs.get(tool_name)
107 }
108
109 pub fn merged_tool_config(&self, tool_name: &str, base: &Value) -> Value {
112 match (base, self.tool_configs.get(tool_name)) {
113 (Value::Object(base_obj), Some(Value::Object(override_obj))) => {
114 let mut merged = base_obj.clone();
115 for (k, v) in override_obj {
116 merged.insert(k.clone(), v.clone());
117 }
118 Value::Object(merged)
119 }
120 (_, Some(override_val)) => override_val.clone(),
121 _ => base.clone(),
122 }
123 }
124}
125
126impl Default for AgentContext {
127 fn default() -> Self {
128 Self::new()
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::path::Path;

    #[test]
    fn context_default_state() {
        let fresh = AgentContext::new();
        assert_eq!(fresh.iteration, 0);
        assert_eq!(fresh.state, AgentState::Running);
    }

    #[test]
    fn context_custom_data() {
        let mut ctx = AgentContext::new();
        ctx.set("project", json!("my-project"));
        assert_eq!(ctx.get("project"), Some(&json!("my-project")));
        assert_eq!(ctx.get("missing"), None);
    }

    #[test]
    fn context_with_cwd() {
        let ctx = AgentContext::new().with_cwd("/tmp/test");
        assert_eq!(ctx.cwd.as_path(), Path::new("/tmp/test"));
    }

    #[test]
    fn tool_config_set_get() {
        let mut ctx = AgentContext::new();
        ctx.set_tool_config("bash", json!({"timeout": 30}));
        let bash = ctx.tool_config("bash").expect("bash config was just set");
        assert_eq!(bash["timeout"], 30);
        assert!(ctx.tool_config("read_file").is_none());
    }

    #[test]
    fn tool_config_merge() {
        let mut ctx = AgentContext::new();
        ctx.set_tool_config("bash", json!({"timeout": 60, "shell": "zsh"}));

        let merged = ctx.merged_tool_config("bash", &json!({"timeout": 30, "cwd": "/tmp"}));
        // Override wins on the shared key; keys unique to either side survive.
        assert_eq!(merged["timeout"], 60);
        assert_eq!(merged["cwd"], "/tmp");
        assert_eq!(merged["shell"], "zsh");
    }

    #[test]
    fn tool_config_merge_no_override() {
        let base = json!({"timeout": 30});
        let merged = AgentContext::new().merged_tool_config("bash", &base);
        assert_eq!(merged, base);
    }

    #[test]
    fn writable_roots_empty_allows_all() {
        assert!(AgentContext::new().is_writable(Path::new("/any/path")));
    }

    #[test]
    fn writable_roots_restricts() {
        let project = PathBuf::from("/home/user/project");
        let ctx = AgentContext::new().with_writable_roots(vec![project]);
        assert!(ctx.is_writable(Path::new("/home/user/project/src/main.rs")));
        assert!(!ctx.is_writable(Path::new("/etc/passwd")));
    }

    #[test]
    fn writable_roots_relative_path() {
        let project = "/home/user/project";
        let ctx = AgentContext::new()
            .with_cwd(project)
            .with_writable_roots(vec![PathBuf::from(project)]);
        assert!(ctx.is_writable(Path::new("src/main.rs")));
        assert!(!ctx.is_writable(Path::new("/etc/passwd")));
    }
}