
rust_memex/tui/detection.rs

//! Auto-detection module for embedding providers.
//!
//! Detects Ollama, MLX server, and other embedding providers automatically.

use anyhow::Result;
use reqwest::Client;
use serde::Deserialize;
use std::time::Duration;

/// Detected embedding provider information.
#[derive(Debug, Clone)]
pub struct DetectedProvider {
    /// Provider type
    pub kind: ProviderKind,
    /// Base URL where the provider is running
    pub base_url: String,
    /// Port number
    pub port: u16,
    /// List of available models (if detected)
    pub models: Vec<String>,
    /// Suggested model to use
    pub suggested_model: Option<String>,
    /// Connection status
    pub status: ProviderStatus,
}

impl DetectedProvider {
    /// Get a human-readable description of the provider and its status.
    pub fn summary_line(&self) -> String {
        match &self.status {
            ProviderStatus::Online(model) => {
                format!("{} at {} - {}", self.kind.label(), self.base_url, model)
            }
            ProviderStatus::OnlineNoModel => {
                format!(
                    "{} at {} (no embedding model)",
                    self.kind.label(),
                    self.base_url
                )
            }
            ProviderStatus::Offline => {
                format!("{} at {} (offline)", self.kind.label(), self.base_url)
            }
        }
    }

    /// Get the provider summary for the UI.
    pub fn summary(&self) -> String {
        self.summary_line()
    }

    /// Check whether the provider is usable (online with an embedding model).
    pub fn is_usable(&self) -> bool {
        matches!(self.status, ProviderStatus::Online(_))
    }

    /// Get the model to use: the online model if present,
    /// otherwise the suggested model.
    pub fn model(&self) -> Option<&str> {
        if let ProviderStatus::Online(ref model) = self.status {
            Some(model.as_str())
        } else {
            self.suggested_model.as_deref()
        }
    }
}

/// Heuristic check: does a model name look like an embedding model?
fn looks_like_embedding_model(model: &str) -> bool {
    let model = model.to_ascii_lowercase();
    model.contains("embedding")
        || model.contains("embed")
        || model.contains("bge")
        || model.contains("nomic")
        || model.contains("mxbai")
        || model.contains("minilm")
}

/// Pick the first model in the list that looks like an embedding model.
fn pick_embedding_model(models: &[String]) -> Option<String> {
    models
        .iter()
        .find(|m| looks_like_embedding_model(m))
        .cloned()
}

/// Type of embedding provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    /// Ollama with embedding models
    Ollama,
    /// MLX embedding server
    Mlx,
    /// Generic OpenAI-compatible endpoint
    OpenAICompat,
    /// Manual configuration
    Manual,
}

impl ProviderKind {
    /// Human-readable label for the provider kind.
    pub fn label(&self) -> &'static str {
        match self {
            ProviderKind::Ollama => "Ollama",
            ProviderKind::Mlx => "MLX Server",
            ProviderKind::OpenAICompat => "OpenAI-Compatible",
            ProviderKind::Manual => "Manual",
        }
    }
}

/// Provider connection status.
#[derive(Debug, Clone)]
pub enum ProviderStatus {
    /// Provider is online with a usable embedding model
    Online(String),
    /// Provider is online but no embedding model was found
    OnlineNoModel,
    /// Provider is offline
    Offline,
}

/// Response from the Ollama `/api/tags` endpoint.
#[derive(Debug, Deserialize)]
struct OllamaTagsResponse {
    models: Vec<OllamaModel>,
}

#[derive(Debug, Deserialize)]
struct OllamaModel {
    name: String,
}

/// Response from an OpenAI-compatible `/v1/models` endpoint.
#[derive(Debug, Deserialize)]
struct ModelsResponse {
    data: Vec<ModelInfo>,
}

#[derive(Debug, Deserialize)]
struct ModelInfo {
    id: String,
}

/// Detect embedding providers on standard ports.
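///
/// # Example
///
/// A sketch of the intended call pattern (marked `ignore`: it needs an
/// async runtime and live local servers):
///
/// ```ignore
/// let providers = detect_providers().await;
/// for provider in &providers {
///     println!("{}", provider.summary_line());
/// }
/// ```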
pub async fn detect_providers() -> Vec<DetectedProvider> {
    let client = Client::builder()
        .timeout(Duration::from_secs(3))
        .connect_timeout(Duration::from_secs(2))
        .build()
        // Fall back to a default client (no timeouts) if the builder fails.
        .unwrap_or_default();

    let mut providers = Vec::new();

    // Check Ollama on localhost:11434
    if let Some(provider) = detect_ollama(&client, "http://localhost", 11434).await {
        providers.push(provider);
    }

    // Check MLX on localhost:12345
    if let Some(provider) = detect_mlx(&client, "http://localhost", 12345).await {
        providers.push(provider);
    }

    // Check dragon:12345 (common remote MLX server)
    if let Some(provider) = detect_mlx(&client, "http://dragon", 12345).await {
        providers.push(provider);
    }

    providers
}

/// Check whether a URL is reachable (simple health check).
/// Used for quick provider connectivity verification.
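///
/// # Example
///
/// A sketch of typical use (`ignore`: needs an async runtime). Note that any
/// HTTP response counts as reachable; only connection errors or timeouts
/// return `false`.
///
/// ```ignore
/// let reachable = check_health("http://localhost:11434").await;
/// ```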
pub async fn check_health(url: &str) -> bool {
    let client = Client::builder()
        .timeout(Duration::from_secs(3))
        .connect_timeout(Duration::from_secs(2))
        .build()
        .unwrap_or_default();
    client.get(url).send().await.is_ok()
}

/// Detect Ollama on a given host/port.
async fn detect_ollama(client: &Client, host: &str, port: u16) -> Option<DetectedProvider> {
    let base_url = format!("{}:{}", host, port);
    let tags_url = format!("{}/api/tags", base_url);

    // Try to get the list of models
    let response = match client.get(&tags_url).send().await {
        Ok(r) if r.status().is_success() => r,
        _ => {
            return Some(DetectedProvider {
                kind: ProviderKind::Ollama,
                base_url: base_url.clone(),
                port,
                models: vec![],
                suggested_model: None,
                status: ProviderStatus::Offline,
            });
        }
    };

    let tags: OllamaTagsResponse = match response.json().await {
        Ok(t) => t,
        Err(_) => {
            return Some(DetectedProvider {
                kind: ProviderKind::Ollama,
                base_url,
                port,
                models: vec![],
                suggested_model: None,
                status: ProviderStatus::OnlineNoModel,
            });
        }
    };

    let models: Vec<String> = tags.models.iter().map(|m| m.name.clone()).collect();

    let embedding_model = pick_embedding_model(&models);

    let status = if let Some(ref model) = embedding_model {
        ProviderStatus::Online(model.clone())
    } else {
        ProviderStatus::OnlineNoModel
    };

    Some(DetectedProvider {
        kind: ProviderKind::Ollama,
        base_url,
        port,
        models,
        suggested_model: embedding_model,
        status,
    })
}

/// Detect an MLX embedding server on a given host/port.
async fn detect_mlx(client: &Client, host: &str, port: u16) -> Option<DetectedProvider> {
    let base_url = format!("{}:{}", host, port);
    let models_url = format!("{}/v1/models", base_url);

    // Try the OpenAI-compatible endpoint
    let response = match client.get(&models_url).send().await {
        Ok(r) if r.status().is_success() => r,
        _ => {
            // Don't report offline MLX servers unless we were explicitly looking for them
            return None;
        }
    };

    let models_resp: ModelsResponse = match response.json().await {
        Ok(m) => m,
        Err(_) => {
            return Some(DetectedProvider {
                kind: ProviderKind::Mlx,
                base_url,
                port,
                models: vec![],
                suggested_model: None,
                status: ProviderStatus::OnlineNoModel,
            });
        }
    };

    let models: Vec<String> = models_resp.data.iter().map(|m| m.id.clone()).collect();

    let embedding_model = pick_embedding_model(&models);

    let status = if let Some(ref model) = embedding_model {
        ProviderStatus::Online(model.clone())
    } else {
        ProviderStatus::OnlineNoModel
    };

    Some(DetectedProvider {
        kind: ProviderKind::Mlx,
        base_url,
        port,
        models,
        suggested_model: embedding_model,
        status,
    })
}

/// Check a custom endpoint, trying the OpenAI-compatible `/v1/models` API
/// first and falling back to the Ollama `/api/tags` API.
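///
/// # Example
///
/// A sketch of typical use (`ignore`: needs an async runtime; the URL is a
/// placeholder for whatever the user entered):
///
/// ```ignore
/// let detected = check_custom_endpoint("http://localhost:8080").await?;
/// if detected.is_usable() {
///     println!("{}", detected.summary_line());
/// }
/// ```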
pub async fn check_custom_endpoint(url: &str) -> Result<DetectedProvider> {
    let client = Client::builder()
        .timeout(Duration::from_secs(5))
        .connect_timeout(Duration::from_secs(3))
        .build()?;

    let base_url = url.trim_end_matches('/');

    // Extract the port from the URL, falling back to the scheme's default
    // (e.g. 443 for https) and finally to 80.
    let port = reqwest::Url::parse(base_url)
        .ok()
        .and_then(|u| u.port_or_known_default())
        .unwrap_or(80);

    // Try /v1/models first
    let models_url = format!("{}/v1/models", base_url);
    if let Ok(resp) = client.get(&models_url).send().await {
        if resp.status().is_success() {
            let body = resp.text().await.unwrap_or_default();
            if let Ok(models_resp) = serde_json::from_str::<ModelsResponse>(&body) {
                let models: Vec<String> = models_resp.data.iter().map(|m| m.id.clone()).collect();

                let embedding_model = pick_embedding_model(&models);

                let status = if let Some(ref model) = embedding_model {
                    ProviderStatus::Online(model.clone())
                } else {
                    ProviderStatus::OnlineNoModel
                };

                return Ok(DetectedProvider {
                    kind: ProviderKind::OpenAICompat,
                    base_url: base_url.to_string(),
                    port,
                    models,
                    suggested_model: embedding_model,
                    status,
                });
            } else {
                // Truncate by characters, not bytes, so a multi-byte UTF-8
                // sequence at the cut point can't cause a slice panic.
                let snippet: String = body.chars().take(200).collect();
                tracing::debug!("Failed to parse /v1/models response: {}", snippet);
            }
        } else {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            let snippet: String = body.chars().take(200).collect();
            tracing::debug!("OpenAI endpoint returned HTTP {}: {}", status, snippet);
        }
    }

    // Try the Ollama endpoint
    let tags_url = format!("{}/api/tags", base_url);
    if let Ok(resp) = client.get(&tags_url).send().await {
        if resp.status().is_success() {
            let body = resp.text().await.unwrap_or_default();
            if let Ok(tags) = serde_json::from_str::<OllamaTagsResponse>(&body) {
                let models: Vec<String> = tags.models.iter().map(|m| m.name.clone()).collect();

                let embedding_model = pick_embedding_model(&models);

                let status = if let Some(ref model) = embedding_model {
                    ProviderStatus::Online(model.clone())
                } else if !models.is_empty() {
                    ProviderStatus::OnlineNoModel
                } else {
                    ProviderStatus::Offline
                };

                return Ok(DetectedProvider {
                    kind: ProviderKind::Ollama,
                    base_url: base_url.to_string(),
                    port,
                    models,
                    suggested_model: embedding_model,
                    status,
                });
            } else {
                let snippet: String = body.chars().take(200).collect();
                tracing::debug!("Failed to parse /api/tags response: {}", snippet);
            }
        } else {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            let snippet: String = body.chars().take(200).collect();
            tracing::debug!("Ollama endpoint returned HTTP {}: {}", status, snippet);
        }
    }

    Ok(DetectedProvider {
        kind: ProviderKind::OpenAICompat,
        base_url: base_url.to_string(),
        port,
        models: vec![],
        suggested_model: None,
        status: ProviderStatus::Offline,
    })
}

/// Get a dimension explanation for the UI.
/// Reports the verified dimension without guessing model variants.
pub fn dimension_explanation(dim: usize) -> String {
    format!("{dim} dims — ensure all providers match this dimension")
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_kind_label() {
        assert_eq!(ProviderKind::Ollama.label(), "Ollama");
        assert_eq!(ProviderKind::Mlx.label(), "MLX Server");
    }

    #[test]
    fn pick_embedding_model_finds_embedding_keyword() {
        let models = vec!["llama3:8b".to_string(), "qwen3-embedding:8b".to_string()];
        assert_eq!(
            pick_embedding_model(&models).as_deref(),
            Some("qwen3-embedding:8b")
        );
    }

    #[test]
    fn dimension_explanation_is_dynamic() {
        let explanation = dimension_explanation(1536);
        assert!(explanation.contains("1536"));
    }
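
    // Extra coverage for the detection heuristics and status fallbacks above;
    // expected values were derived by inspecting this module.
    #[test]
    fn looks_like_embedding_model_is_case_insensitive() {
        assert!(looks_like_embedding_model("Nomic-Embed-Text"));
        assert!(looks_like_embedding_model("BGE-M3"));
        assert!(!looks_like_embedding_model("llama3:8b"));
    }

    #[test]
    fn model_falls_back_to_suggestion_when_offline() {
        let provider = DetectedProvider {
            kind: ProviderKind::Ollama,
            base_url: "http://localhost:11434".to_string(),
            port: 11434,
            models: vec![],
            suggested_model: Some("nomic-embed-text".to_string()),
            status: ProviderStatus::Offline,
        };
        assert_eq!(provider.model(), Some("nomic-embed-text"));
        assert!(!provider.is_usable());
    }

    #[test]
    fn summary_line_reports_offline_status() {
        let provider = DetectedProvider {
            kind: ProviderKind::Mlx,
            base_url: "http://localhost:12345".to_string(),
            port: 12345,
            models: vec![],
            suggested_model: None,
            status: ProviderStatus::Offline,
        };
        assert_eq!(
            provider.summary_line(),
            "MLX Server at http://localhost:12345 (offline)"
        );
    }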
}