tf_idf_vectorizer/vectorizer/
token.rs

1use core::str;
2use std::{collections::{HashMap, HashSet}, fmt::Debug};
3use ahash::RandomState;
4use serde::{Deserialize, Serialize};
5
6
7///  TokenFrequency 構造体
8/// tokenの出現頻度を管理するための構造体です
9/// tokenの出現回数をカウントします。
10/// 
11/// # Examples
12/// ```
13/// use vectorizer::token::TokenFrequency;
14/// let mut token_freq = TokenFrequency::new();
15/// token_freq.add_token("token1");
16/// token_freq.add_token("token2");
17/// token_freq.add_token("token1");
18/// 
19/// ```
20#[derive(Serialize, Deserialize, Debug, Clone)]
21pub struct TokenFrequency {
22    token_count: HashMap<String, u64, RandomState>,
23    total_token_count: u64,
24}
25
26/// Tokenの追加、削除の実装
27impl TokenFrequency {
28    /// 新しいTokenFrequencyを作成するメソッド
29    pub fn new() -> Self {
30        TokenFrequency {
31            token_count: HashMap::with_hasher(RandomState::new()),
32            total_token_count: 0,
33        }
34    }
35
36    /// tokenを追加する
37    /// 
38    /// # Arguments
39    /// * `token` - 追加するトークン
40    #[inline]
41    pub fn add_token(&mut self, token: &str) -> &mut Self {
42        let count = self.token_count.entry(token.to_string()).or_insert(0);
43        *count += 1;
44        self.total_token_count += 1;
45        self
46    }
47
48    /// 複数のtokenを追加する
49    /// 
50    /// # Arguments
51    /// * `tokens` - 追加するトークンのスライス
52    #[inline]
53    pub fn add_tokens<T>(&mut self, tokens: &[T]) -> &mut Self 
54    where T: AsRef<str> 
55    {
56        for token in tokens {
57            let token_str = token.as_ref();
58            self.add_token(token_str);
59        }
60        self
61    }
62
63    /// tokenを引く
64    /// 
65    /// # Arguments
66    /// * `token` - 引くトークン
67    #[inline]
68    pub fn sub_token(&mut self, token: &str) -> &mut Self {
69        if let Some(count) = self.token_count.get_mut(token) {
70            if *count > 1 {
71                *count -= 1;
72                self.total_token_count -= 1;
73            } else if *count == 1 {
74                self.token_count.remove(token);
75                self.total_token_count -= 1;
76            }
77        }
78        self
79    }
80
81    /// 複数のtokenを引く
82    /// 
83    /// # Arguments
84    /// * `tokens` - 引くトークンのスライス
85    #[inline]
86    pub fn sub_tokens<T>(&mut self, tokens: &[T]) -> &mut Self 
87    where T: AsRef<str>
88    {
89        for token in tokens {
90            let token_str = token.as_ref();
91            self.sub_token(token_str);
92        }
93        self
94    }
95
96    /// tokenの出現回数を指定する
97    /// 
98    /// # Arguments
99    /// * `token` - トークン
100    /// * `count` - 出現回数
101    pub fn set_token_count(&mut self, token: &str, count: u64) -> &mut Self {
102        if count == 0 {
103            self.token_count.remove(token);
104        } else {
105            let current_count = self.token_count.entry(token.to_string()).or_insert(0);
106            self.total_token_count += count - *current_count;
107            *current_count = count;
108        }
109        self
110    }
111}
112
113/// TokenFrequencyの情報を取得するための実装
114impl TokenFrequency {
115    /// すべてのtokenの出現回数を取得します
116    /// 
117    /// # Returns
118    /// * `Vec<(String, u32)>` - トークンとその出現回数のベクタ
119    #[inline]
120    pub fn token_count_vector(&self) -> Vec<(String, u64)> {
121        self.token_count.iter().map(|(token, &count)| {
122            (token.clone(), count)
123        }).collect()
124    }
125
126    /// すべてのtokenの出現回数を取得します
127    /// 文字列はこれの参照を返します
128    /// 
129    /// # Returns
130    /// * `Vec<(&str, u32)>` - トークンとその出現回数のベクタ
131    #[inline]
132    pub fn token_count_vector_ref_str(&self) -> Vec<(&str, u64)> {
133        self.token_count.iter().map(|(token, &count)| {
134            (token.as_str(), count)
135        }).collect()
136    }
137
138    /// すべてのtokenの出現回数を取得します
139    /// 文字列はこれの参照を返します
140    /// 
141    /// # Returns
142    /// * `HashMap<&str, u32>` - トークンとその出現回数のハッシュマップ
143    #[inline]
144    pub fn token_count_hashmap_ref_str(&self) -> HashMap<&str, u64, RandomState> {
145        self.token_count.iter().map(|(token, &count)| {
146            (token.as_str(), count)
147        }).collect()
148    }
149
150    /// 全tokenのカウントの合計を取得します
151    /// 
152    /// # Returns
153    /// * `u64` - tokenのカウントの合計
154    #[inline]
155    pub fn token_sum(&self) -> u64 {
156        self.total_token_count
157    }
158
159    /// あるtokenの出現回数を取得します
160    /// 
161    /// # Arguments
162    /// * `token` - トークン
163    /// 
164    /// # Returns
165    /// * `u32` - トークンの出現回数
166    #[inline]
167    pub fn token_count(&self, token: &str) -> u64 {
168        *self.token_count.get(token).unwrap_or(&0)
169    }
170
171    /// もっとも多く出現したtokenを取得します
172    /// 同じ出現回数のtokenが複数ある場合は、すべてのtokenを取得します
173    /// 
174    /// # Returns
175    /// * `Vec<(String, u32)>` - トークンとその出現回数のベクタ
176    #[inline]
177    pub fn most_frequent_tokens_vector(&self) -> Vec<(String, u64)> {
178        if let Some(&max_count) = self.token_count.values().max() {
179            self.token_count.iter()
180                .filter(|&(_, &count)| count == max_count)
181                .map(|(token, &count)| (token.clone(), count))
182                .collect()
183        } else {
184            Vec::new()
185        }
186    }
187
188    /// もっとも多く出現したtokenの出現回数取得します
189    /// 
190    /// # Returns
191    /// * `u32` - 再頻出tokenの出現回数
192    #[inline]
193    pub fn most_frequent_token_count(&self) -> u64 {
194        if let Some(&max_count) = self.token_count.values().max() {
195            max_count
196        } else {
197            0
198        }
199    }
200
201    /// tokenが存在するかどうかを確認します
202    /// 
203    /// # Arguments
204    /// * `token` - トークン
205    /// 
206    /// # Returns
207    /// * `bool` - tokenが存在する場合はtrue、存在しない場合はfalse
208    #[inline]
209    pub fn contains_token(&self, token: &str) -> bool {
210        self.token_count.contains_key(token)
211    }
212
213    /// tokenのsetを取得します
214    /// 
215    /// # Returns
216    /// * `Vec<String>` - tokenのset
217    #[inline]
218    pub fn token_set(&self) -> Vec<String> {
219        self.token_count.keys().cloned().collect()
220    }
221
222    /// tokenのsetを取得します
223    /// 文字列はこれの参照を返します
224    /// 
225    /// # Returns
226    /// * `Vec<&str>` - tokenのset
227    #[inline]
228    pub fn token_set_ref_str(&self) -> Vec<&str> {
229        self.token_count.keys().map(|s| s.as_str()).collect()
230    }
231
232    /// tokenのsetを取得します
233    /// 
234    /// # Returns
235    /// * `HashSet<String>` - tokenのset
236    #[inline]
237    pub fn token_hashset(&self) -> HashSet<String, RandomState> {
238        self.token_count.keys().cloned().collect()
239    }
240
241    /// tokenのsetを取得します
242    /// 文字列はこれの参照を返します
243    /// 
244    /// # Returns
245    /// * `HashSet<&str>` - tokenのset
246    #[inline]
247    pub fn token_hashset_ref_str(&self) -> HashSet<&str, RandomState> {
248        self.token_count.keys().map(|s| s.as_str()).collect()
249    }
250
251    /// 出現した単語数を取得します
252    /// 
253    /// # Returns
254    /// * `usize` - 出現した単語数
255    #[inline]
256    pub fn token_num(&self) -> usize {
257        self.token_count.len()
258    }
259
260    /// stop_tokenを削除します
261    /// 
262    /// # Arguments
263    /// * `stop_tokens` - 削除するトークンのスライス
264    /// 
265    /// # Returns
266    /// * `u64` - 削除されたtokenの合計数
267    #[inline]
268    pub fn remove_stop_tokens(&mut self, stop_tokens: &[&str]) -> u64{
269        let mut removed_total_count: u64 = 0;
270        for &stop_token in stop_tokens {
271            if let Some(count) = self.token_count.remove(stop_token) {
272                removed_total_count += count as u64;
273            }
274        }
275        self.total_token_count -= removed_total_count;
276        removed_total_count
277    }
278
279    /// 条件に基づいてtokenを削除します
280    /// 
281    /// # Arguments
282    /// * `condition` - 条件を満たすtokenを削除するクロージャ
283    /// 
284    /// # Returns
285    /// * `u64` - 削除されたtokenの合計数
286    #[inline]
287    pub fn remove_tokens_by<F>(&mut self, condition: F) -> u64
288    where
289        F: Fn(&str, &u64) -> bool,
290    {
291        let mut removed_total_count: u64 = 0;
292        self.token_count.retain(|token, count| {
293            if condition(token, count) {
294                removed_total_count += *count as u64;
295                false
296            } else {
297                true
298            }
299        });
300        self.total_token_count -= removed_total_count as u64;
301
302        removed_total_count
303    }
304
305    /// 頻度でソートされたトークンのベクタを取得(降順)
306    /// 
307    /// # Returns
308    /// * `Vec<(String, u32)>` - 頻度でソートされたトークンのベクタ
309    #[inline]
310    pub fn sorted_frequency_vector(&self) -> Vec<(String, u64)> {
311        let mut token_list: Vec<(String, u64)> = self.token_count
312            .iter()
313            .map(|(token, &count)| (token.clone(), count))
314            .collect();
315
316        token_list.sort_by(|a, b| b.1.cmp(&a.1));
317        token_list
318    }
319
320    /// 辞書順でソートされたトークンのベクタを取得(昇順)
321    /// 
322    /// # Returns
323    /// * `Vec<(String, u32)>` - 辞書順でソートされたトークンのベクタ
324    #[inline]
325    pub fn sorted_dict_order_vector(&self) -> Vec<(String, u64)> {
326        let mut token_list: Vec<(String, u64)> = self.token_count
327            .iter()
328            .map(|(token, &count)| (token.clone(), count))
329            .collect();
330
331        token_list.sort_by(|a, b| a.0.cmp(&b.0));
332        token_list
333    }
334
335    /// tokenの多様性を計算します
336    /// 1.0は完全な多様性を示し、0.0は完全な非多様性を示します
337    /// 
338    /// # Returns
339    /// * `f64` - tokenの多様性
340    #[inline]
341    pub fn unique_token_ratio(&self) -> f64 {
342        if self.total_token_count == 0 {
343            return 0.0;
344        }
345        self.token_count.len() as f64 / self.total_token_count as f64
346    }
347
348    /// カウントを全リセットします
349    #[inline]
350    pub fn clear(&mut self) {
351        self.token_count.clear();
352        self.total_token_count = 0;
353    }
354
355    /// shrink internal storage to fit current size
356    #[inline]
357    pub fn shrink_to_fit(&mut self) {
358        self.token_count.shrink_to_fit();
359    }
360}