1use axum::{
2 BoxError,
3 body::Body,
4 http::{HeaderMap, HeaderName, HeaderValue, StatusCode},
5 response::Response as AxumResponse,
6};
7use bytes::Bytes;
8use futures::{Stream, StreamExt};
9use std::pin::Pin;
10
11pub enum HandlerResponse {
48 Response(AxumResponse<Body>),
50 Stream {
52 stream: Pin<Box<dyn Stream<Item = Result<Bytes, BoxError>> + Send + 'static>>,
54 status: StatusCode,
56 headers: HeaderMap,
58 },
59}
60
61impl From<AxumResponse<Body>> for HandlerResponse {
62 fn from(response: AxumResponse<Body>) -> Self {
63 HandlerResponse::Response(response)
64 }
65}
66
67impl HandlerResponse {
68 pub fn into_response(self) -> AxumResponse<Body> {
77 match self {
78 HandlerResponse::Response(response) => response,
79 HandlerResponse::Stream {
80 stream,
81 status,
82 mut headers,
83 } => {
84 let body = Body::from_stream(stream);
85 let mut response = AxumResponse::new(body);
86 *response.status_mut() = status;
87 response.headers_mut().extend(headers.drain());
88 response
89 }
90 }
91 }
92
93 pub fn stream<S, E>(stream: S) -> Self
124 where
125 S: Stream<Item = Result<Bytes, E>> + Send + 'static,
126 E: Into<BoxError>,
127 {
128 let mapped = stream.map(|chunk| chunk.map_err(Into::into));
129 HandlerResponse::Stream {
130 stream: Box::pin(mapped),
131 status: StatusCode::OK,
132 headers: HeaderMap::new(),
133 }
134 }
135
136 pub fn with_status(mut self, status: StatusCode) -> Self {
154 if let HandlerResponse::Stream { status: s, .. } = &mut self {
155 *s = status;
156 }
157 self
158 }
159
160 pub fn with_header(mut self, name: HeaderName, value: HeaderValue) -> Self {
185 if let HandlerResponse::Stream { headers, .. } = &mut self {
186 headers.insert(name, value);
187 }
188 self
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::header;
    use http_body_util::BodyExt;

    /// Boxed error type used by every fixture stream in this module.
    type TestError = Box<dyn std::error::Error + Send + Sync>;

    /// Wraps a fixed list of stream items in a `HandlerResponse::Stream`.
    fn stream_response(chunks: Vec<Result<Bytes, TestError>>) -> HandlerResponse {
        HandlerResponse::stream(futures::stream::iter(chunks))
    }

    /// Builds a `Stream` response that yields `data` as its only chunk.
    fn single_chunk(data: &'static str) -> HandlerResponse {
        stream_response(vec![Ok(Bytes::from(data))])
    }

    /// Returns the status and headers of a `Stream` variant; panics on `Response`.
    fn stream_parts(response: HandlerResponse) -> (StatusCode, HeaderMap) {
        match response {
            HandlerResponse::Stream {
                status, headers, ..
            } => (status, headers),
            HandlerResponse::Response(_) => panic!("Expected Stream variant"),
        }
    }

    /// Returns the wrapped axum response of a `Response` variant; panics on `Stream`.
    fn response_variant(response: HandlerResponse) -> AxumResponse<Body> {
        match response {
            HandlerResponse::Response(resp) => resp,
            HandlerResponse::Stream { .. } => panic!("Expected Response variant"),
        }
    }

    /// Collects the full body of `response`, panicking if the body yields an error.
    async fn collect_body(response: AxumResponse<Body>) -> Bytes {
        response.into_body().collect().await.unwrap().to_bytes()
    }

    #[test]
    fn test_from_axum_response() {
        let axum_response = AxumResponse::new(Body::from("test body"));
        let handler_response = HandlerResponse::from(axum_response);

        assert!(
            matches!(handler_response, HandlerResponse::Response(_)),
            "Expected Response variant"
        );
    }

    #[test]
    fn test_stream_creation_with_chunks() {
        let handler_response = stream_response(vec![
            Ok(Bytes::from("chunk1")),
            Ok(Bytes::from("chunk2")),
            Ok(Bytes::from("chunk3")),
        ]);

        let (status, headers) = stream_parts(handler_response);
        assert_eq!(status, StatusCode::OK);
        assert!(headers.is_empty());
    }

    #[test]
    fn test_stream_with_custom_status() {
        let handler_response = single_chunk("partial").with_status(StatusCode::PARTIAL_CONTENT);

        let (status, _) = stream_parts(handler_response);
        assert_eq!(status, StatusCode::PARTIAL_CONTENT);
    }

    #[test]
    fn test_stream_with_headers() {
        let handler_response = single_chunk("test")
            .with_header(header::CONTENT_TYPE, HeaderValue::from_static("application/x-ndjson"))
            .with_header(header::CACHE_CONTROL, HeaderValue::from_static("no-cache"));

        let (_, headers) = stream_parts(handler_response);
        assert_eq!(headers.get(header::CONTENT_TYPE).unwrap(), "application/x-ndjson");
        assert_eq!(headers.get(header::CACHE_CONTROL).unwrap(), "no-cache");
    }

    #[tokio::test]
    async fn test_stream_body_consumption() {
        let handler_response = stream_response(vec![
            Ok(Bytes::from("hello ")),
            Ok(Bytes::from("world")),
            Ok(Bytes::from("!")),
        ])
        .with_status(StatusCode::OK);

        let bytes = collect_body(handler_response.into_response()).await;
        assert_eq!(bytes, "hello world!");
    }

    #[tokio::test]
    async fn test_into_response_for_response_variant() {
        let original_body = "test response body";
        let handler_response = HandlerResponse::from(AxumResponse::new(Body::from(original_body)));

        let bytes = collect_body(handler_response.into_response()).await;
        assert_eq!(bytes, original_body);
    }

    #[test]
    fn test_method_chaining() {
        let handler_response = single_chunk("chained")
            .with_status(StatusCode::CREATED)
            .with_header(header::CONTENT_TYPE, HeaderValue::from_static("text/plain"))
            .with_header(header::ETAG, HeaderValue::from_static("\"abc123\""));

        let (status, headers) = stream_parts(handler_response);
        assert_eq!(status, StatusCode::CREATED);
        assert_eq!(headers.get(header::CONTENT_TYPE).unwrap(), "text/plain");
        assert_eq!(headers.get(header::ETAG).unwrap(), "\"abc123\"");
    }

    #[tokio::test]
    async fn test_empty_stream() {
        let handler_response = stream_response(Vec::new()).with_status(StatusCode::NO_CONTENT);

        let axum_response = handler_response.into_response();
        assert_eq!(axum_response.status(), StatusCode::NO_CONTENT);
        assert!(collect_body(axum_response).await.is_empty());
    }

    #[tokio::test]
    async fn test_large_stream() {
        let pieces: Vec<String> = (0..150).map(|i| format!("chunk{} ", i)).collect();
        let chunks: Vec<Result<Bytes, TestError>> =
            pieces.iter().map(|piece| Ok(Bytes::from(piece.clone()))).collect();

        let bytes = collect_body(stream_response(chunks).into_response()).await;

        assert!(bytes.len() > 1000);
        // Comparing against the exact concatenation also verifies chunk order,
        // which per-chunk `contains` checks cannot.
        assert_eq!(&bytes[..], pieces.concat().as_bytes());
    }

    #[tokio::test]
    async fn test_stream_error_propagation() {
        let handler_response = stream_response(vec![
            Ok(Bytes::from("good1 ")),
            Err("custom error".into()),
            Ok(Bytes::from("good2")),
        ]);

        let result = handler_response.into_response().into_body().collect().await;
        assert!(result.is_err());
    }

    #[test]
    fn test_response_variant_ignores_with_status() {
        let axum_response = AxumResponse::builder()
            .status(StatusCode::OK)
            .body(Body::from("test"))
            .unwrap();

        let result = HandlerResponse::from(axum_response).with_status(StatusCode::NOT_FOUND);
        assert_eq!(response_variant(result).status(), StatusCode::OK);
    }

    #[test]
    fn test_response_variant_ignores_with_header() {
        let axum_response = AxumResponse::builder()
            .status(StatusCode::OK)
            .header(header::CONTENT_TYPE, "text/plain")
            .body(Body::from("test"))
            .unwrap();

        let result = HandlerResponse::from(axum_response)
            .with_header(header::CACHE_CONTROL, HeaderValue::from_static("max-age=3600"));

        let resp = response_variant(result);
        assert!(resp.headers().get(header::CACHE_CONTROL).is_none());
        assert_eq!(resp.headers().get(header::CONTENT_TYPE).unwrap(), "text/plain");
    }

    #[tokio::test]
    async fn test_stream_into_response_applies_status_and_headers() {
        let axum_response = single_chunk("stream data")
            .with_status(StatusCode::PARTIAL_CONTENT)
            .with_header(header::CONTENT_RANGE, HeaderValue::from_static("bytes 0-10/100"))
            .into_response();

        assert_eq!(axum_response.status(), StatusCode::PARTIAL_CONTENT);
        assert_eq!(
            axum_response.headers().get(header::CONTENT_RANGE).unwrap(),
            "bytes 0-10/100"
        );
        assert_eq!(collect_body(axum_response).await, "stream data");
    }

    #[test]
    fn test_multiple_header_replacements() {
        let handler_response = single_chunk("data")
            .with_header(header::CONTENT_TYPE, HeaderValue::from_static("application/json"))
            .with_header(header::CONTENT_TYPE, HeaderValue::from_static("application/x-ndjson"));

        let (_, headers) = stream_parts(handler_response);
        assert_eq!(headers.get(header::CONTENT_TYPE).unwrap(), "application/x-ndjson");
        // `with_header` replaces rather than appends.
        assert_eq!(headers.get_all(header::CONTENT_TYPE).iter().count(), 1);
    }

    #[test]
    fn test_stream_with_various_status_codes() {
        let status_codes = [
            StatusCode::OK,
            StatusCode::CREATED,
            StatusCode::ACCEPTED,
            StatusCode::PARTIAL_CONTENT,
            StatusCode::MULTI_STATUS,
        ];

        for expected in status_codes {
            let (status, _) = stream_parts(single_chunk("test").with_status(expected));
            assert_eq!(status, expected);
        }
    }

    #[tokio::test]
    async fn test_stream_with_json_lines_content() {
        let handler_response = stream_response(vec![
            Ok(Bytes::from(r#"{"index":0,"payload":"alpha"}"#)),
            Ok(Bytes::from("\n")),
            Ok(Bytes::from(r#"{"index":1,"payload":"beta"}"#)),
            Ok(Bytes::from("\n")),
            Ok(Bytes::from(r#"{"index":2,"payload":"gamma"}"#)),
            Ok(Bytes::from("\n")),
        ])
        .with_status(StatusCode::OK)
        .with_header(header::CONTENT_TYPE, HeaderValue::from_static("application/x-ndjson"));

        let axum_response = handler_response.into_response();
        assert_eq!(axum_response.status(), StatusCode::OK);

        let bytes = collect_body(axum_response).await;
        assert_eq!(
            bytes,
            concat!(
                r#"{"index":0,"payload":"alpha"}"#,
                "\n",
                r#"{"index":1,"payload":"beta"}"#,
                "\n",
                r#"{"index":2,"payload":"gamma"}"#,
                "\n",
            )
        );
    }

    #[tokio::test]
    async fn test_response_roundtrip() {
        let original = AxumResponse::builder()
            .status(StatusCode::OK)
            .header(header::CONTENT_TYPE, "text/plain")
            .body(Body::from("roundtrip test"))
            .unwrap();

        let result = HandlerResponse::from(original).into_response();

        assert_eq!(result.status(), StatusCode::OK);
        assert_eq!(result.headers().get(header::CONTENT_TYPE).unwrap(), "text/plain");
        assert_eq!(collect_body(result).await, "roundtrip test");
    }

    #[tokio::test]
    async fn test_single_chunk_stream() {
        let axum_response = single_chunk("only").with_status(StatusCode::OK).into_response();

        assert_eq!(axum_response.status(), StatusCode::OK);
        assert_eq!(collect_body(axum_response).await, "only");
    }

    #[tokio::test]
    async fn test_very_large_stream_many_chunks() {
        let chunk_count = 1500;
        let chunks: Vec<Result<Bytes, TestError>> =
            (0..chunk_count).map(|_| Ok(Bytes::from_static(b"x"))).collect();

        let bytes = collect_body(stream_response(chunks).into_response()).await;
        assert_eq!(bytes.len(), chunk_count);
    }

    #[tokio::test]
    async fn test_stream_with_varying_chunk_sizes() {
        let chunks: Vec<Result<Bytes, TestError>> = vec![
            Ok(Bytes::from("x")),
            Ok(Bytes::from("xx".repeat(100))),
            Ok(Bytes::from("x".repeat(10_000))),
            Ok(Bytes::from("x".repeat(100_000))),
        ];

        let bytes = collect_body(stream_response(chunks).into_response()).await;
        // 1 + 200 + 10_000 + 100_000
        assert_eq!(bytes.len(), 110_201);
    }

    #[tokio::test]
    async fn test_stream_error_in_middle() {
        let chunks: Vec<Result<Bytes, TestError>> = (0..1000)
            .map(|i| {
                if i == 500 {
                    Err("midstream error".into())
                } else {
                    Ok(Bytes::from("chunk"))
                }
            })
            .collect();

        let result = stream_response(chunks).into_response().into_body().collect().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_stream_with_sse_headers() {
        let axum_response = stream_response(vec![
            Ok(Bytes::from("event: message\n")),
            Ok(Bytes::from("data: {\"msg\": \"hello\"}\n\n")),
        ])
        .with_status(StatusCode::OK)
        .with_header(header::CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))
        .with_header(header::CACHE_CONTROL, HeaderValue::from_static("no-cache"))
        .with_header(header::CONNECTION, HeaderValue::from_static("keep-alive"))
        .into_response();

        assert_eq!(axum_response.status(), StatusCode::OK);
        assert_eq!(
            axum_response.headers().get(header::CONTENT_TYPE).unwrap(),
            "text/event-stream"
        );
        assert_eq!(axum_response.headers().get(header::CACHE_CONTROL).unwrap(), "no-cache");
        assert_eq!(axum_response.headers().get(header::CONNECTION).unwrap(), "keep-alive");

        let body = collect_body(axum_response).await;
        assert_eq!(body, "event: message\ndata: {\"msg\": \"hello\"}\n\n");
    }

    #[tokio::test]
    async fn test_stream_with_websocket_headers() {
        let accept = HeaderName::from_static("sec-websocket-accept");
        let axum_response = single_chunk("ws-frame-data")
            .with_status(StatusCode::OK)
            .with_header(header::UPGRADE, HeaderValue::from_static("websocket"))
            .with_header(
                accept.clone(),
                HeaderValue::from_static("s3pPLMBiTxaQ9kYGzzhZRbK+xOo="),
            )
            .into_response();

        assert_eq!(axum_response.status(), StatusCode::OK);
        assert_eq!(axum_response.headers().get(header::UPGRADE).unwrap(), "websocket");
        assert_eq!(
            axum_response.headers().get(&accept).unwrap(),
            "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
        );
        assert_eq!(collect_body(axum_response).await, "ws-frame-data");
    }

    #[test]
    fn test_stream_status_transition() {
        let handler_response = single_chunk("data")
            .with_status(StatusCode::OK)
            .with_status(StatusCode::PARTIAL_CONTENT);

        let (status, _) = stream_parts(handler_response);
        assert_eq!(status, StatusCode::PARTIAL_CONTENT);
    }

    #[tokio::test]
    async fn test_stream_chunked_encoding_simulation() {
        // `into_response` does not interpret chunk framing: the bytes below,
        // including the size lines and CRLFs, must come back out verbatim.
        let axum_response = stream_response(vec![
            Ok(Bytes::from("5\r\nhello\r\n")),
            Ok(Bytes::from("5\r\nworld\r\n")),
            Ok(Bytes::from("0\r\n\r\n")),
        ])
        .with_header(header::TRANSFER_ENCODING, HeaderValue::from_static("chunked"))
        .into_response();

        let body = collect_body(axum_response).await;
        assert_eq!(body, "5\r\nhello\r\n5\r\nworld\r\n0\r\n\r\n");
    }

    #[tokio::test]
    async fn test_stream_with_binary_data() {
        let axum_response = stream_response(vec![
            Ok(Bytes::from(vec![0xFF, 0xD8, 0xFF])),
            Ok(Bytes::from(vec![0xE0, 0x00, 0x10])),
            Ok(Bytes::from(vec![0x4A, 0x46, 0x49])),
        ])
        .with_header(
            header::CONTENT_TYPE,
            HeaderValue::from_static("application/octet-stream"),
        )
        .into_response();

        let bytes = collect_body(axum_response).await;
        assert_eq!(
            bytes.to_vec(),
            [0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46, 0x49]
        );
    }

    #[tokio::test]
    async fn test_stream_with_null_bytes() {
        let axum_response = stream_response(vec![
            Ok(Bytes::from(vec![0x00, 0x01, 0x02])),
            Ok(Bytes::from(vec![0x00, 0x00, 0x00])),
            Ok(Bytes::from(vec![0xFF, 0xFE, 0xFD])),
        ])
        .into_response();

        let bytes = collect_body(axum_response).await;
        assert_eq!(
            bytes.to_vec(),
            [0x00, 0x01, 0x02, 0x00, 0x00, 0x00, 0xFF, 0xFE, 0xFD]
        );
    }

    #[test]
    fn test_stream_with_many_headers() {
        let handler_response = (0..50).fold(single_chunk("data"), |response, i| {
            let name = HeaderName::from_bytes(format!("x-custom-{}", i).as_bytes()).unwrap();
            response.with_header(name, HeaderValue::from_static("value"))
        });

        let axum_response = handler_response.into_response();
        assert_eq!(axum_response.status(), StatusCode::OK);
        assert_eq!(axum_response.headers().len(), 50);
    }

    #[tokio::test]
    async fn test_empty_stream_with_204_no_content() {
        let axum_response = stream_response(Vec::new())
            .with_status(StatusCode::NO_CONTENT)
            .into_response();

        assert_eq!(axum_response.status(), StatusCode::NO_CONTENT);
        assert!(collect_body(axum_response).await.is_empty());
    }

    #[test]
    fn test_stream_repeated_header_updates() {
        let handler_response = single_chunk("test")
            .with_header(header::CONTENT_TYPE, HeaderValue::from_static("text/plain"))
            .with_header(header::CONTENT_TYPE, HeaderValue::from_static("application/json"))
            .with_header(header::CONTENT_TYPE, HeaderValue::from_static("application/xml"));

        let (_, headers) = stream_parts(handler_response);
        assert_eq!(headers.get(header::CONTENT_TYPE).unwrap(), "application/xml");
        assert_eq!(headers.get_all(header::CONTENT_TYPE).iter().count(), 1);
    }

    #[tokio::test]
    async fn test_stream_with_extremely_long_chunk() {
        let chunks: Vec<Result<Bytes, TestError>> =
            vec![Ok(Bytes::from("x".repeat(10_000_000)))];

        let bytes = collect_body(stream_response(chunks).into_response()).await;
        assert_eq!(bytes.len(), 10_000_000);
    }

    #[tokio::test]
    async fn test_stream_with_zero_length_chunks() {
        let axum_response = stream_response(vec![
            Ok(Bytes::from("hello")),
            Ok(Bytes::new()),
            Ok(Bytes::from("world")),
            Ok(Bytes::new()),
            Ok(Bytes::from("!")),
        ])
        .into_response();

        assert_eq!(collect_body(axum_response).await, "helloworld!");
    }

    #[test]
    fn test_response_variant_preserves_original_status() {
        let axum_response = AxumResponse::builder()
            .status(StatusCode::BAD_REQUEST)
            .body(Body::from("error"))
            .unwrap();

        let result = HandlerResponse::from(axum_response)
            .with_status(StatusCode::OK)
            .with_status(StatusCode::INTERNAL_SERVER_ERROR);

        assert_eq!(response_variant(result).status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_stream_into_response_preserves_headers() {
        let axum_response = single_chunk("data")
            .with_header(header::CONTENT_TYPE, HeaderValue::from_static("application/json"))
            .with_header(header::CACHE_CONTROL, HeaderValue::from_static("max-age=3600"))
            .with_header(header::ETAG, HeaderValue::from_static("\"abc123\""))
            .into_response();

        let headers = axum_response.headers();
        assert_eq!(headers.get(header::CONTENT_TYPE).unwrap(), "application/json");
        assert_eq!(headers.get(header::CACHE_CONTROL).unwrap(), "max-age=3600");
        assert_eq!(headers.get(header::ETAG).unwrap(), "\"abc123\"");
        assert_eq!(headers.len(), 3);
    }

    #[test]
    fn test_stream_with_error_status_codes() {
        let error_statuses = [
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::SERVICE_UNAVAILABLE,
            StatusCode::GATEWAY_TIMEOUT,
        ];

        for expected in error_statuses {
            let (status, _) = stream_parts(single_chunk("error").with_status(expected));
            assert_eq!(status, expected);
        }
    }
}