1use serde::{Deserialize, Serialize};
2
3use crate::config::loader::VTCodeConfig;
4use crate::config::types::AgentConfig as CoreAgentConfig;
5use crate::llm::{
6 factory::{create_provider_with_config, get_factory},
7 provider as uni,
8};
9use crate::models::ModelId;
10
11#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
12pub enum TaskClass {
13 Simple,
14 Standard,
15 Complex,
16 CodegenHeavy,
17 RetrievalHeavy,
18}
19
20impl std::fmt::Display for TaskClass {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 match self {
23 TaskClass::Simple => write!(f, "simple"),
24 TaskClass::Standard => write!(f, "standard"),
25 TaskClass::Complex => write!(f, "complex"),
26 TaskClass::CodegenHeavy => write!(f, "codegen_heavy"),
27 TaskClass::RetrievalHeavy => write!(f, "retrieval_heavy"),
28 }
29 }
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct RouteDecision {
34 pub class: TaskClass,
35 pub selected_model: String,
36}
37
/// Stateless entry point for task classification and model routing.
///
/// All functionality is exposed through associated functions; the type holds
/// no data. The common traits are derived so the public type can be printed,
/// copied, and default-constructed like any other unit value.
#[derive(Debug, Clone, Copy, Default)]
pub struct Router;
39
40impl Router {
41 pub fn classify_heuristic(input: &str) -> TaskClass {
42 let text = input.to_lowercase();
43 let has_code_fence = text.contains("```") || text.contains("diff --git");
44 let has_patch_keywords = [
45 "apply_patch",
46 "unified diff",
47 "patch",
48 "edit_file",
49 "create_file",
50 ]
51 .iter()
52 .any(|k| text.contains(k));
53 let retrieval = [
54 "search",
55 "web",
56 "google",
57 "docs",
58 "cite",
59 "source",
60 "up-to-date",
61 ]
62 .iter()
63 .any(|k| text.contains(k));
64 let complex_markers = [
65 "plan",
66 "multi-step",
67 "decompose",
68 "orchestrate",
69 "architecture",
70 "benchmark",
71 "implement end-to-end",
72 "design api",
73 "refactor module",
74 "evaluate",
75 "tests suite",
76 ];
77 let complex = complex_markers.iter().any(|k| text.contains(k));
78 let long = text.len() > 1200;
79
80 if has_code_fence || has_patch_keywords {
81 return TaskClass::CodegenHeavy;
82 }
83 if retrieval {
84 return TaskClass::RetrievalHeavy;
85 }
86 if complex || long {
87 return TaskClass::Complex;
88 }
89 if text.len() < 120 {
90 return TaskClass::Simple;
91 }
92 TaskClass::Standard
93 }
94
95 pub fn route(vt_cfg: &VTCodeConfig, core: &CoreAgentConfig, input: &str) -> RouteDecision {
96 let router_cfg = &vt_cfg.router;
97 let class = if router_cfg.heuristic_classification {
98 Self::classify_heuristic(input)
99 } else {
100 TaskClass::Standard
102 };
103
104 let model = match class {
105 TaskClass::Simple => non_empty_or(&router_cfg.models.simple, &core.model),
106 TaskClass::Standard => non_empty_or(&router_cfg.models.standard, &core.model),
107 TaskClass::Complex => non_empty_or(&router_cfg.models.complex, &core.model),
108 TaskClass::CodegenHeavy => non_empty_or(&router_cfg.models.codegen_heavy, &core.model),
109 TaskClass::RetrievalHeavy => {
110 non_empty_or(&router_cfg.models.retrieval_heavy, &core.model)
111 }
112 };
113
114 RouteDecision {
115 class,
116 selected_model: model.to_string(),
117 }
118 }
119
120 pub async fn route_async(
123 vt_cfg: &VTCodeConfig,
124 core: &CoreAgentConfig,
125 api_key: &str,
126 input: &str,
127 ) -> RouteDecision {
128 let router_cfg = &vt_cfg.router;
129 let mut class = if router_cfg.heuristic_classification {
130 Self::classify_heuristic(input)
131 } else {
132 TaskClass::Standard
133 };
134
135 if !router_cfg.llm_router_model.trim().is_empty() {
136 let provider_name = if core.provider.trim().is_empty() {
137 core.model
138 .parse::<ModelId>()
139 .ok()
140 .map(|model| model.provider().to_string())
141 .or_else(|| {
142 let factory = get_factory().lock().unwrap();
143 factory.provider_from_model(core.model.as_str())
144 })
145 .unwrap_or_else(|| "gemini".to_string())
146 } else {
147 core.provider.to_lowercase()
148 };
149 if let Ok(provider) = create_provider_with_config(
150 &provider_name,
151 Some(api_key.to_string()),
152 None,
153 Some(router_cfg.llm_router_model.clone()),
154 ) {
155 let sys = "You are a routing classifier. Output only one label: simple | standard | complex | codegen_heavy | retrieval_heavy. Choose the best class for the user's last message. No prose.".to_string();
156 let req = uni::LLMRequest {
157 messages: vec![uni::Message::user(input.to_string())],
158 system_prompt: Some(sys),
159 tools: None,
160 model: router_cfg.llm_router_model.clone(),
161 max_tokens: Some(8),
162 temperature: Some(0.0),
163 stream: false,
164 tool_choice: Some(uni::ToolChoice::none()),
165 parallel_tool_calls: None,
166 parallel_tool_config: None,
167 reasoning_effort: Some(vt_cfg.agent.reasoning_effort.clone()),
168 };
169 if let Ok(resp) = provider.generate(req).await {
170 if let Some(text) = resp.content {
171 let t = text.trim().to_lowercase();
172 class = match t {
173 x if x.contains("codegen") => TaskClass::CodegenHeavy,
174 x if x.contains("retrieval") => TaskClass::RetrievalHeavy,
175 x if x.contains("complex") => TaskClass::Complex,
176 x if x.contains("simple") => TaskClass::Simple,
177 _ => TaskClass::Standard,
178 };
179 }
180 }
181 }
182 }
183
184 let model = match class {
185 TaskClass::Simple => non_empty_or(&router_cfg.models.simple, &core.model),
186 TaskClass::Standard => non_empty_or(&router_cfg.models.standard, &core.model),
187 TaskClass::Complex => non_empty_or(&router_cfg.models.complex, &core.model),
188 TaskClass::CodegenHeavy => non_empty_or(&router_cfg.models.codegen_heavy, &core.model),
189 TaskClass::RetrievalHeavy => {
190 non_empty_or(&router_cfg.models.retrieval_heavy, &core.model)
191 }
192 };
193
194 RouteDecision {
195 class,
196 selected_model: model.to_string(),
197 }
198 }
199}
200
/// Returns `value` unchanged unless it is empty or whitespace-only, in which
/// case `fallback` is returned instead.
fn non_empty_or<'a>(value: &'a str, fallback: &'a str) -> &'a str {
    match value.trim() {
        "" => fallback,
        _ => value,
    }
}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant must render as its snake_case label.
    #[test]
    fn test_task_class_display() {
        let expected = [
            (TaskClass::Simple, "simple"),
            (TaskClass::Standard, "standard"),
            (TaskClass::Complex, "complex"),
            (TaskClass::CodegenHeavy, "codegen_heavy"),
            (TaskClass::RetrievalHeavy, "retrieval_heavy"),
        ];

        for (class, label) in expected.iter() {
            assert_eq!(class.to_string(), *label);
        }
    }
}