tf_idf_vectorizer/vectorizer/mod.rs

pub mod corpus;
pub mod tfidf;
pub mod token;
pub mod serde;
pub mod evaluate;

use std::{rc::Rc, sync::Arc};
use std::hash::Hash;

use num_traits::Num;
use ::serde::{Deserialize, Serialize};

use crate::utils::datastruct::map::IndexMap;
use crate::utils::datastruct::vector::{ZeroSpVec, ZeroSpVecTrait};
use crate::utils::normalizer::DeNormalizer;
use crate::vectorizer::{corpus::Corpus, tfidf::{DefaultTFIDFEngine, TFIDFEngine}, token::TokenFrequency};

pub type KeyRc<K> = Rc<K>;
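
/// Maps document keys to sparse TF vectors and caches the IDF values derived
/// from the shared [`Corpus`].
///
/// A minimal usage sketch. This is an illustration only: `Corpus::new()` and
/// the exact `TokenFrequency` setters are assumptions inferred from how they
/// are used elsewhere in this module, so the example is not compiled.
///
/// ```ignore
/// use std::sync::Arc;
///
/// let corpus = Arc::new(Corpus::new()); // assumed constructor
/// let mut vectorizer: TFIDFVectorizer = TFIDFVectorizer::new(corpus);
///
/// let mut freq = TokenFrequency::new();
/// freq.set_token_count("hello", 2); // setter used by this module below
/// freq.set_token_count("world", 1);
///
/// vectorizer.add_doc("doc-1".to_string(), &freq);
/// assert!(vectorizer.contains_doc(&"doc-1".to_string()));
/// assert_eq!(vectorizer.doc_num(), 1);
/// ```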
#[derive(Debug, Clone)]
pub struct TFIDFVectorizer<N = f32, K = String, E = DefaultTFIDFEngine>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N, K> + Send + Sync,
    K: Clone + Send + Sync + Eq + Hash,
{
    /// Document's TF vector
    pub documents: IndexMap<KeyRc<K>, TFVector<N>>,
    /// TF vector's token dimension sample and reverse index
    pub token_dim_rev_index: IndexMap<Box<str>, Vec<KeyRc<K>>>,
    /// Corpus reference
    pub corpus_ref: Arc<Corpus>,
    /// IDF vector
    pub idf_cache: IDFVector,
    _marker: std::marker::PhantomData<E>,
}

/// Totals 64 bytes including all child fields
#[derive(Debug, Serialize, Deserialize, Clone)]
#[repr(align(64))]
pub struct TFVector<N>
where
    N: Num + Copy,
{
    /// TF vector,
    /// stored as a sparse vector
    pub tf_vec: ZeroSpVec<N>,
    /// sum of tokens of this document
    pub token_sum: u64,
    /// denormalize number for this document,
    /// used to reverse-calculate token counts from tf values
    pub denormalize_num: f64,
}
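
// Round-trip note (inferred from `get_tf_into_token_freq` below, not an
// additional API): a stored tf value `v` maps back to an approximate count via
//
//     count ≈ token_sum.denormalize(denormalize_num) * v
//
// e.g. with token_sum = 100 and v = 0.03, and assuming denormalize(1.0) is the
// identity, the reconstructed count is 3 (up to quantization error in `N`).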

// Size assertions,
// checked via const evaluation:
// a compile error here means the struct is no longer 64 bytes

#[allow(dead_code)]
const TF_VECTOR_SIZE: usize = core::mem::size_of::<TFVector<f32>>();
static_assertions::const_assert!(TF_VECTOR_SIZE == 64);
#[allow(dead_code)]
const ZSV_SIZE: usize = core::mem::size_of::<ZeroSpVec<f32>>();
static_assertions::const_assert!(ZSV_SIZE == 48);

impl<N> TFVector<N>
where
    N: Num + Copy,
{
    pub fn shrink_to_fit(&mut self) {
        self.tf_vec.shrink_to_fit();
    }
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct IDFVector
{
    /// IDF vector; not sparse because it is mostly filled
    pub idf_vec: Vec<f32>,
    /// denormalize number for idf
    pub denormalize_num: f64,
    /// corpus generation number at the last IDF calculation,
    /// used as a cache-invalidation marker
    pub latest_entropy: u64,
    /// document count
    pub doc_num: u64,
}

impl IDFVector
{
    pub fn new() -> Self {
        Self {
            idf_vec: Vec::new(),
            denormalize_num: 1.0,
            latest_entropy: 0,
            doc_num: 0,
        }
    }
}

impl<N, K, E> TFIDFVectorizer<N, K, E>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N, K> + Send + Sync,
    K: Clone + Send + Sync + Eq + Hash,
{
    /// Create a new TFIDFVectorizer instance
    pub fn new(corpus_ref: Arc<Corpus>) -> Self {
        let mut instance = Self {
            documents: IndexMap::new(),
            token_dim_rev_index: IndexMap::new(),
            corpus_ref,
            idf_cache: IDFVector::new(),
            _marker: std::marker::PhantomData,
        };
        instance.re_calc_idf();
        instance
    }

    /// Set the corpus reference
    /// and recalculate the IDF
    pub fn set_corpus_ref(&mut self, corpus_ref: Arc<Corpus>) {
        self.corpus_ref = corpus_ref;
        self.re_calc_idf();
    }

    /// Recalculate the IDF if the Corpus has changed
    pub fn update_idf(&mut self) {
        if self.corpus_ref.get_gen_num() != self.idf_cache.latest_entropy {
            self.re_calc_idf();
        }
        // do nothing if there was no update
    }

    /// Recalculate the IDF from the Corpus
    fn re_calc_idf(&mut self) {
        self.idf_cache.latest_entropy = self.corpus_ref.get_gen_num();
        self.idf_cache.doc_num = self.corpus_ref.get_doc_num();
        (self.idf_cache.idf_vec, self.idf_cache.denormalize_num) = E::idf_vec(&self.corpus_ref, self.token_dim_rev_index.keys());
    }
}
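
// Cache-invalidation sketch. Assumption: `Corpus::get_gen_num()` returns a
// generation counter that changes whenever the corpus is mutated, so:
//
//     vectorizer.update_idf(); // generation changed   -> IDF recomputed
//     vectorizer.update_idf(); // generation unchanged -> no-op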

impl<N, K, E> TFIDFVectorizer<N, K, E>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N, K> + Send + Sync,
    K: Clone + Send + Sync + Eq + Hash,
{
    /// Add a document.
    /// The immediately referenced Corpus is also updated
    pub fn add_doc(&mut self, key: K, doc: &TokenFrequency) {
        // create the key Rc
        let key_rc = KeyRc::new(key);
        if self.documents.contains_key(&key_rc) {
            self.del_doc(&key_rc);
        }
        let token_sum = doc.token_sum();
        // add the document's tokens to the corpus
        self.add_corpus(doc);
        // incrementally add new vocabulary (O(|doc_vocab|))
        for tok in doc.token_set() {
            self.token_dim_rev_index
                .entry_mut(tok.into_boxed_str())
                .or_insert_with(Vec::new)
                .push(Rc::clone(&key_rc)); // add to the reverse index
        }

        let (mut tf_vec, denormalize_num) = E::tf_vec(doc, self.token_dim_rev_index.as_index_set());
        tf_vec.shrink_to_fit();
        let mut tf_entry = TFVector {
            tf_vec,
            token_sum,
            denormalize_num,
        };
        tf_entry.shrink_to_fit();
        self.documents.insert(key_rc, tf_entry);
    }
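
    // Upsert sketch (follows from the `contains_key` check above, not a
    // separate API): adding under an existing key replaces the old document,
    // because `del_doc` first cleans up the corpus and the reverse index:
    //
    //     vectorizer.add_doc("k".to_string(), &freq_v1); // freq_v1: hypothetical
    //     vectorizer.add_doc("k".to_string(), &freq_v2); // replaces freq_v1
    //     assert_eq!(vectorizer.doc_num(), 1);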

    pub fn del_doc(&mut self, key: &K) {
        let rc_key = KeyRc::new(key.clone());
        if let Some(doc) = self.documents.get(&rc_key) {
            let tokens = doc.tf_vec.raw_iter()
                .filter_map(|(idx, _)| {
                    if let Some(doc_keys) = self.token_dim_rev_index.get_with_index_mut(idx) {
                        // remove from the reverse index
                        doc_keys.retain(|k| k.as_ref() != key);
                    }
                    self.token_dim_rev_index.get_key_with_index(idx).cloned()
                }).collect::<Vec<Box<str>>>();
            // remove the document
            self.documents.swap_remove(&rc_key);
            // also remove from the corpus
            self.corpus_ref.sub_set(&tokens);
        }
    }

    /// Get the TFVector by document ID
    pub fn get_tf(&self, key: &K) -> Option<&TFVector<N>> {
        let rc_key = KeyRc::new(key.clone());
        self.documents.get(&rc_key)
    }

    /// Get the TokenFrequency by document ID.
    /// If quantized, there may be some error.
    /// Words not included in the corpus are ignored
    pub fn get_tf_into_token_freq(&self, key: &K) -> Option<TokenFrequency> {
        if let Some(tf_vec) = self.get_tf(key) {
            let mut token_freq = TokenFrequency::new();
            tf_vec.tf_vec.raw_iter().for_each(|(idx, val)| {
                if let Some(token) = self.token_dim_rev_index.get_key_with_index(idx) {
                    let val_f64 = (*val).into();
                    let token_num = tf_vec.token_sum.denormalize(tf_vec.denormalize_num) as f64 * val_f64;
                    token_freq.set_token_count(token, token_num as u64);
                } // out-of-range indices are ignored
            });
            Some(token_freq)
        } else {
            None
        }
    }
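
    // Reconstruction sketch for the method above (`freq` is a hypothetical
    // previously added TokenFrequency):
    //
    //     vectorizer.add_doc("k".to_string(), &freq);
    //     let recovered = vectorizer.get_tf_into_token_freq(&"k".to_string());
    //     // `recovered` approximates `freq` for tokens known to the corpus;
    //     // counts may drift slightly when N is a quantized type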

    /// Check if a document with the given ID exists
    pub fn contains_doc(&self, key: &K) -> bool {
        let rc_key = KeyRc::new(key.clone());
        self.documents.contains_key(&rc_key)
    }

    /// Check if the token exists in the token dimension sample
    pub fn contains_token(&self, token: &str) -> bool {
        self.token_dim_rev_index.contains_key(&Box::<str>::from(token))
    }

    /// Check if all tokens in the given TokenFrequency exist in the token dimension sample
    pub fn contains_tokens_from_freq(&self, freq: &TokenFrequency) -> bool {
        freq.token_set_ref_str().iter().all(|tok| self.contains_token(tok))
    }

    pub fn doc_num(&self) -> usize {
        self.documents.len()
    }

    /// Add the document's tokens to the corpus,
    /// updating the referenced corpus
    fn add_corpus(&self, doc: &TokenFrequency) {
        self.corpus_ref.add_set(&doc.token_set_ref_str());
    }
}