Skip to main content

sgr_agent_core/
context.rs

//! Agent execution context — run state plus domain-specific data shared with tools.
2
3use serde_json::Value;
4use std::collections::HashMap;
5use std::path::PathBuf;
6
/// Reserved key in [`AgentContext::custom`] whose integer value overrides the
/// `max_tokens` setting; read back via [`AgentContext::max_tokens_override`].
pub const MAX_TOKENS_OVERRIDE_KEY: &str = "_max_tokens_override";
9
/// Lifecycle state of an agent run.
///
/// Variant order is significant (it fixes the implicit discriminants), so
/// append new variants at the end.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AgentState {
    /// The run is executing; this is the state `AgentContext::new` starts in.
    Running,
    /// The run finished (presumably successfully — confirm with the agent loop).
    Completed,
    /// The run stopped because of an error (NOTE(review): confirm who sets this).
    Failed,
    /// The run was stopped before finishing (NOTE(review): confirm who sets this).
    Cancelled,
    /// The run is paused, presumably waiting for user input — confirm.
    WaitingInput,
}
19
20/// Shared context passed to tools during execution.
21#[derive(Debug, Clone)]
22pub struct AgentContext {
23    pub iteration: usize,
24    pub state: AgentState,
25    pub cwd: PathBuf,
26    pub custom: HashMap<String, Value>,
27    pub tool_configs: HashMap<String, Value>,
28    pub writable_roots: Vec<PathBuf>,
29    pub observations: Vec<String>,
30    pub observation_limit: usize,
31    pub tool_cache: HashMap<String, String>,
32}
33
34impl AgentContext {
35    pub fn new() -> Self {
36        Self {
37            iteration: 0,
38            state: AgentState::Running,
39            cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
40            custom: HashMap::new(),
41            tool_configs: HashMap::new(),
42            writable_roots: Vec::new(),
43            observations: Vec::new(),
44            observation_limit: 30,
45            tool_cache: HashMap::new(),
46        }
47    }
48
49    pub fn observe(&mut self, entry: impl Into<String>) {
50        self.observations.push(entry.into());
51        while self.observations.len() > self.observation_limit {
52            self.observations.remove(0);
53        }
54    }
55
56    pub fn observation_summary(&self) -> Option<String> {
57        if self.observations.is_empty() {
58            None
59        } else {
60            Some(format!("OBSERVATION LOG:\n{}", self.observations.join("\n")))
61        }
62    }
63
64    pub fn cache_tool_result(&mut self, key: impl Into<String>, result: impl Into<String>) {
65        self.tool_cache.insert(key.into(), result.into());
66    }
67
68    pub fn cached_tool_result(&self, key: &str) -> Option<&str> {
69        self.tool_cache.get(key).map(|s| s.as_str())
70    }
71
72    pub fn invalidate_cache(&mut self, key: &str) {
73        self.tool_cache.remove(key);
74    }
75
76    pub fn with_cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
77        self.cwd = cwd.into();
78        self
79    }
80
81    pub fn with_writable_roots(mut self, roots: Vec<PathBuf>) -> Self {
82        self.writable_roots = roots;
83        self
84    }
85
86    pub fn is_writable(&self, path: &std::path::Path) -> bool {
87        if self.writable_roots.is_empty() {
88            return true;
89        }
90        let abs_path = if path.is_absolute() {
91            path.to_path_buf()
92        } else {
93            self.cwd.join(path)
94        };
95        let resolved = std::fs::canonicalize(&abs_path).unwrap_or_else(|_| {
96            if let Some(parent) = abs_path.parent()
97                && let Ok(canon_parent) = std::fs::canonicalize(parent)
98                && let Some(name) = abs_path.file_name()
99            {
100                return canon_parent.join(name);
101            }
102            abs_path.clone()
103        });
104        self.writable_roots.iter().any(|root| {
105            let canon_root = std::fs::canonicalize(root).unwrap_or_else(|_| root.clone());
106            resolved.starts_with(&canon_root)
107        })
108    }
109
110    pub fn set(&mut self, key: impl Into<String>, value: Value) {
111        self.custom.insert(key.into(), value);
112    }
113
114    pub fn get(&self, key: &str) -> Option<&Value> {
115        self.custom.get(key)
116    }
117
118    pub fn max_tokens_override(&self) -> Option<u32> {
119        self.custom
120            .get(MAX_TOKENS_OVERRIDE_KEY)
121            .and_then(|v| v.as_u64())
122            .map(|v| v as u32)
123    }
124
125    pub fn set_tool_config(&mut self, tool_name: impl Into<String>, config: Value) {
126        self.tool_configs.insert(tool_name.into(), config);
127    }
128
129    pub fn tool_config(&self, tool_name: &str) -> Option<&Value> {
130        self.tool_configs.get(tool_name)
131    }
132
133    pub fn merged_tool_config(&self, tool_name: &str, base: &Value) -> Value {
134        match (base, self.tool_configs.get(tool_name)) {
135            (Value::Object(base_obj), Some(Value::Object(override_obj))) => {
136                let mut merged = base_obj.clone();
137                for (k, v) in override_obj {
138                    merged.insert(k.clone(), v.clone());
139                }
140                Value::Object(merged)
141            }
142            (_, Some(override_val)) => override_val.clone(),
143            _ => base.clone(),
144        }
145    }
146}
147
148impl Default for AgentContext {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn context_default_state() {
        let ctx = AgentContext::new();
        assert_eq!(ctx.state, AgentState::Running);
        assert_eq!(ctx.iteration, 0);
    }

    #[test]
    fn context_custom_data() {
        let mut ctx = AgentContext::new();
        ctx.set("project", serde_json::json!("my-project"));
        assert_eq!(ctx.get("project").unwrap(), "my-project");
        assert!(ctx.get("missing").is_none());
    }

    #[test]
    fn context_with_cwd() {
        let ctx = AgentContext::new().with_cwd("/tmp/test");
        assert_eq!(ctx.cwd, PathBuf::from("/tmp/test"));
    }

    #[test]
    fn tool_config_merge() {
        let mut ctx = AgentContext::new();
        ctx.set_tool_config("bash", serde_json::json!({"timeout": 60, "shell": "zsh"}));
        let base = serde_json::json!({"timeout": 30, "cwd": "/tmp"});
        let merged = ctx.merged_tool_config("bash", &base);
        assert_eq!(merged["timeout"], 60);
        assert_eq!(merged["cwd"], "/tmp");
        assert_eq!(merged["shell"], "zsh");
    }

    #[test]
    fn tool_config_non_object_override_and_missing_override() {
        let mut ctx = AgentContext::new();
        ctx.set_tool_config("bash", serde_json::json!("raw"));
        let base = serde_json::json!({"timeout": 30});
        // A non-object override replaces the base entirely.
        assert_eq!(ctx.merged_tool_config("bash", &base), serde_json::json!("raw"));
        // Without an override the base is returned unchanged.
        assert_eq!(ctx.merged_tool_config("other", &base), base);
    }

    #[test]
    fn observations_keep_only_most_recent() {
        let mut ctx = AgentContext::new();
        assert!(ctx.observation_summary().is_none());

        ctx.observation_limit = 2;
        for entry in ["a", "b", "c"] {
            ctx.observe(entry);
        }
        assert_eq!(ctx.observations, vec!["b", "c"]);
        assert_eq!(
            ctx.observation_summary().as_deref(),
            Some("OBSERVATION LOG:\nb\nc")
        );

        // Lowering the limit trims several entries at once.
        ctx.observation_limit = 1;
        ctx.observe("d");
        assert_eq!(ctx.observations, vec!["d"]);
    }

    #[test]
    fn tool_cache_round_trip() {
        let mut ctx = AgentContext::new();
        assert!(ctx.cached_tool_result("k").is_none());
        ctx.cache_tool_result("k", "v");
        assert_eq!(ctx.cached_tool_result("k"), Some("v"));
        ctx.invalidate_cache("k");
        assert!(ctx.cached_tool_result("k").is_none());
    }

    #[test]
    fn max_tokens_override_reads_custom_key() {
        let mut ctx = AgentContext::new();
        assert_eq!(ctx.max_tokens_override(), None);
        ctx.set(MAX_TOKENS_OVERRIDE_KEY, serde_json::json!(4096));
        assert_eq!(ctx.max_tokens_override(), Some(4096));
        // Non-integer values are ignored.
        ctx.set(MAX_TOKENS_OVERRIDE_KEY, serde_json::json!("4096"));
        assert_eq!(ctx.max_tokens_override(), None);
    }

    #[test]
    fn writable_roots_restrict_paths() {
        // No roots configured: everything is writable.
        assert!(AgentContext::new().is_writable(std::path::Path::new("/anywhere")));

        let root = std::env::temp_dir();
        let ctx = AgentContext::new()
            .with_cwd(root.clone())
            .with_writable_roots(vec![root.clone()]);
        assert!(ctx.is_writable(&root.join("sgr_agent_core_context_test_file")));
        // Relative paths are resolved against `cwd`.
        assert!(ctx.is_writable(std::path::Path::new("sgr_agent_core_context_test_file")));
        if let Some(outside) = root.parent() {
            assert!(!ctx.is_writable(outside));
        }
    }
}