spectrusty_core/memory/
arrays.rs1use std::{marker::PhantomData};
10use std::mem::{MaybeUninit};
11use std::ptr;
12
13use serde::{
14 de::{SeqAccess, Visitor},
15 ser::SerializeTuple,
16 Deserialize, Deserializer, Serialize, Serializer,
17};
18
19pub fn serialize<S: Serializer, T: Serialize, const N: usize>(
20 data: &[T; N],
21 ser: S,
22) -> Result<S::Ok, S::Error> {
23 let mut tuple = ser.serialize_tuple(N)?;
24 for item in data {
25 tuple.serialize_element(item)?;
26 }
27 tuple.end()
28}
29
30struct ArrayVisitor<T, const N: usize>(PhantomData<T>);
31
32impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
33 where T: Deserialize<'de>,
34{
35 type Value = [T; N];
36
37 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
38 formatter.write_str(&format!("an array of length {}", N))
39 }
40
41 #[inline]
42 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
43 where A: SeqAccess<'de>,
44 {
45 struct ArrayUninit<T, const N: usize> {
46 data: [MaybeUninit<T>; N],
47 init_len: usize
48 }
49
50 impl<T, const N: usize> Drop for ArrayUninit<T, N> {
51 fn drop(&mut self) {
52 for elem in &mut self.data[0..self.init_len] {
53 unsafe { ptr::drop_in_place(elem.as_mut_ptr()); }
54 }
55 }
56 }
57
58 let mut ary: ArrayUninit<T, N> = ArrayUninit {
59 data: unsafe { MaybeUninit::uninit().assume_init() },
60 init_len: 0
61 };
62 let mut iter = ary.data.iter_mut();
63 while let Some(val) = seq.next_element()? {
64 if let Some(elem) = iter.next() {
65 elem.write(val);
66 ary.init_len += 1;
67 }
68 else {
69 return Err(serde::de::Error::invalid_length(N, &self));
70 }
71 }
72 if ary.init_len != N {
73 return Err(serde::de::Error::invalid_length(N, &self));
74 }
75 ary.init_len = 0;
77 Ok(unsafe { (&ary.data as *const _ as *const [T; N]).read() })
78 }
79}
80
81pub fn deserialize<'de, D, T, const N: usize>(deserializer: D) -> Result<[T; N], D::Error>
82where
83 D: Deserializer<'de>,
84 T: Deserialize<'de>,
85{
86 deserializer.deserialize_tuple(N, ArrayVisitor::<T, N>(PhantomData))
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// Transparent newtype that routes (de)serialization of the inner array
    /// through this module's `serialize` and `deserialize` functions.
    #[derive(Clone, Copy, PartialEq, Debug)]
    #[derive(Serialize, Deserialize)]
    #[serde(transparent)]
    struct ArrayWrap<T: Serialize, const N: usize>(
        #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] [T;N])
        where for <'a> T: Deserialize<'a>;

    #[test]
    fn arrays_serde_works() {
        // JSON: array of integers, including length-mismatch rejections.
        let numbers = ArrayWrap::<u8, 5>([1, 2, 3, 4, 5]);
        let json = serde_json::to_string(&numbers).unwrap();
        assert_eq!(json.as_str(), "[1,2,3,4,5]");
        assert!(serde_json::from_str::<ArrayWrap<u8, 4>>(&json).is_err());
        assert!(serde_json::from_str::<ArrayWrap<u8, 6>>(&json).is_err());
        let from_json = serde_json::from_str::<ArrayWrap<u8, 5>>(&json).unwrap();
        assert_eq!(numbers, from_json);

        // bincode: the same array through a binary format.
        let binary: Vec<u8> = bincode::serialize(&numbers).unwrap();
        assert!(bincode::deserialize::<ArrayWrap<u8, 6>>(&binary).is_err());
        let from_binary = bincode::deserialize::<ArrayWrap<u8, 5>>(&binary).unwrap();
        assert_eq!(numbers, from_binary);

        // JSON: array of a non-`Copy` element type.
        let words = ArrayWrap::<String, 3>([
            "foo".to_string(),
            "bar".to_string(),
            "baz".to_string(),
        ]);
        let json = serde_json::to_string(&words).unwrap();
        assert_eq!(json.as_str(), r#"["foo","bar","baz"]"#);
        assert!(serde_json::from_str::<ArrayWrap<String, 2>>(&json).is_err());
        assert!(serde_json::from_str::<ArrayWrap<String, 4>>(&json).is_err());
        let from_json = serde_json::from_str::<ArrayWrap<String, 3>>(&json).unwrap();
        assert_eq!(words, from_json);
    }
}