spacetimedb_sats/algebraic_value/ser.rs

1use crate::bsatn::decode;
2use crate::de::DeserializeSeed;
3use crate::ser::{self, ForwardNamedToSeqProduct, Serialize};
4use crate::{i256, u256, WithTypespace};
5use crate::{AlgebraicValue, ArrayValue, F32, F64};
6use core::convert::Infallible;
7use core::mem::MaybeUninit;
8use core::ptr;
9use second_stack::uninit_slice;
10use std::alloc::{self, Layout};
11
12/// Serialize `x` as an [`AlgebraicValue`].
13pub fn value_serialize(x: &(impl Serialize + ?Sized)) -> AlgebraicValue {
14    x.serialize(ValueSerializer).unwrap_or_else(|e| match e {})
15}
16
/// A [`Serializer`](ser::Serializer) that builds an in-memory [`AlgebraicValue`]
/// as its output instead of writing to an external format.
///
/// It is a unit struct and carries no state.
pub struct ValueSerializer;
/// Generates a serializer method named `$method` that takes a primitive of type `$ty`
/// and converts it into `Self::Ok` through its `Into` implementation.
macro_rules! method {
    ($method:ident -> $ty:ty) => {
        fn $method(self, value: $ty) -> Result<Self::Ok, Self::Error> {
            Ok(Into::into(value))
        }
    };
}
28
29impl ser::Serializer for ValueSerializer {
30    type Ok = AlgebraicValue;
31    type Error = Infallible;
32
33    type SerializeArray = SerializeArrayValue;
34    type SerializeSeqProduct = SerializeProductValue;
35    type SerializeNamedProduct = ForwardNamedToSeqProduct<SerializeProductValue>;
36
37    method!(serialize_bool -> bool);
38    method!(serialize_u8 -> u8);
39    method!(serialize_u16 -> u16);
40    method!(serialize_u32 -> u32);
41    method!(serialize_u64 -> u64);
42    method!(serialize_u128 -> u128);
43    method!(serialize_u256 -> u256);
44    method!(serialize_i8 -> i8);
45    method!(serialize_i16 -> i16);
46    method!(serialize_i32 -> i32);
47    method!(serialize_i64 -> i64);
48    method!(serialize_i128 -> i128);
49    method!(serialize_i256 -> i256);
50    method!(serialize_f32 -> f32);
51    method!(serialize_f64 -> f64);
52
53    fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
54        Ok(AlgebraicValue::String(v.into()))
55    }
56    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
57        Ok(AlgebraicValue::Bytes(v.into()))
58    }
59
60    fn serialize_array(self, len: usize) -> Result<Self::SerializeArray, Self::Error> {
61        Ok(SerializeArrayValue {
62            len: Some(len),
63            array: Default::default(),
64        })
65    }
66
67    fn serialize_seq_product(self, len: usize) -> Result<Self::SerializeSeqProduct, Self::Error> {
68        Ok(SerializeProductValue {
69            elements: Vec::with_capacity(len),
70        })
71    }
72
73    fn serialize_named_product(self, len: usize) -> Result<Self::SerializeNamedProduct, Self::Error> {
74        ForwardNamedToSeqProduct::forward(self, len)
75    }
76
77    fn serialize_variant<T: ser::Serialize + ?Sized>(
78        self,
79        tag: u8,
80        _name: Option<&str>,
81        value: &T,
82    ) -> Result<Self::Ok, Self::Error> {
83        value.serialize(self).map(|v| AlgebraicValue::sum(tag, v))
84    }
85
86    unsafe fn serialize_bsatn<Ty>(self, ty: &Ty, mut bsatn: &[u8]) -> Result<Self::Ok, Self::Error>
87    where
88        for<'a, 'de> WithTypespace<'a, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
89    {
90        let res = decode(ty, &mut bsatn);
91        // SAFETY: Caller promised that `res.is_ok()`.
92        let val = unsafe { res.unwrap_unchecked() };
93        Ok(val.into())
94    }
95
96    unsafe fn serialize_bsatn_in_chunks<'a, Ty, I: Iterator<Item = &'a [u8]>>(
97        self,
98        ty: &Ty,
99        total_bsatn_len: usize,
100        chunks: I,
101    ) -> Result<Self::Ok, Self::Error>
102    where
103        for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into<AlgebraicValue>>,
104    {
105        // SAFETY: Caller promised `total_bsatn_len == chunks.map(|c| c.len()).sum() <= isize::MAX`.
106        unsafe {
107            concat_byte_chunks_buf(total_bsatn_len, chunks, |bsatn| {
108                // SAFETY: Caller promised `AlgebraicValue::decode(ty, &mut bytes).is_ok()`.
109                ValueSerializer.serialize_bsatn(ty, bsatn)
110            })
111        }
112    }
113
114    unsafe fn serialize_str_in_chunks<'a, I: Iterator<Item = &'a [u8]>>(
115        self,
116        total_len: usize,
117        string: I,
118    ) -> Result<Self::Ok, Self::Error> {
119        // SAFETY: Caller promised `total_len == string.map(|c| c.len()).sum() <= isize::MAx`.
120        let bytes = unsafe { concat_byte_chunks(total_len, string) };
121
122        // SAFETY: Caller promised `bytes` is UTF-8.
123        let string = unsafe { String::from_utf8_unchecked(bytes) };
124        Ok(string.into_boxed_str().into())
125    }
126}
127
128/// Returns the concatenation of `chunks` that must be of `total_len` as a `Vec<u8>`.
129///
130/// # Safety
131///
132/// - `total_len == chunks.map(|c| c.len()).sum() <= isize::MAX`
133unsafe fn concat_byte_chunks<'a>(total_len: usize, chunks: impl Iterator<Item = &'a [u8]>) -> Vec<u8> {
134    if total_len == 0 {
135        return Vec::new();
136    }
137
138    // Allocate space for `[u8; total_len]` on the heap.
139    let layout = Layout::array::<u8>(total_len);
140    // SAFETY: Caller promised that `total_len <= isize`.
141    let layout = unsafe { layout.unwrap_unchecked() };
142    // SAFETY: We checked above that `layout.size() != 0`.
143    let ptr = unsafe { alloc::alloc(layout) };
144    if ptr.is_null() {
145        alloc::handle_alloc_error(layout);
146    }
147
148    // Copy over each `chunk`.
149    // SAFETY:
150    // 1. `ptr` is valid for writes as we own it
151    //    caller promised that all `chunk`s will fit in `total_len`.
152    // 2. `ptr` points to a new allocation so it cannot overlap with any in `chunks`.
153    unsafe { write_byte_chunks(ptr, chunks) };
154
155    // Convert allocation to a `Vec<u8>`.
156    // SAFETY:
157    // - `ptr` was allocated using global allocator.
158    // - `u8` and `ptr`'s allocation both have alignment of 1.
159    // - `ptr`'s allocation is `total_len <= isize::MAX`.
160    // - `total_len <= total_len` holds.
161    // - `total_len` values were initialized at type `u8`
162    //    as we know `total_len == chunks.map(|c| c.len()).sum()`.
163    unsafe { Vec::from_raw_parts(ptr, total_len, total_len) }
164}
165
166/// Returns the concatenation of `chunks` that must be of `total_len` as a `Vec<u8>`.
167///
168/// # Safety
169///
170/// - `total_len == chunks.map(|c| c.len()).sum() <= isize::MAX`
171pub unsafe fn concat_byte_chunks_buf<'a, R>(
172    total_len: usize,
173    chunks: impl Iterator<Item = &'a [u8]>,
174    run: impl FnOnce(&[u8]) -> R,
175) -> R {
176    uninit_slice(total_len, |buf: &mut [MaybeUninit<u8>]| {
177        let dst = buf.as_mut_ptr().cast();
178        debug_assert_eq!(total_len, buf.len());
179        // SAFETY:
180        // 1. `buf.len() == total_len`
181        // 2. `buf` cannot overlap with anything yielded by `var_iter`.
182        unsafe { write_byte_chunks(dst, chunks) }
183        // SAFETY: Every byte of `buf` was initialized in the previous call
184        // as we know that `total_len == var_iter.map(|c| c.len()).sum()`.
185        let bytes = unsafe { slice_assume_init_ref(buf) };
186        run(bytes)
187    })
188}
189
/// Writes the bytes of every chunk in `chunks`, back to back, starting at `dst`.
///
/// # Safety
///
/// Let `total_len == chunks.map(|c| c.len()).sum()`.
/// 1. `dst` must be valid for writes for `total_len` bytes.
/// 2. `dst..(dst + total_len)` does not overlap with any slice yielded by `chunks`.
unsafe fn write_byte_chunks<'a>(dst: *mut u8, chunks: impl Iterator<Item = &'a [u8]>) {
    // Number of bytes written so far, which is also where the next chunk starts.
    let mut offset = 0;
    for chunk in chunks {
        // SAFETY:
        // - `offset` is the length of the chunks already written, so `offset <= total_len`
        //   and `dst.add(offset)` stays within the destination (or one past its end) by (1).
        // - `chunk` is a live slice, so it is valid for reads of `chunk.len()` bytes.
        // - `offset + chunk.len() <= total_len`, so the write stays in bounds by (1).
        // - `align_of::<u8>() == 1`, so both pointers are trivially aligned.
        // - By (2), the destination range cannot overlap `chunk`.
        unsafe {
            let target = dst.add(offset);
            ptr::copy_nonoverlapping(chunk.as_ptr(), target, chunk.len());
        }
        offset += chunk.len();
    }
}
215
/// Reinterprets a slice of `MaybeUninit<T>` as a slice of `T`,
/// asserting that every element is initialized.
///
/// Mirrors the standard library's `MaybeUninit::slice_assume_init_ref`,
/// which was not stable when this helper was written:
/// <https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#method.slice_assume_init_ref>
///
/// # Safety
///
/// All elements of `slice` must be initialized.
pub const unsafe fn slice_assume_init_ref<T>(slice: &[MaybeUninit<T>]) -> &[T] {
    // `MaybeUninit<T>` is guaranteed to have the same size, alignment, and ABI as `T`,
    // so `[MaybeUninit<T>]` and `[T]` of the same length share one layout.
    let uninit: *const [MaybeUninit<T>] = slice;
    let init = uninit as *const [T];
    // SAFETY: The caller guarantees every element is initialized, and `init` points to
    // memory borrowed through `slice`, so it is valid for reads for the returned lifetime.
    unsafe { &*init }
}
231
232/// Continuation for serializing an array.
233pub struct SerializeArrayValue {
234    /// For efficiency, the first time `serialize_element` is done,
235    /// this is used to allocate with capacity.
236    len: Option<usize>,
237    /// The array being built.
238    array: ArrayValueBuilder,
239}
240
241impl ser::SerializeArray for SerializeArrayValue {
242    type Ok = AlgebraicValue;
243    type Error = <ValueSerializer as ser::Serializer>::Error;
244
245    fn serialize_element<T: ser::Serialize + ?Sized>(&mut self, elem: &T) -> Result<(), Self::Error> {
246        self.array
247            .push(value_serialize(elem), self.len.take())
248            .expect("heterogeneous array");
249        Ok(())
250    }
251
252    fn end(self) -> Result<Self::Ok, Self::Error> {
253        Ok(ArrayValue::from(self.array).into())
254    }
255}
256
257/// A builder for [`ArrayValue`]s
258#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
259enum ArrayValueBuilder {
260    /// An array of [`SumValue`](crate::SumValue)s.
261    Sum(Vec<crate::SumValue>),
262    /// An array of [`ProductValue`](crate::ProductValue)s.
263    Product(Vec<crate::ProductValue>),
264    /// An array of [`bool`]s.
265    Bool(Vec<bool>),
266    /// An array of [`i8`]s.
267    I8(Vec<i8>),
268    /// An array of [`u8`]s.
269    U8(Vec<u8>),
270    /// An array of [`i16`]s.
271    I16(Vec<i16>),
272    /// An array of [`u16`]s.
273    U16(Vec<u16>),
274    /// An array of [`i32`]s.
275    I32(Vec<i32>),
276    /// An array of [`u32`]s.
277    U32(Vec<u32>),
278    /// An array of [`i64`]s.
279    I64(Vec<i64>),
280    /// An array of [`u64`]s.
281    U64(Vec<u64>),
282    /// An array of [`i128`]s.
283    I128(Vec<i128>),
284    /// An array of [`u128`]s.
285    U128(Vec<u128>),
286    /// An array of [`i256`]s.
287    I256(Vec<i256>),
288    /// An array of [`u256`]s.
289    U256(Vec<u256>),
290    /// An array of totally ordered [`F32`]s.
291    F32(Vec<F32>),
292    /// An array of totally ordered [`F64`]s.
293    F64(Vec<F64>),
294    /// An array of UTF-8 strings.
295    String(Vec<Box<str>>),
296    /// An array of arrays.
297    Array(Vec<ArrayValue>),
298}
299
300impl ArrayValueBuilder {
301    /// Returns the length of the array.
302    fn len(&self) -> usize {
303        match self {
304            Self::Sum(v) => v.len(),
305            Self::Product(v) => v.len(),
306            Self::Bool(v) => v.len(),
307            Self::I8(v) => v.len(),
308            Self::U8(v) => v.len(),
309            Self::I16(v) => v.len(),
310            Self::U16(v) => v.len(),
311            Self::I32(v) => v.len(),
312            Self::U32(v) => v.len(),
313            Self::I64(v) => v.len(),
314            Self::U64(v) => v.len(),
315            Self::I128(v) => v.len(),
316            Self::U128(v) => v.len(),
317            Self::I256(v) => v.len(),
318            Self::U256(v) => v.len(),
319            Self::F32(v) => v.len(),
320            Self::F64(v) => v.len(),
321            Self::String(v) => v.len(),
322            Self::Array(v) => v.len(),
323        }
324    }
325
326    /// Returns whether the array is empty.
327    #[must_use]
328    fn is_empty(&self) -> bool {
329        self.len() == 0
330    }
331
332    /// Returns a singleton array with `val` as its only element.
333    ///
334    /// Optionally allocates the backing `Vec<_>`s with `capacity`.
335    fn from_one_with_capacity(val: AlgebraicValue, capacity: Option<usize>) -> Self {
336        fn vec<T>(e: T, c: Option<usize>) -> Vec<T> {
337            let mut vec = c.map_or(Vec::new(), Vec::with_capacity);
338            vec.push(e);
339            vec
340        }
341
342        match val {
343            AlgebraicValue::Sum(x) => vec(x, capacity).into(),
344            AlgebraicValue::Product(x) => vec(x, capacity).into(),
345            AlgebraicValue::Bool(x) => vec(x, capacity).into(),
346            AlgebraicValue::I8(x) => vec(x, capacity).into(),
347            AlgebraicValue::U8(x) => vec(x, capacity).into(),
348            AlgebraicValue::I16(x) => vec(x, capacity).into(),
349            AlgebraicValue::U16(x) => vec(x, capacity).into(),
350            AlgebraicValue::I32(x) => vec(x, capacity).into(),
351            AlgebraicValue::U32(x) => vec(x, capacity).into(),
352            AlgebraicValue::I64(x) => vec(x, capacity).into(),
353            AlgebraicValue::U64(x) => vec(x, capacity).into(),
354            AlgebraicValue::I128(x) => vec(x.0, capacity).into(),
355            AlgebraicValue::U128(x) => vec(x.0, capacity).into(),
356            AlgebraicValue::I256(x) => vec(*x, capacity).into(),
357            AlgebraicValue::U256(x) => vec(*x, capacity).into(),
358            AlgebraicValue::F32(x) => vec(x, capacity).into(),
359            AlgebraicValue::F64(x) => vec(x, capacity).into(),
360            AlgebraicValue::String(x) => vec(x, capacity).into(),
361            AlgebraicValue::Array(x) => vec(x, capacity).into(),
362            AlgebraicValue::Min | AlgebraicValue::Max => panic!("not defined for Min/Max"),
363        }
364    }
365
366    /// Pushes the value `val` onto the array `self`
367    /// or returns back `Err(val)` if there was a type mismatch
368    /// between the base type of the array and `val`.
369    ///
370    /// Optionally allocates the backing `Vec<_>`s with `capacity`.
371    fn push(&mut self, val: AlgebraicValue, capacity: Option<usize>) -> Result<(), AlgebraicValue> {
372        match (self, val) {
373            (Self::Sum(v), AlgebraicValue::Sum(val)) => v.push(val),
374            (Self::Product(v), AlgebraicValue::Product(val)) => v.push(val),
375            (Self::Bool(v), AlgebraicValue::Bool(val)) => v.push(val),
376            (Self::I8(v), AlgebraicValue::I8(val)) => v.push(val),
377            (Self::U8(v), AlgebraicValue::U8(val)) => v.push(val),
378            (Self::I16(v), AlgebraicValue::I16(val)) => v.push(val),
379            (Self::U16(v), AlgebraicValue::U16(val)) => v.push(val),
380            (Self::I32(v), AlgebraicValue::I32(val)) => v.push(val),
381            (Self::U32(v), AlgebraicValue::U32(val)) => v.push(val),
382            (Self::I64(v), AlgebraicValue::I64(val)) => v.push(val),
383            (Self::U64(v), AlgebraicValue::U64(val)) => v.push(val),
384            (Self::I128(v), AlgebraicValue::I128(val)) => v.push(val.0),
385            (Self::U128(v), AlgebraicValue::U128(val)) => v.push(val.0),
386            (Self::I256(v), AlgebraicValue::I256(val)) => v.push(*val),
387            (Self::U256(v), AlgebraicValue::U256(val)) => v.push(*val),
388            (Self::F32(v), AlgebraicValue::F32(val)) => v.push(val),
389            (Self::F64(v), AlgebraicValue::F64(val)) => v.push(val),
390            (Self::String(v), AlgebraicValue::String(val)) => v.push(val),
391            (Self::Array(v), AlgebraicValue::Array(val)) => v.push(val),
392            (me, val) if me.is_empty() => *me = Self::from_one_with_capacity(val, capacity),
393            (_, val) => return Err(val),
394        }
395        Ok(())
396    }
397}
398
399impl From<ArrayValueBuilder> for ArrayValue {
400    fn from(value: ArrayValueBuilder) -> Self {
401        use ArrayValueBuilder::*;
402        match value {
403            Sum(v) => Self::Sum(v.into()),
404            Product(v) => Self::Product(v.into()),
405            Bool(v) => Self::Bool(v.into()),
406            I8(v) => Self::I8(v.into()),
407            U8(v) => Self::U8(v.into()),
408            I16(v) => Self::I16(v.into()),
409            U16(v) => Self::U16(v.into()),
410            I32(v) => Self::I32(v.into()),
411            U32(v) => Self::U32(v.into()),
412            I64(v) => Self::I64(v.into()),
413            U64(v) => Self::U64(v.into()),
414            I128(v) => Self::I128(v.into()),
415            U128(v) => Self::U128(v.into()),
416            I256(v) => Self::I256(v.into()),
417            U256(v) => Self::U256(v.into()),
418            F32(v) => Self::F32(v.into()),
419            F64(v) => Self::F64(v.into()),
420            String(v) => Self::String(v.into()),
421            Array(v) => Self::Array(v.into()),
422        }
423    }
424}
425
426impl Default for ArrayValueBuilder {
427    /// The default `ArrayValue` is an empty array of sum values.
428    fn default() -> Self {
429        Self::from(Vec::<crate::SumValue>::default())
430    }
431}
432
/// Implements `From<Vec<$elem>>` for `ArrayValueBuilder`,
/// wrapping the vector in the `ArrayValueBuilder::$variant` variant.
macro_rules! impl_from_array {
    ($elem:ty, $variant:ident) => {
        impl From<Vec<$elem>> for ArrayValueBuilder {
            fn from(elems: Vec<$elem>) -> Self {
                ArrayValueBuilder::$variant(elems)
            }
        }
    };
}
442
443impl_from_array!(crate::SumValue, Sum);
444impl_from_array!(crate::ProductValue, Product);
445impl_from_array!(bool, Bool);
446impl_from_array!(i8, I8);
447impl_from_array!(u8, U8);
448impl_from_array!(i16, I16);
449impl_from_array!(u16, U16);
450impl_from_array!(i32, I32);
451impl_from_array!(u32, U32);
452impl_from_array!(i64, I64);
453impl_from_array!(u64, U64);
454impl_from_array!(i128, I128);
455impl_from_array!(u128, U128);
456impl_from_array!(i256, I256);
457impl_from_array!(u256, U256);
458impl_from_array!(F32, F32);
459impl_from_array!(F64, F64);
460impl_from_array!(Box<str>, String);
461impl_from_array!(ArrayValue, Array);
462
463/// Continuation for serializing a map value.
464pub struct SerializeProductValue {
465    /// The elements serialized so far.
466    elements: Vec<AlgebraicValue>,
467}
468
469impl ser::SerializeSeqProduct for SerializeProductValue {
470    type Ok = AlgebraicValue;
471    type Error = <ValueSerializer as ser::Serializer>::Error;
472
473    fn serialize_element<T: ser::Serialize + ?Sized>(&mut self, elem: &T) -> Result<(), Self::Error> {
474        self.elements.push(value_serialize(elem));
475        Ok(())
476    }
477    fn end(self) -> Result<Self::Ok, Self::Error> {
478        Ok(AlgebraicValue::product(self.elements))
479    }
480}