sgr_agent_core/
context.rs1use serde_json::Value;
4use std::collections::HashMap;
5use std::path::PathBuf;
6
/// Key in `AgentContext::custom` under which an optional per-run `max_tokens`
/// override is stored; it is read back by `AgentContext::max_tokens_override`.
pub const MAX_TOKENS_OVERRIDE_KEY: &str = "_max_tokens_override";
9
/// Lifecycle status of an agent run.
///
/// `AgentContext::new` starts every context in [`AgentState::Running`]; the other
/// variants are presumably assigned by the agent loop, which is not part of this file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AgentState {
    /// The run is in progress (initial state).
    Running,
    /// The run has finished (presumably successfully — confirm with the agent loop).
    Completed,
    /// The run has ended in failure.
    Failed,
    /// The run was stopped before it finished.
    Cancelled,
    /// The run is paused, presumably waiting for external/user input.
    WaitingInput,
}
19
20#[derive(Debug, Clone)]
22pub struct AgentContext {
23 pub iteration: usize,
24 pub state: AgentState,
25 pub cwd: PathBuf,
26 pub custom: HashMap<String, Value>,
27 pub tool_configs: HashMap<String, Value>,
28 pub writable_roots: Vec<PathBuf>,
29 pub observations: Vec<String>,
30 pub observation_limit: usize,
31 pub tool_cache: HashMap<String, String>,
32}
33
34impl AgentContext {
35 pub fn new() -> Self {
36 Self {
37 iteration: 0,
38 state: AgentState::Running,
39 cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
40 custom: HashMap::new(),
41 tool_configs: HashMap::new(),
42 writable_roots: Vec::new(),
43 observations: Vec::new(),
44 observation_limit: 30,
45 tool_cache: HashMap::new(),
46 }
47 }
48
49 pub fn observe(&mut self, entry: impl Into<String>) {
50 self.observations.push(entry.into());
51 while self.observations.len() > self.observation_limit {
52 self.observations.remove(0);
53 }
54 }
55
56 pub fn observation_summary(&self) -> Option<String> {
57 if self.observations.is_empty() {
58 None
59 } else {
60 Some(format!(
61 "OBSERVATION LOG:\n{}",
62 self.observations.join("\n")
63 ))
64 }
65 }
66
67 pub fn cache_tool_result(&mut self, key: impl Into<String>, result: impl Into<String>) {
68 self.tool_cache.insert(key.into(), result.into());
69 }
70
71 pub fn cached_tool_result(&self, key: &str) -> Option<&str> {
72 self.tool_cache.get(key).map(|s| s.as_str())
73 }
74
75 pub fn invalidate_cache(&mut self, key: &str) {
76 self.tool_cache.remove(key);
77 }
78
79 pub fn with_cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
80 self.cwd = cwd.into();
81 self
82 }
83
84 pub fn with_writable_roots(mut self, roots: Vec<PathBuf>) -> Self {
85 self.writable_roots = roots;
86 self
87 }
88
89 pub fn is_writable(&self, path: &std::path::Path) -> bool {
90 if self.writable_roots.is_empty() {
91 return true;
92 }
93 let abs_path = if path.is_absolute() {
94 path.to_path_buf()
95 } else {
96 self.cwd.join(path)
97 };
98 let resolved = std::fs::canonicalize(&abs_path).unwrap_or_else(|_| {
99 if let Some(parent) = abs_path.parent()
100 && let Ok(canon_parent) = std::fs::canonicalize(parent)
101 && let Some(name) = abs_path.file_name()
102 {
103 return canon_parent.join(name);
104 }
105 abs_path.clone()
106 });
107 self.writable_roots.iter().any(|root| {
108 let canon_root = std::fs::canonicalize(root).unwrap_or_else(|_| root.clone());
109 resolved.starts_with(&canon_root)
110 })
111 }
112
113 pub fn set(&mut self, key: impl Into<String>, value: Value) {
114 self.custom.insert(key.into(), value);
115 }
116
117 pub fn get(&self, key: &str) -> Option<&Value> {
118 self.custom.get(key)
119 }
120
121 pub fn max_tokens_override(&self) -> Option<u32> {
122 self.custom
123 .get(MAX_TOKENS_OVERRIDE_KEY)
124 .and_then(|v| v.as_u64())
125 .map(|v| v as u32)
126 }
127
128 pub fn set_tool_config(&mut self, tool_name: impl Into<String>, config: Value) {
129 self.tool_configs.insert(tool_name.into(), config);
130 }
131
132 pub fn tool_config(&self, tool_name: &str) -> Option<&Value> {
133 self.tool_configs.get(tool_name)
134 }
135
136 pub fn merged_tool_config(&self, tool_name: &str, base: &Value) -> Value {
137 match (base, self.tool_configs.get(tool_name)) {
138 (Value::Object(base_obj), Some(Value::Object(override_obj))) => {
139 let mut merged = base_obj.clone();
140 for (k, v) in override_obj {
141 merged.insert(k.clone(), v.clone());
142 }
143 Value::Object(merged)
144 }
145 (_, Some(override_val)) => override_val.clone(),
146 _ => base.clone(),
147 }
148 }
149}
150
151impl Default for AgentContext {
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn context_default_state() {
        let ctx = AgentContext::new();
        assert_eq!(ctx.state, AgentState::Running);
        assert_eq!(ctx.iteration, 0);
    }

    #[test]
    fn context_custom_data() {
        let mut ctx = AgentContext::new();
        ctx.set("project", serde_json::json!("my-project"));
        assert_eq!(ctx.get("project").unwrap(), "my-project");
        assert!(ctx.get("missing").is_none());
    }

    #[test]
    fn context_with_cwd() {
        let ctx = AgentContext::new().with_cwd("/tmp/test");
        assert_eq!(ctx.cwd, PathBuf::from("/tmp/test"));
    }

    #[test]
    fn tool_config_merge() {
        let mut ctx = AgentContext::new();
        ctx.set_tool_config("bash", serde_json::json!({"timeout": 60, "shell": "zsh"}));
        let base = serde_json::json!({"timeout": 30, "cwd": "/tmp"});
        let merged = ctx.merged_tool_config("bash", &base);
        assert_eq!(merged["timeout"], 60);
        assert_eq!(merged["cwd"], "/tmp");
        assert_eq!(merged["shell"], "zsh");
    }

    #[test]
    fn tool_config_non_object_override_replaces_base() {
        let mut ctx = AgentContext::new();
        ctx.set_tool_config("bash", serde_json::json!(5));
        let base = serde_json::json!({"timeout": 30});
        assert_eq!(ctx.merged_tool_config("bash", &base), serde_json::json!(5));
        // No override configured: base comes back unchanged.
        assert_eq!(ctx.merged_tool_config("missing", &base), base);
    }

    #[test]
    fn observe_evicts_oldest_entries() {
        let mut ctx = AgentContext::new();
        assert!(ctx.observation_summary().is_none());
        ctx.observation_limit = 2;
        ctx.observe("a");
        ctx.observe("b");
        ctx.observe("c");
        assert_eq!(ctx.observations, vec!["b".to_string(), "c".to_string()]);
        assert_eq!(
            ctx.observation_summary().as_deref(),
            Some("OBSERVATION LOG:\nb\nc")
        );
    }

    #[test]
    fn tool_cache_roundtrip() {
        let mut ctx = AgentContext::new();
        assert_eq!(ctx.cached_tool_result("k"), None);
        ctx.cache_tool_result("k", "v");
        assert_eq!(ctx.cached_tool_result("k"), Some("v"));
        ctx.invalidate_cache("k");
        assert_eq!(ctx.cached_tool_result("k"), None);
    }

    #[test]
    fn max_tokens_override_reads_integer_values() {
        let mut ctx = AgentContext::new();
        assert_eq!(ctx.max_tokens_override(), None);
        ctx.set(MAX_TOKENS_OVERRIDE_KEY, serde_json::json!(4096));
        assert_eq!(ctx.max_tokens_override(), Some(4096));
        ctx.set(MAX_TOKENS_OVERRIDE_KEY, serde_json::json!("4096"));
        assert_eq!(ctx.max_tokens_override(), None);
        ctx.set(MAX_TOKENS_OVERRIDE_KEY, serde_json::json!(-1));
        assert_eq!(ctx.max_tokens_override(), None);
    }

    #[test]
    fn is_writable_without_roots_allows_everything() {
        let ctx = AgentContext::new();
        assert!(ctx.is_writable(std::path::Path::new("/definitely/anywhere")));
    }

    #[test]
    fn is_writable_respects_roots() {
        let tmp = std::env::temp_dir();
        let ctx = AgentContext::new()
            .with_cwd(tmp.clone())
            .with_writable_roots(vec![tmp.clone()]);
        assert!(ctx.is_writable(&tmp.join("sgr_ctx_probe_file.txt")));
        // Relative paths are resolved against the context's cwd.
        assert!(ctx.is_writable(std::path::Path::new("sgr_ctx_relative_probe.txt")));

        let narrow = AgentContext::new()
            .with_writable_roots(vec![tmp.join("sgr_ctx_missing_root_dir")]);
        assert!(!narrow.is_writable(&tmp.join("sgr_ctx_outside_probe.txt")));
    }
}