tf_idf_vectorizer/vectorizer/
serde.rs1use std::sync::Arc;
2use std::hash::Hash;
3
4use ahash::RandomState;
5use num_traits::Num;
6use serde::{ser::SerializeStruct, Deserialize, Serialize};
7
8use crate::{Corpus, TFIDFVectorizer, utils::datastruct::{map::IndexMap, vector::TFVectorTrait}, vectorizer::{IDFVector, KeyRc, TFVector, tfidf::{DefaultTFIDFEngine, TFIDFEngine}}};
9
10#[derive(Debug, Deserialize, Serialize)]
28pub struct TFIDFData<N = f32, K = String, E = DefaultTFIDFEngine>
29where
30 N: Num + Copy,
31 E: TFIDFEngine<N>,
32 K: Clone + Eq + Hash,
33{
34 pub documents: IndexMap<KeyRc<K>, TFVector<N>>,
36 pub term_dim_sample: Vec<Box<str>>,
38 #[serde(default, skip_serializing, skip_deserializing)]
40 pub idf: Option<IDFVector>,
41 #[serde(default, skip_serializing, skip_deserializing)]
42 pub(crate) _marker: std::marker::PhantomData<E>,
43}
44
45impl<N, K, E> TFIDFData<N, K, E>
46where
47 N: Num + Copy + Into<f64> + Send + Sync,
48 E: TFIDFEngine<N>,
49 K: Clone + Send + Sync + Eq + Hash,
50{
51 pub fn into_tf_idf_vectorizer(self, corpus_ref: Arc<Corpus>) -> TFIDFVectorizer<N, K, E>
54 {
55 let mut term_dim_rev_index: IndexMap<Box<str>, Vec<KeyRc<K>>, RandomState> =
56 IndexMap::with_capacity(self.term_dim_sample.len());
57 self.term_dim_sample.iter().for_each(|term| {
59 term_dim_rev_index.insert(term.clone(), Vec::new());
60 });
61 self.documents.iter().for_each(|(key, doc_tf_vec)| {
62 doc_tf_vec.raw_iter().for_each(|(idx, _)| {
63 term_dim_rev_index.get_with_index_mut(idx as usize).unwrap().push(key.clone());
64 });
65 });
66
67 let mut instance = TFIDFVectorizer {
68 documents: self.documents,
69 term_dim_rev_index: term_dim_rev_index,
70 corpus_ref,
71 idf_cache: IDFVector::new(),
72 _marker: std::marker::PhantomData,
73 };
74 instance.update_idf();
75 instance
76 }
77}
78
79impl<N, K, E> Serialize for TFIDFVectorizer<N, K, E>
80where
81 N: Num + Copy + Serialize + Into<f64> + Send + Sync,
82 K: Serialize + Clone + Send + Sync + Eq + Hash,
83 E: TFIDFEngine<N>,
84{
85 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
89 where
90 S: serde::Serializer,
91 {
92 let mut state = serializer.serialize_struct("TFIDFVectorizer", 3)?;
93 state.serialize_field("documents", &self.documents)?;
94 state.serialize_field("term_dim_sample", &self.term_dim_rev_index.keys())?;
95 state.serialize_field("corpus_ref", &self.corpus_ref)?;
96 state.end()
97 }
98}
99
100impl<'de, N, K, E> Deserialize<'de> for TFIDFVectorizer<N, K, E>
101where
102 N: Num + Copy + Deserialize<'de> + Into<f64> + Send + Sync,
103 K: Deserialize<'de> + Clone + Send + Sync + Eq + Hash,
104 E: TFIDFEngine<N> + Send + Sync,
105{
106 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
110 where
111 D: serde::Deserializer<'de>,
112 {
113 #[derive(Deserialize)]
114 struct TFIDFVectorizerHelper<N, K>
115 where
116 N: Num + Copy,
117 K: Clone + Eq + Hash,
118 {
119 documents: IndexMap<KeyRc<K>, TFVector<N>>,
120 term_dim_sample: Vec<Box<str>>,
121 corpus_ref: Arc<Corpus>,
122 }
123
124 let helper = TFIDFVectorizerHelper::<N, K>::deserialize(deserializer)?;
125 let mut term_dim_rev_index: IndexMap<Box<str>, Vec<KeyRc<K>>, RandomState> =
126 IndexMap::with_capacity(helper.term_dim_sample.len());
127 helper.term_dim_sample.iter().for_each(|term| {
129 term_dim_rev_index.insert(term.clone(), Vec::new());
130 });
131 helper.documents.iter().for_each(|(key, doc_tf_vec)| {
132 doc_tf_vec.raw_iter().for_each(|(idx, _)| {
133 term_dim_rev_index.get_with_index_mut(idx as usize).unwrap().push(key.clone());
134 });
135 });
136
137 Ok(TFIDFVectorizer {
138 documents: helper.documents,
139 term_dim_rev_index,
140 corpus_ref: helper.corpus_ref,
141 idf_cache: IDFVector::new(),
142 _marker: std::marker::PhantomData,
143 })
144 }
145}
146
147impl<N, K, E> TFIDFVectorizer<N, K, E>
148where
149 N: Num + Copy + Serialize + Into<f64> + Send + Sync,
150 K: Serialize + Clone + Send + Sync + Eq + Hash,
151 E: TFIDFEngine<N> + Send + Sync,
152{
153 pub fn into_tfidf_data(self) -> TFIDFData<N, K, E> {
154 let term_dim_sample = self.term_dim_rev_index.keys().clone();
155 TFIDFData {
156 documents: self.documents,
157 term_dim_sample,
158 idf: None,
159 _marker: std::marker::PhantomData,
160 }
161 }
162}