volans_request/client/
handler.rs

1use std::{
2    collections::VecDeque,
3    fmt, io,
4    task::{Context, Poll},
5    time::Duration,
6};
7
8use futures::{AsyncWriteExt, FutureExt};
9use futures_bounded::{Delay, FuturesMap};
10use volans_swarm::{
11    ConnectionHandler, ConnectionHandlerEvent, OutboundStreamHandler, OutboundUpgradeSend,
12    StreamUpgradeError, SubstreamProtocol,
13};
14
15use crate::{Codec, RequestId, Upgrade};
16
17pub struct Handler<TCodec>
18where
19    TCodec: Codec,
20{
21    codec: TCodec,
22    pending_outbound: VecDeque<OutboundRequest<TCodec>>,
23    requested_outbound: VecDeque<OutboundRequest<TCodec>>,
24    pending_events: VecDeque<Event<TCodec>>,
25    requesting: FuturesMap<RequestId, Result<Event<TCodec>, io::Error>>,
26}
27
28impl<TCodec> Handler<TCodec>
29where
30    TCodec: Codec + Send + 'static,
31{
32    pub fn new(codec: TCodec, stream_timeout: Duration) -> Self {
33        Self {
34            codec,
35            pending_outbound: VecDeque::new(),
36            requested_outbound: VecDeque::new(),
37            pending_events: VecDeque::new(),
38            requesting: FuturesMap::new(move || Delay::futures_timer(stream_timeout), 10),
39        }
40    }
41}
42
43pub enum Event<TCodec>
44where
45    TCodec: Codec,
46{
47    Response {
48        request_id: RequestId,
49        response: TCodec::Response,
50    },
51    Unsupported(RequestId),
52    Timeout(RequestId),
53    StreamError {
54        request_id: RequestId,
55        error: io::Error,
56    },
57}
58
59impl<TCodec> fmt::Debug for Event<TCodec>
60where
61    TCodec: Codec,
62{
63    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64        match self {
65            Event::Response { request_id, .. } => f
66                .debug_struct("Response")
67                .field("request_id", request_id)
68                .finish_non_exhaustive(),
69            Event::Unsupported(request_id) => f
70                .debug_struct("UnsupportedProtocol")
71                .field("request_id", request_id)
72                .finish_non_exhaustive(),
73            Event::Timeout(request_id) => f
74                .debug_struct("Timeout")
75                .field("request_id", request_id)
76                .finish_non_exhaustive(),
77            Event::StreamError { request_id, error } => f
78                .debug_struct("StreamError")
79                .field("request_id", request_id)
80                .field("error", error)
81                .finish_non_exhaustive(),
82        }
83    }
84}
85
86pub struct OutboundRequest<TCodec: Codec> {
87    pub(crate) request_id: RequestId,
88    pub(crate) request: TCodec::Request,
89    pub(crate) protocol: TCodec::Protocol,
90}
91
92impl<TCodec: Codec> fmt::Debug for OutboundRequest<TCodec> {
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        f.debug_struct("OutboundRequest").finish_non_exhaustive()
95    }
96}
97
98impl<TCodec> ConnectionHandler for Handler<TCodec>
99where
100    TCodec: Codec + Send + 'static,
101{
102    type Action = OutboundRequest<TCodec>;
103    type Event = Event<TCodec>;
104
105    fn handle_action(&mut self, action: Self::Action) {
106        self.pending_outbound.push_back(action);
107    }
108
109    fn poll_close(&mut self, _: &mut Context<'_>) -> Poll<Option<Self::Event>> {
110        if let Some(event) = self.pending_events.pop_front() {
111            return Poll::Ready(Some(event));
112        }
113        Poll::Ready(None)
114    }
115
116    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::Event>> {
117        match self.requesting.poll_unpin(cx) {
118            Poll::Ready((_, Ok(Ok(event)))) => {
119                return Poll::Ready(ConnectionHandlerEvent::Notify(event));
120            }
121            Poll::Ready((request_id, Ok(Err(error)))) => {
122                return Poll::Ready(ConnectionHandlerEvent::Notify(Event::StreamError {
123                    request_id,
124                    error,
125                }));
126            }
127            Poll::Ready((request_id, Err(_))) => {
128                return Poll::Ready(ConnectionHandlerEvent::Notify(Event::Timeout(request_id)));
129            }
130            Poll::Pending => {}
131        }
132        if let Some(event) = self.pending_events.pop_front() {
133            return Poll::Ready(ConnectionHandlerEvent::Notify(event));
134        }
135        Poll::Pending
136    }
137}
138
139impl<TCodec> OutboundStreamHandler for Handler<TCodec>
140where
141    TCodec: Codec + Clone + Send + 'static,
142{
143    type OutboundUpgrade = Upgrade<TCodec::Protocol>;
144    type OutboundUserData = ();
145
146    fn on_fully_negotiated(
147        &mut self,
148        _user_data: Self::OutboundUserData,
149        (mut stream, protocol): <Self::OutboundUpgrade as OutboundUpgradeSend>::Output,
150    ) {
151        let message = self
152            .requested_outbound
153            .pop_front()
154            .expect("negotiated a stream without a pending message");
155
156        let mut codec = self.codec.clone();
157        let request_id = message.request_id;
158
159        let fut = async move {
160            let write = codec.write_request(&protocol, &mut stream, message.request);
161            write.await?;
162            stream.close().await?;
163            let read = codec.read_response(&protocol, &mut stream);
164            let response = read.await?;
165
166            Ok(Event::Response {
167                request_id,
168                response,
169            })
170        };
171
172        if self.requesting.try_push(request_id, fut.boxed()).is_err() {
173            self.pending_events.push_back(Event::StreamError {
174                request_id,
175                error: io::Error::other("max sub-streams reached"),
176            });
177        }
178    }
179
180    fn on_upgrade_error(
181        &mut self,
182        _user_data: Self::OutboundUserData,
183        error: StreamUpgradeError<<Self::OutboundUpgrade as OutboundUpgradeSend>::Error>,
184    ) {
185        let outbound = self
186            .requested_outbound
187            .pop_front()
188            .expect("negotiated a stream without a pending message");
189
190        match error {
191            StreamUpgradeError::Timeout => {
192                self.pending_events
193                    .push_back(Event::Timeout(outbound.request_id));
194            }
195            StreamUpgradeError::NegotiationFailed => {
196                self.pending_events
197                    .push_back(Event::Unsupported(outbound.request_id));
198            }
199            StreamUpgradeError::Apply(_) => {}
200            StreamUpgradeError::Io(error) => {
201                self.pending_events.push_back(Event::StreamError {
202                    request_id: outbound.request_id,
203                    error: error,
204                });
205            }
206        }
207    }
208
209    fn poll_outbound_request(
210        &mut self,
211        _cx: &mut Context<'_>,
212    ) -> Poll<SubstreamProtocol<Self::OutboundUpgrade, Self::OutboundUserData>> {
213        if let Some(request) = self.pending_outbound.pop_front() {
214            let protocol = request.protocol.clone();
215            self.requested_outbound.push_back(request);
216            return Poll::Ready(SubstreamProtocol::new(Upgrade::new_single(protocol), ()));
217        }
218        Poll::Pending
219    }
220}