1use serde::{Deserialize, Serialize};
2
3use crate::config::loader::VTCodeConfig;
4use crate::config::types::AgentConfig as CoreAgentConfig;
5use crate::llm::{factory::create_provider_for_model, provider as uni};
6
7#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
8pub enum TaskClass {
9 Simple,
10 Standard,
11 Complex,
12 CodegenHeavy,
13 RetrievalHeavy,
14}
15
16impl std::fmt::Display for TaskClass {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 match self {
19 TaskClass::Simple => write!(f, "simple"),
20 TaskClass::Standard => write!(f, "standard"),
21 TaskClass::Complex => write!(f, "complex"),
22 TaskClass::CodegenHeavy => write!(f, "codegen_heavy"),
23 TaskClass::RetrievalHeavy => write!(f, "retrieval_heavy"),
24 }
25 }
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct RouteDecision {
30 pub class: TaskClass,
31 pub selected_model: String,
32}
33
34pub struct Router;
35
36impl Router {
37 pub fn classify_heuristic(input: &str) -> TaskClass {
38 let text = input.to_lowercase();
39 let has_code_fence = text.contains("```") || text.contains("diff --git");
40 let has_patch_keywords = [
41 "apply_patch",
42 "unified diff",
43 "patch",
44 "edit_file",
45 "create_file",
46 ]
47 .iter()
48 .any(|k| text.contains(k));
49 let retrieval = [
50 "search",
51 "web",
52 "google",
53 "docs",
54 "cite",
55 "source",
56 "up-to-date",
57 ]
58 .iter()
59 .any(|k| text.contains(k));
60 let complex_markers = [
61 "plan",
62 "multi-step",
63 "decompose",
64 "orchestrate",
65 "architecture",
66 "benchmark",
67 "implement end-to-end",
68 "design api",
69 "refactor module",
70 "evaluate",
71 "tests suite",
72 ];
73 let complex = complex_markers.iter().any(|k| text.contains(k));
74 let long = text.len() > 1200;
75
76 if has_code_fence || has_patch_keywords {
77 return TaskClass::CodegenHeavy;
78 }
79 if retrieval {
80 return TaskClass::RetrievalHeavy;
81 }
82 if complex || long {
83 return TaskClass::Complex;
84 }
85 if text.len() < 120 {
86 return TaskClass::Simple;
87 }
88 TaskClass::Standard
89 }
90
91 pub fn route(vt_cfg: &VTCodeConfig, core: &CoreAgentConfig, input: &str) -> RouteDecision {
92 let router_cfg = &vt_cfg.router;
93 let class = if router_cfg.heuristic_classification {
94 Self::classify_heuristic(input)
95 } else {
96 TaskClass::Standard
98 };
99
100 let model = match class {
101 TaskClass::Simple => non_empty_or(&router_cfg.models.simple, &core.model),
102 TaskClass::Standard => non_empty_or(&router_cfg.models.standard, &core.model),
103 TaskClass::Complex => non_empty_or(&router_cfg.models.complex, &core.model),
104 TaskClass::CodegenHeavy => non_empty_or(&router_cfg.models.codegen_heavy, &core.model),
105 TaskClass::RetrievalHeavy => {
106 non_empty_or(&router_cfg.models.retrieval_heavy, &core.model)
107 }
108 };
109
110 RouteDecision {
111 class,
112 selected_model: model.to_string(),
113 }
114 }
115
116 pub async fn route_async(
119 vt_cfg: &VTCodeConfig,
120 core: &CoreAgentConfig,
121 api_key: &str,
122 input: &str,
123 ) -> RouteDecision {
124 let router_cfg = &vt_cfg.router;
125 let mut class = if router_cfg.heuristic_classification {
126 Self::classify_heuristic(input)
127 } else {
128 TaskClass::Standard
129 };
130
131 if !router_cfg.llm_router_model.trim().is_empty() {
132 if let Ok(provider) =
133 create_provider_for_model(&router_cfg.llm_router_model, api_key.to_string())
134 {
135 let sys = "You are a routing classifier. Output only one label: simple | standard | complex | codegen_heavy | retrieval_heavy. Choose the best class for the user's last message. No prose.".to_string();
136 let req = uni::LLMRequest {
137 messages: vec![uni::Message::user(input.to_string())],
138 system_prompt: Some(sys),
139 tools: None,
140 model: router_cfg.llm_router_model.clone(),
141 max_tokens: Some(8),
142 temperature: Some(0.0),
143 stream: false,
144 tool_choice: Some(uni::ToolChoice::none()),
145 parallel_tool_calls: None,
146 parallel_tool_config: None,
147 reasoning_effort: Some(vt_cfg.agent.reasoning_effort.clone()),
148 };
149 if let Ok(resp) = provider.generate(req).await {
150 if let Some(text) = resp.content {
151 let t = text.trim().to_lowercase();
152 class = match t {
153 x if x.contains("codegen") => TaskClass::CodegenHeavy,
154 x if x.contains("retrieval") => TaskClass::RetrievalHeavy,
155 x if x.contains("complex") => TaskClass::Complex,
156 x if x.contains("simple") => TaskClass::Simple,
157 _ => TaskClass::Standard,
158 };
159 }
160 }
161 }
162 }
163
164 let model = match class {
165 TaskClass::Simple => non_empty_or(&router_cfg.models.simple, &core.model),
166 TaskClass::Standard => non_empty_or(&router_cfg.models.standard, &core.model),
167 TaskClass::Complex => non_empty_or(&router_cfg.models.complex, &core.model),
168 TaskClass::CodegenHeavy => non_empty_or(&router_cfg.models.codegen_heavy, &core.model),
169 TaskClass::RetrievalHeavy => {
170 non_empty_or(&router_cfg.models.retrieval_heavy, &core.model)
171 }
172 };
173
174 RouteDecision {
175 class,
176 selected_model: model.to_string(),
177 }
178 }
179}
180
/// Returns `value` unless it is empty or whitespace-only, in which case `fallback` is returned.
///
/// Whitespace only affects the emptiness check: the chosen string is returned untrimmed.
fn non_empty_or<'a>(value: &'a str, fallback: &'a str) -> &'a str {
    match value.trim() {
        "" => fallback,
        _ => value,
    }
}
188
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_class_display() {
        assert_eq!(format!("{}", TaskClass::Simple), "simple");
        assert_eq!(format!("{}", TaskClass::Standard), "standard");
        assert_eq!(format!("{}", TaskClass::Complex), "complex");
        assert_eq!(format!("{}", TaskClass::CodegenHeavy), "codegen_heavy");
        assert_eq!(format!("{}", TaskClass::RetrievalHeavy), "retrieval_heavy");
    }

    /// Each keyword family maps to its class.
    #[test]
    fn test_classify_heuristic_keyword_signals() {
        assert_eq!(
            Router::classify_heuristic("```rust\nfn main() {}\n```"),
            TaskClass::CodegenHeavy
        );
        assert_eq!(
            Router::classify_heuristic("Search the web for the latest tokio docs"),
            TaskClass::RetrievalHeavy
        );
        assert_eq!(
            Router::classify_heuristic("Design the architecture for a multi-step pipeline"),
            TaskClass::Complex
        );
    }

    /// Without keywords, input length decides between Simple, Standard and Complex.
    #[test]
    fn test_classify_heuristic_length_buckets() {
        assert_eq!(Router::classify_heuristic("fix typo"), TaskClass::Simple);
        assert_eq!(
            Router::classify_heuristic(&"a".repeat(200)),
            TaskClass::Standard
        );
        assert_eq!(
            Router::classify_heuristic(&"word ".repeat(300)),
            TaskClass::Complex
        );
    }

    #[test]
    fn test_non_empty_or_falls_back_on_blank() {
        assert_eq!(non_empty_or("", "default"), "default");
        assert_eq!(non_empty_or("   ", "default"), "default");
        assert_eq!(non_empty_or("gpt", "default"), "gpt");
    }
}