// Source file: subx_cli/services/ai/openai.rs
1use crate::Result;
2use crate::cli::display_ai_usage;
3use crate::error::SubXError;
4use crate::services::ai::AiUsageStats;
5use crate::services::ai::{
6    AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
7};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde_json::Value;
11use serde_json::json;
12use std::time::Duration;
13
14use crate::services::ai::hosted_hint::{append_local_hint, maybe_attach_local_hint};
15use crate::services::ai::prompts::{PromptBuilder, ResponseParser};
16use crate::services::ai::retry::HttpRetryClient;
17
/// OpenAI Chat Completions API client.
///
/// Holds the HTTP client, credentials, and request parameters used by
/// [`OpenAIClient::chat_completion`]. Construct via [`OpenAIClient::new`]
/// (default endpoint) or [`OpenAIClient::from_config`] (validated custom
/// endpoint and timeout).
pub struct OpenAIClient {
    // reqwest HTTP client, built with a per-request timeout.
    client: Client,
    // Bearer token for the `Authorization` header; masked in the Debug impl.
    api_key: String,
    // Model identifier sent in the request body (e.g. "gpt-4.1-mini").
    model: String,
    // Sampling temperature forwarded verbatim to the API.
    temperature: f32,
    // Upper bound on completion tokens per request.
    max_tokens: u32,
    // Number of retry attempts used by the HttpRetryClient impl.
    retry_attempts: u32,
    // Base delay between retries, in milliseconds.
    retry_delay_ms: u64,
    // API base URL, stored without a trailing slash
    // (default: "https://api.openai.com/v1").
    base_url: String,
}
29
30impl std::fmt::Debug for OpenAIClient {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("OpenAIClient")
33            .field("client", &self.client)
34            .field("api_key", &"[REDACTED]")
35            .field("model", &self.model)
36            .field("temperature", &self.temperature)
37            .field("max_tokens", &self.max_tokens)
38            .field("retry_attempts", &self.retry_attempts)
39            .field("retry_delay_ms", &self.retry_delay_ms)
40            .field("base_url", &self.base_url)
41            .finish()
42    }
43}
44
// Marker impls: OpenAIClient reuses the default prompt-building and
// response-parsing methods provided by the traits in `services::ai::prompts`
// (e.g. `build_analysis_prompt`, `parse_match_result`), so no overrides are
// needed here.
impl PromptBuilder for OpenAIClient {}
impl ResponseParser for OpenAIClient {}
// Wire the client's retry configuration into the shared retry helper;
// `make_request_with_retry` (used in `chat_completion`) is provided by
// the `HttpRetryClient` trait and reads these two accessors.
impl HttpRetryClient for OpenAIClient {
    /// Maximum number of retry attempts for a failed HTTP request.
    fn retry_attempts(&self) -> u32 {
        self.retry_attempts
    }
    /// Base delay between retry attempts, in milliseconds.
    fn retry_delay_ms(&self) -> u64 {
        self.retry_delay_ms
    }
}
55
// Unit tests for `OpenAIClient`: mockall doubles the `AIProvider` trait
// surface, wiremock exercises HTTP-level behavior (success, API error,
// and the §3.6 local-provider hint paths).
#[cfg(test)]
mod tests {
    use super::*;
    use mockall::{mock, predicate::eq};
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Generates `MockAIClient`, a mockall double implementing `AIProvider`.
    mock! {
        AIClient {}

        #[async_trait]
        impl AIProvider for AIClient {
            async fn analyze_content(&self, request: AnalysisRequest) -> crate::Result<MatchResult>;
            async fn verify_match(&self, verification: VerificationRequest) -> crate::Result<ConfidenceScore>;
        }
    }

    /// Constructor stores every argument verbatim in the matching field.
    #[tokio::test]
    async fn test_openai_client_creation() {
        let client = OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.5, 1000, 2, 100);
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-4.1-mini");
        assert_eq!(client.temperature, 0.5);
        assert_eq!(client.max_tokens, 1000);
        assert_eq!(client.retry_attempts, 2);
        assert_eq!(client.retry_delay_ms, 100);
    }

    /// Happy path: a 200 response with the canonical OpenAI shape yields
    /// `choices[0].message.content`; the bearer header must be sent.
    #[tokio::test]
    async fn test_chat_completion_success() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "choices": [{"message": {"content": "test response content"}}]
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("test-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let resp = client.chat_completion(messages).await.unwrap();
        assert_eq!(resp, "test response content");
    }

    /// A non-2xx status (here 400) surfaces as an error to the caller.
    #[tokio::test]
    async fn test_chat_completion_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(400).set_body_json(json!({
                "error": {"message":"Invalid API key"}
            })))
            .mount(&server)
            .await;
        let mut client =
            OpenAIClient::new("bad-key".into(), "gpt-4.1-mini".into(), 0.3, 1000, 1, 0);
        client.base_url = server.uri();
        let messages = vec![json!({"role":"user","content":"test"})];
        let result = client.chat_completion(messages).await;
        assert!(result.is_err());
    }

    /// Smoke test of the mockall double: the expected `MatchResult` is
    /// returned for the exact request it was configured with.
    #[tokio::test]
    async fn test_analyze_content_with_mock() {
        let mut mock = MockAIClient::new();
        let req = AnalysisRequest {
            video_files: vec!["v.mp4".into()],
            subtitle_files: vec!["s.srt".into()],
            content_samples: vec![],
        };
        let expected = MatchResult {
            matches: vec![],
            confidence: 0.5,
            reasoning: "OK".into(),
        };
        mock.expect_analyze_content()
            .with(eq(req.clone()))
            .times(1)
            .returning(move |_| Ok(expected.clone()));
        let res = mock.analyze_content(req.clone()).await.unwrap();
        assert_eq!(res.confidence, 0.5);
    }

    /// The PromptBuilder/ResponseParser defaults: the analysis prompt embeds
    /// the file names and requests JSON; a well-formed JSON reply parses.
    #[test]
    fn test_prompt_building_and_parsing() {
        let client = OpenAIClient::new("k".into(), "m".into(), 0.1, 1000, 0, 0);
        let request = AnalysisRequest {
            video_files: vec!["F1.mp4".into()],
            subtitle_files: vec!["S1.srt".into()],
            content_samples: vec![],
        };
        let prompt = client.build_analysis_prompt(&request);
        assert!(prompt.contains("F1.mp4"));
        assert!(prompt.contains("S1.srt"));
        assert!(prompt.contains("JSON"));
        let json_resp = r#"{ "matches": [], "confidence":0.9, "reasoning":"r" }"#;
        let mr = client.parse_match_result(json_resp).unwrap();
        assert_eq!(mr.confidence, 0.9);
    }

    /// `from_config` copies the relevant fields out of `AIConfig`.
    #[test]
    fn test_openai_client_from_config() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "https://custom.openai.com/v1".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 2000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            request_timeout_seconds: 60,
            api_version: None,
        };
        let client = OpenAIClient::from_config(&config).unwrap();
        assert_eq!(client.api_key, "test-key");
        assert_eq!(client.model, "gpt-test");
        assert_eq!(client.temperature, 0.7);
        assert_eq!(client.max_tokens, 2000);
    }

    /// `from_config` rejects base URLs whose scheme is not http/https.
    #[test]
    fn test_openai_client_from_config_invalid_base_url() {
        let config = crate::config::AIConfig {
            provider: "openai".to_string(),
            api_key: Some("test-key".to_string()),
            model: "gpt-test".to_string(),
            base_url: "ftp://invalid.url".to_string(),
            max_sample_length: 500,
            temperature: 0.7,
            max_tokens: 1000,
            retry_attempts: 2,
            retry_delay_ms: 150,
            request_timeout_seconds: 30,
            api_version: None,
        };
        let err = OpenAIClient::from_config(&config).unwrap_err();
        // Non-http/https protocols should return protocol error message
        assert!(
            err.to_string()
                .contains("Base URL must use http or https protocol")
        );
    }

    /// §3.6 — connection refused against `127.0.0.1` on a port with no
    /// listener results in a hint-bearing error.
    #[tokio::test]
    async fn test_hosted_hint_connection_refused_loopback() {
        let port = pick_unused_port().await;
        let mut client = OpenAIClient::new("k".into(), "gpt-4.1-mini".into(), 0.0, 16, 0, 0);
        client.base_url = format!("http://127.0.0.1:{}", port);
        let err = client
            .chat_completion(vec![json!({"role":"user","content":"x"})])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("ollama") && msg.contains("ai.provider"),
            "expected local-provider hint, got: {msg}"
        );
    }

    /// §3.6 — connection refused against an RFC1918 address surfaces the hint.
    #[tokio::test]
    async fn test_hosted_hint_connection_refused_rfc1918() {
        // 192.168.0.1 with a low port is unlikely to listen and produces a
        // transport failure (connect refused / unreachable / timeout)
        // within the configured request timeout. The hint must be appended
        // regardless of which transport sub-kind reqwest reports.
        let client = OpenAIClient::new_with_base_url_and_timeout(
            "k".into(),
            "gpt-4.1-mini".into(),
            0.0,
            16,
            0,
            0,
            "http://192.168.0.1:1".to_string(),
            1,
        );
        let err = client
            .chat_completion(vec![json!({"role":"user","content":"x"})])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("ollama") && msg.contains("ai.provider"),
            "expected local-provider hint, got: {msg}"
        );
    }

    /// §3.6 — HTTP 200 with a non-OpenAI body (`{"hello":"world"}`) MUST
    /// surface the local-provider hint via the parse-shape branch.
    #[tokio::test]
    async fn test_hosted_hint_http_200_non_openai_body() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "hello": "world" })))
            .mount(&server)
            .await;
        let mut client = OpenAIClient::new("k".into(), "gpt-4.1-mini".into(), 0.0, 16, 0, 0);
        client.base_url = server.uri();
        let err = client
            .chat_completion(vec![json!({"role":"user","content":"x"})])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("Invalid API response format"),
            "expected base parse-shape message: {msg}"
        );
        assert!(
            msg.contains("ollama") && msg.contains("ai.provider"),
            "expected local-provider hint: {msg}"
        );
    }

    /// §3.6 negative — a genuine failure against a public host MUST NOT
    /// surface the local-provider hint. We use TEST-NET-1 (`192.0.2.0/24`,
    /// RFC 5737), which is reserved for documentation and never routable,
    /// so the call fails deterministically without external dependencies.
    /// `192.0.2.x` is not in any private range, so the predicate must
    /// classify the host as public and suppress the hint.
    #[tokio::test]
    async fn test_hosted_hint_not_emitted_for_public_host() {
        let client = OpenAIClient::new_with_base_url_and_timeout(
            "k".into(),
            "gpt-4.1-mini".into(),
            0.0,
            16,
            0,
            0,
            "https://192.0.2.1/v1".to_string(),
            1,
        );
        let err = client
            .chat_completion(vec![json!({"role":"user","content":"x"})])
            .await
            .unwrap_err();
        let msg = err.to_string();
        assert!(
            !msg.contains("ollama"),
            "public-host failure must NOT carry the local-provider hint: {msg}"
        );
    }

    /// Helper: bind a TCP listener on `127.0.0.1:0` to obtain a port the
    /// kernel allocated, then drop the listener so the port is free
    /// (race-prone in theory but reliable in tests because no sibling test
    /// rebinds it within microseconds).
    async fn pick_unused_port() -> u16 {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        drop(listener);
        port
    }
}
319
impl OpenAIClient {
    /// Create a new client against the public OpenAI endpoint
    /// (`https://api.openai.com/v1`) with the default 30-second timeout.
    pub fn new(
        api_key: String,
        model: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
    ) -> Self {
        Self::new_with_base_url(
            api_key,
            model,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            "https://api.openai.com/v1".to_string(),
        )
    }

    /// Create a new OpenAIClient with custom base_url support.
    ///
    /// The URL is NOT validated here (see [`OpenAIClient::from_config`] for
    /// the validated path); the timeout defaults to 30 seconds.
    pub fn new_with_base_url(
        api_key: String,
        model: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
        base_url: String,
    ) -> Self {
        // Use default 30 second timeout for backward compatibility
        Self::new_with_base_url_and_timeout(
            api_key,
            model,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            base_url,
            30,
        )
    }

    /// Create a new OpenAIClient with custom base_url and timeout support.
    ///
    /// This is the root constructor the other two delegate to.
    ///
    /// # Panics
    /// Panics if the underlying reqwest client cannot be built — this only
    /// happens on a broken TLS/system configuration, so it is treated as a
    /// startup invariant rather than a recoverable error.
    #[allow(clippy::too_many_arguments)]
    pub fn new_with_base_url_and_timeout(
        api_key: String,
        model: String,
        temperature: f32,
        max_tokens: u32,
        retry_attempts: u32,
        retry_delay_ms: u64,
        base_url: String,
        request_timeout_seconds: u64,
    ) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(request_timeout_seconds))
            .build()
            .expect("Failed to create HTTP client");
        Self {
            client,
            api_key,
            model,
            temperature,
            max_tokens,
            retry_attempts,
            retry_delay_ms,
            // Normalize so request paths can always append "/chat/completions".
            base_url: base_url.trim_end_matches('/').to_string(),
        }
    }

    /// Create a client from the unified configuration.
    ///
    /// # Errors
    /// Returns a config error when the API key is missing or the base URL is
    /// malformed / uses a scheme other than http(s).
    pub fn from_config(config: &crate::config::AIConfig) -> crate::Result<Self> {
        let api_key = config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::error::SubXError::config("Missing OpenAI API Key"))?;

        // Validate base URL format
        Self::validate_base_url(&config.base_url)?;
        // Warn (not fail) if a bearer token would travel over plain HTTP.
        crate::services::ai::security::warn_on_insecure_http_str(&config.base_url, api_key);

        Ok(Self::new_with_base_url_and_timeout(
            api_key.clone(),
            config.model.clone(),
            config.temperature,
            config.max_tokens,
            config.retry_attempts,
            config.retry_delay_ms,
            config.base_url.clone(),
            config.request_timeout_seconds,
        ))
    }

    /// Validate base URL format: must parse, use http/https, and have a host.
    fn validate_base_url(url: &str) -> crate::Result<()> {
        use url::Url;
        let parsed = Url::parse(url)
            .map_err(|e| crate::error::SubXError::config(format!("Invalid base URL: {}", e)))?;

        if !matches!(parsed.scheme(), "http" | "https") {
            return Err(crate::error::SubXError::config(
                "Base URL must use http or https protocol".to_string(),
            ));
        }

        if parsed.host().is_none() {
            return Err(crate::error::SubXError::config(
                "Base URL must contain a valid hostname".to_string(),
            ));
        }

        Ok(())
    }

    /// Send a raw chat completion request to the OpenAI Chat Completions API.
    ///
    /// Posts `messages` to `{base_url}/chat/completions` with the configured
    /// model, temperature, and max_tokens, retrying transient failures per
    /// the retry settings, and returns `choices[0].message.content`.
    ///
    /// # Errors
    /// - Transport failures; a local-provider hint is appended when the host
    ///   looks like a local/private endpoint (`maybe_attach_local_hint`).
    /// - Non-2xx statuses (error body sanitized and truncated first).
    /// - Responses exceeding 10 MiB, bodies that are not valid JSON, or JSON
    ///   missing the canonical OpenAI shape (the latter always carries the
    ///   local-provider hint).
    pub async fn chat_completion(&self, messages: Vec<serde_json::Value>) -> Result<String> {
        let request_body = json!({
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        });

        let request = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body);
        // Retry loop lives in HttpRetryClient; on final failure, decide
        // whether to append the "try a local provider" hint based on host.
        let mut response = match self.make_request_with_retry(request).await {
            Ok(r) => r,
            Err(e) => return Err(maybe_attach_local_hint(e, &self.base_url)),
        };

        const MAX_AI_RESPONSE_BYTES: u64 = 10 * 1024 * 1024; // 10 MiB
        // Fast reject when the server declares an oversized Content-Length.
        if let Some(len) = response.content_length() {
            if len > MAX_AI_RESPONSE_BYTES {
                return Err(SubXError::AiService(format!(
                    "AI response too large: {} bytes (limit: {} bytes)",
                    len, MAX_AI_RESPONSE_BYTES
                )));
            }
        }

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await?;
            // Truncate and scrub URLs/secrets before surfacing the body.
            let safe_body = crate::services::ai::error_sanitizer::sanitize_url_in_error(
                &crate::services::ai::error_sanitizer::truncate_error_body(
                    &error_text,
                    crate::services::ai::error_sanitizer::DEFAULT_ERROR_BODY_MAX_LEN,
                ),
            );
            return Err(SubXError::AiService(format!(
                "OpenAI API error {}: {}",
                status, safe_body
            )));
        }

        // Bounded chunked read to guard against oversized responses when
        // content_length() is not reported by the server.
        let mut body = Vec::new();
        while let Some(chunk) = response.chunk().await? {
            body.extend_from_slice(&chunk);
            if body.len() as u64 > MAX_AI_RESPONSE_BYTES {
                return Err(SubXError::AiService(format!(
                    "AI response too large: {} bytes read (limit: {} bytes)",
                    body.len(),
                    MAX_AI_RESPONSE_BYTES
                )));
            }
        }
        let response_json: Value = serde_json::from_slice(&body)
            .map_err(|e| SubXError::AiService(format!("Failed to parse AI response: {}", e)))?;
        // serde_json's Index on Value yields Null (never panics) for missing
        // keys, so this lookup is safe even on arbitrary JSON.
        let content = response_json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| {
                // Body parsed as JSON but the canonical OpenAI shape is
                // missing — almost always a hosted client pointed at a
                // non-OpenAI endpoint. Append the local-provider hint
                // unconditionally here (no host predicate): the parse-shape
                // failure is itself a strong signal.
                SubXError::AiService(append_local_hint("Invalid API response format"))
            })?;

        // Parse usage statistics and display; best-effort — missing or
        // malformed `usage` is silently skipped.
        if let Some(usage_obj) = response_json.get("usage") {
            if let (Some(p), Some(c), Some(t)) = (
                usage_obj.get("prompt_tokens").and_then(Value::as_u64),
                usage_obj.get("completion_tokens").and_then(Value::as_u64),
                usage_obj.get("total_tokens").and_then(Value::as_u64),
            ) {
                let stats = AiUsageStats {
                    model: self.model.clone(),
                    prompt_tokens: p as u32,
                    completion_tokens: c as u32,
                    total_tokens: t as u32,
                };
                display_ai_usage(&stats);
            }
        }

        Ok(content.to_string())
    }
}
527
528#[async_trait]
529impl AIProvider for OpenAIClient {
530    async fn analyze_content(&self, request: AnalysisRequest) -> Result<MatchResult> {
531        let prompt = self.build_analysis_prompt(&request);
532        let messages = vec![
533            json!({"role": "system", "content": Self::get_analysis_system_message()}),
534            json!({"role": "user", "content": prompt}),
535        ];
536        let response = self.chat_completion(messages).await?;
537        self.parse_match_result(&response)
538    }
539
540    async fn verify_match(&self, verification: VerificationRequest) -> Result<ConfidenceScore> {
541        let prompt = self.build_verification_prompt(&verification);
542        let messages = vec![
543            json!({"role": "system", "content": Self::get_verification_system_message()}),
544            json!({"role": "user", "content": prompt}),
545        ];
546        let response = self.chat_completion(messages).await?;
547        self.parse_confidence_score(&response)
548    }
549
550    async fn chat_completion(&self, messages: Vec<serde_json::Value>) -> Result<String> {
551        OpenAIClient::chat_completion(self, messages).await
552    }
553}