tf_idf_vectorizer/vectorizer/
term.rs

1use core::str;
2use std::{collections::{HashMap, HashSet}, fmt::Debug};
3use ahash::RandomState;
4use serde::{Deserialize, Serialize};
5
6use crate::Corpus;
7
8
/// Term frequency table for a single document.
///
/// Manages the per-document term statistics used for TF calculation.
///
/// Tracks:
/// - the occurrence count of each term
/// - the total number of term occurrences in the document
///
/// ### Use Cases
/// - TF calculation
/// - term-level statistics
///
/// # Examples
/// ```
/// use crate::tf_idf_vectorizer::vectorizer::term::TermFrequency;
/// let mut term_freq = TermFrequency::new();
/// term_freq.add_term("term1");
/// term_freq.add_term("term2");
/// term_freq.add_term("term1");
///
/// assert_eq!(term_freq.term_count("term1"), 2);
/// assert_eq!(term_freq.term_sum(), 3);
/// ```
///
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TermFrequency {
    /// Occurrence count per term, keyed by the term text.
    term_count: HashMap<String, u64, RandomState>,
    /// Total number of term occurrences; kept in step with `term_count` by the mutating methods.
    total_term_count: u64,
}
37
38/// Implementation for adding and removing terms
39impl TermFrequency {
40    /// Create a new TermFrequency
41    pub fn new() -> Self {
42        TermFrequency {
43            term_count: HashMap::with_hasher(RandomState::new()),
44            total_term_count: 0,
45        }
46    }
47
48    /// Add a term
49    ///
50    /// # Arguments
51    /// * `term` - term to add
52    #[inline]
53    pub fn add_term(&mut self, term: &str) -> &mut Self {
54        let count = self.term_count.entry(term.to_string()).or_insert(0);
55        *count += 1;
56        self.total_term_count += 1;
57        self
58    }
59
60    /// Add multiple terms
61    ///
62    /// # Arguments
63    /// * `terms` - Slice of terms to add
64    #[inline]
65    pub fn add_terms<T>(&mut self, terms: &[T]) -> &mut Self 
66    where T: AsRef<str> 
67    {
68        for term in terms {
69            let term_str = term.as_ref();
70            self.add_term(term_str);
71        }
72        self
73    }
74
75    /// Subtract a term
76    ///
77    /// # Arguments
78    /// * `term` - term to subtract
79    #[inline]
80    pub fn sub_term(&mut self, term: &str) -> &mut Self {
81        if let Some(count) = self.term_count.get_mut(term) {
82            if *count > 1 {
83                *count -= 1;
84                self.total_term_count -= 1;
85            } else if *count == 1 {
86                self.term_count.remove(term);
87                self.total_term_count -= 1;
88            }
89        }
90        self
91    }
92
93    /// Subtract multiple terms
94    ///
95    /// # Arguments
96    /// * `terms` - Slice of terms to subtract
97    #[inline]
98    pub fn sub_terms<T>(&mut self, terms: &[T]) -> &mut Self 
99    where T: AsRef<str>
100    {
101        for term in terms {
102            let term_str = term.as_ref();
103            self.sub_term(term_str);
104        }
105        self
106    }
107
108    /// Set the occurrence count for a term
109    ///
110    /// # Arguments
111    /// * `term` - term
112    /// * `count` - Occurrence count
113    pub fn set_term_count(&mut self, term: &str, count: u64) -> &mut Self {
114        if count == 0 {
115            self.term_count.remove(term);
116        } else {
117            let current_count = self.term_count.entry(term.to_string()).or_insert(0);
118            self.total_term_count += count - *current_count;
119            *current_count = count;
120        }
121        self
122    }
123
124    /// Merge with another TermFrequency
125    /// # Arguments
126    /// * `other` - Another TermFrequency to merge with
127    pub fn add_terms_from_freq(&mut self, other: &TermFrequency) -> &mut Self {
128        for (term, &count) in &other.term_count {
129            let entry = self.term_count.entry(term.clone()).or_insert(0);
130            *entry += count;
131            self.total_term_count += count;
132        }
133        self
134    }
135
136    /// Scale the term counts by a scalar
137    /// # Arguments
138    /// * `scalar` - Scalar to scale by
139    pub fn scale(&mut self, scalar: f64) -> &mut Self {
140        let mut total_count = 0;
141        self.term_count.iter_mut().for_each(|(_, count)| {
142            *count = ((*count as f64) * scalar).round() as u64;
143            total_count += *count;
144        });
145        self.total_term_count = total_count;
146        self
147    }
148}
149
150impl<T> From<&[T]> for TermFrequency
151where
152    T: AsRef<str>,
153{
154    fn from(terms: &[T]) -> Self {
155        let mut tf = TermFrequency::new();
156        tf.add_terms(terms);
157        tf
158    }
159}
160
161impl From<Corpus> for TermFrequency {
162    fn from(corpus: Corpus) -> Self {
163        let mut tf = TermFrequency::new();
164        for entry in corpus.term_counts.iter() {
165            let term = entry.key();
166            let count = *entry.value();
167            tf.set_term_count(term, count);
168        }
169        tf
170    }
171}
172
173/// Implementation for retrieving information from TermFrequency
174impl TermFrequency {
175    /// Get iterator over all terms and their counts
176    /// 
177    /// # Returns
178    /// * `impl Iterator<Item=(&str, u64)>` - Iterator over terms and their counts
179    #[inline]
180    pub fn iter(&self) -> impl Iterator<Item=(&str, u64)> {
181        self.term_count.iter().map(|(term, &count)| {
182            (term.as_str(), count)
183        })
184    }
185
186    /// Get a vector of all terms and their counts
187    ///
188    /// # Returns
189    /// * `Vec<(String, u64)>` - Vector of terms and their counts
190    #[inline]
191    pub fn term_count_vector(&self) -> Vec<(String, u64)> {
192        self.term_count.iter().map(|(term, &count)| {
193            (term.clone(), count)
194        }).collect()
195    }
196
197    /// Get a vector of all terms and their counts (as &str)
198    ///
199    /// # Returns
200    /// * `Vec<(&str, u64)>` - Vector of terms and their counts
201    #[inline]
202    pub fn term_count_vector_ref_str(&self) -> Vec<(&str, u64)> {
203        self.term_count.iter().map(|(term, &count)| {
204            (term.as_str(), count)
205        }).collect()
206    }
207
208    /// Get a hashmap of all terms and their counts (as &str)
209    ///
210    /// # Returns
211    /// * `HashMap<&str, u64>` - HashMap of terms and their counts
212    #[inline]
213    pub fn term_count_hashmap_ref_str(&self) -> HashMap<&str, u64, RandomState> {
214        self.term_count.iter().map(|(term, &count)| {
215            (term.as_str(), count)
216        }).collect()
217    }
218
219    /// Get the total count of all terms
220    ///
221    /// # Returns
222    /// * `u64` - Total term count
223    #[inline]
224    pub fn term_sum(&self) -> u64 {
225        self.total_term_count
226    }
227
228    /// Get the occurrence count for a specific term
229    ///
230    /// # Arguments
231    /// * `term` - term
232    ///
233    /// # Returns
234    /// * `u64` - Occurrence count for the term
235    #[inline]
236    pub fn term_count(&self, term: &str) -> u64 {
237        *self.term_count.get(term).unwrap_or(&0)
238    }
239
240    /// Get the most frequent terms
241    /// If multiple terms have the same count, all are returned
242    ///
243    /// # Returns
244    /// * `Vec<(String, u64)>` - Vector of most frequent terms and their counts
245    #[inline]
246    pub fn most_frequent_terms_vector(&self) -> Vec<(String, u64)> {
247        if let Some(&max_count) = self.term_count.values().max() {
248            self.term_count.iter()
249                .filter(|&(_, &count)| count == max_count)
250                .map(|(term, &count)| (term.clone(), count))
251                .collect()
252        } else {
253            Vec::new()
254        }
255    }
256
257    /// Get the count of the most frequent term
258    ///
259    /// # Returns
260    /// * `u64` - Count of the most frequent term
261    #[inline]
262    pub fn most_frequent_term_count(&self) -> u64 {
263        if let Some(&max_count) = self.term_count.values().max() {
264            max_count
265        } else {
266            0
267        }
268    }
269
270    /// Check if a term exists
271    ///
272    /// # Arguments
273    /// * `term` - term
274    ///
275    /// # Returns
276    /// * `bool` - true if the term exists, false otherwise
277    #[inline]
278    pub fn contains_term(&self, term: &str) -> bool {
279        self.term_count.contains_key(term)
280    }
281
282    /// term_set_iter
283    /// 
284    /// # Returns
285    /// * `impl Iterator<Item=&str>` - Iterator over the set of terms
286    #[inline]
287    pub fn term_set_iter(&self) -> impl Iterator<Item=&str> {
288        self.term_count.keys().map(|s| s.as_str())
289    }
290
291    /// Get the set of terms
292    ///
293    /// # Returns
294    /// * `Vec<String>` - Set of terms
295    #[inline]
296    pub fn term_set(&self) -> Vec<String> {
297        self.term_count.keys().cloned().collect()
298    }
299
300    /// Get the set of terms (as &str)
301    ///
302    /// # Returns
303    /// * `Vec<&str>` - Set of terms
304    #[inline]
305    pub fn term_set_ref_str(&self) -> Vec<&str> {
306        self.term_count.keys().map(|s| s.as_str()).collect()
307    }
308
309    /// Get the set of terms as a HashSet
310    ///
311    /// # Returns
312    /// * `HashSet<String>` - Set of terms
313    #[inline]
314    pub fn term_hashset(&self) -> HashSet<String, RandomState> {
315        self.term_count.keys().cloned().collect()
316    }
317
318    /// Get the set of terms as a HashSet (as &str)
319    ///
320    /// # Returns
321    /// * `HashSet<&str>` - Set of terms
322    #[inline]
323    pub fn term_hashset_ref_str(&self) -> HashSet<&str, RandomState> {
324        self.term_count.keys().map(|s| s.as_str()).collect()
325    }
326
327    /// Get the number of unique terms
328    ///
329    /// # Returns
330    /// * `usize` - Number of unique terms
331    #[inline]
332    pub fn term_num(&self) -> usize {
333        self.term_count.len()
334    }
335
336    /// Remove stop terms
337    ///
338    /// # Arguments
339    /// * `stop_terms` - Slice of stop terms to remove
340    ///
341    /// # Returns
342    /// * `u64` - Total count of removed terms
343    #[inline]
344    pub fn remove_stop_terms(&mut self, stop_terms: &[&str]) -> u64{
345        let mut removed_total_count: u64 = 0;
346        for &stop_term in stop_terms {
347            if let Some(count) = self.term_count.remove(stop_term) {
348                removed_total_count += count as u64;
349            }
350        }
351        self.total_term_count -= removed_total_count;
352        removed_total_count
353    }
354
355    /// Remove terms by a condition
356    ///
357    /// # Arguments
358    /// * `condition` - Closure to determine which terms to remove
359    ///
360    /// # Returns
361    /// * `u64` - Total count of removed terms
362    #[inline]
363    pub fn remove_terms_by<F>(&mut self, condition: F) -> u64
364    where
365        F: Fn(&str, &u64) -> bool,
366    {
367        let mut removed_total_count: u64 = 0;
368        self.term_count.retain(|term, count| {
369            if condition(term, count) {
370                removed_total_count += *count as u64;
371                false
372            } else {
373                true
374            }
375        });
376        self.total_term_count -= removed_total_count as u64;
377
378        removed_total_count
379    }
380
381    /// Get a vector of terms sorted by frequency (descending)
382    ///
383    /// # Returns
384    /// * `Vec<(String, u64)>` - Vector of terms sorted by frequency
385    #[inline]
386    pub fn sorted_frequency_vector(&self) -> Vec<(String, u64)> {
387        let mut term_list: Vec<(String, u64)> = self.term_count
388            .iter()
389            .map(|(term, &count)| (term.clone(), count))
390            .collect();
391
392        term_list.sort_by(|a, b| b.1.cmp(&a.1));
393        term_list
394    }
395
396    /// Get a vector of terms sorted by dictionary order (ascending)
397    ///
398    /// # Returns
399    /// * `Vec<(String, u64)>` - Vector of terms sorted by dictionary order
400    #[inline]
401    pub fn sorted_dict_order_vector(&self) -> Vec<(String, u64)> {
402        let mut term_list: Vec<(String, u64)> = self.term_count
403            .iter()
404            .map(|(term, &count)| (term.clone(), count))
405            .collect();
406
407        term_list.sort_by(|a, b| a.0.cmp(&b.0));
408        term_list
409    }
410
411    /// Calculate the diversity of terms
412    /// 1.0 indicates complete diversity, 0.0 indicates no diversity
413    ///
414    /// # Returns
415    /// * `f64` - Diversity of terms
416    #[inline]
417    pub fn unique_term_ratio(&self) -> f64 {
418        if self.total_term_count == 0 {
419            return 0.0;
420        }
421        self.term_count.len() as f64 / self.total_term_count as f64
422    }
423
424    /// Get the probability distribution P(term) (owned String version)
425    /// Returns an empty vector if total is 0
426    #[inline]
427    pub fn probability_vector(&self) -> Vec<(String, f64)> {
428        if self.total_term_count == 0 {
429            return Vec::new();
430        }
431        let total = self.total_term_count as f64;
432        self.term_count
433            .iter()
434            .map(|(term, &count)| (term.clone(), (count as f64) / total))
435            .collect()
436    }
437
438    /// Get the probability distribution P(term) (as &str)
439    /// Returns an empty vector if total is 0
440    #[inline]
441    pub fn probability_vector_ref_str(&self) -> Vec<(&str, f64)> {
442        if self.total_term_count == 0 {
443            return Vec::new();
444        }
445        let total = self.total_term_count as f64;
446        self.term_count
447            .iter()
448            .map(|(term, &count)| (term.as_str(), (count as f64) / total))
449            .collect()
450    }
451
452    /// Get the probability P(term) for a specific term
453    /// Returns 0.0 if total is 0
454    #[inline]
455    pub fn probability(&self, term: &str) -> f64 {
456        if self.total_term_count == 0 {
457            return 0.0;
458        }
459        (self.term_count(term) as f64) / (self.total_term_count as f64)
460    }
461
462    /// Reset all counts
463    #[inline]
464    pub fn clear(&mut self) {
465        self.term_count.clear();
466        self.total_term_count = 0;
467    }
468
469    /// Shrink internal storage to fit current size
470    #[inline]
471    pub fn shrink_to_fit(&mut self) {
472        self.term_count.shrink_to_fit();
473    }
474}