sgr_agent_core/
context.rs1use serde_json::Value;
4use std::collections::HashMap;
5use std::path::PathBuf;
6
/// Key in [`AgentContext::custom`] holding an optional per-run max-tokens
/// override; read back by `AgentContext::max_tokens_override`.
pub const MAX_TOKENS_OVERRIDE_KEY: &str = "_max_tokens_override";
/// Lifecycle state of an agent run, stored in [`AgentContext::state`].
///
/// `AgentContext::new` starts every context in [`AgentState::Running`];
/// none of `AgentContext`'s own methods change the state afterwards.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AgentState {
    Running,
    Completed,
    Failed,
    Cancelled,
    WaitingInput,
}
20#[derive(Debug, Clone)]
22pub struct AgentContext {
23 pub iteration: usize,
24 pub state: AgentState,
25 pub cwd: PathBuf,
26 pub custom: HashMap<String, Value>,
27 pub tool_configs: HashMap<String, Value>,
28 pub writable_roots: Vec<PathBuf>,
29 pub observations: Vec<String>,
30 pub observation_limit: usize,
31 pub tool_cache: HashMap<String, String>,
32}
33
34impl AgentContext {
35 pub fn new() -> Self {
36 Self {
37 iteration: 0,
38 state: AgentState::Running,
39 cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
40 custom: HashMap::new(),
41 tool_configs: HashMap::new(),
42 writable_roots: Vec::new(),
43 observations: Vec::new(),
44 observation_limit: 30,
45 tool_cache: HashMap::new(),
46 }
47 }
48
49 pub fn observe(&mut self, entry: impl Into<String>) {
50 self.observations.push(entry.into());
51 while self.observations.len() > self.observation_limit {
52 self.observations.remove(0);
53 }
54 }
55
56 pub fn observation_summary(&self) -> Option<String> {
57 if self.observations.is_empty() {
58 None
59 } else {
60 Some(format!("OBSERVATION LOG:\n{}", self.observations.join("\n")))
61 }
62 }
63
64 pub fn cache_tool_result(&mut self, key: impl Into<String>, result: impl Into<String>) {
65 self.tool_cache.insert(key.into(), result.into());
66 }
67
68 pub fn cached_tool_result(&self, key: &str) -> Option<&str> {
69 self.tool_cache.get(key).map(|s| s.as_str())
70 }
71
72 pub fn invalidate_cache(&mut self, key: &str) {
73 self.tool_cache.remove(key);
74 }
75
76 pub fn with_cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
77 self.cwd = cwd.into();
78 self
79 }
80
81 pub fn with_writable_roots(mut self, roots: Vec<PathBuf>) -> Self {
82 self.writable_roots = roots;
83 self
84 }
85
86 pub fn is_writable(&self, path: &std::path::Path) -> bool {
87 if self.writable_roots.is_empty() {
88 return true;
89 }
90 let abs_path = if path.is_absolute() {
91 path.to_path_buf()
92 } else {
93 self.cwd.join(path)
94 };
95 let resolved = std::fs::canonicalize(&abs_path).unwrap_or_else(|_| {
96 if let Some(parent) = abs_path.parent()
97 && let Ok(canon_parent) = std::fs::canonicalize(parent)
98 && let Some(name) = abs_path.file_name()
99 {
100 return canon_parent.join(name);
101 }
102 abs_path.clone()
103 });
104 self.writable_roots.iter().any(|root| {
105 let canon_root = std::fs::canonicalize(root).unwrap_or_else(|_| root.clone());
106 resolved.starts_with(&canon_root)
107 })
108 }
109
110 pub fn set(&mut self, key: impl Into<String>, value: Value) {
111 self.custom.insert(key.into(), value);
112 }
113
114 pub fn get(&self, key: &str) -> Option<&Value> {
115 self.custom.get(key)
116 }
117
118 pub fn max_tokens_override(&self) -> Option<u32> {
119 self.custom
120 .get(MAX_TOKENS_OVERRIDE_KEY)
121 .and_then(|v| v.as_u64())
122 .map(|v| v as u32)
123 }
124
125 pub fn set_tool_config(&mut self, tool_name: impl Into<String>, config: Value) {
126 self.tool_configs.insert(tool_name.into(), config);
127 }
128
129 pub fn tool_config(&self, tool_name: &str) -> Option<&Value> {
130 self.tool_configs.get(tool_name)
131 }
132
133 pub fn merged_tool_config(&self, tool_name: &str, base: &Value) -> Value {
134 match (base, self.tool_configs.get(tool_name)) {
135 (Value::Object(base_obj), Some(Value::Object(override_obj))) => {
136 let mut merged = base_obj.clone();
137 for (k, v) in override_obj {
138 merged.insert(k.clone(), v.clone());
139 }
140 Value::Object(merged)
141 }
142 (_, Some(override_val)) => override_val.clone(),
143 _ => base.clone(),
144 }
145 }
146}
147
148impl Default for AgentContext {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn context_default_state() {
        let ctx = AgentContext::new();
        assert_eq!(ctx.state, AgentState::Running);
        assert_eq!(ctx.iteration, 0);
    }

    #[test]
    fn context_custom_data() {
        let mut ctx = AgentContext::new();
        ctx.set("project", serde_json::json!("my-project"));
        assert_eq!(ctx.get("project").unwrap(), "my-project");
        assert!(ctx.get("missing").is_none());
    }

    #[test]
    fn context_with_cwd() {
        let ctx = AgentContext::new().with_cwd("/tmp/test");
        assert_eq!(ctx.cwd, PathBuf::from("/tmp/test"));
    }

    #[test]
    fn tool_config_merge() {
        let mut ctx = AgentContext::new();
        ctx.set_tool_config("bash", serde_json::json!({"timeout": 60, "shell": "zsh"}));
        let base = serde_json::json!({"timeout": 30, "cwd": "/tmp"});
        let merged = ctx.merged_tool_config("bash", &base);
        assert_eq!(merged["timeout"], 60);
        assert_eq!(merged["cwd"], "/tmp");
        assert_eq!(merged["shell"], "zsh");
    }

    // Without an override the base is returned; a non-object override
    // replaces the base wholesale instead of being merged into it.
    #[test]
    fn tool_config_non_object_override_replaces_base() {
        let mut ctx = AgentContext::new();
        let base = serde_json::json!({"timeout": 30});
        assert_eq!(ctx.merged_tool_config("bash", &base), base);
        ctx.set_tool_config("bash", serde_json::json!(null));
        assert_eq!(
            ctx.merged_tool_config("bash", &base),
            serde_json::Value::Null
        );
    }

    // The log keeps only the newest `observation_limit` entries, in order.
    #[test]
    fn observe_keeps_only_most_recent_entries() {
        let mut ctx = AgentContext::new();
        ctx.observation_limit = 2;
        assert!(ctx.observation_summary().is_none());
        ctx.observe("a");
        ctx.observe("b");
        ctx.observe("c");
        assert_eq!(ctx.observations, vec!["b".to_string(), "c".to_string()]);
        assert_eq!(
            ctx.observation_summary().as_deref(),
            Some("OBSERVATION LOG:\nb\nc")
        );
    }

    #[test]
    fn tool_cache_roundtrip() {
        let mut ctx = AgentContext::new();
        assert!(ctx.cached_tool_result("k").is_none());
        ctx.cache_tool_result("k", "v");
        assert_eq!(ctx.cached_tool_result("k"), Some("v"));
        ctx.invalidate_cache("k");
        assert!(ctx.cached_tool_result("k").is_none());
    }

    #[test]
    fn max_tokens_override_reads_custom_key() {
        let mut ctx = AgentContext::new();
        assert_eq!(ctx.max_tokens_override(), None);
        ctx.set(MAX_TOKENS_OVERRIDE_KEY, serde_json::json!(4096));
        assert_eq!(ctx.max_tokens_override(), Some(4096));
        // A string is not accepted even if it looks numeric.
        ctx.set(MAX_TOKENS_OVERRIDE_KEY, serde_json::json!("4096"));
        assert_eq!(ctx.max_tokens_override(), None);
    }

    #[test]
    fn no_writable_roots_allows_any_path() {
        let ctx = AgentContext::new();
        assert!(ctx.is_writable(std::path::Path::new("/anywhere/at/all")));
    }
}