1use serde::{Deserialize, Serialize};
6use std::time::Duration;
7
8#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10pub struct RetryConfig {
11 #[serde(deserialize_with = "non_zero_u32")]
13 pub max_attempts: u32,
14 #[serde(with = "duration_millis")]
16 pub initial_backoff: Duration,
17 #[serde(with = "duration_millis")]
19 pub max_backoff: Duration,
20 pub backoff_multiplier: f32,
22 pub jitter: bool,
24}
25
26impl Default for RetryConfig {
27 fn default() -> Self {
28 Self {
29 max_attempts: 3,
30 initial_backoff: Duration::from_millis(100),
31 max_backoff: Duration::from_secs(10),
32 backoff_multiplier: 2.0,
33 jitter: true,
34 }
35 }
36}
37
38impl RetryConfig {
39 pub fn validate(&self) -> Result<(), String> {
41 if self.max_attempts == 0 {
42 return Err("max_attempts must be >= 1".to_string());
43 }
44 Ok(())
45 }
46
47 pub fn calculate_backoff(&self, attempt: u32) -> Duration {
60 let base =
61 self.initial_backoff.as_millis() as f32 * self.backoff_multiplier.powi(attempt as i32);
62 let capped = base.min(self.max_backoff.as_millis() as f32);
63
64 let duration = if self.jitter {
65 let jitter_factor = 0.5 + (rand() * 0.5);
67 Duration::from_millis((capped * jitter_factor) as u64)
68 } else {
69 Duration::from_millis(capped as u64)
70 };
71
72 duration.min(self.max_backoff)
73 }
74}
75
76fn non_zero_u32<'de, D>(deserializer: D) -> Result<u32, D::Error>
77where
78 D: serde::Deserializer<'de>,
79{
80 let value = u32::deserialize(deserializer)?;
81 if value == 0 {
82 return Err(serde::de::Error::custom("max_attempts must be >= 1"));
83 }
84 Ok(value)
85}
86
87fn rand() -> f32 {
89 use rand::Rng;
90 rand::thread_rng().gen()
91}
92
93#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
95pub struct HealingConfig {
96 pub enabled: bool,
98 pub strict_mode: bool,
100 pub allow_type_coercion: bool,
102 pub min_confidence: f32,
104 pub allow_fuzzy_matching: bool,
106 pub max_attempts: u32,
108}
109
110impl Default for HealingConfig {
111 fn default() -> Self {
112 Self {
113 enabled: true,
114 strict_mode: false,
115 allow_type_coercion: true,
116 min_confidence: 0.7,
117 allow_fuzzy_matching: true,
118 max_attempts: 3,
119 }
120 }
121}
122
123impl HealingConfig {
124 pub fn strict() -> Self {
126 Self {
127 enabled: true,
128 strict_mode: true,
129 allow_type_coercion: false,
130 min_confidence: 0.95,
131 allow_fuzzy_matching: false,
132 max_attempts: 1,
133 }
134 }
135
136 pub fn lenient() -> Self {
138 Self {
139 enabled: true,
140 strict_mode: false,
141 allow_type_coercion: true,
142 min_confidence: 0.5,
143 allow_fuzzy_matching: true,
144 max_attempts: 5,
145 }
146 }
147}
148
149#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
151pub struct Capabilities {
152 pub streaming: bool,
154 pub function_calling: bool,
156 pub vision: bool,
158 pub max_tokens: u32,
160}
161
162#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
164pub struct ProviderConfig {
165 pub name: String,
167 pub base_url: String,
169 #[serde(
171 skip_serializing_if = "Option::is_none",
172 serialize_with = "serialize_optional_secret"
173 )]
174 pub api_key: Option<String>,
175 #[serde(skip_serializing_if = "Option::is_none")]
177 pub default_model: Option<String>,
178 #[serde(default)]
180 pub retry_config: RetryConfig,
181 #[serde(with = "duration_millis")]
183 pub timeout: Duration,
184 #[serde(default)]
186 pub capabilities: Capabilities,
187}
188
189fn serialize_optional_secret<S>(value: &Option<String>, serializer: S) -> Result<S::Ok, S::Error>
190where
191 S: serde::Serializer,
192{
193 match value {
194 Some(_) => serializer.serialize_some("<redacted>"),
195 None => serializer.serialize_none(),
196 }
197}
198
199impl ProviderConfig {
200 pub fn new(name: impl Into<String>, base_url: impl Into<String>) -> Self {
202 Self {
203 name: name.into(),
204 base_url: base_url.into(),
205 api_key: None,
206 default_model: None,
207 retry_config: RetryConfig::default(),
208 timeout: Duration::from_secs(30),
209 capabilities: Capabilities::default(),
210 }
211 }
212
213 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
215 self.api_key = Some(api_key.into());
216 self
217 }
218
219 pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
221 self.default_model = Some(model.into());
222 self
223 }
224
225 pub fn with_timeout(mut self, timeout: Duration) -> Self {
227 self.timeout = timeout;
228 self
229 }
230}
231
232#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
234pub enum RateLimitScope {
235 #[default]
237 PerInstance,
238 Shared,
240}
241
242#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
244pub struct RateLimitConfig {
245 pub enabled: bool,
247 pub requests_per_second: u32,
249 pub burst_size: u32,
251 #[serde(default)]
253 pub scope: RateLimitScope,
254}
255
256impl Default for RateLimitConfig {
257 fn default() -> Self {
258 Self {
259 enabled: false,
260 requests_per_second: 10,
261 burst_size: 20,
262 scope: RateLimitScope::PerInstance,
263 }
264 }
265}
266
267impl RateLimitConfig {
268 pub fn validate(&self) -> Result<(), String> {
270 if !self.enabled {
271 return Ok(());
272 }
273
274 if self.requests_per_second == 0 {
275 return Err(
276 "requests_per_second must be >= 1 when rate limiting is enabled".to_string(),
277 );
278 }
279
280 if self.burst_size == 0 {
281 return Err("burst_size must be >= 1 when rate limiting is enabled".to_string());
282 }
283
284 Ok(())
285 }
286
287 pub fn new(requests_per_second: u32, burst_size: u32) -> Self {
299 Self {
300 enabled: true,
301 requests_per_second,
302 burst_size,
303 scope: RateLimitScope::PerInstance,
304 }
305 }
306
307 pub fn shared(requests_per_second: u32, burst_size: u32) -> Self {
309 Self {
310 enabled: true,
311 requests_per_second,
312 burst_size,
313 scope: RateLimitScope::Shared,
314 }
315 }
316
317 pub fn disabled() -> Self {
319 Self {
320 enabled: false,
321 requests_per_second: 0,
322 burst_size: 0,
323 scope: RateLimitScope::PerInstance,
324 }
325 }
326}
327
328mod duration_millis {
330 use serde::{Deserialize, Deserializer, Serializer};
331 use std::time::Duration;
332
333 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
334 where
335 S: Serializer,
336 {
337 serializer.serialize_u64(duration.as_millis() as u64)
338 }
339
340 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
341 where
342 D: Deserializer<'de>,
343 {
344 let millis = u64::deserialize(deserializer)?;
345 Ok(Duration::from_millis(millis))
346 }
347}
348
/// Unit tests for the configuration types defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.max_attempts, 3);
        assert_eq!(config.initial_backoff, Duration::from_millis(100));
        assert_eq!(config.max_backoff, Duration::from_secs(10));
        assert_eq!(config.backoff_multiplier, 2.0);
        assert!(config.jitter);
    }

    #[test]
    fn test_retry_config_backoff() {
        // Jitter is off so the delays are exact powers of the multiplier.
        let config = RetryConfig {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        let backoff1 = config.calculate_backoff(0);
        let backoff2 = config.calculate_backoff(1);
        let backoff3 = config.calculate_backoff(2);

        assert_eq!(backoff1, Duration::from_millis(100));
        assert_eq!(backoff2, Duration::from_millis(200));
        assert_eq!(backoff3, Duration::from_millis(400));
    }

    #[test]
    fn test_retry_config_validate_rejects_zero_attempts() {
        let config = RetryConfig {
            max_attempts: 0,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        assert_eq!(
            config.validate().unwrap_err(),
            "max_attempts must be >= 1".to_string()
        );
    }

    #[test]
    fn test_retry_config_deserialize_rejects_zero_attempts() {
        // Raw strings take backslashes literally, so the quotes must NOT be
        // escaped here. With `\"` the document is invalid JSON and parsing
        // fails for that reason alone, which would make this test pass even
        // if the zero check were removed.
        let document = |attempts: u32| {
            format!(
                r#"{{"max_attempts":{},"initial_backoff":100,"max_backoff":1000,"backoff_multiplier":2.0,"jitter":false}}"#,
                attempts
            )
        };

        let err = serde_json::from_str::<RetryConfig>(&document(0)).unwrap_err();
        assert!(
            err.to_string().contains("max_attempts must be >= 1"),
            "unexpected error: {}",
            err
        );

        // Control: the same document with a non-zero value must parse, which
        // proves the rejection above comes from the zero check.
        let parsed: RetryConfig = serde_json::from_str(&document(1)).unwrap();
        assert_eq!(parsed.max_attempts, 1);
    }

    #[test]
    fn test_retry_config_max_backoff() {
        let config = RetryConfig {
            max_attempts: 10,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        let backoff = config.calculate_backoff(10);
        assert!(backoff <= Duration::from_secs(1));
    }

    #[test]
    fn test_healing_config_default() {
        let config = HealingConfig::default();
        assert!(config.enabled);
        assert!(!config.strict_mode);
        assert!(config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.7);
        assert!(config.allow_fuzzy_matching);
    }

    #[test]
    fn test_healing_config_strict() {
        let config = HealingConfig::strict();
        assert!(config.enabled);
        assert!(config.strict_mode);
        assert!(!config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.95);
        assert!(!config.allow_fuzzy_matching);
    }

    #[test]
    fn test_healing_config_lenient() {
        let config = HealingConfig::lenient();
        assert!(config.enabled);
        assert!(!config.strict_mode);
        assert!(config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.5);
        assert!(config.allow_fuzzy_matching);
    }

    #[test]
    fn test_capabilities_default() {
        let caps = Capabilities::default();
        assert!(!caps.streaming);
        assert!(!caps.function_calling);
        assert!(!caps.vision);
        assert_eq!(caps.max_tokens, 0);
    }

    #[test]
    fn test_provider_config_builder() {
        let config = ProviderConfig::new("openai", "https://api.openai.com/v1")
            .with_api_key("sk-test")
            .with_default_model("gpt-4")
            .with_timeout(Duration::from_secs(60));

        assert_eq!(config.name, "openai");
        assert_eq!(config.base_url, "https://api.openai.com/v1");
        assert_eq!(config.api_key, Some("sk-test".to_string()));
        assert_eq!(config.default_model, Some("gpt-4".to_string()));
        assert_eq!(config.timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_config_serialization() {
        let config = RetryConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: RetryConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_provider_config_serialization() {
        let config = ProviderConfig::new("test", "https://example.com");
        let json = serde_json::to_string(&config).unwrap();
        let parsed: ProviderConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config.name, parsed.name);
        assert_eq!(config.base_url, parsed.base_url);
    }

    #[test]
    fn test_provider_config_serialization_redacts_api_key() {
        let config = ProviderConfig::new("test", "https://example.com").with_api_key("secret-key");
        let json = serde_json::to_string(&config).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(
            value.get("api_key"),
            Some(&serde_json::Value::String("<redacted>".to_string()))
        );
    }

    #[test]
    fn test_rate_limit_config_validate_enabled_requires_non_zero_values() {
        let invalid_rps = RateLimitConfig::new(0, 10);
        assert_eq!(
            invalid_rps.validate().unwrap_err(),
            "requests_per_second must be >= 1 when rate limiting is enabled"
        );

        let invalid_burst = RateLimitConfig::new(10, 0);
        assert_eq!(
            invalid_burst.validate().unwrap_err(),
            "burst_size must be >= 1 when rate limiting is enabled"
        );
    }

    #[test]
    fn test_rate_limit_config_validate_disabled_allows_zero_values() {
        let disabled = RateLimitConfig::disabled();
        assert!(disabled.validate().is_ok());
    }

    #[test]
    fn test_jitter_randomness() {
        let config = RetryConfig {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: true,
        };

        // Attempt 1 has a 200 ms base delay before jitter is applied.
        let backoffs: Vec<Duration> = (0..10).map(|_| config.calculate_backoff(1)).collect();

        for backoff in &backoffs {
            let ms = backoff.as_millis();
            assert!(ms >= 100, "Backoff too small: {}ms", ms);
            assert!(ms <= 300, "Backoff too large: {}ms", ms);
        }

        // Ten samples collapsing to one value would mean jitter is not random.
        let unique_count = backoffs
            .iter()
            .collect::<std::collections::HashSet<_>>()
            .len();
        assert!(
            unique_count > 1,
            "All jitter values are the same - RNG may not be working"
        );
    }
}