vtcode_core/core/router.rs

1use serde::{Deserialize, Serialize};
2
3use crate::config::loader::VTCodeConfig;
4use crate::config::types::AgentConfig as CoreAgentConfig;
5use crate::llm::{factory::create_provider_for_model, provider as uni};
6
7#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
8pub enum TaskClass {
9    Simple,
10    Standard,
11    Complex,
12    CodegenHeavy,
13    RetrievalHeavy,
14}
15
16impl std::fmt::Display for TaskClass {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        match self {
19            TaskClass::Simple => write!(f, "simple"),
20            TaskClass::Standard => write!(f, "standard"),
21            TaskClass::Complex => write!(f, "complex"),
22            TaskClass::CodegenHeavy => write!(f, "codegen_heavy"),
23            TaskClass::RetrievalHeavy => write!(f, "retrieval_heavy"),
24        }
25    }
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct RouteDecision {
30    pub class: TaskClass,
31    pub selected_model: String,
32}
33
34pub struct Router;
35
36impl Router {
37    pub fn classify_heuristic(input: &str) -> TaskClass {
38        let text = input.to_lowercase();
39        let has_code_fence = text.contains("```") || text.contains("diff --git");
40        let has_patch_keywords = [
41            "apply_patch",
42            "unified diff",
43            "patch",
44            "edit_file",
45            "create_file",
46        ]
47        .iter()
48        .any(|k| text.contains(k));
49        let retrieval = [
50            "search",
51            "web",
52            "google",
53            "docs",
54            "cite",
55            "source",
56            "up-to-date",
57        ]
58        .iter()
59        .any(|k| text.contains(k));
60        let complex_markers = [
61            "plan",
62            "multi-step",
63            "decompose",
64            "orchestrate",
65            "architecture",
66            "benchmark",
67            "implement end-to-end",
68            "design api",
69            "refactor module",
70            "evaluate",
71            "tests suite",
72        ];
73        let complex = complex_markers.iter().any(|k| text.contains(k));
74        let long = text.len() > 1200;
75
76        if has_code_fence || has_patch_keywords {
77            return TaskClass::CodegenHeavy;
78        }
79        if retrieval {
80            return TaskClass::RetrievalHeavy;
81        }
82        if complex || long {
83            return TaskClass::Complex;
84        }
85        if text.len() < 120 {
86            return TaskClass::Simple;
87        }
88        TaskClass::Standard
89    }
90
91    pub fn route(vt_cfg: &VTCodeConfig, core: &CoreAgentConfig, input: &str) -> RouteDecision {
92        let router_cfg = &vt_cfg.router;
93        let class = if router_cfg.heuristic_classification {
94            Self::classify_heuristic(input)
95        } else {
96            // fallback: treat as standard
97            TaskClass::Standard
98        };
99
100        let model = match class {
101            TaskClass::Simple => non_empty_or(&router_cfg.models.simple, &core.model),
102            TaskClass::Standard => non_empty_or(&router_cfg.models.standard, &core.model),
103            TaskClass::Complex => non_empty_or(&router_cfg.models.complex, &core.model),
104            TaskClass::CodegenHeavy => non_empty_or(&router_cfg.models.codegen_heavy, &core.model),
105            TaskClass::RetrievalHeavy => {
106                non_empty_or(&router_cfg.models.retrieval_heavy, &core.model)
107            }
108        };
109
110        RouteDecision {
111            class,
112            selected_model: model.to_string(),
113        }
114    }
115
116    /// Optional LLM-based classification when `router.llm_router_model` is set.
117    /// Falls back to heuristics on any error.
118    pub async fn route_async(
119        vt_cfg: &VTCodeConfig,
120        core: &CoreAgentConfig,
121        api_key: &str,
122        input: &str,
123    ) -> RouteDecision {
124        let router_cfg = &vt_cfg.router;
125        let mut class = if router_cfg.heuristic_classification {
126            Self::classify_heuristic(input)
127        } else {
128            TaskClass::Standard
129        };
130
131        if !router_cfg.llm_router_model.trim().is_empty() {
132            if let Ok(provider) =
133                create_provider_for_model(&router_cfg.llm_router_model, api_key.to_string())
134            {
135                let sys = "You are a routing classifier. Output only one label: simple | standard | complex | codegen_heavy | retrieval_heavy. Choose the best class for the user's last message. No prose.".to_string();
136                let req = uni::LLMRequest {
137                    messages: vec![uni::Message::user(input.to_string())],
138                    system_prompt: Some(sys),
139                    tools: None,
140                    model: router_cfg.llm_router_model.clone(),
141                    max_tokens: Some(8),
142                    temperature: Some(0.0),
143                    stream: false,
144                    tool_choice: Some(uni::ToolChoice::none()),
145                    parallel_tool_calls: None,
146                    parallel_tool_config: None,
147                    reasoning_effort: Some(vt_cfg.agent.reasoning_effort.clone()),
148                };
149                if let Ok(resp) = provider.generate(req).await {
150                    if let Some(text) = resp.content {
151                        let t = text.trim().to_lowercase();
152                        class = match t {
153                            x if x.contains("codegen") => TaskClass::CodegenHeavy,
154                            x if x.contains("retrieval") => TaskClass::RetrievalHeavy,
155                            x if x.contains("complex") => TaskClass::Complex,
156                            x if x.contains("simple") => TaskClass::Simple,
157                            _ => TaskClass::Standard,
158                        };
159                    }
160                }
161            }
162        }
163
164        let model = match class {
165            TaskClass::Simple => non_empty_or(&router_cfg.models.simple, &core.model),
166            TaskClass::Standard => non_empty_or(&router_cfg.models.standard, &core.model),
167            TaskClass::Complex => non_empty_or(&router_cfg.models.complex, &core.model),
168            TaskClass::CodegenHeavy => non_empty_or(&router_cfg.models.codegen_heavy, &core.model),
169            TaskClass::RetrievalHeavy => {
170                non_empty_or(&router_cfg.models.retrieval_heavy, &core.model)
171            }
172        };
173
174        RouteDecision {
175            class,
176            selected_model: model.to_string(),
177        }
178    }
179}
180
/// Returns `value` unless it is empty or whitespace-only, in which case
/// `fallback` is returned. The chosen string is returned as-is (not trimmed).
fn non_empty_or<'a>(value: &'a str, fallback: &'a str) -> &'a str {
    match value.trim() {
        "" => fallback,
        _ => value,
    }
}
188
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_class_display() {
        assert_eq!(format!("{}", TaskClass::Simple), "simple");
        assert_eq!(format!("{}", TaskClass::Standard), "standard");
        assert_eq!(format!("{}", TaskClass::Complex), "complex");
        assert_eq!(format!("{}", TaskClass::CodegenHeavy), "codegen_heavy");
        assert_eq!(format!("{}", TaskClass::RetrievalHeavy), "retrieval_heavy");
    }

    #[test]
    fn heuristic_detects_code_first() {
        assert_eq!(
            Router::classify_heuristic("```rust\nfn main() {}\n```"),
            TaskClass::CodegenHeavy
        );
        assert_eq!(
            Router::classify_heuristic("Use apply_patch on lib.rs"),
            TaskClass::CodegenHeavy
        );
        // Code markers take priority over retrieval markers.
        assert_eq!(
            Router::classify_heuristic("Search the docs, then apply_patch"),
            TaskClass::CodegenHeavy
        );
    }

    #[test]
    fn heuristic_detects_retrieval_and_complex() {
        assert_eq!(
            Router::classify_heuristic("Search the web for the latest tokio release"),
            TaskClass::RetrievalHeavy
        );
        assert_eq!(
            Router::classify_heuristic("Design API for billing"),
            TaskClass::Complex
        );
        assert_eq!(
            Router::classify_heuristic(&"word ".repeat(300)),
            TaskClass::Complex
        );
    }

    #[test]
    fn heuristic_length_thresholds() {
        assert_eq!(Router::classify_heuristic("rename this variable"), TaskClass::Simple);
        assert_eq!(Router::classify_heuristic(&"x".repeat(119)), TaskClass::Simple);
        assert_eq!(Router::classify_heuristic(&"x".repeat(120)), TaskClass::Standard);
        assert_eq!(Router::classify_heuristic(&"x".repeat(1200)), TaskClass::Standard);
        assert_eq!(Router::classify_heuristic(&"x".repeat(1201)), TaskClass::Complex);
    }
}