swiftide_core/chat_completion/
traits.rs1use anyhow::Result;
2use async_trait::async_trait;
3use dyn_clone::DynClone;
4use std::{borrow::Cow, sync::Arc};
5
6use crate::{AgentContext, CommandOutput, LanguageModelWithBackOff};
7
8use super::{
9 chat_completion_request::ChatCompletionRequest,
10 chat_completion_response::ChatCompletionResponse,
11 errors::{LanguageModelError, ToolError},
12 ToolOutput, ToolSpec,
13};
14
15#[async_trait]
16impl<LLM: ChatCompletion + Clone> ChatCompletion for LanguageModelWithBackOff<LLM> {
17 async fn complete(
18 &self,
19 request: &ChatCompletionRequest,
20 ) -> Result<ChatCompletionResponse, LanguageModelError> {
21 let strategy = self.strategy();
22
23 let op = || {
24 let request = request.clone();
25 async move {
26 self.inner.complete(&request).await.map_err(|e| match e {
27 LanguageModelError::ContextLengthExceeded(e) => {
28 backoff::Error::Permanent(LanguageModelError::ContextLengthExceeded(e))
29 }
30 LanguageModelError::PermanentError(e) => {
31 backoff::Error::Permanent(LanguageModelError::PermanentError(e))
32 }
33 LanguageModelError::TransientError(e) => {
34 backoff::Error::transient(LanguageModelError::TransientError(e))
35 }
36 })
37 }
38 };
39
40 backoff::future::retry(strategy, op).await
41 }
42}
43
44#[async_trait]
45pub trait ChatCompletion: Send + Sync + DynClone {
46 async fn complete(
47 &self,
48 request: &ChatCompletionRequest,
49 ) -> Result<ChatCompletionResponse, LanguageModelError>;
50}
51
52#[async_trait]
53impl ChatCompletion for Box<dyn ChatCompletion> {
54 async fn complete(
55 &self,
56 request: &ChatCompletionRequest,
57 ) -> Result<ChatCompletionResponse, LanguageModelError> {
58 (**self).complete(request).await
59 }
60}
61
62#[async_trait]
63impl ChatCompletion for &dyn ChatCompletion {
64 async fn complete(
65 &self,
66 request: &ChatCompletionRequest,
67 ) -> Result<ChatCompletionResponse, LanguageModelError> {
68 (**self).complete(request).await
69 }
70}
71
72#[async_trait]
73impl<T> ChatCompletion for &T
74where
75 T: ChatCompletion + Clone + 'static,
76{
77 async fn complete(
78 &self,
79 request: &ChatCompletionRequest,
80 ) -> Result<ChatCompletionResponse, LanguageModelError> {
81 (**self).complete(request).await
82 }
83}
84
85impl<LLM> From<&LLM> for Box<dyn ChatCompletion>
86where
87 LLM: ChatCompletion + Clone + 'static,
88{
89 fn from(llm: &LLM) -> Self {
90 Box::new(llm.clone()) as Box<dyn ChatCompletion>
91 }
92}
93
94dyn_clone::clone_trait_object!(ChatCompletion);
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BackoffConfiguration;
    use std::{
        collections::HashSet,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
    };

    /// Which error the mock produces while it is still in its failing phase.
    #[derive(Clone)]
    enum MockErrorType {
        Transient,
        Permanent,
        ContextLengthExceeded,
    }

    /// Chat model stub: the first `should_fail_count` calls fail with `error_type`,
    /// later calls succeed. Every call increments `call_count`.
    #[derive(Clone)]
    struct MockChatCompletion {
        call_count: Arc<AtomicUsize>,
        should_fail_count: usize,
        error_type: MockErrorType,
    }

    impl MockChatCompletion {
        /// Builds the error that corresponds to the configured `error_type`.
        fn failure(&self) -> LanguageModelError {
            match self.error_type {
                MockErrorType::Transient => LanguageModelError::TransientError(Box::new(
                    std::io::Error::new(std::io::ErrorKind::ConnectionReset, "Transient error"),
                )),
                MockErrorType::Permanent => LanguageModelError::PermanentError(Box::new(
                    std::io::Error::new(std::io::ErrorKind::InvalidData, "Permanent error"),
                )),
                MockErrorType::ContextLengthExceeded => {
                    LanguageModelError::ContextLengthExceeded(Box::new(std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        "Context length exceeded",
                    )))
                }
            }
        }
    }

    #[async_trait]
    impl ChatCompletion for MockChatCompletion {
        async fn complete(
            &self,
            _request: &ChatCompletionRequest,
        ) -> Result<ChatCompletionResponse, LanguageModelError> {
            let previous_calls = self.call_count.fetch_add(1, Ordering::SeqCst);

            if previous_calls >= self.should_fail_count {
                return Ok(ChatCompletionResponse {
                    message: Some("Success response".to_string()),
                    tool_calls: None,
                });
            }

            Err(self.failure())
        }
    }

    /// Backoff settings shared by every test: 1s initial interval, 10s overall budget.
    fn backoff_config() -> BackoffConfiguration {
        BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        }
    }

    /// Runs one completion through `LanguageModelWithBackOff`, wrapping a mock that fails
    /// its first two calls with `error_type`. Returns the result and the number of calls
    /// the mock received.
    async fn complete_with_failing_mock(
        error_type: MockErrorType,
    ) -> (Result<ChatCompletionResponse, LanguageModelError>, usize) {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_chat = MockChatCompletion {
            call_count: Arc::clone(&call_count),
            should_fail_count: 2,
            error_type,
        };
        let model_with_backoff = LanguageModelWithBackOff::new(mock_chat, backoff_config());
        let request = ChatCompletionRequest {
            messages: vec![],
            tools_spec: HashSet::default(),
        };

        let result = model_with_backoff.complete(&request).await;
        (result, call_count.load(Ordering::SeqCst))
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_retries_chat_completion_transient_errors() {
        let (result, calls) = complete_with_failing_mock(MockErrorType::Transient).await;

        assert!(result.is_ok());
        assert_eq!(calls, 3);
        assert_eq!(
            result.unwrap().message,
            Some("Success response".to_string())
        );
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_chat_completion_permanent_errors() {
        let (result, calls) = complete_with_failing_mock(MockErrorType::Permanent).await;

        assert!(result.is_err());
        assert_eq!(calls, 1);
        assert!(
            matches!(result, Err(LanguageModelError::PermanentError(_))),
            "Expected PermanentError, got {result:?}"
        );
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_chat_completion_context_length_errors()
    {
        let (result, calls) =
            complete_with_failing_mock(MockErrorType::ContextLengthExceeded).await;

        assert!(result.is_err());
        assert_eq!(calls, 1);
        assert!(
            matches!(result, Err(LanguageModelError::ContextLengthExceeded(_))),
            "Expected ContextLengthExceeded, got {result:?}"
        );
    }
}
256
257impl From<CommandOutput> for ToolOutput {
258 fn from(value: CommandOutput) -> Self {
259 ToolOutput::Text(value.output)
260 }
261}
262
263#[async_trait]
272pub trait Tool: Send + Sync + DynClone {
273 async fn invoke(
275 &self,
276 agent_context: &dyn AgentContext,
277 raw_args: Option<&str>,
278 ) -> Result<ToolOutput, ToolError>;
279
280 fn name(&self) -> Cow<'_, str>;
281
282 fn tool_spec(&self) -> ToolSpec;
283
284 fn boxed<'a>(self) -> Box<dyn Tool + 'a>
285 where
286 Self: Sized + 'a,
287 {
288 Box::new(self) as Box<dyn Tool>
289 }
290}
291
292#[async_trait]
300pub trait ToolBox: Send + Sync + DynClone {
301 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>>;
302
303 fn name(&self) -> Cow<'_, str> {
304 Cow::Borrowed("Unnamed ToolBox")
305 }
306
307 fn boxed<'a>(self) -> Box<dyn ToolBox + 'a>
308 where
309 Self: Sized + 'a,
310 {
311 Box::new(self) as Box<dyn ToolBox>
312 }
313}
314
315#[async_trait]
316impl ToolBox for Vec<Box<dyn Tool>> {
317 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
318 Ok(self.clone())
319 }
320}
321
322#[async_trait]
323impl ToolBox for Box<dyn ToolBox> {
324 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
325 (**self).available_tools().await
326 }
327}
328
329#[async_trait]
330impl ToolBox for Arc<dyn ToolBox> {
331 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
332 (**self).available_tools().await
333 }
334}
335
336#[async_trait]
337impl ToolBox for &dyn ToolBox {
338 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
339 (**self).available_tools().await
340 }
341}
342
343#[async_trait]
344impl ToolBox for &[Box<dyn Tool>] {
345 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
346 Ok(self.to_vec())
347 }
348}
349
350#[async_trait]
351impl ToolBox for [Box<dyn Tool>] {
352 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
353 Ok(self.to_vec())
354 }
355}
356
357dyn_clone::clone_trait_object!(ToolBox);
358
359#[async_trait]
360impl Tool for Box<dyn Tool> {
361 async fn invoke(
362 &self,
363 agent_context: &dyn AgentContext,
364 raw_args: Option<&str>,
365 ) -> Result<ToolOutput, ToolError> {
366 (**self).invoke(agent_context, raw_args).await
367 }
368 fn name(&self) -> Cow<'_, str> {
369 (**self).name()
370 }
371 fn tool_spec(&self) -> ToolSpec {
372 (**self).tool_spec()
373 }
374}
375
376dyn_clone::clone_trait_object!(Tool);
377
378impl PartialEq for Box<dyn Tool> {
381 fn eq(&self, other: &Self) -> bool {
382 self.name() == other.name()
383 }
384}
385impl Eq for Box<dyn Tool> {}
386impl std::hash::Hash for Box<dyn Tool> {
387 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
388 self.name().hash(state);
389 }
390}