tf_idf_vectorizer/utils/datastruct/vector/
serde.rs1use num_traits::Num;
2use serde::{Serialize, Serializer, Deserialize, Deserializer};
3use serde::ser::SerializeStruct;
4use crate::utils::datastruct::vector::ZeroSpVecTrait;
5
6use super::ZeroSpVec;
7
8impl<N> Serialize for ZeroSpVec<N>
9where
10 N: Num + Serialize + Copy,
11{
12 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
13 where S: Serializer {
14 let mut state = serializer.serialize_struct("ZeroSpVec", 4)?;
16 state.serialize_field("len", &(self.len as u64))?;
17 state.serialize_field("nnz", &(self.nnz as u64))?;
18
19 let inds: &[u32] = unsafe {
21 std::slice::from_raw_parts(self.ind_ptr() as *const u32, self.nnz)
22 };
23 let vals: &[N] = unsafe {
24 std::slice::from_raw_parts(self.val_ptr() as *const N, self.nnz)
25 };
26 state.serialize_field("inds", inds)?;
27 state.serialize_field("vals", vals)?;
28 state.end()
29 }
30}
31
32impl<'de, N> Deserialize<'de> for ZeroSpVec<N>
33where
34 N: Num + Deserialize<'de> + Copy,
35{
36 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
37 where D: Deserializer<'de> {
38 use serde::de::{Visitor, MapAccess, SeqAccess, DeserializeSeed, Error as DeError};
39 use std::fmt;
40
41 struct IndsSeed<'a, T: Num + Copy> {
43 vec: &'a mut ZeroSpVec<T>,
44 count: &'a mut usize,
45 }
46 impl<'de, 'a, T> DeserializeSeed<'de> for IndsSeed<'a, T>
47 where
48 T: Num + Deserialize<'de> + Copy,
49 {
50 type Value = ();
51 fn deserialize<Ds>(self, deserializer: Ds) -> Result<Self::Value, Ds::Error>
52 where
53 Ds: Deserializer<'de>,
54 {
55 struct V<'b, T: Num + Copy> {
56 vec: &'b mut ZeroSpVec<T>,
57 count: &'b mut usize,
58 }
59 impl<'de, 'b, T> Visitor<'de> for V<'b, T>
60 where
61 T: Num + Deserialize<'de> + Copy,
62 {
63 type Value = ();
64 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
65 write!(f, "sequence of u32 indices")
66 }
67 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
68 where
69 A: SeqAccess<'de>,
70 {
71 let mut i = 0usize;
72 while let Some(idx) = seq.next_element::<u32>()? {
73 if i == self.vec.buf.cap {
74 self.vec.buf.grow();
75 }
76 unsafe { *self.vec.ind_ptr().add(i) = idx; }
77 i += 1;
78 }
79 *self.count = i;
80 Ok(())
81 }
82 }
83 deserializer.deserialize_seq(V {
84 vec: self.vec,
85 count: self.count,
86 })
87 }
88 }
89
90 struct ValsSeed<'a, T: Num + Copy> {
91 vec: &'a mut ZeroSpVec<T>,
92 count: &'a mut usize,
93 }
94 impl<'de, 'a, T> DeserializeSeed<'de> for ValsSeed<'a, T>
95 where
96 T: Num + Deserialize<'de> + Copy,
97 {
98 type Value = ();
99 fn deserialize<Ds>(self, deserializer: Ds) -> Result<Self::Value, Ds::Error>
100 where
101 Ds: Deserializer<'de>,
102 {
103 struct V<'b, T: Num + Copy> {
104 vec: &'b mut ZeroSpVec<T>,
105 count: &'b mut usize,
106 }
107 impl<'de, 'b, T> Visitor<'de> for V<'b, T>
108 where
109 T: Num + Deserialize<'de> + Copy,
110 {
111 type Value = ();
112 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
113 write!(f, "sequence of values")
114 }
115 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
116 where
117 A: SeqAccess<'de>,
118 {
119 let mut i = 0usize;
120 while let Some(val) = seq.next_element::<T>()? {
121 if i == self.vec.buf.cap {
122 self.vec.buf.grow();
123 }
124 unsafe { *self.vec.val_ptr().add(i) = val; }
125 i += 1;
126 }
127 *self.count = i;
128 Ok(())
129 }
130 }
131 deserializer.deserialize_seq(V {
132 vec: self.vec,
133 count: self.count,
134 })
135 }
136 }
137
138 struct ZeroSpVecVisitor<N> {
139 marker: std::marker::PhantomData<N>,
140 }
141
142 impl<'de, N> Visitor<'de> for ZeroSpVecVisitor<N>
143 where N: Num + Deserialize<'de> + Copy {
144 type Value = ZeroSpVec<N>;
145
146 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
147 write!(f, "struct ZeroSpVec")
148 }
149
150 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
151 where A: MapAccess<'de> {
152 let mut len: Option<usize> = None;
153 let mut nnz: Option<usize> = None;
154 let mut vec = ZeroSpVec::new();
155 let mut inds_count: usize = 0;
156 let mut vals_count: usize = 0;
157
158 while let Some(key) = map.next_key::<String>()? {
159 match key.as_str() {
160 "len" => len = Some(map.next_value::<u64>()? as usize),
161 "nnz" => {
162 let v = map.next_value::<u64>()? as usize;
163 if v > vec.buf.cap {
164 let old_cap = vec.buf.cap;
165 vec.buf.cap = v;
166 if old_cap == 0 {
167 vec.buf.cap_set();
168 } else {
169 vec.buf.re_cap_set();
170 }
171 }
172 nnz = Some(v);
173 },
174 "inds" => { map.next_value_seed(IndsSeed { vec: &mut vec, count: &mut inds_count })?; },
175 "vals" => { map.next_value_seed(ValsSeed { vec: &mut vec, count: &mut vals_count })?; },
176 _ => { let _: serde::de::IgnoredAny = map.next_value()?; }
177 }
178 }
179 let len = len.ok_or_else(|| DeError::missing_field("len"))?;
180 let nnz = nnz.ok_or_else(|| DeError::missing_field("nnz"))?;
181 if inds_count != vals_count { return Err(DeError::custom("inds and vals length mismatch")); }
182 if inds_count != nnz { return Err(DeError::custom("nnz does not match inds/vals length")); }
183 vec.len = len;
184 vec.nnz = nnz;
185 Ok(vec)
186 }
187
188 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
190 where A: SeqAccess<'de> {
191 let len_u64: u64 = seq
192 .next_element()?
193 .ok_or_else(|| DeError::custom("missing len"))?;
194 let nnz_u64: u64 = seq
195 .next_element()?
196 .ok_or_else(|| DeError::custom("missing nnz"))?;
197 let len = len_u64 as usize;
198 let nnz = nnz_u64 as usize;
199
200 let mut vec = ZeroSpVec::with_capacity(nnz);
203 vec.len = len;
204 let mut inds_count: usize = 0;
205 let mut vals_count: usize = 0;
206
207 seq.next_element_seed(IndsSeed {
208 vec: &mut vec,
209 count: &mut inds_count,
210 })?
211 .ok_or_else(|| DeError::custom("missing inds"))?;
212
213 seq.next_element_seed(ValsSeed {
214 vec: &mut vec,
215 count: &mut vals_count,
216 })?
217 .ok_or_else(|| DeError::custom("missing vals"))?;
218
219 if inds_count != vals_count {
220 return Err(DeError::custom("inds and vals length mismatch"));
221 }
222 if inds_count != nnz {
223 return Err(DeError::custom("nnz does not match inds/vals length"));
224 }
225 vec.nnz = nnz;
226 Ok(vec)
227 }
228 }
229 deserializer.deserialize_struct(
230 "ZeroSpVec",
231 &["len", "nnz", "inds", "vals"],
232 ZeroSpVecVisitor { marker: std::marker::PhantomData },
233 )
234 }
235}