Skip to main content

sgr_agent_core/
context.rs

1//! Agent execution context — state and domain-specific data.
2
3use serde_json::Value;
4use std::collections::HashMap;
5use std::path::PathBuf;
6
/// Reserved key in `AgentContext::custom` holding a max-tokens override.
///
/// `AgentContext::max_tokens_override` reads the value stored under this key.
pub const MAX_TOKENS_OVERRIDE_KEY: &str = "_max_tokens_override";
9
/// Lifecycle state of an agent run.
///
/// Only `Running` is assigned in this module (by `AgentContext::new`); the
/// other variants are set by code outside this file. The variant notes below
/// follow the names — NOTE(review): confirm against the agent loop.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AgentState {
    /// The agent is executing; the initial state of a new context.
    Running,
    /// The run finished (presumably successfully).
    Completed,
    /// The run stopped because of an error (presumably).
    Failed,
    /// The run was stopped on request (presumably).
    Cancelled,
    /// The agent is paused waiting for input (presumably from a user).
    WaitingInput,
}
19
20/// Shared context passed to tools during execution.
21#[derive(Debug, Clone)]
22pub struct AgentContext {
23    pub iteration: usize,
24    pub state: AgentState,
25    pub cwd: PathBuf,
26    pub custom: HashMap<String, Value>,
27    pub tool_configs: HashMap<String, Value>,
28    pub writable_roots: Vec<PathBuf>,
29    pub observations: Vec<String>,
30    pub observation_limit: usize,
31    pub tool_cache: HashMap<String, String>,
32}
33
34impl AgentContext {
35    pub fn new() -> Self {
36        Self {
37            iteration: 0,
38            state: AgentState::Running,
39            cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
40            custom: HashMap::new(),
41            tool_configs: HashMap::new(),
42            writable_roots: Vec::new(),
43            observations: Vec::new(),
44            observation_limit: 30,
45            tool_cache: HashMap::new(),
46        }
47    }
48
49    pub fn observe(&mut self, entry: impl Into<String>) {
50        self.observations.push(entry.into());
51        while self.observations.len() > self.observation_limit {
52            self.observations.remove(0);
53        }
54    }
55
56    pub fn observation_summary(&self) -> Option<String> {
57        if self.observations.is_empty() {
58            None
59        } else {
60            Some(format!(
61                "OBSERVATION LOG:\n{}",
62                self.observations.join("\n")
63            ))
64        }
65    }
66
67    pub fn cache_tool_result(&mut self, key: impl Into<String>, result: impl Into<String>) {
68        self.tool_cache.insert(key.into(), result.into());
69    }
70
71    pub fn cached_tool_result(&self, key: &str) -> Option<&str> {
72        self.tool_cache.get(key).map(|s| s.as_str())
73    }
74
75    pub fn invalidate_cache(&mut self, key: &str) {
76        self.tool_cache.remove(key);
77    }
78
79    pub fn with_cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
80        self.cwd = cwd.into();
81        self
82    }
83
84    pub fn with_writable_roots(mut self, roots: Vec<PathBuf>) -> Self {
85        self.writable_roots = roots;
86        self
87    }
88
89    pub fn is_writable(&self, path: &std::path::Path) -> bool {
90        if self.writable_roots.is_empty() {
91            return true;
92        }
93        let abs_path = if path.is_absolute() {
94            path.to_path_buf()
95        } else {
96            self.cwd.join(path)
97        };
98        let resolved = std::fs::canonicalize(&abs_path).unwrap_or_else(|_| {
99            if let Some(parent) = abs_path.parent()
100                && let Ok(canon_parent) = std::fs::canonicalize(parent)
101                && let Some(name) = abs_path.file_name()
102            {
103                return canon_parent.join(name);
104            }
105            abs_path.clone()
106        });
107        self.writable_roots.iter().any(|root| {
108            let canon_root = std::fs::canonicalize(root).unwrap_or_else(|_| root.clone());
109            resolved.starts_with(&canon_root)
110        })
111    }
112
113    pub fn set(&mut self, key: impl Into<String>, value: Value) {
114        self.custom.insert(key.into(), value);
115    }
116
117    pub fn get(&self, key: &str) -> Option<&Value> {
118        self.custom.get(key)
119    }
120
121    pub fn max_tokens_override(&self) -> Option<u32> {
122        self.custom
123            .get(MAX_TOKENS_OVERRIDE_KEY)
124            .and_then(|v| v.as_u64())
125            .map(|v| v as u32)
126    }
127
128    pub fn set_tool_config(&mut self, tool_name: impl Into<String>, config: Value) {
129        self.tool_configs.insert(tool_name.into(), config);
130    }
131
132    pub fn tool_config(&self, tool_name: &str) -> Option<&Value> {
133        self.tool_configs.get(tool_name)
134    }
135
136    pub fn merged_tool_config(&self, tool_name: &str, base: &Value) -> Value {
137        match (base, self.tool_configs.get(tool_name)) {
138            (Value::Object(base_obj), Some(Value::Object(override_obj))) => {
139                let mut merged = base_obj.clone();
140                for (k, v) in override_obj {
141                    merged.insert(k.clone(), v.clone());
142                }
143                Value::Object(merged)
144            }
145            (_, Some(override_val)) => override_val.clone(),
146            _ => base.clone(),
147        }
148    }
149}
150
151impl Default for AgentContext {
152    fn default() -> Self {
153        Self::new()
154    }
155}
156
// Unit tests for `AgentContext`: construction, custom data, tool-config
// merging, the observation log, the tool cache, and the max-tokens override.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn context_default_state() {
        let ctx = AgentContext::new();
        assert_eq!(ctx.state, AgentState::Running);
        assert_eq!(ctx.iteration, 0);
    }

    #[test]
    fn context_custom_data() {
        let mut ctx = AgentContext::new();
        ctx.set("project", json!("my-project"));
        assert_eq!(ctx.get("project").unwrap(), "my-project");
        assert!(ctx.get("missing").is_none());
    }

    #[test]
    fn context_with_cwd() {
        let ctx = AgentContext::new().with_cwd("/tmp/test");
        assert_eq!(ctx.cwd, PathBuf::from("/tmp/test"));
    }

    #[test]
    fn tool_config_merge() {
        let mut ctx = AgentContext::new();
        ctx.set_tool_config("bash", json!({"timeout": 60, "shell": "zsh"}));
        let base = json!({"timeout": 30, "cwd": "/tmp"});
        let merged = ctx.merged_tool_config("bash", &base);
        assert_eq!(merged["timeout"], 60);
        assert_eq!(merged["cwd"], "/tmp");
        assert_eq!(merged["shell"], "zsh");
    }

    #[test]
    fn tool_config_missing_or_non_object_override() {
        let mut ctx = AgentContext::new();
        let base = json!({"timeout": 30});
        // No override registered: the base comes back unchanged.
        assert_eq!(ctx.merged_tool_config("bash", &base), base);
        // A non-object override replaces the base wholesale.
        ctx.set_tool_config("bash", json!(5));
        assert_eq!(ctx.merged_tool_config("bash", &base), json!(5));
    }

    #[test]
    fn observe_evicts_oldest_beyond_limit() {
        let mut ctx = AgentContext::new();
        ctx.observation_limit = 2;
        assert!(ctx.observation_summary().is_none());
        ctx.observe("a");
        ctx.observe("b");
        ctx.observe("c");
        assert_eq!(ctx.observations, vec!["b", "c"]);
        assert_eq!(
            ctx.observation_summary().as_deref(),
            Some("OBSERVATION LOG:\nb\nc")
        );
    }

    #[test]
    fn observe_with_zero_limit_keeps_nothing() {
        let mut ctx = AgentContext::new();
        ctx.observation_limit = 0;
        ctx.observe("dropped");
        assert!(ctx.observations.is_empty());
        assert!(ctx.observation_summary().is_none());
    }

    #[test]
    fn tool_cache_round_trip() {
        let mut ctx = AgentContext::new();
        assert_eq!(ctx.cached_tool_result("ls"), None);
        ctx.cache_tool_result("ls", "a b");
        assert_eq!(ctx.cached_tool_result("ls"), Some("a b"));
        ctx.invalidate_cache("ls");
        assert_eq!(ctx.cached_tool_result("ls"), None);
    }

    #[test]
    fn max_tokens_override_reads_well_known_key() {
        let mut ctx = AgentContext::new();
        assert_eq!(ctx.max_tokens_override(), None);
        ctx.set(MAX_TOKENS_OVERRIDE_KEY, json!(4096));
        assert_eq!(ctx.max_tokens_override(), Some(4096));
        ctx.set(MAX_TOKENS_OVERRIDE_KEY, json!("4096"));
        assert_eq!(ctx.max_tokens_override(), None);
        ctx.set(MAX_TOKENS_OVERRIDE_KEY, json!(-1));
        assert_eq!(ctx.max_tokens_override(), None);
    }

    #[test]
    fn empty_writable_roots_allow_everything() {
        let ctx = AgentContext::new();
        assert!(ctx.is_writable(std::path::Path::new("/definitely/anywhere")));
        assert!(ctx.is_writable(std::path::Path::new("relative/file.txt")));
    }
}