Skip to main content

xai_rust/api/
models.rs

//! Models API for listing and retrieving the models available to the client.
2
3use serde::Deserialize;
4
5use crate::client::XaiClient;
6use crate::{Error, Result};
7
8/// Models API for listing available models.
9#[derive(Debug, Clone)]
10pub struct ModelsApi {
11    client: XaiClient,
12}
13
14impl ModelsApi {
15    pub(crate) fn new(client: XaiClient) -> Self {
16        Self { client }
17    }
18
19    /// List all available models.
20    ///
21    /// # Example
22    ///
23    /// ```rust,no_run
24    /// use xai_rust::XaiClient;
25    ///
26    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
27    /// let client = XaiClient::from_env()?;
28    ///
29    /// let models = client.models().list().await?;
30    /// for model in models.data {
31    ///     println!("{}: {}", model.id, model.owned_by);
32    /// }
33    /// # Ok(())
34    /// # }
35    /// ```
36    pub async fn list(&self) -> Result<ModelListResponse> {
37        let url = format!("{}/models", self.client.base_url());
38
39        let response = self.client.send(self.client.http().get(&url)).await?;
40
41        if !response.status().is_success() {
42            return Err(Error::from_response(response).await);
43        }
44
45        Ok(response.json().await?)
46    }
47
48    /// Get a specific model by ID.
49    pub async fn get(&self, model_id: &str) -> Result<Model> {
50        let id = XaiClient::encode_path(model_id);
51        let url = format!("{}/models/{}", self.client.base_url(), id);
52
53        let response = self.client.send(self.client.http().get(&url)).await?;
54
55        if !response.status().is_success() {
56            return Err(Error::from_response(response).await);
57        }
58
59        Ok(response.json().await?)
60    }
61
62    /// List all available language models.
63    pub async fn language_models(&self) -> Result<ModelListResponse> {
64        let url = format!("{}/language-models", self.client.base_url());
65
66        let response = self.client.send(self.client.http().get(&url)).await?;
67
68        if !response.status().is_success() {
69            return Err(Error::from_response(response).await);
70        }
71
72        Ok(response.json().await?)
73    }
74
75    /// Get a specific language model by ID.
76    pub async fn language_model(&self, model_id: &str) -> Result<Model> {
77        let id = XaiClient::encode_path(model_id);
78        let url = format!("{}/language-models/{}", self.client.base_url(), id);
79
80        let response = self.client.send(self.client.http().get(&url)).await?;
81
82        if !response.status().is_success() {
83            return Err(Error::from_response(response).await);
84        }
85
86        Ok(response.json().await?)
87    }
88
89    /// List all available embedding models.
90    pub async fn embedding_models(&self) -> Result<ModelListResponse> {
91        let url = format!("{}/embedding-models", self.client.base_url());
92
93        let response = self.client.send(self.client.http().get(&url)).await?;
94
95        if !response.status().is_success() {
96            return Err(Error::from_response(response).await);
97        }
98
99        Ok(response.json().await?)
100    }
101
102    /// Get a specific embedding model by ID.
103    pub async fn embedding_model(&self, model_id: &str) -> Result<Model> {
104        let id = XaiClient::encode_path(model_id);
105        let url = format!("{}/embedding-models/{}", self.client.base_url(), id);
106
107        let response = self.client.send(self.client.http().get(&url)).await?;
108
109        if !response.status().is_success() {
110            return Err(Error::from_response(response).await);
111        }
112
113        Ok(response.json().await?)
114    }
115
116    /// List all available image-generation models.
117    pub async fn image_generation_models(&self) -> Result<ModelListResponse> {
118        let url = format!("{}/image-generation-models", self.client.base_url());
119
120        let response = self.client.send(self.client.http().get(&url)).await?;
121
122        if !response.status().is_success() {
123            return Err(Error::from_response(response).await);
124        }
125
126        Ok(response.json().await?)
127    }
128
129    /// Get a specific image-generation model by ID.
130    pub async fn image_generation_model(&self, model_id: &str) -> Result<Model> {
131        let id = XaiClient::encode_path(model_id);
132        let url = format!("{}/image-generation-models/{}", self.client.base_url(), id);
133
134        let response = self.client.send(self.client.http().get(&url)).await?;
135
136        if !response.status().is_success() {
137            return Err(Error::from_response(response).await);
138        }
139
140        Ok(response.json().await?)
141    }
142}
143
144/// Response from listing models.
145#[derive(Debug, Clone, Deserialize)]
146pub struct ModelListResponse {
147    /// Object type.
148    pub object: String,
149    /// List of models.
150    pub data: Vec<Model>,
151}
152
153/// A model available via the API.
154#[derive(Debug, Clone, Deserialize)]
155pub struct Model {
156    /// Model identifier.
157    pub id: String,
158    /// Object type.
159    pub object: String,
160    /// Owner of the model.
161    pub owned_by: String,
162    /// Creation timestamp.
163    #[serde(default)]
164    pub created: Option<i64>,
165}
166
/// Well-known xAI model identifiers, exposed as string constants.
///
/// Using these avoids typos in model names; the live catalogue can always be
/// fetched with `ModelsApi::list`.
// Some constants may have no users inside the crate itself; the allow keeps
// the dead-code lint quiet for them instead of forcing their removal.
#[allow(dead_code)]
pub mod known_models {
    // ---- Grok 4 family ------------------------------------------------

    /// Grok 4, the flagship model.
    pub const GROK_4: &str = "grok-4";
    /// Faster variant of Grok 4.
    pub const GROK_4_FAST: &str = "grok-4-fast";
    /// Grok 4-1 Fast, tuned for agentic workloads.
    pub const GROK_4_1_FAST: &str = "grok-4-1-fast";
    /// Non-reasoning variant of Grok 4 Fast.
    pub const GROK_4_FAST_NON_REASONING: &str = "grok-4-fast-non-reasoning";
    /// Non-reasoning variant of Grok 4-1 Fast.
    pub const GROK_4_1_FAST_NON_REASONING: &str = "grok-4-1-fast-non-reasoning";

    // ---- Specialised models -------------------------------------------

    /// Grok Code Fast 1, tuned for programming tasks.
    pub const GROK_CODE_FAST_1: &str = "grok-code-fast-1";
    /// Grok 2 Image, used for image generation.
    pub const GROK_2_IMAGE: &str = "grok-2-image";
}
185
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// The typed catalogue prefixes served alongside `/models`.
    const TYPED_PREFIXES: [&str; 3] = [
        "/language-models",
        "/embedding-models",
        "/image-generation-models",
    ];

    /// Build a client whose base URL points at `server`.
    fn test_client(server: &MockServer) -> XaiClient {
        XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn list_requests_models_endpoint_and_parses_response() {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/models"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "object": "list",
                "data": [{
                    "id": "grok-4",
                    "object": "model",
                    "owned_by": "xai",
                    "created": 1700000000
                }]
            })))
            .mount(&server)
            .await;

        let client = test_client(&server);

        let response = client.models().list().await.unwrap();
        assert_eq!(response.object, "list");
        assert_eq!(response.data.len(), 1);
        assert_eq!(response.data[0].id, "grok-4");
        assert_eq!(response.data[0].owned_by, "xai");
        assert_eq!(response.data[0].created, Some(1700000000));
    }

    #[tokio::test]
    async fn get_encodes_model_id_as_path_segment() {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/models/grok%2Fspecial%20v1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "grok/special v1",
                "object": "model",
                "owned_by": "xai"
            })))
            .mount(&server)
            .await;

        let client = test_client(&server);

        let model = client.models().get("grok/special v1").await.unwrap();
        assert_eq!(model.id, "grok/special v1");
    }

    #[tokio::test]
    async fn typed_list_endpoints_request_their_own_paths() {
        let server = MockServer::start().await;

        // Each mock echoes its own path as the model id so a mis-routed request
        // is caught by the assertions below.
        for prefix in TYPED_PREFIXES {
            Mock::given(method("GET"))
                .and(path(prefix))
                .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                    "object": "list",
                    "data": [{
                        "id": prefix,
                        "object": "model",
                        "owned_by": "xai"
                    }]
                })))
                .mount(&server)
                .await;
        }

        let client = test_client(&server);
        let models = client.models();

        assert_eq!(
            models.language_models().await.unwrap().data[0].id,
            "/language-models"
        );
        assert_eq!(
            models.embedding_models().await.unwrap().data[0].id,
            "/embedding-models"
        );
        assert_eq!(
            models.image_generation_models().await.unwrap().data[0].id,
            "/image-generation-models"
        );
    }

    #[tokio::test]
    async fn typed_get_endpoints_encode_model_id_as_path_segment() {
        let server = MockServer::start().await;

        for prefix in TYPED_PREFIXES {
            Mock::given(method("GET"))
                .and(path(format!("{prefix}/grok%2Fx")))
                .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                    "id": prefix,
                    "object": "model",
                    "owned_by": "xai"
                })))
                .mount(&server)
                .await;
        }

        let client = test_client(&server);
        let models = client.models();

        assert_eq!(
            models.language_model("grok/x").await.unwrap().id,
            "/language-models"
        );
        assert_eq!(
            models.embedding_model("grok/x").await.unwrap().id,
            "/embedding-models"
        );
        assert_eq!(
            models.image_generation_model("grok/x").await.unwrap().id,
            "/image-generation-models"
        );
    }

    #[tokio::test]
    async fn list_propagates_api_error_response() {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/models"))
            .respond_with(ResponseTemplate::new(503).set_body_json(json!({
                "error": {
                    "message": "service unavailable",
                    "type": "server_error"
                }
            })))
            .mount(&server)
            .await;

        let client = test_client(&server);

        let err = client.models().list().await.unwrap_err();
        match err {
            Error::Api {
                status,
                message,
                error_type,
            } => {
                assert_eq!(status, 503);
                assert_eq!(message, "service unavailable");
                assert_eq!(error_type.as_deref(), Some("server_error"));
            }
            other => panic!("expected api error, got {other:?}"),
        }
    }
}