1pub mod learned;
2
3use std::sync::Arc;
4
5use crate::provider::{
6 Brain, BrainError, BrainRequest, ContentBlock, LatencyClass, Msg, PromptCacheConfig,
7};
8
/// Coarse difficulty class of a task.
///
/// The router uses the tier to look up a provider in its tier → provider
/// policy (keyed by [`TaskTier::as_str`]) and to weight latency and
/// context-window preferences when scoring brains.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TaskTier {
    Trivial,
    Small,
    Medium,
    Hard,
    Vision,
}
19
20impl TaskTier {
21 pub fn from_str(s: &str) -> Self {
22 match s.to_lowercase().as_str() {
23 "trivial" => TaskTier::Trivial,
24 "small" => TaskTier::Small,
25 "medium" => TaskTier::Medium,
26 "hard" => TaskTier::Hard,
27 "vision" => TaskTier::Vision,
28 _ => TaskTier::Medium,
29 }
30 }
31
32 pub fn as_str(&self) -> &str {
33 match self {
34 TaskTier::Trivial => "trivial",
35 TaskTier::Small => "small",
36 TaskTier::Medium => "medium",
37 TaskTier::Hard => "hard",
38 TaskTier::Vision => "vision",
39 }
40 }
41}
42
43#[derive(Debug, Clone)]
44pub struct RoutingNeed {
45 pub tier: TaskTier,
46 pub required_tools: bool,
47 pub required_vision: bool,
48 pub prefer_local: bool,
49}
50
/// Spending limits and amounts spent so far, in US dollars.
#[derive(Clone, Debug)]
pub struct BudgetState {
    /// Maximum spend allowed for the day.
    pub daily_limit_usd: f64,
    /// Amount already spent today.
    pub daily_spent_usd: f64,
    /// Maximum spend allowed for the current session.
    pub session_limit_usd: f64,
    /// Amount already spent in the current session.
    pub session_spent_usd: f64,
}
58
59impl BudgetState {
60 pub fn remaining_daily(&self) -> f64 {
61 (self.daily_limit_usd - self.daily_spent_usd).max(0.0)
62 }
63
64 pub fn remaining_session(&self) -> f64 {
65 (self.session_limit_usd - self.session_spent_usd).max(0.0)
66 }
67
68 pub fn is_exhausted(&self) -> bool {
69 self.remaining_daily() <= 0.0 || self.remaining_session() <= 0.0
70 }
71}
72
73pub trait Router: Send + Sync {
76 fn select(&self, need: &RoutingNeed, budget: &BudgetState) -> Vec<Arc<dyn Brain>>;
79
80 fn on_error(&self, b: &dyn Brain, e: &BrainError) -> Retry;
81
82 fn find_brain_by_id(&self, model_id: &str) -> Option<Arc<dyn Brain>>;
86}
87
/// A router's decision after a brain call fails.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Retry {
    /// Give up on this brain and try the next one in the selected chain.
    NextInChain,
    /// Stop trying altogether.
    Abort,
    /// Wait, then retry the same brain. `BasicRouter::on_error` fills this
    /// from a rate-limit error's `retry_after` value, presumably in seconds.
    WaitAndRetry(u64),
}
94
95use std::collections::HashMap;
98
99use crate::config::Config;
100
101pub struct BasicRouter {
102 providers: HashMap<String, Vec<Arc<dyn Brain>>>,
104 policy: HashMap<String, String>,
106 free_first: bool,
107 preferred_provider: Option<String>,
110 preferred_model: Option<String>,
111 routing_mode: String,
113}
114
115impl BasicRouter {
116 pub fn new(config: &Config, providers: HashMap<String, Vec<Arc<dyn Brain>>>) -> Self {
117 let mut policy = HashMap::new();
118 for (k, v) in &config.routing.policy {
119 policy.insert(k.clone(), v.clone());
120 }
121 if !policy.contains_key("trivial") {
123 policy.insert("trivial".into(), "local".into());
124 }
125 if !policy.contains_key("hard") {
126 policy.insert("hard".into(), "anthropic".into());
127 }
128
129 Self {
130 providers,
131 policy,
132 free_first: config.routing.free_first,
133 preferred_provider: config.routing.preferred_provider.clone(),
134 preferred_model: config.routing.preferred_model.clone(),
135 routing_mode: config.routing.routing_mode.clone(),
136 }
137 }
138
139 fn score(brain: &dyn Brain, need: &RoutingNeed, budget: &BudgetState) -> f64 {
141 let caps = brain.caps();
142 let mut score: f64 = 0.0;
143
144 if need.required_tools {
146 if caps.tools {
147 score += 50.0;
148 } else {
149 score -= 250.0;
150 }
151 }
152 if need.required_vision {
153 if caps.vision {
154 score += 50.0;
155 } else {
156 score -= 300.0;
157 }
158 }
159
160 let est_cost = caps.cost_input_per_mtok + caps.cost_output_per_mtok;
162 if est_cost == 0.0 {
163 score += 100.0; } else if budget.remaining_session() < est_cost * 0.1 {
165 score -= 200.0; } else {
167 score -= est_cost * 10.0; }
169
170 match need.tier {
173 TaskTier::Trivial | TaskTier::Small => match caps.latency {
174 LatencyClass::Fast => score += 15.0,
175 LatencyClass::Medium => score += 6.0,
176 LatencyClass::Slow => score += 0.0,
177 },
178 TaskTier::Medium | TaskTier::Hard | TaskTier::Vision => match caps.latency {
179 LatencyClass::Slow => score += 18.0,
180 LatencyClass::Medium => score += 9.0,
181 LatencyClass::Fast => score += 0.0,
182 },
183 }
184
185 let ctx_weight = match need.tier {
188 TaskTier::Hard | TaskTier::Medium => 20_000.0,
189 _ => 10_000.0,
190 };
191 let ctx_cap = match need.tier {
192 TaskTier::Hard | TaskTier::Medium => 20.0,
193 _ => 10.0,
194 };
195 score += (caps.context_window as f64 / ctx_weight).min(ctx_cap);
196
197 score
198 }
199
200 fn resolve_provider(&self, need: &RoutingNeed) -> &str {
201 if self.routing_mode == "manual" {
203 if let Some(ref pref) = self.preferred_provider {
204 return pref.as_str();
205 }
206 }
207 if let Some(ref pref) = self.preferred_provider {
212 return pref.as_str();
213 }
214 self.policy
215 .get(need.tier.as_str())
216 .map(|s| s.as_str())
217 .unwrap_or("anthropic")
218 }
219
220 pub async fn classify_with_model(&self, task: &str, brain: &dyn Brain) -> TaskTier {
223 let prompt = format!(
224 "Classify this task into exactly one tier: trivial, small, medium, hard, vision.\n\nTask: {}\n\nTier:",
225 task
226 );
227
228 let req = BrainRequest {
229 system: Some("You are a task classifier. Output exactly one word: trivial, small, medium, hard, or vision.".into()),
230 messages: vec![Msg {
231 role: "user".into(),
232 content: vec![ContentBlock::Text { text: prompt }],
233 }],
234 tools: vec![],
235 max_tokens: 10,
236 temperature: 0.0,
237 stop: vec![],
238 cache: PromptCacheConfig::disabled(),
239 };
240
241 match brain.complete(req).await {
242 Ok(mut stream) => {
243 use futures::StreamExt;
244 let mut result = String::new();
245 while let Some(ev) = stream.next().await {
246 if let crate::provider::BrainEvent::TextDelta(t) = ev {
247 result.push_str(&t);
248 }
249 }
250 TaskTier::from_str(result.trim())
251 }
252 Err(_) => TaskTier::Medium, }
254 }
255}
256
257impl Router for BasicRouter {
258 fn select(&self, need: &RoutingNeed, budget: &BudgetState) -> Vec<Arc<dyn Brain>> {
259 if self.routing_mode == "manual" {
261 if let Some(ref model) = self.preferred_model {
262 for (_, brains) in &self.providers {
263 for brain in brains {
264 if brain.id() == *model {
265 return vec![brain.clone()];
266 }
267 }
268 }
269 return vec![];
271 }
272 }
273
274 if budget.is_exhausted() && !need.prefer_local {
275 if let Some(local) = self.providers.get("local") {
277 return local.clone();
278 }
279 return vec![];
280 }
281
282 let preferred_provider = self.resolve_provider(need);
283 let preferred_is_local = preferred_provider == "local" || preferred_provider == "ollama";
284 let mut scored: Vec<(f64, String, Arc<dyn Brain>)> = Vec::new();
285
286 for (provider_name, brains) in &self.providers {
288 if need.prefer_local && provider_name != "local" && provider_name != "ollama" {
289 continue;
290 }
291 for brain in brains {
292 let mut s = Self::score(brain.as_ref(), need, budget);
293 if provider_name == preferred_provider {
294 s += 25.0;
295 }
296 if matches!(need.tier, TaskTier::Trivial | TaskTier::Small)
297 && (provider_name == "local" || provider_name == "ollama")
298 {
299 s += 30.0;
300 }
301 scored.push((s, provider_name.clone(), brain.clone()));
302 }
303 }
304
305 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
308
309 const MAX_CHAIN: usize = 6;
315 const PER_PROVIDER_CAP: usize = 3;
316 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
317 let mut per_provider: HashMap<String, usize> = HashMap::new();
318 let mut result: Vec<(String, Arc<dyn Brain>)> = Vec::new();
320
321 for (_, prov, brain) in &scored {
323 if result.len() >= MAX_CHAIN {
324 break;
325 }
326 let id = brain.id().to_string();
327 if seen.contains(&id) {
328 continue;
329 }
330 let count = per_provider.entry(prov.clone()).or_insert(0);
331 if *count >= PER_PROVIDER_CAP {
332 continue;
333 }
334 *count += 1;
335 seen.insert(id);
336 result.push((prov.clone(), brain.clone()));
337 }
338 if result.len() < MAX_CHAIN {
340 for (_, prov, brain) in &scored {
341 if result.len() >= MAX_CHAIN {
342 break;
343 }
344 let id = brain.id().to_string();
345 if seen.insert(id) {
346 result.push((prov.clone(), brain.clone()));
347 }
348 }
349 }
350
351 if matches!(need.tier, TaskTier::Trivial | TaskTier::Small)
354 && (preferred_is_local || self.free_first)
355 && self.routing_mode != "manual"
356 {
357 if let Some(pos) = result.iter().position(|(prov, b)| {
358 (prov == "local" || prov == "ollama") || b.caps().cost_input_per_mtok == 0.0
359 }) {
360 let chosen = result.remove(pos);
361 result.insert(0, chosen);
362 }
363 }
364
365 result.into_iter().map(|(_, brain)| brain).collect()
366 }
367
368 fn on_error(&self, _b: &dyn Brain, e: &BrainError) -> Retry {
369 match e {
370 BrainError::RateLimit { retry_after } => {
371 if let Some(secs) = retry_after {
372 if *secs <= 10 {
373 Retry::WaitAndRetry(*secs)
374 } else {
375 Retry::NextInChain
376 }
377 } else {
378 Retry::NextInChain
379 }
380 }
381 BrainError::ServerError { status, .. } if *status >= 500 => Retry::NextInChain,
382 BrainError::Timeout => Retry::NextInChain,
383 BrainError::Refusal(_) => Retry::Abort,
384 _ => Retry::NextInChain,
385 }
386 }
387
388 fn find_brain_by_id(&self, model_id: &str) -> Option<Arc<dyn Brain>> {
389 for brains in self.providers.values() {
390 for b in brains {
391 if b.id() == model_id {
392 return Some(b.clone());
393 }
394 }
395 }
396 None
397 }
398}