simple_agent_type/router.rs
1//! Routing strategy trait for provider selection.
2//!
3//! Provides abstractions for routing requests across multiple providers.
4
5use crate::config::ProviderConfig;
6use crate::error::Result;
7use crate::request::CompletionRequest;
8use async_trait::async_trait;
9use std::time::Duration;
10
11/// Trait for routing strategies.
12///
13/// Implementations define how requests are routed across multiple providers:
14/// - Round-robin: Distribute requests evenly
15/// - Priority: Try providers in order
16/// - Latency-based: Route to fastest provider
17/// - Load-balancing: Consider provider load
18/// - Cost-optimized: Route to cheapest provider
19///
20/// # Example Implementation
21///
22/// ```rust
23/// use simple_agent_type::router::RoutingStrategy;
24/// use simple_agent_type::config::ProviderConfig;
25/// use simple_agent_type::request::CompletionRequest;
26/// use simple_agent_type::message::Message;
27/// use simple_agent_type::error::{Result, SimpleAgentsError};
28/// use async_trait::async_trait;
29/// use std::sync::atomic::{AtomicUsize, Ordering};
30///
31/// struct RoundRobinStrategy {
32/// counter: AtomicUsize,
33/// }
34///
35/// #[async_trait]
36/// impl RoutingStrategy for RoundRobinStrategy {
37/// async fn select_provider(
38/// &self,
39/// providers: &[ProviderConfig],
40/// _request: &CompletionRequest,
41/// ) -> Result<usize> {
42/// if providers.is_empty() {
43/// return Err(SimpleAgentsError::Routing("no providers".to_string()));
44/// }
45/// let index = self.counter.fetch_add(1, Ordering::Relaxed);
46/// Ok(index % providers.len())
47/// }
48/// }
49///
50/// let strategy = RoundRobinStrategy {
51/// counter: AtomicUsize::new(0),
52/// };
53/// let providers = vec![
54/// ProviderConfig::new("p1", "http://example.com"),
55/// ProviderConfig::new("p2", "http://example.com"),
56/// ];
57/// let request = CompletionRequest::builder()
58/// .model("gpt-4")
59/// .message(Message::user("Hello!"))
60/// .build()
61/// .unwrap();
62///
63/// let rt = tokio::runtime::Runtime::new().unwrap();
64/// rt.block_on(async {
65/// let index = strategy.select_provider(&providers, &request).await.unwrap();
66/// assert!(index < providers.len());
67/// });
68/// ```
69#[async_trait]
70pub trait RoutingStrategy: Send + Sync {
71 /// Select a provider index for the given request.
72 ///
73 /// # Arguments
74 /// - `providers`: Available providers
75 /// - `request`: The completion request
76 ///
77 /// # Returns
78 /// Index of the selected provider in the `providers` slice.
79 ///
80 /// # Errors
81 /// - If no suitable provider is found
82 /// - If all providers are unavailable
83 async fn select_provider(
84 &self,
85 providers: &[ProviderConfig],
86 request: &CompletionRequest,
87 ) -> Result<usize>;
88
89 /// Report successful request completion.
90 ///
91 /// Used by latency-based and adaptive routing strategies to track
92 /// provider performance.
93 ///
94 /// # Arguments
95 /// - `provider_index`: Index of the provider that succeeded
96 /// - `latency`: Request duration
97 async fn report_success(&self, provider_index: usize, latency: Duration) {
98 let _ = (provider_index, latency);
99 }
100
101 /// Report request failure.
102 ///
103 /// Used by reliability-tracking routing strategies.
104 ///
105 /// # Arguments
106 /// - `provider_index`: Index of the provider that failed
107 async fn report_failure(&self, provider_index: usize) {
108 let _ = provider_index;
109 }
110
111 /// Get strategy name (for logging/debugging).
112 fn name(&self) -> &str {
113 "routing-strategy"
114 }
115}
116
/// Routing mode enum for common, built-in strategies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RoutingMode {
    /// Try providers in priority order
    Priority,
    /// Distribute requests evenly (round-robin)
    RoundRobin,
    /// Route to provider with lowest latency
    LatencyBased,
    /// Random selection
    Random,
}

impl RoutingMode {
    /// Get a human-readable description of this routing mode.
    ///
    /// Every description is a string literal, so the result is `'static`.
    /// It does not borrow from `self`, which means it can outlive the mode
    /// value (e.g. be returned from a function that received the mode by value).
    #[must_use]
    pub fn description(&self) -> &'static str {
        match self {
            Self::Priority => "Try providers in priority order",
            Self::RoundRobin => "Distribute requests evenly across providers",
            Self::LatencyBased => "Route to provider with lowest average latency",
            Self::Random => "Randomly select provider",
        }
    }
}
140}
141
/// Provider health status.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProviderHealth {
    /// Provider is healthy
    Healthy,
    /// Provider is degraded (high error rate)
    Degraded,
    /// Provider is unavailable
    Unavailable,
}

impl ProviderHealth {
    /// Check if provider can be used.
    ///
    /// `Healthy` and `Degraded` providers are usable; only `Unavailable`
    /// providers are not.
    #[must_use]
    pub fn is_available(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Healthy | Self::Degraded => true,
            Self::Unavailable => false,
        }
    }
}

/// Provider metrics for routing decisions.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ProviderMetrics {
    /// Total requests sent
    pub total_requests: u64,
    /// Successful requests
    pub successful_requests: u64,
    /// Failed requests
    pub failed_requests: u64,
    /// Average latency
    pub avg_latency: Duration,
    /// Current health status
    pub health: ProviderHealth,
}

impl Default for ProviderMetrics {
    /// Empty metrics: no requests recorded, zero latency, `Healthy`.
    fn default() -> Self {
        Self {
            total_requests: 0,
            successful_requests: 0,
            failed_requests: 0,
            avg_latency: Duration::ZERO,
            health: ProviderHealth::Healthy,
        }
    }
}

impl ProviderMetrics {
    /// Calculate success rate, always within `0.0..=1.0`.
    ///
    /// Returns `1.0` when no requests have been recorded (no data is treated
    /// as fully successful).
    ///
    /// The ratio is computed in `f64` and only then narrowed to `f32`. Casting
    /// the raw `u64` counters to `f32` first would lose precision beyond 2^24
    /// requests, rounding a provider with failures up to a perfect `1.0`. The
    /// fields are public, so the counters may be inconsistent (more successes
    /// than total requests); the result is clamped so the documented range holds.
    #[must_use]
    pub fn success_rate(&self) -> f32 {
        if self.total_requests == 0 {
            return 1.0;
        }
        let ratio = self.successful_requests as f64 / self.total_requests as f64;
        ratio.clamp(0.0, 1.0) as f32
    }

    /// Calculate failure rate, always within `0.0..=1.0`.
    ///
    /// Defined as `1.0 - success_rate()`, so it is `0.0` when no requests have
    /// been recorded. `failed_requests` is not consulted.
    #[must_use]
    pub fn failure_rate(&self) -> f32 {
        1.0 - self.success_rate()
    }
}
200}
201
202#[cfg(test)]
203mod tests {
204 use super::*;
205
206 #[test]
207 fn test_routing_mode_description() {
208 assert!(!RoutingMode::Priority.description().is_empty());
209 assert!(!RoutingMode::RoundRobin.description().is_empty());
210 assert!(!RoutingMode::LatencyBased.description().is_empty());
211 assert!(!RoutingMode::Random.description().is_empty());
212 }
213
214 #[test]
215 fn test_provider_health_is_available() {
216 assert!(ProviderHealth::Healthy.is_available());
217 assert!(ProviderHealth::Degraded.is_available());
218 assert!(!ProviderHealth::Unavailable.is_available());
219 }
220
221 #[test]
222 fn test_provider_metrics_default() {
223 let metrics = ProviderMetrics::default();
224 assert_eq!(metrics.total_requests, 0);
225 assert_eq!(metrics.successful_requests, 0);
226 assert_eq!(metrics.failed_requests, 0);
227 assert_eq!(metrics.success_rate(), 1.0);
228 assert_eq!(metrics.failure_rate(), 0.0);
229 }
230
231 #[test]
232 fn test_provider_metrics_success_rate() {
233 let metrics = ProviderMetrics {
234 total_requests: 100,
235 successful_requests: 95,
236 failed_requests: 5,
237 avg_latency: Duration::from_millis(200),
238 health: ProviderHealth::Healthy,
239 };
240
241 assert!((metrics.success_rate() - 0.95).abs() < 0.001);
242 assert!((metrics.failure_rate() - 0.05).abs() < 0.001);
243 }
244
245 #[test]
246 fn test_provider_metrics_zero_requests() {
247 let metrics = ProviderMetrics {
248 total_requests: 0,
249 successful_requests: 0,
250 failed_requests: 0,
251 avg_latency: Duration::from_millis(0),
252 health: ProviderHealth::Healthy,
253 };
254
255 // Default to 100% success rate when no data
256 assert_eq!(metrics.success_rate(), 1.0);
257 assert_eq!(metrics.failure_rate(), 0.0);
258 }
259
260 // Test that RoutingStrategy trait is object-safe
261 #[test]
262 fn test_routing_strategy_object_safety() {
263 fn _assert_object_safe(_: &dyn RoutingStrategy) {}
264 }
265}