// tf_idf_vectorizer/vectorizer/token.rs

use std::collections::{HashMap, HashSet};

use ahash::RandomState;
use serde::{Deserialize, Serialize};

use crate::Corpus;

/// A token frequency table: per-token occurrence counts plus a running total
/// of all occurrences.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TokenFrequency {
    token_count: HashMap<String, u64, RandomState>,
    total_token_count: u64,
}

impl TokenFrequency {
    /// Creates an empty frequency table.
    pub fn new() -> Self {
        TokenFrequency {
            token_count: HashMap::with_hasher(RandomState::new()),
            total_token_count: 0,
        }
    }

    /// Records one occurrence of `token`.
    #[inline]
    pub fn add_token(&mut self, token: &str) -> &mut Self {
        let count = self.token_count.entry(token.to_string()).or_insert(0);
        *count += 1;
        self.total_token_count += 1;
        self
    }

    /// Records one occurrence of every token in `tokens`.
    #[inline]
    pub fn add_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where
        T: AsRef<str>,
    {
        for token in tokens {
            self.add_token(token.as_ref());
        }
        self
    }

    /// Removes one occurrence of `token`; the entry is dropped when its count
    /// reaches zero. A token that is not present is ignored.
    #[inline]
    pub fn sub_token(&mut self, token: &str) -> &mut Self {
        if let Some(count) = self.token_count.get_mut(token) {
            if *count > 1 {
                *count -= 1;
            } else {
                self.token_count.remove(token);
            }
            self.total_token_count -= 1;
        }
        self
    }

    /// Removes one occurrence of every token in `tokens`.
    #[inline]
    pub fn sub_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where
        T: AsRef<str>,
    {
        for token in tokens {
            self.sub_token(token.as_ref());
        }
        self
    }

    /// Sets the count of `token` to `count`, keeping the running total
    /// consistent. A count of zero removes the token entirely.
    pub fn set_token_count(&mut self, token: &str, count: u64) -> &mut Self {
        if count == 0 {
            if let Some(old_count) = self.token_count.remove(token) {
                self.total_token_count -= old_count;
            }
        } else {
            let current_count = self.token_count.entry(token.to_string()).or_insert(0);
            // Add before subtracting so lowering a count cannot underflow u64.
            self.total_token_count = self.total_token_count + count - *current_count;
            *current_count = count;
        }
        self
    }

    /// Merges the counts of `other` into `self`.
    pub fn add_tokens_from_freq(&mut self, other: &TokenFrequency) -> &mut Self {
        for (token, &count) in &other.token_count {
            *self.token_count.entry(token.clone()).or_insert(0) += count;
            self.total_token_count += count;
        }
        self
    }

    /// Multiplies every count by `scalar`, rounding to the nearest integer.
    /// Entries whose scaled count rounds to zero are removed, so no
    /// zero-count tokens linger in the map.
    pub fn scale(&mut self, scalar: f64) -> &mut Self {
        let mut total_count: u64 = 0;
        self.token_count.retain(|_, count| {
            *count = ((*count as f64) * scalar).round() as u64;
            total_count += *count;
            *count > 0
        });
        self.total_token_count = total_count;
        self
    }
}

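// Since `new()` takes no arguments, a `Default` impl is the natural companion;
// this is a small convenience sketch, not part of the original file.
impl Default for TokenFrequency {
    fn default() -> Self {
        Self::new()
    }
}
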
impl<T> From<&[T]> for TokenFrequency
where
    T: AsRef<str>,
{
    fn from(tokens: &[T]) -> Self {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(tokens);
        tf
    }
}

impl From<Corpus> for TokenFrequency {
    fn from(corpus: Corpus) -> Self {
        let mut tf = TokenFrequency::new();
        // Copy each (token, count) pair out of the corpus's token-count map.
        for entry in corpus.token_counts.iter() {
            tf.set_token_count(entry.key(), *entry.value());
        }
        tf
    }
}

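// Illustrative check, compiled only under `cargo test`: a sketch of the
// `From<&[T]>` path and the basic queries defined below. Everything used here
// is defined in this file; the expected values follow from the input slice.
#[cfg(test)]
#[test]
fn from_slice_sketch() {
    let tf = TokenFrequency::from(&["a", "b", "a"][..]);
    assert_eq!(tf.token_count("a"), 2);
    assert_eq!(tf.token_count("b"), 1);
    assert_eq!(tf.token_sum(), 3);
    assert_eq!(tf.token_num(), 2);
}
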
impl TokenFrequency {
    /// Iterates over `(token, count)` pairs.
    #[inline]
    pub fn iter(&self) -> impl Iterator<Item = (&str, u64)> {
        self.token_count.iter().map(|(token, &count)| (token.as_str(), count))
    }

    /// Returns all `(token, count)` pairs with owned token strings.
    #[inline]
    pub fn token_count_vector(&self) -> Vec<(String, u64)> {
        self.token_count.iter().map(|(token, &count)| (token.clone(), count)).collect()
    }

    /// Returns all `(token, count)` pairs borrowing the token strings.
    #[inline]
    pub fn token_count_vector_ref_str(&self) -> Vec<(&str, u64)> {
        self.token_count.iter().map(|(token, &count)| (token.as_str(), count)).collect()
    }

    /// Returns the counts as a map borrowing the token strings.
    #[inline]
    pub fn token_count_hashmap_ref_str(&self) -> HashMap<&str, u64, RandomState> {
        self.token_count.iter().map(|(token, &count)| (token.as_str(), count)).collect()
    }

    /// Total number of token occurrences (the sum of all counts).
    #[inline]
    pub fn token_sum(&self) -> u64 {
        self.total_token_count
    }

    /// Count for a single token, or 0 if it has never been added.
    #[inline]
    pub fn token_count(&self, token: &str) -> u64 {
        self.token_count.get(token).copied().unwrap_or(0)
    }

    /// Returns every token that has the maximum count (there may be ties).
    #[inline]
    pub fn most_frequent_tokens_vector(&self) -> Vec<(String, u64)> {
        if let Some(&max_count) = self.token_count.values().max() {
            self.token_count.iter()
                .filter(|&(_, &count)| count == max_count)
                .map(|(token, &count)| (token.clone(), count))
                .collect()
        } else {
            Vec::new()
        }
    }

    /// The maximum count over all tokens, or 0 if the table is empty.
    #[inline]
    pub fn most_frequent_token_count(&self) -> u64 {
        self.token_count.values().max().copied().unwrap_or(0)
    }

    /// Whether `token` is present with a nonzero count.
    #[inline]
    pub fn contains_token(&self, token: &str) -> bool {
        self.token_count.contains_key(token)
    }

    /// Iterates over the distinct tokens.
    #[inline]
    pub fn token_set_iter(&self) -> impl Iterator<Item = &str> {
        self.token_count.keys().map(|s| s.as_str())
    }

    /// Returns the distinct tokens as owned strings.
    #[inline]
    pub fn token_set(&self) -> Vec<String> {
        self.token_count.keys().cloned().collect()
    }

    /// Returns the distinct tokens as borrowed strings.
    #[inline]
    pub fn token_set_ref_str(&self) -> Vec<&str> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    /// Returns the distinct tokens as an owned set.
    #[inline]
    pub fn token_hashset(&self) -> HashSet<String, RandomState> {
        self.token_count.keys().cloned().collect()
    }

    /// Returns the distinct tokens as a set of borrowed strings.
    #[inline]
    pub fn token_hashset_ref_str(&self) -> HashSet<&str, RandomState> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    /// Number of distinct tokens.
    #[inline]
    pub fn token_num(&self) -> usize {
        self.token_count.len()
    }

    /// Removes every listed stop token, returning how many occurrences were
    /// dropped in total.
    #[inline]
    pub fn remove_stop_tokens(&mut self, stop_tokens: &[&str]) -> u64 {
        let mut removed_total_count: u64 = 0;
        for &stop_token in stop_tokens {
            if let Some(count) = self.token_count.remove(stop_token) {
                removed_total_count += count;
            }
        }
        self.total_token_count -= removed_total_count;
        removed_total_count
    }

    /// Removes every token for which `condition` returns `true`, returning
    /// how many occurrences were dropped in total.
    #[inline]
    pub fn remove_tokens_by<F>(&mut self, condition: F) -> u64
    where
        F: Fn(&str, &u64) -> bool,
    {
        let mut removed_total_count: u64 = 0;
        self.token_count.retain(|token, count| {
            if condition(token, count) {
                removed_total_count += *count;
                false
            } else {
                true
            }
        });
        self.total_token_count -= removed_total_count;
        removed_total_count
    }

    /// Returns `(token, count)` pairs sorted by count, highest first.
    #[inline]
    pub fn sorted_frequency_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();
        token_list.sort_by(|a, b| b.1.cmp(&a.1));
        token_list
    }

    /// Returns `(token, count)` pairs sorted lexicographically by token.
    #[inline]
    pub fn sorted_dict_order_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();
        token_list.sort_by(|a, b| a.0.cmp(&b.0));
        token_list
    }

    /// Ratio of distinct tokens to total occurrences (the type-token ratio);
    /// 0.0 for an empty table.
    #[inline]
    pub fn unique_token_ratio(&self) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        self.token_count.len() as f64 / self.total_token_count as f64
    }

    /// Returns each token's relative frequency as owned `(token, p)` pairs.
    #[inline]
    pub fn probability_vector(&self) -> Vec<(String, f64)> {
        if self.total_token_count == 0 {
            return Vec::new();
        }
        let total = self.total_token_count as f64;
        self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), (count as f64) / total))
            .collect()
    }

    /// Returns each token's relative frequency, borrowing the token strings.
    #[inline]
    pub fn probability_vector_ref_str(&self) -> Vec<(&str, f64)> {
        if self.total_token_count == 0 {
            return Vec::new();
        }
        let total = self.total_token_count as f64;
        self.token_count
            .iter()
            .map(|(token, &count)| (token.as_str(), (count as f64) / total))
            .collect()
    }

    /// Relative frequency of a single token; 0.0 if it is absent or the
    /// table is empty.
    #[inline]
    pub fn probability(&self, token: &str) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        (self.token_count(token) as f64) / (self.total_token_count as f64)
    }

    /// Removes all tokens and resets the running total.
    #[inline]
    pub fn clear(&mut self) {
        self.token_count.clear();
        self.total_token_count = 0;
    }

    /// Releases excess capacity held by the underlying map.
    #[inline]
    pub fn shrink_to_fit(&mut self) {
        self.token_count.shrink_to_fit();
    }
}
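
// A few illustrative tests (sketches over the API above, not original to this
// file). Expected values follow directly from the inputs.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bookkeeping_stays_consistent() {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(&["a", "a", "b", "c"]);
        assert_eq!(tf.token_sum(), 4);

        // Lowering a count adjusts the total; setting it to zero removes the entry.
        tf.set_token_count("a", 1);
        assert_eq!(tf.token_sum(), 3);
        tf.set_token_count("c", 0);
        assert!(!tf.contains_token("c"));
        assert_eq!(tf.token_sum(), 2);

        // Subtracting the last occurrence drops the entry entirely.
        tf.sub_token("b");
        assert!(!tf.contains_token("b"));
        assert_eq!(tf.token_sum(), 1);
    }

    #[test]
    fn remove_tokens_by_condition() {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(&["x", "x", "x", "y"]);
        // Drop every token that occurs fewer than two times.
        let removed = tf.remove_tokens_by(|_, &count| count < 2);
        assert_eq!(removed, 1);
        assert_eq!(tf.token_sum(), 3);
        assert_eq!(tf.token_num(), 1);
    }

    #[test]
    fn probabilities_sum_to_one() {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(&["a", "b", "b", "b"]);
        assert!((tf.probability("b") - 0.75).abs() < 1e-12);
        let total: f64 = tf.probability_vector().iter().map(|(_, p)| p).sum();
        assert!((total - 1.0).abs() < 1e-12);
    }
}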