tl_cli/translation/client.rs

1use anyhow::{Context, Result};
2use bytes::Bytes;
3use futures_util::Stream;
4use reqwest::Client;
5use serde::Serialize;
6use sha2::{Digest, Sha256};
7use std::borrow::Cow;
8use std::pin::Pin;
9
10use super::prompt::{SYSTEM_PROMPT_TEMPLATE, build_system_prompt_with_style};
11use super::sse_parser::sse_to_text_stream;
12
/// A request to translate text.
///
/// Bundles every parameter needed to perform a translation and to derive a
/// unique cache key for it (see [`TranslationRequest::cache_key`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TranslationRequest {
    /// The text to translate.
    pub source_text: String,
    /// The target language (ISO 639-1 code, e.g., "ja", "en").
    pub target_language: String,
    /// The model to use for translation.
    pub model: String,
    /// The API endpoint URL.
    ///
    /// This participates in the cache key only: `TranslationClient::translate_stream`
    /// sends the request to the client's own endpoint and does not read this field.
    pub endpoint: String,
    /// The translation style prompt (if specified).
    pub style: Option<String>,
}
30
31impl TranslationRequest {
32    /// Computes a unique cache key for this request.
33    ///
34    /// The key is a SHA-256 hash of the source text, target language,
35    /// model, endpoint, style, and prompt template hash.
36    pub fn cache_key(&self) -> String {
37        let prompt_hash = Self::prompt_hash();
38
39        let cache_input = serde_json::json!({
40            "source_text": self.source_text,
41            "target_language": self.target_language,
42            "model": self.model,
43            "endpoint": self.endpoint,
44            "prompt_hash": prompt_hash,
45            "style": self.style
46        });
47
48        let mut hasher = Sha256::new();
49        hasher.update(cache_input.to_string().as_bytes());
50        hex::encode(hasher.finalize())
51    }
52
53    /// Computes a hash of the system prompt template.
54    ///
55    /// Used to invalidate cache when the prompt changes.
56    pub fn prompt_hash() -> String {
57        let mut hasher = Sha256::new();
58        hasher.update(SYSTEM_PROMPT_TEMPLATE.as_bytes());
59        hex::encode(hasher.finalize())
60    }
61}
62
63/// Request body for the chat completions API.
64#[derive(Debug, Serialize)]
65struct ChatCompletionRequest<'a> {
66    model: &'a str,
67    messages: Vec<Message<'a>>,
68    stream: bool,
69}
70
71impl<'a> ChatCompletionRequest<'a> {
72    /// Builds a chat completion request for translation.
73    fn for_translation(model: &'a str, system_prompt: &'a str, source_text: &'a str) -> Self {
74        Self {
75            model,
76            messages: vec![
77                Message {
78                    role: "system",
79                    content: Cow::Borrowed(system_prompt),
80                },
81                Message {
82                    role: "user",
83                    content: Cow::Borrowed(source_text),
84                },
85            ],
86            stream: true,
87        }
88    }
89}
90
91#[derive(Debug, Serialize)]
92struct Message<'a> {
93    role: &'static str,
94    content: Cow<'a, str>,
95}
96
97/// Client for translating text using OpenAI-compatible APIs.
98///
99/// Supports streaming responses for real-time output.
100///
101/// # Example
102///
103/// ```no_run
104/// use tl_cli::translation::{TranslationClient, TranslationRequest};
105/// use futures_util::StreamExt;
106///
107/// # async fn example() -> anyhow::Result<()> {
108/// let client = TranslationClient::new(
109///     "http://localhost:11434".to_string(),
110///     None,
111/// );
112///
113/// let request = TranslationRequest {
114///     source_text: "Hello, world!".to_string(),
115///     target_language: "ja".to_string(),
116///     model: "gemma3:12b".to_string(),
117///     endpoint: "http://localhost:11434".to_string(),
118///     style: None,
119/// };
120///
121/// let mut stream = client.translate_stream(&request).await?;
122/// while let Some(chunk) = stream.next().await {
123///     print!("{}", chunk?);
124/// }
125/// # Ok(())
126/// # }
127/// ```
128pub struct TranslationClient {
129    client: Client,
130    endpoint: String,
131    api_key: Option<String>,
132}
133
134impl TranslationClient {
135    /// Creates a new translation client.
136    pub fn new(endpoint: String, api_key: Option<String>) -> Self {
137        Self {
138            client: Client::new(),
139            endpoint,
140            api_key,
141        }
142    }
143
144    /// Translates text and returns a stream of response chunks.
145    ///
146    /// The stream yields chunks of the translated text as they arrive,
147    /// enabling real-time display of the translation.
148    pub async fn translate_stream(
149        &self,
150        request: &TranslationRequest,
151    ) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>> {
152        let byte_stream = self
153            .send_chat_completion(
154                &request.model,
155                &request.target_language,
156                &request.source_text,
157                request.style.as_deref(),
158            )
159            .await?;
160
161        Ok(Box::pin(sse_to_text_stream(byte_stream)))
162    }
163
164    /// Sends a chat completion request and returns the raw byte stream.
165    async fn send_chat_completion(
166        &self,
167        model: &str,
168        target_language: &str,
169        source_text: &str,
170        style: Option<&str>,
171    ) -> Result<impl Stream<Item = reqwest::Result<Bytes>> + Send + 'static> {
172        let url = self.build_url();
173        let system_prompt = build_system_prompt_with_style(target_language, style);
174        let chat_request =
175            ChatCompletionRequest::for_translation(model, &system_prompt, source_text);
176
177        let response = self.send_request(&url, &chat_request).await?;
178
179        Ok(response.bytes_stream())
180    }
181
182    /// Sends an HTTP POST request with optional authorization.
183    async fn send_request<T: Serialize + Sync>(
184        &self,
185        url: &str,
186        body: &T,
187    ) -> Result<reqwest::Response> {
188        let mut request = self.client.post(url).json(body);
189
190        if let Some(api_key) = &self.api_key {
191            request = request.header("Authorization", format!("Bearer {api_key}"));
192        }
193
194        let response = request
195            .send()
196            .await
197            .with_context(|| format!("Failed to connect to API endpoint: {url}"))?;
198
199        if !response.status().is_success() {
200            let status = response.status();
201            let body = response.text().await.unwrap_or_default();
202            anyhow::bail!("API request failed with status {status}: {body}");
203        }
204
205        Ok(response)
206    }
207
208    /// Builds the chat completions API URL.
209    fn build_url(&self) -> String {
210        format!(
211            "{}/v1/chat/completions",
212            self.endpoint.trim_end_matches('/')
213        )
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline request shared by the tests; each test mutates one field.
    fn create_test_request() -> TranslationRequest {
        TranslationRequest {
            source_text: "Hello, world!".to_string(),
            target_language: "ja".to_string(),
            model: "gemma3:12b".to_string(),
            endpoint: "http://localhost:11434".to_string(),
            style: None,
        }
    }

    #[test]
    fn test_cache_key_is_consistent() {
        let request = create_test_request();
        let key1 = request.cache_key();
        let key2 = request.cache_key();
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_cache_key_is_hex_string() {
        let request = create_test_request();
        let key = request.cache_key();
        // SHA-256 produces 64 hex characters
        assert_eq!(key.len(), 64);
        assert!(key.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_cache_key_differs_for_different_source_text() {
        let request1 = create_test_request();
        let mut request2 = create_test_request();
        request2.source_text = "Different text".to_string();
        assert_ne!(request1.cache_key(), request2.cache_key());
    }

    #[test]
    fn test_cache_key_differs_for_different_target_language() {
        let request1 = create_test_request();
        let mut request2 = create_test_request();
        request2.target_language = "en".to_string();
        assert_ne!(request1.cache_key(), request2.cache_key());
    }

    #[test]
    fn test_cache_key_differs_for_different_model() {
        let request1 = create_test_request();
        let mut request2 = create_test_request();
        request2.model = "gpt-4o".to_string();
        assert_ne!(request1.cache_key(), request2.cache_key());
    }

    #[test]
    fn test_cache_key_differs_for_different_endpoint() {
        let request1 = create_test_request();
        let mut request2 = create_test_request();
        request2.endpoint = "https://api.openai.com".to_string();
        assert_ne!(request1.cache_key(), request2.cache_key());
    }

    #[test]
    fn test_cache_key_differs_for_different_style() {
        let request1 = create_test_request();
        let mut request2 = create_test_request();
        request2.style = Some("formal".to_string());
        let mut request3 = create_test_request();
        request3.style = Some("casual".to_string());
        // None vs Some, and two different Some values, must all be distinct.
        assert_ne!(request1.cache_key(), request2.cache_key());
        assert_ne!(request2.cache_key(), request3.cache_key());
    }

    #[test]
    fn test_prompt_hash_is_consistent() {
        let hash1 = TranslationRequest::prompt_hash();
        let hash2 = TranslationRequest::prompt_hash();
        assert_eq!(hash1, hash2);
    }

    #[test]
    fn test_prompt_hash_is_hex_string() {
        let hash = TranslationRequest::prompt_hash();
        // SHA-256 produces 64 hex characters
        assert_eq!(hash.len(), 64);
        assert!(hash.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn test_chat_completion_request_serialization() {
        let request = ChatCompletionRequest::for_translation("gemma3:12b", "sys", "Hello");
        let value = serde_json::to_value(&request).expect("request should serialize");
        assert_eq!(
            value,
            serde_json::json!({
                "model": "gemma3:12b",
                "messages": [
                    {"role": "system", "content": "sys"},
                    {"role": "user", "content": "Hello"}
                ],
                "stream": true
            })
        );
    }

    #[test]
    fn test_translation_client_new() {
        let client = TranslationClient::new(
            "http://localhost:11434".to_string(),
            Some("test-api-key".to_string()),
        );
        assert_eq!(client.endpoint, "http://localhost:11434");
        assert_eq!(client.api_key, Some("test-api-key".to_string()));
    }

    #[test]
    fn test_translation_client_new_without_api_key() {
        let client = TranslationClient::new("http://localhost:11434".to_string(), None);
        assert_eq!(client.endpoint, "http://localhost:11434");
        assert!(client.api_key.is_none());
    }

    #[test]
    fn test_build_url_without_trailing_slash() {
        let client = TranslationClient::new("http://localhost:11434".to_string(), None);
        assert_eq!(
            client.build_url(),
            "http://localhost:11434/v1/chat/completions"
        );
    }

    #[test]
    fn test_build_url_with_trailing_slash() {
        let client = TranslationClient::new("http://localhost:11434/".to_string(), None);
        assert_eq!(
            client.build_url(),
            "http://localhost:11434/v1/chat/completions"
        );
    }
}