// swiftide_core/indexing_decorators.rs
use std::fmt::Debug;
use std::time::Duration;

use anyhow::Result;
use async_trait::async_trait;

use crate::chat_completion::errors::LanguageModelError;
use crate::{EmbeddingModel, Embeddings, SimplePrompt, prompt::Prompt};
/// Settings for the exponential backoff used when retrying language model calls.
///
/// Converted into a `backoff::ExponentialBackoff` by
/// `LanguageModelWithBackOff::strategy`.
#[derive(Debug, Clone, Copy)]
pub struct BackoffConfiguration {
    /// Delay, in seconds, before the first retry.
    pub initial_interval_sec: u64,
    /// Factor the retry interval is multiplied by after each attempt.
    pub multiplier: f64,
    /// Randomization factor handed to the backoff strategy (jitters the interval).
    pub randomization_factor: f64,
    /// Total elapsed time, in seconds, after which retrying gives up.
    pub max_elapsed_time_sec: u64,
}
25
26impl Default for BackoffConfiguration {
27 fn default() -> Self {
28 Self {
29 initial_interval_sec: 1,
30 multiplier: 2.0,
31 randomization_factor: 0.5,
32 max_elapsed_time_sec: 60,
33 }
34 }
35}
36
/// Decorator that wraps a language model client and retries failed calls
/// using exponential backoff.
///
/// NOTE(review): as written, backoff is applied by the `SimplePrompt` impl
/// only; the `EmbeddingModel` impl delegates without retrying.
#[derive(Debug, Clone)]
pub struct LanguageModelWithBackOff<P: Clone> {
    /// The wrapped model client.
    pub(crate) inner: P,
    /// Backoff settings; a fresh strategy is built from these per call.
    config: BackoffConfiguration,
}
42
43impl<P: Clone> LanguageModelWithBackOff<P> {
44 pub fn new(client: P, config: BackoffConfiguration) -> Self {
45 Self {
46 inner: client,
47 config,
48 }
49 }
50
51 pub(crate) fn strategy(&self) -> backoff::ExponentialBackoff {
52 backoff::ExponentialBackoffBuilder::default()
53 .with_initial_interval(Duration::from_secs(self.config.initial_interval_sec))
54 .with_multiplier(self.config.multiplier)
55 .with_max_elapsed_time(Some(Duration::from_secs(self.config.max_elapsed_time_sec)))
56 .with_randomization_factor(self.config.randomization_factor)
57 .build()
58 }
59}
60
61#[async_trait]
62impl<P: SimplePrompt + Clone> SimplePrompt for LanguageModelWithBackOff<P> {
63 async fn prompt(&self, prompt: Prompt) -> Result<String, LanguageModelError> {
64 let strategy = self.strategy();
65
66 let op = || {
67 let prompt = prompt.clone();
68 async {
69 self.inner.prompt(prompt).await.map_err(|e| match e {
70 LanguageModelError::ContextLengthExceeded(e) => {
71 backoff::Error::Permanent(LanguageModelError::ContextLengthExceeded(e))
72 }
73 LanguageModelError::PermanentError(e) => {
74 backoff::Error::Permanent(LanguageModelError::PermanentError(e))
75 }
76 LanguageModelError::TransientError(e) => {
77 backoff::Error::transient(LanguageModelError::TransientError(e))
78 }
79 })
80 }
81 };
82
83 backoff::future::retry(strategy, op).await
84 }
85
86 fn name(&self) -> &'static str {
87 self.inner.name()
88 }
89}
90
#[async_trait]
impl<P: EmbeddingModel + Clone> EmbeddingModel for LanguageModelWithBackOff<P> {
    /// Delegates directly to the wrapped model.
    ///
    /// NOTE(review): unlike `prompt`, no backoff/retry is applied here —
    /// confirm this asymmetry is intentional.
    async fn embed(&self, input: Vec<String>) -> Result<Embeddings, LanguageModelError> {
        self.inner.embed(input).await
    }

    fn name(&self) -> &'static str {
        self.inner.name()
    }
}
101
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Which error variant the mock produces while it is failing.
    #[derive(Debug, Clone, Copy)]
    enum MockErrorType {
        Transient,
        Permanent,
        ContextLengthExceeded,
    }

    /// A `SimplePrompt` stub that fails its first `should_fail_count` calls
    /// with the configured error type, then succeeds.
    #[derive(Debug, Clone)]
    struct MockSimplePrompt {
        call_count: Arc<AtomicUsize>,
        should_fail_count: usize,
        error_type: MockErrorType,
    }

    #[async_trait]
    impl SimplePrompt for MockSimplePrompt {
        async fn prompt(&self, _prompt: Prompt) -> Result<String, LanguageModelError> {
            let attempts_so_far = self.call_count.fetch_add(1, Ordering::SeqCst);

            if attempts_so_far >= self.should_fail_count {
                return Ok("Success response".to_string());
            }

            match self.error_type {
                MockErrorType::Transient => Err(LanguageModelError::TransientError(Box::new(
                    std::io::Error::new(std::io::ErrorKind::ConnectionReset, "Transient error"),
                ))),
                MockErrorType::Permanent => Err(LanguageModelError::PermanentError(Box::new(
                    std::io::Error::new(std::io::ErrorKind::InvalidData, "Permanent error"),
                ))),
                MockErrorType::ContextLengthExceeded => {
                    Err(LanguageModelError::ContextLengthExceeded(Box::new(
                        std::io::Error::new(
                            std::io::ErrorKind::InvalidInput,
                            "Context length exceeded",
                        ),
                    )))
                }
            }
        }

        fn name(&self) -> &'static str {
            "MockSimplePrompt"
        }
    }

    /// Backoff settings shared by all tests: short intervals so runs stay fast.
    fn test_config() -> BackoffConfiguration {
        BackoffConfiguration {
            initial_interval_sec: 1,
            max_elapsed_time_sec: 10,
            multiplier: 1.5,
            randomization_factor: 0.5,
        }
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_retries_transient_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock = MockSimplePrompt {
            call_count: call_count.clone(),
            should_fail_count: 2,
            error_type: MockErrorType::Transient,
        };
        let model = LanguageModelWithBackOff::new(mock, test_config());

        let result = model.prompt(Prompt::from("Test prompt")).await;

        // Two transient failures are retried; the third attempt succeeds.
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
        assert_eq!(result.unwrap(), "Success response");
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_permanent_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock = MockSimplePrompt {
            call_count: call_count.clone(),
            should_fail_count: 1,
            error_type: MockErrorType::Permanent,
        };
        let model = LanguageModelWithBackOff::new(mock, test_config());

        let result = model.prompt(Prompt::from("Test prompt")).await;

        // A permanent error surfaces immediately: exactly one attempt.
        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        assert!(
            matches!(result, Err(LanguageModelError::PermanentError(_))),
            "Expected PermanentError"
        );
    }

    #[tokio::test]
    async fn test_language_model_with_backoff_does_not_retry_context_length_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let mock = MockSimplePrompt {
            call_count: call_count.clone(),
            should_fail_count: 1,
            error_type: MockErrorType::ContextLengthExceeded,
        };
        let model = LanguageModelWithBackOff::new(mock, test_config());

        let result = model.prompt(Prompt::from("Test prompt")).await;

        // Context-length errors are treated as permanent: exactly one attempt.
        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        assert!(
            matches!(result, Err(LanguageModelError::ContextLengthExceeded(_))),
            "Expected ContextLengthExceeded"
        );
    }
}