simple_websockets/
lib.rs

1//! An easy-to-use WebSocket server.
2//!
3//! To start a WebSocket listener, simply call [`launch()`], and use the
4//! returned [`EventHub`] to react to client messages, connections, and disconnections.
5//!
6//! # Example
7//!
8//! A WebSocket echo server:
9//!
10//! ```no_run
11//! use simple_websockets::{Event, Responder};
12//! use std::collections::HashMap;
13//!
14//! fn main() {
15//!     // listen for WebSockets on port 8080:
16//!     let event_hub = simple_websockets::launch(8080)
17//!         .expect("failed to listen on port 8080");
18//!     // map between client ids and the client's `Responder`:
19//!     let mut clients: HashMap<u64, Responder> = HashMap::new();
20//!
21//!     loop {
22//!         match event_hub.poll_event() {
23//!             Event::Connect(client_id, responder) => {
24//!                 println!("A client connected with id #{}", client_id);
25//!                 // add their Responder to our `clients` map:
26//!                 clients.insert(client_id, responder);
27//!             },
28//!             Event::Disconnect(client_id) => {
29//!                 println!("Client #{} disconnected.", client_id);
30//!                 // remove the disconnected client from the clients map:
31//!                 clients.remove(&client_id);
32//!             },
33//!             Event::Message(client_id, message) => {
34//!                 println!("Received a message from client #{}: {:?}", client_id, message);
35//!                 // retrieve this client's `Responder`:
36//!                 let responder = clients.get(&client_id).unwrap();
37//!                 // echo the message back:
38//!                 responder.send(message);
39//!             },
40//!         }
41//!     }
42//! }
43//! ```
44use futures_util::{SinkExt, StreamExt};
45use tokio::net::{TcpListener, TcpStream};
46use tokio::runtime::Runtime;
47use tokio_tungstenite::{accept_async, tungstenite};
48
/// Errors returned by this crate's launch functions.
#[derive(Debug)]
pub enum Error {
    /// Returned by [`launch`] and [`launch_from_listener`] if the websocket
    /// listener could not be bound or its listener thread failed to start.
    FailedToStart,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::FailedToStart => f.write_str("failed to start the websocket listener"),
        }
    }
}

// No underlying cause is stored, so the default `source()` (None) is correct.
impl std::error::Error for Error {}
54
/// An outgoing/incoming message to/from a websocket.
///
/// Messages can be compared with `==`, which is handy when matching on
/// expected client input or asserting in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
    /// A text message (always valid UTF-8, as it is a `String`)
    Text(String),
    /// A binary message
    Binary(Vec<u8>),
}
63
64impl Message {
65    fn into_tungstenite(self) -> tungstenite::Message {
66        match self {
67            Self::Text(text) => tungstenite::Message::Text(text),
68            Self::Binary(bytes) => tungstenite::Message::Binary(bytes),
69        }
70    }
71
72    fn from_tungstenite(message: tungstenite::Message) -> Option<Self> {
73        match message {
74            tungstenite::Message::Binary(bytes) => Some(Self::Binary(bytes)),
75            tungstenite::Message::Text(text) => Some(Self::Text(text)),
76            _ => None,
77        }
78    }
79}
80
81enum ResponderCommand {
82    Message(Message),
83    CloseConnection,
84}
85
86/// Sends outgoing messages to a websocket.
87/// Every connected websocket client has a corresponding `Responder`.
88///
89/// `Responder`s can be safely cloned and sent across threads, to be used in a
90/// multi-producer single-consumer paradigm.
91///
92/// If a Reponder is dropped while its client is still connected, the connection
93/// will be automatically closed. If there are multiple clones of a Responder,
94/// The client will not be disconnected until the last Responder is dropped.
95#[derive(Debug, Clone)]
96pub struct Responder {
97    tx: flume::Sender<ResponderCommand>,
98    client_id: u64,
99}
100
101impl Responder {
102    fn new(tx: flume::Sender<ResponderCommand>, client_id: u64) -> Self {
103        Self { tx, client_id }
104    }
105
106    /// Sends a message to the client represented by this `Responder`.
107    ///
108    /// Returns true if the message was sent, or false if it wasn't
109    /// sent (because the client is disconnected).
110    ///
111    /// Note that this *doesn't* need a mutable reference to `self`.
112    pub fn send(&self, message: Message) -> bool {
113        self.tx.send(ResponderCommand::Message(message)).is_ok()
114    }
115
116    /// Closes this client's connection.
117    ///
118    /// Note that this *doesn't* need a mutable reference to `self`.
119    pub fn close(&self) {
120        let _ = self.tx.send(ResponderCommand::CloseConnection);
121    }
122
123    /// The id of the client that this `Responder` is connected to.
124    pub fn client_id(&self) -> u64 {
125        self.client_id
126    }
127}
128
129/// An incoming event from a client.
130/// This can be an incoming message, a new client connection, or a disconnection.
131#[derive(Debug)]
132pub enum Event {
133    /// A new client has connected.
134    Connect(
135        /// id of the client who connected
136        u64,
137        /// [`Responder`] used to send messages back to this client
138        Responder,
139    ),
140
141    /// A client has disconnected.
142    Disconnect(
143        /// id of the client who disconnected
144        u64,
145    ),
146
147    /// An incoming message from a client.
148    Message(
149        /// id of the client who sent the message
150        u64,
151        /// the message
152        Message,
153    ),
154}
155
156/// A queue of incoming events from clients.
157///
158/// The `EventHub` is the centerpiece of this library, and it is where all
159/// messages, connections, and disconnections are received.
160#[derive(Debug)]
161pub struct EventHub {
162    rx: flume::Receiver<Event>,
163}
164
165impl EventHub {
166    fn new(rx: flume::Receiver<Event>) -> Self {
167        Self { rx }
168    }
169
170    /// Clears the event queue and returns all the events that were in the queue.
171    pub fn drain(&self) -> Vec<Event> {
172        if self.rx.is_disconnected() && self.rx.is_empty() {
173            panic!("EventHub channel disconnected. Panicking because Websocket listener thread was killed.");
174        }
175
176        self.rx.drain().collect()
177    }
178
179    /// Returns the next event, or None if the queue is empty.
180    pub fn next_event(&self) -> Option<Event> {
181        self.rx.try_recv().ok()
182    }
183
184    /// Returns the next event, blocking if the queue is empty.
185    pub fn poll_event(&self) -> Event {
186        self.rx.recv().unwrap()
187    }
188
189    /// Async version of [`poll_event`](Self::poll_event)
190    pub async fn poll_async(&self) -> Event {
191        self.rx.recv_async().await.expect("Parent thread is dead")
192    }
193
194    /// Returns true if there are currently no events in the queue.
195    pub fn is_empty(&self) -> bool {
196        self.rx.is_empty()
197    }
198}
199
200/// Start listening for websocket connections on `port`.
201/// On success, returns an [`EventHub`] for receiving messages and
202/// connection/disconnection notifications.
203pub fn launch(port: u16) -> Result<EventHub, Error> {
204    let address = format!("0.0.0.0:{}", port);
205    let listener = std::net::TcpListener::bind(&address).map_err(|_| Error::FailedToStart)?;
206    return launch_from_listener(listener);
207}
208
209/// Start listening for websocket connections with the specified [`TcpListener`](std::net::TcpListener).
210/// The listener must be bound (by calling [`bind`](std::net::TcpListener::bind)) before being passed to
211/// `launch_from_listener`.
212///
213/// ```no_run
214/// use std::net::TcpListener;
215///
216/// fn main() {
217///     // Example of using a pre-bound listener instead of providing a port.
218///     let listener = TcpListener::bind("0.0.0.0:8080").unwrap();
219///     let event_hub = simple_websockets::launch_from_listener(listener).expect("failed to listen on port 8080");
220///     // ...
221/// }
222/// ```
223pub fn launch_from_listener(listener: std::net::TcpListener) -> Result<EventHub, Error> {
224    let (tx, rx) = flume::unbounded();
225    std::thread::Builder::new()
226        .name("Websocket listener".to_string())
227        .spawn(move || {
228            start_runtime(tx, listener).unwrap();
229        })
230        .map_err(|_| Error::FailedToStart)?;
231
232    Ok(EventHub::new(rx))
233}
234
235fn start_runtime(
236    event_tx: flume::Sender<Event>,
237    listener: std::net::TcpListener,
238) -> Result<(), Error> {
239    listener
240        .set_nonblocking(true)
241        .map_err(|_| Error::FailedToStart)?;
242    Runtime::new()
243        .map_err(|_| Error::FailedToStart)?
244        .block_on(async {
245            let tokio_listener = TcpListener::from_std(listener).unwrap();
246            let mut current_id: u64 = 0;
247            loop {
248                match tokio_listener.accept().await {
249                    Ok((stream, _)) => {
250                        tokio::spawn(handle_connection(stream, event_tx.clone(), current_id));
251                        current_id = current_id.wrapping_add(1);
252                    }
253                    _ => {}
254                }
255            }
256        })
257}
258
259async fn handle_connection(stream: TcpStream, event_tx: flume::Sender<Event>, id: u64) {
260    let ws_stream = match accept_async(stream).await {
261        Ok(s) => s,
262        Err(_) => return,
263    };
264
265    let (mut outgoing, mut incoming) = ws_stream.split();
266
267    // channel for the `Responder` to send things to this websocket
268    let (resp_tx, resp_rx) = flume::unbounded();
269
270    event_tx
271        .send(Event::Connect(id, Responder::new(resp_tx, id)))
272        .expect("Parent thread is dead");
273
274    // future that waits for commands from the `Responder`
275    let responder_events = async move {
276        while let Ok(event) = resp_rx.recv_async().await {
277            match event {
278                ResponderCommand::Message(message) => {
279                    if let Err(_) = outgoing.send(message.into_tungstenite()).await {
280                        let _ = outgoing.close().await;
281                        return Ok(());
282                    }
283                }
284                ResponderCommand::CloseConnection => {
285                    let _ = outgoing.close().await;
286                    return Ok(());
287                }
288            }
289        }
290
291        // Disconnect if the `Responder` was dropped without explicitly disconnecting
292        let _ = outgoing.close().await;
293
294        // this future always returns Ok, so that it wont stop the try_join
295        Result::<(), ()>::Ok(())
296    };
297
298    let event_tx2 = event_tx.clone();
299    //future that forwards messages received from the websocket to the event channel
300    let events = async move {
301        while let Some(message) = incoming.next().await {
302            if let Ok(tungstenite_msg) = message {
303                if let Some(msg) = Message::from_tungstenite(tungstenite_msg) {
304                    event_tx2
305                        .send(Event::Message(id, msg))
306                        .expect("Parent thread is dead");
307                }
308            }
309        }
310
311        // stop the try_join once the websocket is closed and all pending incoming
312        // messages have been sent to the event channel.
313        // stopping the try_join causes responder_events to be closed too so that the
314        // `Receiver` cant send any more messages.
315        Result::<(), ()>::Err(())
316    };
317
318    // use try_join so that when `events` returns Err (the websocket closes), responder_events will be stopped too
319    let _ = futures_util::try_join!(responder_events, events);
320
321    event_tx
322        .send(Event::Disconnect(id))
323        .expect("Parent thread is dead");
324}