tf_idf_vectorizer/utils/datastruct/vector/serde.rs

1use num_traits::Num;
2use serde::{Serialize, Serializer, Deserialize, Deserializer};
3use serde::ser::SerializeStruct;
4use crate::utils::datastruct::vector::ZeroSpVecTrait;
5
6use super::ZeroSpVec;
7
8impl<N> Serialize for ZeroSpVec<N>
9where
10    N: Num + Serialize + Copy,
11{
12    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
13    where S: Serializer {
14    // シリアライズするフィールドは len, nnz, inds, vals
15    let mut state = serializer.serialize_struct("ZeroSpVec", 4)?;
16        state.serialize_field("len", &(self.len as u64))?;
17        state.serialize_field("nnz", &(self.nnz as u64))?;
18
19        // zero-copy
20        let inds: &[u32] = unsafe {
21            std::slice::from_raw_parts(self.ind_ptr() as *const u32, self.nnz)
22        };
23        let vals: &[N] = unsafe {
24            std::slice::from_raw_parts(self.val_ptr() as *const N, self.nnz)
25        };
26        state.serialize_field("inds", inds)?;
27        state.serialize_field("vals", vals)?;
28        state.end()
29    }
30}
31
32impl<'de, N> Deserialize<'de> for ZeroSpVec<N>
33where
34    N: Num + Deserialize<'de> + Copy,
35{
36    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
37    where D: Deserializer<'de> {
38    use serde::de::{Visitor, MapAccess, SeqAccess, DeserializeSeed, Error as DeError};
39        use std::fmt;
40
41        // シーケンスを直接内部バッファへ書き込むためのSeed
42        struct IndsSeed<'a, T: Num + Copy> {
43            vec: &'a mut ZeroSpVec<T>,
44            count: &'a mut usize,
45        }
46        impl<'de, 'a, T> DeserializeSeed<'de> for IndsSeed<'a, T>
47        where
48            T: Num + Deserialize<'de> + Copy,
49        {
50            type Value = ();
51            fn deserialize<Ds>(self, deserializer: Ds) -> Result<Self::Value, Ds::Error>
52            where
53                Ds: Deserializer<'de>,
54            {
55                struct V<'b, T: Num + Copy> {
56                    vec: &'b mut ZeroSpVec<T>,
57                    count: &'b mut usize,
58                }
59                impl<'de, 'b, T> Visitor<'de> for V<'b, T>
60                where
61                    T: Num + Deserialize<'de> + Copy,
62                {
63                    type Value = ();
64                    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
65                        write!(f, "sequence of u32 indices")
66                    }
67                    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
68                    where
69                        A: SeqAccess<'de>,
70                    {
71                        let mut i = 0usize;
72                        while let Some(idx) = seq.next_element::<u32>()? {
73                            if i == self.vec.buf.cap {
74                                self.vec.buf.grow();
75                            }
76                            unsafe { *self.vec.ind_ptr().add(i) = idx; }
77                            i += 1;
78                        }
79                        *self.count = i;
80                        Ok(())
81                    }
82                }
83                deserializer.deserialize_seq(V {
84                    vec: self.vec,
85                    count: self.count,
86                })
87            }
88        }
89
90        struct ValsSeed<'a, T: Num + Copy> {
91            vec: &'a mut ZeroSpVec<T>,
92            count: &'a mut usize,
93        }
94        impl<'de, 'a, T> DeserializeSeed<'de> for ValsSeed<'a, T>
95        where
96            T: Num + Deserialize<'de> + Copy,
97        {
98            type Value = ();
99            fn deserialize<Ds>(self, deserializer: Ds) -> Result<Self::Value, Ds::Error>
100            where
101                Ds: Deserializer<'de>,
102            {
103                struct V<'b, T: Num + Copy> {
104                    vec: &'b mut ZeroSpVec<T>,
105                    count: &'b mut usize,
106                }
107                impl<'de, 'b, T> Visitor<'de> for V<'b, T>
108                where
109                    T: Num + Deserialize<'de> + Copy,
110                {
111                    type Value = ();
112                    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
113                        write!(f, "sequence of values")
114                    }
115                    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
116                    where
117                        A: SeqAccess<'de>,
118                    {
119                        let mut i = 0usize;
120                        while let Some(val) = seq.next_element::<T>()? {
121                            if i == self.vec.buf.cap {
122                                self.vec.buf.grow();
123                            }
124                            unsafe { *self.vec.val_ptr().add(i) = val; }
125                            i += 1;
126                        }
127                        *self.count = i;
128                        Ok(())
129                    }
130                }
131                deserializer.deserialize_seq(V {
132                    vec: self.vec,
133                    count: self.count,
134                })
135            }
136        }
137
138        struct ZeroSpVecVisitor<N> {
139            marker: std::marker::PhantomData<N>,
140        }
141
142        impl<'de, N> Visitor<'de> for ZeroSpVecVisitor<N>
143        where N: Num + Deserialize<'de> + Copy {
144            type Value = ZeroSpVec<N>;
145
146            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
147                write!(f, "struct ZeroSpVec")
148            }
149
150            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
151            where A: MapAccess<'de> {
152                let mut len: Option<usize> = None;
153                let mut nnz: Option<usize> = None;
154                let mut vec = ZeroSpVec::new();
155                let mut inds_count: usize = 0;
156                let mut vals_count: usize = 0;
157
158                while let Some(key) = map.next_key::<String>()? {
159                    match key.as_str() {
160                        "len" => len = Some(map.next_value::<u64>()? as usize),
161                        "nnz" => {
162                            let v = map.next_value::<u64>()? as usize;
163                            if v > vec.buf.cap {
164                                let old_cap = vec.buf.cap;
165                                vec.buf.cap = v;
166                                if old_cap == 0 {
167                                    vec.buf.cap_set();
168                                } else {
169                                    vec.buf.re_cap_set();
170                                }
171                            }
172                            nnz = Some(v);
173                        },
174                        "inds" => { map.next_value_seed(IndsSeed { vec: &mut vec, count: &mut inds_count })?; },
175                        "vals" => { map.next_value_seed(ValsSeed { vec: &mut vec, count: &mut vals_count })?; },
176                        _ => { let _: serde::de::IgnoredAny = map.next_value()?; }
177                    }
178                }
179                let len = len.ok_or_else(|| DeError::missing_field("len"))?;
180                let nnz = nnz.ok_or_else(|| DeError::missing_field("nnz"))?;
181                if inds_count != vals_count { return Err(DeError::custom("inds and vals length mismatch")); }
182                if inds_count != nnz { return Err(DeError::custom("nnz does not match inds/vals length")); }
183                vec.len = len;
184                vec.nnz = nnz;
185                Ok(vec)
186            }
187
188            // bincode等のフィールド名を持たないフォーマット用: フィールド順のシーケンス
189            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
190            where A: SeqAccess<'de> {
191                let len_u64: u64 = seq
192                    .next_element()?
193                    .ok_or_else(|| DeError::custom("missing len"))?;
194                let nnz_u64: u64 = seq
195                    .next_element()?
196                    .ok_or_else(|| DeError::custom("missing nnz"))?;
197                let len = len_u64 as usize;
198                let nnz = nnz_u64 as usize;
199
200                // 新形式(inds, vals)のみ対応(bincodeの旧entries形式は非対応)
201                // Vecへ一旦受けず、内部バッファへ直接書き込む
202                let mut vec = ZeroSpVec::with_capacity(nnz);
203                vec.len = len;
204                let mut inds_count: usize = 0;
205                let mut vals_count: usize = 0;
206
207                seq.next_element_seed(IndsSeed {
208                    vec: &mut vec,
209                    count: &mut inds_count,
210                })?
211                .ok_or_else(|| DeError::custom("missing inds"))?;
212
213                seq.next_element_seed(ValsSeed {
214                    vec: &mut vec,
215                    count: &mut vals_count,
216                })?
217                .ok_or_else(|| DeError::custom("missing vals"))?;
218
219                if inds_count != vals_count {
220                    return Err(DeError::custom("inds and vals length mismatch"));
221                }
222                if inds_count != nnz {
223                    return Err(DeError::custom("nnz does not match inds/vals length"));
224                }
225                vec.nnz = nnz;
226                Ok(vec)
227            }
228        }
229        deserializer.deserialize_struct(
230            "ZeroSpVec",
231            &["len", "nnz", "inds", "vals"],
232            ZeroSpVecVisitor { marker: std::marker::PhantomData },
233        )
234    }
235}