1use crate::ai::{error::AiError, model::Model, provider::AiProvider, types::*};
2use std::{
3 collections::HashSet,
4 sync::{Arc, Mutex},
5};
6
7fn validate_tool_use_results(messages: &[Message]) -> Result<(), AiError> {
8 for (i, message) in messages.iter().enumerate() {
9 if message.role != MessageRole::Assistant {
10 continue;
11 }
12
13 let tool_uses = message.content.tool_uses();
14 if tool_uses.is_empty() {
15 continue;
16 }
17
18 let tool_use_ids: HashSet<&str> = tool_uses.iter().map(|tu| tu.id.as_str()).collect();
19
20 let Some(next_message) = messages.get(i + 1) else {
21 continue;
22 };
23
24 if next_message.role != MessageRole::User {
25 let ids: Vec<&str> = tool_use_ids.into_iter().collect();
26 return Err(AiError::Terminal(anyhow::anyhow!(
27 "ValidationException: messages.{}: tool_use ids were found without tool_result blocks immediately after: {}. Each tool_use block must have a corresponding tool_result block in the next message",
28 i,
29 ids.join(", ")
30 )));
31 }
32
33 let tool_results = next_message.content.tool_results();
34 let result_ids: HashSet<&str> = tool_results
35 .iter()
36 .map(|tr| tr.tool_use_id.as_str())
37 .collect();
38
39 let missing_ids: Vec<&str> = tool_use_ids
40 .iter()
41 .filter(|id| !result_ids.contains(*id))
42 .copied()
43 .collect();
44
45 if !missing_ids.is_empty() {
46 return Err(AiError::Terminal(anyhow::anyhow!(
47 "ValidationException: messages.{}: tool_use ids were found without tool_result blocks immediately after: {}. Each tool_use block must have a corresponding tool_result block in the next message",
48 i,
49 missing_ids.join(", ")
50 )));
51 }
52 }
53
54 Ok(())
55}
56
/// Scripted behavior for [`MockProvider`]: decides what each `converse` call returns
/// and, for the stateful variants, what the provider does on subsequent calls.
///
/// Serialized with snake_case variant names. Every successful mock response reports
/// `TokenUsage::new(10, 10)`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum MockBehavior {
    /// Answers every call with the text `"Mock response"` and `StopReason::EndTurn`.
    #[default]
    Success,
    /// Fails with `AiError::Retryable` while `remaining_errors > 0`, decrementing the
    /// count on each failure; once it reaches 0, answers with `"Success after retries"`.
    RetryableErrorThenSuccess { remaining_errors: usize },
    /// Every call fails with `AiError::Retryable`.
    AlwaysRetryableError,
    /// Every call fails with `AiError::Terminal`.
    AlwaysNonRetryableError,
    /// Every call requests the same tool (tool-use id `tool_{tool_name}`), preceded by a
    /// short text block, with `StopReason::ToolUse`. The behavior does not change afterwards.
    ToolUse {
        /// Name of the requested tool.
        tool_name: String,
        /// JSON-encoded arguments; a string that is not valid JSON is sent as `{}`.
        tool_arguments: String,
    },
    /// Requests the tool once (same response shape as `ToolUse`), then switches to `Success`.
    ToolUseThenSuccess {
        /// Name of the requested tool.
        tool_name: String,
        /// JSON-encoded arguments; a string that is not valid JSON is sent as `{}`.
        tool_arguments: String,
    },
    /// Every call fails with `AiError::InputTooLong`.
    AlwaysInputTooLong,
    /// Fails with `AiError::InputTooLong` while `remaining_errors > 0`, decrementing the
    /// count on each failure; once it reaches 0, answers with
    /// `"Success after input too long errors"`.
    InputTooLongThenSuccess { remaining_errors: usize },
    /// Answers with plain text (no tool use, `StopReason::EndTurn`) until the count is
    /// used up, then behaves as `ToolUseThenSuccess` with the given tool.
    ///
    /// NOTE: the count is decremented with `saturating_sub`, so a value of 0 still
    /// produces one text response before the tool use.
    TextOnlyThenToolUse {
        /// Number of text-only responses to give before requesting the tool.
        remaining_text_responses: usize,
        /// Name of the tool requested afterwards.
        tool_name: String,
        /// JSON-encoded arguments for that tool.
        tool_arguments: String,
    },
    /// Requests the first tool, then behaves as `ToolUseThenSuccess` with the second tool.
    ToolUseThenToolUse {
        /// Tool requested on the first call.
        first_tool_name: String,
        /// JSON-encoded arguments for the first tool.
        first_tool_arguments: String,
        /// Tool requested on the following call.
        second_tool_name: String,
        /// JSON-encoded arguments for the second tool.
        second_tool_arguments: String,
    },
    /// Requests every `(tool_name, tool_arguments)` pair in a single response (tool-use
    /// ids `tool_{tool_name}_{index}`), preceded by a short text block, then switches to `Success`.
    MultipleToolUses { tool_uses: Vec<(String, String)> },
    /// Each call pops the front behavior and runs it; an empty queue answers as `Success`.
    ///
    /// Queues must not be nested: `converse` panics if it pops a `BehaviorQueue`.
    BehaviorQueue { behaviors: Vec<MockBehavior> },
}
102
103#[derive(Clone)]
105pub struct MockProvider {
106 behavior: Arc<Mutex<MockBehavior>>,
107 call_count: Arc<Mutex<usize>>,
108 captured_requests: Arc<Mutex<Vec<ConversationRequest>>>,
109}
110
111impl MockProvider {
112 pub fn new(behavior: MockBehavior) -> Self {
113 Self {
114 behavior: Arc::new(Mutex::new(behavior)),
115 call_count: Arc::new(Mutex::new(0)),
116 captured_requests: Arc::new(Mutex::new(Vec::new())),
117 }
118 }
119
120 fn pop_behavior_from_queue(behavior: &mut MockBehavior) -> MockBehavior {
121 if let MockBehavior::BehaviorQueue { behaviors } = behavior {
122 if behaviors.is_empty() {
123 return MockBehavior::Success;
124 }
125 return behaviors.remove(0);
126 }
127 behavior.clone()
128 }
129
130 pub fn set_behavior(&self, behavior: MockBehavior) {
131 *self.behavior.lock().unwrap() = behavior;
132 }
133
134 pub fn get_call_count(&self) -> usize {
135 *self.call_count.lock().unwrap()
136 }
137
138 pub fn reset_call_count(&self) {
139 *self.call_count.lock().unwrap() = 0;
140 }
141
142 pub fn get_captured_requests(&self) -> Vec<ConversationRequest> {
143 self.captured_requests.lock().unwrap().clone()
144 }
145
146 pub fn get_last_captured_request(&self) -> Option<ConversationRequest> {
147 self.captured_requests.lock().unwrap().last().cloned()
148 }
149
150 pub fn clear_captured_requests(&self) {
151 self.captured_requests.lock().unwrap().clear();
152 }
153}
154
155#[async_trait::async_trait]
156impl AiProvider for MockProvider {
157 fn name(&self) -> &'static str {
158 "mock"
159 }
160
161 fn supported_models(&self) -> HashSet<Model> {
162 HashSet::from([Model::None])
163 }
164
165 async fn converse(
166 &self,
167 request: ConversationRequest,
168 ) -> Result<ConversationResponse, AiError> {
169 validate_tool_use_results(&request.messages)?;
170
171 {
173 let mut requests = self.captured_requests.lock().unwrap();
174 requests.push(request.clone());
175 }
176
177 {
179 let mut count = self.call_count.lock().unwrap();
180 *count += 1;
181 }
182
183 let effective = {
184 let mut behavior = self.behavior.lock().unwrap();
185 Self::pop_behavior_from_queue(&mut behavior)
186 };
187
188 match effective {
189 MockBehavior::Success => Ok(ConversationResponse {
190 content: Content::text_only("Mock response".to_string()),
191 usage: TokenUsage::new(10, 10),
192 stop_reason: StopReason::EndTurn,
193 }),
194 MockBehavior::RetryableErrorThenSuccess {
195 mut remaining_errors,
196 } => {
197 if remaining_errors > 0 {
198 remaining_errors -= 1;
199 self.set_behavior(MockBehavior::RetryableErrorThenSuccess { remaining_errors });
200 Err(AiError::Retryable(anyhow::anyhow!(
201 "Mock retryable error (remaining: {})",
202 remaining_errors
203 )))
204 } else {
205 Ok(ConversationResponse {
206 content: Content::text_only("Success after retries".to_string()),
207 usage: TokenUsage::new(10, 10),
208 stop_reason: StopReason::EndTurn,
209 })
210 }
211 }
212 MockBehavior::AlwaysRetryableError => Err(AiError::Retryable(anyhow::anyhow!(
213 "Mock retryable error (always fails)"
214 ))),
215 MockBehavior::AlwaysNonRetryableError => Err(AiError::Terminal(anyhow::anyhow!(
216 "Mock non-retryable error"
217 ))),
218 MockBehavior::ToolUse {
219 tool_name,
220 tool_arguments,
221 } => {
222 let tool_use = ToolUseData {
223 id: format!("tool_{tool_name}"),
224 name: tool_name.clone(),
225 arguments: serde_json::from_str(&tool_arguments)
226 .unwrap_or_else(|_| serde_json::json!({})),
227 };
228
229 Ok(ConversationResponse {
230 content: Content::new(vec![
231 ContentBlock::Text(format!(
232 "I'll use the {tool_name} tool to help with this task."
233 )),
234 ContentBlock::ToolUse(tool_use),
235 ]),
236 usage: TokenUsage::new(10, 10),
237 stop_reason: StopReason::ToolUse,
238 })
239 }
240 MockBehavior::ToolUseThenSuccess {
241 tool_name,
242 tool_arguments,
243 } => {
244 let tool_use = ToolUseData {
245 id: format!("tool_{tool_name}"),
246 name: tool_name.clone(),
247 arguments: serde_json::from_str(&tool_arguments)
248 .unwrap_or_else(|_| serde_json::json!({})),
249 };
250
251 let response = ConversationResponse {
252 content: Content::new(vec![
253 ContentBlock::Text(format!(
254 "I'll use the {tool_name} tool to help with this task."
255 )),
256 ContentBlock::ToolUse(tool_use),
257 ]),
258 usage: TokenUsage::new(10, 10),
259 stop_reason: StopReason::ToolUse,
260 };
261
262 self.set_behavior(MockBehavior::Success);
263 Ok(response)
264 }
265 MockBehavior::AlwaysInputTooLong => Err(AiError::InputTooLong(anyhow::anyhow!(
266 "Mock input too long error (always fails)"
267 ))),
268 MockBehavior::InputTooLongThenSuccess {
269 mut remaining_errors,
270 } => {
271 if remaining_errors > 0 {
272 remaining_errors -= 1;
273 self.set_behavior(MockBehavior::InputTooLongThenSuccess { remaining_errors });
274 Err(AiError::InputTooLong(anyhow::anyhow!(
275 "Mock input too long error (remaining: {})",
276 remaining_errors
277 )))
278 } else {
279 Ok(ConversationResponse {
280 content: Content::text_only(
281 "Success after input too long errors".to_string(),
282 ),
283 usage: TokenUsage::new(10, 10),
284 stop_reason: StopReason::EndTurn,
285 })
286 }
287 }
288 MockBehavior::TextOnlyThenToolUse {
289 mut remaining_text_responses,
290 tool_name,
291 tool_arguments,
292 } => {
293 remaining_text_responses = remaining_text_responses.saturating_sub(1);
294
295 if remaining_text_responses == 0 {
296 self.set_behavior(MockBehavior::ToolUseThenSuccess {
297 tool_name,
298 tool_arguments,
299 });
300 } else {
301 self.set_behavior(MockBehavior::TextOnlyThenToolUse {
302 remaining_text_responses,
303 tool_name,
304 tool_arguments,
305 });
306 }
307
308 Ok(ConversationResponse {
309 content: Content::text_only("Mock text response without tools".to_string()),
310 usage: TokenUsage::new(10, 10),
311 stop_reason: StopReason::EndTurn,
312 })
313 }
314 MockBehavior::ToolUseThenToolUse {
315 first_tool_name,
316 first_tool_arguments,
317 second_tool_name,
318 second_tool_arguments,
319 } => {
320 let tool_use = ToolUseData {
321 id: format!("tool_{first_tool_name}"),
322 name: first_tool_name.clone(),
323 arguments: serde_json::from_str(&first_tool_arguments)
324 .unwrap_or_else(|_| serde_json::json!({})),
325 };
326
327 let response = ConversationResponse {
328 content: Content::new(vec![
329 ContentBlock::Text(format!(
330 "I'll use the {first_tool_name} tool to help with this task."
331 )),
332 ContentBlock::ToolUse(tool_use),
333 ]),
334 usage: TokenUsage::new(10, 10),
335 stop_reason: StopReason::ToolUse,
336 };
337
338 self.set_behavior(MockBehavior::ToolUseThenSuccess {
339 tool_name: second_tool_name,
340 tool_arguments: second_tool_arguments,
341 });
342
343 Ok(response)
344 }
345 MockBehavior::MultipleToolUses { tool_uses } => {
346 let mut content_blocks = vec![ContentBlock::Text(
347 "I'll use multiple tools to help with this task.".to_string(),
348 )];
349
350 for (index, (tool_name, tool_arguments)) in tool_uses.iter().enumerate() {
351 let tool_use = ToolUseData {
352 id: format!("tool_{}_{}", tool_name, index),
353 name: tool_name.clone(),
354 arguments: serde_json::from_str(tool_arguments)
355 .unwrap_or_else(|_| serde_json::json!({})),
356 };
357 content_blocks.push(ContentBlock::ToolUse(tool_use));
358 }
359
360 self.set_behavior(MockBehavior::Success);
361
362 Ok(ConversationResponse {
363 content: Content::new(content_blocks),
364 usage: TokenUsage::new(10, 10),
365 stop_reason: StopReason::ToolUse,
366 })
367 }
368 MockBehavior::BehaviorQueue { .. } => {
369 panic!("Bug: nested BehaviorQueue detected. Test setup error - BehaviorQueues cannot contain other BehaviorQueues")
370 }
371 }
372 }
373
374 fn get_cost(&self, _model: &Model) -> Cost {
375 Cost::new(0.001, 0.002, 0.0, 0.0)
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the minimal request used by these tests: one user message, no tools,
    /// no system prompt and no stop sequences.
    fn test_request() -> ConversationRequest {
        ConversationRequest {
            messages: vec![Message::user("Test")],
            model: Model::None.default_settings(),
            system_prompt: String::new(),
            stop_sequences: vec![],
            tools: vec![],
        }
    }

    #[tokio::test]
    async fn test_mock_provider_success() {
        let provider = MockProvider::new(MockBehavior::Success);

        let response = provider
            .converse(test_request())
            .await
            .expect("Success behavior should return a response");

        assert_eq!(response.content.text(), "Mock response");
        assert_eq!(provider.get_call_count(), 1);
    }

    #[tokio::test]
    async fn test_mock_provider_retry_then_success() {
        let provider = MockProvider::new(MockBehavior::RetryableErrorThenSuccess {
            remaining_errors: 2,
        });

        // The first two attempts consume the scripted retryable errors.
        for attempt in 1..=2 {
            let outcome = provider.converse(test_request()).await;
            assert!(
                matches!(outcome, Err(AiError::Retryable(_))),
                "attempt {attempt} should fail with a retryable error"
            );
        }

        let response = provider
            .converse(test_request())
            .await
            .expect("third attempt should succeed");
        assert_eq!(response.content.text(), "Success after retries");
        assert_eq!(provider.get_call_count(), 3);
    }
}
429}