
tiny_loop/llm/openai.rs

use crate::types::{FinishReason, LLMResponse, Message, ToolDefinition};
use async_trait::async_trait;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};

/// Callback for streaming OpenAI responses
pub type OpenAIStreamCallback = Box<dyn FnMut(String) + Send + Sync>;

/// Request payload for OpenAI chat completions API
#[derive(Serialize)]
struct ChatRequest {
    /// Model ID
    model: String,
    /// Conversation messages
    messages: Vec<Message>,
    /// Available tools for the model
    tools: Vec<ToolDefinition>,
    /// Enable streaming
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}

/// Response from OpenAI chat completions API
#[derive(Deserialize)]
struct ChatResponse {
    /// List of completion choices
    choices: Vec<Choice>,
}

/// Streaming response chunk
#[derive(Deserialize)]
struct StreamChunk {
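    /// Incremental choices in this chunk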
    choices: Vec<StreamChoice>,
}

/// Streaming choice
#[derive(Deserialize)]
struct StreamChoice {
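    /// Incremental message delta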
    delta: Delta,
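    /// Reason the stream finished, if included in this chunk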
    #[serde(default)]
    finish_reason: Option<FinishReason>,
}

/// Delta content in streaming
#[derive(Deserialize)]
struct Delta {
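    /// Incremental text content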
    #[serde(default)]
    content: Option<String>,
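    /// Incremental tool call fragments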
    #[serde(default)]
    tool_calls: Option<Vec<crate::types::ToolCall>>,
}

/// Single completion choice from the API response
#[derive(Deserialize)]
struct Choice {
    /// Assistant's response message
    message: Message,
    /// Reason the completion finished
    finish_reason: FinishReason,
}

/// OpenAI-compatible LLM provider
///
/// # Examples
///
/// ```
/// use tiny_loop::llm::OpenAIProvider;
///
/// let provider = OpenAIProvider::new()
///     .api_key("sk-...")
///     .model("gpt-4o");
/// ```
pub struct OpenAIProvider {
    /// HTTP client for API requests
    client: reqwest::Client,
    /// API base URL
    base_url: String,
    /// API authentication key
    api_key: String,
    /// Model identifier
    model: String,
    /// Additional HTTP headers
    custom_headers: HeaderMap,
    /// Maximum number of retries on failure
    max_retries: u32,
    /// Delay between retries in milliseconds
    retry_delay_ms: u64,
    /// Custom body fields to merge into the request
    custom_body: Map<String, Value>,
    /// Stream callback for LLM responses
    stream_callback: Option<OpenAIStreamCallback>,
}

impl Default for OpenAIProvider {
    fn default() -> Self {
        Self::new()
    }
}

impl OpenAIProvider {
    /// Create a new OpenAI provider with default settings
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new();
    /// ```
    pub fn new() -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url: "https://api.openai.com/v1".into(),
            api_key: "".into(),
            model: "gpt-4o".into(),
            custom_headers: HeaderMap::new(),
            max_retries: 3,
            retry_delay_ms: 1000,
            custom_body: Map::new(),
            stream_callback: None,
        }
    }

    /// Set the base URL for the API endpoint (default: `https://api.openai.com/v1`)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .base_url("https://api.custom.com/v1");
    /// ```
    pub fn base_url(mut self, value: impl Into<String>) -> Self {
        self.base_url = value.into();
        self
    }

    /// Set the API key for authentication (default: empty string)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .api_key("sk-...");
    /// ```
    pub fn api_key(mut self, value: impl Into<String>) -> Self {
        self.api_key = value.into();
        self
    }

    /// Set the model name to use (default: `gpt-4o`)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .model("gpt-4o-mini");
    /// ```
    pub fn model(mut self, value: impl Into<String>) -> Self {
        self.model = value.into();
        self
    }

    /// Add a custom HTTP header to requests
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .header("x-custom-header", "value")
    ///     .unwrap();
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error if the header name or value contains invalid characters.
    pub fn header(
        mut self,
        key: impl Into<String>,
        value: impl Into<String>,
    ) -> anyhow::Result<Self> {
        self.custom_headers.insert(
            HeaderName::try_from(key.into())?,
            HeaderValue::try_from(value.into())?,
        );
        Ok(self)
    }

    /// Set the maximum number of retries on failure (default: 3)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .max_retries(5);
    /// ```
    pub fn max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }

    /// Set the delay between retries in milliseconds (default: 1000)
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .retry_delay(2000);
    /// ```
    pub fn retry_delay(mut self, delay_ms: u64) -> Self {
        self.retry_delay_ms = delay_ms;
        self
    }

    /// Set custom body fields to merge into the request
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    /// use serde_json::json;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .body(json!({
    ///         "top_p": 0.9,
    ///         "frequency_penalty": 0.5
    ///     }))
    ///     .unwrap();
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error if the value is not a JSON object.
    pub fn body(mut self, body: Value) -> anyhow::Result<Self> {
        self.custom_body = body
            .as_object()
            .ok_or_else(|| anyhow::anyhow!("body must be a JSON object"))?
            .clone();
        Ok(self)
    }

    /// Set a stream callback invoked with each content chunk as it arrives
    ///
    /// # Examples
    ///
    /// ```
    /// use tiny_loop::llm::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::new()
    ///     .stream_callback(|chunk| print!("{}", chunk));
    /// ```
    pub fn stream_callback<F>(mut self, callback: F) -> Self
    where
        F: FnMut(String) + Send + Sync + 'static,
    {
        self.stream_callback = Some(Box::new(callback));
        self
    }
}

#[async_trait]
impl super::LLMProvider for OpenAIProvider {
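    /// Send the conversation to the API, retrying failed calls up to
    /// `max_retries` times with a `retry_delay_ms` pause between attempts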
    async fn call(
        &mut self,
        messages: &[Message],
        tools: &[ToolDefinition],
    ) -> anyhow::Result<LLMResponse> {
        let mut attempt = 0;
        loop {
            attempt += 1;
            tracing::debug!(
                model = %self.model,
                messages = messages.len(),
                tools = tools.len(),
                streaming = self.stream_callback.is_some(),
                attempt = attempt,
                max_retries = self.max_retries,
                "Calling LLM API"
            );

            match self.call_once(messages, tools).await {
                Ok(response) => return Ok(response),
                Err(e) if attempt > self.max_retries => {
                    tracing::debug!("Max retries exceeded");
                    return Err(e);
                }
                Err(e) => {
                    tracing::debug!("API call failed, retrying: {}", e);
                    tokio::time::sleep(tokio::time::Duration::from_millis(self.retry_delay_ms))
                        .await;
                }
            }
        }
    }
}

impl OpenAIProvider {
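    /// Perform a single API request: serialize the chat request, merge in any
    /// custom body fields, send it, and parse either a streaming or a plain
    /// JSON response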
    async fn call_once(
        &mut self,
        messages: &[Message],
        tools: &[ToolDefinition],
    ) -> anyhow::Result<LLMResponse> {
        let request = ChatRequest {
            model: self.model.clone(),
            messages: messages.to_vec(),
            tools: tools.to_vec(),
            stream: if self.stream_callback.is_some() {
                Some(true)
            } else {
                None
            },
        };

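        // Merge user-supplied custom fields into the serialized request body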
        let mut body = serde_json::to_value(&request)?.as_object().unwrap().clone();
        body.extend(self.custom_body.clone());

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .headers(self.custom_headers.clone())
            .json(&body)
            .send()
            .await?;

        let status = response.status();
        tracing::trace!("LLM API response status: {}", status);

        if !status.is_success() {
            let body = response.text().await?;
            tracing::debug!("LLM API error: status={}, body={}", status, body);
            anyhow::bail!("API error ({}): {}", status, body);
        }

        if self.stream_callback.is_some() {
            self.handle_stream(response).await
        } else {
            let body = response.text().await?;
            let chat_response: ChatResponse = serde_json::from_str(&body)
                .map_err(|e| anyhow::anyhow!("Failed to parse response: {}. Body: {}", e, body))?;
            tracing::debug!("LLM API call completed successfully");
            let choice = chat_response.choices.first()
                .ok_or_else(|| anyhow::anyhow!("Response contained no choices"))?;
            let Message::Assistant(msg) = &choice.message else {
                anyhow::bail!("Expected Assistant message, got: {:?}", choice.message);
            };
            Ok(LLMResponse {
                message: msg.clone(),
                finish_reason: choice.finish_reason.clone(),
            })
        }
    }

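    /// Consume a server-sent-events response body, accumulating content
    /// deltas (forwarding each to the stream callback), collecting tool call
    /// fragments, and recording the final finish reason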
    async fn handle_stream(&mut self, response: reqwest::Response) -> anyhow::Result<LLMResponse> {
        use futures::TryStreamExt;

        let mut stream = response.bytes_stream();
        let mut buffer = String::new();
        let mut content = String::new();
        let mut tool_calls = Vec::new();
        let mut finish_reason = FinishReason::Stop;

        while let Some(chunk) = stream.try_next().await? {
            buffer.push_str(&String::from_utf8_lossy(&chunk));

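            // Drain each complete newline-terminated SSE line from the buffer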
            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer.drain(..=line_end);

                if let Some(data) = line.strip_prefix("data: ") {
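                    // "[DONE]" marks the end of the stream; stop draining
                    // buffered lines and let the byte stream run out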
                    if data == "[DONE]" {
                        break;
                    }

                    if let Ok(chunk) = serde_json::from_str::<StreamChunk>(data) {
                        if let Some(choice) = chunk.choices.first() {
                            if let Some(delta_content) = &choice.delta.content {
                                content.push_str(delta_content);
                                if let Some(callback) = &mut self.stream_callback {
                                    callback(delta_content.clone());
                                }
                            }

                            if let Some(delta_tool_calls) = &choice.delta.tool_calls {
                                tool_calls.extend(delta_tool_calls.clone());
                            }

                            if let Some(reason) = &choice.finish_reason {
                                finish_reason = reason.clone();
                            }
                        }
                    }
                }
            }
        }

        tracing::debug!("Streaming completed, total length: {}", content.len());
        Ok(LLMResponse {
            message: crate::types::AssistantMessage {
                content,
                tool_calls: if tool_calls.is_empty() {
                    None
                } else {
                    Some(tool_calls)
                },
            },
            finish_reason,
        })
    }
}