1use hyper::body::{Body, Frame, SizeHint};
2use pin_project_lite::pin_project;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6use std::task::{Context, Poll};
7use tonic::body::BoxBody;
8use tonic::transport::Channel;
9use tower_service::Service;

/// A cheaply cloneable boolean flag whose state is shared by every clone.
///
/// All clones of a `FrameSignal` point at the same [`AtomicBool`] (behind an [`Arc`]), so
/// setting or clearing it through one handle is observed through every other handle.
/// In this file it is set by `RequestFrameMonitorBody` whenever the wrapped request body
/// yields a frame.
#[derive(Clone, Debug, Default)]
pub struct FrameSignal(Arc<AtomicBool>);

impl FrameSignal {
    /// Creates a new, unsignalled flag that is not shared with any existing handle.
    pub fn new() -> Self {
        // The derived `Default` builds `Arc::new(AtomicBool::new(false))`.
        Self::default()
    }

    /// Returns `true` if the flag has been set since creation or since the last
    /// [`reset`](FrameSignal::reset).
    pub fn is_signalled(&self) -> bool {
        self.0.load(Ordering::Acquire)
    }

    /// Clears the flag for every handle that shares it.
    pub fn reset(&self) {
        self.store_flag(false)
    }

    /// Sets the flag for every handle that shares it.
    fn signal(&self) {
        self.store_flag(true)
    }

    /// Writes `value` with `Release` ordering, pairing with the `Acquire` load in
    /// `is_signalled`.
    fn store_flag(&self, value: bool) {
        self.0.store(value, Ordering::Release)
    }
}

33pin_project! {
34 struct RequestFrameMonitorBody<B> {
35 #[pin]
36 inner: B,
37 frame_signal: FrameSignal,
38 }
39}
40
41impl<B> Body for RequestFrameMonitorBody<B>
42where
43 B: Body,
44{
45 type Data = B::Data;
46 type Error = B::Error;
47
48 fn poll_frame(
49 self: Pin<&mut Self>,
50 cx: &mut Context<'_>,
51 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
52 let this = self.project();
53 match this.inner.poll_frame(cx) {
54 Poll::Ready(Some(res)) => match res {
55 Ok(frame) => {
56 this.frame_signal.signal();
57 Poll::Ready(Some(Ok(frame)))
58 }
59 Err(status) => Poll::Ready(Some(Err(status))),
60 },
61 Poll::Ready(None) => Poll::Ready(None),
62 Poll::Pending => Poll::Pending,
63 }
64 }
65
66 fn is_end_stream(&self) -> bool {
67 self.inner.is_end_stream()
68 }
69
70 fn size_hint(&self) -> SizeHint {
71 self.inner.size_hint()
72 }
73}
74
75#[derive(Clone, Debug)]
77pub struct RequestFrameMonitor<S = Channel>
78where
79 S: Clone,
80{
81 inner: S,
83
84 frame_signal: FrameSignal,
86}
87
88impl<S: Clone> RequestFrameMonitor<S> {
89 pub fn new(inner: S, frame_signal: FrameSignal) -> Self {
90 Self {
91 inner,
92 frame_signal: frame_signal.clone(),
93 }
94 }
95}
96
97impl<S> Service<http::Request<BoxBody>> for RequestFrameMonitor<S>
98where
99 S: Service<http::Request<BoxBody>> + Clone,
100{
101 type Response = S::Response;
102 type Error = S::Error;
103 type Future = S::Future;
104
105 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
106 self.inner.poll_ready(cx)
107 }
108
109 fn call(&mut self, req: http::Request<BoxBody>) -> Self::Future {
110 let (head, body) = req.into_parts();
111 let body = BoxBody::new(RequestFrameMonitorBody {
112 inner: body,
113 frame_signal: self.frame_signal.clone(),
114 });
115 let clone = self.inner.clone();
117 let mut inner = std::mem::replace(&mut self.inner, clone);
118 inner.call(http::Request::from_parts(head, body))
119 }
120}