Skip to main content

titan_api_codec/codec/ws/
v1.rs

1//! Codecs for version 1 of the Titan WebSocket API.
2
3use std::str::FromStr;
4
5use titan_api_types::ws::v1;
6
7use crate::{
8    codec::{Codec, CodecLoadError, TypedDecoder, TypedEncoder, WrappedDecoder, WrappedEncoder},
9    dec::{messagepack::MessagePackDecoder, DecodeError, Decoder},
10    enc::{messagepack::MessagePackEncoder, EncodeError, Encoder},
11    transform::{
12        BrotliCompressor, BrotliDecompressor, GzipCompressor, GzipDecompressor, ZstdCompressor,
13        ZstdDecompressor,
14    },
15};
16
/// Wire codec used by clients of version 1 of the Titan WebSocket API.
///
/// Every variant serializes messages with MessagePack; the variants differ only in the
/// compression applied on top of that. Compressed variants are only present when the
/// corresponding cargo feature (`zstd`, `brotli`, `gzip`) is enabled.
#[non_exhaustive]
#[derive(Debug, Default, PartialEq, Eq)]
pub enum ClientCodec {
    /// MessagePack payloads sent as-is, with no compression. This is the default.
    #[default]
    Uncompressed,

    /// MessagePack payloads compressed with zstd.
    #[cfg(feature = "zstd")]
    Zstd,

    /// MessagePack payloads compressed with brotli.
    #[cfg(feature = "brotli")]
    Brotli,

    /// MessagePack payloads compressed with gzip.
    #[cfg(feature = "gzip")]
    Gzip,
}
34
/// Wire codec used by servers that provide version 1 of the Titan WebSocket API.
///
/// Every variant serializes messages with MessagePack; the variants differ only in the
/// compression applied on top of that. Compressed variants are only present when the
/// corresponding cargo feature (`zstd`, `brotli`, `gzip`) is enabled.
#[non_exhaustive]
#[derive(Debug, Default, PartialEq, Eq)]
pub enum ServerCodec {
    /// MessagePack payloads sent as-is, with no compression. This is the default.
    #[default]
    Uncompressed,

    /// MessagePack payloads compressed with zstd.
    #[cfg(feature = "zstd")]
    Zstd,

    /// MessagePack payloads compressed with brotli.
    #[cfg(feature = "brotli")]
    Brotli,

    /// MessagePack payloads compressed with gzip.
    #[cfg(feature = "gzip")]
    Gzip,
}
52
53impl FromStr for ClientCodec {
54    type Err = CodecLoadError;
55
56    /// Attempts to load a [`ClientCodec`] from a WebSocket subprotocol string.
57    fn from_str(s: &str) -> Result<Self, Self::Err> {
58        match s {
59            v1::WEBSOCKET_SUBPROTO_BASE => Ok(Self::Uncompressed),
60            #[cfg(feature = "zstd")]
61            v1::WEBSOCKET_SUBPROTO_ZSTD => Ok(Self::Zstd),
62            #[cfg(not(feature = "zstd"))]
63            v1::WEBSOCKET_SUBPROTO_ZSTD => Err(CodecLoadError::DisabledEncoding("zstd")),
64            #[cfg(feature = "brotli")]
65            v1::WEBSOCKET_SUBPROTO_BROTLI => Ok(Self::Brotli),
66            #[cfg(not(feature = "brotli"))]
67            v1::WEBSOCKET_SUBPROTO_ZSTD => Err(CodecLoadError::DisabledEncoding("brotli")),
68            #[cfg(feature = "gzip")]
69            v1::WEBSOCKET_SUBPROTO_GZIP => Ok(Self::Gzip),
70            #[cfg(not(feature = "gzip"))]
71            v1::WEBSOCKET_SUBPROTO_ZSTD => Err(CodecLoadError::DisabledEncoding("gzip")),
72            _ => Err(CodecLoadError::UnsupportedProtocol(s.into())),
73        }
74    }
75}
76
77impl Codec for ClientCodec {
78    type SendItem = v1::ClientRequest;
79    type SendError = EncodeError;
80    type RecvItem = v1::ServerMessage;
81    type RecvError = DecodeError;
82
83    fn encoder(
84        &self,
85    ) -> Box<dyn TypedEncoder<Self::SendItem, Error = Self::SendError> + Send + Sync> {
86        match self {
87            Self::Uncompressed => Box::new(WrappedEncoder::new(MessagePackEncoder::default())),
88            #[cfg(feature = "zstd")]
89            Self::Zstd => Box::new(WrappedEncoder::new(
90                MessagePackEncoder::default().transform(ZstdCompressor::default()),
91            )),
92            #[cfg(feature = "brotli")]
93            Self::Brotli => Box::new(WrappedEncoder::new(
94                MessagePackEncoder::default().transform(BrotliCompressor::default()),
95            )),
96            #[cfg(feature = "gzip")]
97            Self::Gzip => Box::new(WrappedEncoder::new(
98                MessagePackEncoder::default().transform(GzipCompressor::default()),
99            )),
100        }
101    }
102
103    fn decoder(
104        &self,
105    ) -> Box<dyn TypedDecoder<Item = Self::RecvItem, Error = Self::RecvError> + Send + Sync> {
106        match self {
107            Self::Uncompressed => Box::new(WrappedDecoder::new(MessagePackDecoder::default())),
108            #[cfg(feature = "zstd")]
109            Self::Zstd => Box::new(WrappedDecoder::new(
110                MessagePackDecoder::default().transformed(ZstdDecompressor::default()),
111            )),
112            #[cfg(feature = "brotli")]
113            Self::Brotli => Box::new(WrappedDecoder::new(
114                MessagePackDecoder::default().transformed(BrotliDecompressor::default()),
115            )),
116            #[cfg(feature = "gzip")]
117            Self::Gzip => Box::new(WrappedDecoder::new(
118                MessagePackDecoder::default().transformed(GzipDecompressor::default()),
119            )),
120        }
121    }
122}
123
124impl FromStr for ServerCodec {
125    type Err = CodecLoadError;
126
127    /// Attempts to load a [`ServerCodec`] from a WebSocket subprotocol string.
128    fn from_str(s: &str) -> Result<Self, Self::Err> {
129        match s {
130            v1::WEBSOCKET_SUBPROTO_BASE => Ok(Self::Uncompressed),
131            #[cfg(feature = "zstd")]
132            v1::WEBSOCKET_SUBPROTO_ZSTD => Ok(Self::Zstd),
133            #[cfg(not(feature = "zstd"))]
134            v1::WEBSOCKET_SUBPROTO_ZSTD => Err(CodecLoadError::DisabledEncoding("zstd")),
135            #[cfg(feature = "brotli")]
136            v1::WEBSOCKET_SUBPROTO_BROTLI => Ok(Self::Brotli),
137            #[cfg(not(feature = "brotli"))]
138            v1::WEBSOCKET_SUBPROTO_ZSTD => Err(CodecLoadError::DisabledEncoding("brotli")),
139            #[cfg(feature = "gzip")]
140            v1::WEBSOCKET_SUBPROTO_GZIP => Ok(Self::Gzip),
141            #[cfg(not(feature = "gzip"))]
142            v1::WEBSOCKET_SUBPROTO_ZSTD => Err(CodecLoadError::DisabledEncoding("gzip")),
143            _ => Err(CodecLoadError::UnsupportedProtocol(s.into())),
144        }
145    }
146}
147
148impl Codec for ServerCodec {
149    type SendItem = v1::ServerMessage;
150    type SendError = EncodeError;
151    type RecvItem = v1::ClientRequest;
152    type RecvError = DecodeError;
153
154    fn encoder(
155        &self,
156    ) -> Box<dyn TypedEncoder<Self::SendItem, Error = Self::SendError> + Send + Sync> {
157        match self {
158            Self::Uncompressed => Box::new(WrappedEncoder::new(MessagePackEncoder::default())),
159            #[cfg(feature = "zstd")]
160            Self::Zstd => Box::new(WrappedEncoder::new(
161                MessagePackEncoder::default().transform(ZstdCompressor::default()),
162            )),
163            #[cfg(feature = "brotli")]
164            Self::Brotli => Box::new(WrappedEncoder::new(
165                MessagePackEncoder::default().transform(BrotliCompressor::default()),
166            )),
167            #[cfg(feature = "gzip")]
168            Self::Gzip => Box::new(WrappedEncoder::new(
169                MessagePackEncoder::default().transform(GzipCompressor::default()),
170            )),
171        }
172    }
173
174    fn decoder(
175        &self,
176    ) -> Box<dyn TypedDecoder<Item = Self::RecvItem, Error = Self::RecvError> + Send + Sync> {
177        match self {
178            Self::Uncompressed => Box::new(WrappedDecoder::new(MessagePackDecoder::default())),
179            #[cfg(feature = "zstd")]
180            Self::Zstd => Box::new(WrappedDecoder::new(
181                MessagePackDecoder::default().transformed(ZstdDecompressor::default()),
182            )),
183            #[cfg(feature = "brotli")]
184            Self::Brotli => Box::new(WrappedDecoder::new(
185                MessagePackDecoder::default().transformed(BrotliDecompressor::default()),
186            )),
187            #[cfg(feature = "gzip")]
188            Self::Gzip => Box::new(WrappedDecoder::new(
189                MessagePackDecoder::default().transformed(GzipDecompressor::default()),
190            )),
191        }
192    }
193}
194
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use titan_api_types::ws::v1;

    use crate::codec::{Codec, CodecLoadError};

    use super::{ClientCodec, ServerCodec};

    #[test]
    fn construct_client_codec_base() {
        let codec = ClientCodec::from_str(v1::WEBSOCKET_SUBPROTO_BASE)
            .expect("should construct base codec");
        assert_eq!(codec, ClientCodec::Uncompressed);
    }

    #[cfg(feature = "zstd")]
    #[test]
    fn construct_client_codec_zstd() {
        let codec = ClientCodec::from_str(v1::WEBSOCKET_SUBPROTO_ZSTD)
            .expect("should construct zstd codec");
        assert_eq!(codec, ClientCodec::Zstd);
    }

    #[cfg(feature = "brotli")]
    #[test]
    fn construct_client_codec_brotli() {
        let codec = ClientCodec::from_str(v1::WEBSOCKET_SUBPROTO_BROTLI)
            .expect("should construct brotli codec");
        assert_eq!(codec, ClientCodec::Brotli);
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn construct_client_codec_gzip() {
        let codec = ClientCodec::from_str(v1::WEBSOCKET_SUBPROTO_GZIP)
            .expect("should construct gzip codec");
        assert_eq!(codec, ClientCodec::Gzip);
    }

    #[test]
    fn construct_client_codec_invalid() {
        let err =
            ClientCodec::from_str("invalid").expect_err("should have errored on invalid protocol");
        assert_eq!(err, CodecLoadError::UnsupportedProtocol("invalid".into()))
    }

    #[test]
    fn construct_server_codec_base() {
        let codec = ServerCodec::from_str(v1::WEBSOCKET_SUBPROTO_BASE)
            .expect("should construct base codec");
        assert_eq!(codec, ServerCodec::Uncompressed);
    }

    #[cfg(feature = "zstd")]
    #[test]
    fn construct_server_codec_zstd() {
        let codec = ServerCodec::from_str(v1::WEBSOCKET_SUBPROTO_ZSTD)
            .expect("should construct zstd codec");
        assert_eq!(codec, ServerCodec::Zstd);
    }

    #[cfg(feature = "brotli")]
    #[test]
    fn construct_server_codec_brotli() {
        let codec = ServerCodec::from_str(v1::WEBSOCKET_SUBPROTO_BROTLI)
            .expect("should construct brotli codec");
        assert_eq!(codec, ServerCodec::Brotli);
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn construct_server_codec_gzip() {
        let codec = ServerCodec::from_str(v1::WEBSOCKET_SUBPROTO_GZIP)
            .expect("should construct gzip codec");
        assert_eq!(codec, ServerCodec::Gzip);
    }

    #[test]
    fn construct_server_codec_invalid() {
        let err =
            ServerCodec::from_str("invalid").expect_err("should have errored on invalid protocol");
        assert_eq!(err, CodecLoadError::UnsupportedProtocol("invalid".into()))
    }

    /// Asserts that both codec types reject the known subprotocol `proto` with
    /// `DisabledEncoding(name)` because its compression feature is compiled out.
    // Only compiled when at least one compression feature is disabled, so it is never dead.
    #[cfg(any(not(feature = "zstd"), not(feature = "brotli"), not(feature = "gzip")))]
    fn assert_disabled_encoding(proto: &str, name: &'static str) {
        let err = ClientCodec::from_str(proto)
            .expect_err("client codec should reject a disabled encoding");
        assert_eq!(err, CodecLoadError::DisabledEncoding(name));

        let err = ServerCodec::from_str(proto)
            .expect_err("server codec should reject a disabled encoding");
        assert_eq!(err, CodecLoadError::DisabledEncoding(name));
    }

    #[cfg(not(feature = "zstd"))]
    #[test]
    fn construct_codec_zstd_disabled() {
        assert_disabled_encoding(v1::WEBSOCKET_SUBPROTO_ZSTD, "zstd");
    }

    #[cfg(not(feature = "brotli"))]
    #[test]
    fn construct_codec_brotli_disabled() {
        assert_disabled_encoding(v1::WEBSOCKET_SUBPROTO_BROTLI, "brotli");
    }

    #[cfg(not(feature = "gzip"))]
    #[test]
    fn construct_codec_gzip_disabled() {
        assert_disabled_encoding(v1::WEBSOCKET_SUBPROTO_GZIP, "gzip");
    }

    /// Encodes a sample request with `client_codec` and checks that `server_codec`
    /// decodes it back to an equal value.
    fn assert_request_roundtrip(client_codec: ClientCodec, server_codec: ServerCodec) {
        let mut encoder = client_codec.encoder();
        let mut decoder = server_codec.decoder();

        let request = v1::ClientRequest {
            id: 1,
            data: v1::RequestData::GetInfo(v1::GetInfoRequest::default()),
        };
        let encoded = encoder
            .encode_mut(&request)
            .expect("should encode client request");
        let decoded = decoder
            .decode_mut(encoded)
            .expect("should decode client request");

        assert_eq!(request, decoded);
    }

    #[test]
    fn roundtrip_base() {
        assert_request_roundtrip(ClientCodec::Uncompressed, ServerCodec::Uncompressed);
    }

    #[cfg(feature = "zstd")]
    #[test]
    fn roundtrip_zstd() {
        assert_request_roundtrip(ClientCodec::Zstd, ServerCodec::Zstd);
    }

    #[cfg(feature = "brotli")]
    #[test]
    fn roundtrip_brotli() {
        assert_request_roundtrip(ClientCodec::Brotli, ServerCodec::Brotli);
    }

    #[cfg(feature = "gzip")]
    #[test]
    fn roundtrip_gzip() {
        assert_request_roundtrip(ClientCodec::Gzip, ServerCodec::Gzip);
    }
}