Skip to main content

simple_agents_router/
round_robin.rs

1//! Round-robin routing implementation.
2//!
3//! Distributes requests evenly across configured providers.
4
5use simple_agent_type::prelude::{
6    CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
7};
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10
11/// Router that selects providers using round-robin order.
12pub struct RoundRobinRouter {
13    providers: Vec<Arc<dyn Provider>>,
14    counter: AtomicUsize,
15}
16
17impl RoundRobinRouter {
18    /// Create a new round-robin router.
19    ///
20    /// # Errors
21    /// Returns a routing error if no providers are supplied.
22    pub fn new(providers: Vec<Arc<dyn Provider>>) -> Result<Self> {
23        if providers.is_empty() {
24            return Err(SimpleAgentsError::Routing(
25                "no providers configured".to_string(),
26            ));
27        }
28
29        Ok(Self {
30            providers,
31            counter: AtomicUsize::new(0),
32        })
33    }
34
35    /// Return the number of configured providers.
36    pub fn provider_count(&self) -> usize {
37        self.providers.len()
38    }
39
40    /// Execute a completion request using round-robin provider selection.
41    pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
42        let index = self.select_provider_index()?;
43        let provider = &self.providers[index];
44        let provider_request = provider.transform_request(request)?;
45        let provider_response = provider.execute(provider_request).await?;
46        provider.transform_response(provider_response)
47    }
48
49    /// Execute a streaming request using round-robin provider selection.
50    pub async fn stream(
51        &self,
52        request: &CompletionRequest,
53    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
54        let index = self.select_provider_index()?;
55        let provider = &self.providers[index];
56        let provider_request = provider.transform_request(request)?;
57        provider.execute_stream(provider_request).await
58    }
59
60    fn select_provider_index(&self) -> Result<usize> {
61        let len = self.providers.len();
62        if len == 0 {
63            return Err(SimpleAgentsError::Routing(
64                "no providers configured".to_string(),
65            ));
66        }
67
68        let index = self.counter.fetch_add(1, Ordering::Relaxed);
69        Ok(index % len)
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use simple_agent_type::prelude::*;

    /// Provider stub that always succeeds and tags responses with its own name.
    struct MockProvider {
        name: &'static str,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            Self { name }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(ProviderResponse::new(200, serde_json::Value::Null))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }

    fn build_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("test-model")
            .message(Message::user("hello"))
            .build()
            .unwrap()
    }

    /// Build a router over mock providers with the given names, in order.
    fn router_with(names: &[&'static str]) -> RoundRobinRouter {
        let providers: Vec<Arc<dyn Provider>> = names
            .iter()
            .map(|&name| Arc::new(MockProvider::new(name)) as Arc<dyn Provider>)
            .collect();
        RoundRobinRouter::new(providers).unwrap()
    }

    #[test]
    fn empty_router_returns_error() {
        match RoundRobinRouter::new(Vec::new()) {
            Ok(_) => panic!("expected error, got Ok"),
            Err(SimpleAgentsError::Routing(message)) => {
                assert_eq!(message, "no providers configured");
            }
            Err(other) => panic!("unexpected error type: {:?}", other),
        }
    }

    #[test]
    fn provider_count_reports_configured_providers() {
        assert_eq!(router_with(&["p1"]).provider_count(), 1);
        assert_eq!(router_with(&["p1", "p2", "p3"]).provider_count(), 3);
    }

    #[tokio::test]
    async fn round_robin_rotates_providers() {
        let router = router_with(&["p1", "p2"]);

        let request = build_request();
        let first = router.complete(&request).await.unwrap();
        let second = router.complete(&request).await.unwrap();
        let third = router.complete(&request).await.unwrap();

        assert_eq!(first.provider.as_deref(), Some("p1"));
        assert_eq!(second.provider.as_deref(), Some("p2"));
        assert_eq!(third.provider.as_deref(), Some("p1"));
    }

    #[tokio::test]
    async fn round_robin_cycles_through_all_providers_repeatedly() {
        let router = router_with(&["p1", "p2", "p3"]);
        let request = build_request();

        let mut seen = Vec::new();
        for _ in 0..7 {
            let response = router.complete(&request).await.unwrap();
            seen.push(response.provider.unwrap());
        }

        assert_eq!(seen, ["p1", "p2", "p3", "p1", "p2", "p3", "p1"]);
    }
}