1use crate::error::{ApiError, ErrorResponse};
74use bytes::Bytes;
75use futures_util::StreamExt;
76use http::{header, HeaderMap, HeaderValue, StatusCode};
77use http_body_util::Full;
78use rustapi_openapi::schema::{RustApiSchema, SchemaCtx};
79use rustapi_openapi::{MediaType, Operation, ResponseModifier, ResponseSpec, SchemaRef};
80use serde::Serialize;
81use std::collections::BTreeMap;
82use std::pin::Pin;
83use std::task::{Context, Poll};
84
85pub enum Body {
87 Full(Full<Bytes>),
89 Streaming(Pin<Box<dyn http_body::Body<Data = Bytes, Error = ApiError> + Send + 'static>>),
91}
92
93impl Body {
94 pub fn new(bytes: Bytes) -> Self {
96 Self::Full(Full::new(bytes))
97 }
98
99 pub fn empty() -> Self {
101 Self::Full(Full::new(Bytes::new()))
102 }
103
104 pub fn from_stream<S, E>(stream: S) -> Self
106 where
107 S: futures_util::Stream<Item = Result<Bytes, E>> + Send + 'static,
108 E: Into<ApiError> + 'static,
109 {
110 let body = http_body_util::StreamBody::new(
111 stream.map(|res| res.map_err(|e| e.into()).map(http_body::Frame::data)),
112 );
113 Self::Streaming(Box::pin(body))
114 }
115}
116
117impl Default for Body {
118 fn default() -> Self {
119 Self::empty()
120 }
121}
122
123impl http_body::Body for Body {
124 type Data = Bytes;
125 type Error = ApiError;
126
127 fn poll_frame(
128 self: Pin<&mut Self>,
129 cx: &mut Context<'_>,
130 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
131 match self.get_mut() {
132 Body::Full(b) => Pin::new(b)
133 .poll_frame(cx)
134 .map_err(|_| ApiError::internal("Infallible error")),
135 Body::Streaming(b) => b.as_mut().poll_frame(cx),
136 }
137 }
138
139 fn is_end_stream(&self) -> bool {
140 match self {
141 Body::Full(b) => b.is_end_stream(),
142 Body::Streaming(b) => b.is_end_stream(),
143 }
144 }
145
146 fn size_hint(&self) -> http_body::SizeHint {
147 match self {
148 Body::Full(b) => b.size_hint(),
149 Body::Streaming(b) => b.size_hint(),
150 }
151 }
152}
153
154impl From<Bytes> for Body {
155 fn from(bytes: Bytes) -> Self {
156 Self::new(bytes)
157 }
158}
159
160impl From<String> for Body {
161 fn from(s: String) -> Self {
162 Self::new(Bytes::from(s))
163 }
164}
165
166impl From<&'static str> for Body {
167 fn from(s: &'static str) -> Self {
168 Self::new(Bytes::from(s))
169 }
170}
171
172impl From<Vec<u8>> for Body {
173 fn from(v: Vec<u8>) -> Self {
174 Self::new(Bytes::from(v))
175 }
176}
177
178pub type Response = http::Response<Body>;
180
181pub trait IntoResponse {
183 fn into_response(self) -> Response;
185}
186
187impl IntoResponse for Response {
189 fn into_response(self) -> Response {
190 self
191 }
192}
193
194impl IntoResponse for () {
196 fn into_response(self) -> Response {
197 http::Response::builder()
198 .status(StatusCode::OK)
199 .body(Body::empty())
200 .unwrap()
201 }
202}
203
204impl IntoResponse for &'static str {
206 fn into_response(self) -> Response {
207 http::Response::builder()
208 .status(StatusCode::OK)
209 .header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
210 .body(Body::from(self))
211 .unwrap()
212 }
213}
214
215impl IntoResponse for String {
217 fn into_response(self) -> Response {
218 http::Response::builder()
219 .status(StatusCode::OK)
220 .header(header::CONTENT_TYPE, "text/plain; charset=utf-8")
221 .body(Body::from(self))
222 .unwrap()
223 }
224}
225
226impl IntoResponse for StatusCode {
228 fn into_response(self) -> Response {
229 http::Response::builder()
230 .status(self)
231 .body(Body::empty())
232 .unwrap()
233 }
234}
235
236impl<R: IntoResponse> IntoResponse for (StatusCode, R) {
238 fn into_response(self) -> Response {
239 let mut response = self.1.into_response();
240 *response.status_mut() = self.0;
241 response
242 }
243}
244
245impl<R: IntoResponse> IntoResponse for (StatusCode, HeaderMap, R) {
247 fn into_response(self) -> Response {
248 let mut response = self.2.into_response();
249 *response.status_mut() = self.0;
250 response.headers_mut().extend(self.1);
251 response
252 }
253}
254
255impl<T: IntoResponse, E: IntoResponse> IntoResponse for Result<T, E> {
257 fn into_response(self) -> Response {
258 match self {
259 Ok(v) => v.into_response(),
260 Err(e) => e.into_response(),
261 }
262 }
263}
264
265impl IntoResponse for ApiError {
268 fn into_response(self) -> Response {
269 let status = self.status;
270 let error_response = ErrorResponse::from(self);
272 let body = serde_json::to_vec(&error_response).unwrap_or_else(|_| {
273 br#"{"error":{"type":"internal_error","message":"Failed to serialize error"}}"#.to_vec()
274 });
275
276 http::Response::builder()
277 .status(status)
278 .header(header::CONTENT_TYPE, "application/json")
279 .body(Body::from(body))
280 .unwrap()
281 }
282}
283
284impl ResponseModifier for ApiError {
285 fn update_response(op: &mut Operation) {
286 op.responses.insert(
289 "400".to_string(),
290 ResponseSpec {
291 description: "Bad Request".to_string(),
292 content: {
293 let mut map = BTreeMap::new();
294 map.insert(
295 "application/json".to_string(),
296 MediaType {
297 schema: Some(SchemaRef::Ref {
298 reference: "#/components/schemas/ErrorSchema".to_string(),
299 }),
300 example: None,
301 },
302 );
303 map
304 },
305 headers: BTreeMap::new(),
306 },
307 );
308
309 op.responses.insert(
311 "500".to_string(),
312 ResponseSpec {
313 description: "Internal Server Error".to_string(),
314 content: {
315 let mut map = BTreeMap::new();
316 map.insert(
317 "application/json".to_string(),
318 MediaType {
319 schema: Some(SchemaRef::Ref {
320 reference: "#/components/schemas/ErrorSchema".to_string(),
321 }),
322 example: None,
323 },
324 );
325 map
326 },
327 headers: BTreeMap::new(),
328 },
329 );
330 }
331}
332
333#[derive(Debug, Clone)]
346pub struct Created<T>(pub T);
347
348impl<T: Serialize> IntoResponse for Created<T> {
349 fn into_response(self) -> Response {
350 match serde_json::to_vec(&self.0) {
351 Ok(body) => http::Response::builder()
352 .status(StatusCode::CREATED)
353 .header(header::CONTENT_TYPE, "application/json")
354 .body(Body::from(body))
355 .unwrap(),
356 Err(err) => {
357 ApiError::internal(format!("Failed to serialize response: {}", err)).into_response()
358 }
359 }
360 }
361}
362
363impl<T: RustApiSchema> ResponseModifier for Created<T> {
364 fn update_response(op: &mut Operation) {
365 let mut ctx = SchemaCtx::new();
366 let schema_ref = T::schema(&mut ctx);
367
368 op.responses.insert(
369 "201".to_string(),
370 ResponseSpec {
371 description: "Created".to_string(),
372 content: {
373 let mut map = BTreeMap::new();
374 map.insert(
375 "application/json".to_string(),
376 MediaType {
377 schema: Some(schema_ref),
378 example: None,
379 },
380 );
381 map
382 },
383 headers: BTreeMap::new(),
384 },
385 );
386 }
387}
388
389#[derive(Debug, Clone, Copy)]
402pub struct NoContent;
403
404impl IntoResponse for NoContent {
405 fn into_response(self) -> Response {
406 http::Response::builder()
407 .status(StatusCode::NO_CONTENT)
408 .body(Body::empty())
409 .unwrap()
410 }
411}
412
413impl ResponseModifier for NoContent {
414 fn update_response(op: &mut Operation) {
415 op.responses.insert(
416 "204".to_string(),
417 ResponseSpec {
418 description: "No Content".to_string(),
419 content: BTreeMap::new(),
420 headers: BTreeMap::new(),
421 },
422 );
423 }
424}
425
426#[derive(Debug, Clone)]
428pub struct Html<T>(pub T);
429
430impl<T: Into<String>> IntoResponse for Html<T> {
431 fn into_response(self) -> Response {
432 http::Response::builder()
433 .status(StatusCode::OK)
434 .header(header::CONTENT_TYPE, "text/html; charset=utf-8")
435 .body(Body::from(self.0.into()))
436 .unwrap()
437 }
438}
439
440impl<T> ResponseModifier for Html<T> {
441 fn update_response(op: &mut Operation) {
442 op.responses.insert(
443 "200".to_string(),
444 ResponseSpec {
445 description: "HTML Content".to_string(),
446 content: {
447 let mut map = BTreeMap::new();
448 map.insert(
449 "text/html".to_string(),
450 MediaType {
451 schema: Some(SchemaRef::Inline(
452 serde_json::json!({ "type": "string" }),
453 )),
454 example: None,
455 },
456 );
457 map
458 },
459 headers: BTreeMap::new(),
460 },
461 );
462 }
463}
464
465#[derive(Debug, Clone)]
467pub struct Redirect {
468 status: StatusCode,
469 location: HeaderValue,
470}
471
472impl Redirect {
473 pub fn to(uri: &str) -> Self {
475 Self {
476 status: StatusCode::FOUND,
477 location: HeaderValue::from_str(uri).expect("Invalid redirect URI"),
478 }
479 }
480
481 pub fn permanent(uri: &str) -> Self {
483 Self {
484 status: StatusCode::MOVED_PERMANENTLY,
485 location: HeaderValue::from_str(uri).expect("Invalid redirect URI"),
486 }
487 }
488
489 pub fn temporary(uri: &str) -> Self {
491 Self {
492 status: StatusCode::TEMPORARY_REDIRECT,
493 location: HeaderValue::from_str(uri).expect("Invalid redirect URI"),
494 }
495 }
496}
497
498impl IntoResponse for Redirect {
499 fn into_response(self) -> Response {
500 http::Response::builder()
501 .status(self.status)
502 .header(header::LOCATION, self.location)
503 .body(Body::empty())
504 .unwrap()
505 }
506}
507
508impl ResponseModifier for Redirect {
509 fn update_response(op: &mut Operation) {
510 op.responses.insert(
513 "3xx".to_string(),
514 ResponseSpec {
515 description: "Redirection".to_string(),
516 content: BTreeMap::new(),
517 headers: BTreeMap::new(),
518 },
519 );
520 }
521}
522
523#[derive(Debug, Clone)]
541pub struct WithStatus<T, const CODE: u16>(pub T);
542
543impl<T: IntoResponse, const CODE: u16> IntoResponse for WithStatus<T, CODE> {
544 fn into_response(self) -> Response {
545 let mut response = self.0.into_response();
546 if let Ok(status) = StatusCode::from_u16(CODE) {
548 *response.status_mut() = status;
549 }
550 response
551 }
552}
553
554impl<T: RustApiSchema, const CODE: u16> ResponseModifier for WithStatus<T, CODE> {
555 fn update_response(op: &mut Operation) {
556 let mut ctx = SchemaCtx::new();
557 let schema_ref = T::schema(&mut ctx);
558
559 op.responses.insert(
560 CODE.to_string(),
561 ResponseSpec {
562 description: format!("Response with status {}", CODE),
563 content: {
564 let mut map = BTreeMap::new();
565 map.insert(
566 "application/json".to_string(),
567 MediaType {
568 schema: Some(schema_ref),
569 example: None,
570 },
571 );
572 map
573 },
574 headers: BTreeMap::new(),
575 },
576 );
577 }
578}
579
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Collects an entire response body into one `Bytes` buffer.
    async fn body_to_bytes(body: Body) -> Bytes {
        use http_body_util::BodyExt;
        body.collect().await.unwrap().to_bytes()
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // `WithStatus` must apply each requested status and pass the wrapped
        // text body through unchanged.
        #[test]
        fn prop_with_status_response_correctness(
            body in "[a-zA-Z0-9 ]{0,100}",
        ) {
            let rt = tokio::runtime::Runtime::new().unwrap();
            rt.block_on(async {
                // Builds one response per listed code and checks its status.
                macro_rules! assert_status {
                    ($text:expr; $($code:literal),+ $(,)?) => {
                        $(
                            let response: Response =
                                WithStatus::<_, { $code }>($text.clone()).into_response();
                            prop_assert_eq!(response.status().as_u16(), $code);
                        )+
                    };
                }

                assert_status!(body; 200, 201, 202, 204, 400, 404, 418, 500, 503);

                let echoed: Response = WithStatus::<_, 200>(body.clone()).into_response();
                let bytes = body_to_bytes(echoed.into_body()).await;
                let text = String::from_utf8_lossy(&bytes);
                prop_assert_eq!(text.as_ref(), body.as_str());

                Ok(())
            })?;
        }
    }

    #[tokio::test]
    async fn test_with_status_preserves_content_type() {
        let response: Response = WithStatus::<_, 202>("hello world").into_response();

        assert_eq!(response.status().as_u16(), 202);
        let content_type = response.headers().get(header::CONTENT_TYPE).unwrap();
        assert_eq!(content_type, "text/plain; charset=utf-8");
    }

    #[tokio::test]
    async fn test_with_status_with_empty_body() {
        let response: Response = WithStatus::<_, 204>(()).into_response();

        assert_eq!(response.status().as_u16(), 204);
        let bytes = body_to_bytes(response.into_body()).await;
        assert!(bytes.is_empty());
    }

    #[test]
    fn test_with_status_common_codes() {
        // Checks that wrapping `""` in `WithStatus<_, CODE>` yields status `CODE`.
        macro_rules! assert_codes {
            ($($code:literal),+ $(,)?) => {
                $(
                    assert_eq!(
                        WithStatus::<_, { $code }>("").into_response().status().as_u16(),
                        $code
                    );
                )+
            };
        }

        assert_codes!(100, 200, 201, 202, 204, 301, 302, 400, 401, 403, 404, 500, 502, 503);
    }
}