tf_idf_vectorizer/vectorizer/
term.rs

1use core::str;
2use std::{collections::{HashMap, HashSet}, fmt::Debug};
3use ahash::RandomState;
4use serde::{Deserialize, Serialize};
5
6use crate::Corpus;
7
8
/// Term-frequency table.
///
/// Stores how many times each distinct term has been recorded, together with
/// a running total of all recorded occurrences.
///
/// # Examples
/// ```
/// use crate::tf_idf_vectorizer::vectorizer::term::TermFrequency;
/// let mut term_freq = TermFrequency::new();
/// term_freq.add_term("term1");
/// term_freq.add_term("term2");
/// term_freq.add_term("term1");
///
/// assert_eq!(term_freq.term_count("term1"), 2);
/// ```
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TermFrequency {
    /// Occurrence count per distinct term, keyed by the term text.
    term_count: HashMap<String, u64, RandomState>,
    /// Running total of occurrences across all terms
    /// (intended to equal the sum of the values in `term_count`).
    total_term_count: u64,
}
28
29/// Implementation for adding and removing terms
30impl TermFrequency {
31    /// Create a new TermFrequency
32    pub fn new() -> Self {
33        TermFrequency {
34            term_count: HashMap::with_hasher(RandomState::new()),
35            total_term_count: 0,
36        }
37    }
38
39    /// Add a term
40    ///
41    /// # Arguments
42    /// * `term` - term to add
43    #[inline]
44    pub fn add_term(&mut self, term: &str) -> &mut Self {
45        let count = self.term_count.entry(term.to_string()).or_insert(0);
46        *count += 1;
47        self.total_term_count += 1;
48        self
49    }
50
51    /// Add multiple terms
52    ///
53    /// # Arguments
54    /// * `terms` - Slice of terms to add
55    #[inline]
56    pub fn add_terms<T>(&mut self, terms: &[T]) -> &mut Self 
57    where T: AsRef<str> 
58    {
59        for term in terms {
60            let term_str = term.as_ref();
61            self.add_term(term_str);
62        }
63        self
64    }
65
66    /// Subtract a term
67    ///
68    /// # Arguments
69    /// * `term` - term to subtract
70    #[inline]
71    pub fn sub_term(&mut self, term: &str) -> &mut Self {
72        if let Some(count) = self.term_count.get_mut(term) {
73            if *count > 1 {
74                *count -= 1;
75                self.total_term_count -= 1;
76            } else if *count == 1 {
77                self.term_count.remove(term);
78                self.total_term_count -= 1;
79            }
80        }
81        self
82    }
83
84    /// Subtract multiple terms
85    ///
86    /// # Arguments
87    /// * `terms` - Slice of terms to subtract
88    #[inline]
89    pub fn sub_terms<T>(&mut self, terms: &[T]) -> &mut Self 
90    where T: AsRef<str>
91    {
92        for term in terms {
93            let term_str = term.as_ref();
94            self.sub_term(term_str);
95        }
96        self
97    }
98
99    /// Set the occurrence count for a term
100    ///
101    /// # Arguments
102    /// * `term` - term
103    /// * `count` - Occurrence count
104    pub fn set_term_count(&mut self, term: &str, count: u64) -> &mut Self {
105        if count == 0 {
106            self.term_count.remove(term);
107        } else {
108            let current_count = self.term_count.entry(term.to_string()).or_insert(0);
109            self.total_term_count += count - *current_count;
110            *current_count = count;
111        }
112        self
113    }
114
115    /// Merge with another TermFrequency
116    /// # Arguments
117    /// * `other` - Another TermFrequency to merge with
118    pub fn add_terms_from_freq(&mut self, other: &TermFrequency) -> &mut Self {
119        for (term, &count) in &other.term_count {
120            let entry = self.term_count.entry(term.clone()).or_insert(0);
121            *entry += count;
122            self.total_term_count += count;
123        }
124        self
125    }
126
127    /// Scale the term counts by a scalar
128    /// # Arguments
129    /// * `scalar` - Scalar to scale by
130    pub fn scale(&mut self, scalar: f64) -> &mut Self {
131        let mut total_count = 0;
132        self.term_count.iter_mut().for_each(|(_, count)| {
133            *count = ((*count as f64) * scalar).round() as u64;
134            total_count += *count;
135        });
136        self.total_term_count = total_count;
137        self
138    }
139}
140
141impl<T> From<&[T]> for TermFrequency
142where
143    T: AsRef<str>,
144{
145    fn from(terms: &[T]) -> Self {
146        let mut tf = TermFrequency::new();
147        tf.add_terms(terms);
148        tf
149    }
150}
151
152impl From<Corpus> for TermFrequency {
153    fn from(corpus: Corpus) -> Self {
154        let mut tf = TermFrequency::new();
155        for entry in corpus.term_counts.iter() {
156            let term = entry.key();
157            let count = *entry.value();
158            tf.set_term_count(term, count);
159        }
160        tf
161    }
162}
163
164/// Implementation for retrieving information from TermFrequency
165impl TermFrequency {
166    /// Get iterator over all terms and their counts
167    /// 
168    /// # Returns
169    /// * `impl Iterator<Item=(&str, u64)>` - Iterator over terms and their counts
170    #[inline]
171    pub fn iter(&self) -> impl Iterator<Item=(&str, u64)> {
172        self.term_count.iter().map(|(term, &count)| {
173            (term.as_str(), count)
174        })
175    }
176
177    /// Get a vector of all terms and their counts
178    ///
179    /// # Returns
180    /// * `Vec<(String, u64)>` - Vector of terms and their counts
181    #[inline]
182    pub fn term_count_vector(&self) -> Vec<(String, u64)> {
183        self.term_count.iter().map(|(term, &count)| {
184            (term.clone(), count)
185        }).collect()
186    }
187
188    /// Get a vector of all terms and their counts (as &str)
189    ///
190    /// # Returns
191    /// * `Vec<(&str, u64)>` - Vector of terms and their counts
192    #[inline]
193    pub fn term_count_vector_ref_str(&self) -> Vec<(&str, u64)> {
194        self.term_count.iter().map(|(term, &count)| {
195            (term.as_str(), count)
196        }).collect()
197    }
198
199    /// Get a hashmap of all terms and their counts (as &str)
200    ///
201    /// # Returns
202    /// * `HashMap<&str, u64>` - HashMap of terms and their counts
203    #[inline]
204    pub fn term_count_hashmap_ref_str(&self) -> HashMap<&str, u64, RandomState> {
205        self.term_count.iter().map(|(term, &count)| {
206            (term.as_str(), count)
207        }).collect()
208    }
209
210    /// Get the total count of all terms
211    ///
212    /// # Returns
213    /// * `u64` - Total term count
214    #[inline]
215    pub fn term_sum(&self) -> u64 {
216        self.total_term_count
217    }
218
219    /// Get the occurrence count for a specific term
220    ///
221    /// # Arguments
222    /// * `term` - term
223    ///
224    /// # Returns
225    /// * `u64` - Occurrence count for the term
226    #[inline]
227    pub fn term_count(&self, term: &str) -> u64 {
228        *self.term_count.get(term).unwrap_or(&0)
229    }
230
231    /// Get the most frequent terms
232    /// If multiple terms have the same count, all are returned
233    ///
234    /// # Returns
235    /// * `Vec<(String, u64)>` - Vector of most frequent terms and their counts
236    #[inline]
237    pub fn most_frequent_terms_vector(&self) -> Vec<(String, u64)> {
238        if let Some(&max_count) = self.term_count.values().max() {
239            self.term_count.iter()
240                .filter(|&(_, &count)| count == max_count)
241                .map(|(term, &count)| (term.clone(), count))
242                .collect()
243        } else {
244            Vec::new()
245        }
246    }
247
248    /// Get the count of the most frequent term
249    ///
250    /// # Returns
251    /// * `u64` - Count of the most frequent term
252    #[inline]
253    pub fn most_frequent_term_count(&self) -> u64 {
254        if let Some(&max_count) = self.term_count.values().max() {
255            max_count
256        } else {
257            0
258        }
259    }
260
261    /// Check if a term exists
262    ///
263    /// # Arguments
264    /// * `term` - term
265    ///
266    /// # Returns
267    /// * `bool` - true if the term exists, false otherwise
268    #[inline]
269    pub fn contains_term(&self, term: &str) -> bool {
270        self.term_count.contains_key(term)
271    }
272
273    /// term_set_iter
274    /// 
275    /// # Returns
276    /// * `impl Iterator<Item=&str>` - Iterator over the set of terms
277    #[inline]
278    pub fn term_set_iter(&self) -> impl Iterator<Item=&str> {
279        self.term_count.keys().map(|s| s.as_str())
280    }
281
282    /// Get the set of terms
283    ///
284    /// # Returns
285    /// * `Vec<String>` - Set of terms
286    #[inline]
287    pub fn term_set(&self) -> Vec<String> {
288        self.term_count.keys().cloned().collect()
289    }
290
291    /// Get the set of terms (as &str)
292    ///
293    /// # Returns
294    /// * `Vec<&str>` - Set of terms
295    #[inline]
296    pub fn term_set_ref_str(&self) -> Vec<&str> {
297        self.term_count.keys().map(|s| s.as_str()).collect()
298    }
299
300    /// Get the set of terms as a HashSet
301    ///
302    /// # Returns
303    /// * `HashSet<String>` - Set of terms
304    #[inline]
305    pub fn term_hashset(&self) -> HashSet<String, RandomState> {
306        self.term_count.keys().cloned().collect()
307    }
308
309    /// Get the set of terms as a HashSet (as &str)
310    ///
311    /// # Returns
312    /// * `HashSet<&str>` - Set of terms
313    #[inline]
314    pub fn term_hashset_ref_str(&self) -> HashSet<&str, RandomState> {
315        self.term_count.keys().map(|s| s.as_str()).collect()
316    }
317
318    /// Get the number of unique terms
319    ///
320    /// # Returns
321    /// * `usize` - Number of unique terms
322    #[inline]
323    pub fn term_num(&self) -> usize {
324        self.term_count.len()
325    }
326
327    /// Remove stop terms
328    ///
329    /// # Arguments
330    /// * `stop_terms` - Slice of stop terms to remove
331    ///
332    /// # Returns
333    /// * `u64` - Total count of removed terms
334    #[inline]
335    pub fn remove_stop_terms(&mut self, stop_terms: &[&str]) -> u64{
336        let mut removed_total_count: u64 = 0;
337        for &stop_term in stop_terms {
338            if let Some(count) = self.term_count.remove(stop_term) {
339                removed_total_count += count as u64;
340            }
341        }
342        self.total_term_count -= removed_total_count;
343        removed_total_count
344    }
345
346    /// Remove terms by a condition
347    ///
348    /// # Arguments
349    /// * `condition` - Closure to determine which terms to remove
350    ///
351    /// # Returns
352    /// * `u64` - Total count of removed terms
353    #[inline]
354    pub fn remove_terms_by<F>(&mut self, condition: F) -> u64
355    where
356        F: Fn(&str, &u64) -> bool,
357    {
358        let mut removed_total_count: u64 = 0;
359        self.term_count.retain(|term, count| {
360            if condition(term, count) {
361                removed_total_count += *count as u64;
362                false
363            } else {
364                true
365            }
366        });
367        self.total_term_count -= removed_total_count as u64;
368
369        removed_total_count
370    }
371
372    /// Get a vector of terms sorted by frequency (descending)
373    ///
374    /// # Returns
375    /// * `Vec<(String, u64)>` - Vector of terms sorted by frequency
376    #[inline]
377    pub fn sorted_frequency_vector(&self) -> Vec<(String, u64)> {
378        let mut term_list: Vec<(String, u64)> = self.term_count
379            .iter()
380            .map(|(term, &count)| (term.clone(), count))
381            .collect();
382
383        term_list.sort_by(|a, b| b.1.cmp(&a.1));
384        term_list
385    }
386
387    /// Get a vector of terms sorted by dictionary order (ascending)
388    ///
389    /// # Returns
390    /// * `Vec<(String, u64)>` - Vector of terms sorted by dictionary order
391    #[inline]
392    pub fn sorted_dict_order_vector(&self) -> Vec<(String, u64)> {
393        let mut term_list: Vec<(String, u64)> = self.term_count
394            .iter()
395            .map(|(term, &count)| (term.clone(), count))
396            .collect();
397
398        term_list.sort_by(|a, b| a.0.cmp(&b.0));
399        term_list
400    }
401
402    /// Calculate the diversity of terms
403    /// 1.0 indicates complete diversity, 0.0 indicates no diversity
404    ///
405    /// # Returns
406    /// * `f64` - Diversity of terms
407    #[inline]
408    pub fn unique_term_ratio(&self) -> f64 {
409        if self.total_term_count == 0 {
410            return 0.0;
411        }
412        self.term_count.len() as f64 / self.total_term_count as f64
413    }
414
415    /// Get the probability distribution P(term) (owned String version)
416    /// Returns an empty vector if total is 0
417    #[inline]
418    pub fn probability_vector(&self) -> Vec<(String, f64)> {
419        if self.total_term_count == 0 {
420            return Vec::new();
421        }
422        let total = self.total_term_count as f64;
423        self.term_count
424            .iter()
425            .map(|(term, &count)| (term.clone(), (count as f64) / total))
426            .collect()
427    }
428
429    /// Get the probability distribution P(term) (as &str)
430    /// Returns an empty vector if total is 0
431    #[inline]
432    pub fn probability_vector_ref_str(&self) -> Vec<(&str, f64)> {
433        if self.total_term_count == 0 {
434            return Vec::new();
435        }
436        let total = self.total_term_count as f64;
437        self.term_count
438            .iter()
439            .map(|(term, &count)| (term.as_str(), (count as f64) / total))
440            .collect()
441    }
442
443    /// Get the probability P(term) for a specific term
444    /// Returns 0.0 if total is 0
445    #[inline]
446    pub fn probability(&self, term: &str) -> f64 {
447        if self.total_term_count == 0 {
448            return 0.0;
449        }
450        (self.term_count(term) as f64) / (self.total_term_count as f64)
451    }
452
453    /// Reset all counts
454    #[inline]
455    pub fn clear(&mut self) {
456        self.term_count.clear();
457        self.total_term_count = 0;
458    }
459
460    /// Shrink internal storage to fit current size
461    #[inline]
462    pub fn shrink_to_fit(&mut self) {
463        self.term_count.shrink_to_fit();
464    }
465}