vtcode_core/llm/providers/xai.rs

use crate::config::constants::{models, urls};
use crate::config::core::PromptCachingConfig;
use crate::llm::client::LLMClient;
use crate::llm::error_display;
use crate::llm::provider::{LLMError, LLMProvider, LLMRequest, LLMResponse};
use crate::llm::providers::openai::OpenAIProvider;
use crate::llm::types as llm_types;
use async_trait::async_trait;

/// xAI provider that leverages the OpenAI-compatible Grok API surface.
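///
/// A minimal construction sketch; the module path and the `XAI_API_KEY`
/// variable name are assumptions based on this crate's layout rather than a
/// documented contract:
///
/// ```ignore
/// use vtcode_core::llm::providers::xai::XAIProvider;
///
/// // Uses the crate's default Grok model and the default xAI base URL.
/// let provider = XAIProvider::new(
///     std::env::var("XAI_API_KEY").expect("XAI_API_KEY must be set"),
/// );
/// ```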
pub struct XAIProvider {
    inner: OpenAIProvider,
    model: String,
    prompt_cache_enabled: bool,
}

impl XAIProvider {
    pub fn new(api_key: String) -> Self {
        Self::with_model_internal(api_key, models::xai::DEFAULT_MODEL.to_string(), None)
    }

    pub fn with_model(api_key: String, model: String) -> Self {
        Self::with_model_internal(api_key, model, None)
    }

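    /// Builds a provider from optional configuration values, falling back to the
    /// crate defaults for the model and the xAI API base URL when `None` is given.
    ///
    /// A minimal sketch; the environment-variable lookup is illustrative only:
    ///
    /// ```ignore
    /// let provider = XAIProvider::from_config(
    ///     std::env::var("XAI_API_KEY").ok(), // assumed variable name
    ///     None, // fall back to models::xai::DEFAULT_MODEL
    ///     None, // fall back to urls::XAI_API_BASE
    ///     None, // treat prompt caching as enabled, forward no config
    /// );
    /// ```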
    pub fn from_config(
        api_key: Option<String>,
        model: Option<String>,
        base_url: Option<String>,
        prompt_cache: Option<PromptCachingConfig>,
    ) -> Self {
        let resolved_model = model.unwrap_or_else(|| models::xai::DEFAULT_MODEL.to_string());
        let resolved_base_url = base_url.unwrap_or_else(|| urls::XAI_API_BASE.to_string());
        let (prompt_cache_enabled, prompt_cache_forward) =
            Self::extract_prompt_cache_settings(prompt_cache);
        let inner = OpenAIProvider::from_config(
            api_key,
            Some(resolved_model.clone()),
            Some(resolved_base_url),
            prompt_cache_forward,
        );

        Self {
            inner,
            model: resolved_model,
            prompt_cache_enabled,
        }
    }

    fn with_model_internal(
        api_key: String,
        model: String,
        prompt_cache: Option<PromptCachingConfig>,
    ) -> Self {
        Self::from_config(Some(api_key), Some(model), None, prompt_cache)
    }

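    /// Resolves whether prompt caching is effectively enabled for xAI and which
    /// configuration, if any, should be forwarded to the inner OpenAI-compatible
    /// provider. Both the global flag and the xAI provider flag must be set for
    /// the config to be forwarded; with no config at all, caching is treated as
    /// enabled but nothing is forwarded.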
    fn extract_prompt_cache_settings(
        prompt_cache: Option<PromptCachingConfig>,
    ) -> (bool, Option<PromptCachingConfig>) {
        if let Some(cfg) = prompt_cache {
            let provider_enabled = cfg.providers.xai.enabled;
            let enabled = cfg.enabled && provider_enabled;
            if enabled {
                (true, Some(cfg))
            } else {
                (false, None)
            }
        } else {
            (true, None)
        }
    }
}

#[async_trait]
impl LLMProvider for XAIProvider {
    fn name(&self) -> &str {
        "xai"
    }

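    /// Reasoning support is only advertised for the Grok 4 family; an empty model
    /// string falls back to the provider's configured default model.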
    fn supports_reasoning(&self, model: &str) -> bool {
        let requested = if model.trim().is_empty() {
            self.model.as_str()
        } else {
            model
        };

        requested == models::xai::GROK_4
            || requested == models::xai::GROK_4_CODE
            || requested == models::xai::GROK_4_CODE_LATEST
    }

    fn supports_reasoning_effort(&self, _model: &str) -> bool {
        false
    }

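    /// Fills in the configured default model when the request leaves it empty,
    /// then delegates the call to the inner OpenAI-compatible provider.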
    async fn generate(&self, mut request: LLMRequest) -> Result<LLMResponse, LLMError> {
        // xAI prompt caching is managed by the platform, so no additional request
        // parameters are needed regardless of `prompt_cache_enabled`.
        let _ = self.prompt_cache_enabled;

        if request.model.trim().is_empty() {
            request.model = self.model.clone();
        }
        self.inner.generate(request).await
    }

    fn supported_models(&self) -> Vec<String> {
        models::xai::SUPPORTED_MODELS
            .iter()
            .map(|s| s.to_string())
            .collect()
    }

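    /// Rejects empty conversations and unsupported model ids, then validates each
    /// message against the OpenAI message rules, since the Grok API surface is
    /// OpenAI-compatible.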
    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
        if request.messages.is_empty() {
            let formatted = error_display::format_llm_error("xAI", "Messages cannot be empty");
            return Err(LLMError::InvalidRequest(formatted));
        }

        if !request.model.trim().is_empty() && !self.supported_models().contains(&request.model) {
            let formatted = error_display::format_llm_error(
                "xAI",
                &format!("Unsupported model: {}", request.model),
            );
            return Err(LLMError::InvalidRequest(formatted));
        }

        for message in &request.messages {
            if let Err(err) = message.validate_for_provider("openai") {
                let formatted = error_display::format_llm_error("xAI", &err);
                return Err(LLMError::InvalidRequest(formatted));
            }
        }

        Ok(())
    }
}

#[async_trait]
impl LLMClient for XAIProvider {
    async fn generate(&mut self, prompt: &str) -> Result<llm_types::LLMResponse, LLMError> {
        <OpenAIProvider as LLMClient>::generate(&mut self.inner, prompt).await
    }

    fn backend_kind(&self) -> llm_types::BackendKind {
        llm_types::BackendKind::XAI
    }

    fn model_id(&self) -> &str {
        &self.model
    }
}
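
// A minimal sketch of the prompt-cache resolution above: with no configuration,
// caching is reported as enabled but no config is forwarded to the inner
// OpenAI-compatible provider.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn prompt_cache_defaults_to_enabled_without_config() {
        let (enabled, forwarded) = XAIProvider::extract_prompt_cache_settings(None);
        assert!(enabled);
        assert!(forwarded.is_none());
    }
}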