tf_idf_vectorizer/vectorizer/term.rs

use core::str;
use std::{collections::{HashMap, HashSet}, fmt::Debug};
use ahash::RandomState;
use serde::{Deserialize, Serialize};

use crate::Corpus;

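/// Per-term occurrence counts together with a cached total term count.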
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TermFrequency {
    term_count: HashMap<String, u64, RandomState>,
    total_term_count: u64,
}

impl TermFrequency {
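    /// Creates an empty `TermFrequency`.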
    pub fn new() -> Self {
        TermFrequency {
            term_count: HashMap::with_hasher(RandomState::new()),
            total_term_count: 0,
        }
    }

    #[inline]
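    /// Records one occurrence of `term`, updating the per-term count and the total.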
    pub fn add_term(&mut self, term: &str) -> &mut Self {
        let count = self.term_count.entry(term.to_string()).or_insert(0);
        *count += 1;
        self.total_term_count += 1;
        self
    }

    #[inline]
    pub fn add_terms<T>(&mut self, terms: &[T]) -> &mut Self
    where T: AsRef<str>
    {
        for term in terms {
            let term_str = term.as_ref();
            self.add_term(term_str);
        }
        self
    }

    #[inline]
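    /// Removes one occurrence of `term`; the term is dropped from the map when its
    /// count reaches zero, and unknown terms are ignored.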
    pub fn sub_term(&mut self, term: &str) -> &mut Self {
        if let Some(count) = self.term_count.get_mut(term) {
            if *count > 1 {
                *count -= 1;
                self.total_term_count -= 1;
            } else if *count == 1 {
                self.term_count.remove(term);
                self.total_term_count -= 1;
            }
        }
        self
    }

    #[inline]
    pub fn sub_terms<T>(&mut self, terms: &[T]) -> &mut Self
    where T: AsRef<str>
    {
        for term in terms {
            let term_str = term.as_ref();
            self.sub_term(term_str);
        }
        self
    }

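    /// Sets the count of `term` to an exact value while keeping the total consistent;
    /// a count of zero removes the term entirely.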
    pub fn set_term_count(&mut self, term: &str, count: u64) -> &mut Self {
        if count == 0 {
            // Removing a term must also subtract its occurrences from the total.
            if let Some(removed) = self.term_count.remove(term) {
                self.total_term_count -= removed;
            }
        } else {
            let current_count = self.term_count.entry(term.to_string()).or_insert(0);
            // Recompute the total without underflowing when the new count is lower than the old one.
            self.total_term_count = self.total_term_count - *current_count + count;
            *current_count = count;
        }
        self
    }

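    /// Merges every term count from `other` into this frequency table.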
    pub fn add_terms_from_freq(&mut self, other: &TermFrequency) -> &mut Self {
        for (term, &count) in &other.term_count {
            let entry = self.term_count.entry(term.clone()).or_insert(0);
            *entry += count;
            self.total_term_count += count;
        }
        self
    }

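    /// Multiplies every count by `scalar` (rounding to the nearest integer), recomputes
    /// the total, and drops terms whose count rounds to zero.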
    pub fn scale(&mut self, scalar: f64) -> &mut Self {
        let mut total_count = 0;
        self.term_count.values_mut().for_each(|count| {
            *count = ((*count as f64) * scalar).round() as u64;
            total_count += *count;
        });
        // Drop entries whose count rounded down to zero so lookups stay consistent.
        self.term_count.retain(|_, count| *count > 0);
        self.total_term_count = total_count;
        self
    }
}

impl<T> From<&[T]> for TermFrequency
where
    T: AsRef<str>,
{
    fn from(terms: &[T]) -> Self {
        let mut tf = TermFrequency::new();
        tf.add_terms(terms);
        tf
    }
}

impl From<Corpus> for TermFrequency {
    fn from(corpus: Corpus) -> Self {
        let mut tf = TermFrequency::new();
        for entry in corpus.term_counts.iter() {
            let term = entry.key();
            let count = *entry.value();
            tf.set_term_count(term, count);
        }
        tf
    }
}

impl TermFrequency {
    #[inline]
    pub fn iter(&self) -> impl Iterator<Item=(&str, u64)> {
        self.term_count.iter().map(|(term, &count)| {
            (term.as_str(), count)
        })
    }

    #[inline]
    pub fn term_count_vector(&self) -> Vec<(String, u64)> {
        self.term_count.iter().map(|(term, &count)| {
            (term.clone(), count)
        }).collect()
    }

    #[inline]
    pub fn term_count_vector_ref_str(&self) -> Vec<(&str, u64)> {
        self.term_count.iter().map(|(term, &count)| {
            (term.as_str(), count)
        }).collect()
    }

    #[inline]
    pub fn term_count_hashmap_ref_str(&self) -> HashMap<&str, u64, RandomState> {
        self.term_count.iter().map(|(term, &count)| {
            (term.as_str(), count)
        }).collect()
    }

    #[inline]
    pub fn term_sum(&self) -> u64 {
        self.total_term_count
    }

    #[inline]
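    /// Returns the count recorded for `term`, or 0 if the term is absent.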
    pub fn term_count(&self, term: &str) -> u64 {
        *self.term_count.get(term).unwrap_or(&0)
    }

    #[inline]
    pub fn most_frequent_terms_vector(&self) -> Vec<(String, u64)> {
        if let Some(&max_count) = self.term_count.values().max() {
            self.term_count.iter()
                .filter(|&(_, &count)| count == max_count)
                .map(|(term, &count)| (term.clone(), count))
                .collect()
        } else {
            Vec::new()
        }
    }

    #[inline]
    pub fn most_frequent_term_count(&self) -> u64 {
        if let Some(&max_count) = self.term_count.values().max() {
            max_count
        } else {
            0
        }
    }

    #[inline]
    pub fn contains_term(&self, term: &str) -> bool {
        self.term_count.contains_key(term)
    }

    #[inline]
    pub fn term_set_iter(&self) -> impl Iterator<Item=&str> {
        self.term_count.keys().map(|s| s.as_str())
    }

    #[inline]
    pub fn term_set(&self) -> Vec<String> {
        self.term_count.keys().cloned().collect()
    }

    #[inline]
    pub fn term_set_ref_str(&self) -> Vec<&str> {
        self.term_count.keys().map(|s| s.as_str()).collect()
    }

    #[inline]
    pub fn term_hashset(&self) -> HashSet<String, RandomState> {
        self.term_count.keys().cloned().collect()
    }

    #[inline]
    pub fn term_hashset_ref_str(&self) -> HashSet<&str, RandomState> {
        self.term_count.keys().map(|s| s.as_str()).collect()
    }

    #[inline]
    pub fn term_num(&self) -> usize {
        self.term_count.len()
    }

    #[inline]
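    /// Removes every listed stop term and returns the total number of occurrences removed.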
    pub fn remove_stop_terms(&mut self, stop_terms: &[&str]) -> u64 {
        let mut removed_total_count: u64 = 0;
        for &stop_term in stop_terms {
            if let Some(count) = self.term_count.remove(stop_term) {
                removed_total_count += count;
            }
        }
        self.total_term_count -= removed_total_count;
        removed_total_count
    }

    #[inline]
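    /// Removes every term for which `condition(term, count)` returns `true` and returns
    /// the total number of occurrences removed.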
    pub fn remove_terms_by<F>(&mut self, condition: F) -> u64
    where
        F: Fn(&str, &u64) -> bool,
    {
        let mut removed_total_count: u64 = 0;
        self.term_count.retain(|term, count| {
            if condition(term, count) {
                removed_total_count += *count;
                false
            } else {
                true
            }
        });
        self.total_term_count -= removed_total_count;

        removed_total_count
    }

    #[inline]
    pub fn sorted_frequency_vector(&self) -> Vec<(String, u64)> {
        let mut term_list: Vec<(String, u64)> = self.term_count
            .iter()
            .map(|(term, &count)| (term.clone(), count))
            .collect();

        term_list.sort_by(|a, b| b.1.cmp(&a.1));
        term_list
    }

    #[inline]
    pub fn sorted_dict_order_vector(&self) -> Vec<(String, u64)> {
        let mut term_list: Vec<(String, u64)> = self.term_count
            .iter()
            .map(|(term, &count)| (term.clone(), count))
            .collect();

        term_list.sort_by(|a, b| a.0.cmp(&b.0));
        term_list
    }

    #[inline]
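    /// Ratio of distinct terms to total term occurrences; 0.0 for an empty table.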
    pub fn unique_term_ratio(&self) -> f64 {
        if self.total_term_count == 0 {
            return 0.0;
        }
        self.term_count.len() as f64 / self.total_term_count as f64
    }

    #[inline]
    pub fn probability_vector(&self) -> Vec<(String, f64)> {
        if self.total_term_count == 0 {
            return Vec::new();
        }
        let total = self.total_term_count as f64;
        self.term_count
            .iter()
            .map(|(term, &count)| (term.clone(), (count as f64) / total))
            .collect()
    }

    #[inline]
    pub fn probability_vector_ref_str(&self) -> Vec<(&str, f64)> {
        if self.total_term_count == 0 {
            return Vec::new();
        }
        let total = self.total_term_count as f64;
        self.term_count
            .iter()
            .map(|(term, &count)| (term.as_str(), (count as f64) / total))
            .collect()
    }

    #[inline]
    pub fn probability(&self, term: &str) -> f64 {
        if self.total_term_count == 0 {
            return 0.0;
        }
        (self.term_count(term) as f64) / (self.total_term_count as f64)
    }

    #[inline]
    pub fn clear(&mut self) {
        self.term_count.clear();
        self.total_term_count = 0;
    }

    #[inline]
    pub fn shrink_to_fit(&mut self) {
        self.term_count.shrink_to_fit();
    }
}
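
// A minimal usage sketch (illustrative only): it exercises the counting,
// probability, and stop-term APIs defined above and checks that the cached
// total stays in sync with the per-term counts.
#[cfg(test)]
mod term_frequency_usage {
    use super::TermFrequency;

    #[test]
    fn counts_probabilities_and_stop_term_removal() {
        let mut tf = TermFrequency::new();
        tf.add_terms(&["the", "quick", "brown", "fox", "the"]);

        // Per-term counts and the running total stay consistent.
        assert_eq!(tf.term_count("the"), 2);
        assert_eq!(tf.term_sum(), 5);
        assert_eq!(tf.term_num(), 4);

        // Relative frequency of a term within this document.
        assert!((tf.probability("the") - 0.4).abs() < 1e-12);

        // Dropping a stop term returns how many occurrences were removed.
        let removed = tf.remove_stop_terms(&["the"]);
        assert_eq!(removed, 2);
        assert_eq!(tf.term_sum(), 3);
        assert!(!tf.contains_term("the"));
    }
}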