spacetimedb_sats/algebraic_value/
ser.rs

1use crate::ser::{self, ForwardNamedToSeqProduct, Serialize};
2use crate::{i256, u256};
3use crate::{AlgebraicType, AlgebraicValue, ArrayValue, F32, F64};
4use core::convert::Infallible;
5use core::mem::MaybeUninit;
6use core::ptr;
7use second_stack::uninit_slice;
8use std::alloc::{self, Layout};
9
10/// Serialize `x` as an [`AlgebraicValue`].
11pub fn value_serialize(x: &(impl Serialize + ?Sized)) -> AlgebraicValue {
12    x.serialize(ValueSerializer).unwrap_or_else(|e| match e {})
13}
14
/// An implementation of [`Serializer`](ser::Serializer)
/// where the output of serialization is an `AlgebraicValue`.
///
/// This is a stateless unit struct, so it is cheap to copy and construct.
#[derive(Clone, Copy, Debug, Default)]
pub struct ValueSerializer;
18
/// Expands to a `Serializer` method named `$method` that accepts a primitive
/// of type `$prim` and converts it into the serializer's `Ok` type via `Into`.
macro_rules! method {
    ($method:ident -> $prim:ty) => {
        fn $method(self, value: $prim) -> Result<Self::Ok, Self::Error> {
            let converted: Self::Ok = value.into();
            Ok(converted)
        }
    };
}
26
27impl ser::Serializer for ValueSerializer {
28    type Ok = AlgebraicValue;
29    type Error = Infallible;
30
31    type SerializeArray = SerializeArrayValue;
32    type SerializeSeqProduct = SerializeProductValue;
33    type SerializeNamedProduct = ForwardNamedToSeqProduct<SerializeProductValue>;
34
35    method!(serialize_bool -> bool);
36    method!(serialize_u8 -> u8);
37    method!(serialize_u16 -> u16);
38    method!(serialize_u32 -> u32);
39    method!(serialize_u64 -> u64);
40    method!(serialize_u128 -> u128);
41    method!(serialize_u256 -> u256);
42    method!(serialize_i8 -> i8);
43    method!(serialize_i16 -> i16);
44    method!(serialize_i32 -> i32);
45    method!(serialize_i64 -> i64);
46    method!(serialize_i128 -> i128);
47    method!(serialize_i256 -> i256);
48    method!(serialize_f32 -> f32);
49    method!(serialize_f64 -> f64);
50
51    fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
52        Ok(AlgebraicValue::String(v.into()))
53    }
54    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
55        Ok(AlgebraicValue::Bytes(v.into()))
56    }
57
58    fn serialize_array(self, len: usize) -> Result<Self::SerializeArray, Self::Error> {
59        Ok(SerializeArrayValue {
60            len: Some(len),
61            array: Default::default(),
62        })
63    }
64
65    fn serialize_seq_product(self, len: usize) -> Result<Self::SerializeSeqProduct, Self::Error> {
66        Ok(SerializeProductValue {
67            elements: Vec::with_capacity(len),
68        })
69    }
70
71    fn serialize_named_product(self, len: usize) -> Result<Self::SerializeNamedProduct, Self::Error> {
72        ForwardNamedToSeqProduct::forward(self, len)
73    }
74
75    fn serialize_variant<T: ser::Serialize + ?Sized>(
76        self,
77        tag: u8,
78        _name: Option<&str>,
79        value: &T,
80    ) -> Result<Self::Ok, Self::Error> {
81        value.serialize(self).map(|v| AlgebraicValue::sum(tag, v))
82    }
83
84    unsafe fn serialize_bsatn(self, ty: &AlgebraicType, mut bsatn: &[u8]) -> Result<Self::Ok, Self::Error> {
85        let res = AlgebraicValue::decode(ty, &mut bsatn);
86        // SAFETY: Caller promised that `res.is_ok()`.
87        Ok(unsafe { res.unwrap_unchecked() })
88    }
89
90    unsafe fn serialize_bsatn_in_chunks<'a, I: Iterator<Item = &'a [u8]>>(
91        self,
92        ty: &crate::AlgebraicType,
93        total_bsatn_len: usize,
94        chunks: I,
95    ) -> Result<Self::Ok, Self::Error> {
96        // SAFETY: Caller promised `total_bsatn_len == chunks.map(|c| c.len()).sum() <= isize::MAX`.
97        unsafe {
98            concat_byte_chunks_buf(total_bsatn_len, chunks, |bsatn| {
99                // SAFETY: Caller promised `AlgebraicValue::decode(ty, &mut bytes).is_ok()`.
100                ValueSerializer.serialize_bsatn(ty, bsatn)
101            })
102        }
103    }
104
105    unsafe fn serialize_str_in_chunks<'a, I: Iterator<Item = &'a [u8]>>(
106        self,
107        total_len: usize,
108        string: I,
109    ) -> Result<Self::Ok, Self::Error> {
110        // SAFETY: Caller promised `total_len == string.map(|c| c.len()).sum() <= isize::MAx`.
111        let bytes = unsafe { concat_byte_chunks(total_len, string) };
112
113        // SAFETY: Caller promised `bytes` is UTF-8.
114        let string = unsafe { String::from_utf8_unchecked(bytes) };
115        Ok(string.into_boxed_str().into())
116    }
117}
118
119/// Returns the concatenation of `chunks` that must be of `total_len` as a `Vec<u8>`.
120///
121/// # Safety
122///
123/// - `total_len == chunks.map(|c| c.len()).sum() <= isize::MAX`
124unsafe fn concat_byte_chunks<'a>(total_len: usize, chunks: impl Iterator<Item = &'a [u8]>) -> Vec<u8> {
125    if total_len == 0 {
126        return Vec::new();
127    }
128
129    // Allocate space for `[u8; total_len]` on the heap.
130    let layout = Layout::array::<u8>(total_len);
131    // SAFETY: Caller promised that `total_len <= isize`.
132    let layout = unsafe { layout.unwrap_unchecked() };
133    // SAFETY: We checked above that `layout.size() != 0`.
134    let ptr = unsafe { alloc::alloc(layout) };
135    if ptr.is_null() {
136        alloc::handle_alloc_error(layout);
137    }
138
139    // Copy over each `chunk`.
140    // SAFETY:
141    // 1. `ptr` is valid for writes as we own it
142    //    caller promised that all `chunk`s will fit in `total_len`.
143    // 2. `ptr` points to a new allocation so it cannot overlap with any in `chunks`.
144    unsafe { write_byte_chunks(ptr, chunks) };
145
146    // Convert allocation to a `Vec<u8>`.
147    // SAFETY:
148    // - `ptr` was allocated using global allocator.
149    // - `u8` and `ptr`'s allocation both have alignment of 1.
150    // - `ptr`'s allocation is `total_len <= isize::MAX`.
151    // - `total_len <= total_len` holds.
152    // - `total_len` values were initialized at type `u8`
153    //    as we know `total_len == chunks.map(|c| c.len()).sum()`.
154    unsafe { Vec::from_raw_parts(ptr, total_len, total_len) }
155}
156
157/// Returns the concatenation of `chunks` that must be of `total_len` as a `Vec<u8>`.
158///
159/// # Safety
160///
161/// - `total_len == chunks.map(|c| c.len()).sum() <= isize::MAX`
162pub unsafe fn concat_byte_chunks_buf<'a, R>(
163    total_len: usize,
164    chunks: impl Iterator<Item = &'a [u8]>,
165    run: impl FnOnce(&[u8]) -> R,
166) -> R {
167    uninit_slice(total_len, |buf: &mut [MaybeUninit<u8>]| {
168        let dst = buf.as_mut_ptr().cast();
169        debug_assert_eq!(total_len, buf.len());
170        // SAFETY:
171        // 1. `buf.len() == total_len`
172        // 2. `buf` cannot overlap with anything yielded by `var_iter`.
173        unsafe { write_byte_chunks(dst, chunks) }
174        // SAFETY: Every byte of `buf` was initialized in the previous call
175        // as we know that `total_len == var_iter.map(|c| c.len()).sum()`.
176        let bytes = unsafe { slice_assume_init_ref(buf) };
177        run(bytes)
178    })
179}
180
/// Writes every slice yielded by `chunks` back-to-back into memory starting at `dst`.
///
/// # Safety
///
/// Let `total_len == chunks.map(|c| c.len()).sum()`.
/// 1. `dst` must be valid for writes for `total_len` bytes.
/// 2. `dst..(dst + total_len)` does not overlap with any slice yielded by `chunks`.
unsafe fn write_byte_chunks<'a>(dst: *mut u8, chunks: impl Iterator<Item = &'a [u8]>) {
    // Number of bytes already written; the next chunk lands at `dst + offset`.
    let mut offset = 0usize;
    for chunk in chunks {
        // SAFETY:
        // - `offset` equals the bytes written so far, so by (1) `dst.add(offset)`
        //   stays within (or exactly one past the end of) the writable region,
        //   and `cursor..cursor + chunk.len()` is in bounds.
        // - `chunk` is a live slice, hence valid for reads of `chunk.len()` bytes.
        // - `u8` has alignment 1, so both pointers are trivially aligned.
        // - By (2) the destination range cannot overlap with `chunk`.
        unsafe {
            let cursor = dst.add(offset);
            ptr::copy_nonoverlapping(chunk.as_ptr(), cursor, chunk.len());
        }
        offset += chunk.len();
    }
}
206
/// Reinterprets a `[MaybeUninit<T>]` as a `[T]`, asserting every element is initialized.
///
/// Mirrors the standard library's unstable `MaybeUninit::slice_assume_init_ref`.
/// <https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#method.slice_assume_init_ref>
///
/// # Safety
///
/// All elements of `slice` must be initialized.
pub const unsafe fn slice_assume_init_ref<T>(slice: &[MaybeUninit<T>]) -> &[T] {
    // `MaybeUninit<T>` is guaranteed to have the same layout as `T`,
    // so this pointer cast preserves both the address and the element count.
    let as_init = slice as *const [MaybeUninit<T>] as *const [T];
    // SAFETY: The caller guarantees every element is initialized. The pointer comes from
    // a live shared reference, so it is valid for reads for the lifetime of `slice`.
    unsafe { &*as_init }
}
222
223/// Continuation for serializing an array.
224pub struct SerializeArrayValue {
225    /// For efficiency, the first time `serialize_element` is done,
226    /// this is used to allocate with capacity.
227    len: Option<usize>,
228    /// The array being built.
229    array: ArrayValueBuilder,
230}
231
232impl ser::SerializeArray for SerializeArrayValue {
233    type Ok = AlgebraicValue;
234    type Error = <ValueSerializer as ser::Serializer>::Error;
235
236    fn serialize_element<T: ser::Serialize + ?Sized>(&mut self, elem: &T) -> Result<(), Self::Error> {
237        self.array
238            .push(value_serialize(elem), self.len.take())
239            .expect("heterogeneous array");
240        Ok(())
241    }
242
243    fn end(self) -> Result<Self::Ok, Self::Error> {
244        Ok(ArrayValue::from(self.array).into())
245    }
246}
247
/// A builder for [`ArrayValue`]s
///
/// Each variant stores its elements in a growable `Vec` of the element's base type.
/// Converting into an `ArrayValue` (see the `From` impl) turns each `Vec` into the
/// corresponding `ArrayValue` variant.
///
/// NOTE: the derived `Ord`/`PartialOrd` compare by variant declaration order first,
/// so reordering the variants below changes ordering semantics.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
enum ArrayValueBuilder {
    /// An array of [`SumValue`](crate::SumValue)s.
    Sum(Vec<crate::SumValue>),
    /// An array of [`ProductValue`](crate::ProductValue)s.
    Product(Vec<crate::ProductValue>),
    /// An array of [`bool`]s.
    Bool(Vec<bool>),
    /// An array of [`i8`]s.
    I8(Vec<i8>),
    /// An array of [`u8`]s.
    U8(Vec<u8>),
    /// An array of [`i16`]s.
    I16(Vec<i16>),
    /// An array of [`u16`]s.
    U16(Vec<u16>),
    /// An array of [`i32`]s.
    I32(Vec<i32>),
    /// An array of [`u32`]s.
    U32(Vec<u32>),
    /// An array of [`i64`]s.
    I64(Vec<i64>),
    /// An array of [`u64`]s.
    U64(Vec<u64>),
    /// An array of [`i128`]s.
    I128(Vec<i128>),
    /// An array of [`u128`]s.
    U128(Vec<u128>),
    /// An array of [`i256`]s.
    I256(Vec<i256>),
    /// An array of [`u256`]s.
    U256(Vec<u256>),
    /// An array of totally ordered [`F32`]s.
    F32(Vec<F32>),
    /// An array of totally ordered [`F64`]s.
    F64(Vec<F64>),
    /// An array of UTF-8 strings.
    String(Vec<Box<str>>),
    /// An array of arrays.
    Array(Vec<ArrayValue>),
}
290
291impl ArrayValueBuilder {
292    /// Returns the length of the array.
293    fn len(&self) -> usize {
294        match self {
295            Self::Sum(v) => v.len(),
296            Self::Product(v) => v.len(),
297            Self::Bool(v) => v.len(),
298            Self::I8(v) => v.len(),
299            Self::U8(v) => v.len(),
300            Self::I16(v) => v.len(),
301            Self::U16(v) => v.len(),
302            Self::I32(v) => v.len(),
303            Self::U32(v) => v.len(),
304            Self::I64(v) => v.len(),
305            Self::U64(v) => v.len(),
306            Self::I128(v) => v.len(),
307            Self::U128(v) => v.len(),
308            Self::I256(v) => v.len(),
309            Self::U256(v) => v.len(),
310            Self::F32(v) => v.len(),
311            Self::F64(v) => v.len(),
312            Self::String(v) => v.len(),
313            Self::Array(v) => v.len(),
314        }
315    }
316
317    /// Returns whether the array is empty.
318    #[must_use]
319    fn is_empty(&self) -> bool {
320        self.len() == 0
321    }
322
323    /// Returns a singleton array with `val` as its only element.
324    ///
325    /// Optionally allocates the backing `Vec<_>`s with `capacity`.
326    fn from_one_with_capacity(val: AlgebraicValue, capacity: Option<usize>) -> Self {
327        fn vec<T>(e: T, c: Option<usize>) -> Vec<T> {
328            let mut vec = c.map_or(Vec::new(), Vec::with_capacity);
329            vec.push(e);
330            vec
331        }
332
333        match val {
334            AlgebraicValue::Sum(x) => vec(x, capacity).into(),
335            AlgebraicValue::Product(x) => vec(x, capacity).into(),
336            AlgebraicValue::Bool(x) => vec(x, capacity).into(),
337            AlgebraicValue::I8(x) => vec(x, capacity).into(),
338            AlgebraicValue::U8(x) => vec(x, capacity).into(),
339            AlgebraicValue::I16(x) => vec(x, capacity).into(),
340            AlgebraicValue::U16(x) => vec(x, capacity).into(),
341            AlgebraicValue::I32(x) => vec(x, capacity).into(),
342            AlgebraicValue::U32(x) => vec(x, capacity).into(),
343            AlgebraicValue::I64(x) => vec(x, capacity).into(),
344            AlgebraicValue::U64(x) => vec(x, capacity).into(),
345            AlgebraicValue::I128(x) => vec(x.0, capacity).into(),
346            AlgebraicValue::U128(x) => vec(x.0, capacity).into(),
347            AlgebraicValue::I256(x) => vec(*x, capacity).into(),
348            AlgebraicValue::U256(x) => vec(*x, capacity).into(),
349            AlgebraicValue::F32(x) => vec(x, capacity).into(),
350            AlgebraicValue::F64(x) => vec(x, capacity).into(),
351            AlgebraicValue::String(x) => vec(x, capacity).into(),
352            AlgebraicValue::Array(x) => vec(x, capacity).into(),
353            AlgebraicValue::Min | AlgebraicValue::Max => panic!("not defined for Min/Max"),
354        }
355    }
356
357    /// Pushes the value `val` onto the array `self`
358    /// or returns back `Err(val)` if there was a type mismatch
359    /// between the base type of the array and `val`.
360    ///
361    /// Optionally allocates the backing `Vec<_>`s with `capacity`.
362    fn push(&mut self, val: AlgebraicValue, capacity: Option<usize>) -> Result<(), AlgebraicValue> {
363        match (self, val) {
364            (Self::Sum(v), AlgebraicValue::Sum(val)) => v.push(val),
365            (Self::Product(v), AlgebraicValue::Product(val)) => v.push(val),
366            (Self::Bool(v), AlgebraicValue::Bool(val)) => v.push(val),
367            (Self::I8(v), AlgebraicValue::I8(val)) => v.push(val),
368            (Self::U8(v), AlgebraicValue::U8(val)) => v.push(val),
369            (Self::I16(v), AlgebraicValue::I16(val)) => v.push(val),
370            (Self::U16(v), AlgebraicValue::U16(val)) => v.push(val),
371            (Self::I32(v), AlgebraicValue::I32(val)) => v.push(val),
372            (Self::U32(v), AlgebraicValue::U32(val)) => v.push(val),
373            (Self::I64(v), AlgebraicValue::I64(val)) => v.push(val),
374            (Self::U64(v), AlgebraicValue::U64(val)) => v.push(val),
375            (Self::I128(v), AlgebraicValue::I128(val)) => v.push(val.0),
376            (Self::U128(v), AlgebraicValue::U128(val)) => v.push(val.0),
377            (Self::I256(v), AlgebraicValue::I256(val)) => v.push(*val),
378            (Self::U256(v), AlgebraicValue::U256(val)) => v.push(*val),
379            (Self::F32(v), AlgebraicValue::F32(val)) => v.push(val),
380            (Self::F64(v), AlgebraicValue::F64(val)) => v.push(val),
381            (Self::String(v), AlgebraicValue::String(val)) => v.push(val),
382            (Self::Array(v), AlgebraicValue::Array(val)) => v.push(val),
383            (me, val) if me.is_empty() => *me = Self::from_one_with_capacity(val, capacity),
384            (_, val) => return Err(val),
385        }
386        Ok(())
387    }
388}
389
390impl From<ArrayValueBuilder> for ArrayValue {
391    fn from(value: ArrayValueBuilder) -> Self {
392        use ArrayValueBuilder::*;
393        match value {
394            Sum(v) => Self::Sum(v.into()),
395            Product(v) => Self::Product(v.into()),
396            Bool(v) => Self::Bool(v.into()),
397            I8(v) => Self::I8(v.into()),
398            U8(v) => Self::U8(v.into()),
399            I16(v) => Self::I16(v.into()),
400            U16(v) => Self::U16(v.into()),
401            I32(v) => Self::I32(v.into()),
402            U32(v) => Self::U32(v.into()),
403            I64(v) => Self::I64(v.into()),
404            U64(v) => Self::U64(v.into()),
405            I128(v) => Self::I128(v.into()),
406            U128(v) => Self::U128(v.into()),
407            I256(v) => Self::I256(v.into()),
408            U256(v) => Self::U256(v.into()),
409            F32(v) => Self::F32(v.into()),
410            F64(v) => Self::F64(v.into()),
411            String(v) => Self::String(v.into()),
412            Array(v) => Self::Array(v.into()),
413        }
414    }
415}
416
417impl Default for ArrayValueBuilder {
418    /// The default `ArrayValue` is an empty array of sum values.
419    fn default() -> Self {
420        Self::from(Vec::<crate::SumValue>::default())
421    }
422}
423
/// Implements `From<Vec<$elem>>` for `ArrayValueBuilder`, wrapping the vector
/// in the builder variant named `$variant`.
macro_rules! impl_from_array {
    ($elem:ty, $variant:ident) => {
        impl From<Vec<$elem>> for ArrayValueBuilder {
            fn from(elems: Vec<$elem>) -> Self {
                ArrayValueBuilder::$variant(elems)
            }
        }
    };
}
433
434impl_from_array!(crate::SumValue, Sum);
435impl_from_array!(crate::ProductValue, Product);
436impl_from_array!(bool, Bool);
437impl_from_array!(i8, I8);
438impl_from_array!(u8, U8);
439impl_from_array!(i16, I16);
440impl_from_array!(u16, U16);
441impl_from_array!(i32, I32);
442impl_from_array!(u32, U32);
443impl_from_array!(i64, I64);
444impl_from_array!(u64, U64);
445impl_from_array!(i128, I128);
446impl_from_array!(u128, U128);
447impl_from_array!(i256, I256);
448impl_from_array!(u256, U256);
449impl_from_array!(F32, F32);
450impl_from_array!(F64, F64);
451impl_from_array!(Box<str>, String);
452impl_from_array!(ArrayValue, Array);
453
454/// Continuation for serializing a map value.
455pub struct SerializeProductValue {
456    /// The elements serialized so far.
457    elements: Vec<AlgebraicValue>,
458}
459
460impl ser::SerializeSeqProduct for SerializeProductValue {
461    type Ok = AlgebraicValue;
462    type Error = <ValueSerializer as ser::Serializer>::Error;
463
464    fn serialize_element<T: ser::Serialize + ?Sized>(&mut self, elem: &T) -> Result<(), Self::Error> {
465        self.elements.push(value_serialize(elem));
466        Ok(())
467    }
468    fn end(self) -> Result<Self::Ok, Self::Error> {
469        Ok(AlgebraicValue::product(self.elements))
470    }
471}