swiftide_core/indexing_decorators.rs

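//! Decorators that wrap language model clients with retry and exponential backoff behavior.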
use std::fmt::Debug;

use crate::chat_completion::{ChatCompletionRequest, ChatCompletionResponse};
use crate::stream_backoff::{StreamBackoff, TokioSleeper};
use crate::{ChatCompletion, ChatCompletionStream};
use crate::{EmbeddingModel, Embeddings, SimplePrompt, prompt::Prompt};

use crate::chat_completion::errors::LanguageModelError;
use anyhow::Result;
use async_trait::async_trait;
use futures_util::{StreamExt as _, TryStreamExt as _};
use std::time::Duration;

/// Backoff configuration for API calls.
///
/// Each time an API call fails, backoff waits an increasing period of time before each subsequent
/// retry attempt. See <https://docs.rs/backoff/latest/backoff/> for more details.
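///
/// # Example
///
/// A minimal sketch; the import path is assumed and may differ in your crate:
///
/// ```ignore
/// use swiftide_core::indexing_decorators::BackoffConfiguration;
///
/// // Start at 1s, grow the wait by 2x per retry with jitter, give up after 60s total.
/// let config = BackoffConfiguration {
///     initial_interval_sec: 1,
///     multiplier: 2.0,
///     randomization_factor: 0.5,
///     max_elapsed_time_sec: 60,
/// };
/// ```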
#[derive(Debug, Clone, Copy)]
pub struct BackoffConfiguration {
    /// Initial interval in seconds between retries
    pub initial_interval_sec: u64,
    /// The factor by which the interval is multiplied on each retry attempt
    pub multiplier: f64,
    /// Introduces randomness to avoid retry storms
    pub randomization_factor: f64,
    /// Maximum total time in seconds allowed across all attempts. Once this much time has
    /// elapsed, no further retries are attempted and the request is considered to have failed.
    pub max_elapsed_time_sec: u64,
}

impl Default for BackoffConfiguration {
    fn default() -> Self {
        Self {
            initial_interval_sec: 1,
            multiplier: 2.0,
            randomization_factor: 0.5,
            max_elapsed_time_sec: 60,
        }
    }
}

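/// Decorator that wraps a language model client and retries failed calls with exponential
/// backoff. Transient errors are retried; permanent and context-length errors are returned
/// immediately.
///
/// # Example
///
/// A minimal sketch, assuming `client` implements `SimplePrompt`:
///
/// ```ignore
/// let model = LanguageModelWithBackOff::new(client, BackoffConfiguration::default());
/// let answer = model.prompt("What is backoff?".into()).await?;
/// ```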
#[derive(Debug, Clone)]
pub struct LanguageModelWithBackOff<P: Clone> {
    pub(crate) inner: P,
    config: BackoffConfiguration,
}

impl<P: Clone> LanguageModelWithBackOff<P> {
    pub fn new(client: P, config: BackoffConfiguration) -> Self {
        Self {
            inner: client,
            config,
        }
    }

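    /// Builds the exponential backoff strategy from the stored configuration.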
    pub(crate) fn strategy(&self) -> backoff::ExponentialBackoff {
        backoff::ExponentialBackoffBuilder::default()
            .with_initial_interval(Duration::from_secs(self.config.initial_interval_sec))
            .with_multiplier(self.config.multiplier)
            .with_max_elapsed_time(Some(Duration::from_secs(self.config.max_elapsed_time_sec)))
            .with_randomization_factor(self.config.randomization_factor)
            .build()
    }
}

#[async_trait]
impl<P: SimplePrompt + Clone> SimplePrompt for LanguageModelWithBackOff<P> {
    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
        let strategy = self.strategy();

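        // Classify errors for the backoff policy: only transient errors are retried;
        // permanent and context-length errors abort immediately.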
        let op = || {
            let prompt = prompt.clone();
            async {
                self.inner.prompt(prompt).await.map_err(|e| match e {
                    LanguageModelError::ContextLengthExceeded(e) => {
                        backoff::Error::Permanent(LanguageModelError::ContextLengthExceeded(e))
                    }
                    LanguageModelError::PermanentError(e) => {
                        backoff::Error::Permanent(LanguageModelError::PermanentError(e))
                    }
                    LanguageModelError::TransientError(e) => {
                        backoff::Error::transient(LanguageModelError::TransientError(e))
                    }
                })
            }
        };

        backoff::future::retry(strategy, op).await
    }

    fn name(&self) -> &'static str {
        self.inner.name()
    }
}

#[async_trait]
impl<P: EmbeddingModel + Clone> EmbeddingModel for LanguageModelWithBackOff<P> {
    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
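        // Embedding calls are delegated as-is; no backoff is applied here.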
        self.inner.embed(input).await
    }

    fn name(&self) -> &'static str {
        self.inner.name()
    }
}

#[async_trait]
impl<LLM: ChatCompletion + Clone> ChatCompletion for LanguageModelWithBackOff<LLM> {
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LanguageModelError> {
        let strategy = self.strategy();

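        // Same error classification as `prompt`: only transient errors are retried.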
        let op = || async move {
            self.inner.complete(request).await.map_err(|e| match e {
                LanguageModelError::ContextLengthExceeded(e) => {
                    backoff::Error::Permanent(LanguageModelError::ContextLengthExceeded(e))
                }
                LanguageModelError::PermanentError(e) => {
                    backoff::Error::Permanent(LanguageModelError::PermanentError(e))
                }
                LanguageModelError::TransientError(e) => {
                    backoff::Error::transient(LanguageModelError::TransientError(e))
                }
            })
        };

        backoff::future::retry(strategy, op).await
    }

    async fn complete_stream(&self, request: &ChatCompletionRequest) -> ChatCompletionStream {
        let strategy = self.strategy();

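        // Wrap the inner stream's errors into `backoff` errors so `StreamBackoff` can retry
        // transient failures, then unwrap them back into plain `LanguageModelError`s.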
        let stream = self.inner.complete_stream(request).await;
        let stream = stream
            .map_err(|e| match e {
                LanguageModelError::ContextLengthExceeded(e) => {
                    backoff::Error::Permanent(LanguageModelError::ContextLengthExceeded(e))
                }
                LanguageModelError::PermanentError(e) => {
                    backoff::Error::Permanent(LanguageModelError::PermanentError(e))
                }
                LanguageModelError::TransientError(e) => {
                    backoff::Error::transient(LanguageModelError::TransientError(e))
                }
            })
            .boxed();
        StreamBackoff::new(stream, strategy, TokioSleeper)
            .map_err(|e| match e {
                backoff::Error::Permanent(e) => e,
                backoff::Error::Transient { err, .. } => err,
            })
            .boxed()
    }
}

#[cfg(test)]
mod tests {

    use uuid::Uuid;

    use super::*;
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[derive(Debug, Clone)]
    struct MockSimplePrompt {
        call_count: Arc<AtomicUsize>,
        should_fail_count: usize,
        error_type: MockErrorType,
    }

    #[derive(Debug, Clone, Copy)]
    enum MockErrorType {
        Transient,
        Permanent,
        ContextLengthExceeded,
    }

    #[derive(Clone)]
    struct MockChatCompletion {
        call_count: Arc<AtomicUsize>,
        should_fail_count: usize,
        error_type: MockErrorType,
    }

    #[async_trait]
    impl ChatCompletion for MockChatCompletion {
        async fn complete(
            &self,
            _request: &ChatCompletionRequest,
        ) -> Result<ChatCompletionResponse, LanguageModelError> {
            let count = self.call_count.fetch_add(1, Ordering::SeqCst);

            if count < self.should_fail_count {
                match self.error_type {
                    MockErrorType::Transient => Err(LanguageModelError::TransientError(Box::new(
                        std::io::Error::new(std::io::ErrorKind::ConnectionReset, "Transient error"),
                    ))),
                    MockErrorType::Permanent => Err(LanguageModelError::PermanentError(Box::new(
                        std::io::Error::new(std::io::ErrorKind::InvalidData, "Permanent error"),
                    ))),
                    MockErrorType::ContextLengthExceeded => Err(
                        LanguageModelError::ContextLengthExceeded(Box::new(std::io::Error::new(
                            std::io::ErrorKind::InvalidInput,
                            "Context length exceeded",
                        ))),
                    ),
                }
            } else {
                Ok(ChatCompletionResponse {
                    id: Uuid::new_v4(),
                    message: Some("Success response".to_string()),
                    tool_calls: None,
                    delta: None,
                    usage: None,
                })
            }
        }
    }

    #[async_trait]
    impl SimplePrompt for MockSimplePrompt {
        async fn prompt(&self, _prompt: Prompt) -> Result<String, LanguageModelError> {
            let count = self.call_count.fetch_add(1, Ordering::SeqCst);

            if count < self.should_fail_count {
                match self.error_type {
                    MockErrorType::Transient => Err(LanguageModelError::TransientError(Box::new(
                        std::io::Error::new(std::io::ErrorKind::ConnectionReset, "Transient error"),
                    ))),
                    MockErrorType::Permanent => Err(LanguageModelError::PermanentError(Box::new(
                        std::io::Error::new(std::io::ErrorKind::InvalidData, "Permanent error"),
                    ))),
                    MockErrorType::ContextLengthExceeded => Err(
                        LanguageModelError::ContextLengthExceeded(Box::new(std::io::Error::new(
                            std::io::ErrorKind::InvalidInput,
                            "Context length exceeded",
                        ))),
                    ),
                }
            } else {
                Ok("Success response".to_string())
            }
        }

        fn name(&self) -> &'static str {
            "MockSimplePrompt"
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_retries_transient_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_prompt = MockSimplePrompt {
            call_count: call_count.clone(),
            should_fail_count: 2, // Fail twice, succeed on third attempt
            error_type: MockErrorType::Transient,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_prompt, config);

        let result = model_with_backoff.prompt(Prompt::from("Test prompt")).await;

        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
        assert_eq!(result.unwrap(), "Success response");
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_permanent_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_prompt = MockSimplePrompt {
            call_count: call_count.clone(),
            should_fail_count: 1,
            error_type: MockErrorType::Permanent,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_prompt, config);

        let result = model_with_backoff.prompt(Prompt::from("Test prompt")).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        match result {
            Err(LanguageModelError::PermanentError(_)) => {} // Expected
            _ => panic!("Expected PermanentError"),
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_context_length_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_prompt = MockSimplePrompt {
            call_count: call_count.clone(),
            should_fail_count: 1,
            error_type: MockErrorType::ContextLengthExceeded,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_prompt, config);

        let result = model_with_backoff.prompt(Prompt::from("Test prompt")).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        match result {
            Err(LanguageModelError::ContextLengthExceeded(_)) => {} // Expected
            _ => panic!("Expected ContextLengthExceeded"),
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_retries_chat_completion_transient_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_chat = MockChatCompletion {
            call_count: call_count.clone(),
            should_fail_count: 2, // Fail twice, succeed on third attempt
            error_type: MockErrorType::Transient,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_chat, config);

        let request = ChatCompletionRequest {
            messages: vec![],
            tools_spec: HashSet::default(),
        };

        let result = model_with_backoff.complete(&request).await;

        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
        assert_eq!(
            result.unwrap().message,
            Some("Success response".to_string())
        );
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_chat_completion_permanent_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_chat = MockChatCompletion {
            call_count: call_count.clone(),
            should_fail_count: 2, // Would fail twice if retried
            error_type: MockErrorType::Permanent,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_chat, config);

        let request = ChatCompletionRequest {
            messages: vec![],
            tools_spec: HashSet::default(),
        };

        let result = model_with_backoff.complete(&request).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1); // Should only be called once

        match result {
            Err(LanguageModelError::PermanentError(_)) => {} // Expected
            _ => panic!("Expected PermanentError, got {result:?}"),
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_chat_completion_context_length_errors()
    {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock_chat = MockChatCompletion {
            call_count: call_count.clone(),
            should_fail_count: 2, // Would fail twice if retried
            error_type: MockErrorType::ContextLengthExceeded,
        };

        let config = BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        };

        let model_with_backoff = LanguageModelWithBackOff::new(mock_chat, config);

        let request = ChatCompletionRequest {
            messages: vec![],
            tools_spec: HashSet::default(),
        };

        let result = model_with_backoff.complete(&request).await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1); // Should only be called once

        match result {
            Err(LanguageModelError::ContextLengthExceeded(_)) => {} // Expected
            _ => panic!("Expected ContextLengthExceeded, got {result:?}"),
        }
    }
}