// subx_cli/services/ai/local.rs
1//! Local / OpenAI-compatible LLM provider client.
2//!
3//! `LocalLLMClient` targets any OpenAI-compatible chat-completions endpoint
4//! exposed by a local, LAN, VPN, or remote runtime (Ollama, LM Studio,
5//! llama.cpp `llama-server`, vLLM, text-generation-webui, etc.). It mirrors
6//! the structure of [`crate::services::ai::openrouter::OpenRouterClient`]
7//! but treats the API key as optional and emits actionable, sanitized
8//! error messages when the local endpoint is unreachable, returns a
9//! non-OpenAI response, or signals that the requested model is not loaded.
10
11use crate::Result;
12use crate::cli::display_ai_usage;
13use crate::error::SubXError;
14use crate::services::ai::AiUsageStats;
15use crate::services::ai::{
16    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
17};
18use async_trait::async_trait;
19use reqwest::Client;
20use serde_json::{Value, json};
21use std::time::Duration;
22use tokio::time;
23
24use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
25use crate::services::ai::retry::HttpRetryClient;
26
/// Client for OpenAI-compatible local LLM runtimes.
///
/// The struct mirrors the field layout of the hosted-provider clients
/// (`OpenAIClient`, `OpenRouterClient`) but stores the API key as
/// `Option<String>` because most local runtimes accept unauthenticated
/// requests. The `request_timeout_seconds` value is retained so it can be
/// embedded in timeout error messages.
pub struct LocalLLMClient {
    /// Shared HTTP client, built once in `new` with the request timeout applied.
    client: Client,
    /// Optional bearer token; empty/whitespace-only keys are normalized to `None`.
    api_key: Option<String>,
    /// Model identifier sent verbatim in the request body.
    model: String,
    /// Sampling temperature forwarded to the endpoint.
    temperature: f32,
    /// Upper bound on generated tokens per completion.
    max_tokens: u32,
    /// Number of additional attempts after the first failure.
    retry_attempts: u32,
    /// Fixed delay between retry attempts, in milliseconds.
    retry_delay_ms: u64,
    /// Endpoint base URL with trailing slashes stripped at construction.
    base_url: String,
    /// Per-request timeout; also quoted in timeout error messages.
    request_timeout_seconds: u64,
}
45
46impl std::fmt::Debug for LocalLLMClient {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.debug_struct("LocalLLMClient")
49            .field("client", &self.client)
50            .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]"))
51            .field("model", &self.model)
52            .field("temperature", &self.temperature)
53            .field("max_tokens", &self.max_tokens)
54            .field("retry_attempts", &self.retry_attempts)
55            .field("retry_delay_ms", &self.retry_delay_ms)
56            .field("base_url", &self.base_url)
57            .field("request_timeout_seconds", &self.request_timeout_seconds)
58            .finish()
59    }
60}
61
// Reuse the shared prompt-construction and response-parsing helpers; the
// empty bodies rely on the traits' provided default methods.
impl PromptBuilder for LocalLLMClient {}
impl ResponseParser for LocalLLMClient {}
// Expose this client's retry configuration to the shared HTTP retry helper.
impl HttpRetryClient for LocalLLMClient {
    // Number of additional attempts after the first failure.
    fn retry_attempts(&self) -> u32 {
        self.retry_attempts
    }
    // Fixed back-off between attempts, in milliseconds.
    fn retry_delay_ms(&self) -> u64 {
        self.retry_delay_ms
    }
}
72
impl LocalLLMClient {
    /// Construct a new `LocalLLMClient` with explicit parameters.
    ///
    /// Empty or whitespace-only API keys are normalized to `None`, and all
    /// trailing `/` characters are stripped from `base_url` so URL joining
    /// always inserts exactly one separator.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `reqwest::Client` cannot be built.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        api_key: Option<String>,
        model: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
        base_url: String,
        request_timeout_seconds: u64,
    ) -> Self {
        // An empty or whitespace-only key carries no credential; drop it so
        // the request path only has to check `Option::is_some`.
        let normalized_key = match api_key {
            Some(raw) => {
                let trimmed = raw.trim();
                if trimmed.is_empty() {
                    None
                } else {
                    Some(trimmed.to_string())
                }
            }
            None => None,
        };

        // Strip every trailing slash so `http://h/v1/` and `http://h/v1`
        // store identically and joined paths never contain `//`.
        let normalized_base = base_url.trim_end_matches('/').to_string();

        Self {
            client: Client::builder()
                .timeout(Duration::from_secs(request_timeout_seconds))
                .build()
                .expect("Failed to create HTTP client"),
            api_key: normalized_key,
            model,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            base_url: normalized_base,
            request_timeout_seconds,
        }
    }
117
118    /// Create a `LocalLLMClient` from the unified `AIConfig`.
119    ///
120    /// Validates that `base_url` is non-empty, emits the shared insecure-HTTP
121    /// warning (which already exempts loopback), and constructs an inner
122    /// `reqwest::Client` honoring `request_timeout_seconds`.
123    pub fn from_config(config: &crate::config::AIConfig) -> Result<Self> {
124        if config.base_url.trim().is_empty() {
125            return Err(SubXError::config(
126                "ai.base_url is required for the local provider",
127            ));
128        }
129
130        // The shared helper warns only when an api_key is configured against
131        // a non-loopback HTTP host. Pass the configured key (or empty string
132        // when absent) so the existing loopback exemption applies.
133        let api_key_for_warning = config.api_key.clone().unwrap_or_default();
134        crate::services::ai::security::warn_on_insecure_http_str(
135            &config.base_url,
136            &api_key_for_warning,
137        );
138
139        Ok(Self::new(
140            config.api_key.clone(),
141            config.model.clone(),
142            config.temperature,
143            config.max_tokens,
144            config.retry_attempts,
145            config.retry_delay_ms,
146            config.base_url.clone(),
147            config.request_timeout_seconds,
148        ))
149    }
150
151    /// Build the `chat/completions` URL by joining the stored `base_url`
152    /// with the path segment using exactly one `/` separator.
153    fn chat_completions_url(&self) -> String {
154        // `base_url` already had any single trailing slash stripped during
155        // construction, so this format never produces `//chat/completions`.
156        format!("{}/chat/completions", self.base_url)
157    }
158
    /// Issue a chat-completions request, applying retry + actionable error
    /// mapping. Sends only OpenAI-canonical body fields.
    ///
    /// On success returns the assistant text from
    /// `choices[0].message.content`. When the response contains a complete
    /// `usage` object, token counts are displayed as a side effect.
    ///
    /// # Errors
    ///
    /// Returns an error when the endpoint is unreachable or times out, when
    /// the response exceeds the 10 MiB cap, when the HTTP status is not 2xx,
    /// or when the body is not OpenAI-compatible JSON.
    pub async fn chat_completion(&self, messages: Vec<Value>) -> Result<String> {
        // Canonical OpenAI body only — no runtime-specific extensions.
        let request_body = json!({
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        });

        let mut builder = self
            .client
            .post(self.chat_completions_url())
            .header("Content-Type", "application/json")
            .json(&request_body);
        // Attach the bearer token only when a key was configured; most
        // local runtimes accept unauthenticated requests.
        if let Some(ref key) = self.api_key {
            builder = builder.header("Authorization", format!("Bearer {}", key));
        }

        // Note: `local_provider_hint()` is intentionally NOT appended here.
        // The advisory exists so hosted-provider clients can suggest the
        // local provider; from this client (which IS the local provider)
        // it would be tautological.

        let mut response = self.send_with_retry(builder).await?;

        // Reject oversized responses up-front when the server declares a
        // Content-Length, before any body bytes are pulled.
        const MAX_AI_RESPONSE_BYTES: u64 = 10 * 1024 * 1024; // 10 MiB
        if let Some(len) = response.content_length() {
            if len > MAX_AI_RESPONSE_BYTES {
                return Err(SubXError::AiService(format!(
                    "AI response too large: {} bytes (limit: {} bytes)",
                    len, MAX_AI_RESPONSE_BYTES
                )));
            }
        }

        // Non-2xx statuses are turned into actionable errors (404 / model
        // missing detection lives in `map_http_error`).
        if !response.status().is_success() {
            return Err(self.map_http_error(response).await);
        }

        // Bounded chunked read in case the server omits Content-Length.
        let mut body = Vec::new();
        while let Some(chunk) = response
            .chunk()
            .await
            .map_err(|e| self.map_reqwest_error(e))?
        {
            body.extend_from_slice(&chunk);
            // Enforce the same cap while streaming so a lying or silent
            // server cannot exhaust memory.
            if body.len() as u64 > MAX_AI_RESPONSE_BYTES {
                return Err(SubXError::AiService(format!(
                    "AI response too large: {} bytes read (limit: {} bytes)",
                    body.len(),
                    MAX_AI_RESPONSE_BYTES
                )));
            }
        }

        let response_json: Value = serde_json::from_slice(&body).map_err(|e| {
            SubXError::AiService(format!(
                "local LLM response was not OpenAI-compatible JSON: {}",
                e
            ))
        })?;

        // Valid JSON that lacks the OpenAI shape gets the same actionable
        // message as a parse failure, with the missing path called out.
        let content = response_json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| {
                SubXError::AiService(
                    "local LLM response was not OpenAI-compatible JSON: \
                     missing choices[0].message.content"
                        .to_string(),
                )
            })?;

        // Report usage only when all three canonical counters are present;
        // partial usage objects are silently ignored.
        if let Some(usage_obj) = response_json.get("usage") {
            if let (Some(p), Some(c), Some(t)) = (
                usage_obj.get("prompt_tokens").and_then(Value::as_u64),
                usage_obj.get("completion_tokens").and_then(Value::as_u64),
                usage_obj.get("total_tokens").and_then(Value::as_u64),
            ) {
                // NOTE(review): `as u32` truncates counts above u32::MAX —
                // presumably unreachable for token counters, but unchecked.
                let stats = AiUsageStats {
                    model: self.model.clone(),
                    prompt_tokens: p as u32,
                    completion_tokens: c as u32,
                    total_tokens: t as u32,
                };
                display_ai_usage(&stats);
            }
        }

        Ok(content.to_string())
    }
251
    /// Send `request` with retry, surfacing actionable transport errors.
    ///
    /// Retries up to `retry_attempts` additional times with a fixed
    /// `retry_delay_ms` sleep, on transport errors and on 5xx responses.
    /// Non-5xx responses (including 4xx) are returned immediately so the
    /// caller can map them; an exhausted budget returns the last 5xx
    /// response rather than an error.
    async fn send_with_retry(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
        let mut attempts: u32 = 0;
        loop {
            // Each attempt needs a fresh request; builders whose body cannot
            // be cloned (e.g. streaming bodies) cannot be retried at all.
            let cloned = request.try_clone().ok_or_else(|| {
                SubXError::AiService("Request body cannot be cloned for retry".to_string())
            })?;
            match cloned.send().await {
                Ok(resp) => {
                    // Only server-side (5xx) failures are worth retrying.
                    if resp.status().is_server_error() && attempts < self.retry_attempts {
                        attempts += 1;
                        log::warn!(
                            "Request attempt {} failed with status {}. Retrying in {}ms...",
                            attempts,
                            resp.status(),
                            self.retry_delay_ms
                        );
                        time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
                        continue;
                    }
                    return Ok(resp);
                }
                // Transport-level failure with retry budget remaining.
                Err(e) if attempts < self.retry_attempts => {
                    attempts += 1;
                    log::warn!(
                        "Request attempt {} failed: {}. Retrying in {}ms...",
                        attempts,
                        e,
                        self.retry_delay_ms
                    );
                    time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
                    continue;
                }
                // Budget exhausted: convert to an actionable error.
                Err(e) => return Err(self.map_reqwest_error(e)),
            }
        }
    }
289
290    /// Map a low-level `reqwest::Error` into an actionable `SubXError`.
291    fn map_reqwest_error(&self, err: reqwest::Error) -> SubXError {
292        let url = sanitize_base_url(&self.base_url);
293        if err.is_timeout() {
294            return SubXError::AiService(format!(
295                "local LLM endpoint timed out after {}s: {}",
296                self.request_timeout_seconds, url
297            ));
298        }
299        if err.is_connect() {
300            return SubXError::AiService(format!("local LLM endpoint unreachable: {}", url));
301        }
302        // Fall back to the generic conversion (which already strips query
303        // strings from any embedded URLs via `error_sanitizer`).
304        err.into()
305    }
306
307    /// Map a non-2xx HTTP response into an actionable `SubXError`. Reads
308    /// the body, sanitizes it, and detects the "model not found" pattern.
309    async fn map_http_error(&self, response: reqwest::Response) -> SubXError {
310        let status = response.status();
311        let body_text = response.text().await.unwrap_or_default();
312        let safe_body = crate::services::ai::error_sanitizer::sanitize_url_in_error(
313            &crate::services::ai::error_sanitizer::truncate_error_body(
314                &body_text,
315                crate::services::ai::error_sanitizer::DEFAULT_ERROR_BODY_MAX_LEN,
316            ),
317        );
318
319        if status.as_u16() == 404 || body_indicates_model_missing(&body_text) {
320            return SubXError::AiService(format!("local LLM model not found: {}", self.model));
321        }
322
323        SubXError::AiService(format!(
324            "local LLM endpoint returned HTTP {}: {}",
325            status, safe_body
326        ))
327    }
328}
329
/// Detect common "model not loaded / no such model" body patterns emitted
/// by Ollama, LM Studio, llama.cpp, and vLLM.
///
/// The body must mention "model" (case-insensitively) AND contain one of
/// the known missing-model phrases; either signal alone is not enough.
fn body_indicates_model_missing(body: &str) -> bool {
    const MISSING_PHRASES: [&str; 4] =
        ["not found", "not loaded", "no such model", "unknown model"];
    let lowered = body.to_ascii_lowercase();
    lowered.contains("model") && MISSING_PHRASES.iter().any(|p| lowered.contains(p))
}
343
344/// Strip userinfo, query strings, and fragments from a base URL so it can
345/// be safely embedded in user-facing error messages.
346///
347/// Returns only `scheme://host[:port]/path`. Trailing slashes on the path
348/// are preserved as authored. If the input cannot be parsed as a URL the
349/// helper falls back to `"<unparseable URL>"` so credentials hidden in a
350/// malformed string can never leak.
351pub(crate) fn sanitize_base_url(input: &str) -> String {
352    match url::Url::parse(input) {
353        Ok(mut url) => {
354            // Wipe credentials. The setters only fail when the URL cannot
355            // carry userinfo; ignore any error and continue.
356            let _ = url.set_username("");
357            let _ = url.set_password(None);
358            url.set_query(None);
359            url.set_fragment(None);
360
361            let scheme = url.scheme();
362            // Render IPv6 hosts with brackets so the result remains a
363            // syntactically valid URL (e.g. `http://[::1]:11434/v1`).
364            let host_display = match url.host() {
365                Some(url::Host::Ipv6(addr)) => format!("[{}]", addr),
366                Some(_) => url.host_str().unwrap_or_default().to_string(),
367                None => return "<unparseable URL>".to_string(),
368            };
369            let path = url.path();
370            match url.port() {
371                Some(port) => format!("{}://{}:{}{}", scheme, host_display, port, path),
372                None => format!("{}://{}{}", scheme, host_display, path),
373            }
374        }
375        Err(_) => "<unparseable URL>".to_string(),
376    }
377}
378
#[async_trait]
impl AIProvider for LocalLLMClient {
    /// Analyze file listings and return match candidates.
    ///
    /// Builds the shared analysis prompt, sends it with a fixed English
    /// system message, and parses the reply into a `MatchResult`.
    async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
        let prompt = self.build_analysis_prompt(&request);
        let messages = vec![
            json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
            json!({"role": "user", "content": prompt}),
        ];
        let response = self.chat_completion(messages).await?;
        self.parse_match_result(&response)
    }

    /// Score an existing match, parsing the reply into a `ConfidenceScore`.
    async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
        let prompt = self.build_verification_prompt(&verification);
        let messages = vec![
            json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
            json!({"role": "user", "content": prompt}),
        ];
        let response = self.chat_completion(messages).await?;
        self.parse_confidence_score(&response)
    }

    /// Delegate to the inherent method so trait users get the same retry,
    /// size-limit, and error-mapping behavior.
    async fn chat_completion(&self, messages: Vec<Value>) -> Result<String> {
        LocalLLMClient::chat_completion(self, messages).await
    }
}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a throwaway client with compact retry settings so unit tests
    /// stay fast; `base_url` and `api_key` vary per test.
    fn make_client(base_url: &str, api_key: Option<&str>) -> LocalLLMClient {
        let owned_key = api_key.map(str::to_string);
        LocalLLMClient::new(
            owned_key,
            "llama3.1:8b-instruct".to_string(),
            0.3,
            1024,
            1,
            10,
            base_url.to_string(),
            120,
        )
    }
422
    // Debug output must mask the credential while still signalling presence.
    #[test]
    fn debug_redacts_api_key() {
        let client = make_client("http://localhost:11434/v1", Some("super-secret-token"));
        let rendered = format!("{:?}", client);
        assert!(
            rendered.contains("[REDACTED]"),
            "Debug output should redact api_key, got: {rendered}"
        );
        assert!(!rendered.contains("super-secret-token"));
    }

    // An absent key renders as plain `None`, not a redaction marker.
    #[test]
    fn debug_marks_missing_api_key_as_none() {
        let client = make_client("http://localhost:11434/v1", None);
        let rendered = format!("{:?}", client);
        assert!(rendered.contains("api_key: None"), "got: {rendered}");
    }

    // A trailing slash on base_url must not produce `//` in the joined URL.
    #[test]
    fn url_join_with_trailing_slash() {
        let client = make_client("http://localhost:11434/v1/", None);
        assert_eq!(
            client.chat_completions_url(),
            "http://localhost:11434/v1/chat/completions"
        );
        assert!(!client.chat_completions_url().contains("//chat"));
    }

    #[test]
    fn url_join_without_trailing_slash() {
        let client = make_client("http://localhost:11434/v1", None);
        assert_eq!(
            client.chat_completions_url(),
            "http://localhost:11434/v1/chat/completions"
        );
    }

    // A bare host (no `/v1` prefix) joins directly onto the root.
    #[test]
    fn url_join_root_base_url() {
        let client = make_client("http://localhost:11434", None);
        assert_eq!(
            client.chat_completions_url(),
            "http://localhost:11434/chat/completions"
        );
    }

    #[test]
    fn sanitize_base_url_strips_userinfo_query_and_fragment() {
        assert_eq!(
            sanitize_base_url("http://user:secret@127.0.0.1:11434/v1?token=abc#frag"),
            "http://127.0.0.1:11434/v1"
        );
    }

    #[test]
    fn sanitize_base_url_preserves_plain_localhost() {
        assert_eq!(
            sanitize_base_url("http://localhost:11434/v1"),
            "http://localhost:11434/v1"
        );
    }

    #[test]
    fn sanitize_base_url_preserves_trailing_slash() {
        // Trailing slashes on the path are preserved as authored.
        assert_eq!(
            sanitize_base_url("https://host:8080/api/v1/"),
            "https://host:8080/api/v1/"
        );
    }

    // Malformed input must fall back to the opaque marker, never echo.
    #[test]
    fn sanitize_base_url_handles_unparseable_input() {
        assert_eq!(sanitize_base_url("not a url"), "<unparseable URL>");
        assert_eq!(sanitize_base_url(""), "<unparseable URL>");
    }

    // Password-only userinfo (empty username) must also be removed.
    #[test]
    fn sanitize_base_url_strips_password_only() {
        assert_eq!(
            sanitize_base_url("https://:pwd@host:8080/v1"),
            "https://host:8080/v1"
        );
    }

    #[test]
    fn sanitize_base_url_preserves_ipv6_brackets() {
        assert_eq!(
            sanitize_base_url("http://[::1]:11434/v1"),
            "http://[::1]:11434/v1"
        );
        assert_eq!(
            sanitize_base_url("https://[fd00::1]:8443/v1/"),
            "https://[fd00::1]:8443/v1/"
        );
        // Userinfo on IPv6 host must still be stripped while brackets remain.
        assert_eq!(
            sanitize_base_url("http://user:pwd@[::1]:11434/v1?token=secret"),
            "http://[::1]:11434/v1"
        );
    }

    // Positive cases cover Ollama / LM Studio / vLLM phrasings; negatives
    // confirm the "model" mention is required alongside a missing phrase.
    #[test]
    fn body_indicates_model_missing_detects_common_patterns() {
        assert!(body_indicates_model_missing(
            "{\"error\":\"model 'foo' not found, try pulling it first\"}"
        ));
        assert!(body_indicates_model_missing(
            "{\"error\":\"Model not loaded\"}"
        ));
        assert!(body_indicates_model_missing(
            "{\"detail\":\"no such model: bar\"}"
        ));
        assert!(body_indicates_model_missing(
            "{\"error\":\"unknown model llama99\"}"
        ));
        assert!(!body_indicates_model_missing(
            "{\"error\":\"server overloaded\"}"
        ));
        assert!(!body_indicates_model_missing(""));
    }
544
    /// Build an `AIConfig` for the local provider with deterministic test
    /// values; `base_url` and `api_key` vary per test.
    fn make_config(base_url: &str, api_key: Option<&str>) -> crate::config::AIConfig {
        crate::config::AIConfig {
            provider: "local".to_string(),
            api_key: api_key.map(|s| s.to_string()),
            model: "llama3.1:8b-instruct".to_string(),
            base_url: base_url.to_string(),
            max_sample_length: 500,
            temperature: 0.3,
            max_tokens: 1024,
            retry_attempts: 2,
            retry_delay_ms: 100,
            request_timeout_seconds: 120,
            api_version: None,
        }
    }

    #[test]
    fn from_config_rejects_empty_base_url() {
        let config = make_config("", None);
        let err = LocalLLMClient::from_config(&config).unwrap_err();
        assert!(
            err.to_string().contains("ai.base_url is required"),
            "unexpected error: {err}"
        );
    }

    // Whitespace-only base_url must fail the same trim-based validation.
    #[test]
    fn from_config_rejects_whitespace_base_url() {
        let config = make_config("   ", None);
        assert!(LocalLLMClient::from_config(&config).is_err());
    }

    #[test]
    fn from_config_accepts_loopback_http() {
        let config = make_config("http://localhost:11434/v1", None);
        let client = LocalLLMClient::from_config(&config).expect("should accept loopback HTTP");
        assert!(client.api_key.is_none());
        assert_eq!(client.base_url, "http://localhost:11434/v1");
    }

    // Plain-HTTP LAN endpoints are allowed (warning-only policy).
    #[test]
    fn from_config_accepts_lan_http() {
        let config = make_config("http://192.168.1.50:11434/v1", None);
        let client = LocalLLMClient::from_config(&config).expect("LAN HTTP must be accepted");
        assert_eq!(client.base_url, "http://192.168.1.50:11434/v1");
    }

    #[test]
    fn from_config_accepts_https() {
        let config = make_config("https://ollama.tailnet.ts.net/v1", Some("vllm-token"));
        let client = LocalLLMClient::from_config(&config).expect("HTTPS must be accepted");
        assert_eq!(client.base_url, "https://ollama.tailnet.ts.net/v1");
        assert_eq!(client.api_key.as_deref(), Some("vllm-token"));
    }

    // Normalization in `new` collapses an empty configured key to `None`.
    #[test]
    fn from_config_normalizes_empty_api_key_to_none() {
        let config = make_config("http://localhost:11434/v1", Some(""));
        let client = LocalLLMClient::from_config(&config).unwrap();
        assert!(
            client.api_key.is_none(),
            "empty api_key should normalize to None"
        );
    }

    #[test]
    fn from_config_trims_trailing_slash_in_base_url() {
        let config = make_config("http://localhost:11434/v1/", None);
        let client = LocalLLMClient::from_config(&config).unwrap();
        assert_eq!(client.base_url, "http://localhost:11434/v1");
        assert_eq!(
            client.chat_completions_url(),
            "http://localhost:11434/v1/chat/completions"
        );
    }
}