typed_websocket/
lib.rs

1use futures_util::{SinkExt, StreamExt};
2use serde::{Deserialize, Serialize};
3use std::error::Error;
4use tokio_tungstenite::tungstenite::protocol::Message;
5use tokio_tungstenite::WebSocketStream;
6
7/// Generic WebSocket wrapper for typed OUTPUT and INPUT messages
8pub struct TypedWebSocketStream<S, INPUT, OUTPUT>
9where
10    INPUT: Serialize,
11    OUTPUT: for<'de> Deserialize<'de>,
12{
13    stream: WebSocketStream<S>,
14    _marker_in: std::marker::PhantomData<INPUT>,
15    _marker_out: std::marker::PhantomData<OUTPUT>,
16}
17
18impl<S, OUTPUT, INPUT> TypedWebSocketStream<S, INPUT, OUTPUT>
19where
20    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
21    INPUT: Serialize,
22    OUTPUT: for<'de> Deserialize<'de>,
23{
24    /// Create a new TypedWebSocketStream
25    pub fn new(stream: WebSocketStream<S>) -> Self {
26        Self {
27            stream,
28            _marker_in: std::marker::PhantomData,
29            _marker_out: std::marker::PhantomData,
30        }
31    }
32
33    /// Send a strongly-typed message
34    pub async fn send(&mut self, message: INPUT) -> Result<(), Box<dyn Error>> {
35        let json = serde_json::to_string(&message)?; // Serialize the message to JSON
36        self.stream.send(Message::Text(json)).await?; // Send as WebSocket text message
37        Ok(())
38    }
39
40    /// Receive a strongly-typed message
41    pub async fn receive(&mut self) -> Result<OUTPUT, Box<dyn Error>> {
42        if let Some(Ok(Message::Text(json))) = self.stream.next().await {
43            let message: OUTPUT = serde_json::from_str(&json)?; // Deserialize the JSON
44            Ok(message)
45        } else {
46            Err("Failed to receive valid text message".into())
47        }
48    }
49
50    /// Close the WebSocket connection
51    pub async fn close(&mut self) -> Result<(), Box<dyn Error>> {
52        self.stream.send(Message::Close(None)).await?;
53        Ok(())
54    }
55}