Skip to main content

simple_agent_type/
config.rs

1//! Configuration types for SimpleAgents.
2//!
3//! Provides configuration for retry, healing, and provider capabilities.
4
5use serde::{Deserialize, Serialize};
6use std::time::Duration;
7
8/// Retry configuration for failed requests.
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10pub struct RetryConfig {
11    /// Maximum number of retry attempts
12    #[serde(deserialize_with = "non_zero_u32")]
13    pub max_attempts: u32,
14    /// Initial backoff duration
15    #[serde(with = "duration_millis")]
16    pub initial_backoff: Duration,
17    /// Maximum backoff duration
18    #[serde(with = "duration_millis")]
19    pub max_backoff: Duration,
20    /// Backoff multiplier for exponential backoff
21    pub backoff_multiplier: f32,
22    /// Add random jitter to backoff
23    pub jitter: bool,
24}
25
26impl Default for RetryConfig {
27    fn default() -> Self {
28        Self {
29            max_attempts: 3,
30            initial_backoff: Duration::from_millis(100),
31            max_backoff: Duration::from_secs(10),
32            backoff_multiplier: 2.0,
33            jitter: true,
34        }
35    }
36}
37
38impl RetryConfig {
39    /// Validate retry configuration invariants.
40    pub fn validate(&self) -> Result<(), String> {
41        if self.max_attempts == 0 {
42            return Err("max_attempts must be >= 1".to_string());
43        }
44        Ok(())
45    }
46
47    /// Calculate backoff duration for a given attempt.
48    ///
49    /// # Example
50    /// ```
51    /// use simple_agent_type::config::RetryConfig;
52    /// use std::time::Duration;
53    ///
54    /// let config = RetryConfig::default();
55    /// let backoff = config.calculate_backoff(1);
56    /// assert!(backoff >= Duration::from_millis(100));
57    /// assert!(backoff <= Duration::from_millis(200)); // with jitter
58    /// ```
59    pub fn calculate_backoff(&self, attempt: u32) -> Duration {
60        let base =
61            self.initial_backoff.as_millis() as f32 * self.backoff_multiplier.powi(attempt as i32);
62        let capped = base.min(self.max_backoff.as_millis() as f32);
63
64        let duration = if self.jitter {
65            // Add up to 50% jitter
66            let jitter_factor = 0.5 + (rand() * 0.5);
67            Duration::from_millis((capped * jitter_factor) as u64)
68        } else {
69            Duration::from_millis(capped as u64)
70        };
71
72        duration.min(self.max_backoff)
73    }
74}
75
76fn non_zero_u32<'de, D>(deserializer: D) -> Result<u32, D::Error>
77where
78    D: serde::Deserializer<'de>,
79{
80    let value = u32::deserialize(deserializer)?;
81    if value == 0 {
82        return Err(serde::de::Error::custom("max_attempts must be >= 1"));
83    }
84    Ok(value)
85}
86
87// Cryptographically secure random number generator for jitter (0.0-1.0)
88fn rand() -> f32 {
89    use rand::Rng;
90    rand::thread_rng().gen()
91}
92
93/// Healing configuration for response coercion.
94#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
95pub struct HealingConfig {
96    /// Enable healing/coercion
97    pub enabled: bool,
98    /// Strict mode (fail on coercion)
99    pub strict_mode: bool,
100    /// Allow type coercion
101    pub allow_type_coercion: bool,
102    /// Minimum confidence threshold (0.0-1.0)
103    pub min_confidence: f32,
104    /// Allow fuzzy field name matching
105    pub allow_fuzzy_matching: bool,
106    /// Maximum healing attempts
107    pub max_attempts: u32,
108}
109
110impl Default for HealingConfig {
111    fn default() -> Self {
112        Self {
113            enabled: true,
114            strict_mode: false,
115            allow_type_coercion: true,
116            min_confidence: 0.7,
117            allow_fuzzy_matching: true,
118            max_attempts: 3,
119        }
120    }
121}
122
123impl HealingConfig {
124    /// Create a strict healing configuration.
125    pub fn strict() -> Self {
126        Self {
127            enabled: true,
128            strict_mode: true,
129            allow_type_coercion: false,
130            min_confidence: 0.95,
131            allow_fuzzy_matching: false,
132            max_attempts: 1,
133        }
134    }
135
136    /// Create a lenient healing configuration.
137    pub fn lenient() -> Self {
138        Self {
139            enabled: true,
140            strict_mode: false,
141            allow_type_coercion: true,
142            min_confidence: 0.5,
143            allow_fuzzy_matching: true,
144            max_attempts: 5,
145        }
146    }
147}
148
149/// Provider capabilities.
150#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
151pub struct Capabilities {
152    /// Supports streaming responses
153    pub streaming: bool,
154    /// Supports function/tool calling
155    pub function_calling: bool,
156    /// Supports vision/image inputs
157    pub vision: bool,
158    /// Maximum output tokens
159    pub max_tokens: u32,
160}
161
162/// Provider configuration.
163#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
164pub struct ProviderConfig {
165    /// Provider name
166    pub name: String,
167    /// Base URL for API
168    pub base_url: String,
169    /// API key (optional for some providers)
170    #[serde(
171        skip_serializing_if = "Option::is_none",
172        serialize_with = "serialize_optional_secret"
173    )]
174    pub api_key: Option<String>,
175    /// Default model
176    #[serde(skip_serializing_if = "Option::is_none")]
177    pub default_model: Option<String>,
178    /// Retry configuration
179    #[serde(default)]
180    pub retry_config: RetryConfig,
181    /// Request timeout
182    #[serde(with = "duration_millis")]
183    pub timeout: Duration,
184    /// Provider capabilities
185    #[serde(default)]
186    pub capabilities: Capabilities,
187}
188
189fn serialize_optional_secret<S>(value: &Option<String>, serializer: S) -> Result<S::Ok, S::Error>
190where
191    S: serde::Serializer,
192{
193    match value {
194        Some(_) => serializer.serialize_some("<redacted>"),
195        None => serializer.serialize_none(),
196    }
197}
198
199impl ProviderConfig {
200    /// Create a new provider configuration.
201    pub fn new(name: impl Into<String>, base_url: impl Into<String>) -> Self {
202        Self {
203            name: name.into(),
204            base_url: base_url.into(),
205            api_key: None,
206            default_model: None,
207            retry_config: RetryConfig::default(),
208            timeout: Duration::from_secs(30),
209            capabilities: Capabilities::default(),
210        }
211    }
212
213    /// Set the API key.
214    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
215        self.api_key = Some(api_key.into());
216        self
217    }
218
219    /// Set the default model.
220    pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
221        self.default_model = Some(model.into());
222        self
223    }
224
225    /// Set the timeout.
226    pub fn with_timeout(mut self, timeout: Duration) -> Self {
227        self.timeout = timeout;
228        self
229    }
230}
231
232/// Rate limiting scope.
233#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
234pub enum RateLimitScope {
235    /// Per provider instance (isolated rate limits)
236    #[default]
237    PerInstance,
238    /// Shared across all instances with the same API key
239    Shared,
240}
241
242/// Rate limiting configuration.
243#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
244pub struct RateLimitConfig {
245    /// Enable rate limiting
246    pub enabled: bool,
247    /// Maximum requests per second
248    pub requests_per_second: u32,
249    /// Burst size (maximum concurrent requests)
250    pub burst_size: u32,
251    /// Rate limit scope
252    #[serde(default)]
253    pub scope: RateLimitScope,
254}
255
256impl Default for RateLimitConfig {
257    fn default() -> Self {
258        Self {
259            enabled: false,
260            requests_per_second: 10,
261            burst_size: 20,
262            scope: RateLimitScope::PerInstance,
263        }
264    }
265}
266
267impl RateLimitConfig {
268    /// Validate rate-limit configuration invariants.
269    pub fn validate(&self) -> Result<(), String> {
270        if !self.enabled {
271            return Ok(());
272        }
273
274        if self.requests_per_second == 0 {
275            return Err(
276                "requests_per_second must be >= 1 when rate limiting is enabled".to_string(),
277            );
278        }
279
280        if self.burst_size == 0 {
281            return Err("burst_size must be >= 1 when rate limiting is enabled".to_string());
282        }
283
284        Ok(())
285    }
286
287    /// Create a new rate limit configuration with given requests per second.
288    ///
289    /// # Example
290    /// ```
291    /// use simple_agent_type::config::RateLimitConfig;
292    ///
293    /// let config = RateLimitConfig::new(50, 100);
294    /// assert_eq!(config.requests_per_second, 50);
295    /// assert_eq!(config.burst_size, 100);
296    /// assert!(config.enabled);
297    /// ```
298    pub fn new(requests_per_second: u32, burst_size: u32) -> Self {
299        Self {
300            enabled: true,
301            requests_per_second,
302            burst_size,
303            scope: RateLimitScope::PerInstance,
304        }
305    }
306
307    /// Create rate limit config with shared scope.
308    pub fn shared(requests_per_second: u32, burst_size: u32) -> Self {
309        Self {
310            enabled: true,
311            requests_per_second,
312            burst_size,
313            scope: RateLimitScope::Shared,
314        }
315    }
316
317    /// Disable rate limiting.
318    pub fn disabled() -> Self {
319        Self {
320            enabled: false,
321            requests_per_second: 0,
322            burst_size: 0,
323            scope: RateLimitScope::PerInstance,
324        }
325    }
326}
327
328// Serde helper for Duration serialization/deserialization as milliseconds
329mod duration_millis {
330    use serde::{Deserialize, Deserializer, Serializer};
331    use std::time::Duration;
332
333    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
334    where
335        S: Serializer,
336    {
337        serializer.serialize_u64(duration.as_millis() as u64)
338    }
339
340    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
341    where
342        D: Deserializer<'de>,
343    {
344        let millis = u64::deserialize(deserializer)?;
345        Ok(Duration::from_millis(millis))
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.max_attempts, 3);
        assert_eq!(config.initial_backoff, Duration::from_millis(100));
        assert_eq!(config.max_backoff, Duration::from_secs(10));
        assert_eq!(config.backoff_multiplier, 2.0);
        assert!(config.jitter);
    }

    #[test]
    fn test_retry_config_backoff() {
        let config = RetryConfig {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        let backoff1 = config.calculate_backoff(0);
        let backoff2 = config.calculate_backoff(1);
        let backoff3 = config.calculate_backoff(2);

        assert_eq!(backoff1, Duration::from_millis(100));
        assert_eq!(backoff2, Duration::from_millis(200));
        assert_eq!(backoff3, Duration::from_millis(400));
    }

    #[test]
    fn test_retry_config_validate_rejects_zero_attempts() {
        let config = RetryConfig {
            max_attempts: 0,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        assert_eq!(
            config.validate().unwrap_err(),
            "max_attempts must be >= 1".to_string()
        );
    }

    #[test]
    fn test_retry_config_deserialize_rejects_zero_attempts() {
        // Inside `r#"..."#` a `\"` is a literal backslash followed by a quote, which makes the
        // JSON itself invalid. Use plain quotes so the failure comes from `non_zero_u32`.
        let valid = r#"{"max_attempts":1,"initial_backoff":100,"max_backoff":1000,"backoff_multiplier":2.0,"jitter":false}"#;
        let parsed: RetryConfig =
            serde_json::from_str(valid).expect("well-formed config with max_attempts=1 parses");
        assert_eq!(parsed.max_attempts, 1);

        let zero = r#"{"max_attempts":0,"initial_backoff":100,"max_backoff":1000,"backoff_multiplier":2.0,"jitter":false}"#;
        let err = serde_json::from_str::<RetryConfig>(zero).unwrap_err();
        assert!(
            err.to_string().contains("max_attempts must be >= 1"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_retry_config_max_backoff() {
        let config = RetryConfig {
            max_attempts: 10,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        // 100ms * 2^10 far exceeds the cap, so without jitter the result is exactly the cap.
        let backoff = config.calculate_backoff(10);
        assert_eq!(backoff, Duration::from_secs(1));
    }

    #[test]
    fn test_healing_config_default() {
        let config = HealingConfig::default();
        assert!(config.enabled);
        assert!(!config.strict_mode);
        assert!(config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.7);
        assert!(config.allow_fuzzy_matching);
    }

    #[test]
    fn test_healing_config_strict() {
        let config = HealingConfig::strict();
        assert!(config.enabled);
        assert!(config.strict_mode);
        assert!(!config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.95);
        assert!(!config.allow_fuzzy_matching);
    }

    #[test]
    fn test_healing_config_lenient() {
        let config = HealingConfig::lenient();
        assert!(config.enabled);
        assert!(!config.strict_mode);
        assert!(config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.5);
        assert!(config.allow_fuzzy_matching);
    }

    #[test]
    fn test_capabilities_default() {
        let caps = Capabilities::default();
        assert!(!caps.streaming);
        assert!(!caps.function_calling);
        assert!(!caps.vision);
        assert_eq!(caps.max_tokens, 0);
    }

    #[test]
    fn test_provider_config_builder() {
        let config = ProviderConfig::new("openai", "https://api.openai.com/v1")
            .with_api_key("sk-test")
            .with_default_model("gpt-4")
            .with_timeout(Duration::from_secs(60));

        assert_eq!(config.name, "openai");
        assert_eq!(config.base_url, "https://api.openai.com/v1");
        assert_eq!(config.api_key, Some("sk-test".to_string()));
        assert_eq!(config.default_model, Some("gpt-4".to_string()));
        assert_eq!(config.timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_config_serialization() {
        let config = RetryConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: RetryConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_provider_config_serialization() {
        // Without an API key nothing is redacted, so the round trip must be lossless.
        let config = ProviderConfig::new("test", "https://example.com");
        let json = serde_json::to_string(&config).unwrap();
        let parsed: ProviderConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_provider_config_serialization_redacts_api_key() {
        let config = ProviderConfig::new("test", "https://example.com").with_api_key("secret-key");
        let json = serde_json::to_string(&config).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(
            value.get("api_key"),
            Some(&serde_json::Value::String("<redacted>".to_string()))
        );
    }

    #[test]
    fn test_rate_limit_config_validate_enabled_requires_non_zero_values() {
        let invalid_rps = RateLimitConfig::new(0, 10);
        assert_eq!(
            invalid_rps.validate().unwrap_err(),
            "requests_per_second must be >= 1 when rate limiting is enabled"
        );

        let invalid_burst = RateLimitConfig::new(10, 0);
        assert_eq!(
            invalid_burst.validate().unwrap_err(),
            "burst_size must be >= 1 when rate limiting is enabled"
        );
    }

    #[test]
    fn test_rate_limit_config_validate_disabled_allows_zero_values() {
        let disabled = RateLimitConfig::disabled();
        assert!(disabled.validate().is_ok());
    }

    #[test]
    fn test_jitter_randomness() {
        let config = RetryConfig {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: true,
        };

        // Generate multiple backoffs and verify they're different (with high probability).
        let backoffs: Vec<Duration> = (0..10).map(|_| config.calculate_backoff(1)).collect();

        // Attempt 1 has a 200ms base; jitter scales it by a factor in [0.5, 1.0),
        // so every value must fall within 100..=200ms.
        for backoff in &backoffs {
            let ms = backoff.as_millis();
            assert!(ms >= 100, "Backoff too small: {}ms", ms); // 50% of 200ms
            assert!(ms <= 200, "Backoff too large: {}ms", ms); // jitter never lengthens
        }

        // At least some values should differ (very high probability with a real RNG).
        let unique_count = backoffs
            .iter()
            .collect::<std::collections::HashSet<_>>()
            .len();
        assert!(
            unique_count > 1,
            "All jitter values are the same - RNG may not be working"
        );
    }
}