Skip to main content

swarm_engine_llm/
registry.rs

1//! Model Registry - Ollamaモデルの動的検出と管理
2//!
3//! Ollamaと連携してインストール済みモデルを自動検出
4
5use ollama_rs::Ollama;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9/// モデル情報
10#[derive(Debug, Clone)]
11pub struct ModelInfo {
12    /// モデル名 (e.g., "qwen2.5-coder:1.5b")
13    pub name: String,
14    /// サイズ (bytes)
15    pub size_bytes: u64,
16}
17
18/// Model Registry
19pub struct ModelRegistry {
20    ollama: Ollama,
21    /// キャッシュされたモデル一覧
22    models: Arc<RwLock<Vec<ModelInfo>>>,
23    /// エンドポイント
24    endpoint: String,
25}
26
27impl ModelRegistry {
28    /// 新規作成
29    pub fn new(host: &str, port: u16) -> Self {
30        let endpoint = format!("{}:{}", host, port);
31        Self {
32            ollama: Ollama::new(host.to_string(), port),
33            models: Arc::new(RwLock::new(Vec::new())),
34            endpoint,
35        }
36    }
37
38    /// デフォルト (localhost:11434)
39    pub fn default_local() -> Self {
40        Self::new("http://localhost", 11434)
41    }
42
43    /// モデル一覧を検出してキャッシュ
44    pub async fn discover(&self) -> Result<Vec<ModelInfo>, RegistryError> {
45        let local_models = self
46            .ollama
47            .list_local_models()
48            .await
49            .map_err(|e| RegistryError::ConnectionFailed(e.to_string()))?;
50
51        let models: Vec<ModelInfo> = local_models
52            .into_iter()
53            .map(|m| ModelInfo {
54                name: m.name,
55                size_bytes: m.size,
56            })
57            .collect();
58
59        // キャッシュ更新
60        {
61            let mut cache = self.models.write().await;
62            *cache = models.clone();
63        }
64
65        tracing::info!(
66            endpoint = %self.endpoint,
67            count = models.len(),
68            "Discovered {} models",
69            models.len()
70        );
71
72        Ok(models)
73    }
74
75    /// キャッシュからモデル一覧を取得
76    pub async fn list(&self) -> Vec<ModelInfo> {
77        self.models.read().await.clone()
78    }
79
80    /// モデル名で検索
81    pub async fn get(&self, name: &str) -> Option<ModelInfo> {
82        let models = self.models.read().await;
83        models.iter().find(|m| m.name == name).cloned()
84    }
85
86    /// 名前のプレフィックスでフィルタ (e.g., "hf.co/LiquidAI", "qwen")
87    pub async fn by_prefix(&self, prefix: &str) -> Vec<ModelInfo> {
88        let models = self.models.read().await;
89        models
90            .iter()
91            .filter(|m| m.name.starts_with(prefix))
92            .cloned()
93            .collect()
94    }
95
96    /// 名前に含まれる文字列でフィルタ
97    pub async fn search(&self, query: &str) -> Vec<ModelInfo> {
98        let query_lower = query.to_lowercase();
99        let models = self.models.read().await;
100        models
101            .iter()
102            .filter(|m| m.name.to_lowercase().contains(&query_lower))
103            .cloned()
104            .collect()
105    }
106
107    /// モデルが存在するか確認
108    pub async fn exists(&self, name: &str) -> bool {
109        self.get(name).await.is_some()
110    }
111
112    /// 最初に見つかったモデルを返す(フォールバック用)
113    pub async fn first(&self) -> Option<ModelInfo> {
114        let models = self.models.read().await;
115        models.first().cloned()
116    }
117
118    /// モデル名を解決(存在確認 + フォールバック)
119    pub async fn resolve(&self, preferred: &str) -> Result<ModelInfo, RegistryError> {
120        // 優先モデルが存在すればそれを返す
121        if let Some(model) = self.get(preferred).await {
122            return Ok(model);
123        }
124
125        // なければ最初のモデルを返す
126        self.first()
127            .await
128            .ok_or_else(|| RegistryError::NoModelsAvailable {
129                requested: preferred.to_string(),
130            })
131    }
132
133    /// エンドポイントを取得
134    pub fn endpoint(&self) -> &str {
135        &self.endpoint
136    }
137}
138
139impl Default for ModelRegistry {
140    fn default() -> Self {
141        Self::default_local()
142    }
143}
144
145/// Registry エラー
146#[derive(Debug, thiserror::Error)]
147pub enum RegistryError {
148    #[error("Failed to connect to Ollama: {0}")]
149    ConnectionFailed(String),
150
151    #[error("Model '{requested}' not found and no fallback available")]
152    NoModelsAvailable { requested: String },
153}
154
155impl From<RegistryError> for swarm_engine_core::error::SwarmError {
156    fn from(err: RegistryError) -> Self {
157        match err {
158            RegistryError::ConnectionFailed(msg) => {
159                swarm_engine_core::error::SwarmError::NetworkTransient { message: msg }
160            }
161            RegistryError::NoModelsAvailable { requested } => {
162                swarm_engine_core::error::SwarmError::Config {
163                    message: format!("Model '{}' not found", requested),
164                }
165            }
166        }
167    }
168}
169
170#[cfg(test)]
171mod tests {
172    use super::*;
173
174    #[tokio::test]
175    async fn test_registry_creation() {
176        let registry = ModelRegistry::default_local();
177        assert_eq!(registry.endpoint(), "http://localhost:11434");
178    }
179}