Skip to main content

rustrade_integration/serde/de/
mod.rs

1use crate::serde::de::error::{DeBinaryError, DeBinaryErrorKind};
2use tracing::debug;
3
4/// Deserialisation error variants.
5pub mod error;
6
7/// Deserialisation utilities.
8mod util;
9
10// Re-export deserialisation utilities.
11pub use util::*;
12
/// Trait for types that can deserialise input payload data (e.g. strings, bytes, etc.) into
/// structured Rust types.
///
/// `deserialise` is an associated function (it takes no `self`), so implementors act as
/// type-level selectors of a wire format, e.g. `<DeJson as Deserialiser<_, _>>::deserialise(..)`.
pub trait Deserialiser<Input, Output> {
    /// Error produced when `Input` cannot be turned into `Output`.
    type Error;

    /// Deserialises the `Input` into the `Output`.
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] when the implementation fails to deserialise `input`.
    fn deserialise(input: Input) -> Result<Output, Self::Error>;
}
21
/// JSON deserialiser.
///
/// A stateless unit marker type; it derives the common value traits so it can be freely
/// copied, compared and stored like any other plain value.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DeJson;
25
26impl<'a, Output> Deserialiser<&'a [u8], Output> for DeJson
27where
28    Output: serde::Deserialize<'a> + 'a,
29{
30    type Error = DeBinaryError;
31
32    fn deserialise(input: &'a [u8]) -> Result<Output, Self::Error> {
33        Self::de_bytes(input)
34    }
35}
36
37impl<Output> Deserialiser<bytes::Bytes, Output> for DeJson
38where
39    Output: for<'a> serde::Deserialize<'a>,
40{
41    type Error = DeBinaryError;
42
43    fn deserialise(input: bytes::Bytes) -> Result<Output, Self::Error> {
44        Self::de_bytes(input.as_ref())
45    }
46}
47
48impl DeJson {
49    /// Deserialises a byte slice into the target `Output` type using [`serde_json`].
50    pub fn de_bytes<'a, Output>(input: &'a [u8]) -> Result<Output, DeBinaryError>
51    where
52        Output: serde::Deserialize<'a> + 'a,
53    {
54        serde_json::from_slice::<Output>(input).map_err(|error| {
55            let input_str = std::str::from_utf8(input).unwrap_or("<invalid UTF-8>");
56
57            debug!(
58                %error,
59                ?input,
60                %input_str,
61                input_type = "&[u8]",
62                target_type = %std::any::type_name::<Output>(),
63                "failed to deserialise via SerDe"
64            );
65
66            DeBinaryError {
67                payload: input.to_vec(),
68                kind: DeBinaryErrorKind::Serde(error),
69            }
70        })
71    }
72}
73
/// Protobuf deserialiser.
///
/// A stateless unit marker type; it derives the common value traits so it can be freely
/// copied, compared and stored like any other plain value.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DeProtobuf;
77
78impl<'a, Output> Deserialiser<&'a [u8], Output> for DeProtobuf
79where
80    Output: prost::Message + Default,
81{
82    type Error = DeBinaryError;
83
84    fn deserialise(input: &'a [u8]) -> Result<Output, Self::Error> {
85        Self::decode_bytes(input)
86    }
87}
88
89impl<Output> Deserialiser<bytes::Bytes, Output> for DeProtobuf
90where
91    Output: prost::Message + Default,
92{
93    type Error = DeBinaryError;
94
95    fn deserialise(input: bytes::Bytes) -> Result<Output, Self::Error> {
96        Self::decode_bytes(input.as_ref())
97    }
98}
99
100impl DeProtobuf {
101    /// Decodes a byte slice into the target `Output` type using [`prost`].
102    pub fn decode_bytes<Output>(input: &[u8]) -> Result<Output, DeBinaryError>
103    where
104        Output: prost::Message + Default,
105    {
106        Output::decode(input).map_err(|error| {
107            debug!(
108                %error,
109                ?input,
110                target_type = %std::any::type_name::<Output>(),
111                input_type = "&[u8]",
112                "failed to deserialise via prost::Message"
113            );
114
115            DeBinaryError {
116                payload: input.to_vec(),
117                kind: DeBinaryErrorKind::Proto(error),
118            }
119        })
120    }
121}
122
#[cfg(test)]
#[allow(clippy::unwrap_used)] // Test code: panics on bad input are acceptable
mod tests {
    use super::*;

    #[test]
    fn test_de_json_de_bytes_valid() {
        let input = br#"{"a":1,"b":"hello"}"#;
        let result: std::collections::HashMap<String, serde_json::Value> =
            DeJson::de_bytes(input).unwrap();
        assert_eq!(result["a"], serde_json::json!(1));
        assert_eq!(result["b"], serde_json::json!("hello"));
    }

    #[test]
    fn test_de_json_de_bytes_invalid_json() {
        let input = b"not valid json";
        let result = DeJson::de_bytes::<serde_json::Value>(input);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.payload, input.to_vec());
        assert!(matches!(err.kind, DeBinaryErrorKind::Serde(_)));
    }

    #[test]
    fn test_de_json_de_bytes_invalid_utf8() {
        // Non-UTF-8 input must still produce a Serde error carrying the exact payload.
        let input: &[u8] = &[0xff, 0xfe];
        let err = DeJson::de_bytes::<serde_json::Value>(input).unwrap_err();
        assert_eq!(err.payload, input.to_vec());
        assert!(matches!(err.kind, DeBinaryErrorKind::Serde(_)));
    }

    #[test]
    fn test_de_json_de_bytes_borrowed_output() {
        // The `Deserialize<'a> + 'a` bound lets `Output` borrow from the input slice.
        let input = br#""hello""#;
        let result: &str = DeJson::de_bytes(input).unwrap();
        assert_eq!(result, "hello");
    }

    #[test]
    fn test_de_json_deserialiser_trait_slice() {
        let input: &[u8] = br#"[1,2,3]"#;
        let result: Vec<u8> = <DeJson as Deserialiser<&[u8], _>>::deserialise(input).unwrap();
        assert_eq!(result, vec![1, 2, 3]);
    }

    #[test]
    fn test_de_json_deserialiser_trait_bytes() {
        let input = bytes::Bytes::from(r#"42"#);
        let result: u64 = <DeJson as Deserialiser<bytes::Bytes, _>>::deserialise(input).unwrap();
        assert_eq!(result, 42);
    }

    #[test]
    fn test_de_protobuf_decode_bytes_valid() {
        // Field 1, wire type 0 (varint), value 42 — prost's `Message` impl for `u64`
        // (the `UInt64Value` wrapper) stores its value in field 1.
        let input: &[u8] = &[0x08, 0x2a];
        let result: u64 = DeProtobuf::decode_bytes(input).unwrap();
        assert_eq!(result, 42);
    }

    #[test]
    fn test_de_protobuf_decode_bytes_invalid() {
        // A lone byte with the continuation bit set is a truncated varint field key.
        let input: &[u8] = &[0xff];
        let err = DeProtobuf::decode_bytes::<u64>(input).unwrap_err();
        assert_eq!(err.payload, input.to_vec());
        assert!(matches!(err.kind, DeBinaryErrorKind::Proto(_)));
    }

    #[test]
    fn test_de_protobuf_deserialiser_trait_slice() {
        let input: &[u8] = &[0x08, 0x2a];
        let result: u64 =
            <DeProtobuf as Deserialiser<&[u8], _>>::deserialise(input).unwrap();
        assert_eq!(result, 42);
    }

    #[test]
    fn test_de_protobuf_deserialiser_trait_bytes() {
        let input = bytes::Bytes::from_static(&[0x08, 0x2a]);
        let result: u64 =
            <DeProtobuf as Deserialiser<bytes::Bytes, _>>::deserialise(input).unwrap();
        assert_eq!(result, 42);
    }
}