// tf_idf_vectorizer/utils/math/vector/serde.rs

use num::Num;
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};

use super::ZeroSpVec;
use crate::utils::math::vector::ZeroSpVecTrait;

8impl<N> Serialize for ZeroSpVec<N>
9where
10 N: Num + Serialize + Copy,
11{
12 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
13 where S: Serializer {
14 let mut state = serializer.serialize_struct("ZeroSpVec", 4)?;
16 state.serialize_field("len", &(self.len as u64))?;
17 state.serialize_field("nnz", &(self.nnz as u64))?;
18
19 let mut inds: Vec<u32> = Vec::with_capacity(self.nnz);
21 let mut vals: Vec<N> = Vec::with_capacity(self.nnz);
22 unsafe {
23 for i in 0..self.nnz {
24 inds.push(*self.ind_ptr().add(i));
25 vals.push(*self.val_ptr().add(i));
26 }
27 }
28 state.serialize_field("inds", &inds)?;
29 state.serialize_field("vals", &vals)?;
30 state.end()
31 }
32}
33
34impl<'de, N> Deserialize<'de> for ZeroSpVec<N>
35where
36 N: Num + Deserialize<'de> + Copy,
37{
38 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
39 where D: Deserializer<'de> {
40 use serde::de::{Visitor, MapAccess, SeqAccess, DeserializeSeed, Error as DeError};
41 use std::fmt;
42
43 struct ZeroSpVecVisitor<N> {
44 marker: std::marker::PhantomData<N>,
45 }
46
47 impl<'de, N> Visitor<'de> for ZeroSpVecVisitor<N>
48 where N: Num + Deserialize<'de> + Copy {
49 type Value = ZeroSpVec<N>;
50
51 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
52 write!(f, "struct ZeroSpVec")
53 }
54
55 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
56 where A: MapAccess<'de> {
57 let mut len: Option<usize> = None;
58 let mut nnz: Option<usize> = None;
59 let mut vec = ZeroSpVec::new();
60 let mut inds_count: usize = 0;
61 let mut vals_count: usize = 0;
62
63 struct IndsSeed<'a, T: Num + Copy> { vec: &'a mut ZeroSpVec<T>, count: &'a mut usize }
65 impl<'de, 'a, T> DeserializeSeed<'de> for IndsSeed<'a, T>
66 where T: Num + Deserialize<'de> + Copy {
67 type Value = ();
68 fn deserialize<Ds>(self, deserializer: Ds) -> Result<Self::Value, Ds::Error>
69 where Ds: Deserializer<'de> {
70 struct V<'b, T: Num + Copy> { vec: &'b mut ZeroSpVec<T>, count: &'b mut usize }
71 impl<'de, 'b, T> Visitor<'de> for V<'b, T>
72 where T: Num + Deserialize<'de> + Copy {
73 type Value = ();
74 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "sequence of u32 indices") }
75 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
76 where A: SeqAccess<'de> {
77 let mut i = 0usize;
78 while let Some(idx) = seq.next_element::<u32>()? {
79 if i == self.vec.buf.cap { self.vec.buf.grow(); }
80 unsafe { *self.vec.ind_ptr().add(i) = idx; }
81 i += 1;
82 }
83 *self.count = i;
84 Ok(())
85 }
86 }
87 deserializer.deserialize_seq(V { vec: self.vec, count: self.count })
88 }
89 }
90
91 struct ValsSeed<'a, T: Num + Copy> { vec: &'a mut ZeroSpVec<T>, count: &'a mut usize }
92 impl<'de, 'a, T> DeserializeSeed<'de> for ValsSeed<'a, T>
93 where T: Num + Deserialize<'de> + Copy {
94 type Value = ();
95 fn deserialize<Ds>(self, deserializer: Ds) -> Result<Self::Value, Ds::Error>
96 where Ds: Deserializer<'de> {
97 struct V<'b, T: Num + Copy> { vec: &'b mut ZeroSpVec<T>, count: &'b mut usize }
98 impl<'de, 'b, T> Visitor<'de> for V<'b, T>
99 where T: Num + Deserialize<'de> + Copy {
100 type Value = ();
101 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "sequence of values") }
102 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
103 where A: SeqAccess<'de> {
104 let mut i = 0usize;
105 while let Some(val) = seq.next_element::<T>()? {
106 if i == self.vec.buf.cap { self.vec.buf.grow(); }
107 unsafe { *self.vec.val_ptr().add(i) = val; }
108 i += 1;
109 }
110 *self.count = i;
111 Ok(())
112 }
113 }
114 deserializer.deserialize_seq(V { vec: self.vec, count: self.count })
115 }
116 }
117
118 while let Some(key) = map.next_key::<String>()? {
119 match key.as_str() {
120 "len" => len = Some(map.next_value::<u64>()? as usize),
121 "nnz" => {
122 let v = map.next_value::<u64>()? as usize;
123 if v > vec.buf.cap { vec.buf.cap = v; vec.buf.cap_set(); }
124 nnz = Some(v);
125 },
126 "inds" => { map.next_value_seed(IndsSeed { vec: &mut vec, count: &mut inds_count })?; },
127 "vals" => { map.next_value_seed(ValsSeed { vec: &mut vec, count: &mut vals_count })?; },
128 _ => { let _: serde::de::IgnoredAny = map.next_value()?; }
129 }
130 }
131 let len = len.ok_or_else(|| DeError::missing_field("len"))?;
132 let nnz = nnz.ok_or_else(|| DeError::missing_field("nnz"))?;
133 if inds_count != vals_count { return Err(DeError::custom("inds and vals length mismatch")); }
134 if inds_count != nnz { return Err(DeError::custom("nnz does not match inds/vals length")); }
135 vec.len = len;
136 vec.nnz = nnz;
137 Ok(vec)
138 }
139
140 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
142 where A: SeqAccess<'de> {
143 let len_u64: u64 = seq
144 .next_element()?
145 .ok_or_else(|| DeError::custom("missing len"))?;
146 let nnz_u64: u64 = seq
147 .next_element()?
148 .ok_or_else(|| DeError::custom("missing nnz"))?;
149 let len = len_u64 as usize;
150 let nnz = nnz_u64 as usize;
151
152 let inds: Vec<u32> = seq
154 .next_element()?
155 .ok_or_else(|| DeError::custom("missing inds"))?;
156 let vals: Vec<N> = seq
157 .next_element()?
158 .ok_or_else(|| DeError::custom("missing vals"))?;
159
160 if inds.len() != vals.len() {
161 return Err(DeError::custom("inds and vals length mismatch"));
162 }
163 if inds.len() != nnz {
164 return Err(DeError::custom("nnz does not match inds/vals length"));
165 }
166
167 let mut vec = ZeroSpVec::with_capacity(nnz);
168 vec.len = len;
169 for i in 0..nnz {
170 unsafe { vec.raw_push(inds[i] as usize, vals[i]) };
171 }
172 Ok(vec)
173 }
174 }
175 deserializer.deserialize_struct(
176 "ZeroSpVec",
177 &["len", "nnz", "inds", "vals"],
178 ZeroSpVecVisitor { marker: std::marker::PhantomData },
179 )
180 }
181}