1use crate::types::*;
14use anyhow::{anyhow, Context, Result};
15use reqwest::blocking::{Client, Response};
16use std::time::{Duration, Instant};
17use tracing::{debug, warn};
18
/// Path prefix under which every studio API endpoint is mounted; appended to
/// the normalized base URL by `ApiClient::url`.
const API_PREFIX: &str = "/graphics/api";

/// `tracing` target used for all HTTP request/response events in this module.
const TRACE_TARGET: &str = "studio_worker::http";
27
/// A non-success HTTP response from the studio API.
///
/// Carries the operation label, the numeric status code, and the response body
/// text, so callers can report the failure and decide whether to retry it
/// (see [`HttpStatusError::is_transient`]).
#[derive(Debug, thiserror::Error)]
#[error("{op} failed: {status} — {body}")]
pub struct HttpStatusError {
    /// Short label of the API operation that failed (e.g. `"complete"`).
    pub op: String,
    /// HTTP status code returned by the server.
    pub status: u16,
    /// Response body text; empty when `ApiClient::check` could not read the body.
    pub body: String,
}
39
40impl HttpStatusError {
41 pub fn is_transient(&self) -> bool {
44 self.status >= 500
45 }
46}
47
48pub fn is_transient_upload_error(e: &anyhow::Error) -> bool {
52 if let Some(status) = e.downcast_ref::<HttpStatusError>() {
53 return status.is_transient();
54 }
55 e.downcast_ref::<reqwest::Error>().is_some()
56}
57
/// Blocking client for the studio worker HTTP API.
pub struct ApiClient {
    /// Base URL as produced by `ApiClient::new`: parsed, without query,
    /// fragment, trailing slash, or a trailing `/graphics/api` segment.
    pub base_url: String,
    /// HTTP client used for every request; `ApiClient::new` fills it with the
    /// process-wide client from `shared_client`.
    pub client: Client,
}
62
63fn shared_client() -> Result<Client> {
68 static CLIENT: std::sync::OnceLock<Client> = std::sync::OnceLock::new();
69 if let Some(client) = CLIENT.get() {
70 return Ok(client.clone());
71 }
72 let built = Client::builder()
73 .timeout(Duration::from_secs(60))
74 .build()
75 .context("building reqwest client")?;
76 Ok(CLIENT.get_or_init(|| built).clone())
79}
80
81impl ApiClient {
82 pub fn new(base_url: String) -> Result<Self> {
83 Ok(Self {
84 base_url: normalize_base_url(&base_url)?,
85 client: shared_client()?,
86 })
87 }
88
89 fn url(&self, path: &str) -> String {
90 format!("{}{}{}", self.base_url, API_PREFIX, path)
91 }
92
93 fn check(&self, op: &str, url: &str, started: Instant, response: Response) -> Result<Response> {
98 let status = response.status();
99 let elapsed_ms = started.elapsed().as_millis() as u64;
100 if status.is_success() || status.as_u16() == 204 {
101 debug!(
102 target: TRACE_TARGET,
103 op,
104 endpoint = %url,
105 status = status.as_u16(),
106 elapsed_ms,
107 "ok"
108 );
109 return Ok(response);
110 }
111 let body = response.text().unwrap_or_default();
114 warn!(
115 target: TRACE_TARGET,
116 op,
117 endpoint = %url,
118 status = status.as_u16(),
119 elapsed_ms,
120 body = %body,
121 "{op} failed"
122 );
123 Err(HttpStatusError {
124 op: op.to_string(),
125 status: status.as_u16(),
126 body,
127 }
128 .into())
129 }
130
131 pub fn register_request(
139 &self,
140 payload: &AutoRegisterRequest,
141 ) -> Result<AutoRegisterRequestResponse> {
142 let url = self.url("/workers/register-request");
143 let started = Instant::now();
144 let response = self.client.post(&url).json(payload).send()?;
145 let response = self.check("register-request", &url, started, response)?;
146 Ok(response.json()?)
147 }
148
149 pub fn poll_register_status(
156 &self,
157 request_id: &str,
158 registration_secret: &str,
159 ) -> Result<Option<RegisterStatus>> {
160 let url = self.url(&format!("/workers/register-requests/{request_id}"));
161 let started = Instant::now();
162 let response = self
163 .client
164 .get(&url)
165 .bearer_auth(registration_secret)
166 .send()?;
167 if response.status().as_u16() == 404 {
168 debug!(
169 target: TRACE_TARGET,
170 op = "register-poll",
171 endpoint = %url,
172 status = 404,
173 elapsed_ms = started.elapsed().as_millis() as u64,
174 "register request not found (stale id; orchestrator will recreate)"
175 );
176 return Ok(None);
177 }
178 let response = self.check("register-poll", &url, started, response)?;
179 Ok(Some(response.json()?))
180 }
181
182 pub fn complete(
190 &self,
191 worker_id: &str,
192 token: &str,
193 job_id: &str,
194 ext: &str,
195 prompt: &str,
196 image: Vec<u8>,
197 ) -> Result<()> {
198 let mime = mime_for_ext(ext);
199 let bytes = image.len() as u64;
200 debug!(
204 target: TRACE_TARGET,
205 op = "complete",
206 job_id,
207 ext,
208 mime,
209 bytes,
210 "uploading job result"
211 );
212 let part = reqwest::blocking::multipart::Part::bytes(image)
213 .file_name(format!("{job_id}.{ext}"))
214 .mime_str(mime)?;
215 let form = reqwest::blocking::multipart::Form::new()
216 .text("prompt", prompt.to_string())
217 .text("ext", ext.to_string())
218 .part("image", part);
219 let url = self.url(&format!("/workers/{worker_id}/jobs/{job_id}/complete"));
220 let started = Instant::now();
221 let response = self
222 .client
223 .post(&url)
224 .bearer_auth(token)
225 .multipart(form)
226 .send()?;
227 self.check("complete", &url, started, response)?;
228 Ok(())
229 }
230
231 #[allow(clippy::too_many_arguments)]
238 pub fn complete_with_retry(
239 &self,
240 worker_id: &str,
241 token: &str,
242 job_id: &str,
243 ext: &str,
244 prompt: &str,
245 image: Vec<u8>,
246 retries: u32,
247 pause: Duration,
248 ) -> Result<()> {
249 let mut attempt: u32 = 0;
250 loop {
251 match self.complete(worker_id, token, job_id, ext, prompt, image.clone()) {
252 Ok(()) => return Ok(()),
253 Err(e) if attempt < retries && is_transient_upload_error(&e) => {
254 attempt += 1;
255 warn!(
256 target: TRACE_TARGET,
257 op = "complete",
258 job_id,
259 attempt,
260 max_attempts = retries + 1,
261 error = %e,
262 "transient upload failure; retrying"
263 );
264 std::thread::sleep(pause * attempt);
265 }
266 Err(e) => return Err(e),
267 }
268 }
269 }
270}
271
272fn normalize_base_url(base_url: &str) -> Result<String> {
273 let mut url =
274 url::Url::parse(base_url).map_err(|e| anyhow!("invalid api_base_url {base_url:?}: {e}"))?;
275 url.set_query(None);
276 url.set_fragment(None);
277
278 let trimmed_path = url.path().trim_end_matches('/').to_string();
279 if trimmed_path.ends_with(API_PREFIX) {
280 let without_prefix = trimmed_path[..trimmed_path.len() - API_PREFIX.len()].to_string();
281 url.set_path(if without_prefix.is_empty() {
282 "/"
283 } else {
284 &without_prefix
285 });
286 }
287
288 Ok(url.as_str().trim_end_matches('/').to_string())
289}
290
/// Maps a file extension (without the leading dot) to a MIME type.
///
/// Matching is ASCII case-insensitive, so `"PNG"` and `"png"` agree. Unknown
/// extensions, including the empty string, fall back to
/// `application/octet-stream`.
pub fn mime_for_ext(ext: &str) -> &'static str {
    const KNOWN: &[(&str, &str)] = &[
        ("png", "image/png"),
        ("jpg", "image/jpeg"),
        ("jpeg", "image/jpeg"),
        ("webp", "image/webp"),
        ("gif", "image/gif"),
        ("wav", "audio/wav"),
        ("mp3", "audio/mpeg"),
        ("mp4", "video/mp4"),
    ];
    KNOWN
        .iter()
        .find(|(known, _)| known.eq_ignore_ascii_case(ext))
        .map_or("application/octet-stream", |&(_, mime)| mime)
}
310
#[cfg(test)]
mod tests {
    use super::*;

    // Each extension in the table resolves to its expected MIME type.
    #[test]
    fn mime_for_ext_maps_known_image_audio_video_types() {
        assert_eq!(mime_for_ext("png"), "image/png");
        assert_eq!(mime_for_ext("webp"), "image/webp");
        assert_eq!(mime_for_ext("gif"), "image/gif");
        assert_eq!(mime_for_ext("wav"), "audio/wav");
        assert_eq!(mime_for_ext("mp3"), "audio/mpeg");
        assert_eq!(mime_for_ext("mp4"), "video/mp4");
    }

    // Unknown and empty extensions fall back to the generic binary type.
    #[test]
    fn mime_for_ext_falls_back_to_octet_stream_for_unknown() {
        assert_eq!(mime_for_ext("bin"), "application/octet-stream");
        assert_eq!(mime_for_ext(""), "application/octet-stream");
    }

    // A base URL that already ends in `/graphics/api/` must not yield a
    // doubled prefix once `url()` appends API_PREFIX again.
    #[test]
    fn normalize_base_url_strips_existing_graphics_api_prefix() {
        let api = ApiClient::new("https://studio.example/graphics/api/".into()).unwrap();
        assert_eq!(
            api.url("/workers/register-request"),
            "https://studio.example/graphics/api/workers/register-request"
        );
    }

    // Path segments in front of `/graphics/api` survive normalization.
    #[test]
    fn normalize_base_url_preserves_outer_mount_path() {
        let api = ApiClient::new("https://studio.example/custom/graphics/api".into()).unwrap();
        assert_eq!(
            api.url("/workers/register-request"),
            "https://studio.example/custom/graphics/api/workers/register-request"
        );
    }

    // 5xx is retryable; 4xx, including the 499 boundary, is terminal.
    #[test]
    fn is_transient_classifies_5xx_as_retryable_and_4xx_as_terminal() {
        let err = |status| HttpStatusError {
            op: "complete".into(),
            status,
            body: "x".into(),
        };
        assert!(err(500).is_transient());
        assert!(err(503).is_transient());
        assert!(!err(499).is_transient());
        assert!(!err(409).is_transient());
        assert!(!err(400).is_transient());
    }

    // Exercises each branch of is_transient_upload_error: HTTP status errors,
    // reqwest transport errors, and unrelated anyhow errors.
    #[test]
    fn is_transient_upload_error_branches_on_error_kind() {
        let server_err: anyhow::Error = HttpStatusError {
            op: "complete".into(),
            status: 502,
            body: "bad gateway".into(),
        }
        .into();
        assert!(is_transient_upload_error(&server_err));
        let client_err: anyhow::Error = HttpStatusError {
            op: "complete".into(),
            status: 409,
            body: "conflict".into(),
        }
        .into();
        assert!(!is_transient_upload_error(&client_err));

        // Produces a real reqwest transport error by sending to loopback port 1.
        // NOTE(review): assumes nothing listens on 127.0.0.1:1 in the test
        // environment; confirm for CI sandboxes.
        let transport: anyhow::Error = Client::builder()
            .timeout(Duration::from_millis(200))
            .build()
            .unwrap()
            .post("http://127.0.0.1:1/unreachable")
            .body(Vec::<u8>::new())
            .send()
            .expect_err("connect to a dead port must fail")
            .into();
        assert!(
            is_transient_upload_error(&transport),
            "a transport-level failure must be retryable"
        );

        // A plain anyhow error is neither an HttpStatusError nor a reqwest::Error.
        let unrelated = anyhow!("local disk full while staging the upload");
        assert!(!is_transient_upload_error(&unrelated));
    }

    // Guards against an engine output extension silently uploading as
    // octet-stream. NOTE(review): presumably this list mirrors the extensions
    // the engines emit; keep it in sync when engines change.
    #[test]
    fn mime_for_ext_covers_every_extension_engines_emit() {
        for ext in ["png", "webp", "gif", "wav"] {
            assert_ne!(
                mime_for_ext(ext),
                "application/octet-stream",
                "engine output extension {ext:?} must map to a real MIME type"
            );
        }
    }
}