Skip to main content

sgr_agent/
agent.rs

1//! Agent trait — decides what to do next given conversation history and tools.
2
3use crate::types::{SgrError, ToolCall};
4
5/// Agent's decision: what to do next.
6#[derive(Debug, Clone)]
7pub struct Decision {
8    /// Agent's assessment of the current situation.
9    pub situation: String,
10    /// Task breakdown (reasoning steps).
11    pub task: Vec<String>,
12    /// Tool calls to execute.
13    pub tool_calls: Vec<ToolCall>,
14    /// If true, the agent considers the task done.
15    pub completed: bool,
16}
17
18/// Errors from agent operations.
19#[derive(Debug, thiserror::Error)]
20pub enum AgentError {
21    #[error("LLM error: {0}")]
22    Llm(#[from] SgrError),
23    #[error("tool error: {0}")]
24    Tool(String),
25    #[error("loop detected after {0} iterations")]
26    LoopDetected(usize),
27    #[error("max steps reached: {0}")]
28    MaxSteps(usize),
29    #[error("cancelled")]
30    Cancelled,
31}
32
33/// An agent that decides what tools to call given conversation history.
34///
35/// Lifecycle hooks (all have default no-op implementations):
36/// - `prepare_context` — called before each step to modify context
37/// - `prepare_tools` — called before each step to filter/modify tool set
38/// - `after_action` — called after tool execution with results
39#[async_trait::async_trait]
40pub trait Agent: Send + Sync {
41    /// Given messages and available tools, decide what to do next.
42    async fn decide(
43        &self,
44        messages: &[crate::types::Message],
45        tools: &crate::registry::ToolRegistry,
46    ) -> Result<Decision, AgentError>;
47
48    /// Stateful decide — returns response_id for multi-turn chaining.
49    ///
50    /// The agent loop tracks `response_id` across steps and passes it here.
51    /// Agents that support stateful sessions (e.g., HybridAgent with
52    /// `tools_call_stateful`) can use it for delta-only requests.
53    ///
54    /// Default: delegates to stateless `decide()`, returns `None`.
55    async fn decide_stateful(
56        &self,
57        messages: &[crate::types::Message],
58        tools: &crate::registry::ToolRegistry,
59        _previous_response_id: Option<&str>,
60    ) -> Result<(Decision, Option<String>), AgentError> {
61        let d = self.decide(messages, tools).await?;
62        Ok((d, None))
63    }
64
65    /// Hook: modify context before each step. Default: no-op.
66    fn prepare_context(
67        &self,
68        _ctx: &mut crate::context::AgentContext,
69        _messages: &[crate::types::Message],
70    ) {
71    }
72
73    /// Hook: filter or reorder tools before each step.
74    /// Returns tool names to include. Default: all tools.
75    fn prepare_tools(
76        &self,
77        _ctx: &crate::context::AgentContext,
78        tools: &crate::registry::ToolRegistry,
79    ) -> Vec<String> {
80        tools.list().iter().map(|t| t.name().to_string()).collect()
81    }
82
83    /// Hook: called after tool execution with the tool name and output.
84    /// Can modify context or messages. Default: no-op.
85    fn after_action(
86        &self,
87        _ctx: &mut crate::context::AgentContext,
88        _tool_name: &str,
89        _output: &str,
90    ) {
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// A completed decision with no work keeps its flag and an empty tool list.
    #[test]
    fn decision_completed() {
        let d = Decision {
            situation: "done".into(),
            task: vec![],
            tool_calls: vec![],
            completed: true,
        };
        assert!(d.completed);
        assert!(d.tool_calls.is_empty());
    }

    /// Pins the `Display` text of every variant constructible without an
    /// `SgrError` (the `Llm` variant needs one, so it is not covered here).
    #[test]
    fn agent_error_display() {
        let cases = [
            (AgentError::Tool("boom".into()), "tool error: boom"),
            (AgentError::LoopDetected(5), "loop detected after 5 iterations"),
            (AgentError::MaxSteps(50), "max steps reached: 50"),
            (AgentError::Cancelled, "cancelled"),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }
}