tf_idf_vectorizer/vectorizer/token.rs

use core::str;
use std::{collections::{HashMap, HashSet}, fmt::Debug};
use ahash::RandomState;
use serde::{Deserialize, Serialize};

use crate::Corpus;

/// TokenFrequency
/// Manages the frequency of token occurrences,
/// counting how many times each token appears.
///
/// # Examples
/// ```
/// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
/// let mut token_freq = TokenFrequency::new();
/// token_freq.add_token("token1");
/// token_freq.add_token("token2");
/// token_freq.add_token("token1");
///
/// assert_eq!(token_freq.token_count("token1"), 2);
/// ```
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TokenFrequency {
    token_count: HashMap<String, u64, RandomState>,
    total_token_count: u64,
}

/// Implementation for adding and removing tokens
impl TokenFrequency {
    /// Create a new TokenFrequency
    pub fn new() -> Self {
        TokenFrequency {
            token_count: HashMap::with_hasher(RandomState::new()),
            total_token_count: 0,
        }
    }

    /// Add a token
    ///
    /// # Arguments
    /// * `token` - Token to add
    #[inline]
    pub fn add_token(&mut self, token: &str) -> &mut Self {
        let count = self.token_count.entry(token.to_string()).or_insert(0);
        *count += 1;
        self.total_token_count += 1;
        self
    }

    /// Add multiple tokens
    ///
    /// # Arguments
    /// * `tokens` - Slice of tokens to add
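    ///
    /// # Examples
    /// A small sketch of bulk insertion (crate path as in the struct-level docs):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "b", "a"]);
    /// assert_eq!(tf.token_count("a"), 2);
    /// assert_eq!(tf.token_sum(), 3);
    /// ```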
    #[inline]
    pub fn add_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where T: AsRef<str>
    {
        for token in tokens {
            let token_str = token.as_ref();
            self.add_token(token_str);
        }
        self
    }

    /// Subtract a token
    ///
    /// # Arguments
    /// * `token` - Token to subtract
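    ///
    /// # Examples
    /// Subtracting the last occurrence removes the entry entirely:
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a"]);
    /// tf.sub_token("a");
    /// assert_eq!(tf.token_count("a"), 1);
    /// tf.sub_token("a");
    /// assert!(!tf.contains_token("a"));
    /// ```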
    #[inline]
    pub fn sub_token(&mut self, token: &str) -> &mut Self {
        if let Some(count) = self.token_count.get_mut(token) {
            if *count > 1 {
                *count -= 1;
                self.total_token_count -= 1;
            } else if *count == 1 {
                self.token_count.remove(token);
                self.total_token_count -= 1;
            }
        }
        self
    }

    /// Subtract multiple tokens
    ///
    /// # Arguments
    /// * `tokens` - Slice of tokens to subtract
    #[inline]
    pub fn sub_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where T: AsRef<str>
    {
        for token in tokens {
            let token_str = token.as_ref();
            self.sub_token(token_str);
        }
        self
    }

    /// Set the occurrence count for a token
    /// Setting `count` to 0 removes the token entirely.
    ///
    /// # Arguments
    /// * `token` - Token
    /// * `count` - Occurrence count
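    ///
    /// # Examples
    /// Setting a count overwrites the previous value and keeps the total in sync:
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.set_token_count("a", 3);
    /// assert_eq!(tf.token_sum(), 3);
    /// tf.set_token_count("a", 0);
    /// assert!(!tf.contains_token("a"));
    /// assert_eq!(tf.token_sum(), 0);
    /// ```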
    pub fn set_token_count(&mut self, token: &str, count: u64) -> &mut Self {
        if count == 0 {
            // Removing a token must also remove its contribution to the total.
            if let Some(old_count) = self.token_count.remove(token) {
                self.total_token_count -= old_count;
            }
        } else {
            let current_count = self.token_count.entry(token.to_string()).or_insert(0);
            // Add before subtracting so lowering a count cannot underflow u64.
            self.total_token_count = self.total_token_count + count - *current_count;
            *current_count = count;
        }
        self
    }

    /// Merge with another TokenFrequency
    ///
    /// # Arguments
    /// * `other` - Another TokenFrequency to merge with
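    ///
    /// # Examples
    /// Counts from `other` are added on top of the existing counts:
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut a = TokenFrequency::new();
    /// a.add_tokens(&["x", "y"]);
    /// let mut b = TokenFrequency::new();
    /// b.add_tokens(&["x", "x"]);
    /// a.add_tokens_from_freq(&b);
    /// assert_eq!(a.token_count("x"), 3);
    /// assert_eq!(a.token_sum(), 4);
    /// ```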
    pub fn add_tokens_from_freq(&mut self, other: &TokenFrequency) -> &mut Self {
        for (token, &count) in &other.token_count {
            let entry = self.token_count.entry(token.clone()).or_insert(0);
            *entry += count;
            self.total_token_count += count;
        }
        self
    }

    /// Scale all token counts by a scalar
    /// Counts are rounded to the nearest integer; tokens whose scaled
    /// count rounds to zero are removed.
    ///
    /// # Arguments
    /// * `scalar` - Scalar to scale by
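    ///
    /// # Examples
    /// Entries that round to zero are dropped:
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "a", "a", "b"]);
    /// tf.scale(0.25);
    /// assert_eq!(tf.token_count("a"), 1);
    /// assert!(!tf.contains_token("b")); // 0.25 rounds to 0
    /// assert_eq!(tf.token_sum(), 1);
    /// ```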
    pub fn scale(&mut self, scalar: f64) -> &mut Self {
        let mut total_count = 0;
        // Rescale in place, dropping entries that round to zero so the map
        // never holds zero-count tokens.
        self.token_count.retain(|_, count| {
            *count = ((*count as f64) * scalar).round() as u64;
            total_count += *count;
            *count > 0
        });
        self.total_token_count = total_count;
        self
    }
}

impl<T> From<&[T]> for TokenFrequency
where
    T: AsRef<str>,
{
    /// Build a TokenFrequency by counting every token in the slice
    fn from(tokens: &[T]) -> Self {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(tokens);
        tf
    }
}

impl From<Corpus> for TokenFrequency {
    /// Build a TokenFrequency from the per-token counts stored in a Corpus
    fn from(corpus: Corpus) -> Self {
        let mut tf = TokenFrequency::new();
        for entry in corpus.token_counts.iter() {
            let token = entry.key();
            let count = *entry.value();
            tf.set_token_count(token, count);
        }
        tf
    }
}

/// Implementation for retrieving information from TokenFrequency
impl TokenFrequency {
    /// Get a vector of all tokens and their counts
    ///
    /// # Returns
    /// * `Vec<(String, u64)>` - Vector of tokens and their counts
    #[inline]
    pub fn token_count_vector(&self) -> Vec<(String, u64)> {
        self.token_count.iter().map(|(token, &count)| {
            (token.clone(), count)
        }).collect()
    }

    /// Get a vector of all tokens and their counts (as `&str`)
    ///
    /// # Returns
    /// * `Vec<(&str, u64)>` - Vector of tokens and their counts
    #[inline]
    pub fn token_count_vector_ref_str(&self) -> Vec<(&str, u64)> {
        self.token_count.iter().map(|(token, &count)| {
            (token.as_str(), count)
        }).collect()
    }

    /// Get a hashmap of all tokens and their counts (as `&str`)
    ///
    /// # Returns
    /// * `HashMap<&str, u64>` - HashMap of tokens and their counts
    #[inline]
    pub fn token_count_hashmap_ref_str(&self) -> HashMap<&str, u64, RandomState> {
        self.token_count.iter().map(|(token, &count)| {
            (token.as_str(), count)
        }).collect()
    }

    /// Get the total count of all tokens
    ///
    /// # Returns
    /// * `u64` - Total token count
    #[inline]
    pub fn token_sum(&self) -> u64 {
        self.total_token_count
    }

    /// Get the occurrence count for a specific token
    ///
    /// # Arguments
    /// * `token` - Token
    ///
    /// # Returns
    /// * `u64` - Occurrence count for the token
    #[inline]
    pub fn token_count(&self, token: &str) -> u64 {
        *self.token_count.get(token).unwrap_or(&0)
    }

    /// Get the most frequent tokens
    /// If multiple tokens tie for the highest count, all of them are returned
    ///
    /// # Returns
    /// * `Vec<(String, u64)>` - Vector of most frequent tokens and their counts
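    ///
    /// # Examples
    /// Ties for the highest count are all returned:
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "b", "b", "c"]);
    /// let mut top = tf.most_frequent_tokens_vector();
    /// top.sort(); // HashMap iteration order is unspecified
    /// assert_eq!(top, vec![("a".to_string(), 2), ("b".to_string(), 2)]);
    /// ```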
    #[inline]
    pub fn most_frequent_tokens_vector(&self) -> Vec<(String, u64)> {
        if let Some(&max_count) = self.token_count.values().max() {
            self.token_count.iter()
                .filter(|&(_, &count)| count == max_count)
                .map(|(token, &count)| (token.clone(), count))
                .collect()
        } else {
            Vec::new()
        }
    }

    /// Get the count of the most frequent token
    ///
    /// # Returns
    /// * `u64` - Count of the most frequent token, or 0 if empty
    #[inline]
    pub fn most_frequent_token_count(&self) -> u64 {
        self.token_count.values().max().copied().unwrap_or(0)
    }

    /// Check if a token exists
    ///
    /// # Arguments
    /// * `token` - Token
    ///
    /// # Returns
    /// * `bool` - true if the token exists, false otherwise
    #[inline]
    pub fn contains_token(&self, token: &str) -> bool {
        self.token_count.contains_key(token)
    }

    /// Get the set of unique tokens
    ///
    /// # Returns
    /// * `Vec<String>` - Unique tokens
    #[inline]
    pub fn token_set(&self) -> Vec<String> {
        self.token_count.keys().cloned().collect()
    }

    /// Get the set of unique tokens (as `&str`)
    ///
    /// # Returns
    /// * `Vec<&str>` - Unique tokens
    #[inline]
    pub fn token_set_ref_str(&self) -> Vec<&str> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    /// Get the set of tokens as a HashSet
    ///
    /// # Returns
    /// * `HashSet<String>` - Set of tokens
    #[inline]
    pub fn token_hashset(&self) -> HashSet<String, RandomState> {
        self.token_count.keys().cloned().collect()
    }

    /// Get the set of tokens as a HashSet (as `&str`)
    ///
    /// # Returns
    /// * `HashSet<&str>` - Set of tokens
    #[inline]
    pub fn token_hashset_ref_str(&self) -> HashSet<&str, RandomState> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    /// Get the number of unique tokens
    ///
    /// # Returns
    /// * `usize` - Number of unique tokens
    #[inline]
    pub fn token_num(&self) -> usize {
        self.token_count.len()
    }

    /// Remove stop tokens
    ///
    /// # Arguments
    /// * `stop_tokens` - Slice of stop tokens to remove
    ///
    /// # Returns
    /// * `u64` - Total count of removed tokens
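    ///
    /// # Examples
    /// The returned value is the combined count of all removed tokens:
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["the", "the", "cat"]);
    /// let removed = tf.remove_stop_tokens(&["the"]);
    /// assert_eq!(removed, 2);
    /// assert_eq!(tf.token_sum(), 1);
    /// ```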
    #[inline]
    pub fn remove_stop_tokens(&mut self, stop_tokens: &[&str]) -> u64 {
        let mut removed_total_count: u64 = 0;
        for &stop_token in stop_tokens {
            if let Some(count) = self.token_count.remove(stop_token) {
                removed_total_count += count;
            }
        }
        self.total_token_count -= removed_total_count;
        removed_total_count
    }

    /// Remove tokens by a condition
    ///
    /// # Arguments
    /// * `condition` - Closure to determine which tokens to remove
    ///
    /// # Returns
    /// * `u64` - Total count of removed tokens
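    ///
    /// # Examples
    /// For instance, pruning rare tokens:
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "b"]);
    /// let removed = tf.remove_tokens_by(|_token, &count| count < 2);
    /// assert_eq!(removed, 1);
    /// assert_eq!(tf.token_count("a"), 2);
    /// ```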
    #[inline]
    pub fn remove_tokens_by<F>(&mut self, condition: F) -> u64
    where
        F: Fn(&str, &u64) -> bool,
    {
        let mut removed_total_count: u64 = 0;
        self.token_count.retain(|token, count| {
            if condition(token, count) {
                removed_total_count += *count;
                false
            } else {
                true
            }
        });
        self.total_token_count -= removed_total_count;
        removed_total_count
    }

    /// Get a vector of tokens sorted by frequency (descending)
    ///
    /// # Returns
    /// * `Vec<(String, u64)>` - Vector of tokens sorted by frequency
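    ///
    /// # Examples
    /// Most frequent tokens come first:
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["b", "a", "b"]);
    /// assert_eq!(
    ///     tf.sorted_frequency_vector(),
    ///     vec![("b".to_string(), 2), ("a".to_string(), 1)]
    /// );
    /// ```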
    #[inline]
    pub fn sorted_frequency_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| b.1.cmp(&a.1));
        token_list
    }

    /// Get a vector of tokens sorted by dictionary order (ascending)
    ///
    /// # Returns
    /// * `Vec<(String, u64)>` - Vector of tokens sorted by dictionary order
    #[inline]
    pub fn sorted_dict_order_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| a.0.cmp(&b.0));
        token_list
    }

    /// Calculate the diversity of tokens as the ratio of unique tokens
    /// to the total token count.
    /// 1.0 means every token is unique; values near 0.0 indicate heavy
    /// repetition. Returns 0.0 when no tokens have been added.
    ///
    /// # Returns
    /// * `f64` - Diversity of tokens
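    ///
    /// # Examples
    /// Two unique tokens out of four occurrences gives 0.5:
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "b", "b"]);
    /// assert_eq!(tf.unique_token_ratio(), 0.5);
    /// ```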
    #[inline]
    pub fn unique_token_ratio(&self) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        self.token_count.len() as f64 / self.total_token_count as f64
    }

    /// Get the probability distribution P(token) (owned String version)
    /// Returns an empty vector if total is 0
    #[inline]
    pub fn probability_vector(&self) -> Vec<(String, f64)> {
        if self.total_token_count == 0 {
            return Vec::new();
        }
        let total = self.total_token_count as f64;
        self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), (count as f64) / total))
            .collect()
    }

    /// Get the probability distribution P(token) (as `&str`)
    /// Returns an empty vector if total is 0
    #[inline]
    pub fn probability_vector_ref_str(&self) -> Vec<(&str, f64)> {
        if self.total_token_count == 0 {
            return Vec::new();
        }
        let total = self.total_token_count as f64;
        self.token_count
            .iter()
            .map(|(token, &count)| (token.as_str(), (count as f64) / total))
            .collect()
    }

    /// Get the probability P(token) for a specific token
    /// Returns 0.0 if total is 0
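    ///
    /// # Examples
    /// A token's probability is its count divided by the total:
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "b", "b"]);
    /// assert_eq!(tf.probability("a"), 0.5);
    /// assert_eq!(tf.probability("missing"), 0.0);
    /// ```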
    #[inline]
    pub fn probability(&self, token: &str) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        (self.token_count(token) as f64) / (self.total_token_count as f64)
    }

    /// Reset all counts
    #[inline]
    pub fn clear(&mut self) {
        self.token_count.clear();
        self.total_token_count = 0;
    }

    /// Shrink internal storage to fit current size
    #[inline]
    pub fn shrink_to_fit(&mut self) {
        self.token_count.shrink_to_fit();
    }
}