tf_idf_vectorizer/utils/datastruct/vector/
serde.rs1use num_traits::Num;
2use serde::{Deserialize, Serialize, Serializer};
3use serde::ser::SerializeStruct;
4
5use crate::utils::datastruct::vector::{TFVector, TFVectorTrait};
6
7impl<N> Serialize for TFVector<N>
8where
9 N: Num + Serialize + Copy,
10{
11 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
12 where S: Serializer {
13 let mut state = serializer.serialize_struct("TFVector", 5)?;
15 let vals = self.as_val_slice();
16 let inds = self.as_ind_slice();
17 state.serialize_field("len", &(self.len()))?;
18 state.serialize_field("nnz", &(self.nnz()))?;
19 state.serialize_field("inds", inds)?;
20 state.serialize_field("vals", vals)?;
21 state.serialize_field("term_sum", &self.term_sum())?;
22 state.end()
23 }
24}
25
26#[derive(Deserialize)]
27struct TFVectorDe<N> {
28 len: u32,
29 #[serde(default)]
30 nnz: Option<u32>,
31 inds: Vec<u32>,
32 vals: Vec<N>,
33 term_sum: u32,
34}
35
36impl<'de, N> Deserialize<'de> for TFVector<N>
37where
38 N: Num + Deserialize<'de> + Copy,
39{
40 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
41 where
42 D: serde::Deserializer<'de>,
43 {
44 let de = TFVectorDe::<N>::deserialize(deserializer)?;
45
46 debug_assert_eq!(de.inds.len(), de.vals.len());
47 debug_assert!(de.nnz.map_or(true, |n| n as usize == de.inds.len()));
48
49 Ok(unsafe { TFVector::from_vec(de.inds, de.vals, de.len, de.term_sum) })
50 }
51}