1use std::{
2 convert::Infallible,
3 fmt, io,
4 task::{Context, Poll},
5 time::Duration,
6};
7
8use futures::{
9 AsyncWriteExt, FutureExt, SinkExt, StreamExt,
10 channel::{mpsc, oneshot},
11};
12use futures_bounded::{Delay, FuturesMap};
13use smallvec::SmallVec;
14use volans_swarm::{
15 ConnectionHandler, ConnectionHandlerEvent, InboundStreamHandler, InboundUpgradeSend,
16 SubstreamProtocol,
17};
18
19use crate::{Codec, RequestId, Upgrade};
20
21pub struct Handler<TCodec>
22where
23 TCodec: Codec,
24{
25 codec: TCodec,
26 protocols: SmallVec<[TCodec::Protocol; 2]>,
27 receiver: mpsc::Receiver<(
28 RequestId,
29 TCodec::Request,
30 oneshot::Sender<TCodec::Response>,
31 )>,
32 sender: mpsc::Sender<(
33 RequestId,
34 TCodec::Request,
35 oneshot::Sender<TCodec::Response>,
36 )>,
37 requesting: FuturesMap<RequestId, Result<Event<TCodec>, io::Error>>,
38}
39
40impl<TCodec> Handler<TCodec>
41where
42 TCodec: Codec + Send + 'static,
43{
44 pub fn new(
45 codec: TCodec,
46 protocols: SmallVec<[TCodec::Protocol; 2]>,
47 stream_timeout: Duration,
48 ) -> Self {
49 let (sender, receiver) = mpsc::channel(0);
50 Self {
51 codec,
52 protocols,
53 receiver,
54 sender,
55 requesting: FuturesMap::new(move || Delay::futures_timer(stream_timeout), 10),
56 }
57 }
58}
59
60pub enum Event<TCodec>
61where
62 TCodec: Codec,
63{
64 Request {
65 request_id: RequestId,
66 request: TCodec::Request,
67 sender: oneshot::Sender<TCodec::Response>,
68 },
69 Error {
70 request_id: RequestId,
71 error: io::Error,
72 },
73 Response(RequestId),
74 Discard(RequestId),
75 Timeout(RequestId),
76}
77
78impl<TCodec> fmt::Debug for Event<TCodec>
79where
80 TCodec: Codec,
81{
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 match self {
84 Event::Request { request_id, .. } => f
85 .debug_struct("InboundEvent::Request")
86 .field("request_id", request_id)
87 .finish(),
88 Event::Error { request_id, error } => f
89 .debug_struct("InboundEvent::Error")
90 .field("request_id", request_id)
91 .field("error", error)
92 .finish(),
93 Event::Response(request_id) => f
94 .debug_struct("InboundEvent::Response")
95 .field("request_id", request_id)
96 .finish(),
97 Event::Discard(request_id) => f
98 .debug_struct("InboundEvent::Discard")
99 .field("request_id", request_id)
100 .finish(),
101 Event::Timeout(request_id) => f
102 .debug_struct("InboundEvent::Timeout")
103 .field("request_id", request_id)
104 .finish(),
105 }
106 }
107}
108
109impl<TCodec> ConnectionHandler for Handler<TCodec>
110where
111 TCodec: Codec + Send + 'static,
112{
113 type Action = Infallible;
114 type Event = Event<TCodec>;
115
116 fn handle_action(&mut self, _action: Self::Action) {
117 unreachable!("Request handler does not support actions");
118 }
119
120 fn poll_close(&mut self, _: &mut Context<'_>) -> Poll<Option<Self::Event>> {
121 Poll::Ready(None)
122 }
123
124 fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::Event>> {
125 match self.requesting.poll_unpin(cx) {
126 Poll::Ready((_, Ok(Ok(event)))) => {
127 return Poll::Ready(ConnectionHandlerEvent::Notify(event));
128 }
129 Poll::Ready((request_id, Ok(Err(error)))) => {
130 return Poll::Ready(ConnectionHandlerEvent::Notify(Event::Error {
131 request_id,
132 error,
133 }));
134 }
135 Poll::Ready((request_id, Err(_))) => {
136 return Poll::Ready(ConnectionHandlerEvent::Notify(Event::Timeout(request_id)));
137 }
138 Poll::Pending => {}
139 }
140
141 match self.receiver.poll_next_unpin(cx) {
142 Poll::Ready(Some((request_id, request, sender))) => {
143 return Poll::Ready(ConnectionHandlerEvent::Notify(Event::Request {
144 request_id,
145 request,
146 sender,
147 }));
148 }
149 Poll::Ready(None) | Poll::Pending => {}
150 }
151
152 Poll::Pending
153 }
154}
155
156impl<TCodec> InboundStreamHandler for Handler<TCodec>
157where
158 TCodec: Codec + Clone + Send + 'static,
159{
160 type InboundUpgrade = Upgrade<TCodec::Protocol>;
161 type InboundUserData = ();
162
163 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundUpgrade, Self::InboundUserData> {
164 SubstreamProtocol::new(
165 Upgrade {
166 protocols: self.protocols.clone(),
167 },
168 (),
169 )
170 }
171
172 fn on_fully_negotiated(
173 &mut self,
174 _user_data: Self::InboundUserData,
175 (mut stream, protocol): <Self::InboundUpgrade as InboundUpgradeSend>::Output,
176 ) {
177 let mut codec = self.codec.clone();
178 let request_id = RequestId::next();
179 let mut sender = self.sender.clone();
180 let fut = async move {
181 let (response_sender, response_receiver) = oneshot::channel();
182 let request = codec.read_request(&protocol, &mut stream).await?;
183 sender
184 .send((request_id, request, response_sender))
185 .await
186 .expect("Request handler sender should not be closed");
187 drop(sender);
188 if let Ok(response) = response_receiver.await {
189 codec
190 .write_response(&protocol, &mut stream, response)
191 .await?;
192 stream.close().await?;
193 return Ok(Event::Response(request_id));
194 } else {
195 stream.close().await?;
196 return Ok(Event::Discard(request_id));
197 }
198 };
199 match self.requesting.try_push(request_id, fut.boxed()) {
200 Ok(()) => {}
201 Err(_) => {
202 tracing::warn!("Request handler is overloaded, dropping request");
203 }
204 }
205 }
206
207 fn on_upgrade_error(
208 &mut self,
209 _user_data: Self::InboundUserData,
210 _error: <Self::InboundUpgrade as InboundUpgradeSend>::Error,
211 ) {
212 unreachable!("Request handler does not support upgrade errors");
213 }
214}