1use async_trait::async_trait;
2use reqwest::{self, header};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use tokio_util::sync::CancellationToken;
6use tracing::{debug, error};
7
8use crate::api::Model;
9use crate::api::error::ApiError;
10use crate::api::provider::{CompletionResponse, Provider};
11use crate::app::conversation::{
12 AssistantContent, Message as AppMessage, ThoughtContent, ToolResult, UserContent,
13};
14use steer_tools::ToolSchema;
15
/// Endpoint for the OpenAI chat completions API.
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
17
/// Client for the OpenAI chat completions API.
///
/// Cheap to clone: `reqwest::Client` is internally reference-counted.
#[derive(Clone)]
pub struct OpenAIClient {
    // Pre-configured in `new` with the Authorization header and a
    // request timeout.
    http_client: reqwest::Client,
}
22
/// A chat message in the OpenAI wire format.
///
/// The JSON `role` field is derived from the variant name via
/// `#[serde(tag = "role", rename_all = "lowercase")]`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "role", rename_all = "lowercase")]
enum OpenAIMessage {
    /// System prompt message.
    System {
        content: OpenAIContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// End-user message.
    User {
        content: OpenAIContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Model message; `content` may be absent when the assistant turn
    /// consisted only of tool calls.
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<OpenAIContent>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<OpenAIToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Result of a tool invocation, linked back to the originating call
    /// via `tool_call_id`.
    Tool {
        content: OpenAIContent,
        tool_call_id: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}
52
/// Message content: either a plain string or an array of typed parts.
///
/// `untagged`: a JSON string maps to `String`, a JSON array to `Array`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum OpenAIContent {
    String(String),
    Array(Vec<OpenAIContentPart>),
}
60
/// One element of an array-form message content.
///
/// Only text parts are modeled here; other part types (e.g. images) are
/// not supported by this client.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
enum OpenAIContentPart {
    #[serde(rename = "text")]
    Text { text: String },
}
68
/// Function definition advertised to the model as a callable tool.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIFunction {
    name: String,
    description: String,
    // JSON Schema describing the function's parameters.
    parameters: serde_json::Value,
}
76
/// Entry in the request's `tools` array.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAITool {
    // Always "function" in this client (see `convert_tools`).
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAIFunction,
}
84
/// A tool call requested by the model (or echoed back in history).
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIToolCall {
    /// Correlation id that the matching `Tool` message must reference.
    id: String,
    // Always "function" in this client.
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAIFunctionCall,
}
93
/// Name and arguments of a function invocation.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIFunctionCall {
    name: String,
    // The API transports arguments as a JSON-encoded *string*, not a
    // JSON object; callers parse it separately.
    arguments: String,
}
99
/// `reasoning_effort` request parameter for reasoning-capable models.
///
/// Serializes as "low" / "medium" / "high".
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ReasoningEffort {
    Low,
    Medium,
    High,
}
107
/// `service_tier` request parameter.
///
/// Serializes as "auto" / "default" / "flex".
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum ServiceTier {
    Auto,
    Default,
    Flex,
}
115
/// `audio` request parameter controlling audio output.
///
/// Unused by this client today (always `None` in `complete`), kept for
/// API completeness.
#[derive(Debug, Serialize, Deserialize)]
struct AudioOutput {
    #[serde(skip_serializing_if = "Option::is_none")]
    voice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<String>,
}
123
/// `stop` request parameter: one stop sequence or a list of them.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum StopSequences {
    Single(String),
    Multiple(Vec<String>),
}
130
/// `stream_options` request parameter (only relevant when `stream` is set).
#[derive(Debug, Serialize, Deserialize)]
struct StreamOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    include_usage: Option<bool>,
}
136
137#[derive(Debug, Serialize, Deserialize)]
138#[serde(untagged)]
139enum ToolChoice {
140 #[serde(rename = "auto")]
141 Auto,
142 #[serde(rename = "required")]
143 Required,
144 Specific {
145 #[serde(rename = "type")]
146 tool_type: String,
147 function: ToolChoiceFunction,
148 },
149}
150
/// Names the specific function a `ToolChoice::Specific` pins the model to.
#[derive(Debug, Serialize, Deserialize)]
struct ToolChoiceFunction {
    name: String,
}
155
156#[derive(Debug, Serialize, Deserialize)]
157#[serde(untagged)]
158enum ResponseFormat {
159 JsonObject {
160 #[serde(rename = "type")]
161 format_type: String, },
163 JsonSchema {
164 #[serde(rename = "type")]
165 format_type: String, json_schema: serde_json::Value,
167 },
168}
169
/// Discriminator for `Prediction`; serializes as "content".
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum PredictionType {
    Content,
}
175
/// `prediction` request parameter (predicted-output content).
///
/// Untagged with a single struct variant; the `"type"` field inside the
/// variant carries the discriminator.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum Prediction {
    Content {
        #[serde(rename = "type")]
        prediction_type: PredictionType,
        content: String,
    },
}
185
/// `web_search_options` request parameter.
#[derive(Debug, Serialize, Deserialize)]
struct WebSearchOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    max_results: Option<u32>,
}
191
/// Request body for the chat completions endpoint.
///
/// Mirrors the OpenAI chat-completions parameter set. Every optional
/// field is omitted from the serialized JSON when `None`, so a request
/// only contains what was explicitly set. Only `model`, `messages`,
/// `temperature`, `tools`, and (for reasoning models)
/// `max_completion_tokens` / `reasoning_effort` are populated by this
/// client today; the rest are kept for completeness.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionRequest {
    /// Model identifier string.
    model: String,
    /// Conversation so far, oldest first.
    messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    audio: Option<AudioOutput>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, f32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    logprobs: Option<bool>,
    /// Upper bound on generated tokens (including reasoning tokens).
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    metadata: Option<HashMap<String, String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    modalities: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    prediction: Option<Prediction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    /// Only meaningful for reasoning-capable models.
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<ReasoningEffort>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    service_tier: Option<ServiceTier>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<StopSequences>,
    #[serde(skip_serializing_if = "Option::is_none")]
    store: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    /// Function tools the model may call.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAITool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_logprobs: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    web_search_options: Option<WebSearchOptions>,
}
249
/// Top-level (non-streaming) chat completion response body.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAICompletionResponse {
    id: String,
    object: String,
    // Unix timestamp of creation.
    created: u64,
    model: String,
    /// Candidate completions; `complete` only consumes the first one.
    choices: Vec<Choice>,
    usage: OpenAIUsage,
}
259
/// One candidate completion within the response.
#[derive(Debug, Serialize, Deserialize)]
struct Choice {
    index: i32,
    message: AssistantMessage,
    // e.g. why generation stopped; optional because not all responses
    // include it.
    finish_reason: Option<String>,
}
266
/// The assistant message inside a response `Choice`.
#[derive(Debug, Serialize, Deserialize)]
struct AssistantMessage {
    /// May be absent when the model only produced tool calls.
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAIToolCall>>,
    /// Reasoning text, when the provider surfaces it on this field
    /// (mapped to a `Thought` block in `complete`).
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}
275
/// Breakdown of prompt-side token usage.
#[derive(Debug, Serialize, Deserialize)]
struct PromptTokensDetails {
    cached_tokens: usize,
    audio_tokens: usize,
}
281
/// Breakdown of completion-side token usage.
#[derive(Debug, Serialize, Deserialize)]
struct CompletionTokensDetails {
    reasoning_tokens: usize,
    audio_tokens: usize,
    accepted_prediction_tokens: usize,
    rejected_prediction_tokens: usize,
}
289
/// Token accounting for a completion.
///
/// The detail sub-structs are optional: serde defaults `Option` fields
/// to `None` when the response omits them.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt_tokens_details: Option<PromptTokensDetails>,
    #[serde(skip_serializing_if = "Option::is_none")]
    completion_tokens_details: Option<CompletionTokensDetails>,
}
300
301impl OpenAIClient {
302 pub fn new(api_key: String) -> Self {
303 let mut headers = header::HeaderMap::new();
304 headers.insert(
305 header::AUTHORIZATION,
306 header::HeaderValue::from_str(&format!("Bearer {api_key}"))
307 .expect("Invalid API key format"),
308 );
309
310 let client = reqwest::Client::builder()
311 .default_headers(headers)
312 .timeout(std::time::Duration::from_secs(300)) .build()
314 .expect("Failed to build HTTP client");
315
316 Self {
317 http_client: client,
318 }
319 }
320
321 fn convert_messages(
322 &self,
323 messages: Vec<AppMessage>,
324 system: Option<String>,
325 ) -> Vec<OpenAIMessage> {
326 let mut openai_messages = Vec::new();
327
328 if let Some(system_content) = system {
330 openai_messages.push(OpenAIMessage::System {
331 content: OpenAIContent::String(system_content),
332 name: None,
333 });
334 }
335
336 for message in messages {
338 match &message.data {
339 crate::app::conversation::MessageData::User { content, .. } => {
340 let combined_text = content
342 .iter()
343 .filter_map(|user_content| match user_content {
344 UserContent::Text { text } => Some(text.clone()),
345 UserContent::CommandExecution {
346 command,
347 stdout,
348 stderr,
349 exit_code,
350 } => Some(UserContent::format_command_execution_as_xml(
351 command, stdout, stderr, *exit_code,
352 )),
353 UserContent::AppCommand { .. } => {
354 None
356 }
357 })
358 .collect::<Vec<_>>()
359 .join("\n");
360
361 if !combined_text.trim().is_empty() {
363 openai_messages.push(OpenAIMessage::User {
364 content: OpenAIContent::String(combined_text),
365 name: None,
366 });
367 }
368 }
369 crate::app::conversation::MessageData::Assistant { content, .. } => {
370 let mut text_parts = Vec::new();
372 let mut tool_calls = Vec::new();
373
374 for content_block in content {
375 match content_block {
376 AssistantContent::Text { text } => {
377 text_parts.push(text.clone());
378 }
379 AssistantContent::ToolCall { tool_call } => {
380 tool_calls.push(OpenAIToolCall {
381 id: tool_call.id.clone(),
382 tool_type: "function".to_string(),
383 function: OpenAIFunctionCall {
384 name: tool_call.name.clone(),
385 arguments: tool_call.parameters.to_string(),
386 },
387 });
388 }
389 AssistantContent::Thought { .. } => {
390 continue;
392 }
393 }
394 }
395
396 let content = if text_parts.is_empty() {
398 None
399 } else {
400 Some(OpenAIContent::String(text_parts.join("\n")))
401 };
402
403 let tool_calls_opt = if tool_calls.is_empty() {
404 None
405 } else {
406 Some(tool_calls)
407 };
408
409 openai_messages.push(OpenAIMessage::Assistant {
410 content,
411 tool_calls: tool_calls_opt,
412 name: None,
413 });
414 }
415 crate::app::conversation::MessageData::Tool {
416 tool_use_id,
417 result,
418 ..
419 } => {
420 let content_text = match result {
422 ToolResult::Error(e) => format!("Error: {e}"),
423 _ => {
424 let text = result.llm_format();
425 if text.trim().is_empty() {
426 "(No output)".to_string()
427 } else {
428 text
429 }
430 }
431 };
432
433 openai_messages.push(OpenAIMessage::Tool {
434 content: OpenAIContent::String(content_text),
435 tool_call_id: tool_use_id.clone(),
436 name: None,
437 });
438 }
439 }
440 }
441
442 openai_messages
443 }
444
445 fn convert_tools(&self, tools: Vec<ToolSchema>) -> Vec<OpenAITool> {
446 tools
447 .into_iter()
448 .map(|tool| OpenAITool {
449 tool_type: "function".to_string(),
450 function: OpenAIFunction {
451 name: tool.name,
452 description: tool.description,
453 parameters: serde_json::json!({
454 "type": tool.input_schema.schema_type,
455 "properties": tool.input_schema.properties,
456 "required": tool.input_schema.required,
457 }),
458 },
459 })
460 .collect()
461 }
462}
463
464#[async_trait]
465impl Provider for OpenAIClient {
466 fn name(&self) -> &'static str {
467 "openai"
468 }
469
470 async fn complete(
471 &self,
472 model: Model,
473 messages: Vec<AppMessage>,
474 system: Option<String>,
475 tools: Option<Vec<ToolSchema>>,
476 token: CancellationToken,
477 ) -> Result<CompletionResponse, ApiError> {
478 let openai_messages = self.convert_messages(messages, system);
480 let openai_tools = tools.map(|t| self.convert_tools(t));
481
482 let request = if model.supports_thinking() {
483 CompletionRequest {
484 model: model.as_ref().to_string(),
485 messages: openai_messages,
486 audio: None,
487 frequency_penalty: None,
488 logit_bias: None,
489 logprobs: None,
490 max_completion_tokens: Some(32_000), metadata: None,
492 modalities: None,
493 n: None,
494 parallel_tool_calls: None,
495 prediction: None,
496 presence_penalty: None,
497 reasoning_effort: Some(ReasoningEffort::High),
498 response_format: None,
499 seed: None,
500 service_tier: None,
501 stop: None,
502 store: None,
503 stream: None,
504 stream_options: None,
505 temperature: Some(1.0),
506 tool_choice: None,
507 tools: openai_tools,
508 top_logprobs: None,
509 top_p: None,
510 user: None,
511 web_search_options: None,
512 }
513 } else {
514 CompletionRequest {
515 model: model.as_ref().to_string(),
516 messages: openai_messages,
517 audio: None,
518 frequency_penalty: None,
519 logit_bias: None,
520 logprobs: None,
521 max_completion_tokens: None,
522 metadata: None,
523 modalities: None,
524 n: None,
525 parallel_tool_calls: None,
526 prediction: None,
527 presence_penalty: None,
528 reasoning_effort: None,
529 response_format: None,
530 seed: None,
531 service_tier: None,
532 stop: None,
533 store: None,
534 stream: None,
535 stream_options: None,
536 temperature: Some(1.0),
537 tool_choice: None,
538 tools: openai_tools,
539 top_logprobs: None,
540 top_p: None,
541 user: None,
542 web_search_options: None,
543 }
544 };
545
546 let response = self
547 .http_client
548 .post(API_URL)
549 .json(&request)
550 .send()
551 .await
552 .map_err(ApiError::Network)?;
553
554 if !response.status().is_success() {
555 let status = response.status();
556 let error_text = response.text().await.unwrap_or_else(|_| String::new());
557
558 debug!(
559 target: "openai::complete",
560 "OpenAI API error - Status: {}, Body: {}",
561 status,
562 error_text
563 );
564
565 return match status.as_u16() {
566 429 => Err(ApiError::RateLimited {
567 provider: self.name().to_string(),
568 details: error_text,
569 }),
570 400 => Err(ApiError::InvalidRequest {
571 provider: self.name().to_string(),
572 details: error_text,
573 }),
574 401 => Err(ApiError::AuthenticationFailed {
575 provider: self.name().to_string(),
576 details: error_text,
577 }),
578 _ => Err(ApiError::ServerError {
579 provider: self.name().to_string(),
580 status_code: status.as_u16(),
581 details: error_text,
582 }),
583 };
584 }
585
586 let response_text = tokio::select! {
587 _ = token.cancelled() => {
588 debug!(target: "openai::complete", "Cancellation token triggered while reading successful response body.");
589 return Err(ApiError::Cancelled { provider: self.name().to_string() });
590 }
591 text_res = response.text() => {
592 text_res?
593 }
594 };
595
596 let openai_response: OpenAICompletionResponse = serde_json::from_str(&response_text)
597 .map_err(|e| {
598 error!(
599 target: "openai::complete",
600 "Failed to parse response: {}, Body: {}",
601 e,
602 response_text
603 );
604 ApiError::ResponseParsingError {
605 provider: self.name().to_string(),
606 details: format!("Error: {e}, Body: {response_text}"),
607 }
608 })?;
609
610 if let Some(choice) = openai_response.choices.first() {
612 let mut content_blocks = Vec::new();
613
614 if let Some(reasoning) = &choice.message.reasoning_content {
616 content_blocks.push(AssistantContent::Thought {
617 thought: ThoughtContent::Simple {
618 text: reasoning.clone(),
619 },
620 });
621 }
622
623 if let Some(content) = &choice.message.content {
625 if !content.trim().is_empty() {
626 content_blocks.push(AssistantContent::Text {
627 text: content.clone(),
628 });
629 }
630 }
631
632 if let Some(tool_calls) = &choice.message.tool_calls {
634 for tool_call in tool_calls {
635 let parameters = serde_json::from_str(&tool_call.function.arguments)
637 .unwrap_or(serde_json::Value::Null);
638
639 content_blocks.push(AssistantContent::ToolCall {
640 tool_call: steer_tools::ToolCall {
641 id: tool_call.id.clone(),
642 name: tool_call.function.name.clone(),
643 parameters,
644 },
645 });
646 }
647 }
648
649 Ok(crate::api::provider::CompletionResponse {
650 content: content_blocks,
651 })
652 } else {
653 Err(ApiError::NoChoices {
654 provider: self.name().to_string(),
655 })
656 }
657 }
658}