1use std::{
2 collections::VecDeque,
3 fmt, io,
4 task::{Context, Poll},
5 time::Duration,
6};
7
8use futures::{AsyncWriteExt, FutureExt};
9use futures_bounded::{Delay, FuturesMap};
10use volans_swarm::{
11 ConnectionHandler, ConnectionHandlerEvent, OutboundStreamHandler, OutboundUpgradeSend,
12 StreamUpgradeError, SubstreamProtocol,
13};
14
15use crate::{Codec, RequestId, Upgrade};
16
17pub struct Handler<TCodec>
18where
19 TCodec: Codec,
20{
21 codec: TCodec,
22 pending_outbound: VecDeque<OutboundRequest<TCodec>>,
23 requested_outbound: VecDeque<OutboundRequest<TCodec>>,
24 pending_events: VecDeque<Event<TCodec>>,
25 requesting: FuturesMap<RequestId, Result<Event<TCodec>, io::Error>>,
26}
27
28impl<TCodec> Handler<TCodec>
29where
30 TCodec: Codec + Send + 'static,
31{
32 pub fn new(codec: TCodec, stream_timeout: Duration) -> Self {
33 Self {
34 codec,
35 pending_outbound: VecDeque::new(),
36 requested_outbound: VecDeque::new(),
37 pending_events: VecDeque::new(),
38 requesting: FuturesMap::new(move || Delay::futures_timer(stream_timeout), 10),
39 }
40 }
41}
42
43pub enum Event<TCodec>
44where
45 TCodec: Codec,
46{
47 Response {
48 request_id: RequestId,
49 response: TCodec::Response,
50 },
51 Unsupported(RequestId),
52 Timeout(RequestId),
53 StreamError {
54 request_id: RequestId,
55 error: io::Error,
56 },
57}
58
59impl<TCodec> fmt::Debug for Event<TCodec>
60where
61 TCodec: Codec,
62{
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 match self {
65 Event::Response { request_id, .. } => f
66 .debug_struct("Response")
67 .field("request_id", request_id)
68 .finish_non_exhaustive(),
69 Event::Unsupported(request_id) => f
70 .debug_struct("UnsupportedProtocol")
71 .field("request_id", request_id)
72 .finish_non_exhaustive(),
73 Event::Timeout(request_id) => f
74 .debug_struct("Timeout")
75 .field("request_id", request_id)
76 .finish_non_exhaustive(),
77 Event::StreamError { request_id, error } => f
78 .debug_struct("StreamError")
79 .field("request_id", request_id)
80 .field("error", error)
81 .finish_non_exhaustive(),
82 }
83 }
84}
85
86pub struct OutboundRequest<TCodec: Codec> {
87 pub(crate) request_id: RequestId,
88 pub(crate) request: TCodec::Request,
89 pub(crate) protocol: TCodec::Protocol,
90}
91
92impl<TCodec: Codec> fmt::Debug for OutboundRequest<TCodec> {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 f.debug_struct("OutboundRequest").finish_non_exhaustive()
95 }
96}
97
98impl<TCodec> ConnectionHandler for Handler<TCodec>
99where
100 TCodec: Codec + Send + 'static,
101{
102 type Action = OutboundRequest<TCodec>;
103 type Event = Event<TCodec>;
104
105 fn handle_action(&mut self, action: Self::Action) {
106 self.pending_outbound.push_back(action);
107 }
108
109 fn poll_close(&mut self, _: &mut Context<'_>) -> Poll<Option<Self::Event>> {
110 if let Some(event) = self.pending_events.pop_front() {
111 return Poll::Ready(Some(event));
112 }
113 Poll::Ready(None)
114 }
115
116 fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::Event>> {
117 match self.requesting.poll_unpin(cx) {
118 Poll::Ready((_, Ok(Ok(event)))) => {
119 return Poll::Ready(ConnectionHandlerEvent::Notify(event));
120 }
121 Poll::Ready((request_id, Ok(Err(error)))) => {
122 return Poll::Ready(ConnectionHandlerEvent::Notify(Event::StreamError {
123 request_id,
124 error,
125 }));
126 }
127 Poll::Ready((request_id, Err(_))) => {
128 return Poll::Ready(ConnectionHandlerEvent::Notify(Event::Timeout(request_id)));
129 }
130 Poll::Pending => {}
131 }
132 if let Some(event) = self.pending_events.pop_front() {
133 return Poll::Ready(ConnectionHandlerEvent::Notify(event));
134 }
135 Poll::Pending
136 }
137}
138
139impl<TCodec> OutboundStreamHandler for Handler<TCodec>
140where
141 TCodec: Codec + Clone + Send + 'static,
142{
143 type OutboundUpgrade = Upgrade<TCodec::Protocol>;
144 type OutboundUserData = ();
145
146 fn on_fully_negotiated(
147 &mut self,
148 _user_data: Self::OutboundUserData,
149 (mut stream, protocol): <Self::OutboundUpgrade as OutboundUpgradeSend>::Output,
150 ) {
151 let message = self
152 .requested_outbound
153 .pop_front()
154 .expect("negotiated a stream without a pending message");
155
156 let mut codec = self.codec.clone();
157 let request_id = message.request_id;
158
159 let fut = async move {
160 let write = codec.write_request(&protocol, &mut stream, message.request);
161 write.await?;
162 stream.close().await?;
163 let read = codec.read_response(&protocol, &mut stream);
164 let response = read.await?;
165
166 Ok(Event::Response {
167 request_id,
168 response,
169 })
170 };
171
172 if self.requesting.try_push(request_id, fut.boxed()).is_err() {
173 self.pending_events.push_back(Event::StreamError {
174 request_id,
175 error: io::Error::other("max sub-streams reached"),
176 });
177 }
178 }
179
180 fn on_upgrade_error(
181 &mut self,
182 _user_data: Self::OutboundUserData,
183 error: StreamUpgradeError<<Self::OutboundUpgrade as OutboundUpgradeSend>::Error>,
184 ) {
185 let outbound = self
186 .requested_outbound
187 .pop_front()
188 .expect("negotiated a stream without a pending message");
189
190 match error {
191 StreamUpgradeError::Timeout => {
192 self.pending_events
193 .push_back(Event::Timeout(outbound.request_id));
194 }
195 StreamUpgradeError::NegotiationFailed => {
196 self.pending_events
197 .push_back(Event::Unsupported(outbound.request_id));
198 }
199 StreamUpgradeError::Apply(_) => {}
200 StreamUpgradeError::Io(error) => {
201 self.pending_events.push_back(Event::StreamError {
202 request_id: outbound.request_id,
203 error: error,
204 });
205 }
206 }
207 }
208
209 fn poll_outbound_request(
210 &mut self,
211 _cx: &mut Context<'_>,
212 ) -> Poll<SubstreamProtocol<Self::OutboundUpgrade, Self::OutboundUserData>> {
213 if let Some(request) = self.pending_outbound.pop_front() {
214 let protocol = request.protocol.clone();
215 self.requested_outbound.push_back(request);
216 return Poll::Ready(SubstreamProtocol::new(Upgrade::new_single(protocol), ()));
217 }
218 Poll::Pending
219 }
220}