tf_idf_vectorizer/vectorizer/token.rs

use std::collections::{HashMap, HashSet};
use ahash::RandomState;
use serde::{Deserialize, Serialize};

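/// Frequency table of tokens: per-token occurrence counts plus a running
/// total of all tokens added.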
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TokenFrequency {
    token_count: HashMap<String, u64, RandomState>,
    total_token_count: u64,
}

impl TokenFrequency {
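    /// Creates an empty `TokenFrequency`.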
    pub fn new() -> Self {
        TokenFrequency {
            token_count: HashMap::with_hasher(RandomState::new()),
            total_token_count: 0,
        }
    }

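    /// Records one occurrence of `token`.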
    #[inline]
    pub fn add_token(&mut self, token: &str) -> &mut Self {
        let count = self.token_count.entry(token.to_string()).or_insert(0);
        *count += 1;
        self.total_token_count += 1;
        self
    }

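    /// Records one occurrence of every token in `tokens`.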
    #[inline]
    pub fn add_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where
        T: AsRef<str>,
    {
        for token in tokens {
            self.add_token(token.as_ref());
        }
        self
    }

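    /// Removes one occurrence of `token`, dropping the entry when its count
    /// reaches zero. Does nothing if the token is absent.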
    #[inline]
    pub fn sub_token(&mut self, token: &str) -> &mut Self {
        if let Some(count) = self.token_count.get_mut(token) {
            if *count > 1 {
                *count -= 1;
            } else {
                // Last occurrence: remove the entry entirely.
                self.token_count.remove(token);
            }
            self.total_token_count -= 1;
        }
        self
    }

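    /// Removes one occurrence of every token in `tokens`.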
    #[inline]
    pub fn sub_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where
        T: AsRef<str>,
    {
        for token in tokens {
            self.sub_token(token.as_ref());
        }
        self
    }

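    /// Sets the count of `token` to `count`, adjusting the running total.
    /// A count of zero removes the token entirely.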
    pub fn set_token_count(&mut self, token: &str, count: u64) -> &mut Self {
        if count == 0 {
            // Removing a token must also subtract its count from the total.
            if let Some(old_count) = self.token_count.remove(token) {
                self.total_token_count -= old_count;
            }
        } else {
            let current_count = self.token_count.entry(token.to_string()).or_insert(0);
            // Add before subtracting so the unsigned arithmetic cannot
            // underflow when the new count is smaller than the current one.
            self.total_token_count = self.total_token_count + count - *current_count;
            *current_count = count;
        }
        self
    }
}

impl TokenFrequency {
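    /// Returns `(token, count)` pairs as owned strings, in arbitrary order.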
    #[inline]
    pub fn token_count_vector(&self) -> Vec<(String, u64)> {
        self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect()
    }

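    /// Returns `(token, count)` pairs borrowing the tokens as `&str`.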
    #[inline]
    pub fn token_count_vector_ref_str(&self) -> Vec<(&str, u64)> {
        self.token_count
            .iter()
            .map(|(token, &count)| (token.as_str(), count))
            .collect()
    }

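    /// Returns the counts as a `HashMap` keyed by borrowed `&str` tokens.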
    #[inline]
    pub fn token_count_hashmap_ref_str(&self) -> HashMap<&str, u64, RandomState> {
        self.token_count
            .iter()
            .map(|(token, &count)| (token.as_str(), count))
            .collect()
    }

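    /// Total number of tokens added (the sum of all counts).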
    #[inline]
    pub fn token_sum(&self) -> u64 {
        self.total_token_count
    }

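    /// Count of a single token, or 0 if it has not been seen.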
    #[inline]
    pub fn token_count(&self, token: &str) -> u64 {
        self.token_count.get(token).copied().unwrap_or(0)
    }

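    /// All tokens tied for the highest count, together with that count.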
    #[inline]
    pub fn most_frequent_tokens_vector(&self) -> Vec<(String, u64)> {
        if let Some(&max_count) = self.token_count.values().max() {
            self.token_count
                .iter()
                .filter(|&(_, &count)| count == max_count)
                .map(|(token, &count)| (token.clone(), count))
                .collect()
        } else {
            Vec::new()
        }
    }

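    /// The highest count of any token, or 0 if the table is empty.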
    #[inline]
    pub fn most_frequent_token_count(&self) -> u64 {
        self.token_count.values().copied().max().unwrap_or(0)
    }

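    /// Whether `token` currently has a nonzero count.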
    #[inline]
    pub fn contains_token(&self, token: &str) -> bool {
        self.token_count.contains_key(token)
    }

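    /// The distinct tokens as owned strings, in arbitrary order.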
    #[inline]
    pub fn token_set(&self) -> Vec<String> {
        self.token_count.keys().cloned().collect()
    }

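    /// The distinct tokens as borrowed `&str` slices.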
    #[inline]
    pub fn token_set_ref_str(&self) -> Vec<&str> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

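    /// The distinct tokens as a `HashSet` of owned strings.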
    #[inline]
    pub fn token_hashset(&self) -> HashSet<String, RandomState> {
        self.token_count.keys().cloned().collect()
    }

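    /// The distinct tokens as a `HashSet` of borrowed `&str` slices.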
    #[inline]
    pub fn token_hashset_ref_str(&self) -> HashSet<&str, RandomState> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

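    /// Number of distinct tokens.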
    #[inline]
    pub fn token_num(&self) -> usize {
        self.token_count.len()
    }

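    /// Removes every listed stop token and returns the total number of
    /// occurrences that were dropped.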
    #[inline]
    pub fn remove_stop_tokens(&mut self, stop_tokens: &[&str]) -> u64 {
        let mut removed_total_count: u64 = 0;
        for &stop_token in stop_tokens {
            if let Some(count) = self.token_count.remove(stop_token) {
                removed_total_count += count;
            }
        }
        self.total_token_count -= removed_total_count;
        removed_total_count
    }

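    /// Removes every token for which `condition` returns `true` and returns
    /// the total number of occurrences that were dropped.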
    #[inline]
    pub fn remove_tokens_by<F>(&mut self, condition: F) -> u64
    where
        F: Fn(&str, &u64) -> bool,
    {
        let mut removed_total_count: u64 = 0;
        self.token_count.retain(|token, count| {
            if condition(token, count) {
                removed_total_count += *count;
                false
            } else {
                true
            }
        });
        self.total_token_count -= removed_total_count;

        removed_total_count
    }

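    /// `(token, count)` pairs sorted by descending count.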
    #[inline]
    pub fn sorted_frequency_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| b.1.cmp(&a.1));
        token_list
    }

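    /// `(token, count)` pairs sorted lexicographically by token.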
    #[inline]
    pub fn sorted_dict_order_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| a.0.cmp(&b.0));
        token_list
    }

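    /// Ratio of distinct tokens to total tokens (type-token ratio);
    /// 0.0 when the table is empty.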
    #[inline]
    pub fn unique_token_ratio(&self) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        self.token_count.len() as f64 / self.total_token_count as f64
    }

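    /// `(token, probability)` pairs, where each probability is the token's
    /// count divided by the total token count.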
    #[inline]
    pub fn probability_vector(&self) -> Vec<(String, f64)> {
        if self.total_token_count == 0 {
            return Vec::new();
        }
        let total = self.total_token_count as f64;
        self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count as f64 / total))
            .collect()
    }

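    /// Borrowed-`&str` variant of [`Self::probability_vector`].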
    #[inline]
    pub fn probability_vector_ref_str(&self) -> Vec<(&str, f64)> {
        if self.total_token_count == 0 {
            return Vec::new();
        }
        let total = self.total_token_count as f64;
        self.token_count
            .iter()
            .map(|(token, &count)| (token.as_str(), count as f64 / total))
            .collect()
    }

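    /// Empirical probability of `token`; 0.0 when the table is empty.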
    #[inline]
    pub fn probability(&self, token: &str) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        self.token_count(token) as f64 / self.total_token_count as f64
    }

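    /// Removes all tokens and resets the total count to zero.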
    #[inline]
    pub fn clear(&mut self) {
        self.token_count.clear();
        self.total_token_count = 0;
    }

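    /// Shrinks the underlying map's capacity to fit its current contents.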
    #[inline]
    pub fn shrink_to_fit(&mut self) {
        self.token_count.shrink_to_fit();
    }
}
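
// A minimal usage sketch (not part of the original file): it exercises the
// add/sub/set paths and checks that `total_token_count` stays consistent
// with the per-token counts. Assumes `ahash` and `serde` are already
// dependencies of this crate.
#[cfg(test)]
mod tests {
    use super::TokenFrequency;

    #[test]
    fn counts_stay_consistent() {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(&["a", "b", "b", "c"]);
        assert_eq!(tf.token_sum(), 4);
        assert_eq!(tf.token_count("b"), 2);

        // Removing one "b" leaves one occurrence behind.
        tf.sub_token("b");
        assert_eq!(tf.token_count("b"), 1);
        assert_eq!(tf.token_sum(), 3);

        // Setting a count to zero removes the token and shrinks the total.
        tf.set_token_count("a", 0);
        assert_eq!(tf.token_count("a"), 0);
        assert_eq!(tf.token_sum(), 2);
    }
}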