Skip to main content

vtcode_core/models_manager/
manager.rs

//! Models Manager - coordinates model discovery, caching, and selection.
//!
//! This module provides the main `ModelsManager` struct, which coordinates:
//! - Local model presets (built-in configurations)
//! - Remote model discovery (fetching from provider APIs)
//! - Disk caching with a TTL
//! - Model family resolution
8
9use chrono::Utc;
10use hashbrown::HashSet;
11use std::path::PathBuf;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::RwLock;
15use tracing::{debug, error, info};
16
17use super::cache::{self, ModelsCache};
18use super::model_family::{ModelFamily, find_family_for_model};
19use super::model_presets::{ModelInfo, ModelPreset, builtin_model_presets, presets_for_provider};
20use crate::config::models::Provider;
21use crate::llm::providers::llamacpp::fetch_llamacpp_models;
22
/// File name of the on-disk models cache; it is joined onto the VT Code home
/// directory by `ModelsManager::cache_path`.
const MODEL_CACHE_FILE: &str = "models_cache.json";

/// Default cache TTL (120 seconds, i.e. 2 minutes). A cache older than this is
/// treated as stale by `try_load_cache`.
const DEFAULT_MODEL_CACHE_TTL: Duration = Duration::from_secs(120);

/// Fallback default model for the Gemini provider, used when no local preset
/// for that provider is flagged as default.
const GEMINI_DEFAULT_MODEL: &str = "gemini-3-flash-preview";

/// Fallback default model for the OpenAI provider, used when no local preset
/// for that provider is flagged as default.
const OPENAI_DEFAULT_MODEL: &str = "gpt-5.4";

/// Fallback default model for the Anthropic provider, used when no local preset
/// for that provider is flagged as default.
const ANTHROPIC_DEFAULT_MODEL: &str = "claude-opus-4-8";
37
/// Coordinates remote model discovery plus cached metadata on disk.
///
/// Fields that change after construction through `&self` methods
/// (`remote_models`, `etag`, `current_provider`) are wrapped in an async
/// `RwLock`; `cache_ttl` and `remote_models_enabled` can only be changed
/// through `&mut self` setters.
#[derive(Debug)]
pub struct ModelsManager {
    /// Local built-in model presets (all built-ins, or only those of one
    /// provider, depending on the constructor used)
    local_models: Vec<ModelPreset>,
    /// Remote models fetched from provider APIs or restored from the disk cache
    remote_models: RwLock<Vec<ModelInfo>>,
    /// ETag for conditional requests; currently only populated from the cache
    etag: RwLock<Option<String>>,
    /// VT Code home directory for cache storage
    vtcode_home: PathBuf,
    /// Cache TTL: maximum age at which the disk cache is still considered fresh
    cache_ttl: Duration,
    /// Current active provider
    current_provider: RwLock<Provider>,
    /// Whether remote model fetching is enabled
    remote_models_enabled: bool,
}
56
57impl Default for ModelsManager {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl ModelsManager {
64    /// Construct a new ModelsManager with default settings
65    pub fn new() -> Self {
66        let vtcode_home = Self::default_vtcode_home();
67        Self {
68            local_models: builtin_model_presets(),
69            remote_models: RwLock::new(Vec::new()),
70            etag: RwLock::new(None),
71            vtcode_home,
72            cache_ttl: DEFAULT_MODEL_CACHE_TTL,
73            current_provider: RwLock::new(Provider::default()),
74            remote_models_enabled: true,
75        }
76    }
77
78    /// Construct with a specific home directory
79    pub fn with_home(vtcode_home: PathBuf) -> Self {
80        Self {
81            local_models: builtin_model_presets(),
82            remote_models: RwLock::new(Vec::new()),
83            etag: RwLock::new(None),
84            vtcode_home,
85            cache_ttl: DEFAULT_MODEL_CACHE_TTL,
86            current_provider: RwLock::new(Provider::default()),
87            remote_models_enabled: true,
88        }
89    }
90
91    /// Construct with a specific provider
92    pub fn with_provider(provider: Provider) -> Self {
93        let vtcode_home = Self::default_vtcode_home();
94        Self {
95            local_models: presets_for_provider(provider),
96            remote_models: RwLock::new(Vec::new()),
97            etag: RwLock::new(None),
98            vtcode_home,
99            cache_ttl: DEFAULT_MODEL_CACHE_TTL,
100            current_provider: RwLock::new(provider),
101            remote_models_enabled: true,
102        }
103    }
104
105    /// Construct with specific home directory and provider
106    pub fn with_home_and_provider(vtcode_home: PathBuf, provider: Provider) -> Self {
107        Self {
108            local_models: presets_for_provider(provider),
109            remote_models: RwLock::new(Vec::new()),
110            etag: RwLock::new(None),
111            vtcode_home,
112            cache_ttl: DEFAULT_MODEL_CACHE_TTL,
113            current_provider: RwLock::new(provider),
114            remote_models_enabled: true,
115        }
116    }
117
    /// Enable or disable remote model fetching.
    ///
    /// While disabled, `refresh_available_models` returns immediately without
    /// consulting the disk cache or any provider API.
    pub fn set_remote_models_enabled(&mut self, enabled: bool) {
        self.remote_models_enabled = enabled;
    }
122
    /// Set the cache TTL.
    ///
    /// The TTL is the maximum age at which the on-disk models cache is still
    /// accepted as fresh when loading it.
    pub fn set_cache_ttl(&mut self, ttl: Duration) {
        self.cache_ttl = ttl;
    }
127
128    /// Get the default VT Code home directory
129    fn default_vtcode_home() -> PathBuf {
130        dirs::home_dir()
131            .map(|h| h.join(".vtcode"))
132            .unwrap_or_else(|| PathBuf::from(".vtcode"))
133    }
134
135    /// Refresh available models, using cache if fresh
136    pub async fn refresh_available_models(&self) -> anyhow::Result<()> {
137        if !self.remote_models_enabled {
138            debug!("Remote model fetching is disabled");
139            return Ok(());
140        }
141
142        // Try to load from cache first
143        if self.try_load_cache().await {
144            debug!("Using cached models");
145            return Ok(());
146        }
147
148        let provider = *self.current_provider.read().await;
149
150        match provider {
151            Provider::Ollama => {
152                debug!("Fetching remote models for Ollama...");
153                match self.fetch_ollama_models().await {
154                    Ok(models) => {
155                        info!("Fetched {} models from Ollama", models.len());
156                        self.apply_remote_models(models.clone()).await;
157                        self.persist_cache(&models, None).await;
158                        Ok(())
159                    }
160                    Err(e) => {
161                        error!("Failed to fetch Ollama models: {e}");
162                        // Fall back to local presets if fetch fails
163                        Ok(())
164                    }
165                }
166            }
167            Provider::LlamaCpp => {
168                debug!("Fetching remote models for llama.cpp...");
169                match self.fetch_llamacpp_models().await {
170                    Ok(models) => {
171                        info!("Fetched {} models from llama.cpp", models.len());
172                        self.apply_remote_models(models.clone()).await;
173                        self.persist_cache(&models, None).await;
174                        Ok(())
175                    }
176                    Err(e) => {
177                        error!("Failed to fetch llama.cpp models: {e}");
178                        Ok(())
179                    }
180                }
181            }
182            _ => {
183                // For other providers, we don't have remote discovery yet
184                info!(
185                    "Remote model discovery for {:?} not implemented, using local presets",
186                    provider
187                );
188                Ok(())
189            }
190        }
191    }
192
193    /// Fetch models from Ollama API
194    async fn fetch_ollama_models(&self) -> anyhow::Result<Vec<ModelInfo>> {
195        let client = reqwest::Client::new();
196        let resp = client.get("http://localhost:11434/api/tags").send().await?;
197
198        if !resp.status().is_success() {
199            return Err(anyhow::anyhow!("Ollama API returned {}", resp.status()));
200        }
201
202        let json: serde_json::Value = resp.json().await?;
203        let mut models = Vec::new();
204
205        if let Some(ollama_models) = json.get("models").and_then(|m| m.as_array()) {
206            for m in ollama_models {
207                if let Some(name) = m.get("name").and_then(|s| s.as_str()) {
208                    models.push(ModelInfo {
209                        slug: name.to_string(),
210                        display_name: format!("{} (Ollama)", name),
211                        description: format!("Ollama model: {}", name),
212                        provider: Provider::Ollama,
213                        default_reasoning_level: crate::config::types::ReasoningEffortLevel::Medium,
214                        supported_reasoning_levels: vec![],
215                        context_window: Some(32_000), // Default for most Ollama models
216                        supports_tool_use: true,
217                        supports_streaming: true,
218                        supports_reasoning: false,
219                        priority: 100,
220                        visibility: "list".to_string(),
221                        supported_in_api: true,
222                        upgrade: None,
223                    });
224                }
225            }
226        }
227
228        Ok(models)
229    }
230
231    async fn fetch_llamacpp_models(&self) -> anyhow::Result<Vec<ModelInfo>> {
232        let mut models = Vec::new();
233        for model in fetch_llamacpp_models(None).await? {
234            models.push(ModelInfo {
235                slug: model.clone(),
236                display_name: format!("{model} (llama.cpp)"),
237                description: format!("llama.cpp model: {model}"),
238                provider: Provider::LlamaCpp,
239                default_reasoning_level: crate::config::types::ReasoningEffortLevel::Medium,
240                supported_reasoning_levels: vec![],
241                context_window: Some(131_072),
242                supports_tool_use: true,
243                supports_streaming: true,
244                supports_reasoning: true,
245                priority: 100,
246                visibility: "list".to_string(),
247                supported_in_api: true,
248                upgrade: None,
249            });
250        }
251
252        Ok(models)
253    }
254
255    /// List available models for the current provider
256    pub async fn list_models(&self) -> Vec<ModelPreset> {
257        if let Err(err) = self.refresh_available_models().await {
258            error!("Failed to refresh available models: {err}");
259        }
260        let remote_models = self.remote_models.read().await;
261        self.build_available_models(remote_models.clone())
262    }
263
264    /// List available models for a specific provider
265    pub async fn list_models_for_provider(&self, provider: Provider) -> Vec<ModelPreset> {
266        let all_models = self.list_models().await;
267        all_models
268            .into_iter()
269            .filter(|m| m.provider == provider)
270            .collect()
271    }
272
273    /// Try to list models without async refresh (uses cache only)
274    pub fn try_list_models(&self) -> Result<Vec<ModelPreset>, tokio::sync::TryLockError> {
275        let remote_models = self.remote_models.try_read()?;
276        Ok(self.build_available_models(remote_models.clone()))
277    }
278
    /// Get the model family for a given model slug.
    ///
    /// Delegates directly to `find_family_for_model`; no manager state is read
    /// and nothing is awaited, so the result depends only on `model`.
    pub async fn construct_model_family(&self, model: &str) -> ModelFamily {
        find_family_for_model(model)
    }
283
284    /// Get the model to use, resolving defaults if not specified
285    pub async fn get_model(&self, model: Option<&str>) -> String {
286        if let Some(m) = model {
287            return m.to_string();
288        }
289
290        // Refresh models to ensure we have the latest
291        if let Err(err) = self.refresh_available_models().await {
292            error!("Failed to refresh available models: {err}");
293        }
294
295        // Return default for current provider
296        let provider = *self.current_provider.read().await;
297        self.get_default_model_for_provider(provider)
298    }
299
300    /// Get the default model for a specific provider
301    pub fn get_default_model_for_provider(&self, provider: Provider) -> String {
302        // First check if there's a default in local presets
303        if let Some(preset) = self
304            .local_models
305            .iter()
306            .find(|p| p.provider == provider && p.is_default)
307        {
308            return preset.model.clone();
309        }
310
311        // Fall back to hardcoded defaults
312        match provider {
313            Provider::Gemini => GEMINI_DEFAULT_MODEL.to_string(),
314            Provider::OpenAI => OPENAI_DEFAULT_MODEL.to_string(),
315            Provider::Anthropic => ANTHROPIC_DEFAULT_MODEL.to_string(),
316            Provider::Copilot => {
317                crate::config::constants::models::copilot::DEFAULT_MODEL.to_string()
318            }
319            Provider::DeepSeek => "deepseek-reasoner".to_string(),
320            Provider::ZAI => "glm-5.1".to_string(),
321            Provider::Minimax => {
322                crate::config::constants::models::minimax::DEFAULT_MODEL.to_string()
323            }
324            Provider::Mistral => {
325                crate::config::constants::models::mistral::MISTRAL_LARGE_3.to_string()
326            }
327            Provider::OpenRouter => "xiaomi/mimo-v2.5-pro".to_string(),
328            Provider::Ollama => "gpt-oss:20b".to_string(),
329            Provider::LmStudio => {
330                crate::config::constants::models::lmstudio::DEFAULT_MODEL.to_string()
331            }
332            Provider::LlamaCpp => {
333                crate::config::constants::models::llamacpp::DEFAULT_MODEL.to_string()
334            }
335            Provider::Moonshot => {
336                crate::config::constants::models::moonshot::DEFAULT_MODEL.to_string()
337            }
338            Provider::HuggingFace => "deepseek-ai/DeepSeek-V3-0324".to_string(),
339            Provider::OpenCodeZen => {
340                crate::config::constants::models::opencode_zen::DEFAULT_MODEL.to_string()
341            }
342            Provider::OpenCodeGo => {
343                crate::config::constants::models::opencode_go::DEFAULT_MODEL.to_string()
344            }
345            Provider::MiMo => crate::config::constants::models::mimo::DEFAULT_MODEL.to_string(),
346            Provider::Qwen => crate::config::constants::models::qwen::DEFAULT_MODEL.to_string(),
347            Provider::StepFun => {
348                crate::config::constants::models::stepfun::DEFAULT_MODEL.to_string()
349            }
350            Provider::Evolink => {
351                crate::config::constants::models::evolink::DEFAULT_MODEL.to_string()
352            }
353            Provider::Poolside => {
354                crate::config::constants::models::poolside::DEFAULT_MODEL.to_string()
355            }
356        }
357    }
358
359    /// Get model offline (without network) for testing
360    #[cfg(test)]
361    pub fn get_model_offline(model: Option<&str>) -> String {
362        model.unwrap_or(GEMINI_DEFAULT_MODEL).to_string()
363    }
364
    /// Construct model family offline for testing.
    ///
    /// Associated-function counterpart of `construct_model_family`: it makes
    /// the same `find_family_for_model` call without needing a manager
    /// instance or an async context.
    #[cfg(test)]
    pub fn construct_model_family_offline(model: &str) -> ModelFamily {
        find_family_for_model(model)
    }
370
371    /// Apply remote models (replace cached state)
372    async fn apply_remote_models(&self, models: Vec<ModelInfo>) {
373        *self.remote_models.write().await = models;
374    }
375
376    /// Try to load from cache
377    async fn try_load_cache(&self) -> bool {
378        let cache_path = self.cache_path();
379        let cache = match cache::load_cache(&cache_path).await {
380            Ok(Some(cache)) => cache,
381            Ok(None) => {
382                debug!("No cache file found at {:?}", cache_path);
383                return false;
384            }
385            Err(err) => {
386                error!("Failed to load models cache: {err}");
387                return false;
388            }
389        };
390
391        if !cache.is_fresh(self.cache_ttl) {
392            debug!("Cache is stale (age: {:?})", cache.age());
393            return false;
394        }
395
396        let models: Vec<ModelInfo> = cache.models.into_iter().collect();
397
398        *self.etag.write().await = cache.etag;
399        self.apply_remote_models(models).await;
400        true
401    }
402
403    /// Persist cache to disk
404    async fn persist_cache(&self, models: &[ModelInfo], etag: Option<String>) {
405        let provider = *self.current_provider.read().await;
406        let cache = ModelsCache {
407            fetched_at: Utc::now(),
408            etag,
409            provider: provider.to_string(),
410            models: models.to_vec(),
411        };
412        let cache_path = self.cache_path();
413        if let Err(err) = cache::save_cache(&cache_path, &cache).await {
414            error!("Failed to write models cache: {err}");
415        }
416    }
417
418    /// Build available models by merging remote and local presets
419    fn build_available_models(&self, mut remote_models: Vec<ModelInfo>) -> Vec<ModelPreset> {
420        // Sort by priority
421        remote_models.sort_by(|a, b| a.priority.cmp(&b.priority));
422
423        // Convert remote models to presets
424        let remote_presets: Vec<ModelPreset> = remote_models.into_iter().map(Into::into).collect();
425        let existing_presets = self.local_models.clone();
426        let mut merged_presets = Self::merge_presets(remote_presets, existing_presets);
427        merged_presets = self.filter_visible_models(merged_presets);
428
429        // Ensure one default per provider
430        self.ensure_defaults(&mut merged_presets);
431
432        merged_presets
433    }
434
435    /// Filter to only visible models
436    fn filter_visible_models(&self, models: Vec<ModelPreset>) -> Vec<ModelPreset> {
437        models
438            .into_iter()
439            .filter(|model| model.show_in_picker && model.supported_in_api)
440            .collect()
441    }
442
443    /// Merge remote and local presets, preferring remote when duplicates exist
444    fn merge_presets(
445        remote_presets: Vec<ModelPreset>,
446        existing_presets: Vec<ModelPreset>,
447    ) -> Vec<ModelPreset> {
448        if remote_presets.is_empty() {
449            return existing_presets;
450        }
451
452        let remote_slugs: HashSet<String> = remote_presets
453            .iter()
454            .map(|preset| preset.model.clone())
455            .collect();
456
457        let mut merged_presets = remote_presets;
458        for mut preset in existing_presets {
459            if remote_slugs.contains(&preset.model) {
460                continue;
461            }
462            preset.is_default = false;
463            merged_presets.push(preset);
464        }
465
466        merged_presets
467    }
468
469    /// Ensure there's at least one default model
470    fn ensure_defaults(&self, presets: &mut [ModelPreset]) {
471        let has_default = presets.iter().any(|p| p.is_default);
472        if !has_default && let Some(first) = presets.first_mut() {
473            first.is_default = true;
474        }
475    }
476
477    /// Get the cache file path
478    fn cache_path(&self) -> PathBuf {
479        self.vtcode_home.join(MODEL_CACHE_FILE)
480    }
481
482    /// Set the current provider
483    pub async fn set_provider(&self, provider: Provider) {
484        *self.current_provider.write().await = provider;
485    }
486
487    /// Get the current provider
488    pub async fn get_provider(&self) -> Provider {
489        *self.current_provider.read().await
490    }
491
492    /// Find a model preset by ID
493    pub async fn find_model(&self, model_id: &str) -> Option<ModelPreset> {
494        let models = self.list_models().await;
495        models
496            .into_iter()
497            .find(|m| m.model == model_id || m.id == model_id)
498    }
499
500    /// Check if a model exists
501    pub async fn model_exists(&self, model_id: &str) -> bool {
502        self.find_model(model_id).await.is_some()
503    }
504
505    /// Check if a model exists (sync, uses local presets only)
506    ///
507    /// This is a fast, non-blocking check that only looks at local presets.
508    /// Use `model_exists` for the async version that includes remote models.
509    pub fn model_exists_sync(&self, model_id: &str) -> bool {
510        self.local_models
511            .iter()
512            .any(|m| m.model == model_id || m.id == model_id)
513    }
514
    /// Get all supported providers.
    ///
    /// Thin wrapper that returns whatever `Provider::all_providers()` reports.
    pub fn supported_providers() -> Vec<Provider> {
        Provider::all_providers()
    }
519
520    /// Get version string for API requests
521    pub fn client_version() -> String {
522        format!(
523            "{}.{}.{}",
524            env!("CARGO_PKG_VERSION_MAJOR"),
525            env!("CARGO_PKG_VERSION_MINOR"),
526            env!("CARGO_PKG_VERSION_PATCH")
527        )
528    }
529}
530
/// Thread-safe reference-counted `ModelsManager`.
///
/// Because the manager sits behind an `Arc`, the `&mut self` setters
/// (`set_remote_models_enabled`, `set_cache_ttl`) must be applied before the
/// manager is wrapped; the `&self` async methods remain usable afterwards.
pub type SharedModelsManager = Arc<ModelsManager>;
533
534/// Create a new shared ModelsManager
535pub fn new_shared_models_manager() -> SharedModelsManager {
536    Arc::new(ModelsManager::new())
537}
538
539/// Create a new shared ModelsManager with specific provider
540pub fn new_shared_models_manager_with_provider(provider: Provider) -> SharedModelsManager {
541    Arc::new(ModelsManager::with_provider(provider))
542}
543
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{TempDir, tempdir};

    /// Build a manager whose cache lives in a fresh temporary directory.
    ///
    /// Tests that trigger a refresh would otherwise read the developer's real
    /// `~/.vtcode/models_cache.json`, making their outcome depend on the
    /// environment. The `TempDir` is returned so it outlives the manager.
    fn isolated_manager() -> (TempDir, ModelsManager) {
        let dir = tempdir().expect("create temp dir");
        let manager = ModelsManager::with_home(dir.path().to_path_buf());
        (dir, manager)
    }

    #[tokio::test]
    async fn test_new_manager() {
        let manager = ModelsManager::new();
        assert!(!manager.local_models.is_empty());
    }

    #[tokio::test]
    async fn test_list_models() {
        let (_home, manager) = isolated_manager();
        let models = manager.list_models().await;
        assert!(!models.is_empty());
    }

    #[tokio::test]
    async fn test_list_models_for_provider() {
        let (_home, manager) = isolated_manager();
        let gemini_models = manager.list_models_for_provider(Provider::Gemini).await;
        assert!(!gemini_models.is_empty());
        assert!(gemini_models.iter().all(|m| m.provider == Provider::Gemini));
    }

    #[tokio::test]
    async fn test_get_model_with_default() {
        let dir = tempdir().expect("create temp dir");
        let manager =
            ModelsManager::with_home_and_provider(dir.path().to_path_buf(), Provider::Gemini);
        let model = manager.get_model(None).await;
        assert!(!model.is_empty());
    }

    #[tokio::test]
    async fn test_get_model_with_explicit() {
        // An explicit model returns before any refresh, so no cache is touched.
        let manager = ModelsManager::new();
        let model = manager.get_model(Some("custom-model")).await;
        assert_eq!(model, "custom-model");
    }

    #[tokio::test]
    async fn test_construct_model_family() {
        let manager = ModelsManager::new();
        let family = manager
            .construct_model_family("gemini-3-flash-preview")
            .await;
        assert_eq!(family.family, "gemini-3");
        assert_eq!(family.provider, Provider::Gemini);
    }

    #[tokio::test]
    async fn test_find_model() {
        let (_home, manager) = isolated_manager();
        let model = manager.find_model("gemini-3-flash-preview").await;
        assert!(model.is_some());
    }

    #[tokio::test]
    async fn test_model_exists() {
        let (_home, manager) = isolated_manager();
        assert!(manager.model_exists("gemini-3-flash-preview").await);
        assert!(!manager.model_exists("nonexistent-model").await);
    }

    #[tokio::test]
    async fn test_set_provider() {
        let manager = ModelsManager::new();
        manager.set_provider(Provider::Anthropic).await;
        assert_eq!(manager.get_provider().await, Provider::Anthropic);
    }

    #[tokio::test]
    async fn test_cache_operations() {
        let dir = tempdir().expect("create temp dir");
        let manager = ModelsManager::with_home(dir.path().to_path_buf());

        // Initially no cache
        let cached = manager.try_load_cache().await;
        assert!(!cached);

        // Persist some models
        let models = vec![ModelInfo {
            slug: "test-model".to_string(),
            display_name: "Test Model".to_string(),
            description: "A test".to_string(),
            provider: Provider::Gemini,
            default_reasoning_level: crate::config::types::ReasoningEffortLevel::Medium,
            supported_reasoning_levels: vec![],
            context_window: Some(128_000),
            supports_tool_use: true,
            supports_streaming: true,
            supports_reasoning: false,
            priority: 0,
            visibility: "list".to_string(),
            supported_in_api: true,
            upgrade: None,
        }];
        manager.persist_cache(&models, None).await;

        // Now cache should load
        let cached = manager.try_load_cache().await;
        assert!(cached);
    }

    #[test]
    fn test_client_version() {
        let version = ModelsManager::client_version();
        assert!(!version.is_empty());
        assert!(version.contains('.'));
    }

    #[test]
    fn test_supported_providers() {
        let providers = ModelsManager::supported_providers();
        assert!(!providers.is_empty());
        assert!(providers.contains(&Provider::Gemini));
        assert!(providers.contains(&Provider::OpenAI));
    }

    #[test]
    fn moonshot_default_model_uses_curated_default() {
        let manager = ModelsManager::new();
        assert_eq!(
            manager.get_default_model_for_provider(Provider::Moonshot),
            crate::config::constants::models::moonshot::DEFAULT_MODEL
        );
    }
}
667            manager.get_default_model_for_provider(Provider::Moonshot),
668            crate::config::constants::models::moonshot::DEFAULT_MODEL
669        );
670    }
671}