ultrafast_models_sdk/providers/
circuit_breaker_provider.rs1use crate::circuit_breaker::{
2 CircuitBreaker, CircuitBreakerConfig, CircuitBreakerError, CircuitState,
3};
4use crate::error::ProviderError;
5use crate::models::{
6 AudioRequest, AudioResponse, ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
7 ImageRequest, ImageResponse, SpeechRequest, SpeechResponse,
8};
9use crate::providers::{Provider, ProviderHealth, StreamResult};
10use std::sync::Arc;
11
12pub struct CircuitBreakerProvider {
14 inner: Arc<dyn Provider>,
15 circuit_breaker: CircuitBreaker,
16}
17
18impl CircuitBreakerProvider {
19 pub fn new(provider: Arc<dyn Provider>, config: CircuitBreakerConfig) -> Self {
20 let circuit_breaker =
21 CircuitBreaker::new(format!("{}_circuit_breaker", provider.name()), config);
22
23 Self {
24 inner: provider,
25 circuit_breaker,
26 }
27 }
28
29 pub fn with_default_config(provider: Arc<dyn Provider>) -> Self {
30 Self::new(provider, CircuitBreakerConfig::default())
31 }
32
33 pub async fn get_circuit_state(&self) -> CircuitState {
34 self.circuit_breaker.get_state().await
35 }
36
37 pub async fn force_open(&self) {
38 self.circuit_breaker.force_open().await;
39 }
40
41 pub async fn force_closed(&self) {
42 self.circuit_breaker.force_closed().await;
43 }
44
45 pub async fn get_circuit_breaker_metrics(
46 &self,
47 ) -> Result<
48 crate::circuit_breaker::CircuitBreakerMetrics,
49 crate::circuit_breaker::CircuitBreakerError,
50 > {
51 Ok(self.circuit_breaker.get_metrics().await)
52 }
53
54 async fn handle_circuit_breaker_error(&self, error: CircuitBreakerError) -> ProviderError {
55 match error {
56 CircuitBreakerError::Open => {
57 tracing::warn!("Provider {} circuit breaker is OPEN", self.inner.name());
58 ProviderError::ServiceUnavailable
59 }
60 CircuitBreakerError::Timeout => {
61 tracing::warn!("Provider {} operation timed out", self.inner.name());
62 ProviderError::Timeout
63 }
64 }
65 }
66}
67
68#[async_trait::async_trait]
69impl Provider for CircuitBreakerProvider {
70 fn name(&self) -> &str {
71 self.inner.name()
72 }
73
74 fn supports_streaming(&self) -> bool {
75 self.inner.supports_streaming()
76 }
77
78 fn supports_function_calling(&self) -> bool {
79 self.inner.supports_function_calling()
80 }
81
82 fn supported_models(&self) -> Vec<String> {
83 self.inner.supported_models()
84 }
85
86 async fn chat_completion(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
87 let inner = self.inner.clone();
88 let operation = || async move { inner.chat_completion(request).await };
89
90 match self.circuit_breaker.call(operation).await {
91 Ok(response) => Ok(response),
92 Err(cb_error) => Err(self.handle_circuit_breaker_error(cb_error).await),
93 }
94 }
95
96 async fn stream_chat_completion(
97 &self,
98 request: ChatRequest,
99 ) -> Result<StreamResult, ProviderError> {
100 let state = self.circuit_breaker.get_state().await;
103 if state == CircuitState::Open {
104 return Err(ProviderError::ServiceUnavailable);
105 }
106
107 let inner = self.inner.clone();
109 let operation = || async move { inner.stream_chat_completion(request).await };
110
111 match self.circuit_breaker.call(operation).await {
112 Ok(stream) => Ok(stream),
113 Err(cb_error) => Err(self.handle_circuit_breaker_error(cb_error).await),
114 }
115 }
116
117 async fn embedding(
118 &self,
119 request: EmbeddingRequest,
120 ) -> Result<EmbeddingResponse, ProviderError> {
121 let inner = self.inner.clone();
122 let operation = || async move { inner.embedding(request).await };
123
124 match self.circuit_breaker.call(operation).await {
125 Ok(response) => Ok(response),
126 Err(cb_error) => Err(self.handle_circuit_breaker_error(cb_error).await),
127 }
128 }
129
130 async fn image_generation(
131 &self,
132 request: ImageRequest,
133 ) -> Result<ImageResponse, ProviderError> {
134 let inner = self.inner.clone();
135 let operation = || async move { inner.image_generation(request).await };
136
137 match self.circuit_breaker.call(operation).await {
138 Ok(response) => Ok(response),
139 Err(cb_error) => Err(self.handle_circuit_breaker_error(cb_error).await),
140 }
141 }
142
143 async fn audio_transcription(
144 &self,
145 request: AudioRequest,
146 ) -> Result<AudioResponse, ProviderError> {
147 let inner = self.inner.clone();
148 let operation = || async move { inner.audio_transcription(request).await };
149
150 match self.circuit_breaker.call(operation).await {
151 Ok(response) => Ok(response),
152 Err(cb_error) => Err(self.handle_circuit_breaker_error(cb_error).await),
153 }
154 }
155
156 async fn text_to_speech(
157 &self,
158 request: SpeechRequest,
159 ) -> Result<SpeechResponse, ProviderError> {
160 let inner = self.inner.clone();
161 let operation = || async move { inner.text_to_speech(request).await };
162
163 match self.circuit_breaker.call(operation).await {
164 Ok(response) => Ok(response),
165 Err(cb_error) => Err(self.handle_circuit_breaker_error(cb_error).await),
166 }
167 }
168
169 async fn health_check(&self) -> Result<ProviderHealth, ProviderError> {
170 let inner = self.inner.clone();
171 let operation = || async move { inner.health_check().await };
172
173 match self.circuit_breaker.call(operation).await {
174 Ok(health) => Ok(health),
175 Err(cb_error) => Err(self.handle_circuit_breaker_error(cb_error).await),
176 }
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Message;
    use crate::providers::{HealthStatus, ProviderHealth};
    use std::collections::HashMap;
    use std::time::Duration;

    /// Test double whose `chat_completion` sleeps for `delay` and then either
    /// fails with `ServiceUnavailable` or returns a canned response.
    struct MockProvider {
        name: String,
        should_fail: bool,
        delay: Duration,
    }

    impl MockProvider {
        /// Builds a mock with the given name, failure switch and artificial delay.
        fn new(name: String, should_fail: bool, delay: Duration) -> Self {
            Self {
                name,
                should_fail,
                delay,
            }
        }
    }

    #[async_trait::async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn supports_streaming(&self) -> bool {
            false
        }

        fn supports_function_calling(&self) -> bool {
            false
        }

        fn supported_models(&self) -> Vec<String> {
            vec!["test-model".to_string()]
        }

        async fn chat_completion(
            &self,
            _request: ChatRequest,
        ) -> Result<ChatResponse, ProviderError> {
            // Simulated latency; lets tests exceed the breaker's request timeout.
            tokio::time::sleep(self.delay).await;

            if self.should_fail {
                return Err(ProviderError::ServiceUnavailable);
            }

            Ok(ChatResponse {
                id: "test-id".to_string(),
                object: "chat.completion".to_string(),
                created: 1234567890,
                model: "test-model".to_string(),
                choices: vec![],
                usage: None,
                system_fingerprint: None,
            })
        }

        async fn stream_chat_completion(
            &self,
            _request: ChatRequest,
        ) -> Result<StreamResult, ProviderError> {
            Err(ProviderError::Configuration {
                message: "Streaming not supported by mock provider".to_string(),
            })
        }

        async fn health_check(&self) -> Result<ProviderHealth, ProviderError> {
            if self.should_fail {
                return Err(ProviderError::ServiceUnavailable);
            }

            Ok(ProviderHealth {
                status: HealthStatus::Healthy,
                latency_ms: Some(10),
                last_check: chrono::Utc::now(),
                details: HashMap::new(),
                error_rate: 0.0,
            })
        }
    }

    /// Wraps a fresh `MockProvider` named "test" in a `CircuitBreakerProvider`.
    fn guarded_mock(
        should_fail: bool,
        delay: Duration,
        config: CircuitBreakerConfig,
    ) -> CircuitBreakerProvider {
        let mock_provider = Arc::new(MockProvider::new("test".to_string(), should_fail, delay));
        CircuitBreakerProvider::new(mock_provider, config)
    }

    /// Minimal chat request used by every test.
    fn sample_request() -> ChatRequest {
        ChatRequest {
            model: "test-model".to_string(),
            messages: vec![Message::user("test")],
            ..Default::default()
        }
    }

    /// A fast, successful call returns `Ok` and leaves the breaker closed.
    #[tokio::test]
    async fn test_circuit_breaker_provider_success() {
        let config = CircuitBreakerConfig {
            failure_threshold: 2,
            recovery_timeout: Duration::from_millis(100),
            request_timeout: Duration::from_millis(50),
            half_open_max_calls: 1,
        };
        let cb_provider = guarded_mock(false, Duration::from_millis(10), config);

        let outcome = cb_provider.chat_completion(sample_request()).await;

        assert!(outcome.is_ok());
        assert_eq!(cb_provider.get_circuit_state().await, CircuitState::Closed);
    }

    /// With a threshold of one, a single failure opens the breaker and the
    /// following call is reported as `ServiceUnavailable`.
    #[tokio::test]
    async fn test_circuit_breaker_provider_failure() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            recovery_timeout: Duration::from_millis(100),
            request_timeout: Duration::from_millis(50),
            half_open_max_calls: 1,
        };
        let cb_provider = guarded_mock(true, Duration::from_millis(10), config);

        let first = cb_provider.chat_completion(sample_request()).await;
        assert!(first.is_err());
        assert_eq!(cb_provider.get_circuit_state().await, CircuitState::Open);

        let second = cb_provider.chat_completion(sample_request()).await;
        assert!(second.is_err());
        assert!(
            matches!(second, Err(ProviderError::ServiceUnavailable)),
            "Expected ServiceUnavailable error"
        );
    }

    /// A mock delay (100 ms) longer than the request timeout (50 ms) is
    /// reported as `Timeout` and, with a threshold of one, opens the breaker.
    #[tokio::test]
    async fn test_circuit_breaker_provider_timeout() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            recovery_timeout: Duration::from_millis(100),
            request_timeout: Duration::from_millis(50),
            half_open_max_calls: 1,
        };
        let cb_provider = guarded_mock(false, Duration::from_millis(100), config);

        let outcome = cb_provider.chat_completion(sample_request()).await;

        assert!(outcome.is_err());
        assert!(
            matches!(outcome, Err(ProviderError::Timeout)),
            "Expected Timeout error"
        );
        assert_eq!(cb_provider.get_circuit_state().await, CircuitState::Open);
    }
}