sod_tungstenite/
lib.rs

1//! [`sod::Service`] implementations to interact with [`tungstenite`] websockets.
2//!
3//! ## Service Impls
4//!
5//! All Services are [`Retryable`] and are able to be blocking or non-blocking.
6//!
7//! - [`WsSession`] is a [`MutService`] that wraps a [`tungstenite::WebSocket`], accepting [`WsSessionEvent`] to send or receive messages. `WsSession::into_split` can split a `WsSession` into a `WsReader`, `WsWriter`, and `WsFlusher`.
8//! - [`WsReader`] is a [`Service`] that wraps a [`Mutex<tungstenite::WebSocket>`], accepting a `()` as input and producing [`tungstenite::Message`] as output.
9//! - [`WsWriter`] is a [`Service`] that wraps a [`Mutex<tungstenite::WebSocket>`], accepting a `tungstenite::Message` as input.
10//! - [`WsFlusher`] is a [`Service`] that wraps a [`Mutex<tungstenite::WebSocket>`], accepting a `()` as input.
//! - [`WsServer`] is a [`Service`] that listens on a TCP port, accepting a `()` as input and producing an [`UninitializedWsSession`] as output.
12//!
13//! ## Features
14//!
15//! - `native-tls` to enable Native TLS
//! - `__rustls-tls` to enable Rustls TLS
17//!
18//! ## Blocking Example
19//!
20//! ```no_run
21//! use sod::{idle::backoff, MaybeProcessService, MutService, RetryService, Service, ServiceChain};
22//! use sod_tungstenite::{UninitializedWsSession, WsServer, WsSession, WsSessionEvent};
23//! use std::{sync::atomic::Ordering, thread::spawn};
24//! use tungstenite::{http::StatusCode, Message};
25//! use url::Url;
26//!
27//! // server session logic to add `"pong: "` in front of text payload
28//! struct PongService;
29//! impl Service for PongService {
30//!     type Input = Message;
31//!     type Output = Option<Message>;
32//!     type Error = ();
33//!     fn process(&self, input: Message) -> Result<Self::Output, Self::Error> {
34//!         match input {
35//!             Message::Text(text) => Ok(Some(Message::Text(format!("pong: {text}")))),
36//!             _ => Ok(None),
37//!         }
38//!     }
39//! }
40//!
41//! // wires session logic and spawns in new thread
42//! struct SessionSpawner;
43//! impl Service for SessionSpawner {
44//!     type Input = UninitializedWsSession;
45//!     type Output = ();
46//!     type Error = ();
47//!     fn process(&self, input: UninitializedWsSession) -> Result<Self::Output, Self::Error> {
48//!         spawn(|| {
49//!             let (r, w, f) = input.handshake().unwrap().into_split();
50//!             let chain = ServiceChain::start(r)
51//!                 .next(PongService)
52//!                 .next(MaybeProcessService::new(w))
53//!                 .next(MaybeProcessService::new(f))
54//!                 .end();
55//!             sod::thread::spawn_loop(chain, |err| {
56//!                 println!("Session: {err:?}");
57//!                 Err(err) // stop thread on error
58//!             });
59//!         });
60//!         Ok(())
61//!     }
62//! }
63//!
64//! // start a blocking server that creates blocking sessions
65//! let server = WsServer::bind("127.0.0.1:48490").unwrap();
66//!
67//! // spawn a thread to start accepting new server sessions
68//! let handle = sod::thread::spawn_loop(
69//!     ServiceChain::start(server).next(SessionSpawner).end(),
70//!     |err| {
71//!         println!("Server: {err:?}");
72//!         Err(err) // stop thread on error
73//!     },
74//! );
75//!
76//! // connect a client to the server
77//! let (mut client, _) =
78//!     WsSession::connect(Url::parse("ws://127.0.0.1:48490/socket").unwrap()).unwrap();
79//!
80//! // client writes `"hello world"` payload
81//! client
82//!     .process(WsSessionEvent::WriteMessage(Message::Text(
83//!         "hello world!".to_owned(),
84//!     )))
85//!     .unwrap();
86//!
87//! // client receives `"pong: hello world"` payload
88//! println!(
89//!     "Received: {:?}",
90//!     client.process(WsSessionEvent::ReadMessage).unwrap()
91//! );
92//!
93//! // join until server crashes
94//! handle.join().unwrap();
95//! ```
96//!
97//! ## Non-Blocking Example
98//!
99//! ```
100//! use sod::{idle::backoff, MaybeProcessService, MutService, RetryService, Service, ServiceChain};
101//! use sod_tungstenite::{UninitializedWsSession, WsServer, WsSession, WsSessionEvent};
102//! use std::{sync::atomic::Ordering, thread::spawn};
103//! use tungstenite::{http::StatusCode, Message};
104//! use url::Url;
105//!
106//! // server session logic to add `"pong: "` in front of text payload
107//! struct PongService;
108//! impl Service for PongService {
109//!     type Input = Message;
110//!     type Output = Option<Message>;
111//!     type Error = ();
112//!     fn process(&self, input: Message) -> Result<Self::Output, Self::Error> {
113//!         match input {
114//!             Message::Text(text) => Ok(Some(Message::Text(format!("pong: {text}")))),
115//!             _ => Ok(None),
116//!         }
117//!     }
118//! }
119//!
120//! // wires session logic and spawns in new thread
121//! struct SessionSpawner;
122//! impl Service for SessionSpawner {
123//!     type Input = UninitializedWsSession;
124//!     type Output = ();
125//!     type Error = ();
126//!     fn process(&self, input: UninitializedWsSession) -> Result<Self::Output, Self::Error> {
127//!         spawn(|| {
128//!             let (r, w, f) = input.handshake().unwrap().into_split();
129//!             let chain = ServiceChain::start(RetryService::new(r, backoff))
130//!                 .next(PongService)
131//!                 .next(MaybeProcessService::new(RetryService::new(w, backoff)))
132//!                 .next(MaybeProcessService::new(f))
133//!                 .end();
134//!             sod::thread::spawn_loop(chain, |err| {
135//!                 println!("Session: {err:?}");
136//!                 Err(err) // stop thread on error
137//!             });
138//!         });
139//!         Ok(())
140//!     }
141//! }
142//!
143//! // start a non-blocking server that creates non-blocking sessions
144//! let server = WsServer::bind("127.0.0.1:48490")
145//!     .unwrap()
146//!     .with_nonblocking_sessions(true)
147//!     .with_nonblocking_server(true)
148//!     .unwrap();
149//!
150//! // spawn a thread to start accepting new server sessions
151//! let handle = sod::thread::spawn_loop(
152//!     ServiceChain::start(RetryService::new(server, backoff))
153//!         .next(SessionSpawner)
154//!         .end(),
155//!     |err| {
156//!         println!("Server: {err:?}");
157//!         Err(err) // stop thread on error
158//!     },
159//! );
160//!
161//! // connect a client to the server
162//! let (mut client, response) =
163//!     WsSession::connect(Url::parse("ws://127.0.0.1:48490/socket").unwrap()).unwrap();
164//! assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
165//!
166//! // client writes `"hello world"` payload
167//! client
168//!     .process(WsSessionEvent::WriteMessage(Message::Text(
169//!         "hello world!".to_owned(),
170//!     )))
171//!     .unwrap();
172//!
173//! // client receives `"pong: hello world"` payload
174//! assert_eq!(
175//!     client.process(WsSessionEvent::ReadMessage).unwrap(),
176//!     Some(Message::Text("pong: hello world!".to_owned()))
177//! );
178//!
179//! // stop the server
180//! sod::idle::KEEP_RUNNING.store(false, Ordering::Release);
181//! handle.join().unwrap();
182//! ```
183
184use sod::{MutService, RetryError, Retryable, Service};
185use std::{
186    borrow::BorrowMut,
187    io::{self, ErrorKind, Read, Write},
188    net::{TcpListener, TcpStream, ToSocketAddrs},
189    sync::{Arc, Mutex},
190};
191use tungstenite::{
192    accept_hdr_with_config, accept_with_config,
193    client::IntoClientRequest,
194    handshake::{
195        client::Response,
196        server::{Callback, NoCallback},
197    },
198    protocol::WebSocketConfig,
199    stream::MaybeTlsStream,
200    Error, Message, WebSocket,
201};
202
203pub extern crate tungstenite;
204
205/// An input event for [`WsSession`], which can be a read or write.
206#[derive(Clone, Debug)]
207pub enum WsSessionEvent {
208    ReadMessage,
209    WriteMessage(Message),
210    Flush,
211}
212
213/// A [`MutService`] that wraps a [`tungstenite::WebSocket`], processing a [`WsSessionEvent`], producing a `Some(Message)` when a [`Message`] is read, and producing `None` otherwise.
214pub struct WsSession<S> {
215    ws: WebSocket<S>,
216}
217impl<S> WsSession<S> {
218    /// Wrap the given [`WebSocket`]
219    pub fn new(ws: WebSocket<S>) -> Self {
220        Self { ws }
221    }
222    /// Split this `WsSession` into a [`WsReader`] and [`WsWriter`], utilizing a [`Mutex`] to coordinate mutability on the underlying stream.
223    pub fn into_split(self) -> (WsReader<S>, WsWriter<S>, WsFlusher<S>) {
224        let ws = Arc::new(Mutex::new(self.ws));
225        (
226            WsReader::new(Arc::clone(&ws)),
227            WsWriter::new(Arc::clone(&ws)),
228            WsFlusher::new(ws),
229        )
230    }
231}
232impl WsSession<MaybeTlsStream<TcpStream>> {
233    /// Connect to the given URL as a WebSocket Client, producing a [`WsSession`] and HTTP [`Response`].
234    pub fn connect<Req: IntoClientRequest>(
235        request: Req,
236    ) -> Result<(WsSession<MaybeTlsStream<TcpStream>>, Response), Error> {
237        let (ws, resp) = tungstenite::connect(request)?;
238        Ok((WsSession::new(ws), resp))
239    }
240    /// Configure the underlying [`MaybeTlsStream`] to be non-blocking.
241    ///
242    /// Non-blocking services should usually be encpasulated by a [`RetryService`].
243    pub fn set_nonblocking(&self, nonblocking: bool) -> Result<(), io::Error> {
244        set_nonblocking(self.ws.get_ref(), nonblocking)
245    }
246}
247impl<S: Read + Write> MutService for WsSession<S> {
248    type Input = WsSessionEvent;
249    type Output = Option<Message>;
250    type Error = Error;
251    fn process(&mut self, input: WsSessionEvent) -> Result<Self::Output, Self::Error> {
252        Ok(match input {
253            WsSessionEvent::ReadMessage => Some(self.ws.borrow_mut().read()?),
254            WsSessionEvent::WriteMessage(message) => {
255                self.ws.borrow_mut().send(message)?;
256                None
257            }
258            WsSessionEvent::Flush => {
259                self.ws.borrow_mut().flush()?;
260                None
261            }
262        })
263    }
264}
265impl<S> Retryable<WsSessionEvent, Error> for WsSession<S> {
266    fn parse_retry(&self, err: Error) -> Result<WsSessionEvent, RetryError<Error>> {
267        match err {
268            Error::WriteBufferFull(message) => Ok(WsSessionEvent::WriteMessage(message)),
269            Error::Io(io_err) => match &io_err.kind() {
270                ErrorKind::WouldBlock => Ok(WsSessionEvent::ReadMessage),
271                _ => Err(RetryError::ServiceError(Error::Io(io_err))),
272            },
273            err => Err(RetryError::ServiceError(err)),
274        }
275    }
276}
277
278/// The read-side of a split [`WsSession`].
279#[derive(Clone)]
280pub struct WsReader<S> {
281    ws: Arc<Mutex<WebSocket<S>>>,
282}
283impl<S> WsReader<S> {
284    fn new(ws: Arc<Mutex<WebSocket<S>>>) -> Self {
285        Self { ws }
286    }
287}
288impl<S: Read + Write> Service for WsReader<S> {
289    type Input = ();
290    type Output = Message;
291    type Error = Error;
292    fn process(&self, _: ()) -> Result<Self::Output, Self::Error> {
293        let mut lock = match self.ws.lock() {
294            Ok(lock) => lock,
295            Err(_) => {
296                return Err(Error::Io(io::Error::new(
297                    ErrorKind::Other,
298                    "WsReader mutex poisoned",
299                )))
300            }
301        };
302        lock.read()
303    }
304}
305impl<S> Retryable<(), Error> for WsReader<S> {
306    fn parse_retry(&self, err: Error) -> Result<(), RetryError<Error>> {
307        match err {
308            Error::Io(io_err) => match &io_err.kind() {
309                ErrorKind::WouldBlock => Ok(()),
310                _ => Err(RetryError::ServiceError(Error::Io(io_err))),
311            },
312            err => Err(RetryError::ServiceError(err)),
313        }
314    }
315}
316
317/// The write-side of a split [`WsSession`].
318#[derive(Clone)]
319pub struct WsWriter<S> {
320    ws: Arc<Mutex<WebSocket<S>>>,
321}
322impl<S> WsWriter<S> {
323    fn new(ws: Arc<Mutex<WebSocket<S>>>) -> Self {
324        Self { ws }
325    }
326}
327impl<S: Read + Write> Service for WsWriter<S> {
328    type Input = Message;
329    type Output = ();
330    type Error = Error;
331    fn process(&self, input: Message) -> Result<Self::Output, Self::Error> {
332        let mut lock = match self.ws.lock() {
333            Ok(lock) => lock,
334            Err(_) => {
335                return Err(Error::Io(io::Error::new(
336                    ErrorKind::Other,
337                    "WsWriter mutex poisoned",
338                )))
339            }
340        };
341        lock.write(input)
342    }
343}
344impl<S> Retryable<Message, Error> for WsWriter<S> {
345    fn parse_retry(&self, err: Error) -> Result<Message, RetryError<Error>> {
346        match err {
347            Error::WriteBufferFull(message) => Ok(message),
348            err => Err(RetryError::ServiceError(err)),
349        }
350    }
351}
352impl<S> Retryable<Option<Message>, Error> for WsWriter<S> {
353    fn parse_retry(&self, err: Error) -> Result<Option<Message>, RetryError<Error>> {
354        match err {
355            Error::WriteBufferFull(message) => Ok(Some(message)),
356            err => Err(RetryError::ServiceError(err)),
357        }
358    }
359}
360
361/// The flush-side of a split [`WsSession`].
362#[derive(Clone)]
363pub struct WsFlusher<S> {
364    ws: Arc<Mutex<WebSocket<S>>>,
365}
366impl<S> WsFlusher<S> {
367    fn new(ws: Arc<Mutex<WebSocket<S>>>) -> Self {
368        Self { ws }
369    }
370}
371impl<S: Read + Write> Service for WsFlusher<S> {
372    type Input = ();
373    type Output = ();
374    type Error = Error;
375    fn process(&self, (): ()) -> Result<Self::Output, Self::Error> {
376        let mut lock = match self.ws.lock() {
377            Ok(lock) => lock,
378            Err(_) => {
379                return Err(Error::Io(io::Error::new(
380                    ErrorKind::Other,
381                    "WsFlusher mutex poisoned",
382                )))
383            }
384        };
385        lock.flush()
386    }
387}
388
/// Used to configure if and how TLS is used for a [`WsServer`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Tls {
    /// Plain TCP, no TLS.
    None,
    /// TLS backed by the `native-tls` feature.
    #[cfg(feature = "native-tls")]
    Native,
    /// TLS backed by the `__rustls-tls` feature.
    #[cfg(feature = "__rustls-tls")]
    Rustls,
}
397
398/// A [`WsSession`] that has yet to complete its handshake.
399///
400/// Calling `UninitializedWsSession::handshake` will block on the handshake, producing a [`WsSession`].
401pub struct UninitializedWsSession {
402    stream: MaybeTlsStream<TcpStream>,
403    nonblocking: bool,
404}
405impl UninitializedWsSession {
406    fn new(stream: MaybeTlsStream<TcpStream>, nonblocking: bool) -> Self {
407        Self {
408            stream,
409            nonblocking,
410        }
411    }
412
413    /// Perform a blocking handshake, producing a [`WsSession`] or [`io::Error`] from `self`.
414    pub fn handshake(self) -> Result<WsSession<MaybeTlsStream<TcpStream>>, io::Error> {
415        self.handshake_with_params::<NoCallback>(None, None)
416    }
417
418    /// Perform a blocking handshake, with optional config and optional callback, producing a [`WsSession`] or [`io::Error`] from `self`.
419    pub fn handshake_with_params<C: Callback>(
420        self,
421        callback: Option<C>,
422        config: Option<WebSocketConfig>,
423    ) -> Result<WsSession<MaybeTlsStream<TcpStream>>, io::Error> {
424        let stream = self.stream;
425        set_nonblocking(&stream, false)?;
426        let ws = if let Some(callback) = callback {
427            match accept_hdr_with_config(stream, callback, config) {
428                Ok(v) => v,
429                Err(err) => {
430                    return Err(io::Error::new(
431                        ErrorKind::Other,
432                        format!("HandshakeError: {err:?}"),
433                    ))
434                }
435            }
436        } else {
437            match accept_with_config(stream, config) {
438                Ok(v) => v,
439                Err(err) => {
440                    return Err(io::Error::new(
441                        ErrorKind::Other,
442                        format!("HandshakeError: {err:?}"),
443                    ))
444                }
445            }
446        };
447        let session = WsSession::new(ws);
448        session.set_nonblocking(self.nonblocking)?;
449        return Ok(session);
450    }
451}
452
453/// A TCP Server that produces [`UninitializedWsSession`] as output.
454pub struct WsServer {
455    server: TcpListener,
456    tls: Tls,
457    nonblocking_sessions: bool,
458}
459impl WsServer {
460    /// Wrap the given TcpListener
461    pub fn new(server: TcpListener) -> Self {
462        Self {
463            server,
464            tls: Tls::None,
465            nonblocking_sessions: false,
466        }
467    }
468
469    /// Bind to the given socket address
470    pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self, io::Error> {
471        let server = TcpListener::bind(addr)?;
472        Ok(Self {
473            server,
474            tls: Tls::None,
475            nonblocking_sessions: false,
476        })
477    }
478}
479impl WsServer {
480    /// Builder pattern, set the TLS mode to use
481    pub fn with_tls(mut self, tls: Tls) -> Self {
482        self.tls = tls;
483        self
484    }
485    /// Builder pattern, configure the nonblocking status for the underlying [`TcpListener`]
486    pub fn with_nonblocking_server(self, nonblocking: bool) -> Result<Self, io::Error> {
487        self.server.set_nonblocking(nonblocking)?;
488        Ok(self)
489    }
490    /// Builder pattern, configure the default nonblocking status for produced [`WsSessions`] structs.
491    pub fn with_nonblocking_sessions(mut self, nonblocking_sessions: bool) -> Self {
492        self.nonblocking_sessions = nonblocking_sessions;
493        self
494    }
495}
496impl Service for WsServer {
497    type Input = ();
498    type Output = UninitializedWsSession;
499    type Error = io::Error;
500    fn process(&self, _: ()) -> Result<Self::Output, Self::Error> {
501        match self.server.accept() {
502            Ok((stream, _)) => {
503                #[cfg(not(feature = "native-tls"))]
504                let stream = match self.tls {
505                    Tls::None => MaybeTlsStream::Plain(stream),
506                    #[cfg(feature = "native-tls")]
507                    Tls::Native => MaybeTlsStream::NativeTls(stream),
508                    #[cfg(feature = "__rustls-tls")]
509                    Tls::Rustls => MaybeTlsStream::Rustls(stream),
510                };
511                Ok(UninitializedWsSession::new(
512                    stream,
513                    self.nonblocking_sessions,
514                ))
515            }
516            Err(err) => Err(err),
517        }
518    }
519}
520impl Retryable<(), io::Error> for WsServer {
521    fn parse_retry(&self, err: io::Error) -> Result<(), RetryError<io::Error>> {
522        match &err.kind() {
523            ErrorKind::WouldBlock => Ok(()),
524            _ => Err(RetryError::ServiceError(err)),
525        }
526    }
527}
528
529fn set_nonblocking(stream: &MaybeTlsStream<TcpStream>, nonblocking: bool) -> Result<(), io::Error> {
530    match stream {
531        MaybeTlsStream::Plain(stream) => stream.set_nonblocking(nonblocking),
532        #[cfg(feature = "native-tls")]
533        MaybeTlsStream::NativeTls(stream) => stream.set_nonblocking(nonblocking),
534        #[cfg(feature = "__rustls-tls")]
535        MaybeTlsStream::Rustls(stream) => stream.set_nonblocking(nonblocking),
536        _ => return Err(io::Error::new(ErrorKind::Other, "unrecognized stream type")),
537    }
538}