1use std::collections::HashMap;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14use bytes::{Bytes, BytesMut};
15use futures_util::stream::Stream;
16
17use crate::error::ScrapflyError;
18use crate::result::scrape::ScrapeResult;
19
/// Line terminator expected after a boundary line.
const CRLF: &[u8] = b"\r\n";
/// Blank-line marker that terminates a part's header block.
const DOUBLE_CRLF: &[u8] = b"\r\n\r\n";
22
/// One part extracted from a `multipart/mixed` batch response.
#[derive(Debug)]
pub struct BatchPart {
    /// Part headers. The stream parser stores names trimmed and lower-cased,
    /// and values trimmed; a repeated name keeps the last value seen.
    pub headers: HashMap<String, String>,

    /// Raw, undecoded body bytes of the part.
    pub body: Bytes,
}
34
/// Response view built from a [`BatchPart`] by `build_proxified_response`.
#[derive(Debug)]
pub struct BatchProxifiedResponse {
    /// Status parsed from the part's `x-scrapfly-scrape-status` header;
    /// `build_proxified_response` falls back to 200 when it is absent or invalid.
    pub status: u16,

    /// Headers kept by `build_proxified_response`: `content-type`, the
    /// `x-scrapfly-*` headers, and `x-scrapfly-upstream-*` headers with that
    /// prefix removed.
    pub headers: HashMap<String, String>,

    /// Raw body bytes, taken unchanged from the part.
    pub body: Bytes,
}
55
56impl BatchProxifiedResponse {
57 pub fn text(&self) -> String {
59 String::from_utf8_lossy(&self.body).into_owned()
60 }
61
62 pub fn content_type(&self) -> Option<&str> {
64 self.headers.get("content-type").map(String::as_str)
65 }
66
67 pub fn scrapfly_log(&self) -> Option<&str> {
69 self.headers
70 .get("x-scrapfly-log")
71 .or_else(|| self.headers.get("x-scrapfly-log-uuid"))
72 .map(String::as_str)
73 }
74}
75
/// Serialization format requested for batch parts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BatchFormat {
    /// JSON parts; this is the default.
    #[default]
    Json,
    /// MessagePack parts.
    Msgpack,
}

impl BatchFormat {
    /// Returns the media type string associated with this format.
    pub(crate) fn accept_header(self) -> &'static str {
        match self {
            Self::Msgpack => "application/msgpack",
            Self::Json => "application/json",
        }
    }
}
95
/// Options for a batch request.
#[derive(Debug, Clone, Default)]
pub struct BatchOptions {
    /// Part serialization format; defaults to [`BatchFormat::Json`].
    pub format: BatchFormat,
}
102
/// Outcome of a single entry in a batch: a decoded scrape result, a
/// proxified response, or an error.
// NOTE(review): the lint is silenced because the variants differ a lot in
// size; presumably they are left unboxed to keep the public API simple — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum BatchOutcome {
    /// A scrape result.
    Scrape(ScrapeResult),

    /// A proxified response (see [`BatchProxifiedResponse`]).
    Proxified(BatchProxifiedResponse),

    /// An error associated with this entry.
    Err(ScrapflyError),
}
126
/// Returns the offset of the first occurrence of `needle` in `buf`.
///
/// An empty `needle` matches at offset 0; a `needle` longer than `buf`
/// never matches.
fn find_subslice(buf: &[u8], needle: &[u8]) -> Option<usize> {
    match needle.len() {
        // `windows(0)` would panic, so the empty needle is answered directly.
        0 => Some(0),
        width => buf.windows(width).position(|window| window == needle),
    }
}
138
/// Splits a `Content-Type` value into its media type and parameters.
///
/// The media type and parameter names are trimmed and lower-cased. Parameter
/// values are trimmed and keep their case; one pair of surrounding double
/// quotes is removed. Segments without `=` are ignored, and a repeated
/// parameter keeps the last value.
fn parse_content_type(value: &str) -> (String, HashMap<String, String>) {
    let mut segments = value.split(';');

    // `split` always yields at least one item, so the fallback is never used.
    let mime = segments.next().unwrap_or("").trim().to_lowercase();

    let params: HashMap<String, String> = segments
        .filter_map(|segment| segment.split_once('='))
        .map(|(name, raw)| {
            let raw = raw.trim();
            // Only strip quotes when both an opening and a closing one exist.
            let unquoted = raw
                .strip_prefix('"')
                .and_then(|rest| rest.strip_suffix('"'))
                .unwrap_or(raw);

            (name.trim().to_lowercase(), unquoted.to_string())
        })
        .collect();

    (mime, params)
}
162
/// Adapter that turns a stream of byte chunks into a stream of
/// [`BatchPart`]s by parsing multipart framing incrementally.
pub struct BatchPartStream<S> {
    /// Underlying source of byte chunks.
    inner: S,
    /// `--<boundary>`: used to locate the first boundary.
    boundary_line: Vec<u8>,
    /// `\r\n--<boundary>`: used to locate every later boundary.
    boundary_sep: Vec<u8>,
    /// Bytes received from `inner` that have not been consumed yet.
    buf: BytesMut,
    /// Current position in the parsing state machine.
    state: State,
    /// Set once `inner` has reported end of stream.
    done: bool,
}
173
/// Parsing states of [`BatchPartStream`].
enum State {
    /// Looking for the first `--<boundary>`; bytes before it are discarded.
    FindFirstBoundary,
    /// A boundary was just consumed; the next bytes decide whether a part
    /// follows (CRLF or bare LF) or the stream ends (`--` or anything else).
    BoundarySuffix,
    /// Collecting the header block up to the blank line that ends it.
    Headers,
    /// Collecting the body of the current part.
    Body {
        /// Headers parsed for the current part.
        headers: HashMap<String, String>,
        /// Value of `content-length` when present and a valid `usize`;
        /// otherwise the body runs until the next boundary separator.
        content_length: Option<usize>,
    },
    /// A length-delimited body was yielded; skip up to and including the
    /// next `\r\n--<boundary>`.
    ConsumeSeparator,
    /// Terminal state: the stream yields nothing further.
    Done,
}
194
195impl<S> BatchPartStream<S>
196where
197 S: Stream<Item = Result<Bytes, reqwest::Error>> + Unpin,
198{
199 pub fn new(stream: S, boundary: &str) -> Self {
202 let boundary_line = format!("--{}", boundary).into_bytes();
203 let boundary_sep = format!("\r\n--{}", boundary).into_bytes();
204
205 Self {
206 inner: stream,
207 boundary_line,
208 boundary_sep,
209 buf: BytesMut::new(),
210 state: State::FindFirstBoundary,
211 done: false,
212 }
213 }
214}
215
216impl<S> Stream for BatchPartStream<S>
217where
218 S: Stream<Item = Result<Bytes, reqwest::Error>> + Unpin,
219{
220 type Item = Result<BatchPart, ScrapflyError>;
221
222 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
223 loop {
224 let this = &mut *self;
226
227 match &mut this.state {
228 State::Done => return Poll::Ready(None),
229
230 State::FindFirstBoundary => {
231 if let Some(idx) = find_subslice(&this.buf, &this.boundary_line) {
232 let _ = this.buf.split_to(idx + this.boundary_line.len());
233 this.state = State::BoundarySuffix;
234 continue;
235 }
236 }
237
238 State::BoundarySuffix => {
239 if this.buf.len() < 2 {
240 } else {
242 let head = &this.buf[..2];
243
244 if head == b"--" {
245 this.state = State::Done;
246
247 return Poll::Ready(None);
248 }
249
250 if head == CRLF {
251 let _ = this.buf.split_to(2);
252 this.state = State::Headers;
253 continue;
254 }
255
256 if this.buf[0] == b'\n' {
258 let _ = this.buf.split_to(1);
259 this.state = State::Headers;
260 continue;
261 }
262
263 this.state = State::Done;
264
265 return Poll::Ready(None);
266 }
267 }
268
269 State::Headers => {
270 if let Some(idx) = find_subslice(&this.buf, DOUBLE_CRLF) {
271 let header_block = this.buf.split_to(idx).freeze();
272 let _ = this.buf.split_to(DOUBLE_CRLF.len());
273
274 let mut headers: HashMap<String, String> = HashMap::new();
275
276 let bytes_ref: &[u8] = header_block.as_ref();
277
278 for line in bytes_ref.split(|b: &u8| *b == b'\n') {
279 let line: &[u8] = if let Some(l) = line.strip_suffix(&[b'\r'][..]) {
280 l
281 } else {
282 line
283 };
284
285 if line.is_empty() {
286 continue;
287 }
288
289 let s = match std::str::from_utf8(line) {
290 Ok(s) => s,
291 Err(_) => continue,
292 };
293
294 if let Some(colon) = s.find(':') {
295 let k = s[..colon].trim().to_lowercase();
296 let v = s[colon + 1..].trim().to_string();
297 headers.insert(k, v);
298 }
299 }
300
301 let content_length = headers
302 .get("content-length")
303 .and_then(|v| v.parse::<usize>().ok());
304
305 this.state = State::Body {
306 headers,
307 content_length,
308 };
309 continue;
310 }
311 }
312
313 State::Body {
314 headers,
315 content_length,
316 } => {
317 let (body_end, consume_sep_after_yield) = match *content_length {
327 Some(cl) if this.buf.len() >= cl => (Some(cl), true),
328 Some(_) => (None, false),
329 None => (find_subslice(&this.buf, &this.boundary_sep), false),
330 };
331
332 if let Some(end) = body_end {
333 let body = this.buf.split_to(end).freeze();
334
335 let part = BatchPart {
336 headers: std::mem::take(headers),
337 body,
338 };
339
340 if consume_sep_after_yield {
341 this.state = State::ConsumeSeparator;
342 } else {
343 let _ = this.buf.split_to(this.boundary_sep.len());
346 this.state = State::BoundarySuffix;
347 }
348
349 return Poll::Ready(Some(Ok(part)));
350 }
351 }
352
353 State::ConsumeSeparator => {
354 if let Some(idx) = find_subslice(&this.buf, &this.boundary_sep) {
355 let _ = this.buf.split_to(idx + this.boundary_sep.len());
356 this.state = State::BoundarySuffix;
357 continue;
358 }
359 }
361 }
362
363 if this.done {
365 return Poll::Ready(None);
366 }
367
368 match Pin::new(&mut this.inner).poll_next(cx) {
369 Poll::Pending => return Poll::Pending,
370 Poll::Ready(None) => {
371 this.done = true;
372 continue;
375 }
376 Poll::Ready(Some(Err(e))) => {
377 return Poll::Ready(Some(Err(ScrapflyError::Config(format!(
378 "batch stream error: {}",
379 e
380 )))));
381 }
382 Poll::Ready(Some(Ok(bytes))) => {
383 this.buf.extend_from_slice(&bytes);
384 continue;
385 }
386 }
387 }
388 }
389}
390
391pub fn parts_from_response(
394 resp: reqwest::Response,
395) -> Result<BatchPartStream<impl Stream<Item = Result<Bytes, reqwest::Error>>>, ScrapflyError> {
396 let ct = resp
397 .headers()
398 .get("content-type")
399 .and_then(|v| v.to_str().ok())
400 .unwrap_or("")
401 .to_string();
402
403 let (mime, params) = parse_content_type(&ct);
404
405 if mime != "multipart/mixed" {
406 return Err(ScrapflyError::Config(format!(
407 "scrape_batch: expected Content-Type multipart/mixed, got {:?}",
408 ct
409 )));
410 }
411
412 let boundary = params.get("boundary").cloned().ok_or_else(|| {
413 ScrapflyError::Config(format!(
414 "scrape_batch: Content-Type multipart/mixed missing boundary: {:?}",
415 ct
416 ))
417 })?;
418
419 Ok(BatchPartStream::new(resp.bytes_stream(), &boundary))
420}
421
422const UPSTREAM_PREFIX: &str = "x-scrapfly-upstream-";
425
426pub fn build_proxified_response(part: BatchPart) -> BatchProxifiedResponse {
431 let status: u16 = part
432 .headers
433 .get("x-scrapfly-scrape-status")
434 .and_then(|s| s.parse().ok())
435 .unwrap_or(200);
436
437 let mut out_headers: HashMap<String, String> = HashMap::new();
438
439 for (key, value) in &part.headers {
440 if key == "content-type" {
441 out_headers.insert("content-type".into(), value.clone());
442 } else if let Some(stripped) = key.strip_prefix(UPSTREAM_PREFIX) {
443 out_headers.insert(stripped.to_string(), value.clone());
444 } else if key.starts_with("x-scrapfly-") {
445 out_headers.insert(key.clone(), value.clone());
446 }
447 }
448
449 if !out_headers.contains_key("x-scrapfly-log") {
452 if let Some(log_uuid) = out_headers.get("x-scrapfly-log-uuid").cloned() {
453 out_headers.insert("x-scrapfly-log".into(), log_uuid);
454 }
455 }
456
457 BatchProxifiedResponse {
458 status,
459 headers: out_headers,
460 body: part.body,
461 }
462}
463
464pub fn decode_part_body<T: serde::de::DeserializeOwned>(
467 part: &BatchPart,
468) -> Result<T, ScrapflyError> {
469 let ct = part
470 .headers
471 .get("content-type")
472 .cloned()
473 .unwrap_or_else(|| "application/json".to_string());
474
475 if ct.starts_with("application/json") {
476 return serde_json::from_slice::<T>(&part.body)
477 .map_err(|e| ScrapflyError::Config(format!("scrape_batch: decode JSON part: {}", e)));
478 }
479
480 if ct.starts_with("application/msgpack") || ct.starts_with("application/x-msgpack") {
481 return rmp_serde::from_slice::<T>(&part.body).map_err(|e| {
482 ScrapflyError::Config(format!("scrape_batch: decode msgpack part: {}", e))
483 });
484 }
485
486 Err(ScrapflyError::Config(format!(
487 "scrape_batch: unsupported part Content-Type: {:?}",
488 ct
489 )))
490}