1use simple_agent_type::prelude::{
6 CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
7};
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::Arc;
11
12#[derive(Debug, Clone, PartialEq)]
14pub struct ProviderCost {
15 pub name: String,
17 pub cost_per_1k_tokens: f64,
19}
20
21impl ProviderCost {
22 pub fn new(name: impl Into<String>, cost_per_1k_tokens: f64) -> Result<Self> {
24 if !cost_per_1k_tokens.is_finite() || cost_per_1k_tokens < 0.0 {
25 return Err(SimpleAgentsError::Routing(
26 "provider cost must be a non-negative finite value".to_string(),
27 ));
28 }
29
30 Ok(Self {
31 name: name.into(),
32 cost_per_1k_tokens,
33 })
34 }
35}
36
37#[derive(Debug, Clone, Default)]
39pub struct CostRouterConfig {
40 pub costs: Vec<ProviderCost>,
42}
43
44impl CostRouterConfig {
45 pub fn new(costs: Vec<ProviderCost>) -> Self {
47 Self { costs }
48 }
49}
50
51pub struct CostRouter {
53 providers: Vec<Arc<dyn Provider>>,
54 costs: Vec<f64>,
55 counter: AtomicUsize,
56}
57
58impl CostRouter {
59 pub fn new(providers: Vec<Arc<dyn Provider>>, config: CostRouterConfig) -> Result<Self> {
64 if providers.is_empty() {
65 return Err(SimpleAgentsError::Routing(
66 "no providers configured".to_string(),
67 ));
68 }
69
70 let mut cost_map = HashMap::new();
71 for cost in config.costs {
72 if !cost.cost_per_1k_tokens.is_finite() || cost.cost_per_1k_tokens < 0.0 {
73 return Err(SimpleAgentsError::Routing(
74 "provider cost must be a non-negative finite value".to_string(),
75 ));
76 }
77 cost_map.insert(cost.name, cost.cost_per_1k_tokens);
78 }
79
80 let mut costs = Vec::with_capacity(providers.len());
81 for provider in &providers {
82 let name = provider.name();
83 match cost_map.get(name) {
84 Some(cost) => costs.push(*cost),
85 None => {
86 return Err(SimpleAgentsError::Routing(format!(
87 "missing cost for provider: {}",
88 name
89 )));
90 }
91 }
92 }
93
94 Ok(Self {
95 providers,
96 costs,
97 counter: AtomicUsize::new(0),
98 })
99 }
100
101 pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
103 let index = self.select_provider_index()?;
104 let provider = &self.providers[index];
105 let provider_request = provider.transform_request(request)?;
106 let provider_response = provider.execute(provider_request).await?;
107 provider.transform_response(provider_response)
108 }
109
110 pub async fn stream(
112 &self,
113 request: &CompletionRequest,
114 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
115 let index = self.select_provider_index()?;
116 let provider = &self.providers[index];
117 let provider_request = provider.transform_request(request)?;
118 provider.execute_stream(provider_request).await
119 }
120
121 fn select_provider_index(&self) -> Result<usize> {
122 if self.providers.is_empty() {
123 return Err(SimpleAgentsError::Routing(
124 "no providers configured".to_string(),
125 ));
126 }
127
128 let mut min_cost = f64::INFINITY;
129 for cost in &self.costs {
130 if *cost < min_cost {
131 min_cost = *cost;
132 }
133 }
134
135 if !min_cost.is_finite() {
136 return Err(SimpleAgentsError::Routing(
137 "invalid provider costs".to_string(),
138 ));
139 }
140
141 let min_indices: Vec<usize> = self
142 .costs
143 .iter()
144 .enumerate()
145 .filter(|(_, cost)| **cost == min_cost)
146 .map(|(index, _)| index)
147 .collect();
148
149 if min_indices.is_empty() {
150 return Err(SimpleAgentsError::Routing(
151 "no providers configured".to_string(),
152 ));
153 }
154
155 let offset = self.counter.fetch_add(1, Ordering::Relaxed);
156 Ok(min_indices[offset % min_indices.len()])
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use simple_agent_type::prelude::*;

    /// Minimal provider that answers every request with a canned response.
    struct MockProvider {
        name: &'static str,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            Self { name }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(ProviderResponse::new(200, serde_json::Value::Null))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }

    fn build_costs(entries: Vec<ProviderCost>) -> CostRouterConfig {
        CostRouterConfig::new(entries)
    }

    /// Builds one mock provider per name, preserving order.
    fn mock_providers(names: &[&'static str]) -> Vec<Arc<dyn Provider>> {
        names
            .iter()
            .map(|&name| Arc::new(MockProvider::new(name)) as Arc<dyn Provider>)
            .collect()
    }

    /// Returns the message of a routing error; panics on `Ok` or any other error variant.
    fn routing_error_message<T>(result: Result<T>) -> String {
        match result {
            Ok(_) => panic!("expected error, got Ok"),
            Err(SimpleAgentsError::Routing(message)) => message,
            Err(_) => panic!("unexpected error type"),
        }
    }

    #[test]
    fn provider_cost_rejects_negative_and_non_finite_values() {
        for bad in [-0.1, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            assert_eq!(
                routing_error_message(ProviderCost::new("p1", bad)),
                "provider cost must be a non-negative finite value"
            );
        }
        assert!(ProviderCost::new("p1", 0.0).is_ok());
    }

    #[test]
    fn empty_router_returns_error() {
        let config = build_costs(vec![ProviderCost::new("p1", 0.5).unwrap()]);
        let result = CostRouter::new(Vec::new(), config);
        assert_eq!(routing_error_message(result), "no providers configured");
    }

    #[test]
    fn router_rejects_unvalidated_cost() {
        // Public fields allow bypassing `ProviderCost::new`; the router must still validate.
        let config = build_costs(vec![ProviderCost {
            name: "p1".to_string(),
            cost_per_1k_tokens: -1.0,
        }]);
        let result = CostRouter::new(mock_providers(&["p1"]), config);
        assert_eq!(
            routing_error_message(result),
            "provider cost must be a non-negative finite value"
        );
    }

    #[test]
    fn missing_cost_returns_error() {
        let config = build_costs(vec![ProviderCost::new("p1", 0.5).unwrap()]);
        let result = CostRouter::new(mock_providers(&["p1", "p2"]), config);
        assert_eq!(
            routing_error_message(result),
            "missing cost for provider: p2"
        );
    }

    #[test]
    fn selects_lowest_cost() {
        let config = build_costs(vec![
            ProviderCost::new("p1", 0.8).unwrap(),
            ProviderCost::new("p2", 0.2).unwrap(),
        ]);
        let router = CostRouter::new(mock_providers(&["p1", "p2"]), config).unwrap();

        // A unique cheapest provider is chosen every time, not rotated away from.
        for _ in 0..3 {
            assert_eq!(router.select_provider_index().unwrap(), 1);
        }
    }

    #[test]
    fn duplicate_cost_entries_use_last_value() {
        let config = build_costs(vec![
            ProviderCost::new("p1", 0.1).unwrap(),
            ProviderCost::new("p2", 0.5).unwrap(),
            ProviderCost::new("p1", 0.9).unwrap(),
        ]);
        let router = CostRouter::new(mock_providers(&["p1", "p2"]), config).unwrap();

        assert_eq!(router.select_provider_index().unwrap(), 1);
    }

    #[test]
    fn tie_breaks_with_rotation() {
        let config = build_costs(vec![
            ProviderCost::new("p1", 0.5).unwrap(),
            ProviderCost::new("p2", 0.5).unwrap(),
            ProviderCost::new("p3", 0.8).unwrap(),
        ]);
        let router = CostRouter::new(mock_providers(&["p1", "p2", "p3"]), config).unwrap();

        // Rotation covers only the tied providers and wraps back to the first.
        let picks: Vec<usize> = (0..4)
            .map(|_| router.select_provider_index().unwrap())
            .collect();
        assert_eq!(picks, vec![0, 1, 0, 1]);
    }
}