1use serde::{Deserialize, Serialize};
6use std::time::Duration;
7
/// Exponential-backoff retry policy.
///
/// Both durations are (de)serialized as integer milliseconds through the
/// `duration_millis` serde adapter.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Maximum number of attempts. Not read by `calculate_backoff`;
    /// presumably enforced by the caller's retry loop — confirm.
    pub max_attempts: u32,
    /// Delay for the first retry (`attempt == 0`) before the cap and jitter
    /// are applied.
    #[serde(with = "duration_millis")]
    pub initial_backoff: Duration,
    /// Upper bound on every delay returned by `calculate_backoff`.
    #[serde(with = "duration_millis")]
    pub max_backoff: Duration,
    /// Factor the delay is multiplied by once per attempt.
    pub backoff_multiplier: f32,
    /// When `true`, the capped delay is scaled by a random factor between
    /// 0.5 and 1.0.
    pub jitter: bool,
}
24
25impl Default for RetryConfig {
26 fn default() -> Self {
27 Self {
28 max_attempts: 3,
29 initial_backoff: Duration::from_millis(100),
30 max_backoff: Duration::from_secs(10),
31 backoff_multiplier: 2.0,
32 jitter: true,
33 }
34 }
35}
36
37impl RetryConfig {
38 pub fn calculate_backoff(&self, attempt: u32) -> Duration {
51 let base =
52 self.initial_backoff.as_millis() as f32 * self.backoff_multiplier.powi(attempt as i32);
53 let capped = base.min(self.max_backoff.as_millis() as f32);
54
55 let duration = if self.jitter {
56 let jitter_factor = 0.5 + (rand() * 0.5);
58 Duration::from_millis((capped * jitter_factor) as u64)
59 } else {
60 Duration::from_millis(capped as u64)
61 };
62
63 duration.min(self.max_backoff)
64 }
65}
66
67fn rand() -> f32 {
69 use rand::Rng;
70 rand::thread_rng().gen()
71}
72
/// Tuning knobs for the healing subsystem.
///
/// None of these fields are read in this file. The notes below are based on
/// the field names and the presets (`default`, `strict`, `lenient`).
/// NOTE(review): confirm the exact semantics against the consumer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HealingConfig {
    /// Master switch. Every preset sets it to `true`.
    pub enabled: bool,
    /// Strict-mode flag. Only `HealingConfig::strict()` sets it.
    pub strict_mode: bool,
    /// Whether type coercion is permitted.
    pub allow_type_coercion: bool,
    /// Confidence threshold. The presets use 0.5 (lenient), 0.7 (default) and
    /// 0.95 (strict). Presumably meant to lie in `[0, 1]`; not validated here.
    pub min_confidence: f32,
    /// Whether fuzzy matching is permitted.
    pub allow_fuzzy_matching: bool,
    /// Attempt limit. The presets use 1 (strict), 3 (default) and 5 (lenient).
    pub max_attempts: u32,
}
89
90impl Default for HealingConfig {
91 fn default() -> Self {
92 Self {
93 enabled: true,
94 strict_mode: false,
95 allow_type_coercion: true,
96 min_confidence: 0.7,
97 allow_fuzzy_matching: true,
98 max_attempts: 3,
99 }
100 }
101}
102
103impl HealingConfig {
104 pub fn strict() -> Self {
106 Self {
107 enabled: true,
108 strict_mode: true,
109 allow_type_coercion: false,
110 min_confidence: 0.95,
111 allow_fuzzy_matching: false,
112 max_attempts: 1,
113 }
114 }
115
116 pub fn lenient() -> Self {
118 Self {
119 enabled: true,
120 strict_mode: false,
121 allow_type_coercion: true,
122 min_confidence: 0.5,
123 allow_fuzzy_matching: true,
124 max_attempts: 5,
125 }
126 }
127}
128
/// Capability flags and limits for a provider.
///
/// `Default` yields every flag `false` and `max_tokens == 0`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct Capabilities {
    /// Streaming support flag.
    pub streaming: bool,
    /// Function-calling support flag.
    pub function_calling: bool,
    /// Vision support flag (presumably image inputs — confirm).
    pub vision: bool,
    /// Token limit. Defaults to `0`.
    /// NOTE(review): whether `0` means "unknown" or "unlimited" is not
    /// established here — confirm with consumers.
    pub max_tokens: u32,
}
141
142#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
144pub struct ProviderConfig {
145 pub name: String,
147 pub base_url: String,
149 #[serde(skip_serializing_if = "Option::is_none")]
151 pub api_key: Option<String>,
152 #[serde(skip_serializing_if = "Option::is_none")]
154 pub default_model: Option<String>,
155 #[serde(default)]
157 pub retry_config: RetryConfig,
158 #[serde(with = "duration_millis")]
160 pub timeout: Duration,
161 #[serde(default)]
163 pub capabilities: Capabilities,
164}
165
166impl ProviderConfig {
167 pub fn new(name: impl Into<String>, base_url: impl Into<String>) -> Self {
169 Self {
170 name: name.into(),
171 base_url: base_url.into(),
172 api_key: None,
173 default_model: None,
174 retry_config: RetryConfig::default(),
175 timeout: Duration::from_secs(30),
176 capabilities: Capabilities::default(),
177 }
178 }
179
180 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
182 self.api_key = Some(api_key.into());
183 self
184 }
185
186 pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
188 self.default_model = Some(model.into());
189 self
190 }
191
192 pub fn with_timeout(mut self, timeout: Duration) -> Self {
194 self.timeout = timeout;
195 self
196 }
197}
198
/// Scope over which a rate limit applies. Defaults to `PerInstance`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum RateLimitScope {
    /// Default scope. Presumably each instance keeps its own budget —
    /// confirm against the limiter implementation.
    #[default]
    PerInstance,
    /// Presumably a single budget shared across instances — confirm against
    /// the limiter implementation.
    Shared,
}
208
/// Request rate-limiting settings.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Whether rate limiting applies. It is `false` in `Default` and in
    /// `RateLimitConfig::disabled()`.
    pub enabled: bool,
    /// Sustained request rate, in requests per second.
    pub requests_per_second: u32,
    /// Burst allowance (presumably the token-bucket capacity — confirm
    /// against the limiter).
    pub burst_size: u32,
    /// Scope of the limit. Falls back to `RateLimitScope::PerInstance` when
    /// missing from serialized input.
    #[serde(default)]
    pub scope: RateLimitScope,
}
222
223impl Default for RateLimitConfig {
224 fn default() -> Self {
225 Self {
226 enabled: false,
227 requests_per_second: 10,
228 burst_size: 20,
229 scope: RateLimitScope::PerInstance,
230 }
231 }
232}
233
234impl RateLimitConfig {
235 pub fn new(requests_per_second: u32, burst_size: u32) -> Self {
247 Self {
248 enabled: true,
249 requests_per_second,
250 burst_size,
251 scope: RateLimitScope::PerInstance,
252 }
253 }
254
255 pub fn shared(requests_per_second: u32, burst_size: u32) -> Self {
257 Self {
258 enabled: true,
259 requests_per_second,
260 burst_size,
261 scope: RateLimitScope::Shared,
262 }
263 }
264
265 pub fn disabled() -> Self {
267 Self {
268 enabled: false,
269 requests_per_second: 0,
270 burst_size: 0,
271 scope: RateLimitScope::PerInstance,
272 }
273 }
274}
275
276mod duration_millis {
278 use serde::{Deserialize, Deserializer, Serializer};
279 use std::time::Duration;
280
281 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
282 where
283 S: Serializer,
284 {
285 serializer.serialize_u64(duration.as_millis() as u64)
286 }
287
288 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
289 where
290 D: Deserializer<'de>,
291 {
292 let millis = u64::deserialize(deserializer)?;
293 Ok(Duration::from_millis(millis))
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_attempts, 3);
        assert_eq!(config.initial_backoff, Duration::from_millis(100));
        assert_eq!(config.max_backoff, Duration::from_secs(10));
        assert_eq!(config.backoff_multiplier, 2.0);
        assert!(config.jitter);
    }

    #[test]
    fn test_retry_config_backoff() {
        let config = RetryConfig {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        assert_eq!(config.calculate_backoff(0), Duration::from_millis(100));
        assert_eq!(config.calculate_backoff(1), Duration::from_millis(200));
        assert_eq!(config.calculate_backoff(2), Duration::from_millis(400));
        assert_eq!(config.calculate_backoff(3), Duration::from_millis(800));
    }

    #[test]
    fn test_retry_config_max_backoff() {
        let config = RetryConfig {
            max_attempts: 10,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        // 100 ms * 2^10 = 102.4 s, so without jitter the result is exactly the cap.
        assert_eq!(config.calculate_backoff(10), Duration::from_secs(1));

        // With jitter the capped value is scaled by 0.5..=1.0, so it stays
        // within [cap / 2, cap].
        let jittered = RetryConfig {
            jitter: true,
            ..config
        };
        let backoff = jittered.calculate_backoff(10);
        assert!(backoff >= Duration::from_millis(500));
        assert!(backoff <= Duration::from_secs(1));
    }

    #[test]
    fn test_healing_config_default() {
        let config = HealingConfig::default();
        assert!(config.enabled);
        assert!(!config.strict_mode);
        assert!(config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.7);
        assert!(config.allow_fuzzy_matching);
        assert_eq!(config.max_attempts, 3);
    }

    #[test]
    fn test_healing_config_strict() {
        let config = HealingConfig::strict();
        assert!(config.enabled);
        assert!(config.strict_mode);
        assert!(!config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.95);
        assert!(!config.allow_fuzzy_matching);
        assert_eq!(config.max_attempts, 1);
    }

    #[test]
    fn test_healing_config_lenient() {
        let config = HealingConfig::lenient();
        assert!(config.enabled);
        assert!(!config.strict_mode);
        assert!(config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.5);
        assert!(config.allow_fuzzy_matching);
        assert_eq!(config.max_attempts, 5);
    }

    #[test]
    fn test_capabilities_default() {
        let caps = Capabilities::default();
        assert!(!caps.streaming);
        assert!(!caps.function_calling);
        assert!(!caps.vision);
        assert_eq!(caps.max_tokens, 0);
    }

    #[test]
    fn test_provider_config_builder() {
        let config = ProviderConfig::new("openai", "https://api.openai.com/v1")
            .with_api_key("sk-test")
            .with_default_model("gpt-4")
            .with_timeout(Duration::from_secs(60));

        assert_eq!(config.name, "openai");
        assert_eq!(config.base_url, "https://api.openai.com/v1");
        assert_eq!(config.api_key, Some("sk-test".to_string()));
        assert_eq!(config.default_model, Some("gpt-4".to_string()));
        assert_eq!(config.timeout, Duration::from_secs(60));
        assert_eq!(config.retry_config, RetryConfig::default());
        assert_eq!(config.capabilities, Capabilities::default());
    }

    #[test]
    fn test_config_serialization() {
        let config = RetryConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: RetryConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_durations_serialize_as_millis() {
        let json = serde_json::to_string(&RetryConfig::default()).unwrap();
        assert!(json.contains("\"initial_backoff\":100,"), "{}", json);
        assert!(json.contains("\"max_backoff\":10000,"), "{}", json);

        let parsed: RetryConfig = serde_json::from_str(
            r#"{"max_attempts":2,"initial_backoff":250,"max_backoff":5000,"backoff_multiplier":1.5,"jitter":false}"#,
        )
        .unwrap();
        assert_eq!(parsed.initial_backoff, Duration::from_millis(250));
        assert_eq!(parsed.max_backoff, Duration::from_secs(5));
    }

    #[test]
    fn test_provider_config_serialization() {
        let config = ProviderConfig::new("test", "https://example.com");
        let json = serde_json::to_string(&config).unwrap();
        // `None` optionals are skipped entirely rather than written as `null`.
        assert!(!json.contains("api_key"), "{}", json);
        assert!(!json.contains("default_model"), "{}", json);

        let parsed: ProviderConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_provider_config_serialization_with_optionals() {
        let config = ProviderConfig::new("test", "https://example.com")
            .with_api_key("sk-test")
            .with_default_model("gpt-4");
        let json = serde_json::to_string(&config).unwrap();
        let parsed: ProviderConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_rate_limit_presets() {
        let default = RateLimitConfig::default();
        assert!(!default.enabled);
        assert_eq!(default.requests_per_second, 10);
        assert_eq!(default.burst_size, 20);
        assert_eq!(default.scope, RateLimitScope::PerInstance);

        let per_instance = RateLimitConfig::new(5, 15);
        assert!(per_instance.enabled);
        assert_eq!(per_instance.requests_per_second, 5);
        assert_eq!(per_instance.burst_size, 15);
        assert_eq!(per_instance.scope, RateLimitScope::PerInstance);

        let shared = RateLimitConfig::shared(5, 15);
        assert!(shared.enabled);
        assert_eq!(shared.requests_per_second, 5);
        assert_eq!(shared.burst_size, 15);
        assert_eq!(shared.scope, RateLimitScope::Shared);

        let disabled = RateLimitConfig::disabled();
        assert!(!disabled.enabled);
        assert_eq!(disabled.requests_per_second, 0);
        assert_eq!(disabled.burst_size, 0);
        assert_eq!(disabled.scope, RateLimitScope::PerInstance);
    }

    #[test]
    fn test_rate_limit_scope_defaults_when_missing() {
        let parsed: RateLimitConfig =
            serde_json::from_str(r#"{"enabled":true,"requests_per_second":1,"burst_size":2}"#)
                .unwrap();
        assert_eq!(parsed.scope, RateLimitScope::PerInstance);
    }

    #[test]
    fn test_jitter_randomness() {
        let config = RetryConfig {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: true,
        };

        let backoffs: Vec<Duration> = (0..10).map(|_| config.calculate_backoff(1)).collect();

        // The un-jittered delay for attempt 1 is 200 ms. Jitter scales it by a
        // factor between 0.5 and 1.0.
        for backoff in &backoffs {
            let ms = backoff.as_millis();
            assert!(ms >= 100, "Backoff too small: {}ms", ms);
            assert!(ms <= 200, "Backoff too large: {}ms", ms);
        }

        let unique_count = backoffs.iter().collect::<HashSet<_>>().len();
        assert!(
            unique_count > 1,
            "All jitter values are the same - RNG may not be working"
        );
    }
}