// tf_idf_vectorizer/utils/math/vector/serde.rs
use num::Num;
2use serde::{Serialize, Serializer, Deserialize, Deserializer};
3use serde::ser::SerializeStruct;
4use crate::utils::math::vector::ZeroSpVecTrait;
5
6use super::ZeroSpVec;
7
8impl<N> Serialize for ZeroSpVec<N>
9where
10 N: Num + Serialize + Copy,
11{
12 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
13 where S: Serializer {
14 let mut state = serializer.serialize_struct("ZeroSpVec", 3)?;
16 state.serialize_field("len", &(self.len as u64))?;
17 state.serialize_field("nnz", &(self.nnz as u64))?;
18
19 let mut entries = Vec::with_capacity(self.nnz);
21 unsafe {
23 for i in 0..self.nnz {
24 let idx = *self.ind_ptr().add(i) as u64;
25 let val = *self.val_ptr().add(i);
26 entries.push((idx, val));
27 }
28 }
29 state.serialize_field("entries", &entries)?;
30 state.end()
31 }
32}
33
34impl<'de, N> Deserialize<'de> for ZeroSpVec<N>
35where
36 N: Num + Deserialize<'de> + Copy,
37{
38 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
39 where D: Deserializer<'de> {
40 use serde::de::{Visitor, MapAccess, Error as DeError};
41 use std::fmt;
42
43 struct ZeroSpVecVisitor<N> {
44 marker: std::marker::PhantomData<N>,
45 }
46
47 impl<'de, N> Visitor<'de> for ZeroSpVecVisitor<N>
48 where N: Num + Deserialize<'de> + Copy {
49 type Value = ZeroSpVec<N>;
50
51 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
52 write!(f, "struct ZeroSpVec")
53 }
54
55 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
56 where A: MapAccess<'de> {
57 let mut len = None;
58 let mut nnz = None;
59 let mut entries = None;
60 while let Some(key) = map.next_key::<String>()? {
61 match key.as_str() {
62 "len" => len = Some(map.next_value::<u64>()? as usize),
63 "nnz" => nnz = Some(map.next_value::<u64>()? as usize),
64 "entries" => entries = Some(map.next_value::<Vec<(u64, N)>>()?),
65 _ => { let _: serde::de::IgnoredAny = map.next_value()?; }
66 }
67 }
68 let len = len.ok_or_else(|| DeError::missing_field("len"))?;
69 let nnz = nnz.ok_or_else(|| DeError::missing_field("nnz"))?;
70 let entries = entries.ok_or_else(|| DeError::missing_field("entries"))?;
71 let mut vec = ZeroSpVec::with_capacity(nnz);
72 vec.len = len;
73 for (index, value) in entries {
74 let idx_u32: u32 = u32::try_from(index).map_err(|_| DeError::custom("index overflow for u32 storage"))?;
75 unsafe { vec.raw_push(idx_u32 as usize, value) };
76 }
77 Ok(vec)
78 }
79 }
80 deserializer.deserialize_struct(
81 "ZeroSpVec",
82 &["len", "nnz", "entries"],
83 ZeroSpVecVisitor { marker: std::marker::PhantomData },
84 )
85 }
86}