1use futures_util::{SinkExt, StreamExt};
2use serde::{Deserialize, Serialize};
3use std::error::Error;
4use tokio_tungstenite::tungstenite::protocol::Message;
5use tokio_tungstenite::WebSocketStream;
6
7pub struct TypedWebSocketStream<S, INPUT, OUTPUT>
9where
10 INPUT: Serialize,
11 OUTPUT: for<'de> Deserialize<'de>,
12{
13 stream: WebSocketStream<S>,
14 _marker_in: std::marker::PhantomData<INPUT>,
15 _marker_out: std::marker::PhantomData<OUTPUT>,
16}
17
18impl<S, OUTPUT, INPUT> TypedWebSocketStream<S, INPUT, OUTPUT>
19where
20 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
21 INPUT: Serialize,
22 OUTPUT: for<'de> Deserialize<'de>,
23{
24 pub fn new(stream: WebSocketStream<S>) -> Self {
26 Self {
27 stream,
28 _marker_in: std::marker::PhantomData,
29 _marker_out: std::marker::PhantomData,
30 }
31 }
32
33 pub async fn send(&mut self, message: INPUT) -> Result<(), Box<dyn Error>> {
35 let json = serde_json::to_string(&message)?; self.stream.send(Message::Text(json)).await?; Ok(())
38 }
39
40 pub async fn receive(&mut self) -> Result<OUTPUT, Box<dyn Error>> {
42 if let Some(Ok(Message::Text(json))) = self.stream.next().await {
43 let message: OUTPUT = serde_json::from_str(&json)?; Ok(message)
45 } else {
46 Err("Failed to receive valid text message".into())
47 }
48 }
49
50 pub async fn close(&mut self) -> Result<(), Box<dyn Error>> {
52 self.stream.send(Message::Close(None)).await?;
53 Ok(())
54 }
55}