Skip to main content

rustic_ai/providers/
grok.rs

1use std::sync::Arc;
2
3use reqwest::Url;
4
5use crate::model::{Model, ModelSettings};
6use crate::providers::openai::{OpenAIChatCapabilities, OpenAIChatModel};
7use crate::providers::{Provider, ProviderError};
8
9#[derive(Clone, Debug)]
10pub struct GrokProvider {
11    api_key: String,
12    base_url: Url,
13}
14
15impl GrokProvider {
16    pub fn new(
17        api_key: impl Into<String>,
18        base_url: impl AsRef<str>,
19    ) -> Result<Self, ProviderError> {
20        let url = Url::parse(base_url.as_ref())
21            .map_err(|_| ProviderError::InvalidModel(base_url.as_ref().to_string()))?;
22        Ok(Self {
23            api_key: api_key.into(),
24            base_url: url,
25        })
26    }
27
28    pub fn from_env() -> Result<Self, ProviderError> {
29        let api_key = std::env::var("XAI_API_KEY")
30            .or_else(|_| std::env::var("GROK_API_KEY"))
31            .map_err(|_| ProviderError::MissingApiKey("grok".to_string()))?;
32        Self::new(api_key, "https://api.x.ai/v1")
33    }
34}
35
36impl Provider for GrokProvider {
37    fn name(&self) -> &str {
38        "grok"
39    }
40
41    fn model(&self, model: &str, settings: Option<ModelSettings>) -> Arc<dyn Model> {
42        let capabilities = OpenAIChatCapabilities {
43            supports_response_format: false,
44            supports_parallel_tool_calls: true,
45            reject_binary_images: true,
46        };
47        Arc::new(OpenAIChatModel::new_with_capabilities(
48            model,
49            self.api_key.clone(),
50            self.base_url.clone(),
51            settings,
52            capabilities,
53        ))
54    }
55}
56
#[cfg(test)]
mod tests {
    use super::*;

    /// An unparsable base URL surfaces as `ProviderError::InvalidModel`.
    #[test]
    fn grok_provider_rejects_invalid_url() {
        let outcome = GrokProvider::new("key", "not a url");
        let error = outcome.expect_err("invalid url");
        assert!(matches!(error, ProviderError::InvalidModel(_)));
    }

    /// A valid provider reports its name and builds a model with the
    /// requested name.
    #[test]
    fn grok_provider_builds_model() {
        let grok = GrokProvider::new("key", "https://api.x.ai/v1").expect("valid provider");
        assert_eq!(grok.name(), "grok");

        let built = grok.model("grok-test", None);
        assert_eq!(built.name(), "grok-test");
    }
}