1use anyhow::Result;
6use reqwest::Client;
7use serde::Deserialize;
8use std::time::Duration;
9
/// A local embedding provider discovered during auto-detection (or entered
/// manually), together with what we learned about it.
#[derive(Debug, Clone)]
pub struct DetectedProvider {
    // Which provider family answered (Ollama, MLX, OpenAI-compatible, manual).
    pub kind: ProviderKind,
    // Base URL including scheme and host, e.g. `http://localhost:11434`.
    pub base_url: String,
    // Port the provider listens on (also embedded in `base_url`).
    pub port: u16,
    // Every model name the provider reported; empty if the listing failed.
    pub models: Vec<String>,
    // Best embedding-model candidate found among `models`, if any.
    pub suggested_model: Option<String>,
    // Reachability / readiness as determined by the probe.
    pub status: ProviderStatus,
}
26
27impl DetectedProvider {
28 pub fn summary_line(&self) -> String {
30 match &self.status {
31 ProviderStatus::Online(model) => {
32 format!("{} at {} - {}", self.kind.label(), self.base_url, model)
33 }
34 ProviderStatus::OnlineNoModel => {
35 format!(
36 "{} at {} (no embedding model)",
37 self.kind.label(),
38 self.base_url
39 )
40 }
41 ProviderStatus::Offline => {
42 format!("{} at {} (offline)", self.kind.label(), self.base_url)
43 }
44 }
45 }
46
47 pub fn summary(&self) -> String {
49 self.summary_line()
50 }
51
52 pub fn is_usable(&self) -> bool {
54 matches!(self.status, ProviderStatus::Online(_))
55 }
56
57 pub fn model(&self) -> Option<&str> {
59 if let ProviderStatus::Online(ref model) = self.status {
60 Some(model.as_str())
61 } else {
62 self.suggested_model.as_deref()
63 }
64 }
65}
66
/// Heuristic: does this model name look like an embedding model?
///
/// Matches case-insensitively against substrings seen in the names of
/// popular embedding models. `"embed"` also covers `"embedding"`
/// (`nomic-embed-text`, `qwen3-embedding`, `mxbai-embed-large`, ...),
/// so the previous separate `"embedding"` check was redundant.
fn looks_like_embedding_model(model: &str) -> bool {
    const KEYWORDS: [&str; 5] = ["embed", "bge", "nomic", "mxbai", "minilm"];
    let lowered = model.to_ascii_lowercase();
    KEYWORDS.iter().any(|kw| lowered.contains(kw))
}
76
77fn pick_embedding_model(models: &[String]) -> Option<String> {
78 models
79 .iter()
80 .find(|m| looks_like_embedding_model(m))
81 .cloned()
82}
83
/// The family of embedding provider that was detected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    Ollama,
    Mlx,
    OpenAICompat,
    Manual,
}

impl ProviderKind {
    /// Human-readable display name for this provider family.
    pub fn label(&self) -> &'static str {
        match *self {
            Self::Ollama => "Ollama",
            Self::Mlx => "MLX Server",
            Self::OpenAICompat => "OpenAI-Compatible",
            Self::Manual => "Manual",
        }
    }
}
107
/// Result of probing a provider endpoint.
#[derive(Debug, Clone)]
pub enum ProviderStatus {
    // Reachable and an embedding model was identified (the model name).
    Online(String),
    // Reachable, but no model name matched the embedding heuristic.
    OnlineNoModel,
    // Endpoint did not answer (or answered with an error status).
    Offline,
}
118
/// Shape of Ollama's `GET /api/tags` response (only the fields we read).
#[derive(Debug, Deserialize)]
struct OllamaTagsResponse {
    models: Vec<OllamaModel>,
}

/// A single entry in the Ollama tags list.
#[derive(Debug, Deserialize)]
struct OllamaModel {
    name: String,
}
129
/// Shape of an OpenAI-compatible `GET /v1/models` response
/// (only the fields we read).
#[derive(Debug, Deserialize)]
struct ModelsResponse {
    data: Vec<ModelInfo>,
}

/// A single model entry from `/v1/models`.
#[derive(Debug, Deserialize)]
struct ModelInfo {
    id: String,
}
140
141pub async fn detect_providers() -> Vec<DetectedProvider> {
143 let client = Client::builder()
144 .timeout(Duration::from_secs(3))
145 .connect_timeout(Duration::from_secs(2))
146 .build()
147 .unwrap_or_default();
148
149 let mut providers = Vec::new();
150
151 if let Some(provider) = detect_ollama(&client, "http://localhost", 11434).await {
153 providers.push(provider);
154 }
155
156 if let Some(provider) = detect_mlx(&client, "http://localhost", 12345).await {
158 providers.push(provider);
159 }
160
161 if let Some(provider) = detect_mlx(&client, "http://dragon", 12345).await {
163 providers.push(provider);
164 }
165
166 providers
167}
168
169pub async fn check_health(url: &str) -> bool {
172 let client = Client::builder()
173 .timeout(Duration::from_secs(3))
174 .connect_timeout(Duration::from_secs(2))
175 .build()
176 .unwrap_or_default();
177 client.get(url).send().await.is_ok()
178}
179
180async fn detect_ollama(client: &Client, host: &str, port: u16) -> Option<DetectedProvider> {
182 let base_url = format!("{}:{}", host, port);
183 let tags_url = format!("{}/api/tags", base_url);
184
185 let response = match client.get(&tags_url).send().await {
187 Ok(r) if r.status().is_success() => r,
188 _ => {
189 return Some(DetectedProvider {
190 kind: ProviderKind::Ollama,
191 base_url: base_url.clone(),
192 port,
193 models: vec![],
194 suggested_model: None,
195 status: ProviderStatus::Offline,
196 });
197 }
198 };
199
200 let tags: OllamaTagsResponse = match response.json().await {
201 Ok(t) => t,
202 Err(_) => {
203 return Some(DetectedProvider {
204 kind: ProviderKind::Ollama,
205 base_url,
206 port,
207 models: vec![],
208 suggested_model: None,
209 status: ProviderStatus::OnlineNoModel,
210 });
211 }
212 };
213
214 let models: Vec<String> = tags.models.iter().map(|m| m.name.clone()).collect();
215
216 let embedding_model = pick_embedding_model(&models);
217
218 let status = if let Some(ref model) = embedding_model {
219 ProviderStatus::Online(model.clone())
220 } else {
221 ProviderStatus::OnlineNoModel
222 };
223
224 Some(DetectedProvider {
225 kind: ProviderKind::Ollama,
226 base_url,
227 port,
228 models,
229 suggested_model: embedding_model,
230 status,
231 })
232}
233
234async fn detect_mlx(client: &Client, host: &str, port: u16) -> Option<DetectedProvider> {
236 let base_url = format!("{}:{}", host, port);
237 let models_url = format!("{}/v1/models", base_url);
238
239 let response = match client.get(&models_url).send().await {
241 Ok(r) if r.status().is_success() => r,
242 _ => {
243 return None;
245 }
246 };
247
248 let models_resp: ModelsResponse = match response.json().await {
249 Ok(m) => m,
250 Err(_) => {
251 return Some(DetectedProvider {
252 kind: ProviderKind::Mlx,
253 base_url,
254 port,
255 models: vec![],
256 suggested_model: None,
257 status: ProviderStatus::OnlineNoModel,
258 });
259 }
260 };
261
262 let models: Vec<String> = models_resp.data.iter().map(|m| m.id.clone()).collect();
263
264 let embedding_model = pick_embedding_model(&models);
265
266 let status = if let Some(ref model) = embedding_model {
267 ProviderStatus::Online(model.clone())
268 } else {
269 ProviderStatus::OnlineNoModel
270 };
271
272 Some(DetectedProvider {
273 kind: ProviderKind::Mlx,
274 base_url,
275 port,
276 models,
277 suggested_model: embedding_model,
278 status,
279 })
280}
281
282pub async fn check_custom_endpoint(url: &str) -> Result<DetectedProvider> {
284 let client = Client::builder()
285 .timeout(Duration::from_secs(5))
286 .connect_timeout(Duration::from_secs(3))
287 .build()?;
288
289 let base_url = url.trim_end_matches('/');
290
291 let port = reqwest::Url::parse(base_url)
293 .ok()
294 .and_then(|u| u.port())
295 .unwrap_or(80);
296
297 let models_url = format!("{}/v1/models", base_url);
299 if let Ok(resp) = client.get(&models_url).send().await {
300 if resp.status().is_success() {
301 let body = resp.text().await.unwrap_or_default();
302 if let Ok(models_resp) = serde_json::from_str::<ModelsResponse>(&body) {
303 let models: Vec<String> = models_resp.data.iter().map(|m| m.id.clone()).collect();
304
305 let embedding_model = models
306 .iter()
307 .find(|m| m.contains("embedding") || m.contains("Embedding"))
308 .cloned();
309
310 let status = if let Some(ref model) = embedding_model {
311 ProviderStatus::Online(model.clone())
312 } else {
313 ProviderStatus::OnlineNoModel
314 };
315
316 return Ok(DetectedProvider {
317 kind: ProviderKind::OpenAICompat,
318 base_url: base_url.to_string(),
319 port,
320 models,
321 suggested_model: embedding_model,
322 status,
323 });
324 } else {
325 tracing::debug!(
326 "Failed to parse /v1/models response: {}",
327 &body[..body.len().min(200)]
328 );
329 }
330 } else {
331 let status = resp.status();
332 let body = resp.text().await.unwrap_or_default();
333 tracing::debug!(
334 "OpenAI endpoint returned HTTP {}: {}",
335 status,
336 &body[..body.len().min(200)]
337 );
338 }
339 }
340
341 let tags_url = format!("{}/api/tags", base_url);
343 if let Ok(resp) = client.get(&tags_url).send().await {
344 if resp.status().is_success() {
345 let body = resp.text().await.unwrap_or_default();
346 if let Ok(tags) = serde_json::from_str::<OllamaTagsResponse>(&body) {
347 let models: Vec<String> = tags.models.iter().map(|m| m.name.clone()).collect();
348
349 let embedding_model = pick_embedding_model(&models);
350
351 let status = if let Some(ref model) = embedding_model {
352 ProviderStatus::Online(model.clone())
353 } else if !models.is_empty() {
354 ProviderStatus::OnlineNoModel
355 } else {
356 ProviderStatus::Offline
357 };
358
359 return Ok(DetectedProvider {
360 kind: ProviderKind::Ollama,
361 base_url: base_url.to_string(),
362 port,
363 models,
364 suggested_model: embedding_model,
365 status,
366 });
367 } else {
368 tracing::debug!(
369 "Failed to parse /api/tags response: {}",
370 &body[..body.len().min(200)]
371 );
372 }
373 } else {
374 let status = resp.status();
375 let body = resp.text().await.unwrap_or_default();
376 tracing::debug!(
377 "Ollama endpoint returned HTTP {}: {}",
378 status,
379 &body[..body.len().min(200)]
380 );
381 }
382 }
383
384 Ok(DetectedProvider {
385 kind: ProviderKind::OpenAICompat,
386 base_url: base_url.to_string(),
387 port,
388 models: vec![],
389 suggested_model: None,
390 status: ProviderStatus::Offline,
391 })
392}
393
/// Human-readable reminder that every configured provider must produce
/// vectors of the given dimension.
pub fn dimension_explanation(dim: usize) -> String {
    format!("{} dims — ensure all providers match this dimension", dim)
}
399
#[cfg(test)]
mod tests {
    use super::*;

    // `label()` returns the fixed display name per provider family.
    #[test]
    fn test_provider_kind_display() {
        assert_eq!(ProviderKind::Ollama.label(), "Ollama");
        assert_eq!(ProviderKind::Mlx.label(), "MLX Server");
    }

    // The heuristic skips non-embedding names and returns the first match.
    #[test]
    fn pick_embedding_model_finds_embedding_keyword() {
        let models = vec!["llama3:8b".to_string(), "qwen3-embedding:8b".to_string()];
        assert_eq!(
            pick_embedding_model(&models).as_deref(),
            Some("qwen3-embedding:8b")
        );
    }

    // The explanation text must embed the actual dimension value.
    #[test]
    fn dimension_explanation_is_dynamic() {
        let explanation = dimension_explanation(1536);
        assert!(explanation.contains("1536"));
    }
}