use bytes::Bytes;
use futures_util::Stream;
use http::{header, StatusCode};

use crate::response::{IntoResponse, Response};
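
/// Streaming response body: wraps a `Stream` of `Result<Bytes, E>` chunks and
/// converts it into an HTTP response with an optional content type.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest; any
/// surrounding handler or router wiring is assumed rather than defined here):
///
/// ```ignore
/// use bytes::Bytes;
/// use futures_util::stream;
///
/// let chunks: Vec<Result<Bytes, std::convert::Infallible>> =
///     vec![Ok(Bytes::from("hello ")), Ok(Bytes::from("world"))];
/// let response = StreamBody::new(stream::iter(chunks))
///     .content_type("text/plain")
///     .into_response();
/// ```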
pub struct StreamBody<S> {
    #[allow(dead_code)]
    stream: S,
    content_type: Option<String>,
}

impl<S> StreamBody<S> {
    /// Wraps the given stream with no explicit content type.
    pub fn new(stream: S) -> Self {
        Self {
            stream,
            content_type: None,
        }
    }

    /// Sets the `Content-Type` header of the resulting response; defaults to
    /// `application/octet-stream` when unset.
    pub fn content_type(mut self, content_type: impl Into<String>) -> Self {
        self.content_type = Some(content_type.into());
        self
    }
}

impl<S, E> IntoResponse for StreamBody<S>
where
    S: Stream<Item = Result<Bytes, E>> + Send + 'static,
    E: std::error::Error + Send + Sync + 'static,
{
    fn into_response(self) -> Response {
        let content_type = self
            .content_type
            .unwrap_or_else(|| "application/octet-stream".to_string());

        use futures_util::StreamExt;
        // Map stream errors into the crate's ApiError so the chunks match the
        // body's expected item type.
        let stream = self
            .stream
            .map(|res| res.map_err(|e| crate::error::ApiError::internal(e.to_string())));
        let body = crate::response::Body::from_stream(stream);

        // The status and header values above are statically valid, so the
        // builder cannot fail here.
        http::Response::builder()
            .status(StatusCode::OK)
            .header(header::CONTENT_TYPE, content_type)
            .header(header::TRANSFER_ENCODING, "chunked")
            .body(body)
            .unwrap()
    }
}
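
/// Builds a [`StreamBody`] from any iterator of `Result<Bytes, E>` chunks.
///
/// A small usage sketch (illustrative only, not compiled as a doctest):
///
/// ```ignore
/// use bytes::Bytes;
///
/// let chunks: Vec<Result<Bytes, std::convert::Infallible>> =
///     vec![Ok(Bytes::from("chunk 1")), Ok(Bytes::from("chunk 2"))];
/// let response = stream_from_iter(chunks)
///     .content_type("application/octet-stream")
///     .into_response();
/// ```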
pub fn stream_from_iter<I, E>(
    chunks: I,
) -> StreamBody<futures_util::stream::Iter<std::vec::IntoIter<Result<Bytes, E>>>>
where
    I: IntoIterator<Item = Result<Bytes, E>>,
{
    use futures_util::stream;
    let vec: Vec<_> = chunks.into_iter().collect();
    StreamBody::new(stream::iter(vec))
}
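
/// Builds a [`StreamBody`] from an iterator of string-like chunks, converting
/// each item to `Bytes`.
///
/// A small usage sketch (illustrative only, not compiled as a doctest):
///
/// ```ignore
/// let lines: Vec<Result<&str, std::convert::Infallible>> =
///     vec![Ok("hello\n"), Ok("world\n")];
/// let response = stream_from_strings(lines)
///     .content_type("text/plain")
///     .into_response();
/// ```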
pub fn stream_from_strings<I, S, E>(
    strings: I,
) -> StreamBody<futures_util::stream::Iter<std::vec::IntoIter<Result<Bytes, E>>>>
where
    I: IntoIterator<Item = Result<S, E>>,
    S: Into<String>,
{
    use futures_util::stream;
    let vec: Vec<_> = strings
        .into_iter()
        .map(|r| r.map(|s| Bytes::from(s.into())))
        .collect();
    StreamBody::new(stream::iter(vec))
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream;

    #[test]
    fn test_stream_body_default_content_type() {
        let chunks: Vec<Result<Bytes, std::convert::Infallible>> = vec![Ok(Bytes::from("chunk 1"))];
        let stream_body = StreamBody::new(stream::iter(chunks));
        let response = stream_body.into_response();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get(header::TRANSFER_ENCODING).unwrap(),
            "chunked"
        );
    }

    #[test]
    fn test_stream_body_custom_content_type() {
        let chunks: Vec<Result<Bytes, std::convert::Infallible>> = vec![Ok(Bytes::from("chunk 1"))];
        let stream_body = StreamBody::new(stream::iter(chunks)).content_type("text/plain");
        let response = stream_body.into_response();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            "text/plain"
        );
    }

    #[test]
    fn test_stream_from_iter() {
        let chunks: Vec<Result<Bytes, std::convert::Infallible>> =
            vec![Ok(Bytes::from("chunk 1")), Ok(Bytes::from("chunk 2"))];
        let stream_body = stream_from_iter(chunks);
        let response = stream_body.into_response();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[test]
    fn test_stream_from_strings() {
        let strings: Vec<Result<&str, std::convert::Infallible>> = vec![Ok("hello"), Ok("world")];
        let stream_body = stream_from_strings(strings);
        let response = stream_body.into_response();

        assert_eq!(response.status(), StatusCode::OK);
    }
}

#[cfg(test)]
mod property_tests {
    use super::*;
    use futures_util::stream;
    use futures_util::StreamExt;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_chunk_within_limit_accepted(
            chunk_size in 100usize..1000,
            limit in 1000usize..10000,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let data = vec![0u8; chunk_size];
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> =
                    vec![Ok(Bytes::from(data))];
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let result = streaming_body.next().await;
                prop_assert!(result.is_some());
                prop_assert!(result.unwrap().is_ok());

                prop_assert_eq!(streaming_body.bytes_read(), chunk_size);

                Ok(())
            })?;
        }

        #[test]
        fn prop_chunk_exceeding_limit_rejected(
            limit in 100usize..1000,
            excess in 1usize..100,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let chunk_size = limit + excess;
                let data = vec![0u8; chunk_size];
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> =
                    vec![Ok(Bytes::from(data))];
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let result = streaming_body.next().await;
                prop_assert!(result.is_some());
                let error = result.unwrap();
                prop_assert!(error.is_err());

                let err = error.unwrap_err();
                prop_assert_eq!(err.status, StatusCode::PAYLOAD_TOO_LARGE);

                Ok(())
            })?;
        }

        #[test]
        fn prop_multiple_chunks_within_limit(
            chunk_size in 100usize..500,
            num_chunks in 2usize..5,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let total_size = chunk_size * num_chunks;
                let limit = total_size + 100;
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> = (0..num_chunks)
                    .map(|_| Ok(Bytes::from(vec![0u8; chunk_size])))
                    .collect();
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let mut total_read = 0;
                while let Some(result) = streaming_body.next().await {
                    prop_assert!(result.is_ok());
                    total_read += result.unwrap().len();
                }

                prop_assert_eq!(total_read, total_size);
                prop_assert_eq!(streaming_body.bytes_read(), total_size);

                Ok(())
            })?;
        }

        #[test]
        fn prop_multiple_chunks_exceeding_limit(
            chunk_size in 100usize..500,
            num_chunks in 3usize..6,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let _total_size = chunk_size * num_chunks;
                let limit = chunk_size + 50;
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> = (0..num_chunks)
                    .map(|_| Ok(Bytes::from(vec![0u8; chunk_size])))
                    .collect();
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let first = streaming_body.next().await;
                prop_assert!(first.is_some());
                prop_assert!(first.unwrap().is_ok());

                let second = streaming_body.next().await;
                prop_assert!(second.is_some());
                let error = second.unwrap();
                prop_assert!(error.is_err());

                let err = error.unwrap_err();
                prop_assert_eq!(err.status, StatusCode::PAYLOAD_TOO_LARGE);

                Ok(())
            })?;
        }

        #[test]
        fn prop_no_limit_unlimited(
            chunk_size in 1000usize..10000,
            num_chunks in 5usize..10,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> = (0..num_chunks)
                    .map(|_| Ok(Bytes::from(vec![0u8; chunk_size])))
                    .collect();
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, None);

                let mut count = 0;
                while let Some(result) = streaming_body.next().await {
                    prop_assert!(result.is_ok());
                    count += 1;
                }

                prop_assert_eq!(count, num_chunks);
                prop_assert_eq!(streaming_body.bytes_read(), chunk_size * num_chunks);

                Ok(())
            })?;
        }

        #[test]
        fn prop_bytes_read_accurate(
            sizes in prop::collection::vec(100usize..1000, 1..10)
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let total_size: usize = sizes.iter().sum();
                let limit = total_size + 1000;
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> = sizes
                    .iter()
                    .map(|&size| Ok(Bytes::from(vec![0u8; size])))
                    .collect();
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let mut cumulative = 0;
                while let Some(result) = streaming_body.next().await {
                    let chunk = result.unwrap();
                    cumulative += chunk.len();

                    prop_assert_eq!(streaming_body.bytes_read(), cumulative);
                }

                prop_assert_eq!(streaming_body.bytes_read(), total_size);

                Ok(())
            })?;
        }

        #[test]
        fn prop_exact_limit_accepted(chunk_size in 500usize..5000) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let limit = chunk_size;
                let data = vec![0u8; chunk_size];
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> =
                    vec![Ok(Bytes::from(data))];
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let result = streaming_body.next().await;
                prop_assert!(result.is_some());
                prop_assert!(result.unwrap().is_ok());

                prop_assert_eq!(streaming_body.bytes_read(), chunk_size);

                Ok(())
            })?;
        }

        #[test]
        fn prop_one_byte_over_rejected(limit in 500usize..5000) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let chunk_size = limit + 1;
                let data = vec![0u8; chunk_size];
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> =
                    vec![Ok(Bytes::from(data))];
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let result = streaming_body.next().await;
                prop_assert!(result.is_some());
                let error = result.unwrap();
                prop_assert!(error.is_err());

                Ok(())
            })?;
        }

        #[test]
        fn prop_empty_chunks_ignored(
            chunk_size in 100usize..1000,
            num_empty in 1usize..5,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let limit = chunk_size + 100;

                let mut chunks: Vec<Result<Bytes, crate::error::ApiError>> = vec![];

                for _ in 0..num_empty {
                    chunks.push(Ok(Bytes::new()));
                }

                chunks.push(Ok(Bytes::from(vec![0u8; chunk_size])));

                let stream_data = stream::iter(chunks);
                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                while let Some(result) = streaming_body.next().await {
                    prop_assert!(result.is_ok());
                }

                prop_assert_eq!(streaming_body.bytes_read(), chunk_size);

                Ok(())
            })?;
        }

        #[test]
        fn prop_limit_cumulative(
            chunk1_size in 300usize..600,
            chunk2_size in 300usize..600,
            limit in 500usize..900,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> = vec![
                    Ok(Bytes::from(vec![0u8; chunk1_size])),
                    Ok(Bytes::from(vec![0u8; chunk2_size])),
                ];
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let first = streaming_body.next().await;
                if chunk1_size <= limit {
                    prop_assert!(first.unwrap().is_ok());

                    let second = streaming_body.next().await;
                    let total = chunk1_size + chunk2_size;

                    if total <= limit {
                        prop_assert!(second.unwrap().is_ok());
                    } else {
                        prop_assert!(second.unwrap().is_err());
                    }
                } else {
                    prop_assert!(first.unwrap().is_err());
                }

                Ok(())
            })?;
        }

        #[test]
        fn prop_default_config_limit(_seed in 0u32..10) {
            let config = StreamingConfig::default();
            prop_assert_eq!(config.max_body_size, Some(10 * 1024 * 1024));
        }

        #[test]
        fn prop_error_message_includes_limit(limit in 1000usize..10000) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let chunk_size = limit + 100;
                let data = vec![0u8; chunk_size];
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> =
                    vec![Ok(Bytes::from(data))];
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let result = streaming_body.next().await;
                let error = result.unwrap().unwrap_err();

                prop_assert!(error.message.contains(&limit.to_string()));
                prop_assert!(error.message.contains("exceeded limit"));

                Ok(())
            })?;
        }
    }
}
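
/// Configuration for streaming request bodies.
///
/// A small usage sketch (illustrative only, not compiled as a doctest):
///
/// ```ignore
/// // Default: bodies are capped at 10 MiB.
/// let config = StreamingConfig::default();
/// assert_eq!(config.max_body_size, Some(10 * 1024 * 1024));
///
/// // Explicitly disable the limit.
/// let unlimited = StreamingConfig { max_body_size: None };
/// ```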
#[derive(Debug, Clone, Copy)]
pub struct StreamingConfig {
    /// Maximum number of body bytes to accept; `None` disables the limit.
    pub max_body_size: Option<usize>,
}

impl Default for StreamingConfig {
    fn default() -> Self {
        Self {
            max_body_size: Some(10 * 1024 * 1024), // 10 MiB
        }
    }
}
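
/// A body stream that counts bytes as they are read and fails with
/// `413 Payload Too Large` once an optional size limit is exceeded.
///
/// A small usage sketch (illustrative only, not compiled as a doctest; assumes
/// an async context):
///
/// ```ignore
/// use futures_util::{stream, StreamExt};
///
/// let chunks = vec![Ok(bytes::Bytes::from("hello"))];
/// let mut body = StreamingBody::from_stream(stream::iter(chunks), Some(1024));
/// while let Some(chunk) = body.next().await {
///     let chunk = chunk.expect("within the 1024-byte limit");
///     println!("chunk of {} bytes, {} read so far", chunk.len(), body.bytes_read());
/// }
/// ```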
pub struct StreamingBody {
    inner: StreamingInner,
    bytes_read: usize,
    limit: Option<usize>,
}

/// The underlying source of body frames: either a hyper request body or a
/// boxed generic stream.
enum StreamingInner {
    Hyper(hyper::body::Incoming),
    Generic(
        std::pin::Pin<
            Box<
                dyn futures_util::Stream<Item = Result<Bytes, crate::error::ApiError>>
                    + Send
                    + Sync,
            >,
        >,
    ),
}

impl StreamingBody {
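    /// Wraps a `hyper::body::Incoming` request body, enforcing `limit` (when
    /// `Some`) as the maximum total number of body bytes that may be read.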
    pub fn new(inner: hyper::body::Incoming, limit: Option<usize>) -> Self {
        Self {
            inner: StreamingInner::Hyper(inner),
            bytes_read: 0,
            limit,
        }
    }
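
    /// Wraps any `Send + Sync` stream of `Result<Bytes, ApiError>` chunks,
    /// enforcing `limit` (when `Some`) as the maximum total number of body
    /// bytes that may be read.
    ///
    /// A small usage sketch (illustrative only, not compiled as a doctest;
    /// assumes an async context):
    ///
    /// ```ignore
    /// use futures_util::{stream, StreamExt};
    ///
    /// let chunks: Vec<Result<bytes::Bytes, crate::error::ApiError>> =
    ///     vec![Ok(bytes::Bytes::from(vec![0u8; 2048]))];
    /// let mut body = StreamingBody::from_stream(stream::iter(chunks), Some(1024));
    ///
    /// // The first chunk alone exceeds the 1024-byte limit, so the stream
    /// // yields an error with status 413 Payload Too Large.
    /// assert!(body.next().await.unwrap().is_err());
    /// ```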
    pub fn from_stream<S>(stream: S, limit: Option<usize>) -> Self
    where
        S: futures_util::Stream<Item = Result<Bytes, crate::error::ApiError>>
            + Send
            + Sync
            + 'static,
    {
        Self {
            inner: StreamingInner::Generic(Box::pin(stream)),
            bytes_read: 0,
            limit,
        }
    }

    /// Total number of body bytes yielded so far.
    pub fn bytes_read(&self) -> usize {
        self.bytes_read
    }
}

impl Stream for StreamingBody {
    type Item = Result<Bytes, crate::error::ApiError>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use hyper::body::Body;

        // Both arms track the cumulative number of body bytes and fail with
        // 413 Payload Too Large as soon as the configured limit is exceeded.
        match &mut self.inner {
            StreamingInner::Hyper(incoming) => {
                loop {
                    match std::pin::Pin::new(&mut *incoming).poll_frame(cx) {
                        std::task::Poll::Ready(Some(Ok(frame))) => {
                            if let Ok(data) = frame.into_data() {
                                let len = data.len();
                                self.bytes_read += len;
                                if let Some(limit) = self.limit {
                                    if self.bytes_read > limit {
                                        return std::task::Poll::Ready(Some(Err(
                                            crate::error::ApiError::new(
                                                StatusCode::PAYLOAD_TOO_LARGE,
                                                "payload_too_large",
                                                format!(
                                                    "Body size exceeded limit of {} bytes",
                                                    limit
                                                ),
                                            ),
                                        )));
                                    }
                                }
                                return std::task::Poll::Ready(Some(Ok(data)));
                            }
                            // Non-data frames (e.g. trailers) carry no body
                            // bytes; poll for the next frame.
                            continue;
                        }
                        std::task::Poll::Ready(Some(Err(e))) => {
                            return std::task::Poll::Ready(Some(Err(
                                crate::error::ApiError::bad_request(e.to_string()),
                            )));
                        }
                        std::task::Poll::Ready(None) => return std::task::Poll::Ready(None),
                        std::task::Poll::Pending => return std::task::Poll::Pending,
                    }
                }
            }
            StreamingInner::Generic(stream) => match stream.as_mut().poll_next(cx) {
                std::task::Poll::Ready(Some(Ok(data))) => {
                    let len = data.len();
                    self.bytes_read += len;
                    if let Some(limit) = self.limit {
                        if self.bytes_read > limit {
                            return std::task::Poll::Ready(Some(Err(crate::error::ApiError::new(
                                StatusCode::PAYLOAD_TOO_LARGE,
                                "payload_too_large",
                                format!("Body size exceeded limit of {} bytes", limit),
                            ))));
                        }
                    }
                    std::task::Poll::Ready(Some(Ok(data)))
                }
                other => other,
            },
        }
    }
}