steer_core/api/openai/client.rs

use async_trait::async_trait;
use reqwest::{self, header};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error};

use crate::api::Model;
use crate::api::error::ApiError;
use crate::api::provider::{CompletionResponse, Provider};
use crate::app::conversation::{
    AssistantContent, Message as AppMessage, ThoughtContent, ToolResult, UserContent,
};
use steer_tools::ToolSchema;

const API_URL: &str = "https://api.openai.com/v1/chat/completions";

#[derive(Clone)]
pub struct OpenAIClient {
    http_client: reqwest::Client,
}

// OpenAI-specific message format
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
enum OpenAIMessage {
    System {
        content: OpenAIContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    User {
        content: OpenAIContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<OpenAIContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<OpenAIToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    Tool {
        content: OpenAIContent,
        tool_call_id: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}

// OpenAI content can be a string or an array of content parts
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum OpenAIContent {
    String(String),
    Array(Vec<OpenAIContentPart>),
}

// OpenAI content parts for multi-modal messages
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
enum OpenAIContentPart {
    #[serde(rename = "text")]
    Text { text: String },
}

// OpenAI function calling format
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

// OpenAI tool format
#[derive(Debug, Serialize, Deserialize)]
struct OpenAITool {
    #[serde(rename = "type")]
    tool_type: String, // "function"
    function: OpenAIFunction,
}

// OpenAI tool call
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIToolCall {
    id: String,
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAIFunctionCall,
}

#[derive(Debug, Serialize, Deserialize)]
struct OpenAIFunctionCall {
    name: String,
    arguments: String, // JSON string
}
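
// Illustrative wire shape for a serialized tool call as produced by the two structs
// above (the id, function name, and arguments are examples, not values emitted by
// this crate):
//
//   {
//     "id": "call_abc123",
//     "type": "function",
//     "function": { "name": "get_weather", "arguments": "{\"city\":\"Paris\"}" }
//   }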

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ReasoningEffort {
    Low,
    Medium,
    High,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ServiceTier {
    Auto,
    Default,
    Flex,
}

#[derive(Debug, Serialize, Deserialize)]
struct AudioOutput {
    #[serde(skip_serializing_if = "Option::is_none")]
    voice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum StopSequences {
    Single(String),
    Multiple(Vec<String>),
}

#[derive(Debug, Serialize, Deserialize)]
struct StreamOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    include_usage: Option<bool>,
}

// OpenAI `tool_choice`: either a mode string ("auto" / "required") or an object
// selecting a specific function. The mode strings live in an inner enum because
// `#[serde(untagged)]` ignores variant renames, so bare unit variants would
// serialize as `null` instead of the expected lowercase strings.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum ToolChoice {
    Mode(ToolChoiceMode),
    Specific {
        #[serde(rename = "type")]
        tool_type: String,
        function: ToolChoiceFunction,
    },
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ToolChoiceMode {
    Auto,
    Required,
}

#[derive(Debug, Serialize, Deserialize)]
struct ToolChoiceFunction {
    name: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum ResponseFormat {
    JsonObject {
        #[serde(rename = "type")]
        format_type: String, // "json_object"
    },
    JsonSchema {
        #[serde(rename = "type")]
        format_type: String, // "json_schema"
        json_schema: serde_json::Value,
    },
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum PredictionType {
    Content,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum Prediction {
    Content {
        #[serde(rename = "type")]
        prediction_type: PredictionType,
        content: String,
    },
}

#[derive(Debug, Serialize, Deserialize)]
struct WebSearchOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    max_results: Option<u32>,
}

#[derive(Debug, Serialize, Deserialize)]
struct CompletionRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    audio: Option<AudioOutput>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, f32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    metadata: Option<HashMap<String, String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    modalities: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prediction: Option<Prediction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<ReasoningEffort>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    service_tier: Option<ServiceTier>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<StopSequences>,
    #[serde(skip_serializing_if = "Option::is_none")]
    store: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAITool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_logprobs: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    web_search_options: Option<WebSearchOptions>,
}
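
// A minimal request body as serialized by this struct once all `None` fields are
// skipped (model name and message text are illustrative only):
//
//   {
//     "model": "gpt-4o",
//     "messages": [{ "role": "user", "content": "hello" }],
//     "temperature": 1.0
//   }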

#[derive(Debug, Serialize, Deserialize)]
struct OpenAICompletionResponse {
    id: String,
    object: String,
    created: u64,
    model: String,
    choices: Vec<Choice>,
    usage: OpenAIUsage,
}

#[derive(Debug, Serialize, Deserialize)]
struct Choice {
    index: i32,
    message: AssistantMessage,
    finish_reason: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
struct AssistantMessage {
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAIToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
struct PromptTokensDetails {
    cached_tokens: usize,
    audio_tokens: usize,
}

#[derive(Debug, Serialize, Deserialize)]
struct CompletionTokensDetails {
    reasoning_tokens: usize,
    audio_tokens: usize,
    accepted_prediction_tokens: usize,
    rejected_prediction_tokens: usize,
}

#[derive(Debug, Serialize, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt_tokens_details: Option<PromptTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    completion_tokens_details: Option<CompletionTokensDetails>,
}

impl OpenAIClient {
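    /// Builds a client whose requests carry an `Authorization: Bearer <key>` header.
    ///
    /// A minimal usage sketch (reading the key from `OPENAI_API_KEY` is an assumption
    /// of this example, not something the crate requires):
    ///
    /// ```ignore
    /// let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
    /// let client = OpenAIClient::new(api_key);
    /// ```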
    pub fn new(api_key: String) -> Self {
        let mut headers = header::HeaderMap::new();
        headers.insert(
            header::AUTHORIZATION,
            header::HeaderValue::from_str(&format!("Bearer {api_key}"))
                .expect("Invalid API key format"),
        );

        let client = reqwest::Client::builder()
            .default_headers(headers)
            .timeout(std::time::Duration::from_secs(300)) // 5 minute timeout for o3
            .build()
            .expect("Failed to build HTTP client");

        Self {
            http_client: client,
        }
    }

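    /// Maps app-level conversation messages onto the OpenAI chat format: user content
    /// is flattened to text, assistant content is split into text plus `tool_calls`,
    /// tool results become `role: "tool"` messages keyed by `tool_call_id`, and an
    /// optional system prompt is prepended.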
    fn convert_messages(
        &self,
        messages: Vec<AppMessage>,
        system: Option<String>,
    ) -> Vec<OpenAIMessage> {
        let mut openai_messages = Vec::new();

        // Add system message if provided
        if let Some(system_content) = system {
            openai_messages.push(OpenAIMessage::System {
                content: OpenAIContent::String(system_content),
                name: None,
            });
        }

        // Convert our messages to OpenAI format
        for message in messages {
            match &message.data {
                crate::app::conversation::MessageData::User { content, .. } => {
                    // Convert UserContent to text
                    let combined_text = content
                        .iter()
                        .filter_map(|user_content| match user_content {
                            UserContent::Text { text } => Some(text.clone()),
                            UserContent::CommandExecution {
                                command,
                                stdout,
                                stderr,
                                exit_code,
                            } => Some(UserContent::format_command_execution_as_xml(
                                command, stdout, stderr, *exit_code,
                            )),
                            UserContent::AppCommand { .. } => {
                                // Don't send app commands to the model - they're for local execution only
                                None
                            }
                        })
                        .collect::<Vec<_>>()
                        .join("\n");

                    // Only add the message if it has content after filtering
                    if !combined_text.trim().is_empty() {
                        openai_messages.push(OpenAIMessage::User {
                            content: OpenAIContent::String(combined_text),
                            name: None,
                        });
                    }
                }
                crate::app::conversation::MessageData::Assistant { content, .. } => {
                    // Convert AssistantContent to OpenAI format
                    let mut text_parts = Vec::new();
                    let mut tool_calls = Vec::new();

                    for content_block in content {
                        match content_block {
                            AssistantContent::Text { text } => {
                                text_parts.push(text.clone());
                            }
                            AssistantContent::ToolCall { tool_call } => {
                                tool_calls.push(OpenAIToolCall {
                                    id: tool_call.id.clone(),
                                    tool_type: "function".to_string(),
                                    function: OpenAIFunctionCall {
                                        name: tool_call.name.clone(),
                                        arguments: tool_call.parameters.to_string(),
                                    },
                                });
                            }
                            AssistantContent::Thought { .. } => {
                                // Thought content is not forwarded to OpenAI
                                continue;
                            }
                        }
                    }

                    // Build the assistant message
                    let content = if text_parts.is_empty() {
                        None
                    } else {
                        Some(OpenAIContent::String(text_parts.join("\n")))
                    };

                    let tool_calls_opt = if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    };

                    openai_messages.push(OpenAIMessage::Assistant {
                        content,
                        tool_calls: tool_calls_opt,
                        name: None,
                    });
                }
                crate::app::conversation::MessageData::Tool {
                    tool_use_id,
                    result,
                    ..
                } => {
                    // Convert ToolResult to OpenAI format
                    let content_text = match result {
                        ToolResult::Error(e) => format!("Error: {e}"),
                        _ => {
                            let text = result.llm_format();
                            if text.trim().is_empty() {
                                "(No output)".to_string()
                            } else {
                                text
                            }
                        }
                    };

                    openai_messages.push(OpenAIMessage::Tool {
                        content: OpenAIContent::String(content_text),
                        tool_call_id: tool_use_id.clone(),
                        name: None,
                    });
                }
            }
        }

        openai_messages
    }

    fn convert_tools(&self, tools: Vec<ToolSchema>) -> Vec<OpenAITool> {
        tools
            .into_iter()
            .map(|tool| OpenAITool {
                tool_type: "function".to_string(),
                function: OpenAIFunction {
                    name: tool.name,
                    description: tool.description,
                    parameters: serde_json::json!({
                        "type": tool.input_schema.schema_type,
                        "properties": tool.input_schema.properties,
                        "required": tool.input_schema.required,
                    }),
                },
            })
            .collect()
    }
}

#[async_trait]
impl Provider for OpenAIClient {
    fn name(&self) -> &'static str {
        "openai"
    }

    async fn complete(
        &self,
        model: Model,
        messages: Vec<AppMessage>,
        system: Option<String>,
        tools: Option<Vec<ToolSchema>>,
        token: CancellationToken,
    ) -> Result<CompletionResponse, ApiError> {
        let openai_messages = self.convert_messages(messages, system);
        let openai_tools = tools.map(|t| self.convert_tools(t));

        // Reasoning-capable models get an explicit completion-token budget and a
        // reasoning-effort hint; other models rely on the API defaults.
        let (max_completion_tokens, reasoning_effort) = if model.supports_thinking() {
            (Some(32_000), Some(ReasoningEffort::High)) // Budget may need tweaking based on context window
        } else {
            (None, None)
        };

        let request = CompletionRequest {
            model: model.as_ref().to_string(),
            messages: openai_messages,
            audio: None,
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            max_completion_tokens,
            metadata: None,
            modalities: None,
            n: None,
            parallel_tool_calls: None,
            prediction: None,
            presence_penalty: None,
            reasoning_effort,
            response_format: None,
            seed: None,
            service_tier: None,
            stop: None,
            store: None,
            stream: None,
            stream_options: None,
            temperature: Some(1.0),
            tool_choice: None,
            tools: openai_tools,
            top_logprobs: None,
            top_p: None,
            user: None,
            web_search_options: None,
        };

        let response = self
            .http_client
            .post(API_URL)
            .json(&request)
            .send()
            .await
            .map_err(ApiError::Network)?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_else(|_| String::new());

            debug!(
                target: "openai::complete",
                "OpenAI API error - Status: {}, Body: {}",
                status,
                error_text
            );

            return match status.as_u16() {
                429 => Err(ApiError::RateLimited {
                    provider: self.name().to_string(),
                    details: error_text,
                }),
                400 => Err(ApiError::InvalidRequest {
                    provider: self.name().to_string(),
                    details: error_text,
                }),
                401 => Err(ApiError::AuthenticationFailed {
                    provider: self.name().to_string(),
                    details: error_text,
                }),
                _ => Err(ApiError::ServerError {
                    provider: self.name().to_string(),
                    status_code: status.as_u16(),
                    details: error_text,
                }),
            };
        }

        let response_text = tokio::select! {
            _ = token.cancelled() => {
                debug!(target: "openai::complete", "Cancellation token triggered while reading successful response body.");
                return Err(ApiError::Cancelled { provider: self.name().to_string() });
            }
            text_res = response.text() => {
                text_res?
            }
        };

        let openai_response: OpenAICompletionResponse = serde_json::from_str(&response_text)
            .map_err(|e| {
                error!(
                    target: "openai::complete",
                    "Failed to parse response: {}, Body: {}",
                    e,
                    response_text
                );
                ApiError::ResponseParsingError {
                    provider: self.name().to_string(),
                    details: format!("Error: {e}, Body: {response_text}"),
                }
            })?;

        // Convert OpenAI response to our CompletionResponse
        if let Some(choice) = openai_response.choices.first() {
            let mut content_blocks = Vec::new();

            // Add reasoning content if present (convert to thought)
            if let Some(reasoning) = &choice.message.reasoning_content {
                content_blocks.push(AssistantContent::Thought {
                    thought: ThoughtContent::Simple {
                        text: reasoning.clone(),
                    },
                });
            }

            // Add regular content
            if let Some(content) = &choice.message.content {
                if !content.trim().is_empty() {
                    content_blocks.push(AssistantContent::Text {
                        text: content.clone(),
                    });
                }
            }

            // Add tool calls
            if let Some(tool_calls) = &choice.message.tool_calls {
                for tool_call in tool_calls {
                    // Parse the arguments JSON string
                    let parameters = serde_json::from_str(&tool_call.function.arguments)
                        .unwrap_or(serde_json::Value::Null);

                    content_blocks.push(AssistantContent::ToolCall {
                        tool_call: steer_tools::ToolCall {
                            id: tool_call.id.clone(),
                            name: tool_call.function.name.clone(),
                            parameters,
                        },
                    });
                }
            }

            Ok(CompletionResponse {
                content: content_blocks,
            })
        } else {
            Err(ApiError::NoChoices {
                provider: self.name().to_string(),
            })
        }
    }
}
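
// A minimal serialization sketch for the wire types above. These checks only cover
// behavior that follows directly from the serde attributes on this file's types; the
// literal values ("ok", "call_123") are illustrative.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tool_message_serializes_with_role_tag() {
        let msg = OpenAIMessage::Tool {
            content: OpenAIContent::String("ok".to_string()),
            tool_call_id: "call_123".to_string(),
            name: None,
        };
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(value["role"], "tool");
        assert_eq!(value["content"], "ok");
        assert_eq!(value["tool_call_id"], "call_123");
    }

    #[test]
    fn assistant_message_skips_absent_fields() {
        let msg = OpenAIMessage::Assistant {
            content: None,
            tool_calls: None,
            name: None,
        };
        // Only the internally tagged role should remain once the `None` fields are skipped.
        assert_eq!(
            serde_json::to_value(&msg).unwrap(),
            serde_json::json!({ "role": "assistant" })
        );
    }
}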