swiftide_core/chat_completion/traits.rs

use anyhow::Result;
use async_trait::async_trait;
use dyn_clone::DynClone;
use futures_util::Stream;
use std::{borrow::Cow, pin::Pin, sync::Arc};

use crate::{AgentContext, LanguageModelWithBackOff};

use super::{
    ToolCall, ToolOutput, ToolSpec,
    chat_completion_request::ChatCompletionRequest,
    chat_completion_response::ChatCompletionResponse,
    errors::{LanguageModelError, ToolError},
};

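/// Adds retry-with-backoff behavior to any wrapped [`ChatCompletion`] implementation: transient
/// errors are retried using the configured strategy, while permanent and context-length errors
/// are returned immediately.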
#[async_trait]
impl<LLM: ChatCompletion + Clone> ChatCompletion for LanguageModelWithBackOff<LLM> {
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError> {
        let strategy = self.strategy();

        let op = || {
            let request = request.clone();
            async move {
                self.inner.complete(&request).await.map_err(|e| match e {
                    LanguageModelError::ContextLengthExceeded(e) => {
                        backoff::Error::Permanent(LanguageModelError::ContextLengthExceeded(e))
                    }
                    LanguageModelError::PermanentError(e) => {
                        backoff::Error::Permanent(LanguageModelError::PermanentError(e))
                    }
                    LanguageModelError::TransientError(e) => {
                        backoff::Error::transient(LanguageModelError::TransientError(e))
                    }
                })
            }
        };

        backoff::future::retry(strategy, op).await
    }
}

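/// A pinned, boxed stream of chat completion responses, as returned by
/// [`ChatCompletion::complete_stream`].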
pub type ChatCompletionStream =
    Pin<Box<dyn Stream<Item = Result<ChatCompletionResponse, LanguageModelError>> + Send>>;

#[async_trait]
pub trait ChatCompletion: Send + Sync + DynClone {
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError>;

    /// Stream the completion response. If streaming is not supported by the implementation,
    /// the default returns a stream yielding a single, complete response.
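    ///
    /// # Example
    ///
    /// A sketch of draining the stream; the `llm` and `request` values are assumed to exist in
    /// the caller's scope:
    ///
    /// ```ignore
    /// use futures_util::StreamExt;
    ///
    /// let mut stream = llm.complete_stream(&request).await;
    /// while let Some(chunk) = stream.next().await {
    ///     let response = chunk?;
    ///     // Non-streaming providers yield exactly one item with the full response.
    /// }
    /// ```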
    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
        Box::pin(tokio_stream::iter(vec![self.complete(request).await]))
    }
}

#[async_trait]
impl ChatCompletion for Box<dyn ChatCompletion> {
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError> {
        (**self).complete(request).await
    }

    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
        (**self).complete_stream(request).await
    }
}

#[async_trait]
impl ChatCompletion for &dyn ChatCompletion {
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError> {
        (**self).complete(request).await
    }

    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
        (**self).complete_stream(request).await
    }
}

#[async_trait]
impl<T> ChatCompletion for &T
where
    T: ChatCompletion + Clone + 'static,
{
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError> {
        (**self).complete(request).await
    }

    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
        (**self).complete_stream(request).await
    }
}

impl<LLM> From<&LLM> for Box<dyn ChatCompletion>
where
    LLM: ChatCompletion + Clone + 'static,
{
    fn from(llm: &LLM) -> Self {
        Box::new(llm.clone()) as Box<dyn ChatCompletion>
    }
}

dyn_clone::clone_trait_object!(ChatCompletion);

#[cfg(test)]
mod tests {
    use uuid::Uuid;

    use super::*;
    use crate::BackoffConfiguration;
    use std::{
        collections::HashSet,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    #[derive(Clone)]
    enum MockErrorType {
        Transient,
        Permanent,
        ContextLengthExceeded,
    }

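    /// Mock that fails the first `should_fail_count` calls with the configured error type,
    /// then succeeds with a canned response.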
    #[derive(Clone)]
    struct MockChatCompletion {
        call_count: Arc<AtomicUsize>,
        should_fail_count: usize,
        error_type: MockErrorType,
    }

    #[async_trait]
    impl ChatCompletion for MockChatCompletion {
        async fn complete(
            &self,
            _request: &ChatCompletionRequest,
        ) -> Result<ChatCompletionResponse, LanguageModelError> {
            let count = self.call_count.fetch_add(1, Ordering::SeqCst);

            if count < self.should_fail_count {
                match self.error_type {
                    MockErrorType::Transient => Err(LanguageModelError::TransientError(Box::new(
                        std::io::Error::new(std::io::ErrorKind::ConnectionReset, "Transient error"),
                    ))),
                    MockErrorType::Permanent => Err(LanguageModelError::PermanentError(Box::new(
                        std::io::Error::new(std::io::ErrorKind::InvalidData, "Permanent error"),
                    ))),
                    MockErrorType::ContextLengthExceeded => Err(
                        LanguageModelError::ContextLengthExceeded(Box::new(std::io::Error::new(
                            std::io::ErrorKind::InvalidInput,
                            "Context length exceeded",
                        ))),
                    ),
                }
            } else {
                Ok(ChatCompletionResponse {
                    id: Uuid::new_v4(),
                    message: Some("Success response".to_string()),
                    tool_calls: None,
                    delta: None,
                    usage: None,
                })
            }
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_retries_chat_completion_transient_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_chat = MockChatCompletion {
            call_count: call_count.clone(),
            should_fail_count: 2, // Fail twice, succeed on third attempt
            error_type: MockErrorType::Transient,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_chat, config);

        let request = ChatCompletionRequest {
            messages: vec![],
            tools_spec: HashSet::default(),
        };

        let result = model_with_backoff.complete(&request).await;

        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
        assert_eq!(
            result.unwrap().message,
            Some("Success response".to_string())
        );
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_chat_completion_permanent_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_chat = MockChatCompletion {
            call_count: call_count.clone(),
            should_fail_count: 2, // Would fail twice if retried
            error_type: MockErrorType::Permanent,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_chat, config);

        let request = ChatCompletionRequest {
            messages: vec![],
            tools_spec: HashSet::default(),
        };

        let result = model_with_backoff.complete(&request).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1); // Should only be called once

        match result {
            Err(LanguageModelError::PermanentError(_)) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_chat_completion_context_length_errors()
    {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_chat = MockChatCompletion {
            call_count: call_count.clone(),
            should_fail_count: 2, // Would fail twice if retried
            error_type: MockErrorType::ContextLengthExceeded,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_chat, config);

        let request = ChatCompletionRequest {
            messages: vec![],
            tools_spec: HashSet::default(),
        };

        let result = model_with_backoff.complete(&request).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1); // Should only be called once

        match result {
            Err(LanguageModelError::ContextLengthExceeded(_)) => {} // Expected
            _ => panic!("Expected ContextLengthExceeded, got {result:?}"),
        }
    }
}

/// The `Tool` trait is the main interface for chat completion and agent tools.
///
/// `swiftide-macros` provides a set of macros to generate implementations of this trait. If you
/// need more control over the implementation, you can implement the trait manually.
///
/// The `ToolSpec` is what gets sent to the LLM; a builder is provided. The `name` is expected to
/// be unique and is used to identify the tool. It must match the name in the `ToolSpec`.
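///
/// # Example
///
/// A minimal sketch of a manual implementation. The `Echo` struct, the exact `ToolSpec` builder
/// methods, and the `ToolOutput::Text` variant are illustrative assumptions:
///
/// ```ignore
/// #[derive(Clone)]
/// struct Echo;
///
/// #[async_trait]
/// impl Tool for Echo {
///     async fn invoke(
///         &self,
///         _agent_context: &dyn AgentContext,
///         _tool_call: &ToolCall,
///     ) -> Result<ToolOutput, ToolError> {
///         // Return a plain text result to the agent.
///         Ok(ToolOutput::Text("hello from echo".to_string()))
///     }
///
///     fn name(&self) -> Cow<'_, str> {
///         "echo".into()
///     }
///
///     fn tool_spec(&self) -> ToolSpec {
///         ToolSpec::builder()
///             .name("echo")
///             .description("Echoes a fixed greeting")
///             .build()
///             .unwrap()
///     }
/// }
/// ```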
#[async_trait]
pub trait Tool: Send + Sync + DynClone {
    /// Invokes the tool with the original `ToolCall` from the LLM and the shared agent context.
    async fn invoke(
        &self,
        agent_context: &dyn AgentContext,
        tool_call: &ToolCall,
    ) -> Result<ToolOutput, ToolError>;

    fn name(&self) -> Cow<'_, str>;

    fn tool_spec(&self) -> ToolSpec;

    fn boxed<'a>(self) -> Box<dyn Tool + 'a>
    where
        Self: Sized + 'a,
    {
        Box::new(self) as Box<dyn Tool>
    }
}

/// A toolbox is a collection of tools.
///
/// It can be a plain list, an MCP client, or any other source of tools.
///
/// This allows agents to resolve their tools at runtime rather than at construction time, and it
/// also allows tools to be dynamically loaded and unloaded.
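///
/// # Example
///
/// A sketch using the `Vec<Box<dyn Tool>>` implementation below; the `Echo` tool is an assumed
/// `Tool` implementor:
///
/// ```ignore
/// let toolbox: Box<dyn ToolBox> = vec![Echo.boxed()].boxed();
///
/// // Tools are resolved at runtime, e.g. right before an agent starts its completion loop.
/// let tools = toolbox.available_tools().await?;
/// assert_eq!(tools.len(), 1);
/// ```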
#[async_trait]
pub trait ToolBox: Send + Sync + DynClone {
    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>>;

    fn name(&self) -> Cow<'_, str> {
        Cow::Borrowed("Unnamed ToolBox")
    }

    fn boxed<'a>(self) -> Box<dyn ToolBox + 'a>
    where
        Self: Sized + 'a,
    {
        Box::new(self) as Box<dyn ToolBox>
    }
}

#[async_trait]
impl ToolBox for Vec<Box<dyn Tool>> {
    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
        Ok(self.clone())
    }
}

#[async_trait]
impl ToolBox for Box<dyn ToolBox> {
    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
        (**self).available_tools().await
    }
}

#[async_trait]
impl ToolBox for Arc<dyn ToolBox> {
    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
        (**self).available_tools().await
    }
}

#[async_trait]
impl ToolBox for &dyn ToolBox {
    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
        (**self).available_tools().await
    }
}

#[async_trait]
impl ToolBox for &[Box<dyn Tool>] {
    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
        Ok(self.to_vec())
    }
}

#[async_trait]
impl ToolBox for [Box<dyn Tool>] {
    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
        Ok(self.to_vec())
    }
}

dyn_clone::clone_trait_object!(ToolBox);

#[async_trait]
impl Tool for Box<dyn Tool> {
    async fn invoke(
        &self,
        agent_context: &dyn AgentContext,
        tool_call: &ToolCall,
    ) -> Result<ToolOutput, ToolError> {
        (**self).invoke(agent_context, tool_call).await
    }

    fn name(&self) -> Cow<'_, str> {
        (**self).name()
    }

    fn tool_spec(&self) -> ToolSpec {
        (**self).tool_spec()
    }
}

dyn_clone::clone_trait_object!(Tool);

/// Tools are identified by, and unique in, their name; this allows comparison, hashing, and
/// lookups by name.
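///
/// A sketch of a name-keyed lookup (`Echo` is an assumed `Tool` implementor):
///
/// ```ignore
/// use std::collections::HashSet;
///
/// let mut tools: HashSet<Box<dyn Tool>> = HashSet::new();
/// tools.insert(Echo.boxed());
/// tools.insert(Echo.boxed()); // Same name, so the set still holds a single tool.
/// assert_eq!(tools.len(), 1);
/// ```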
impl PartialEq for Box<dyn Tool> {
    fn eq(&self, other: &Self) -> bool {
        self.name() == other.name()
    }
}

impl Eq for Box<dyn Tool> {}

impl std::hash::Hash for Box<dyn Tool> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.name().hash(state);
    }
}