1use crate::error::{ApiError, ErrorResponse};
74use bytes::Bytes;
75use futures_util::StreamExt;
76use http::{header, HeaderMap, HeaderValue, StatusCode};
77use http_body_util::Full;
78use rustapi_openapi::{MediaType, Operation, ResponseModifier, ResponseSpec, Schema, SchemaRef};
79use serde::Serialize;
80use std::collections::HashMap;
81use std::pin::Pin;
82use std::task::{Context, Poll};
83
84pub enum Body {
86 Full(Full<Bytes>),
88 Streaming(Pin<Box<dyn http_body::Body<Data = Bytes, Error = ApiError> + Send + 'static>>),
90}
91
92impl Body {
93 pub fn new(bytes: Bytes) -> Self {
95 Self::Full(Full::new(bytes))
96 }
97
98 pub fn empty() -> Self {
100 Self::Full(Full::new(Bytes::new()))
101 }
102
103 pub fn from_stream<S, E>(stream: S) -> Self
105 where
106 S: futures_util::Stream<Item = Result<Bytes, E>> + Send + 'static,
107 E: Into<ApiError> + 'static,
108 {
109 let body = http_body_util::StreamBody::new(
110 stream.map(|res| res.map_err(|e| e.into()).map(http_body::Frame::data)),
111 );
112 Self::Streaming(Box::pin(body))
113 }
114}
115
116impl Default for Body {
117 fn default() -> Self {
118 Self::empty()
119 }
120}
121
122impl http_body::Body for Body {
123 type Data = Bytes;
124 type Error = ApiError;
125
126 fn poll_frame(
127 self: Pin<&mut Self>,
128 cx: &mut Context<'_>,
129 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
130 match self.get_mut() {
131 Body::Full(b) => Pin::new(b)
132 .poll_frame(cx)
133 .map_err(|_| ApiError::internal("Infallible error")),
134 Body::Streaming(b) => b.as_mut().poll_frame(cx),
135 }
136 }
137
138 fn is_end_stream(&self) -> bool {
139 match self {
140 Body::Full(b) => b.is_end_stream(),
141 Body::Streaming(b) => b.is_end_stream(),
142 }
143 }
144
145 fn size_hint(&self) -> http_body::SizeHint {
146 match self {
147 Body::Full(b) => b.size_hint(),
148 Body::Streaming(b) => b.size_hint(),
149 }
150 }
151}
152
153impl From<Bytes> for Body {
154 fn from(bytes: Bytes) -> Self {
155 Self::new(bytes)
156 }
157}
158
159impl From<String> for Body {
160 fn from(s: String) -> Self {
161 Self::new(Bytes::from(s))
162 }
163}
164
165impl From<&'static str> for Body {
166 fn from(s: &'static str) -> Self {
167 Self::new(Bytes::from(s))
168 }
169}
170
171impl From<Vec<u8>> for Body {
172 fn from(v: Vec<u8>) -> Self {
173 Self::new(Bytes::from(v))
174 }
175}
176
/// The HTTP response type produced by handlers: an `http::Response` whose body is [`Body`].
pub type Response = http::Response<Body>;
179
/// Conversion of a value (typically a handler's return value) into a [`Response`].
pub trait IntoResponse {
    /// Consumes `self` and builds the response to send.
    fn into_response(self) -> Response;
}
185
186impl IntoResponse for Response {
188 fn into_response(self) -> Response {
189 self
190 }
191}
192
193impl IntoResponse for () {
195 fn into_response(self) -> Response {
196 http::Response::builder()
197 .status(StatusCode::OK)
198 .body(Body::empty())
199 .unwrap()
200 }
201}
202
203impl IntoResponse for &'static str {
205 fn into_response(self) -> Response {
206 http::Response::builder()
207 .status(StatusCode::OK)
208 .header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
209 .body(Body::from(self))
210 .unwrap()
211 }
212}
213
214impl IntoResponse for String {
216 fn into_response(self) -> Response {
217 http::Response::builder()
218 .status(StatusCode::OK)
219 .header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
220 .body(Body::from(self))
221 .unwrap()
222 }
223}
224
225impl IntoResponse for StatusCode {
227 fn into_response(self) -> Response {
228 http::Response::builder()
229 .status(self)
230 .body(Body::empty())
231 .unwrap()
232 }
233}
234
235impl<R: IntoResponse> IntoResponse for (StatusCode, R) {
237 fn into_response(self) -> Response {
238 let mut response = self.1.into_response();
239 *response.status_mut() = self.0;
240 response
241 }
242}
243
244impl<R: IntoResponse> IntoResponse for (StatusCode, HeaderMap, R) {
246 fn into_response(self) -> Response {
247 let mut response = self.2.into_response();
248 *response.status_mut() = self.0;
249 response.headers_mut().extend(self.1);
250 response
251 }
252}
253
254impl<T: IntoResponse, E: IntoResponse> IntoResponse for Result<T, E> {
256 fn into_response(self) -> Response {
257 match self {
258 Ok(v) => v.into_response(),
259 Err(e) => e.into_response(),
260 }
261 }
262}
263
264impl IntoResponse for ApiError {
267 fn into_response(self) -> Response {
268 let status = self.status;
269 let error_response = ErrorResponse::from(self);
271 let body = serde_json::to_vec(&error_response).unwrap_or_else(|_| {
272 br#"{"error":{"type":"internal_error","message":"Failed to serialize error"}}"#.to_vec()
273 });
274
275 http::Response::builder()
276 .status(status)
277 .header(header::CONTENT_TYPE, "application/json")
278 .body(Body::from(body))
279 .unwrap()
280 }
281}
282
283impl ResponseModifier for ApiError {
284 fn update_response(op: &mut Operation) {
285 op.responses.insert(
288 "400".to_string(),
289 ResponseSpec {
290 description: "Bad Request".to_string(),
291 content: {
292 let mut map = HashMap::new();
293 map.insert(
294 "application/json".to_string(),
295 MediaType {
296 schema: SchemaRef::Ref {
297 reference: "#/components/schemas/ErrorSchema".to_string(),
298 },
299 },
300 );
301 Some(map)
302 },
303 },
304 );
305
306 op.responses.insert(
308 "500".to_string(),
309 ResponseSpec {
310 description: "Internal Server Error".to_string(),
311 content: {
312 let mut map = HashMap::new();
313 map.insert(
314 "application/json".to_string(),
315 MediaType {
316 schema: SchemaRef::Ref {
317 reference: "#/components/schemas/ErrorSchema".to_string(),
318 },
319 },
320 );
321 Some(map)
322 },
323 },
324 );
325 }
326}
327
/// Wrapper that answers with `201 Created` and a JSON body serialized from the
/// wrapped value.
#[derive(Clone, Debug)]
pub struct Created<T>(pub T);
342
343impl<T: Serialize> IntoResponse for Created<T> {
344 fn into_response(self) -> Response {
345 match serde_json::to_vec(&self.0) {
346 Ok(body) => http::Response::builder()
347 .status(StatusCode::CREATED)
348 .header(header::CONTENT_TYPE, "application/json")
349 .body(Body::from(body))
350 .unwrap(),
351 Err(err) => {
352 ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
353 }
354 }
355 }
356}
357
358impl<T: for<'a> Schema<'a>> ResponseModifier for Created<T> {
359 fn update_response(op: &mut Operation) {
360 let (name, _) = T::schema();
361
362 let schema_ref = SchemaRef::Ref {
363 reference: format!("#/components/schemas/{}", name),
364 };
365
366 op.responses.insert(
367 "201".to_string(),
368 ResponseSpec {
369 description: "Created".to_string(),
370 content: {
371 let mut map = HashMap::new();
372 map.insert(
373 "application/json".to_string(),
374 MediaType { schema: schema_ref },
375 );
376 Some(map)
377 },
378 },
379 );
380 }
381}
382
/// Marker that answers with `204 No Content` and an empty body.
#[derive(Copy, Clone, Debug)]
pub struct NoContent;
397
398impl IntoResponse for NoContent {
399 fn into_response(self) -> Response {
400 http::Response::builder()
401 .status(StatusCode::NO_CONTENT)
402 .body(Body::empty())
403 .unwrap()
404 }
405}
406
407impl ResponseModifier for NoContent {
408 fn update_response(op: &mut Operation) {
409 op.responses.insert(
410 "204".to_string(),
411 ResponseSpec {
412 description: "No Content".to_string(),
413 content: None,
414 },
415 );
416 }
417}
418
/// Wrapper that answers with `200 OK` and the wrapped text as UTF-8 HTML.
#[derive(Clone, Debug)]
pub struct Html<T>(pub T);
422
423impl<T: Into<String>> IntoResponse for Html<T> {
424 fn into_response(self) -> Response {
425 http::Response::builder()
426 .status(StatusCode::OK)
427 .header(header::CONTENT_TYPE, "text/html; charset=utf-8")
428 .body(Body::from(self.0.into()))
429 .unwrap()
430 }
431}
432
433impl<T> ResponseModifier for Html<T> {
434 fn update_response(op: &mut Operation) {
435 op.responses.insert(
436 "200".to_string(),
437 ResponseSpec {
438 description: "HTML Content".to_string(),
439 content: {
440 let mut map = HashMap::new();
441 map.insert(
442 "text/html".to_string(),
443 MediaType {
444 schema: SchemaRef::Inline(serde_json::json!({ "type": "string" })),
445 },
446 );
447 Some(map)
448 },
449 },
450 );
451 }
452}
453
454#[derive(Debug, Clone)]
456pub struct Redirect {
457 status: StatusCode,
458 location: HeaderValue,
459}
460
461impl Redirect {
462 pub fn to(uri: &str) -> Self {
464 Self {
465 status: StatusCode::FOUND,
466 location: HeaderValue::from_str(uri).expect("Invalid redirect URI"),
467 }
468 }
469
470 pub fn permanent(uri: &str) -> Self {
472 Self {
473 status: StatusCode::MOVED_PERMANENTLY,
474 location: HeaderValue::from_str(uri).expect("Invalid redirect URI"),
475 }
476 }
477
478 pub fn temporary(uri: &str) -> Self {
480 Self {
481 status: StatusCode::TEMPORARY_REDIRECT,
482 location: HeaderValue::from_str(uri).expect("Invalid redirect URI"),
483 }
484 }
485}
486
487impl IntoResponse for Redirect {
488 fn into_response(self) -> Response {
489 http::Response::builder()
490 .status(self.status)
491 .header(header::LOCATION, self.location)
492 .body(Body::empty())
493 .unwrap()
494 }
495}
496
497impl ResponseModifier for Redirect {
498 fn update_response(op: &mut Operation) {
499 op.responses.insert(
502 "3xx".to_string(),
503 ResponseSpec {
504 description: "Redirection".to_string(),
505 content: None,
506 },
507 );
508 }
509}
510
/// Wrapper that converts the inner value with its own `IntoResponse` impl and
/// then overrides the status with the compile-time code `CODE`.
#[derive(Clone, Debug)]
pub struct WithStatus<T, const CODE: u16>(pub T);
530
531impl<T: IntoResponse, const CODE: u16> IntoResponse for WithStatus<T, CODE> {
532 fn into_response(self) -> Response {
533 let mut response = self.0.into_response();
534 if let Ok(status) = StatusCode::from_u16(CODE) {
536 *response.status_mut() = status;
537 }
538 response
539 }
540}
541
542impl<T: for<'a> Schema<'a>, const CODE: u16> ResponseModifier for WithStatus<T, CODE> {
543 fn update_response(op: &mut Operation) {
544 let (name, _) = T::schema();
545
546 let schema_ref = SchemaRef::Ref {
547 reference: format!("#/components/schemas/{}", name),
548 };
549
550 op.responses.insert(
551 CODE.to_string(),
552 ResponseSpec {
553 description: format!("Response with status {}", CODE),
554 content: {
555 let mut map = HashMap::new();
556 map.insert(
557 "application/json".to_string(),
558 MediaType { schema: schema_ref },
559 );
560 Some(map)
561 },
562 },
563 );
564 }
565}
566
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Drains `body` to completion and concatenates every data frame it yields.
    async fn body_to_bytes(body: Body) -> Bytes {
        use http_body_util::BodyExt;
        let collected = body.collect().await.unwrap();
        collected.to_bytes()
    }

    /// Wraps `inner` in `WithStatus<_, CODE>` and returns the resulting status as a number.
    fn status_after<T: IntoResponse, const CODE: u16>(inner: T) -> u16 {
        WithStatus::<T, CODE>(inner).into_response().status().as_u16()
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn prop_with_status_response_correctness(
            body in "[a-zA-Z0-9 ]{0,100}",
        ) {
            // The const-generic code must override the inner `String`'s 200 status.
            prop_assert_eq!(status_after::<_, 200>(body.clone()), 200);
            prop_assert_eq!(status_after::<_, 201>(body.clone()), 201);
            prop_assert_eq!(status_after::<_, 202>(body.clone()), 202);
            prop_assert_eq!(status_after::<_, 204>(body.clone()), 204);
            prop_assert_eq!(status_after::<_, 400>(body.clone()), 400);
            prop_assert_eq!(status_after::<_, 404>(body.clone()), 404);
            prop_assert_eq!(status_after::<_, 418>(body.clone()), 418);
            prop_assert_eq!(status_after::<_, 500>(body.clone()), 500);
            prop_assert_eq!(status_after::<_, 503>(body.clone()), 503);

            // The wrapped body must pass through byte-for-byte.
            let runtime = tokio::runtime::Runtime::new().unwrap();
            let response: Response = WithStatus::<_, 200>(body.clone()).into_response();
            let bytes = runtime.block_on(body_to_bytes(response.into_body()));
            prop_assert_eq!(String::from_utf8_lossy(&bytes).as_ref(), body.as_str());
        }
    }

    #[test]
    fn test_with_status_preserves_content_type() {
        let response: Response = WithStatus::<_, 202>("hello world").into_response();

        assert_eq!(response.status().as_u16(), 202);
        let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
        assert_eq!(content_type, "text/plain; charset=utf-8");
    }

    #[tokio::test]
    async fn test_with_status_with_empty_body() {
        let response: Response = WithStatus::<_, 204>(()).into_response();
        assert_eq!(response.status().as_u16(), 204);

        let bytes = body_to_bytes(response.into_body()).await;
        assert!(bytes.is_empty());
    }

    #[test]
    fn test_with_status_common_codes() {
        assert_eq!(status_after::<_, 100>(""), 100);
        assert_eq!(status_after::<_, 200>(""), 200);
        assert_eq!(status_after::<_, 201>(""), 201);
        assert_eq!(status_after::<_, 202>(""), 202);
        assert_eq!(status_after::<_, 204>(""), 204);
        assert_eq!(status_after::<_, 301>(""), 301);
        assert_eq!(status_after::<_, 302>(""), 302);
        assert_eq!(status_after::<_, 400>(""), 400);
        assert_eq!(status_after::<_, 401>(""), 401);
        assert_eq!(status_after::<_, 403>(""), 403);
        assert_eq!(status_after::<_, 404>(""), 404);
        assert_eq!(status_after::<_, 500>(""), 500);
        assert_eq!(status_after::<_, 502>(""), 502);
        assert_eq!(status_after::<_, 503>(""), 503);
    }
}