1use serde::{Deserialize, Serialize};
6use std::time::Duration;
7
8#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10pub struct RetryConfig {
11 #[serde(deserialize_with = "non_zero_u32")]
13 pub max_attempts: u32,
14 #[serde(with = "duration_millis")]
16 pub initial_backoff: Duration,
17 #[serde(with = "duration_millis")]
19 pub max_backoff: Duration,
20 pub backoff_multiplier: f32,
22 pub jitter: bool,
24}
25
26impl Default for RetryConfig {
27 fn default() -> Self {
28 Self {
29 max_attempts: 3,
30 initial_backoff: Duration::from_millis(100),
31 max_backoff: Duration::from_secs(10),
32 backoff_multiplier: 2.0,
33 jitter: true,
34 }
35 }
36}
37
38impl RetryConfig {
39 pub fn validate(&self) -> Result<(), String> {
41 if self.max_attempts == 0 {
42 return Err("max_attempts must be >= 1".to_string());
43 }
44 Ok(())
45 }
46
47 pub fn calculate_backoff(&self, attempt: u32) -> Duration {
60 let base =
61 self.initial_backoff.as_millis() as f32 * self.backoff_multiplier.powi(attempt as i32);
62 let capped = base.min(self.max_backoff.as_millis() as f32);
63
64 let duration = if self.jitter {
65 let jitter_factor = 0.5 + (rand() * 0.5);
67 Duration::from_millis((capped * jitter_factor) as u64)
68 } else {
69 Duration::from_millis(capped as u64)
70 };
71
72 duration.min(self.max_backoff)
73 }
74}
75
76fn non_zero_u32<'de, D>(deserializer: D) -> Result<u32, D::Error>
77where
78 D: serde::Deserializer<'de>,
79{
80 let value = u32::deserialize(deserializer)?;
81 if value == 0 {
82 return Err(serde::de::Error::custom("max_attempts must be >= 1"));
83 }
84 Ok(value)
85}
86
87fn rand() -> f32 {
89 use rand::Rng;
90 rand::thread_rng().gen()
91}
92
93#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
95pub struct HealingConfig {
96 pub enabled: bool,
98 pub strict_mode: bool,
100 pub allow_type_coercion: bool,
102 pub min_confidence: f32,
104 pub allow_fuzzy_matching: bool,
106 pub max_attempts: u32,
108}
109
110impl Default for HealingConfig {
111 fn default() -> Self {
112 Self {
113 enabled: true,
114 strict_mode: false,
115 allow_type_coercion: true,
116 min_confidence: 0.7,
117 allow_fuzzy_matching: true,
118 max_attempts: 3,
119 }
120 }
121}
122
123impl HealingConfig {
124 pub fn strict() -> Self {
126 Self {
127 enabled: true,
128 strict_mode: true,
129 allow_type_coercion: false,
130 min_confidence: 0.95,
131 allow_fuzzy_matching: false,
132 max_attempts: 1,
133 }
134 }
135
136 pub fn lenient() -> Self {
138 Self {
139 enabled: true,
140 strict_mode: false,
141 allow_type_coercion: true,
142 min_confidence: 0.5,
143 allow_fuzzy_matching: true,
144 max_attempts: 5,
145 }
146 }
147}
148
149#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
151pub struct Capabilities {
152 pub streaming: bool,
154 pub function_calling: bool,
156 pub vision: bool,
158 pub max_tokens: u32,
160}
161
162#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
164pub struct ProviderConfig {
165 pub name: String,
167 pub base_url: String,
169 #[serde(skip_serializing_if = "Option::is_none")]
171 pub api_key: Option<String>,
172 #[serde(skip_serializing_if = "Option::is_none")]
174 pub default_model: Option<String>,
175 #[serde(default)]
177 pub retry_config: RetryConfig,
178 #[serde(with = "duration_millis")]
180 pub timeout: Duration,
181 #[serde(default)]
183 pub capabilities: Capabilities,
184}
185
186impl ProviderConfig {
187 pub fn new(name: impl Into<String>, base_url: impl Into<String>) -> Self {
189 Self {
190 name: name.into(),
191 base_url: base_url.into(),
192 api_key: None,
193 default_model: None,
194 retry_config: RetryConfig::default(),
195 timeout: Duration::from_secs(30),
196 capabilities: Capabilities::default(),
197 }
198 }
199
200 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
202 self.api_key = Some(api_key.into());
203 self
204 }
205
206 pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
208 self.default_model = Some(model.into());
209 self
210 }
211
212 pub fn with_timeout(mut self, timeout: Duration) -> Self {
214 self.timeout = timeout;
215 self
216 }
217}
218
219#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
221pub enum RateLimitScope {
222 #[default]
224 PerInstance,
225 Shared,
227}
228
229#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
231pub struct RateLimitConfig {
232 pub enabled: bool,
234 pub requests_per_second: u32,
236 pub burst_size: u32,
238 #[serde(default)]
240 pub scope: RateLimitScope,
241}
242
243impl Default for RateLimitConfig {
244 fn default() -> Self {
245 Self {
246 enabled: false,
247 requests_per_second: 10,
248 burst_size: 20,
249 scope: RateLimitScope::PerInstance,
250 }
251 }
252}
253
254impl RateLimitConfig {
255 pub fn validate(&self) -> Result<(), String> {
257 if !self.enabled {
258 return Ok(());
259 }
260
261 if self.requests_per_second == 0 {
262 return Err(
263 "requests_per_second must be >= 1 when rate limiting is enabled".to_string(),
264 );
265 }
266
267 if self.burst_size == 0 {
268 return Err("burst_size must be >= 1 when rate limiting is enabled".to_string());
269 }
270
271 Ok(())
272 }
273
274 pub fn new(requests_per_second: u32, burst_size: u32) -> Self {
286 Self {
287 enabled: true,
288 requests_per_second,
289 burst_size,
290 scope: RateLimitScope::PerInstance,
291 }
292 }
293
294 pub fn shared(requests_per_second: u32, burst_size: u32) -> Self {
296 Self {
297 enabled: true,
298 requests_per_second,
299 burst_size,
300 scope: RateLimitScope::Shared,
301 }
302 }
303
304 pub fn disabled() -> Self {
306 Self {
307 enabled: false,
308 requests_per_second: 0,
309 burst_size: 0,
310 scope: RateLimitScope::PerInstance,
311 }
312 }
313}
314
315mod duration_millis {
317 use serde::{Deserialize, Deserializer, Serializer};
318 use std::time::Duration;
319
320 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
321 where
322 S: Serializer,
323 {
324 serializer.serialize_u64(duration.as_millis() as u64)
325 }
326
327 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
328 where
329 D: Deserializer<'de>,
330 {
331 let millis = u64::deserialize(deserializer)?;
332 Ok(Duration::from_millis(millis))
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.max_attempts, 3);
        assert_eq!(config.initial_backoff, Duration::from_millis(100));
        assert_eq!(config.max_backoff, Duration::from_secs(10));
        assert_eq!(config.backoff_multiplier, 2.0);
        assert!(config.jitter);
    }

    #[test]
    fn test_retry_config_backoff() {
        let config = RetryConfig {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        let backoff1 = config.calculate_backoff(0);
        let backoff2 = config.calculate_backoff(1);
        let backoff3 = config.calculate_backoff(2);

        assert_eq!(backoff1, Duration::from_millis(100));
        assert_eq!(backoff2, Duration::from_millis(200));
        assert_eq!(backoff3, Duration::from_millis(400));
    }

    #[test]
    fn test_retry_config_validate_rejects_zero_attempts() {
        let config = RetryConfig {
            max_attempts: 0,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        assert_eq!(
            config.validate().unwrap_err(),
            "max_attempts must be >= 1".to_string()
        );
    }

    #[test]
    fn test_retry_config_deserialize_rejects_zero_attempts() {
        // The JSON must be well-formed (no `\"` escapes inside a raw string) so that the
        // failure comes from the `non_zero_u32` validator rather than from the parser.
        let json = r#"{"max_attempts":0,"initial_backoff":100,"max_backoff":1000,"backoff_multiplier":2.0,"jitter":false}"#;
        let err = serde_json::from_str::<RetryConfig>(json).unwrap_err();
        assert!(
            err.to_string().contains("max_attempts must be >= 1"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_retry_config_deserialize_accepts_non_zero_attempts() {
        // Positive control for the test above: the same document with a valid count parses.
        let json = r#"{"max_attempts":1,"initial_backoff":100,"max_backoff":1000,"backoff_multiplier":2.0,"jitter":false}"#;
        let parsed: RetryConfig = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.max_attempts, 1);
        assert_eq!(parsed.initial_backoff, Duration::from_millis(100));
        assert_eq!(parsed.max_backoff, Duration::from_secs(1));
    }

    #[test]
    fn test_retry_config_max_backoff() {
        let config = RetryConfig {
            max_attempts: 10,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        // 100ms * 2^10 far exceeds the cap, so the result is exactly the cap.
        let backoff = config.calculate_backoff(10);
        assert_eq!(backoff, Duration::from_secs(1));
    }

    #[test]
    fn test_healing_config_default() {
        let config = HealingConfig::default();
        assert!(config.enabled);
        assert!(!config.strict_mode);
        assert!(config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.7);
        assert!(config.allow_fuzzy_matching);
        assert_eq!(config.max_attempts, 3);
    }

    #[test]
    fn test_healing_config_strict() {
        let config = HealingConfig::strict();
        assert!(config.enabled);
        assert!(config.strict_mode);
        assert!(!config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.95);
        assert!(!config.allow_fuzzy_matching);
        assert_eq!(config.max_attempts, 1);
    }

    #[test]
    fn test_healing_config_lenient() {
        let config = HealingConfig::lenient();
        assert!(config.enabled);
        assert!(!config.strict_mode);
        assert!(config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.5);
        assert!(config.allow_fuzzy_matching);
        assert_eq!(config.max_attempts, 5);
    }

    #[test]
    fn test_capabilities_default() {
        let caps = Capabilities::default();
        assert!(!caps.streaming);
        assert!(!caps.function_calling);
        assert!(!caps.vision);
        assert_eq!(caps.max_tokens, 0);
    }

    #[test]
    fn test_provider_config_builder() {
        let config = ProviderConfig::new("openai", "https://api.openai.com/v1")
            .with_api_key("sk-test")
            .with_default_model("gpt-4")
            .with_timeout(Duration::from_secs(60));

        assert_eq!(config.name, "openai");
        assert_eq!(config.base_url, "https://api.openai.com/v1");
        assert_eq!(config.api_key, Some("sk-test".to_string()));
        assert_eq!(config.default_model, Some("gpt-4".to_string()));
        assert_eq!(config.timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_config_serialization() {
        let config = RetryConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: RetryConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_provider_config_serialization() {
        let config = ProviderConfig::new("test", "https://example.com");
        let json = serde_json::to_string(&config).unwrap();
        let parsed: ProviderConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config.name, parsed.name);
        assert_eq!(config.base_url, parsed.base_url);
        // Every field, including the millisecond-encoded timeout, survives the round trip.
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_rate_limit_config_validate_enabled_requires_non_zero_values() {
        let invalid_rps = RateLimitConfig::new(0, 10);
        assert_eq!(
            invalid_rps.validate().unwrap_err(),
            "requests_per_second must be >= 1 when rate limiting is enabled"
        );

        let invalid_burst = RateLimitConfig::new(10, 0);
        assert_eq!(
            invalid_burst.validate().unwrap_err(),
            "burst_size must be >= 1 when rate limiting is enabled"
        );
    }

    #[test]
    fn test_rate_limit_config_validate_disabled_allows_zero_values() {
        let disabled = RateLimitConfig::disabled();
        assert!(disabled.validate().is_ok());
    }

    #[test]
    fn test_rate_limit_config_shared_scope() {
        let shared = RateLimitConfig::shared(5, 10);
        assert!(shared.enabled);
        assert_eq!(shared.scope, RateLimitScope::Shared);
        assert!(shared.validate().is_ok());
    }

    #[test]
    fn test_jitter_randomness() {
        let config = RetryConfig {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: true,
        };

        let backoffs: Vec<Duration> = (0..10).map(|_| config.calculate_backoff(1)).collect();

        // Attempt 1 has an un-jittered delay of 200ms. Jitter scales it by a factor between
        // 0.5 and 1.0, so every sample must land in [100ms, 200ms].
        for backoff in &backoffs {
            let ms = backoff.as_millis();
            assert!(ms >= 100, "Backoff too small: {}ms", ms);
            assert!(ms <= 200, "Backoff too large: {}ms", ms);
        }

        let unique_count = backoffs
            .iter()
            .collect::<std::collections::HashSet<_>>()
            .len();
        assert!(
            unique_count > 1,
            "All jitter values are the same - RNG may not be working"
        );
    }
}