simple_agents_router/
round_robin.rs1use simple_agent_type::prelude::{
6 CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
7};
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10
11pub struct RoundRobinRouter {
13 providers: Vec<Arc<dyn Provider>>,
14 counter: AtomicUsize,
15}
16
17impl RoundRobinRouter {
18 pub fn new(providers: Vec<Arc<dyn Provider>>) -> Result<Self> {
23 if providers.is_empty() {
24 return Err(SimpleAgentsError::Routing(
25 "no providers configured".to_string(),
26 ));
27 }
28
29 Ok(Self {
30 providers,
31 counter: AtomicUsize::new(0),
32 })
33 }
34
35 pub fn provider_count(&self) -> usize {
37 self.providers.len()
38 }
39
40 pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
42 let index = self.select_provider_index()?;
43 let provider = &self.providers[index];
44 let provider_request = provider.transform_request(request)?;
45 let provider_response = provider.execute(provider_request).await?;
46 provider.transform_response(provider_response)
47 }
48
49 pub async fn stream(
51 &self,
52 request: &CompletionRequest,
53 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
54 let index = self.select_provider_index()?;
55 let provider = &self.providers[index];
56 let provider_request = provider.transform_request(request)?;
57 provider.execute_stream(provider_request).await
58 }
59
60 fn select_provider_index(&self) -> Result<usize> {
61 let len = self.providers.len();
62 if len == 0 {
63 return Err(SimpleAgentsError::Routing(
64 "no providers configured".to_string(),
65 ));
66 }
67
68 let index = self.counter.fetch_add(1, Ordering::Relaxed);
69 Ok(index % len)
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use simple_agent_type::prelude::*;

    /// Stub provider that succeeds on every call and stamps its own name into the
    /// response's `provider` field, so tests can observe which provider was picked.
    struct MockProvider(&'static str);

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.0
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(ProviderResponse::new(200, serde_json::Value::Null))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            let choice = CompletionChoice {
                index: 0,
                message: Message::assistant("ok"),
                finish_reason: FinishReason::Stop,
                logprobs: None,
            };
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![choice],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }

    /// Builds a router with one `MockProvider` per entry of `names`, in the same order.
    fn router_over(names: &[&'static str]) -> RoundRobinRouter {
        let providers = names
            .iter()
            .map(|&label| Arc::new(MockProvider(label)) as Arc<dyn Provider>)
            .collect();
        RoundRobinRouter::new(providers).unwrap()
    }

    /// Minimal valid request; its contents are ignored by `MockProvider`.
    fn build_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("test-model")
            .message(Message::user("hello"))
            .build()
            .unwrap()
    }

    #[test]
    fn empty_router_returns_error() {
        // `.err()` avoids needing `Debug` on the router type, which `unwrap_err` would require.
        let failure = RoundRobinRouter::new(Vec::new())
            .err()
            .expect("expected error, got Ok");
        match failure {
            SimpleAgentsError::Routing(message) => {
                assert_eq!(message, "no providers configured");
            }
            _ => panic!("unexpected error type"),
        }
    }

    #[tokio::test]
    async fn round_robin_rotates_providers() {
        let router = router_over(&["p1", "p2"]);
        let request = build_request();

        // Two providers, three calls: the third call must wrap back around to the first.
        for expected in ["p1", "p2", "p1"] {
            let response = router.complete(&request).await.unwrap();
            assert_eq!(response.provider.as_deref(), Some(expected));
        }
    }
}