// tf_idf_vectorizer/vectorizer/token.rs
use std::{collections::{HashMap, HashSet}, fmt::Debug};
use ahash::RandomState;
use serde::{Deserialize, Serialize};

use crate::Corpus;


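/// Per-token occurrence counts plus the running total of all occurrences,
/// the raw material for TF scores.
///
/// A minimal usage sketch (marked `ignore` because the crate-root re-export
/// `tf_idf_vectorizer::TokenFrequency` is an assumption about this crate's
/// public API, not something this file pins down):
///
/// ```ignore
/// use tf_idf_vectorizer::TokenFrequency;
///
/// let mut tf = TokenFrequency::new();
/// tf.add_tokens(&["rust", "tf", "idf", "rust"]);
/// assert_eq!(tf.token_count("rust"), 2);
/// assert_eq!(tf.token_sum(), 4);
/// ```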
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TokenFrequency {
    token_count: HashMap<String, u64, RandomState>,
    total_token_count: u64,
}

impl TokenFrequency {
    /// Creates an empty frequency table.
    pub fn new() -> Self {
        TokenFrequency {
            token_count: HashMap::with_hasher(RandomState::new()),
            total_token_count: 0,
        }
    }

    /// Records one occurrence of `token`.
    #[inline]
    pub fn add_token(&mut self, token: &str) -> &mut Self {
        let count = self.token_count.entry(token.to_string()).or_insert(0);
        *count += 1;
        self.total_token_count += 1;
        self
    }

    /// Records one occurrence of every token in `tokens`.
    #[inline]
    pub fn add_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where
        T: AsRef<str>,
    {
        for token in tokens {
            self.add_token(token.as_ref());
        }
        self
    }

    /// Removes one occurrence of `token`; the entry disappears once its
    /// count reaches zero. Does nothing for an unknown token.
    #[inline]
    pub fn sub_token(&mut self, token: &str) -> &mut Self {
        if let Some(count) = self.token_count.get_mut(token) {
            if *count > 1 {
                *count -= 1;
            } else {
                self.token_count.remove(token);
            }
            self.total_token_count -= 1;
        }
        self
    }

    /// Removes one occurrence of every token in `tokens`.
    #[inline]
    pub fn sub_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where
        T: AsRef<str>,
    {
        for token in tokens {
            self.sub_token(token.as_ref());
        }
        self
    }

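    /// Sets the absolute count for `token`; a count of zero removes the
    /// entry. The running total is adjusted by the difference.
    ///
    /// A minimal sketch of the intended bookkeeping (`ignore`d because the
    /// crate-root re-export assumed above may not match the real layout):
    ///
    /// ```ignore
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "b"]);  // a=2, b=1, sum=3
    /// tf.set_token_count("a", 5);       // sum becomes 3 - 2 + 5 = 6
    /// tf.set_token_count("b", 0);       // "b" is removed, sum drops to 5
    /// assert_eq!(tf.token_sum(), 5);
    /// assert!(!tf.contains_token("b"));
    /// ```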
    pub fn set_token_count(&mut self, token: &str, count: u64) -> &mut Self {
        if count == 0 {
            // Removing an entry must also subtract its count from the total.
            if let Some(removed) = self.token_count.remove(token) {
                self.total_token_count -= removed;
            }
        } else {
            let current_count = self.token_count.entry(token.to_string()).or_insert(0);
            // Subtract the old count before adding the new one so a shrinking
            // count cannot underflow the unsigned total.
            self.total_token_count = self.total_token_count - *current_count + count;
            *current_count = count;
        }
        self
    }

    /// Merges all counts from `other` into `self`.
    pub fn add_tokens_from_freq(&mut self, other: &TokenFrequency) -> &mut Self {
        for (token, &count) in &other.token_count {
            let entry = self.token_count.entry(token.clone()).or_insert(0);
            *entry += count;
            self.total_token_count += count;
        }
        self
    }

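    /// Multiplies every count by `scalar`, rounding to the nearest integer;
    /// counts that round to zero are dropped and the total is recomputed.
    ///
    /// A small sketch of the rounding behaviour (`ignore`d, crate-root
    /// re-export assumed as above):
    ///
    /// ```ignore
    /// let mut tf = TokenFrequency::new();
    /// tf.set_token_count("x", 10);
    /// tf.set_token_count("y", 1);
    /// tf.scale(0.4);
    /// assert_eq!(tf.token_count("x"), 4);  // 10 * 0.4 = 4
    /// assert_eq!(tf.token_count("y"), 0);  // 0.4 rounds to 0, entry dropped
    /// assert_eq!(tf.token_sum(), 4);
    /// ```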
    pub fn scale(&mut self, scalar: f64) -> &mut Self {
        let mut total_count = 0;
        // Rescale each count, dropping entries that round down to zero so the
        // map never holds zero-count tokens.
        self.token_count.retain(|_, count| {
            *count = ((*count as f64) * scalar).round() as u64;
            total_count += *count;
            *count > 0
        });
        self.total_token_count = total_count;
        self
    }
}

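/// Builds a frequency table directly from a token slice.
///
/// Sketch (`ignore`d, crate-root re-export assumed):
///
/// ```ignore
/// let tf = TokenFrequency::from(["a", "b", "a"].as_slice());
/// assert_eq!(tf.token_count("a"), 2);
/// ```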
impl<T> From<&[T]> for TokenFrequency
where
    T: AsRef<str>,
{
    fn from(tokens: &[T]) -> Self {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(tokens);
        tf
    }
}

impl From<Corpus> for TokenFrequency {
    fn from(corpus: Corpus) -> Self {
        let mut tf = TokenFrequency::new();
        for entry in corpus.token_counts.iter() {
            let token = entry.key();
            let count = *entry.value();
            tf.set_token_count(token, count);
        }
        tf
    }
}

impl TokenFrequency {
    /// Returns `(token, count)` pairs as owned strings.
    #[inline]
    pub fn token_count_vector(&self) -> Vec<(String, u64)> {
        self.token_count.iter().map(|(token, &count)| {
            (token.clone(), count)
        }).collect()
    }

    /// Returns `(token, count)` pairs borrowing the stored strings.
    #[inline]
    pub fn token_count_vector_ref_str(&self) -> Vec<(&str, u64)> {
        self.token_count.iter().map(|(token, &count)| {
            (token.as_str(), count)
        }).collect()
    }

    /// Returns the counts as a map keyed by borrowed strings.
    #[inline]
    pub fn token_count_hashmap_ref_str(&self) -> HashMap<&str, u64, RandomState> {
        self.token_count.iter().map(|(token, &count)| {
            (token.as_str(), count)
        }).collect()
    }

    /// Returns the total number of token occurrences.
    #[inline]
    pub fn token_sum(&self) -> u64 {
        self.total_token_count
    }

    /// Returns the count for `token`, or 0 if it is unknown.
    #[inline]
    pub fn token_count(&self, token: &str) -> u64 {
        *self.token_count.get(token).unwrap_or(&0)
    }

    /// Returns every token that is tied for the highest count.
    #[inline]
    pub fn most_frequent_tokens_vector(&self) -> Vec<(String, u64)> {
        if let Some(&max_count) = self.token_count.values().max() {
            self.token_count.iter()
                .filter(|&(_, &count)| count == max_count)
                .map(|(token, &count)| (token.clone(), count))
                .collect()
        } else {
            Vec::new()
        }
    }

    /// Returns the highest single-token count, or 0 when empty.
    #[inline]
    pub fn most_frequent_token_count(&self) -> u64 {
        self.token_count.values().max().copied().unwrap_or(0)
    }

    /// Returns `true` if `token` has been recorded.
    #[inline]
    pub fn contains_token(&self, token: &str) -> bool {
        self.token_count.contains_key(token)
    }

    /// Returns the distinct tokens as owned strings.
    #[inline]
    pub fn token_set(&self) -> Vec<String> {
        self.token_count.keys().cloned().collect()
    }

    /// Returns the distinct tokens as borrowed strings.
    #[inline]
    pub fn token_set_ref_str(&self) -> Vec<&str> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    /// Returns the distinct tokens as an owned set.
    #[inline]
    pub fn token_hashset(&self) -> HashSet<String, RandomState> {
        self.token_count.keys().cloned().collect()
    }

    /// Returns the distinct tokens as a set of borrowed strings.
    #[inline]
    pub fn token_hashset_ref_str(&self) -> HashSet<&str, RandomState> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    /// Returns the number of distinct tokens.
    #[inline]
    pub fn token_num(&self) -> usize {
        self.token_count.len()
    }

    /// Removes every listed stop token and returns how many occurrences
    /// were dropped.
    #[inline]
    pub fn remove_stop_tokens(&mut self, stop_tokens: &[&str]) -> u64 {
        let mut removed_total_count: u64 = 0;
        for &stop_token in stop_tokens {
            if let Some(count) = self.token_count.remove(stop_token) {
                removed_total_count += count;
            }
        }
        self.total_token_count -= removed_total_count;
        removed_total_count
    }

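    /// Removes every token for which `condition` returns `true` and returns
    /// the total number of occurrences removed.
    ///
    /// A minimal sketch (`ignore`d, crate-root re-export assumed), here
    /// dropping tokens that occur only once:
    ///
    /// ```ignore
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "b", "rare"]);
    /// let removed = tf.remove_tokens_by(|_token, &count| count == 1);
    /// assert_eq!(removed, 2);          // "b" and "rare"
    /// assert_eq!(tf.token_sum(), 2);   // only the two "a"s remain
    /// ```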
    #[inline]
    pub fn remove_tokens_by<F>(&mut self, condition: F) -> u64
    where
        F: Fn(&str, &u64) -> bool,
    {
        let mut removed_total_count: u64 = 0;
        self.token_count.retain(|token, count| {
            if condition(token, count) {
                removed_total_count += *count;
                false
            } else {
                true
            }
        });
        self.total_token_count -= removed_total_count;

        removed_total_count
    }

    /// Returns `(token, count)` pairs sorted by descending count.
    #[inline]
    pub fn sorted_frequency_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| b.1.cmp(&a.1));
        token_list
    }

    /// Returns `(token, count)` pairs sorted lexicographically by token.
    #[inline]
    pub fn sorted_dict_order_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| a.0.cmp(&b.0));
        token_list
    }

    /// Returns distinct tokens divided by total occurrences (the type-token
    /// ratio), or 0.0 when the table is empty.
    #[inline]
    pub fn unique_token_ratio(&self) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        self.token_count.len() as f64 / self.total_token_count as f64
    }

    /// Returns each token's relative frequency as owned strings.
    #[inline]
    pub fn probability_vector(&self) -> Vec<(String, f64)> {
        if self.total_token_count == 0 {
            return Vec::new();
        }
        let total = self.total_token_count as f64;
        self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), (count as f64) / total))
            .collect()
    }

    /// Returns each token's relative frequency as borrowed strings.
    #[inline]
    pub fn probability_vector_ref_str(&self) -> Vec<(&str, f64)> {
        if self.total_token_count == 0 {
            return Vec::new();
        }
        let total = self.total_token_count as f64;
        self.token_count
            .iter()
            .map(|(token, &count)| (token.as_str(), (count as f64) / total))
            .collect()
    }

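    /// Returns `token`'s relative frequency, i.e. `count / total`, or 0.0
    /// for an unknown token or an empty table.
    ///
    /// Sketch (`ignore`d, crate-root re-export assumed):
    ///
    /// ```ignore
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "b", "c"]);
    /// assert!((tf.probability("a") - 0.5).abs() < 1e-12);
    /// assert_eq!(tf.probability("zzz"), 0.0);
    /// ```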
    #[inline]
    pub fn probability(&self, token: &str) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        (self.token_count(token) as f64) / (self.total_token_count as f64)
    }

    /// Removes all tokens and resets the total to zero.
    #[inline]
    pub fn clear(&mut self) {
        self.token_count.clear();
        self.total_token_count = 0;
    }

    /// Releases excess capacity held by the underlying map.
    #[inline]
    pub fn shrink_to_fit(&mut self) {
        self.token_count.shrink_to_fit();
    }
}
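
// A small self-check sketch, not part of the original file: it exercises the
// bookkeeping invariant (total_token_count always equals the sum of all
// per-token counts) across the mutating methods above.
#[cfg(test)]
mod invariant_tests {
    use super::TokenFrequency;

    fn total_matches(tf: &TokenFrequency) -> bool {
        let summed: u64 = tf.token_count_vector().iter().map(|(_, c)| *c).sum();
        summed == tf.token_sum()
    }

    #[test]
    fn mutations_keep_total_consistent() {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(&["a", "a", "b", "c"]);
        assert!(total_matches(&tf));

        tf.sub_token("a");
        tf.set_token_count("b", 7);
        tf.set_token_count("c", 0); // exercises the removal path
        assert!(total_matches(&tf));

        let removed = tf.remove_stop_tokens(&["a"]);
        assert_eq!(removed, 1);
        assert!(total_matches(&tf));

        tf.scale(0.5);
        assert!(total_matches(&tf));
    }
}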