tf_idf_vectorizer/vectorizer/
serde.rs1use std::sync::Arc;
2use std::hash::Hash;
3
4use ahash::RandomState;
5use num_traits::Num;
6use serde::{ser::SerializeStruct, Deserialize, Serialize};
7
8use crate::{Corpus, TFIDFVectorizer, utils::datastruct::{map::IndexMap, vector::ZeroSpVecTrait}, vectorizer::{IDFVector, KeyRc, TFVector, tfidf::{DefaultTFIDFEngine, TFIDFEngine}}};
9
10#[derive(Debug, Deserialize, Serialize)]
14pub struct TFIDFData<N = f32, K = String, E = DefaultTFIDFEngine>
15where
16 N: Num + Copy,
17 E: TFIDFEngine<N, K>,
18 K: Clone + Eq + Hash,
19{
20 pub documents: IndexMap<KeyRc<K>, TFVector<N>>,
22 pub token_dim_sample: Vec<Box<str>>,
24 #[serde(default, skip_serializing, skip_deserializing)]
26 pub idf: Option<IDFVector>,
27 #[serde(default, skip_serializing, skip_deserializing)]
28 pub(crate) _marker: std::marker::PhantomData<E>,
29}
30
31impl<N, K, E> TFIDFData<N, K, E>
32where
33 N: Num + Copy + Into<f64> + Send + Sync,
34 E: TFIDFEngine<N, K>,
35 K: Clone + Send + Sync + Eq + Hash,
36{
37 pub fn into_tf_idf_vectorizer(self, corpus_ref: Arc<Corpus>) -> TFIDFVectorizer<N, K, E>
40 {
41 let mut token_dim_rev_index: IndexMap<Box<str>, Vec<KeyRc<K>>, RandomState> =
42 IndexMap::with_capacity(self.token_dim_sample.len());
43 self.token_dim_sample.iter().for_each(|token| {
45 token_dim_rev_index.insert(token.clone(), Vec::new());
46 });
47 self.documents.iter().for_each(|(key, doc)| {
48 doc.tf_vec.raw_iter().for_each(|(idx, _)| {
49 token_dim_rev_index.get_with_index_mut(idx).unwrap().push(key.clone());
50 });
51 });
52
53 let mut instance = TFIDFVectorizer {
54 documents: self.documents,
55 token_dim_rev_index: token_dim_rev_index,
56 corpus_ref,
57 idf_cache: IDFVector::new(),
58 _marker: std::marker::PhantomData,
59 };
60 instance.update_idf();
61 instance
62 }
63}
64
65impl<N, K, E> Serialize for TFIDFVectorizer<N, K, E>
66where
67 N: Num + Copy + Serialize + Into<f64> + Send + Sync,
68 K: Serialize + Clone + Send + Sync + Eq + Hash,
69 E: TFIDFEngine<N, K>,
70{
71 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
75 where
76 S: serde::Serializer,
77 {
78 let mut state = serializer.serialize_struct("TFIDFVectorizer", 3)?;
79 state.serialize_field("documents", &self.documents)?;
80 state.serialize_field("token_dim_sample", &self.token_dim_rev_index.keys())?;
81 state.serialize_field("corpus_ref", &self.corpus_ref)?;
82 state.end()
83 }
84}
85
86impl<'de, N, K, E> Deserialize<'de> for TFIDFVectorizer<N, K, E>
87where
88 N: Num + Copy + Deserialize<'de> + Into<f64> + Send + Sync,
89 K: Deserialize<'de> + Clone + Send + Sync + Eq + Hash,
90 E: TFIDFEngine<N, K> + Send + Sync,
91{
92 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
96 where
97 D: serde::Deserializer<'de>,
98 {
99 #[derive(Deserialize)]
100 struct TFIDFVectorizerHelper<N, K>
101 where
102 N: Num + Copy,
103 K: Clone + Eq + Hash,
104 {
105 documents: IndexMap<KeyRc<K>, TFVector<N>>,
106 token_dim_sample: Vec<Box<str>>,
107 #[serde(default, skip_deserializing)]
108 corpus_ref: Arc<Corpus>,
109 }
110
111 let helper = TFIDFVectorizerHelper::<N, K>::deserialize(deserializer)?;
112 let mut token_dim_rev_index: IndexMap<Box<str>, Vec<KeyRc<K>>, RandomState> =
113 IndexMap::with_capacity(helper.token_dim_sample.len());
114 helper.token_dim_sample.iter().for_each(|token| {
116 token_dim_rev_index.insert(token.clone(), Vec::new());
117 });
118 helper.documents.iter().for_each(|(key, doc)| {
119 doc.tf_vec.raw_iter().for_each(|(idx, _)| {
120 token_dim_rev_index.get_with_index_mut(idx).unwrap().push(key.clone());
121 });
122 });
123
124 Ok(TFIDFVectorizer {
125 documents: helper.documents,
126 token_dim_rev_index,
127 corpus_ref: helper.corpus_ref,
128 idf_cache: IDFVector::new(),
129 _marker: std::marker::PhantomData,
130 })
131 }
132}
133
134impl<N, K, E> TFIDFVectorizer<N, K, E>
135where
136 N: Num + Copy + Serialize + Into<f64> + Send + Sync,
137 K: Serialize + Clone + Send + Sync + Eq + Hash,
138 E: TFIDFEngine<N, K> + Send + Sync,
139{
140 pub fn into_tfidf_data(self) -> TFIDFData<N, K, E> {
141 let token_dim_sample = self.token_dim_rev_index.keys().clone();
142 TFIDFData {
143 documents: self.documents,
144 token_dim_sample,
145 idf: None,
146 _marker: std::marker::PhantomData,
147 }
148 }
149}