Skip to main content

simple_agent_type/
config.rs

1//! Configuration types for SimpleAgents.
2//!
3//! Provides configuration for retry, healing, and provider capabilities.
4
5use serde::{Deserialize, Serialize};
6use std::time::Duration;
7
8/// Retry configuration for failed requests.
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10pub struct RetryConfig {
11    /// Maximum number of retry attempts
12    #[serde(deserialize_with = "non_zero_u32")]
13    pub max_attempts: u32,
14    /// Initial backoff duration
15    #[serde(with = "duration_millis")]
16    pub initial_backoff: Duration,
17    /// Maximum backoff duration
18    #[serde(with = "duration_millis")]
19    pub max_backoff: Duration,
20    /// Backoff multiplier for exponential backoff
21    pub backoff_multiplier: f32,
22    /// Add random jitter to backoff
23    pub jitter: bool,
24}
25
26impl Default for RetryConfig {
27    fn default() -> Self {
28        Self {
29            max_attempts: 3,
30            initial_backoff: Duration::from_millis(100),
31            max_backoff: Duration::from_secs(10),
32            backoff_multiplier: 2.0,
33            jitter: true,
34        }
35    }
36}
37
38impl RetryConfig {
39    /// Validate retry configuration invariants.
40    pub fn validate(&self) -> Result<(), String> {
41        if self.max_attempts == 0 {
42            return Err("max_attempts must be >= 1".to_string());
43        }
44        Ok(())
45    }
46
47    /// Calculate backoff duration for a given attempt.
48    ///
49    /// # Example
50    /// ```
51    /// use simple_agent_type::config::RetryConfig;
52    /// use std::time::Duration;
53    ///
54    /// let config = RetryConfig::default();
55    /// let backoff = config.calculate_backoff(1);
56    /// assert!(backoff >= Duration::from_millis(100));
57    /// assert!(backoff <= Duration::from_millis(200)); // with jitter
58    /// ```
59    pub fn calculate_backoff(&self, attempt: u32) -> Duration {
60        let base =
61            self.initial_backoff.as_millis() as f32 * self.backoff_multiplier.powi(attempt as i32);
62        let capped = base.min(self.max_backoff.as_millis() as f32);
63
64        let duration = if self.jitter {
65            // Add up to 50% jitter
66            let jitter_factor = 0.5 + (rand() * 0.5);
67            Duration::from_millis((capped * jitter_factor) as u64)
68        } else {
69            Duration::from_millis(capped as u64)
70        };
71
72        duration.min(self.max_backoff)
73    }
74}
75
76fn non_zero_u32<'de, D>(deserializer: D) -> Result<u32, D::Error>
77where
78    D: serde::Deserializer<'de>,
79{
80    let value = u32::deserialize(deserializer)?;
81    if value == 0 {
82        return Err(serde::de::Error::custom("max_attempts must be >= 1"));
83    }
84    Ok(value)
85}
86
87// Cryptographically secure random number generator for jitter (0.0-1.0)
88fn rand() -> f32 {
89    use rand::Rng;
90    rand::thread_rng().gen()
91}
92
93/// Healing configuration for response coercion.
94#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
95pub struct HealingConfig {
96    /// Enable healing/coercion
97    pub enabled: bool,
98    /// Strict mode (fail on coercion)
99    pub strict_mode: bool,
100    /// Allow type coercion
101    pub allow_type_coercion: bool,
102    /// Minimum confidence threshold (0.0-1.0)
103    pub min_confidence: f32,
104    /// Allow fuzzy field name matching
105    pub allow_fuzzy_matching: bool,
106    /// Maximum healing attempts
107    pub max_attempts: u32,
108}
109
110impl Default for HealingConfig {
111    fn default() -> Self {
112        Self {
113            enabled: true,
114            strict_mode: false,
115            allow_type_coercion: true,
116            min_confidence: 0.7,
117            allow_fuzzy_matching: true,
118            max_attempts: 3,
119        }
120    }
121}
122
123impl HealingConfig {
124    /// Create a strict healing configuration.
125    pub fn strict() -> Self {
126        Self {
127            enabled: true,
128            strict_mode: true,
129            allow_type_coercion: false,
130            min_confidence: 0.95,
131            allow_fuzzy_matching: false,
132            max_attempts: 1,
133        }
134    }
135
136    /// Create a lenient healing configuration.
137    pub fn lenient() -> Self {
138        Self {
139            enabled: true,
140            strict_mode: false,
141            allow_type_coercion: true,
142            min_confidence: 0.5,
143            allow_fuzzy_matching: true,
144            max_attempts: 5,
145        }
146    }
147}
148
149/// Provider capabilities.
150#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
151pub struct Capabilities {
152    /// Supports streaming responses
153    pub streaming: bool,
154    /// Supports function/tool calling
155    pub function_calling: bool,
156    /// Supports vision/image inputs
157    pub vision: bool,
158    /// Maximum output tokens
159    pub max_tokens: u32,
160}
161
162/// Provider configuration.
163#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
164pub struct ProviderConfig {
165    /// Provider name
166    pub name: String,
167    /// Base URL for API
168    pub base_url: String,
169    /// API key (optional for some providers)
170    #[serde(skip_serializing_if = "Option::is_none")]
171    pub api_key: Option<String>,
172    /// Default model
173    #[serde(skip_serializing_if = "Option::is_none")]
174    pub default_model: Option<String>,
175    /// Retry configuration
176    #[serde(default)]
177    pub retry_config: RetryConfig,
178    /// Request timeout
179    #[serde(with = "duration_millis")]
180    pub timeout: Duration,
181    /// Provider capabilities
182    #[serde(default)]
183    pub capabilities: Capabilities,
184}
185
186impl ProviderConfig {
187    /// Create a new provider configuration.
188    pub fn new(name: impl Into<String>, base_url: impl Into<String>) -> Self {
189        Self {
190            name: name.into(),
191            base_url: base_url.into(),
192            api_key: None,
193            default_model: None,
194            retry_config: RetryConfig::default(),
195            timeout: Duration::from_secs(30),
196            capabilities: Capabilities::default(),
197        }
198    }
199
200    /// Set the API key.
201    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
202        self.api_key = Some(api_key.into());
203        self
204    }
205
206    /// Set the default model.
207    pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
208        self.default_model = Some(model.into());
209        self
210    }
211
212    /// Set the timeout.
213    pub fn with_timeout(mut self, timeout: Duration) -> Self {
214        self.timeout = timeout;
215        self
216    }
217}
218
219/// Rate limiting scope.
220#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
221pub enum RateLimitScope {
222    /// Per provider instance (isolated rate limits)
223    #[default]
224    PerInstance,
225    /// Shared across all instances with the same API key
226    Shared,
227}
228
229/// Rate limiting configuration.
230#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
231pub struct RateLimitConfig {
232    /// Enable rate limiting
233    pub enabled: bool,
234    /// Maximum requests per second
235    pub requests_per_second: u32,
236    /// Burst size (maximum concurrent requests)
237    pub burst_size: u32,
238    /// Rate limit scope
239    #[serde(default)]
240    pub scope: RateLimitScope,
241}
242
243impl Default for RateLimitConfig {
244    fn default() -> Self {
245        Self {
246            enabled: false,
247            requests_per_second: 10,
248            burst_size: 20,
249            scope: RateLimitScope::PerInstance,
250        }
251    }
252}
253
254impl RateLimitConfig {
255    /// Validate rate-limit configuration invariants.
256    pub fn validate(&self) -> Result<(), String> {
257        if !self.enabled {
258            return Ok(());
259        }
260
261        if self.requests_per_second == 0 {
262            return Err(
263                "requests_per_second must be >= 1 when rate limiting is enabled".to_string(),
264            );
265        }
266
267        if self.burst_size == 0 {
268            return Err("burst_size must be >= 1 when rate limiting is enabled".to_string());
269        }
270
271        Ok(())
272    }
273
274    /// Create a new rate limit configuration with given requests per second.
275    ///
276    /// # Example
277    /// ```
278    /// use simple_agent_type::config::RateLimitConfig;
279    ///
280    /// let config = RateLimitConfig::new(50, 100);
281    /// assert_eq!(config.requests_per_second, 50);
282    /// assert_eq!(config.burst_size, 100);
283    /// assert!(config.enabled);
284    /// ```
285    pub fn new(requests_per_second: u32, burst_size: u32) -> Self {
286        Self {
287            enabled: true,
288            requests_per_second,
289            burst_size,
290            scope: RateLimitScope::PerInstance,
291        }
292    }
293
294    /// Create rate limit config with shared scope.
295    pub fn shared(requests_per_second: u32, burst_size: u32) -> Self {
296        Self {
297            enabled: true,
298            requests_per_second,
299            burst_size,
300            scope: RateLimitScope::Shared,
301        }
302    }
303
304    /// Disable rate limiting.
305    pub fn disabled() -> Self {
306        Self {
307            enabled: false,
308            requests_per_second: 0,
309            burst_size: 0,
310            scope: RateLimitScope::PerInstance,
311        }
312    }
313}
314
315// Serde helper for Duration serialization/deserialization as milliseconds
316mod duration_millis {
317    use serde::{Deserialize, Deserializer, Serializer};
318    use std::time::Duration;
319
320    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
321    where
322        S: Serializer,
323    {
324        serializer.serialize_u64(duration.as_millis() as u64)
325    }
326
327    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
328    where
329        D: Deserializer<'de>,
330    {
331        let millis = u64::deserialize(deserializer)?;
332        Ok(Duration::from_millis(millis))
333    }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.max_attempts, 3);
        assert_eq!(config.initial_backoff, Duration::from_millis(100));
        assert_eq!(config.max_backoff, Duration::from_secs(10));
        assert_eq!(config.backoff_multiplier, 2.0);
        assert!(config.jitter);
    }

    #[test]
    fn test_retry_config_backoff() {
        let config = RetryConfig {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        let backoff1 = config.calculate_backoff(0);
        let backoff2 = config.calculate_backoff(1);
        let backoff3 = config.calculate_backoff(2);

        assert_eq!(backoff1, Duration::from_millis(100));
        assert_eq!(backoff2, Duration::from_millis(200));
        assert_eq!(backoff3, Duration::from_millis(400));
    }

    #[test]
    fn test_retry_config_validate_rejects_zero_attempts() {
        let config = RetryConfig {
            max_attempts: 0,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        assert_eq!(
            config.validate().unwrap_err(),
            "max_attempts must be >= 1".to_string()
        );
    }

    #[test]
    fn test_retry_config_deserialize_rejects_zero_attempts() {
        // Plain quotes: inside a raw string `\"` is a literal backslash, which made the
        // JSON itself malformed and let this test pass for the wrong reason.
        let json = r#"{"max_attempts":0,"initial_backoff":100,"max_backoff":1000,"backoff_multiplier":2.0,"jitter":false}"#;
        let err = serde_json::from_str::<RetryConfig>(json).unwrap_err();
        assert!(
            err.to_string().contains("max_attempts must be >= 1"),
            "unexpected error: {}",
            err
        );

        // Positive control: the same document with a non-zero value parses.
        let valid = json.replace("\"max_attempts\":0", "\"max_attempts\":1");
        let parsed: RetryConfig = serde_json::from_str(&valid).unwrap();
        assert_eq!(parsed.max_attempts, 1);
    }

    #[test]
    fn test_retry_config_max_backoff() {
        let config = RetryConfig {
            max_attempts: 10,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        // 100ms * 2^10 = 102.4s, which must be capped at exactly max_backoff.
        let backoff = config.calculate_backoff(10);
        assert_eq!(backoff, Duration::from_secs(1));
    }

    #[test]
    fn test_healing_config_default() {
        let config = HealingConfig::default();
        assert!(config.enabled);
        assert!(!config.strict_mode);
        assert!(config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.7);
        assert!(config.allow_fuzzy_matching);
        assert_eq!(config.max_attempts, 3);
    }

    #[test]
    fn test_healing_config_strict() {
        let config = HealingConfig::strict();
        assert!(config.enabled);
        assert!(config.strict_mode);
        assert!(!config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.95);
        assert!(!config.allow_fuzzy_matching);
        assert_eq!(config.max_attempts, 1);
    }

    #[test]
    fn test_healing_config_lenient() {
        let config = HealingConfig::lenient();
        assert!(config.enabled);
        assert!(!config.strict_mode);
        assert!(config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.5);
        assert!(config.allow_fuzzy_matching);
        assert_eq!(config.max_attempts, 5);
    }

    #[test]
    fn test_capabilities_default() {
        let caps = Capabilities::default();
        assert!(!caps.streaming);
        assert!(!caps.function_calling);
        assert!(!caps.vision);
        assert_eq!(caps.max_tokens, 0);
    }

    #[test]
    fn test_provider_config_builder() {
        let config = ProviderConfig::new("openai", "https://api.openai.com/v1")
            .with_api_key("sk-test")
            .with_default_model("gpt-4")
            .with_timeout(Duration::from_secs(60));

        assert_eq!(config.name, "openai");
        assert_eq!(config.base_url, "https://api.openai.com/v1");
        assert_eq!(config.api_key, Some("sk-test".to_string()));
        assert_eq!(config.default_model, Some("gpt-4".to_string()));
        assert_eq!(config.timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_config_serialization() {
        let config = RetryConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: RetryConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_provider_config_serialization() {
        let config = ProviderConfig::new("test", "https://example.com");
        let json = serde_json::to_string(&config).unwrap();
        let parsed: ProviderConfig = serde_json::from_str(&json).unwrap();
        // Full round-trip: covers the millisecond timeout and the nested retry config too.
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_rate_limit_config_validate_enabled_requires_non_zero_values() {
        let invalid_rps = RateLimitConfig::new(0, 10);
        assert_eq!(
            invalid_rps.validate().unwrap_err(),
            "requests_per_second must be >= 1 when rate limiting is enabled"
        );

        let invalid_burst = RateLimitConfig::new(10, 0);
        assert_eq!(
            invalid_burst.validate().unwrap_err(),
            "burst_size must be >= 1 when rate limiting is enabled"
        );
    }

    #[test]
    fn test_rate_limit_config_validate_disabled_allows_zero_values() {
        let disabled = RateLimitConfig::disabled();
        assert!(disabled.validate().is_ok());
    }

    #[test]
    fn test_jitter_randomness() {
        let config = RetryConfig {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: true,
        };

        // Generate multiple backoffs and verify they're different (with high probability)
        let backoffs: Vec<Duration> = (0..10).map(|_| config.calculate_backoff(1)).collect();

        // Attempt 1 has a 200ms base; jitter scales it by a factor in [0.5, 1.0),
        // so every value must land in 100..=200ms.
        for backoff in &backoffs {
            let ms = backoff.as_millis();
            assert!(ms >= 100, "Backoff too small: {}ms", ms); // 50% of 200ms = 100ms
            assert!(ms <= 200, "Backoff too large: {}ms", ms); // jitter never lengthens
        }

        // At least some values should be different (very high probability with true randomness)
        let unique_count = backoffs
            .iter()
            .collect::<std::collections::HashSet<_>>()
            .len();
        assert!(
            unique_count > 1,
            "All jitter values are the same - RNG may not be working"
        );
    }
}