1pub mod learned;
2
3use std::sync::Arc;
4
5use crate::provider::{
6 Brain, BrainError, BrainRequest, ContentBlock, LatencyClass, Msg, PromptCacheConfig,
7};
8
/// Coarse class of a task, used as the key for routing decisions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskTier {
    Trivial,
    Small,
    Medium,
    Hard,
    Vision,
}

impl TaskTier {
    /// Parses a tier name, ignoring ASCII case and surrounding whitespace.
    ///
    /// This never fails: anything that is not one of the five tier names
    /// (including the empty string) yields [`TaskTier::Medium`].
    pub fn from_str(s: &str) -> Self {
        let name = s.trim();
        // All tier names are plain ASCII, so an ASCII case-insensitive compare
        // accepts the same spellings as lowercasing would, without allocating.
        let is = |candidate: &str| name.eq_ignore_ascii_case(candidate);

        if is("trivial") {
            TaskTier::Trivial
        } else if is("small") {
            TaskTier::Small
        } else if is("medium") {
            TaskTier::Medium
        } else if is("hard") {
            TaskTier::Hard
        } else if is("vision") {
            TaskTier::Vision
        } else {
            TaskTier::Medium
        }
    }

    /// Returns the canonical lowercase name of the tier.
    ///
    /// The result is a string literal, so it is `'static` and may outlive the
    /// `TaskTier` value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            TaskTier::Trivial => "trivial",
            TaskTier::Small => "small",
            TaskTier::Medium => "medium",
            TaskTier::Hard => "hard",
            TaskTier::Vision => "vision",
        }
    }
}
42
43#[derive(Debug, Clone)]
44pub struct RoutingNeed {
45 pub tier: TaskTier,
46 pub required_tools: bool,
47 pub required_vision: bool,
48 pub prefer_local: bool,
49}
50
/// Spending limits and running totals for the current day and session.
///
/// Going by the field names, all amounts are in US dollars.
#[derive(Clone, Debug)]
pub struct BudgetState {
    pub daily_limit_usd: f64,
    pub daily_spent_usd: f64,
    pub session_limit_usd: f64,
    pub session_spent_usd: f64,
}

impl BudgetState {
    /// Amount still available today, clamped so overspending reads as `0.0`.
    pub fn remaining_daily(&self) -> f64 {
        let left = self.daily_limit_usd - self.daily_spent_usd;
        left.max(0.0)
    }

    /// Amount still available in this session, clamped so overspending reads
    /// as `0.0`.
    pub fn remaining_session(&self) -> f64 {
        let left = self.session_limit_usd - self.session_spent_usd;
        left.max(0.0)
    }

    /// Returns `true` once either the daily or the session allowance has
    /// nothing left.
    pub fn is_exhausted(&self) -> bool {
        if self.remaining_daily() <= 0.0 {
            return true;
        }
        self.remaining_session() <= 0.0
    }
}
72
73pub trait Router: Send + Sync {
76 fn select(&self, need: &RoutingNeed, budget: &BudgetState) -> Vec<Arc<dyn Brain>>;
79
80 fn on_error(&self, b: &dyn Brain, e: &BrainError) -> Retry;
81
82 fn find_brain_by_id(&self, model_id: &str) -> Option<Arc<dyn Brain>>;
86}
87
/// Action a [`Router`] asks the caller to take after a brain failed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Retry {
    /// Give up on this brain and move on to the next one in the chain.
    NextInChain,
    /// Stop; do not try any further brain.
    Abort,
    /// Wait, then retry. The payload is the delay; `BasicRouter::on_error`
    /// fills it from a rate limit's `retry_after` value (presumably seconds;
    /// confirm against the provider error type).
    WaitAndRetry(u64),
}
94
95use std::collections::HashMap;
98
99use crate::config::Config;
100
101pub struct BasicRouter {
102 providers: HashMap<String, Vec<Arc<dyn Brain>>>,
104 policy: HashMap<String, String>,
106 free_first: bool,
107 preferred_provider: Option<String>,
110}
111
112impl BasicRouter {
113 pub fn new(config: &Config, providers: HashMap<String, Vec<Arc<dyn Brain>>>) -> Self {
114 let mut policy = HashMap::new();
115 for (k, v) in &config.routing.policy {
116 policy.insert(k.clone(), v.clone());
117 }
118 if !policy.contains_key("trivial") {
120 policy.insert("trivial".into(), "local".into());
121 }
122 if !policy.contains_key("hard") {
123 policy.insert("hard".into(), "anthropic".into());
124 }
125
126 Self {
127 providers,
128 policy,
129 free_first: config.routing.free_first,
130 preferred_provider: config.routing.preferred_provider.clone(),
131 }
132 }
133
134 fn score(brain: &dyn Brain, need: &RoutingNeed, budget: &BudgetState) -> f64 {
136 let caps = brain.caps();
137 let mut score: f64 = 0.0;
138
139 if need.required_tools {
141 if caps.tools {
142 score += 50.0;
143 } else {
144 score -= 250.0;
145 }
146 }
147 if need.required_vision {
148 if caps.vision {
149 score += 50.0;
150 } else {
151 score -= 300.0;
152 }
153 }
154
155 let est_cost = caps.cost_input_per_mtok + caps.cost_output_per_mtok;
157 if est_cost == 0.0 {
158 score += 100.0; } else if budget.remaining_session() < est_cost * 0.1 {
160 score -= 200.0; } else {
162 score -= est_cost * 10.0; }
164
165 match need.tier {
168 TaskTier::Trivial | TaskTier::Small => match caps.latency {
169 LatencyClass::Fast => score += 15.0,
170 LatencyClass::Medium => score += 6.0,
171 LatencyClass::Slow => score += 0.0,
172 },
173 TaskTier::Medium | TaskTier::Hard | TaskTier::Vision => match caps.latency {
174 LatencyClass::Slow => score += 18.0,
175 LatencyClass::Medium => score += 9.0,
176 LatencyClass::Fast => score += 0.0,
177 },
178 }
179
180 let ctx_weight = match need.tier {
183 TaskTier::Hard | TaskTier::Medium => 20_000.0,
184 _ => 10_000.0,
185 };
186 let ctx_cap = match need.tier {
187 TaskTier::Hard | TaskTier::Medium => 20.0,
188 _ => 10.0,
189 };
190 score += (caps.context_window as f64 / ctx_weight).min(ctx_cap);
191
192 score
193 }
194
195 fn resolve_provider(&self, need: &RoutingNeed) -> &str {
196 if let Some(ref pref) = self.preferred_provider {
201 return pref.as_str();
202 }
203 self.policy
204 .get(need.tier.as_str())
205 .map(|s| s.as_str())
206 .unwrap_or("anthropic")
207 }
208
209 pub async fn classify_with_model(&self, task: &str, brain: &dyn Brain) -> TaskTier {
212 let prompt = format!(
213 "Classify this task into exactly one tier: trivial, small, medium, hard, vision.\n\nTask: {}\n\nTier:",
214 task
215 );
216
217 let req = BrainRequest {
218 system: Some("You are a task classifier. Output exactly one word: trivial, small, medium, hard, or vision.".into()),
219 messages: vec![Msg {
220 role: "user".into(),
221 content: vec![ContentBlock::Text { text: prompt }],
222 }],
223 tools: vec![],
224 max_tokens: 10,
225 temperature: 0.0,
226 stop: vec![],
227 cache: PromptCacheConfig::disabled(),
228 };
229
230 match brain.complete(req).await {
231 Ok(mut stream) => {
232 use futures::StreamExt;
233 let mut result = String::new();
234 while let Some(ev) = stream.next().await {
235 if let crate::provider::BrainEvent::TextDelta(t) = ev {
236 result.push_str(&t);
237 }
238 }
239 TaskTier::from_str(result.trim())
240 }
241 Err(_) => TaskTier::Medium, }
243 }
244}
245
246impl Router for BasicRouter {
247 fn select(&self, need: &RoutingNeed, budget: &BudgetState) -> Vec<Arc<dyn Brain>> {
248 if budget.is_exhausted() && !need.prefer_local {
249 if let Some(local) = self.providers.get("local") {
251 return local.clone();
252 }
253 return vec![];
254 }
255
256 let preferred_provider = self.resolve_provider(need);
257 let preferred_is_local = preferred_provider == "local" || preferred_provider == "ollama";
258 let mut scored: Vec<(f64, String, Arc<dyn Brain>)> = Vec::new();
259
260 for (provider_name, brains) in &self.providers {
262 if need.prefer_local && provider_name != "local" && provider_name != "ollama" {
263 continue;
264 }
265 for brain in brains {
266 let mut s = Self::score(brain.as_ref(), need, budget);
267 if provider_name == preferred_provider {
268 s += 25.0;
269 }
270 if matches!(need.tier, TaskTier::Trivial | TaskTier::Small)
271 && (provider_name == "local" || provider_name == "ollama")
272 {
273 s += 30.0;
274 }
275 scored.push((s, provider_name.clone(), brain.clone()));
276 }
277 }
278
279 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
282
283 const MAX_CHAIN: usize = 6;
289 const PER_PROVIDER_CAP: usize = 3;
290 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
291 let mut per_provider: HashMap<String, usize> = HashMap::new();
292 let mut result: Vec<(String, Arc<dyn Brain>)> = Vec::new();
294
295 for (_, prov, brain) in &scored {
297 if result.len() >= MAX_CHAIN {
298 break;
299 }
300 let id = brain.id().to_string();
301 if seen.contains(&id) {
302 continue;
303 }
304 let count = per_provider.entry(prov.clone()).or_insert(0);
305 if *count >= PER_PROVIDER_CAP {
306 continue;
307 }
308 *count += 1;
309 seen.insert(id);
310 result.push((prov.clone(), brain.clone()));
311 }
312 if result.len() < MAX_CHAIN {
314 for (_, prov, brain) in &scored {
315 if result.len() >= MAX_CHAIN {
316 break;
317 }
318 let id = brain.id().to_string();
319 if seen.insert(id) {
320 result.push((prov.clone(), brain.clone()));
321 }
322 }
323 }
324
325 if matches!(need.tier, TaskTier::Trivial | TaskTier::Small)
328 && (preferred_is_local || self.free_first)
329 {
330 if let Some(pos) = result.iter().position(|(prov, b)| {
331 (prov == "local" || prov == "ollama") || b.caps().cost_input_per_mtok == 0.0
332 }) {
333 let chosen = result.remove(pos);
334 result.insert(0, chosen);
335 }
336 }
337
338 result.into_iter().map(|(_, brain)| brain).collect()
339 }
340
341 fn on_error(&self, _b: &dyn Brain, e: &BrainError) -> Retry {
342 match e {
343 BrainError::RateLimit { retry_after } => {
344 if let Some(secs) = retry_after {
345 if *secs <= 10 {
346 Retry::WaitAndRetry(*secs)
347 } else {
348 Retry::NextInChain
349 }
350 } else {
351 Retry::NextInChain
352 }
353 }
354 BrainError::ServerError { status, .. } if *status >= 500 => Retry::NextInChain,
355 BrainError::Timeout => Retry::NextInChain,
356 BrainError::Refusal(_) => Retry::Abort,
357 _ => Retry::NextInChain,
358 }
359 }
360
361 fn find_brain_by_id(&self, model_id: &str) -> Option<Arc<dyn Brain>> {
362 for brains in self.providers.values() {
363 for b in brains {
364 if b.id() == model_id {
365 return Some(b.clone());
366 }
367 }
368 }
369 None
370 }
371}