tf_idf_vectorizer/vectorizer/token.rs
use core::str;
use std::{collections::{HashMap, HashSet}, fmt::Debug};
use ahash::RandomState;
use serde::{Deserialize, Serialize};

/// Per-token occurrence counts together with the running total of all tokens added.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TokenFrequency {
    token_count: HashMap<String, u64, RandomState>,
    total_token_count: u64,
}

impl TokenFrequency {
    /// Creates an empty frequency table.
    pub fn new() -> Self {
        TokenFrequency {
            token_count: HashMap::with_hasher(RandomState::new()),
            total_token_count: 0,
        }
    }

    /// Records one occurrence of `token`.
    #[inline]
    pub fn add_token(&mut self, token: &str) -> &mut Self {
        let count = self.token_count.entry(token.to_string()).or_insert(0);
        *count += 1;
        self.total_token_count += 1;
        self
    }

    /// Records one occurrence of each token in `tokens`.
    #[inline]
    pub fn add_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where T: AsRef<str>
    {
        for token in tokens {
            let token_str = token.as_ref();
            self.add_token(token_str);
        }
        self
    }

    /// Removes one occurrence of `token`; the entry is dropped once its count
    /// reaches zero. Absent tokens are left untouched.
    #[inline]
    pub fn sub_token(&mut self, token: &str) -> &mut Self {
        if let Some(count) = self.token_count.get_mut(token) {
            if *count > 1 {
                *count -= 1;
            } else {
                self.token_count.remove(token);
            }
            self.total_token_count -= 1;
        }
        self
    }

    /// Removes one occurrence of each token in `tokens`.
    #[inline]
    pub fn sub_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where T: AsRef<str>
    {
        for token in tokens {
            let token_str = token.as_ref();
            self.sub_token(token_str);
        }
        self
    }

    /// Sets the count of `token` to `count`, keeping the running total consistent.
    /// A count of zero removes the token.
    pub fn set_token_count(&mut self, token: &str, count: u64) -> &mut Self {
        if count == 0 {
            // Bug fix: subtract the removed occurrences from the total as well.
            if let Some(old_count) = self.token_count.remove(token) {
                self.total_token_count -= old_count;
            }
        } else {
            let current_count = self.token_count.entry(token.to_string()).or_insert(0);
            // Add before subtracting so lowering a count cannot underflow the u64 total.
            self.total_token_count = self.total_token_count + count - *current_count;
            *current_count = count;
        }
        self
    }

    /// Merges all counts from `other` into `self`.
    pub fn add_tokens_from_freq(&mut self, other: &TokenFrequency) -> &mut Self {
        for (token, &count) in &other.token_count {
            let entry = self.token_count.entry(token.clone()).or_insert(0);
            *entry += count;
            self.total_token_count += count;
        }
        self
    }

    /// Multiplies every count by `scalar` (expected non-negative), rounding to
    /// the nearest integer, and recomputes the total.
    pub fn scale(&mut self, scalar: f64) -> &mut Self {
        let mut total_count = 0;
        // Bug fix: drop entries whose scaled count rounds down to zero, so the
        // map never holds zero-count tokens.
        self.token_count.retain(|_, count| {
            *count = ((*count as f64) * scalar).round() as u64;
            total_count += *count;
            *count > 0
        });
        self.total_token_count = total_count;
        self
    }
}
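
// --- Illustrative tests (not from the original file) ---
// A minimal sketch exercising the mutation API above; the module and test
// names are hypothetical, and the expected totals follow from the calls shown.
#[cfg(test)]
mod mutation_api_sketch {
    use super::TokenFrequency;

    #[test]
    fn add_sub_and_set_keep_the_total_consistent() {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(&["rust", "rust", "tf", "idf"]);
        assert_eq!(tf.token_sum(), 4);
        assert_eq!(tf.token_count("rust"), 2);

        // Subtracting the last occurrence drops the entry entirely.
        tf.sub_token("tf");
        assert!(!tf.contains_token("tf"));
        assert_eq!(tf.token_sum(), 3);

        // set_token_count rebalances the running total in both directions,
        // and a count of zero removes the token.
        tf.set_token_count("rust", 5);
        assert_eq!(tf.token_sum(), 6);
        tf.set_token_count("rust", 0);
        assert_eq!(tf.token_sum(), 1);
    }
}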

impl TokenFrequency {
    /// Returns `(token, count)` pairs with owned `String` keys.
    #[inline]
    pub fn token_count_vector(&self) -> Vec<(String, u64)> {
        self.token_count.iter().map(|(token, &count)| {
            (token.clone(), count)
        }).collect()
    }

    /// Returns `(token, count)` pairs borrowing the keys as `&str`.
    #[inline]
    pub fn token_count_vector_ref_str(&self) -> Vec<(&str, u64)> {
        self.token_count.iter().map(|(token, &count)| {
            (token.as_str(), count)
        }).collect()
    }

    /// Returns a borrowed view of the counts as a `HashMap<&str, u64>`.
    #[inline]
    pub fn token_count_hashmap_ref_str(&self) -> HashMap<&str, u64, RandomState> {
        self.token_count.iter().map(|(token, &count)| {
            (token.as_str(), count)
        }).collect()
    }

    /// Total number of token occurrences recorded so far.
    #[inline]
    pub fn token_sum(&self) -> u64 {
        self.total_token_count
    }

    /// Count for a single token, or 0 if it has not been seen.
    #[inline]
    pub fn token_count(&self, token: &str) -> u64 {
        *self.token_count.get(token).unwrap_or(&0)
    }

    /// All tokens tied for the highest count, paired with that count.
    #[inline]
    pub fn most_frequent_tokens_vector(&self) -> Vec<(String, u64)> {
        if let Some(&max_count) = self.token_count.values().max() {
            self.token_count.iter()
                .filter(|&(_, &count)| count == max_count)
                .map(|(token, &count)| (token.clone(), count))
                .collect()
        } else {
            Vec::new()
        }
    }

    /// The highest single-token count, or 0 when the table is empty.
    #[inline]
    pub fn most_frequent_token_count(&self) -> u64 {
        self.token_count.values().max().copied().unwrap_or(0)
    }

    /// Whether `token` has a non-zero count.
    #[inline]
    pub fn contains_token(&self, token: &str) -> bool {
        self.token_count.contains_key(token)
    }

    /// All distinct tokens as owned `String`s (arbitrary order).
    #[inline]
    pub fn token_set(&self) -> Vec<String> {
        self.token_count.keys().cloned().collect()
    }

    /// All distinct tokens as borrowed `&str`s (arbitrary order).
    #[inline]
    pub fn token_set_ref_str(&self) -> Vec<&str> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    /// All distinct tokens as an owned `HashSet`.
    #[inline]
    pub fn token_hashset(&self) -> HashSet<String, RandomState> {
        self.token_count.keys().cloned().collect()
    }

    /// All distinct tokens as a borrowed `HashSet<&str>`.
    #[inline]
    pub fn token_hashset_ref_str(&self) -> HashSet<&str, RandomState> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    /// Number of distinct tokens.
    #[inline]
    pub fn token_num(&self) -> usize {
        self.token_count.len()
    }

    /// Removes every listed stop token, returning how many occurrences were dropped.
    #[inline]
    pub fn remove_stop_tokens(&mut self, stop_tokens: &[&str]) -> u64 {
        let mut removed_total_count: u64 = 0;
        for &stop_token in stop_tokens {
            if let Some(count) = self.token_count.remove(stop_token) {
                removed_total_count += count;
            }
        }
        self.total_token_count -= removed_total_count;
        removed_total_count
    }

    /// Removes every token for which `condition(token, count)` returns `true`,
    /// returning how many occurrences were dropped.
    #[inline]
    pub fn remove_tokens_by<F>(&mut self, condition: F) -> u64
    where
        F: Fn(&str, &u64) -> bool,
    {
        let mut removed_total_count: u64 = 0;
        self.token_count.retain(|token, count| {
            if condition(token, count) {
                removed_total_count += *count;
                false
            } else {
                true
            }
        });
        self.total_token_count -= removed_total_count;

        removed_total_count
    }

    /// `(token, count)` pairs sorted by count, descending.
    #[inline]
    pub fn sorted_frequency_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| b.1.cmp(&a.1));
        token_list
    }

    /// `(token, count)` pairs sorted lexicographically by token.
    #[inline]
    pub fn sorted_dict_order_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| a.0.cmp(&b.0));
        token_list
    }

    /// Ratio of distinct tokens to total occurrences (0.0 when empty).
    #[inline]
    pub fn unique_token_ratio(&self) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        self.token_count.len() as f64 / self.total_token_count as f64
    }

    /// `(token, probability)` pairs, where each probability is count / total.
    #[inline]
    pub fn probability_vector(&self) -> Vec<(String, f64)> {
        if self.total_token_count == 0 {
            return Vec::new();
        }
        let total = self.total_token_count as f64;
        self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), (count as f64) / total))
            .collect()
    }

    /// Borrowed variant of [`probability_vector`](Self::probability_vector).
    #[inline]
    pub fn probability_vector_ref_str(&self) -> Vec<(&str, f64)> {
        if self.total_token_count == 0 {
            return Vec::new();
        }
        let total = self.total_token_count as f64;
        self.token_count
            .iter()
            .map(|(token, &count)| (token.as_str(), (count as f64) / total))
            .collect()
    }

    /// Probability of `token` (count / total), or 0.0 when the table is empty.
    #[inline]
    pub fn probability(&self, token: &str) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        (self.token_count(token) as f64) / (self.total_token_count as f64)
    }

    /// Removes all tokens and resets the total to zero.
    #[inline]
    pub fn clear(&mut self) {
        self.token_count.clear();
        self.total_token_count = 0;
    }

    /// Releases excess capacity held by the underlying map.
    #[inline]
    pub fn shrink_to_fit(&mut self) {
        self.token_count.shrink_to_fit();
    }
}
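
// --- Illustrative tests (not from the original file) ---
// A minimal sketch of the query API: probabilities, stop-token removal, and
// sorting. The serde round-trip assumes `serde_json` is available as a
// dev-dependency, which this file alone does not guarantee.
#[cfg(test)]
mod query_api_sketch {
    use super::TokenFrequency;

    #[test]
    fn probabilities_stop_tokens_and_sorting() {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(&["the", "the", "cat", "sat"]);

        // probability = count / total (here 2 / 4).
        assert!((tf.probability("the") - 0.5).abs() < f64::EPSILON);

        // remove_stop_tokens reports how many occurrences were dropped
        // and shrinks the running total accordingly.
        assert_eq!(tf.remove_stop_tokens(&["the"]), 2);
        assert_eq!(tf.token_sum(), 2);
        assert_eq!(tf.unique_token_ratio(), 1.0);

        // sorted_frequency_vector orders pairs by count, descending.
        let sorted = tf.sorted_frequency_vector();
        assert!(sorted.windows(2).all(|w| w[0].1 >= w[1].1));
    }

    #[test]
    fn serde_round_trip() {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(&["a", "b", "b"]);
        let json = serde_json::to_string(&tf).unwrap();
        let restored: TokenFrequency = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.token_sum(), 3);
        assert_eq!(restored.token_count("b"), 2);
    }
}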