use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;

use super::{LlmProvider, LlmRequest, LlmResponse, LlmUsage};

/// Default port for a local Ollama server.
pub const OLLAMA_PORT: u16 = 11434;
/// Default port for a local LM Studio server.
pub const LMSTUDIO_PORT: u16 = 1234;

/// Kind of local LLM server that was detected or targeted.
#[derive(Debug, Clone, PartialEq)]
pub enum LocalLlmType {
    Ollama,
    LmStudio,
}

impl std::fmt::Display for LocalLlmType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LocalLlmType::Ollama => write!(f, "Ollama"),
            LocalLlmType::LmStudio => write!(f, "LM Studio"),
        }
    }
}

/// Details about a detected local LLM server.
#[derive(Debug, Clone)]
pub struct LocalLlmInfo {
    pub server_type: LocalLlmType,
    pub base_url: String,
    pub models: Vec<String>,
}

/// Provider backed by a local, OpenAI-compatible server (Ollama or LM Studio).
pub struct OllamaProvider {
    client: Client,
    base_url: String,
    server_type: LocalLlmType,
    default_model: String,
}

impl OllamaProvider {
    /// Create a provider for the given base URL and server type.
    pub fn new(base_url: &str, server_type: LocalLlmType) -> Self {
        Self {
            client: Client::builder()
                .timeout(Duration::from_secs(300))
                .build()
                .expect("Failed to create HTTP client"),
            base_url: base_url.trim_end_matches('/').to_string(),
            server_type,
            default_model: "llama3.2".to_string(),
        }
    }

    /// Provider pointed at the default local Ollama endpoint.
    pub fn ollama() -> Self {
        Self::new(
            &format!("http://localhost:{}", OLLAMA_PORT),
            LocalLlmType::Ollama,
        )
    }

    /// Provider pointed at the default local LM Studio endpoint.
    pub fn lmstudio() -> Self {
        Self::new(
            &format!("http://localhost:{}", LMSTUDIO_PORT),
            LocalLlmType::LmStudio,
        )
    }

    /// Override the model used when a request does not specify one.
    pub fn with_model(mut self, model: &str) -> Self {
        self.default_model = model.to_string();
        self
    }

    /// List the models the server currently advertises.
    pub async fn list_models(&self) -> Result<Vec<String>> {
        match self.server_type {
            LocalLlmType::Ollama => self.list_ollama_models().await,
            LocalLlmType::LmStudio => self.list_lmstudio_models().await,
        }
    }

    async fn list_ollama_models(&self) -> Result<Vec<String>> {
        let url = format!("{}/api/tags", self.base_url);
        let response = self
            .client
            .get(&url)
            .send()
            .await
            .context("Failed to connect to Ollama")?;

        let tags: OllamaTagsResponse = response
            .json()
            .await
            .context("Failed to parse Ollama models response")?;

        Ok(tags.models.into_iter().map(|m| m.name).collect())
    }

    async fn list_lmstudio_models(&self) -> Result<Vec<String>> {
        let url = format!("{}/v1/models", self.base_url);
        let response = self
            .client
            .get(&url)
            .send()
            .await
            .context("Failed to connect to LM Studio")?;

        let models: OpenAiModelsResponse = response
            .json()
            .await
            .context("Failed to parse LM Studio models response")?;

        Ok(models.data.into_iter().map(|m| m.id).collect())
    }
}

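// A minimal usage sketch (not called anywhere in this module): construct a
// provider against a locally running Ollama instance and pin it to the first
// model the server advertises. The function name and fallback behaviour are
// illustrative assumptions, not part of the public API.
#[allow(dead_code)]
async fn example_pick_first_model() -> Result<OllamaProvider> {
    let provider = OllamaProvider::ollama();
    // Ask the server which models it has pulled; keep the built-in default
    // model name if the list comes back empty.
    let models = provider.list_models().await?;
    Ok(match models.first() {
        Some(model) => provider.with_model(model),
        None => provider,
    })
}
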
impl Default for OllamaProvider {
    fn default() -> Self {
        Self::ollama()
    }
}

#[async_trait]
impl LlmProvider for OllamaProvider {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        let url = format!("{}/v1/chat/completions", self.base_url);

        // Fall back to the configured default model when the request does not
        // name one explicitly.
        let model = if request.model.is_empty() || request.model == "default" {
            self.default_model.clone()
        } else {
            request.model.clone()
        };

        let openai_request = OpenAiChatRequest {
            model: model.clone(),
            messages: request
                .messages
                .iter()
                .map(|m| OpenAiMessage {
                    role: match m.role {
                        super::LlmRole::System => "system".to_string(),
                        super::LlmRole::User => "user".to_string(),
                        super::LlmRole::Assistant => "assistant".to_string(),
                    },
                    content: m.content.clone(),
                })
                .collect(),
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            stream: false,
        };

        let response = self
            .client
            .post(&url)
            .json(&openai_request)
            .send()
            .await
            .context(format!("Failed to send request to {}", self.server_type))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(anyhow::anyhow!(
                "{} returned error {}: {}",
                self.server_type,
                status,
                error_text
            ));
        }

        let openai_response: OpenAiChatResponse = response
            .json()
            .await
            .context("Failed to parse response from local LLM")?;

        let content = openai_response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default();

        Ok(LlmResponse {
            content,
            model: openai_response.model,
            usage: openai_response.usage.map(|u| LlmUsage {
                prompt_tokens: u.prompt_tokens,
                completion_tokens: u.completion_tokens,
                total_tokens: u.total_tokens,
            }),
        })
    }

    fn name(&self) -> &'static str {
        match self.server_type {
            LocalLlmType::Ollama => "ollama",
            LocalLlmType::LmStudio => "lmstudio",
        }
    }
}

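// Sketch of how a caller might route a prepared request to whichever local
// server is up, preferring Ollama. Construction of `LlmRequest` is left to
// the caller since its full definition lives in the parent module; the helper
// name here is an illustrative assumption.
#[allow(dead_code)]
async fn example_complete_locally(request: LlmRequest) -> Result<LlmResponse> {
    let provider = if check_server("localhost", OLLAMA_PORT, 200).await {
        OllamaProvider::ollama()
    } else {
        OllamaProvider::lmstudio()
    };
    provider.complete(request).await
}
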
/// Return true if an HTTP server answers on `host:port` within `timeout_ms`.
pub async fn check_server(host: &str, port: u16, timeout_ms: u64) -> bool {
    let client = match Client::builder()
        .timeout(Duration::from_millis(timeout_ms))
        .build()
    {
        Ok(c) => c,
        Err(_) => return false,
    };

    // Try the server root first, then fall back to the OpenAI-style models
    // endpoint in case the root does not respond.
    let health_url = format!("http://{}:{}/", host, port);
    if client.get(&health_url).send().await.is_ok() {
        return true;
    }

    let models_url = format!("http://{}:{}/v1/models", host, port);
    client.get(&models_url).send().await.is_ok()
}

/// Probe the default Ollama and LM Studio ports and return details for every
/// server that responds, including its advertised models.
pub async fn detect_local_llms() -> Vec<LocalLlmInfo> {
    let mut detected = Vec::new();

    if check_server("localhost", OLLAMA_PORT, 500).await {
        let provider = OllamaProvider::ollama();
        let models = provider.list_models().await.unwrap_or_default();
        detected.push(LocalLlmInfo {
            server_type: LocalLlmType::Ollama,
            base_url: format!("http://localhost:{}", OLLAMA_PORT),
            models,
        });
    }

    if check_server("localhost", LMSTUDIO_PORT, 500).await {
        let provider = OllamaProvider::lmstudio();
        let models = provider.list_models().await.unwrap_or_default();
        detected.push(LocalLlmInfo {
            server_type: LocalLlmType::LmStudio,
            base_url: format!("http://localhost:{}", LMSTUDIO_PORT),
            models,
        });
    }

    detected
}

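// Sketch: turning detection results into a ready-to-use provider, taking the
// first server found and its first advertised model when one is present.
// Purely illustrative; the helper name is an assumption, not existing API.
#[allow(dead_code)]
async fn example_provider_from_detection() -> Option<OllamaProvider> {
    let info = detect_local_llms().await.into_iter().next()?;
    let provider = match info.server_type {
        LocalLlmType::Ollama => OllamaProvider::ollama(),
        LocalLlmType::LmStudio => OllamaProvider::lmstudio(),
    };
    Some(match info.models.first() {
        Some(model) => provider.with_model(model),
        None => provider,
    })
}
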
/// Quick concurrent probe: true if either Ollama or LM Studio answers locally.
pub async fn any_local_llm_available() -> bool {
    let (ollama, lmstudio) = tokio::join!(
        check_server("localhost", OLLAMA_PORT, 200),
        check_server("localhost", LMSTUDIO_PORT, 200)
    );
    ollama || lmstudio
}

#[derive(Debug, Deserialize)]
struct OllamaTagsResponse {
    models: Vec<OllamaModel>,
}

#[derive(Debug, Deserialize)]
struct OllamaModel {
    name: String,
    #[allow(dead_code)]
    modified_at: Option<String>,
    #[allow(dead_code)]
    size: Option<u64>,
}

#[derive(Debug, Deserialize)]
struct OpenAiModelsResponse {
    data: Vec<OpenAiModelInfo>,
}

#[derive(Debug, Deserialize)]
struct OpenAiModelInfo {
    id: String,
}

#[derive(Debug, Serialize)]
struct OpenAiChatRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<usize>,
    stream: bool,
}

#[derive(Debug, Serialize, Deserialize)]
struct OpenAiMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OpenAiChatResponse {
    model: String,
    choices: Vec<OpenAiChoice>,
    usage: Option<OpenAiUsageInfo>,
}

#[derive(Debug, Deserialize)]
struct OpenAiChoice {
    message: OpenAiMessage,
}

#[derive(Debug, Deserialize)]
struct OpenAiUsageInfo {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_detect_local_llms() {
        let detected = detect_local_llms().await;
        println!("Detected {} local LLM server(s)", detected.len());
        for info in &detected {
            println!(
                "  - {} at {} with {} models",
                info.server_type,
                info.base_url,
                info.models.len()
            );
            for model in &info.models {
                println!("    • {}", model);
            }
        }
    }

    #[tokio::test]
    async fn test_check_server_timeout() {
        let start = std::time::Instant::now();
        let result = check_server("localhost", 59999, 100).await;
        let elapsed = start.elapsed();

        assert!(!result);
        assert!(
            elapsed.as_millis() < 500,
            "Timeout took too long: {:?}",
            elapsed
        );
    }
}