vtcode_core/llm/providers/lmstudio/
client.rs1use std::io;
9use std::path::Path;
10use std::time::Duration;
11
12use serde_json::Value as JsonValue;
13
/// Message used in errors when the LM Studio server cannot be reached, and appended to the
/// error reported when the server answers the models endpoint with a non-success status.
pub const LMSTUDIO_CONNECTION_ERROR: &str =
    "LM Studio is not responding. Install from https://lmstudio.ai/download and run 'lms server start'.";
15
16#[derive(Clone, Debug)]
20pub struct LMStudioClient {
21 client: reqwest::Client,
22 base_url: String,
23 use_native_api: bool,
25}
26
27impl LMStudioClient {
28 pub async fn try_from_base_url(base_url: &str) -> io::Result<Self> {
30 Self::try_from_base_url_with_api_version(base_url, false).await
31 }
32
33 pub async fn try_from_base_url_with_api_version(
38 base_url: &str,
39 use_native_api: bool,
40 ) -> io::Result<Self> {
41 let client = reqwest::Client::builder()
42 .connect_timeout(Duration::from_secs(5))
43 .build()
44 .unwrap_or_else(|_| reqwest::Client::new());
45
46 let instance = Self {
47 client,
48 base_url: base_url.to_string(),
49 use_native_api,
50 };
51
52 instance.check_server().await?;
53 Ok(instance)
54 }
55
56 fn models_endpoint(&self) -> String {
58 let base = self.base_url.trim_end_matches('/');
59 if self.use_native_api {
60 format!("{base}/api/v0/models")
61 } else {
62 format!("{base}/v1/models")
63 }
64 }
65
66 async fn check_server(&self) -> io::Result<()> {
68 let url = self.models_endpoint();
69 let response = self.client.get(&url).send().await;
70
71 if let Ok(resp) = response {
72 if resp.status().is_success() {
73 Ok(())
74 } else {
75 Err(io::Error::other(format!(
76 "Server returned error: {} {LMSTUDIO_CONNECTION_ERROR}",
77 resp.status()
78 )))
79 }
80 } else {
81 Err(io::Error::other(LMSTUDIO_CONNECTION_ERROR))
82 }
83 }
84
85 pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
87 let url = self.models_endpoint();
88 let response = self
89 .client
90 .get(&url)
91 .send()
92 .await
93 .map_err(|e| io::Error::other(format!("Request failed: {e}")))?;
94
95 if response.status().is_success() {
96 let json: JsonValue = response.json().await.map_err(|e| {
97 io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}"))
98 })?;
99
100 let models = json["data"]
101 .as_array()
102 .ok_or_else(|| {
103 io::Error::new(io::ErrorKind::InvalidData, "No 'data' array in response")
104 })?
105 .iter()
106 .filter_map(|model| model["id"].as_str())
107 .map(ToString::to_string)
108 .collect();
109
110 Ok(models)
111 } else {
112 Err(io::Error::other(format!(
113 "Failed to fetch models: {}",
114 response.status()
115 )))
116 }
117 }
118
119 pub async fn load_model(&self, model: &str) -> io::Result<()> {
124 if self.use_native_api {
125 let url = format!("{}/api/v0/models/load", self.base_url.trim_end_matches('/'));
126 let request_body = serde_json::json!({
127 "model": model
128 });
129
130 let response = self
131 .client
132 .post(&url)
133 .header("Content-Type", "application/json")
134 .json(&request_body)
135 .send()
136 .await
137 .map_err(|e| io::Error::other(format!("Request failed: {e}")))?;
138
139 if response.status().is_success() {
140 tracing::info!("Successfully loaded model '{model}' via native API");
141 Ok(())
142 } else {
143 Err(io::Error::other(format!(
144 "Failed to load model: {}",
145 response.status()
146 )))
147 }
148 } else {
149 let url = format!(
151 "{}/v1/chat/completions",
152 self.base_url.trim_end_matches('/')
153 );
154 let request_body = serde_json::json!({
155 "model": model,
156 "messages": [{"role": "user", "content": "hi"}],
157 "max_tokens": 1
158 });
159
160 let response = self
161 .client
162 .post(&url)
163 .header("Content-Type", "application/json")
164 .json(&request_body)
165 .send()
166 .await
167 .map_err(|e| io::Error::other(format!("Request failed: {e}")))?;
168
169 if response.status().is_success() {
170 tracing::info!("Successfully loaded model '{model}'");
171 Ok(())
172 } else {
173 Err(io::Error::other(format!(
174 "Failed to load model: {}",
175 response.status()
176 )))
177 }
178 }
179 }
180
181 pub async fn unload_model(&self, model: &str) -> io::Result<()> {
185 if !self.use_native_api {
186 return Err(io::Error::other(
187 "Model unload requires native API (use_native_api = true)",
188 ));
189 }
190
191 let url = format!(
192 "{}/api/v0/models/unload",
193 self.base_url.trim_end_matches('/')
194 );
195 let request_body = serde_json::json!({
196 "model": model
197 });
198
199 let response = self
200 .client
201 .post(&url)
202 .header("Content-Type", "application/json")
203 .json(&request_body)
204 .send()
205 .await
206 .map_err(|e| io::Error::other(format!("Request failed: {e}")))?;
207
208 if response.status().is_success() {
209 tracing::info!("Successfully unloaded model '{model}'");
210 Ok(())
211 } else {
212 Err(io::Error::other(format!(
213 "Failed to unload model: {}",
214 response.status()
215 )))
216 }
217 }
218
219 fn find_lms() -> io::Result<String> {
221 Self::find_lms_with_home_dir(None)
222 }
223
224 fn find_lms_with_home_dir(home_dir: Option<&str>) -> io::Result<String> {
226 if which::which("lms").is_ok() {
228 return Ok("lms".to_string());
229 }
230
231 let home = match home_dir {
233 Some(dir) => dir.to_string(),
234 None => {
235 #[cfg(unix)]
236 {
237 std::env::var("HOME").unwrap_or_default()
238 }
239 #[cfg(windows)]
240 {
241 std::env::var("USERPROFILE").unwrap_or_default()
242 }
243 }
244 };
245
246 #[cfg(unix)]
247 let fallback_path = format!("{home}/.lmstudio/bin/lms");
248 #[cfg(windows)]
249 let fallback_path = format!("{home}/.lmstudio/bin/lms.exe");
250
251 if Path::new(&fallback_path).exists() {
252 Ok(fallback_path)
253 } else {
254 Err(io::Error::new(
255 io::ErrorKind::NotFound,
256 "LM Studio not found. Please install LM Studio from https://lmstudio.ai/",
257 ))
258 }
259 }
260
261 pub async fn download_model(&self, model: &str) -> io::Result<()> {
263 let lms = Self::find_lms()?;
264 tracing::info!(model, "downloading model");
265
266 let status = std::process::Command::new(&lms)
267 .args(["get", "--yes", model])
268 .stdout(std::process::Stdio::inherit())
269 .stderr(std::process::Stdio::null())
270 .status()
271 .map_err(|e| {
272 io::Error::other(format!("Failed to execute '{lms} get --yes {model}': {e}"))
273 })?;
274
275 if !status.success() {
276 return Err(io::Error::other(format!(
277 "Model download failed with exit code: {}",
278 status.code().unwrap_or(-1)
279 )));
280 }
281
282 tracing::info!("Successfully downloaded model '{model}'");
283 Ok(())
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Extracts a readable message from a panic payload.
    fn panic_message(payload: Box<dyn std::any::Any + Send>) -> String {
        if let Some(owned) = payload.downcast_ref::<String>() {
            owned.clone()
        } else if let Some(literal) = payload.downcast_ref::<&str>() {
            (*literal).to_string()
        } else {
            "unknown panic".to_string()
        }
    }

    /// Starts a mock server, returning `None` when starting it panicked with a
    /// permission error so the calling test can be skipped.
    async fn start_mock_server_or_skip() -> Option<MockServer> {
        let join_error = match tokio::spawn(async { MockServer::start().await }).await {
            Ok(server) => return Some(server),
            Err(err) => err,
        };

        if !join_error.is_panic() {
            panic!("mock server task should complete: {join_error}");
        }

        let message = panic_message(join_error.into_panic());
        let permission_denied = ["Operation not permitted", "PermissionDenied"]
            .iter()
            .any(|needle| message.contains(*needle));

        if permission_denied {
            None
        } else {
            panic!("mock server should start: {message}");
        }
    }

    /// Returns a mock server unless `CODEX_SANDBOX_NETWORK_DISABLED` is set or the
    /// server cannot be started for permission reasons.
    async fn mock_server_unless_sandboxed() -> Option<MockServer> {
        if std::env::var("CODEX_SANDBOX_NETWORK_DISABLED").is_ok() {
            return None;
        }
        start_mock_server_or_skip().await
    }

    /// Registers `response` for GET requests to `route` on `server`.
    async fn mount_get(server: &MockServer, route: &str, response: ResponseTemplate) {
        Mock::given(method("GET"))
            .and(path(route))
            .respond_with(response)
            .mount(server)
            .await;
    }

    /// Builds a 200 response carrying `body` as `application/json`.
    fn json_ok(body: serde_json::Value) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_raw(body.to_string(), "application/json")
    }

    #[test]
    fn test_find_lms() {
        // Either outcome is acceptable; only the error text is checked.
        if let Err(e) = LMStudioClient::find_lms() {
            assert!(e.to_string().contains("LM Studio not found"));
        }
    }

    #[test]
    fn test_find_lms_with_mock_home() {
        let fake_home = if cfg!(windows) {
            "C:\\test\\home"
        } else {
            "/test/home"
        };

        if let Err(e) = LMStudioClient::find_lms_with_home_dir(Some(fake_home)) {
            assert!(e.to_string().contains("LM Studio not found"));
        }
    }

    #[tokio::test]
    async fn test_fetch_models_happy_path() {
        let Some(server) = mock_server_unless_sandboxed().await else {
            return;
        };
        let body = serde_json::json!({
            "data": [
                {"id": "openai/gpt-oss-20b"},
            ]
        });
        mount_get(&server, "/v1/models", json_ok(body)).await;

        let client = LMStudioClient::try_from_base_url(&server.uri()).await;
        assert!(client.is_ok());

        let models = client
            .unwrap()
            .fetch_models()
            .await
            .expect("fetch models");
        assert!(models.contains(&"openai/gpt-oss-20b".to_string()));
    }

    #[tokio::test]
    async fn test_fetch_models_native_api() {
        let Some(server) = mock_server_unless_sandboxed().await else {
            return;
        };
        let body = serde_json::json!({
            "data": [
                {"id": "lmstudio-community/meta-llama-3.1-8b-instruct"},
            ]
        });
        mount_get(&server, "/api/v0/models", json_ok(body)).await;

        let client = LMStudioClient::try_from_base_url_with_api_version(&server.uri(), true).await;
        assert!(client.is_ok());

        let models = client
            .unwrap()
            .fetch_models()
            .await
            .expect("fetch models");
        assert!(models.contains(&"lmstudio-community/meta-llama-3.1-8b-instruct".to_string()));
    }

    #[tokio::test]
    async fn test_fetch_models_no_data_array() {
        let Some(server) = mock_server_unless_sandboxed().await else {
            return;
        };
        mount_get(&server, "/v1/models", json_ok(serde_json::json!({}))).await;

        let client = LMStudioClient::try_from_base_url(&server.uri())
            .await
            .unwrap();
        let result = client.fetch_models().await;

        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("No 'data' array in response"));
    }

    #[tokio::test]
    async fn test_check_server_happy_path() {
        let Some(server) = mock_server_unless_sandboxed().await else {
            return;
        };
        mount_get(&server, "/v1/models", ResponseTemplate::new(200)).await;

        LMStudioClient::try_from_base_url(&server.uri())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_check_server_error() {
        let Some(server) = mock_server_unless_sandboxed().await else {
            return;
        };
        mount_get(&server, "/v1/models", ResponseTemplate::new(404)).await;

        let result = LMStudioClient::try_from_base_url(&server.uri()).await;
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("Server returned error: 404"));
    }
}