tf_idf_vectorizer/utils/math/vector/
serde.rs

1use num::Num;
2use serde::{Serialize, Serializer, Deserialize, Deserializer};
3use serde::ser::SerializeStruct;
4use crate::utils::math::vector::ZeroSpVecTrait;
5
6use super::ZeroSpVec;
7
8impl<N> Serialize for ZeroSpVec<N>
9where
10    N: Num + Serialize + Copy,
11{
12    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
13    where S: Serializer {
14    // シリアライズするフィールドは len, nnz, inds, vals
15    let mut state = serializer.serialize_struct("ZeroSpVec", 4)?;
16        state.serialize_field("len", &(self.len as u64))?;
17        state.serialize_field("nnz", &(self.nnz as u64))?;
18        
19        // inds/vals: バッファをそのまま配列として出力する(タプルを避ける)
20        let mut inds: Vec<u32> = Vec::with_capacity(self.nnz);
21        let mut vals: Vec<N> = Vec::with_capacity(self.nnz);
22        unsafe {
23            for i in 0..self.nnz {
24                inds.push(*self.ind_ptr().add(i));
25                vals.push(*self.val_ptr().add(i));
26            }
27        }
28        state.serialize_field("inds", &inds)?;
29        state.serialize_field("vals", &vals)?;
30        state.end()
31    }
32}
33
34impl<'de, N> Deserialize<'de> for ZeroSpVec<N>
35where
36    N: Num + Deserialize<'de> + Copy,
37{
38    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
39    where D: Deserializer<'de> {
40    use serde::de::{Visitor, MapAccess, SeqAccess, DeserializeSeed, Error as DeError};
41        use std::fmt;
42
43        struct ZeroSpVecVisitor<N> {
44            marker: std::marker::PhantomData<N>,
45        }
46
47        impl<'de, N> Visitor<'de> for ZeroSpVecVisitor<N>
48        where N: Num + Deserialize<'de> + Copy {
49            type Value = ZeroSpVec<N>;
50
51            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
52                write!(f, "struct ZeroSpVec")
53            }
54
55            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
56            where A: MapAccess<'de> {
57                let mut len: Option<usize> = None;
58                let mut nnz: Option<usize> = None;
59                let mut vec = ZeroSpVec::new();
60                let mut inds_count: usize = 0;
61                let mut vals_count: usize = 0;
62
63                // シーケンスを直接内部バッファへ書き込むためのSeed
64                struct IndsSeed<'a, T: Num + Copy> { vec: &'a mut ZeroSpVec<T>, count: &'a mut usize }
65                impl<'de, 'a, T> DeserializeSeed<'de> for IndsSeed<'a, T>
66                where T: Num + Deserialize<'de> + Copy {
67                    type Value = ();
68                    fn deserialize<Ds>(self, deserializer: Ds) -> Result<Self::Value, Ds::Error>
69                    where Ds: Deserializer<'de> {
70                        struct V<'b, T: Num + Copy> { vec: &'b mut ZeroSpVec<T>, count: &'b mut usize }
71                        impl<'de, 'b, T> Visitor<'de> for V<'b, T>
72                        where T: Num + Deserialize<'de> + Copy {
73                            type Value = ();
74                            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "sequence of u32 indices") }
75                            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
76                            where A: SeqAccess<'de> {
77                                let mut i = 0usize;
78                                while let Some(idx) = seq.next_element::<u32>()? {
79                                    if i == self.vec.buf.cap { self.vec.buf.grow(); }
80                                    unsafe { *self.vec.ind_ptr().add(i) = idx; }
81                                    i += 1;
82                                }
83                                *self.count = i;
84                                Ok(())
85                            }
86                        }
87                        deserializer.deserialize_seq(V { vec: self.vec, count: self.count })
88                    }
89                }
90
91                struct ValsSeed<'a, T: Num + Copy> { vec: &'a mut ZeroSpVec<T>, count: &'a mut usize }
92                impl<'de, 'a, T> DeserializeSeed<'de> for ValsSeed<'a, T>
93                where T: Num + Deserialize<'de> + Copy {
94                    type Value = ();
95                    fn deserialize<Ds>(self, deserializer: Ds) -> Result<Self::Value, Ds::Error>
96                    where Ds: Deserializer<'de> {
97                        struct V<'b, T: Num + Copy> { vec: &'b mut ZeroSpVec<T>, count: &'b mut usize }
98                        impl<'de, 'b, T> Visitor<'de> for V<'b, T>
99                        where T: Num + Deserialize<'de> + Copy {
100                            type Value = ();
101                            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "sequence of values") }
102                            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
103                            where A: SeqAccess<'de> {
104                                let mut i = 0usize;
105                                while let Some(val) = seq.next_element::<T>()? {
106                                    if i == self.vec.buf.cap { self.vec.buf.grow(); }
107                                    unsafe { *self.vec.val_ptr().add(i) = val; }
108                                    i += 1;
109                                }
110                                *self.count = i;
111                                Ok(())
112                            }
113                        }
114                        deserializer.deserialize_seq(V { vec: self.vec, count: self.count })
115                    }
116                }
117
118                while let Some(key) = map.next_key::<String>()? {
119                    match key.as_str() {
120                        "len" => len = Some(map.next_value::<u64>()? as usize),
121                        "nnz" => {
122                            let v = map.next_value::<u64>()? as usize;
123                            if v > vec.buf.cap { vec.buf.cap = v; vec.buf.cap_set(); }
124                            nnz = Some(v);
125                        },
126                        "inds" => { map.next_value_seed(IndsSeed { vec: &mut vec, count: &mut inds_count })?; },
127                        "vals" => { map.next_value_seed(ValsSeed { vec: &mut vec, count: &mut vals_count })?; },
128                        _ => { let _: serde::de::IgnoredAny = map.next_value()?; }
129                    }
130                }
131                let len = len.ok_or_else(|| DeError::missing_field("len"))?;
132                let nnz = nnz.ok_or_else(|| DeError::missing_field("nnz"))?;
133                if inds_count != vals_count { return Err(DeError::custom("inds and vals length mismatch")); }
134                if inds_count != nnz { return Err(DeError::custom("nnz does not match inds/vals length")); }
135                vec.len = len;
136                vec.nnz = nnz;
137                Ok(vec)
138            }
139
140            // bincode等のフィールド名を持たないフォーマット用: フィールド順のシーケンス
141            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
142            where A: SeqAccess<'de> {
143                let len_u64: u64 = seq
144                    .next_element()?
145                    .ok_or_else(|| DeError::custom("missing len"))?;
146                let nnz_u64: u64 = seq
147                    .next_element()?
148                    .ok_or_else(|| DeError::custom("missing nnz"))?;
149                let len = len_u64 as usize;
150                let nnz = nnz_u64 as usize;
151
152                // 新形式(inds, vals)のみ対応(bincodeの旧entries形式は非対応)
153                let inds: Vec<u32> = seq
154                    .next_element()?
155                    .ok_or_else(|| DeError::custom("missing inds"))?;
156                let vals: Vec<N> = seq
157                    .next_element()?
158                    .ok_or_else(|| DeError::custom("missing vals"))?;
159
160                if inds.len() != vals.len() {
161                    return Err(DeError::custom("inds and vals length mismatch"));
162                }
163                if inds.len() != nnz {
164                    return Err(DeError::custom("nnz does not match inds/vals length"));
165                }
166
167                let mut vec = ZeroSpVec::with_capacity(nnz);
168                vec.len = len;
169                for i in 0..nnz {
170                    unsafe { vec.raw_push(inds[i] as usize, vals[i]) };
171                }
172                Ok(vec)
173            }
174        }
175        deserializer.deserialize_struct(
176            "ZeroSpVec",
177            &["len", "nnz", "inds", "vals"],
178            ZeroSpVecVisitor { marker: std::marker::PhantomData },
179        )
180    }
181}