tf_idf_vectorizer/vectorizer/token.rs

use core::str;
use std::{collections::{HashMap, HashSet}, fmt::Debug};

use indexmap::IndexMap;
use num::Num;
use serde::{Deserialize, Serialize};

use crate::utils::normalizer::IntoNormalizer;
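/// Token occurrence statistics: an insertion-ordered `token -> count` map plus
/// the running total of all counted occurrences. The same structure doubles as
/// a document-frequency table for the IDF helpers below.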
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TokenFrequency {
    #[serde(with = "indexmap::map::serde_seq")]
    token_count: IndexMap<String, u32>,
    total_token_count: u64,
}

impl TokenFrequency {
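    /// Creates an empty `TokenFrequency`.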
    pub fn new() -> Self {
        TokenFrequency {
            token_count: IndexMap::new(),
            total_token_count: 0,
        }
    }

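    /// Records one occurrence of `token`.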
    #[inline]
    pub fn add_token(&mut self, token: &str) -> &mut Self {
        let count = self.token_count.entry(token.to_string()).or_insert(0);
        *count += 1;
        self.total_token_count += 1;
        self
    }

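    /// Records one occurrence of each token in `tokens`.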
    #[inline]
    pub fn add_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where T: AsRef<str>
    {
        for token in tokens {
            let token_str = token.as_ref();
            self.add_token(token_str);
        }
        self
    }

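    /// Overwrites the stored count for `token`, adjusting the running total to match.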
    #[deprecated(note = "If `count` is 0, `token_num` still counts the entry as one unique token, \
        so this method can make the `token_num` count incorrect; it is deprecated for that reason.")]
    pub fn set_token_count(&mut self, token: &str, count: u32) -> &mut Self {
        if let Some(existing_count) = self.token_count.get_mut(token) {
            if count >= *existing_count {
                self.total_token_count += count as u64 - *existing_count as u64;
            } else {
                self.total_token_count -= *existing_count as u64 - count as u64;
            }
            *existing_count = count;
        } else {
            self.token_count.insert(token.to_string(), count);
            self.total_token_count += count as u64;
        }
        self
    }
}

impl TokenFrequency {
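    /// Log-scaled term frequency: `ln(count + 1) / ln(max_count + 1)`, so the
    /// most frequent token maps to 1.0. Returns 0.0 when `count` is 0;
    /// assumes `count <= max_count`.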
    #[inline]
    pub fn tf_calc(max_count: u32, count: u32) -> f64 {
        if count == 0 {
            return 0.0;
        }
        (count as f64 + 1.0).ln() / (max_count as f64 + 1.0).ln()
    }

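    /// TF for every token as `(owned token, normalized value)` pairs.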
    #[inline]
    pub fn tf_vector<N>(&self) -> Vec<(String, N)>
    where f64: IntoNormalizer<N>, N: Num {
        let max_count = self.most_frequent_token_count();
        self.token_count
            .iter()
            .map(|(token, &count)| {
                (token.clone(), Self::tf_calc(max_count, count).into_normalized())
            })
            .collect()
    }

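    /// Borrowed-key variant of [`Self::tf_vector`].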
    #[inline]
    pub fn tf_vector_ref_str<N>(&self) -> Vec<(&str, N)>
    where f64: IntoNormalizer<N>, N: Num {
        let max_count = self.most_frequent_token_count();
        self.token_count
            .iter()
            .map(|(token, &count)| {
                (token.as_str(), Self::tf_calc(max_count, count).into_normalized())
            })
            .collect()
    }

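    /// TF for every token as a `HashMap` keyed by owned tokens.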
    #[inline]
    pub fn tf_hashmap<N>(&self) -> HashMap<String, N>
    where f64: IntoNormalizer<N>, N: Num {
        let max_count = self.most_frequent_token_count();
        self.token_count
            .iter()
            .map(|(token, &count)| {
                (token.clone(), Self::tf_calc(max_count, count).into_normalized())
            })
            .collect()
    }

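    /// Borrowed-key variant of [`Self::tf_hashmap`].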
    #[inline]
    pub fn tf_hashmap_ref_str<N>(&self) -> HashMap<&str, N>
    where f64: IntoNormalizer<N>, N: Num {
        let max_count = self.most_frequent_token_count();
        self.token_count
            .iter()
            .map(|(token, &count)| {
                (token.as_str(), Self::tf_calc(max_count, count).into_normalized())
            })
            .collect()
    }

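    /// TF of a single token, computed against the current most frequent count.
    /// Unseen tokens yield the normalized form of 0.0.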
    #[inline]
    pub fn tf_token<N>(&self, token: &str) -> N
    where f64: IntoNormalizer<N>, N: Num {
        let max_count = self.most_frequent_token_count();
        let count = self.token_count.get(token).copied().unwrap_or(0);
        Self::tf_calc(max_count, count).into_normalized()
    }
}

impl TokenFrequency {
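    /// Largest IDF that `idf_calc` can produce for this table, reached at the
    /// smallest meaningful document frequency (`doc_count == 1`, assuming every
    /// tracked token appears in at least one document); used to scale IDF into
    /// the `0..=1` range.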
    #[inline]
    fn idf_max(&self, total_doc_count: u64) -> f64 {
        (1.0 + total_doc_count as f64 / 2.0).ln()
    }

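    /// Smoothed IDF scaled by `max_idf`:
    /// `ln(1 + total_doc_count / (1 + doc_count)) / max_idf`.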
    #[inline]
    pub fn idf_calc(total_doc_count: u64, max_idf: f64, doc_count: u32) -> f64 {
        (1.0 + total_doc_count as f64 / (1.0 + doc_count as f64)).ln() / max_idf
    }

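    /// IDF for every token as `(owned token, normalized value)` pairs. Here
    /// each stored count is interpreted as a document frequency.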
    #[inline]
    pub fn idf_vector<N>(&self, total_doc_count: u64) -> Vec<(String, N)>
    where f64: IntoNormalizer<N>, N: Num {
        let max_idf = self.idf_max(total_doc_count);
        self.token_count
            .iter()
            .map(|(token, &doc_count)| {
                let idf = Self::idf_calc(total_doc_count, max_idf, doc_count);
                (token.clone(), idf.into_normalized())
            })
            .collect()
    }

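    /// Borrowed-key variant of [`Self::idf_vector`].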
    #[inline]
    pub fn idf_vector_ref_str<N>(&self, total_doc_count: u64) -> Vec<(&str, N)>
    where f64: IntoNormalizer<N>, N: Num {
        let max_idf = self.idf_max(total_doc_count);
        self.token_count.iter().map(|(token, &doc_count)| {
            let idf = Self::idf_calc(total_doc_count, max_idf, doc_count);
            (token.as_str(), idf.into_normalized())
        }).collect()
    }

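    /// IDF for every token as a `HashMap` keyed by owned tokens.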
    #[inline]
    pub fn idf_hashmap<N>(&self, total_doc_count: u64) -> HashMap<String, N>
    where f64: IntoNormalizer<N>, N: Num {
        let max_idf = self.idf_max(total_doc_count);
        self.token_count
            .iter()
            .map(|(token, &doc_count)| {
                let idf = Self::idf_calc(total_doc_count, max_idf, doc_count);
                (token.clone(), idf.into_normalized())
            })
            .collect()
    }

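    /// Borrowed-key variant of [`Self::idf_hashmap`].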
    #[inline]
    pub fn idf_hashmap_ref_str<N>(&self, total_doc_count: u64) -> HashMap<&str, N>
    where f64: IntoNormalizer<N>, N: Num {
        let max_idf = self.idf_max(total_doc_count);
        self.token_count.iter().map(|(token, &doc_count)| {
            let idf = Self::idf_calc(total_doc_count, max_idf, doc_count);
            (token.as_str(), idf.into_normalized())
        }).collect()
    }
}

impl TokenFrequency {
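    /// All `(owned token, count)` pairs in insertion order.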
    #[inline]
    pub fn token_count_vector(&self) -> Vec<(String, u32)> {
        self.token_count.iter().map(|(token, &count)| {
            (token.clone(), count)
        }).collect()
    }

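    /// Borrowed-key variant of [`Self::token_count_vector`].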
    #[inline]
    pub fn token_count_vector_ref_str(&self) -> Vec<(&str, u32)> {
        self.token_count.iter().map(|(token, &count)| {
            (token.as_str(), count)
        }).collect()
    }

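    /// Counts as a `HashMap` keyed by borrowed tokens.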
    #[inline]
    pub fn token_count_hashmap_ref_str(&self) -> HashMap<&str, u32> {
        self.token_count.iter().map(|(token, &count)| {
            (token.as_str(), count)
        }).collect()
    }

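    /// Total number of token occurrences recorded so far.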
    #[inline]
    pub fn token_total_count(&self) -> u64 {
        self.total_token_count
    }

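    /// Count for a single token; 0 if the token has not been seen.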
    #[inline]
    pub fn token_count(&self, token: &str) -> u32 {
        *self.token_count.get(token).unwrap_or(&0)
    }

    #[inline]
    pub fn most_frequent_tokens_vector(&self) -> Vec<(String, u32)> {
        if let Some(&max_count) = self.token_count.values().max() {
            self.token_count.iter()
                .filter(|&(_, &count)| count == max_count)
                .map(|(token, &count)| (token.clone(), count))
                .collect()
        } else {
            Vec::new()
        }
    }

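    /// Highest count among all tokens, or 0 if the table is empty.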
    #[inline]
    pub fn most_frequent_token_count(&self) -> u32 {
        self.token_count.values().copied().max().unwrap_or(0)
    }

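    /// Whether the table contains an entry for `token`.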
    #[inline]
    pub fn contains_token(&self, token: &str) -> bool {
        self.token_count.contains_key(token)
    }

    #[inline]
    pub fn token_set(&self) -> Vec<String> {
        self.token_count.keys().cloned().collect()
    }

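    /// Borrowed variant of [`Self::token_set`].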
    #[inline]
    pub fn token_set_ref_str(&self) -> Vec<&str> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

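    /// Unique tokens as an owned `HashSet`.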
    #[inline]
    pub fn token_hashset(&self) -> HashSet<String> {
        self.token_count.keys().cloned().collect()
    }

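    /// Borrowed variant of [`Self::token_hashset`].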
    #[inline]
    pub fn token_hashset_ref_str(&self) -> HashSet<&str> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

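    /// Number of unique tokens.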
    #[inline]
    pub fn token_num(&self) -> usize {
        self.token_count.len()
    }

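    /// Removes every token for which `condition` returns `true` and returns
    /// the total number of occurrences removed.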
    #[inline]
    pub fn remove_tokens_by_condition<F>(&mut self, condition: F) -> u64
    where
        F: Fn(&str, &u32) -> bool,
    {
        let mut removed_total_count: u64 = 0;
        self.token_count.retain(|token, count| {
            if condition(token, count) {
                removed_total_count += *count as u64;
                false
            } else {
                true
            }
        });
        self.total_token_count -= removed_total_count;

        removed_total_count
    }

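    /// `(owned token, count)` pairs sorted by descending count.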
    #[inline]
    pub fn sorted_frequency_vector(&self) -> Vec<(String, u32)> {
        let mut token_list: Vec<(String, u32)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| b.1.cmp(&a.1));
        token_list
    }

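    /// `(owned token, count)` pairs sorted lexicographically by token.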
    #[inline]
    pub fn sorted_dict_order_vector(&self) -> Vec<(String, u32)> {
        let mut token_list: Vec<(String, u32)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| a.0.cmp(&b.0));
        token_list
    }

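    /// Ratio of unique tokens to total occurrences (0.0 for an empty table).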
    #[inline]
    pub fn unique_token_ratio(&self) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        self.token_count.len() as f64 / self.total_token_count as f64
    }

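    /// Removes all tokens and resets the total count.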
    #[inline]
    pub fn clear(&mut self) {
        self.token_count.clear();
        self.total_token_count = 0;
    }
}

#[cfg(test)]
mod tests {
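    use super::*;

    // A minimal smoke-test sketch, assuming the semantics implemented above:
    // counts accumulate per token, `total_token_count` tracks every added
    // occurrence, and `tf_calc` maps the most frequent token to 1.0.
    #[test]
    fn add_and_query_counts() {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(&["a", "b", "a"]);

        assert_eq!(tf.token_count("a"), 2);
        assert_eq!(tf.token_count("b"), 1);
        assert_eq!(tf.token_total_count(), 3);
        assert_eq!(tf.token_num(), 2);
        assert_eq!(tf.most_frequent_token_count(), 2);

        // ln(2 + 1) / ln(2 + 1) == 1.0 for the most frequent token.
        assert!((TokenFrequency::tf_calc(2, 2) - 1.0).abs() < 1e-12);

        // Removing every token with count < 2 drops only "b".
        let removed = tf.remove_tokens_by_condition(|_, &count| count < 2);
        assert_eq!(removed, 1);
        assert_eq!(tf.token_total_count(), 2);
        assert!(tf.contains_token("a"));
        assert!(!tf.contains_token("b"));
    }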
568}