1use anyhow::{Context, Result};
2
3use crate::config::api_keys::{ApiKeySources, get_api_key};
4use crate::config::constants::model_helpers;
5use crate::config::loader::VTCodeConfig;
6use crate::config::models::{ModelId, Provider};
7use crate::config::types::AgentConfig as RuntimeAgentConfig;
8use crate::llm::factory::{ProviderConfig, create_provider_with_config, infer_provider_from_model};
9use crate::llm::provider::LLMProvider;
10
/// Identifies which caller is asking for a lightweight model route.
///
/// `resolve_lightweight_route` uses this to decide whether the shared
/// `agent.small_model` configuration applies: `Memory`, `LargeReadSummary`,
/// `WebSummary` and `GitHistorySummary` each have their own opt-in toggle,
/// while the remaining variants are always eligible (see
/// `feature_uses_shared_model`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LightweightFeature {
    Memory,
    PromptSuggestions,
    PromptRefinement,
    AutoModeReview,
    AutoModeProbe,
    LargeReadSummary,
    WebSummary,
    GitHistorySummary,
    Subagent,
}
23
/// A concrete (provider, model) pair that a request can be sent to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelRoute {
    /// Provider identifier; lowercased when produced by `main_model_route`.
    pub provider_name: String,
    /// Model slug passed to the provider.
    pub model: String,
}
29
/// Records which rule in `resolve_lightweight_route` produced the primary route.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LightweightRouteSource {
    /// The caller's explicit override model was accepted.
    FeatureOverride,
    /// The model configured in `agent.small_model.model` was accepted.
    SharedConfigured,
    /// The model was derived from the main model by `auto_lightweight_model`.
    SharedAutomatic,
    /// No lightweight model applies; the main model is used directly.
    MainModel,
}
37
/// Result of resolving a lightweight route for a feature.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LightweightRouteResolution {
    /// Route to try first.
    pub primary: ModelRoute,
    /// The main-model route, present only when it differs from `primary`.
    pub fallback: Option<ModelRoute>,
    /// Which resolution rule chose `primary`.
    pub source: LightweightRouteSource,
    /// Human-readable note explaining why a configured model was ignored, if any.
    pub warning: Option<String>,
}
45
46impl LightweightRouteResolution {
47 pub fn uses_lightweight_model(&self) -> bool {
48 !matches!(self.source, LightweightRouteSource::MainModel)
49 }
50
51 pub fn fallback_to_main_model(&self) -> Option<&ModelRoute> {
52 self.fallback.as_ref()
53 }
54}
55
56pub fn resolve_lightweight_route(
57 runtime_config: &RuntimeAgentConfig,
58 vt_cfg: Option<&VTCodeConfig>,
59 feature: LightweightFeature,
60 explicit_override_model: Option<&str>,
61) -> LightweightRouteResolution {
62 let main_route = main_model_route(runtime_config);
63 let main_provider = main_route.provider_name.as_str();
64
65 let mut warning = None;
66 if let Some(configured_model) = explicit_override_model
67 .map(str::trim)
68 .filter(|value| !value.is_empty())
69 {
70 if let Some(route) = route_for_candidate(main_provider, configured_model) {
71 return LightweightRouteResolution {
72 fallback: (route != main_route).then_some(main_route),
73 primary: route,
74 source: LightweightRouteSource::FeatureOverride,
75 warning: None,
76 };
77 }
78
79 warning = Some(format!(
80 "ignored lightweight override model '{}' because it does not match the active provider '{}'",
81 configured_model, main_provider
82 ));
83 }
84
85 let Some(vt_cfg) = vt_cfg else {
86 return LightweightRouteResolution {
87 primary: main_route,
88 fallback: None,
89 source: LightweightRouteSource::MainModel,
90 warning,
91 };
92 };
93
94 let shared_cfg = &vt_cfg.agent.small_model;
95 if !shared_cfg.enabled || !feature_uses_shared_model(shared_cfg, feature) {
96 return LightweightRouteResolution {
97 primary: main_route,
98 fallback: None,
99 source: LightweightRouteSource::MainModel,
100 warning,
101 };
102 }
103
104 let configured_model = shared_cfg.model.trim();
105 if !configured_model.is_empty() {
106 if let Some(route) = route_for_candidate(main_provider, configured_model) {
107 return LightweightRouteResolution {
108 fallback: (route != main_route).then_some(main_route),
109 primary: route,
110 source: LightweightRouteSource::SharedConfigured,
111 warning,
112 };
113 }
114
115 warning = Some(format!(
116 "ignored lightweight model '{}' because it does not match the active provider '{}'",
117 configured_model, main_provider
118 ));
119 }
120
121 let primary = ModelRoute {
122 provider_name: main_route.provider_name.clone(),
123 model: auto_lightweight_model(main_provider, &main_route.model),
124 };
125 LightweightRouteResolution {
126 fallback: (primary != main_route).then_some(main_route),
127 primary,
128 source: LightweightRouteSource::SharedAutomatic,
129 warning,
130 }
131}
132
133pub fn main_model_route(runtime_config: &RuntimeAgentConfig) -> ModelRoute {
134 let provider_name = if runtime_config.provider.trim().is_empty() {
135 infer_provider_from_model(&runtime_config.model)
136 .map(|provider| provider.to_string().to_lowercase())
137 .unwrap_or_else(|| "gemini".to_string())
138 } else {
139 runtime_config.provider.to_lowercase()
140 };
141
142 ModelRoute {
143 provider_name,
144 model: runtime_config.model.clone(),
145 }
146}
147
148pub fn auto_lightweight_model(provider_name: &str, active_model: &str) -> String {
149 let trimmed_model = active_model.trim();
150 let provider = resolve_provider_for_model(provider_name, trimmed_model);
151
152 if let Ok(model_id) = trimmed_model.parse::<ModelId>() {
153 if model_id.is_efficient_variant() {
154 return model_id.as_str().to_string();
155 }
156
157 if let Some(lightweight_model) = model_id.preferred_lightweight_variant() {
158 return lightweight_model.as_str().to_string();
159 }
160 }
161
162 if let Some(lightweight_model) = preferred_lightweight_model_slug(provider, trimmed_model) {
163 return lightweight_model;
164 }
165
166 provider_default_lightweight_model(provider)
167 .or_else(|| model_helpers::default_for(provider_name))
168 .unwrap_or(trimmed_model)
169 .to_string()
170}
171
172pub fn lightweight_model_choices(provider_name: &str, active_model: &str) -> Vec<String> {
173 let provider = resolve_provider_for_model(provider_name, active_model);
174 let auto_model = auto_lightweight_model(provider_name, active_model);
175 let mut choices = Vec::new();
176
177 if !auto_model.trim().is_empty() {
178 choices.push(auto_model.clone());
179 }
180 if !active_model.trim().is_empty() {
181 choices.push(active_model.trim().to_string());
182 }
183
184 if let Some(models) = model_helpers::supported_for(provider.as_ref()) {
185 for model in models {
186 let include = model
187 .parse::<ModelId>()
188 .map(|model_id| model_id.is_efficient_variant())
189 .unwrap_or(false)
190 || model.eq_ignore_ascii_case(active_model.trim());
191 if include {
192 choices.push((*model).to_string());
193 }
194 }
195 }
196
197 choices.sort();
198 choices.dedup();
199 if let Some(auto_index) = choices
200 .iter()
201 .position(|candidate| candidate.eq_ignore_ascii_case(auto_model.as_str()))
202 {
203 let auto = choices.remove(auto_index);
204 choices.insert(0, auto);
205 }
206 choices
207}
208
209pub fn create_provider_for_model_route(
210 route: &ModelRoute,
211 runtime_config: &RuntimeAgentConfig,
212 vt_cfg: Option<&VTCodeConfig>,
213) -> Result<Box<dyn LLMProvider>> {
214 let api_key = resolve_api_key_for_model_route(route, runtime_config);
215 create_provider_with_config(
216 &route.provider_name,
217 ProviderConfig {
218 api_key,
219 openai_chatgpt_auth: runtime_config.openai_chatgpt_auth.clone(),
220 copilot_auth: vt_cfg.map(|cfg| cfg.auth.copilot.clone()),
221 base_url: None,
222 model: Some(route.model.clone()),
223 prompt_cache: Some(runtime_config.prompt_cache.clone()),
224 timeouts: None,
225 openai: vt_cfg.map(|cfg| cfg.provider.openai.clone()),
226 anthropic: vt_cfg.map(|cfg| cfg.provider.anthropic.clone()),
227 model_behavior: runtime_config.model_behavior.clone(),
228 workspace_root: Some(runtime_config.workspace.clone()),
229 },
230 )
231 .with_context(|| {
232 format!(
233 "Failed to initialize lightweight provider '{}' for model '{}'",
234 route.provider_name, route.model
235 )
236 })
237}
238
239pub fn resolve_api_key_for_model_route(
240 route: &ModelRoute,
241 runtime_config: &RuntimeAgentConfig,
242) -> Option<String> {
243 if route
244 .provider_name
245 .eq_ignore_ascii_case(main_model_route(runtime_config).provider_name.as_str())
246 && !runtime_config.api_key.trim().is_empty()
247 {
248 return Some(runtime_config.api_key.clone());
249 }
250
251 get_api_key(&route.provider_name, &ApiKeySources::default()).ok()
252}
253
254fn feature_uses_shared_model(
255 shared_cfg: &vtcode_config::core::agent::AgentSmallModelConfig,
256 feature: LightweightFeature,
257) -> bool {
258 match feature {
259 LightweightFeature::Memory => shared_cfg.use_for_memory,
260 LightweightFeature::LargeReadSummary => shared_cfg.use_for_large_reads,
261 LightweightFeature::WebSummary => shared_cfg.use_for_web_summary,
262 LightweightFeature::GitHistorySummary => shared_cfg.use_for_git_history,
263 LightweightFeature::PromptSuggestions
264 | LightweightFeature::PromptRefinement
265 | LightweightFeature::AutoModeReview
266 | LightweightFeature::AutoModeProbe
267 | LightweightFeature::Subagent => true,
268 }
269}
270
271fn route_for_candidate(main_provider: &str, candidate_model: &str) -> Option<ModelRoute> {
272 if infer_provider_from_model(candidate_model)
273 .map(|provider| !provider.as_ref().eq_ignore_ascii_case(main_provider))
274 .unwrap_or(false)
275 {
276 return None;
277 }
278
279 Some(ModelRoute {
280 provider_name: main_provider.to_string(),
281 model: candidate_model.to_string(),
282 })
283}
284
285fn provider_from_name(provider_name: &str) -> Provider {
286 known_provider_from_name(provider_name).unwrap_or(Provider::Gemini)
287}
288
289fn resolve_provider_for_model(provider_name: &str, active_model: &str) -> Provider {
290 known_provider_from_name(provider_name)
291 .or_else(|| infer_provider_from_model(active_model))
292 .unwrap_or_else(|| provider_from_name(provider_name))
293}
294
295fn known_provider_from_name(provider_name: &str) -> Option<Provider> {
296 match provider_name.to_ascii_lowercase().as_str() {
297 "openai" => Some(Provider::OpenAI),
298 "anthropic" => Some(Provider::Anthropic),
299 "copilot" => Some(Provider::Copilot),
300 "deepseek" => Some(Provider::DeepSeek),
301 "gemini" | "google" => Some(Provider::Gemini),
302 "openrouter" => Some(Provider::OpenRouter),
303 "ollama" => Some(Provider::Ollama),
304 "lmstudio" => Some(Provider::LmStudio),
305 "llamacpp" | "llama.cpp" | "llama-cpp" => Some(Provider::LlamaCpp),
306 "moonshot" => Some(Provider::Moonshot),
307 "zai" => Some(Provider::ZAI),
308 "minimax" => Some(Provider::Minimax),
309 "huggingface" => Some(Provider::HuggingFace),
310 "stepfun" => Some(Provider::StepFun),
311 "evolink" => Some(Provider::Evolink),
312 _ => None,
313 }
314}
315
/// Maps an active model slug to a preferred lightweight sibling using
/// provider-specific substring heuristics.
///
/// Matching runs on the trimmed, ASCII-lowercased slug, and the first matching
/// branch wins. The order of checks within each provider arm is therefore
/// significant (e.g. `gpt-5.4-mini` is tested before `gpt-5.4`).
///
/// Returns `Some(slug)` with either a canonical `ModelId` string or the trimmed
/// active model itself. Returns `None` when no heuristic applies, which lets
/// `auto_lightweight_model` fall back to provider defaults.
fn preferred_lightweight_model_slug(provider: Provider, active_model: &str) -> Option<String> {
    let trimmed_model = active_model.trim();
    let lower = trimmed_model.to_ascii_lowercase();

    match provider {
        Provider::Anthropic => {
            // Any Haiku slug is normalized to the `ClaudeHaiku45` id.
            if lower.contains("haiku") {
                return Some(ModelId::ClaudeHaiku45.as_str().to_string());
            }
            // Sonnet/Opus slugs also map to `ClaudeHaiku45`.
            // NOTE(review): this returns the same id as the branch above; the two
            // checks could be merged.
            if lower.contains("sonnet") || lower.contains("opus") {
                return Some(ModelId::ClaudeHaiku45.as_str().to_string());
            }
            None
        }
        Provider::OpenAI => {
            // Already a 5.4 mini/nano: keep the active slug as-is.
            if lower.contains("gpt-5.4-mini") || lower.contains("gpt-5.4-nano") {
                return Some(trimmed_model.to_string());
            }
            // Any other 5.4 slug maps to `GPT54Mini`; must come after the check above.
            if lower.contains("gpt-5.4") {
                return Some(ModelId::GPT54Mini.as_str().to_string());
            }
            // Already a gpt-5 mini/nano: keep the active slug as-is.
            if lower.contains("gpt-5-mini") || lower.contains("gpt-5-nano") {
                return Some(trimmed_model.to_string());
            }
            // Remaining gpt-5.x, bare gpt-5 and gpt-5-codex slugs map to `GPT54Mini`.
            if lower.contains("gpt-5.") || lower == "gpt-5" || lower.contains("gpt-5-codex") {
                return Some(ModelId::GPT54Mini.as_str().to_string());
            }
            None
        }
        Provider::Copilot => {
            if lower.contains("gpt-5.4-mini") {
                return Some(trimmed_model.to_string());
            }
            // Both GPT-5 and Claude slugs on Copilot map to `CopilotGPT54Mini`.
            if lower.contains("gpt-5") || lower.contains("claude") {
                return Some(ModelId::CopilotGPT54Mini.as_str().to_string());
            }
            None
        }
        Provider::DeepSeek => {
            if lower.contains("flash") || lower.contains("chat") {
                return Some(trimmed_model.to_string());
            }
            // NOTE(review): "pro"/"reasoner" slugs are also returned unchanged, so
            // the active model becomes its own lightweight route, whereas
            // `provider_default_lightweight_model` would pick `DeepSeekV4Flash`.
            // Confirm this is intended.
            if lower.contains("pro") || lower.contains("reasoner") {
                return Some(trimmed_model.to_string());
            }
            None
        }
        Provider::Gemini => {
            // NOTE(review): "flash preview" contains a space, so it does not match
            // hyphenated slugs such as "...-flash-preview"; those fall through to
            // the "3.1" / "gemini-3" checks below. Confirm intent.
            if lower.contains("flash-lite") || lower.contains("flash preview") {
                return Some(trimmed_model.to_string());
            }
            // Any slug containing "3.1" maps to `Gemini31FlashLitePreview`;
            // checked before the broader "gemini-3" match.
            if lower.contains("3.1") {
                return Some(ModelId::Gemini31FlashLitePreview.as_str().to_string());
            }
            if lower.contains("gemini-3") || lower.contains("gemini 3") {
                return Some(ModelId::Gemini35Flash.as_str().to_string());
            }
            None
        }
        Provider::ZAI => {
            // NOTE(review): both branches return `ZaiGlm5`; the "glm-5.1" check is
            // subsumed by the "glm-5" check below.
            if lower.contains("glm-5.1") {
                return Some(ModelId::ZaiGlm5.as_str().to_string());
            }
            if lower.contains("glm-5") {
                return Some(ModelId::ZaiGlm5.as_str().to_string());
            }
            None
        }
        Provider::Minimax => {
            if lower.contains("m2.5") {
                return Some(trimmed_model.to_string());
            }
            if lower.contains("m2.7") {
                return Some(ModelId::MinimaxM25.as_str().to_string());
            }
            None
        }
        // These providers always keep the active slug. As a result,
        // `auto_lightweight_model` never reaches their
        // `provider_default_lightweight_model` entry, unless the `ModelId`
        // branch has already returned.
        Provider::StepFun => Some(trimmed_model.to_string()),
        Provider::Evolink => Some(trimmed_model.to_string()),
        _ => None,
    }
}
398
399fn provider_default_lightweight_model(provider: Provider) -> Option<&'static str> {
400 match provider {
401 Provider::OpenAI => Some(ModelId::GPT54Mini.as_str()),
402 Provider::Anthropic => Some(ModelId::ClaudeHaiku45.as_str()),
403 Provider::Copilot => Some(ModelId::CopilotGPT54Mini.as_str()),
404 Provider::DeepSeek => Some(ModelId::DeepSeekV4Flash.as_str()),
405 Provider::Gemini => Some(ModelId::Gemini35Flash.as_str()),
406 Provider::ZAI => Some(ModelId::ZaiGlm5.as_str()),
407 Provider::Minimax => Some(ModelId::MinimaxM25.as_str()),
408 Provider::StepFun => Some(ModelId::StepFun37Flash.as_str()),
409 Provider::Evolink => Some(ModelId::EvolinkGpt52.as_str()),
410 _ => None,
411 }
412}
413
// Unit tests for lightweight route resolution and automatic model selection.
#[cfg(test)]
mod tests {
    use super::*;

    // Runtime config whose main model is `ModelId::GPT54` on the "openai"
    // provider with a non-empty test API key; other fields are neutral defaults.
    fn runtime_config() -> RuntimeAgentConfig {
        RuntimeAgentConfig {
            model: ModelId::GPT54.as_str().to_string(),
            api_key: "test-key".to_string(),
            provider: "openai".to_string(),
            openai_chatgpt_auth: None,
            api_key_env: "OPENAI_API_KEY".to_string(),
            workspace: std::env::temp_dir().join("vtcode-lightweight-routing-tests"),
            verbose: false,
            quiet: false,
            theme: "default".to_string(),
            reasoning_effort: Default::default(),
            ui_surface: Default::default(),
            prompt_cache: Default::default(),
            model_source: Default::default(),
            custom_api_keys: Default::default(),
            checkpointing_enabled: false,
            checkpointing_storage_dir: None,
            checkpointing_max_snapshots: 0,
            checkpointing_max_age_days: None,
            max_conversation_turns: 0,
            model_behavior: None,
        }
    }

    // An explicit same-provider override is accepted verbatim and tagged as
    // `FeatureOverride`.
    #[test]
    fn explicit_override_uses_active_provider() {
        let runtime = runtime_config();
        let route = resolve_lightweight_route(
            &runtime,
            Some(&VTCodeConfig::default()),
            LightweightFeature::Memory,
            Some("gpt-5-mini"),
        );

        assert_eq!(route.primary.provider_name, "openai");
        assert_eq!(route.primary.model, "gpt-5-mini");
        assert_eq!(route.source, LightweightRouteSource::FeatureOverride);
    }

    // A shared model from another provider (Anthropic) is rejected with a
    // warning, and the automatic same-provider pick is used instead.
    // Presumes the default config has `small_model` enabled.
    #[test]
    fn cross_provider_shared_model_falls_back_to_auto_same_provider() {
        let runtime = runtime_config();
        let mut vt_cfg = VTCodeConfig::default();
        vt_cfg.agent.small_model.model = "claude-4-5-haiku".to_string();

        let route = resolve_lightweight_route(
            &runtime,
            Some(&vt_cfg),
            LightweightFeature::PromptSuggestions,
            None,
        );

        assert_eq!(route.primary.provider_name, "openai");
        assert_eq!(route.primary.model, ModelId::GPT54Mini.as_str());
        assert_eq!(route.source, LightweightRouteSource::SharedAutomatic);
        assert!(route.warning.is_some());
    }

    #[test]
    fn auto_lightweight_model_prefers_same_generation_openai_sibling() {
        assert_eq!(
            auto_lightweight_model("openai", ModelId::GPT54.as_str()),
            ModelId::GPT54Mini.as_str()
        );
    }

    // Covers both a parsed `ModelId` input and a free-form Sonnet slug.
    #[test]
    fn auto_lightweight_model_uses_closest_anthropic_haiku_pair() {
        assert_eq!(
            auto_lightweight_model("anthropic", ModelId::ClaudeSonnet46.as_str()),
            ModelId::ClaudeHaiku45.as_str()
        );
        assert_eq!(
            auto_lightweight_model("anthropic", "claude-sonnet-4.5"),
            ModelId::ClaudeHaiku45.as_str()
        );
    }

    #[test]
    fn auto_lightweight_model_uses_lower_generation_glm_pair() {
        assert_eq!(
            auto_lightweight_model("zai", ModelId::ZaiGlm51.as_str()),
            ModelId::ZaiGlm5.as_str()
        );
    }

    #[test]
    fn auto_lightweight_model_prefers_same_generation_gemini_flash_lite() {
        assert_eq!(
            auto_lightweight_model("gemini", ModelId::Gemini31ProPreview.as_str()),
            ModelId::Gemini31FlashLitePreview.as_str()
        );
    }

    // An unrecognised provider name still gets a family-appropriate pick,
    // based on the model slug.
    #[test]
    fn auto_lightweight_model_infers_family_for_custom_provider() {
        assert_eq!(
            auto_lightweight_model("mycorp", ModelId::GPT54.as_str()),
            ModelId::GPT54Mini.as_str()
        );
    }

    // Turning off the per-feature toggle routes straight to the main model,
    // with no fallback.
    #[test]
    fn disabled_feature_uses_main_model() {
        let runtime = runtime_config();
        let mut vt_cfg = VTCodeConfig::default();
        vt_cfg.agent.small_model.use_for_memory = false;

        let route =
            resolve_lightweight_route(&runtime, Some(&vt_cfg), LightweightFeature::Memory, None);

        assert_eq!(route.primary.model, ModelId::GPT54.as_str());
        assert_eq!(route.source, LightweightRouteSource::MainModel);
        assert!(route.fallback.is_none());
    }
}