tf_idf_vectorizer/vectorizer/mod.rs

pub mod corpus;
pub mod tfidf;
pub mod token;
pub mod serde;
pub mod evaluate;

use std::{rc::Rc, sync::Arc};
use std::hash::Hash;

use num_traits::Num;
// `::serde` refers to the external crate, not the local `serde` module above
use ::serde::{Deserialize, Serialize};

use crate::utils::datastruct::map::IndexMap;
use crate::utils::datastruct::vector::{ZeroSpVec, ZeroSpVecTrait};
use crate::utils::normalizer::DeNormalizer;
use crate::vectorizer::{
    corpus::Corpus,
    tfidf::{DefaultTFIDFEngine, TFIDFEngine},
    token::TokenFrequency,
};

/// Shared handle for document keys: the same `Rc` is stored in both the
/// document map and the reverse index, avoiding clones of the key itself.
pub type KeyRc<K> = Rc<K>;

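/// A minimal end-to-end sketch (hypothetical, so the block is marked
/// `ignore`: it assumes a `Corpus::new()` constructor in the sibling
/// `corpus` module, which this file does not show):
///
/// ```ignore
/// use std::sync::Arc;
///
/// let corpus = Arc::new(Corpus::new());
/// let mut vectorizer: TFIDFVectorizer = TFIDFVectorizer::new(Arc::clone(&corpus));
///
/// let mut freq = TokenFrequency::new();
/// freq.set_token_count("rust", 3);
/// vectorizer.add_doc("doc-1".to_string(), &freq);
/// assert!(vectorizer.contains_doc(&"doc-1".to_string()));
/// ```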
#[derive(Debug, Clone)]
pub struct TFIDFVectorizer<N = f32, K = String, E = DefaultTFIDFEngine>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N, K> + Send + Sync,
    K: Clone + Send + Sync + Eq + Hash,
{
    /// TF vector of each document, keyed by document key
    pub documents: IndexMap<KeyRc<K>, TFVector<N>>,
    /// Token dimension ordering (vocabulary) plus a reverse index from
    /// each token to the keys of the documents that contain it
    pub token_dim_rev_index: IndexMap<Box<str>, Vec<KeyRc<K>>>,
    /// Shared reference to the corpus
    pub corpus_ref: Arc<Corpus>,
    /// Cached IDF vector
    pub idf_cache: IDFVector<N>,
    _marker: std::marker::PhantomData<E>,
}

/// Totals exactly 64 bytes including its fields (see the assertions below)
#[derive(Debug, Serialize, Deserialize, Clone)]
#[repr(align(64))]
pub struct TFVector<N>
where
    N: Num + Copy,
{
    /// TF vector, stored as a sparse vector
    pub tf_vec: ZeroSpVec<N>,
    /// Total number of tokens in this document
    pub token_sum: u64,
    /// Denormalization factor for this document, used to recover
    /// token counts from TF values
    pub denormalize_num: f64,
}

// Size assertions, checked via const evaluation:
// if either of these fails to compile, the corresponding type is no
// longer the expected size.

#[allow(dead_code)]
const TF_VECTOR_SIZE: usize = core::mem::size_of::<TFVector<f32>>();
static_assertions::const_assert!(TF_VECTOR_SIZE == 64);
#[allow(dead_code)]
const ZSV_SIZE: usize = core::mem::size_of::<ZeroSpVec<f32>>();
static_assertions::const_assert!(ZSV_SIZE == 48);
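
// Layout arithmetic behind the 64-byte assertion, assuming no padding is
// inserted between the fields:
//   tf_vec: ZeroSpVec<f32>    48 bytes (asserted above)
//   token_sum: u64             8 bytes
//   denormalize_num: f64       8 bytes
//   total                     64 bytes, matching #[repr(align(64))]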

impl<N> TFVector<N>
where
    N: Num + Copy,
{
    /// Shrink the backing sparse vector to its minimal capacity
    pub fn shrink_to_fit(&mut self) {
        self.tf_vec.shrink_to_fit();
    }
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct IDFVector<N>
where
    N: Num,
{
    /// IDF vector; dense rather than sparse because it is mostly filled
    pub idf_vec: Vec<N>,
    /// Denormalization factor for the IDF values
    pub denormalize_num: f64,
    /// Corpus generation number at the time of the last recalculation
    pub latest_entropy: u64,
    /// Number of documents
    pub doc_num: u64,
}

impl<N> IDFVector<N>
where
    N: Num,
{
    pub fn new() -> Self {
        Self {
            idf_vec: Vec::new(),
            denormalize_num: 1.0,
            latest_entropy: 0,
            doc_num: 0,
        }
    }
}

impl<N, K, E> TFIDFVectorizer<N, K, E>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N, K> + Send + Sync,
    K: Clone + Send + Sync + Eq + Hash,
{
    /// Create a new TFIDFVectorizer instance
    pub fn new(corpus_ref: Arc<Corpus>) -> Self {
        let mut instance = Self {
            documents: IndexMap::new(),
            token_dim_rev_index: IndexMap::new(),
            corpus_ref,
            idf_cache: IDFVector::new(),
            _marker: std::marker::PhantomData,
        };
        instance.re_calc_idf();
        instance
    }

    /// Set the corpus reference and recalculate the IDF
    pub fn set_corpus_ref(&mut self, corpus_ref: Arc<Corpus>) {
        self.corpus_ref = corpus_ref;
        self.re_calc_idf();
    }

    /// Recalculate the IDF if the corpus has changed
    pub fn update_idf(&mut self) {
        if self.corpus_ref.get_gen_num() != self.idf_cache.latest_entropy {
            self.re_calc_idf();
        }
        // do nothing if the corpus has not been updated
    }
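
    // A hypothetical polling pattern built on `update_idf` (sketch, not taken
    // from this file): the corpus is mutated elsewhere through the shared Arc,
    // which presumably advances its generation number, and the vectorizer then
    // re-derives the IDF exactly once:
    //
    //     corpus.add_set(&new_doc.token_set_ref_str()); // corpus changes
    //     vectorizer.update_idf();                      // recomputes IDF once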

    /// Recalculate the IDF from the corpus
    fn re_calc_idf(&mut self) {
        self.idf_cache.latest_entropy = self.corpus_ref.get_gen_num();
        self.idf_cache.doc_num = self.corpus_ref.get_doc_num();
        (self.idf_cache.idf_vec, self.idf_cache.denormalize_num) =
            E::idf_vec(&self.corpus_ref, self.token_dim_rev_index.keys());
    }
}

impl<N, K, E> TFIDFVectorizer<N, K, E>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N, K> + Send + Sync,
    K: Clone + Send + Sync + Eq + Hash,
{
    /// Add a document under the given key.
    /// The referenced corpus is updated immediately as well.
    /// If a document with the same key already exists, it is replaced.
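    ///
    /// A hypothetical call, assuming a vectorizer with the default `String`
    /// key type (marked `ignore`):
    ///
    /// ```ignore
    /// let mut freq = TokenFrequency::new();
    /// freq.set_token_count("hello", 2);
    /// vectorizer.add_doc("doc-42".to_string(), &freq);
    /// ```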
    pub fn add_doc(&mut self, key: K, doc: &TokenFrequency) {
        // wrap the key in an Rc so it can be shared with the reverse index
        let key_rc = KeyRc::new(key);
        if self.documents.contains_key(&key_rc) {
            self.del_doc(&key_rc);
        }
        let token_sum = doc.token_sum();
        // add the document's tokens to the corpus
        self.add_corpus(doc);
        // incrementally add new vocabulary (O(|doc_vocab|))
        for tok in doc.token_set_ref_str() {
            self.token_dim_rev_index
                .entry_mut(&Box::from(tok))
                .or_insert_with(Vec::new)
                .push(Rc::clone(&key_rc)); // add to the reverse index
        }

        let (tf_vec, denormalize_num) = E::tf_vec(doc, self.token_dim_rev_index.keys());
        let mut doc = TFVector {
            tf_vec,
            token_sum,
            denormalize_num,
        };
        doc.shrink_to_fit();
        self.documents.insert(&key_rc, doc);
    }

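    /// Delete a document by key: its entries are removed from the reverse
    /// index and its tokens are subtracted from the referenced corpus.
    /// Does nothing if the key is not present.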
    pub fn del_doc(&mut self, key: &K) {
        let rc_key = KeyRc::new(key.clone());
        if let Some(doc) = self.documents.get(&rc_key) {
            let tokens = doc.tf_vec.raw_iter()
                .filter_map(|(idx, _)| {
                    if let Some(doc_keys) = self.token_dim_rev_index.get_with_index_mut(idx) {
                        // remove this document's key from the reverse index,
                        // keeping every other document's key
                        doc_keys.retain(|k| **k != *key);
                    }
                    self.token_dim_rev_index.get_key_with_index(idx).cloned()
                }).collect::<Vec<Box<str>>>();
            // remove the document itself
            self.documents.swap_remove(&rc_key);
            // and subtract its tokens from the corpus
            self.corpus_ref.sub_set(&tokens);
        }
    }

    /// Get the TFVector for a document key
    pub fn get_tf(&self, key: &K) -> Option<&TFVector<N>> {
        let rc_key = KeyRc::new(key.clone());
        self.documents.get(&rc_key)
    }

    /// Get a TokenFrequency for a document key.
    /// If the TF values were quantized, the recovered counts may carry
    /// some rounding error. Tokens not present in the corpus are ignored.
    pub fn get_tf_into_token_freq(&self, key: &K) -> Option<TokenFrequency> {
        if let Some(tf_vec) = self.get_tf(key) {
            let mut token_freq = TokenFrequency::new();
            tf_vec.tf_vec.raw_iter().for_each(|(idx, val)| {
                if let Some(token) = self.token_dim_rev_index.get_key_with_index(idx) {
                    let val_f64: f64 = (*val).into();
                    // reverse the normalization: tf value * denormalized token total
                    let token_num: f64 = tf_vec.token_sum.denormalize(tf_vec.denormalize_num) * val_f64;
                    token_freq.set_token_count(token, token_num as u64);
                } // indices out of range are ignored
            });
            Some(token_freq)
        } else {
            None
        }
    }
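
    // Worked example of the reverse calculation above (illustrative numbers,
    // assuming `denormalize` maps the stored `token_sum` back to the original
    // token total): with token_sum = 100, denormalize(...) = 100.0, and a
    // stored tf value of 0.03, token_count = (100.0 * 0.03) as u64 = 3.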

    /// Check if a document with the given key exists
    pub fn contains_doc(&self, key: &K) -> bool {
        let rc_key = KeyRc::new(key.clone());
        self.documents.contains_key(&rc_key)
    }

    /// Check if the token exists in the token dimension sample
    pub fn contains_token(&self, token: &str) -> bool {
        self.token_dim_rev_index.contains_key(&Box::<str>::from(token))
    }

    /// Check if every token in the given TokenFrequency exists in the
    /// token dimension sample
    pub fn contains_tokens_from_freq(&self, freq: &TokenFrequency) -> bool {
        freq.token_set_ref_str().iter().all(|tok| self.contains_token(tok))
    }

    /// Number of stored documents
    pub fn doc_num(&self) -> usize {
        self.documents.len()
    }

    /// Add a document's tokens to the referenced corpus
    fn add_corpus(&mut self, doc: &TokenFrequency) {
        self.corpus_ref.add_set(&doc.token_set_ref_str());
    }
}