swiftide_core/chat_completion/
traits.rs1use anyhow::Result;
2use async_trait::async_trait;
3use dyn_clone::DynClone;
4use futures_util::Stream;
5use std::{borrow::Cow, pin::Pin, sync::Arc};
6
7use crate::{AgentContext, CommandOutput, LanguageModelWithBackOff};
8
9use super::{
10 ToolOutput, ToolSpec,
11 chat_completion_request::ChatCompletionRequest,
12 chat_completion_response::ChatCompletionResponse,
13 errors::{LanguageModelError, ToolError},
14};
15
16#[async_trait]
17impl<LLM: ChatCompletion + Clone> ChatCompletion for LanguageModelWithBackOff<LLM> {
18 async fn complete(
19 &self,
20 request: &ChatCompletionRequest,
21 ) -> Result<ChatCompletionResponse, LanguageModelError> {
22 let strategy = self.strategy();
23
24 let op = || {
25 let request = request.clone();
26 async move {
27 self.inner.complete(&request).await.map_err(|e| match e {
28 LanguageModelError::ContextLengthExceeded(e) => {
29 backoff::Error::Permanent(LanguageModelError::ContextLengthExceeded(e))
30 }
31 LanguageModelError::PermanentError(e) => {
32 backoff::Error::Permanent(LanguageModelError::PermanentError(e))
33 }
34 LanguageModelError::TransientError(e) => {
35 backoff::Error::transient(LanguageModelError::TransientError(e))
36 }
37 })
38 }
39 };
40
41 backoff::future::retry(strategy, op).await
42 }
43}
44
45pub type ChatCompletionStream =
46 Pin<Box<dyn Stream<Item = Result<ChatCompletionResponse, LanguageModelError>> + Send>>;
47#[async_trait]
48pub trait ChatCompletion: Send + Sync + DynClone {
49 async fn complete(
50 &self,
51 request: &ChatCompletionRequest,
52 ) -> Result<ChatCompletionResponse, LanguageModelError>;
53
54 async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
57 Box::pin(tokio_stream::iter(vec![self.complete(request).await]))
58 }
59}
60
61#[async_trait]
62impl ChatCompletion for Box<dyn ChatCompletion> {
63 async fn complete(
64 &self,
65 request: &ChatCompletionRequest,
66 ) -> Result<ChatCompletionResponse, LanguageModelError> {
67 (**self).complete(request).await
68 }
69
70 async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
71 (**self).complete_stream(request).await
72 }
73}
74
75#[async_trait]
76impl ChatCompletion for &dyn ChatCompletion {
77 async fn complete(
78 &self,
79 request: &ChatCompletionRequest,
80 ) -> Result<ChatCompletionResponse, LanguageModelError> {
81 (**self).complete(request).await
82 }
83
84 async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
85 (**self).complete_stream(request).await
86 }
87}
88
89#[async_trait]
90impl<T> ChatCompletion for &T
91where
92 T: ChatCompletion + Clone + 'static,
93{
94 async fn complete(
95 &self,
96 request: &ChatCompletionRequest,
97 ) -> Result<ChatCompletionResponse, LanguageModelError> {
98 (**self).complete(request).await
99 }
100
101 async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
102 (**self).complete_stream(request).await
103 }
104}
105
106impl<LLM> From<&LLM> for Box<dyn ChatCompletion>
107where
108 LLM: ChatCompletion + Clone + 'static,
109{
110 fn from(llm: &LLM) -> Self {
111 Box::new(llm.clone()) as Box<dyn ChatCompletion>
112 }
113}
114
115dyn_clone::clone_trait_object!(ChatCompletion);
116
#[cfg(test)]
mod tests {
    use uuid::Uuid;

    use super::*;
    use crate::BackoffConfiguration;
    use std::{
        collections::HashSet,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
    };

    /// Which error the mock reports while it is still failing.
    #[derive(Clone)]
    enum MockErrorType {
        Transient,
        Permanent,
        ContextLengthExceeded,
    }

    /// Chat completion double: the first `should_fail_count` calls fail with `error_type`,
    /// later calls succeed. Every call increments `call_count`.
    #[derive(Clone)]
    struct MockChatCompletion {
        call_count: Arc<AtomicUsize>,
        should_fail_count: usize,
        error_type: MockErrorType,
    }

    impl MockChatCompletion {
        /// Builds the error corresponding to `self.error_type`.
        fn failure(&self) -> LanguageModelError {
            match self.error_type {
                MockErrorType::Transient => LanguageModelError::TransientError(Box::new(
                    std::io::Error::new(std::io::ErrorKind::ConnectionReset, "Transient error"),
                )),
                MockErrorType::Permanent => LanguageModelError::PermanentError(Box::new(
                    std::io::Error::new(std::io::ErrorKind::InvalidData, "Permanent error"),
                )),
                MockErrorType::ContextLengthExceeded => {
                    LanguageModelError::ContextLengthExceeded(Box::new(std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        "Context length exceeded",
                    )))
                }
            }
        }
    }

    #[async_trait]
    impl ChatCompletion for MockChatCompletion {
        async fn complete(
            &self,
            _request: &ChatCompletionRequest,
        ) -> Result<ChatCompletionResponse, LanguageModelError> {
            let previous_calls = self.call_count.fetch_add(1, Ordering::SeqCst);

            if previous_calls >= self.should_fail_count {
                return Ok(ChatCompletionResponse {
                    id: Uuid::new_v4(),
                    message: Some("Success response".to_string()),
                    tool_calls: None,
                    delta: None,
                    usage: None,
                });
            }

            Err(self.failure())
        }
    }

    /// A mock that fails twice with `error_type` before succeeding.
    fn mock_failing_twice(
        call_count: &Arc<AtomicUsize>,
        error_type: MockErrorType,
    ) -> MockChatCompletion {
        MockChatCompletion {
            call_count: Arc::clone(call_count),
            should_fail_count: 2,
            error_type,
        }
    }

    /// The backoff configuration shared by all tests.
    fn backoff_config() -> BackoffConfiguration {
        BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        }
    }

    /// A request with no messages and no tools.
    fn empty_request() -> ChatCompletionRequest {
        ChatCompletionRequest {
            messages: vec![],
            tools_spec: HashSet::default(),
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_retries_chat_completion_transient_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let model_with_backoff = LanguageModelWithBackOff::new(
            mock_failing_twice(&call_count, MockErrorType::Transient),
            backoff_config(),
        );

        let result = model_with_backoff.complete(&empty_request()).await;

        assert!(result.is_ok());
        // Two transient failures followed by one success.
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
        assert_eq!(
            result.unwrap().message,
            Some("Success response".to_string())
        );
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_chat_completion_permanent_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let model_with_backoff = LanguageModelWithBackOff::new(
            mock_failing_twice(&call_count, MockErrorType::Permanent),
            backoff_config(),
        );

        let result = model_with_backoff.complete(&empty_request()).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        assert!(
            matches!(result, Err(LanguageModelError::PermanentError(_))),
            "Expected PermanentError, got {result:?}"
        );
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_chat_completion_context_length_errors()
    {
        let call_count = Arc::new(AtomicUsize::new(0));
        let model_with_backoff = LanguageModelWithBackOff::new(
            mock_failing_twice(&call_count, MockErrorType::ContextLengthExceeded),
            backoff_config(),
        );

        let result = model_with_backoff.complete(&empty_request()).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        assert!(
            matches!(result, Err(LanguageModelError::ContextLengthExceeded(_))),
            "Expected ContextLengthExceeded, got {result:?}"
        );
    }
}
281}
282
283impl From<CommandOutput> for ToolOutput {
284 fn from(value: CommandOutput) -> Self {
285 ToolOutput::Text(value.output)
286 }
287}
288
289#[async_trait]
298pub trait Tool: Send + Sync + DynClone {
299 async fn invoke(
301 &self,
302 agent_context: &dyn AgentContext,
303 raw_args: Option<&str>,
304 ) -> Result<ToolOutput, ToolError>;
305
306 fn name(&self) -> Cow<'_, str>;
307
308 fn tool_spec(&self) -> ToolSpec;
309
310 fn boxed<'a>(self) -> Box<dyn Tool + 'a>
311 where
312 Self: Sized + 'a,
313 {
314 Box::new(self) as Box<dyn Tool>
315 }
316}
317
318#[async_trait]
326pub trait ToolBox: Send + Sync + DynClone {
327 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>>;
328
329 fn name(&self) -> Cow<'_, str> {
330 Cow::Borrowed("Unnamed ToolBox")
331 }
332
333 fn boxed<'a>(self) -> Box<dyn ToolBox + 'a>
334 where
335 Self: Sized + 'a,
336 {
337 Box::new(self) as Box<dyn ToolBox>
338 }
339}
340
341#[async_trait]
342impl ToolBox for Vec<Box<dyn Tool>> {
343 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
344 Ok(self.clone())
345 }
346}
347
348#[async_trait]
349impl ToolBox for Box<dyn ToolBox> {
350 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
351 (**self).available_tools().await
352 }
353}
354
355#[async_trait]
356impl ToolBox for Arc<dyn ToolBox> {
357 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
358 (**self).available_tools().await
359 }
360}
361
362#[async_trait]
363impl ToolBox for &dyn ToolBox {
364 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
365 (**self).available_tools().await
366 }
367}
368
369#[async_trait]
370impl ToolBox for &[Box<dyn Tool>] {
371 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
372 Ok(self.to_vec())
373 }
374}
375
376#[async_trait]
377impl ToolBox for [Box<dyn Tool>] {
378 async fn available_tools(&self) -> Result<Vec<Box<dyn Tool>>> {
379 Ok(self.to_vec())
380 }
381}
382
383dyn_clone::clone_trait_object!(ToolBox);
384
385#[async_trait]
386impl Tool for Box<dyn Tool> {
387 async fn invoke(
388 &self,
389 agent_context: &dyn AgentContext,
390 raw_args: Option<&str>,
391 ) -> Result<ToolOutput, ToolError> {
392 (**self).invoke(agent_context, raw_args).await
393 }
394 fn name(&self) -> Cow<'_, str> {
395 (**self).name()
396 }
397 fn tool_spec(&self) -> ToolSpec {
398 (**self).tool_spec()
399 }
400}
401
402dyn_clone::clone_trait_object!(Tool);
403
404impl PartialEq for Box<dyn Tool> {
407 fn eq(&self, other: &Self) -> bool {
408 self.name() == other.name()
409 }
410}
411impl Eq for Box<dyn Tool> {}
412impl std::hash::Hash for Box<dyn Tool> {
413 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
414 self.name().hash(state);
415 }
416}