// tf_idf_vectorizer/vectorizer/token.rs

use core::str;
use std::{collections::{HashMap, HashSet}, fmt::Debug};

use indexmap::IndexMap;
use num::Num;
use serde::{Deserialize, Serialize};

use crate::utils::normalizer::IntoNormalizer;

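/// Accumulates per-token occurrence counts (in insertion order) together with
/// the total number of tokens seen; the building block for the TF and IDF
/// scores below.
///
/// A minimal usage sketch, assuming the crate is consumed as a library:
///
/// ```ignore
/// let mut tf = TokenFrequency::new();
/// tf.add_tokens(&["rust", "rust", "tf"]);
/// assert_eq!(tf.token_count("rust"), 2);
/// assert_eq!(tf.token_total_count(), 3);
/// ```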
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TokenFrequency {
    #[serde(with = "indexmap::map::serde_seq")]
    token_count: IndexMap<String, u32>,
    total_token_count: u64,
}

impl TokenFrequency {
    /// Creates an empty frequency table.
    pub fn new() -> Self {
        TokenFrequency {
            token_count: IndexMap::new(),
            total_token_count: 0,
        }
    }

    /// Increments the count for `token`, inserting it with count 1 if absent.
    #[inline]
    pub fn add_token(&mut self, token: &str) -> &mut Self {
        let count = self.token_count.entry(token.to_string()).or_insert(0);
        *count += 1;
        self.total_token_count += 1;
        self
    }

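    /// Adds every token in `tokens`; accepts any slice of string-like values
    /// (e.g. `&[&str]` or `&[String]`).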
    #[inline]
    pub fn add_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where T: AsRef<str>
    {
        for token in tokens {
            self.add_token(token.as_ref());
        }
        self
    }

    #[deprecated(note = "If `count` is 0, `token_num` still counts that token as one unique token; this method can therefore make the `token_num` count incorrect.")]
    pub fn set_token_count(&mut self, token: &str, count: u32) -> &mut Self {
        if let Some(existing_count) = self.token_count.get_mut(token) {
            if count >= *existing_count {
                self.total_token_count += count as u64 - *existing_count as u64;
            } else {
                self.total_token_count -= *existing_count as u64 - count as u64;
            }
            *existing_count = count;
        } else {
            self.token_count.insert(token.to_string(), count);
            self.total_token_count += count as u64;
        }
        self
    }
}

impl TokenFrequency {
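    /// Log-scaled term frequency, normalized against the most frequent token:
    /// `tf = ln(count + 1) / ln(max_count + 1)`, so the most frequent token
    /// scores 1.0 and an absent token scores 0.0
    /// (e.g. `tf_calc(4, 4) == 1.0`, `tf_calc(4, 0) == 0.0`).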
    #[inline]
    pub fn tf_calc(max_count: u32, count: u32) -> f64 {
        if count == 0 {
            return 0.0;
        }
        (count as f64 + 1.0).ln() / (max_count as f64 + 1.0).ln()
    }

    #[inline]
    pub fn tf_vector<N>(&self) -> Vec<(String, N)>
    where f64: IntoNormalizer<N>, N: Num {
        let max_count = self.most_frequent_token_count();
        self.token_count
            .iter()
            .map(|(token, &count)| {
                (token.clone(), Self::tf_calc(max_count, count).into_normalized())
            })
            .collect()
    }

    #[inline]
    pub fn tf_vector_ref_str<N>(&self) -> Vec<(&str, N)>
    where f64: IntoNormalizer<N>, N: Num {
        let max_count = self.most_frequent_token_count();
        self.token_count
            .iter()
            .map(|(token, &count)| {
                (token.as_str(), Self::tf_calc(max_count, count).into_normalized())
            })
            .collect()
    }

    #[inline]
    pub fn tf_hashmap<N>(&self) -> HashMap<String, N>
    where f64: IntoNormalizer<N>, N: Num {
        let max_count = self.most_frequent_token_count();
        self.token_count
            .iter()
            .map(|(token, &count)| {
                (token.clone(), Self::tf_calc(max_count, count).into_normalized())
            })
            .collect()
    }

    #[inline]
    pub fn tf_hashmap_ref_str<N>(&self) -> HashMap<&str, N>
    where f64: IntoNormalizer<N>, N: Num {
        let max_count = self.most_frequent_token_count();
        self.token_count
            .iter()
            .map(|(token, &count)| {
                (token.as_str(), Self::tf_calc(max_count, count).into_normalized())
            })
            .collect()
    }

    #[inline]
    pub fn tf_token<N>(&self, token: &str) -> N
    where f64: IntoNormalizer<N>, N: Num {
        let max_count = self.most_frequent_token_count();
        let count = self.token_count.get(token).copied().unwrap_or(0);
        Self::tf_calc(max_count, count).into_normalized()
    }
}

impl TokenFrequency {
    /// Normalization denominator for `idf_calc`: the raw idf of a token
    /// that appears in exactly one document (`doc_count == 1`).
    #[inline]
    fn idf_max(&self, total_doc_count: u64) -> f64 {
        (1.0 + total_doc_count as f64 / 2.0).ln()
    }

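    /// Smoothed inverse document frequency, normalized by `max_idf`
    /// (see `idf_max`): `idf = ln(1 + N / (1 + doc_count)) / max_idf`,
    /// where `N` is `total_doc_count`; a token present in exactly one
    /// document scores 1.0.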
    #[inline]
    pub fn idf_calc(total_doc_count: u64, max_idf: f64, doc_count: u32) -> f64 {
        (1.0 + total_doc_count as f64 / (1.0 + doc_count as f64)).ln() / max_idf
    }

    #[inline]
    pub fn idf_vector<N>(&self, total_doc_count: u64) -> Vec<(String, N)>
    where f64: IntoNormalizer<N>, N: Num {
        let max_idf = self.idf_max(total_doc_count);
        self.token_count
            .iter()
            .map(|(token, &doc_count)| {
                let idf = Self::idf_calc(total_doc_count, max_idf, doc_count);
                (token.clone(), idf.into_normalized())
            })
            .collect()
    }

    #[inline]
    pub fn idf_vector_ref_str<N>(&self, total_doc_count: u64) -> Vec<(&str, N)>
    where f64: IntoNormalizer<N>, N: Num {
        let max_idf = self.idf_max(total_doc_count);
        self.token_count
            .iter()
            .map(|(token, &doc_count)| {
                let idf = Self::idf_calc(total_doc_count, max_idf, doc_count);
                (token.as_str(), idf.into_normalized())
            })
            .collect()
    }

    #[inline]
    pub fn idf_hashmap<N>(&self, total_doc_count: u64) -> HashMap<String, N>
    where f64: IntoNormalizer<N>, N: Num {
        let max_idf = self.idf_max(total_doc_count);
        self.token_count
            .iter()
            .map(|(token, &doc_count)| {
                let idf = Self::idf_calc(total_doc_count, max_idf, doc_count);
                (token.clone(), idf.into_normalized())
            })
            .collect()
    }

    #[inline]
    pub fn idf_hashmap_ref_str<N>(&self, total_doc_count: u64) -> HashMap<&str, N>
    where f64: IntoNormalizer<N>, N: Num {
        let max_idf = self.idf_max(total_doc_count);
        self.token_count
            .iter()
            .map(|(token, &doc_count)| {
                let idf = Self::idf_calc(total_doc_count, max_idf, doc_count);
                (token.as_str(), idf.into_normalized())
            })
            .collect()
    }
}

impl TokenFrequency {
    #[inline]
    pub fn token_count_vector(&self) -> Vec<(String, u32)> {
        self.token_count.iter().map(|(token, &count)| {
            (token.clone(), count)
        }).collect()
    }

    #[inline]
    pub fn token_count_vector_ref_str(&self) -> Vec<(&str, u32)> {
        self.token_count.iter().map(|(token, &count)| {
            (token.as_str(), count)
        }).collect()
    }

    #[inline]
    pub fn token_count_hashmap_ref_str(&self) -> HashMap<&str, u32> {
        self.token_count.iter().map(|(token, &count)| {
            (token.as_str(), count)
        }).collect()
    }

    #[inline]
    pub fn token_total_count(&self) -> u64 {
        self.total_token_count
    }

    #[inline]
    pub fn token_count(&self, token: &str) -> u32 {
        self.token_count.get(token).copied().unwrap_or(0)
    }

    #[inline]
    pub fn most_frequent_tokens_vector(&self) -> Vec<(String, u32)> {
        if let Some(&max_count) = self.token_count.values().max() {
            self.token_count.iter()
                .filter(|&(_, &count)| count == max_count)
                .map(|(token, &count)| (token.clone(), count))
                .collect()
        } else {
            Vec::new()
        }
    }

    #[inline]
    pub fn most_frequent_token_count(&self) -> u32 {
        self.token_count.values().copied().max().unwrap_or(0)
    }

    #[inline]
    pub fn contains_token(&self, token: &str) -> bool {
        self.token_count.contains_key(token)
    }

    #[inline]
    pub fn token_set(&self) -> Vec<String> {
        self.token_count.keys().cloned().collect()
    }

    #[inline]
    pub fn token_set_ref_str(&self) -> Vec<&str> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    #[inline]
    pub fn token_hashset(&self) -> HashSet<String> {
        self.token_count.keys().cloned().collect()
    }

    #[inline]
    pub fn token_hashset_ref_str(&self) -> HashSet<&str> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    #[inline]
    pub fn token_num(&self) -> usize {
        self.token_count.len()
    }

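    /// Removes every token for which `condition` returns `true` and subtracts
    /// the removed occurrences from the running total; returns the total
    /// number of occurrences removed.
    ///
    /// A minimal usage sketch (pruning tokens that occur fewer than 2 times):
    ///
    /// ```ignore
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "b"]);
    /// let removed = tf.remove_tokens_by_condition(|_token, &count| count < 2);
    /// assert_eq!(removed, 1); // "b" occurred once and was dropped
    /// ```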
    #[inline]
    pub fn remove_tokens_by_condition<F>(&mut self, condition: F) -> u64
    where
        F: Fn(&str, &u32) -> bool,
    {
        let mut removed_total_count: u64 = 0;
        self.token_count.retain(|token, count| {
            if condition(token, count) {
                removed_total_count += *count as u64;
                false
            } else {
                true
            }
        });
        self.total_token_count -= removed_total_count;

        removed_total_count
    }

    #[inline]
    pub fn sorted_frequency_vector(&self) -> Vec<(String, u32)> {
        let mut token_list: Vec<(String, u32)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        // Sort by descending occurrence count.
        token_list.sort_by(|a, b| b.1.cmp(&a.1));
        token_list
    }

    #[inline]
    pub fn sorted_dict_order_vector(&self) -> Vec<(String, u32)> {
        let mut token_list: Vec<(String, u32)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        // Sort tokens lexicographically.
        token_list.sort_by(|a, b| a.0.cmp(&b.0));
        token_list
    }

    #[inline]
    pub fn unique_token_ratio(&self) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        self.token_count.len() as f64 / self.total_token_count as f64
    }

    #[inline]
    pub fn clear(&mut self) {
        self.token_count.clear();
        self.total_token_count = 0;
    }
}

#[cfg(test)]
mod tests {
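    // Minimal example tests: a sketch of the intended call flow, exercising
    // only behavior defined in this file.
    use super::*;

    #[test]
    fn counts_tokens_and_total() {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(&["rust", "rust", "tf"]);
        assert_eq!(tf.token_count("rust"), 2);
        assert_eq!(tf.token_total_count(), 3);
        assert_eq!(tf.token_num(), 2);
        assert!(tf.contains_token("tf"));
    }

    #[test]
    fn tf_is_log_scaled_and_normalized() {
        // tf_calc(max_count, count) = ln(count + 1) / ln(max_count + 1):
        // the most frequent token maps to 1.0, a missing token to 0.0.
        assert_eq!(TokenFrequency::tf_calc(4, 4), 1.0);
        assert_eq!(TokenFrequency::tf_calc(4, 0), 0.0);
        assert!(TokenFrequency::tf_calc(4, 2) > 0.0);
        assert!(TokenFrequency::tf_calc(4, 2) < 1.0);
    }

    #[test]
    fn removing_tokens_updates_total() {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(&["a", "a", "b"]);
        let removed = tf.remove_tokens_by_condition(|_, &count| count < 2);
        assert_eq!(removed, 1);
        assert_eq!(tf.token_total_count(), 2);
        assert_eq!(tf.token_num(), 1);
    }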
}