simple_agents_router/
round_robin.rs1use simple_agent_type::prelude::{
6 CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
7};
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10
11pub struct RoundRobinRouter {
13 providers: Vec<Arc<dyn Provider>>,
14 counter: AtomicUsize,
15}
16
17impl RoundRobinRouter {
18 pub fn new(providers: Vec<Arc<dyn Provider>>) -> Result<Self> {
23 if providers.is_empty() {
24 return Err(SimpleAgentsError::Routing(
25 "no providers configured".to_string(),
26 ));
27 }
28
29 Ok(Self {
30 providers,
31 counter: AtomicUsize::new(0),
32 })
33 }
34
35 pub fn provider_count(&self) -> usize {
37 self.providers.len()
38 }
39
40 pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
42 let index = self.select_provider_index()?;
43 let provider = &self.providers[index];
44 let provider_request = provider.transform_request(request)?;
45 let provider_response = provider.execute(provider_request).await?;
46 provider.transform_response(provider_response)
47 }
48
49 pub async fn stream(
51 &self,
52 request: &CompletionRequest,
53 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
54 let index = self.select_provider_index()?;
55 let provider = &self.providers[index];
56 eprintln!(
57 "RoundRobinRouter.stream: provider={}, stream={:?}",
58 provider.name(),
59 request.stream
60 );
61 let provider_request = provider.transform_request(request)?;
62 provider.execute_stream(provider_request).await
63 }
64
65 fn select_provider_index(&self) -> Result<usize> {
66 let len = self.providers.len();
67 if len == 0 {
68 return Err(SimpleAgentsError::Routing(
69 "no providers configured".to_string(),
70 ));
71 }
72
73 let index = self.counter.fetch_add(1, Ordering::Relaxed);
74 Ok(index % len)
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use simple_agent_type::prelude::*;

    /// Test double that performs no I/O.
    ///
    /// It stamps its own name into every response, so tests can observe which
    /// provider the router picked.
    struct MockProvider {
        name: &'static str,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            MockProvider { name }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            let body = serde_json::Value::Null;
            Ok(ProviderResponse::new(200, body))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            let only_choice = CompletionChoice {
                index: 0,
                message: Message::assistant("ok"),
                finish_reason: FinishReason::Stop,
                logprobs: None,
            };

            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![only_choice],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }

    /// Builds a minimal single-message request for the mock providers.
    fn build_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("test-model")
            .message(Message::user("hello"))
            .build()
            .unwrap()
    }

    #[test]
    fn empty_router_returns_error() {
        match RoundRobinRouter::new(Vec::new()) {
            Err(SimpleAgentsError::Routing(message)) => {
                assert_eq!(message, "no providers configured");
            }
            Err(_) => panic!("unexpected error type"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }

    #[tokio::test]
    async fn round_robin_rotates_providers() {
        let providers: Vec<Arc<dyn Provider>> = vec![
            Arc::new(MockProvider::new("p1")),
            Arc::new(MockProvider::new("p2")),
        ];
        let router = RoundRobinRouter::new(providers).unwrap();
        let request = build_request();

        // Two providers and three calls: the third call must wrap back to the first.
        for &expected in &["p1", "p2", "p1"] {
            let response = router.complete(&request).await.unwrap();
            assert_eq!(response.provider.as_deref(), Some(expected));
        }
    }
}