web_socket_io/
lib.rs

1#![warn(missing_docs)]
2#![doc = include_str!("../README.md")]
3
4/// Error types
5pub mod error;
6use error::{ConnClose, NotifyError, ReceiverClosed};
7pub use web_socket;
8
9use std::{
10    collections::HashMap,
11    future::Future,
12    io,
13    ops::ControlFlow,
14    sync::{Arc, Mutex},
15    task::{Context, Poll},
16};
17use tokio::{
18    io::{AsyncRead, AsyncWrite},
19    sync::mpsc::Sender,
20};
21use web_socket::{DataType, Event, Stream, WebSocket};
22
/// Boxed, thread-safe error type returned by the internal frame parsers.
pub(crate) type DynErr = Box<dyn std::error::Error + Sync + Send + 'static>;
24
25type Resetter = Arc<Mutex<HashMap<u32, ResetShared>>>;
26
27/// `SocketIo` manages WebSocket communication for handling RPC events.
28/// 
29/// It utilizes WebSocket  technology to facilitate real-time communication, providing mechanisms for sending requests 
30/// and receiving responses.
31/// 
32/// Additionally, it supports RPC cancellation and timeout functionality, 
33/// allowing for better control over ongoing operations.
34/// 
35/// The struct efficiently manages concurrent RPC events and notifies clients of relevant occurrences.
36pub struct SocketIo {
37    ws: WebSocket<Box<dyn AsyncRead + Send + Unpin + 'static>>,
38    tx: Sender<Reply>,
39    resetter: Resetter,
40}
41
/// Outgoing message queued for the background writer task.
enum Reply {
    /// Payload of a received ping, echoed back to the peer as a pong.
    Ping(Box<[u8]>),
    /// Encoded frame (rpc response or notification) sent as a data message.
    Response(Box<[u8]>),
}
46
47/// `Procedure` represents an RPC (Remote Procedure Call) or notification in the system.
48pub enum Procedure {
49    /// `Call` represents a RPC event
50    Call(Request, Response, AbortController),
51
52    /// `Notify` represents a one-way notification that includes only a request.
53    Notify(Request),
54}
55
56/// `Notifier` is used to send notifications, Sends notifications where no response expected.
57#[derive(Clone)]
58pub struct Notifier {
59    tx: Sender<Reply>,
60}
61
62async fn notify(tx: &Sender<Reply>, name: &str, data: &[u8]) -> Result<(), NotifyError> {
63    let event_name = name.as_bytes();
64    let event_name_len: u8 = event_name
65        .len()
66        .try_into()
67        .map_err(|_| NotifyError::EventNameTooBig)?;
68
69    let mut buf = Vec::with_capacity(5 + data.len());
70
71    buf.push(1); // frame type
72    buf.push(event_name_len);
73    buf.extend_from_slice(event_name);
74    buf.extend_from_slice(data);
75
76    tx.send(Reply::Response(buf.into()))
77        .await
78        .map_err(|_| NotifyError::ReceiverClosed)
79}
80
81impl Notifier {
82    /// Sends a notification with the given name and data.
83    pub async fn notify(&self, name: &str, data: impl AsRef<[u8]>) -> Result<(), NotifyError> {
84        notify(&self.tx, name, data.as_ref()).await
85    }
86}
87
88impl SocketIo {
89    /// Returns a `Notifier` for sending notifications.
90    pub fn notifier(&self) -> Notifier {
91        Notifier {
92            tx: self.tx.clone(),
93        }
94    }
95
96    /// Sends a notification with the given name and data.
97    pub async fn notify(&mut self, name: &str, data: impl AsRef<[u8]>) -> Result<(), NotifyError> {
98        notify(&self.tx, name, data.as_ref()).await
99    }
100
101    /// Creates a new `SocketIo` instance with the specified reader, writer, and buffer size.
102    ///
103    /// # Arguments
104    ///
105    /// * `reader` - The source for reading data.
106    /// * `writer` - The destination for writing data.
107    /// * `buffer` - The size of the buffer for the channel.
108    pub fn new<I, O>(reader: I, writer: O, buffer: usize) -> Self
109    where
110        I: Unpin + AsyncRead + Send + 'static,
111        O: Unpin + AsyncWrite + Send + 'static,
112    {
113        let (tx, mut rx) = tokio::sync::mpsc::channel::<Reply>(buffer);
114        let mut ws_writer = WebSocket::server(writer);
115        tokio::spawn(async move {
116            loop {
117                while let Some(reply) = rx.recv().await {
118                    let o = match reply {
119                        Reply::Ping(data) => ws_writer.send_pong(data).await,
120                        Reply::Response(data) => ws_writer.send(&data[..]).await,
121                    };
122                    if o.is_err() {
123                        break;
124                    }
125                }
126            }
127        });
128        Self {
129            ws: WebSocket::server(Box::new(reader)),
130            tx,
131            resetter: Default::default(),
132        }
133    }
134
135    /// Receives the next `Procedure` (either a rpc or notification).
136    ///
137    /// ## Connection State
138    /// - Returns `io::ErrorKind::ConnectionReset` when an error event occurs.
139    /// - Returns `io::ErrorKind::ConnectionAborted` when a close event is received.
140    pub async fn recv(&mut self) -> io::Result<Procedure> {
141        let mut buf = Vec::with_capacity(4096);
142        let result = async {
143            loop {
144                match self.ws.recv().await? {
145                    Event::Data { ty, data } => match ty {
146                        DataType::Complete(_) => {
147                            if let ControlFlow::Break(p) = self
148                                .into_event(data)
149                                .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
150                            {
151                                return Ok(p);
152                            }
153                        }
154                        DataType::Stream(stream) => {
155                            buf.extend_from_slice(&data);
156                            if let Stream::End(_) = stream {
157                                if let ControlFlow::Break(p) =
158                                    self.into_event(data).map_err(|err| {
159                                        io::Error::new(io::ErrorKind::InvalidData, err)
160                                    })?
161                                {
162                                    return Ok(p);
163                                }
164                            }
165                        }
166                    },
167                    Event::Ping(data) => {
168                        let _ = self.tx.send(Reply::Ping(data)).await;
169                    }
170                    Event::Pong(_) => {}
171                    Event::Error(err) => {
172                        return Err(io::Error::new(io::ErrorKind::ConnectionReset, err))
173                    }
174                    Event::Close { code, reason } => {
175                        return Err(io::Error::new(
176                            io::ErrorKind::ConnectionAborted,
177                            ConnClose { code, reason },
178                        ))
179                    }
180                }
181            }
182        }
183        .await;
184        if result.is_err() {
185            for (_, reset_inner) in self.resetter.lock().unwrap().drain() {
186                reset_inner.lock().unwrap().reset();
187            }
188        }
189        result
190    }
191
192    fn into_event(&mut self, buf: Box<[u8]>) -> Result<ControlFlow<Procedure>, DynErr> {
193        let reader = &mut &buf[..];
194        let frame_type = get_slice(reader, 1)?[0];
195
196        match frame_type {
197            1 => {
198                let method_len = validate_and_parse_utf8_rpc_name(reader)?;
199                let data_offset = (buf.len() - reader.len()) as u16;
200                Ok(ControlFlow::Break(Procedure::Notify(Request {
201                    buf,
202                    method_offset: 2,
203                    method_len,
204                    data_offset,
205                })))
206            }
207            2 => {
208                let id = parse_rpc_id(reader)?;
209                let method_len = validate_and_parse_utf8_rpc_name(reader)?;
210                let data_offset = (buf.len() - reader.len()) as u16;
211
212                let reset = AbortController::new();
213                self.resetter
214                    .lock()
215                    .unwrap()
216                    .insert(id, reset.inner.clone());
217
218                Ok(ControlFlow::Break(Procedure::Call(
219                    Request {
220                        buf,
221                        method_offset: 6,
222                        method_len,
223                        data_offset,
224                    },
225                    Response {
226                        id,
227                        tx: self.tx.clone(),
228                        resetter: self.resetter.clone(),
229                    },
230                    reset,
231                )))
232            }
233            3 => {
234                let id = parse_rpc_id(reader)?;
235                if let Some(reset_inner) = self.resetter.lock().unwrap().remove(&id) {
236                    reset_inner.lock().unwrap().reset();
237                }
238                Ok(ControlFlow::Continue(()))
239            }
240            _ => Err("invalid frame".into()),
241        }
242    }
243}
244
/// Shared reset (cancellation) state of a single rpc call.
struct ResetInner {
    /// Set by `reset`; no code here clears it again.
    is_reset: bool,
    /// Waker of the task currently waiting for the reset, if any.
    // TODO: fold `is_reset` and "has waker" into one `AtomicUsize` state,
    // possibly guarded by a spin lock built on that state instead of a `Mutex`.
    waker: Option<std::task::Waker>,
}

impl ResetInner {
    /// Creates a state that is not reset and has no registered waker.
    fn new() -> Self {
        ResetInner {
            waker: None,
            is_reset: false,
        }
    }

    /// Marks the call as reset and wakes the registered task, if there is one.
    ///
    /// The waker is woken by reference and stays stored.
    fn reset(&mut self) {
        self.is_reset = true;
        self.waker.iter().for_each(std::task::Waker::wake_by_ref);
    }
}

/// Reference-counted, mutex-protected `ResetInner`, shared between the
/// in-flight registry and the call's `AbortController`.
type ResetShared = Arc<Mutex<ResetInner>>;
269
270/// `AbortController` is a controller that allows you to monitor for a stream reset and
271/// cancel an associated asynchronous task if the reset occurs.
272pub struct AbortController {
273    inner: ResetShared,
274}
275
276impl AbortController {
277    pub(crate) fn new() -> Self {
278        Self {
279            inner: Arc::new(Mutex::new(ResetInner::new())),
280        }
281    }
282
283    /// Polls to be notified when the client resets this rpc.
284    /// If the stream has not been reset. This returns `Poll::Pending`
285    pub fn poll_reset(&mut self, cx: &mut Context<'_>) -> Poll<()> {
286        let mut inner = self.inner.lock().unwrap();
287        if inner.is_reset {
288            return Poll::Ready(());
289        }
290        match inner.waker.as_mut() {
291            Some(w) => w.clone_from(cx.waker()),
292            None => inner.waker = Some(cx.waker().clone()),
293        }
294        drop(inner);
295        Poll::Pending
296    }
297
298    /// Awaits the stream reset event.
299    pub async fn reset(&mut self) {
300        std::future::poll_fn(|cx| self.poll_reset(cx)).await;
301    }
302
303    /// Executes a given asynchronous task and aborts it when stream is reset.
304    ///
305    /// ### Example
306    ///
307    /// ```rust
308    /// controller.abort_on_reset(async {  }).await;
309    /// ```
310    pub async fn abort_on_reset(mut self, task: impl Future) {
311        let mut task = std::pin::pin!(task);
312        std::future::poll_fn(|cx| {
313            if let Poll::Ready(()) = self.poll_reset(cx) {
314                return Poll::Ready(());
315            }
316            task.as_mut().poll(cx).map(|_| ())
317        })
318        .await;
319    }
320
321    /// Spawns a new task that will be aborted if the stream is reset.
322    ///
323    /// This function spawns the given task in background, and automatically cancels
324    /// the task if the stream reset event occurs.
325    ///
326    /// ### Example
327    ///
328    /// ```rust
329    /// controller.spawn_and_abort_on_reset(async { ... });
330    /// ```
331    pub fn spawn_and_abort_on_reset<F>(self, task: F) -> tokio::task::JoinHandle<()>
332    where
333        F: Future + Send + 'static,
334    {
335        tokio::task::spawn(self.abort_on_reset(task))
336    }
337}
338
/// An incoming rpc request or notification, backed by the raw frame bytes.
#[derive(Debug)]
pub struct Request {
    /// The complete received frame.
    buf: Box<[u8]>,
    /// Byte offset of the method name within `buf`.
    method_offset: u8,
    /// Length of the method name in bytes.
    method_len: u8,
    /// Byte offset within `buf` where the payload starts.
    data_offset: u16,
}
347
348/// Represents a response used to send the result of a rpc request.
349pub struct Response {
350    id: u32,
351    tx: Sender<Reply>,
352    resetter: Resetter,
353}
354
355impl Drop for Response {
356    fn drop(&mut self) {
357        self.resetter.lock().unwrap().remove(&self.id);
358    }
359}
360
361/// Represents rpc response.
362impl Response {
363    /// Returns the ID of the rpc request.
364    #[inline]
365    pub fn id(&self) -> u32 {
366        self.id
367    }
368
369    /// Sends the response with the provided data.
370    pub async fn send(self, data: impl AsRef<[u8]>) -> Result<(), ReceiverClosed> {
371        let data = data.as_ref();
372        let mut buf = Vec::with_capacity(5 + data.len());
373
374        buf.push(4); // frame type
375        buf.extend_from_slice(&self.id.to_be_bytes()); // call id
376        buf.extend_from_slice(data);
377
378        self.tx
379            .send(Reply::Response(buf.into()))
380            .await
381            .map_err(|_| ReceiverClosed)
382    }
383}
384
385impl Request {
386    /// Returns the rpc method name.
387    #[inline]
388    pub fn method(&self) -> &str {
389        unsafe {
390            let offset = self.method_offset as usize;
391            let length = self.method_len as usize;
392            std::str::from_utf8_unchecked(&self.buf.get_unchecked(offset..(offset + length)))
393        }
394    }
395
396    /// Returns the data payload of the request.
397    #[inline]
398    pub fn data(&self) -> &[u8] {
399        &self.buf[self.data_offset.into()..]
400    }
401}
402
/// Reads a big-endian `u32` call id from the front of `reader` and advances past it.
///
/// Leaves `reader` untouched and returns `"insufficient bytes"` if fewer than
/// 4 bytes remain.
fn parse_rpc_id(reader: &mut &[u8]) -> Result<u32, &'static str> {
    let input: &[u8] = *reader;
    if input.len() < 4 {
        return Err("insufficient bytes");
    }
    let (id_bytes, rest) = input.split_at(4);
    *reader = rest;

    let mut be = [0u8; 4];
    be.copy_from_slice(id_bytes);
    Ok(u32::from_be_bytes(be))
}
408
409fn validate_and_parse_utf8_rpc_name(reader: &mut &[u8]) -> Result<u8, DynErr> {
410    let method_len = get_slice(reader, 1)?[0];
411    std::str::from_utf8(get_slice(reader, method_len as usize)?)?;
412    Ok(method_len)
413}
414
/// Splits the first `len` bytes off `reader`, advancing it past them.
///
/// Returns `"insufficient bytes"` and leaves `reader` unchanged if fewer than `len`
/// bytes remain.
fn get_slice<'de>(reader: &mut &'de [u8], len: usize) -> Result<&'de [u8], &'static str> {
    let input: &'de [u8] = *reader;
    if len > input.len() {
        return Err("insufficient bytes");
    }
    // The explicit length check above means `split_at` cannot panic, and the
    // optimizer can drop its redundant bounds check, so no `unsafe` is needed.
    let (head, tail) = input.split_at(len);
    *reader = tail;
    Ok(head)
}