tf_idf_vectorizer/vectorizer/mod.rs

pub mod corpus;
pub mod tfidf;
pub mod token;
pub mod serde;
pub mod evaluate;

use std::sync::Arc;

use indexmap::IndexSet;
use num_traits::Num;
use ::serde::{Deserialize, Serialize};

use crate::{utils::{math::vector::{ZeroSpVec, ZeroSpVecTrait}, normalizer::DeNormalizer}, vectorizer::{corpus::Corpus, tfidf::{DefaultTFIDFEngine, TFIDFEngine}, token::TokenFrequency}};
use ahash::RandomState;

#[derive(Debug, Clone)]
pub struct TFIDFVectorizer<N = f32, K = String, E = DefaultTFIDFEngine>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N> + Send + Sync,
    K: Clone + Send + Sync,
{
    /// TF vectors of the indexed documents
    pub documents: Vec<TFVector<N, K>>,
    /// Token dimension sample: maps each token to a stable dimension index shared by all TF vectors
    pub token_dim_sample: IndexSet<Box<str>, RandomState>,
    /// Reference to the shared corpus
    pub corpus_ref: Arc<Corpus>,
    /// IDF vector
    pub idf: IDFVector<N>,
    _marker: std::marker::PhantomData<E>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TFVector<N, K>
where
    N: Num + Copy,
{
    /// TF vector, stored sparsely
    pub tf_vec: ZeroSpVec<N>,
    /// total number of tokens in this document
    pub token_sum: u64,
    /// denormalization factor for this document,
    /// used to recover raw token counts from TF values
    pub denormalize_num: f64,
    /// Document ID
    pub key: K,
}

impl<N, K> TFVector<N, K>
where
    N: Num + Copy,
{
    /// Release excess capacity held by the sparse TF vector
    pub fn shrink_to_fit(&mut self) {
        self.tf_vec.shrink_to_fit();
    }
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct IDFVector<N>
where
    N: Num,
{
    /// IDF vector; stored densely because it is mostly non-zero
    pub idf_vec: Vec<N>,
    /// denormalization factor for the IDF values
    pub denormalize_num: f64,
    /// corpus generation number observed at the last IDF recalculation
    pub latest_entropy: u64,
    /// document count
    pub doc_num: u64,
}

impl<N> IDFVector<N>
where
    N: Num,
{
    pub fn new() -> Self {
        Self {
            idf_vec: Vec::new(),
            denormalize_num: 1.0,
            latest_entropy: 0,
            doc_num: 0,
        }
    }
}

impl<N, K, E> TFIDFVectorizer<N, K, E>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N> + Send + Sync,
    K: Clone + Send + Sync,
{
    /// Create a new TFIDFVectorizer instance
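    ///
    /// # Example
    ///
    /// A minimal construction sketch. `Corpus::new()` is assumed here for
    /// illustration; substitute the actual way to build a `Corpus`.
    ///
    /// ```ignore
    /// use std::sync::Arc;
    ///
    /// let corpus = Arc::new(Corpus::new()); // assumed constructor
    /// let vectorizer: TFIDFVectorizer = TFIDFVectorizer::new(corpus);
    /// assert_eq!(vectorizer.doc_num(), 0);
    /// ```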
    pub fn new(corpus_ref: Arc<Corpus>) -> Self {
        let mut instance = Self {
            documents: Vec::new(),
            token_dim_sample: IndexSet::with_hasher(RandomState::new()),
            corpus_ref,
            idf: IDFVector::new(),
            _marker: std::marker::PhantomData,
        };
        instance.re_calc_idf();
        instance
    }

    /// Set the corpus reference and recalculate the IDF
    pub fn set_corpus_ref(&mut self, corpus_ref: Arc<Corpus>) {
        self.corpus_ref = corpus_ref;
        self.re_calc_idf();
    }

    /// Recalculate the IDF if the corpus has changed
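    ///
    /// # Example
    ///
    /// A sketch, assuming the shared `Corpus` has been modified elsewhere:
    ///
    /// ```ignore
    /// vectorizer.update_idf(); // cheap no-op when the corpus generation is unchanged
    /// ```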
    pub fn update_idf(&mut self) {
        if self.corpus_ref.get_gen_num() != self.idf.latest_entropy {
            self.re_calc_idf();
        }
        // no update needed if the generation number matches
    }

    /// Recalculate the IDF from the corpus
    fn re_calc_idf(&mut self) {
        self.idf.latest_entropy = self.corpus_ref.get_gen_num();
        self.idf.doc_num = self.corpus_ref.get_doc_num();
        (self.idf.idf_vec, self.idf.denormalize_num) = E::idf_vec(&self.corpus_ref, &self.token_dim_sample);
    }
}

impl<N, K, E> TFIDFVectorizer<N, K, E>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N> + Send + Sync,
    K: PartialEq + Clone + Send + Sync,
{
    /// Add a document.
    /// The referenced corpus is updated immediately as well.
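    ///
    /// # Example
    ///
    /// A sketch of indexing one document; `set_token_count` (used elsewhere
    /// in this module) stands in for however a `TokenFrequency` is normally
    /// populated.
    ///
    /// ```ignore
    /// let mut freq = TokenFrequency::new();
    /// freq.set_token_count("rust", 3);
    /// freq.set_token_count("tfidf", 1);
    /// vectorizer.add_doc("doc-1".to_string(), &freq);
    /// assert!(vectorizer.contains_doc(&"doc-1".to_string()));
    /// ```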
    pub fn add_doc(&mut self, doc_id: K, doc: &TokenFrequency) {
        let token_sum = doc.token_sum();
        // add the document's tokens to the corpus
        self.add_corpus(doc);
        // incrementally add any new vocabulary (O(|doc_vocab|))
        for tok in doc.token_set_ref_str() {
            if !self.token_dim_sample.contains(tok) {
                self.token_dim_sample.insert(tok.into());
            }
        }

        let (tf_vec, denormalize_num) = E::tf_vec(doc, &self.token_dim_sample);
        let mut doc = TFVector {
            tf_vec,
            token_sum,
            denormalize_num,
            key: doc_id,
        };
        doc.shrink_to_fit();
        self.documents.push(doc);
    }

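    /// Delete a document by its ID.
    /// The document's tokens are also subtracted from the referenced corpus.
    ///
    /// # Example
    ///
    /// A sketch, continuing the `add_doc` example above:
    ///
    /// ```ignore
    /// vectorizer.del_doc(&"doc-1".to_string());
    /// assert!(!vectorizer.contains_doc(&"doc-1".to_string()));
    /// ```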
    pub fn del_doc(&mut self, doc_id: &K) {
        if let Some(pos) = self.documents.iter().position(|doc| &doc.key == doc_id) {
            let doc = &self.documents[pos];
            let token_set = doc.tf_vec.raw_iter()
                .filter_map(|(idx, _)| self.token_dim_sample.get_index(idx).map(|s| s.as_ref()))
                .collect::<Vec<&str>>();
            // subtract the document's tokens from the corpus
            self.corpus_ref.sub_set(&token_set);
            // remove the document itself
            self.documents.remove(pos);
        }
    }

    /// Get a TFVector by document ID
    pub fn get_tf(&self, key: &K) -> Option<&TFVector<N, K>> {
        self.documents.iter().find(|doc| &doc.key == key)
    }

    /// Reconstruct a TokenFrequency for a document ID from its stored TF vector.
    /// If the TF values were quantized, the recovered counts may contain some error.
    /// Tokens not present in the token dimension sample are ignored.
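    ///
    /// # Example
    ///
    /// A round-trip sketch, continuing the `add_doc` example above. Counts
    /// are recovered from the TF values via `token_sum` and the
    /// denormalization factor, so quantized `N` types may round slightly.
    ///
    /// ```ignore
    /// let recovered = vectorizer.get_tf_into_token_freq(&"doc-1".to_string());
    /// assert!(recovered.is_some());
    /// ```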
    pub fn get_tf_into_token_freq(&self, key: &K) -> Option<TokenFrequency> {
        if let Some(tf_vec) = self.get_tf(key) {
            let mut token_freq = TokenFrequency::new();
            tf_vec.tf_vec.raw_iter().for_each(|(idx, val)| {
                if let Some(token) = self.token_dim_sample.get_index(idx) {
                    let val_f64: f64 = (*val).into();
                    // recover the raw count: denormalized token_sum scaled by the TF value
                    let token_num: f64 = tf_vec.token_sum.denormalize(tf_vec.denormalize_num) * val_f64;
                    token_freq.set_token_count(token, token_num as u64);
                } // indices outside the token dimension sample are ignored
            });
            Some(token_freq)
        } else {
            None
        }
    }

    /// Check if a document with the given ID exists
    pub fn contains_doc(&self, key: &K) -> bool {
        self.documents.iter().any(|doc| &doc.key == key)
    }

    /// Check if the token exists in the token dimension sample
    pub fn contains_token(&self, token: &str) -> bool {
        self.token_dim_sample.contains(token)
    }

    /// Check if all tokens in the given TokenFrequency exist in the token dimension sample
    pub fn contains_tokens_from_freq(&self, freq: &TokenFrequency) -> bool {
        freq.token_set_ref_str().iter().all(|tok| self.token_dim_sample.contains(*tok))
    }

    /// Number of indexed documents
    pub fn doc_num(&self) -> usize {
        self.documents.len()
    }

    /// Add the document's tokens to the referenced corpus
    fn add_corpus(&mut self, doc: &TokenFrequency) {
        self.corpus_ref.add_set(&doc.token_set_ref_str());
    }
}