Skip to main content

simple_agents_router/
cost.rs

1//! Cost-based routing implementation.
2//!
//! Routes each request to the lowest-cost provider.
4
5use simple_agent_type::prelude::{
6    CompletionChunk, CompletionRequest, CompletionResponse, Provider, Result, SimpleAgentsError,
7};
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::Arc;
11
12/// Cost metadata for a provider.
13#[derive(Debug, Clone, PartialEq)]
14pub struct ProviderCost {
15    /// Provider name that matches `Provider::name()`.
16    pub name: String,
17    /// Cost per 1k tokens.
18    pub cost_per_1k_tokens: f64,
19}
20
21impl ProviderCost {
22    /// Create a new provider cost entry.
23    pub fn new(name: impl Into<String>, cost_per_1k_tokens: f64) -> Result<Self> {
24        if !cost_per_1k_tokens.is_finite() || cost_per_1k_tokens < 0.0 {
25            return Err(SimpleAgentsError::Routing(
26                "provider cost must be a non-negative finite value".to_string(),
27            ));
28        }
29
30        Ok(Self {
31            name: name.into(),
32            cost_per_1k_tokens,
33        })
34    }
35}
36
37/// Configuration for cost-based routing.
38#[derive(Debug, Clone, Default)]
39pub struct CostRouterConfig {
40    /// Provider costs.
41    pub costs: Vec<ProviderCost>,
42}
43
44impl CostRouterConfig {
45    /// Create a config from a list of provider costs.
46    pub fn new(costs: Vec<ProviderCost>) -> Self {
47        Self { costs }
48    }
49}
50
51/// Router that selects providers based on lowest cost.
52pub struct CostRouter {
53    providers: Vec<Arc<dyn Provider>>,
54    costs: Vec<f64>,
55    counter: AtomicUsize,
56}
57
58impl CostRouter {
59    /// Create a cost router using the provided config.
60    ///
61    /// # Errors
62    /// Returns a routing error if providers or costs are missing.
63    pub fn new(providers: Vec<Arc<dyn Provider>>, config: CostRouterConfig) -> Result<Self> {
64        if providers.is_empty() {
65            return Err(SimpleAgentsError::Routing(
66                "no providers configured".to_string(),
67            ));
68        }
69
70        let mut cost_map = HashMap::new();
71        for cost in config.costs {
72            if !cost.cost_per_1k_tokens.is_finite() || cost.cost_per_1k_tokens < 0.0 {
73                return Err(SimpleAgentsError::Routing(
74                    "provider cost must be a non-negative finite value".to_string(),
75                ));
76            }
77            cost_map.insert(cost.name, cost.cost_per_1k_tokens);
78        }
79
80        let mut costs = Vec::with_capacity(providers.len());
81        for provider in &providers {
82            let name = provider.name();
83            match cost_map.get(name) {
84                Some(cost) => costs.push(*cost),
85                None => {
86                    return Err(SimpleAgentsError::Routing(format!(
87                        "missing cost for provider: {}",
88                        name
89                    )));
90                }
91            }
92        }
93
94        Ok(Self {
95            providers,
96            costs,
97            counter: AtomicUsize::new(0),
98        })
99    }
100
101    /// Execute a completion request using cost-based selection.
102    pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
103        let index = self.select_provider_index()?;
104        let provider = &self.providers[index];
105        let provider_request = provider.transform_request(request)?;
106        let provider_response = provider.execute(provider_request).await?;
107        provider.transform_response(provider_response)
108    }
109
110    /// Execute a streaming request using cost-based selection.
111    pub async fn stream(
112        &self,
113        request: &CompletionRequest,
114    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
115        let index = self.select_provider_index()?;
116        let provider = &self.providers[index];
117        let provider_request = provider.transform_request(request)?;
118        provider.execute_stream(provider_request).await
119    }
120
121    fn select_provider_index(&self) -> Result<usize> {
122        if self.providers.is_empty() {
123            return Err(SimpleAgentsError::Routing(
124                "no providers configured".to_string(),
125            ));
126        }
127
128        let mut min_cost = f64::INFINITY;
129        for cost in &self.costs {
130            if *cost < min_cost {
131                min_cost = *cost;
132            }
133        }
134
135        if !min_cost.is_finite() {
136            return Err(SimpleAgentsError::Routing(
137                "invalid provider costs".to_string(),
138            ));
139        }
140
141        let min_indices: Vec<usize> = self
142            .costs
143            .iter()
144            .enumerate()
145            .filter(|(_, cost)| **cost == min_cost)
146            .map(|(index, _)| index)
147            .collect();
148
149        if min_indices.is_empty() {
150            return Err(SimpleAgentsError::Routing(
151                "no providers configured".to_string(),
152            ));
153        }
154
155        let offset = self.counter.fetch_add(1, Ordering::Relaxed);
156        Ok(min_indices[offset % min_indices.len()])
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use simple_agent_type::prelude::*;

    /// Provider stub that answers every call with a canned success value.
    struct MockProvider {
        name: &'static str,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            MockProvider { name }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(ProviderResponse::new(200, serde_json::Value::Null))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            let choice = CompletionChoice {
                index: 0,
                message: Message::assistant("ok"),
                finish_reason: FinishReason::Stop,
                logprobs: None,
            };

            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![choice],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }

    /// Build a router config from `(provider name, cost)` pairs.
    fn cost_config(entries: &[(&'static str, f64)]) -> CostRouterConfig {
        let costs = entries
            .iter()
            .map(|&(name, cost)| ProviderCost::new(name, cost).unwrap())
            .collect();
        CostRouterConfig::new(costs)
    }

    /// Build one mock provider per name, in the given order.
    fn mock_providers(names: &[&'static str]) -> Vec<Arc<dyn Provider>> {
        names
            .iter()
            .map(|&name| Arc::new(MockProvider::new(name)) as Arc<dyn Provider>)
            .collect()
    }

    /// Unwrap the message of a routing error, failing the test on anything else.
    fn routing_message(result: Result<CostRouter>) -> String {
        match result {
            Err(SimpleAgentsError::Routing(message)) => message,
            Err(_) => panic!("unexpected error type"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }

    #[test]
    fn empty_router_returns_error() {
        let outcome = CostRouter::new(Vec::new(), cost_config(&[("p1", 0.5)]));

        assert_eq!(routing_message(outcome), "no providers configured");
    }

    #[test]
    fn missing_cost_returns_error() {
        let outcome = CostRouter::new(mock_providers(&["p1", "p2"]), cost_config(&[("p1", 0.5)]));

        assert_eq!(routing_message(outcome), "missing cost for provider: p2");
    }

    #[test]
    fn selects_lowest_cost() {
        let pricing = cost_config(&[("p1", 0.8), ("p2", 0.2)]);
        let router = CostRouter::new(mock_providers(&["p1", "p2"]), pricing).unwrap();

        let chosen = router.select_provider_index().unwrap();
        assert_eq!(chosen, 1);
    }

    #[test]
    fn tie_breaks_with_rotation() {
        let pricing = cost_config(&[("p1", 0.5), ("p2", 0.5), ("p3", 0.8)]);
        let router = CostRouter::new(mock_providers(&["p1", "p2", "p3"]), pricing).unwrap();

        // The two cheapest providers share the load in declaration order.
        let first = router.select_provider_index().unwrap();
        let second = router.select_provider_index().unwrap();

        assert_eq!(first, 0);
        assert_eq!(second, 1);
    }
}