vtcode_core/core/router.rs

1use serde::{Deserialize, Serialize};
2
3use crate::config::loader::VTCodeConfig;
4use crate::config::types::AgentConfig as CoreAgentConfig;
5use crate::llm::{
6    factory::{create_provider_with_config, get_factory},
7    provider as uni,
8};
9use crate::models::ModelId;
10
11#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
12pub enum TaskClass {
13    Simple,
14    Standard,
15    Complex,
16    CodegenHeavy,
17    RetrievalHeavy,
18}
19
20impl std::fmt::Display for TaskClass {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        match self {
23            TaskClass::Simple => write!(f, "simple"),
24            TaskClass::Standard => write!(f, "standard"),
25            TaskClass::Complex => write!(f, "complex"),
26            TaskClass::CodegenHeavy => write!(f, "codegen_heavy"),
27            TaskClass::RetrievalHeavy => write!(f, "retrieval_heavy"),
28        }
29    }
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct RouteDecision {
34    pub class: TaskClass,
35    pub selected_model: String,
36}
37
/// Stateless entry point for task classification and model routing.
///
/// All functionality is exposed as associated functions; the unit struct carries no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct Router;
39
40impl Router {
41    pub fn classify_heuristic(input: &str) -> TaskClass {
42        let text = input.to_lowercase();
43        let has_code_fence = text.contains("```") || text.contains("diff --git");
44        let has_patch_keywords = [
45            "apply_patch",
46            "unified diff",
47            "patch",
48            "edit_file",
49            "create_file",
50        ]
51        .iter()
52        .any(|k| text.contains(k));
53        let retrieval = [
54            "search",
55            "web",
56            "google",
57            "docs",
58            "cite",
59            "source",
60            "up-to-date",
61        ]
62        .iter()
63        .any(|k| text.contains(k));
64        let complex_markers = [
65            "plan",
66            "multi-step",
67            "decompose",
68            "orchestrate",
69            "architecture",
70            "benchmark",
71            "implement end-to-end",
72            "design api",
73            "refactor module",
74            "evaluate",
75            "tests suite",
76        ];
77        let complex = complex_markers.iter().any(|k| text.contains(k));
78        let long = text.len() > 1200;
79
80        if has_code_fence || has_patch_keywords {
81            return TaskClass::CodegenHeavy;
82        }
83        if retrieval {
84            return TaskClass::RetrievalHeavy;
85        }
86        if complex || long {
87            return TaskClass::Complex;
88        }
89        if text.len() < 120 {
90            return TaskClass::Simple;
91        }
92        TaskClass::Standard
93    }
94
95    pub fn route(vt_cfg: &VTCodeConfig, core: &CoreAgentConfig, input: &str) -> RouteDecision {
96        let router_cfg = &vt_cfg.router;
97        let class = if router_cfg.heuristic_classification {
98            Self::classify_heuristic(input)
99        } else {
100            // fallback: treat as standard
101            TaskClass::Standard
102        };
103
104        let model = match class {
105            TaskClass::Simple => non_empty_or(&router_cfg.models.simple, &core.model),
106            TaskClass::Standard => non_empty_or(&router_cfg.models.standard, &core.model),
107            TaskClass::Complex => non_empty_or(&router_cfg.models.complex, &core.model),
108            TaskClass::CodegenHeavy => non_empty_or(&router_cfg.models.codegen_heavy, &core.model),
109            TaskClass::RetrievalHeavy => {
110                non_empty_or(&router_cfg.models.retrieval_heavy, &core.model)
111            }
112        };
113
114        RouteDecision {
115            class,
116            selected_model: model.to_string(),
117        }
118    }
119
120    /// Optional LLM-based classification when `router.llm_router_model` is set.
121    /// Falls back to heuristics on any error.
122    pub async fn route_async(
123        vt_cfg: &VTCodeConfig,
124        core: &CoreAgentConfig,
125        api_key: &str,
126        input: &str,
127    ) -> RouteDecision {
128        let router_cfg = &vt_cfg.router;
129        let mut class = if router_cfg.heuristic_classification {
130            Self::classify_heuristic(input)
131        } else {
132            TaskClass::Standard
133        };
134
135        if !router_cfg.llm_router_model.trim().is_empty() {
136            let provider_name = if core.provider.trim().is_empty() {
137                core.model
138                    .parse::<ModelId>()
139                    .ok()
140                    .map(|model| model.provider().to_string())
141                    .or_else(|| {
142                        let factory = get_factory().lock().unwrap();
143                        factory.provider_from_model(core.model.as_str())
144                    })
145                    .unwrap_or_else(|| "gemini".to_string())
146            } else {
147                core.provider.to_lowercase()
148            };
149            if let Ok(provider) = create_provider_with_config(
150                &provider_name,
151                Some(api_key.to_string()),
152                None,
153                Some(router_cfg.llm_router_model.clone()),
154                Some(core.prompt_cache.clone()),
155            ) {
156                let sys = "You are a routing classifier. Output only one label: simple | standard | complex | codegen_heavy | retrieval_heavy. Choose the best class for the user's last message. No prose.".to_string();
157                let supports_effort =
158                    provider.supports_reasoning_effort(&router_cfg.llm_router_model);
159                let reasoning_effort = if supports_effort {
160                    Some(vt_cfg.agent.reasoning_effort.as_str().to_string())
161                } else {
162                    None
163                };
164                let req = uni::LLMRequest {
165                    messages: vec![uni::Message::user(input.to_string())],
166                    system_prompt: Some(sys),
167                    tools: None,
168                    model: router_cfg.llm_router_model.clone(),
169                    max_tokens: Some(8),
170                    temperature: Some(0.0),
171                    stream: false,
172                    tool_choice: Some(uni::ToolChoice::none()),
173                    parallel_tool_calls: None,
174                    parallel_tool_config: None,
175                    reasoning_effort,
176                };
177                if let Ok(resp) = provider.generate(req).await {
178                    if let Some(text) = resp.content {
179                        let t = text.trim().to_lowercase();
180                        class = match t {
181                            x if x.contains("codegen") => TaskClass::CodegenHeavy,
182                            x if x.contains("retrieval") => TaskClass::RetrievalHeavy,
183                            x if x.contains("complex") => TaskClass::Complex,
184                            x if x.contains("simple") => TaskClass::Simple,
185                            _ => TaskClass::Standard,
186                        };
187                    }
188                }
189            }
190        }
191
192        let model = match class {
193            TaskClass::Simple => non_empty_or(&router_cfg.models.simple, &core.model),
194            TaskClass::Standard => non_empty_or(&router_cfg.models.standard, &core.model),
195            TaskClass::Complex => non_empty_or(&router_cfg.models.complex, &core.model),
196            TaskClass::CodegenHeavy => non_empty_or(&router_cfg.models.codegen_heavy, &core.model),
197            TaskClass::RetrievalHeavy => {
198                non_empty_or(&router_cfg.models.retrieval_heavy, &core.model)
199            }
200        };
201
202        RouteDecision {
203            class,
204            selected_model: model.to_string(),
205        }
206    }
207}
208
/// Returns `value` unless it is empty or whitespace-only, in which case returns `fallback`.
///
/// A non-blank `value` is returned as-is (it is not trimmed).
fn non_empty_or<'a>(value: &'a str, fallback: &'a str) -> &'a str {
    match value.trim() {
        "" => fallback,
        _ => value,
    }
}
216
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_class_display() {
        assert_eq!(format!("{}", TaskClass::Simple), "simple");
        assert_eq!(format!("{}", TaskClass::Standard), "standard");
        assert_eq!(format!("{}", TaskClass::Complex), "complex");
        assert_eq!(format!("{}", TaskClass::CodegenHeavy), "codegen_heavy");
        assert_eq!(format!("{}", TaskClass::RetrievalHeavy), "retrieval_heavy");
    }

    #[test]
    fn heuristic_detects_codegen_first() {
        assert_eq!(
            Router::classify_heuristic("```rust\nfn main() {}\n```"),
            TaskClass::CodegenHeavy
        );
        assert_eq!(
            Router::classify_heuristic("diff --git a/x b/x"),
            TaskClass::CodegenHeavy
        );
        // Codegen signals take precedence over retrieval keywords.
        assert_eq!(
            Router::classify_heuristic("search for the patch"),
            TaskClass::CodegenHeavy
        );
    }

    #[test]
    fn heuristic_detects_retrieval_case_insensitively() {
        assert_eq!(
            Router::classify_heuristic("please search the web"),
            TaskClass::RetrievalHeavy
        );
        assert_eq!(
            Router::classify_heuristic("SEARCH THE DOCS"),
            TaskClass::RetrievalHeavy
        );
    }

    #[test]
    fn heuristic_detects_complex_by_keyword_or_length() {
        assert_eq!(
            Router::classify_heuristic("help me plan the rollout"),
            TaskClass::Complex
        );
        assert_eq!(Router::classify_heuristic(&"x".repeat(1300)), TaskClass::Complex);
        // Exactly at the threshold is not "long".
        assert_eq!(Router::classify_heuristic(&"x".repeat(1200)), TaskClass::Standard);
    }

    #[test]
    fn heuristic_length_thresholds_for_simple_and_standard() {
        assert_eq!(Router::classify_heuristic("hello there"), TaskClass::Simple);
        assert_eq!(Router::classify_heuristic(&"x".repeat(119)), TaskClass::Simple);
        assert_eq!(Router::classify_heuristic(&"x".repeat(120)), TaskClass::Standard);
        assert_eq!(Router::classify_heuristic(&"x".repeat(200)), TaskClass::Standard);
    }

    #[test]
    fn non_empty_or_falls_back_only_for_blank_values() {
        assert_eq!(non_empty_or("", "fallback"), "fallback");
        assert_eq!(non_empty_or("   ", "fallback"), "fallback");
        assert_eq!(non_empty_or("model", "fallback"), "model");
    }
}