vtcode_core/llm/providers/openai.rs

use crate::config::constants::{models, urls};
use crate::llm::client::LLMClient;
use crate::llm::error_display;
use crate::llm::provider::{
    FinishReason, LLMError, LLMProvider, LLMRequest, LLMResponse, Message, MessageRole, ToolCall,
    ToolChoice, ToolDefinition,
};
use crate::llm::types as llm_types;
use async_trait::async_trait;
use reqwest::Client as HttpClient;
use serde_json::{Value, json};

use super::extract_reasoning_trace;

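/// OpenAI provider that issues requests to the Chat Completions endpoint
/// (`{base_url}/chat/completions`) over HTTP.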
pub struct OpenAIProvider {
    api_key: String,
    http_client: HttpClient,
    base_url: String,
    model: String,
}

impl OpenAIProvider {
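    /// Creates a provider that uses the default OpenAI model.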
    pub fn new(api_key: String) -> Self {
        Self::with_model(api_key, models::openai::DEFAULT_MODEL.to_string())
    }

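    /// Creates a provider for a specific model, using the default API base URL
    /// from `urls::OPENAI_API_BASE`.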
    pub fn with_model(api_key: String, model: String) -> Self {
        Self {
            api_key,
            http_client: HttpClient::new(),
            base_url: urls::OPENAI_API_BASE.to_string(),
            model,
        }
    }

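    /// Builds a provider from optional configuration values. A missing API key
    /// falls back to an empty string; a missing model or base URL falls back to
    /// the defaults used by [`OpenAIProvider::new`].
    ///
    /// Illustrative sketch (the model name here is only a placeholder):
    ///
    /// ```ignore
    /// let provider = OpenAIProvider::from_config(
    ///     std::env::var("OPENAI_API_KEY").ok(),
    ///     Some("gpt-4o-mini".to_string()),
    ///     None,
    /// );
    /// ```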
    pub fn from_config(
        api_key: Option<String>,
        model: Option<String>,
        base_url: Option<String>,
    ) -> Self {
        let api_key_value = api_key.unwrap_or_default();
        let mut provider = if let Some(model_value) = model {
            Self::with_model(api_key_value, model_value)
        } else {
            Self::new(api_key_value)
        };
        if let Some(base) = base_url {
            provider.base_url = base;
        }
        provider
    }

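    /// Wraps a plain text prompt in a minimal single-user-message request for the
    /// configured model.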
    fn default_request(&self, prompt: &str) -> LLMRequest {
        LLMRequest {
            messages: vec![Message::user(prompt.to_string())],
            system_prompt: None,
            tools: None,
            model: self.model.clone(),
            max_tokens: None,
            temperature: None,
            stream: false,
            tool_choice: None,
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        }
    }

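    /// Interprets the prompt as a serialized chat request when it is a JSON object,
    /// falling back to a plain single-message request otherwise.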
    fn parse_client_prompt(&self, prompt: &str) -> LLMRequest {
        let trimmed = prompt.trim_start();
        if trimmed.starts_with('{') {
            if let Ok(value) = serde_json::from_str::<Value>(trimmed) {
                if let Some(request) = self.parse_chat_request(&value) {
                    return request;
                }
            }
        }

        self.default_request(prompt)
    }

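    /// Converts an OpenAI-style chat completion payload into an `LLMRequest`.
    /// Returns `None` when no messages can be recovered from the payload.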
    fn parse_chat_request(&self, value: &Value) -> Option<LLMRequest> {
        let messages_value = value.get("messages")?.as_array()?;
        let mut system_prompt = None;
        let mut messages = Vec::new();

        for entry in messages_value {
            let role = entry
                .get("role")
                .and_then(|r| r.as_str())
                .unwrap_or(crate::config::constants::message_roles::USER);
            let content = entry.get("content");
            let text_content = content.map(Self::extract_content_text).unwrap_or_default();

            match role {
                "system" => {
                    if system_prompt.is_none() && !text_content.is_empty() {
                        system_prompt = Some(text_content);
                    }
                }
                "assistant" => {
                    let tool_calls = entry
                        .get("tool_calls")
                        .and_then(|tc| tc.as_array())
                        .map(|calls| {
                            calls
                                .iter()
                                .filter_map(|call| {
                                    let id = call.get("id").and_then(|v| v.as_str())?;
                                    let function = call.get("function")?;
                                    let name = function.get("name").and_then(|v| v.as_str())?;
                                    let arguments = function.get("arguments");
                                    let serialized = arguments.map_or("{}".to_string(), |value| {
                                        if value.is_string() {
                                            value.as_str().unwrap_or("").to_string()
                                        } else {
                                            value.to_string()
                                        }
                                    });
                                    Some(ToolCall::function(
                                        id.to_string(),
                                        name.to_string(),
                                        serialized,
                                    ))
                                })
                                .collect::<Vec<_>>()
                        })
                        .filter(|calls| !calls.is_empty());

                    let message = if let Some(calls) = tool_calls {
                        Message {
                            role: MessageRole::Assistant,
                            content: text_content,
                            tool_calls: Some(calls),
                            tool_call_id: None,
                        }
                    } else {
                        Message::assistant(text_content)
                    };
                    messages.push(message);
                }
                "tool" => {
                    let tool_call_id = entry
                        .get("tool_call_id")
                        .and_then(|id| id.as_str())
                        .map(|s| s.to_string());
                    let content_value = entry
                        .get("content")
                        .map(|value| {
                            if text_content.is_empty() {
                                value.to_string()
                            } else {
                                text_content.clone()
                            }
                        })
                        .unwrap_or_else(|| text_content.clone());
                    messages.push(Message {
                        role: MessageRole::Tool,
                        content: content_value,
                        tool_calls: None,
                        tool_call_id,
                    });
                }
                _ => {
                    messages.push(Message::user(text_content));
                }
            }
        }

        if messages.is_empty() {
            return None;
        }

        let tools = value.get("tools").and_then(|tools_value| {
            let tools_array = tools_value.as_array()?;
            let converted: Vec<_> = tools_array
                .iter()
                .filter_map(|tool| {
                    let function = tool.get("function")?;
                    let name = function.get("name").and_then(|n| n.as_str())?;
                    let description = function
                        .get("description")
                        .and_then(|d| d.as_str())
                        .unwrap_or("")
                        .to_string();
                    let parameters = function
                        .get("parameters")
                        .cloned()
                        .unwrap_or_else(|| json!({}));
                    Some(ToolDefinition::function(
                        name.to_string(),
                        description,
                        parameters,
                    ))
                })
                .collect();

            if converted.is_empty() {
                None
            } else {
                Some(converted)
            }
        });

        let max_tokens = value
            .get("max_tokens")
            .and_then(|v| v.as_u64())
            .map(|v| v as u32);
        let temperature = value
            .get("temperature")
            .and_then(|v| v.as_f64())
            .map(|v| v as f32);
        let stream = value
            .get("stream")
            .and_then(|v| v.as_bool())
            .unwrap_or(false);
        let tool_choice = value.get("tool_choice").and_then(Self::parse_tool_choice);
        let parallel_tool_calls = value.get("parallel_tool_calls").and_then(|v| v.as_bool());
        let reasoning_effort = value
            .get("reasoning_effort")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string())
            .or_else(|| {
                value
                    .get("reasoning")
                    .and_then(|r| r.get("effort"))
                    .and_then(|effort| effort.as_str())
                    .map(|s| s.to_string())
            });

        let model = value
            .get("model")
            .and_then(|m| m.as_str())
            .unwrap_or(&self.model)
            .to_string();

        Some(LLMRequest {
            messages,
            system_prompt,
            tools,
            model,
            max_tokens,
            temperature,
            stream,
            tool_choice,
            parallel_tool_calls,
            parallel_tool_config: None,
            reasoning_effort,
        })
    }

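    /// Flattens message content into plain text, handling both the string form and
    /// the multi-part array form.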
    fn extract_content_text(content: &Value) -> String {
        match content {
            Value::String(text) => text.to_string(),
            Value::Array(parts) => parts
                .iter()
                .filter_map(|part| {
                    if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
                        Some(text.to_string())
                    } else if let Some(Value::String(text)) = part.get("content") {
                        Some(text.clone())
                    } else {
                        None
                    }
                })
                .collect::<Vec<_>>()
                .join(""),
            _ => String::new(),
        }
    }

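    /// Maps an OpenAI `tool_choice` value, in either its string or object form,
    /// onto the internal `ToolChoice` representation.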
    fn parse_tool_choice(choice: &Value) -> Option<ToolChoice> {
        match choice {
            Value::String(value) => match value.as_str() {
                "auto" => Some(ToolChoice::auto()),
                "none" => Some(ToolChoice::none()),
                "required" => Some(ToolChoice::any()),
                _ => None,
            },
            Value::Object(map) => {
                let choice_type = map.get("type").and_then(|t| t.as_str())?;
                match choice_type {
                    "function" => map
                        .get("function")
                        .and_then(|f| f.get("name"))
                        .and_then(|n| n.as_str())
                        .map(|name| ToolChoice::function(name.to_string())),
                    "auto" => Some(ToolChoice::auto()),
                    "none" => Some(ToolChoice::none()),
                    "any" | "required" => Some(ToolChoice::any()),
                    _ => None,
                }
            }
            _ => None,
        }
    }

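    /// Builds the Chat Completions request body from an `LLMRequest`, including
    /// tools, tool choice, and optional sampling and reasoning parameters.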
    fn convert_to_openai_format(&self, request: &LLMRequest) -> Result<Value, LLMError> {
        let mut messages = Vec::new();

        // The system prompt, when present, becomes a leading system message.
        if let Some(system_prompt) = &request.system_prompt {
            messages.push(json!({
                "role": crate::config::constants::message_roles::SYSTEM,
                "content": system_prompt
            }));
        }

        for msg in &request.messages {
            let role = msg.role.as_openai_str();
            let mut message = json!({
                "role": role,
                "content": msg.content
            });

            // Assistant messages carry their tool calls; tool messages carry the id
            // of the call they answer.
            if msg.role == MessageRole::Assistant {
                if let Some(tool_calls) = &msg.tool_calls {
                    if !tool_calls.is_empty() {
                        let tool_calls_json: Vec<Value> = tool_calls
                            .iter()
                            .map(|tc| {
                                json!({
                                    "id": tc.id,
                                    "type": "function",
                                    "function": {
                                        "name": tc.function.name,
                                        "arguments": tc.function.arguments
                                    }
                                })
                            })
                            .collect();
                        message["tool_calls"] = Value::Array(tool_calls_json);
                    }
                }
            }

            if msg.role == MessageRole::Tool {
                if let Some(tool_call_id) = &msg.tool_call_id {
                    message["tool_call_id"] = Value::String(tool_call_id.clone());
                }
            }

            messages.push(message);
        }

        if messages.is_empty() {
            let formatted_error = error_display::format_llm_error("OpenAI", "No messages provided");
            return Err(LLMError::InvalidRequest(formatted_error));
        }

        let mut openai_request = json!({
            "model": request.model,
            "messages": messages,
            "stream": request.stream
        });

        if let Some(max_tokens) = request.max_tokens {
            openai_request["max_tokens"] = json!(max_tokens);
        }

        if let Some(temperature) = request.temperature {
            openai_request["temperature"] = json!(temperature);
        }

        if let Some(tools) = &request.tools {
            if !tools.is_empty() {
                let tools_json: Vec<Value> = tools
                    .iter()
                    .map(|tool| {
                        json!({
                            "type": "function",
                            "function": {
                                "name": tool.function.name,
                                "description": tool.function.description,
                                "parameters": tool.function.parameters
                            }
                        })
                    })
                    .collect();
                openai_request["tools"] = Value::Array(tools_json);
            }
        }

        if let Some(tool_choice) = &request.tool_choice {
            openai_request["tool_choice"] = tool_choice.to_provider_format("openai");
        }

        if let Some(parallel) = request.parallel_tool_calls {
            openai_request["parallel_tool_calls"] = Value::Bool(parallel);
        }

        // Reasoning effort is only forwarded for models that support it.
        if let Some(effort) = request.reasoning_effort.as_deref() {
            if self.supports_reasoning_effort(&request.model) {
                openai_request["reasoning"] = json!({ "effort": effort });
            }
        }

        Ok(openai_request)
    }

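    /// Extracts the content, tool calls, reasoning trace, finish reason, and token
    /// usage from a Chat Completions response body.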
    fn parse_openai_response(&self, response_json: Value) -> Result<LLMResponse, LLMError> {
        let choices = response_json
            .get("choices")
            .and_then(|c| c.as_array())
            .ok_or_else(|| {
                let formatted_error = error_display::format_llm_error(
                    "OpenAI",
                    "Invalid response format: missing choices",
                );
                LLMError::Provider(formatted_error)
            })?;

        if choices.is_empty() {
            let formatted_error =
                error_display::format_llm_error("OpenAI", "No choices in response");
            return Err(LLMError::Provider(formatted_error));
        }

        // Only the first choice is used.
        let choice = &choices[0];
        let message = choice.get("message").ok_or_else(|| {
            let formatted_error = error_display::format_llm_error(
                "OpenAI",
                "Invalid response format: missing message",
            );
            LLMError::Provider(formatted_error)
        })?;

        // Content may be a plain string or an array of text parts.
        let content = match message.get("content") {
            Some(Value::String(text)) => Some(text.to_string()),
            Some(Value::Array(parts)) => {
                let text = parts
                    .iter()
                    .filter_map(|part| part.get("text").and_then(|t| t.as_str()))
                    .collect::<Vec<_>>()
                    .join("");
                if text.is_empty() { None } else { Some(text) }
            }
            _ => None,
        };

        let tool_calls = message
            .get("tool_calls")
            .and_then(|tc| tc.as_array())
            .map(|calls| {
                calls
                    .iter()
                    .filter_map(|call| {
                        let id = call.get("id").and_then(|v| v.as_str())?;
                        let function = call.get("function")?;
                        let name = function.get("name").and_then(|v| v.as_str())?;
                        let arguments = function.get("arguments");
                        let serialized = arguments.map_or("{}".to_string(), |value| {
                            if value.is_string() {
                                value.as_str().unwrap_or("").to_string()
                            } else {
                                value.to_string()
                            }
                        });
                        Some(ToolCall::function(
                            id.to_string(),
                            name.to_string(),
                            serialized,
                        ))
                    })
                    .collect::<Vec<_>>()
            })
            .filter(|calls| !calls.is_empty());

        // Reasoning traces may appear on the message or on the choice itself.
        let reasoning = message
            .get("reasoning")
            .and_then(extract_reasoning_trace)
            .or_else(|| choice.get("reasoning").and_then(extract_reasoning_trace));

        let finish_reason = choice
            .get("finish_reason")
            .and_then(|fr| fr.as_str())
            .map(|fr| match fr {
                "stop" => FinishReason::Stop,
                "length" => FinishReason::Length,
                "tool_calls" => FinishReason::ToolCalls,
                "content_filter" => FinishReason::ContentFilter,
                other => FinishReason::Error(other.to_string()),
            })
            .unwrap_or(FinishReason::Stop);

        let usage = response_json
            .get("usage")
            .map(|usage_value| crate::llm::provider::Usage {
                prompt_tokens: usage_value
                    .get("prompt_tokens")
                    .and_then(|pt| pt.as_u64())
                    .unwrap_or(0) as u32,
                completion_tokens: usage_value
                    .get("completion_tokens")
                    .and_then(|ct| ct.as_u64())
                    .unwrap_or(0) as u32,
                total_tokens: usage_value
                    .get("total_tokens")
                    .and_then(|tt| tt.as_u64())
                    .unwrap_or(0) as u32,
            });

        Ok(LLMResponse {
            content,
            tool_calls,
            usage,
            finish_reason,
            reasoning,
        })
    }
}

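/// `LLMProvider` implementation: capability queries, request validation, and
/// non-streaming generation against the Chat Completions API.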
#[async_trait]
impl LLMProvider for OpenAIProvider {
    fn name(&self) -> &str {
        "openai"
    }

    fn supports_reasoning(&self, _model: &str) -> bool {
        false
    }

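    /// Returns true when the requested model (falling back to the configured
    /// default for an empty argument) is listed in `models::openai::REASONING_MODELS`.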
    fn supports_reasoning_effort(&self, model: &str) -> bool {
        let requested = if model.trim().is_empty() {
            self.model.as_str()
        } else {
            model
        };
        models::openai::REASONING_MODELS
            .iter()
            .any(|candidate| *candidate == requested)
    }

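    /// Sends the request to `{base_url}/chat/completions` and parses the response,
    /// mapping rate-limit and quota failures to `LLMError::RateLimit`.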
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse, LLMError> {
        let openai_request = self.convert_to_openai_format(&request)?;
        let url = format!("{}/chat/completions", self.base_url);

        let response = self
            .http_client
            .post(&url)
            .bearer_auth(&self.api_key)
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| {
                let formatted_error =
                    error_display::format_llm_error("OpenAI", &format!("Network error: {}", e));
                LLMError::Network(formatted_error)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();

            // Treat HTTP 429 and quota or rate-limit error bodies as LLMError::RateLimit.
            if status.as_u16() == 429
                || error_text.contains("insufficient_quota")
                || error_text.contains("quota")
                || error_text.contains("rate limit")
            {
                return Err(LLMError::RateLimit);
            }

            let formatted_error = error_display::format_llm_error(
                "OpenAI",
                &format!("HTTP {}: {}", status, error_text),
            );
            return Err(LLMError::Provider(formatted_error));
        }

        let openai_response: Value = response.json().await.map_err(|e| {
            let formatted_error = error_display::format_llm_error(
                "OpenAI",
                &format!("Failed to parse response: {}", e),
            );
            LLMError::Provider(formatted_error)
        })?;

        self.parse_openai_response(openai_response)
    }

    fn supported_models(&self) -> Vec<String> {
        models::openai::SUPPORTED_MODELS
            .iter()
            .map(|s| s.to_string())
            .collect()
    }

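    /// Rejects empty message lists, unsupported models, and messages that fail
    /// provider-specific validation.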
    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
        if request.messages.is_empty() {
            let formatted_error =
                error_display::format_llm_error("OpenAI", "Messages cannot be empty");
            return Err(LLMError::InvalidRequest(formatted_error));
        }

        if !self.supported_models().contains(&request.model) {
            let formatted_error = error_display::format_llm_error(
                "OpenAI",
                &format!("Unsupported model: {}", request.model),
            );
            return Err(LLMError::InvalidRequest(formatted_error));
        }

        for message in &request.messages {
            if let Err(err) = message.validate_for_provider("openai") {
                let formatted = error_display::format_llm_error("OpenAI", &err);
                return Err(LLMError::InvalidRequest(formatted));
            }
        }

        Ok(())
    }
}

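/// `LLMClient` implementation that accepts raw prompt strings (plain text or a
/// serialized chat request) and returns the simplified client response type.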
#[async_trait]
impl LLMClient for OpenAIProvider {
    async fn generate(&mut self, prompt: &str) -> Result<llm_types::LLMResponse, LLMError> {
        let request = self.parse_client_prompt(prompt);
        let request_model = request.model.clone();
        let response = LLMProvider::generate(self, request).await?;

        Ok(llm_types::LLMResponse {
            content: response.content.unwrap_or_default(),
            model: request_model,
            usage: response.usage.map(|u| llm_types::Usage {
                prompt_tokens: u.prompt_tokens as usize,
                completion_tokens: u.completion_tokens as usize,
                total_tokens: u.total_tokens as usize,
            }),
            reasoning: response.reasoning,
        })
    }

    fn backend_kind(&self) -> llm_types::BackendKind {
        llm_types::BackendKind::OpenAI
    }

    fn model_id(&self) -> &str {
        &self.model
    }
}