tf_idf_vectorizer/vectorizer/mod.rs

pub mod corpus;
pub mod tfidf;
pub mod token;
pub mod serde;
pub mod evaluate;

use std::{rc::Rc, sync::Arc};
use std::hash::Hash;

use half::f16;
use num_traits::Num;
use ::serde::{Deserialize, Serialize};

use crate::utils::datastruct::map::IndexMap;
use crate::utils::datastruct::vector::{ZeroSpVec, ZeroSpVecTrait};
use crate::utils::normalizer::DeNormalizer;
use crate::vectorizer::{corpus::Corpus, tfidf::{DefaultTFIDFEngine, TFIDFEngine}, token::TokenFrequency};

/// Reference-counted document key, shared between `documents` and the
/// token reverse index so each key is allocated only once.
pub type KeyRc<K> = Rc<K>;

#[derive(Debug, Clone)]
pub struct TFIDFVectorizer<N = f16, K = String, E = DefaultTFIDFEngine>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N, K> + Send + Sync,
    K: Clone + Send + Sync + Eq + Hash,
{
    /// Each document's TF vector, keyed by document key
    pub documents: IndexMap<KeyRc<K>, TFVector<N>>,
    /// Token dimension table: maps each token (one TF-vector dimension)
    /// to the keys of the documents that contain it (reverse index)
    pub token_dim_rev_index: IndexMap<Box<str>, Vec<KeyRc<K>>>,
    /// Corpus reference
    pub corpus_ref: Arc<Corpus>,
    /// Cached IDF vector
    pub idf_cache: IDFVector,
    _marker: std::marker::PhantomData<E>,
}
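
// How the pieces fit together: each document key is allocated once behind a
// `KeyRc` and shared between `documents` (key -> TF vector) and
// `token_dim_rev_index` (token -> keys of documents containing that token).
// The position of a token in `token_dim_rev_index` doubles as its dimension
// index in every `TFVector` and in `idf_cache.idf_vec`.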

/// Totals 64 bytes including all of its fields (see the size assertions below)
#[derive(Debug, Serialize, Deserialize, Clone)]
#[repr(align(64))]
pub struct TFVector<N>
where
    N: Num + Copy,
{
    /// TF vector, stored as a sparse vector
    pub tf_vec: ZeroSpVec<N>,
    /// total token count of this document
    pub token_sum: u64,
    /// denormalization factor for this document,
    /// used to recover token counts from TF values
    pub denormalize_num: f64,
}

// Size assertions, checked via const evaluation:
// if these fail to compile, TFVector is no longer 64 bytes.

#[allow(dead_code)]
const TF_VECTOR_SIZE: usize = core::mem::size_of::<TFVector<f32>>();
static_assertions::const_assert!(TF_VECTOR_SIZE == 64);
#[allow(dead_code)]
const ZSV_SIZE: usize = core::mem::size_of::<ZeroSpVec<f32>>();
static_assertions::const_assert!(ZSV_SIZE == 48);

impl<N> TFVector<N>
where
    N: Num + Copy,
{
    pub fn shrink_to_fit(&mut self) {
        self.tf_vec.shrink_to_fit();
    }
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct IDFVector {
    /// IDF vector; stored densely because it is mostly filled
    pub idf_vec: Vec<f32>,
    /// denormalization factor for the IDF values
    pub denormalize_num: f64,
    /// corpus generation number observed at the last IDF recalculation,
    /// used to detect staleness
    pub latest_entropy: u64,
    /// document count
    pub doc_num: u64,
}

impl IDFVector {
    pub fn new() -> Self {
        Self {
            idf_vec: Vec::new(),
            denormalize_num: 1.0,
            latest_entropy: 0,
            doc_num: 0,
        }
    }
}
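
// Idiomatic convenience sketch: a `Default` impl that simply delegates to
// `new()`, so `IDFVector` can be used wherever `Default` is expected.
impl Default for IDFVector {
    fn default() -> Self {
        Self::new()
    }
}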

impl<N, K, E> TFIDFVectorizer<N, K, E>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N, K> + Send + Sync,
    K: Clone + Send + Sync + Eq + Hash,
{
    /// Create a new TFIDFVectorizer instance
    pub fn new(corpus_ref: Arc<Corpus>) -> Self {
        let mut instance = Self {
            documents: IndexMap::new(),
            token_dim_rev_index: IndexMap::new(),
            corpus_ref,
            idf_cache: IDFVector::new(),
            _marker: std::marker::PhantomData,
        };
        instance.re_calc_idf();
        instance
    }

    /// Set the corpus reference and recalculate the IDF
    pub fn set_corpus_ref(&mut self, corpus_ref: Arc<Corpus>) {
        self.corpus_ref = corpus_ref;
        self.re_calc_idf();
    }

    /// Recalculate the IDF if the Corpus has changed
    pub fn update_idf(&mut self) {
        if self.corpus_ref.get_gen_num() != self.idf_cache.latest_entropy {
            self.re_calc_idf();
        }
        // if nothing has changed, do nothing
    }

    /// Recalculate the IDF from the Corpus
    fn re_calc_idf(&mut self) {
        self.idf_cache.latest_entropy = self.corpus_ref.get_gen_num();
        self.idf_cache.doc_num = self.corpus_ref.get_doc_num();
        (self.idf_cache.idf_vec, self.idf_cache.denormalize_num) =
            E::idf_vec(&self.corpus_ref, self.token_dim_rev_index.keys());
    }
}
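
// Construction sketch: `new` takes an `Arc<Corpus>` and fills the IDF cache
// right away. Illustrative only; it assumes a `Corpus::new()` constructor in
// the `corpus` module, which is not shown in this file.
//
//     let corpus = Arc::new(Corpus::new());
//     let mut vectorizer: TFIDFVectorizer = TFIDFVectorizer::new(Arc::clone(&corpus));
//     // after the corpus changes elsewhere, refresh the cached IDF:
//     vectorizer.update_idf();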

impl<N, K, E> TFIDFVectorizer<N, K, E>
where
    N: Num + Copy + Into<f64> + Send + Sync,
    E: TFIDFEngine<N, K> + Send + Sync,
    K: Clone + Send + Sync + Eq + Hash,
{
    /// Add a document
    /// The referenced Corpus is updated immediately as well
    pub fn add_doc(&mut self, key: K, doc: &TokenFrequency) {
        // create the shared key handle
        let key_rc = KeyRc::new(key);
        if self.documents.contains_key(&key_rc) {
            self.del_doc(&key_rc);
        }
        let token_sum = doc.token_sum();
        // add the document's tokens to the corpus
        self.add_corpus(doc);
        // incrementally add new vocabulary (O(|doc_vocab|))
        for tok in doc.token_set() {
            self.token_dim_rev_index
                .entry_mut(tok.into_boxed_str())
                .or_insert_with(Vec::new)
                .push(Rc::clone(&key_rc)); // add to the reverse index
        }

        let (tf_vec, denormalize_num) = E::tf_vec(doc, self.token_dim_rev_index.as_index_set());
        let mut tf_entry = TFVector {
            tf_vec,
            token_sum,
            denormalize_num,
        };
        tf_entry.shrink_to_fit();
        self.documents.insert(key_rc, tf_entry);
    }

    /// Delete a document, removing it from the reverse index and the corpus
    pub fn del_doc(&mut self, key: &K) {
        let rc_key = KeyRc::new(key.clone());
        if let Some(doc) = self.documents.get(&rc_key) {
            let tokens = doc.tf_vec.raw_iter()
                .filter_map(|(idx, _)| {
                    if let Some(doc_keys) = self.token_dim_rev_index.get_with_index_mut(idx) {
                        // remove from the reverse index
                        doc_keys.retain(|k| k != &rc_key);
                    }
                    self.token_dim_rev_index.get_key_with_index(idx).cloned()
                }).collect::<Vec<Box<str>>>();
            // remove the document
            self.documents.swap_remove(&rc_key);
            // remove its tokens from the corpus as well
            self.corpus_ref.sub_set(&tokens);
        }
    }

    /// Get the TFVector for a document key
    pub fn get_tf(&self, key: &K) -> Option<&TFVector<N>> {
        let rc_key = KeyRc::new(key.clone());
        self.documents.get(&rc_key)
    }

    /// Reconstruct the TokenFrequency for a document key.
    /// If values were quantized, there may be some rounding error.
    /// Tokens outside the token dimension table are ignored.
    pub fn get_tf_into_token_freq(&self, key: &K) -> Option<TokenFrequency> {
        if let Some(tf_vec) = self.get_tf(key) {
            let mut token_freq = TokenFrequency::new();
            tf_vec.tf_vec.raw_iter().for_each(|(idx, val)| {
                if let Some(token) = self.token_dim_rev_index.get_key_with_index(idx) {
                    let val_f64 = (*val).into();
                    // recover the raw token count from the normalized TF value
                    let token_num = tf_vec.token_sum.denormalize(tf_vec.denormalize_num) as f64 * val_f64;
                    token_freq.set_token_count(token, token_num as u64);
                } // out-of-range indices are ignored
            });
            Some(token_freq)
        } else {
            None
        }
    }

    /// Check whether a document with the given key exists
    pub fn contains_doc(&self, key: &K) -> bool {
        let rc_key = KeyRc::new(key.clone());
        self.documents.contains_key(&rc_key)
    }

    /// Check whether the token exists in the token dimension table
    pub fn contains_token(&self, token: &str) -> bool {
        self.token_dim_rev_index.contains_key(&Box::<str>::from(token))
    }

    /// Check whether every token in the given TokenFrequency exists in the token dimension table
    pub fn contains_tokens_from_freq(&self, freq: &TokenFrequency) -> bool {
        freq.token_set_ref_str().iter().all(|tok| self.contains_token(tok))
    }

    /// Number of stored documents
    pub fn doc_num(&self) -> usize {
        self.documents.len()
    }

    /// Add a document's tokens to the corpus,
    /// updating the referenced corpus in place
    fn add_corpus(&self, doc: &TokenFrequency) {
        self.corpus_ref.add_set(&doc.token_set_ref_str());
    }
}