use crate::types::{Message, ToolDefinition};
use async_trait::async_trait;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};

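/// Request body for the chat completions endpoint.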
#[derive(Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<Message>,
    // Some OpenAI-compatible servers reject an empty `tools` array, so omit it.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}

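/// Non-streaming response payload; only the fields we read are modeled.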
#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

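/// One server-sent-events chunk of a streaming response.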
#[derive(Deserialize)]
struct StreamChunk {
    choices: Vec<StreamChoice>,
}

#[derive(Deserialize)]
struct StreamChoice {
    delta: Delta,
}

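/// Incremental piece of the assistant message carried by a stream chunk.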
#[derive(Deserialize)]
struct Delta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<crate::types::ToolCall>>,
}

#[derive(Deserialize)]
struct Choice {
    message: Message,
}

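/// Provider for the OpenAI chat completions API and compatible servers,
/// configured through the builder methods below.
///
/// A minimal usage sketch (the key and model are placeholders):
///
/// ```ignore
/// let provider = OpenAIProvider::new()
///     .api_key("sk-...")
///     .model("gpt-4o-mini")
///     .max_retries(5);
/// ```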
pub struct OpenAIProvider {
    client: reqwest::Client,
    /// API root, e.g. `https://api.openai.com/v1`.
    base_url: String,
    api_key: String,
    model: String,
    temperature: Option<f32>,
    max_tokens: Option<u32>,
    /// Extra headers attached to every request.
    custom_headers: HeaderMap,
    /// Retries allowed after the first failed attempt.
    max_retries: u32,
    /// Fixed delay between attempts.
    retry_delay_ms: u64,
    /// Extra top-level fields merged into every request body.
    custom_body: Map<String, Value>,
}

impl Default for OpenAIProvider {
    fn default() -> Self {
        Self::new()
    }
}

impl OpenAIProvider {
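    /// Creates a provider with defaults: the public OpenAI endpoint, the
    /// "gpt-4o" model, three retries, and a one-second retry delay. An API
    /// key must be set via [`Self::api_key`] before calling the API.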
    pub fn new() -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url: "https://api.openai.com/v1".into(),
            api_key: "".into(),
            model: "gpt-4o".into(),
            temperature: None,
            max_tokens: None,
            custom_headers: HeaderMap::new(),
            max_retries: 3,
            retry_delay_ms: 1000,
            custom_body: Map::new(),
        }
    }

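    /// Overrides the API base URL, e.g. to target a proxy or an
    /// OpenAI-compatible server.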
    pub fn base_url(mut self, value: impl Into<String>) -> Self {
        self.base_url = value.into();
        self
    }

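    /// Sets the bearer token used for authentication.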
    pub fn api_key(mut self, value: impl Into<String>) -> Self {
        self.api_key = value.into();
        self
    }

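    /// Sets the model name sent with each request.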
    pub fn model(mut self, value: impl Into<String>) -> Self {
        self.model = value.into();
        self
    }

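    /// Sets the sampling temperature; pass `None` to use the server default.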
    pub fn temperature(mut self, value: impl Into<Option<f32>>) -> Self {
        self.temperature = value.into();
        self
    }

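    /// Caps the number of tokens the model may generate; `None` leaves the
    /// limit to the server.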
    pub fn max_tokens(mut self, value: impl Into<Option<u32>>) -> Self {
        self.max_tokens = value.into();
        self
    }

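    /// Adds a custom header to every request. Fails if the name or value is
    /// not a valid HTTP header.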
    pub fn header(
        mut self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> anyhow::Result<Self> {
        self.custom_headers.insert(
            HeaderName::try_from(key.into())?,
            HeaderValue::try_from(value.into())?,
        );
        Ok(self)
    }

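    /// Sets how many times a failed call is retried, in addition to the
    /// initial attempt.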
    pub fn max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

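    /// Sets the fixed delay, in milliseconds, between retry attempts.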
    pub fn retry_delay(mut self, delay_ms: u64) -> Self {
        self.retry_delay_ms = delay_ms;
        self
    }

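    /// Merges extra top-level fields into every request body; the value must
    /// be a JSON object. Keys that collide with generated fields (such as
    /// `model`) override them.
    ///
    /// A usage sketch (the field shown is illustrative):
    ///
    /// ```ignore
    /// let provider = OpenAIProvider::new()
    ///     .body(serde_json::json!({ "top_p": 0.9 }))?;
    /// ```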
    pub fn body(mut self, body: Value) -> anyhow::Result<Self> {
        self.custom_body = body
            .as_object()
            .ok_or_else(|| anyhow::anyhow!("body must be a JSON object"))?
            .clone();
        Ok(self)
    }
}

#[async_trait]
impl super::LLMProvider for OpenAIProvider {
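    /// Calls the API, retrying on failure up to `max_retries` times with a
    /// fixed delay between attempts.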
    async fn call(
        &self,
        messages: &[Message],
        tools: &[ToolDefinition],
        mut stream_callback: Option<&mut super::StreamCallback>,
    ) -> anyhow::Result<Message> {
        let mut attempt = 0;
        loop {
            attempt += 1;
            tracing::debug!(
                model = %self.model,
                messages = messages.len(),
                tools = tools.len(),
                streaming = stream_callback.is_some(),
                attempt = attempt,
                max_retries = self.max_retries,
                "Calling LLM API"
            );

            // Reborrow the callback so it remains available for the next attempt.
            match self
                .call_once(messages, tools, stream_callback.as_deref_mut())
                .await
            {
                Ok(message) => return Ok(message),
                Err(e) if attempt > self.max_retries => {
                    tracing::debug!("Max retries exceeded");
                    return Err(e);
                }
                Err(e) => {
                    tracing::debug!("API call failed, retrying: {}", e);
                    tokio::time::sleep(tokio::time::Duration::from_millis(self.retry_delay_ms))
                        .await;
                }
            }
        }
    }
}

impl OpenAIProvider {
    async fn call_once(
        &self,
        messages: &[Message],
        tools: &[ToolDefinition],
        stream_callback: Option<&mut super::StreamCallback>,
    ) -> anyhow::Result<Message> {
        let request = ChatRequest {
            model: self.model.clone(),
            messages: messages.to_vec(),
            tools: tools.to_vec(),
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            stream: if stream_callback.is_some() {
                Some(true)
            } else {
                None
            },
        };

        // A struct with named fields always serializes to a JSON object, so
        // this unwrap cannot fail; custom fields then overwrite generated ones.
        let mut body = serde_json::to_value(&request)?.as_object().unwrap().clone();
        body.extend(self.custom_body.clone());

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .headers(self.custom_headers.clone())
            .json(&body)
            .send()
            .await?;

        let status = response.status();
        tracing::trace!("LLM API response status: {}", status);

        if !status.is_success() {
            let body = response.text().await?;
            tracing::debug!("LLM API error: status={}, body={}", status, body);
            anyhow::bail!("API error ({}): {}", status, body);
        }

        if let Some(callback) = stream_callback {
            self.handle_stream(response, callback).await
        } else {
            let body = response.text().await?;
            let chat_response: ChatResponse = serde_json::from_str(&body)
                .map_err(|e| anyhow::anyhow!("Failed to parse response: {}. Body: {}", e, body))?;
            let choice = chat_response
                .choices
                .first()
                .ok_or_else(|| anyhow::anyhow!("API response contained no choices"))?;
            tracing::debug!("LLM API call completed successfully");
            Ok(choice.message.clone())
        }
    }

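    /// Consumes a server-sent-events stream, forwarding content deltas to the
    /// callback and accumulating the full assistant message.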
    async fn handle_stream(
        &self,
        response: reqwest::Response,
        callback: &mut super::StreamCallback,
    ) -> anyhow::Result<Message> {
        use futures::TryStreamExt;

        let mut stream = response.bytes_stream();
        let mut buffer = String::new();
        let mut content = String::new();
        let mut tool_calls = Vec::new();

        while let Some(chunk) = stream.try_next().await? {
            buffer.push_str(&String::from_utf8_lossy(&chunk));

            // Process every complete line currently in the buffer; a partial
            // line stays buffered until the next chunk completes it.
            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer.drain(..=line_end);

                if let Some(data) = line.strip_prefix("data: ") {
                    // The server closes the stream after [DONE], so breaking
                    // out of the line loop is sufficient.
                    if data == "[DONE]" {
                        break;
                    }

                    // Payloads that fail to parse (e.g. keep-alives) are skipped.
                    if let Ok(chunk) = serde_json::from_str::<StreamChunk>(data) {
                        if let Some(choice) = chunk.choices.first() {
                            if let Some(delta_content) = &choice.delta.content {
                                content.push_str(delta_content);
                                callback(delta_content.clone());
                            }

                            if let Some(delta_tool_calls) = &choice.delta.tool_calls {
                                tool_calls.extend(delta_tool_calls.clone());
                            }
                        }
                    }
                }
            }
        }

        tracing::debug!("Streaming completed, total length: {}", content.len());
        Ok(Message::Assistant {
            content,
            tool_calls: if tool_calls.is_empty() {
                None
            } else {
                Some(tool_calls)
            },
        })
    }
}