//! tiny_loop/llm/openai.rs — OpenAI-compatible chat-completions provider.
1use crate::types::{FinishReason, LLMResponse, Message, ToolDefinition};
2use async_trait::async_trait;
3use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
4use serde::{Deserialize, Serialize};
5use serde_json::{Map, Value};
6
/// Callback invoked with each incremental content chunk while streaming
/// OpenAI responses. Boxed so the provider can own an arbitrary closure.
pub type OpenAIStreamCallback = Box<dyn FnMut(String) + Send + Sync>;
9
10/// Request payload for OpenAI chat completions API
11#[derive(Serialize)]
12struct ChatRequest {
13    /// Model ID
14    model: String,
15    /// Conversation messages
16    messages: Vec<Message>,
17    /// Available tools for the model
18    tools: Vec<ToolDefinition>,
19    /// Enable streaming
20    #[serde(skip_serializing_if = "Option::is_none")]
21    stream: Option<bool>,
22}
23
/// Response from OpenAI chat completions API (non-streaming)
#[derive(Deserialize)]
struct ChatResponse {
    /// List of completion choices; only the first is consumed by this provider
    choices: Vec<Choice>,
}
30
/// One server-sent-event payload from a streaming chat completion
#[derive(Deserialize)]
struct StreamChunk {
    /// Incremental choices; only the first is consumed by this provider
    choices: Vec<StreamChoice>,
}
36
/// Streaming choice
#[derive(Deserialize)]
struct StreamChoice {
    /// Incremental message fragment carried by this chunk
    delta: Delta,
    /// Present only on the final chunk of a choice; `None` otherwise
    #[serde(default)]
    finish_reason: Option<FinishReason>,
}
44
/// Delta content in streaming; every field is an optional fragment
#[derive(Deserialize)]
struct Delta {
    /// Text fragment to append to the accumulated content
    #[serde(default)]
    content: Option<String>,
    /// Tool-call data carried by this chunk
    #[serde(default)]
    tool_calls: Option<Vec<crate::types::ToolCall>>,
}
53
/// Single completion choice from the non-streaming API response
#[derive(Deserialize)]
struct Choice {
    /// Assistant's response message
    message: Message,
    /// Reason the completion finished (e.g. stop, tool call)
    finish_reason: FinishReason,
}
62
/// OpenAI-compatible LLM provider
///
/// Configured via builder-style methods; `call` retries failed requests
/// up to `max_retries` times.
///
/// # Examples
///
/// ```
/// use tiny_loop::llm::OpenAIProvider;
///
/// let provider = OpenAIProvider::new()
///     .api_key("sk-...")
///     .model("gpt-4o");
/// ```
pub struct OpenAIProvider {
    /// HTTP client for API requests
    client: reqwest::Client,
    /// API base URL (no trailing slash; `/chat/completions` is appended)
    base_url: String,
    /// API authentication key, sent as a `Bearer` token
    api_key: String,
    /// Model identifier
    model: String,
    /// Additional HTTP headers merged into every request
    custom_headers: HeaderMap,
    /// Maximum number of retries on failure (total attempts = retries + 1)
    max_retries: u32,
    /// Delay between retries in milliseconds
    retry_delay_ms: u64,
    /// Custom body fields merged into (and overriding) the request JSON
    custom_body: Map<String, Value>,
    /// Stream callback; when set, requests are made with `stream: true`
    stream_callback: Option<OpenAIStreamCallback>,
}
94
impl Default for OpenAIProvider {
    /// Equivalent to [`OpenAIProvider::new`].
    fn default() -> Self {
        Self::new()
    }
}
100
101impl OpenAIProvider {
102    /// Create a new OpenAI provider with default settings
103    ///
104    /// # Examples
105    ///
106    /// ```
107    /// use tiny_loop::llm::OpenAIProvider;
108    ///
109    /// let provider = OpenAIProvider::new();
110    /// ```
111    pub fn new() -> Self {
112        Self {
113            client: reqwest::Client::new(),
114            base_url: "https://api.openai.com/v1".into(),
115            api_key: "".into(),
116            model: "gpt-4o".into(),
117            custom_headers: HeaderMap::new(),
118            max_retries: 3,
119            retry_delay_ms: 1000,
120            custom_body: Map::new(),
121            stream_callback: None,
122        }
123    }
124
125    /// Set the base URL for the API endpoint (default: `https://api.openai.com/v1`)
126    ///
127    /// # Examples
128    ///
129    /// ```
130    /// use tiny_loop::llm::OpenAIProvider;
131    ///
132    /// let provider = OpenAIProvider::new()
133    ///     .base_url("https://api.custom.com/v1");
134    /// ```
135    pub fn base_url(mut self, value: impl Into<String>) -> Self {
136        self.base_url = value.into();
137        self
138    }
139
140    /// Set the API key for authentication (default: empty string)
141    ///
142    /// # Examples
143    ///
144    /// ```
145    /// use tiny_loop::llm::OpenAIProvider;
146    ///
147    /// let provider = OpenAIProvider::new()
148    ///     .api_key("sk-...");
149    /// ```
150    pub fn api_key(mut self, value: impl Into<String>) -> Self {
151        self.api_key = value.into();
152        self
153    }
154
155    /// Set the model name to use (default: `gpt-4o`)
156    ///
157    /// # Examples
158    ///
159    /// ```
160    /// use tiny_loop::llm::OpenAIProvider;
161    ///
162    /// let provider = OpenAIProvider::new()
163    ///     .model("gpt-4o-mini");
164    /// ```
165    pub fn model(mut self, value: impl Into<String>) -> Self {
166        self.model = value.into();
167        self
168    }
169
170    /// Add a custom HTTP header to requests
171    ///
172    /// # Examples
173    ///
174    /// ```
175    /// use tiny_loop::llm::OpenAIProvider;
176    ///
177    /// let provider = OpenAIProvider::new()
178    ///     .header("x-custom-header", "value")
179    ///     .unwrap();
180    /// ```
181    ///
182    /// # Errors
183    ///
184    /// Returns an error if the header name or value contains invalid characters.
185    pub fn header(
186        mut self,
187        key: impl Into<String>,
188        value: impl Into<String>,
189    ) -> crate::Result<Self> {
190        self.custom_headers.insert(
191            HeaderName::try_from(key.into())
192                .map_err(|e| crate::Error::InvalidHeader(e.to_string()))?,
193            HeaderValue::try_from(value.into())
194                .map_err(|e| crate::Error::InvalidHeader(e.to_string()))?,
195        );
196        Ok(self)
197    }
198
199    /// Set maximum number of retries on failure (default: 3)
200    ///
201    /// # Examples
202    ///
203    /// ```
204    /// use tiny_loop::llm::OpenAIProvider;
205    ///
206    /// let provider = OpenAIProvider::new()
207    ///     .max_retries(5);
208    /// ```
209    pub fn max_retries(mut self, retries: u32) -> Self {
210        self.max_retries = retries;
211        self
212    }
213
214    /// Set delay between retries in milliseconds (default: 1000)
215    ///
216    /// # Examples
217    ///
218    /// ```
219    /// use tiny_loop::llm::OpenAIProvider;
220    ///
221    /// let provider = OpenAIProvider::new()
222    ///     .retry_delay(2000);
223    /// ```
224    pub fn retry_delay(mut self, delay_ms: u64) -> Self {
225        self.retry_delay_ms = delay_ms;
226        self
227    }
228
229    /// Set custom body fields to merge into the request
230    ///
231    /// # Examples
232    ///
233    /// ```
234    /// use tiny_loop::llm::OpenAIProvider;
235    /// use serde_json::json;
236    ///
237    /// let provider = OpenAIProvider::new()
238    ///     .body(json!({
239    ///         "top_p": 0.9,
240    ///         "frequency_penalty": 0.5
241    ///     }))
242    ///     .unwrap();
243    /// ```
244    ///
245    /// # Errors
246    ///
247    /// Returns an error if the value is not a JSON object
248    pub fn body(mut self, body: Value) -> crate::Result<Self> {
249        self.custom_body = body.as_object().ok_or(crate::Error::InvalidBody)?.clone();
250        Ok(self)
251    }
252
253    /// Set stream callback for LLM responses
254    ///
255    /// # Examples
256    ///
257    /// ```
258    /// use tiny_loop::llm::OpenAIProvider;
259    ///
260    /// let provider = OpenAIProvider::new()
261    ///     .stream_callback(|chunk| print!("{}", chunk));
262    /// ```
263    pub fn stream_callback<F>(mut self, callback: F) -> Self
264    where
265        F: FnMut(String) + Send + Sync + 'static,
266    {
267        self.stream_callback = Some(Box::new(callback));
268        self
269    }
270}
271
272#[async_trait]
273impl super::LLMProvider for OpenAIProvider {
274    async fn call(
275        &mut self,
276        messages: &[Message],
277        tools: &[ToolDefinition],
278    ) -> crate::Result<LLMResponse> {
279        let mut attempt = 0;
280        loop {
281            attempt += 1;
282            tracing::debug!(
283                model = %self.model,
284                messages = messages.len(),
285                tools = tools.len(),
286                streaming = self.stream_callback.is_some(),
287                attempt = attempt,
288                max_retries = self.max_retries,
289                "Calling LLM API"
290            );
291
292            match self.call_once(messages, tools).await {
293                Ok(response) => return Ok(response),
294                Err(e) if attempt > self.max_retries => {
295                    tracing::debug!("Max retries exceeded");
296                    return Err(e);
297                }
298                Err(e) => {
299                    tracing::debug!("API call failed, retrying: {}", e);
300                    tokio::time::sleep(tokio::time::Duration::from_millis(self.retry_delay_ms))
301                        .await;
302                }
303            }
304        }
305    }
306}
307
308impl OpenAIProvider {
309    async fn call_once(
310        &mut self,
311        messages: &[Message],
312        tools: &[ToolDefinition],
313    ) -> crate::Result<LLMResponse> {
314        let request = ChatRequest {
315            model: self.model.clone(),
316            messages: messages.to_vec(),
317            tools: tools.to_vec(),
318            stream: if self.stream_callback.is_some() {
319                Some(true)
320            } else {
321                None
322            },
323        };
324
325        let mut body = serde_json::to_value(&request)?.as_object().unwrap().clone();
326        body.extend(self.custom_body.clone());
327
328        let response = self
329            .client
330            .post(format!("{}/chat/completions", self.base_url))
331            .header("Authorization", format!("Bearer {}", self.api_key))
332            .header("Content-Type", "application/json")
333            .headers(self.custom_headers.clone())
334            .json(&body)
335            .send()
336            .await?;
337
338        let status = response.status();
339        tracing::trace!("LLM API response status: {}", status);
340
341        if !status.is_success() {
342            let body = response.text().await?;
343            tracing::debug!("LLM API error: status={}, body={}", status, body);
344            return Err(crate::Error::ApiError {
345                status: status.as_u16(),
346                body,
347            });
348        }
349
350        if self.stream_callback.is_some() {
351            self.handle_stream(response).await
352        } else {
353            let body = response.text().await?;
354            let chat_response: ChatResponse = serde_json::from_str(&body).map_err(|e| {
355                crate::Error::Custom(format!("Failed to parse response: {}. Body: {}", e, body))
356            })?;
357            tracing::debug!("LLM API call completed successfully");
358            let choice = &chat_response.choices[0];
359            let Message::Assistant(msg) = &choice.message else {
360                return Err(crate::Error::UnexpectedMessage(format!(
361                    "{:?}",
362                    choice.message
363                )));
364            };
365            Ok(LLMResponse {
366                message: msg.clone(),
367                finish_reason: choice.finish_reason.clone(),
368            })
369        }
370    }
371
372    async fn handle_stream(&mut self, response: reqwest::Response) -> crate::Result<LLMResponse> {
373        use futures::TryStreamExt;
374
375        let mut stream = response.bytes_stream();
376        let mut buffer = String::new();
377        let mut content = String::new();
378        let mut tool_calls = Vec::new();
379        let mut finish_reason = FinishReason::Stop;
380
381        while let Some(chunk) = stream.try_next().await? {
382            buffer.push_str(&String::from_utf8_lossy(&chunk));
383
384            while let Some(line_end) = buffer.find('\n') {
385                let line = buffer[..line_end].trim().to_string();
386                buffer.drain(..=line_end);
387
388                if let Some(data) = line.strip_prefix("data: ") {
389                    if data == "[DONE]" {
390                        break;
391                    }
392
393                    if let Ok(chunk) = serde_json::from_str::<StreamChunk>(data) {
394                        if let Some(choice) = chunk.choices.first() {
395                            if let Some(delta_content) = &choice.delta.content {
396                                content.push_str(delta_content);
397                                if let Some(callback) = &mut self.stream_callback {
398                                    callback(delta_content.clone());
399                                }
400                            }
401
402                            if let Some(delta_tool_calls) = &choice.delta.tool_calls {
403                                tool_calls.extend(delta_tool_calls.clone());
404                            }
405
406                            if let Some(reason) = &choice.finish_reason {
407                                finish_reason = reason.clone();
408                            }
409                        }
410                    }
411                }
412            }
413        }
414
415        tracing::debug!("Streaming completed, total length: {}", content.len());
416        Ok(LLMResponse {
417            message: crate::types::AssistantMessage {
418                content,
419                tool_calls: if tool_calls.is_empty() {
420                    None
421                } else {
422                    Some(tool_calls)
423                },
424            },
425            finish_reason,
426        })
427    }
428}