1use serde::de::{DeserializeSeed, Error, IgnoredAny, IntoDeserializer, SeqAccess, Visitor};
2use serde::{Deserialize, Deserializer};
3use std::borrow::BorrowMut;
4use std::boxed::Box;
5use std::vec;
6use std::vec::Vec;
7
// Generates one `Visitor` method per listed primitive type. Every generated
// method hands its scalar argument to `Context::deserialize_num`, so all
// numeric primitives share a single code path.
macro_rules! forward_visitors {
    ($(fn $name:ident($ty:ty);)*) => {
        $(
            fn $name<E: Error>(self, value: $ty) -> Result<Self::Value, E> {
                self.deserialize_num(value)
            }
        )*
    };
}
15
/// Per-dimension lengths of an n-dimensional array, viewable as a `[usize]` slice.
///
/// Implementors declare how many dimensions they can hold through
/// [`Shape::MIN_DIMS`] and [`Shape::MAX_DIMS`]; the deserializer in this module
/// consults these bounds while discovering the nesting depth of the input.
pub trait Shape: BorrowMut<[usize]> {
    /// Fewest dimensions this shape type can represent.
    const MIN_DIMS: usize;

    /// Most dimensions this shape type can represent.
    const MAX_DIMS: usize;

    /// Builds a shape with `dims` dimensions whose lengths are all zero.
    fn new_zeroed(dims: usize) -> Self;

    /// Returns the length of dimension `dims`, or `None` when the shape has
    /// fewer than `dims + 1` dimensions.
    fn dim_len(&self, dims: usize) -> Option<usize> {
        let lens: &[usize] = self.borrow();
        lens.get(dims).cloned()
    }

    /// Records `value` as the length of dimension `dims`.
    ///
    /// # Panics
    ///
    /// Panics if `dims` is out of range. In debug builds it also panics when
    /// the dimension's length is not currently zero, i.e. was already recorded.
    fn set_dim_len(&mut self, dims: usize, value: usize) {
        debug_assert_eq!(self.dim_len(dims), Some(0));
        let lens: &mut [usize] = self.borrow_mut();
        lens[dims] = value;
    }
}
38
39impl Shape for Box<[usize]> {
40 const MIN_DIMS: usize = 0;
41 const MAX_DIMS: usize = usize::MAX;
42
43 fn new_zeroed(dims: usize) -> Self {
44 vec![0; dims].into_boxed_slice()
45 }
46}
47
48impl<const DIMS: usize> Shape for [usize; DIMS] {
49 const MIN_DIMS: usize = DIMS;
50 const MAX_DIMS: usize = DIMS;
51
52 fn new_zeroed(dims: usize) -> Self {
53 debug_assert_eq!(dims, DIMS);
54 [0; DIMS]
55 }
56}
57
/// Bounded-rank shape stored inline: between 0 and `MAX_DIMS` dimensions.
#[cfg(feature = "arrayvec")]
impl<const MAX_DIMS: usize> Shape for arrayvec::ArrayVec<usize, MAX_DIMS> {
    const MIN_DIMS: usize = 0;
    const MAX_DIMS: usize = MAX_DIMS;

    /// Fills a fresh `ArrayVec` with `dims` zero lengths. `dims` must not
    /// exceed the inline capacity (checked in debug builds only).
    fn new_zeroed(dims: usize) -> Self {
        debug_assert!(dims <= MAX_DIMS);
        let mut lens = Self::new();
        lens.extend((0..dims).map(|_| 0));
        lens
    }
}
70
/// Mutable state threaded through the visitor while it walks the nested input.
#[derive(Debug)]
struct Context<T, S> {
    /// Scalars collected so far, in the order they were visited.
    data: Vec<T>,
    /// Per-dimension lengths; `None` until the first descent into the input
    /// has determined how many dimensions there are.
    shape: Option<S>,
    /// Nesting depth of the value currently being visited (0 at the top level).
    current_dim: usize,
}
77
78impl<'de, T: Deserialize<'de>, S: Shape> Context<T, S> {
79 fn got_number<E: Error>(&mut self) -> Result<(), E> {
80 match &self.shape {
81 Some(shape) => {
82 if self.current_dim < shape.borrow().len() {
83 return Err(E::invalid_type(
85 serde::de::Unexpected::Other("a single number"),
86 &"a sequence",
87 ));
88 }
89 }
90 None => {
91 if self.current_dim < S::MIN_DIMS {
96 return Err(Error::custom(format_args!(
97 "didn't reach the expected minimum dims {}, got {}",
98 S::MIN_DIMS,
99 self.current_dim,
100 )));
101 }
102 self.shape = Some(S::new_zeroed(self.current_dim));
103 }
104 }
105 Ok(())
106 }
107
108 fn deserialize_num_from<D: Deserializer<'de>>(
109 &mut self,
110 deserializer: D,
111 ) -> Result<(), D::Error> {
112 self.got_number()?;
113 let value = T::deserialize(deserializer)?;
114 self.data.push(value);
115 Ok(())
116 }
117
118 fn deserialize_num<E: Error>(&mut self, arg: impl IntoDeserializer<'de, E>) -> Result<(), E> {
119 self.deserialize_num_from(arg.into_deserializer())
120 }
121}
122
123impl<'de, T: Deserialize<'de>, S: Shape> Visitor<'de> for &mut Context<T, S> {
124 type Value = ();
125
126 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
127 formatter.write_str("a sequence or a single number")
128 }
129
130 fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
131 if let Some(shape) = &self.shape {
134 let expected_len = shape.dim_len(self.current_dim).ok_or_else(|| {
138 Error::invalid_type(serde::de::Unexpected::Seq, &"a single number")
139 })?;
140 self.current_dim += 1;
141 for _ in 0..expected_len {
143 seq.next_element_seed(&mut *self)?
144 .ok_or_else(|| Error::custom("expected more elements"))?;
145 }
146 if seq.next_element::<IgnoredAny>()?.is_some() {
149 return Err(Error::custom("expected end of sequence"));
150 }
151 self.current_dim -= 1;
152 } else {
153 debug_assert!(self.shape.is_none());
155 self.current_dim += 1;
156 if self.current_dim > S::MAX_DIMS {
157 return Err(Error::custom(format_args!(
158 "maximum dims of {} exceeded",
159 S::MAX_DIMS
160 )));
161 }
162 let mut len = 0;
164 while seq.next_element_seed(&mut *self)?.is_some() {
165 len += 1;
166 }
167 self.current_dim -= 1;
168 let shape = self
170 .shape
171 .as_mut()
172 .expect("internal error: shape should be allocated by now");
173 shape.set_dim_len(self.current_dim, len);
174 }
175 Ok(())
176 }
177
178 forward_visitors! {
179 fn visit_i8(i8);
180 fn visit_i16(i16);
181 fn visit_i32(i32);
182 fn visit_i64(i64);
183 fn visit_u8(u8);
184 fn visit_u16(u16);
185 fn visit_u32(u32);
186 fn visit_u64(u64);
187 fn visit_f32(f32);
188 fn visit_f64(f64);
189 fn visit_i128(i128);
190 fn visit_u128(u128);
191 }
192
193 fn visit_newtype_struct<D: Deserializer<'de>>(
194 self,
195 deserializer: D,
196 ) -> Result<Self::Value, D::Error> {
197 self.deserialize_num_from(deserializer)
200 }
201}
202
203impl<'de, T: Deserialize<'de>, S: Shape> DeserializeSeed<'de> for &mut Context<T, S> {
204 type Value = ();
205
206 fn deserialize<D>(self, deserializer: D) -> Result<(), D::Error>
207 where
208 D: Deserializer<'de>,
209 {
210 deserializer.deserialize_any(self)
211 }
212}
213
214pub trait MakeNDim {
216 type Shape: Shape;
218 type Item;
220
221 fn from_shape_and_data(shape: Self::Shape, data: Vec<Self::Item>) -> Self;
223}
224
225pub fn deserialize<'de, A, D>(deserializer: D) -> Result<A, D::Error>
229where
230 A: MakeNDim,
231 A::Item: Deserialize<'de>,
232 D: Deserializer<'de>,
233{
234 let mut context = Context {
235 data: Vec::new(),
236 shape: None,
237 current_dim: 0,
238 };
239 deserializer.deserialize_any(&mut context)?;
240 Ok(A::from_shape_and_data(context.shape.unwrap(), context.data))
241}