Skip to main content

tycode_core/ai/
mock.rs

1use crate::ai::{error::AiError, model::Model, provider::AiProvider, types::*};
2use std::{
3    collections::HashSet,
4    sync::{Arc, Mutex},
5};
6
7fn validate_tool_use_results(messages: &[Message]) -> Result<(), AiError> {
8    for (i, message) in messages.iter().enumerate() {
9        if message.role != MessageRole::Assistant {
10            continue;
11        }
12
13        let tool_uses = message.content.tool_uses();
14        if tool_uses.is_empty() {
15            continue;
16        }
17
18        let tool_use_ids: HashSet<&str> = tool_uses.iter().map(|tu| tu.id.as_str()).collect();
19
20        let Some(next_message) = messages.get(i + 1) else {
21            continue;
22        };
23
24        if next_message.role != MessageRole::User {
25            let ids: Vec<&str> = tool_use_ids.into_iter().collect();
26            return Err(AiError::Terminal(anyhow::anyhow!(
27                "ValidationException: messages.{}: tool_use ids were found without tool_result blocks immediately after: {}. Each tool_use block must have a corresponding tool_result block in the next message",
28                i,
29                ids.join(", ")
30            )));
31        }
32
33        let tool_results = next_message.content.tool_results();
34        let result_ids: HashSet<&str> = tool_results
35            .iter()
36            .map(|tr| tr.tool_use_id.as_str())
37            .collect();
38
39        let missing_ids: Vec<&str> = tool_use_ids
40            .iter()
41            .filter(|id| !result_ids.contains(*id))
42            .copied()
43            .collect();
44
45        if !missing_ids.is_empty() {
46            return Err(AiError::Terminal(anyhow::anyhow!(
47                "ValidationException: messages.{}: tool_use ids were found without tool_result blocks immediately after: {}. Each tool_use block must have a corresponding tool_result block in the next message",
48                i,
49                missing_ids.join(", ")
50            )));
51        }
52    }
53
54    Ok(())
55}
56
/// Mock behavior for the mock provider.
///
/// Variants are (de)serialized with snake_case names via
/// `#[serde(rename_all = "snake_case")]`, e.g. `always_retryable_error`.
///
/// Several variants are stateful: after a call, the provider advances the
/// behavior to the next step of its sequence (for example, a decremented
/// error counter).
///
/// Tool-use variants parse `tool_arguments` as JSON. If parsing fails, the
/// arguments silently fall back to an empty object `{}`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum MockBehavior {
    /// Return successful responses (text "Mock response", stop reason `EndTurn`)
    #[default]
    Success,
    /// Return a retryable error `remaining_errors` times, then succeed
    RetryableErrorThenSuccess { remaining_errors: usize },
    /// Always return a retryable error
    AlwaysRetryableError,
    /// Always return a non-retryable (terminal) error
    AlwaysNonRetryableError,
    /// Return the same tool use response on every call; the tool-use id is `tool_{tool_name}`
    ToolUse {
        tool_name: String,
        tool_arguments: String,
    },
    /// Return a tool use response once, then success
    ToolUseThenSuccess {
        tool_name: String,
        tool_arguments: String,
    },
    /// Always return an InputTooLong error
    AlwaysInputTooLong,
    /// Return InputTooLong error `remaining_errors` times, then succeed
    InputTooLongThenSuccess { remaining_errors: usize },
    /// Return text-only responses N times (`remaining_text_responses`), then a tool use response
    TextOnlyThenToolUse {
        remaining_text_responses: usize,
        tool_name: String,
        tool_arguments: String,
    },
    /// Return two tool uses in sequence (one per call), then success
    ToolUseThenToolUse {
        first_tool_name: String,
        first_tool_arguments: String,
        second_tool_name: String,
        second_tool_arguments: String,
    },
    /// Return multiple tool uses in a single response, then success.
    ///
    /// Each entry is `(tool_name, tool_arguments)`; ids are `tool_{tool_name}_{index}`.
    MultipleToolUses { tool_uses: Vec<(String, String)> },
    /// Enables sequential multi-turn conversation testing by orchestrating predetermined agent responses.
    ///
    /// Entries are taken from the front of `behaviors`. Once the queue is
    /// empty, calls behave like `Success`. A `BehaviorQueue` nested inside
    /// another queue is a test-setup bug and panics when it is reached.
    BehaviorQueue { behaviors: Vec<MockBehavior> },
}
102
/// Mock AI provider for testing.
///
/// Every field is an `Arc<Mutex<_>>`, so clones of a `MockProvider` share the
/// same behavior, call counter and captured requests.
#[derive(Clone)]
pub struct MockProvider {
    // Behavior consulted on each `converse` call; stateful variants are
    // advanced after the call.
    behavior: Arc<Mutex<MockBehavior>>,
    // Number of `converse` calls whose request passed tool-use validation.
    call_count: Arc<Mutex<usize>>,
    // Every request that passed tool-use validation, in call order.
    captured_requests: Arc<Mutex<Vec<ConversationRequest>>>,
}
110
111impl MockProvider {
112    pub fn new(behavior: MockBehavior) -> Self {
113        Self {
114            behavior: Arc::new(Mutex::new(behavior)),
115            call_count: Arc::new(Mutex::new(0)),
116            captured_requests: Arc::new(Mutex::new(Vec::new())),
117        }
118    }
119
120    fn pop_behavior_from_queue(behavior: &mut MockBehavior) -> MockBehavior {
121        if let MockBehavior::BehaviorQueue { behaviors } = behavior {
122            if behaviors.is_empty() {
123                return MockBehavior::Success;
124            }
125            return behaviors.remove(0);
126        }
127        behavior.clone()
128    }
129
130    pub fn set_behavior(&self, behavior: MockBehavior) {
131        *self.behavior.lock().unwrap() = behavior;
132    }
133
134    pub fn get_call_count(&self) -> usize {
135        *self.call_count.lock().unwrap()
136    }
137
138    pub fn reset_call_count(&self) {
139        *self.call_count.lock().unwrap() = 0;
140    }
141
142    pub fn get_captured_requests(&self) -> Vec<ConversationRequest> {
143        self.captured_requests.lock().unwrap().clone()
144    }
145
146    pub fn get_last_captured_request(&self) -> Option<ConversationRequest> {
147        self.captured_requests.lock().unwrap().last().cloned()
148    }
149
150    pub fn clear_captured_requests(&self) {
151        self.captured_requests.lock().unwrap().clear();
152    }
153}
154
155#[async_trait::async_trait]
156impl AiProvider for MockProvider {
157    fn name(&self) -> &'static str {
158        "mock"
159    }
160
161    fn supported_models(&self) -> HashSet<Model> {
162        HashSet::from([Model::None])
163    }
164
165    async fn converse(
166        &self,
167        request: ConversationRequest,
168    ) -> Result<ConversationResponse, AiError> {
169        validate_tool_use_results(&request.messages)?;
170
171        // Capture the request
172        {
173            let mut requests = self.captured_requests.lock().unwrap();
174            requests.push(request.clone());
175        }
176
177        // Increment call count
178        {
179            let mut count = self.call_count.lock().unwrap();
180            *count += 1;
181        }
182
183        let effective = {
184            let mut behavior = self.behavior.lock().unwrap();
185            Self::pop_behavior_from_queue(&mut behavior)
186        };
187
188        match effective {
189            MockBehavior::Success => Ok(ConversationResponse {
190                content: Content::text_only("Mock response".to_string()),
191                usage: TokenUsage::new(10, 10),
192                stop_reason: StopReason::EndTurn,
193            }),
194            MockBehavior::RetryableErrorThenSuccess {
195                mut remaining_errors,
196            } => {
197                if remaining_errors > 0 {
198                    remaining_errors -= 1;
199                    self.set_behavior(MockBehavior::RetryableErrorThenSuccess { remaining_errors });
200                    Err(AiError::Retryable(anyhow::anyhow!(
201                        "Mock retryable error (remaining: {})",
202                        remaining_errors
203                    )))
204                } else {
205                    Ok(ConversationResponse {
206                        content: Content::text_only("Success after retries".to_string()),
207                        usage: TokenUsage::new(10, 10),
208                        stop_reason: StopReason::EndTurn,
209                    })
210                }
211            }
212            MockBehavior::AlwaysRetryableError => Err(AiError::Retryable(anyhow::anyhow!(
213                "Mock retryable error (always fails)"
214            ))),
215            MockBehavior::AlwaysNonRetryableError => Err(AiError::Terminal(anyhow::anyhow!(
216                "Mock non-retryable error"
217            ))),
218            MockBehavior::ToolUse {
219                tool_name,
220                tool_arguments,
221            } => {
222                let tool_use = ToolUseData {
223                    id: format!("tool_{tool_name}"),
224                    name: tool_name.clone(),
225                    arguments: serde_json::from_str(&tool_arguments)
226                        .unwrap_or_else(|_| serde_json::json!({})),
227                };
228
229                Ok(ConversationResponse {
230                    content: Content::new(vec![
231                        ContentBlock::Text(format!(
232                            "I'll use the {tool_name} tool to help with this task."
233                        )),
234                        ContentBlock::ToolUse(tool_use),
235                    ]),
236                    usage: TokenUsage::new(10, 10),
237                    stop_reason: StopReason::ToolUse,
238                })
239            }
240            MockBehavior::ToolUseThenSuccess {
241                tool_name,
242                tool_arguments,
243            } => {
244                let tool_use = ToolUseData {
245                    id: format!("tool_{tool_name}"),
246                    name: tool_name.clone(),
247                    arguments: serde_json::from_str(&tool_arguments)
248                        .unwrap_or_else(|_| serde_json::json!({})),
249                };
250
251                let response = ConversationResponse {
252                    content: Content::new(vec![
253                        ContentBlock::Text(format!(
254                            "I'll use the {tool_name} tool to help with this task."
255                        )),
256                        ContentBlock::ToolUse(tool_use),
257                    ]),
258                    usage: TokenUsage::new(10, 10),
259                    stop_reason: StopReason::ToolUse,
260                };
261
262                self.set_behavior(MockBehavior::Success);
263                Ok(response)
264            }
265            MockBehavior::AlwaysInputTooLong => Err(AiError::InputTooLong(anyhow::anyhow!(
266                "Mock input too long error (always fails)"
267            ))),
268            MockBehavior::InputTooLongThenSuccess {
269                mut remaining_errors,
270            } => {
271                if remaining_errors > 0 {
272                    remaining_errors -= 1;
273                    self.set_behavior(MockBehavior::InputTooLongThenSuccess { remaining_errors });
274                    Err(AiError::InputTooLong(anyhow::anyhow!(
275                        "Mock input too long error (remaining: {})",
276                        remaining_errors
277                    )))
278                } else {
279                    Ok(ConversationResponse {
280                        content: Content::text_only(
281                            "Success after input too long errors".to_string(),
282                        ),
283                        usage: TokenUsage::new(10, 10),
284                        stop_reason: StopReason::EndTurn,
285                    })
286                }
287            }
288            MockBehavior::TextOnlyThenToolUse {
289                mut remaining_text_responses,
290                tool_name,
291                tool_arguments,
292            } => {
293                remaining_text_responses = remaining_text_responses.saturating_sub(1);
294
295                if remaining_text_responses == 0 {
296                    self.set_behavior(MockBehavior::ToolUseThenSuccess {
297                        tool_name,
298                        tool_arguments,
299                    });
300                } else {
301                    self.set_behavior(MockBehavior::TextOnlyThenToolUse {
302                        remaining_text_responses,
303                        tool_name,
304                        tool_arguments,
305                    });
306                }
307
308                Ok(ConversationResponse {
309                    content: Content::text_only("Mock text response without tools".to_string()),
310                    usage: TokenUsage::new(10, 10),
311                    stop_reason: StopReason::EndTurn,
312                })
313            }
314            MockBehavior::ToolUseThenToolUse {
315                first_tool_name,
316                first_tool_arguments,
317                second_tool_name,
318                second_tool_arguments,
319            } => {
320                let tool_use = ToolUseData {
321                    id: format!("tool_{first_tool_name}"),
322                    name: first_tool_name.clone(),
323                    arguments: serde_json::from_str(&first_tool_arguments)
324                        .unwrap_or_else(|_| serde_json::json!({})),
325                };
326
327                let response = ConversationResponse {
328                    content: Content::new(vec![
329                        ContentBlock::Text(format!(
330                            "I'll use the {first_tool_name} tool to help with this task."
331                        )),
332                        ContentBlock::ToolUse(tool_use),
333                    ]),
334                    usage: TokenUsage::new(10, 10),
335                    stop_reason: StopReason::ToolUse,
336                };
337
338                self.set_behavior(MockBehavior::ToolUseThenSuccess {
339                    tool_name: second_tool_name,
340                    tool_arguments: second_tool_arguments,
341                });
342
343                Ok(response)
344            }
345            MockBehavior::MultipleToolUses { tool_uses } => {
346                let mut content_blocks = vec![ContentBlock::Text(
347                    "I'll use multiple tools to help with this task.".to_string(),
348                )];
349
350                for (index, (tool_name, tool_arguments)) in tool_uses.iter().enumerate() {
351                    let tool_use = ToolUseData {
352                        id: format!("tool_{}_{}", tool_name, index),
353                        name: tool_name.clone(),
354                        arguments: serde_json::from_str(tool_arguments)
355                            .unwrap_or_else(|_| serde_json::json!({})),
356                    };
357                    content_blocks.push(ContentBlock::ToolUse(tool_use));
358                }
359
360                self.set_behavior(MockBehavior::Success);
361
362                Ok(ConversationResponse {
363                    content: Content::new(content_blocks),
364                    usage: TokenUsage::new(10, 10),
365                    stop_reason: StopReason::ToolUse,
366                })
367            }
368            MockBehavior::BehaviorQueue { .. } => {
369                panic!("Bug: nested BehaviorQueue detected. Test setup error - BehaviorQueues cannot contain other BehaviorQueues")
370            }
371        }
372    }
373
374    fn get_cost(&self, _model: &Model) -> Cost {
375        // Mock provider uses test costs
376        Cost::new(0.001, 0.002, 0.0, 0.0)
377    }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal request: one user message ("Test"), default model
    /// settings, and no system prompt, stop sequences or tools.
    fn test_request() -> ConversationRequest {
        ConversationRequest {
            messages: vec![Message::user("Test")],
            model: Model::None.default_settings(),
            system_prompt: String::new(),
            stop_sequences: vec![],
            tools: vec![],
        }
    }

    #[tokio::test]
    async fn test_mock_provider_success() {
        let provider = MockProvider::new(MockBehavior::Success);

        let response = provider
            .converse(test_request())
            .await
            .expect("Success behavior should not fail");

        assert_eq!(response.content.text(), "Mock response");
        assert_eq!(provider.get_call_count(), 1);
    }

    #[tokio::test]
    async fn test_mock_provider_retry_then_success() {
        let provider = MockProvider::new(MockBehavior::RetryableErrorThenSuccess {
            remaining_errors: 2,
        });
        let request = test_request();

        // Both configured errors surface as retryable before the mock recovers.
        for attempt in 1..=2 {
            let outcome = provider.converse(request.clone()).await;
            assert!(
                matches!(outcome, Err(AiError::Retryable(_))),
                "attempt {attempt} should be a retryable error"
            );
        }

        let response = provider
            .converse(request)
            .await
            .expect("third attempt should succeed");
        assert_eq!(response.content.text(), "Success after retries");
        assert_eq!(provider.get_call_count(), 3);
    }
}