Skip to main content

vtcode_core/llm/
model_resolver.rs

1use std::borrow::Cow;
2use std::str::FromStr;
3
4use crate::config::api_keys::api_key_env_var;
5use crate::config::models::{
6    ModelCatalogEntry, ModelId, ModelPricing, Provider, catalog_provider_keys, model_catalog_entry,
7};
8use crate::llm::provider::Usage;
9use vtcode_config::auth::{AuthCredentialsStoreMode, CustomApiKeyStorage};
10
/// Outcome of checking whether a resolved model can be used with the
/// credentials currently available (see `ModelResolver::availability`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelAvailability {
    /// A non-blank API-key environment variable or a stored custom API key was found.
    Available,
    /// No credential of any supported kind was found.
    MissingCredential,
    /// Authentication does not rely on a user-supplied API key: the provider uses
    /// managed auth, a stored session/OAuth token exists, or no API-key variable
    /// is defined for the provider.
    ManagedAuthAvailable,
    /// An API-key environment variable is set but its value is unusable (e.g. blank).
    Misconfigured,
    /// The model runs on a local provider and needs no remote credentials.
    LocalOnly,
}
19
20impl ModelAvailability {
21    pub fn requires_api_key(&self) -> bool {
22        matches!(self, Self::MissingCredential | Self::Misconfigured)
23    }
24
25    pub fn uses_managed_auth(&self) -> bool {
26        matches!(self, Self::ManagedAuthAvailable)
27    }
28}
29
/// Metadata for a model discovered at runtime rather than listed in the static
/// catalog. It is attached to a `ResolvedModel` only when no catalog entry exists.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DynamicModelMeta {
    /// Human-readable name shown instead of the raw model id.
    pub display_name: String,
    /// Optional description; an empty string is treated as absent by
    /// `ResolvedModel::description`.
    pub description: Option<String>,
    /// Optional context-window size, used when the catalog provides none.
    pub context_window: Option<usize>,
}
36
37#[derive(Debug, Clone, Copy)]
38pub struct DynamicModelRef<'a> {
39    pub provider: Provider,
40    pub model_id: &'a str,
41}
42
43#[derive(Debug, Clone)]
44pub struct ResolvedModel {
45    pub provider: Provider,
46    pub model_id: String,
47    pub catalog: Option<ModelCatalogEntry>,
48    pub dynamic: Option<DynamicModelMeta>,
49    pub availability: ModelAvailability,
50}
51
52impl ResolvedModel {
53    pub fn known_model(&self) -> bool {
54        self.catalog.is_some()
55    }
56
57    pub fn reasoning_supported(&self) -> bool {
58        self.catalog
59            .map(|entry| entry.reasoning)
60            .unwrap_or_else(|| self.provider.supports_reasoning_effort(&self.model_id))
61    }
62
63    pub fn service_tier_supported(&self) -> bool {
64        self.provider.supports_service_tier(&self.model_id)
65    }
66
67    pub fn supports_tool_calls(&self) -> bool {
68        self.catalog.map(|entry| entry.tool_call).unwrap_or(true)
69    }
70
71    pub fn context_window(&self) -> Option<usize> {
72        self.catalog
73            .map(|entry| entry.context_window)
74            .filter(|value| *value > 0)
75            .or_else(|| {
76                self.dynamic
77                    .as_ref()
78                    .and_then(|dynamic| dynamic.context_window)
79            })
80    }
81
82    pub fn input_modalities(&self) -> &'static [&'static str] {
83        self.catalog
84            .map(|entry| entry.input_modalities)
85            .unwrap_or(&[])
86    }
87
88    pub fn display_name(&self) -> Cow<'_, str> {
89        if let Some(catalog) = self.catalog {
90            return Cow::Borrowed(catalog.display_name);
91        }
92        if let Some(dynamic) = &self.dynamic {
93            return Cow::Borrowed(dynamic.display_name.as_str());
94        }
95        Cow::Borrowed(self.model_id.as_str())
96    }
97
98    pub fn description(&self) -> Option<Cow<'_, str>> {
99        if let Some(catalog) = self.catalog {
100            return (!catalog.description.is_empty()).then_some(Cow::Borrowed(catalog.description));
101        }
102        self.dynamic.as_ref().and_then(|dynamic| {
103            dynamic
104                .description
105                .as_deref()
106                .filter(|value| !value.is_empty())
107                .map(Cow::Borrowed)
108        })
109    }
110
111    pub fn pricing(&self) -> Option<ModelPricing> {
112        self.catalog.map(|entry| entry.pricing).filter(|pricing| {
113            pricing.input.is_some()
114                || pricing.output.is_some()
115                || pricing.cache_read.is_some()
116                || pricing.cache_write.is_some()
117        })
118    }
119
120    pub fn env_key(&self) -> String {
121        api_key_env_var(self.provider.as_ref())
122    }
123}
124
/// Stateless namespace for model-resolution functions; every method is an
/// associated function that takes no `self`.
pub struct ModelResolver;
126
127impl ModelResolver {
128    pub fn resolve(
129        provider_override: Option<&str>,
130        model: &str,
131        dynamic_models: &[DynamicModelRef<'_>],
132        dynamic_meta: Option<DynamicModelMeta>,
133    ) -> Option<ResolvedModel> {
134        let model = model.trim();
135        if model.is_empty() {
136            return None;
137        }
138
139        if let Some(provider) = provider_override.and_then(parse_provider_override) {
140            return Some(Self::resolve_for_provider(
141                provider,
142                model,
143                dynamic_models,
144                dynamic_meta,
145            ));
146        }
147
148        if let Ok(model_id) = ModelId::from_str(model) {
149            return Some(Self::resolve_for_model_id(
150                model,
151                model_id,
152                dynamic_models,
153                dynamic_meta,
154            ));
155        }
156
157        if let Some((provider, entry)) = find_catalog_provider(model) {
158            return Some(ResolvedModel {
159                provider,
160                model_id: model.to_string(),
161                catalog: Some(entry),
162                dynamic: dynamic_meta,
163                availability: Self::availability(provider, model),
164            });
165        }
166
167        if let Some(provider) = find_dynamic_provider(model, dynamic_models) {
168            return Some(Self::resolve_for_provider(
169                provider,
170                model,
171                dynamic_models,
172                dynamic_meta,
173            ));
174        }
175
176        let provider = heuristic_provider_from_model(model)?;
177        Some(Self::resolve_for_provider(
178            provider,
179            model,
180            dynamic_models,
181            dynamic_meta,
182        ))
183    }
184
185    pub fn resolve_provider(
186        provider_override: Option<&str>,
187        model: &str,
188        dynamic_models: &[DynamicModelRef<'_>],
189    ) -> Option<Provider> {
190        Self::resolve(provider_override, model, dynamic_models, None)
191            .map(|resolved| resolved.provider)
192    }
193
194    pub fn availability(provider: Provider, model: &str) -> ModelAvailability {
195        if provider.is_local() && !local_model_requires_remote_auth(provider, model) {
196            return ModelAvailability::LocalOnly;
197        }
198
199        if provider.uses_managed_auth() {
200            return ModelAvailability::ManagedAuthAvailable;
201        }
202
203        if provider == Provider::OpenAI
204            && vtcode_config::auth::load_openai_chatgpt_session()
205                .ok()
206                .flatten()
207                .is_some()
208        {
209            return ModelAvailability::ManagedAuthAvailable;
210        }
211
212        if provider == Provider::OpenRouter
213            && vtcode_config::auth::load_oauth_token()
214                .ok()
215                .flatten()
216                .is_some()
217        {
218            return ModelAvailability::ManagedAuthAvailable;
219        }
220
221        let env_key = api_key_env_var(provider.as_ref());
222        if env_key.trim().is_empty() {
223            return ModelAvailability::ManagedAuthAvailable;
224        }
225
226        if has_env_value(&env_key) || has_stored_key(provider) {
227            return ModelAvailability::Available;
228        }
229
230        if std::env::var(&env_key).is_ok() {
231            return ModelAvailability::Misconfigured;
232        }
233
234        ModelAvailability::MissingCredential
235    }
236
237    pub fn estimate_cost(pricing: ModelPricing, usage: &Usage) -> Option<f64> {
238        let input_cost = pricing.input?;
239        let output_cost = pricing.output?;
240
241        let mut total = (usage.prompt_tokens as f64 * input_cost)
242            + (usage.completion_tokens as f64 * output_cost);
243
244        if let Some(cache_read_cost) = pricing.cache_read {
245            total += usage.cache_read_tokens_or_fallback() as f64 * cache_read_cost;
246        }
247
248        if let Some(cache_write_cost) = pricing.cache_write {
249            total += usage.cache_creation_tokens_or_zero() as f64 * cache_write_cost;
250        }
251
252        Some(total)
253    }
254
255    fn resolve_for_provider(
256        provider: Provider,
257        model: &str,
258        dynamic_models: &[DynamicModelRef<'_>],
259        dynamic_meta: Option<DynamicModelMeta>,
260    ) -> ResolvedModel {
261        let catalog = model_catalog_entry(provider.as_ref(), model);
262        let dynamic = if catalog.is_some() || !has_dynamic_model(provider, model, dynamic_models) {
263            None
264        } else {
265            dynamic_meta.or_else(|| {
266                Some(DynamicModelMeta {
267                    display_name: model.to_string(),
268                    description: None,
269                    context_window: None,
270                })
271            })
272        };
273
274        ResolvedModel {
275            provider,
276            model_id: model.to_string(),
277            catalog,
278            dynamic,
279            availability: Self::availability(provider, model),
280        }
281    }
282
283    fn resolve_for_model_id(
284        requested_model: &str,
285        model_id: ModelId,
286        dynamic_models: &[DynamicModelRef<'_>],
287        dynamic_meta: Option<DynamicModelMeta>,
288    ) -> ResolvedModel {
289        let provider = model_id.provider();
290        let catalog = model_catalog_entry(provider.as_ref(), model_id.as_str());
291        let dynamic =
292            if catalog.is_some() || !has_dynamic_model(provider, requested_model, dynamic_models) {
293                None
294            } else {
295                dynamic_meta.or_else(|| {
296                    Some(DynamicModelMeta {
297                        display_name: requested_model.to_string(),
298                        description: None,
299                        context_window: None,
300                    })
301                })
302            };
303
304        ResolvedModel {
305            provider,
306            model_id: requested_model.to_string(),
307            catalog,
308            dynamic,
309            availability: Self::availability(provider, requested_model),
310        }
311    }
312}
313
314fn parse_provider_override(value: &str) -> Option<Provider> {
315    let trimmed = value.trim();
316    if trimmed.is_empty() {
317        None
318    } else {
319        Provider::from_str(trimmed).ok()
320    }
321}
322
323fn find_catalog_provider(model: &str) -> Option<(Provider, ModelCatalogEntry)> {
324    let mut matches: Vec<(Provider, ModelCatalogEntry)> = catalog_provider_keys()
325        .iter()
326        .filter_map(|provider_key| {
327            let provider = Provider::from_str(provider_key).ok()?;
328            model_catalog_entry(provider_key, model).map(|entry| (provider, entry))
329        })
330        .collect();
331    matches.sort_by_key(|(provider, _)| provider_precedence(*provider));
332    matches.into_iter().next()
333}
334
335fn find_dynamic_provider(model: &str, dynamic_models: &[DynamicModelRef<'_>]) -> Option<Provider> {
336    let mut matches = dynamic_models
337        .iter()
338        .filter(|candidate| candidate.model_id.eq_ignore_ascii_case(model))
339        .map(|candidate| candidate.provider);
340    let first = matches.next()?;
341    if matches.all(|provider| provider == first) {
342        Some(first)
343    } else {
344        None
345    }
346}
347
348fn has_dynamic_model(
349    provider: Provider,
350    model: &str,
351    dynamic_models: &[DynamicModelRef<'_>],
352) -> bool {
353    dynamic_models.iter().any(|candidate| {
354        candidate.provider == provider && candidate.model_id.eq_ignore_ascii_case(model)
355    })
356}
357
358fn provider_precedence(provider: Provider) -> usize {
359    match provider {
360        Provider::OpenAI => 0,
361        Provider::Anthropic => 1,
362        Provider::Gemini => 2,
363        Provider::DeepSeek => 3,
364        Provider::ZAI => 4,
365        Provider::Minimax => 5,
366        Provider::Mistral => 6,
367        Provider::Moonshot => 7,
368        Provider::OpenRouter => 8,
369        Provider::HuggingFace => 9,
370        Provider::Copilot => 10,
371        Provider::Ollama => 11,
372        Provider::LmStudio => 12,
373        Provider::LlamaCpp => 13,
374        Provider::OpenCodeZen => 14,
375        Provider::OpenCodeGo => 15,
376        Provider::MiMo => 16,
377        Provider::Qwen => 17,
378        Provider::StepFun => 18,
379        Provider::Evolink => 19,
380        Provider::Poolside => 20,
381    }
382}
383
384fn local_model_requires_remote_auth(provider: Provider, model: &str) -> bool {
385    provider == Provider::Ollama && (model.contains(":cloud") || model.contains("-cloud"))
386}
387
/// Returns `true` when `env_key` is set to valid UTF-8 that is non-blank after
/// trimming.
fn has_env_value(env_key: &str) -> bool {
    std::env::var(env_key)
        .map(|value| !value.trim().is_empty())
        .unwrap_or(false)
}
391
392fn has_stored_key(provider: Provider) -> bool {
393    CustomApiKeyStorage::new(provider.as_ref())
394        .load(AuthCredentialsStoreMode::default())
395        .ok()
396        .flatten()
397        .is_some()
398}
399
400pub(crate) fn heuristic_provider_from_model(model: &str) -> Option<Provider> {
401    let trimmed = model.trim();
402    if trimmed.is_empty() {
403        return None;
404    }
405
406    if trimmed.contains(':') && !trimmed.contains('/') && !trimmed.contains('@') {
407        return Some(Provider::Ollama);
408    }
409
410    let model = trimmed.to_ascii_lowercase();
411    if model.starts_with("gpt-oss-")
412        || model.starts_with("gpt-")
413        || model.starts_with("o1")
414        || model.starts_with("o3")
415        || model.starts_with("o4")
416        || model.starts_with("codex")
417    {
418        Some(Provider::OpenAI)
419    } else if model == "copilot" || model.starts_with("copilot-") {
420        Some(Provider::Copilot)
421    } else if model.starts_with("claude-") {
422        Some(Provider::Anthropic)
423    } else if model.starts_with("deepseek-") {
424        Some(Provider::DeepSeek)
425    } else if model.starts_with("mistral-")
426        || model.starts_with("ministral-")
427        || model.starts_with("codestral-")
428    {
429        Some(Provider::Mistral)
430    } else if model.contains("gemini") || model.starts_with("palm") {
431        Some(Provider::Gemini)
432    } else if model.starts_with("glm-") {
433        Some(Provider::ZAI)
434    } else if model.starts_with("lmstudio-community/") {
435        Some(Provider::LmStudio)
436    } else if model.starts_with("mimo-") {
437        Some(Provider::MiMo)
438    } else if model.starts_with("qwen3.") || model.starts_with("qwen-") {
439        Some(Provider::Qwen)
440    } else if model.starts_with("step-") {
441        Some(Provider::StepFun)
442    } else if model.starts_with("moonshot-") || model.starts_with("kimi-") {
443        Some(Provider::Moonshot)
444    } else if model.starts_with("opencode/") || model.starts_with("opencode-zen/") {
445        Some(Provider::OpenCodeZen)
446    } else if model.starts_with("opencode-go/") {
447        Some(Provider::OpenCodeGo)
448    } else if model.starts_with("poolside/") {
449        Some(Provider::Poolside)
450    } else if model.starts_with("deepseek-ai/")
451        || model.starts_with("openai/gpt-oss-")
452        || model.starts_with("zai-org/")
453        || model.starts_with("moonshotai/")
454        || model.starts_with("minimaxai/")
455        || model.starts_with("nvidia/")
456    {
457        Some(Provider::HuggingFace)
458    } else if model.starts_with("mixtral-")
459        || model.starts_with("qwen-")
460        || model.starts_with("meta-")
461        || model.starts_with("llama-")
462        || model.starts_with("command-")
463        || model.contains('/')
464        || model.contains('@')
465    {
466        Some(Provider::OpenRouter)
467    } else {
468        None
469    }
470}
471
#[cfg(test)]
mod tests {
    use super::*;

    /// Resolves `model` with no provider override, no dynamic models and no
    /// metadata, panicking with `label` if nothing resolves.
    fn resolve_plain(model: &str, label: &str) -> ResolvedModel {
        ModelResolver::resolve(None, model, &[], None).expect(label)
    }

    #[test]
    fn resolver_prefers_catalog_match_over_heuristic() {
        let resolved = resolve_plain("gpt-5.4", "model");

        assert_eq!(resolved.provider, Provider::OpenAI);
        assert!(resolved.known_model());
        assert_eq!(resolved.display_name(), "GPT-5.4");
    }

    #[test]
    fn resolver_uses_model_id_to_disambiguate_shared_opencode_slugs() {
        let bare = resolve_plain("glm-5.1", "bare model");
        assert_eq!(bare.provider, Provider::ZAI);

        // The same slug behind different gateways must select each gateway's own
        // catalog entry.
        let gateway_cases = [
            (
                "opencode/glm-5.1",
                "opencode zen",
                Provider::OpenCodeZen,
                "GLM-5.1 (OpenCode Zen)",
            ),
            (
                "opencode-go/glm-5.1",
                "opencode go",
                Provider::OpenCodeGo,
                "GLM-5.1 (OpenCode Go)",
            ),
        ];
        for (slug, label, expected_provider, expected_name) in gateway_cases {
            let resolved = resolve_plain(slug, label);
            assert_eq!(resolved.provider, expected_provider);
            assert!(resolved.known_model());
            assert_eq!(resolved.display_name(), expected_name);
        }
    }

    #[test]
    fn resolver_uses_provider_override_for_dynamic_model() {
        let model_id = "custom-local-model";
        let dynamic_models = [DynamicModelRef {
            provider: Provider::Ollama,
            model_id,
        }];
        let meta = DynamicModelMeta {
            display_name: "Custom Local Model".to_string(),
            description: Some("dynamic".to_string()),
            context_window: Some(32_000),
        };

        let resolved =
            ModelResolver::resolve(Some("ollama"), model_id, &dynamic_models, Some(meta))
                .expect("resolved model");

        assert_eq!(resolved.provider, Provider::Ollama);
        assert!(!resolved.known_model());
        assert_eq!(resolved.context_window(), Some(32_000));
    }

    #[test]
    fn estimate_cost_uses_usage_totals() {
        let usage = Usage {
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
            cached_prompt_tokens: Some(20),
            cache_creation_tokens: Some(10),
            cache_read_tokens: None,
        };
        let pricing = ModelPricing {
            input: Some(0.001),
            output: Some(0.002),
            cache_read: Some(0.0001),
            cache_write: Some(0.0002),
        };

        let total = ModelResolver::estimate_cost(pricing, &usage).expect("cost");
        assert!(total > 0.0);
    }
}