use bytes::Bytes;
use futures_util::Stream;
use http::{header, StatusCode};
use http_body_util::Full;

use crate::response::{IntoResponse, Response};

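/// Response body built from a stream of `Bytes` chunks, with an optional
/// `Content-Type` override (defaults to `application/octet-stream`).
///
/// A minimal usage sketch, kept as a non-compiled (`ignore`) example because
/// the final module path of `StreamBody` is assumed:
///
/// ```ignore
/// use bytes::Bytes;
/// use futures_util::stream;
///
/// let chunks: Vec<Result<Bytes, std::convert::Infallible>> =
///     vec![Ok(Bytes::from("hello")), Ok(Bytes::from("world"))];
/// let response = StreamBody::new(stream::iter(chunks))
///     .content_type("text/plain")
///     .into_response();
/// ```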
pub struct StreamBody<S> {
    #[allow(dead_code)]
    stream: S,
    content_type: Option<String>,
}

impl<S> StreamBody<S> {
    pub fn new(stream: S) -> Self {
        Self {
            stream,
            content_type: None,
        }
    }

    pub fn content_type(mut self, content_type: impl Into<String>) -> Self {
        self.content_type = Some(content_type.into());
        self
    }
}

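// `IntoResponse` for `StreamBody`: sets the status, `Content-Type` (falling
// back to `application/octet-stream`), and `Transfer-Encoding: chunked`
// headers. The body is currently an empty `Full` placeholder; the wrapped
// stream is not consumed here (hence `#[allow(dead_code)]` on `stream`).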
impl<S, E> IntoResponse for StreamBody<S>
where
    S: Stream<Item = Result<Bytes, E>> + Send + 'static,
    E: std::error::Error + Send + Sync + 'static,
{
    fn into_response(self) -> Response {
        let content_type = self
            .content_type
            .unwrap_or_else(|| "application/octet-stream".to_string());

        http::Response::builder()
            .status(StatusCode::OK)
            .header(header::CONTENT_TYPE, content_type)
            .header(header::TRANSFER_ENCODING, "chunked")
            .body(Full::new(Bytes::new()))
            .unwrap()
    }
}

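/// Builds a `StreamBody` from any iterator of `Result<Bytes, E>` chunks.
///
/// Sketch of typical usage (non-compiled; the surrounding module path is
/// assumed):
///
/// ```ignore
/// use bytes::Bytes;
///
/// let chunks: Vec<Result<Bytes, std::convert::Infallible>> =
///     vec![Ok(Bytes::from("chunk 1")), Ok(Bytes::from("chunk 2"))];
/// let response = stream_from_iter(chunks).into_response();
/// ```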
pub fn stream_from_iter<I, E>(
    chunks: I,
) -> StreamBody<futures_util::stream::Iter<std::vec::IntoIter<Result<Bytes, E>>>>
where
    I: IntoIterator<Item = Result<Bytes, E>>,
{
    use futures_util::stream;
    let vec: Vec<_> = chunks.into_iter().collect();
    StreamBody::new(stream::iter(vec))
}

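/// Builds a `StreamBody` from an iterator of string-like values, converting
/// each item to `Bytes`.
///
/// Sketch of typical usage (non-compiled; module path assumed):
///
/// ```ignore
/// let strings: Vec<Result<&str, std::convert::Infallible>> =
///     vec![Ok("hello"), Ok("world")];
/// let response = stream_from_strings(strings).into_response();
/// ```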
pub fn stream_from_strings<I, S, E>(
    strings: I,
) -> StreamBody<futures_util::stream::Iter<std::vec::IntoIter<Result<Bytes, E>>>>
where
    I: IntoIterator<Item = Result<S, E>>,
    S: Into<String>,
{
    use futures_util::stream;
    let vec: Vec<_> = strings
        .into_iter()
        .map(|r| r.map(|s| Bytes::from(s.into())))
        .collect();
    StreamBody::new(stream::iter(vec))
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream;

    #[test]
    fn test_stream_body_default_content_type() {
        let chunks: Vec<Result<Bytes, std::convert::Infallible>> = vec![Ok(Bytes::from("chunk 1"))];
        let stream_body = StreamBody::new(stream::iter(chunks));
        let response = stream_body.into_response();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get(header::TRANSFER_ENCODING).unwrap(),
            "chunked"
        );
    }

    #[test]
    fn test_stream_body_custom_content_type() {
        let chunks: Vec<Result<Bytes, std::convert::Infallible>> = vec![Ok(Bytes::from("chunk 1"))];
        let stream_body = StreamBody::new(stream::iter(chunks)).content_type("text/plain");
        let response = stream_body.into_response();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get(header::CONTENT_TYPE).unwrap(),
            "text/plain"
        );
    }

    #[test]
    fn test_stream_from_iter() {
        let chunks: Vec<Result<Bytes, std::convert::Infallible>> =
            vec![Ok(Bytes::from("chunk 1")), Ok(Bytes::from("chunk 2"))];
        let stream_body = stream_from_iter(chunks);
        let response = stream_body.into_response();

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[test]
    fn test_stream_from_strings() {
        let strings: Vec<Result<&str, std::convert::Infallible>> = vec![Ok("hello"), Ok("world")];
        let stream_body = stream_from_strings(strings);
        let response = stream_body.into_response();

        assert_eq!(response.status(), StatusCode::OK);
    }
}

#[cfg(test)]
mod property_tests {
    use super::*;
    use futures_util::stream;
    use futures_util::StreamExt;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_chunk_within_limit_accepted(
            chunk_size in 100usize..1000,
            limit in 1000usize..10000,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let data = vec![0u8; chunk_size];
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> =
                    vec![Ok(Bytes::from(data))];
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let result = streaming_body.next().await;
                prop_assert!(result.is_some());
                prop_assert!(result.unwrap().is_ok());

                prop_assert_eq!(streaming_body.bytes_read(), chunk_size);

                Ok(())
            })?;
        }

        #[test]
        fn prop_chunk_exceeding_limit_rejected(
            limit in 100usize..1000,
            excess in 1usize..100,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let chunk_size = limit + excess;
                let data = vec![0u8; chunk_size];
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> =
                    vec![Ok(Bytes::from(data))];
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let result = streaming_body.next().await;
                prop_assert!(result.is_some());
                let error = result.unwrap();
                prop_assert!(error.is_err());

                let err = error.unwrap_err();
                prop_assert_eq!(err.status, StatusCode::PAYLOAD_TOO_LARGE);

                Ok(())
            })?;
        }

        #[test]
        fn prop_multiple_chunks_within_limit(
            chunk_size in 100usize..500,
            num_chunks in 2usize..5,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let total_size = chunk_size * num_chunks;
                let limit = total_size + 100;
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> = (0..num_chunks)
                    .map(|_| Ok(Bytes::from(vec![0u8; chunk_size])))
                    .collect();
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let mut total_read = 0;
                while let Some(result) = streaming_body.next().await {
                    prop_assert!(result.is_ok());
                    total_read += result.unwrap().len();
                }

                prop_assert_eq!(total_read, total_size);
                prop_assert_eq!(streaming_body.bytes_read(), total_size);

                Ok(())
            })?;
        }

        #[test]
        fn prop_multiple_chunks_exceeding_limit(
            chunk_size in 100usize..500,
            num_chunks in 3usize..6,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let total_size = chunk_size * num_chunks;
                let limit = chunk_size + 50;
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> = (0..num_chunks)
                    .map(|_| Ok(Bytes::from(vec![0u8; chunk_size])))
                    .collect();
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let first = streaming_body.next().await;
                prop_assert!(first.is_some());
                prop_assert!(first.unwrap().is_ok());

                let second = streaming_body.next().await;
                prop_assert!(second.is_some());
                let error = second.unwrap();
                prop_assert!(error.is_err());

                let err = error.unwrap_err();
                prop_assert_eq!(err.status, StatusCode::PAYLOAD_TOO_LARGE);

                Ok(())
            })?;
        }

        #[test]
        fn prop_no_limit_unlimited(
            chunk_size in 1000usize..10000,
            num_chunks in 5usize..10,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> = (0..num_chunks)
                    .map(|_| Ok(Bytes::from(vec![0u8; chunk_size])))
                    .collect();
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, None);

                let mut count = 0;
                while let Some(result) = streaming_body.next().await {
                    prop_assert!(result.is_ok());
                    count += 1;
                }

                prop_assert_eq!(count, num_chunks);
                prop_assert_eq!(streaming_body.bytes_read(), chunk_size * num_chunks);

                Ok(())
            })?;
        }

        #[test]
        fn prop_bytes_read_accurate(
            sizes in prop::collection::vec(100usize..1000, 1..10)
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let total_size: usize = sizes.iter().sum();
                let limit = total_size + 1000;
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> = sizes
                    .iter()
                    .map(|&size| Ok(Bytes::from(vec![0u8; size])))
                    .collect();
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let mut cumulative = 0;
                while let Some(result) = streaming_body.next().await {
                    let chunk = result.unwrap();
                    cumulative += chunk.len();

                    prop_assert_eq!(streaming_body.bytes_read(), cumulative);
                }

                prop_assert_eq!(streaming_body.bytes_read(), total_size);

                Ok(())
            })?;
        }

        #[test]
        fn prop_exact_limit_accepted(chunk_size in 500usize..5000) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let limit = chunk_size;
                let data = vec![0u8; chunk_size];
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> =
                    vec![Ok(Bytes::from(data))];
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let result = streaming_body.next().await;
                prop_assert!(result.is_some());
                prop_assert!(result.unwrap().is_ok());

                prop_assert_eq!(streaming_body.bytes_read(), chunk_size);

                Ok(())
            })?;
        }

        #[test]
        fn prop_one_byte_over_rejected(limit in 500usize..5000) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let chunk_size = limit + 1;
                let data = vec![0u8; chunk_size];
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> =
                    vec![Ok(Bytes::from(data))];
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let result = streaming_body.next().await;
                prop_assert!(result.is_some());
                let error = result.unwrap();
                prop_assert!(error.is_err());

                Ok(())
            })?;
        }

        #[test]
        fn prop_empty_chunks_ignored(
            chunk_size in 100usize..1000,
            num_empty in 1usize..5,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let limit = chunk_size + 100;

                let mut chunks: Vec<Result<Bytes, crate::error::ApiError>> = vec![];

                for _ in 0..num_empty {
                    chunks.push(Ok(Bytes::new()));
                }

                chunks.push(Ok(Bytes::from(vec![0u8; chunk_size])));

                let stream_data = stream::iter(chunks);
                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                while let Some(result) = streaming_body.next().await {
                    prop_assert!(result.is_ok());
                }

                prop_assert_eq!(streaming_body.bytes_read(), chunk_size);

                Ok(())
            })?;
        }

        #[test]
        fn prop_limit_cumulative(
            chunk1_size in 300usize..600,
            chunk2_size in 300usize..600,
            limit in 500usize..900,
        ) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> = vec![
                    Ok(Bytes::from(vec![0u8; chunk1_size])),
                    Ok(Bytes::from(vec![0u8; chunk2_size])),
                ];
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let first = streaming_body.next().await;
                if chunk1_size <= limit {
                    prop_assert!(first.unwrap().is_ok());

                    let second = streaming_body.next().await;
                    let total = chunk1_size + chunk2_size;

                    if total <= limit {
                        prop_assert!(second.unwrap().is_ok());
                    } else {
                        prop_assert!(second.unwrap().is_err());
                    }
                } else {
                    prop_assert!(first.unwrap().is_err());
                }

                Ok(())
            })?;
        }

        #[test]
        fn prop_default_config_limit(_seed in 0u32..10) {
            let config = StreamingConfig::default();
            prop_assert_eq!(config.max_body_size, Some(10 * 1024 * 1024));
        }

        #[test]
        fn prop_error_message_includes_limit(limit in 1000usize..10000) {
            tokio::runtime::Runtime::new().unwrap().block_on(async {
                let chunk_size = limit + 100;
                let data = vec![0u8; chunk_size];
                let chunks: Vec<Result<Bytes, crate::error::ApiError>> =
                    vec![Ok(Bytes::from(data))];
                let stream_data = stream::iter(chunks);

                let mut streaming_body = StreamingBody::from_stream(stream_data, Some(limit));

                let result = streaming_body.next().await;
                let error = result.unwrap().unwrap_err();

                prop_assert!(error.message.contains(&limit.to_string()));
                prop_assert!(error.message.contains("exceeded limit"));

                Ok(())
            })?;
        }
    }
}

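/// Configuration for request-body streaming. `max_body_size` is the cumulative
/// byte limit enforced by `StreamingBody`; `None` disables the check.
///
/// Non-compiled sketch of the default (10 MiB) limit:
///
/// ```ignore
/// let config = StreamingConfig::default();
/// assert_eq!(config.max_body_size, Some(10 * 1024 * 1024));
/// ```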
#[derive(Debug, Clone, Copy)]
pub struct StreamingConfig {
    pub max_body_size: Option<usize>,
}

impl Default for StreamingConfig {
    fn default() -> Self {
        Self {
            max_body_size: Some(10 * 1024 * 1024),
        }
    }
}

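/// Request body stream that counts the bytes it has yielded and, when a limit
/// is set, returns a `413 Payload Too Large` `ApiError` once the cumulative
/// size exceeds that limit.
///
/// A minimal consumption sketch, assuming an async context and this crate's
/// `ApiError` type (non-compiled):
///
/// ```ignore
/// use bytes::Bytes;
/// use futures_util::{stream, StreamExt};
///
/// let chunks: Vec<Result<Bytes, crate::error::ApiError>> =
///     vec![Ok(Bytes::from(vec![0u8; 512]))];
/// let mut body = StreamingBody::from_stream(stream::iter(chunks), Some(1024));
/// while let Some(chunk) = body.next().await {
///     let chunk = chunk.expect("within the configured limit");
///     // ... process the chunk ...
/// }
/// assert_eq!(body.bytes_read(), 512);
/// ```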
pub struct StreamingBody {
    inner: StreamingInner,
    bytes_read: usize,
    limit: Option<usize>,
}

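// Internal byte source: either a raw hyper `Incoming` body or a boxed generic
// stream of already-materialized `Bytes` chunks.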
enum StreamingInner {
    Hyper(hyper::body::Incoming),
    Generic(
        std::pin::Pin<
            Box<
                dyn futures_util::Stream<Item = Result<Bytes, crate::error::ApiError>>
                    + Send
                    + Sync,
            >,
        >,
    ),
}

impl StreamingBody {
    pub fn new(inner: hyper::body::Incoming, limit: Option<usize>) -> Self {
        Self {
            inner: StreamingInner::Hyper(inner),
            bytes_read: 0,
            limit,
        }
    }

    pub fn from_stream<S>(stream: S, limit: Option<usize>) -> Self
    where
        S: futures_util::Stream<Item = Result<Bytes, crate::error::ApiError>>
            + Send
            + Sync
            + 'static,
    {
        Self {
            inner: StreamingInner::Generic(Box::pin(stream)),
            bytes_read: 0,
            limit,
        }
    }

    pub fn bytes_read(&self) -> usize {
        self.bytes_read
    }
}

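// Both variants apply the same limit check: after each data chunk, the running
// `bytes_read` total is compared against `limit`, and exceeding it yields a
// `413 Payload Too Large` `ApiError` instead of the chunk. For the hyper body,
// non-data frames (e.g. trailers) are skipped.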
impl Stream for StreamingBody {
    type Item = Result<Bytes, crate::error::ApiError>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use hyper::body::Body;

        match &mut self.inner {
            StreamingInner::Hyper(incoming) => {
                loop {
                    match std::pin::Pin::new(&mut *incoming).poll_frame(cx) {
                        std::task::Poll::Ready(Some(Ok(frame))) => {
                            if let Ok(data) = frame.into_data() {
                                let len = data.len();
                                self.bytes_read += len;
                                if let Some(limit) = self.limit {
                                    if self.bytes_read > limit {
                                        return std::task::Poll::Ready(Some(Err(
                                            crate::error::ApiError::new(
                                                StatusCode::PAYLOAD_TOO_LARGE,
                                                "payload_too_large",
                                                format!(
                                                    "Body size exceeded limit of {} bytes",
                                                    limit
                                                ),
                                            ),
                                        )));
                                    }
                                }
                                return std::task::Poll::Ready(Some(Ok(data)));
                            }
                            continue;
                        }
                        std::task::Poll::Ready(Some(Err(e))) => {
                            return std::task::Poll::Ready(Some(Err(
                                crate::error::ApiError::bad_request(e.to_string()),
                            )));
                        }
                        std::task::Poll::Ready(None) => return std::task::Poll::Ready(None),
                        std::task::Poll::Pending => return std::task::Poll::Pending,
                    }
                }
            }
            StreamingInner::Generic(stream) => match stream.as_mut().poll_next(cx) {
                std::task::Poll::Ready(Some(Ok(data))) => {
                    let len = data.len();
                    self.bytes_read += len;
                    if let Some(limit) = self.limit {
                        if self.bytes_read > limit {
                            return std::task::Poll::Ready(Some(Err(crate::error::ApiError::new(
                                StatusCode::PAYLOAD_TOO_LARGE,
                                "payload_too_large",
                                format!("Body size exceeded limit of {} bytes", limit),
                            ))));
                        }
                    }
                    std::task::Poll::Ready(Some(Ok(data)))
                }
                other => other,
            },
        }
    }
}