Skip to main content

simple_agent_type/
router.rs

1//! Routing strategy trait for provider selection.
2//!
3//! Provides abstractions for routing requests across multiple providers.
4
5use crate::config::ProviderConfig;
6use crate::error::Result;
7use crate::request::CompletionRequest;
8use async_trait::async_trait;
9use std::time::Duration;
10
11/// Trait for routing strategies.
12///
13/// Implementations define how requests are routed across multiple providers:
14/// - Round-robin: Distribute requests evenly
15/// - Priority: Try providers in order
16/// - Latency-based: Route to fastest provider
17/// - Load-balancing: Consider provider load
18/// - Cost-optimized: Route to cheapest provider
19///
20/// # Example Implementation
21///
22/// ```rust
23/// use simple_agent_type::router::RoutingStrategy;
24/// use simple_agent_type::config::ProviderConfig;
25/// use simple_agent_type::request::CompletionRequest;
26/// use simple_agent_type::message::Message;
27/// use simple_agent_type::error::{Result, SimpleAgentsError};
28/// use async_trait::async_trait;
29/// use std::sync::atomic::{AtomicUsize, Ordering};
30///
31/// struct RoundRobinStrategy {
32///     counter: AtomicUsize,
33/// }
34///
35/// #[async_trait]
36/// impl RoutingStrategy for RoundRobinStrategy {
37///     async fn select_provider(
38///         &self,
39///         providers: &[ProviderConfig],
40///         _request: &CompletionRequest,
41///     ) -> Result<usize> {
42///         if providers.is_empty() {
43///             return Err(SimpleAgentsError::Routing("no providers".to_string()));
44///         }
45///         let index = self.counter.fetch_add(1, Ordering::Relaxed);
46///         Ok(index % providers.len())
47///     }
48/// }
49///
50/// let strategy = RoundRobinStrategy {
51///     counter: AtomicUsize::new(0),
52/// };
53/// let providers = vec![
54///     ProviderConfig::new("p1", "http://example.com"),
55///     ProviderConfig::new("p2", "http://example.com"),
56/// ];
57/// let request = CompletionRequest::builder()
58///     .model("gpt-4")
59///     .message(Message::user("Hello!"))
60///     .build()
61///     .unwrap();
62///
63/// let rt = tokio::runtime::Runtime::new().unwrap();
64/// rt.block_on(async {
65///     let index = strategy.select_provider(&providers, &request).await.unwrap();
66///     assert!(index < providers.len());
67/// });
68/// ```
69#[async_trait]
70pub trait RoutingStrategy: Send + Sync {
71    /// Select a provider index for the given request.
72    ///
73    /// # Arguments
74    /// - `providers`: Available providers
75    /// - `request`: The completion request
76    ///
77    /// # Returns
78    /// Index of the selected provider in the `providers` slice.
79    ///
80    /// # Errors
81    /// - If no suitable provider is found
82    /// - If all providers are unavailable
83    async fn select_provider(
84        &self,
85        providers: &[ProviderConfig],
86        request: &CompletionRequest,
87    ) -> Result<usize>;
88
89    /// Report successful request completion.
90    ///
91    /// Used by latency-based and adaptive routing strategies to track
92    /// provider performance.
93    ///
94    /// # Arguments
95    /// - `provider_index`: Index of the provider that succeeded
96    /// - `latency`: Request duration
97    async fn report_success(&self, provider_index: usize, latency: Duration) {
98        let _ = (provider_index, latency);
99    }
100
101    /// Report request failure.
102    ///
103    /// Used by reliability-tracking routing strategies.
104    ///
105    /// # Arguments
106    /// - `provider_index`: Index of the provider that failed
107    async fn report_failure(&self, provider_index: usize) {
108        let _ = provider_index;
109    }
110
111    /// Get strategy name (for logging/debugging).
112    fn name(&self) -> &str {
113        "routing-strategy"
114    }
115}
116
/// Routing mode enum for common strategies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RoutingMode {
    /// Try providers in priority order
    Priority,
    /// Distribute requests evenly (round-robin)
    RoundRobin,
    /// Route to provider with lowest latency
    LatencyBased,
    /// Random selection
    Random,
}

impl RoutingMode {
    /// Get a human-readable description.
    ///
    /// # Returns
    /// A short sentence describing the mode. Every description is a string
    /// literal, so the returned slice is `'static` and is not tied to the
    /// lifetime of the `RoutingMode` value it was obtained from.
    pub fn description(&self) -> &'static str {
        // Exhaustive on purpose: adding a variant must force a new description.
        match self {
            Self::Priority => "Try providers in priority order",
            Self::RoundRobin => "Distribute requests evenly across providers",
            Self::LatencyBased => "Route to provider with lowest average latency",
            Self::Random => "Randomly select provider",
        }
    }
}
141
/// Coarse health classification of a provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProviderHealth {
    /// Provider is operating normally
    Healthy,
    /// Provider is degraded (high error rate)
    Degraded,
    /// Provider cannot be reached or used
    Unavailable,
}

impl ProviderHealth {
    /// Whether a provider in this state can be used.
    ///
    /// `Healthy` and `Degraded` count as available; `Unavailable` does not.
    pub fn is_available(&self) -> bool {
        match self {
            Self::Healthy | Self::Degraded => true,
            Self::Unavailable => false,
        }
    }
}
159
160/// Provider metrics for routing decisions.
161#[derive(Debug, Clone, Copy, PartialEq)]
162pub struct ProviderMetrics {
163    /// Total requests sent
164    pub total_requests: u64,
165    /// Successful requests
166    pub successful_requests: u64,
167    /// Failed requests
168    pub failed_requests: u64,
169    /// Average latency
170    pub avg_latency: Duration,
171    /// Current health status
172    pub health: ProviderHealth,
173}
174
175impl Default for ProviderMetrics {
176    fn default() -> Self {
177        Self {
178            total_requests: 0,
179            successful_requests: 0,
180            failed_requests: 0,
181            avg_latency: Duration::from_millis(0),
182            health: ProviderHealth::Healthy,
183        }
184    }
185}
186
187impl ProviderMetrics {
188    /// Calculate success rate (0.0-1.0).
189    pub fn success_rate(&self) -> f32 {
190        if self.total_requests == 0 {
191            return 1.0;
192        }
193        self.successful_requests as f32 / self.total_requests as f32
194    }
195
196    /// Calculate failure rate (0.0-1.0).
197    pub fn failure_rate(&self) -> f32 {
198        1.0 - self.success_rate()
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Every routing mode exposes a non-empty description.
    #[test]
    fn test_routing_mode_description() {
        let modes = [
            RoutingMode::Priority,
            RoutingMode::RoundRobin,
            RoutingMode::LatencyBased,
            RoutingMode::Random,
        ];
        for mode in modes.iter() {
            assert!(
                !mode.description().is_empty(),
                "empty description for {:?}",
                mode
            );
        }
    }

    /// Only `Unavailable` is reported as not usable.
    #[test]
    fn test_provider_health_is_available() {
        let expectations = [
            (ProviderHealth::Healthy, true),
            (ProviderHealth::Degraded, true),
            (ProviderHealth::Unavailable, false),
        ];
        for (health, expected) in expectations.iter() {
            assert_eq!(health.is_available(), *expected, "{:?}", health);
        }
    }

    /// Default metrics start with zeroed counters and a perfect success rate.
    #[test]
    fn test_provider_metrics_default() {
        let fresh = ProviderMetrics::default();
        assert_eq!(fresh.total_requests, 0);
        assert_eq!(fresh.successful_requests, 0);
        assert_eq!(fresh.failed_requests, 0);
        assert_eq!(fresh.success_rate(), 1.0);
        assert_eq!(fresh.failure_rate(), 0.0);
    }

    /// 95 successes out of 100 requests yields rates of roughly 0.95 / 0.05.
    #[test]
    fn test_provider_metrics_success_rate() {
        let sample = ProviderMetrics {
            total_requests: 100,
            successful_requests: 95,
            failed_requests: 5,
            avg_latency: Duration::from_millis(200),
            health: ProviderHealth::Healthy,
        };

        let success_error = (sample.success_rate() - 0.95).abs();
        let failure_error = (sample.failure_rate() - 0.05).abs();
        assert!(success_error < 0.001);
        assert!(failure_error < 0.001);
    }

    /// With no recorded requests the success rate defaults to 100%.
    #[test]
    fn test_provider_metrics_zero_requests() {
        let idle = ProviderMetrics {
            total_requests: 0,
            successful_requests: 0,
            failed_requests: 0,
            avg_latency: Duration::from_millis(0),
            health: ProviderHealth::Healthy,
        };

        assert_eq!(idle.success_rate(), 1.0);
        assert_eq!(idle.failure_rate(), 0.0);
    }

    /// Compile-time check that `RoutingStrategy` can be used as a trait object.
    #[test]
    fn test_routing_strategy_object_safety() {
        fn _accepts_trait_object(_strategy: &dyn RoutingStrategy) {}
    }
}