swiftide_core/chat_completion/
traits.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use dyn_clone::DynClone;
4use std::{borrow::Cow, sync::Arc};
5
6use crate::{AgentContext, CommandOutput, LanguageModelWithBackOff};
7
8use super::{
9    chat_completion_request::ChatCompletionRequest,
10    chat_completion_response::ChatCompletionResponse,
11    errors::{LanguageModelError, ToolError},
12    ToolOutput, ToolSpec,
13};
14
15#[async_trait]
16impl<LLM: ChatCompletion + Clone> ChatCompletion for LanguageModelWithBackOff<LLM> {
17    async fn complete(
18        &self,
19        request: &ChatCompletionRequest,
20    ) -> Result<ChatCompletionResponse, LanguageModelError> {
21        let strategy = self.strategy();
22
23        let op = || {
24            let request = request.clone();
25            async move {
26                self.inner.complete(&request).await.map_err(|e| match e {
27                    LanguageModelError::ContextLengthExceeded(e) => {
28                        backoff::Error::Permanent(LanguageModelError::ContextLengthExceeded(e))
29                    }
30                    LanguageModelError::PermanentError(e) => {
31                        backoff::Error::Permanent(LanguageModelError::PermanentError(e))
32                    }
33                    LanguageModelError::TransientError(e) => {
34                        backoff::Error::transient(LanguageModelError::TransientError(e))
35                    }
36                })
37            }
38        };
39
40        backoff::future::retry(strategy, op).await
41    }
42}
43
44#[async_trait]
45pub trait ChatCompletion: Send + Sync + DynClone {
46    async fn complete(
47        &self,
48        request: &ChatCompletionRequest,
49    ) -> Result<ChatCompletionResponse, LanguageModelError>;
50}
51
52#[async_trait]
53impl ChatCompletion for Box<dyn ChatCompletion> {
54    async fn complete(
55        &self,
56        request: &ChatCompletionRequest,
57    ) -> Result<ChatCompletionResponse, LanguageModelError> {
58        (**self).complete(request).await
59    }
60}
61
62#[async_trait]
63impl ChatCompletion for &dyn ChatCompletion {
64    async fn complete(
65        &self,
66        request: &ChatCompletionRequest,
67    ) -> Result<ChatCompletionResponse, LanguageModelError> {
68        (**self).complete(request).await
69    }
70}
71
72#[async_trait]
73impl<T> ChatCompletion for &T
74where
75    T: ChatCompletion + Clone + 'static,
76{
77    async fn complete(
78        &self,
79        request: &ChatCompletionRequest,
80    ) -> Result<ChatCompletionResponse, LanguageModelError> {
81        (**self).complete(request).await
82    }
83}
84
85impl<LLM> From<&LLM> for Box<dyn ChatCompletion>
86where
87    LLM: ChatCompletion + Clone + 'static,
88{
89    fn from(llm: &LLM) -> Self {
90        Box::new(llm.clone()) as Box<dyn ChatCompletion>
91    }
92}
93
94dyn_clone::clone_trait_object!(ChatCompletion);
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BackoffConfiguration;
    use std::{
        collections::HashSet,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
    };

    /// Which `LanguageModelError` variant the mock reports while it is still failing.
    #[derive(Clone)]
    enum MockErrorType {
        Transient,
        Permanent,
        ContextLengthExceeded,
    }

    /// Mock model that fails its first `should_fail_count` calls and succeeds afterwards.
    ///
    /// `call_count` is shared so a test can observe how often the model was invoked.
    #[derive(Clone)]
    struct MockChatCompletion {
        call_count: Arc<AtomicUsize>,
        should_fail_count: usize,
        error_type: MockErrorType,
    }

    #[async_trait]
    impl ChatCompletion for MockChatCompletion {
        async fn complete(
            &self,
            _request: &ChatCompletionRequest,
        ) -> Result<ChatCompletionResponse, LanguageModelError> {
            // `fetch_add` returns the previous value, i.e. the zero-based index of this call.
            let previous_calls = self.call_count.fetch_add(1, Ordering::SeqCst);

            if previous_calls >= self.should_fail_count {
                return Ok(ChatCompletionResponse {
                    message: Some("Success response".to_string()),
                    tool_calls: None,
                });
            }

            let failure = match self.error_type {
                MockErrorType::Transient => LanguageModelError::TransientError(Box::new(
                    std::io::Error::new(std::io::ErrorKind::ConnectionReset, "Transient error"),
                )),
                MockErrorType::Permanent => LanguageModelError::PermanentError(Box::new(
                    std::io::Error::new(std::io::ErrorKind::InvalidData, "Permanent error"),
                )),
                MockErrorType::ContextLengthExceeded => {
                    LanguageModelError::ContextLengthExceeded(Box::new(std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        "Context length exceeded",
                    )))
                }
            };
            Err(failure)
        }
    }

    /// Backoff settings shared by all tests in this module.
    fn backoff_config() -> BackoffConfiguration {
        BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        }
    }

    /// A request without messages or tools; the mock ignores its content.
    fn empty_request() -> ChatCompletionRequest {
        ChatCompletionRequest {
            messages: vec![],
            tools_spec: HashSet::default(),
        }
    }

    /// Sends one request through a backoff-wrapped mock whose first two calls fail with
    /// `error_type`, and returns the outcome together with the number of calls made.
    async fn complete_with_backoff(
        error_type: MockErrorType,
    ) -> (Result<ChatCompletionResponse, LanguageModelError>, usize) {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_chat = MockChatCompletion {
            call_count: Arc::clone(&call_count),
            should_fail_count: 2, // Fails twice; a third attempt would succeed
            error_type,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_chat, backoff_config());
        let request = empty_request();

        let result = model_with_backoff.complete(&request).await;
        let calls = call_count.load(Ordering::SeqCst);

        (result, calls)
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_retries_chat_completion_transient_errors() {
        let (result, calls) = complete_with_backoff(MockErrorType::Transient).await;

        assert!(result.is_ok());
        // Two failed attempts plus the successful third one.
        assert_eq!(calls, 3);
        assert_eq!(
            result.unwrap().message,
            Some("Success response".to_string())
        );
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_chat_completion_permanent_errors() {
        let (result, calls) = complete_with_backoff(MockErrorType::Permanent).await;

        assert!(result.is_err());
        assert_eq!(calls, 1); // Should only be called once

        assert!(
            matches!(result, Err(LanguageModelError::PermanentError(_))),
            "Expected PermanentError, got {result:?}"
        );
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_chat_completion_context_length_errors()
    {
        let (result, calls) = complete_with_backoff(MockErrorType::ContextLengthExceeded).await;

        assert!(result.is_err());
        assert_eq!(calls, 1); // Should only be called once

        assert!(
            matches!(result, Err(LanguageModelError::ContextLengthExceeded(_))),
            "Expected ContextLengthExceeded, got {result:?}"
        );
    }
}
256
257impl From<CommandOutput> for ToolOutput {
258    fn from(value: CommandOutput) -> Self {
259        ToolOutput::Text(value.output)
260    }
261}
262
263/// The `Tool` trait is the main interface for chat completion and agent tools.
264///
265/// `swiftide-macros` provides a set of macros to generate implementations of this trait. If you
266/// need more control over the implementation, you can implement the trait manually.
267///
268/// The `ToolSpec` is what will end up with the LLM. A builder is provided. The `name` is expected
269/// to be unique, and is used to identify the tool. It should be the same as the name in the
270/// `ToolSpec`.
271#[async_trait]
272pub trait Tool: Send + Sync + DynClone {
273    // tbd
274    async fn invoke(
275        &self,
276        agent_context: &dyn AgentContext,
277        raw_args: Option<&str>,
278    ) -> Result<ToolOutput, ToolError>;
279
280    fn name(&self) -> Cow<'_, str>;
281
282    fn tool_spec(&self) -> ToolSpec;
283
284    fn boxed<'a>(self) -> Box<dyn Tool + 'a>
285    where
286        Self: Sized + 'a,
287    {
288        Box::new(self) as Box<dyn Tool>
289    }
290}
291
292/// A toolbox is a collection of tools
293///
294/// It can be a list, an mcp client, or anything else we can think of.
295///
296/// This allows agents to not know their tools when they are created, and to get them at runtime.
297///
298/// It also allows for tools to be dynamically loaded and unloaded, etc.
299#[async_trait]
300pub trait ToolBox: Send + Sync + DynClone {
301    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>>;
302
303    fn name(&self) -> Cow<'_, str> {
304        Cow::Borrowed("Unnamed ToolBox")
305    }
306
307    fn boxed<'a>(self) -> Box<dyn ToolBox + 'a>
308    where
309        Self: Sized + 'a,
310    {
311        Box::new(self) as Box<dyn ToolBox>
312    }
313}
314
315#[async_trait]
316impl ToolBox for Vec<Box<dyn Tool>> {
317    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
318        Ok(self.clone())
319    }
320}
321
322#[async_trait]
323impl ToolBox for Box<dyn ToolBox> {
324    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
325        (**self).available_tools().await
326    }
327}
328
329#[async_trait]
330impl ToolBox for Arc<dyn ToolBox> {
331    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
332        (**self).available_tools().await
333    }
334}
335
336#[async_trait]
337impl ToolBox for &dyn ToolBox {
338    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
339        (**self).available_tools().await
340    }
341}
342
343#[async_trait]
344impl ToolBox for &[Box<dyn Tool>] {
345    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
346        Ok(self.to_vec())
347    }
348}
349
350#[async_trait]
351impl ToolBox for [Box<dyn Tool>] {
352    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
353        Ok(self.to_vec())
354    }
355}
356
357dyn_clone::clone_trait_object!(ToolBox);
358
359#[async_trait]
360impl Tool for Box<dyn Tool> {
361    async fn invoke(
362        &self,
363        agent_context: &dyn AgentContext,
364        raw_args: Option<&str>,
365    ) -> Result<ToolOutput, ToolError> {
366        (**self).invoke(agent_context, raw_args).await
367    }
368    fn name(&self) -> Cow<'_, str> {
369        (**self).name()
370    }
371    fn tool_spec(&self) -> ToolSpec {
372        (**self).tool_spec()
373    }
374}
375
376dyn_clone::clone_trait_object!(Tool);
377
378/// Tools are identified and unique by name
379/// These allow comparison and lookups
380impl PartialEq for Box<dyn Tool> {
381    fn eq(&self, other: &Self) -> bool {
382        self.name() == other.name()
383    }
384}
385impl Eq for Box<dyn Tool> {}
386impl std::hash::Hash for Box<dyn Tool> {
387    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
388        self.name().hash(state);
389    }
390}