use crate::Result;
use crate::cli::display_ai_usage;
use crate::error::SubXError;
use crate::services::ai::AiUsageStats;
use crate::services::ai::{
    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
};
use async_trait::async_trait;
use reqwest::Client;
use serde_json::{Value, json};
use std::time::Duration;
use tokio::time;

use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
use crate::services::ai::retry::HttpRetryClient;

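/// OpenAI-compatible chat-completions client for locally hosted LLM
/// endpoints (e.g. Ollama or vLLM serving the OpenAI API shape).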
pub struct LocalLLMClient {
    client: Client,
    api_key: Option<String>,
    model: String,
    temperature: f32,
    max_tokens: u32,
    retry_attempts: u32,
    retry_delay_ms: u64,
    base_url: String,
    request_timeout_seconds: u64,
}

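// Manual Debug impl: the API key is rendered as "[REDACTED]" when present so
// it can never leak into logs or panic output.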
impl std::fmt::Debug for LocalLLMClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LocalLLMClient")
            .field("client", &self.client)
            .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]"))
            .field("model", &self.model)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .field("retry_attempts", &self.retry_attempts)
            .field("retry_delay_ms", &self.retry_delay_ms)
            .field("base_url", &self.base_url)
            .field("request_timeout_seconds", &self.request_timeout_seconds)
            .finish()
    }
}

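// Prompt construction and response parsing come from the shared AI-service
// traits' default methods; only the retry settings need explicit wiring.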
impl PromptBuilder for LocalLLMClient {}
impl ResponseParser for LocalLLMClient {}
impl HttpRetryClient for LocalLLMClient {
    fn retry_attempts(&self) -> u32 {
        self.retry_attempts
    }
    fn retry_delay_ms(&self) -> u64 {
        self.retry_delay_ms
    }
}

impl LocalLLMClient {
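    /// Creates a client with explicit settings. The API key is trimmed (an
    /// empty key becomes `None`) and any trailing `/` on `base_url` is
    /// stripped so URL joining stays predictable.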
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        api_key: Option<String>,
        model: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
        base_url: String,
        request_timeout_seconds: u64,
    ) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(request_timeout_seconds))
            .build()
            .expect("Failed to create HTTP client");

        let api_key = api_key.and_then(|k| {
            let trimmed = k.trim().to_string();
            if trimmed.is_empty() {
                None
            } else {
                Some(trimmed)
            }
        });

        Self {
            client,
            api_key,
            model,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            base_url: base_url.trim_end_matches('/').to_string(),
            request_timeout_seconds,
        }
    }

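    /// Builds a client from the application's AI configuration. An empty
    /// `ai.base_url` is rejected, and a warning is emitted when credentials
    /// would travel over plain HTTP.
    ///
    /// A minimal usage sketch (assumes an already-loaded `AIConfig` with
    /// `provider = "local"`):
    ///
    /// ```ignore
    /// let client = LocalLLMClient::from_config(&config)?;
    /// ```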
    pub fn from_config(config: &crate::config::AIConfig) -> Result<Self> {
        if config.base_url.trim().is_empty() {
            return Err(SubXError::config(
                "ai.base_url is required for the local provider",
            ));
        }

        let api_key_for_warning = config.api_key.clone().unwrap_or_default();
        crate::services::ai::security::warn_on_insecure_http_str(
            &config.base_url,
            &api_key_for_warning,
        );

        Ok(Self::new(
            config.api_key.clone(),
            config.model.clone(),
            config.temperature,
            config.max_tokens,
            config.retry_attempts,
            config.retry_delay_ms,
            config.base_url.clone(),
            config.request_timeout_seconds,
        ))
    }

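    /// Joins `base_url` with the OpenAI-style `/chat/completions` path.
    /// `new` strips any trailing slash, so the join never produces `//`.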
    fn chat_completions_url(&self) -> String {
        format!("{}/chat/completions", self.base_url)
    }

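    /// Sends an OpenAI-compatible chat-completion request and returns the
    /// first choice's message content, enforcing a 10 MiB cap on the
    /// response body.
    ///
    /// A minimal call sketch (the message shape matches what the
    /// `AIProvider` impl below builds):
    ///
    /// ```ignore
    /// let reply = client
    ///     .chat_completion(vec![json!({"role": "user", "content": "ping"})])
    ///     .await?;
    /// ```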
    pub async fn chat_completion(&self, messages: Vec<Value>) -> Result<String> {
        let request_body = json!({
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        });

        let mut builder = self
            .client
            .post(self.chat_completions_url())
            .header("Content-Type", "application/json")
            .json(&request_body);
        if let Some(ref key) = self.api_key {
            builder = builder.header("Authorization", format!("Bearer {}", key));
        }

        let mut response = self.send_with_retry(builder).await?;

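        // Reject oversized responses up front when the server reports a
        // Content-Length, so we never buffer an unbounded body in memory.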
        const MAX_AI_RESPONSE_BYTES: u64 = 10 * 1024 * 1024;
        if let Some(len) = response.content_length() {
            if len > MAX_AI_RESPONSE_BYTES {
                return Err(SubXError::AiService(format!(
                    "AI response too large: {} bytes (limit: {} bytes)",
                    len, MAX_AI_RESPONSE_BYTES
                )));
            }
        }

        if !response.status().is_success() {
            return Err(self.map_http_error(response).await);
        }

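        // Stream the body in chunks, re-checking the cap as bytes arrive in
        // case the server omitted or understated Content-Length.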
        let mut body = Vec::new();
        while let Some(chunk) = response
            .chunk()
            .await
            .map_err(|e| self.map_reqwest_error(e))?
        {
            body.extend_from_slice(&chunk);
            if body.len() as u64 > MAX_AI_RESPONSE_BYTES {
                return Err(SubXError::AiService(format!(
                    "AI response too large: {} bytes read (limit: {} bytes)",
                    body.len(),
                    MAX_AI_RESPONSE_BYTES
                )));
            }
        }

        let response_json: Value = serde_json::from_slice(&body).map_err(|e| {
            SubXError::AiService(format!(
                "local LLM response was not OpenAI-compatible JSON: {}",
                e
            ))
        })?;

        let content = response_json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| {
                SubXError::AiService(
                    "local LLM response was not OpenAI-compatible JSON: \
                     missing choices[0].message.content"
                        .to_string(),
                )
            })?;

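        // Surface token usage when the endpoint reports OpenAI-style `usage`.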
        if let Some(usage_obj) = response_json.get("usage") {
            if let (Some(p), Some(c), Some(t)) = (
                usage_obj.get("prompt_tokens").and_then(Value::as_u64),
                usage_obj.get("completion_tokens").and_then(Value::as_u64),
                usage_obj.get("total_tokens").and_then(Value::as_u64),
            ) {
                let stats = AiUsageStats {
                    model: self.model.clone(),
                    prompt_tokens: p as u32,
                    completion_tokens: c as u32,
                    total_tokens: t as u32,
                };
                display_ai_usage(&stats);
            }
        }

        Ok(content.to_string())
    }

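    /// Sends the request, retrying transport errors and 5xx responses up to
    /// `retry_attempts` times with a fixed `retry_delay_ms` pause between
    /// attempts.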
    async fn send_with_retry(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
        let mut attempts: u32 = 0;
        loop {
            let cloned = request.try_clone().ok_or_else(|| {
                SubXError::AiService("Request body cannot be cloned for retry".to_string())
            })?;
            match cloned.send().await {
                Ok(resp) => {
                    if resp.status().is_server_error() && attempts < self.retry_attempts {
                        attempts += 1;
                        log::warn!(
                            "Request attempt {} failed with status {}. Retrying in {}ms...",
                            attempts,
                            resp.status(),
                            self.retry_delay_ms
                        );
                        time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
                        continue;
                    }
                    return Ok(resp);
                }
                Err(e) if attempts < self.retry_attempts => {
                    attempts += 1;
                    log::warn!(
                        "Request attempt {} failed: {}. Retrying in {}ms...",
                        attempts,
                        e,
                        self.retry_delay_ms
                    );
                    time::sleep(Duration::from_millis(self.retry_delay_ms)).await;
                    continue;
                }
                Err(e) => return Err(self.map_reqwest_error(e)),
            }
        }
    }

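    /// Translates transport errors into user-facing messages. The base URL is
    /// sanitized first so credentials embedded in it never reach error output.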
    fn map_reqwest_error(&self, err: reqwest::Error) -> SubXError {
        let url = sanitize_base_url(&self.base_url);
        if err.is_timeout() {
            return SubXError::AiService(format!(
                "local LLM endpoint timed out after {}s: {}",
                self.request_timeout_seconds, url
            ));
        }
        if err.is_connect() {
            return SubXError::AiService(format!("local LLM endpoint unreachable: {}", url));
        }
        err.into()
    }

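    /// Maps a non-success HTTP response to an error, truncating and
    /// sanitizing the body and special-casing "model not found" replies.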
    async fn map_http_error(&self, response: reqwest::Response) -> SubXError {
        let status = response.status();
        let body_text = response.text().await.unwrap_or_default();
        let safe_body = crate::services::ai::error_sanitizer::sanitize_url_in_error(
            &crate::services::ai::error_sanitizer::truncate_error_body(
                &body_text,
                crate::services::ai::error_sanitizer::DEFAULT_ERROR_BODY_MAX_LEN,
            ),
        );

        if status.as_u16() == 404 || body_indicates_model_missing(&body_text) {
            return SubXError::AiService(format!("local LLM model not found: {}", self.model));
        }

        SubXError::AiService(format!(
            "local LLM endpoint returned HTTP {}: {}",
            status, safe_body
        ))
    }
}

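/// Heuristic for "model missing" error bodies from common local LLM servers:
/// the body must mention "model" together with a not-found/not-loaded phrase.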
fn body_indicates_model_missing(body: &str) -> bool {
    let lower = body.to_ascii_lowercase();
    let mentions_model = lower.contains("model");
    if !mentions_model {
        return false;
    }
    lower.contains("not found")
        || lower.contains("not loaded")
        || lower.contains("no such model")
        || lower.contains("unknown model")
}

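/// Strips userinfo, query, and fragment from a URL for safe display, keeping
/// scheme, host (IPv6 hosts stay bracketed), port, and path. Unparseable
/// input collapses to a `<unparseable URL>` placeholder.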
pub(crate) fn sanitize_base_url(input: &str) -> String {
    match url::Url::parse(input) {
        Ok(mut url) => {
            let _ = url.set_username("");
            let _ = url.set_password(None);
            url.set_query(None);
            url.set_fragment(None);

            let scheme = url.scheme();
            let host_display = match url.host() {
                Some(url::Host::Ipv6(addr)) => format!("[{}]", addr),
                Some(_) => url.host_str().unwrap_or_default().to_string(),
                None => return "<unparseable URL>".to_string(),
            };
            let path = url.path();
            match url.port() {
                Some(port) => format!("{}://{}:{}{}", scheme, host_display, port, path),
                None => format!("{}://{}{}", scheme, host_display, path),
            }
        }
        Err(_) => "<unparseable URL>".to_string(),
    }
}

#[async_trait]
impl AIProvider for LocalLLMClient {
    async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
        let prompt = self.build_analysis_prompt(&request);
        let messages = vec![
            json!({"role": "system", "content": "You are a professional subtitle matching assistant that can analyze the correspondence between video and subtitle files."}),
            json!({"role": "user", "content": prompt}),
        ];
        let response = self.chat_completion(messages).await?;
        self.parse_match_result(&response)
    }

    async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
        let prompt = self.build_verification_prompt(&verification);
        let messages = vec![
            json!({"role": "system", "content": "Please evaluate the confidence level of subtitle matching and provide a score between 0-1."}),
            json!({"role": "user", "content": prompt}),
        ];
        let response = self.chat_completion(messages).await?;
        self.parse_confidence_score(&response)
    }

    async fn chat_completion(&self, messages: Vec<Value>) -> Result<String> {
        // Delegate to the inherent method; the qualified path keeps the call
        // unambiguous between the trait and inherent implementations.
        LocalLLMClient::chat_completion(self, messages).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

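    // Builds a test client; the short retry settings keep any retry path fast.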
    fn make_client(base_url: &str, api_key: Option<&str>) -> LocalLLMClient {
        LocalLLMClient::new(
            api_key.map(|s| s.to_string()),
            "llama3.1:8b-instruct".to_string(),
            0.3,
            1024,
            1,
            10,
            base_url.to_string(),
            120,
        )
    }

    #[test]
    fn debug_redacts_api_key() {
        let client = make_client("http://localhost:11434/v1", Some("super-secret-token"));
        let rendered = format!("{:?}", client);
        assert!(
            rendered.contains("[REDACTED]"),
            "Debug output should redact api_key, got: {rendered}"
        );
        assert!(!rendered.contains("super-secret-token"));
    }

    #[test]
    fn debug_marks_missing_api_key_as_none() {
        let client = make_client("http://localhost:11434/v1", None);
        let rendered = format!("{:?}", client);
        assert!(rendered.contains("api_key: None"), "got: {rendered}");
    }

    #[test]
    fn url_join_with_trailing_slash() {
        let client = make_client("http://localhost:11434/v1/", None);
        assert_eq!(
            client.chat_completions_url(),
            "http://localhost:11434/v1/chat/completions"
        );
        assert!(!client.chat_completions_url().contains("//chat"));
    }

    #[test]
    fn url_join_without_trailing_slash() {
        let client = make_client("http://localhost:11434/v1", None);
        assert_eq!(
            client.chat_completions_url(),
            "http://localhost:11434/v1/chat/completions"
        );
    }

    #[test]
    fn url_join_root_base_url() {
        let client = make_client("http://localhost:11434", None);
        assert_eq!(
            client.chat_completions_url(),
            "http://localhost:11434/chat/completions"
        );
    }

    #[test]
    fn sanitize_base_url_strips_userinfo_query_and_fragment() {
        assert_eq!(
            sanitize_base_url("http://user:secret@127.0.0.1:11434/v1?token=abc#frag"),
            "http://127.0.0.1:11434/v1"
        );
    }

    #[test]
    fn sanitize_base_url_preserves_plain_localhost() {
        assert_eq!(
            sanitize_base_url("http://localhost:11434/v1"),
            "http://localhost:11434/v1"
        );
    }

    #[test]
    fn sanitize_base_url_preserves_trailing_slash() {
        assert_eq!(
            sanitize_base_url("https://host:8080/api/v1/"),
            "https://host:8080/api/v1/"
        );
    }

    #[test]
    fn sanitize_base_url_handles_unparseable_input() {
        assert_eq!(sanitize_base_url("not a url"), "<unparseable URL>");
        assert_eq!(sanitize_base_url(""), "<unparseable URL>");
    }

    #[test]
    fn sanitize_base_url_strips_password_only() {
        assert_eq!(
            sanitize_base_url("https://:pwd@host:8080/v1"),
            "https://host:8080/v1"
        );
    }

    #[test]
    fn sanitize_base_url_preserves_ipv6_brackets() {
        assert_eq!(
            sanitize_base_url("http://[::1]:11434/v1"),
            "http://[::1]:11434/v1"
        );
        assert_eq!(
            sanitize_base_url("https://[fd00::1]:8443/v1/"),
            "https://[fd00::1]:8443/v1/"
        );
        assert_eq!(
            sanitize_base_url("http://user:pwd@[::1]:11434/v1?token=secret"),
            "http://[::1]:11434/v1"
        );
    }

    #[test]
    fn body_indicates_model_missing_detects_common_patterns() {
        assert!(body_indicates_model_missing(
            "{\"error\":\"model 'foo' not found, try pulling it first\"}"
        ));
        assert!(body_indicates_model_missing(
            "{\"error\":\"Model not loaded\"}"
        ));
        assert!(body_indicates_model_missing(
            "{\"detail\":\"no such model: bar\"}"
        ));
        assert!(body_indicates_model_missing(
            "{\"error\":\"unknown model llama99\"}"
        ));
        assert!(!body_indicates_model_missing(
            "{\"error\":\"server overloaded\"}"
        ));
        assert!(!body_indicates_model_missing(""));
    }

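    // Builds a minimal AIConfig for the from_config tests.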
    fn make_config(base_url: &str, api_key: Option<&str>) -> crate::config::AIConfig {
        crate::config::AIConfig {
            provider: "local".to_string(),
            api_key: api_key.map(|s| s.to_string()),
            model: "llama3.1:8b-instruct".to_string(),
            base_url: base_url.to_string(),
            max_sample_length: 500,
            temperature: 0.3,
            max_tokens: 1024,
            retry_attempts: 2,
            retry_delay_ms: 100,
            request_timeout_seconds: 120,
            api_version: None,
        }
    }

    #[test]
    fn from_config_rejects_empty_base_url() {
        let config = make_config("", None);
        let err = LocalLLMClient::from_config(&config).unwrap_err();
        assert!(
            err.to_string().contains("ai.base_url is required"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn from_config_rejects_whitespace_base_url() {
        let config = make_config(" ", None);
        assert!(LocalLLMClient::from_config(&config).is_err());
    }

    #[test]
    fn from_config_accepts_loopback_http() {
        let config = make_config("http://localhost:11434/v1", None);
        let client = LocalLLMClient::from_config(&config).expect("should accept loopback HTTP");
        assert!(client.api_key.is_none());
        assert_eq!(client.base_url, "http://localhost:11434/v1");
    }

    #[test]
    fn from_config_accepts_lan_http() {
        let config = make_config("http://192.168.1.50:11434/v1", None);
        let client = LocalLLMClient::from_config(&config).expect("LAN HTTP must be accepted");
        assert_eq!(client.base_url, "http://192.168.1.50:11434/v1");
    }

    #[test]
    fn from_config_accepts_https() {
        let config = make_config("https://ollama.tailnet.ts.net/v1", Some("vllm-token"));
        let client = LocalLLMClient::from_config(&config).expect("HTTPS must be accepted");
        assert_eq!(client.base_url, "https://ollama.tailnet.ts.net/v1");
        assert_eq!(client.api_key.as_deref(), Some("vllm-token"));
    }

    #[test]
    fn from_config_normalizes_empty_api_key_to_none() {
        let config = make_config("http://localhost:11434/v1", Some(""));
        let client = LocalLLMClient::from_config(&config).unwrap();
        assert!(
            client.api_key.is_none(),
            "empty api_key should normalize to None"
        );
    }

    #[test]
    fn from_config_trims_trailing_slash_in_base_url() {
        let config = make_config("http://localhost:11434/v1/", None);
        let client = LocalLLMClient::from_config(&config).unwrap();
        assert_eq!(client.base_url, "http://localhost:11434/v1");
        assert_eq!(
            client.chat_completions_url(),
            "http://localhost:11434/v1/chat/completions"
        );
    }
}