// tf_idf_vectorizer/vectorizer/term.rs
use core::str;
use std::{collections::{HashMap, HashSet}, fmt::Debug};
use ahash::RandomState;
use serde::{Deserialize, Serialize};

use crate::Corpus;

/// Multiset of terms: maps each term to its occurrence count and keeps a
/// running total of all occurrences for O(1) `term_sum` queries.
/// Uses `ahash::RandomState` instead of the default SipHash for faster
/// hashing of string keys (keys are not attacker-controlled here).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TermFrequency {
    // Per-term occurrence counts.
    term_count: HashMap<String, u64, RandomState>,
    // Invariant: equals the sum of all values in `term_count`; every mutator
    // must keep this in sync.
    total_term_count: u64,
}
28
29impl TermFrequency {
31 pub fn new() -> Self {
33 TermFrequency {
34 term_count: HashMap::with_hasher(RandomState::new()),
35 total_term_count: 0,
36 }
37 }
38
39 #[inline]
44 pub fn add_term(&mut self, term: &str) -> &mut Self {
45 let count = self.term_count.entry(term.to_string()).or_insert(0);
46 *count += 1;
47 self.total_term_count += 1;
48 self
49 }
50
51 #[inline]
56 pub fn add_terms<T>(&mut self, terms: &[T]) -> &mut Self
57 where T: AsRef<str>
58 {
59 for term in terms {
60 let term_str = term.as_ref();
61 self.add_term(term_str);
62 }
63 self
64 }
65
66 #[inline]
71 pub fn sub_term(&mut self, term: &str) -> &mut Self {
72 if let Some(count) = self.term_count.get_mut(term) {
73 if *count > 1 {
74 *count -= 1;
75 self.total_term_count -= 1;
76 } else if *count == 1 {
77 self.term_count.remove(term);
78 self.total_term_count -= 1;
79 }
80 }
81 self
82 }
83
84 #[inline]
89 pub fn sub_terms<T>(&mut self, terms: &[T]) -> &mut Self
90 where T: AsRef<str>
91 {
92 for term in terms {
93 let term_str = term.as_ref();
94 self.sub_term(term_str);
95 }
96 self
97 }
98
99 pub fn set_term_count(&mut self, term: &str, count: u64) -> &mut Self {
105 if count == 0 {
106 self.term_count.remove(term);
107 } else {
108 let current_count = self.term_count.entry(term.to_string()).or_insert(0);
109 self.total_term_count += count - *current_count;
110 *current_count = count;
111 }
112 self
113 }
114
115 pub fn add_terms_from_freq(&mut self, other: &TermFrequency) -> &mut Self {
119 for (term, &count) in &other.term_count {
120 let entry = self.term_count.entry(term.clone()).or_insert(0);
121 *entry += count;
122 self.total_term_count += count;
123 }
124 self
125 }
126
127 pub fn scale(&mut self, scalar: f64) -> &mut Self {
131 let mut total_count = 0;
132 self.term_count.iter_mut().for_each(|(_, count)| {
133 *count = ((*count as f64) * scalar).round() as u64;
134 total_count += *count;
135 });
136 self.total_term_count = total_count;
137 self
138 }
139}
140
141impl<T> From<&[T]> for TermFrequency
142where
143 T: AsRef<str>,
144{
145 fn from(terms: &[T]) -> Self {
146 let mut tf = TermFrequency::new();
147 tf.add_terms(terms);
148 tf
149 }
150}
151
impl From<Corpus> for TermFrequency {
    /// Builds a frequency table from the aggregate term counts of a `Corpus`.
    // NOTE(review): `corpus.term_counts` exposes `iter()` yielding entries with
    // `key()` / `value()` accessors — presumably a concurrent map
    // (DashMap-style); confirm against the `Corpus` definition.
    fn from(corpus: Corpus) -> Self {
        let mut tf = TermFrequency::new();
        for entry in corpus.term_counts.iter() {
            let term = entry.key();
            let count = *entry.value();
            // `set_term_count` keeps `total_term_count` in sync per insertion.
            tf.set_term_count(term, count);
        }
        tf
    }
}
163
164impl TermFrequency {
166 #[inline]
171 pub fn iter(&self) -> impl Iterator<Item=(&str, u64)> {
172 self.term_count.iter().map(|(term, &count)| {
173 (term.as_str(), count)
174 })
175 }
176
177 #[inline]
182 pub fn term_count_vector(&self) -> Vec<(String, u64)> {
183 self.term_count.iter().map(|(term, &count)| {
184 (term.clone(), count)
185 }).collect()
186 }
187
188 #[inline]
193 pub fn term_count_vector_ref_str(&self) -> Vec<(&str, u64)> {
194 self.term_count.iter().map(|(term, &count)| {
195 (term.as_str(), count)
196 }).collect()
197 }
198
199 #[inline]
204 pub fn term_count_hashmap_ref_str(&self) -> HashMap<&str, u64, RandomState> {
205 self.term_count.iter().map(|(term, &count)| {
206 (term.as_str(), count)
207 }).collect()
208 }
209
210 #[inline]
215 pub fn term_sum(&self) -> u64 {
216 self.total_term_count
217 }
218
219 #[inline]
227 pub fn term_count(&self, term: &str) -> u64 {
228 *self.term_count.get(term).unwrap_or(&0)
229 }
230
231 #[inline]
237 pub fn most_frequent_terms_vector(&self) -> Vec<(String, u64)> {
238 if let Some(&max_count) = self.term_count.values().max() {
239 self.term_count.iter()
240 .filter(|&(_, &count)| count == max_count)
241 .map(|(term, &count)| (term.clone(), count))
242 .collect()
243 } else {
244 Vec::new()
245 }
246 }
247
248 #[inline]
253 pub fn most_frequent_term_count(&self) -> u64 {
254 if let Some(&max_count) = self.term_count.values().max() {
255 max_count
256 } else {
257 0
258 }
259 }
260
261 #[inline]
269 pub fn contains_term(&self, term: &str) -> bool {
270 self.term_count.contains_key(term)
271 }
272
273 #[inline]
278 pub fn term_set_iter(&self) -> impl Iterator<Item=&str> {
279 self.term_count.keys().map(|s| s.as_str())
280 }
281
282 #[inline]
287 pub fn term_set(&self) -> Vec<String> {
288 self.term_count.keys().cloned().collect()
289 }
290
291 #[inline]
296 pub fn term_set_ref_str(&self) -> Vec<&str> {
297 self.term_count.keys().map(|s| s.as_str()).collect()
298 }
299
300 #[inline]
305 pub fn term_hashset(&self) -> HashSet<String, RandomState> {
306 self.term_count.keys().cloned().collect()
307 }
308
309 #[inline]
314 pub fn term_hashset_ref_str(&self) -> HashSet<&str, RandomState> {
315 self.term_count.keys().map(|s| s.as_str()).collect()
316 }
317
318 #[inline]
323 pub fn term_num(&self) -> usize {
324 self.term_count.len()
325 }
326
327 #[inline]
335 pub fn remove_stop_terms(&mut self, stop_terms: &[&str]) -> u64{
336 let mut removed_total_count: u64 = 0;
337 for &stop_term in stop_terms {
338 if let Some(count) = self.term_count.remove(stop_term) {
339 removed_total_count += count as u64;
340 }
341 }
342 self.total_term_count -= removed_total_count;
343 removed_total_count
344 }
345
346 #[inline]
354 pub fn remove_terms_by<F>(&mut self, condition: F) -> u64
355 where
356 F: Fn(&str, &u64) -> bool,
357 {
358 let mut removed_total_count: u64 = 0;
359 self.term_count.retain(|term, count| {
360 if condition(term, count) {
361 removed_total_count += *count as u64;
362 false
363 } else {
364 true
365 }
366 });
367 self.total_term_count -= removed_total_count as u64;
368
369 removed_total_count
370 }
371
372 #[inline]
377 pub fn sorted_frequency_vector(&self) -> Vec<(String, u64)> {
378 let mut term_list: Vec<(String, u64)> = self.term_count
379 .iter()
380 .map(|(term, &count)| (term.clone(), count))
381 .collect();
382
383 term_list.sort_by(|a, b| b.1.cmp(&a.1));
384 term_list
385 }
386
387 #[inline]
392 pub fn sorted_dict_order_vector(&self) -> Vec<(String, u64)> {
393 let mut term_list: Vec<(String, u64)> = self.term_count
394 .iter()
395 .map(|(term, &count)| (term.clone(), count))
396 .collect();
397
398 term_list.sort_by(|a, b| a.0.cmp(&b.0));
399 term_list
400 }
401
402 #[inline]
408 pub fn unique_term_ratio(&self) -> f64 {
409 if self.total_term_count == 0 {
410 return 0.0;
411 }
412 self.term_count.len() as f64 / self.total_term_count as f64
413 }
414
415 #[inline]
418 pub fn probability_vector(&self) -> Vec<(String, f64)> {
419 if self.total_term_count == 0 {
420 return Vec::new();
421 }
422 let total = self.total_term_count as f64;
423 self.term_count
424 .iter()
425 .map(|(term, &count)| (term.clone(), (count as f64) / total))
426 .collect()
427 }
428
429 #[inline]
432 pub fn probability_vector_ref_str(&self) -> Vec<(&str, f64)> {
433 if self.total_term_count == 0 {
434 return Vec::new();
435 }
436 let total = self.total_term_count as f64;
437 self.term_count
438 .iter()
439 .map(|(term, &count)| (term.as_str(), (count as f64) / total))
440 .collect()
441 }
442
443 #[inline]
446 pub fn probability(&self, term: &str) -> f64 {
447 if self.total_term_count == 0 {
448 return 0.0;
449 }
450 (self.term_count(term) as f64) / (self.total_term_count as f64)
451 }
452
453 #[inline]
455 pub fn clear(&mut self) {
456 self.term_count.clear();
457 self.total_term_count = 0;
458 }
459
460 #[inline]
462 pub fn shrink_to_fit(&mut self) {
463 self.term_count.shrink_to_fit();
464 }
465}