sod_tungstenite/lib.rs
1//! [`sod::Service`] implementations to interact with [`tungstenite`] websockets.
2//!
3//! ## Service Impls
4//!
5//! All Services are [`Retryable`] and are able to be blocking or non-blocking.
6//!
//! - [`WsSession`] is a [`MutService`] that wraps a [`tungstenite::WebSocket`], accepting [`WsSessionEvent`] to read, write, or flush messages. `WsSession::into_split` can split a `WsSession` into a `WsReader`, `WsWriter`, and `WsFlusher`.
8//! - [`WsReader`] is a [`Service`] that wraps a [`Mutex<tungstenite::WebSocket>`], accepting a `()` as input and producing [`tungstenite::Message`] as output.
9//! - [`WsWriter`] is a [`Service`] that wraps a [`Mutex<tungstenite::WebSocket>`], accepting a `tungstenite::Message` as input.
10//! - [`WsFlusher`] is a [`Service`] that wraps a [`Mutex<tungstenite::WebSocket>`], accepting a `()` as input.
//! - [`WsServer`] is a [`Service`] that listens on a TCP port, accepting a `()` as input and producing an [`UninitializedWsSession`] as output.
12//!
13//! ## Features
14//!
15//! - `native-tls` to enable Native TLS
//! - `__rustls-tls` to enable Rustls TLS
17//!
18//! ## Blocking Example
19//!
20//! ```no_run
21//! use sod::{idle::backoff, MaybeProcessService, MutService, RetryService, Service, ServiceChain};
22//! use sod_tungstenite::{UninitializedWsSession, WsServer, WsSession, WsSessionEvent};
23//! use std::{sync::atomic::Ordering, thread::spawn};
24//! use tungstenite::{http::StatusCode, Message};
25//! use url::Url;
26//!
27//! // server session logic to add `"pong: "` in front of text payload
28//! struct PongService;
29//! impl Service for PongService {
30//! type Input = Message;
31//! type Output = Option<Message>;
32//! type Error = ();
33//! fn process(&self, input: Message) -> Result<Self::Output, Self::Error> {
34//! match input {
35//! Message::Text(text) => Ok(Some(Message::Text(format!("pong: {text}")))),
36//! _ => Ok(None),
37//! }
38//! }
39//! }
40//!
41//! // wires session logic and spawns in new thread
42//! struct SessionSpawner;
43//! impl Service for SessionSpawner {
44//! type Input = UninitializedWsSession;
45//! type Output = ();
46//! type Error = ();
47//! fn process(&self, input: UninitializedWsSession) -> Result<Self::Output, Self::Error> {
48//! spawn(|| {
49//! let (r, w, f) = input.handshake().unwrap().into_split();
50//! let chain = ServiceChain::start(r)
51//! .next(PongService)
52//! .next(MaybeProcessService::new(w))
53//! .next(MaybeProcessService::new(f))
54//! .end();
55//! sod::thread::spawn_loop(chain, |err| {
56//! println!("Session: {err:?}");
57//! Err(err) // stop thread on error
58//! });
59//! });
60//! Ok(())
61//! }
62//! }
63//!
64//! // start a blocking server that creates blocking sessions
65//! let server = WsServer::bind("127.0.0.1:48490").unwrap();
66//!
67//! // spawn a thread to start accepting new server sessions
68//! let handle = sod::thread::spawn_loop(
69//! ServiceChain::start(server).next(SessionSpawner).end(),
70//! |err| {
71//! println!("Server: {err:?}");
72//! Err(err) // stop thread on error
73//! },
74//! );
75//!
76//! // connect a client to the server
77//! let (mut client, _) =
78//! WsSession::connect(Url::parse("ws://127.0.0.1:48490/socket").unwrap()).unwrap();
79//!
80//! // client writes `"hello world"` payload
81//! client
82//! .process(WsSessionEvent::WriteMessage(Message::Text(
83//! "hello world!".to_owned(),
84//! )))
85//! .unwrap();
86//!
87//! // client receives `"pong: hello world"` payload
88//! println!(
89//! "Received: {:?}",
90//! client.process(WsSessionEvent::ReadMessage).unwrap()
91//! );
92//!
93//! // join until server crashes
94//! handle.join().unwrap();
95//! ```
96//!
97//! ## Non-Blocking Example
98//!
99//! ```
100//! use sod::{idle::backoff, MaybeProcessService, MutService, RetryService, Service, ServiceChain};
101//! use sod_tungstenite::{UninitializedWsSession, WsServer, WsSession, WsSessionEvent};
102//! use std::{sync::atomic::Ordering, thread::spawn};
103//! use tungstenite::{http::StatusCode, Message};
104//! use url::Url;
105//!
106//! // server session logic to add `"pong: "` in front of text payload
107//! struct PongService;
108//! impl Service for PongService {
109//! type Input = Message;
110//! type Output = Option<Message>;
111//! type Error = ();
112//! fn process(&self, input: Message) -> Result<Self::Output, Self::Error> {
113//! match input {
114//! Message::Text(text) => Ok(Some(Message::Text(format!("pong: {text}")))),
115//! _ => Ok(None),
116//! }
117//! }
118//! }
119//!
120//! // wires session logic and spawns in new thread
121//! struct SessionSpawner;
122//! impl Service for SessionSpawner {
123//! type Input = UninitializedWsSession;
124//! type Output = ();
125//! type Error = ();
126//! fn process(&self, input: UninitializedWsSession) -> Result<Self::Output, Self::Error> {
127//! spawn(|| {
128//! let (r, w, f) = input.handshake().unwrap().into_split();
129//! let chain = ServiceChain::start(RetryService::new(r, backoff))
130//! .next(PongService)
131//! .next(MaybeProcessService::new(RetryService::new(w, backoff)))
132//! .next(MaybeProcessService::new(f))
133//! .end();
134//! sod::thread::spawn_loop(chain, |err| {
135//! println!("Session: {err:?}");
136//! Err(err) // stop thread on error
137//! });
138//! });
139//! Ok(())
140//! }
141//! }
142//!
143//! // start a non-blocking server that creates non-blocking sessions
144//! let server = WsServer::bind("127.0.0.1:48490")
145//! .unwrap()
146//! .with_nonblocking_sessions(true)
147//! .with_nonblocking_server(true)
148//! .unwrap();
149//!
150//! // spawn a thread to start accepting new server sessions
151//! let handle = sod::thread::spawn_loop(
152//! ServiceChain::start(RetryService::new(server, backoff))
153//! .next(SessionSpawner)
154//! .end(),
155//! |err| {
156//! println!("Server: {err:?}");
157//! Err(err) // stop thread on error
158//! },
159//! );
160//!
161//! // connect a client to the server
162//! let (mut client, response) =
163//! WsSession::connect(Url::parse("ws://127.0.0.1:48490/socket").unwrap()).unwrap();
164//! assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
165//!
166//! // client writes `"hello world"` payload
167//! client
168//! .process(WsSessionEvent::WriteMessage(Message::Text(
169//! "hello world!".to_owned(),
170//! )))
171//! .unwrap();
172//!
173//! // client receives `"pong: hello world"` payload
174//! assert_eq!(
175//! client.process(WsSessionEvent::ReadMessage).unwrap(),
176//! Some(Message::Text("pong: hello world!".to_owned()))
177//! );
178//!
179//! // stop the server
180//! sod::idle::KEEP_RUNNING.store(false, Ordering::Release);
181//! handle.join().unwrap();
182//! ```
183
184use sod::{MutService, RetryError, Retryable, Service};
185use std::{
186 borrow::BorrowMut,
187 io::{self, ErrorKind, Read, Write},
188 net::{TcpListener, TcpStream, ToSocketAddrs},
189 sync::{Arc, Mutex},
190};
191use tungstenite::{
192 accept_hdr_with_config, accept_with_config,
193 client::IntoClientRequest,
194 handshake::{
195 client::Response,
196 server::{Callback, NoCallback},
197 },
198 protocol::WebSocketConfig,
199 stream::MaybeTlsStream,
200 Error, Message, WebSocket,
201};
202
203pub extern crate tungstenite;
204
205/// An input event for [`WsSession`], which can be a read or write.
206#[derive(Clone, Debug)]
207pub enum WsSessionEvent {
208 ReadMessage,
209 WriteMessage(Message),
210 Flush,
211}
212
213/// A [`MutService`] that wraps a [`tungstenite::WebSocket`], processing a [`WsSessionEvent`], producing a `Some(Message)` when a [`Message`] is read, and producing `None` otherwise.
214pub struct WsSession<S> {
215 ws: WebSocket<S>,
216}
217impl<S> WsSession<S> {
218 /// Wrap the given [`WebSocket`]
219 pub fn new(ws: WebSocket<S>) -> Self {
220 Self { ws }
221 }
222 /// Split this `WsSession` into a [`WsReader`] and [`WsWriter`], utilizing a [`Mutex`] to coordinate mutability on the underlying stream.
223 pub fn into_split(self) -> (WsReader<S>, WsWriter<S>, WsFlusher<S>) {
224 let ws = Arc::new(Mutex::new(self.ws));
225 (
226 WsReader::new(Arc::clone(&ws)),
227 WsWriter::new(Arc::clone(&ws)),
228 WsFlusher::new(ws),
229 )
230 }
231}
232impl WsSession<MaybeTlsStream<TcpStream>> {
233 /// Connect to the given URL as a WebSocket Client, producing a [`WsSession`] and HTTP [`Response`].
234 pub fn connect<Req: IntoClientRequest>(
235 request: Req,
236 ) -> Result<(WsSession<MaybeTlsStream<TcpStream>>, Response), Error> {
237 let (ws, resp) = tungstenite::connect(request)?;
238 Ok((WsSession::new(ws), resp))
239 }
240 /// Configure the underlying [`MaybeTlsStream`] to be non-blocking.
241 ///
242 /// Non-blocking services should usually be encpasulated by a [`RetryService`].
243 pub fn set_nonblocking(&self, nonblocking: bool) -> Result<(), io::Error> {
244 set_nonblocking(self.ws.get_ref(), nonblocking)
245 }
246}
247impl<S: Read + Write> MutService for WsSession<S> {
248 type Input = WsSessionEvent;
249 type Output = Option<Message>;
250 type Error = Error;
251 fn process(&mut self, input: WsSessionEvent) -> Result<Self::Output, Self::Error> {
252 Ok(match input {
253 WsSessionEvent::ReadMessage => Some(self.ws.borrow_mut().read()?),
254 WsSessionEvent::WriteMessage(message) => {
255 self.ws.borrow_mut().send(message)?;
256 None
257 }
258 WsSessionEvent::Flush => {
259 self.ws.borrow_mut().flush()?;
260 None
261 }
262 })
263 }
264}
265impl<S> Retryable<WsSessionEvent, Error> for WsSession<S> {
266 fn parse_retry(&self, err: Error) -> Result<WsSessionEvent, RetryError<Error>> {
267 match err {
268 Error::WriteBufferFull(message) => Ok(WsSessionEvent::WriteMessage(message)),
269 Error::Io(io_err) => match &io_err.kind() {
270 ErrorKind::WouldBlock => Ok(WsSessionEvent::ReadMessage),
271 _ => Err(RetryError::ServiceError(Error::Io(io_err))),
272 },
273 err => Err(RetryError::ServiceError(err)),
274 }
275 }
276}
277
278/// The read-side of a split [`WsSession`].
279#[derive(Clone)]
280pub struct WsReader<S> {
281 ws: Arc<Mutex<WebSocket<S>>>,
282}
283impl<S> WsReader<S> {
284 fn new(ws: Arc<Mutex<WebSocket<S>>>) -> Self {
285 Self { ws }
286 }
287}
288impl<S: Read + Write> Service for WsReader<S> {
289 type Input = ();
290 type Output = Message;
291 type Error = Error;
292 fn process(&self, _: ()) -> Result<Self::Output, Self::Error> {
293 let mut lock = match self.ws.lock() {
294 Ok(lock) => lock,
295 Err(_) => {
296 return Err(Error::Io(io::Error::new(
297 ErrorKind::Other,
298 "WsReader mutex poisoned",
299 )))
300 }
301 };
302 lock.read()
303 }
304}
305impl<S> Retryable<(), Error> for WsReader<S> {
306 fn parse_retry(&self, err: Error) -> Result<(), RetryError<Error>> {
307 match err {
308 Error::Io(io_err) => match &io_err.kind() {
309 ErrorKind::WouldBlock => Ok(()),
310 _ => Err(RetryError::ServiceError(Error::Io(io_err))),
311 },
312 err => Err(RetryError::ServiceError(err)),
313 }
314 }
315}
316
317/// The write-side of a split [`WsSession`].
318#[derive(Clone)]
319pub struct WsWriter<S> {
320 ws: Arc<Mutex<WebSocket<S>>>,
321}
322impl<S> WsWriter<S> {
323 fn new(ws: Arc<Mutex<WebSocket<S>>>) -> Self {
324 Self { ws }
325 }
326}
327impl<S: Read + Write> Service for WsWriter<S> {
328 type Input = Message;
329 type Output = ();
330 type Error = Error;
331 fn process(&self, input: Message) -> Result<Self::Output, Self::Error> {
332 let mut lock = match self.ws.lock() {
333 Ok(lock) => lock,
334 Err(_) => {
335 return Err(Error::Io(io::Error::new(
336 ErrorKind::Other,
337 "WsWriter mutex poisoned",
338 )))
339 }
340 };
341 lock.write(input)
342 }
343}
344impl<S> Retryable<Message, Error> for WsWriter<S> {
345 fn parse_retry(&self, err: Error) -> Result<Message, RetryError<Error>> {
346 match err {
347 Error::WriteBufferFull(message) => Ok(message),
348 err => Err(RetryError::ServiceError(err)),
349 }
350 }
351}
352impl<S> Retryable<Option<Message>, Error> for WsWriter<S> {
353 fn parse_retry(&self, err: Error) -> Result<Option<Message>, RetryError<Error>> {
354 match err {
355 Error::WriteBufferFull(message) => Ok(Some(message)),
356 err => Err(RetryError::ServiceError(err)),
357 }
358 }
359}
360
361/// The flush-side of a split [`WsSession`].
362#[derive(Clone)]
363pub struct WsFlusher<S> {
364 ws: Arc<Mutex<WebSocket<S>>>,
365}
366impl<S> WsFlusher<S> {
367 fn new(ws: Arc<Mutex<WebSocket<S>>>) -> Self {
368 Self { ws }
369 }
370}
371impl<S: Read + Write> Service for WsFlusher<S> {
372 type Input = ();
373 type Output = ();
374 type Error = Error;
375 fn process(&self, (): ()) -> Result<Self::Output, Self::Error> {
376 let mut lock = match self.ws.lock() {
377 Ok(lock) => lock,
378 Err(_) => {
379 return Err(Error::Io(io::Error::new(
380 ErrorKind::Other,
381 "WsFlusher mutex poisoned",
382 )))
383 }
384 };
385 lock.flush()
386 }
387}
388
/// Used to configure if and how TLS is used for a [`WsServer`].
///
/// NOTE(review): `WsServer` does not perform a server-side TLS handshake, so the TLS modes do not
/// currently yield working TLS sessions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Tls {
    /// Plain TCP, no TLS.
    None,
    /// TLS through the `native-tls` backend (requires the `native-tls` feature).
    #[cfg(feature = "native-tls")]
    Native,
    /// TLS through the `rustls` backend (requires the `__rustls-tls` feature).
    #[cfg(feature = "__rustls-tls")]
    Rustls,
}
397
398/// A [`WsSession`] that has yet to complete its handshake.
399///
400/// Calling `UninitializedWsSession::handshake` will block on the handshake, producing a [`WsSession`].
401pub struct UninitializedWsSession {
402 stream: MaybeTlsStream<TcpStream>,
403 nonblocking: bool,
404}
405impl UninitializedWsSession {
406 fn new(stream: MaybeTlsStream<TcpStream>, nonblocking: bool) -> Self {
407 Self {
408 stream,
409 nonblocking,
410 }
411 }
412
413 /// Perform a blocking handshake, producing a [`WsSession`] or [`io::Error`] from `self`.
414 pub fn handshake(self) -> Result<WsSession<MaybeTlsStream<TcpStream>>, io::Error> {
415 self.handshake_with_params::<NoCallback>(None, None)
416 }
417
418 /// Perform a blocking handshake, with optional config and optional callback, producing a [`WsSession`] or [`io::Error`] from `self`.
419 pub fn handshake_with_params<C: Callback>(
420 self,
421 callback: Option<C>,
422 config: Option<WebSocketConfig>,
423 ) -> Result<WsSession<MaybeTlsStream<TcpStream>>, io::Error> {
424 let stream = self.stream;
425 set_nonblocking(&stream, false)?;
426 let ws = if let Some(callback) = callback {
427 match accept_hdr_with_config(stream, callback, config) {
428 Ok(v) => v,
429 Err(err) => {
430 return Err(io::Error::new(
431 ErrorKind::Other,
432 format!("HandshakeError: {err:?}"),
433 ))
434 }
435 }
436 } else {
437 match accept_with_config(stream, config) {
438 Ok(v) => v,
439 Err(err) => {
440 return Err(io::Error::new(
441 ErrorKind::Other,
442 format!("HandshakeError: {err:?}"),
443 ))
444 }
445 }
446 };
447 let session = WsSession::new(ws);
448 session.set_nonblocking(self.nonblocking)?;
449 return Ok(session);
450 }
451}
452
453/// A TCP Server that produces [`UninitializedWsSession`] as output.
454pub struct WsServer {
455 server: TcpListener,
456 tls: Tls,
457 nonblocking_sessions: bool,
458}
459impl WsServer {
460 /// Wrap the given TcpListener
461 pub fn new(server: TcpListener) -> Self {
462 Self {
463 server,
464 tls: Tls::None,
465 nonblocking_sessions: false,
466 }
467 }
468
469 /// Bind to the given socket address
470 pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self, io::Error> {
471 let server = TcpListener::bind(addr)?;
472 Ok(Self {
473 server,
474 tls: Tls::None,
475 nonblocking_sessions: false,
476 })
477 }
478}
479impl WsServer {
480 /// Builder pattern, set the TLS mode to use
481 pub fn with_tls(mut self, tls: Tls) -> Self {
482 self.tls = tls;
483 self
484 }
485 /// Builder pattern, configure the nonblocking status for the underlying [`TcpListener`]
486 pub fn with_nonblocking_server(self, nonblocking: bool) -> Result<Self, io::Error> {
487 self.server.set_nonblocking(nonblocking)?;
488 Ok(self)
489 }
490 /// Builder pattern, configure the default nonblocking status for produced [`WsSessions`] structs.
491 pub fn with_nonblocking_sessions(mut self, nonblocking_sessions: bool) -> Self {
492 self.nonblocking_sessions = nonblocking_sessions;
493 self
494 }
495}
496impl Service for WsServer {
497 type Input = ();
498 type Output = UninitializedWsSession;
499 type Error = io::Error;
500 fn process(&self, _: ()) -> Result<Self::Output, Self::Error> {
501 match self.server.accept() {
502 Ok((stream, _)) => {
503 #[cfg(not(feature = "native-tls"))]
504 let stream = match self.tls {
505 Tls::None => MaybeTlsStream::Plain(stream),
506 #[cfg(feature = "native-tls")]
507 Tls::Native => MaybeTlsStream::NativeTls(stream),
508 #[cfg(feature = "__rustls-tls")]
509 Tls::Rustls => MaybeTlsStream::Rustls(stream),
510 };
511 Ok(UninitializedWsSession::new(
512 stream,
513 self.nonblocking_sessions,
514 ))
515 }
516 Err(err) => Err(err),
517 }
518 }
519}
520impl Retryable<(), io::Error> for WsServer {
521 fn parse_retry(&self, err: io::Error) -> Result<(), RetryError<io::Error>> {
522 match &err.kind() {
523 ErrorKind::WouldBlock => Ok(()),
524 _ => Err(RetryError::ServiceError(err)),
525 }
526 }
527}
528
529fn set_nonblocking(stream: &MaybeTlsStream<TcpStream>, nonblocking: bool) -> Result<(), io::Error> {
530 match stream {
531 MaybeTlsStream::Plain(stream) => stream.set_nonblocking(nonblocking),
532 #[cfg(feature = "native-tls")]
533 MaybeTlsStream::NativeTls(stream) => stream.set_nonblocking(nonblocking),
534 #[cfg(feature = "__rustls-tls")]
535 MaybeTlsStream::Rustls(stream) => stream.set_nonblocking(nonblocking),
536 _ => return Err(io::Error::new(ErrorKind::Other, "unrecognized stream type")),
537 }
538}