1use serde::{Deserialize, Serialize};
2
3use crate::config::loader::VTCodeConfig;
4use crate::config::types::AgentConfig as CoreAgentConfig;
5use crate::llm::{
6 factory::{create_provider_with_config, get_factory},
7 provider as uni,
8};
9use crate::models::ModelId;
10
/// Coarse category assigned to a user request so the router can pick a model for it.
///
/// Produced by `Router::classify_heuristic` and, in `Router::route_async`, optionally
/// overridden by an LLM classifier.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
pub enum TaskClass {
    /// Short request (the heuristic uses fewer than 120 bytes) that matches no other rule.
    Simple,
    /// Default class: used when no other rule matches or when heuristic
    /// classification is disabled in the router config.
    Standard,
    /// Request containing planning/architecture style markers, or a long input
    /// (the heuristic uses more than 1200 bytes).
    Complex,
    /// Request containing a code fence, a `diff --git` header, or patch/file-edit keywords.
    CodegenHeavy,
    /// Request containing search/web/docs/citation style keywords.
    RetrievalHeavy,
}
19
20impl std::fmt::Display for TaskClass {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 match self {
23 TaskClass::Simple => write!(f, "simple"),
24 TaskClass::Standard => write!(f, "standard"),
25 TaskClass::Complex => write!(f, "complex"),
26 TaskClass::CodegenHeavy => write!(f, "codegen_heavy"),
27 TaskClass::RetrievalHeavy => write!(f, "retrieval_heavy"),
28 }
29 }
30}
31
/// Outcome of a routing pass: the class assigned to the request and the model chosen for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteDecision {
    /// Class the request was assigned to.
    pub class: TaskClass,
    /// Model identifier selected for `class`: the router's per-class override when it is
    /// non-blank, otherwise the core agent's configured model.
    pub selected_model: String,
}
37
/// Stateless entry point for request routing; all functionality is exposed as
/// associated functions that take the configuration as arguments.
pub struct Router;
39
40impl Router {
41 pub fn classify_heuristic(input: &str) -> TaskClass {
42 let text = input.to_lowercase();
43 let has_code_fence = text.contains("```") || text.contains("diff --git");
44 let has_patch_keywords = [
45 "apply_patch",
46 "unified diff",
47 "patch",
48 "edit_file",
49 "create_file",
50 ]
51 .iter()
52 .any(|k| text.contains(k));
53 let retrieval = [
54 "search",
55 "web",
56 "google",
57 "docs",
58 "cite",
59 "source",
60 "up-to-date",
61 ]
62 .iter()
63 .any(|k| text.contains(k));
64 let complex_markers = [
65 "plan",
66 "multi-step",
67 "decompose",
68 "orchestrate",
69 "architecture",
70 "benchmark",
71 "implement end-to-end",
72 "design api",
73 "refactor module",
74 "evaluate",
75 "tests suite",
76 ];
77 let complex = complex_markers.iter().any(|k| text.contains(k));
78 let long = text.len() > 1200;
79
80 if has_code_fence || has_patch_keywords {
81 return TaskClass::CodegenHeavy;
82 }
83 if retrieval {
84 return TaskClass::RetrievalHeavy;
85 }
86 if complex || long {
87 return TaskClass::Complex;
88 }
89 if text.len() < 120 {
90 return TaskClass::Simple;
91 }
92 TaskClass::Standard
93 }
94
95 pub fn route(vt_cfg: &VTCodeConfig, core: &CoreAgentConfig, input: &str) -> RouteDecision {
96 let router_cfg = &vt_cfg.router;
97 let class = if router_cfg.heuristic_classification {
98 Self::classify_heuristic(input)
99 } else {
100 TaskClass::Standard
102 };
103
104 let model = match class {
105 TaskClass::Simple => non_empty_or(&router_cfg.models.simple, &core.model),
106 TaskClass::Standard => non_empty_or(&router_cfg.models.standard, &core.model),
107 TaskClass::Complex => non_empty_or(&router_cfg.models.complex, &core.model),
108 TaskClass::CodegenHeavy => non_empty_or(&router_cfg.models.codegen_heavy, &core.model),
109 TaskClass::RetrievalHeavy => {
110 non_empty_or(&router_cfg.models.retrieval_heavy, &core.model)
111 }
112 };
113
114 RouteDecision {
115 class,
116 selected_model: model.to_string(),
117 }
118 }
119
120 pub async fn route_async(
123 vt_cfg: &VTCodeConfig,
124 core: &CoreAgentConfig,
125 api_key: &str,
126 input: &str,
127 ) -> RouteDecision {
128 let router_cfg = &vt_cfg.router;
129 let mut class = if router_cfg.heuristic_classification {
130 Self::classify_heuristic(input)
131 } else {
132 TaskClass::Standard
133 };
134
135 if !router_cfg.llm_router_model.trim().is_empty() {
136 let provider_name = if core.provider.trim().is_empty() {
137 core.model
138 .parse::<ModelId>()
139 .ok()
140 .map(|model| model.provider().to_string())
141 .or_else(|| {
142 let factory = get_factory().lock().unwrap();
143 factory.provider_from_model(core.model.as_str())
144 })
145 .unwrap_or_else(|| "gemini".to_string())
146 } else {
147 core.provider.to_lowercase()
148 };
149 if let Ok(provider) = create_provider_with_config(
150 &provider_name,
151 Some(api_key.to_string()),
152 None,
153 Some(router_cfg.llm_router_model.clone()),
154 Some(core.prompt_cache.clone()),
155 ) {
156 let sys = "You are a routing classifier. Output only one label: simple | standard | complex | codegen_heavy | retrieval_heavy. Choose the best class for the user's last message. No prose.".to_string();
157 let supports_effort =
158 provider.supports_reasoning_effort(&router_cfg.llm_router_model);
159 let reasoning_effort = if supports_effort {
160 Some(vt_cfg.agent.reasoning_effort.as_str().to_string())
161 } else {
162 None
163 };
164 let req = uni::LLMRequest {
165 messages: vec![uni::Message::user(input.to_string())],
166 system_prompt: Some(sys),
167 tools: None,
168 model: router_cfg.llm_router_model.clone(),
169 max_tokens: Some(8),
170 temperature: Some(0.0),
171 stream: false,
172 tool_choice: Some(uni::ToolChoice::none()),
173 parallel_tool_calls: None,
174 parallel_tool_config: None,
175 reasoning_effort,
176 };
177 if let Ok(resp) = provider.generate(req).await {
178 if let Some(text) = resp.content {
179 let t = text.trim().to_lowercase();
180 class = match t {
181 x if x.contains("codegen") => TaskClass::CodegenHeavy,
182 x if x.contains("retrieval") => TaskClass::RetrievalHeavy,
183 x if x.contains("complex") => TaskClass::Complex,
184 x if x.contains("simple") => TaskClass::Simple,
185 _ => TaskClass::Standard,
186 };
187 }
188 }
189 }
190 }
191
192 let model = match class {
193 TaskClass::Simple => non_empty_or(&router_cfg.models.simple, &core.model),
194 TaskClass::Standard => non_empty_or(&router_cfg.models.standard, &core.model),
195 TaskClass::Complex => non_empty_or(&router_cfg.models.complex, &core.model),
196 TaskClass::CodegenHeavy => non_empty_or(&router_cfg.models.codegen_heavy, &core.model),
197 TaskClass::RetrievalHeavy => {
198 non_empty_or(&router_cfg.models.retrieval_heavy, &core.model)
199 }
200 };
201
202 RouteDecision {
203 class,
204 selected_model: model.to_string(),
205 }
206 }
207}
208
/// Returns `value` unchanged unless it is empty or whitespace-only, in which case
/// `fallback` is returned instead.
fn non_empty_or<'a>(value: &'a str, fallback: &'a str) -> &'a str {
    match value.trim() {
        "" => fallback,
        _ => value,
    }
}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant must render as its lowercase snake_case label.
    #[test]
    fn test_task_class_display() {
        let expected = [
            (TaskClass::Simple, "simple"),
            (TaskClass::Standard, "standard"),
            (TaskClass::Complex, "complex"),
            (TaskClass::CodegenHeavy, "codegen_heavy"),
            (TaskClass::RetrievalHeavy, "retrieval_heavy"),
        ];

        for (class, label) in expected {
            assert_eq!(class.to_string(), label);
        }
    }
}