//! A wrapper that adds retry with exponential backoff to language model
//! clients: transient errors are retried, while permanent and
//! context-length errors are returned immediately.

use crate::chat_completion::{ChatCompletionRequest, ChatCompletionResponse};
use crate::stream_backoff::{StreamBackoff, TokioSleeper};
use crate::{ChatCompletion, ChatCompletionStream};
use crate::{EmbeddingModel, Embeddings, SimplePrompt, prompt::Prompt};

use crate::chat_completion::errors::LanguageModelError;
use anyhow::Result;
use async_trait::async_trait;
use futures_util::{StreamExt as _, TryStreamExt as _};
use std::time::Duration;

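/// Configuration for the exponential backoff applied by
/// [`LanguageModelWithBackOff`]. The fields map directly onto
/// `backoff::ExponentialBackoffBuilder`.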
#[derive(Debug, Clone, Copy)]
pub struct BackoffConfiguration {
    /// Delay before the first retry, in seconds.
    pub initial_interval_sec: u64,
    /// Factor by which the delay grows after each retry.
    pub multiplier: f64,
    /// Jitter factor in `[0, 1]`; each delay is randomized by up to this fraction.
    pub randomization_factor: f64,
    /// Total time budget, in seconds, after which no further retries are attempted.
    pub max_elapsed_time_sec: u64,
}

impl Default for BackoffConfiguration {
    fn default() -> Self {
        Self {
            initial_interval_sec: 1,
            multiplier: 2.0,
            randomization_factor: 0.5,
            max_elapsed_time_sec: 60,
        }
    }
}

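/// Wraps a language model client and retries transient failures using
/// exponential backoff. `PermanentError` and `ContextLengthExceeded`
/// failures are surfaced immediately, without retrying.
///
/// A minimal usage sketch (`client` stands in for any [`SimplePrompt`]
/// implementation; the variable name is illustrative):
///
/// ```ignore
/// let model = LanguageModelWithBackOff::new(client, BackoffConfiguration::default());
/// let answer = model.prompt(Prompt::from("Hello")).await?;
/// ```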
#[derive(Debug, Clone)]
pub struct LanguageModelWithBackOff<P: Clone> {
    pub(crate) inner: P,
    config: BackoffConfiguration,
}

impl<P: Clone> LanguageModelWithBackOff<P> {
    pub fn new(client: P, config: BackoffConfiguration) -> Self {
        Self {
            inner: client,
            config,
        }
    }

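    /// Builds the `backoff` strategy from the stored configuration. With the
    /// default configuration, retry delays are roughly 1s, 2s, 4s, ..., each
    /// jittered by up to +/-50%, and retrying stops once about 60s of total
    /// time has elapsed.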
    pub(crate) fn strategy(&self) -> backoff::ExponentialBackoff {
        backoff::ExponentialBackoffBuilder::default()
            .with_initial_interval(Duration::from_secs(self.config.initial_interval_sec))
            .with_multiplier(self.config.multiplier)
            .with_max_elapsed_time(Some(Duration::from_secs(self.config.max_elapsed_time_sec)))
            .with_randomization_factor(self.config.randomization_factor)
            .build()
    }
}

/// Classifies a `LanguageModelError` for the `backoff` crate: transient
/// errors may be retried, everything else is permanent.
fn to_backoff_error(e: LanguageModelError) -> backoff::Error<LanguageModelError> {
    match e {
        LanguageModelError::ContextLengthExceeded(e) => {
            backoff::Error::Permanent(LanguageModelError::ContextLengthExceeded(e))
        }
        LanguageModelError::PermanentError(e) => {
            backoff::Error::Permanent(LanguageModelError::PermanentError(e))
        }
        LanguageModelError::TransientError(e) => {
            backoff::Error::transient(LanguageModelError::TransientError(e))
        }
    }
}

#[async_trait]
impl<P: SimplePrompt + Clone> SimplePrompt for LanguageModelWithBackOff<P> {
    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
        let strategy = self.strategy();

        let op = || {
            let prompt = prompt.clone();
            async { self.inner.prompt(prompt).await.map_err(to_backoff_error) }
        };

        backoff::future::retry(strategy, op).await
    }

    fn name(&self) -> &'static str {
        self.inner.name()
    }
}

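// Note: embeddings are currently forwarded as-is; only prompt and chat
// completion calls are retried with backoff.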
#[async_trait]
impl<P: EmbeddingModel + Clone> EmbeddingModel for LanguageModelWithBackOff<P> {
    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
        self.inner.embed(input).await
    }

    fn name(&self) -> &'static str {
        self.inner.name()
    }
}

#[async_trait]
impl<LLM: ChatCompletion + Clone> ChatCompletion for LanguageModelWithBackOff<LLM> {
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError> {
        let strategy = self.strategy();

        let op = || async move { self.inner.complete(request).await.map_err(to_backoff_error) };

        backoff::future::retry(strategy, op).await
    }

    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
        let strategy = self.strategy();

        // Classify errors from the inner stream so `StreamBackoff` can back
        // off on transient failures, then strip the classification again
        // before yielding items to the caller.
        let stream = self.inner.complete_stream(request).await;
        let stream = stream.map_err(to_backoff_error).boxed();
        StreamBackoff::new(stream, strategy, TokioSleeper)
            .map_err(|e| match e {
                backoff::Error::Permanent(e) => e,
                backoff::Error::Transient { err, .. } => err,
            })
            .boxed()
    }
}

#[cfg(test)]
mod tests {
    use uuid::Uuid;

    use super::*;
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[derive(Debug, Clone)]
    struct MockSimplePrompt {
        call_count: Arc<AtomicUsize>,
        should_fail_count: usize,
        error_type: MockErrorType,
    }

    #[derive(Debug, Clone, Copy)]
    enum MockErrorType {
        Transient,
        Permanent,
        ContextLengthExceeded,
    }

    #[derive(Clone)]
    struct MockChatCompletion {
        call_count: Arc<AtomicUsize>,
        should_fail_count: usize,
        error_type: MockErrorType,
    }

    #[async_trait]
    impl ChatCompletion for MockChatCompletion {
        async fn complete(
            &self,
            _request: &ChatCompletionRequest,
        ) -> Result<ChatCompletionResponse, LanguageModelError> {
            let count = self.call_count.fetch_add(1, Ordering::SeqCst);

            if count < self.should_fail_count {
                match self.error_type {
                    MockErrorType::Transient => Err(LanguageModelError::TransientError(Box::new(
                        std::io::Error::new(std::io::ErrorKind::ConnectionReset, "Transient error"),
                    ))),
                    MockErrorType::Permanent => Err(LanguageModelError::PermanentError(Box::new(
                        std::io::Error::new(std::io::ErrorKind::InvalidData, "Permanent error"),
                    ))),
                    MockErrorType::ContextLengthExceeded => Err(
                        LanguageModelError::ContextLengthExceeded(Box::new(std::io::Error::new(
                            std::io::ErrorKind::InvalidInput,
                            "Context length exceeded",
                        ))),
                    ),
                }
            } else {
                Ok(ChatCompletionResponse {
                    id: Uuid::new_v4(),
                    message: Some("Success response".to_string()),
                    tool_calls: None,
                    delta: None,
                    usage: None,
                })
            }
        }
    }

    #[async_trait]
    impl SimplePrompt for MockSimplePrompt {
        async fn prompt(&self, _prompt: Prompt) -> Result<String, LanguageModelError> {
            let count = self.call_count.fetch_add(1, Ordering::SeqCst);

            if count < self.should_fail_count {
                match self.error_type {
                    MockErrorType::Transient => Err(LanguageModelError::TransientError(Box::new(
                        std::io::Error::new(std::io::ErrorKind::ConnectionReset, "Transient error"),
                    ))),
                    MockErrorType::Permanent => Err(LanguageModelError::PermanentError(Box::new(
                        std::io::Error::new(std::io::ErrorKind::InvalidData, "Permanent error"),
                    ))),
                    MockErrorType::ContextLengthExceeded => Err(
                        LanguageModelError::ContextLengthExceeded(Box::new(std::io::Error::new(
                            std::io::ErrorKind::InvalidInput,
                            "Context length exceeded",
                        ))),
                    ),
                }
            } else {
                Ok("Success response".to_string())
            }
        }

        fn name(&self) -> &'static str {
            "MockSimplePrompt"
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_retries_transient_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_prompt = MockSimplePrompt {
            call_count: call_count.clone(),
            should_fail_count: 2,
            error_type: MockErrorType::Transient,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_prompt, config);

        let result = model_with_backoff.prompt(Prompt::from("Test prompt")).await;

        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
        assert_eq!(result.unwrap(), "Success response");
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_permanent_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_prompt = MockSimplePrompt {
            call_count: call_count.clone(),
            should_fail_count: 1,
            error_type: MockErrorType::Permanent,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_prompt, config);

        let result = model_with_backoff.prompt(Prompt::from("Test prompt")).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        match result {
            Err(LanguageModelError::PermanentError(_)) => {}
            _ => panic!("Expected PermanentError"),
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_context_length_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_prompt = MockSimplePrompt {
            call_count: call_count.clone(),
            should_fail_count: 1,
            error_type: MockErrorType::ContextLengthExceeded,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_prompt, config);

        let result = model_with_backoff.prompt(Prompt::from("Test prompt")).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        match result {
            Err(LanguageModelError::ContextLengthExceeded(_)) => {}
            _ => panic!("Expected ContextLengthExceeded"),
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_retries_chat_completion_transient_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_chat = MockChatCompletion {
            call_count: call_count.clone(),
            should_fail_count: 2,
            error_type: MockErrorType::Transient,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_chat, config);

        let request = ChatCompletionRequest {
            messages: vec![],
            tools_spec: HashSet::default(),
        };

        let result = model_with_backoff.complete(&request).await;

        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
        assert_eq!(
            result.unwrap().message,
            Some("Success response".to_string())
        );
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_chat_completion_permanent_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_chat = MockChatCompletion {
            call_count: call_count.clone(),
            should_fail_count: 2,
            error_type: MockErrorType::Permanent,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_chat, config);

        let request = ChatCompletionRequest {
            messages: vec![],
            tools_spec: HashSet::default(),
        };

        let result = model_with_backoff.complete(&request).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        match result {
            Err(LanguageModelError::PermanentError(_)) => {}
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_chat_completion_context_length_errors()
    {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_chat = MockChatCompletion {
            call_count: call_count.clone(),
            should_fail_count: 2,
            error_type: MockErrorType::ContextLengthExceeded,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_chat, config);

        let request = ChatCompletionRequest {
            messages: vec![],
            tools_spec: HashSet::default(),
        };

        let result = model_with_backoff.complete(&request).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        match result {
            Err(LanguageModelError::ContextLengthExceeded(_)) => {}
            _ => panic!("Expected ContextLengthExceeded, got {result:?}"),
        }
    }
}