Skip to main content

xitca_web/handler/types/
websocket.rs

1use core::{
2    cmp::Ordering,
3    convert::Infallible,
4    future::{Future, poll_fn},
5    pin::{Pin, pin},
6    time::Duration,
7};
8
9use futures_core::stream::Stream;
10use http_ws::{
11    HandshakeError, Item, Message as WsMessage, ProtocolError, WsOutput,
12    stream::{RequestStream, WsError},
13};
14use tokio::time::{Instant, sleep};
15use xitca_unsafe_collection::{
16    bytes::BytesStr,
17    futures::{Select, SelectOutput},
18};
19
20use crate::{
21    body::{BodyStream, RequestBody, ResponseBody, StreamDataBody},
22    bytes::Bytes,
23    context::WebContext,
24    error::{Error, ErrorStatus, HeaderNotFound},
25    handler::{FromRequest, Responder},
26    http::{
27        Protocol, StatusCode, WebResponse,
28        header::{CONNECTION, SEC_WEBSOCKET_VERSION, UPGRADE},
29    },
30    service::Service,
31};
32
33pub use http_ws::{ResponseSender, ResponseWeakSender};
34
35/// simplified websocket message type.
36/// for more variant of message please reference [http_ws::Message] type.
37#[derive(Debug, Eq, PartialEq)]
38pub enum Message {
39    Text(BytesStr),
40    Binary(Bytes),
41    Continuation(Item),
42}
43
44type BoxFuture<'a> = Pin<Box<dyn Future<Output = ()> + 'a>>;
45
46type OnMsgCB = Box<dyn for<'a> FnMut(&'a mut ResponseSender, Message) -> BoxFuture<'a>>;
47
48type OnErrCB<E> = Box<dyn FnMut(WsError<E>) -> BoxFuture<'static>>;
49
50type OnCloseCB<B> = Box<dyn for<'a> FnOnce(Pin<&'a mut RequestStream<StreamDataBody<B>>>) -> BoxFuture<'a>>;
51
52pub struct WebSocket<B = RequestBody>
53where
54    B: BodyStream,
55{
56    ws: WsOutput<StreamDataBody<B>>,
57    ping_interval: Duration,
58    max_unanswered_ping: u8,
59    on_msg: OnMsgCB,
60    on_err: OnErrCB<B::Error>,
61    on_close: OnCloseCB<B>,
62}
63
64impl<B> WebSocket<B>
65where
66    B: BodyStream,
67{
68    fn new(ws: WsOutput<StreamDataBody<B>>) -> Self {
69        #[cold]
70        #[inline(never)]
71        fn boxed_future() -> BoxFuture<'static> {
72            Box::pin(async {})
73        }
74
75        Self {
76            ws,
77            ping_interval: Duration::from_secs(15),
78            max_unanswered_ping: 3,
79            on_msg: Box::new(|_, _| boxed_future()),
80            on_err: Box::new(|_| boxed_future()),
81            on_close: Box::new(|_| boxed_future()),
82        }
83    }
84
85    /// Set interval duration of server side ping message to client.
86    pub fn set_ping_interval(&mut self, dur: Duration) -> &mut Self {
87        self.ping_interval = dur;
88        self
89    }
90
91    /// Set max number of consecutive server side ping messages that are not
92    /// answered by client.
93    ///
94    /// # Panic:
95    /// when 0 is passed as argument.
96    pub fn set_max_unanswered_ping(&mut self, size: u8) -> &mut Self {
97        assert!(size > 0, "max_unanswered_ping MUST be none 0");
98        self.max_unanswered_ping = size;
99        self
100    }
101
102    /// Get a reference of Websocket message sender.
103    /// Can be used to send message to client.
104    pub fn msg_sender(&self) -> &ResponseSender {
105        &self.ws.2
106    }
107
108    /// Async function that would be called when new message arrived from client.
109    pub fn on_msg<F>(&mut self, func: F) -> &mut Self
110    where
111        F: for<'a> FnMut(&'a mut ResponseSender, Message) -> BoxFuture<'a> + 'static,
112    {
113        self.on_msg = Box::new(func);
114        self
115    }
116
117    /// Async function that would be called when error occurred.
118    pub fn on_err<F, Fut>(&mut self, mut func: F) -> &mut Self
119    where
120        F: FnMut(WsError<B::Error>) -> Fut + 'static,
121        Fut: Future<Output = ()> + 'static,
122    {
123        self.on_err = Box::new(move |e| Box::pin(func(e)));
124        self
125    }
126
127    /// Async function that would be called when closing the websocket connection.
128    pub fn on_close<F, Fut>(&mut self, func: F) -> &mut Self
129    where
130        F: FnOnce(Pin<&mut RequestStream<StreamDataBody<B>>>) -> Fut + 'static,
131        Fut: Future<Output = ()> + 'static,
132    {
133        self.on_close = Box::new(|stream| Box::pin(func(stream)));
134        self
135    }
136}
137
138impl<'r, C, B> Service<WebContext<'r, C, B>> for HandshakeError {
139    type Response = WebResponse;
140    type Error = Infallible;
141
142    async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
143        let e = match self {
144            HandshakeError::NoConnectionUpgrade => HeaderNotFound(CONNECTION),
145            HandshakeError::NoVersionHeader => HeaderNotFound(SEC_WEBSOCKET_VERSION),
146            HandshakeError::NoWebsocketUpgrade => HeaderNotFound(UPGRADE),
147            // TODO: refine error mapping of the remaining branches.
148            _ => return StatusCode::INTERNAL_SERVER_ERROR.call(ctx).await,
149        };
150
151        e.call(ctx).await
152    }
153}
154
155impl<'a, 'r, C, B> FromRequest<'a, WebContext<'r, C, B>> for WebSocket<B>
156where
157    C: 'static,
158    B: BodyStream + Default + 'static,
159{
160    type Type<'b> = WebSocket<B>;
161    type Error = Error;
162
163    #[inline]
164    async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
165        let body = ctx.take_body_ref();
166        if let Some(protocol) = ctx.req().body().protocol() {
167            if protocol != Protocol::WEB_SOCKET {
168                // TODO: strong typed error display protocol mismatch information
169                return Err(ErrorStatus::bad_request().into());
170            }
171        }
172        let ws = http_ws::ws(ctx.req(), StreamDataBody::new(body)).map_err(Error::from_service)?;
173        Ok(WebSocket::new(ws))
174    }
175}
176
177impl<'r, C, B> Responder<WebContext<'r, C, B>> for WebSocket<B>
178where
179    B: BodyStream + 'static,
180{
181    type Response = WebResponse;
182    type Error = Infallible;
183
184    async fn respond(self, _: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
185        let Self {
186            ws,
187            ping_interval,
188            max_unanswered_ping,
189            on_msg,
190            on_err,
191            on_close,
192        } = self;
193
194        let (decode, res, tx) = ws;
195
196        tokio::task::spawn_local(spawn_task(
197            ping_interval,
198            max_unanswered_ping,
199            decode,
200            tx,
201            on_msg,
202            on_err,
203            on_close,
204        ));
205
206        Ok(res.map(|body| ResponseBody::boxed(StreamDataBody::new(body))))
207    }
208}
209
210async fn spawn_task<B>(
211    ping_interval: Duration,
212    max_unanswered_ping: u8,
213    decode: RequestStream<StreamDataBody<B>>,
214    mut tx: ResponseSender,
215    mut on_msg: OnMsgCB,
216    mut on_err: OnErrCB<B::Error>,
217    on_close: OnCloseCB<B>,
218) where
219    B: BodyStream,
220{
221    let on_msg = &mut *on_msg;
222    let on_err = &mut *on_err;
223
224    let mut decode = pin!(decode);
225
226    let spawn_inner = async {
227        let mut sleep = pin!(sleep(ping_interval));
228
229        let mut un_answered_ping = 0u8;
230
231        loop {
232            match poll_fn(|cx| decode.as_mut().poll_next(cx)).select(sleep.as_mut()).await {
233                SelectOutput::A(Some(Ok(msg))) => {
234                    let msg = match msg {
235                        WsMessage::Text(txt) => Message::Text(BytesStr::try_from(txt).unwrap()),
236                        WsMessage::Binary(bin) => Message::Binary(bin),
237                        WsMessage::Continuation(item) => Message::Continuation(item),
238                        WsMessage::Nop => continue,
239                        WsMessage::Pong(_) => {
240                            if let Some(num) = un_answered_ping.checked_sub(1) {
241                                un_answered_ping = num;
242                            }
243                            continue;
244                        }
245                        WsMessage::Ping(ping) => {
246                            tx.pong(ping).await?;
247                            continue;
248                        }
249                        WsMessage::Close(reason) => {
250                            return match tx.close(reason).await {
251                                // close echoed back to remote successfully
252                                Ok(_) => Ok(()),
253                                // close is locally intialized and echo of it has already been received
254                                Err(ProtocolError::SendClosed) => Ok(()),
255                                Err(err) => Err(err.into()),
256                            };
257                        }
258                    };
259
260                    on_msg(&mut tx, msg).await
261                }
262                SelectOutput::A(Some(Err(e))) => return Err(e),
263                // RequestSream never yields None. It's either Close message received or connection broken that mark the end of it
264                SelectOutput::A(None) => unreachable!("RequestStream never yields None"),
265                SelectOutput::B(_) => match un_answered_ping.cmp(&max_unanswered_ping) {
266                    Ordering::Less => {
267                        if let Err(e) = tx.ping(Bytes::new()).await {
268                            // continue ping timer when websocket is closed.
269                            // client may be lagging behind and not respond to close message immediately.
270                            if !matches!(e, ProtocolError::SendClosed) {
271                                return Err(e.into());
272                            }
273                        }
274                        un_answered_ping += 1;
275                        sleep.as_mut().reset(Instant::now() + ping_interval);
276                    }
277                    // ping time out
278                    _ => return Ok(()),
279                },
280            }
281        }
282    };
283
284    if let Err(err) = spawn_inner.await {
285        on_err(err).await;
286    }
287
288    on_close(decode).await;
289}