trillium_websockets/
json.rs

1/*!
2# websocket json adapter
3
4See the documentation for [`JsonWebSocketHandler`]
5*/
6
7use crate::{WebSocket, WebSocketConn, WebSocketHandler};
8use async_tungstenite::tungstenite::{protocol::CloseFrame, Message};
9use futures_lite::{ready, Stream};
10use serde::{de::DeserializeOwned, Serialize};
11use std::{
12    fmt::Debug,
13    ops::{Deref, DerefMut},
14    pin::Pin,
15    task::{Context, Poll},
16};
17use trillium::async_trait;
18
19/**
20# Implement this trait to use websockets with a json handler
21
22JsonWebSocketHandler provides a small layer of abstraction on top of
23[`WebSocketHandler`], serializing and deserializing messages for
24you. This may eventually move to a crate of its own.
25
26## ℹ️ In order to use this trait, the `json` crate feature must be enabled.
27
28```
29use async_channel::{unbounded, Receiver, Sender};
30use serde::{Deserialize, Serialize};
31use std::pin::Pin;
32use trillium::{async_trait, log_error};
33use trillium_websockets::{json_websocket, JsonWebSocketHandler, WebSocketConn, Result};
34
35#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
36struct Response {
37    inbound_message: Inbound,
38}
39
40#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
41struct Inbound {
42    message: String,
43}
44
45struct SomeJsonChannel;
46
47#[async_trait]
48impl JsonWebSocketHandler for SomeJsonChannel {
49    type InboundMessage = Inbound;
50    type OutboundMessage = Response;
51    type StreamType = Pin<Box<Receiver<Self::OutboundMessage>>>;
52
53    async fn connect(&self, conn: &mut WebSocketConn) -> Self::StreamType {
54        let (s, r) = unbounded();
55        conn.insert_state(s);
56        Box::pin(r)
57    }
58
59    async fn receive_message(
60        &self,
61        inbound_message: Result<Self::InboundMessage>,
62        conn: &mut WebSocketConn,
63    ) {
64        if let Ok(inbound_message) = inbound_message {
65            log_error!(
66                conn.state::<Sender<Response>>()
67                    .unwrap()
68                    .send(Response { inbound_message })
69                    .await
70            );
71        }
72    }
73}
74
75// fn main() {
76//    trillium_smol::run(json_websocket(SomeJsonChannel));
77// }
78```
79
80*/
81#[allow(unused_variables)]
82#[async_trait]
83pub trait JsonWebSocketHandler: Send + Sync + 'static {
84    /**
85    A type that can be deserialized from the json sent from the
86    connected clients
87    */
88    type InboundMessage: DeserializeOwned + Send + 'static;
89
90    /**
91    A serializable type that will be sent in the StreamType and
92    received by the connected websocket clients
93    */
94    type OutboundMessage: Serialize + Send + 'static;
95
96    /**
97    A type that implements a stream of
98    [`Self::OutboundMessage`]s. This can be
99    futures_lite::stream::Pending if you never need to send an
100    outbound message.
101    */
102    type StreamType: Stream<Item = Self::OutboundMessage> + Send + Sync + 'static;
103
104    /**
105    `connect` is called once for each upgraded websocket
106    connection, and returns a Self::StreamType.
107    */
108    async fn connect(&self, conn: &mut WebSocketConn) -> Self::StreamType;
109
110    /**
111    `receive_message` is called once for each successfully deserialized
112    InboundMessage along with the websocket conn that it was received
113    from.
114    */
115    async fn receive_message(
116        &self,
117        message: crate::Result<Self::InboundMessage>,
118        conn: &mut WebSocketConn,
119    );
120
121    /**
122    `disconnect` is called when websocket clients disconnect, along
123    with a CloseFrame, if one was provided. Implementing `disconnect`
124    is optional.
125    */
126    async fn disconnect(&self, conn: &mut WebSocketConn, close_frame: Option<CloseFrame<'static>>) {
127    }
128}
129
/**
The adapter that wraps a [`JsonWebSocketHandler`]

This type rarely needs to be named or built by hand; obtain one
through [`WebSocket::new_json`] or [`json_websocket`] instead.
*/
pub struct JsonHandler<T> {
    // the wrapped user-provided handler
    pub(crate) handler: T,
}
139
140impl<T> Deref for JsonHandler<T> {
141    type Target = T;
142
143    fn deref(&self) -> &Self::Target {
144        &self.handler
145    }
146}
147
148impl<T> DerefMut for JsonHandler<T> {
149    fn deref_mut(&mut self) -> &mut Self::Target {
150        &mut self.handler
151    }
152}
153
154impl<T: JsonWebSocketHandler> JsonHandler<T> {
155    pub(crate) fn new(handler: T) -> Self {
156        Self { handler }
157    }
158}
159
160impl<T> Debug for JsonHandler<T> {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        f.debug_struct("JsonWebSocketHandler").finish()
163    }
164}
165
166pin_project_lite::pin_project! {
167    /**
168    A stream for internal use that attempts to serialize the items in the
169    wrapped stream to a [`Message::Text`]
170     */
171    #[derive(Debug)]
172    pub struct SerializedStream<T> {
173        #[pin] inner: T
174    }
175}
176
177impl<T> Stream for SerializedStream<T>
178where
179    T: Stream,
180    T::Item: Serialize,
181{
182    type Item = Message;
183
184    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
185        Poll::Ready(
186            ready!(self.project().inner.poll_next(cx))
187                .and_then(|i| match serde_json::to_string(&i) {
188                    Ok(j) => Some(j),
189                    Err(e) => {
190                        log::error!("serialization error: {e}");
191                        None
192                    }
193                })
194                .map(Message::Text),
195        )
196    }
197}
198
199#[async_trait]
200impl<T> WebSocketHandler for JsonHandler<T>
201where
202    T: JsonWebSocketHandler,
203{
204    type OutboundStream = SerializedStream<T::StreamType>;
205
206    async fn connect(
207        &self,
208        mut conn: WebSocketConn,
209    ) -> Option<(WebSocketConn, Self::OutboundStream)> {
210        let stream = SerializedStream {
211            inner: self.handler.connect(&mut conn).await,
212        };
213        Some((conn, stream))
214    }
215
216    async fn inbound(&self, message: Message, conn: &mut WebSocketConn) {
217        self.handler
218            .receive_message(
219                message
220                    .to_text()
221                    .map_err(Into::into)
222                    .and_then(|m| serde_json::from_str(m).map_err(Into::into)),
223                conn,
224            )
225            .await;
226    }
227
228    async fn disconnect(&self, conn: &mut WebSocketConn, close_frame: Option<CloseFrame<'static>>) {
229        self.handler.disconnect(conn, close_frame).await
230    }
231}
232
233impl<T> WebSocket<JsonHandler<T>>
234where
235    T: JsonWebSocketHandler,
236{
237    /**
238    Build a new trillium WebSocket handler from the provided
239    [`JsonWebSocketHandler`]
240     */
241    pub fn new_json(handler: T) -> Self {
242        Self::new(JsonHandler::new(handler))
243    }
244}
245
246/**
247builds a new trillium handler from the provided
248[`JsonWebSocketHandler`]. Alias for [`WebSocket::new_json`]
249*/
250pub fn json_websocket<T>(json_websocket_handler: T) -> WebSocket<JsonHandler<T>>
251where
252    T: JsonWebSocketHandler,
253{
254    WebSocket::new_json(json_websocket_handler)
255}