spacetimedb_sats/bsatn/
ser.rs

1use crate::buffer::BufWriter;
2use crate::ser::{self, Error, ForwardNamedToSeqProduct, SerializeArray, SerializeSeqProduct};
3use crate::AlgebraicValue;
4use crate::{i256, u256};
5use core::fmt;
6
7/// Defines the BSATN serialization data format.
8pub struct Serializer<'a, W> {
9    writer: &'a mut W,
10}
11
12impl<'a, W> Serializer<'a, W> {
13    /// Returns a serializer using the given `writer`.
14    pub fn new(writer: &'a mut W) -> Self {
15        Self { writer }
16    }
17
18    /// Reborrows the serializer.
19    #[inline]
20    fn reborrow(&mut self) -> Serializer<'_, W> {
21        Serializer { writer: self.writer }
22    }
23}
24
25impl<W: BufWriter> Serializer<'_, W> {
26    /// Directly write `bytes` to the writer.
27    /// This is a raw API. Only use this if you know what you are doing.
28    #[inline(always)]
29    #[doc(hidden)]
30    pub fn raw_write_bytes(self, bytes: &[u8]) {
31        self.writer.put_slice(bytes);
32    }
33}
34
/// Failure raised while encoding a value as BSATN.
///
/// The error carries only a free-form, human-readable message.
#[derive(Debug, Clone)]
// TODO: rename to EncodeError
pub struct BsatnError {
    /// Human-readable description of what went wrong.
    custom: String,
}

impl fmt::Display for BsatnError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Emit the stored message as-is via `write_str`.
        let Self { custom } = self;
        f.write_str(custom)
    }
}

impl std::error::Error for BsatnError {}
49
50impl Error for BsatnError {
51    fn custom<T: fmt::Display>(msg: T) -> Self {
52        let custom = msg.to_string();
53        Self { custom }
54    }
55}
56
57/// Writes `len` converted to a `u32` to `writer`.
58///
59/// Errors if `len` would not fit in a `u32`.
60fn put_len(writer: &mut impl BufWriter, len: usize) -> Result<(), BsatnError> {
61    let len = len.try_into().map_err(|_| BsatnError::custom("len too long"))?;
62    writer.put_u32(len);
63    Ok(())
64}
65
66impl<W: BufWriter> ser::Serializer for Serializer<'_, W> {
67    type Ok = ();
68    type Error = BsatnError;
69    type SerializeArray = Self;
70    type SerializeSeqProduct = Self;
71    type SerializeNamedProduct = ForwardNamedToSeqProduct<Self>;
72
73    fn serialize_bool(self, v: bool) -> Result<Self::Ok, Self::Error> {
74        self.writer.put_u8(v as u8);
75        Ok(())
76    }
77    fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> {
78        self.writer.put_u8(v);
79        Ok(())
80    }
81    fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> {
82        self.writer.put_u16(v);
83        Ok(())
84    }
85    fn serialize_u32(self, v: u32) -> Result<Self::Ok, Self::Error> {
86        self.writer.put_u32(v);
87        Ok(())
88    }
89    fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {
90        self.writer.put_u64(v);
91        Ok(())
92    }
93    fn serialize_u128(self, v: u128) -> Result<Self::Ok, Self::Error> {
94        self.writer.put_u128(v);
95        Ok(())
96    }
97    fn serialize_u256(self, v: u256) -> Result<Self::Ok, Self::Error> {
98        self.writer.put_u256(v);
99        Ok(())
100    }
101    fn serialize_i8(self, v: i8) -> Result<Self::Ok, Self::Error> {
102        self.writer.put_i8(v);
103        Ok(())
104    }
105    fn serialize_i16(self, v: i16) -> Result<Self::Ok, Self::Error> {
106        self.writer.put_i16(v);
107        Ok(())
108    }
109    fn serialize_i32(self, v: i32) -> Result<Self::Ok, Self::Error> {
110        self.writer.put_i32(v);
111        Ok(())
112    }
113    fn serialize_i64(self, v: i64) -> Result<Self::Ok, Self::Error> {
114        self.writer.put_i64(v);
115        Ok(())
116    }
117    fn serialize_i128(self, v: i128) -> Result<Self::Ok, Self::Error> {
118        self.writer.put_i128(v);
119        Ok(())
120    }
121    fn serialize_i256(self, v: i256) -> Result<Self::Ok, Self::Error> {
122        self.writer.put_i256(v);
123        Ok(())
124    }
125    fn serialize_f32(self, v: f32) -> Result<Self::Ok, Self::Error> {
126        self.writer.put_u32(v.to_bits());
127        Ok(())
128    }
129    fn serialize_f64(self, v: f64) -> Result<Self::Ok, Self::Error> {
130        self.writer.put_u64(v.to_bits());
131        Ok(())
132    }
133    fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
134        self.serialize_bytes(v.as_bytes())
135    }
136    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
137        put_len(self.writer, v.len())?; // N.B. `v.len() > u32::MAX` isn't allowed.
138        self.writer.put_slice(v);
139        Ok(())
140    }
141    fn serialize_array(self, len: usize) -> Result<Self::SerializeArray, Self::Error> {
142        put_len(self.writer, len)?; // N.B. `len > u32::MAX` isn't allowed.
143        Ok(self)
144    }
145    fn serialize_seq_product(self, _len: usize) -> Result<Self::SerializeSeqProduct, Self::Error> {
146        Ok(self)
147    }
148    fn serialize_named_product(self, len: usize) -> Result<Self::SerializeNamedProduct, Self::Error> {
149        // Serialize named like unnamed.
150        self.serialize_seq_product(len).map(ForwardNamedToSeqProduct::new)
151    }
152    fn serialize_variant<T: super::Serialize + ?Sized>(
153        self,
154        tag: u8,
155        _name: Option<&str>,
156        value: &T,
157    ) -> Result<Self::Ok, Self::Error> {
158        self.writer.put_u8(tag);
159        value.serialize(self)
160    }
161
162    unsafe fn serialize_bsatn(self, ty: &crate::AlgebraicType, bsatn: &[u8]) -> Result<Self::Ok, Self::Error> {
163        debug_assert!(AlgebraicValue::decode(ty, &mut { bsatn }).is_ok());
164        self.writer.put_slice(bsatn);
165        Ok(())
166    }
167
168    unsafe fn serialize_bsatn_in_chunks<'a, I: Clone + Iterator<Item = &'a [u8]>>(
169        self,
170        ty: &crate::AlgebraicType,
171        total_bsatn_len: usize,
172        bsatn: I,
173    ) -> Result<Self::Ok, Self::Error> {
174        debug_assert!(total_bsatn_len <= isize::MAX as usize);
175        debug_assert!(AlgebraicValue::decode(ty, &mut &*concat_bytes_slow(total_bsatn_len, bsatn.clone())).is_ok());
176
177        for chunk in bsatn {
178            self.writer.put_slice(chunk);
179        }
180        Ok(())
181    }
182
183    unsafe fn serialize_str_in_chunks<'a, I: Clone + Iterator<Item = &'a [u8]>>(
184        self,
185        total_len: usize,
186        string: I,
187    ) -> Result<Self::Ok, Self::Error> {
188        debug_assert!(total_len <= isize::MAX as usize);
189        debug_assert!(String::from_utf8(concat_bytes_slow(total_len, string.clone())).is_ok());
190
191        // We need to len-prefix to make this BSATN.
192        put_len(self.writer, total_len)?;
193
194        for chunk in string {
195            self.writer.put_slice(chunk);
196        }
197        Ok(())
198    }
199}
200
/// Gathers every slice yielded by `chunks` into one freshly allocated buffer.
///
/// `cap` is the expected combined length. It is used to size the allocation up front.
///
/// # Panics
///
/// Panics if the combined length of `chunks` is not exactly `cap`.
fn concat_bytes_slow<'a>(cap: usize, chunks: impl Iterator<Item = &'a [u8]>) -> Vec<u8> {
    let joined = chunks.fold(Vec::with_capacity(cap), |mut acc, piece| {
        acc.extend_from_slice(piece);
        acc
    });
    assert_eq!(joined.len(), cap);
    joined
}
210
211impl<W: BufWriter> SerializeArray for Serializer<'_, W> {
212    type Ok = ();
213    type Error = BsatnError;
214
215    fn serialize_element<T: super::Serialize + ?Sized>(&mut self, elem: &T) -> Result<(), Self::Error> {
216        elem.serialize(self.reborrow())
217    }
218
219    fn end(self) -> Result<Self::Ok, Self::Error> {
220        Ok(())
221    }
222}
223
224impl<W: BufWriter> SerializeSeqProduct for Serializer<'_, W> {
225    type Ok = ();
226    type Error = BsatnError;
227
228    fn serialize_element<T: super::Serialize + ?Sized>(&mut self, elem: &T) -> Result<(), Self::Error> {
229        elem.serialize(self.reborrow())
230    }
231    fn end(self) -> Result<Self::Ok, Self::Error> {
232        Ok(())
233    }
234}