// scrapfly_sdk/batch.rs

//! Streaming multipart/mixed parser for POST /scrape/batch.
//!
//! The API emits one part per scrape as each completes. This module
//! reads the response body as a stream of `Bytes` chunks and yields
//! `(headers, body)` per part as they arrive.
//!
//! Zero new dependencies — only `reqwest`, `bytes`, and `futures-util`,
//! which are already in `Cargo.toml`.

use std::collections::HashMap;
use std::pin::Pin;
use std::task::{Context, Poll};

use bytes::{Bytes, BytesMut};
use futures_util::stream::Stream;

use crate::error::ScrapflyError;
use crate::result::scrape::ScrapeResult;

/// Line terminator expected right after a `--boundary` delimiter.
const CRLF: &[u8] = b"\r\n";
/// Blank-line marker (CRLF CRLF) that ends a part's header block.
const DOUBLE_CRLF: &[u8] = b"\r\n\r\n";

23/// One multipart part: header map (lowercased keys) plus body bytes.
24#[derive(Debug)]
25pub struct BatchPart {
26    /// Per-part headers, lowercased. `content-type` is always set;
27    /// `x-scrapfly-correlation-id` and `x-scrapfly-scrape-status`
28    /// are set by the server.
29    pub headers: HashMap<String, String>,
30
31    /// Part body bytes (not decoded).
32    pub body: Bytes,
33}
34
35/// A proxified batch part surfaced as a native Response-like value.
36/// The part body is the raw upstream response (HTML, JSON, binary,
37/// etc.) — not a JSON envelope. `reqwest::Response` is tied to a
38/// live connection so we cannot re-synthesize one from bytes; this
39/// struct carries the same fields a caller needs.
40#[derive(Debug)]
41pub struct BatchProxifiedResponse {
42    /// Upstream HTTP status code restored from X-Scrapfly-Scrape-Status.
43    pub status: u16,
44
45    /// Response headers: upstream headers (originally prefixed with
46    /// X-Scrapfly-Upstream- on the wire, stripped here) PLUS
47    /// Scrapfly metadata (X-Scrapfly-Log, X-Scrapfly-Content-Format,
48    /// X-Scrapfly-Log-Uuid). `Content-Type` is the upstream's
49    /// content-type.
50    pub headers: HashMap<String, String>,
51
52    /// Raw upstream body bytes.
53    pub body: Bytes,
54}
55
56impl BatchProxifiedResponse {
57    /// Decode the body as UTF-8 text (mirrors reqwest::Response::text()).
58    pub fn text(&self) -> String {
59        String::from_utf8_lossy(&self.body).into_owned()
60    }
61
62    /// Convenience accessor for the response content-type.
63    pub fn content_type(&self) -> Option<&str> {
64        self.headers.get("content-type").map(String::as_str)
65    }
66
67    /// Scrapfly log UUID (X-Scrapfly-Log if present, else X-Scrapfly-Log-Uuid).
68    pub fn scrapfly_log(&self) -> Option<&str> {
69        self.headers
70            .get("x-scrapfly-log")
71            .or_else(|| self.headers.get("x-scrapfly-log-uuid"))
72            .map(String::as_str)
73    }
74}
75
/// Wire format for the per-part body. JSON is the default;
/// Msgpack matches the Scrapfly API's msgpack negotiation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BatchFormat {
    /// application/json (default).
    #[default]
    Json,
    /// application/msgpack — smaller wire payload.
    Msgpack,
}

impl BatchFormat {
    /// MIME type string that corresponds to this format.
    pub(crate) fn accept_header(self) -> &'static str {
        match self {
            BatchFormat::Json => "application/json",
            BatchFormat::Msgpack => "application/msgpack",
        }
    }
}

96/// Options for `Client::scrape_batch_with_options`.
97#[derive(Debug, Clone, Default)]
98pub struct BatchOptions {
99    /// Wire format for per-part bodies. Defaults to JSON.
100    pub format: BatchFormat,
101}
102
103/// Per-part outcome yielded by `Client::scrape_batch`.
104///
105/// NOTE on `clippy::large_enum_variant`: the `Scrape(ScrapeResult)` variant is
106/// ~528 bytes vs. ~120 bytes for the others. Boxing to shrink it would be a
107/// breaking change to v0.2.x consumers pattern-matching on `BatchOutcome::Scrape`.
108/// We accept the unbalanced size — each enum instance is ephemeral (yielded
109/// once per batch part via a stream) and count tops out at the batch size.
110#[allow(clippy::large_enum_variant)]
111#[derive(Debug)]
112pub enum BatchOutcome {
113    /// Standard per-part scrape result (JSON envelope decoded).
114    Scrape(ScrapeResult),
115
116    /// Proxified part: the upstream's raw response, with status +
117    /// headers + body restored from the multipart part metadata.
118    /// Surfaces when the originating `ScrapeConfig.proxified_response
119    /// == true`. Matches the single-scrape `scrape_proxified()`
120    /// return shape as closely as we can without a live connection.
121    Proxified(BatchProxifiedResponse),
122
123    /// Per-part error (decode failure, per-scrape upstream error, etc.).
124    Err(ScrapflyError),
125}
126
/// Return the index of the first occurrence of `needle` in `buf`.
///
/// An empty `needle` matches at index 0 (even in an empty `buf`).
/// Returns `None` when `needle` does not occur or is longer than `buf`.
fn find_subslice(buf: &[u8], needle: &[u8]) -> Option<usize> {
    // `windows(0)` panics, so the empty needle must be handled first.
    if needle.is_empty() {
        return Some(0);
    }

    // When `buf` is shorter than `needle`, `windows` yields nothing
    // and `position` returns `None`.
    buf.windows(needle.len()).position(|window| window == needle)
}

/// Split a Content-Type header value into its MIME type and parameters.
///
/// Returns `(mime, params)` where `mime` is trimmed and lowercased and
/// `params` maps lowercased parameter names to their trimmed values.
/// A value wrapped in a pair of double quotes has the quotes removed.
/// Pieces without an `=` are ignored; when a parameter name repeats,
/// the last occurrence wins.
fn parse_content_type(value: &str) -> (String, HashMap<String, String>) {
    let (mime, rest) = match value.split_once(';') {
        Some(parts) => parts,
        None => return (value.trim().to_lowercase(), HashMap::new()),
    };

    let params = rest
        .split(';')
        .filter_map(|piece| piece.split_once('='))
        .map(|(name, raw)| {
            let raw = raw.trim();
            // Only unquote when both an opening and a closing quote
            // exist; a lone `"` is kept as-is.
            let unquoted = raw
                .strip_prefix('"')
                .and_then(|inner| inner.strip_suffix('"'))
                .unwrap_or(raw);

            (name.trim().to_lowercase(), unquoted.to_string())
        })
        .collect();

    (mime.trim().to_lowercase(), params)
}

163/// Stream adapter: wraps a reqwest bytes stream and yields one
164/// `BatchPart` per multipart section as the body arrives.
165pub struct BatchPartStream<S> {
166    inner: S,
167    boundary_line: Vec<u8>,
168    boundary_sep: Vec<u8>,
169    buf: BytesMut,
170    state: State,
171    done: bool,
172}
173
/// Parser states for `BatchPartStream`.
enum State {
    /// Haven't found the first --boundary yet; discard anything before it.
    FindFirstBoundary,
    /// Just consumed a --boundary; next is either CRLF or "--" (terminator).
    BoundarySuffix,
    /// Reading part headers up to CRLF CRLF.
    Headers,
    /// Reading part body either by Content-Length or up to next boundary.
    Body {
        /// Headers parsed for the part whose body is being read.
        headers: HashMap<String, String>,
        /// Parsed `content-length` header, when present and numeric.
        content_length: Option<usize>,
    },
    /// Body already yielded; scan for the trailing "\r\n--<boundary>"
    /// and discard it before transitioning to BoundarySuffix. This
    /// lets Content-Length framing yield a part the instant its body
    /// bytes arrive, without waiting for the next part's boundary.
    ConsumeSeparator,
    /// Stream is done; no more parts.
    Done,
}

195impl<S> BatchPartStream<S>
196where
197    S: Stream<Item = Result<Bytes, reqwest::Error>> + Unpin,
198{
199    /// Construct a part stream wrapping `stream` with the given
200    /// `boundary` (without the leading `--`).
201    pub fn new(stream: S, boundary: &str) -> Self {
202        let boundary_line = format!("--{}", boundary).into_bytes();
203        let boundary_sep = format!("\r\n--{}", boundary).into_bytes();
204
205        Self {
206            inner: stream,
207            boundary_line,
208            boundary_sep,
209            buf: BytesMut::new(),
210            state: State::FindFirstBoundary,
211            done: false,
212        }
213    }
214}
215
216impl<S> Stream for BatchPartStream<S>
217where
218    S: Stream<Item = Result<Bytes, reqwest::Error>> + Unpin,
219{
220    type Item = Result<BatchPart, ScrapflyError>;
221
222    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
223        loop {
224            // Run the state machine on the current buffer.
225            let this = &mut *self;
226
227            match &mut this.state {
228                State::Done => return Poll::Ready(None),
229
230                State::FindFirstBoundary => {
231                    if let Some(idx) = find_subslice(&this.buf, &this.boundary_line) {
232                        let _ = this.buf.split_to(idx + this.boundary_line.len());
233                        this.state = State::BoundarySuffix;
234                        continue;
235                    }
236                }
237
238                State::BoundarySuffix => {
239                    if this.buf.len() < 2 {
240                        // need more
241                    } else {
242                        let head = &this.buf[..2];
243
244                        if head == b"--" {
245                            this.state = State::Done;
246
247                            return Poll::Ready(None);
248                        }
249
250                        if head == CRLF {
251                            let _ = this.buf.split_to(2);
252                            this.state = State::Headers;
253                            continue;
254                        }
255
256                        // Tolerate LF-only.
257                        if this.buf[0] == b'\n' {
258                            let _ = this.buf.split_to(1);
259                            this.state = State::Headers;
260                            continue;
261                        }
262
263                        this.state = State::Done;
264
265                        return Poll::Ready(None);
266                    }
267                }
268
269                State::Headers => {
270                    if let Some(idx) = find_subslice(&this.buf, DOUBLE_CRLF) {
271                        let header_block = this.buf.split_to(idx).freeze();
272                        let _ = this.buf.split_to(DOUBLE_CRLF.len());
273
274                        let mut headers: HashMap<String, String> = HashMap::new();
275
276                        let bytes_ref: &[u8] = header_block.as_ref();
277
278                        for line in bytes_ref.split(|b: &u8| *b == b'\n') {
279                            let line: &[u8] = if let Some(l) = line.strip_suffix(&[b'\r'][..]) {
280                                l
281                            } else {
282                                line
283                            };
284
285                            if line.is_empty() {
286                                continue;
287                            }
288
289                            let s = match std::str::from_utf8(line) {
290                                Ok(s) => s,
291                                Err(_) => continue,
292                            };
293
294                            if let Some(colon) = s.find(':') {
295                                let k = s[..colon].trim().to_lowercase();
296                                let v = s[colon + 1..].trim().to_string();
297                                headers.insert(k, v);
298                            }
299                        }
300
301                        let content_length = headers
302                            .get("content-length")
303                            .and_then(|v| v.parse::<usize>().ok());
304
305                        this.state = State::Body {
306                            headers,
307                            content_length,
308                        };
309                        continue;
310                    }
311                }
312
313                State::Body {
314                    headers,
315                    content_length,
316                } => {
317                    // With Content-Length, yield the part the instant
318                    // its body bytes arrive. Consuming the trailing
319                    // "\r\n--<boundary>" separator is deferred to the
320                    // next poll via State::ConsumeSeparator — that
321                    // way the caller observes streaming order even
322                    // when the next part is slow to land on the wire.
323                    //
324                    // Without Content-Length we have no choice but to
325                    // scan for the separator (it's how the body ends).
326                    let (body_end, consume_sep_after_yield) = match *content_length {
327                        Some(cl) if this.buf.len() >= cl => (Some(cl), true),
328                        Some(_) => (None, false),
329                        None => (find_subslice(&this.buf, &this.boundary_sep), false),
330                    };
331
332                    if let Some(end) = body_end {
333                        let body = this.buf.split_to(end).freeze();
334
335                        let part = BatchPart {
336                            headers: std::mem::take(headers),
337                            body,
338                        };
339
340                        if consume_sep_after_yield {
341                            this.state = State::ConsumeSeparator;
342                        } else {
343                            // Separator was part of the body_end scan;
344                            // drop its bytes and go back to the suffix.
345                            let _ = this.buf.split_to(this.boundary_sep.len());
346                            this.state = State::BoundarySuffix;
347                        }
348
349                        return Poll::Ready(Some(Ok(part)));
350                    }
351                }
352
353                State::ConsumeSeparator => {
354                    if let Some(idx) = find_subslice(&this.buf, &this.boundary_sep) {
355                        let _ = this.buf.split_to(idx + this.boundary_sep.len());
356                        this.state = State::BoundarySuffix;
357                        continue;
358                    }
359                    // Need more bytes; fall through to the pump block.
360                }
361            }
362
363            // Need more bytes from the underlying stream.
364            if this.done {
365                return Poll::Ready(None);
366            }
367
368            match Pin::new(&mut this.inner).poll_next(cx) {
369                Poll::Pending => return Poll::Pending,
370                Poll::Ready(None) => {
371                    this.done = true;
372                    // Let the state machine see EOF on next iteration —
373                    // usually it'll go to Done.
374                    continue;
375                }
376                Poll::Ready(Some(Err(e))) => {
377                    return Poll::Ready(Some(Err(ScrapflyError::Config(format!(
378                        "batch stream error: {}",
379                        e
380                    )))));
381                }
382                Poll::Ready(Some(Ok(bytes))) => {
383                    this.buf.extend_from_slice(&bytes);
384                    continue;
385                }
386            }
387        }
388    }
389}
390
391/// Convenience: take a reqwest `Response` whose Content-Type is
392/// multipart/mixed and return a typed `Stream<Item=BatchPart>`.
393pub fn parts_from_response(
394    resp: reqwest::Response,
395) -> Result<BatchPartStream<impl Stream<Item = Result<Bytes, reqwest::Error>>>, ScrapflyError> {
396    let ct = resp
397        .headers()
398        .get("content-type")
399        .and_then(|v| v.to_str().ok())
400        .unwrap_or("")
401        .to_string();
402
403    let (mime, params) = parse_content_type(&ct);
404
405    if mime != "multipart/mixed" {
406        return Err(ScrapflyError::Config(format!(
407            "scrape_batch: expected Content-Type multipart/mixed, got {:?}",
408            ct
409        )));
410    }
411
412    let boundary = params.get("boundary").cloned().ok_or_else(|| {
413        ScrapflyError::Config(format!(
414            "scrape_batch: Content-Type multipart/mixed missing boundary: {:?}",
415            ct
416        ))
417    })?;
418
419    Ok(BatchPartStream::new(resp.bytes_stream(), &boundary))
420}
421
/// Header prefix used by the server to forward upstream response
/// headers on proxified batch parts. Lowercase, because part header
/// names are lowercased when parsed.
const UPSTREAM_PREFIX: &str = "x-scrapfly-upstream-";

426/// Synthesize a `BatchProxifiedResponse` from a proxified batch part.
427/// Restores the upstream HTTP status from `X-Scrapfly-Scrape-Status`,
428/// merges upstream headers (after stripping the `X-Scrapfly-Upstream-`
429/// prefix) with Scrapfly metadata headers, and exposes the raw body.
430pub fn build_proxified_response(part: BatchPart) -> BatchProxifiedResponse {
431    let status: u16 = part
432        .headers
433        .get("x-scrapfly-scrape-status")
434        .and_then(|s| s.parse().ok())
435        .unwrap_or(200);
436
437    let mut out_headers: HashMap<String, String> = HashMap::new();
438
439    for (key, value) in &part.headers {
440        if key == "content-type" {
441            out_headers.insert("content-type".into(), value.clone());
442        } else if let Some(stripped) = key.strip_prefix(UPSTREAM_PREFIX) {
443            out_headers.insert(stripped.to_string(), value.clone());
444        } else if key.starts_with("x-scrapfly-") {
445            out_headers.insert(key.clone(), value.clone());
446        }
447    }
448
449    // Normalize X-Scrapfly-Log-Uuid → X-Scrapfly-Log for parity with
450    // the single-scrape proxified response.
451    if !out_headers.contains_key("x-scrapfly-log") {
452        if let Some(log_uuid) = out_headers.get("x-scrapfly-log-uuid").cloned() {
453            out_headers.insert("x-scrapfly-log".into(), log_uuid);
454        }
455    }
456
457    BatchProxifiedResponse {
458        status,
459        headers: out_headers,
460        body: part.body,
461    }
462}
463
464/// Decode a part body according to its Content-Type. Supports
465/// `application/json` (default) and `application/msgpack`.
466pub fn decode_part_body<T: serde::de::DeserializeOwned>(
467    part: &BatchPart,
468) -> Result<T, ScrapflyError> {
469    let ct = part
470        .headers
471        .get("content-type")
472        .cloned()
473        .unwrap_or_else(|| "application/json".to_string());
474
475    if ct.starts_with("application/json") {
476        return serde_json::from_slice::<T>(&part.body)
477            .map_err(|e| ScrapflyError::Config(format!("scrape_batch: decode JSON part: {}", e)));
478    }
479
480    if ct.starts_with("application/msgpack") || ct.starts_with("application/x-msgpack") {
481        return rmp_serde::from_slice::<T>(&part.body).map_err(|e| {
482            ScrapflyError::Config(format!("scrape_batch: decode msgpack part: {}", e))
483        });
484    }
485
486    Err(ScrapflyError::Config(format!(
487        "scrape_batch: unsupported part Content-Type: {:?}",
488        ct
489    )))
490}