1use core::{
2 cmp::Ordering,
3 convert::Infallible,
4 future::{Future, poll_fn},
5 pin::{Pin, pin},
6 time::Duration,
7};
8
9use futures_core::stream::Stream;
10use http_ws::{
11 HandshakeError, Item, Message as WsMessage, ProtocolError, WsOutput,
12 stream::{RequestStream, WsError},
13};
14use tokio::time::{Instant, sleep};
15use xitca_unsafe_collection::{
16 bytes::BytesStr,
17 futures::{Select, SelectOutput},
18};
19
20use crate::{
21 body::{BodyStream, RequestBody, ResponseBody, StreamDataBody},
22 bytes::Bytes,
23 context::WebContext,
24 error::{Error, ErrorStatus, HeaderNotFound},
25 handler::{FromRequest, Responder},
26 http::{
27 Protocol, StatusCode, WebResponse,
28 header::{CONNECTION, SEC_WEBSOCKET_VERSION, UPGRADE},
29 },
30 service::Service,
31};
32
33pub use http_ws::{ResponseSender, ResponseWeakSender};
34
35#[derive(Debug, Eq, PartialEq)]
38pub enum Message {
39 Text(BytesStr),
40 Binary(Bytes),
41 Continuation(Item),
42}
43
/// Boxed, type-erased future returned by all websocket callbacks.
type BoxFuture<'a> = Pin<Box<dyn Future<Output = ()> + 'a>>;

/// Callback invoked for every data [`Message`]. The returned future may borrow the
/// [`ResponseSender`] for its whole lifetime, which is why the bound is higher-ranked.
type OnMsgCB = Box<dyn for<'a> FnMut(&'a mut ResponseSender, Message) -> BoxFuture<'a>>;

/// Callback invoked with the error that ended the connection loop.
type OnErrCB<E> = Box<dyn FnMut(WsError<E>) -> BoxFuture<'static>>;

/// Callback invoked exactly once (it is `FnOnce`) with the pinned request stream after the
/// connection loop has finished.
type OnCloseCB<B> = Box<dyn for<'a> FnOnce(Pin<&'a mut RequestStream<StreamDataBody<B>>>) -> BoxFuture<'a>>;
51
/// Websocket extractor and responder.
///
/// Extract it from a request to perform the handshake, configure ping behaviour and callbacks
/// through its setters, then return it from the handler. Responding spawns a local task that
/// drives the connection and returns the handshake response.
pub struct WebSocket<B = RequestBody>
where
    B: BodyStream,
{
    /// Handshake output from `http_ws::ws`: (request decode stream, response, response sender).
    ws: WsOutput<StreamDataBody<B>>,
    /// Delay between server-initiated pings.
    ping_interval: Duration,
    /// Number of outstanding pings at which the connection is ended.
    max_unanswered_ping: u8,
    /// Handler for data messages.
    on_msg: OnMsgCB,
    /// Handler for the error that ends the connection loop.
    on_err: OnErrCB<B::Error>,
    /// Handler run once after the connection loop finishes.
    on_close: OnCloseCB<B>,
}
63
64impl<B> WebSocket<B>
65where
66 B: BodyStream,
67{
68 fn new(ws: WsOutput<StreamDataBody<B>>) -> Self {
69 #[cold]
70 #[inline(never)]
71 fn boxed_future() -> BoxFuture<'static> {
72 Box::pin(async {})
73 }
74
75 Self {
76 ws,
77 ping_interval: Duration::from_secs(15),
78 max_unanswered_ping: 3,
79 on_msg: Box::new(|_, _| boxed_future()),
80 on_err: Box::new(|_| boxed_future()),
81 on_close: Box::new(|_| boxed_future()),
82 }
83 }
84
85 pub fn set_ping_interval(&mut self, dur: Duration) -> &mut Self {
87 self.ping_interval = dur;
88 self
89 }
90
91 pub fn set_max_unanswered_ping(&mut self, size: u8) -> &mut Self {
97 assert!(size > 0, "max_unanswered_ping MUST be none 0");
98 self.max_unanswered_ping = size;
99 self
100 }
101
102 pub fn msg_sender(&self) -> &ResponseSender {
105 &self.ws.2
106 }
107
108 pub fn on_msg<F>(&mut self, func: F) -> &mut Self
110 where
111 F: for<'a> FnMut(&'a mut ResponseSender, Message) -> BoxFuture<'a> + 'static,
112 {
113 self.on_msg = Box::new(func);
114 self
115 }
116
117 pub fn on_err<F, Fut>(&mut self, mut func: F) -> &mut Self
119 where
120 F: FnMut(WsError<B::Error>) -> Fut + 'static,
121 Fut: Future<Output = ()> + 'static,
122 {
123 self.on_err = Box::new(move |e| Box::pin(func(e)));
124 self
125 }
126
127 pub fn on_close<F, Fut>(&mut self, func: F) -> &mut Self
129 where
130 F: FnOnce(Pin<&mut RequestStream<StreamDataBody<B>>>) -> Fut + 'static,
131 Fut: Future<Output = ()> + 'static,
132 {
133 self.on_close = Box::new(|stream| Box::pin(func(stream)));
134 self
135 }
136}
137
138impl<'r, C, B> Service<WebContext<'r, C, B>> for HandshakeError {
139 type Response = WebResponse;
140 type Error = Infallible;
141
142 async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
143 let e = match self {
144 HandshakeError::NoConnectionUpgrade => HeaderNotFound(CONNECTION),
145 HandshakeError::NoVersionHeader => HeaderNotFound(SEC_WEBSOCKET_VERSION),
146 HandshakeError::NoWebsocketUpgrade => HeaderNotFound(UPGRADE),
147 _ => return StatusCode::INTERNAL_SERVER_ERROR.call(ctx).await,
149 };
150
151 e.call(ctx).await
152 }
153}
154
155impl<'a, 'r, C, B> FromRequest<'a, WebContext<'r, C, B>> for WebSocket<B>
156where
157 C: 'static,
158 B: BodyStream + Default + 'static,
159{
160 type Type<'b> = WebSocket<B>;
161 type Error = Error;
162
163 #[inline]
164 async fn from_request(ctx: &'a WebContext<'r, C, B>) -> Result<Self, Self::Error> {
165 let body = ctx.take_body_ref();
166 if let Some(protocol) = ctx.req().body().protocol() {
167 if protocol != Protocol::WEB_SOCKET {
168 return Err(ErrorStatus::bad_request().into());
170 }
171 }
172 let ws = http_ws::ws(ctx.req(), StreamDataBody::new(body)).map_err(Error::from_service)?;
173 Ok(WebSocket::new(ws))
174 }
175}
176
177impl<'r, C, B> Responder<WebContext<'r, C, B>> for WebSocket<B>
178where
179 B: BodyStream + 'static,
180{
181 type Response = WebResponse;
182 type Error = Infallible;
183
184 async fn respond(self, _: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
185 let Self {
186 ws,
187 ping_interval,
188 max_unanswered_ping,
189 on_msg,
190 on_err,
191 on_close,
192 } = self;
193
194 let (decode, res, tx) = ws;
195
196 tokio::task::spawn_local(spawn_task(
197 ping_interval,
198 max_unanswered_ping,
199 decode,
200 tx,
201 on_msg,
202 on_err,
203 on_close,
204 ));
205
206 Ok(res.map(|body| ResponseBody::boxed(StreamDataBody::new(body))))
207 }
208}
209
210async fn spawn_task<B>(
211 ping_interval: Duration,
212 max_unanswered_ping: u8,
213 decode: RequestStream<StreamDataBody<B>>,
214 mut tx: ResponseSender,
215 mut on_msg: OnMsgCB,
216 mut on_err: OnErrCB<B::Error>,
217 on_close: OnCloseCB<B>,
218) where
219 B: BodyStream,
220{
221 let on_msg = &mut *on_msg;
222 let on_err = &mut *on_err;
223
224 let mut decode = pin!(decode);
225
226 let spawn_inner = async {
227 let mut sleep = pin!(sleep(ping_interval));
228
229 let mut un_answered_ping = 0u8;
230
231 loop {
232 match poll_fn(|cx| decode.as_mut().poll_next(cx)).select(sleep.as_mut()).await {
233 SelectOutput::A(Some(Ok(msg))) => {
234 let msg = match msg {
235 WsMessage::Text(txt) => Message::Text(BytesStr::try_from(txt).unwrap()),
236 WsMessage::Binary(bin) => Message::Binary(bin),
237 WsMessage::Continuation(item) => Message::Continuation(item),
238 WsMessage::Nop => continue,
239 WsMessage::Pong(_) => {
240 if let Some(num) = un_answered_ping.checked_sub(1) {
241 un_answered_ping = num;
242 }
243 continue;
244 }
245 WsMessage::Ping(ping) => {
246 tx.pong(ping).await?;
247 continue;
248 }
249 WsMessage::Close(reason) => {
250 return match tx.close(reason).await {
251 Ok(_) => Ok(()),
253 Err(ProtocolError::SendClosed) => Ok(()),
255 Err(err) => Err(err.into()),
256 };
257 }
258 };
259
260 on_msg(&mut tx, msg).await
261 }
262 SelectOutput::A(Some(Err(e))) => return Err(e),
263 SelectOutput::A(None) => unreachable!("RequestStream never yields None"),
265 SelectOutput::B(_) => match un_answered_ping.cmp(&max_unanswered_ping) {
266 Ordering::Less => {
267 if let Err(e) = tx.ping(Bytes::new()).await {
268 if !matches!(e, ProtocolError::SendClosed) {
271 return Err(e.into());
272 }
273 }
274 un_answered_ping += 1;
275 sleep.as_mut().reset(Instant::now() + ping_interval);
276 }
277 _ => return Ok(()),
279 },
280 }
281 }
282 };
283
284 if let Err(err) = spawn_inner.await {
285 on_err(err).await;
286 }
287
288 on_close(decode).await;
289}