Skip to main content

simple_agent_type/
config.rs

//! Configuration types for SimpleAgents.
//!
//! Provides configuration for retry, healing, rate limiting, and provider settings.
4
5use serde::{Deserialize, Serialize};
6use std::time::Duration;
7
8/// Retry configuration for failed requests.
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10pub struct RetryConfig {
11    /// Maximum number of retry attempts
12    pub max_attempts: u32,
13    /// Initial backoff duration
14    #[serde(with = "duration_millis")]
15    pub initial_backoff: Duration,
16    /// Maximum backoff duration
17    #[serde(with = "duration_millis")]
18    pub max_backoff: Duration,
19    /// Backoff multiplier for exponential backoff
20    pub backoff_multiplier: f32,
21    /// Add random jitter to backoff
22    pub jitter: bool,
23}
24
25impl Default for RetryConfig {
26    fn default() -> Self {
27        Self {
28            max_attempts: 3,
29            initial_backoff: Duration::from_millis(100),
30            max_backoff: Duration::from_secs(10),
31            backoff_multiplier: 2.0,
32            jitter: true,
33        }
34    }
35}
36
37impl RetryConfig {
38    /// Calculate backoff duration for a given attempt.
39    ///
40    /// # Example
41    /// ```
42    /// use simple_agent_type::config::RetryConfig;
43    /// use std::time::Duration;
44    ///
45    /// let config = RetryConfig::default();
46    /// let backoff = config.calculate_backoff(1);
47    /// assert!(backoff >= Duration::from_millis(100));
48    /// assert!(backoff <= Duration::from_millis(200)); // with jitter
49    /// ```
50    pub fn calculate_backoff(&self, attempt: u32) -> Duration {
51        let base =
52            self.initial_backoff.as_millis() as f32 * self.backoff_multiplier.powi(attempt as i32);
53        let capped = base.min(self.max_backoff.as_millis() as f32);
54
55        let duration = if self.jitter {
56            // Add up to 50% jitter
57            let jitter_factor = 0.5 + (rand() * 0.5);
58            Duration::from_millis((capped * jitter_factor) as u64)
59        } else {
60            Duration::from_millis(capped as u64)
61        };
62
63        duration.min(self.max_backoff)
64    }
65}
66
67// Cryptographically secure random number generator for jitter (0.0-1.0)
68fn rand() -> f32 {
69    use rand::Rng;
70    rand::thread_rng().gen()
71}
72
73/// Healing configuration for response coercion.
74#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
75pub struct HealingConfig {
76    /// Enable healing/coercion
77    pub enabled: bool,
78    /// Strict mode (fail on coercion)
79    pub strict_mode: bool,
80    /// Allow type coercion
81    pub allow_type_coercion: bool,
82    /// Minimum confidence threshold (0.0-1.0)
83    pub min_confidence: f32,
84    /// Allow fuzzy field name matching
85    pub allow_fuzzy_matching: bool,
86    /// Maximum healing attempts
87    pub max_attempts: u32,
88}
89
90impl Default for HealingConfig {
91    fn default() -> Self {
92        Self {
93            enabled: true,
94            strict_mode: false,
95            allow_type_coercion: true,
96            min_confidence: 0.7,
97            allow_fuzzy_matching: true,
98            max_attempts: 3,
99        }
100    }
101}
102
103impl HealingConfig {
104    /// Create a strict healing configuration.
105    pub fn strict() -> Self {
106        Self {
107            enabled: true,
108            strict_mode: true,
109            allow_type_coercion: false,
110            min_confidence: 0.95,
111            allow_fuzzy_matching: false,
112            max_attempts: 1,
113        }
114    }
115
116    /// Create a lenient healing configuration.
117    pub fn lenient() -> Self {
118        Self {
119            enabled: true,
120            strict_mode: false,
121            allow_type_coercion: true,
122            min_confidence: 0.5,
123            allow_fuzzy_matching: true,
124            max_attempts: 5,
125        }
126    }
127}
128
129/// Provider capabilities.
130#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
131pub struct Capabilities {
132    /// Supports streaming responses
133    pub streaming: bool,
134    /// Supports function/tool calling
135    pub function_calling: bool,
136    /// Supports vision/image inputs
137    pub vision: bool,
138    /// Maximum output tokens
139    pub max_tokens: u32,
140}
141
142/// Provider configuration.
143#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
144pub struct ProviderConfig {
145    /// Provider name
146    pub name: String,
147    /// Base URL for API
148    pub base_url: String,
149    /// API key (optional for some providers)
150    #[serde(skip_serializing_if = "Option::is_none")]
151    pub api_key: Option<String>,
152    /// Default model
153    #[serde(skip_serializing_if = "Option::is_none")]
154    pub default_model: Option<String>,
155    /// Retry configuration
156    #[serde(default)]
157    pub retry_config: RetryConfig,
158    /// Request timeout
159    #[serde(with = "duration_millis")]
160    pub timeout: Duration,
161    /// Provider capabilities
162    #[serde(default)]
163    pub capabilities: Capabilities,
164}
165
166impl ProviderConfig {
167    /// Create a new provider configuration.
168    pub fn new(name: impl Into<String>, base_url: impl Into<String>) -> Self {
169        Self {
170            name: name.into(),
171            base_url: base_url.into(),
172            api_key: None,
173            default_model: None,
174            retry_config: RetryConfig::default(),
175            timeout: Duration::from_secs(30),
176            capabilities: Capabilities::default(),
177        }
178    }
179
180    /// Set the API key.
181    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
182        self.api_key = Some(api_key.into());
183        self
184    }
185
186    /// Set the default model.
187    pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
188        self.default_model = Some(model.into());
189        self
190    }
191
192    /// Set the timeout.
193    pub fn with_timeout(mut self, timeout: Duration) -> Self {
194        self.timeout = timeout;
195        self
196    }
197}
198
199/// Rate limiting scope.
200#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
201pub enum RateLimitScope {
202    /// Per provider instance (isolated rate limits)
203    #[default]
204    PerInstance,
205    /// Shared across all instances with the same API key
206    Shared,
207}
208
209/// Rate limiting configuration.
210#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
211pub struct RateLimitConfig {
212    /// Enable rate limiting
213    pub enabled: bool,
214    /// Maximum requests per second
215    pub requests_per_second: u32,
216    /// Burst size (maximum concurrent requests)
217    pub burst_size: u32,
218    /// Rate limit scope
219    #[serde(default)]
220    pub scope: RateLimitScope,
221}
222
223impl Default for RateLimitConfig {
224    fn default() -> Self {
225        Self {
226            enabled: false,
227            requests_per_second: 10,
228            burst_size: 20,
229            scope: RateLimitScope::PerInstance,
230        }
231    }
232}
233
234impl RateLimitConfig {
235    /// Create a new rate limit configuration with given requests per second.
236    ///
237    /// # Example
238    /// ```
239    /// use simple_agent_type::config::RateLimitConfig;
240    ///
241    /// let config = RateLimitConfig::new(50, 100);
242    /// assert_eq!(config.requests_per_second, 50);
243    /// assert_eq!(config.burst_size, 100);
244    /// assert!(config.enabled);
245    /// ```
246    pub fn new(requests_per_second: u32, burst_size: u32) -> Self {
247        Self {
248            enabled: true,
249            requests_per_second,
250            burst_size,
251            scope: RateLimitScope::PerInstance,
252        }
253    }
254
255    /// Create rate limit config with shared scope.
256    pub fn shared(requests_per_second: u32, burst_size: u32) -> Self {
257        Self {
258            enabled: true,
259            requests_per_second,
260            burst_size,
261            scope: RateLimitScope::Shared,
262        }
263    }
264
265    /// Disable rate limiting.
266    pub fn disabled() -> Self {
267        Self {
268            enabled: false,
269            requests_per_second: 0,
270            burst_size: 0,
271            scope: RateLimitScope::PerInstance,
272        }
273    }
274}
275
276// Serde helper for Duration serialization/deserialization as milliseconds
277mod duration_millis {
278    use serde::{Deserialize, Deserializer, Serializer};
279    use std::time::Duration;
280
281    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
282    where
283        S: Serializer,
284    {
285        serializer.serialize_u64(duration.as_millis() as u64)
286    }
287
288    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
289    where
290        D: Deserializer<'de>,
291    {
292        let millis = u64::deserialize(deserializer)?;
293        Ok(Duration::from_millis(millis))
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// The default retry policy carries the documented stock values.
    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_attempts, 3);
        assert_eq!(config.initial_backoff, Duration::from_millis(100));
        assert_eq!(config.max_backoff, Duration::from_secs(10));
        assert_eq!(config.backoff_multiplier, 2.0);
        assert!(config.jitter);
    }

    /// Without jitter the backoff doubles with each (zero-based) attempt.
    #[test]
    fn test_retry_config_backoff() {
        let config = RetryConfig {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        let backoff1 = config.calculate_backoff(0);
        let backoff2 = config.calculate_backoff(1);
        let backoff3 = config.calculate_backoff(2);

        assert_eq!(backoff1, Duration::from_millis(100));
        assert_eq!(backoff2, Duration::from_millis(200));
        assert_eq!(backoff3, Duration::from_millis(400));
    }

    /// The computed backoff never exceeds `max_backoff`.
    #[test]
    fn test_retry_config_max_backoff() {
        let config = RetryConfig {
            max_attempts: 10,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        let backoff = config.calculate_backoff(10);
        assert!(backoff <= Duration::from_secs(1));
    }

    #[test]
    fn test_healing_config_default() {
        let config = HealingConfig::default();
        assert!(config.enabled);
        assert!(!config.strict_mode);
        assert!(config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.7);
        assert!(config.allow_fuzzy_matching);
        assert_eq!(config.max_attempts, 3);
    }

    #[test]
    fn test_healing_config_strict() {
        let config = HealingConfig::strict();
        assert!(config.enabled);
        assert!(config.strict_mode);
        assert!(!config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.95);
        assert!(!config.allow_fuzzy_matching);
        assert_eq!(config.max_attempts, 1);
    }

    #[test]
    fn test_healing_config_lenient() {
        let config = HealingConfig::lenient();
        assert!(config.enabled);
        assert!(!config.strict_mode);
        assert!(config.allow_type_coercion);
        assert_eq!(config.min_confidence, 0.5);
        assert!(config.allow_fuzzy_matching);
        assert_eq!(config.max_attempts, 5);
    }

    #[test]
    fn test_capabilities_default() {
        let caps = Capabilities::default();
        assert!(!caps.streaming);
        assert!(!caps.function_calling);
        assert!(!caps.vision);
        assert_eq!(caps.max_tokens, 0);
    }

    #[test]
    fn test_provider_config_builder() {
        let config = ProviderConfig::new("openai", "https://api.openai.com/v1")
            .with_api_key("sk-test")
            .with_default_model("gpt-4")
            .with_timeout(Duration::from_secs(60));

        assert_eq!(config.name, "openai");
        assert_eq!(config.base_url, "https://api.openai.com/v1");
        assert_eq!(config.api_key, Some("sk-test".to_string()));
        assert_eq!(config.default_model, Some("gpt-4".to_string()));
        assert_eq!(config.timeout, Duration::from_secs(60));
    }

    /// A retry config survives a JSON round trip unchanged.
    #[test]
    fn test_config_serialization() {
        let config = RetryConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: RetryConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config, parsed);
    }

    /// A provider config survives a JSON round trip unchanged, including the
    /// optional fields that are skipped when `None`.
    #[test]
    fn test_provider_config_serialization() {
        let config = ProviderConfig::new("test", "https://example.com");
        let json = serde_json::to_string(&config).unwrap();
        let parsed: ProviderConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(config.name, parsed.name);
        assert_eq!(config.base_url, parsed.base_url);
        // Compare every field, not just the two strings above.
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_jitter_randomness() {
        let config = RetryConfig {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: true,
        };

        // Generate multiple backoffs and verify they're different (with high probability)
        let backoffs: Vec<Duration> = (0..10).map(|_| config.calculate_backoff(1)).collect();

        // Attempt 1 has a 200ms base, and jitter scales it by a factor between
        // 0.5 and 1.0, so every sample must fall within 100-200ms.
        for backoff in &backoffs {
            let ms = backoff.as_millis();
            assert!(ms >= 100, "Backoff too small: {}ms", ms); // 50% of 200ms = 100ms
            assert!(ms <= 200, "Backoff too large: {}ms", ms); // 100% of 200ms = 200ms
        }

        // At least some values should be different (very high probability with true randomness)
        let unique_count = backoffs
            .iter()
            .collect::<std::collections::HashSet<_>>()
            .len();
        assert!(
            unique_count > 1,
            "All jitter values are the same - RNG may not be working"
        );
    }
}