// tf_idf_vectorizer/vectorizer/token.rs
use core::str;
use std::{collections::{HashMap, HashSet}, fmt::Debug};
use ahash::RandomState;
use serde::{Deserialize, Serialize};


7#[derive(Serialize, Deserialize, Debug, Clone)]
22pub struct TokenFrequency {
23 token_count: HashMap<String, u64, RandomState>,
24 total_token_count: u64,
25}
26
27impl TokenFrequency {
29 pub fn new() -> Self {
31 TokenFrequency {
32 token_count: HashMap::with_hasher(RandomState::new()),
33 total_token_count: 0,
34 }
35 }
36
37 #[inline]
42 pub fn add_token(&mut self, token: &str) -> &mut Self {
43 let count = self.token_count.entry(token.to_string()).or_insert(0);
44 *count += 1;
45 self.total_token_count += 1;
46 self
47 }
48
49 #[inline]
54 pub fn add_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
55 where T: AsRef<str>
56 {
57 for token in tokens {
58 let token_str = token.as_ref();
59 self.add_token(token_str);
60 }
61 self
62 }
63
64 #[inline]
69 pub fn sub_token(&mut self, token: &str) -> &mut Self {
70 if let Some(count) = self.token_count.get_mut(token) {
71 if *count > 1 {
72 *count -= 1;
73 self.total_token_count -= 1;
74 } else if *count == 1 {
75 self.token_count.remove(token);
76 self.total_token_count -= 1;
77 }
78 }
79 self
80 }
81
82 #[inline]
87 pub fn sub_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
88 where T: AsRef<str>
89 {
90 for token in tokens {
91 let token_str = token.as_ref();
92 self.sub_token(token_str);
93 }
94 self
95 }
96
97 pub fn set_token_count(&mut self, token: &str, count: u64) -> &mut Self {
103 if count == 0 {
104 self.token_count.remove(token);
105 } else {
106 let current_count = self.token_count.entry(token.to_string()).or_insert(0);
107 self.total_token_count += count - *current_count;
108 *current_count = count;
109 }
110 self
111 }
112
113 pub fn add_tokens_from_freq(&mut self, other: &TokenFrequency) -> &mut Self {
117 for (token, &count) in &other.token_count {
118 let entry = self.token_count.entry(token.clone()).or_insert(0);
119 *entry += count;
120 self.total_token_count += count;
121 }
122 self
123 }
124
125 pub fn scale(&mut self, scalar: f64) -> &mut Self {
129 let mut total_count = 0;
130 self.token_count.iter_mut().for_each(|(_, count)| {
131 *count = ((*count as f64) * scalar).round() as u64;
132 total_count += *count;
133 });
134 self.total_token_count = total_count;
135 self
136 }
137}
138
139impl<T> From<&[T]> for TokenFrequency
140where
141 T: AsRef<str>,
142{
143 fn from(tokens: &[T]) -> Self {
144 let mut tf = TokenFrequency::new();
145 tf.add_tokens(tokens);
146 tf
147 }
148}
149
150impl TokenFrequency {
152 #[inline]
157 pub fn token_count_vector(&self) -> Vec<(String, u64)> {
158 self.token_count.iter().map(|(token, &count)| {
159 (token.clone(), count)
160 }).collect()
161 }
162
163 #[inline]
168 pub fn token_count_vector_ref_str(&self) -> Vec<(&str, u64)> {
169 self.token_count.iter().map(|(token, &count)| {
170 (token.as_str(), count)
171 }).collect()
172 }
173
174 #[inline]
179 pub fn token_count_hashmap_ref_str(&self) -> HashMap<&str, u64, RandomState> {
180 self.token_count.iter().map(|(token, &count)| {
181 (token.as_str(), count)
182 }).collect()
183 }
184
185 #[inline]
190 pub fn token_sum(&self) -> u64 {
191 self.total_token_count
192 }
193
194 #[inline]
202 pub fn token_count(&self, token: &str) -> u64 {
203 *self.token_count.get(token).unwrap_or(&0)
204 }
205
206 #[inline]
212 pub fn most_frequent_tokens_vector(&self) -> Vec<(String, u64)> {
213 if let Some(&max_count) = self.token_count.values().max() {
214 self.token_count.iter()
215 .filter(|&(_, &count)| count == max_count)
216 .map(|(token, &count)| (token.clone(), count))
217 .collect()
218 } else {
219 Vec::new()
220 }
221 }
222
223 #[inline]
228 pub fn most_frequent_token_count(&self) -> u64 {
229 if let Some(&max_count) = self.token_count.values().max() {
230 max_count
231 } else {
232 0
233 }
234 }
235
236 #[inline]
244 pub fn contains_token(&self, token: &str) -> bool {
245 self.token_count.contains_key(token)
246 }
247
248 #[inline]
253 pub fn token_set(&self) -> Vec<String> {
254 self.token_count.keys().cloned().collect()
255 }
256
257 #[inline]
262 pub fn token_set_ref_str(&self) -> Vec<&str> {
263 self.token_count.keys().map(|s| s.as_str()).collect()
264 }
265
266 #[inline]
271 pub fn token_hashset(&self) -> HashSet<String, RandomState> {
272 self.token_count.keys().cloned().collect()
273 }
274
275 #[inline]
280 pub fn token_hashset_ref_str(&self) -> HashSet<&str, RandomState> {
281 self.token_count.keys().map(|s| s.as_str()).collect()
282 }
283
284 #[inline]
289 pub fn token_num(&self) -> usize {
290 self.token_count.len()
291 }
292
293 #[inline]
301 pub fn remove_stop_tokens(&mut self, stop_tokens: &[&str]) -> u64{
302 let mut removed_total_count: u64 = 0;
303 for &stop_token in stop_tokens {
304 if let Some(count) = self.token_count.remove(stop_token) {
305 removed_total_count += count as u64;
306 }
307 }
308 self.total_token_count -= removed_total_count;
309 removed_total_count
310 }
311
312 #[inline]
320 pub fn remove_tokens_by<F>(&mut self, condition: F) -> u64
321 where
322 F: Fn(&str, &u64) -> bool,
323 {
324 let mut removed_total_count: u64 = 0;
325 self.token_count.retain(|token, count| {
326 if condition(token, count) {
327 removed_total_count += *count as u64;
328 false
329 } else {
330 true
331 }
332 });
333 self.total_token_count -= removed_total_count as u64;
334
335 removed_total_count
336 }
337
338 #[inline]
343 pub fn sorted_frequency_vector(&self) -> Vec<(String, u64)> {
344 let mut token_list: Vec<(String, u64)> = self.token_count
345 .iter()
346 .map(|(token, &count)| (token.clone(), count))
347 .collect();
348
349 token_list.sort_by(|a, b| b.1.cmp(&a.1));
350 token_list
351 }
352
353 #[inline]
358 pub fn sorted_dict_order_vector(&self) -> Vec<(String, u64)> {
359 let mut token_list: Vec<(String, u64)> = self.token_count
360 .iter()
361 .map(|(token, &count)| (token.clone(), count))
362 .collect();
363
364 token_list.sort_by(|a, b| a.0.cmp(&b.0));
365 token_list
366 }
367
368 #[inline]
374 pub fn unique_token_ratio(&self) -> f64 {
375 if self.total_token_count == 0 {
376 return 0.0;
377 }
378 self.token_count.len() as f64 / self.total_token_count as f64
379 }
380
381 #[inline]
384 pub fn probability_vector(&self) -> Vec<(String, f64)> {
385 if self.total_token_count == 0 {
386 return Vec::new();
387 }
388 let total = self.total_token_count as f64;
389 self.token_count
390 .iter()
391 .map(|(token, &count)| (token.clone(), (count as f64) / total))
392 .collect()
393 }
394
395 #[inline]
398 pub fn probability_vector_ref_str(&self) -> Vec<(&str, f64)> {
399 if self.total_token_count == 0 {
400 return Vec::new();
401 }
402 let total = self.total_token_count as f64;
403 self.token_count
404 .iter()
405 .map(|(token, &count)| (token.as_str(), (count as f64) / total))
406 .collect()
407 }
408
409 #[inline]
412 pub fn probability(&self, token: &str) -> f64 {
413 if self.total_token_count == 0 {
414 return 0.0;
415 }
416 (self.token_count(token) as f64) / (self.total_token_count as f64)
417 }
418
419 #[inline]
421 pub fn clear(&mut self) {
422 self.token_count.clear();
423 self.total_token_count = 0;
424 }
425
426 #[inline]
428 pub fn shrink_to_fit(&mut self) {
429 self.token_count.shrink_to_fit();
430 }
431}