1use serde::Serialize;
4
5use crate::client::XaiClient;
6use crate::models::image::{ImageGenerationRequest, ImageGenerationResponse, ImageResponseFormat};
7use crate::{Error, Result};
8
9#[derive(Debug, Clone)]
11pub struct ImagesApi {
12 client: XaiClient,
13}
14
15impl ImagesApi {
16 pub(crate) fn new(client: XaiClient) -> Self {
17 Self { client }
18 }
19
20 pub fn generate(
42 &self,
43 model: impl Into<String>,
44 prompt: impl Into<String>,
45 ) -> ImageGenerationBuilder {
46 ImageGenerationBuilder::new(self.client.clone(), model.into(), prompt.into())
47 }
48
49 pub fn edit(
51 &self,
52 model: impl Into<String>,
53 image: impl Into<String>,
54 prompt: impl Into<String>,
55 ) -> ImageEditBuilder {
56 ImageEditBuilder::new(
57 self.client.clone(),
58 model.into(),
59 image.into(),
60 prompt.into(),
61 )
62 }
63}
64
65#[derive(Debug)]
67pub struct ImageGenerationBuilder {
68 client: XaiClient,
69 request: ImageGenerationRequest,
70}
71
72impl ImageGenerationBuilder {
73 fn new(client: XaiClient, model: String, prompt: String) -> Self {
74 Self {
75 client,
76 request: ImageGenerationRequest {
77 model,
78 prompt,
79 n: None,
80 response_format: None,
81 },
82 }
83 }
84
85 pub fn n(mut self, n: u8) -> Self {
87 self.request.n = Some(n.clamp(1, 10));
88 self
89 }
90
91 pub fn response_format(mut self, format: ImageResponseFormat) -> Self {
93 self.request.response_format = Some(format);
94 self
95 }
96
97 pub fn url_format(self) -> Self {
99 self.response_format(ImageResponseFormat::Url)
100 }
101
102 pub fn base64_format(self) -> Self {
104 self.response_format(ImageResponseFormat::B64Json)
105 }
106
107 pub async fn send(self) -> Result<ImageGenerationResponse> {
109 let url = format!("{}/images/generations", self.client.base_url());
110
111 let response = self
112 .client
113 .send(self.client.http().post(&url).json(&self.request))
114 .await?;
115
116 if !response.status().is_success() {
117 return Err(Error::from_response(response).await);
118 }
119
120 Ok(response.json().await?)
121 }
122}
123
124#[derive(Debug, Clone, Serialize)]
126struct ImageEditRequest {
127 model: String,
128 image: String,
129 prompt: String,
130 n: Option<u8>,
131 response_format: Option<ImageResponseFormat>,
132}
133
134impl ImageEditRequest {
135 fn new(model: String, image: String, prompt: String) -> Self {
136 Self {
137 model,
138 image,
139 prompt,
140 n: None,
141 response_format: None,
142 }
143 }
144}
145
146#[derive(Debug)]
148pub struct ImageEditBuilder {
149 client: XaiClient,
150 request: ImageEditRequest,
151}
152
153impl ImageEditBuilder {
154 fn new(client: XaiClient, model: String, image: String, prompt: String) -> Self {
155 Self {
156 client,
157 request: ImageEditRequest::new(model, image, prompt),
158 }
159 }
160
161 pub fn n(mut self, n: u8) -> Self {
163 self.request.n = Some(n.clamp(1, 10));
164 self
165 }
166
167 pub fn response_format(mut self, format: ImageResponseFormat) -> Self {
169 self.request.response_format = Some(format);
170 self
171 }
172
173 pub fn url_format(self) -> Self {
175 self.response_format(ImageResponseFormat::Url)
176 }
177
178 pub fn base64_format(self) -> Self {
180 self.response_format(ImageResponseFormat::B64Json)
181 }
182
183 pub async fn send(self) -> Result<ImageGenerationResponse> {
185 let url = format!("{}/images/edits", self.client.base_url());
186 let response = self
187 .client
188 .send(self.client.http().post(&url).json(&self.request))
189 .await?;
190
191 if !response.status().is_success() {
192 return Err(Error::from_response(response).await);
193 }
194
195 Ok(response.json().await?)
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a client whose base URL points at the mock `server`.
    fn client_for(server: &MockServer) -> XaiClient {
        XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn generate_forwards_n_and_base64_format() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/images/generations"))
            .respond_with(move |req: &wiremock::Request| {
                let body = serde_json::from_slice::<serde_json::Value>(&req.body).unwrap();
                assert_eq!(body["model"], "grok-2-image");
                assert_eq!(body["prompt"], "draw");
                assert_eq!(body["n"], 2);
                assert_eq!(body["response_format"], "b64_json");
                ResponseTemplate::new(200).set_body_json(json!({
                    "created": 1700000000,
                    "data": [{"b64_json": "aGVsbG8="}]
                }))
            })
            .mount(&server)
            .await;

        let client = client_for(&server);

        let response = client
            .images()
            .generate("grok-2-image", "draw")
            .n(2)
            .base64_format()
            .send()
            .await
            .unwrap();

        assert_eq!(response.first_base64(), Some("aGVsbG8="));
    }

    #[tokio::test]
    async fn generate_clamps_n_to_max_ten() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/images/generations"))
            .respond_with(move |req: &wiremock::Request| {
                let body = serde_json::from_slice::<serde_json::Value>(&req.body).unwrap();
                assert_eq!(body["n"], 10);
                ResponseTemplate::new(200).set_body_json(json!({
                    "created": 1700000000,
                    "data": [{"url": "https://example.com/image.png"}]
                }))
            })
            .mount(&server)
            .await;

        let client = client_for(&server);

        let response = client
            .images()
            .generate("grok-2-image", "draw")
            .n(99)
            .send()
            .await
            .unwrap();

        assert_eq!(response.first_url(), Some("https://example.com/image.png"));
    }

    #[tokio::test]
    async fn generate_clamps_n_to_min_one() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/images/generations"))
            .respond_with(move |req: &wiremock::Request| {
                let body = serde_json::from_slice::<serde_json::Value>(&req.body).unwrap();
                assert_eq!(body["n"], 1);
                ResponseTemplate::new(200).set_body_json(json!({
                    "created": 1700000000,
                    "data": [{"url": "https://example.com/image.png"}]
                }))
            })
            .mount(&server)
            .await;

        let client = client_for(&server);

        let response = client
            .images()
            .generate("grok-2-image", "draw")
            .n(0)
            .send()
            .await
            .unwrap();

        assert_eq!(response.first_url(), Some("https://example.com/image.png"));
    }

    #[tokio::test]
    async fn generate_returns_error_on_non_success_status() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/images/generations"))
            .respond_with(ResponseTemplate::new(400).set_body_json(json!({
                "error": "invalid prompt"
            })))
            .mount(&server)
            .await;

        let client = client_for(&server);

        let result = client
            .images()
            .generate("grok-2-image", "draw")
            .send()
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn edit_forwards_payload_and_path() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/images/edits"))
            .respond_with(move |req: &wiremock::Request| {
                let body = serde_json::from_slice::<serde_json::Value>(&req.body).unwrap();
                assert_eq!(body["model"], "grok-2-image");
                assert_eq!(body["image"], "image-1");
                assert_eq!(body["prompt"], "sunset");
                assert_eq!(body["n"], 3);
                ResponseTemplate::new(200).set_body_json(json!({
                    "created": 1700000000,
                    "data": [{"url": "https://example.com/edited.png"}]
                }))
            })
            .mount(&server)
            .await;

        let client = client_for(&server);

        let response = client
            .images()
            .edit("grok-2-image", "image-1", "sunset")
            .n(3)
            .send()
            .await
            .unwrap();

        assert_eq!(response.first_url(), Some("https://example.com/edited.png"));
    }

    #[tokio::test]
    async fn edit_forwards_base64_format() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/images/edits"))
            .respond_with(move |req: &wiremock::Request| {
                let body = serde_json::from_slice::<serde_json::Value>(&req.body).unwrap();
                assert_eq!(body["response_format"], "b64_json");
                ResponseTemplate::new(200).set_body_json(json!({
                    "created": 1700000000,
                    "data": [{"b64_json": "aGVsbG8="}]
                }))
            })
            .mount(&server)
            .await;

        let client = client_for(&server);

        let response = client
            .images()
            .edit("grok-2-image", "image-1", "sunset")
            .base64_format()
            .send()
            .await
            .unwrap();

        assert_eq!(response.first_base64(), Some("aGVsbG8="));
    }
}