vtcode_core/llm/providers/openai.rs

1use crate::config::constants::{model_helpers, models};
2use crate::llm::client::LLMClient;
3use crate::llm::error_display;
4use crate::llm::provider::{
5    FinishReason, LLMError, LLMProvider, LLMRequest, LLMResponse, Message, MessageRole, ToolCall,
6};
7use crate::llm::types as llm_types;
8use async_trait::async_trait;
9use reqwest::Client as HttpClient;
10use serde_json::{Value, json};
11
12pub struct OpenAIProvider {
13    api_key: String,
14    http_client: HttpClient,
15    base_url: String,
16}
17
18impl OpenAIProvider {
19    pub fn new(api_key: String) -> Self {
20        Self {
21            api_key,
22            http_client: HttpClient::new(),
23            base_url: "https://api.openai.com/v1".to_string(),
24        }
25    }
26}
27
28#[async_trait]
29impl LLMProvider for OpenAIProvider {
30    fn name(&self) -> &str {
31        "openai"
32    }
33
34    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse, LLMError> {
35        let openai_request = self.convert_to_openai_format(&request)?;
36
37        let url = format!("{}/chat/completions", self.base_url);
38
39        let response = self
40            .http_client
41            .post(&url)
42            .bearer_auth(&self.api_key)
43            .json(&openai_request)
44            .send()
45            .await
46            .map_err(|e| {
47                let formatted_error =
48                    error_display::format_llm_error("OpenAI", &format!("Network error: {}", e));
49                LLMError::Network(formatted_error)
50            })?;
51
52        if !response.status().is_success() {
53            let status = response.status();
54            let error_text = response.text().await.unwrap_or_default();
55            let formatted_error = error_display::format_llm_error(
56                "OpenAI",
57                &format!("HTTP {}: {}", status, error_text),
58            );
59            return Err(LLMError::Provider(formatted_error));
60        }
61
62        let openai_response: Value = response.json().await.map_err(|e| {
63            let formatted_error = error_display::format_llm_error(
64                "OpenAI",
65                &format!("Failed to parse response: {}", e),
66            );
67            LLMError::Provider(formatted_error)
68        })?;
69
70        self.parse_openai_response(openai_response)
71    }
72
73    fn supported_models(&self) -> Vec<String> {
74        models::openai::SUPPORTED_MODELS
75            .iter()
76            .map(|s| s.to_string())
77            .collect()
78    }
79
80    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
81        if request.messages.is_empty() {
82            return Err(LLMError::InvalidRequest(
83                "Messages cannot be empty".to_string(),
84            ));
85        }
86
87        if request.model.is_empty() {
88            return Err(LLMError::InvalidRequest(
89                "Model cannot be empty".to_string(),
90            ));
91        }
92
93        Ok(())
94    }
95}
96
97impl OpenAIProvider {
98    fn convert_to_openai_format(&self, request: &LLMRequest) -> Result<Value, LLMError> {
99        let mut messages = Vec::new();
100
101        // Add system message if present
102        if let Some(system_prompt) = &request.system_prompt {
103            messages.push(json!({
104                "role": crate::config::constants::message_roles::SYSTEM,
105                "content": system_prompt
106            }));
107        }
108
109        // Convert messages
110        for msg in &request.messages {
111            // Use the proper role mapping for OpenAI
112            let role = msg.role.as_openai_str();
113
114            let mut message = json!({
115                "role": role,
116                "content": msg.content
117            });
118
119            // Add tool call information for assistant messages
120            // Based on OpenAI docs: only assistant messages can have tool_calls
121            if msg.role == MessageRole::Assistant {
122                if let Some(tool_calls) = &msg.tool_calls {
123                    if !tool_calls.is_empty() {
124                        let tool_calls_json: Vec<Value> = tool_calls
125                            .iter()
126                            .map(|tc| {
127                                json!({
128                                    "id": tc.id,
129                                    "type": "function",
130                                    "function": {
131                                        "name": tc.function.name,
132                                        "arguments": tc.function.arguments
133                                    }
134                                })
135                            })
136                            .collect();
137                        message["tool_calls"] = Value::Array(tool_calls_json);
138                    }
139                }
140            }
141
142            // Add tool_call_id for tool messages
143            // Based on OpenAI docs: tool messages must have tool_call_id
144            if msg.role == MessageRole::Tool {
145                if let Some(tool_call_id) = &msg.tool_call_id {
146                    message["tool_call_id"] = Value::String(tool_call_id.clone());
147                } else {
148                    // This should not happen in well-formed requests
149                    eprintln!("Warning: Tool message without tool_call_id in OpenAI request");
150                }
151            }
152
153            messages.push(message);
154        }
155
156        let mut openai_request = json!({
157            "model": request.model,
158            "messages": messages,
159            "stream": request.stream
160        });
161
162        // Add optional parameters
163        if let Some(max_tokens) = request.max_tokens {
164            openai_request["max_tokens"] = json!(max_tokens);
165        }
166
167        if let Some(temperature) = request.temperature {
168            openai_request["temperature"] = json!(temperature);
169        }
170
171        // Add tools if present
172        if let Some(tools) = &request.tools {
173            if !tools.is_empty() {
174                let tools_json: Vec<Value> = tools
175                    .iter()
176                    .map(|tool| {
177                        json!({
178                            "type": "function",
179                            "function": {
180                                "name": tool.function.name,
181                                "description": tool.function.description,
182                                "parameters": tool.function.parameters
183                            }
184                        })
185                    })
186                    .collect();
187                openai_request["tools"] = Value::Array(tools_json);
188            }
189        }
190
191        // Add tool_choice if specified
192        if let Some(tool_choice) = &request.tool_choice {
193            openai_request["tool_choice"] = tool_choice.to_provider_format("openai");
194        }
195
196        // Add parallel_tool_calls if specified
197        if let Some(parallel) = request.parallel_tool_calls {
198            openai_request["parallel_tool_calls"] = Value::Bool(parallel);
199        }
200
201        // Add reasoning_effort for models that support it (GPT-5 etc.)
202        if let Some(reasoning_effort) = &request.reasoning_effort {
203            if request.model.contains(models::openai::GPT_5)
204                || request.model.contains(models::openai::GPT_5_MINI)
205            {
206                openai_request["reasoning_effort"] = json!(reasoning_effort);
207            }
208        }
209
210        Ok(openai_request)
211    }
212
213    fn parse_openai_response(&self, response_json: Value) -> Result<LLMResponse, LLMError> {
214        let choices = response_json
215            .get("choices")
216            .and_then(|c| c.as_array())
217            .ok_or_else(|| {
218                LLMError::Provider("Invalid response format: missing choices".to_string())
219            })?;
220
221        if choices.is_empty() {
222            return Err(LLMError::Provider("No choices in response".to_string()));
223        }
224
225        let choice = &choices[0];
226        let message = choice.get("message").ok_or_else(|| {
227            LLMError::Provider("Invalid response format: missing message".to_string())
228        })?;
229
230        let content = message
231            .get("content")
232            .and_then(|c| c.as_str())
233            .map(|s| s.to_string());
234
235        // Parse tool calls
236        let tool_calls = message
237            .get("tool_calls")
238            .and_then(|tc| tc.as_array())
239            .map(|calls| {
240                calls
241                    .iter()
242                    .filter_map(|call| {
243                        Some(ToolCall {
244                            id: call.get("id")?.as_str()?.to_string(),
245                            call_type: "function".to_string(),
246                            function: crate::llm::provider::FunctionCall {
247                                name: call.get("function")?.get("name")?.as_str()?.to_string(),
248                                arguments: call
249                                    .get("function")
250                                    .and_then(|f| f.get("arguments"))
251                                    .and_then(|args| args.as_str())
252                                    .unwrap_or("{}")
253                                    .to_string(),
254                            },
255                        })
256                    })
257                    .collect()
258            });
259
260        // Parse finish reason
261        let finish_reason = choice
262            .get("finish_reason")
263            .and_then(|fr| fr.as_str())
264            .map(|fr| match fr {
265                "stop" => FinishReason::Stop,
266                "length" => FinishReason::Length,
267                "tool_calls" => FinishReason::ToolCalls,
268                "content_filter" => FinishReason::ContentFilter,
269                _ => FinishReason::Error(fr.to_string()),
270            })
271            .unwrap_or(FinishReason::Stop);
272
273        // Parse usage
274        let usage = response_json
275            .get("usage")
276            .map(|u| crate::llm::provider::Usage {
277                prompt_tokens: u
278                    .get("prompt_tokens")
279                    .and_then(|pt| pt.as_u64())
280                    .unwrap_or(0) as u32,
281                completion_tokens: u
282                    .get("completion_tokens")
283                    .and_then(|ct| ct.as_u64())
284                    .unwrap_or(0) as u32,
285                total_tokens: u
286                    .get("total_tokens")
287                    .and_then(|tt| tt.as_u64())
288                    .unwrap_or(0) as u32,
289            });
290
291        Ok(LLMResponse {
292            content,
293            tool_calls,
294            usage,
295            finish_reason,
296        })
297    }
298}
299
300#[async_trait]
301impl LLMClient for OpenAIProvider {
302    async fn generate(&mut self, _prompt: &str) -> Result<llm_types::LLMResponse, LLMError> {
303        let model = models::openai::DEFAULT_MODEL.to_string();
304
305        // Validate the model
306        if !model_helpers::is_valid("openai", &model) {
307            return Err(LLMError::InvalidRequest(format!(
308                "Invalid OpenAI model '{}'. See docs/models.json",
309                model
310            )));
311        }
312
313        let request = LLMRequest {
314            messages: vec![Message::user("test".to_string())],
315            system_prompt: None,
316            tools: None,
317            model: "test".to_string(),
318            max_tokens: Some(100),
319            temperature: None,
320            stream: false,
321            tool_choice: None,
322            parallel_tool_calls: None,
323            parallel_tool_config: None,
324            reasoning_effort: None,
325        };
326
327        let response = LLMProvider::generate(self, request.clone()).await?;
328
329        Ok(llm_types::LLMResponse {
330            content: response.content.unwrap_or("".to_string()),
331            model,
332            usage: response.usage.map(|u| llm_types::Usage {
333                prompt_tokens: u.prompt_tokens as usize,
334                completion_tokens: u.completion_tokens as usize,
335                total_tokens: u.total_tokens as usize,
336            }),
337        })
338    }
339
340    fn backend_kind(&self) -> llm_types::BackendKind {
341        llm_types::BackendKind::OpenAI
342    }
343
344    fn model_id(&self) -> &str {
345        models::openai::DEFAULT_MODEL
346    }
347}