Skip to main content

sui_http/middleware/callback/
body.rs

1// Copyright (c) Mysten Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use super::RequestHandler;
5use super::ResponseHandler;
6use http_body::Body;
7use pin_project_lite::pin_project;
8use std::fmt;
9use std::pin::Pin;
10use std::task::Context;
11use std::task::Poll;
12use std::task::ready;
13
14pin_project! {
15    /// Request body wrapper for [`Callback`].
16    ///
17    /// Forwards frames from the inner request body unchanged, surfacing
18    /// every event to the configured [`RequestHandler`] via
19    /// `on_body_chunk`, `on_end_of_stream`, or `on_body_error`.
20    ///
21    /// [`Callback`]: super::Callback
22    pub struct RequestBody<B, H> {
23        #[pin]
24        pub(crate) inner: B,
25        pub(crate) handler: H,
26        // Ensures `on_end_of_stream` fires at most once. A body that emits
27        // a trailers frame and then `Poll::Ready(None)` would otherwise
28        // trigger two end-of-stream callbacks.
29        pub(crate) ended: bool,
30    }
31}
32
33impl<B, H> Body for RequestBody<B, H>
34where
35    B: Body,
36    B::Error: fmt::Display + 'static,
37    H: RequestHandler,
38{
39    type Data = B::Data;
40    type Error = B::Error;
41
42    fn poll_frame(
43        self: Pin<&mut Self>,
44        cx: &mut Context<'_>,
45    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
46        let this = self.project();
47        let result = ready!(this.inner.poll_frame(cx));
48
49        match result {
50            Some(Ok(frame)) => {
51                if let Some(chunk) = frame.data_ref() {
52                    this.handler.on_body_chunk(chunk);
53                } else if let Some(trailers) = frame.trailers_ref()
54                    && !*this.ended
55                {
56                    this.handler.on_end_of_stream(Some(trailers));
57                    *this.ended = true;
58                }
59
60                Poll::Ready(Some(Ok(frame)))
61            }
62            Some(Err(err)) => {
63                this.handler.on_body_error(&err);
64
65                Poll::Ready(Some(Err(err)))
66            }
67            None => {
68                if !*this.ended {
69                    this.handler.on_end_of_stream(None);
70                    *this.ended = true;
71                }
72
73                Poll::Ready(None)
74            }
75        }
76    }
77
78    fn is_end_stream(&self) -> bool {
79        self.inner.is_end_stream()
80    }
81
82    fn size_hint(&self) -> http_body::SizeHint {
83        self.inner.size_hint()
84    }
85}
86
87pin_project! {
88    /// Response body wrapper for [`Callback`].
89    ///
90    /// Forwards frames from the inner response body unchanged, surfacing
91    /// every event to the configured [`ResponseHandler`] via
92    /// `on_body_chunk`, `on_end_of_stream`, or `on_body_error`.
93    ///
94    /// [`Callback`]: super::Callback
95    pub struct ResponseBody<B, H> {
96        #[pin]
97        pub(crate) inner: B,
98        pub(crate) handler: H,
99        pub(crate) ended: bool,
100    }
101}
102
103impl<B, H> Body for ResponseBody<B, H>
104where
105    B: Body,
106    B::Error: fmt::Display + 'static,
107    H: ResponseHandler,
108{
109    type Data = B::Data;
110    type Error = B::Error;
111
112    fn poll_frame(
113        self: Pin<&mut Self>,
114        cx: &mut Context<'_>,
115    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
116        let this = self.project();
117        let result = ready!(this.inner.poll_frame(cx));
118
119        match result {
120            Some(Ok(frame)) => {
121                if let Some(chunk) = frame.data_ref() {
122                    this.handler.on_body_chunk(chunk);
123                } else if let Some(trailers) = frame.trailers_ref()
124                    && !*this.ended
125                {
126                    this.handler.on_end_of_stream(Some(trailers));
127                    *this.ended = true;
128                }
129
130                Poll::Ready(Some(Ok(frame)))
131            }
132            Some(Err(err)) => {
133                this.handler.on_body_error(&err);
134
135                Poll::Ready(Some(Err(err)))
136            }
137            None => {
138                if !*this.ended {
139                    this.handler.on_end_of_stream(None);
140                    *this.ended = true;
141                }
142
143                Poll::Ready(None)
144            }
145        }
146    }
147
148    fn is_end_stream(&self) -> bool {
149        self.inner.is_end_stream()
150    }
151
152    fn size_hint(&self) -> http_body::SizeHint {
153        self.inner.size_hint()
154    }
155}