// tf_idf_vectorizer/vectorizer/token.rs

use core::str;
use std::{collections::{HashMap, HashSet}, fmt::Debug};

use ahash::RandomState;
use serde::{Deserialize, Serialize};

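/// A running frequency table over a token stream: per-token occurrence
/// counts plus the total number of occurrences, used as a building block
/// for TF-IDF style vectorization.
///
/// Illustrative sketch (marked `ignore` since the crate path may differ):
///
/// ```ignore
/// let mut tf = TokenFrequency::new();
/// tf.add_tokens(&["rust", "tf", "idf", "rust"]);
/// assert_eq!(tf.token_count("rust"), 2);
/// assert_eq!(tf.token_sum(), 4);
/// ```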
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TokenFrequency {
    /// Occurrence count per token.
    token_count: HashMap<String, u64, RandomState>,
    /// Sum of all counts in `token_count`.
    total_token_count: u64,
}

impl TokenFrequency {
    /// Creates an empty frequency table.
    pub fn new() -> Self {
        TokenFrequency {
            token_count: HashMap::with_hasher(RandomState::new()),
            total_token_count: 0,
        }
    }

    /// Records one occurrence of `token`.
    #[inline]
    pub fn add_token(&mut self, token: &str) -> &mut Self {
        let count = self.token_count.entry(token.to_string()).or_insert(0);
        *count += 1;
        self.total_token_count += 1;
        self
    }

    /// Records one occurrence of every token in `tokens`.
    #[inline]
    pub fn add_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where
        T: AsRef<str>,
    {
        for token in tokens {
            self.add_token(token.as_ref());
        }
        self
    }

    /// Removes one occurrence of `token`; the entry is dropped once its
    /// count reaches zero. Unknown tokens are ignored.
    #[inline]
    pub fn sub_token(&mut self, token: &str) -> &mut Self {
        if let Some(count) = self.token_count.get_mut(token) {
            if *count > 1 {
                *count -= 1;
                self.total_token_count -= 1;
            } else if *count == 1 {
                self.token_count.remove(token);
                self.total_token_count -= 1;
            }
        }
        self
    }

    /// Removes one occurrence of every token in `tokens`.
    #[inline]
    pub fn sub_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where
        T: AsRef<str>,
    {
        for token in tokens {
            self.sub_token(token.as_ref());
        }
        self
    }

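    /// Sets the count of `token` to exactly `count`, adjusting the running
    /// total by the difference; a count of 0 removes the token entirely.
    ///
    /// Illustrative sketch (variable name is hypothetical):
    ///
    /// ```ignore
    /// tf.set_token_count("rust", 10); // total rises or falls by the difference
    /// ```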
    pub fn set_token_count(&mut self, token: &str, count: u64) -> &mut Self {
        if count == 0 {
            if let Some(removed) = self.token_count.remove(token) {
                self.total_token_count -= removed;
            }
        } else {
            let current_count = self.token_count.entry(token.to_string()).or_insert(0);
            self.total_token_count = self.total_token_count - *current_count + count;
            *current_count = count;
        }
        self
    }
}

impl TokenFrequency {
    /// Returns `(token, count)` pairs with owned token strings.
    #[inline]
    pub fn token_count_vector(&self) -> Vec<(String, u64)> {
        self.token_count.iter().map(|(token, &count)| (token.clone(), count)).collect()
    }

    /// Returns `(token, count)` pairs borrowing the stored token strings.
    #[inline]
    pub fn token_count_vector_ref_str(&self) -> Vec<(&str, u64)> {
        self.token_count.iter().map(|(token, &count)| (token.as_str(), count)).collect()
    }

    /// Returns the counts as a map keyed by borrowed token strings.
    #[inline]
    pub fn token_count_hashmap_ref_str(&self) -> HashMap<&str, u64, RandomState> {
        self.token_count.iter().map(|(token, &count)| (token.as_str(), count)).collect()
    }

    /// Total number of token occurrences (the sum of all counts).
    #[inline]
    pub fn token_sum(&self) -> u64 {
        self.total_token_count
    }

    /// Count for a single token, or 0 if it has never been seen.
    #[inline]
    pub fn token_count(&self, token: &str) -> u64 {
        *self.token_count.get(token).unwrap_or(&0)
    }

    /// Returns every token that is tied for the highest count.
    #[inline]
    pub fn most_frequent_tokens_vector(&self) -> Vec<(String, u64)> {
        if let Some(&max_count) = self.token_count.values().max() {
            self.token_count.iter()
                .filter(|&(_, &count)| count == max_count)
                .map(|(token, &count)| (token.clone(), count))
                .collect()
        } else {
            Vec::new()
        }
    }

    /// Returns the highest single-token count, or 0 if the table is empty.
    #[inline]
    pub fn most_frequent_token_count(&self) -> u64 {
        if let Some(&max_count) = self.token_count.values().max() {
            max_count
        } else {
            0
        }
    }

    /// Returns true if `token` has been seen at least once.
    #[inline]
    pub fn contains_token(&self, token: &str) -> bool {
        self.token_count.contains_key(token)
    }

    /// Returns the distinct tokens as owned strings.
    #[inline]
    pub fn token_set(&self) -> Vec<String> {
        self.token_count.keys().cloned().collect()
    }

    /// Returns the distinct tokens as borrowed strings.
    #[inline]
    pub fn token_set_ref_str(&self) -> Vec<&str> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    /// Returns the distinct tokens as an owned `HashSet`.
    #[inline]
    pub fn token_hashset(&self) -> HashSet<String, RandomState> {
        self.token_count.keys().cloned().collect()
    }

    /// Returns the distinct tokens as a `HashSet` of borrowed strings.
    #[inline]
    pub fn token_hashset_ref_str(&self) -> HashSet<&str, RandomState> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    /// Number of distinct tokens (vocabulary size), as opposed to
    /// `token_sum`, which counts every occurrence.
    #[inline]
    pub fn token_num(&self) -> usize {
        self.token_count.len()
    }

    /// Removes every token in `stop_tokens` and returns how many
    /// occurrences were removed in total.
    #[inline]
    pub fn remove_stop_tokens(&mut self, stop_tokens: &[&str]) -> u64 {
        let mut removed_total_count: u64 = 0;
        for &stop_token in stop_tokens {
            if let Some(count) = self.token_count.remove(stop_token) {
                removed_total_count += count;
            }
        }
        self.total_token_count -= removed_total_count;
        removed_total_count
    }

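    /// Removes every token for which `condition(token, count)` returns true
    /// and returns how many occurrences were removed in total.
    ///
    /// Illustrative sketch of dropping tokens seen only once:
    ///
    /// ```ignore
    /// let removed = tf.remove_tokens_by(|_token, &count| count == 1);
    /// ```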
    #[inline]
    pub fn remove_tokens_by<F>(&mut self, condition: F) -> u64
    where
        F: Fn(&str, &u64) -> bool,
    {
        let mut removed_total_count: u64 = 0;
        self.token_count.retain(|token, count| {
            if condition(token, count) {
                removed_total_count += *count;
                false
            } else {
                true
            }
        });
        self.total_token_count -= removed_total_count;

        removed_total_count
    }

    /// Returns `(token, count)` pairs sorted by count, highest first.
    #[inline]
    pub fn sorted_frequency_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| b.1.cmp(&a.1));
        token_list
    }

    /// Returns `(token, count)` pairs sorted lexicographically by token.
    #[inline]
    pub fn sorted_dict_order_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| a.0.cmp(&b.0));
        token_list
    }

    /// Ratio of distinct tokens to total occurrences (type-token ratio);
    /// returns 0.0 for an empty table.
    #[inline]
    pub fn unique_token_ratio(&self) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        self.token_count.len() as f64 / self.total_token_count as f64
    }

    /// Removes all tokens and resets the total count to zero.
    #[inline]
    pub fn clear(&mut self) {
        self.token_count.clear();
        self.total_token_count = 0;
    }

    /// Shrinks the backing map to fit its current number of entries.
    #[inline]
    pub fn shrink_to_fit(&mut self) {
        self.token_count.shrink_to_fit();
    }
}
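
// A minimal usage sketch added for illustration (not part of the original
// file): it exercises the add / sub / set / remove paths above, with expected
// values derived only from the code in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn counts_stay_consistent() {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(&["a", "b", "a", "c", "a"]);
        assert_eq!(tf.token_count("a"), 3);
        assert_eq!(tf.token_sum(), 5);
        assert_eq!(tf.token_num(), 3);
        assert_eq!(tf.most_frequent_token_count(), 3);

        // Removing one occurrence keeps the running total in sync.
        tf.sub_token("a");
        assert_eq!(tf.token_count("a"), 2);
        assert_eq!(tf.token_sum(), 4);

        // Setting an explicit count adjusts the total by the difference.
        tf.set_token_count("b", 4);
        assert_eq!(tf.token_sum(), 7);

        // Dropping rare tokens via the predicate API reports how many
        // occurrences were removed ("c" is the only token with count < 2).
        let removed = tf.remove_tokens_by(|_, &count| count < 2);
        assert_eq!(removed, 1);
        assert!(!tf.contains_token("c"));
        assert_eq!(tf.token_sum(), 6);
    }
}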