// swiftide_core/indexing_decorators.rs

1use std::fmt::Debug;
2
3use crate::{EmbeddingModel, Embeddings, SimplePrompt, prompt::Prompt};
4
5use crate::chat_completion::errors::LanguageModelError;
6use anyhow::Result;
7use async_trait::async_trait;
8use std::time::Duration;
9
/// Backoff configuration for API calls.
///
/// Each time an API call fails with a transient error, backoff waits an
/// increasing period of time before the next retry attempt.
/// See <https://docs.rs/backoff/latest/backoff/> for more details.
#[derive(Debug, Clone, Copy)]
pub struct BackoffConfiguration {
    /// Initial interval in seconds between retries
    pub initial_interval_sec: u64,
    /// The factor by which the interval is multiplied on each retry attempt
    pub multiplier: f64,
    /// Introduces randomness to the retry intervals to avoid retry storms
    /// (thundering-herd effects when many clients fail at once)
    pub randomization_factor: f64,
    /// Total time all attempts are allowed in seconds. Once a retry must wait longer than this,
    /// the request is considered to have failed.
    pub max_elapsed_time_sec: u64,
}
25
26impl Default for BackoffConfiguration {
27    fn default() -> Self {
28        Self {
29            initial_interval_sec: 1,
30            multiplier: 2.0,
31            randomization_factor: 0.5,
32            max_elapsed_time_sec: 60,
33        }
34    }
35}
36
/// Decorator that wraps a language-model client `P` and retries failed calls
/// using exponential backoff built from a [`BackoffConfiguration`].
#[derive(Debug, Clone)]
pub struct LanguageModelWithBackOff<P: Clone> {
    // The wrapped client; crate-visible so sibling modules can reach the inner model.
    pub(crate) inner: P,
    // Settings used to build the exponential backoff strategy for each call.
    config: BackoffConfiguration,
}
42
43impl<P: Clone> LanguageModelWithBackOff<P> {
44    pub fn new(client: P, config: BackoffConfiguration) -> Self {
45        Self {
46            inner: client,
47            config,
48        }
49    }
50
51    pub(crate) fn strategy(&self) -> backoff::ExponentialBackoff {
52        backoff::ExponentialBackoffBuilder::default()
53            .with_initial_interval(Duration::from_secs(self.config.initial_interval_sec))
54            .with_multiplier(self.config.multiplier)
55            .with_max_elapsed_time(Some(Duration::from_secs(self.config.max_elapsed_time_sec)))
56            .with_randomization_factor(self.config.randomization_factor)
57            .build()
58    }
59}
60
61#[async_trait]
62impl<P: SimplePrompt + Clone> SimplePrompt for LanguageModelWithBackOff<P> {
63    async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
64        let strategy = self.strategy();
65
66        let op = || {
67            let prompt = prompt.clone();
68            async {
69                self.inner.prompt(prompt).await.map_err(|e| match e {
70                    LanguageModelError::ContextLengthExceeded(e) => {
71                        backoff::Error::Permanent(LanguageModelError::ContextLengthExceeded(e))
72                    }
73                    LanguageModelError::PermanentError(e) => {
74                        backoff::Error::Permanent(LanguageModelError::PermanentError(e))
75                    }
76                    LanguageModelError::TransientError(e) => {
77                        backoff::Error::transient(LanguageModelError::TransientError(e))
78                    }
79                })
80            }
81        };
82
83        backoff::future::retry(strategy, op).await
84    }
85
86    fn name(&self) -> &'static str {
87        self.inner.name()
88    }
89}
90
91#[async_trait]
92impl<P: EmbeddingModel + Clone> EmbeddingModel for LanguageModelWithBackOff<P> {
93    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
94        self.inner.embed(input).await
95    }
96
97    fn name(&self) -> &'static str {
98        self.inner.name()
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Which kind of error the mock produces while it is still failing.
    #[derive(Debug, Clone, Copy)]
    enum MockErrorType {
        Transient,
        Permanent,
        ContextLengthExceeded,
    }

    /// A `SimplePrompt` stub that fails a configurable number of times with a
    /// chosen error kind, then succeeds, counting every invocation.
    #[derive(Debug, Clone)]
    struct MockSimplePrompt {
        call_count: Arc<AtomicUsize>,
        should_fail_count: usize,
        error_type: MockErrorType,
    }

    #[async_trait]
    impl SimplePrompt for MockSimplePrompt {
        async fn prompt(&self, _prompt: Prompt) -> Result<String, LanguageModelError> {
            let attempt = self.call_count.fetch_add(1, Ordering::SeqCst);

            // Succeed once the configured number of failures has been consumed.
            if attempt >= self.should_fail_count {
                return Ok("Success response".to_string());
            }

            match self.error_type {
                MockErrorType::Transient => Err(LanguageModelError::TransientError(Box::new(
                    std::io::Error::new(std::io::ErrorKind::ConnectionReset, "Transient error"),
                ))),
                MockErrorType::Permanent => Err(LanguageModelError::PermanentError(Box::new(
                    std::io::Error::new(std::io::ErrorKind::InvalidData, "Permanent error"),
                ))),
                MockErrorType::ContextLengthExceeded => {
                    Err(LanguageModelError::ContextLengthExceeded(Box::new(
                        std::io::Error::new(
                            std::io::ErrorKind::InvalidInput,
                            "Context length exceeded",
                        ),
                    )))
                }
            }
        }

        fn name(&self) -> &'static str {
            "MockSimplePrompt"
        }
    }

    /// Backoff settings shared by every test in this module.
    fn test_config() -> BackoffConfiguration {
        BackoffConfiguration {
            initial_interval_sec: 1,
            multiplier: 1.5,
            randomization_factor: 0.5,
            max_elapsed_time_sec: 10,
        }
    }

    /// Builds a mock that fails `fail_count` times with `error_type`, wrapped
    /// in the backoff decorator under test.
    fn wrapped_mock(
        calls: &Arc<AtomicUsize>,
        fail_count: usize,
        error_type: MockErrorType,
    ) -> LanguageModelWithBackOff<MockSimplePrompt> {
        let mock = MockSimplePrompt {
            call_count: Arc::clone(calls),
            should_fail_count: fail_count,
            error_type,
        };
        LanguageModelWithBackOff::new(mock, test_config())
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_retries_transient_errors() {
        let calls = Arc::new(AtomicUsize::new(0));
        // Fail twice, succeed on the third attempt.
        let model = wrapped_mock(&calls, 2, MockErrorType::Transient);

        let result = model.prompt(Prompt::from("Test prompt")).await;

        assert_eq!(result.unwrap(), "Success response");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_permanent_errors() {
        let calls = Arc::new(AtomicUsize::new(0));
        let model = wrapped_mock(&calls, 1, MockErrorType::Permanent);

        let result = model.prompt(Prompt::from("Test prompt")).await;

        // A permanent error must surface after exactly one call — no retries.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(
            matches!(result, Err(LanguageModelError::PermanentError(_))),
            "Expected PermanentError"
        );
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_context_length_errors() {
        let calls = Arc::new(AtomicUsize::new(0));
        let model = wrapped_mock(&calls, 1, MockErrorType::ContextLengthExceeded);

        let result = model.prompt(Prompt::from("Test prompt")).await;

        // Context-length errors are permanent — exactly one call, no retries.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(
            matches!(result, Err(LanguageModelError::ContextLengthExceeded(_))),
            "Expected ContextLengthExceeded"
        );
    }
}