rustic_ai/providers/
grok.rs1use std::sync::Arc;
2
3use reqwest::Url;
4
5use crate::model::{Model, ModelSettings};
6use crate::providers::openai::{OpenAIChatCapabilities, OpenAIChatModel};
7use crate::providers::{Provider, ProviderError};
8
9#[derive(Clone, Debug)]
10pub struct GrokProvider {
11 api_key: String,
12 base_url: Url,
13}
14
15impl GrokProvider {
16 pub fn new(
17 api_key: impl Into<String>,
18 base_url: impl AsRef<str>,
19 ) -> Result<Self, ProviderError> {
20 let url = Url::parse(base_url.as_ref())
21 .map_err(|_| ProviderError::InvalidModel(base_url.as_ref().to_string()))?;
22 Ok(Self {
23 api_key: api_key.into(),
24 base_url: url,
25 })
26 }
27
28 pub fn from_env() -> Result<Self, ProviderError> {
29 let api_key = std::env::var("XAI_API_KEY")
30 .or_else(|_| std::env::var("GROK_API_KEY"))
31 .map_err(|_| ProviderError::MissingApiKey("grok".to_string()))?;
32 Self::new(api_key, "https://api.x.ai/v1")
33 }
34}
35
36impl Provider for GrokProvider {
37 fn name(&self) -> &str {
38 "grok"
39 }
40
41 fn model(&self, model: &str, settings: Option<ModelSettings>) -> Arc<dyn Model> {
42 let capabilities = OpenAIChatCapabilities {
43 supports_response_format: false,
44 supports_parallel_tool_calls: true,
45 reject_binary_images: true,
46 };
47 Arc::new(OpenAIChatModel::new_with_capabilities(
48 model,
49 self.api_key.clone(),
50 self.base_url.clone(),
51 settings,
52 capabilities,
53 ))
54 }
55}
56
#[cfg(test)]
mod tests {
    use super::*;

    /// A base URL that does not parse must surface as `InvalidModel`.
    #[test]
    fn grok_provider_rejects_invalid_url() {
        let outcome = GrokProvider::new("key", "not a url");
        let error = outcome.expect_err("invalid url");
        assert!(matches!(error, ProviderError::InvalidModel(_)));
    }

    /// A valid provider reports its name and builds a model with the requested name.
    #[test]
    fn grok_provider_builds_model() {
        let grok = GrokProvider::new("key", "https://api.x.ai/v1").expect("valid provider");
        assert_eq!(grok.name(), "grok");

        let built = grok.model("grok-test", None);
        assert_eq!(built.name(), "grok-test");
    }
}