volans_request/server/handler.rs

1use std::{
2    convert::Infallible,
3    fmt, io,
4    task::{Context, Poll},
5    time::Duration,
6};
7
8use futures::{
9    AsyncWriteExt, FutureExt, SinkExt, StreamExt,
10    channel::{mpsc, oneshot},
11};
12use futures_bounded::{Delay, FuturesMap};
13use smallvec::SmallVec;
14use volans_swarm::{
15    ConnectionHandler, ConnectionHandlerEvent, InboundStreamHandler, InboundUpgradeSend,
16    SubstreamProtocol,
17};
18
19use crate::{Codec, RequestId, Upgrade};
20
21pub struct Handler<TCodec>
22where
23    TCodec: Codec,
24{
25    codec: TCodec,
26    protocols: SmallVec<[TCodec::Protocol; 2]>,
27    receiver: mpsc::Receiver<(
28        RequestId,
29        TCodec::Request,
30        oneshot::Sender<TCodec::Response>,
31    )>,
32    sender: mpsc::Sender<(
33        RequestId,
34        TCodec::Request,
35        oneshot::Sender<TCodec::Response>,
36    )>,
37    requesting: FuturesMap<RequestId, Result<Event<TCodec>, io::Error>>,
38}
39
40impl<TCodec> Handler<TCodec>
41where
42    TCodec: Codec + Send + 'static,
43{
44    pub fn new(
45        codec: TCodec,
46        protocols: SmallVec<[TCodec::Protocol; 2]>,
47        stream_timeout: Duration,
48    ) -> Self {
49        let (sender, receiver) = mpsc::channel(0);
50        Self {
51            codec,
52            protocols,
53            receiver,
54            sender,
55            requesting: FuturesMap::new(move || Delay::futures_timer(stream_timeout), 10),
56        }
57    }
58}
59
/// Event reported by the server-side [`Handler`] for a single inbound request stream.
///
/// Every variant carries the `RequestId` that the handler assigned to the stream when it was
/// fully negotiated.
pub enum Event<TCodec>
where
    TCodec: Codec,
{
    /// A request was decoded from an inbound stream.
    ///
    /// A value sent through `sender` is written back on the stream as the response. Dropping
    /// `sender` instead closes the stream without a response, which is later reported as
    /// [`Event::Discard`].
    Request {
        request_id: RequestId,
        request: TCodec::Request,
        sender: oneshot::Sender<TCodec::Response>,
    },
    /// Reading the request, writing the response, or closing the stream failed.
    ///
    /// If decoding the request failed, no [`Event::Request`] was emitted for `request_id`.
    Error {
        request_id: RequestId,
        error: io::Error,
    },
    /// The response was written and the stream was closed.
    Response(RequestId),
    /// The response sender was dropped without a response; the stream was closed.
    Discard(RequestId),
    /// Processing the stream did not finish within the handler's stream timeout.
    Timeout(RequestId),
}
77
78impl<TCodec> fmt::Debug for Event<TCodec>
79where
80    TCodec: Codec,
81{
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        match self {
84            Event::Request { request_id, .. } => f
85                .debug_struct("InboundEvent::Request")
86                .field("request_id", request_id)
87                .finish(),
88            Event::Error { request_id, error } => f
89                .debug_struct("InboundEvent::Error")
90                .field("request_id", request_id)
91                .field("error", error)
92                .finish(),
93            Event::Response(request_id) => f
94                .debug_struct("InboundEvent::Response")
95                .field("request_id", request_id)
96                .finish(),
97            Event::Discard(request_id) => f
98                .debug_struct("InboundEvent::Discard")
99                .field("request_id", request_id)
100                .finish(),
101            Event::Timeout(request_id) => f
102                .debug_struct("InboundEvent::Timeout")
103                .field("request_id", request_id)
104                .finish(),
105        }
106    }
107}
108
109impl<TCodec> ConnectionHandler for Handler<TCodec>
110where
111    TCodec: Codec + Send + 'static,
112{
113    type Action = Infallible;
114    type Event = Event<TCodec>;
115
116    fn handle_action(&mut self, _action: Self::Action) {
117        unreachable!("Request handler does not support actions");
118    }
119
120    fn poll_close(&mut self, _: &mut Context<'_>) -> Poll<Option<Self::Event>> {
121        Poll::Ready(None)
122    }
123
124    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ConnectionHandlerEvent<Self::Event>> {
125        match self.requesting.poll_unpin(cx) {
126            Poll::Ready((_, Ok(Ok(event)))) => {
127                return Poll::Ready(ConnectionHandlerEvent::Notify(event));
128            }
129            Poll::Ready((request_id, Ok(Err(error)))) => {
130                return Poll::Ready(ConnectionHandlerEvent::Notify(Event::Error {
131                    request_id,
132                    error,
133                }));
134            }
135            Poll::Ready((request_id, Err(_))) => {
136                return Poll::Ready(ConnectionHandlerEvent::Notify(Event::Timeout(request_id)));
137            }
138            Poll::Pending => {}
139        }
140
141        match self.receiver.poll_next_unpin(cx) {
142            Poll::Ready(Some((request_id, request, sender))) => {
143                return Poll::Ready(ConnectionHandlerEvent::Notify(Event::Request {
144                    request_id,
145                    request,
146                    sender,
147                }));
148            }
149            Poll::Ready(None) | Poll::Pending => {}
150        }
151
152        Poll::Pending
153    }
154}
155
156impl<TCodec> InboundStreamHandler for Handler<TCodec>
157where
158    TCodec: Codec + Clone + Send + 'static,
159{
160    type InboundUpgrade = Upgrade<TCodec::Protocol>;
161    type InboundUserData = ();
162
163    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundUpgrade, Self::InboundUserData> {
164        SubstreamProtocol::new(
165            Upgrade {
166                protocols: self.protocols.clone(),
167            },
168            (),
169        )
170    }
171
172    fn on_fully_negotiated(
173        &mut self,
174        _user_data: Self::InboundUserData,
175        (mut stream, protocol): <Self::InboundUpgrade as InboundUpgradeSend>::Output,
176    ) {
177        let mut codec = self.codec.clone();
178        let request_id = RequestId::next();
179        let mut sender = self.sender.clone();
180        let fut = async move {
181            let (response_sender, response_receiver) = oneshot::channel();
182            let request = codec.read_request(&protocol, &mut stream).await?;
183            sender
184                .send((request_id, request, response_sender))
185                .await
186                .expect("Request handler sender should not be closed");
187            drop(sender);
188            if let Ok(response) = response_receiver.await {
189                codec
190                    .write_response(&protocol, &mut stream, response)
191                    .await?;
192                stream.close().await?;
193                return Ok(Event::Response(request_id));
194            } else {
195                stream.close().await?;
196                return Ok(Event::Discard(request_id));
197            }
198        };
199        match self.requesting.try_push(request_id, fut.boxed()) {
200            Ok(()) => {}
201            Err(_) => {
202                tracing::warn!("Request handler is overloaded, dropping request");
203            }
204        }
205    }
206
207    fn on_upgrade_error(
208        &mut self,
209        _user_data: Self::InboundUserData,
210        _error: <Self::InboundUpgrade as InboundUpgradeSend>::Error,
211    ) {
212        unreachable!("Request handler does not support upgrade errors");
213    }
214}