tiny_loop/llm/openai.rs

use crate::types::{Message, ToolDefinition};
use async_trait::async_trait;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use serde::{Deserialize, Serialize};

/// Request payload for the OpenAI chat completions API
#[derive(Serialize)]
struct ChatRequest {
    /// Model ID
    model: String,
    /// Conversation messages
    messages: Vec<Message>,
    /// Available tools for the model (omitted when empty, since the API
    /// rejects an empty `tools` array)
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// Sampling temperature (0-2)
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Maximum tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Enable streaming
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}

/// Response from the OpenAI chat completions API
#[derive(Deserialize)]
struct ChatResponse {
    /// List of completion choices
    choices: Vec<Choice>,
}

/// Streaming response chunk
#[derive(Deserialize)]
struct StreamChunk {
    /// Choices carried by this chunk
    choices: Vec<StreamChoice>,
}

/// Streaming choice
#[derive(Deserialize)]
struct StreamChoice {
    /// Incremental update for this choice
    delta: Delta,
}

/// Incremental content carried by one streaming chunk
#[derive(Deserialize)]
struct Delta {
    /// Text fragment appended to the assistant's reply, if any
    #[serde(default)]
    content: Option<String>,
    /// Tool call fragments emitted in this chunk, if any
    #[serde(default)]
    tool_calls: Option<Vec<crate::types::ToolCall>>,
}

/// Single completion choice from the API response
#[derive(Deserialize)]
struct Choice {
    /// Assistant's response message
    message: Message,
}

/// OpenAI-compatible LLM provider
///
/// # Examples
///
/// ```
/// use tiny_loop::llm::OpenAIProvider;
///
/// let provider = OpenAIProvider::new()
///     .api_key("sk-...")
///     .model("gpt-4o")
///     .temperature(0.7);
/// ```
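///
/// Streaming via the [`LLMProvider`](super::LLMProvider) trait. This is a
/// sketch rather than a tested doctest: it assumes `StreamCallback` is a
/// `dyn FnMut(String)`-style alias, that the trait is reachable at
/// `tiny_loop::llm::LLMProvider`, and that `messages` holds
/// `crate::types::Message` values built elsewhere, inside an async context.
///
/// ```ignore
/// use tiny_loop::llm::{LLMProvider, OpenAIProvider};
///
/// let provider = OpenAIProvider::new().api_key("sk-...");
///
/// // Print each text delta as it arrives.
/// let mut on_delta = |delta: String| print!("{delta}");
/// let reply = provider.call(&messages, &[], Some(&mut on_delta)).await?;
/// ```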
pub struct OpenAIProvider {
    /// HTTP client for API requests
    client: reqwest::Client,
    /// API base URL
    base_url: String,
    /// API authentication key
    api_key: String,
    /// Model identifier
    model: String,
    /// Sampling temperature
    temperature: Option<f32>,
    /// Maximum tokens to generate
    max_tokens: Option<u32>,
    /// Additional HTTP headers
    custom_headers: HeaderMap,
    /// Maximum number of retries on failure
    max_retries: u32,
    /// Delay between retries in milliseconds
    retry_delay_ms: u64,
}

impl Default for OpenAIProvider {
    fn default() -> Self {
        Self::new()
    }
}

impl OpenAIProvider {
    /// Create a new OpenAI provider with default settings
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new();
    /// ```
    pub fn new() -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url: "https://api.openai.com/v1".into(),
            api_key: "".into(),
            model: "gpt-4o".into(),
            temperature: None,
            max_tokens: None,
            custom_headers: HeaderMap::new(),
            max_retries: 3,
            retry_delay_ms: 1000,
        }
    }

    /// Set the base URL for the API endpoint (default: `https://api.openai.com/v1`)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .base_url("https://api.custom.com/v1");
    /// ```
    pub fn base_url(mut self, value: impl Into<String>) -> Self {
        self.base_url = value.into();
        self
    }

    /// Set the API key for authentication (default: empty string)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .api_key("sk-...");
    /// ```
    pub fn api_key(mut self, value: impl Into<String>) -> Self {
        self.api_key = value.into();
        self
    }

    /// Set the model name to use (default: `gpt-4o`)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .model("gpt-4o-mini");
    /// ```
    pub fn model(mut self, value: impl Into<String>) -> Self {
        self.model = value.into();
        self
    }

    /// Set the temperature for response randomness (default: `None`)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .temperature(0.7);
    /// ```
    pub fn temperature(mut self, value: impl Into<Option<f32>>) -> Self {
        self.temperature = value.into();
        self
    }

    /// Set the maximum number of tokens to generate (default: `None`)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .max_tokens(1000);
    /// ```
    pub fn max_tokens(mut self, value: impl Into<Option<u32>>) -> Self {
        self.max_tokens = value.into();
        self
    }

    /// Add a custom HTTP header to requests
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .header("x-custom-header", "value");
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if the header name or value contains invalid characters.
    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.custom_headers.insert(
            HeaderName::try_from(key.into()).unwrap(),
            HeaderValue::try_from(value.into()).unwrap(),
        );
        self
    }

    /// Set the maximum number of retries on failure (default: 3)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .max_retries(5);
    /// ```
    pub fn max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    /// Set the delay between retries in milliseconds (default: 1000)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .retry_delay(2000);
    /// ```
    pub fn retry_delay(mut self, delay_ms: u64) -> Self {
        self.retry_delay_ms = delay_ms;
        self
    }
}

#[async_trait]
impl super::LLMProvider for OpenAIProvider {
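    /// Call the chat completions endpoint, retrying failed requests.
    ///
    /// A failed attempt is retried after `retry_delay_ms` milliseconds,
    /// up to `max_retries` retries, so at most `max_retries + 1` requests
    /// are made in total.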
    async fn call(
        &self,
        messages: &[Message],
        tools: &[ToolDefinition],
        mut stream_callback: Option<&mut super::StreamCallback>,
    ) -> anyhow::Result<Message> {
        let mut attempt = 0;
        loop {
            attempt += 1;
            tracing::debug!(
                model = %self.model,
                messages = messages.len(),
                tools = tools.len(),
                streaming = stream_callback.is_some(),
                attempt = attempt,
                max_retries = self.max_retries,
                "Calling LLM API"
            );

            match self
                .call_once(messages, tools, stream_callback.as_deref_mut())
                .await
            {
                Ok(message) => return Ok(message),
                Err(e) if attempt > self.max_retries => {
                    tracing::debug!("Max retries exceeded");
                    return Err(e);
                }
                Err(e) => {
                    tracing::debug!("API call failed, retrying: {}", e);
                    tokio::time::sleep(tokio::time::Duration::from_millis(self.retry_delay_ms))
                        .await;
                }
            }
        }
    }
}

impl OpenAIProvider {
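    /// Perform a single request against `{base_url}/chat/completions`.
    ///
    /// When a stream callback is supplied the request sets `stream: true`
    /// and the response body is consumed incrementally by `handle_stream`;
    /// otherwise the full JSON body is parsed in one pass.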
    async fn call_once(
        &self,
        messages: &[Message],
        tools: &[ToolDefinition],
        stream_callback: Option<&mut super::StreamCallback>,
    ) -> anyhow::Result<Message> {
        let request = ChatRequest {
            model: self.model.clone(),
            messages: messages.to_vec(),
            tools: tools.to_vec(),
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            stream: stream_callback.is_some().then_some(true),
        };

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .headers(self.custom_headers.clone())
            .json(&request)
            .send()
            .await?;

        let status = response.status();
        tracing::trace!("LLM API response status: {}", status);

        if !status.is_success() {
            let body = response.text().await?;
            tracing::debug!("LLM API error: status={}, body={}", status, body);
            anyhow::bail!("API error ({}): {}", status, body);
        }

        if let Some(callback) = stream_callback {
            self.handle_stream(response, callback).await
        } else {
            let body = response.text().await?;
            let chat_response: ChatResponse = serde_json::from_str(&body)
                .map_err(|e| anyhow::anyhow!("Failed to parse response: {}. Body: {}", e, body))?;
            tracing::debug!("LLM API call completed successfully");
            chat_response
                .choices
                .into_iter()
                .next()
                .map(|choice| choice.message)
                .ok_or_else(|| anyhow::anyhow!("API response contained no choices"))
        }
    }

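    /// Consume a server-sent-events (SSE) response body, forwarding each
    /// text delta to `callback` and accumulating the complete assistant
    /// message (content plus any tool calls) to return when the stream ends.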
    async fn handle_stream(
        &self,
        response: reqwest::Response,
        callback: &mut super::StreamCallback,
    ) -> anyhow::Result<Message> {
        use futures::TryStreamExt;

        let mut stream = response.bytes_stream();
        // Partial SSE line carried over between network reads
        let mut buffer = String::new();
        // Accumulated assistant text and tool calls for the final message
        let mut content = String::new();
        let mut tool_calls = Vec::new();

        while let Some(chunk) = stream.try_next().await? {
            buffer.push_str(&String::from_utf8_lossy(&chunk));

            // Drain every complete line in the buffer; SSE events are
            // newline-delimited and payload lines are prefixed "data: ".
            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer.drain(..=line_end);

                if let Some(data) = line.strip_prefix("data: ") {
                    // "[DONE]" marks the end of the stream; the server
                    // closes the connection afterwards.
                    if data == "[DONE]" {
                        break;
                    }

                    if let Ok(chunk) = serde_json::from_str::<StreamChunk>(data) {
                        if let Some(choice) = chunk.choices.first() {
                            if let Some(delta_content) = &choice.delta.content {
                                content.push_str(delta_content);
                                callback(delta_content.clone());
                            }

                            if let Some(delta_tool_calls) = &choice.delta.tool_calls {
                                tool_calls.extend(delta_tool_calls.clone());
                            }
                        }
                    }
                }
            }
        }

        tracing::debug!("Streaming completed, total length: {}", content.len());
        Ok(Message::Assistant {
            content,
            tool_calls: if tool_calls.is_empty() {
                None
            } else {
                Some(tool_calls)
            },
        })
    }
}