Skip to main content

spice_framework/
mock.rs

1use async_trait::async_trait;
2use std::time::Duration;
3
4use crate::agent::{AgentConfig, AgentOutput, AgentUnderTest, ToolCall, Turn};
5use crate::error::SpiceError;
6
7/// A scripted response for the mock agent.
8#[derive(Debug, Clone)]
9pub struct MockResponse {
10    pub final_text: String,
11    pub tool_calls: Vec<ToolCall>,
12    pub error: Option<String>,
13}
14
15impl MockResponse {
16    /// Create a simple text response with no tool calls.
17    pub fn text(text: impl Into<String>) -> Self {
18        Self {
19            final_text: text.into(),
20            tool_calls: vec![],
21            error: None,
22        }
23    }
24
25    /// Create a response with tool calls.
26    pub fn with_tools(text: impl Into<String>, tools: Vec<ToolCall>) -> Self {
27        Self {
28            final_text: text.into(),
29            tool_calls: tools,
30            error: None,
31        }
32    }
33
34    /// Create an error response.
35    pub fn error(msg: impl Into<String>) -> Self {
36        Self {
37            final_text: String::new(),
38            tool_calls: vec![],
39            error: Some(msg.into()),
40        }
41    }
42}
43
44/// A single turn in a multi-turn scripted response.
45#[derive(Debug, Clone)]
46pub struct MockTurn {
47    pub tool_calls: Vec<ToolCall>,
48    pub output_text: Option<String>,
49}
50
51/// A multi-turn scripted response with multiple turns of tool calls.
52#[derive(Debug, Clone)]
53pub struct MockMultiTurnResponse {
54    pub turns: Vec<MockTurn>,
55    pub final_text: String,
56}
57
58impl MockMultiTurnResponse {
59    pub fn new(final_text: impl Into<String>) -> Self {
60        Self {
61            turns: vec![],
62            final_text: final_text.into(),
63        }
64    }
65
66    /// Add a turn with tool calls.
67    pub fn turn(mut self, tool_calls: Vec<ToolCall>) -> Self {
68        self.turns.push(MockTurn {
69            tool_calls,
70            output_text: None,
71        });
72        self
73    }
74
75    /// Add a turn with tool calls and output text.
76    pub fn turn_with_text(
77        mut self,
78        tool_calls: Vec<ToolCall>,
79        text: impl Into<String>,
80    ) -> Self {
81        self.turns.push(MockTurn {
82            tool_calls,
83            output_text: Some(text.into()),
84        });
85        self
86    }
87}
88
89/// A mock agent for deterministic testing.
90pub struct MockAgent {
91    name: String,
92    responses: std::collections::HashMap<String, MockResponse>,
93    multi_turn_responses: std::collections::HashMap<String, MockMultiTurnResponse>,
94    default_response: MockResponse,
95    tools: Vec<String>,
96    role_tools: std::collections::HashMap<String, Vec<String>>,
97}
98
99impl MockAgent {
100    pub fn new(name: impl Into<String>) -> Self {
101        Self {
102            name: name.into(),
103            responses: std::collections::HashMap::new(),
104            multi_turn_responses: std::collections::HashMap::new(),
105            default_response: MockResponse::text("I don't know how to help with that."),
106            tools: vec![],
107            role_tools: std::collections::HashMap::new(),
108        }
109    }
110
111    /// Register a scripted response for a specific user message (exact match).
112    pub fn on(mut self, message: impl Into<String>, response: MockResponse) -> Self {
113        self.responses.insert(message.into(), response);
114        self
115    }
116
117    /// Register a multi-turn scripted response for a specific user message.
118    pub fn on_multi_turn(
119        mut self,
120        message: impl Into<String>,
121        response: MockMultiTurnResponse,
122    ) -> Self {
123        self.multi_turn_responses.insert(message.into(), response);
124        self
125    }
126
127    /// Set the default response for unmatched messages.
128    pub fn default_response(mut self, response: MockResponse) -> Self {
129        self.default_response = response;
130        self
131    }
132
133    /// Set the available tools.
134    pub fn with_tools(mut self, tools: Vec<String>) -> Self {
135        self.tools = tools;
136        self
137    }
138
139    /// Set per-role tool lists.
140    pub fn with_role_tools(mut self, role: &str, tools: &[&str]) -> Self {
141        self.role_tools.insert(
142            role.to_string(),
143            tools.iter().map(|s| s.to_string()).collect(),
144        );
145        self
146    }
147}
148
149#[async_trait]
150impl AgentUnderTest for MockAgent {
151    async fn run(
152        &self,
153        user_message: &str,
154        _config: &AgentConfig,
155    ) -> Result<AgentOutput, SpiceError> {
156        // Check multi-turn responses first
157        if let Some(mt) = self.multi_turn_responses.get(user_message) {
158            let mut turns = Vec::new();
159            let mut all_tools_called = Vec::new();
160
161            for (i, mock_turn) in mt.turns.iter().enumerate() {
162                for tc in &mock_turn.tool_calls {
163                    all_tools_called.push(tc.name.clone());
164                }
165                turns.push(Turn {
166                    index: i,
167                    output_text: mock_turn.output_text.clone(),
168                    tool_calls: mock_turn.tool_calls.clone(),
169                    tool_results: vec![],
170                    stop_reason: Some("tool_use".into()),
171                    duration: Duration::from_millis(1),
172                });
173            }
174
175            // Fix last turn's stop_reason
176            if let Some(last) = turns.last_mut() {
177                last.stop_reason = Some("stop".into());
178                last.output_text = Some(mt.final_text.clone());
179            }
180
181            return Ok(AgentOutput {
182                final_text: mt.final_text.clone(),
183                turns,
184                tools_called: all_tools_called,
185                duration: Duration::from_millis(1),
186                error: None,
187            });
188        }
189
190        // Fall back to single-turn responses
191        let response = self
192            .responses
193            .get(user_message)
194            .unwrap_or(&self.default_response);
195
196        if let Some(err) = &response.error {
197            return Err(SpiceError::AgentError(err.clone()));
198        }
199
200        let tools_called: Vec<String> = response
201            .tool_calls
202            .iter()
203            .map(|tc| tc.name.clone())
204            .collect();
205
206        let turn = Turn {
207            index: 0,
208            output_text: Some(response.final_text.clone()),
209            tool_calls: response.tool_calls.clone(),
210            tool_results: vec![],
211            stop_reason: Some("stop".into()),
212            duration: Duration::from_millis(1),
213        };
214
215        Ok(AgentOutput {
216            final_text: response.final_text.clone(),
217            turns: vec![turn],
218            tools_called,
219            duration: Duration::from_millis(1),
220            error: None,
221        })
222    }
223
224    fn available_tools(&self, config: &AgentConfig) -> Vec<String> {
225        // Check for role-specific tools
226        if let Some(role) = config.data.get("role").and_then(|v| v.as_str()) {
227            if let Some(tools) = self.role_tools.get(role) {
228                return tools.clone();
229            }
230        }
231        self.tools.clone()
232    }
233
234    fn name(&self) -> &str {
235        &self.name
236    }
237}