1use std::time::Duration;
4
5use reqwest::blocking::Client as HttpClient;
6
7use crate::config::ResolvedAuth;
8use crate::inference::redact;
9
10use super::error::ChatError;
11use super::types::{
12 ChatCompletionOutput, ChatCompletionRequest, ChatCompletionResponse, ChatMessage,
13 ReasoningEffort,
14};
15
/// Request timeout (30 seconds) used by the [`ChatClient`] constructors that do not
/// take an explicit timeout.
pub const DEFAULT_CHAT_TIMEOUT: Duration = Duration::from_secs(30);
18
/// Blocking client that posts chat-completion requests to `{base_url}/chat/completions`.
#[derive(Debug, Clone)]
pub struct ChatClient {
    /// Endpoint prefix; trailing `/` characters are trimmed before `/chat/completions`
    /// is appended.
    base_url: String,
    /// Model identifier sent in the `model` field of every request body.
    model: String,
    /// Optional `max_tokens` value forwarded in the request body.
    max_tokens: Option<u32>,
    /// Optional reasoning-effort setting forwarded in the request body.
    reasoning_effort: Option<ReasoningEffort>,
    /// Optional JSON value forwarded as `chat_template_kwargs` in the request body.
    chat_template_kwargs: Option<serde_json::Value>,
    /// Resolved credentials: an optional API key sent as a bearer token plus extra headers.
    auth: ResolvedAuth,
    /// Underlying blocking HTTP client, configured once at construction.
    http: HttpClient,
}
30
31impl ChatClient {
32 pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Result<Self, ChatError> {
39 Self::with_timeout(base_url, model, DEFAULT_CHAT_TIMEOUT)
40 }
41
42 pub fn with_timeout(
49 base_url: impl Into<String>,
50 model: impl Into<String>,
51 timeout: Duration,
52 ) -> Result<Self, ChatError> {
53 Self::with_timeout_and_max_tokens(base_url, model, timeout, None)
54 }
55
56 pub fn with_max_tokens(
63 base_url: impl Into<String>,
64 model: impl Into<String>,
65 max_tokens: Option<u32>,
66 ) -> Result<Self, ChatError> {
67 Self::with_timeout_and_max_tokens(base_url, model, DEFAULT_CHAT_TIMEOUT, max_tokens)
68 }
69
70 pub fn with_timeout_and_max_tokens(
77 base_url: impl Into<String>,
78 model: impl Into<String>,
79 timeout: Duration,
80 max_tokens: Option<u32>,
81 ) -> Result<Self, ChatError> {
82 Self::with_timeout_max_tokens_and_auth(
83 base_url,
84 model,
85 timeout,
86 max_tokens,
87 ResolvedAuth::default(),
88 )
89 }
90
91 pub fn with_timeout_max_tokens_and_auth(
98 base_url: impl Into<String>,
99 model: impl Into<String>,
100 timeout: Duration,
101 max_tokens: Option<u32>,
102 auth: ResolvedAuth,
103 ) -> Result<Self, ChatError> {
104 Self::with_optional_timeout_max_tokens_and_auth(
105 base_url,
106 model,
107 Some(timeout),
108 max_tokens,
109 auth,
110 )
111 }
112
113 pub fn with_no_timeout_and_max_tokens(
120 base_url: impl Into<String>,
121 model: impl Into<String>,
122 max_tokens: Option<u32>,
123 ) -> Result<Self, ChatError> {
124 Self::with_optional_timeout_max_tokens_and_auth(
125 base_url,
126 model,
127 None,
128 max_tokens,
129 ResolvedAuth::default(),
130 )
131 }
132
133 fn with_optional_timeout_max_tokens_and_auth(
134 base_url: impl Into<String>,
135 model: impl Into<String>,
136 timeout: Option<Duration>,
137 max_tokens: Option<u32>,
138 auth: ResolvedAuth,
139 ) -> Result<Self, ChatError> {
140 let mut builder = HttpClient::builder();
141 if let Some(timeout) = timeout {
142 builder = builder.timeout(timeout);
143 }
144 let http = builder.build().map_err(|err| ChatError::Build {
145 message: redact(&err.to_string()),
146 })?;
147 Ok(Self {
148 base_url: base_url.into(),
149 model: model.into(),
150 max_tokens,
151 reasoning_effort: None,
152 chat_template_kwargs: None,
153 auth,
154 http,
155 })
156 }
157
158 #[must_use]
160 pub const fn with_reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
161 self.reasoning_effort = Some(effort);
162 self
163 }
164
165 #[must_use]
167 pub fn with_chat_template_kwargs(mut self, value: serde_json::Value) -> Self {
168 self.chat_template_kwargs = Some(value);
169 self
170 }
171
172 #[must_use]
174 pub fn model(&self) -> &str {
175 &self.model
176 }
177
178 #[must_use]
180 pub fn base_url(&self) -> &str {
181 &self.base_url
182 }
183
184 pub fn complete(
192 &self,
193 messages: Vec<ChatMessage>,
194 temperature: f32,
195 ) -> Result<String, ChatError> {
196 self.complete_raw(messages, temperature)
197 .map(|output| output.content)
198 }
199
200 pub fn complete_raw(
208 &self,
209 messages: Vec<ChatMessage>,
210 temperature: f32,
211 ) -> Result<ChatCompletionOutput, ChatError> {
212 let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
213 let body = ChatCompletionRequest {
214 model: self.model.clone(),
215 messages,
216 max_tokens: self.max_tokens,
217 reasoning_effort: self.reasoning_effort,
218 temperature,
219 chat_template_kwargs: self.chat_template_kwargs.clone(),
220 };
221
222 let mut request = self.http.post(&url).json(&body);
223 if let Some(key) = &self.auth.api_key {
224 request = request.bearer_auth(key);
225 }
226 for (name, value) in &self.auth.extra_headers {
227 request = request.header(name.as_str(), value.as_str());
228 }
229
230 let response = request.send().map_err(|err| ChatError::Http {
231 status: None,
232 message: redact(&err.to_string()),
233 timed_out: err.is_timeout(),
234 })?;
235
236 let status = response.status();
237 if !status.is_success() {
238 let snippet = response.text().unwrap_or_default();
239 return Err(ChatError::Http {
240 status: Some(status.as_u16()),
241 message: redact(&snippet),
242 timed_out: false,
243 });
244 }
245
246 let text = response.text().map_err(|_| ChatError::MalformedResponse)?;
247 let completion: ChatCompletionResponse =
248 serde_json::from_str(&text).map_err(|_| ChatError::MalformedResponse)?;
249 let message = completion
250 .choices
251 .first()
252 .map(|choice| &choice.message)
253 .ok_or(ChatError::MalformedResponse)?;
254 let content = message
255 .content
256 .clone()
257 .filter(|content| !content.trim().is_empty())
258 .ok_or(ChatError::MalformedResponse)?;
259 Ok(ChatCompletionOutput {
260 content,
261 reasoning_content: message.reasoning_content.clone(),
262 raw_response: text,
263 })
264 }
265}
266
/// Removes Markdown code-fence markers from `content` and isolates a JSON-style object.
///
/// Surrounding whitespace, leading ```` ```json ```` / ```` ``` ```` markers and trailing
/// ```` ``` ```` markers are trimmed away. If the remainder has a `{` that occurs before
/// its last `}`, only that brace-delimited span is returned; otherwise the trimmed
/// remainder is returned unchanged.
#[must_use]
pub fn strip_code_fences(content: &str) -> String {
    let unfenced = content.trim();
    let unfenced = unfenced.trim_start_matches("```json");
    let unfenced = unfenced.trim_start_matches("```");
    let unfenced = unfenced.trim_end_matches("```").trim();

    // Span from the first `{` through the last `}`, when they appear in that order.
    let object = unfenced.find('{').and_then(|open| {
        unfenced
            .rfind('}')
            .filter(|&close| close > open)
            .map(|close| &unfenced[open..=close])
    });

    object.unwrap_or(unfenced).to_owned()
}
283
284#[cfg(test)]
285mod tests;