Skip to main content

vtcode_core/llm/
lightweight_routing.rs

1use anyhow::{Context, Result};
2
3use crate::config::api_keys::{ApiKeySources, get_api_key};
4use crate::config::constants::model_helpers;
5use crate::config::loader::VTCodeConfig;
6use crate::config::models::{ModelId, Provider};
7use crate::config::types::AgentConfig as RuntimeAgentConfig;
8use crate::llm::factory::{ProviderConfig, create_provider_with_config, infer_provider_from_model};
9use crate::llm::provider::LLMProvider;
10
/// Features that can ask for a lightweight model route.
///
/// `feature_uses_shared_model` decides per variant whether the shared small-model
/// configuration applies: four variants are gated by a config flag, the rest always
/// qualify.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LightweightFeature {
    /// Gated by the shared config's `use_for_memory` flag.
    Memory,
    /// Always eligible for the shared lightweight model.
    PromptSuggestions,
    /// Always eligible for the shared lightweight model.
    PromptRefinement,
    /// Always eligible for the shared lightweight model.
    AutoModeReview,
    /// Always eligible for the shared lightweight model.
    AutoModeProbe,
    /// Gated by the shared config's `use_for_large_reads` flag.
    LargeReadSummary,
    /// Gated by the shared config's `use_for_web_summary` flag.
    WebSummary,
    /// Gated by the shared config's `use_for_git_history` flag.
    GitHistorySummary,
    /// Always eligible for the shared lightweight model.
    Subagent,
}
23
/// A provider/model pair identifying where a request should be sent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelRoute {
    /// Name of the provider that serves `model`.
    pub provider_name: String,
    /// Identifier of the model to use with that provider.
    pub model: String,
}
29
/// Records which rule produced the primary route of a resolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LightweightRouteSource {
    /// The caller-supplied override model was accepted.
    FeatureOverride,
    /// The model named in the shared small-model configuration was accepted.
    SharedConfigured,
    /// The model was chosen automatically for the active provider.
    SharedAutomatic,
    /// No lightweight routing applied; the main model is used.
    MainModel,
}
37
38#[derive(Debug, Clone, PartialEq, Eq)]
39pub struct LightweightRouteResolution {
40    pub primary: ModelRoute,
41    pub fallback: Option<ModelRoute>,
42    pub source: LightweightRouteSource,
43    pub warning: Option<String>,
44}
45
46impl LightweightRouteResolution {
47    pub fn uses_lightweight_model(&self) -> bool {
48        !matches!(self.source, LightweightRouteSource::MainModel)
49    }
50
51    pub fn fallback_to_main_model(&self) -> Option<&ModelRoute> {
52        self.fallback.as_ref()
53    }
54}
55
56pub fn resolve_lightweight_route(
57    runtime_config: &RuntimeAgentConfig,
58    vt_cfg: Option<&VTCodeConfig>,
59    feature: LightweightFeature,
60    explicit_override_model: Option<&str>,
61) -> LightweightRouteResolution {
62    let main_route = main_model_route(runtime_config);
63    let main_provider = main_route.provider_name.as_str();
64
65    let mut warning = None;
66    if let Some(configured_model) = explicit_override_model
67        .map(str::trim)
68        .filter(|value| !value.is_empty())
69    {
70        if let Some(route) = route_for_candidate(main_provider, configured_model) {
71            return LightweightRouteResolution {
72                fallback: (route != main_route).then_some(main_route),
73                primary: route,
74                source: LightweightRouteSource::FeatureOverride,
75                warning: None,
76            };
77        }
78
79        warning = Some(format!(
80            "ignored lightweight override model '{}' because it does not match the active provider '{}'",
81            configured_model, main_provider
82        ));
83    }
84
85    let Some(vt_cfg) = vt_cfg else {
86        return LightweightRouteResolution {
87            primary: main_route,
88            fallback: None,
89            source: LightweightRouteSource::MainModel,
90            warning,
91        };
92    };
93
94    let shared_cfg = &vt_cfg.agent.small_model;
95    if !shared_cfg.enabled || !feature_uses_shared_model(shared_cfg, feature) {
96        return LightweightRouteResolution {
97            primary: main_route,
98            fallback: None,
99            source: LightweightRouteSource::MainModel,
100            warning,
101        };
102    }
103
104    let configured_model = shared_cfg.model.trim();
105    if !configured_model.is_empty() {
106        if let Some(route) = route_for_candidate(main_provider, configured_model) {
107            return LightweightRouteResolution {
108                fallback: (route != main_route).then_some(main_route),
109                primary: route,
110                source: LightweightRouteSource::SharedConfigured,
111                warning,
112            };
113        }
114
115        warning = Some(format!(
116            "ignored lightweight model '{}' because it does not match the active provider '{}'",
117            configured_model, main_provider
118        ));
119    }
120
121    let primary = ModelRoute {
122        provider_name: main_route.provider_name.clone(),
123        model: auto_lightweight_model(main_provider, &main_route.model),
124    };
125    LightweightRouteResolution {
126        fallback: (primary != main_route).then_some(main_route),
127        primary,
128        source: LightweightRouteSource::SharedAutomatic,
129        warning,
130    }
131}
132
133pub fn main_model_route(runtime_config: &RuntimeAgentConfig) -> ModelRoute {
134    let provider_name = if runtime_config.provider.trim().is_empty() {
135        infer_provider_from_model(&runtime_config.model)
136            .map(|provider| provider.to_string().to_lowercase())
137            .unwrap_or_else(|| "gemini".to_string())
138    } else {
139        runtime_config.provider.to_lowercase()
140    };
141
142    ModelRoute {
143        provider_name,
144        model: runtime_config.model.clone(),
145    }
146}
147
148pub fn auto_lightweight_model(provider_name: &str, active_model: &str) -> String {
149    let trimmed_model = active_model.trim();
150    let provider = resolve_provider_for_model(provider_name, trimmed_model);
151
152    if let Ok(model_id) = trimmed_model.parse::<ModelId>() {
153        if model_id.is_efficient_variant() {
154            return model_id.as_str().to_string();
155        }
156
157        if let Some(lightweight_model) = model_id.preferred_lightweight_variant() {
158            return lightweight_model.as_str().to_string();
159        }
160    }
161
162    if let Some(lightweight_model) = preferred_lightweight_model_slug(provider, trimmed_model) {
163        return lightweight_model;
164    }
165
166    provider_default_lightweight_model(provider)
167        .or_else(|| model_helpers::default_for(provider_name))
168        .unwrap_or(trimmed_model)
169        .to_string()
170}
171
172pub fn lightweight_model_choices(provider_name: &str, active_model: &str) -> Vec<String> {
173    let provider = resolve_provider_for_model(provider_name, active_model);
174    let auto_model = auto_lightweight_model(provider_name, active_model);
175    let mut choices = Vec::new();
176
177    if !auto_model.trim().is_empty() {
178        choices.push(auto_model.clone());
179    }
180    if !active_model.trim().is_empty() {
181        choices.push(active_model.trim().to_string());
182    }
183
184    if let Some(models) = model_helpers::supported_for(provider.as_ref()) {
185        for model in models {
186            let include = model
187                .parse::<ModelId>()
188                .map(|model_id| model_id.is_efficient_variant())
189                .unwrap_or(false)
190                || model.eq_ignore_ascii_case(active_model.trim());
191            if include {
192                choices.push((*model).to_string());
193            }
194        }
195    }
196
197    choices.sort();
198    choices.dedup();
199    if let Some(auto_index) = choices
200        .iter()
201        .position(|candidate| candidate.eq_ignore_ascii_case(auto_model.as_str()))
202    {
203        let auto = choices.remove(auto_index);
204        choices.insert(0, auto);
205    }
206    choices
207}
208
209pub fn create_provider_for_model_route(
210    route: &ModelRoute,
211    runtime_config: &RuntimeAgentConfig,
212    vt_cfg: Option<&VTCodeConfig>,
213) -> Result<Box<dyn LLMProvider>> {
214    let api_key = resolve_api_key_for_model_route(route, runtime_config);
215    create_provider_with_config(
216        &route.provider_name,
217        ProviderConfig {
218            api_key,
219            openai_chatgpt_auth: runtime_config.openai_chatgpt_auth.clone(),
220            copilot_auth: vt_cfg.map(|cfg| cfg.auth.copilot.clone()),
221            base_url: None,
222            model: Some(route.model.clone()),
223            prompt_cache: Some(runtime_config.prompt_cache.clone()),
224            timeouts: None,
225            openai: vt_cfg.map(|cfg| cfg.provider.openai.clone()),
226            anthropic: vt_cfg.map(|cfg| cfg.provider.anthropic.clone()),
227            model_behavior: runtime_config.model_behavior.clone(),
228            workspace_root: Some(runtime_config.workspace.clone()),
229        },
230    )
231    .with_context(|| {
232        format!(
233            "Failed to initialize lightweight provider '{}' for model '{}'",
234            route.provider_name, route.model
235        )
236    })
237}
238
239pub fn resolve_api_key_for_model_route(
240    route: &ModelRoute,
241    runtime_config: &RuntimeAgentConfig,
242) -> Option<String> {
243    if route
244        .provider_name
245        .eq_ignore_ascii_case(main_model_route(runtime_config).provider_name.as_str())
246        && !runtime_config.api_key.trim().is_empty()
247    {
248        return Some(runtime_config.api_key.clone());
249    }
250
251    get_api_key(&route.provider_name, &ApiKeySources::default()).ok()
252}
253
254fn feature_uses_shared_model(
255    shared_cfg: &vtcode_config::core::agent::AgentSmallModelConfig,
256    feature: LightweightFeature,
257) -> bool {
258    match feature {
259        LightweightFeature::Memory => shared_cfg.use_for_memory,
260        LightweightFeature::LargeReadSummary => shared_cfg.use_for_large_reads,
261        LightweightFeature::WebSummary => shared_cfg.use_for_web_summary,
262        LightweightFeature::GitHistorySummary => shared_cfg.use_for_git_history,
263        LightweightFeature::PromptSuggestions
264        | LightweightFeature::PromptRefinement
265        | LightweightFeature::AutoModeReview
266        | LightweightFeature::AutoModeProbe
267        | LightweightFeature::Subagent => true,
268    }
269}
270
271fn route_for_candidate(main_provider: &str, candidate_model: &str) -> Option<ModelRoute> {
272    if infer_provider_from_model(candidate_model)
273        .map(|provider| !provider.as_ref().eq_ignore_ascii_case(main_provider))
274        .unwrap_or(false)
275    {
276        return None;
277    }
278
279    Some(ModelRoute {
280        provider_name: main_provider.to_string(),
281        model: candidate_model.to_string(),
282    })
283}
284
285fn provider_from_name(provider_name: &str) -> Provider {
286    known_provider_from_name(provider_name).unwrap_or(Provider::Gemini)
287}
288
289fn resolve_provider_for_model(provider_name: &str, active_model: &str) -> Provider {
290    known_provider_from_name(provider_name)
291        .or_else(|| infer_provider_from_model(active_model))
292        .unwrap_or_else(|| provider_from_name(provider_name))
293}
294
295fn known_provider_from_name(provider_name: &str) -> Option<Provider> {
296    match provider_name.to_ascii_lowercase().as_str() {
297        "openai" => Some(Provider::OpenAI),
298        "anthropic" => Some(Provider::Anthropic),
299        "copilot" => Some(Provider::Copilot),
300        "deepseek" => Some(Provider::DeepSeek),
301        "gemini" | "google" => Some(Provider::Gemini),
302        "openrouter" => Some(Provider::OpenRouter),
303        "ollama" => Some(Provider::Ollama),
304        "lmstudio" => Some(Provider::LmStudio),
305        "llamacpp" | "llama.cpp" | "llama-cpp" => Some(Provider::LlamaCpp),
306        "moonshot" => Some(Provider::Moonshot),
307        "zai" => Some(Provider::ZAI),
308        "minimax" => Some(Provider::Minimax),
309        "huggingface" => Some(Provider::HuggingFace),
310        "stepfun" => Some(Provider::StepFun),
311        "evolink" => Some(Provider::Evolink),
312        _ => None,
313    }
314}
315
316fn preferred_lightweight_model_slug(provider: Provider, active_model: &str) -> Option<String> {
317    let trimmed_model = active_model.trim();
318    let lower = trimmed_model.to_ascii_lowercase();
319
320    match provider {
321        Provider::Anthropic => {
322            if lower.contains("haiku") {
323                return Some(ModelId::ClaudeHaiku45.as_str().to_string());
324            }
325            if lower.contains("sonnet") || lower.contains("opus") {
326                return Some(ModelId::ClaudeHaiku45.as_str().to_string());
327            }
328            None
329        }
330        Provider::OpenAI => {
331            if lower.contains("gpt-5.4-mini") || lower.contains("gpt-5.4-nano") {
332                return Some(trimmed_model.to_string());
333            }
334            if lower.contains("gpt-5.4") {
335                return Some(ModelId::GPT54Mini.as_str().to_string());
336            }
337            if lower.contains("gpt-5-mini") || lower.contains("gpt-5-nano") {
338                return Some(trimmed_model.to_string());
339            }
340            if lower.contains("gpt-5.") || lower == "gpt-5" || lower.contains("gpt-5-codex") {
341                return Some(ModelId::GPT54Mini.as_str().to_string());
342            }
343            None
344        }
345        Provider::Copilot => {
346            if lower.contains("gpt-5.4-mini") {
347                return Some(trimmed_model.to_string());
348            }
349            if lower.contains("gpt-5") || lower.contains("claude") {
350                return Some(ModelId::CopilotGPT54Mini.as_str().to_string());
351            }
352            None
353        }
354        Provider::DeepSeek => {
355            if lower.contains("flash") || lower.contains("chat") {
356                return Some(trimmed_model.to_string());
357            }
358            if lower.contains("pro") || lower.contains("reasoner") {
359                return Some(trimmed_model.to_string());
360            }
361            None
362        }
363        Provider::Gemini => {
364            if lower.contains("flash-lite") || lower.contains("flash preview") {
365                return Some(trimmed_model.to_string());
366            }
367            if lower.contains("3.1") {
368                return Some(ModelId::Gemini31FlashLitePreview.as_str().to_string());
369            }
370            if lower.contains("gemini-3") || lower.contains("gemini 3") {
371                return Some(ModelId::Gemini35Flash.as_str().to_string());
372            }
373            None
374        }
375        Provider::ZAI => {
376            if lower.contains("glm-5.1") {
377                return Some(ModelId::ZaiGlm5.as_str().to_string());
378            }
379            if lower.contains("glm-5") {
380                return Some(ModelId::ZaiGlm5.as_str().to_string());
381            }
382            None
383        }
384        Provider::Minimax => {
385            if lower.contains("m2.5") {
386                return Some(trimmed_model.to_string());
387            }
388            if lower.contains("m2.7") {
389                return Some(ModelId::MinimaxM25.as_str().to_string());
390            }
391            None
392        }
393        Provider::StepFun => Some(trimmed_model.to_string()),
394        Provider::Evolink => Some(trimmed_model.to_string()),
395        _ => None,
396    }
397}
398
399fn provider_default_lightweight_model(provider: Provider) -> Option<&'static str> {
400    match provider {
401        Provider::OpenAI => Some(ModelId::GPT54Mini.as_str()),
402        Provider::Anthropic => Some(ModelId::ClaudeHaiku45.as_str()),
403        Provider::Copilot => Some(ModelId::CopilotGPT54Mini.as_str()),
404        Provider::DeepSeek => Some(ModelId::DeepSeekV4Flash.as_str()),
405        Provider::Gemini => Some(ModelId::Gemini35Flash.as_str()),
406        Provider::ZAI => Some(ModelId::ZaiGlm5.as_str()),
407        Provider::Minimax => Some(ModelId::MinimaxM25.as_str()),
408        Provider::StepFun => Some(ModelId::StepFun37Flash.as_str()),
409        Provider::Evolink => Some(ModelId::EvolinkGpt52.as_str()),
410        _ => None,
411    }
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    /// Runtime config for an OpenAI main model (`ModelId::GPT54`) with a dummy API key.
    fn runtime_config() -> RuntimeAgentConfig {
        let workspace = std::env::temp_dir().join("vtcode-lightweight-routing-tests");
        RuntimeAgentConfig {
            model: ModelId::GPT54.as_str().to_string(),
            api_key: "test-key".to_string(),
            provider: "openai".to_string(),
            openai_chatgpt_auth: None,
            api_key_env: "OPENAI_API_KEY".to_string(),
            workspace,
            verbose: false,
            quiet: false,
            theme: "default".to_string(),
            reasoning_effort: Default::default(),
            ui_surface: Default::default(),
            prompt_cache: Default::default(),
            model_source: Default::default(),
            custom_api_keys: Default::default(),
            checkpointing_enabled: false,
            checkpointing_storage_dir: None,
            checkpointing_max_snapshots: 0,
            checkpointing_max_age_days: None,
            max_conversation_turns: 0,
            model_behavior: None,
        }
    }

    #[test]
    fn explicit_override_uses_active_provider() {
        let vt_cfg = VTCodeConfig::default();
        let resolution = resolve_lightweight_route(
            &runtime_config(),
            Some(&vt_cfg),
            LightweightFeature::Memory,
            Some("gpt-5-mini"),
        );

        assert_eq!(resolution.primary.provider_name, "openai");
        assert_eq!(resolution.primary.model, "gpt-5-mini");
        assert_eq!(resolution.source, LightweightRouteSource::FeatureOverride);
    }

    #[test]
    fn cross_provider_shared_model_falls_back_to_auto_same_provider() {
        // A shared model from another provider must be ignored with a warning.
        let mut vt_cfg = VTCodeConfig::default();
        vt_cfg.agent.small_model.model = "claude-4-5-haiku".to_string();

        let resolution = resolve_lightweight_route(
            &runtime_config(),
            Some(&vt_cfg),
            LightweightFeature::PromptSuggestions,
            None,
        );

        assert_eq!(resolution.primary.provider_name, "openai");
        assert_eq!(resolution.primary.model, ModelId::GPT54Mini.as_str());
        assert_eq!(resolution.source, LightweightRouteSource::SharedAutomatic);
        assert!(resolution.warning.is_some());
    }

    #[test]
    fn auto_lightweight_model_prefers_same_generation_openai_sibling() {
        let chosen = auto_lightweight_model("openai", ModelId::GPT54.as_str());
        assert_eq!(chosen, ModelId::GPT54Mini.as_str());
    }

    #[test]
    fn auto_lightweight_model_uses_closest_anthropic_haiku_pair() {
        // Both a catalog id and a free-form slug should land on the same Haiku model.
        let from_catalog_id = auto_lightweight_model("anthropic", ModelId::ClaudeSonnet46.as_str());
        assert_eq!(from_catalog_id, ModelId::ClaudeHaiku45.as_str());

        let from_free_form = auto_lightweight_model("anthropic", "claude-sonnet-4.5");
        assert_eq!(from_free_form, ModelId::ClaudeHaiku45.as_str());
    }

    #[test]
    fn auto_lightweight_model_uses_lower_generation_glm_pair() {
        let chosen = auto_lightweight_model("zai", ModelId::ZaiGlm51.as_str());
        assert_eq!(chosen, ModelId::ZaiGlm5.as_str());
    }

    #[test]
    fn auto_lightweight_model_prefers_same_generation_gemini_flash_lite() {
        let chosen = auto_lightweight_model("gemini", ModelId::Gemini31ProPreview.as_str());
        assert_eq!(chosen, ModelId::Gemini31FlashLitePreview.as_str());
    }

    #[test]
    fn auto_lightweight_model_infers_family_for_custom_provider() {
        // "mycorp" is not a known provider name, so the model name drives the choice.
        let chosen = auto_lightweight_model("mycorp", ModelId::GPT54.as_str());
        assert_eq!(chosen, ModelId::GPT54Mini.as_str());
    }

    #[test]
    fn disabled_feature_uses_main_model() {
        let mut vt_cfg = VTCodeConfig::default();
        vt_cfg.agent.small_model.use_for_memory = false;

        let resolution = resolve_lightweight_route(
            &runtime_config(),
            Some(&vt_cfg),
            LightweightFeature::Memory,
            None,
        );

        assert_eq!(resolution.primary.model, ModelId::GPT54.as_str());
        assert_eq!(resolution.source, LightweightRouteSource::MainModel);
        assert!(resolution.fallback.is_none());
    }
}