1#![warn(missing_docs)]
2#![doc = include_str!("../README.md")]
3
4pub mod error;
6use error::{ConnClose, NotifyError, ReceiverClosed};
7pub use web_socket;
8
9use std::{
10 collections::HashMap,
11 future::Future,
12 io,
13 ops::ControlFlow,
14 sync::{Arc, Mutex},
15 task::{Context, Poll},
16};
17use tokio::{
18 io::{AsyncRead, AsyncWrite},
19 sync::mpsc::Sender,
20};
21use web_socket::{DataType, Event, Stream, WebSocket};
22
/// Boxed, thread-safe error used internally for frame-parsing failures.
pub(crate) type DynErr = Box<dyn std::error::Error + Sync + Send + 'static>;
24
25type Resetter = Arc<Mutex<HashMap<u32, ResetShared>>>;
26
27pub struct SocketIo {
37 ws: WebSocket<Box<dyn AsyncRead + Send + Unpin + 'static>>,
38 tx: Sender<Reply>,
39 resetter: Resetter,
40}
41
/// An outgoing frame queued for the background writer task.
enum Reply {
    /// An already-encoded message, written as a WebSocket data frame.
    Response(Box<[u8]>),
    /// Payload of a received ping, echoed back as a pong.
    Ping(Box<[u8]>),
}
46
47pub enum Procedure {
49 Call(Request, Response, AbortController),
51
52 Notify(Request),
54}
55
56#[derive(Clone)]
58pub struct Notifier {
59 tx: Sender<Reply>,
60}
61
62async fn notify(tx: &Sender<Reply>, name: &str, data: &[u8]) -> Result<(), NotifyError> {
63 let event_name = name.as_bytes();
64 let event_name_len: u8 = event_name
65 .len()
66 .try_into()
67 .map_err(|_| NotifyError::EventNameTooBig)?;
68
69 let mut buf = Vec::with_capacity(5 + data.len());
70
71 buf.push(1); buf.push(event_name_len);
73 buf.extend_from_slice(event_name);
74 buf.extend_from_slice(data);
75
76 tx.send(Reply::Response(buf.into()))
77 .await
78 .map_err(|_| NotifyError::ReceiverClosed)
79}
80
81impl Notifier {
82 pub async fn notify(&self, name: &str, data: impl AsRef<[u8]>) -> Result<(), NotifyError> {
84 notify(&self.tx, name, data.as_ref()).await
85 }
86}
87
88impl SocketIo {
89 pub fn notifier(&self) -> Notifier {
91 Notifier {
92 tx: self.tx.clone(),
93 }
94 }
95
96 pub async fn notify(&mut self, name: &str, data: impl AsRef<[u8]>) -> Result<(), NotifyError> {
98 notify(&self.tx, name, data.as_ref()).await
99 }
100
101 pub fn new<I, O>(reader: I, writer: O, buffer: usize) -> Self
109 where
110 I: Unpin + AsyncRead + Send + 'static,
111 O: Unpin + AsyncWrite + Send + 'static,
112 {
113 let (tx, mut rx) = tokio::sync::mpsc::channel::<Reply>(buffer);
114 let mut ws_writer = WebSocket::server(writer);
115 tokio::spawn(async move {
116 loop {
117 while let Some(reply) = rx.recv().await {
118 let o = match reply {
119 Reply::Ping(data) => ws_writer.send_pong(data).await,
120 Reply::Response(data) => ws_writer.send(&data[..]).await,
121 };
122 if o.is_err() {
123 break;
124 }
125 }
126 }
127 });
128 Self {
129 ws: WebSocket::server(Box::new(reader)),
130 tx,
131 resetter: Default::default(),
132 }
133 }
134
135 pub async fn recv(&mut self) -> io::Result<Procedure> {
141 let mut buf = Vec::with_capacity(4096);
142 let result = async {
143 loop {
144 match self.ws.recv().await? {
145 Event::Data { ty, data } => match ty {
146 DataType::Complete(_) => {
147 if let ControlFlow::Break(p) = self
148 .into_event(data)
149 .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
150 {
151 return Ok(p);
152 }
153 }
154 DataType::Stream(stream) => {
155 buf.extend_from_slice(&data);
156 if let Stream::End(_) = stream {
157 if let ControlFlow::Break(p) =
158 self.into_event(data).map_err(|err| {
159 io::Error::new(io::ErrorKind::InvalidData, err)
160 })?
161 {
162 return Ok(p);
163 }
164 }
165 }
166 },
167 Event::Ping(data) => {
168 let _ = self.tx.send(Reply::Ping(data)).await;
169 }
170 Event::Pong(_) => {}
171 Event::Error(err) => {
172 return Err(io::Error::new(io::ErrorKind::ConnectionReset, err))
173 }
174 Event::Close { code, reason } => {
175 return Err(io::Error::new(
176 io::ErrorKind::ConnectionAborted,
177 ConnClose { code, reason },
178 ))
179 }
180 }
181 }
182 }
183 .await;
184 if result.is_err() {
185 for (_, reset_inner) in self.resetter.lock().unwrap().drain() {
186 reset_inner.lock().unwrap().reset();
187 }
188 }
189 result
190 }
191
192 fn into_event(&mut self, buf: Box<[u8]>) -> Result<ControlFlow<Procedure>, DynErr> {
193 let reader = &mut &buf[..];
194 let frame_type = get_slice(reader, 1)?[0];
195
196 match frame_type {
197 1 => {
198 let method_len = validate_and_parse_utf8_rpc_name(reader)?;
199 let data_offset = (buf.len() - reader.len()) as u16;
200 Ok(ControlFlow::Break(Procedure::Notify(Request {
201 buf,
202 method_offset: 2,
203 method_len,
204 data_offset,
205 })))
206 }
207 2 => {
208 let id = parse_rpc_id(reader)?;
209 let method_len = validate_and_parse_utf8_rpc_name(reader)?;
210 let data_offset = (buf.len() - reader.len()) as u16;
211
212 let reset = AbortController::new();
213 self.resetter
214 .lock()
215 .unwrap()
216 .insert(id, reset.inner.clone());
217
218 Ok(ControlFlow::Break(Procedure::Call(
219 Request {
220 buf,
221 method_offset: 6,
222 method_len,
223 data_offset,
224 },
225 Response {
226 id,
227 tx: self.tx.clone(),
228 resetter: self.resetter.clone(),
229 },
230 reset,
231 )))
232 }
233 3 => {
234 let id = parse_rpc_id(reader)?;
235 if let Some(reset_inner) = self.resetter.lock().unwrap().remove(&id) {
236 reset_inner.lock().unwrap().reset();
237 }
238 Ok(ControlFlow::Continue(()))
239 }
240 _ => Err("invalid frame".into()),
241 }
242 }
243}
244
/// Shared cancellation state behind an abort controller.
struct ResetInner {
    /// Waker of the task that most recently polled for a reset, if any.
    waker: Option<std::task::Waker>,
    /// Set once the associated call has been reset; it is never cleared.
    is_reset: bool,
}

impl ResetInner {
    /// Builds state that is not reset and has no registered waker.
    fn new() -> Self {
        ResetInner {
            waker: None,
            is_reset: false,
        }
    }

    /// Marks the state as reset and wakes the registered task, if there is one.
    ///
    /// The waker stays registered, so calling this again wakes it again.
    fn reset(&mut self) {
        self.is_reset = true;
        if let Some(ref registered) = self.waker {
            registered.wake_by_ref();
        }
    }
}
267
268type ResetShared = Arc<Mutex<ResetInner>>;
269
270pub struct AbortController {
273 inner: ResetShared,
274}
275
276impl AbortController {
277 pub(crate) fn new() -> Self {
278 Self {
279 inner: Arc::new(Mutex::new(ResetInner::new())),
280 }
281 }
282
283 pub fn poll_reset(&mut self, cx: &mut Context<'_>) -> Poll<()> {
286 let mut inner = self.inner.lock().unwrap();
287 if inner.is_reset {
288 return Poll::Ready(());
289 }
290 match inner.waker.as_mut() {
291 Some(w) => w.clone_from(cx.waker()),
292 None => inner.waker = Some(cx.waker().clone()),
293 }
294 drop(inner);
295 Poll::Pending
296 }
297
298 pub async fn reset(&mut self) {
300 std::future::poll_fn(|cx| self.poll_reset(cx)).await;
301 }
302
303 pub async fn abort_on_reset(mut self, task: impl Future) {
311 let mut task = std::pin::pin!(task);
312 std::future::poll_fn(|cx| {
313 if let Poll::Ready(()) = self.poll_reset(cx) {
314 return Poll::Ready(());
315 }
316 task.as_mut().poll(cx).map(|_| ())
317 })
318 .await;
319 }
320
321 pub fn spawn_and_abort_on_reset<F>(self, task: F) -> tokio::task::JoinHandle<()>
332 where
333 F: Future + Send + 'static,
334 {
335 tokio::task::spawn(self.abort_on_reset(task))
336 }
337}
338
/// The method name and payload of an incoming [`Procedure`].
///
/// Holds the raw frame; `method()` and `data()` return views into it.
#[derive(Debug)]
pub struct Request {
    /// The complete frame as received.
    buf: Box<[u8]>,
    /// Byte offset of the method name within `buf`.
    method_offset: u8,
    /// Length of the method name in bytes.
    method_len: u8,
    /// Byte offset within `buf` at which the payload starts.
    data_offset: u16,
}
347
348pub struct Response {
350 id: u32,
351 tx: Sender<Reply>,
352 resetter: Resetter,
353}
354
355impl Drop for Response {
356 fn drop(&mut self) {
357 self.resetter.lock().unwrap().remove(&self.id);
358 }
359}
360
361impl Response {
363 #[inline]
365 pub fn id(&self) -> u32 {
366 self.id
367 }
368
369 pub async fn send(self, data: impl AsRef<[u8]>) -> Result<(), ReceiverClosed> {
371 let data = data.as_ref();
372 let mut buf = Vec::with_capacity(5 + data.len());
373
374 buf.push(4); buf.extend_from_slice(&self.id.to_be_bytes()); buf.extend_from_slice(data);
377
378 self.tx
379 .send(Reply::Response(buf.into()))
380 .await
381 .map_err(|_| ReceiverClosed)
382 }
383}
384
385impl Request {
386 #[inline]
388 pub fn method(&self) -> &str {
389 unsafe {
390 let offset = self.method_offset as usize;
391 let length = self.method_len as usize;
392 std::str::from_utf8_unchecked(&self.buf.get_unchecked(offset..(offset + length)))
393 }
394 }
395
396 #[inline]
398 pub fn data(&self) -> &[u8] {
399 &self.buf[self.data_offset.into()..]
400 }
401}
402
403fn parse_rpc_id(reader: &mut &[u8]) -> Result<u32, &'static str> {
404 let raw_id = get_slice(reader, 4)?;
405 let id = u32::from_be_bytes(raw_id.try_into().unwrap());
406 Ok(id)
407}
408
409fn validate_and_parse_utf8_rpc_name(reader: &mut &[u8]) -> Result<u8, DynErr> {
410 let method_len = get_slice(reader, 1)?[0];
411 std::str::from_utf8(get_slice(reader, method_len as usize)?)?;
412 Ok(method_len)
413}
414
/// Splits the first `len` bytes off the front of `reader` and advances it past them.
///
/// Returns the split-off bytes. If fewer than `len` bytes remain, it returns
/// `Err("insufficient bytes")` and leaves `reader` untouched.
fn get_slice<'de>(reader: &mut &'de [u8], len: usize) -> Result<&'de [u8], &'static str> {
    if len > reader.len() {
        return Err("insufficient bytes");
    }
    // Copy out the `&'de [u8]` so the split halves keep the full `'de` lifetime.
    let bytes: &'de [u8] = *reader;
    // The length check above guarantees `split_at` cannot panic, so no
    // `unsafe` unchecked indexing is needed.
    let (head, tail) = bytes.split_at(len);
    *reader = tail;
    Ok(head)
}