xitca_web/handler/types/
websocket.rs

1use core::{
2    cmp::Ordering,
3    convert::Infallible,
4    future::{Future, poll_fn},
5    pin::{Pin, pin},
6    time::Duration,
7};
8
9use std::io;
10
11use futures_core::stream::Stream;
12use http_ws::{
13    HandshakeError, Item, Message as WsMessage, ProtocolError, WsOutput,
14    stream::{RequestStream, WsError},
15};
16use tokio::time::{Instant, sleep};
17use xitca_unsafe_collection::{
18    bytes::BytesStr,
19    futures::{Select, SelectOutput},
20};
21
22use crate::{
23    body::{BodyStream, RequestBody, ResponseBody},
24    bytes::Bytes,
25    context::WebContext,
26    error::{Error, HeaderNotFound},
27    handler::{FromRequest, Responder},
28    http::{
29        StatusCode, WebResponse,
30        header::{CONNECTION, SEC_WEBSOCKET_VERSION, UPGRADE},
31    },
32    service::Service,
33};
34
35pub use http_ws::{ResponseSender, ResponseWeakSender};
36
37/// simplified websocket message type.
38/// for more variant of message please reference [http_ws::Message] type.
39#[derive(Debug, Eq, PartialEq)]
40pub enum Message {
41    Text(BytesStr),
42    Binary(Bytes),
43    Continuation(Item),
44}
45
46type BoxFuture<'a> = Pin<Box<dyn Future<Output = ()> + 'a>>;
47
48type OnMsgCB = Box<dyn for<'a> FnMut(&'a mut ResponseSender, Message) -> BoxFuture<'a>>;
49
50type OnErrCB<E> = Box<dyn FnMut(WsError<E>) -> BoxFuture<'static>>;
51
52type OnCloseCB<B> = Box<dyn for<'a> FnOnce(Pin<&'a mut RequestStream<B>>) -> BoxFuture<'a>>;
53
54pub struct WebSocket<B = RequestBody>
55where
56    B: BodyStream,
57{
58    ws: WsOutput<B>,
59    ping_interval: Duration,
60    max_unanswered_ping: u8,
61    on_msg: OnMsgCB,
62    on_err: OnErrCB<B::Error>,
63    on_close: OnCloseCB<B>,
64}
65
66impl<B> WebSocket<B>
67where
68    B: BodyStream,
69{
70    fn new(ws: WsOutput<B>) -> Self {
71        #[cold]
72        #[inline(never)]
73        fn boxed_future() -> BoxFuture<'static> {
74            Box::pin(async {})
75        }
76
77        Self {
78            ws,
79            ping_interval: Duration::from_secs(15),
80            max_unanswered_ping: 3,
81            on_msg: Box::new(|_, _| boxed_future()),
82            on_err: Box::new(|_| boxed_future()),
83            on_close: Box::new(|_| boxed_future()),
84        }
85    }
86
87    /// Set interval duration of server side ping message to client.
88    pub fn set_ping_interval(&mut self, dur: Duration) -> &mut Self {
89        self.ping_interval = dur;
90        self
91    }
92
93    /// Set max number of consecutive server side ping messages that are not
94    /// answered by client.
95    ///
96    /// # Panic:
97    /// when 0 is passed as argument.
98    pub fn set_max_unanswered_ping(&mut self, size: u8) -> &mut Self {
99        assert!(size > 0, "max_unanswered_ping MUST be none 0");
100        self.max_unanswered_ping = size;
101        self
102    }
103
104    /// Get a reference of Websocket message sender.
105    /// Can be used to send message to client.
106    pub fn msg_sender(&self) -> &ResponseSender {
107        &self.ws.2
108    }
109
110    /// Async function that would be called when new message arrived from client.
111    pub fn on_msg<F>(&mut self, func: F) -> &mut Self
112    where
113        F: for<'a> FnMut(&'a mut ResponseSender, Message) -> BoxFuture<'a> + 'static,
114    {
115        self.on_msg = Box::new(func);
116        self
117    }
118
119    /// Async function that would be called when error occurred.
120    pub fn on_err<F, Fut>(&mut self, mut func: F) -> &mut Self
121    where
122        F: FnMut(WsError<B::Error>) -> Fut + 'static,
123        Fut: Future<Output = ()> + 'static,
124    {
125        self.on_err = Box::new(move |e| Box::pin(func(e)));
126        self
127    }
128
129    /// Async function that would be called when closing the websocket connection.
130    pub fn on_close<F, Fut>(&mut self, func: F) -> &mut Self
131    where
132        F: FnOnce(Pin<&mut RequestStream<B>>) -> Fut + 'static,
133        Fut: Future<Output = ()> + 'static,
134    {
135        self.on_close = Box::new(|stream| Box::pin(func(stream)));
136        self
137    }
138}
139
140impl<'r, C, B> Service<WebContext<'r, C, B>> for HandshakeError {
141    type Response = WebResponse;
142    type Error = Infallible;
143
144    async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
145        let e = match self {
146            HandshakeError::NoConnectionUpgrade => HeaderNotFound(CONNECTION),
147            HandshakeError::NoVersionHeader => HeaderNotFound(SEC_WEBSOCKET_VERSION),
148            HandshakeError::NoWebsocketUpgrade => HeaderNotFound(UPGRADE),
149            // TODO: refine error mapping of the remaining branches.
150            _ => return StatusCode::INTERNAL_SERVER_ERROR.call(ctx).await,
151        };
152
153        e.call(ctx).await
154    }
155}
156
157impl<'a, 'r, C, B> FromRequest<'a, WebContext<'r, C, B>> for WebSocket<B>
158where
159    C: 'static,
160    B: BodyStream + Default + 'static,
161{
162    type Type<'b> = WebSocket<B>;
163    type Error = Error;
164
165    #[inline]
166    async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
167        let body = ctx.take_body_ref();
168        let ws = http_ws::ws(ctx.req(), body).map_err(Error::from_service)?;
169        Ok(WebSocket::new(ws))
170    }
171}
172
173impl<'r, C, B> Responder<WebContext<'r, C, B>> for WebSocket<B>
174where
175    B: BodyStream + 'static,
176{
177    type Response = WebResponse;
178    type Error = Infallible;
179
180    async fn respond(self, _: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
181        let Self {
182            ws,
183            ping_interval,
184            max_unanswered_ping,
185            on_msg,
186            on_err,
187            on_close,
188        } = self;
189
190        let (decode, res, tx) = ws;
191
192        tokio::task::spawn_local(spawn_task(
193            ping_interval,
194            max_unanswered_ping,
195            decode,
196            tx,
197            on_msg,
198            on_err,
199            on_close,
200        ));
201
202        Ok(res.map(ResponseBody::box_stream))
203    }
204}
205
206async fn spawn_task<B>(
207    ping_interval: Duration,
208    max_unanswered_ping: u8,
209    decode: RequestStream<B>,
210    mut tx: ResponseSender,
211    mut on_msg: OnMsgCB,
212    mut on_err: OnErrCB<B::Error>,
213    on_close: OnCloseCB<B>,
214) where
215    B: BodyStream,
216{
217    let on_msg = &mut *on_msg;
218    let on_err = &mut *on_err;
219
220    let mut decode = pin!(decode);
221
222    let spawn_inner = async {
223        let mut sleep = pin!(sleep(ping_interval));
224
225        let mut un_answered_ping = 0u8;
226
227        loop {
228            match poll_fn(|cx| decode.as_mut().poll_next(cx)).select(sleep.as_mut()).await {
229                SelectOutput::A(Some(Ok(msg))) => {
230                    let msg = match msg {
231                        WsMessage::Text(txt) => Message::Text(BytesStr::try_from(txt).unwrap()),
232                        WsMessage::Binary(bin) => Message::Binary(bin),
233                        WsMessage::Continuation(item) => Message::Continuation(item),
234                        WsMessage::Nop => continue,
235                        WsMessage::Pong(_) => {
236                            if let Some(num) = un_answered_ping.checked_sub(1) {
237                                un_answered_ping = num;
238                            }
239                            continue;
240                        }
241                        WsMessage::Ping(ping) => {
242                            tx.send(WsMessage::Pong(ping)).await?;
243                            continue;
244                        }
245                        WsMessage::Close(reason) => {
246                            match tx.send(WsMessage::Close(reason)).await {
247                                // ProtocolError::Closed error means someone already sent close message
248                                // so just ignore it and treat as success.
249                                Ok(_) | Err(ProtocolError::Closed) => return Ok(()),
250                                Err(e) => return Err(e.into()),
251                            }
252                        }
253                    };
254
255                    on_msg(&mut tx, msg).await
256                }
257                SelectOutput::A(Some(Err(e))) => on_err(e).await,
258                SelectOutput::A(None) => return Ok(()),
259                SelectOutput::B(_) => match un_answered_ping.cmp(&max_unanswered_ping) {
260                    Ordering::Less => {
261                        if let Err(e) = tx.send(WsMessage::Ping(Bytes::new())).await {
262                            // continue ping timer when websocket is closed.
263                            // client may be lagging behind and not respond to close message immediately.
264                            if !matches!(e, ProtocolError::Closed) {
265                                return Err(e.into());
266                            }
267                        }
268                        un_answered_ping += 1;
269                        sleep.as_mut().reset(Instant::now() + ping_interval);
270                    }
271                    // on last interval try to send close message to client to inform it connection
272                    // is going away.
273                    Ordering::Equal => match tx.send(WsMessage::Close(None)).await {
274                        Ok(_) => un_answered_ping += 1,
275                        // ProtocolError::Closed error means someone already sent close message
276                        // so just ignore it and end connection right away.
277                        Err(ProtocolError::Closed) => return Ok(()),
278                        Err(e) => return Err(e.into()),
279                    },
280                    // this will only happen when client fail to respond to the close message on last
281                    // interval in time and at this point just closed the connection with an io error.
282                    Ordering::Greater => {
283                        let _ = tx.send_error(io::ErrorKind::UnexpectedEof.into()).await;
284                        return Ok(());
285                    }
286                },
287            }
288        }
289    };
290
291    if let Err(e) = spawn_inner.await {
292        on_err(e).await;
293    }
294
295    on_close(decode).await;
296}