use crate::types::{FinishReason, LLMResponse, Message, ToolDefinition};
use async_trait::async_trait;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};

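/// Callback invoked with each content delta received while streaming a response.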
pub type OpenAIStreamCallback = Box<dyn FnMut(String) + Send + Sync>;

#[derive(Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<Message>,
    // Omitted when empty: some OpenAI-compatible servers reject an empty tools array.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}

#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

#[derive(Deserialize)]
struct StreamChunk {
    choices: Vec<StreamChoice>,
}

#[derive(Deserialize)]
struct StreamChoice {
    delta: Delta,
    #[serde(default)]
    finish_reason: Option<FinishReason>,
}

#[derive(Deserialize)]
struct Delta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<crate::types::ToolCall>>,
}

#[derive(Deserialize)]
struct Choice {
    message: Message,
    finish_reason: FinishReason,
}

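/// OpenAI-compatible chat completions provider, configured in builder style.
///
/// A minimal usage sketch (the import path is assumed; adjust to this crate's layout):
///
/// ```ignore
/// let provider = OpenAIProvider::new()
///     .api_key(std::env::var("OPENAI_API_KEY")?)
///     .model("gpt-4o-mini")
///     .max_retries(5);
/// ```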
pub struct OpenAIProvider {
    client: reqwest::Client,
    base_url: String,
    api_key: String,
    model: String,
    custom_headers: HeaderMap,
    max_retries: u32,
    retry_delay_ms: u64,
    custom_body: Map<String, Value>,
    stream_callback: Option<OpenAIStreamCallback>,
}

impl Default for OpenAIProvider {
    fn default() -> Self {
        Self::new()
    }
}

impl OpenAIProvider {
    pub fn new() -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url: "https://api.openai.com/v1".into(),
            api_key: "".into(),
            model: "gpt-4o".into(),
            custom_headers: HeaderMap::new(),
            max_retries: 3,
            retry_delay_ms: 1000,
            custom_body: Map::new(),
            stream_callback: None,
        }
    }

    pub fn base_url(mut self, value: impl Into<String>) -> Self {
        self.base_url = value.into();
        self
    }

    pub fn api_key(mut self, value: impl Into<String>) -> Self {
        self.api_key = value.into();
        self
    }

    pub fn model(mut self, value: impl Into<String>) -> Self {
        self.model = value.into();
        self
    }

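    /// Adds a header sent with every request; fails if the name or value is
    /// not a valid HTTP header.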
    pub fn header(
        mut self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> anyhow::Result<Self> {
        self.custom_headers.insert(
            HeaderName::try_from(key.into())?,
            HeaderValue::try_from(value.into())?,
        );
        Ok(self)
    }

    pub fn max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    pub fn retry_delay(mut self, delay_ms: u64) -> Self {
        self.retry_delay_ms = delay_ms;
        self
    }

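    /// Merges extra JSON fields into every request body; keys given here
    /// override the generated `model`/`messages`/`tools`/`stream` fields.
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let provider = OpenAIProvider::new()
    ///     .body(serde_json::json!({ "temperature": 0.2 }))?;
    /// ```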
    pub fn body(mut self, body: Value) -> anyhow::Result<Self> {
        self.custom_body = body
            .as_object()
            .ok_or_else(|| anyhow::anyhow!("body must be a JSON object"))?
            .clone();
        Ok(self)
    }

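    /// Registers a callback invoked with each streamed content delta. Setting
    /// a callback switches requests to `stream: true`.
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let provider = OpenAIProvider::new()
    ///     .stream_callback(|delta| print!("{delta}"));
    /// ```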
    pub fn stream_callback<F>(mut self, callback: F) -> Self
    where
        F: FnMut(String) + Send + Sync + 'static,
    {
        self.stream_callback = Some(Box::new(callback));
        self
    }
}

#[async_trait]
impl super::LLMProvider for OpenAIProvider {
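    /// Calls the chat completions endpoint, retrying failed attempts up to
    /// `max_retries` times with a fixed `retry_delay_ms` pause between tries.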
    async fn call(
        &mut self,
        messages: &[Message],
        tools: &[ToolDefinition],
    ) -> anyhow::Result<LLMResponse> {
        let mut attempt = 0;
        loop {
            attempt += 1;
            tracing::debug!(
                model = %self.model,
                messages = messages.len(),
                tools = tools.len(),
                streaming = self.stream_callback.is_some(),
                attempt = attempt,
                max_retries = self.max_retries,
                "Calling LLM API"
            );

            match self.call_once(messages, tools).await {
                Ok(response) => return Ok(response),
                Err(e) if attempt > self.max_retries => {
                    tracing::debug!("Max retries exceeded");
                    return Err(e);
                }
                Err(e) => {
                    tracing::debug!("API call failed, retrying: {}", e);
                    tokio::time::sleep(tokio::time::Duration::from_millis(self.retry_delay_ms))
                        .await;
                }
            }
        }
    }
}

impl OpenAIProvider {
    async fn call_once(
        &mut self,
        messages: &[Message],
        tools: &[ToolDefinition],
    ) -> anyhow::Result<LLMResponse> {
        let request = ChatRequest {
            model: self.model.clone(),
            messages: messages.to_vec(),
            tools: tools.to_vec(),
            stream: if self.stream_callback.is_some() {
                Some(true)
            } else {
                None
            },
        };

        // ChatRequest always serializes to a JSON object, so this pattern cannot fail.
        let Value::Object(mut body) = serde_json::to_value(&request)? else {
            anyhow::bail!("request did not serialize to a JSON object");
        };
        // Custom body entries are merged last, so they override the generated fields.
        body.extend(self.custom_body.clone());

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .headers(self.custom_headers.clone())
            // `.json` also sets the `Content-Type: application/json` header.
            .json(&body)
            .send()
            .await?;

        let status = response.status();
        tracing::trace!("LLM API response status: {}", status);

        if !status.is_success() {
            let body = response.text().await?;
            tracing::debug!("LLM API error: status={}, body={}", status, body);
            anyhow::bail!("API error ({}): {}", status, body);
        }

        if self.stream_callback.is_some() {
            self.handle_stream(response).await
        } else {
            let body = response.text().await?;
            let chat_response: ChatResponse = serde_json::from_str(&body)
                .map_err(|e| anyhow::anyhow!("Failed to parse response: {}. Body: {}", e, body))?;
            tracing::debug!("LLM API call completed successfully");
            let choice = chat_response
                .choices
                .first()
                .ok_or_else(|| anyhow::anyhow!("Response contained no choices"))?;
            let Message::Assistant(msg) = &choice.message else {
                anyhow::bail!("Expected Assistant message, got: {:?}", choice.message);
            };
            Ok(LLMResponse {
                message: msg.clone(),
                finish_reason: choice.finish_reason.clone(),
            })
        }
    }

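    /// Parses the response as a server-sent-event stream, forwarding content
    /// deltas to the callback and accumulating them until `[DONE]`.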
    async fn handle_stream(&mut self, response: reqwest::Response) -> anyhow::Result<LLMResponse> {
        use futures::TryStreamExt;

        let mut stream = response.bytes_stream();
        let mut buffer = String::new();
        let mut content = String::new();
        let mut tool_calls = Vec::new();
        let mut finish_reason = FinishReason::Stop;

        while let Some(chunk) = stream.try_next().await? {
            buffer.push_str(&String::from_utf8_lossy(&chunk));

            // Process each complete line in the buffer; a trailing partial line
            // stays buffered until the next chunk arrives.
            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer.drain(..=line_end);

                if let Some(data) = line.strip_prefix("data: ") {
                    if data == "[DONE]" {
                        break;
                    }

                    if let Ok(chunk) = serde_json::from_str::<StreamChunk>(data) {
                        if let Some(choice) = chunk.choices.first() {
                            if let Some(delta_content) = &choice.delta.content {
                                content.push_str(delta_content);
                                if let Some(callback) = &mut self.stream_callback {
                                    callback(delta_content.clone());
                                }
                            }

                            // NOTE: assumes each delta carries complete tool calls;
                            // the API may stream tool-call arguments incrementally,
                            // which would require accumulating fragments by index.
                            if let Some(delta_tool_calls) = &choice.delta.tool_calls {
                                tool_calls.extend(delta_tool_calls.clone());
                            }

                            if let Some(reason) = &choice.finish_reason {
                                finish_reason = reason.clone();
                            }
                        }
                    }
                }
            }
        }

        tracing::debug!("Streaming completed, total length: {}", content.len());
        Ok(LLMResponse {
            message: crate::types::AssistantMessage {
                content,
                tool_calls: if tool_calls.is_empty() {
                    None
                } else {
                    Some(tool_calls)
                },
            },
            finish_reason,
        })
    }
}