// ultrafast_models_sdk/providers/circuit_breaker_provider.rs

1use crate::circuit_breaker::{
2    CircuitBreaker, CircuitBreakerConfig, CircuitBreakerError, CircuitState,
3};
4use crate::error::ProviderError;
5use crate::models::{
6    AudioRequest, AudioResponse, ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
7    ImageRequest, ImageResponse, SpeechRequest, SpeechResponse,
8};
9use crate::providers::{Provider, ProviderHealth, StreamResult};
10use std::sync::Arc;
11
12/// Wrapper that adds circuit breaker functionality to any provider
13pub struct CircuitBreakerProvider {
14    inner: Arc<dyn Provider>,
15    circuit_breaker: CircuitBreaker,
16}
17
18impl CircuitBreakerProvider {
19    pub fn new(provider: Arc<dyn Provider>, config: CircuitBreakerConfig) -> Self {
20        let circuit_breaker =
21            CircuitBreaker::new(format!("{}_circuit_breaker", provider.name()), config);
22
23        Self {
24            inner: provider,
25            circuit_breaker,
26        }
27    }
28
29    pub fn with_default_config(provider: Arc<dyn Provider>) -> Self {
30        Self::new(provider, CircuitBreakerConfig::default())
31    }
32
33    pub async fn get_circuit_state(&self) -> CircuitState {
34        self.circuit_breaker.get_state().await
35    }
36
37    pub async fn force_open(&self) {
38        self.circuit_breaker.force_open().await;
39    }
40
41    pub async fn force_closed(&self) {
42        self.circuit_breaker.force_closed().await;
43    }
44
45    pub async fn get_circuit_breaker_metrics(
46        &self,
47    ) -> Result<
48        crate::circuit_breaker::CircuitBreakerMetrics,
49        crate::circuit_breaker::CircuitBreakerError,
50    > {
51        Ok(self.circuit_breaker.get_metrics().await)
52    }
53
54    async fn handle_circuit_breaker_error(&self, error: CircuitBreakerError) -> ProviderError {
55        match error {
56            CircuitBreakerError::Open => {
57                tracing::warn!("Provider {} circuit breaker is OPEN", self.inner.name());
58                ProviderError::ServiceUnavailable
59            }
60            CircuitBreakerError::Timeout => {
61                tracing::warn!("Provider {} operation timed out", self.inner.name());
62                ProviderError::Timeout
63            }
64        }
65    }
66}
67
68#[async_trait::async_trait]
69impl Provider for CircuitBreakerProvider {
70    fn name(&self) -> &str {
71        self.inner.name()
72    }
73
74    fn supports_streaming(&self) -> bool {
75        self.inner.supports_streaming()
76    }
77
78    fn supports_function_calling(&self) -> bool {
79        self.inner.supports_function_calling()
80    }
81
82    fn supported_models(&self) -> Vec<String> {
83        self.inner.supported_models()
84    }
85
86    async fn chat_completion(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
87        let inner = self.inner.clone();
88        let operation = || async move { inner.chat_completion(request).await };
89
90        match self.circuit_breaker.call(operation).await {
91            Ok(response) => Ok(response),
92            Err(cb_error) => Err(self.handle_circuit_breaker_error(cb_error).await),
93        }
94    }
95
96    async fn stream_chat_completion(
97        &self,
98        request: ChatRequest,
99    ) -> Result<StreamResult, ProviderError> {
100        // For streaming, we check the circuit breaker state but don't wrap the stream
101        // The stream itself will handle individual chunk failures
102        let state = self.circuit_breaker.get_state().await;
103        if state == CircuitState::Open {
104            return Err(ProviderError::ServiceUnavailable);
105        }
106
107        // Attempt to start the stream
108        let inner = self.inner.clone();
109        let operation = || async move { inner.stream_chat_completion(request).await };
110
111        match self.circuit_breaker.call(operation).await {
112            Ok(stream) => Ok(stream),
113            Err(cb_error) => Err(self.handle_circuit_breaker_error(cb_error).await),
114        }
115    }
116
117    async fn embedding(
118        &self,
119        request: EmbeddingRequest,
120    ) -> Result<EmbeddingResponse, ProviderError> {
121        let inner = self.inner.clone();
122        let operation = || async move { inner.embedding(request).await };
123
124        match self.circuit_breaker.call(operation).await {
125            Ok(response) => Ok(response),
126            Err(cb_error) => Err(self.handle_circuit_breaker_error(cb_error).await),
127        }
128    }
129
130    async fn image_generation(
131        &self,
132        request: ImageRequest,
133    ) -> Result<ImageResponse, ProviderError> {
134        let inner = self.inner.clone();
135        let operation = || async move { inner.image_generation(request).await };
136
137        match self.circuit_breaker.call(operation).await {
138            Ok(response) => Ok(response),
139            Err(cb_error) => Err(self.handle_circuit_breaker_error(cb_error).await),
140        }
141    }
142
143    async fn audio_transcription(
144        &self,
145        request: AudioRequest,
146    ) -> Result<AudioResponse, ProviderError> {
147        let inner = self.inner.clone();
148        let operation = || async move { inner.audio_transcription(request).await };
149
150        match self.circuit_breaker.call(operation).await {
151            Ok(response) => Ok(response),
152            Err(cb_error) => Err(self.handle_circuit_breaker_error(cb_error).await),
153        }
154    }
155
156    async fn text_to_speech(
157        &self,
158        request: SpeechRequest,
159    ) -> Result<SpeechResponse, ProviderError> {
160        let inner = self.inner.clone();
161        let operation = || async move { inner.text_to_speech(request).await };
162
163        match self.circuit_breaker.call(operation).await {
164            Ok(response) => Ok(response),
165            Err(cb_error) => Err(self.handle_circuit_breaker_error(cb_error).await),
166        }
167    }
168
169    async fn health_check(&self) -> Result<ProviderHealth, ProviderError> {
170        let inner = self.inner.clone();
171        let operation = || async move { inner.health_check().await };
172
173        match self.circuit_breaker.call(operation).await {
174            Ok(health) => Ok(health),
175            Err(cb_error) => Err(self.handle_circuit_breaker_error(cb_error).await),
176        }
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Message;
    use crate::providers::{HealthStatus, ProviderHealth};
    use std::collections::HashMap;
    use std::time::Duration;

    /// Provider double: sleeps for `delay` on chat completion, then succeeds or
    /// fails with `ServiceUnavailable` depending on `should_fail`.
    struct MockProvider {
        name: String,
        should_fail: bool,
        delay: Duration,
    }

    impl MockProvider {
        fn new(name: String, should_fail: bool, delay: Duration) -> Self {
            Self {
                name,
                should_fail,
                delay,
            }
        }
    }

    #[async_trait::async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            &self.name
        }

        fn supports_streaming(&self) -> bool {
            false
        }

        fn supports_function_calling(&self) -> bool {
            false
        }

        fn supported_models(&self) -> Vec<String> {
            vec!["test-model".to_string()]
        }

        async fn chat_completion(
            &self,
            _request: ChatRequest,
        ) -> Result<ChatResponse, ProviderError> {
            tokio::time::sleep(self.delay).await;

            if self.should_fail {
                Err(ProviderError::ServiceUnavailable)
            } else {
                Ok(ChatResponse {
                    id: "test-id".to_string(),
                    object: "chat.completion".to_string(),
                    created: 1234567890,
                    model: "test-model".to_string(),
                    choices: vec![],
                    usage: None,
                    system_fingerprint: None,
                })
            }
        }

        async fn stream_chat_completion(
            &self,
            _request: ChatRequest,
        ) -> Result<StreamResult, ProviderError> {
            Err(ProviderError::Configuration {
                message: "Streaming not supported by mock provider".to_string(),
            })
        }

        async fn health_check(&self) -> Result<ProviderHealth, ProviderError> {
            if self.should_fail {
                Err(ProviderError::ServiceUnavailable)
            } else {
                Ok(ProviderHealth {
                    status: HealthStatus::Healthy,
                    latency_ms: Some(10),
                    last_check: chrono::Utc::now(),
                    details: HashMap::new(),
                    error_rate: 0.0,
                })
            }
        }
    }

    /// Minimal chat request shared by all tests.
    fn test_request() -> ChatRequest {
        ChatRequest {
            model: "test-model".to_string(),
            messages: vec![Message::user("test")],
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_circuit_breaker_provider_success() {
        let mock_provider = Arc::new(MockProvider::new(
            "test".to_string(),
            false,
            Duration::from_millis(10),
        ));

        let cb_config = CircuitBreakerConfig {
            failure_threshold: 2,
            recovery_timeout: Duration::from_millis(100),
            request_timeout: Duration::from_millis(50),
            half_open_max_calls: 1,
        };

        let cb_provider = CircuitBreakerProvider::new(mock_provider, cb_config);

        let result = cb_provider.chat_completion(test_request()).await;
        assert!(result.is_ok());
        assert_eq!(cb_provider.get_circuit_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_provider_failure() {
        let mock_provider = Arc::new(MockProvider::new(
            "test".to_string(),
            true,
            Duration::from_millis(10),
        ));

        let cb_config = CircuitBreakerConfig {
            failure_threshold: 1,
            recovery_timeout: Duration::from_millis(100),
            request_timeout: Duration::from_millis(50),
            half_open_max_calls: 1,
        };

        let cb_provider = CircuitBreakerProvider::new(mock_provider, cb_config);

        // First failure should open the circuit
        let result = cb_provider.chat_completion(test_request()).await;
        assert!(result.is_err());
        assert_eq!(cb_provider.get_circuit_state().await, CircuitState::Open);

        // Second call should be blocked by circuit breaker
        let result = cb_provider.chat_completion(test_request()).await;
        assert!(
            matches!(result, Err(ProviderError::ServiceUnavailable)),
            "Expected ServiceUnavailable error"
        );
    }

    #[tokio::test]
    async fn test_circuit_breaker_provider_timeout() {
        let mock_provider = Arc::new(MockProvider::new(
            "test".to_string(),
            false,
            Duration::from_millis(100), // Longer than timeout
        ));

        let cb_config = CircuitBreakerConfig {
            failure_threshold: 1,
            recovery_timeout: Duration::from_millis(100),
            request_timeout: Duration::from_millis(50),
            half_open_max_calls: 1,
        };

        let cb_provider = CircuitBreakerProvider::new(mock_provider, cb_config);

        let result = cb_provider.chat_completion(test_request()).await;
        assert!(
            matches!(result, Err(ProviderError::Timeout)),
            "Expected Timeout error"
        );

        assert_eq!(cb_provider.get_circuit_state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_stream_rejected_while_circuit_open() {
        let mock_provider = Arc::new(MockProvider::new(
            "test".to_string(),
            true,
            Duration::from_millis(10),
        ));

        let cb_config = CircuitBreakerConfig {
            failure_threshold: 1,
            // Long enough that the circuit cannot start recovering during the test.
            recovery_timeout: Duration::from_secs(60),
            request_timeout: Duration::from_millis(50),
            half_open_max_calls: 1,
        };

        let cb_provider = CircuitBreakerProvider::new(mock_provider, cb_config);

        // Trip the breaker with a failing chat completion.
        assert!(cb_provider.chat_completion(test_request()).await.is_err());
        assert_eq!(cb_provider.get_circuit_state().await, CircuitState::Open);

        // Starting a stream must be refused while the circuit is open.
        let result = cb_provider.stream_chat_completion(test_request()).await;
        assert!(
            matches!(result, Err(ProviderError::ServiceUnavailable)),
            "Expected ServiceUnavailable error"
        );
    }
}
367}