serde_ndim/
de.rs

1use serde::de::{DeserializeSeed, Error, IgnoredAny, IntoDeserializer, SeqAccess, Visitor};
2use serde::{Deserialize, Deserializer};
3use std::borrow::BorrowMut;
4use std::boxed::Box;
5use std::vec;
6use std::vec::Vec;
7
/// Generates one visitor method per listed primitive type.
///
/// Every generated method has the shape
/// `fn name<E: Error>(self, value) -> Result<Self::Value, E>` and simply hands the
/// received value to `deserialize_num` on the receiver.
macro_rules! forward_visitors {
    ($(fn $name:ident ($ty:ty);)*) => {
        $(
            fn $name<E: Error>(self, value: $ty) -> Result<Self::Value, E> {
                self.deserialize_num(value)
            }
        )*
    };
}
15
/// Storage for the per-dimension lengths of a multi-dimensional array.
///
/// The storage is viewed as a `[usize]` slice with one entry per dimension.
pub trait Shape: BorrowMut<[usize]> {
    /// Smallest number of dimensions this storage accepts.
    const MIN_DIMS: usize;
    /// Largest number of dimensions this storage accepts.
    const MAX_DIMS: usize;

    /// Build a shape with `dims` dimensions, every length initialised to `0`.
    fn new_zeroed(dims: usize) -> Self;

    /// Length stored for dimension index `dims`, or `None` when the index is out of bounds.
    fn dim_len(&self, dims: usize) -> Option<usize> {
        let lens: &[usize] = self.borrow();
        lens.get(dims).copied()
    }

    /// Store `value` as the length of dimension index `dims`.
    ///
    /// # Panics
    ///
    /// Panics when `dims` is out of bounds. Debug builds additionally panic when the
    /// slot does not currently hold the `0` placeholder.
    fn set_dim_len(&mut self, dims: usize, value: usize) {
        // Only a `0` placeholder is ever supposed to be overwritten.
        debug_assert_eq!(self.dim_len(dims), Some(0));
        let lens: &mut [usize] = self.borrow_mut();
        lens[dims] = value;
    }
}
38
39impl Shape for Box<[usize]> {
40    const MIN_DIMS: usize = 0;
41    const MAX_DIMS: usize = usize::MAX;
42
43    fn new_zeroed(dims: usize) -> Self {
44        vec![0; dims].into_boxed_slice()
45    }
46}
47
48impl<const DIMS: usize> Shape for [usize; DIMS] {
49    const MIN_DIMS: usize = DIMS;
50    const MAX_DIMS: usize = DIMS;
51
52    fn new_zeroed(dims: usize) -> Self {
53        debug_assert_eq!(dims, DIMS);
54        [0; DIMS]
55    }
56}
57
/// Inline, capacity-bounded shape: anything from zero up to `MAX_DIMS` dimensions.
#[cfg(feature = "arrayvec")]
impl<const MAX_DIMS: usize> Shape for arrayvec::ArrayVec<usize, MAX_DIMS> {
    const MIN_DIMS: usize = 0;
    const MAX_DIMS: usize = MAX_DIMS;

    /// Returns a vector holding `dims` zeroes.
    fn new_zeroed(dims: usize) -> Self {
        debug_assert!(dims <= MAX_DIMS);
        let mut zeroed = Self::new();
        zeroed.extend((0..dims).map(|_| 0));
        zeroed
    }
}
70
/// Mutable state threaded through the recursive visit of the nested input.
#[derive(Debug)]
struct Context<T, S> {
    /// Flat list of leaf values, pushed in the order they are encountered.
    data: Vec<T>,
    /// `None` until the first leaf value is seen; afterwards holds the per-dimension
    /// lengths (initially `0` placeholders that are filled in as sequences finish).
    shape: Option<S>,
    /// Number of sequences currently being visited, i.e. the current nesting depth.
    current_dim: usize,
}
77
78impl<'de, T: Deserialize<'de>, S: Shape> Context<T, S> {
79    fn got_number<E: Error>(&mut self) -> Result<(), E> {
80        match &self.shape {
81            Some(shape) => {
82                if self.current_dim < shape.borrow().len() {
83                    // We've seen a sequence at this dims before, but got a number now.
84                    return Err(E::invalid_type(
85                        serde::de::Unexpected::Other("a single number"),
86                        &"a sequence",
87                    ));
88                }
89            }
90            None => {
91                // We've seen a sequence at this dims before, but got a number now.
92                // Once we've seen a numeric value for the first time, this means we reached the innermost dimension.
93                // From now on, start collecting shape info.
94                // To start, allocate the dimension lenghs with placeholders.
95                if self.current_dim < S::MIN_DIMS {
96                    return Err(Error::custom(format_args!(
97                        "didn't reach the expected minimum dims {}, got {}",
98                        S::MIN_DIMS,
99                        self.current_dim,
100                    )));
101                }
102                self.shape = Some(S::new_zeroed(self.current_dim));
103            }
104        }
105        Ok(())
106    }
107
108    fn deserialize_num_from<D: Deserializer<'de>>(
109        &mut self,
110        deserializer: D,
111    ) -> Result<(), D::Error> {
112        self.got_number()?;
113        let value = T::deserialize(deserializer)?;
114        self.data.push(value);
115        Ok(())
116    }
117
118    fn deserialize_num<E: Error>(&mut self, arg: impl IntoDeserializer<'de, E>) -> Result<(), E> {
119        self.deserialize_num_from(arg.into_deserializer())
120    }
121}
122
123impl<'de, T: Deserialize<'de>, S: Shape> Visitor<'de> for &mut Context<T, S> {
124    type Value = ();
125
126    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
127        formatter.write_str("a sequence or a single number")
128    }
129
130    fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
131        // The code paths for when we know the shape and when we're still doing the first
132        // descent are quite different, so we split them into two separate branches.
133        if let Some(shape) = &self.shape {
134            // This is not the first pass anymore, so we've seen all the dimensions.
135            // Check that the current dimension has seen a sequence before and return
136            // its expected length.
137            let expected_len = shape.dim_len(self.current_dim).ok_or_else(|| {
138                Error::invalid_type(serde::de::Unexpected::Seq, &"a single number")
139            })?;
140            self.current_dim += 1;
141            // Consume the expected number of elements.
142            for _ in 0..expected_len {
143                seq.next_element_seed(&mut *self)?
144                    .ok_or_else(|| Error::custom("expected more elements"))?;
145            }
146            // We've seen all the expected elements in this sequence.
147            // Ensure there are no more elements.
148            if seq.next_element::<IgnoredAny>()?.is_some() {
149                return Err(Error::custom("expected end of sequence"));
150            }
151            self.current_dim -= 1;
152        } else {
153            // We're still in the first pass, so we don't know the shape yet.
154            debug_assert!(self.shape.is_none());
155            self.current_dim += 1;
156            if self.current_dim > S::MAX_DIMS {
157                return Err(Error::custom(format_args!(
158                    "maximum dims of {} exceeded",
159                    S::MAX_DIMS
160                )));
161            }
162            // Consume & count all the elements.
163            let mut len = 0;
164            while seq.next_element_seed(&mut *self)?.is_some() {
165                len += 1;
166            }
167            self.current_dim -= 1;
168            // Replace the placeholder `0` with the actual length.
169            let shape = self
170                .shape
171                .as_mut()
172                .expect("internal error: shape should be allocated by now");
173            shape.set_dim_len(self.current_dim, len);
174        }
175        Ok(())
176    }
177
178    forward_visitors! {
179        fn visit_i8(i8);
180        fn visit_i16(i16);
181        fn visit_i32(i32);
182        fn visit_i64(i64);
183        fn visit_u8(u8);
184        fn visit_u16(u16);
185        fn visit_u32(u32);
186        fn visit_u64(u64);
187        fn visit_f32(f32);
188        fn visit_f64(f64);
189        fn visit_i128(i128);
190        fn visit_u128(u128);
191    }
192
193    fn visit_newtype_struct<D: Deserializer<'de>>(
194        self,
195        deserializer: D,
196    ) -> Result<Self::Value, D::Error> {
197        // TODO(?): some deserialize implementations don't treat newtypes as transparent.
198        // If someone complains, add logic for deserializing as actual wrapped newtype.
199        self.deserialize_num_from(deserializer)
200    }
201}
202
203impl<'de, T: Deserialize<'de>, S: Shape> DeserializeSeed<'de> for &mut Context<T, S> {
204    type Value = ();
205
206    fn deserialize<D>(self, deserializer: D) -> Result<(), D::Error>
207    where
208        D: Deserializer<'de>,
209    {
210        deserializer.deserialize_any(self)
211    }
212}
213
/// A trait for types that can be constructed from a shape and flat data.
pub trait MakeNDim {
    /// The shape of the multi-dimensional array.
    type Shape: Shape;
    /// Array element type.
    type Item;

    /// Construct a multi-dimensional array from a shape and flat data.
    ///
    /// `deserialize` in this module calls this once, after the whole input has been
    /// visited successfully, passing the collected dimension lengths as `shape` and
    /// the leaf values in the order they were encountered as `data`.
    fn from_shape_and_data(shape: Self::Shape, data: Vec<Self::Item>) -> Self;
}
224
225/// Deserialize a multi-dimensional column-major array from a recursively nested sequence of numbers.
226///
227/// See [crate-level documentation](../#Deserialization) for more details.
228pub fn deserialize<'de, A, D>(deserializer: D) -> Result<A, D::Error>
229where
230    A: MakeNDim,
231    A::Item: Deserialize<'de>,
232    D: Deserializer<'de>,
233{
234    let mut context = Context {
235        data: Vec::new(),
236        shape: None,
237        current_dim: 0,
238    };
239    deserializer.deserialize_any(&mut context)?;
240    Ok(A::from_shape_and_data(context.shape.unwrap(), context.data))
241}