// tf_idf_vectorizer/vectorizer/mod.rs

1pub mod corpus;
2pub mod tfidf;
3pub mod term;
4pub mod serde;
5pub mod evaluate;
6
7use std::{rc::Rc, sync::Arc};
8use std::hash::Hash;
9
10use half::f16;
11use num_traits::Num;
12use ::serde::{Deserialize, Serialize};
13
14use crate::utils::datastruct::vector::{TFVector, TFVectorTrait};
15use crate::{DefaultTFIDFEngine, TFIDFEngine, TermFrequency};
16use crate::utils::datastruct::map::IndexMap;
17use crate::Corpus;
18
/// Reference-counted handle used as the document key inside the vectorizer.
///
/// A single allocation of each key is shared between the `documents` map and
/// every reverse-index entry that mentions that document. Because this is a
/// `std::rc::Rc`, any type that stores a `KeyRc` is neither `Send` nor `Sync`.
pub type KeyRc<K> = std::rc::Rc<K>;
20
21#[derive(Debug, Clone)]
22pub struct TFIDFVectorizer<N = f16, K = String, E = DefaultTFIDFEngine>
23where
24    N: Num + Copy + Into<f64> + Send + Sync,
25    E: TFIDFEngine<N> + Send + Sync,
26    K: Clone + Send + Sync + Eq + std::hash::Hash,
27{
28    /// Document's TF Vector
29    pub documents: IndexMap<KeyRc<K>, TFVector<N>>,
30    /// TF Vector's term dimension sample and reverse index
31    pub term_dim_rev_index: IndexMap<Box<str>, Vec<KeyRc<K>>>,
32    /// Corpus reference
33    pub corpus_ref: Arc<Corpus>,
34    /// IDF Vector
35    pub idf_cache: IDFVector,
36    _marker: std::marker::PhantomData<E>,
37}
38
39#[derive(Debug, Serialize, Deserialize, Clone)]
40pub struct IDFVector
41{
42    /// IDF Vector it is not sparse because it is mostly filled
43    pub idf_vec: Vec<f32>,
44    /// latest entropy
45    pub latest_entropy: u64,
46    /// document count
47    pub doc_num: u64,
48}
49
50impl IDFVector
51{
52    pub fn new() -> Self {
53        Self {
54            idf_vec: Vec::new(),
55            latest_entropy: 0,
56            doc_num: 0,
57        }
58    }
59}
60
61impl <N, K, E> TFIDFVectorizer<N, K, E>
62where
63    N: Num + Copy + Into<f64> + Send + Sync,
64    E: TFIDFEngine<N> + Send + Sync,
65    K: Clone + Send + Sync + Eq + Hash,
66{
67    /// Create a new TFIDFVectorizer instance
68    pub fn new(corpus_ref: Arc<Corpus>) -> Self {
69        let mut instance = Self {
70            documents: IndexMap::new(),
71            term_dim_rev_index: IndexMap::new(),
72            corpus_ref,
73            idf_cache: IDFVector::new(),
74            _marker: std::marker::PhantomData,
75        };
76        instance.re_calc_idf();
77        instance
78    }
79
80    /// set corpus reference
81    /// and recalculate idf
82    pub fn set_corpus_ref(&mut self, corpus_ref: Arc<Corpus>) {
83        self.corpus_ref = corpus_ref;
84        self.re_calc_idf();
85    }
86
87    /// Corpusに変更があればIDFを再計算する
88    pub fn update_idf(&mut self) {
89        if self.corpus_ref.get_gen_num() != self.idf_cache.latest_entropy {
90            self.re_calc_idf();
91        }
92        // 更新がなければ何もしない
93    }
94
95    /// CorpusからIDFを再計算する
96    fn re_calc_idf(&mut self) {
97        self.idf_cache.latest_entropy = self.corpus_ref.get_gen_num();
98        self.idf_cache.doc_num = self.corpus_ref.get_doc_num();
99        self.idf_cache.idf_vec = E::idf_vec(&self.corpus_ref, self.term_dim_rev_index.keys());
100    }
101}
102
103impl <N, K, E> TFIDFVectorizer<N, K, E>
104where
105    N: Num + Copy + Into<f64> + Send + Sync,
106    E: TFIDFEngine<N> + Send + Sync,
107    K: PartialEq + Clone + Send + Sync + Eq + Hash
108{
109    /// Add a document
110    /// The immediately referenced Corpus is also updated
111    pub fn add_doc(&mut self, key: K, doc: &TermFrequency) {
112        // key_rcを作成
113        let key_rc = KeyRc::new(key);
114        if self.documents.contains_key(&key_rc) {
115            self.del_doc(&key_rc);
116        }
117        // ドキュメントのトークンをコーパスに追加
118        self.add_corpus(doc);
119        // 新語彙を差分追加 (O(|doc_vocab|))
120        for tok in doc.term_set(){
121            self.term_dim_rev_index
122                .entry_mut(tok.into_boxed_str())
123                .or_insert_with(Vec::new)
124                .push(Rc::clone(&key_rc)); // 逆Indexに追加
125        }
126
127        let tf_vec= E::tf_vec(doc, self.term_dim_rev_index.as_index_set());
128        self.documents.insert(key_rc, tf_vec);
129    }
130
131    pub fn del_doc(&mut self, key: &K)
132    where
133        K: PartialEq,
134    {
135        let rc_key = KeyRc::new(key.clone());
136        if let Some(tf_vec) = self.documents.get(&rc_key) {
137            let terms = tf_vec.raw_iter()
138                .filter_map(|(idx, _)| {
139                    let idx = idx as usize;
140                    let doc_keys = self.term_dim_rev_index.get_with_index_mut(idx);
141                    if let Some(doc_keys) = doc_keys {
142                        // 逆Indexから削除
143                        let rc_key = KeyRc::new(key.clone());
144                        doc_keys.retain(|k| *k != rc_key);
145                    }
146                    let term = self.term_dim_rev_index.get_key_with_index(idx).cloned();
147                    term
148                }).collect::<Vec<Box<str>>>();
149            // ドキュメントを削除
150            self.documents.swap_remove(&rc_key);
151            // コーパスからも削除
152            self.corpus_ref.sub_set(&terms);
153        }
154    }
155
156    /// Get TFVector by document ID
157    pub fn get_tf(&self, key: &K) -> Option<&TFVector<N>>
158    where
159        K: Eq + Hash,
160    {
161        let rc_key = KeyRc::new(key.clone());
162        self.documents.get(&rc_key)
163    }
164
165    /// Get TermFrequency by document ID
166    /// If quantized, there may be some error
167    /// Words not included in the corpus are ignored
168    pub fn get_tf_into_term_freq(&self, key: &K) -> Option<TermFrequency>
169    {
170        if let Some(tf_vec) = self.get_tf(key) {
171            let mut term_freq = TermFrequency::new();
172            tf_vec.raw_iter().for_each(|(idx, val)| {
173                let idx = idx as usize;
174                if let Some(term) = self.term_dim_rev_index.get_key_with_index(idx) {
175                    let term_num = E::tf_denorm(val);
176                    term_freq.set_term_count(term, term_num as u64);
177                } // out of range is ignored
178            });
179            Some(term_freq)
180        } else {
181            None
182        }
183    }
184
185    /// Check if a document with the given ID exists
186    pub fn contains_doc(&self, key: &K) -> bool
187    where
188        K: PartialEq,
189    {
190        let rc_key = KeyRc::new(key.clone());
191        self.documents.contains_key(&rc_key)
192    }
193
194    /// Check if the term exists in the term dimension sample
195    pub fn contains_term(&self, term: &str) -> bool {
196        self.term_dim_rev_index.contains_key(&Box::<str>::from(term))
197    }
198
199    /// Check if all terms in the given TermFrequency exist in the term dimension sample
200    pub fn contains_terms_from_freq(&self, freq: &TermFrequency) -> bool {
201        freq.term_set_ref_str().iter().all(|tok| self.contains_term(tok))
202    }
203
204    pub fn doc_num(&self) -> usize {
205        self.documents.len()
206    }
207
208    /// add document to corpus
209    /// update the referenced corpus
210    fn add_corpus(&self, doc: &TermFrequency) {
211        // add document to corpus
212        self.corpus_ref.add_set(&doc.term_set_ref_str());
213    }
214}