use crate::types::{FinishReason, LLMResponse, Message, ToolDefinition};
use async_trait::async_trait;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};

/// Callback invoked with each content delta while a response is streamed.
pub type OpenAIStreamCallback = Box<dyn FnMut(String) + Send + Sync>;
10#[derive(Serialize)]
12struct ChatRequest {
13 model: String,
15 messages: Vec<Message>,
17 tools: Vec<ToolDefinition>,
19 #[serde(skip_serializing_if = "Option::is_none")]
21 stream: Option<bool>,
22}
23
24#[derive(Deserialize)]
26struct ChatResponse {
27 choices: Vec<Choice>,
29}
30
31#[derive(Deserialize)]
33struct StreamChunk {
34 choices: Vec<StreamChoice>,
35}
36
37#[derive(Deserialize)]
39struct StreamChoice {
40 delta: Delta,
41 #[serde(default)]
42 finish_reason: Option<FinishReason>,
43}
44
45#[derive(Deserialize)]
47struct Delta {
48 #[serde(default)]
49 content: Option<String>,
50 #[serde(default)]
51 tool_calls: Option<Vec<crate::types::ToolCall>>,
52}
53
54#[derive(Deserialize)]
56struct Choice {
57 message: Message,
59 finish_reason: FinishReason,
61}
62
63pub struct OpenAIProvider {
75 client: reqwest::Client,
77 base_url: String,
79 api_key: String,
81 model: String,
83 custom_headers: HeaderMap,
85 max_retries: u32,
87 retry_delay_ms: u64,
89 custom_body: Map<String, Value>,
91 stream_callback: Option<OpenAIStreamCallback>,
93}
94
95impl Default for OpenAIProvider {
96 fn default() -> Self {
97 Self::new()
98 }
99}
100
101impl OpenAIProvider {
102 pub fn new() -> Self {
112 Self {
113 client: reqwest::Client::new(),
114 base_url: "https://api.openai.com/v1".into(),
115 api_key: "".into(),
116 model: "gpt-4o".into(),
117 custom_headers: HeaderMap::new(),
118 max_retries: 3,
119 retry_delay_ms: 1000,
120 custom_body: Map::new(),
121 stream_callback: None,
122 }
123 }
124
125 pub fn base_url(mut self, value: impl Into<String>) -> Self {
136 self.base_url = value.into();
137 self
138 }
139
140 pub fn api_key(mut self, value: impl Into<String>) -> Self {
151 self.api_key = value.into();
152 self
153 }
154
155 pub fn model(mut self, value: impl Into<String>) -> Self {
166 self.model = value.into();
167 self
168 }
169
170 pub fn header(
186 mut self,
187 key: impl Into<String>,
188 value: impl Into<String>,
189 ) -> crate::Result<Self> {
190 self.custom_headers.insert(
191 HeaderName::try_from(key.into())
192 .map_err(|e| crate::Error::InvalidHeader(e.to_string()))?,
193 HeaderValue::try_from(value.into())
194 .map_err(|e| crate::Error::InvalidHeader(e.to_string()))?,
195 );
196 Ok(self)
197 }
198
199 pub fn max_retries(mut self, retries: u32) -> Self {
210 self.max_retries = retries;
211 self
212 }
213
214 pub fn retry_delay(mut self, delay_ms: u64) -> Self {
225 self.retry_delay_ms = delay_ms;
226 self
227 }
228
229 pub fn body(mut self, body: Value) -> crate::Result<Self> {
249 self.custom_body = body.as_object().ok_or(crate::Error::InvalidBody)?.clone();
250 Ok(self)
251 }
252
253 pub fn stream_callback<F>(mut self, callback: F) -> Self
264 where
265 F: FnMut(String) + Send + Sync + 'static,
266 {
267 self.stream_callback = Some(Box::new(callback));
268 self
269 }
270}
271
272#[async_trait]
273impl super::LLMProvider for OpenAIProvider {
274 async fn call(
275 &mut self,
276 messages: &[Message],
277 tools: &[ToolDefinition],
278 ) -> crate::Result<LLMResponse> {
279 let mut attempt = 0;
280 loop {
281 attempt += 1;
282 tracing::debug!(
283 model = %self.model,
284 messages = messages.len(),
285 tools = tools.len(),
286 streaming = self.stream_callback.is_some(),
287 attempt = attempt,
288 max_retries = self.max_retries,
289 "Calling LLM API"
290 );
291
292 match self.call_once(messages, tools).await {
293 Ok(response) => return Ok(response),
294 Err(e) if attempt > self.max_retries => {
295 tracing::debug!("Max retries exceeded");
296 return Err(e);
297 }
298 Err(e) => {
299 tracing::debug!("API call failed, retrying: {}", e);
300 tokio::time::sleep(tokio::time::Duration::from_millis(self.retry_delay_ms))
301 .await;
302 }
303 }
304 }
305 }
306}
307
308impl OpenAIProvider {
309 async fn call_once(
310 &mut self,
311 messages: &[Message],
312 tools: &[ToolDefinition],
313 ) -> crate::Result<LLMResponse> {
314 let request = ChatRequest {
315 model: self.model.clone(),
316 messages: messages.to_vec(),
317 tools: tools.to_vec(),
318 stream: if self.stream_callback.is_some() {
319 Some(true)
320 } else {
321 None
322 },
323 };
324
325 let mut body = serde_json::to_value(&request)?.as_object().unwrap().clone();
326 body.extend(self.custom_body.clone());
327
328 let response = self
329 .client
330 .post(format!("{}/chat/completions", self.base_url))
331 .header("Authorization", format!("Bearer {}", self.api_key))
332 .header("Content-Type", "application/json")
333 .headers(self.custom_headers.clone())
334 .json(&body)
335 .send()
336 .await?;
337
338 let status = response.status();
339 tracing::trace!("LLM API response status: {}", status);
340
341 if !status.is_success() {
342 let body = response.text().await?;
343 tracing::debug!("LLM API error: status={}, body={}", status, body);
344 return Err(crate::Error::ApiError {
345 status: status.as_u16(),
346 body,
347 });
348 }
349
350 if self.stream_callback.is_some() {
351 self.handle_stream(response).await
352 } else {
353 let body = response.text().await?;
354 let chat_response: ChatResponse = serde_json::from_str(&body).map_err(|e| {
355 crate::Error::Custom(format!("Failed to parse response: {}. Body: {}", e, body))
356 })?;
357 tracing::debug!("LLM API call completed successfully");
358 let choice = &chat_response.choices[0];
359 let Message::Assistant(msg) = &choice.message else {
360 return Err(crate::Error::UnexpectedMessage(format!(
361 "{:?}",
362 choice.message
363 )));
364 };
365 Ok(LLMResponse {
366 message: msg.clone(),
367 finish_reason: choice.finish_reason.clone(),
368 })
369 }
370 }
371
372 async fn handle_stream(&mut self, response: reqwest::Response) -> crate::Result<LLMResponse> {
373 use futures::TryStreamExt;
374
375 let mut stream = response.bytes_stream();
376 let mut buffer = String::new();
377 let mut content = String::new();
378 let mut tool_calls = Vec::new();
379 let mut finish_reason = FinishReason::Stop;
380
381 while let Some(chunk) = stream.try_next().await? {
382 buffer.push_str(&String::from_utf8_lossy(&chunk));
383
384 while let Some(line_end) = buffer.find('\n') {
385 let line = buffer[..line_end].trim().to_string();
386 buffer.drain(..=line_end);
387
388 if let Some(data) = line.strip_prefix("data: ") {
389 if data == "[DONE]" {
390 break;
391 }
392
393 if let Ok(chunk) = serde_json::from_str::<StreamChunk>(data) {
394 if let Some(choice) = chunk.choices.first() {
395 if let Some(delta_content) = &choice.delta.content {
396 content.push_str(delta_content);
397 if let Some(callback) = &mut self.stream_callback {
398 callback(delta_content.clone());
399 }
400 }
401
402 if let Some(delta_tool_calls) = &choice.delta.tool_calls {
403 tool_calls.extend(delta_tool_calls.clone());
404 }
405
406 if let Some(reason) = &choice.finish_reason {
407 finish_reason = reason.clone();
408 }
409 }
410 }
411 }
412 }
413 }
414
415 tracing::debug!("Streaming completed, total length: {}", content.len());
416 Ok(LLMResponse {
417 message: crate::types::AssistantMessage {
418 content,
419 tool_calls: if tool_calls.is_empty() {
420 None
421 } else {
422 Some(tool_calls)
423 },
424 },
425 finish_reason,
426 })
427 }
428}