vtcode_core/llm/providers/
xai.rs1use crate::config::constants::{models, urls};
2use crate::config::core::PromptCachingConfig;
3use crate::llm::client::LLMClient;
4use crate::llm::error_display;
5use crate::llm::provider::{LLMError, LLMProvider, LLMRequest, LLMResponse};
6use crate::llm::providers::openai::OpenAIProvider;
7use crate::llm::types as llm_types;
8use async_trait::async_trait;
9
10pub struct XAIProvider {
12 inner: OpenAIProvider,
13 model: String,
14 prompt_cache_enabled: bool,
15}
16
17impl XAIProvider {
18 pub fn new(api_key: String) -> Self {
19 Self::with_model_internal(api_key, models::xai::DEFAULT_MODEL.to_string(), None)
20 }
21
22 pub fn with_model(api_key: String, model: String) -> Self {
23 Self::with_model_internal(api_key, model, None)
24 }
25
26 pub fn from_config(
27 api_key: Option<String>,
28 model: Option<String>,
29 base_url: Option<String>,
30 prompt_cache: Option<PromptCachingConfig>,
31 ) -> Self {
32 let resolved_model = model.unwrap_or_else(|| models::xai::DEFAULT_MODEL.to_string());
33 let resolved_base_url = base_url.unwrap_or_else(|| urls::XAI_API_BASE.to_string());
34 let (prompt_cache_enabled, prompt_cache_forward) =
35 Self::extract_prompt_cache_settings(prompt_cache);
36 let inner = OpenAIProvider::from_config(
37 api_key,
38 Some(resolved_model.clone()),
39 Some(resolved_base_url),
40 prompt_cache_forward,
41 );
42
43 Self {
44 inner,
45 model: resolved_model,
46 prompt_cache_enabled,
47 }
48 }
49
50 fn with_model_internal(
51 api_key: String,
52 model: String,
53 prompt_cache: Option<PromptCachingConfig>,
54 ) -> Self {
55 Self::from_config(Some(api_key), Some(model), None, prompt_cache)
56 }
57
58 fn extract_prompt_cache_settings(
59 prompt_cache: Option<PromptCachingConfig>,
60 ) -> (bool, Option<PromptCachingConfig>) {
61 if let Some(cfg) = prompt_cache {
62 let provider_enabled = cfg.providers.xai.enabled;
63 let enabled = cfg.enabled && provider_enabled;
64 if enabled {
65 (true, Some(cfg))
66 } else {
67 (false, None)
68 }
69 } else {
70 (true, None)
71 }
72 }
73}
74
75#[async_trait]
76impl LLMProvider for XAIProvider {
77 fn name(&self) -> &str {
78 "xai"
79 }
80
81 fn supports_reasoning(&self, model: &str) -> bool {
82 let requested = if model.trim().is_empty() {
83 self.model.as_str()
84 } else {
85 model
86 };
87
88 requested == models::xai::GROK_4
89 || requested == models::xai::GROK_4_CODE
90 || requested == models::xai::GROK_4_CODE_LATEST
91 }
92
93 fn supports_reasoning_effort(&self, _model: &str) -> bool {
94 false
95 }
96
97 async fn generate(&self, mut request: LLMRequest) -> Result<LLMResponse, LLMError> {
98 if !self.prompt_cache_enabled {
99 }
101
102 if request.model.trim().is_empty() {
103 request.model = self.model.clone();
104 }
105 self.inner.generate(request).await
106 }
107
108 fn supported_models(&self) -> Vec<String> {
109 models::xai::SUPPORTED_MODELS
110 .iter()
111 .map(|s| s.to_string())
112 .collect()
113 }
114
115 fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
116 if request.messages.is_empty() {
117 let formatted = error_display::format_llm_error("xAI", "Messages cannot be empty");
118 return Err(LLMError::InvalidRequest(formatted));
119 }
120
121 if !request.model.trim().is_empty() && !self.supported_models().contains(&request.model) {
122 let formatted = error_display::format_llm_error(
123 "xAI",
124 &format!("Unsupported model: {}", request.model),
125 );
126 return Err(LLMError::InvalidRequest(formatted));
127 }
128
129 for message in &request.messages {
130 if let Err(err) = message.validate_for_provider("openai") {
131 let formatted = error_display::format_llm_error("xAI", &err);
132 return Err(LLMError::InvalidRequest(formatted));
133 }
134 }
135
136 Ok(())
137 }
138}
139
140#[async_trait]
141impl LLMClient for XAIProvider {
142 async fn generate(&mut self, prompt: &str) -> Result<llm_types::LLMResponse, LLMError> {
143 <OpenAIProvider as LLMClient>::generate(&mut self.inner, prompt).await
144 }
145
146 fn backend_kind(&self) -> llm_types::BackendKind {
147 llm_types::BackendKind::XAI
148 }
149
150 fn model_id(&self) -> &str {
151 &self.model
152 }
153}