swiftide_core/chat_completion/traits.rs

use anyhow::Result;
use async_trait::async_trait;
use dyn_clone::DynClone;
use futures_util::Stream;
use std::{borrow::Cow, pin::Pin, sync::Arc};

use crate::{AgentContext, CommandOutput, LanguageModelWithBackOff};

use super::{
    ToolOutput, ToolSpec,
    chat_completion_request::ChatCompletionRequest,
    chat_completion_response::ChatCompletionResponse,
    errors::{LanguageModelError, ToolError},
};

#[async_trait]
impl<LLM: ChatCompletion + Clone> ChatCompletion for LanguageModelWithBackOff<LLM> {
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError> {
        let strategy = self.strategy();

        let op = || {
            let request = request.clone();
            async move {
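                // Map errors onto `backoff` semantics: context-length and permanent errors abort
                // the retry loop immediately; only transient errors are retried.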
                self.inner.complete(&request).await.map_err(|e| match e {
                    LanguageModelError::ContextLengthExceeded(e) => {
                        backoff::Error::Permanent(LanguageModelError::ContextLengthExceeded(e))
                    }
                    LanguageModelError::PermanentError(e) => {
                        backoff::Error::Permanent(LanguageModelError::PermanentError(e))
                    }
                    LanguageModelError::TransientError(e) => {
                        backoff::Error::transient(LanguageModelError::TransientError(e))
                    }
                })
            }
        };

        backoff::future::retry(strategy, op).await
    }
}

pub type ChatCompletionStream =
    Pin<Box<dyn Stream<Item = Result<ChatCompletionResponse, LanguageModelError>> + Send>>;
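
/// Implemented by anything that can turn a [`ChatCompletionRequest`] into a
/// [`ChatCompletionResponse`].
///
/// # Example
///
/// A minimal sketch of calling a completer; it assumes an `llm` that implements
/// `ChatCompletion` and a prebuilt `request` are in scope.
///
/// ```ignore
/// let response = llm.complete(&request).await?;
/// println!("{:?}", response.message);
/// ```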
#[async_trait]
pub trait ChatCompletion: Send + Sync + DynClone {
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError>;

    /// Stream the completion response. If streaming is not supported by the implementation,
    /// the default returns the full response as a single item.
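    ///
    /// # Example
    ///
    /// An illustrative sketch of consuming the stream; it assumes an `llm` that implements
    /// `ChatCompletion` and a prebuilt `request` are in scope.
    ///
    /// ```ignore
    /// use futures_util::StreamExt as _;
    ///
    /// let mut stream = llm.complete_stream(&request).await;
    /// while let Some(chunk) = stream.next().await {
    ///     // Streaming backends emit partial responses (see `delta`); the non-streaming
    ///     // fallback yields the full response exactly once.
    ///     if let Ok(response) = chunk {
    ///         println!("{:?}", response.message);
    ///     }
    /// }
    /// ```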
    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
        Box::pin(tokio_stream::iter(vec![self.complete(request).await]))
    }
}

#[async_trait]
impl ChatCompletion for Box<dyn ChatCompletion> {
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError> {
        (**self).complete(request).await
    }

    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
        (**self).complete_stream(request).await
    }
}

#[async_trait]
impl ChatCompletion for &dyn ChatCompletion {
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError> {
        (**self).complete(request).await
    }

    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
        (**self).complete_stream(request).await
    }
}

#[async_trait]
impl<T> ChatCompletion for &T
where
    T: ChatCompletion + Clone + 'static,
{
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError> {
        (**self).complete(request).await
    }

    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
        (**self).complete_stream(request).await
    }
}

impl<LLM> From<&LLM> for Box<dyn ChatCompletion>
where
    LLM: ChatCompletion + Clone + 'static,
{
    fn from(llm: &LLM) -> Self {
        Box::new(llm.clone()) as Box<dyn ChatCompletion>
    }
}

dyn_clone::clone_trait_object!(ChatCompletion);

#[cfg(test)]
mod tests {
    use uuid::Uuid;

    use super::*;
    use crate::BackoffConfiguration;
    use std::{
        collections::HashSet,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    #[derive(Clone)]
    enum MockErrorType {
        Transient,
        Permanent,
        ContextLengthExceeded,
    }

    #[derive(Clone)]
    struct MockChatCompletion {
        call_count: Arc<AtomicUsize>,
        should_fail_count: usize,
        error_type: MockErrorType,
    }

    #[async_trait]
    impl ChatCompletion for MockChatCompletion {
        async fn complete(
            &self,
            _request: &ChatCompletionRequest,
        ) -> Result<ChatCompletionResponse, LanguageModelError> {
            let count = self.call_count.fetch_add(1, Ordering::SeqCst);

            if count < self.should_fail_count {
                match self.error_type {
                    MockErrorType::Transient => Err(LanguageModelError::TransientError(Box::new(
                        std::io::Error::new(std::io::ErrorKind::ConnectionReset, "Transient error"),
                    ))),
                    MockErrorType::Permanent => Err(LanguageModelError::PermanentError(Box::new(
                        std::io::Error::new(std::io::ErrorKind::InvalidData, "Permanent error"),
                    ))),
                    MockErrorType::ContextLengthExceeded => Err(
                        LanguageModelError::ContextLengthExceeded(Box::new(std::io::Error::new(
                            std::io::ErrorKind::InvalidInput,
                            "Context length exceeded",
                        ))),
                    ),
                }
            } else {
                Ok(ChatCompletionResponse {
                    id: Uuid::new_v4(),
                    message: Some("Success response".to_string()),
                    tool_calls: None,
                    delta: None,
                    usage: None,
                })
            }
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_retries_chat_completion_transient_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_chat = MockChatCompletion {
            call_count: call_count.clone(),
            should_fail_count: 2, // Fail twice, succeed on third attempt
            error_type: MockErrorType::Transient,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_chat, config);

        let request = ChatCompletionRequest {
            messages: vec![],
            tools_spec: HashSet::default(),
        };

        let result = model_with_backoff.complete(&request).await;

        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
        assert_eq!(
            result.unwrap().message,
            Some("Success response".to_string())
        );
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_chat_completion_permanent_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_chat = MockChatCompletion {
            call_count: call_count.clone(),
            should_fail_count: 2, // Would fail twice if retried
            error_type: MockErrorType::Permanent,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_chat, config);

        let request = ChatCompletionRequest {
            messages: vec![],
            tools_spec: HashSet::default(),
        };

        let result = model_with_backoff.complete(&request).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1); // Should only be called once

        match result {
            Err(LanguageModelError::PermanentError(_)) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_chat_completion_context_length_errors()
    {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_chat = MockChatCompletion {
            call_count: call_count.clone(),
            should_fail_count: 2, // Would fail twice if retried
            error_type: MockErrorType::ContextLengthExceeded,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_chat, config);

        let request = ChatCompletionRequest {
            messages: vec![],
            tools_spec: HashSet::default(),
        };

        let result = model_with_backoff.complete(&request).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1); // Should only be called once

        match result {
            Err(LanguageModelError::ContextLengthExceeded(_)) => {} // Expected
            _ => panic!("Expected ContextLengthExceeded, got {result:?}"),
        }
    }
}

impl From<CommandOutput> for ToolOutput {
    fn from(value: CommandOutput) -> Self {
        ToolOutput::Text(value.output)
    }
}

/// The `Tool` trait is the main interface for chat completion and agent tools.
///
/// `swiftide-macros` provides a set of macros to generate implementations of this trait. If you
/// need more control over the implementation, you can implement the trait manually.
///
/// The `ToolSpec` is what is sent to the LLM; a builder is provided. The `name` is expected to be
/// unique, is used to identify the tool, and should match the name in the `ToolSpec`.
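///
/// # Example
///
/// A minimal manual implementation sketch; the tool name and the exact `ToolSpec` builder calls
/// shown here are illustrative.
///
/// ```ignore
/// #[derive(Clone)]
/// struct Echo;
///
/// #[async_trait]
/// impl Tool for Echo {
///     async fn invoke(
///         &self,
///         _agent_context: &dyn AgentContext,
///         raw_args: Option<&str>,
///     ) -> Result<ToolOutput, ToolError> {
///         Ok(ToolOutput::Text(raw_args.unwrap_or_default().to_string()))
///     }
///
///     fn name(&self) -> Cow<'_, str> {
///         Cow::Borrowed("echo")
///     }
///
///     fn tool_spec(&self) -> ToolSpec {
///         ToolSpec::builder()
///             .name("echo")
///             .description("Echoes the raw arguments back")
///             .build()
///             .unwrap()
///     }
/// }
/// ```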
#[async_trait]
pub trait Tool: Send + Sync + DynClone {
    /// Invoke the tool with the given agent context and raw arguments.
    async fn invoke(
        &self,
        agent_context: &dyn AgentContext,
        raw_args: Option<&str>,
    ) -> Result<ToolOutput, ToolError>;

    fn name(&self) -> Cow<'_, str>;

    fn tool_spec(&self) -> ToolSpec;

    fn boxed<'a>(self) -> Box<dyn Tool + 'a>
    where
        Self: Sized + 'a,
    {
        Box::new(self) as Box<dyn Tool>
    }
}

/// A toolbox is a collection of tools.
///
/// It can be a list, an MCP client, or anything else we can think of.
///
/// This allows agents to discover their tools at runtime instead of having to know them when they
/// are created.
///
/// It also allows tools to be dynamically loaded and unloaded.
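///
/// # Example
///
/// An illustrative sketch; it assumes `my_tool` implements [`Tool`]. A plain
/// `Vec<Box<dyn Tool>>` already acts as a toolbox through the implementation below.
///
/// ```ignore
/// let toolbox: Vec<Box<dyn Tool>> = vec![my_tool.boxed()];
/// let tools = toolbox.available_tools().await?;
/// ```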
#[async_trait]
pub trait ToolBox: Send + Sync + DynClone {
    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>>;

    fn name(&self) -> Cow<'_, str> {
        Cow::Borrowed("Unnamed ToolBox")
    }

    fn boxed<'a>(self) -> Box<dyn ToolBox + 'a>
    where
        Self: Sized + 'a,
    {
        Box::new(self) as Box<dyn ToolBox>
    }
}

#[async_trait]
impl ToolBox for Vec<Box<dyn Tool>> {
    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
        Ok(self.clone())
    }
}

#[async_trait]
impl ToolBox for Box<dyn ToolBox> {
    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
        (**self).available_tools().await
    }
}

#[async_trait]
impl ToolBox for Arc<dyn ToolBox> {
    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
        (**self).available_tools().await
    }
}

#[async_trait]
impl ToolBox for &dyn ToolBox {
    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
        (**self).available_tools().await
    }
}

#[async_trait]
impl ToolBox for &[Box<dyn Tool>] {
    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
        Ok(self.to_vec())
    }
}

#[async_trait]
impl ToolBox for [Box<dyn Tool>] {
    async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
        Ok(self.to_vec())
    }
}

dyn_clone::clone_trait_object!(ToolBox);

#[async_trait]
impl Tool for Box<dyn Tool> {
    async fn invoke(
        &self,
        agent_context: &dyn AgentContext,
        raw_args: Option<&str>,
    ) -> Result<ToolOutput, ToolError> {
        (**self).invoke(agent_context, raw_args).await
    }

    fn name(&self) -> Cow<'_, str> {
        (**self).name()
    }

    fn tool_spec(&self) -> ToolSpec {
        (**self).tool_spec()
    }
}

dyn_clone::clone_trait_object!(Tool);

/// Tools are identified and unique by name.
///
/// These implementations allow comparison and hash-based lookups of boxed tools.
impl PartialEq for Box<dyn Tool> {
    fn eq(&self, other: &Self) -> bool {
        self.name() == other.name()
    }
}

impl Eq for Box<dyn Tool> {}

impl std::hash::Hash for Box<dyn Tool> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.name().hash(state);
    }
}