Skip to main content

simple_agents_router/
round_robin.rs

//! Round-robin routing implementation.
//!
//! Distributes requests evenly across the configured providers by cycling through them in order.
4
5use simple_agent_type::prelude::{
6    CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
7};
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10
11/// Router that selects providers using round-robin order.
12pub struct RoundRobinRouter {
13    providers: Vec<Arc<dyn Provider>>,
14    counter: AtomicUsize,
15}
16
17impl RoundRobinRouter {
18    /// Create a new round-robin router.
19    ///
20    /// # Errors
21    /// Returns a routing error if no providers are supplied.
22    pub fn new(providers: Vec<Arc<dyn Provider>>) -> Result<Self> {
23        if providers.is_empty() {
24            return Err(SimpleAgentsError::Routing(
25                "no providers configured".to_string(),
26            ));
27        }
28
29        Ok(Self {
30            providers,
31            counter: AtomicUsize::new(0),
32        })
33    }
34
35    /// Return the number of configured providers.
36    pub fn provider_count(&self) -> usize {
37        self.providers.len()
38    }
39
40    /// Execute a completion request using round-robin provider selection.
41    pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
42        let index = self.select_provider_index()?;
43        let provider = &self.providers[index];
44        let provider_request = provider.transform_request(request)?;
45        let provider_response = provider.execute(provider_request).await?;
46        provider.transform_response(provider_response)
47    }
48
49    /// Execute a streaming request using round-robin provider selection.
50    pub async fn stream(
51        &self,
52        request: &CompletionRequest,
53    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
54        let index = self.select_provider_index()?;
55        let provider = &self.providers[index];
56        eprintln!(
57            "RoundRobinRouter.stream: provider={}, stream={:?}",
58            provider.name(),
59            request.stream
60        );
61        let provider_request = provider.transform_request(request)?;
62        provider.execute_stream(provider_request).await
63    }
64
65    fn select_provider_index(&self) -> Result<usize> {
66        let len = self.providers.len();
67        if len == 0 {
68            return Err(SimpleAgentsError::Routing(
69                "no providers configured".to_string(),
70            ));
71        }
72
73        let index = self.counter.fetch_add(1, Ordering::Relaxed);
74        Ok(index % len)
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use simple_agent_type::prelude::*;

    /// Provider stub that always succeeds and stamps its own name into the
    /// `provider` field of every response, so tests can tell which one ran.
    struct MockProvider {
        name: &'static str,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            MockProvider { name }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _request: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _request: ProviderRequest) -> Result<ProviderResponse> {
            Ok(ProviderResponse::new(200, serde_json::Value::Null))
        }

        fn transform_response(&self, _response: ProviderResponse) -> Result<CompletionResponse> {
            let choice = CompletionChoice {
                index: 0,
                message: Message::assistant("ok"),
                finish_reason: FinishReason::Stop,
                logprobs: None,
            };

            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![choice],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }

    /// Build a minimal single-message request accepted by the builder.
    fn build_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("test-model")
            .message(Message::user("hello"))
            .build()
            .unwrap()
    }

    #[test]
    fn empty_router_returns_error() {
        // `RoundRobinRouter` is not `Debug`, so match instead of `unwrap_err`.
        match RoundRobinRouter::new(Vec::new()) {
            Err(SimpleAgentsError::Routing(message)) => {
                assert_eq!(message, "no providers configured");
            }
            Err(_) => panic!("unexpected error type"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }

    #[tokio::test]
    async fn round_robin_rotates_providers() {
        let providers: Vec<Arc<dyn Provider>> = vec![
            Arc::new(MockProvider::new("p1")),
            Arc::new(MockProvider::new("p2")),
        ];
        let router = RoundRobinRouter::new(providers).unwrap();
        let request = build_request();

        // Three calls over two providers must wrap back to the first one.
        for expected in ["p1", "p2", "p1"] {
            let response = router.complete(&request).await.unwrap();
            assert_eq!(response.provider.as_deref(), Some(expected));
        }
    }
}