tf_idf_vectorizer/vectorizer/
serde.rs

1use std::sync::Arc;
2use std::hash::Hash;
3
4use ahash::RandomState;
5use num_traits::Num;
6use serde::{ser::SerializeStruct, Deserialize, Serialize};
7
8use crate::{Corpus, TFIDFVectorizer, utils::datastruct::{map::IndexMap, vector::ZeroSpVecTrait}, vectorizer::{IDFVector, KeyRc, TFVector, tfidf::{DefaultTFIDFEngine, TFIDFEngine}}};
9
10/// Data structure for deserializing TFIDFVectorizer.
11/// This struct does not contain references, so it can be serialized.
12/// Use the `into_tf_idf_vectorizer` method to convert to `TFIDFVectorizer`.
13#[derive(Debug, Deserialize, Serialize)]
14pub struct TFIDFData<N = f32, K = String, E = DefaultTFIDFEngine>
15where
16    N: Num + Copy,
17    E: TFIDFEngine<N, K>,
18    K: Clone + Eq + Hash,
19{
20    /// TF vectors for documents
21    pub documents: IndexMap<KeyRc<K>, TFVector<N>>,
22    /// Token dimension sample for TF vectors
23    pub token_dim_sample: Vec<Box<str>>,
24    /// IDF vector
25    #[serde(default, skip_serializing, skip_deserializing)]
26    pub idf: Option<IDFVector>,
27    #[serde(default, skip_serializing, skip_deserializing)]
28    pub(crate) _marker: std::marker::PhantomData<E>,
29}
30
31impl<N, K, E> TFIDFData<N, K, E>
32where
33    N: Num + Copy + Into<f64> + Send + Sync,
34    E: TFIDFEngine<N, K>,
35    K: Clone + Send + Sync + Eq + Hash,
36{
37    /// Convert `TFIDFData` into `TFIDFVectorizer`.
38    /// `corpus_ref` is a reference to the corpus.
39    pub fn into_tf_idf_vectorizer(self, corpus_ref: Arc<Corpus>) -> TFIDFVectorizer<N, K, E>
40    {
41        let mut token_dim_rev_index: IndexMap<Box<str>, Vec<KeyRc<K>>, RandomState> =
42            IndexMap::with_capacity(self.token_dim_sample.len());
43        // 順序通りに初めに登録しておく
44        self.token_dim_sample.iter().for_each(|token| {
45            token_dim_rev_index.insert(token.clone(), Vec::new());
46        });
47        self.documents.iter().for_each(|(key, doc)| {
48            doc.tf_vec.raw_iter().for_each(|(idx, _)| {
49                token_dim_rev_index.get_with_index_mut(idx).unwrap().push(key.clone());
50            });
51        });
52
53        let mut instance = TFIDFVectorizer {
54            documents: self.documents,
55            token_dim_rev_index: token_dim_rev_index,
56            corpus_ref,
57            idf_cache: IDFVector::new(),
58            _marker: std::marker::PhantomData,
59        };
60        instance.update_idf();
61        instance
62    }
63}
64
65impl<N, K, E> Serialize for TFIDFVectorizer<N, K, E>
66where
67    N: Num + Copy + Serialize + Into<f64> + Send + Sync,
68    K: Serialize + Clone + Send + Sync + Eq + Hash,
69    E: TFIDFEngine<N, K>,
70{
71    /// Serialize TFIDFVectorizer.
72    /// This struct contains references, so they are excluded from serialization.
73    /// Use `TFIDFData` for deserialization.
74    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
75    where
76        S: serde::Serializer,
77    {
78        let mut state = serializer.serialize_struct("TFIDFVectorizer", 3)?;
79        state.serialize_field("documents", &self.documents)?;
80        state.serialize_field("token_dim_sample", &self.token_dim_rev_index.keys())?;
81        state.serialize_field("corpus_ref", &self.corpus_ref)?;
82        state.end()
83    }
84}
85
86impl<'de, N, K, E> Deserialize<'de> for TFIDFVectorizer<N, K, E>
87where
88    N: Num + Copy + Deserialize<'de> + Into<f64> + Send + Sync,
89    K: Deserialize<'de> + Clone + Send + Sync + Eq + Hash,
90    E: TFIDFEngine<N, K> + Send + Sync,
91{
92    /// Deserialize TFIDFVectorizer.
93    /// This struct contains references, so they are excluded from deserialization.
94    /// Use `TFIDFData` for deserialization.
95    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
96    where
97        D: serde::Deserializer<'de>,
98    {
99        #[derive(Deserialize)]
100        struct TFIDFVectorizerHelper<N, K>
101        where
102            N: Num + Copy,
103            K: Clone + Eq + Hash,
104        {
105            documents: IndexMap<KeyRc<K>, TFVector<N>>,
106            token_dim_sample: Vec<Box<str>>,
107            #[serde(default, skip_deserializing)]
108            corpus_ref: Arc<Corpus>,
109        }
110
111        let helper = TFIDFVectorizerHelper::<N, K>::deserialize(deserializer)?;
112        let mut token_dim_rev_index: IndexMap<Box<str>, Vec<KeyRc<K>>, RandomState> =
113            IndexMap::with_capacity(helper.token_dim_sample.len());
114        // 順序通りに初めに登録しておく
115        helper.token_dim_sample.iter().for_each(|token| {
116            token_dim_rev_index.insert(token.clone(), Vec::new());
117        });
118        helper.documents.iter().for_each(|(key, doc)| {
119            doc.tf_vec.raw_iter().for_each(|(idx, _)| {
120                token_dim_rev_index.get_with_index_mut(idx).unwrap().push(key.clone());
121            });
122        });
123
124        Ok(TFIDFVectorizer {
125            documents: helper.documents,
126            token_dim_rev_index,
127            corpus_ref: helper.corpus_ref,
128            idf_cache: IDFVector::new(),
129            _marker: std::marker::PhantomData,
130        })
131    }
132}
133
134impl<N, K, E> TFIDFVectorizer<N, K, E>
135where
136    N: Num + Copy + Serialize + Into<f64> + Send + Sync,
137    K: Serialize + Clone + Send + Sync + Eq + Hash,
138    E: TFIDFEngine<N, K> + Send + Sync,
139{
140    pub fn into_tfidf_data(self) -> TFIDFData<N, K, E> {
141        let token_dim_sample = self.token_dim_rev_index.keys().clone();
142        TFIDFData {
143            documents: self.documents,
144            token_dim_sample,
145            idf: None,
146            _marker: std::marker::PhantomData,
147        }
148    }
149}