Skip to main content

studio_worker/
http.rs

//! Thin reqwest wrapper around the studio API.
//!
//! Every response goes through [`ApiClient::check`] (the one exception
//! is the 404 that `poll_register_status` maps to `Ok(None)`, which
//! logs its own `debug` event), which:
//!
//! - emits a structured `tracing` event on success (`debug`) and
//!   failure (`warn`) so operators can see what the worker is talking
//!   to without having to enable wire-level logging in reqwest
//!   (`complete` also logs the upload byte size before the request so
//!   the attempted payload size is visible even when it never finishes),
//!   and
//! - turns non-2xx responses into a typed [`HttpStatusError`] (carried
//!   inside `anyhow::Error`) whose message is tagged with the operation
//!   name so the existing log shipper messages stay legible.
13use crate::types::*;
14use anyhow::{anyhow, Context, Result};
15use reqwest::blocking::{Client, Response};
16use std::time::{Duration, Instant};
17use tracing::{debug, warn};
18
/// Base path under which the worker endpoints are mounted.
///
/// `ApiClient::url` inserts it between the base URL and the endpoint
/// path, and `normalize_base_url` strips it from a configured base URL
/// that already ends with it, so it is never applied twice.
const API_PREFIX: &str = "/graphics/api";

/// Tracing target used for every event emitted by the HTTP client.
/// Keeping it stable lets operators filter with
/// `RUST_LOG=studio_worker::http=debug` without touching the rest of
/// the agent's logs.
const TRACE_TARGET: &str = "studio_worker::http";
27
/// Typed non-2xx response error so callers can branch on the status
/// class (e.g. retry 5xx, never retry 4xx) without sniffing message
/// strings.  The rendered message keeps the legacy `<op> failed:
/// <status> — <body>` shape existing tests and log consumers expect.
#[derive(Debug, thiserror::Error)]
#[error("{op} failed: {status} — {body}")]
pub struct HttpStatusError {
    /// Operation name the failing call was made under (for example
    /// `complete` or `register-poll`); leads the rendered message.
    pub op: String,
    /// Numeric HTTP status code of the rejected response.
    pub status: u16,
    /// Response body text; `ApiClient::check` stores an empty string
    /// when the body could not be read.
    pub body: String,
}
39
40impl HttpStatusError {
41    /// Server-side failures are worth retrying; 4xx contract errors
42    /// are not.
43    pub fn is_transient(&self) -> bool {
44        self.status >= 500
45    }
46}
47
48/// True when the upload failed for a reason a short retry can fix: a
49/// 5xx from the studio or a transport-level error (connect refused,
50/// timeout, broken pipe).  4xx contract errors return false.
51pub fn is_transient_upload_error(e: &anyhow::Error) -> bool {
52    if let Some(status) = e.downcast_ref::<HttpStatusError>() {
53        return status.is_transient();
54    }
55    e.downcast_ref::<reqwest::Error>().is_some()
56}
57
/// Blocking HTTP client for the studio's worker endpoints.
pub struct ApiClient {
    /// Base URL the endpoint paths are appended to.  As set by
    /// `ApiClient::new` it has already been normalised: no query, no
    /// fragment, no trailing `/`, and no trailing `/graphics/api`.
    pub base_url: String,
    /// reqwest blocking client every request is sent through;
    /// `ApiClient::new` fills it with a clone of the process-wide
    /// shared client.
    pub client: Client,
}
62
63/// Process-wide blocking client, shared by every `ApiClient`.
64/// `reqwest::blocking::Client` is an `Arc` around a connection pool;
65/// rebuilding it per call (the old behaviour) re-did TLS setup and
66/// threw the pool away between requests.
67fn shared_client() -> Result<Client> {
68    static CLIENT: std::sync::OnceLock<Client> = std::sync::OnceLock::new();
69    if let Some(client) = CLIENT.get() {
70        return Ok(client.clone());
71    }
72    let built = Client::builder()
73        .timeout(Duration::from_secs(60))
74        .build()
75        .context("building reqwest client")?;
76    // A concurrent first call may have won the publish; either client
77    // is fine — take whichever landed.
78    Ok(CLIENT.get_or_init(|| built).clone())
79}
80
81impl ApiClient {
82    pub fn new(base_url: String) -> Result<Self> {
83        Ok(Self {
84            base_url: normalize_base_url(&base_url)?,
85            client: shared_client()?,
86        })
87    }
88
89    fn url(&self, path: &str) -> String {
90        format!("{}{}{}", self.base_url, API_PREFIX, path)
91    }
92
93    /// Inspect a response, log it, and convert non-2xx into an
94    /// `anyhow` error.  `op` is the human-readable operation name used
95    /// in the error message (kept stable for log-shipper consumers and
96    /// existing tests).
97    fn check(&self, op: &str, url: &str, started: Instant, response: Response) -> Result<Response> {
98        let status = response.status();
99        let elapsed_ms = started.elapsed().as_millis() as u64;
100        if status.is_success() || status.as_u16() == 204 {
101            debug!(
102                target: TRACE_TARGET,
103                op,
104                endpoint = %url,
105                status = status.as_u16(),
106                elapsed_ms,
107                "ok"
108            );
109            return Ok(response);
110        }
111        // Body read consumes the response; we only need it on the
112        // failure path.
113        let body = response.text().unwrap_or_default();
114        warn!(
115            target: TRACE_TARGET,
116            op,
117            endpoint = %url,
118            status = status.as_u16(),
119            elapsed_ms,
120            body = %body,
121            "{op} failed"
122        );
123        Err(HttpStatusError {
124            op: op.to_string(),
125            status: status.as_u16(),
126            body,
127        }
128        .into())
129    }
130
131    // -----------------------------------------------------------------------
132    // Auto-register (operator-approved) flow
133    // -----------------------------------------------------------------------
134
135    /// Create a Pending Workers row.  Unauthenticated on purpose —
136    /// the studio rate-limits this endpoint by source IP and the
137    /// operator manually approves before the worker can do anything.
138    pub fn register_request(
139        &self,
140        payload: &AutoRegisterRequest,
141    ) -> Result<AutoRegisterRequestResponse> {
142        let url = self.url("/workers/register-request");
143        let started = Instant::now();
144        let response = self.client.post(&url).json(payload).send()?;
145        let response = self.check("register-request", &url, started, response)?;
146        Ok(response.json()?)
147    }
148
149    /// Poll the studio for the operator's decision on a previously
150    /// submitted register-request.  Returns `Ok(None)` when the
151    /// request id is unknown to the studio (likely cleaned up or
152    /// never existed) so the orchestrator can drop the stale id and
153    /// start a fresh one.  Auth is the raw `registration_secret`
154    /// presented as a Bearer token.
155    pub fn poll_register_status(
156        &self,
157        request_id: &str,
158        registration_secret: &str,
159    ) -> Result<Option<RegisterStatus>> {
160        let url = self.url(&format!("/workers/register-requests/{request_id}"));
161        let started = Instant::now();
162        let response = self
163            .client
164            .get(&url)
165            .bearer_auth(registration_secret)
166            .send()?;
167        if response.status().as_u16() == 404 {
168            debug!(
169                target: TRACE_TARGET,
170                op = "register-poll",
171                endpoint = %url,
172                status = 404,
173                elapsed_ms = started.elapsed().as_millis() as u64,
174                "register request not found (stale id; orchestrator will recreate)"
175            );
176            return Ok(None);
177        }
178        let response = self.check("register-poll", &url, started, response)?;
179        Ok(Some(response.json()?))
180    }
181
182    /// Complete a job with binary output (image / audio / video).
183    ///
184    /// This is the only worker-side HTTP route that survives the WS
185    /// migration: R2 multipart doesn't fit cleanly into WS frames.
186    /// Heartbeats, claim/accept/reject, completeJson, fail, and log
187    /// shipping all flow over the WS session owned by
188    /// `ws::session::spawn_ws_session`.
189    pub fn complete(
190        &self,
191        worker_id: &str,
192        token: &str,
193        job_id: &str,
194        ext: &str,
195        prompt: &str,
196        image: Vec<u8>,
197    ) -> Result<()> {
198        let mime = mime_for_ext(ext);
199        let bytes = image.len() as u64;
200        // Emitted before the (potentially slow or failing) upload so the
201        // attempted payload size is always in the operator's logs, even
202        // when the request itself never completes.
203        debug!(
204            target: TRACE_TARGET,
205            op = "complete",
206            job_id,
207            ext,
208            mime,
209            bytes,
210            "uploading job result"
211        );
212        let part = reqwest::blocking::multipart::Part::bytes(image)
213            .file_name(format!("{job_id}.{ext}"))
214            .mime_str(mime)?;
215        let form = reqwest::blocking::multipart::Form::new()
216            .text("prompt", prompt.to_string())
217            .text("ext", ext.to_string())
218            .part("image", part);
219        let url = self.url(&format!("/workers/{worker_id}/jobs/{job_id}/complete"));
220        let started = Instant::now();
221        let response = self
222            .client
223            .post(&url)
224            .bearer_auth(token)
225            .multipart(form)
226            .send()?;
227        self.check("complete", &url, started, response)?;
228        Ok(())
229    }
230
231    /// Like [`Self::complete`] but retries transient failures (5xx +
232    /// transport errors) up to `retries` additional attempts, pausing
233    /// `pause * attempt` between them.  4xx contract errors surface
234    /// immediately.  Keeps a brief upload blip from costing a full GPU
235    /// regeneration — a reported `Fail` makes the studio requeue and
236    /// re-render the whole job.
237    #[allow(clippy::too_many_arguments)]
238    pub fn complete_with_retry(
239        &self,
240        worker_id: &str,
241        token: &str,
242        job_id: &str,
243        ext: &str,
244        prompt: &str,
245        image: Vec<u8>,
246        retries: u32,
247        pause: Duration,
248    ) -> Result<()> {
249        let mut attempt: u32 = 0;
250        loop {
251            match self.complete(worker_id, token, job_id, ext, prompt, image.clone()) {
252                Ok(()) => return Ok(()),
253                Err(e) if attempt < retries && is_transient_upload_error(&e) => {
254                    attempt += 1;
255                    warn!(
256                        target: TRACE_TARGET,
257                        op = "complete",
258                        job_id,
259                        attempt,
260                        max_attempts = retries + 1,
261                        error = %e,
262                        "transient upload failure; retrying"
263                    );
264                    std::thread::sleep(pause * attempt);
265                }
266                Err(e) => return Err(e),
267            }
268        }
269    }
270}
271
272fn normalize_base_url(base_url: &str) -> Result<String> {
273    let mut url =
274        url::Url::parse(base_url).map_err(|e| anyhow!("invalid api_base_url {base_url:?}: {e}"))?;
275    url.set_query(None);
276    url.set_fragment(None);
277
278    let trimmed_path = url.path().trim_end_matches('/').to_string();
279    if trimmed_path.ends_with(API_PREFIX) {
280        let without_prefix = trimmed_path[..trimmed_path.len() - API_PREFIX.len()].to_string();
281        url.set_path(if without_prefix.is_empty() {
282            "/"
283        } else {
284            &without_prefix
285        });
286    }
287
288    Ok(url.as_str().trim_end_matches('/').to_string())
289}
290
/// Map a binary output's file extension to the MIME type sent as the
/// multipart `complete` upload's `Content-Type`.  Single source of
/// truth: every engine that emits a `TaskResult` binary extension
/// (synthetic image → `png`/`webp`, sd-cpp → `webp`, tts → `wav`,
/// synthetic video → `webp`, the `video` feature → `gif`) routes
/// through here, so a new extension can't silently drift into
/// `application/octet-stream` and break the studio's stored
/// content-type.
///
/// `ext` is the bare extension without a leading dot.  Matching
/// ignores ASCII case, so `"PNG"` resolves like `"png"`.  Unknown or
/// empty extensions yield `application/octet-stream`.
pub fn mime_for_ext(ext: &str) -> &'static str {
    // Returned for anything not listed in `KNOWN`.
    const FALLBACK: &str = "application/octet-stream";
    // Lowercase extension → MIME type.
    const KNOWN: &[(&str, &str)] = &[
        ("png", "image/png"),
        ("webp", "image/webp"),
        ("gif", "image/gif"),
        ("wav", "audio/wav"),
        ("mp3", "audio/mpeg"),
        ("mp4", "video/mp4"),
    ];
    KNOWN
        .iter()
        .find(|(known, _)| known.eq_ignore_ascii_case(ext))
        .map_or(FALLBACK, |&(_, mime)| mime)
}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// MIME type `mime_for_ext` falls back to for unknown extensions.
    const FALLBACK_MIME: &str = "application/octet-stream";

    /// Build a `complete` status error with the given code and body.
    fn status_error(status: u16, body: &str) -> HttpStatusError {
        HttpStatusError {
            op: "complete".into(),
            status,
            body: body.into(),
        }
    }

    /// Register-request URL produced by a client built from `base`.
    fn register_request_url(base: &str) -> String {
        ApiClient::new(base.into())
            .unwrap()
            .url("/workers/register-request")
    }

    #[test]
    fn mime_for_ext_maps_known_image_audio_video_types() {
        let expected = [
            ("png", "image/png"),
            ("webp", "image/webp"),
            ("gif", "image/gif"),
            ("wav", "audio/wav"),
            ("mp3", "audio/mpeg"),
            ("mp4", "video/mp4"),
        ];
        for (ext, mime) in expected {
            assert_eq!(mime_for_ext(ext), mime);
        }
    }

    #[test]
    fn mime_for_ext_falls_back_to_octet_stream_for_unknown() {
        for unknown in ["bin", ""] {
            assert_eq!(mime_for_ext(unknown), FALLBACK_MIME);
        }
    }

    #[test]
    fn normalize_base_url_strips_existing_graphics_api_prefix() {
        assert_eq!(
            register_request_url("https://studio.example/graphics/api/"),
            "https://studio.example/graphics/api/workers/register-request"
        );
    }

    #[test]
    fn normalize_base_url_preserves_outer_mount_path() {
        assert_eq!(
            register_request_url("https://studio.example/custom/graphics/api"),
            "https://studio.example/custom/graphics/api/workers/register-request"
        );
    }

    #[test]
    fn is_transient_classifies_5xx_as_retryable_and_4xx_as_terminal() {
        // Pins the boundary the upload-retry loop branches on:
        // server-side failures earn a short retry, client-side
        // contract errors do not.
        for retryable in [500, 503] {
            assert!(status_error(retryable, "x").is_transient());
        }
        for terminal in [499, 409, 400] {
            assert!(!status_error(terminal, "x").is_transient());
        }
    }

    #[test]
    fn is_transient_upload_error_branches_on_error_kind() {
        // HTTP status errors: 5xx retries, 4xx is terminal.
        let server_err = anyhow::Error::from(status_error(502, "bad gateway"));
        assert!(is_transient_upload_error(&server_err));
        let client_err = anyhow::Error::from(status_error(409, "conflict"));
        assert!(!is_transient_upload_error(&client_err));

        // Transport-level reqwest error (connection refused).  No HTTP
        // response ever comes back, so the wiremock tests can't reach
        // this branch — yet it is exactly the upload blip a retry
        // fixes.
        let client = Client::builder()
            .timeout(Duration::from_millis(200))
            .build()
            .unwrap();
        let refused = client
            .post("http://127.0.0.1:1/unreachable")
            .body(Vec::<u8>::new())
            .send()
            .expect_err("connect to a dead port must fail");
        let transport = anyhow::Error::from(refused);
        assert!(
            is_transient_upload_error(&transport),
            "a transport-level failure must be retryable"
        );

        // Neither an HTTP status nor a transport failure: not an
        // upload blip, so it must not trigger a retry.
        let unrelated = anyhow!("local disk full while staging the upload");
        assert!(!is_transient_upload_error(&unrelated));
    }

    #[test]
    fn mime_for_ext_covers_every_extension_engines_emit() {
        // Contract lock: every binary extension an engine actually
        // emits must resolve to a real MIME type rather than the
        // octet-stream fallback.  `gif` comes from the `video`
        // feature and is the one that regressed before this guard.
        let emitted = ["png", "webp", "gif", "wav"];
        for ext in emitted {
            assert_ne!(
                mime_for_ext(ext),
                FALLBACK_MIME,
                "engine output extension {ext:?} must map to a real MIME type"
            );
        }
    }
}