tf_idf_vectorizer/utils/datastruct/vector/
serde.rs

1use num_traits::Num;
2use serde::{Deserialize, Serialize, Serializer};
3use serde::ser::SerializeStruct;
4
5use crate::utils::datastruct::vector::{TFVector, TFVectorTrait};
6
7impl<N> Serialize for TFVector<N>
8where
9    N: Num + Serialize + Copy,
10{
11    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
12    where S: Serializer {
13        // シリアライズするフィールドは len, nnz, inds, vals, term_sum
14        let mut state = serializer.serialize_struct("TFVector", 5)?;
15        let vals = self.as_val_slice();
16        let inds = self.as_ind_slice();
17        state.serialize_field("len", &(self.len()))?;
18        state.serialize_field("nnz", &(self.nnz()))?;
19        state.serialize_field("inds", inds)?;
20        state.serialize_field("vals", vals)?;
21        state.serialize_field("term_sum", &self.term_sum())?;
22        state.end()
23    }
24}
25
26#[derive(Deserialize)]
27struct TFVectorDe<N> {
28    len: u32,
29    #[serde(default)]
30    nnz: Option<u32>,
31    inds: Vec<u32>,
32    vals: Vec<N>,
33    term_sum: u32,
34}
35
36impl<'de, N> Deserialize<'de> for TFVector<N>
37where
38    N: Num + Deserialize<'de> + Copy,
39{
40    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
41    where
42        D: serde::Deserializer<'de>,
43    {
44        let de = TFVectorDe::<N>::deserialize(deserializer)?;
45
46        debug_assert_eq!(de.inds.len(), de.vals.len());
47        debug_assert!(de.nnz.map_or(true, |n| n as usize == de.inds.len()));
48
49        Ok(unsafe { TFVector::from_vec(de.inds, de.vals, de.len, de.term_sum) })
50    }
51}