swiftide_core/chat_completion/
traits.rs1use anyhow::Result;
2use async_trait::async_trait;
3use dyn_clone::DynClone;
4use futures_util::Stream;
5use std::{borrow::Cow, pin::Pin, sync::Arc};
6
7use crate::{AgentContext, LanguageModelWithBackOff};
8
9use super::{
10 ToolCall, ToolOutput, ToolSpec,
11 chat_completion_request::ChatCompletionRequest,
12 chat_completion_response::ChatCompletionResponse,
13 errors::{LanguageModelError, ToolError},
14};
15
16#[async_trait]
17impl<LLM: ChatCompletion + Clone> ChatCompletion for LanguageModelWithBackOff<LLM> {
18 async fn complete(
19 &self,
20 request: &ChatCompletionRequest,
21 ) -> Result<ChatCompletionResponse, LanguageModelError> {
22 let strategy = self.strategy();
23
24 let op = || {
25 let request = request.clone();
26 async move {
27 self.inner.complete(&request).await.map_err(|e| match e {
28 LanguageModelError::ContextLengthExceeded(e) => {
29 backoff::Error::Permanent(LanguageModelError::ContextLengthExceeded(e))
30 }
31 LanguageModelError::PermanentError(e) => {
32 backoff::Error::Permanent(LanguageModelError::PermanentError(e))
33 }
34 LanguageModelError::TransientError(e) => {
35 backoff::Error::transient(LanguageModelError::TransientError(e))
36 }
37 })
38 }
39 };
40
41 backoff::future::retry(strategy, op).await
42 }
43}
44
45pub type ChatCompletionStream =
46 Pin<Box<dyn Stream<Item = Result<ChatCompletionResponse, LanguageModelError>> + Send>>;
47#[async_trait]
48pub trait ChatCompletion: Send + Sync + DynClone {
49 async fn complete(
50 &self,
51 request: &ChatCompletionRequest,
52 ) -> Result<ChatCompletionResponse, LanguageModelError>;
53
54 async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
57 Box::pin(tokio_stream::iter(vec![self.complete(request).await]))
58 }
59}
60
61#[async_trait]
62impl ChatCompletion for Box<dyn ChatCompletion> {
63 async fn complete(
64 &self,
65 request: &ChatCompletionRequest,
66 ) -> Result<ChatCompletionResponse, LanguageModelError> {
67 (**self).complete(request).await
68 }
69
70 async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
71 (**self).complete_stream(request).await
72 }
73}
74
75#[async_trait]
76impl ChatCompletion for &dyn ChatCompletion {
77 async fn complete(
78 &self,
79 request: &ChatCompletionRequest,
80 ) -> Result<ChatCompletionResponse, LanguageModelError> {
81 (**self).complete(request).await
82 }
83
84 async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
85 (**self).complete_stream(request).await
86 }
87}
88
89#[async_trait]
90impl<T> ChatCompletion for &T
91where
92 T: ChatCompletion + Clone + 'static,
93{
94 async fn complete(
95 &self,
96 request: &ChatCompletionRequest,
97 ) -> Result<ChatCompletionResponse, LanguageModelError> {
98 (**self).complete(request).await
99 }
100
101 async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
102 (**self).complete_stream(request).await
103 }
104}
105
106impl<LLM> From<&LLM> for Box<dyn ChatCompletion>
107where
108 LLM: ChatCompletion + Clone + 'static,
109{
110 fn from(llm: &LLM) -> Self {
111 Box::new(llm.clone()) as Box<dyn ChatCompletion>
112 }
113}
114
// Makes `Box<dyn ChatCompletion>` cloneable through the `DynClone` supertrait.
dyn_clone::clone_trait_object!(ChatCompletion);
116
#[cfg(test)]
mod tests {
    use uuid::Uuid;

    use super::*;
    use crate::BackoffConfiguration;
    use std::{
        collections::HashSet,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    /// Kind of error the mock returns while it is still in its failing phase.
    #[derive(Clone)]
    enum MockErrorType {
        Transient,
        Permanent,
        ContextLengthExceeded,
    }

    /// Chat completion double. The first `should_fail_count` calls fail with
    /// `error_type` and later calls succeed. Every call bumps `call_count`.
    #[derive(Clone)]
    struct MockChatCompletion {
        call_count: Arc<AtomicUsize>,
        should_fail_count: usize,
        error_type: MockErrorType,
    }

    impl MockChatCompletion {
        /// Mock that fails twice with `error_type`, sharing `call_count` with the test.
        fn failing_twice(call_count: &Arc<AtomicUsize>, error_type: MockErrorType) -> Self {
            Self {
                call_count: Arc::clone(call_count),
                should_fail_count: 2,
                error_type,
            }
        }

        /// Builds the error selected by `error_type`.
        fn error(&self) -> LanguageModelError {
            match self.error_type {
                MockErrorType::Transient => LanguageModelError::TransientError(Box::new(
                    std::io::Error::new(std::io::ErrorKind::ConnectionReset, "Transient error"),
                )),
                MockErrorType::Permanent => LanguageModelError::PermanentError(Box::new(
                    std::io::Error::new(std::io::ErrorKind::InvalidData, "Permanent error"),
                )),
                MockErrorType::ContextLengthExceeded => {
                    LanguageModelError::ContextLengthExceeded(Box::new(std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        "Context length exceeded",
                    )))
                }
            }
        }
    }

    #[async_trait]
    impl ChatCompletion for MockChatCompletion {
        async fn complete(
            &self,
            _request: &ChatCompletionRequest,
        ) -> Result<ChatCompletionResponse, LanguageModelError> {
            let previous_calls = self.call_count.fetch_add(1, Ordering::SeqCst);

            if previous_calls >= self.should_fail_count {
                return Ok(ChatCompletionResponse {
                    id: Uuid::new_v4(),
                    message: Some("Success response".to_string()),
                    tool_calls: None,
                    delta: None,
                    usage: None,
                });
            }

            Err(self.error())
        }
    }

    /// Back-off settings shared by every test.
    fn backoff_config() -> BackoffConfiguration {
        BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        }
    }

    /// Request with no messages and no tools.
    fn empty_request() -> ChatCompletionRequest {
        ChatCompletionRequest {
            messages: vec![],
            tools_spec: HashSet::default(),
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_retries_chat_completion_transient_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let model_with_backoff = LanguageModelWithBackOff::new(
            MockChatCompletion::failing_twice(&call_count, MockErrorType::Transient),
            backoff_config(),
        );

        let result = model_with_backoff.complete(&empty_request()).await;

        // Two transient failures are retried; the third attempt succeeds.
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
        assert_eq!(
            result.unwrap().message,
            Some("Success response".to_string())
        );
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_chat_completion_permanent_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let model_with_backoff = LanguageModelWithBackOff::new(
            MockChatCompletion::failing_twice(&call_count, MockErrorType::Permanent),
            backoff_config(),
        );

        let result = model_with_backoff.complete(&empty_request()).await;

        // A permanent error must surface after the very first attempt.
        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        assert!(
            matches!(result, Err(LanguageModelError::PermanentError(_))),
            "Expected PermanentError, got {result:?}"
        );
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_chat_completion_context_length_errors()
    {
        let call_count = Arc::new(AtomicUsize::new(0));
        let model_with_backoff = LanguageModelWithBackOff::new(
            MockChatCompletion::failing_twice(&call_count, MockErrorType::ContextLengthExceeded),
            backoff_config(),
        );

        let result = model_with_backoff.complete(&empty_request()).await;

        // Exceeding the context length is not retryable either.
        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        assert!(
            matches!(result, Err(LanguageModelError::ContextLengthExceeded(_))),
            "Expected ContextLengthExceeded, got {result:?}"
        );
    }
}
282
283#[async_trait]
292pub trait Tool: Send + Sync + DynClone {
293 async fn invoke(
295 &self,
296 agent_context: &dyn AgentContext,
297 tool_call: &ToolCall,
298 ) -> Result<ToolOutput, ToolError>;
299
300 fn name(&self) -> Cow<'_, str>;
301
302 fn tool_spec(&self) -> ToolSpec;
303
304 fn boxed<'a>(self) -> Box<dyn Tool + 'a>
305 where
306 Self: Sized + 'a,
307 {
308 Box::new(self) as Box<dyn Tool>
309 }
310}
311
312#[async_trait]
320pub trait ToolBox: Send + Sync + DynClone {
321 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>>;
322
323 fn name(&self) -> Cow<'_, str> {
324 Cow::Borrowed("Unnamed ToolBox")
325 }
326
327 fn boxed<'a>(self) -> Box<dyn ToolBox + 'a>
328 where
329 Self: Sized + 'a,
330 {
331 Box::new(self) as Box<dyn ToolBox>
332 }
333}
334
335#[async_trait]
336impl ToolBox for Vec<Box<dyn Tool>> {
337 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
338 Ok(self.clone())
339 }
340}
341
342#[async_trait]
343impl ToolBox for Box<dyn ToolBox> {
344 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
345 (**self).available_tools().await
346 }
347}
348
349#[async_trait]
350impl ToolBox for Arc<dyn ToolBox> {
351 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
352 (**self).available_tools().await
353 }
354}
355
356#[async_trait]
357impl ToolBox for &dyn ToolBox {
358 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
359 (**self).available_tools().await
360 }
361}
362
363#[async_trait]
364impl ToolBox for &[Box<dyn Tool>] {
365 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
366 Ok(self.to_vec())
367 }
368}
369
370#[async_trait]
371impl ToolBox for [Box<dyn Tool>] {
372 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
373 Ok(self.to_vec())
374 }
375}
376
// Makes `Box<dyn ToolBox>` cloneable through the `DynClone` supertrait.
dyn_clone::clone_trait_object!(ToolBox);
378
379#[async_trait]
380impl Tool for Box<dyn Tool> {
381 async fn invoke(
382 &self,
383 agent_context: &dyn AgentContext,
384 tool_call: &ToolCall,
385 ) -> Result<ToolOutput, ToolError> {
386 (**self).invoke(agent_context, tool_call).await
387 }
388 fn name(&self) -> Cow<'_, str> {
389 (**self).name()
390 }
391 fn tool_spec(&self) -> ToolSpec {
392 (**self).tool_spec()
393 }
394}
395
// Makes `Box<dyn Tool>` cloneable through the `DynClone` supertrait.
dyn_clone::clone_trait_object!(Tool);
397
398impl PartialEq for Box<dyn Tool> {
401 fn eq(&self, other: &Self) -> bool {
402 self.name() == other.name()
403 }
404}
405impl Eq for Box<dyn Tool> {}
406impl std::hash::Hash for Box<dyn Tool> {
407 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
408 self.name().hash(state);
409 }
410}