tf_idf_vectorizer/utils/math/vector/serde.rs

1use num::Num;
2use serde::{Serialize, Serializer, Deserialize, Deserializer};
3use serde::ser::SerializeStruct;
4use crate::utils::math::vector::ZeroSpVecTrait;
5
6use super::ZeroSpVec;
7
8impl<N> Serialize for ZeroSpVec<N>
9where
10    N: Num + Serialize + Copy,
11{
12    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
13    where S: Serializer {
14        // シリアライズするフィールドは len, nnz, entries とする
15        let mut state = serializer.serialize_struct("ZeroSpVec", 3)?;
16        state.serialize_field("len", &(self.len as u64))?;
17        state.serialize_field("nnz", &(self.nnz as u64))?;
18        
19        // entries: (index, value) のVecとして順序付きに出力する
20        let mut entries = Vec::with_capacity(self.nnz);
21        //  entries = self.raw_iter().map(|(idx, entry)| (idx as u64, *entry)).collect();
22        unsafe {
23            for i in 0..self.nnz {
24                let idx = *self.ind_ptr().add(i);
25                let val = *self.val_ptr().add(i);
26                entries.push((idx as u64, val));
27            }
28        }
29        state.serialize_field("entries", &entries)?;
30        state.end()
31    }
32}
33
34impl<'de, N> Deserialize<'de> for ZeroSpVec<N>
35where
36    N: Num + Deserialize<'de> + Copy,
37{
38    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
39    where D: Deserializer<'de> {
40        use serde::de::{Visitor, MapAccess, Error as DeError};
41        use std::fmt;
42
43        struct ZeroSpVecVisitor<N> {
44            marker: std::marker::PhantomData<N>,
45        }
46
47        impl<'de, N> Visitor<'de> for ZeroSpVecVisitor<N>
48        where N: Num + Deserialize<'de> + Copy {
49            type Value = ZeroSpVec<N>;
50
51            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
52                write!(f, "struct ZeroSpVec")
53            }
54
55            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
56            where A: MapAccess<'de> {
57                let mut len = None;
58                let mut nnz = None;
59                let mut entries = None;
60                while let Some(key) = map.next_key::<String>()? {
61                    match key.as_str() {
62                        "len" => len = Some(map.next_value::<u64>()? as usize),
63                        "nnz" => nnz = Some(map.next_value::<u64>()? as usize),
64                        "entries" => entries = Some(map.next_value::<Vec<(u64, N)>>()?),
65                        _ => { let _: serde::de::IgnoredAny = map.next_value()?; }
66                    }
67                }
68                let len = len.ok_or_else(|| DeError::missing_field("len"))?;
69                let nnz = nnz.ok_or_else(|| DeError::missing_field("nnz"))?;
70                let entries = entries.ok_or_else(|| DeError::missing_field("entries"))?;
71                let mut vec = ZeroSpVec::with_capacity(nnz);
72                vec.len = len;
73                for (index, value) in entries {
74                    unsafe { vec.raw_push(index as usize, value) };
75                }
76                Ok(vec)
77            }
78        }
79        deserializer.deserialize_struct(
80            "ZeroSpVec",
81            &["len", "nnz", "entries"],
82            ZeroSpVecVisitor { marker: std::marker::PhantomData },
83        )
84    }
85}