vtcode_core/llm/providers/xai.rs

1use crate::config::constants::{models, urls};
2use crate::llm::client::LLMClient;
3use crate::llm::error_display;
4use crate::llm::provider::{LLMError, LLMProvider, LLMRequest, LLMResponse};
5use crate::llm::providers::openai::OpenAIProvider;
6use crate::llm::types as llm_types;
7use async_trait::async_trait;
8
9/// xAI provider that leverages the OpenAI-compatible Grok API surface
10pub struct XAIProvider {
11    inner: OpenAIProvider,
12    model: String,
13}
14
15impl XAIProvider {
16    pub fn new(api_key: String) -> Self {
17        Self::with_model(api_key, models::xai::DEFAULT_MODEL.to_string())
18    }
19
20    pub fn with_model(api_key: String, model: String) -> Self {
21        Self::from_config(Some(api_key), Some(model), None)
22    }
23
24    pub fn from_config(
25        api_key: Option<String>,
26        model: Option<String>,
27        base_url: Option<String>,
28    ) -> Self {
29        let resolved_model = model.unwrap_or_else(|| models::xai::DEFAULT_MODEL.to_string());
30        let resolved_base_url = base_url.unwrap_or_else(|| urls::XAI_API_BASE.to_string());
31        let inner = OpenAIProvider::from_config(
32            api_key,
33            Some(resolved_model.clone()),
34            Some(resolved_base_url),
35        );
36
37        Self {
38            inner,
39            model: resolved_model,
40        }
41    }
42}
43
44#[async_trait]
45impl LLMProvider for XAIProvider {
46    fn name(&self) -> &str {
47        "xai"
48    }
49
50    fn supports_reasoning(&self, model: &str) -> bool {
51        let requested = if model.trim().is_empty() {
52            self.model.as_str()
53        } else {
54            model
55        };
56        requested.contains("reasoning")
57    }
58
59    fn supports_reasoning_effort(&self, _model: &str) -> bool {
60        false
61    }
62
63    async fn generate(&self, mut request: LLMRequest) -> Result<LLMResponse, LLMError> {
64        if request.model.trim().is_empty() {
65            request.model = self.model.clone();
66        }
67        self.inner.generate(request).await
68    }
69
70    fn supported_models(&self) -> Vec<String> {
71        models::xai::SUPPORTED_MODELS
72            .iter()
73            .map(|s| s.to_string())
74            .collect()
75    }
76
77    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
78        if request.messages.is_empty() {
79            let formatted = error_display::format_llm_error("xAI", "Messages cannot be empty");
80            return Err(LLMError::InvalidRequest(formatted));
81        }
82
83        if !request.model.trim().is_empty() && !self.supported_models().contains(&request.model) {
84            let formatted = error_display::format_llm_error(
85                "xAI",
86                &format!("Unsupported model: {}", request.model),
87            );
88            return Err(LLMError::InvalidRequest(formatted));
89        }
90
91        for message in &request.messages {
92            if let Err(err) = message.validate_for_provider("openai") {
93                let formatted = error_display::format_llm_error("xAI", &err);
94                return Err(LLMError::InvalidRequest(formatted));
95            }
96        }
97
98        Ok(())
99    }
100}
101
102#[async_trait]
103impl LLMClient for XAIProvider {
104    async fn generate(&mut self, prompt: &str) -> Result<llm_types::LLMResponse, LLMError> {
105        <OpenAIProvider as LLMClient>::generate(&mut self.inner, prompt).await
106    }
107
108    fn backend_kind(&self) -> llm_types::BackendKind {
109        llm_types::BackendKind::XAI
110    }
111
112    fn model_id(&self) -> &str {
113        &self.model
114    }
115}