Skip to main content

talon_core/llm/
client.rs

1//! Blocking OpenAI-compatible chat-completions client.
2
3use std::time::Duration;
4
5use reqwest::blocking::Client as HttpClient;
6
7use crate::config::ResolvedAuth;
8use crate::inference::redact;
9
10use super::error::ChatError;
11use super::types::{
12    ChatCompletionOutput, ChatCompletionRequest, ChatCompletionResponse, ChatMessage,
13    ReasoningEffort,
14};
15
/// Request timeout used by [`ChatClient::new`] and [`ChatClient::with_max_tokens`],
/// the constructors that do not take an explicit timeout: thirty seconds.
pub const DEFAULT_CHAT_TIMEOUT: Duration = Duration::from_millis(30_000);
18
/// Blocking client for OpenAI-compatible `/chat/completions` endpoints.
///
/// Built through the `ChatClient::new` / `with_*` constructors and optionally
/// refined with the `with_reasoning_effort` / `with_chat_template_kwargs`
/// builder methods.
///
/// NOTE(review): `Debug` is derived, so `{:?}` on a `ChatClient` also formats
/// `auth`. Whether that exposes the API key depends on `ResolvedAuth`'s own
/// `Debug` impl — confirm it redacts secrets before logging clients.
#[derive(Debug, Clone)]
pub struct ChatClient {
    /// Base URL; `/chat/completions` is appended per request after trimming
    /// trailing `/` characters.
    base_url: String,
    /// Model identifier sent in every request body.
    model: String,
    /// Optional completion token cap forwarded as `max_tokens`.
    max_tokens: Option<u32>,
    /// Optional reasoning effort forwarded as `reasoning_effort`.
    reasoning_effort: Option<ReasoningEffort>,
    /// Optional provider-specific chat-template options forwarded in the body.
    chat_template_kwargs: Option<serde_json::Value>,
    /// Bearer API key and extra headers attached to every request.
    auth: ResolvedAuth,
    /// Underlying blocking `reqwest` client used to send requests.
    http: HttpClient,
}
30
31impl ChatClient {
32    /// Builds a client targeting `base_url` with the default timeout.
33    ///
34    /// # Errors
35    ///
36    /// Returns [`ChatError::Build`] if the underlying `reqwest::Client` fails
37    /// to build.
38    pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Result<Self, ChatError> {
39        Self::with_timeout(base_url, model, DEFAULT_CHAT_TIMEOUT)
40    }
41
42    /// Builds a client with a custom timeout.
43    ///
44    /// # Errors
45    ///
46    /// Returns [`ChatError::Build`] if the underlying `reqwest::Client` fails
47    /// to build.
48    pub fn with_timeout(
49        base_url: impl Into<String>,
50        model: impl Into<String>,
51        timeout: Duration,
52    ) -> Result<Self, ChatError> {
53        Self::with_timeout_and_max_tokens(base_url, model, timeout, None)
54    }
55
56    /// Builds a client with the default timeout and optional completion token cap.
57    ///
58    /// # Errors
59    ///
60    /// Returns [`ChatError::Build`] if the underlying `reqwest::Client` fails
61    /// to build.
62    pub fn with_max_tokens(
63        base_url: impl Into<String>,
64        model: impl Into<String>,
65        max_tokens: Option<u32>,
66    ) -> Result<Self, ChatError> {
67        Self::with_timeout_and_max_tokens(base_url, model, DEFAULT_CHAT_TIMEOUT, max_tokens)
68    }
69
70    /// Builds a client with a custom timeout and optional completion token cap.
71    ///
72    /// # Errors
73    ///
74    /// Returns [`ChatError::Build`] if the underlying `reqwest::Client` fails
75    /// to build.
76    pub fn with_timeout_and_max_tokens(
77        base_url: impl Into<String>,
78        model: impl Into<String>,
79        timeout: Duration,
80        max_tokens: Option<u32>,
81    ) -> Result<Self, ChatError> {
82        Self::with_timeout_max_tokens_and_auth(
83            base_url,
84            model,
85            timeout,
86            max_tokens,
87            ResolvedAuth::default(),
88        )
89    }
90
91    /// Builds a client with timeout, token cap, and resolved auth material.
92    ///
93    /// # Errors
94    ///
95    /// Returns [`ChatError::Build`] if the underlying `reqwest::Client` fails
96    /// to build.
97    pub fn with_timeout_max_tokens_and_auth(
98        base_url: impl Into<String>,
99        model: impl Into<String>,
100        timeout: Duration,
101        max_tokens: Option<u32>,
102        auth: ResolvedAuth,
103    ) -> Result<Self, ChatError> {
104        Self::with_optional_timeout_max_tokens_and_auth(
105            base_url,
106            model,
107            Some(timeout),
108            max_tokens,
109            auth,
110        )
111    }
112
113    /// Builds a client with no HTTP request timeout and optional completion token cap.
114    ///
115    /// # Errors
116    ///
117    /// Returns [`ChatError::Build`] if the underlying `reqwest::Client` fails
118    /// to build.
119    pub fn with_no_timeout_and_max_tokens(
120        base_url: impl Into<String>,
121        model: impl Into<String>,
122        max_tokens: Option<u32>,
123    ) -> Result<Self, ChatError> {
124        Self::with_optional_timeout_max_tokens_and_auth(
125            base_url,
126            model,
127            None,
128            max_tokens,
129            ResolvedAuth::default(),
130        )
131    }
132
133    fn with_optional_timeout_max_tokens_and_auth(
134        base_url: impl Into<String>,
135        model: impl Into<String>,
136        timeout: Option<Duration>,
137        max_tokens: Option<u32>,
138        auth: ResolvedAuth,
139    ) -> Result<Self, ChatError> {
140        let mut builder = HttpClient::builder();
141        if let Some(timeout) = timeout {
142            builder = builder.timeout(timeout);
143        }
144        let http = builder.build().map_err(|err| ChatError::Build {
145            message: redact(&err.to_string()),
146        })?;
147        Ok(Self {
148            base_url: base_url.into(),
149            model: model.into(),
150            max_tokens,
151            reasoning_effort: None,
152            chat_template_kwargs: None,
153            auth,
154            http,
155        })
156    }
157
158    /// Sets OpenAI-compatible reasoning effort for the request body.
159    #[must_use]
160    pub const fn with_reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
161        self.reasoning_effort = Some(effort);
162        self
163    }
164
165    /// Sets provider-specific chat-template options.
166    #[must_use]
167    pub fn with_chat_template_kwargs(mut self, value: serde_json::Value) -> Self {
168        self.chat_template_kwargs = Some(value);
169        self
170    }
171
172    /// Returns the configured model identifier.
173    #[must_use]
174    pub fn model(&self) -> &str {
175        &self.model
176    }
177
178    /// Returns the configured chat-completions base URL.
179    #[must_use]
180    pub fn base_url(&self) -> &str {
181        &self.base_url
182    }
183
184    /// Sends a chat completion request and returns the first message content.
185    ///
186    /// # Errors
187    ///
188    /// Returns [`ChatError::Http`] for transport failures or non-2xx statuses.
189    /// Returns [`ChatError::MalformedResponse`] when the response body cannot be
190    /// decoded or the first choice has no content.
191    pub fn complete(
192        &self,
193        messages: Vec<ChatMessage>,
194        temperature: f32,
195    ) -> Result<String, ChatError> {
196        self.complete_raw(messages, temperature)
197            .map(|output| output.content)
198    }
199
200    /// Sends a chat completion request and returns content plus raw response.
201    ///
202    /// # Errors
203    ///
204    /// Returns [`ChatError::Http`] for transport failures or non-2xx statuses.
205    /// Returns [`ChatError::MalformedResponse`] when the response body cannot be
206    /// decoded or the first choice has no visible content.
207    pub fn complete_raw(
208        &self,
209        messages: Vec<ChatMessage>,
210        temperature: f32,
211    ) -> Result<ChatCompletionOutput, ChatError> {
212        let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
213        let body = ChatCompletionRequest {
214            model: self.model.clone(),
215            messages,
216            max_tokens: self.max_tokens,
217            reasoning_effort: self.reasoning_effort,
218            temperature,
219            chat_template_kwargs: self.chat_template_kwargs.clone(),
220        };
221
222        let mut request = self.http.post(&url).json(&body);
223        if let Some(key) = &self.auth.api_key {
224            request = request.bearer_auth(key);
225        }
226        for (name, value) in &self.auth.extra_headers {
227            request = request.header(name.as_str(), value.as_str());
228        }
229
230        let response = request.send().map_err(|err| ChatError::Http {
231            status: None,
232            message: redact(&err.to_string()),
233            timed_out: err.is_timeout(),
234        })?;
235
236        let status = response.status();
237        if !status.is_success() {
238            let snippet = response.text().unwrap_or_default();
239            return Err(ChatError::Http {
240                status: Some(status.as_u16()),
241                message: redact(&snippet),
242                timed_out: false,
243            });
244        }
245
246        let text = response.text().map_err(|_| ChatError::MalformedResponse)?;
247        let completion: ChatCompletionResponse =
248            serde_json::from_str(&text).map_err(|_| ChatError::MalformedResponse)?;
249        let message = completion
250            .choices
251            .first()
252            .map(|choice| &choice.message)
253            .ok_or(ChatError::MalformedResponse)?;
254        let content = message
255            .content
256            .clone()
257            .filter(|content| !content.trim().is_empty())
258            .ok_or(ChatError::MalformedResponse)?;
259        Ok(ChatCompletionOutput {
260            content,
261            reasoning_content: message.reasoning_content.clone(),
262            raw_response: text,
263        })
264    }
265}
266
/// Removes Markdown code-fence markers from an LLM reply and isolates its JSON object.
///
/// First trims whitespace, then strips every leading `json`-tagged fence marker
/// (three backticks followed by `json`), every leading plain fence marker, and
/// every trailing fence marker, then trims again. If the remaining text has a
/// `{` before a later `}`, the span from the first `{` through the last `}` is
/// returned; otherwise the unfenced text is returned as-is.
///
/// Port of `stripCodeFences` in `clients/sidecar-llm/local-llm.ts`.
#[must_use]
pub fn strip_code_fences(content: &str) -> String {
    let unfenced = content.trim();
    let unfenced = unfenced.trim_start_matches("```json");
    let unfenced = unfenced.trim_start_matches("```");
    let unfenced = unfenced.trim_end_matches("```").trim();

    let first_open = unfenced.find('{');
    let last_close = unfenced.rfind('}');
    first_open
        .zip(last_close)
        .filter(|&(open, close)| close > open)
        .map_or_else(
            || unfenced.to_owned(),
            |(open, close)| unfenced[open..=close].to_owned(),
        )
}
283
// Unit tests for this module live out of line (per Rust's module path rules for
// a non-`mod.rs` file: `client/tests.rs` or `client/tests/mod.rs`) and are
// compiled only for test builds.
#[cfg(test)]
mod tests;