tf_idf_vectorizer/vectorizer/
token.rs

1use core::str;
2use std::{collections::{HashMap, HashSet}, fmt::Debug};
3
4use indexmap::IndexMap;
5use num::Num;
6use serde::{Deserialize, Serialize};
7
8use crate::utils::normalizer::IntoNormalizer;
9
10///  TokenFrequency 構造体
11/// tokenの出現頻度を管理するための構造体です
12/// tokenの出現回数をカウントし、TF-IDFの計算を行います
13/// 
14/// # Examples
15/// ```
16/// use vectorizer::token::TokenFrequency;
17/// let mut token_freq = TokenFrequency::new();
18/// token_freq.add_token("token1");
19/// token_freq.add_token("token2");
20/// token_freq.add_token("token1");
21/// 
22/// let tf = token_freq.tf_vector::<f64>();
23/// println!("{:?}", tf);
24/// ```
25#[derive(Serialize, Deserialize, Debug, Clone)]
26pub struct TokenFrequency {
27    #[serde(with = "indexmap::map::serde_seq")]
28    token_count: IndexMap<String, u32>,
29    total_token_count: u64,
30}
31
32/// Tokenの追加、削除の実装
33impl TokenFrequency {
34    /// 新しいTokenFrequencyを作成するメソッド
35    pub fn new() -> Self {
36        TokenFrequency {
37            token_count: IndexMap::new(),
38            total_token_count: 0,
39        }
40    }
41
42    /// tokenを追加する
43    /// 
44    /// # Arguments
45    /// * `token` - 追加するトークン
46    #[inline]
47    pub fn add_token(&mut self, token: &str) -> &mut Self {
48        let count = self.token_count.entry(token.to_string()).or_insert(0);
49        *count += 1;
50        self.total_token_count += 1;
51        self
52    }
53
54    /// 複数のtokenを追加する
55    /// 
56    /// # Arguments
57    /// * `tokens` - 追加するトークンのスライス
58    #[inline]
59    pub fn add_tokens<T>(&mut self, tokens: &[T]) -> &mut Self 
60    where T: AsRef<str> 
61    {
62        for token in tokens {
63            let token_str = token.as_ref();
64            self.add_token(token_str);
65        }
66        self
67    }
68
69    // /// tokenを引く
70    // /// 
71    // /// # Arguments
72    // /// * `token` - 引くトークン
73    // #[inline]
74    // pub fn sub_token(&mut self, token: &str) -> &mut Self {
75    //     if let Some(count) = self.token_count.get_mut(token) {
76    //         if *count > 1 {
77    //             *count -= 1;
78    //             self.total_token_count -= 1;
79    //         } else if *count == 1 {
80    //             self.token_count.remove(token);
81    //             self.total_token_count -= 1;
82    //         }
83    //     }
84    //     self
85    // }
86
87    // /// 複数のtokenを引く
88    // /// 
89    // /// # Arguments
90    // /// * `tokens` - 引くトークンのスライス
91    // #[inline]
92    // pub fn sub_tokens<T>(&mut self, tokens: &[T]) -> &mut Self 
93    // where T: AsRef<str>
94    // {
95    //     for token in tokens {
96    //         let token_str = token.as_ref();
97    //         self.sub_token(token_str);
98    //     }
99    //     self
100    // }
101
102    /// tokenの出現回数を指定する
103    /// 
104    /// # Arguments
105    /// * `token` - トークン
106    /// * `count` - 出現回数
107    #[deprecated(note = "countに0を指定した場合、token_numはそれを1つのユニークなtokenとしてカウントします。
108    このメソッドは、token_numのカウントを不正にする可能性があるため、非推奨です")]
109    pub fn set_token_count(&mut self, token: &str, count: u32) -> &mut Self {
110        if let Some(existing_count) = self.token_count.get_mut(token) {
111            if count >= *existing_count {
112                self.total_token_count += count as u64 - *existing_count as u64;
113            } else {
114                self.total_token_count -= *existing_count as u64 - count as u64;
115            }
116            *existing_count = count;
117        } else {
118            self.token_count.insert(token.to_string(), count);
119            self.total_token_count += count as u64;
120        }
121        self
122    }
123}
124
125/// TF-calculationの実装
126impl TokenFrequency
127{
128    /// TFの計算メソッド
129    /// 
130    /// # Arguments
131    /// * `max_count` - 最大カウント
132    /// * `count` - カウント
133    /// 
134    /// # Returns
135    /// * `f64` - TFの値 (0.0~1.0)
136    #[inline]
137    pub fn tf_calc(max_count: u32, count: u32) -> f64 {
138        if count == 0 {
139            return 0.0;
140        }
141        (count as f64 + 1.0).ln() / (max_count as f64 + 1.0).ln()
142    }
143
144    /// 全tokenのTFを取得します
145    /// 
146    /// # Returns
147    /// * `Vec<(String, N)>` - トークンとそのTFのベクタ
148    #[inline]
149    pub fn tf_vector<N>(&self) -> Vec<(String, N)> 
150    where f64: IntoNormalizer<N>, N: Num {
151        let max_count = self.most_frequent_token_count();
152        self.token_count
153            .iter()
154            .map(|(token, &count)| {
155                (token.clone(), Self::tf_calc(max_count, count).into_normalized())
156            })
157            .collect()
158    }
159
160    /// 全tokenのTFを取得します
161    /// 文字列はこれの参照を返します
162    /// 
163    /// # Returns
164    /// * `Vec<(&str, N)>` - トークンとそのTFのベクタ
165    #[inline]
166    pub fn tf_vector_ref_str<N>(&self) -> Vec<(&str, N)>
167    where f64: IntoNormalizer<N>, N: Num {
168        let max_count = self.most_frequent_token_count();
169        self.token_count
170            .iter()
171            .map(|(token, &count)| {
172                (token.as_str(), Self::tf_calc(max_count, count).into_normalized())
173            })
174            .collect()
175    }
176
177    /// 全tokenのTFを取得します
178    /// 
179    /// # Returns
180    /// * `HashMap<String, N>` - トークンとそのTFのハッシュマップ
181    #[inline]
182    pub fn tf_hashmap<N>(&self) -> HashMap<String, N> 
183    where f64: IntoNormalizer<N>, N: Num {
184        let max_count = self.most_frequent_token_count();
185        self.token_count
186            .iter()
187            .map(|(token, &count)| {
188                (token.clone(), Self::tf_calc(max_count, count).into_normalized())
189            })
190            .collect()
191    }
192
193    /// 全tokenのTFを取得します
194    /// 文字列はこれの参照を返します
195    /// 
196    /// # Returns
197    /// * `HashMap<&str, N>` - トークンとそのTFのハッシュマップ
198    #[inline]
199    pub fn tf_hashmap_ref_str<N>(&self) -> HashMap<&str, N> 
200    where f64: IntoNormalizer<N>, N: Num {
201        let max_count = self.most_frequent_token_count();
202        self.token_count
203            .iter()
204            .map(|(token, &count)| {
205                (token.as_str(), Self::tf_calc(max_count, count).into_normalized())
206            })
207            .collect()
208    }
209
210    /// 特定のtokenのTFを取得します
211    /// 
212    /// # Arguments
213    /// * `token` - トークン
214    /// 
215    /// # Returns
216    /// * `N` - トークンのTF
217    #[inline]
218    pub fn tf_token<N>(&self, token: &str) -> N 
219    where f64: IntoNormalizer<N>, N: Num{
220        let max_count = self.most_frequent_token_count();
221        let count = self.token_count.get(token).copied().unwrap_or(0);
222        Self::tf_calc(max_count, count).into_normalized()
223    }
224}
225
226/// IDF-calculationの実装
227impl TokenFrequency {
228    /// 最大IDFの計算
229    /// 正規化は行われません
230    #[inline]
231    fn idf_max(&self, total_doc_count: u64) -> f64 {
232        (1.0 + total_doc_count as f64 / (2.0)).ln()
233    }
234
235    /// IDFの計算
236    /// 
237    /// # Arguments
238    /// * `total_doc_count` - 全ドキュメント数
239    /// * `max_idf` - 最大IDF
240    /// * `doc_count` - ドキュメント内のトークン数
241    /// 
242    /// # Returns
243    /// * `f64` - IDFの値 (0.0~1.0)
244    #[inline]
245    pub fn idf_calc(total_doc_count: u64, max_idf: f64, doc_count: u32) -> f64 {
246        (1.0 + total_doc_count as f64 / (1.0 + doc_count as f64)).ln() / max_idf
247    }
248
249    /// 全tokenのIDFを取得します
250    /// 
251    /// # Arguments
252    /// * `total_doc_count` - 全ドキュメント数
253    /// 
254    /// # Returns
255    /// * `Vec<(String, N)>` - トークンとそのIDFのベクタ
256    #[inline]
257    pub fn idf_vector<N>(&self, total_doc_count: u64) -> Vec<(String, N)> 
258    where f64: IntoNormalizer<N>, N: Num {
259        self.token_count
260            .iter()
261            .map(|(token, &doc_count)| {
262                let idf = Self::idf_calc(total_doc_count, self.idf_max(total_doc_count), doc_count);
263                (token.clone(), idf.into_normalized())
264            })
265            .collect()
266    }
267
268    /// 全tokenのIDFを取得します
269    /// 文字列はこれの参照を返します
270    ///
271    /// # Arguments
272    /// * `total_doc_count` - 全ドキュメント数
273    /// 
274    /// # Returns
275    /// * `Vec<(&str, N)>` - トークンとそのIDFのベクタ
276    #[inline]
277    pub fn idf_vector_ref_str<N>(&self, total_doc_count: u64) -> Vec<(&str, N)> 
278    where f64: IntoNormalizer<N>, N: Num {
279        self.token_count.iter().map(|(token, &doc_count)| {
280            let idf = Self::idf_calc(total_doc_count, self.idf_max(total_doc_count), doc_count);
281            (token.as_str(), idf.into_normalized())
282        }).collect()
283    }
284
285
286    /// 全tokenのIDFを取得します
287    ///     
288    /// # Arguments
289    /// * `total_doc_count` - 全ドキュメント数
290    /// 
291    /// # Returns
292    /// * `HashMap<String, N>` - トークンとそのIDFのハッシュマップ
293    #[inline]
294    pub fn idf_hashmap<N>(&self, total_doc_count: u64) -> HashMap<String, N> 
295    where f64: IntoNormalizer<N>, N: Num {
296        self.token_count
297            .iter()
298            .map(|(token, &doc_count)| {
299                let idf = Self::idf_calc(total_doc_count, self.idf_max(total_doc_count), doc_count);
300                (token.clone(), idf.into_normalized())
301            })
302            .collect()
303    }
304
305    /// 全tokenのIDFを取得します
306    /// 文字列はこれの参照を返します
307    /// 
308    /// # Arguments
309    /// * `total_doc_count` - 全ドキュメント数
310    /// 
311    /// # Returns
312    /// * `HashMap<&str, N>` - トークンとそのIDFのハッシュマップ
313    #[inline]
314    pub fn idf_hashmap_ref_str<N>(&self, total_doc_count: u64) -> HashMap<&str, N> 
315    where f64: IntoNormalizer<N>, N: Num {
316        self.token_count.iter().map(|(token, &doc_count)| {
317            let idf = Self::idf_calc(total_doc_count, self.idf_max(total_doc_count), doc_count);
318            (token.as_str(), idf.into_normalized())
319        }).collect()
320    }
321}
322
323/// TokenFrequencyの情報を取得するための実装
324impl TokenFrequency {
325    /// すべてのtokenの出現回数を取得します
326    /// 
327    /// # Returns
328    /// * `Vec<(String, u32)>` - トークンとその出現回数のベクタ
329    #[inline]
330    pub fn token_count_vector(&self) -> Vec<(String, u32)> {
331        self.token_count.iter().map(|(token, &count)| {
332            (token.clone(), count)
333        }).collect()
334    }
335
336    /// すべてのtokenの出現回数を取得します
337    /// 文字列はこれの参照を返します
338    /// 
339    /// # Returns
340    /// * `Vec<(&str, u32)>` - トークンとその出現回数のベクタ
341    #[inline]
342    pub fn token_count_vector_ref_str(&self) -> Vec<(&str, u32)> {
343        self.token_count.iter().map(|(token, &count)| {
344            (token.as_str(), count)
345        }).collect()
346    }
347
348    /// すべてのtokenの出現回数を取得します
349    /// 文字列はこれの参照を返します
350    /// 
351    /// # Returns
352    /// * `HashMap<&str, u32>` - トークンとその出現回数のハッシュマップ
353    #[inline]
354    pub fn token_count_hashmap_ref_str(&self) -> HashMap<&str, u32> {
355        self.token_count.iter().map(|(token, &count)| {
356            (token.as_str(), count)
357        }).collect()
358    }
359
360    /// 全tokenのカウントの合計を取得します
361    /// 
362    /// # Returns
363    /// * `u64` - tokenのカウントの合計
364    #[inline]
365    pub fn token_total_count(&self) -> u64 {
366        self.total_token_count
367    }
368
369    /// あるtokenの出現回数を取得します
370    /// 
371    /// # Arguments
372    /// * `token` - トークン
373    /// 
374    /// # Returns
375    /// * `u32` - トークンの出現回数
376    #[inline]
377    pub fn token_count(&self, token: &str) -> u32 {
378        *self.token_count.get(token).unwrap_or(&0)
379    }
380
381    /// もっとも多く出現したtokenを取得します
382    /// 同じ出現回数のtokenが複数ある場合は、すべてのtokenを取得します
383    /// 
384    /// # Returns
385    /// * `Vec<(String, u32)>` - トークンとその出現回数のベクタ
386    #[inline]
387    pub fn most_frequent_tokens_vector(&self) -> Vec<(String, u32)> {
388        if let Some(&max_count) = self.token_count.values().max() {
389            self.token_count.iter()
390                .filter(|&(_, &count)| count == max_count)
391                .map(|(token, &count)| (token.clone(), count))
392                .collect()
393        } else {
394            Vec::new()
395        }
396    }
397
398    /// もっとも多く出現したtokenを取得します
399    /// 
400    /// # Returns
401    /// * `u32` - 再頻出tokenの出現回数
402    #[inline]
403    pub fn most_frequent_token_count(&self) -> u32 {
404        if let Some(&max_count) = self.token_count.values().max() {
405            max_count
406        } else {
407            0
408        }
409    }
410
411    /// tokenが存在するかどうかを確認します
412    /// 
413    /// # Arguments
414    /// * `token` - トークン
415    /// 
416    /// # Returns
417    /// * `bool` - tokenが存在する場合はtrue、存在しない場合はfalse
418    #[inline]
419    pub fn contains_token(&self, token: &str) -> bool {
420        self.token_count.contains_key(token)
421    }
422
423    /// tokenのsetを取得します
424    /// 
425    /// # Returns
426    /// * `Vec<String>` - tokenのset
427    #[inline]
428    pub fn token_set(&self) -> Vec<String> {
429        self.token_count.keys().cloned().collect()
430    }
431
432    /// tokenのsetを取得します
433    /// 文字列はこれの参照を返します
434    /// 
435    /// # Returns
436    /// * `Vec<&str>` - tokenのset
437    #[inline]
438    pub fn token_set_ref_str(&self) -> Vec<&str> {
439        self.token_count.keys().map(|s| s.as_str()).collect()
440    }
441
442    /// tokenのsetを取得します
443    /// 
444    /// # Returns
445    /// * `HashSet<String>` - tokenのset
446    #[inline]
447    pub fn token_hashset(&self) -> HashSet<String> {
448        self.token_count.keys().cloned().collect()
449    }
450
451    /// tokenのsetを取得します
452    /// 文字列はこれの参照を返します
453    /// 
454    /// # Returns
455    /// * `HashSet<&str>` - tokenのset
456    #[inline]
457    pub fn token_hashset_ref_str(&self) -> HashSet<&str> {
458        self.token_count.keys().map(|s| s.as_str()).collect()
459    }
460
461    /// 出現した単語数を取得します
462    /// 
463    /// # Returns
464    /// * `usize` - 出現した単語数
465    #[inline]
466    pub fn token_num(&self) -> usize {
467        self.token_count.len()
468    }
469
470    // /// stop_tokenを削除します
471    // /// 
472    // /// # Arguments
473    // /// * `stop_tokens` - 削除するトークンのスライス
474    // /// 
475    // /// # Returns
476    // /// * `u64` - 削除されたtokenの合計数
477    // #[inline]
478    // pub fn remove_stop_tokens(&mut self, stop_tokens: &[&str]) -> u64{
479    //     let mut removed_total_count: u64 = 0;
480    //     for &stop_token in stop_tokens {
481    //         if let Some(count) = self.token_count.remove(stop_token) {
482    //             removed_total_count += count as u64;
483    //         }
484    //     }
485    //     self.total_token_count -= removed_total_count;
486    //     removed_total_count
487    // }
488
489    /// 条件に基づいてtokenを削除します
490    /// 
491    /// # Arguments
492    /// * `condition` - 条件を満たすtokenを削除するクロージャ
493    /// 
494    /// # Returns
495    /// * `u64` - 削除されたtokenの合計数
496    #[inline]
497    pub fn remove_tokens_by_condition<F>(&mut self, condition: F) -> u64
498    where
499        F: Fn(&str, &u32) -> bool,
500    {
501        let mut removed_total_count: u64 = 0;
502        self.token_count.retain(|token, count| {
503            if condition(token, count) {
504                removed_total_count += *count as u64;
505                false
506            } else {
507                true
508            }
509        });
510        self.total_token_count -= removed_total_count as u64;
511
512        removed_total_count
513    }
514
515    /// 頻度でソートされたトークンのベクタを取得(降順)
516    /// 
517    /// # Returns
518    /// * `Vec<(String, u32)>` - 頻度でソートされたトークンのベクタ
519    #[inline]
520    pub fn sorted_frequency_vector(&self) -> Vec<(String, u32)> {
521        let mut token_list: Vec<(String, u32)> = self.token_count
522            .iter()
523            .map(|(token, &count)| (token.clone(), count))
524            .collect();
525
526        token_list.sort_by(|a, b| b.1.cmp(&a.1));
527        token_list
528    }
529
530    /// 辞書順でソートされたトークンのベクタを取得(昇順)
531    /// 
532    /// # Returns
533    /// * `Vec<(String, u32)>` - 辞書順でソートされたトークンのベクタ
534    #[inline]
535    pub fn sorted_dict_order_vector(&self) -> Vec<(String, u32)> {
536        let mut token_list: Vec<(String, u32)> = self.token_count
537            .iter()
538            .map(|(token, &count)| (token.clone(), count))
539            .collect();
540
541        token_list.sort_by(|a, b| a.0.cmp(&b.0));
542        token_list
543    }
544
545    /// tokenの多様性を計算します
546    /// 1.0は完全な多様性を示し、0.0は完全な非多様性を示します
547    /// 
548    /// # Returns
549    /// * `f64` - tokenの多様性
550    #[inline]
551    pub fn unique_token_ratio(&self) -> f64 {
552        if self.total_token_count == 0 {
553            return 0.0;
554        }
555        self.token_count.len() as f64 / self.total_token_count as f64
556    }
557
558    /// カウントを全リセットします
559    #[inline]
560    pub fn clear(&mut self) {
561        self.token_count.clear();
562        self.total_token_count = 0;
563    }
564}
565
566#[cfg(test)]
567mod tests {
568}