Skip to main content

st/proxy/
ollama.rs

//! 🦙 Ollama & LM Studio Provider - Local LLM Auto-Detection
//!
//! Automatically detects and connects to local LLM servers:
//! - Ollama at localhost:11434
//! - LM Studio at localhost:1234
//!
//! Both expose an OpenAI-compatible chat API (`/v1/chat/completions`), so
//! completions are handled uniformly; model listing uses each server's own
//! endpoint (`/api/tags` for Ollama, `/v1/models` for LM Studio).
//!
//! "Why pay for clouds when you've got a llama at home?" - The Cheet 🦙
10
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use serde::{Deserialize, Serialize};
15use std::time::Duration;
16
17use super::{LlmProvider, LlmRequest, LlmResponse, LlmUsage};
18
/// Default port on which a local Ollama server is expected (`localhost:11434`).
pub const OLLAMA_PORT: u16 = 11_434;

/// Default port on which the LM Studio local server is expected (`localhost:1234`).
pub const LMSTUDIO_PORT: u16 = 1234;
22
/// Detected local LLM server type.
///
/// A plain fieldless enum, so it is `Copy` and can be used as a key in
/// hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LocalLlmType {
    /// An Ollama server.
    Ollama,
    /// An LM Studio local server.
    LmStudio,
}
29
30impl std::fmt::Display for LocalLlmType {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        match self {
33            LocalLlmType::Ollama => write!(f, "Ollama"),
34            LocalLlmType::LmStudio => write!(f, "LM Studio"),
35        }
36    }
37}
38
39/// Information about a detected local LLM server
40#[derive(Debug, Clone)]
41pub struct LocalLlmInfo {
42    pub server_type: LocalLlmType,
43    pub base_url: String,
44    pub models: Vec<String>,
45}
46
47/// 🦙 Provider for local LLM servers (Ollama, LM Studio)
48pub struct OllamaProvider {
49    client: Client,
50    base_url: String,
51    server_type: LocalLlmType,
52    default_model: String,
53}
54
55impl OllamaProvider {
56    /// Create a new Ollama provider with explicit URL
57    pub fn new(base_url: &str, server_type: LocalLlmType) -> Self {
58        Self {
59            client: Client::builder()
60                .timeout(Duration::from_secs(300)) // Local models can be slow
61                .build()
62                .expect("Failed to create HTTP client"),
63            base_url: base_url.trim_end_matches('/').to_string(),
64            server_type,
65            default_model: "llama3.2".to_string(),
66        }
67    }
68
69    /// Create provider for Ollama at default port
70    pub fn ollama() -> Self {
71        Self::new(
72            &format!("http://localhost:{}", OLLAMA_PORT),
73            LocalLlmType::Ollama,
74        )
75    }
76
77    /// Create provider for LM Studio at default port
78    pub fn lmstudio() -> Self {
79        Self::new(
80            &format!("http://localhost:{}", LMSTUDIO_PORT),
81            LocalLlmType::LmStudio,
82        )
83    }
84
85    /// Set the default model to use
86    pub fn with_model(mut self, model: &str) -> Self {
87        self.default_model = model.to_string();
88        self
89    }
90
91    /// List available models from the server
92    pub async fn list_models(&self) -> Result<Vec<String>> {
93        match self.server_type {
94            LocalLlmType::Ollama => self.list_ollama_models().await,
95            LocalLlmType::LmStudio => self.list_lmstudio_models().await,
96        }
97    }
98
99    async fn list_ollama_models(&self) -> Result<Vec<String>> {
100        let url = format!("{}/api/tags", self.base_url);
101        let response = self
102            .client
103            .get(&url)
104            .send()
105            .await
106            .context("Failed to connect to Ollama")?;
107
108        let tags: OllamaTagsResponse = response
109            .json()
110            .await
111            .context("Failed to parse Ollama models response")?;
112
113        Ok(tags.models.into_iter().map(|m| m.name).collect())
114    }
115
116    async fn list_lmstudio_models(&self) -> Result<Vec<String>> {
117        let url = format!("{}/v1/models", self.base_url);
118        let response = self
119            .client
120            .get(&url)
121            .send()
122            .await
123            .context("Failed to connect to LM Studio")?;
124
125        let models: OpenAiModelsResponse = response
126            .json()
127            .await
128            .context("Failed to parse LM Studio models response")?;
129
130        Ok(models.data.into_iter().map(|m| m.id).collect())
131    }
132}
133
134impl Default for OllamaProvider {
135    fn default() -> Self {
136        Self::ollama()
137    }
138}
139
140#[async_trait]
141impl LlmProvider for OllamaProvider {
142    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
143        let url = format!("{}/v1/chat/completions", self.base_url);
144
145        let model = if request.model.is_empty() || request.model == "default" {
146            self.default_model.clone()
147        } else {
148            request.model.clone()
149        };
150
151        let openai_request = OpenAiChatRequest {
152            model: model.clone(),
153            messages: request
154                .messages
155                .iter()
156                .map(|m| OpenAiMessage {
157                    role: match m.role {
158                        super::LlmRole::System => "system".to_string(),
159                        super::LlmRole::User => "user".to_string(),
160                        super::LlmRole::Assistant => "assistant".to_string(),
161                    },
162                    content: m.content.clone(),
163                })
164                .collect(),
165            temperature: request.temperature,
166            max_tokens: request.max_tokens,
167            stream: false, // We don't handle streaming in this basic impl
168        };
169
170        let response = self
171            .client
172            .post(&url)
173            .json(&openai_request)
174            .send()
175            .await
176            .context(format!("Failed to send request to {}", self.server_type))?;
177
178        if !response.status().is_success() {
179            let status = response.status();
180            let error_text = response.text().await.unwrap_or_default();
181            return Err(anyhow::anyhow!(
182                "{} returned error {}: {}",
183                self.server_type,
184                status,
185                error_text
186            ));
187        }
188
189        let openai_response: OpenAiChatResponse = response
190            .json()
191            .await
192            .context("Failed to parse response from local LLM")?;
193
194        let content = openai_response
195            .choices
196            .first()
197            .map(|c| c.message.content.clone())
198            .unwrap_or_default();
199
200        Ok(LlmResponse {
201            content,
202            model: openai_response.model,
203            usage: openai_response.usage.map(|u| LlmUsage {
204                prompt_tokens: u.prompt_tokens,
205                completion_tokens: u.completion_tokens,
206                total_tokens: u.total_tokens,
207            }),
208        })
209    }
210
211    fn name(&self) -> &'static str {
212        match self.server_type {
213            LocalLlmType::Ollama => "ollama",
214            LocalLlmType::LmStudio => "lmstudio",
215        }
216    }
217}
218
219// ============================================================================
220// Auto-Detection
221// ============================================================================
222
223/// Check if a local LLM server is running at the given port
224pub async fn check_server(host: &str, port: u16, timeout_ms: u64) -> bool {
225    let client = match Client::builder()
226        .timeout(Duration::from_millis(timeout_ms))
227        .build()
228    {
229        Ok(c) => c,
230        Err(_) => return false,
231    };
232
233    // Try the health/version endpoint first (fast)
234    let health_url = format!("http://{}:{}/", host, port);
235    if client.get(&health_url).send().await.is_ok() {
236        return true;
237    }
238
239    // Fallback: try the models endpoint
240    let models_url = format!("http://{}:{}/v1/models", host, port);
241    client.get(&models_url).send().await.is_ok()
242}
243
244/// Detect all available local LLM servers
245pub async fn detect_local_llms() -> Vec<LocalLlmInfo> {
246    let mut detected = Vec::new();
247
248    // Check Ollama
249    if check_server("localhost", OLLAMA_PORT, 500).await {
250        let provider = OllamaProvider::ollama();
251        let models = provider.list_models().await.unwrap_or_default();
252        detected.push(LocalLlmInfo {
253            server_type: LocalLlmType::Ollama,
254            base_url: format!("http://localhost:{}", OLLAMA_PORT),
255            models,
256        });
257    }
258
259    // Check LM Studio
260    if check_server("localhost", LMSTUDIO_PORT, 500).await {
261        let provider = OllamaProvider::lmstudio();
262        let models = provider.list_models().await.unwrap_or_default();
263        detected.push(LocalLlmInfo {
264            server_type: LocalLlmType::LmStudio,
265            base_url: format!("http://localhost:{}", LMSTUDIO_PORT),
266            models,
267        });
268    }
269
270    detected
271}
272
273/// Quick check if any local LLM is available (non-blocking, fast timeout)
274pub async fn any_local_llm_available() -> bool {
275    tokio::select! {
276        ollama = check_server("localhost", OLLAMA_PORT, 200) => {
277            if ollama { return true; }
278        }
279        lmstudio = check_server("localhost", LMSTUDIO_PORT, 200) => {
280            if lmstudio { return true; }
281        }
282    }
283
284    // Check remaining
285    check_server("localhost", OLLAMA_PORT, 200).await
286        || check_server("localhost", LMSTUDIO_PORT, 200).await
287}
288
289// ============================================================================
290// API Types
291// ============================================================================
292
293#[derive(Debug, Deserialize)]
294struct OllamaTagsResponse {
295    models: Vec<OllamaModel>,
296}
297
298#[derive(Debug, Deserialize)]
299struct OllamaModel {
300    name: String,
301    #[allow(dead_code)]
302    modified_at: Option<String>,
303    #[allow(dead_code)]
304    size: Option<u64>,
305}
306
307#[derive(Debug, Deserialize)]
308struct OpenAiModelsResponse {
309    data: Vec<OpenAiModelInfo>,
310}
311
312#[derive(Debug, Deserialize)]
313struct OpenAiModelInfo {
314    id: String,
315}
316
317#[derive(Debug, Serialize)]
318struct OpenAiChatRequest {
319    model: String,
320    messages: Vec<OpenAiMessage>,
321    #[serde(skip_serializing_if = "Option::is_none")]
322    temperature: Option<f32>,
323    #[serde(skip_serializing_if = "Option::is_none")]
324    max_tokens: Option<usize>,
325    stream: bool,
326}
327
328#[derive(Debug, Serialize, Deserialize)]
329struct OpenAiMessage {
330    role: String,
331    content: String,
332}
333
334#[derive(Debug, Deserialize)]
335struct OpenAiChatResponse {
336    model: String,
337    choices: Vec<OpenAiChoice>,
338    usage: Option<OpenAiUsageInfo>,
339}
340
341#[derive(Debug, Deserialize)]
342struct OpenAiChoice {
343    message: OpenAiMessage,
344}
345
346#[derive(Debug, Deserialize)]
347struct OpenAiUsageInfo {
348    prompt_tokens: usize,
349    completion_tokens: usize,
350    total_tokens: usize,
351}
352
353// ============================================================================
354// Tests
355// ============================================================================
356
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    /// Smoke test that prints whatever local servers are reachable. It makes
    /// no assertions, so it passes whether or not local LLMs are running.
    #[tokio::test]
    async fn test_detect_local_llms() {
        let servers = detect_local_llms().await;
        println!("Detected {} local LLM server(s)", servers.len());
        for server in servers.iter() {
            println!(
                "  - {} at {} with {} models",
                server.server_type,
                server.base_url,
                server.models.len()
            );
            server
                .models
                .iter()
                .for_each(|model| println!("      • {}", model));
        }
    }

    /// A probe of a port with (presumably) nothing listening must fail fast.
    #[tokio::test]
    async fn test_check_server_timeout() {
        let started = Instant::now();
        let reachable = check_server("localhost", 59999, 100).await;
        let elapsed = started.elapsed();

        assert!(!reachable);
        assert!(
            elapsed.as_millis() < 500,
            "Timeout took too long: {:?}",
            elapsed
        );
    }
}