tiny_loop/llm/openai.rs

use crate::types::{Message, ToolDefinition};
use async_trait::async_trait;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};

/// Request payload for OpenAI chat completions API
#[derive(Serialize)]
struct ChatRequest {
    /// Model ID
    model: String,
    /// Conversation messages
    messages: Vec<Message>,
    /// Available tools for the model
    tools: Vec<ToolDefinition>,
    /// Sampling temperature (0-2)
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Maximum tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Enable streaming
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}

/// Response from OpenAI chat completions API
#[derive(Deserialize)]
struct ChatResponse {
    /// List of completion choices
    choices: Vec<Choice>,
}

/// Streaming response chunk
#[derive(Deserialize)]
struct StreamChunk {
    choices: Vec<StreamChoice>,
}

/// Streaming choice
#[derive(Deserialize)]
struct StreamChoice {
    delta: Delta,
}

/// Delta content in a streaming response
#[derive(Deserialize)]
struct Delta {
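    /// Incremental text content for this chunk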
    #[serde(default)]
    content: Option<String>,
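    /// Incremental tool-call data for this chunk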
    #[serde(default)]
    tool_calls: Option<Vec<crate::types::ToolCall>>,
}

/// Single completion choice from the API response
#[derive(Deserialize)]
struct Choice {
    /// Assistant's response message
    message: Message,
}

/// OpenAI-compatible LLM provider
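///
/// Works with any OpenAI-compatible endpoint; use [`Self::base_url`] to
/// target a different server.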
///
/// # Examples
///
/// ```
/// use tiny_loop::llm::OpenAIProvider;
///
/// let provider = OpenAIProvider::new()
///     .api_key("sk-...")
///     .model("gpt-4o")
///     .temperature(0.7);
/// ```
pub struct OpenAIProvider {
    /// HTTP client for API requests
    client: reqwest::Client,
    /// API base URL
    base_url: String,
    /// API authentication key
    api_key: String,
    /// Model identifier
    model: String,
    /// Sampling temperature
    temperature: Option<f32>,
    /// Maximum tokens to generate
    max_tokens: Option<u32>,
    /// Additional HTTP headers
    custom_headers: HeaderMap,
    /// Maximum number of retries on failure
    max_retries: u32,
    /// Delay between retries in milliseconds
    retry_delay_ms: u64,
    /// Custom body fields to merge into the request
    custom_body: Map<String, Value>,
}

impl Default for OpenAIProvider {
    fn default() -> Self {
        Self::new()
    }
}

impl OpenAIProvider {
    /// Create a new OpenAI provider with default settings
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new();
    /// ```
    pub fn new() -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url: "https://api.openai.com/v1".into(),
            api_key: "".into(),
            model: "gpt-4o".into(),
            temperature: None,
            max_tokens: None,
            custom_headers: HeaderMap::new(),
            max_retries: 3,
            retry_delay_ms: 1000,
            custom_body: Map::new(),
        }
    }

    /// Set the base URL for the API endpoint (default: `https://api.openai.com/v1`)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .base_url("https://api.custom.com/v1");
    /// ```
    pub fn base_url(mut self, value: impl Into<String>) -> Self {
        self.base_url = value.into();
        self
    }

    /// Set the API key for authentication (default: empty string)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .api_key("sk-...");
    /// ```
    pub fn api_key(mut self, value: impl Into<String>) -> Self {
        self.api_key = value.into();
        self
    }

    /// Set the model name to use (default: `gpt-4o`)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .model("gpt-4o-mini");
    /// ```
    pub fn model(mut self, value: impl Into<String>) -> Self {
        self.model = value.into();
        self
    }

    /// Set the temperature for response randomness (default: `None`)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .temperature(0.7);
    /// ```
    pub fn temperature(mut self, value: impl Into<Option<f32>>) -> Self {
        self.temperature = value.into();
        self
    }

    /// Set the maximum number of tokens to generate (default: `None`)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .max_tokens(1000);
    /// ```
    pub fn max_tokens(mut self, value: impl Into<Option<u32>>) -> Self {
        self.max_tokens = value.into();
        self
    }

    /// Add a custom HTTP header to requests
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .header("x-custom-header", "value")
    ///     .unwrap();
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error if the header name or value contains invalid characters.
    pub fn header(
        mut self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> anyhow::Result<Self> {
        self.custom_headers.insert(
            HeaderName::try_from(key.into())?,
            HeaderValue::try_from(value.into())?,
        );
        Ok(self)
    }

    /// Set maximum number of retries on failure (default: 3)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .max_retries(5);
    /// ```
    pub fn max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    /// Set delay between retries in milliseconds (default: 1000)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .retry_delay(2000);
    /// ```
    pub fn retry_delay(mut self, delay_ms: u64) -> Self {
        self.retry_delay_ms = delay_ms;
        self
    }

    /// Set custom body fields to merge into the request
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    /// use serde_json::json;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .body(json!({
    ///         "top_p": 0.9,
    ///         "frequency_penalty": 0.5
    ///     }))
    ///     .unwrap();
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error if the value is not a JSON object.
    pub fn body(mut self, body: Value) -> anyhow::Result<Self> {
        self.custom_body = body
            .as_object()
            .ok_or_else(|| anyhow::anyhow!("body must be a JSON object"))?
            .clone();
        Ok(self)
    }
}

#[async_trait]
impl super::LLMProvider for OpenAIProvider {
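    /// Send the conversation to the API, retrying failed calls up to
    /// `max_retries` times with a `retry_delay_ms` pause between attempts.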
    async fn call(
        &self,
        messages: &[Message],
        tools: &[ToolDefinition],
        mut stream_callback: Option<&mut super::StreamCallback>,
    ) -> anyhow::Result<Message> {
        let mut attempt = 0;
        loop {
            attempt += 1;
            tracing::debug!(
                model = %self.model,
                messages = messages.len(),
                tools = tools.len(),
                streaming = stream_callback.is_some(),
                attempt = attempt,
                max_retries = self.max_retries,
                "Calling LLM API"
            );

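            // `as_deref_mut` reborrows the callback so the same `&mut` can be
            // passed to every retry attempt.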
            match self
                .call_once(messages, tools, stream_callback.as_deref_mut())
                .await
            {
                Ok(message) => return Ok(message),
                Err(e) if attempt > self.max_retries => {
                    tracing::debug!("Max retries exceeded");
                    return Err(e);
                }
                Err(e) => {
                    tracing::debug!("API call failed, retrying: {}", e);
                    tokio::time::sleep(tokio::time::Duration::from_millis(self.retry_delay_ms))
                        .await;
                }
            }
        }
    }
}

impl OpenAIProvider {
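    /// Perform a single, non-retried request to the chat completions endpoint.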
    async fn call_once(
        &self,
        messages: &[Message],
        tools: &[ToolDefinition],
        stream_callback: Option<&mut super::StreamCallback>,
    ) -> anyhow::Result<Message> {
        let request = ChatRequest {
            model: self.model.clone(),
            messages: messages.to_vec(),
            tools: tools.to_vec(),
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            stream: if stream_callback.is_some() {
                Some(true)
            } else {
                None
            },
        };

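        // `ChatRequest` always serializes to a JSON object, so `unwrap` cannot
        // fail here. Custom body fields are merged last and override request
        // fields with the same key.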
        let mut body = serde_json::to_value(&request)?.as_object().unwrap().clone();
        body.extend(self.custom_body.clone());

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .headers(self.custom_headers.clone())
            .json(&body)
            .send()
            .await?;

        let status = response.status();
        tracing::trace!("LLM API response status: {}", status);

        if !status.is_success() {
            let body = response.text().await?;
            tracing::debug!("LLM API error: status={}, body={}", status, body);
            anyhow::bail!("API error ({}): {}", status, body);
        }

        if let Some(callback) = stream_callback {
            self.handle_stream(response, callback).await
        } else {
            let body = response.text().await?;
            let chat_response: ChatResponse = serde_json::from_str(&body)
                .map_err(|e| anyhow::anyhow!("Failed to parse response: {}. Body: {}", e, body))?;
            tracing::debug!("LLM API call completed successfully");
            chat_response
                .choices
                .into_iter()
                .next()
                .map(|c| c.message)
                .ok_or_else(|| anyhow::anyhow!("API response contained no choices"))
        }
    }

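    /// Consume a server-sent-events stream, forwarding content deltas to the
    /// callback and accumulating the full assistant message.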
    async fn handle_stream(
        &self,
        response: reqwest::Response,
        callback: &mut super::StreamCallback,
    ) -> anyhow::Result<Message> {
        use futures::TryStreamExt;

        let mut stream = response.bytes_stream();
        let mut buffer = String::new();
        let mut content = String::new();
        let mut tool_calls = Vec::new();

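        // Network chunks can split SSE lines arbitrarily, so bytes are
        // buffered until a complete newline-terminated line is available.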
        while let Some(chunk) = stream.try_next().await? {
            buffer.push_str(&String::from_utf8_lossy(&chunk));

            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer.drain(..=line_end);

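                // SSE payload lines are prefixed with `data: `; the server
                // marks the end of the stream with a literal `[DONE]` payload.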
                if let Some(data) = line.strip_prefix("data: ") {
                    if data == "[DONE]" {
                        break;
                    }

                    if let Ok(chunk) = serde_json::from_str::<StreamChunk>(data) {
                        if let Some(choice) = chunk.choices.first() {
                            if let Some(delta_content) = &choice.delta.content {
                                content.push_str(delta_content);
                                callback(delta_content.clone());
                            }

                            if let Some(delta_tool_calls) = &choice.delta.tool_calls {
                                tool_calls.extend(delta_tool_calls.clone());
                            }
                        }
                    }
                }
            }
        }

        tracing::debug!("Streaming completed, total length: {}", content.len());
        Ok(Message::Assistant {
            content,
            tool_calls: if tool_calls.is_empty() {
                None
            } else {
                Some(tool_calls)
            },
        })
    }
}