1use serde::Deserialize;
4
5use crate::client::XaiClient;
6use crate::{Error, Result};
7
8#[derive(Debug, Clone)]
10pub struct ModelsApi {
11 client: XaiClient,
12}
13
14impl ModelsApi {
15 pub(crate) fn new(client: XaiClient) -> Self {
16 Self { client }
17 }
18
19 pub async fn list(&self) -> Result<ModelListResponse> {
37 let url = format!("{}/models", self.client.base_url());
38
39 let response = self.client.send(self.client.http().get(&url)).await?;
40
41 if !response.status().is_success() {
42 return Err(Error::from_response(response).await);
43 }
44
45 Ok(response.json().await?)
46 }
47
48 pub async fn get(&self, model_id: &str) -> Result<Model> {
50 let id = XaiClient::encode_path(model_id);
51 let url = format!("{}/models/{}", self.client.base_url(), id);
52
53 let response = self.client.send(self.client.http().get(&url)).await?;
54
55 if !response.status().is_success() {
56 return Err(Error::from_response(response).await);
57 }
58
59 Ok(response.json().await?)
60 }
61
62 pub async fn language_models(&self) -> Result<ModelListResponse> {
64 let url = format!("{}/language-models", self.client.base_url());
65
66 let response = self.client.send(self.client.http().get(&url)).await?;
67
68 if !response.status().is_success() {
69 return Err(Error::from_response(response).await);
70 }
71
72 Ok(response.json().await?)
73 }
74
75 pub async fn language_model(&self, model_id: &str) -> Result<Model> {
77 let id = XaiClient::encode_path(model_id);
78 let url = format!("{}/language-models/{}", self.client.base_url(), id);
79
80 let response = self.client.send(self.client.http().get(&url)).await?;
81
82 if !response.status().is_success() {
83 return Err(Error::from_response(response).await);
84 }
85
86 Ok(response.json().await?)
87 }
88
89 pub async fn embedding_models(&self) -> Result<ModelListResponse> {
91 let url = format!("{}/embedding-models", self.client.base_url());
92
93 let response = self.client.send(self.client.http().get(&url)).await?;
94
95 if !response.status().is_success() {
96 return Err(Error::from_response(response).await);
97 }
98
99 Ok(response.json().await?)
100 }
101
102 pub async fn embedding_model(&self, model_id: &str) -> Result<Model> {
104 let id = XaiClient::encode_path(model_id);
105 let url = format!("{}/embedding-models/{}", self.client.base_url(), id);
106
107 let response = self.client.send(self.client.http().get(&url)).await?;
108
109 if !response.status().is_success() {
110 return Err(Error::from_response(response).await);
111 }
112
113 Ok(response.json().await?)
114 }
115
116 pub async fn image_generation_models(&self) -> Result<ModelListResponse> {
118 let url = format!("{}/image-generation-models", self.client.base_url());
119
120 let response = self.client.send(self.client.http().get(&url)).await?;
121
122 if !response.status().is_success() {
123 return Err(Error::from_response(response).await);
124 }
125
126 Ok(response.json().await?)
127 }
128
129 pub async fn image_generation_model(&self, model_id: &str) -> Result<Model> {
131 let id = XaiClient::encode_path(model_id);
132 let url = format!("{}/image-generation-models/{}", self.client.base_url(), id);
133
134 let response = self.client.send(self.client.http().get(&url)).await?;
135
136 if !response.status().is_success() {
137 return Err(Error::from_response(response).await);
138 }
139
140 Ok(response.json().await?)
141 }
142}
143
144#[derive(Debug, Clone, Deserialize)]
146pub struct ModelListResponse {
147 pub object: String,
149 pub data: Vec<Model>,
151}
152
153#[derive(Debug, Clone, Deserialize)]
155pub struct Model {
156 pub id: String,
158 pub object: String,
160 pub owned_by: String,
162 #[serde(default)]
164 pub created: Option<i64>,
165}
166
/// Identifiers of commonly used xAI models, usable wherever a model id string is
/// expected (for example with `ModelsApi::get`).
// NOTE(review): `dead_code` is presumably allowed because some of these ids are not
// referenced inside the crate itself; confirm it is still required.
#[allow(dead_code)]
pub mod known_models {
    // Grok 4 family.

    /// Model id `"grok-4"`.
    pub const GROK_4: &str = "grok-4";
    /// Model id `"grok-4-fast"`.
    pub const GROK_4_FAST: &str = "grok-4-fast";
    /// Model id `"grok-4-fast-non-reasoning"`.
    pub const GROK_4_FAST_NON_REASONING: &str = "grok-4-fast-non-reasoning";

    // Grok 4.1 family.

    /// Model id `"grok-4-1-fast"`.
    pub const GROK_4_1_FAST: &str = "grok-4-1-fast";
    /// Model id `"grok-4-1-fast-non-reasoning"`.
    pub const GROK_4_1_FAST_NON_REASONING: &str = "grok-4-1-fast-non-reasoning";

    // Other models.

    /// Model id `"grok-code-fast-1"`.
    pub const GROK_CODE_FAST_1: &str = "grok-code-fast-1";
    /// Model id `"grok-2-image"`.
    pub const GROK_2_IMAGE: &str = "grok-2-image";
}
185
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a client with a dummy API key whose base URL points at `server`.
    fn client_for(server: &MockServer) -> XaiClient {
        XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap()
    }

    /// Registers a mock on `server` answering `GET {route}` with `status` and `body`.
    async fn mock_get(server: &MockServer, route: &str, status: u16, body: serde_json::Value) {
        Mock::given(method("GET"))
            .and(path(route))
            .respond_with(ResponseTemplate::new(status).set_body_json(body))
            .mount(server)
            .await;
    }

    #[tokio::test]
    async fn list_requests_models_endpoint_and_parses_response() {
        let server = MockServer::start().await;
        let body = json!({
            "object": "list",
            "data": [{
                "id": "grok-4",
                "object": "model",
                "owned_by": "xai",
                "created": 1700000000
            }]
        });
        mock_get(&server, "/models", 200, body).await;

        let client = client_for(&server);
        let listing = client.models().list().await.unwrap();

        assert_eq!(listing.object, "list");
        assert_eq!(listing.data.len(), 1);

        let first = &listing.data[0];
        assert_eq!(first.id, "grok-4");
        assert_eq!(first.owned_by, "xai");
        assert_eq!(first.created, Some(1700000000));
    }

    #[tokio::test]
    async fn get_encodes_model_id_as_path_segment() {
        let server = MockServer::start().await;
        // The mock only matches if `/` and the space are percent-encoded.
        let body = json!({
            "id": "grok/special v1",
            "object": "model",
            "owned_by": "xai"
        });
        mock_get(&server, "/models/grok%2Fspecial%20v1", 200, body).await;

        let client = client_for(&server);
        let fetched = client.models().get("grok/special v1").await.unwrap();

        assert_eq!(fetched.id, "grok/special v1");
    }

    #[tokio::test]
    async fn list_propagates_api_error_response() {
        let server = MockServer::start().await;
        let body = json!({
            "error": {
                "message": "service unavailable",
                "type": "server_error"
            }
        });
        mock_get(&server, "/models", 503, body).await;

        let client = client_for(&server);
        let failure = client.models().list().await.unwrap_err();

        match failure {
            Error::Api {
                status,
                message,
                error_type,
            } => {
                assert_eq!(status, 503);
                assert_eq!(message, "service unavailable");
                assert_eq!(error_type.as_deref(), Some("server_error"));
            }
            other => panic!("expected api error, got {other:?}"),
        }
    }
}