1use core::{
2 cmp::Ordering,
3 convert::Infallible,
4 future::{Future, poll_fn},
5 pin::{Pin, pin},
6 time::Duration,
7};
8
9use std::io;
10
11use futures_core::stream::Stream;
12use http_ws::{
13 HandshakeError, Item, Message as WsMessage, ProtocolError, WsOutput,
14 stream::{RequestStream, WsError},
15};
16use tokio::time::{Instant, sleep};
17use xitca_unsafe_collection::{
18 bytes::BytesStr,
19 futures::{Select, SelectOutput},
20};
21
22use crate::{
23 body::{BodyStream, RequestBody, ResponseBody},
24 bytes::Bytes,
25 context::WebContext,
26 error::{Error, HeaderNotFound},
27 handler::{FromRequest, Responder},
28 http::{
29 StatusCode, WebResponse,
30 header::{CONNECTION, SEC_WEBSOCKET_VERSION, UPGRADE},
31 },
32 service::Service,
33};
34
35pub use http_ws::{ResponseSender, ResponseWeakSender};
36
37#[derive(Debug, Eq, PartialEq)]
40pub enum Message {
41 Text(BytesStr),
42 Binary(Bytes),
43 Continuation(Item),
44}
45
46type BoxFuture<'a> = Pin<Box<dyn Future<Output = ()> + 'a>>;
47
48type OnMsgCB = Box<dyn for<'a> FnMut(&'a mut ResponseSender, Message) -> BoxFuture<'a>>;
49
50type OnErrCB<E> = Box<dyn FnMut(WsError<E>) -> BoxFuture<'static>>;
51
52type OnCloseCB<B> = Box<dyn for<'a> FnOnce(Pin<&'a mut RequestStream<B>>) -> BoxFuture<'a>>;
53
54pub struct WebSocket<B = RequestBody>
55where
56 B: BodyStream,
57{
58 ws: WsOutput<B>,
59 ping_interval: Duration,
60 max_unanswered_ping: u8,
61 on_msg: OnMsgCB,
62 on_err: OnErrCB<B::Error>,
63 on_close: OnCloseCB<B>,
64}
65
66impl<B> WebSocket<B>
67where
68 B: BodyStream,
69{
70 fn new(ws: WsOutput<B>) -> Self {
71 #[cold]
72 #[inline(never)]
73 fn boxed_future() -> BoxFuture<'static> {
74 Box::pin(async {})
75 }
76
77 Self {
78 ws,
79 ping_interval: Duration::from_secs(15),
80 max_unanswered_ping: 3,
81 on_msg: Box::new(|_, _| boxed_future()),
82 on_err: Box::new(|_| boxed_future()),
83 on_close: Box::new(|_| boxed_future()),
84 }
85 }
86
87 pub fn set_ping_interval(&mut self, dur: Duration) -> &mut Self {
89 self.ping_interval = dur;
90 self
91 }
92
93 pub fn set_max_unanswered_ping(&mut self, size: u8) -> &mut Self {
99 assert!(size > 0, "max_unanswered_ping MUST be none 0");
100 self.max_unanswered_ping = size;
101 self
102 }
103
104 pub fn msg_sender(&self) -> &ResponseSender {
107 &self.ws.2
108 }
109
110 pub fn on_msg<F>(&mut self, func: F) -> &mut Self
112 where
113 F: for<'a> FnMut(&'a mut ResponseSender, Message) -> BoxFuture<'a> + 'static,
114 {
115 self.on_msg = Box::new(func);
116 self
117 }
118
119 pub fn on_err<F, Fut>(&mut self, mut func: F) -> &mut Self
121 where
122 F: FnMut(WsError<B::Error>) -> Fut + 'static,
123 Fut: Future<Output = ()> + 'static,
124 {
125 self.on_err = Box::new(move |e| Box::pin(func(e)));
126 self
127 }
128
129 pub fn on_close<F, Fut>(&mut self, func: F) -> &mut Self
131 where
132 F: FnOnce(Pin<&mut RequestStream<B>>) -> Fut + 'static,
133 Fut: Future<Output = ()> + 'static,
134 {
135 self.on_close = Box::new(|stream| Box::pin(func(stream)));
136 self
137 }
138}
139
140impl<'r, C, B> Service<WebContext<'r, C, B>> for HandshakeError {
141 type Response = WebResponse;
142 type Error = Infallible;
143
144 async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
145 let e = match self {
146 HandshakeError::NoConnectionUpgrade => HeaderNotFound(CONNECTION),
147 HandshakeError::NoVersionHeader => HeaderNotFound(SEC_WEBSOCKET_VERSION),
148 HandshakeError::NoWebsocketUpgrade => HeaderNotFound(UPGRADE),
149 _ => return StatusCode::INTERNAL_SERVER_ERROR.call(ctx).await,
151 };
152
153 e.call(ctx).await
154 }
155}
156
157impl<'a, 'r, C, B> FromRequest<'a, WebContext<'r, C, B>> for WebSocket<B>
158where
159 C: 'static,
160 B: BodyStream + Default + 'static,
161{
162 type Type<'b> = WebSocket<B>;
163 type Error = Error;
164
165 #[inline]
166 async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
167 let body = ctx.take_body_ref();
168 let ws = http_ws::ws(ctx.req(), body).map_err(Error::from_service)?;
169 Ok(WebSocket::new(ws))
170 }
171}
172
173impl<'r, C, B> Responder<WebContext<'r, C, B>> for WebSocket<B>
174where
175 B: BodyStream + 'static,
176{
177 type Response = WebResponse;
178 type Error = Infallible;
179
180 async fn respond(self, _: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
181 let Self {
182 ws,
183 ping_interval,
184 max_unanswered_ping,
185 on_msg,
186 on_err,
187 on_close,
188 } = self;
189
190 let (decode, res, tx) = ws;
191
192 tokio::task::spawn_local(spawn_task(
193 ping_interval,
194 max_unanswered_ping,
195 decode,
196 tx,
197 on_msg,
198 on_err,
199 on_close,
200 ));
201
202 Ok(res.map(ResponseBody::box_stream))
203 }
204}
205
206async fn spawn_task<B>(
207 ping_interval: Duration,
208 max_unanswered_ping: u8,
209 decode: RequestStream<B>,
210 mut tx: ResponseSender,
211 mut on_msg: OnMsgCB,
212 mut on_err: OnErrCB<B::Error>,
213 on_close: OnCloseCB<B>,
214) where
215 B: BodyStream,
216{
217 let on_msg = &mut *on_msg;
218 let on_err = &mut *on_err;
219
220 let mut decode = pin!(decode);
221
222 let spawn_inner = async {
223 let mut sleep = pin!(sleep(ping_interval));
224
225 let mut un_answered_ping = 0u8;
226
227 loop {
228 match poll_fn(|cx| decode.as_mut().poll_next(cx)).select(sleep.as_mut()).await {
229 SelectOutput::A(Some(Ok(msg))) => {
230 let msg = match msg {
231 WsMessage::Text(txt) => Message::Text(BytesStr::try_from(txt).unwrap()),
232 WsMessage::Binary(bin) => Message::Binary(bin),
233 WsMessage::Continuation(item) => Message::Continuation(item),
234 WsMessage::Nop => continue,
235 WsMessage::Pong(_) => {
236 if let Some(num) = un_answered_ping.checked_sub(1) {
237 un_answered_ping = num;
238 }
239 continue;
240 }
241 WsMessage::Ping(ping) => {
242 tx.send(WsMessage::Pong(ping)).await?;
243 continue;
244 }
245 WsMessage::Close(reason) => {
246 match tx.send(WsMessage::Close(reason)).await {
247 Ok(_) | Err(ProtocolError::Closed) => return Ok(()),
250 Err(e) => return Err(e.into()),
251 }
252 }
253 };
254
255 on_msg(&mut tx, msg).await
256 }
257 SelectOutput::A(Some(Err(e))) => on_err(e).await,
258 SelectOutput::A(None) => return Ok(()),
259 SelectOutput::B(_) => match un_answered_ping.cmp(&max_unanswered_ping) {
260 Ordering::Less => {
261 if let Err(e) = tx.send(WsMessage::Ping(Bytes::new())).await {
262 if !matches!(e, ProtocolError::Closed) {
265 return Err(e.into());
266 }
267 }
268 un_answered_ping += 1;
269 sleep.as_mut().reset(Instant::now() + ping_interval);
270 }
271 Ordering::Equal => match tx.send(WsMessage::Close(None)).await {
274 Ok(_) => un_answered_ping += 1,
275 Err(ProtocolError::Closed) => return Ok(()),
278 Err(e) => return Err(e.into()),
279 },
280 Ordering::Greater => {
283 let _ = tx.send_error(io::ErrorKind::UnexpectedEof.into()).await;
284 return Ok(());
285 }
286 },
287 }
288 }
289 };
290
291 if let Err(e) = spawn_inner.await {
292 on_err(e).await;
293 }
294
295 on_close(decode).await;
296}