1use async_trait::async_trait;
2use std::time::Duration;
3
4use crate::agent::{AgentConfig, AgentOutput, AgentUnderTest, ToolCall, Turn};
5use crate::error::SpiceError;
6
7#[derive(Debug, Clone)]
9pub struct MockResponse {
10 pub final_text: String,
11 pub tool_calls: Vec<ToolCall>,
12 pub error: Option<String>,
13}
14
15impl MockResponse {
16 pub fn text(text: impl Into<String>) -> Self {
18 Self {
19 final_text: text.into(),
20 tool_calls: vec![],
21 error: None,
22 }
23 }
24
25 pub fn with_tools(text: impl Into<String>, tools: Vec<ToolCall>) -> Self {
27 Self {
28 final_text: text.into(),
29 tool_calls: tools,
30 error: None,
31 }
32 }
33
34 pub fn error(msg: impl Into<String>) -> Self {
36 Self {
37 final_text: String::new(),
38 tool_calls: vec![],
39 error: Some(msg.into()),
40 }
41 }
42}
43
44#[derive(Debug, Clone)]
46pub struct MockTurn {
47 pub tool_calls: Vec<ToolCall>,
48 pub output_text: Option<String>,
49}
50
51#[derive(Debug, Clone)]
53pub struct MockMultiTurnResponse {
54 pub turns: Vec<MockTurn>,
55 pub final_text: String,
56}
57
58impl MockMultiTurnResponse {
59 pub fn new(final_text: impl Into<String>) -> Self {
60 Self {
61 turns: vec![],
62 final_text: final_text.into(),
63 }
64 }
65
66 pub fn turn(mut self, tool_calls: Vec<ToolCall>) -> Self {
68 self.turns.push(MockTurn {
69 tool_calls,
70 output_text: None,
71 });
72 self
73 }
74
75 pub fn turn_with_text(
77 mut self,
78 tool_calls: Vec<ToolCall>,
79 text: impl Into<String>,
80 ) -> Self {
81 self.turns.push(MockTurn {
82 tool_calls,
83 output_text: Some(text.into()),
84 });
85 self
86 }
87}
88
89pub struct MockAgent {
91 name: String,
92 responses: std::collections::HashMap<String, MockResponse>,
93 multi_turn_responses: std::collections::HashMap<String, MockMultiTurnResponse>,
94 default_response: MockResponse,
95 tools: Vec<String>,
96 role_tools: std::collections::HashMap<String, Vec<String>>,
97}
98
99impl MockAgent {
100 pub fn new(name: impl Into<String>) -> Self {
101 Self {
102 name: name.into(),
103 responses: std::collections::HashMap::new(),
104 multi_turn_responses: std::collections::HashMap::new(),
105 default_response: MockResponse::text("I don't know how to help with that."),
106 tools: vec![],
107 role_tools: std::collections::HashMap::new(),
108 }
109 }
110
111 pub fn on(mut self, message: impl Into<String>, response: MockResponse) -> Self {
113 self.responses.insert(message.into(), response);
114 self
115 }
116
117 pub fn on_multi_turn(
119 mut self,
120 message: impl Into<String>,
121 response: MockMultiTurnResponse,
122 ) -> Self {
123 self.multi_turn_responses.insert(message.into(), response);
124 self
125 }
126
127 pub fn default_response(mut self, response: MockResponse) -> Self {
129 self.default_response = response;
130 self
131 }
132
133 pub fn with_tools(mut self, tools: Vec<String>) -> Self {
135 self.tools = tools;
136 self
137 }
138
139 pub fn with_role_tools(mut self, role: &str, tools: &[&str]) -> Self {
141 self.role_tools.insert(
142 role.to_string(),
143 tools.iter().map(|s| s.to_string()).collect(),
144 );
145 self
146 }
147}
148
149#[async_trait]
150impl AgentUnderTest for MockAgent {
151 async fn run(
152 &self,
153 user_message: &str,
154 _config: &AgentConfig,
155 ) -> Result<AgentOutput, SpiceError> {
156 if let Some(mt) = self.multi_turn_responses.get(user_message) {
158 let mut turns = Vec::new();
159 let mut all_tools_called = Vec::new();
160
161 for (i, mock_turn) in mt.turns.iter().enumerate() {
162 for tc in &mock_turn.tool_calls {
163 all_tools_called.push(tc.name.clone());
164 }
165 turns.push(Turn {
166 index: i,
167 output_text: mock_turn.output_text.clone(),
168 tool_calls: mock_turn.tool_calls.clone(),
169 tool_results: vec![],
170 stop_reason: Some("tool_use".into()),
171 duration: Duration::from_millis(1),
172 });
173 }
174
175 if let Some(last) = turns.last_mut() {
177 last.stop_reason = Some("stop".into());
178 last.output_text = Some(mt.final_text.clone());
179 }
180
181 return Ok(AgentOutput {
182 final_text: mt.final_text.clone(),
183 turns,
184 tools_called: all_tools_called,
185 duration: Duration::from_millis(1),
186 error: None,
187 });
188 }
189
190 let response = self
192 .responses
193 .get(user_message)
194 .unwrap_or(&self.default_response);
195
196 if let Some(err) = &response.error {
197 return Err(SpiceError::AgentError(err.clone()));
198 }
199
200 let tools_called: Vec<String> = response
201 .tool_calls
202 .iter()
203 .map(|tc| tc.name.clone())
204 .collect();
205
206 let turn = Turn {
207 index: 0,
208 output_text: Some(response.final_text.clone()),
209 tool_calls: response.tool_calls.clone(),
210 tool_results: vec![],
211 stop_reason: Some("stop".into()),
212 duration: Duration::from_millis(1),
213 };
214
215 Ok(AgentOutput {
216 final_text: response.final_text.clone(),
217 turns: vec![turn],
218 tools_called,
219 duration: Duration::from_millis(1),
220 error: None,
221 })
222 }
223
224 fn available_tools(&self, config: &AgentConfig) -> Vec<String> {
225 if let Some(role) = config.data.get("role").and_then(|v| v.as_str()) {
227 if let Some(tools) = self.role_tools.get(role) {
228 return tools.clone();
229 }
230 }
231 self.tools.clone()
232 }
233
234 fn name(&self) -> &str {
235 &self.name
236 }
237}