Skip to main content

vtcode_core/llm/providers/lmstudio/
client.rs

//! High-level LM Studio client for server interaction and model management.
//!
//! Supports both the LM Studio native REST API (`/api/v0/*`) and the
//! OpenAI-compatible endpoints (`/v1/*`). The native API provides enhanced
//! features such as model management (load/unload), richer model metadata, and
//! TTL-based auto-evict.
//!
//! See: <https://lmstudio.ai/docs/developer>
8use std::io;
9use std::path::Path;
10use std::time::Duration;
11
12use serde_json::Value as JsonValue;
13
/// User-facing message reported when the LM Studio server cannot be reached
/// or answers the reachability probe with a non-success status. It tells the
/// user where to install LM Studio and how to start its server.
pub const LMSTUDIO_CONNECTION_ERROR: &str = "LM Studio is not responding. Install from https://lmstudio.ai/download and run 'lms server start'.";
15
16/// Client for interacting with a local LM Studio instance.
17///
18/// Supports both native REST API (`/api/v0/*`) and OpenAI-compatible endpoints (`/v1/*`).
19#[derive(Clone, Debug)]
20pub struct LMStudioClient {
21    client: reqwest::Client,
22    base_url: String,
23    /// Use native REST API endpoints (default: false, uses OpenAI-compatible endpoints)
24    use_native_api: bool,
25}
26
27impl LMStudioClient {
28    /// Create a client from a base URL and verify the server is reachable.
29    pub async fn try_from_base_url(base_url: &str) -> io::Result<Self> {
30        Self::try_from_base_url_with_api_version(base_url, false).await
31    }
32
33    /// Create a client with explicit API version selection.
34    ///
35    /// - `use_native_api = false`: Use OpenAI-compatible endpoints at `/v1/*` (default)
36    /// - `use_native_api = true`: Use native REST API at `/api/v0/*`
37    pub async fn try_from_base_url_with_api_version(
38        base_url: &str,
39        use_native_api: bool,
40    ) -> io::Result<Self> {
41        let client = reqwest::Client::builder()
42            .connect_timeout(Duration::from_secs(5))
43            .build()
44            .unwrap_or_else(|_| reqwest::Client::new());
45
46        let instance = Self {
47            client,
48            base_url: base_url.to_string(),
49            use_native_api,
50        };
51
52        instance.check_server().await?;
53        Ok(instance)
54    }
55
56    /// Get the models endpoint URL based on API version.
57    fn models_endpoint(&self) -> String {
58        let base = self.base_url.trim_end_matches('/');
59        if self.use_native_api {
60            format!("{base}/api/v0/models")
61        } else {
62            format!("{base}/v1/models")
63        }
64    }
65
66    /// Verify that the server is reachable.
67    async fn check_server(&self) -> io::Result<()> {
68        let url = self.models_endpoint();
69        let response = self.client.get(&url).send().await;
70
71        if let Ok(resp) = response {
72            if resp.status().is_success() {
73                Ok(())
74            } else {
75                Err(io::Error::other(format!(
76                    "Server returned error: {} {LMSTUDIO_CONNECTION_ERROR}",
77                    resp.status()
78                )))
79            }
80        } else {
81            Err(io::Error::other(LMSTUDIO_CONNECTION_ERROR))
82        }
83    }
84
85    /// Fetch the list of model IDs available on the server.
86    pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
87        let url = self.models_endpoint();
88        let response = self
89            .client
90            .get(&url)
91            .send()
92            .await
93            .map_err(|e| io::Error::other(format!("Request failed: {e}")))?;
94
95        if response.status().is_success() {
96            let json: JsonValue = response.json().await.map_err(|e| {
97                io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}"))
98            })?;
99
100            let models = json["data"]
101                .as_array()
102                .ok_or_else(|| {
103                    io::Error::new(io::ErrorKind::InvalidData, "No 'data' array in response")
104                })?
105                .iter()
106                .filter_map(|model| model["id"].as_str())
107                .map(ToString::to_string)
108                .collect();
109
110            Ok(models)
111        } else {
112            Err(io::Error::other(format!(
113                "Failed to fetch models: {}",
114                response.status()
115            )))
116        }
117    }
118
119    /// Load a model into memory (pre-loads for faster inference).
120    ///
121    /// Uses native REST API `/api/v0/models/load` when `use_native_api` is true,
122    /// otherwise sends a minimal request via `/v1/chat/completions`.
123    pub async fn load_model(&self, model: &str) -> io::Result<()> {
124        if self.use_native_api {
125            let url = format!("{}/api/v0/models/load", self.base_url.trim_end_matches('/'));
126            let request_body = serde_json::json!({
127                "model": model
128            });
129
130            let response = self
131                .client
132                .post(&url)
133                .header("Content-Type", "application/json")
134                .json(&request_body)
135                .send()
136                .await
137                .map_err(|e| io::Error::other(format!("Request failed: {e}")))?;
138
139            if response.status().is_success() {
140                tracing::info!("Successfully loaded model '{model}' via native API");
141                Ok(())
142            } else {
143                Err(io::Error::other(format!(
144                    "Failed to load model: {}",
145                    response.status()
146                )))
147            }
148        } else {
149            // Use OpenAI-compatible endpoint with minimal chat completion
150            let url = format!(
151                "{}/v1/chat/completions",
152                self.base_url.trim_end_matches('/')
153            );
154            let request_body = serde_json::json!({
155                "model": model,
156                "messages": [{"role": "user", "content": "hi"}],
157                "max_tokens": 1
158            });
159
160            let response = self
161                .client
162                .post(&url)
163                .header("Content-Type", "application/json")
164                .json(&request_body)
165                .send()
166                .await
167                .map_err(|e| io::Error::other(format!("Request failed: {e}")))?;
168
169            if response.status().is_success() {
170                tracing::info!("Successfully loaded model '{model}'");
171                Ok(())
172            } else {
173                Err(io::Error::other(format!(
174                    "Failed to load model: {}",
175                    response.status()
176                )))
177            }
178        }
179    }
180
181    /// Unload a model from memory (native REST API only).
182    ///
183    /// This endpoint requires `use_native_api = true`.
184    pub async fn unload_model(&self, model: &str) -> io::Result<()> {
185        if !self.use_native_api {
186            return Err(io::Error::other(
187                "Model unload requires native API (use_native_api = true)",
188            ));
189        }
190
191        let url = format!(
192            "{}/api/v0/models/unload",
193            self.base_url.trim_end_matches('/')
194        );
195        let request_body = serde_json::json!({
196            "model": model
197        });
198
199        let response = self
200            .client
201            .post(&url)
202            .header("Content-Type", "application/json")
203            .json(&request_body)
204            .send()
205            .await
206            .map_err(|e| io::Error::other(format!("Request failed: {e}")))?;
207
208        if response.status().is_success() {
209            tracing::info!("Successfully unloaded model '{model}'");
210            Ok(())
211        } else {
212            Err(io::Error::other(format!(
213                "Failed to unload model: {}",
214                response.status()
215            )))
216        }
217    }
218
219    /// Find the `lms` CLI tool, checking PATH and fallback locations.
220    fn find_lms() -> io::Result<String> {
221        Self::find_lms_with_home_dir(None)
222    }
223
224    /// Find `lms` CLI with an optional home directory override (for testing).
225    fn find_lms_with_home_dir(home_dir: Option<&str>) -> io::Result<String> {
226        // First try 'lms' in PATH
227        if which::which("lms").is_ok() {
228            return Ok("lms".to_string());
229        }
230
231        // Platform-specific fallback paths
232        let home = match home_dir {
233            Some(dir) => dir.to_string(),
234            None => {
235                #[cfg(unix)]
236                {
237                    std::env::var("HOME").unwrap_or_default()
238                }
239                #[cfg(windows)]
240                {
241                    std::env::var("USERPROFILE").unwrap_or_default()
242                }
243            }
244        };
245
246        #[cfg(unix)]
247        let fallback_path = format!("{home}/.lmstudio/bin/lms");
248        #[cfg(windows)]
249        let fallback_path = format!("{home}/.lmstudio/bin/lms.exe");
250
251        if Path::new(&fallback_path).exists() {
252            Ok(fallback_path)
253        } else {
254            Err(io::Error::new(
255                io::ErrorKind::NotFound,
256                "LM Studio not found. Please install LM Studio from https://lmstudio.ai/",
257            ))
258        }
259    }
260
261    /// Download a model using the `lms` CLI tool.
262    pub async fn download_model(&self, model: &str) -> io::Result<()> {
263        let lms = Self::find_lms()?;
264        tracing::info!(model, "downloading model");
265
266        let status = std::process::Command::new(&lms)
267            .args(["get", "--yes", model])
268            .stdout(std::process::Stdio::inherit())
269            .stderr(std::process::Stdio::null())
270            .status()
271            .map_err(|e| {
272                io::Error::other(format!("Failed to execute '{lms} get --yes {model}': {e}"))
273            })?;
274
275        if !status.success() {
276            return Err(io::Error::other(format!(
277                "Model download failed with exit code: {}",
278                status.code().unwrap_or(-1)
279            )));
280        }
281
282        tracing::info!("Successfully downloaded model '{model}'");
283        Ok(())
284    }
285}
286
287#[cfg(test)]
288mod tests {
289    use super::*;
290
291    fn panic_message(payload: Box<dyn std::any::Any + Send>) -> String {
292        if let Some(message) = payload.downcast_ref::<String>() {
293            return message.clone();
294        }
295        if let Some(message) = payload.downcast_ref::<&str>() {
296            return (*message).to_string();
297        }
298        "unknown panic".to_string()
299    }
300
301    async fn start_mock_server_or_skip() -> Option<wiremock::MockServer> {
302        match tokio::spawn(async { wiremock::MockServer::start().await }).await {
303            Ok(server) => Some(server),
304            Err(err) if err.is_panic() => {
305                let message = panic_message(err.into_panic());
306                if message.contains("Operation not permitted")
307                    || message.contains("PermissionDenied")
308                {
309                    return None;
310                }
311                panic!("mock server should start: {message}");
312            }
313            Err(err) => panic!("mock server task should complete: {err}"),
314        }
315    }
316
317    #[test]
318    fn test_find_lms() {
319        let result = LMStudioClient::find_lms();
320        match result {
321            Ok(_) => {
322                // lms was found in PATH - that's fine
323            }
324            Err(e) => {
325                // Expected error when LM Studio not installed
326                assert!(e.to_string().contains("LM Studio not found"));
327            }
328        }
329    }
330
331    #[test]
332    fn test_find_lms_with_mock_home() {
333        // Test fallback path construction without touching env vars
334        #[cfg(unix)]
335        {
336            let result = LMStudioClient::find_lms_with_home_dir(Some("/test/home"));
337            if let Err(e) = result {
338                assert!(e.to_string().contains("LM Studio not found"));
339            }
340        }
341        #[cfg(windows)]
342        {
343            let result = LMStudioClient::find_lms_with_home_dir(Some("C:\\test\\home"));
344            if let Err(e) = result {
345                assert!(e.to_string().contains("LM Studio not found"));
346            }
347        }
348    }
349
350    #[tokio::test]
351    async fn test_fetch_models_happy_path() {
352        if std::env::var("CODEX_SANDBOX_NETWORK_DISABLED").is_ok() {
353            return;
354        }
355
356        let Some(server) = start_mock_server_or_skip().await else {
357            return;
358        };
359        wiremock::Mock::given(wiremock::matchers::method("GET"))
360            .and(wiremock::matchers::path("/v1/models"))
361            .respond_with(
362                wiremock::ResponseTemplate::new(200).set_body_raw(
363                    serde_json::json!({
364                        "data": [
365                            {"id": "openai/gpt-oss-20b"},
366                        ]
367                    })
368                    .to_string(),
369                    "application/json",
370                ),
371            )
372            .mount(&server)
373            .await;
374
375        let client = LMStudioClient::try_from_base_url(&server.uri()).await;
376        assert!(client.is_ok());
377
378        let client = client.unwrap();
379        let models = client.fetch_models().await.expect("fetch models");
380        assert!(models.contains(&"openai/gpt-oss-20b".to_string()));
381    }
382
383    #[tokio::test]
384    async fn test_fetch_models_native_api() {
385        if std::env::var("CODEX_SANDBOX_NETWORK_DISABLED").is_ok() {
386            return;
387        }
388
389        let Some(server) = start_mock_server_or_skip().await else {
390            return;
391        };
392        wiremock::Mock::given(wiremock::matchers::method("GET"))
393            .and(wiremock::matchers::path("/api/v0/models"))
394            .respond_with(
395                wiremock::ResponseTemplate::new(200).set_body_raw(
396                    serde_json::json!({
397                        "data": [
398                            {"id": "lmstudio-community/meta-llama-3.1-8b-instruct"},
399                        ]
400                    })
401                    .to_string(),
402                    "application/json",
403                ),
404            )
405            .mount(&server)
406            .await;
407
408        let client = LMStudioClient::try_from_base_url_with_api_version(&server.uri(), true).await;
409        assert!(client.is_ok());
410
411        let client = client.unwrap();
412        let models = client.fetch_models().await.expect("fetch models");
413        assert!(models.contains(&"lmstudio-community/meta-llama-3.1-8b-instruct".to_string()));
414    }
415
416    #[tokio::test]
417    async fn test_fetch_models_no_data_array() {
418        if std::env::var("CODEX_SANDBOX_NETWORK_DISABLED").is_ok() {
419            return;
420        }
421
422        let Some(server) = start_mock_server_or_skip().await else {
423            return;
424        };
425        wiremock::Mock::given(wiremock::matchers::method("GET"))
426            .and(wiremock::matchers::path("/v1/models"))
427            .respond_with(
428                wiremock::ResponseTemplate::new(200)
429                    .set_body_raw(serde_json::json!({}).to_string(), "application/json"),
430            )
431            .mount(&server)
432            .await;
433
434        let client = LMStudioClient::try_from_base_url(&server.uri()).await;
435        let client = client.unwrap();
436        let result = client.fetch_models().await;
437
438        assert!(result.is_err());
439        assert!(
440            result
441                .unwrap_err()
442                .to_string()
443                .contains("No 'data' array in response")
444        );
445    }
446
447    #[tokio::test]
448    async fn test_check_server_happy_path() {
449        if std::env::var("CODEX_SANDBOX_NETWORK_DISABLED").is_ok() {
450            return;
451        }
452
453        let Some(server) = start_mock_server_or_skip().await else {
454            return;
455        };
456        wiremock::Mock::given(wiremock::matchers::method("GET"))
457            .and(wiremock::matchers::path("/v1/models"))
458            .respond_with(wiremock::ResponseTemplate::new(200))
459            .mount(&server)
460            .await;
461
462        let result = LMStudioClient::try_from_base_url(&server.uri()).await;
463        result.unwrap();
464    }
465
466    #[tokio::test]
467    async fn test_check_server_error() {
468        if std::env::var("CODEX_SANDBOX_NETWORK_DISABLED").is_ok() {
469            return;
470        }
471
472        let Some(server) = start_mock_server_or_skip().await else {
473            return;
474        };
475        wiremock::Mock::given(wiremock::matchers::method("GET"))
476            .and(wiremock::matchers::path("/v1/models"))
477            .respond_with(wiremock::ResponseTemplate::new(404))
478            .mount(&server)
479            .await;
480
481        let result = LMStudioClient::try_from_base_url(&server.uri()).await;
482        assert!(result.is_err());
483        assert!(
484            result
485                .unwrap_err()
486                .to_string()
487                .contains("Server returned error: 404")
488        );
489    }
490}