tf_idf_vectorizer/vectorizer/token.rs

use std::collections::{HashMap, HashSet};
use ahash::RandomState;
use serde::{Deserialize, Serialize};

use crate::Corpus;


/// TokenFrequency
/// Manages token occurrence frequencies: counts how many times each token
/// appears and keeps a running total of all tokens.
///
/// # Examples
/// ```
/// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
/// let mut token_freq = TokenFrequency::new();
/// token_freq.add_token("token1");
/// token_freq.add_token("token2");
/// token_freq.add_token("token1");
///
/// assert_eq!(token_freq.token_count("token1"), 2);
/// ```
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TokenFrequency {
    token_count: HashMap<String, u64, RandomState>,
    total_token_count: u64,
}

/// Implementation for adding and removing tokens
impl TokenFrequency {
    /// Create a new TokenFrequency
    pub fn new() -> Self {
        TokenFrequency {
            token_count: HashMap::with_hasher(RandomState::new()),
            total_token_count: 0,
        }
    }

    /// Add a token
    ///
    /// # Arguments
    /// * `token` - Token to add
    #[inline]
    pub fn add_token(&mut self, token: &str) -> &mut Self {
        let count = self.token_count.entry(token.to_string()).or_insert(0);
        *count += 1;
        self.total_token_count += 1;
        self
    }

    /// Add multiple tokens
    ///
    /// # Arguments
    /// * `tokens` - Slice of tokens to add
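    ///
    /// # Examples
    /// A minimal usage sketch (assuming the module path from the struct-level example):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "b", "a"]);
    /// assert_eq!(tf.token_count("a"), 2);
    /// assert_eq!(tf.token_sum(), 3);
    /// ```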
    #[inline]
    pub fn add_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where T: AsRef<str>
    {
        for token in tokens {
            let token_str = token.as_ref();
            self.add_token(token_str);
        }
        self
    }

    /// Subtract a token
    ///
    /// # Arguments
    /// * `token` - Token to subtract
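    ///
    /// # Examples
    /// A small sketch of the decrement-and-remove behavior (module path assumed as above):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a"]);
    /// tf.sub_token("a");
    /// assert_eq!(tf.token_count("a"), 1);
    /// tf.sub_token("a");
    /// assert!(!tf.contains_token("a"));
    /// ```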
    #[inline]
    pub fn sub_token(&mut self, token: &str) -> &mut Self {
        if let Some(count) = self.token_count.get_mut(token) {
            // Stored counts are always >= 1; a count of 1 means the entry goes away.
            if *count > 1 {
                *count -= 1;
            } else {
                self.token_count.remove(token);
            }
            self.total_token_count -= 1;
        }
        self
    }

    /// Subtract multiple tokens
    ///
    /// # Arguments
    /// * `tokens` - Slice of tokens to subtract
    #[inline]
    pub fn sub_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where T: AsRef<str>
    {
        for token in tokens {
            let token_str = token.as_ref();
            self.sub_token(token_str);
        }
        self
    }

    /// Set the occurrence count for a token
    ///
    /// # Arguments
    /// * `token` - Token
    /// * `count` - Occurrence count
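    ///
    /// # Examples
    /// A sketch showing that the running total tracks the new count
    /// (module path assumed as above):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.set_token_count("a", 5);
    /// assert_eq!(tf.token_sum(), 5);
    /// tf.set_token_count("a", 2);
    /// assert_eq!(tf.token_sum(), 2);
    /// tf.set_token_count("a", 0);
    /// assert!(!tf.contains_token("a"));
    /// assert_eq!(tf.token_sum(), 0);
    /// ```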
    pub fn set_token_count(&mut self, token: &str, count: u64) -> &mut Self {
        if count == 0 {
            // Removing the entry must also remove its contribution to the total.
            if let Some(removed) = self.token_count.remove(token) {
                self.total_token_count -= removed;
            }
        } else {
            let current_count = self.token_count.entry(token.to_string()).or_insert(0);
            // Rebuild the total rather than adding `count - *current_count`,
            // which would underflow u64 whenever the count is lowered.
            self.total_token_count = self.total_token_count - *current_count + count;
            *current_count = count;
        }
        self
    }

    /// Merge the counts from another TokenFrequency into this one
    ///
    /// # Arguments
    /// * `other` - Another TokenFrequency to merge with
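    ///
    /// # Examples
    /// A quick sketch of merging two frequency tables (module path assumed as above):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut a = TokenFrequency::new();
    /// a.add_tokens(&["x", "y"]);
    /// let mut b = TokenFrequency::new();
    /// b.add_tokens(&["y", "y"]);
    /// a.add_tokens_from_freq(&b);
    /// assert_eq!(a.token_count("y"), 3);
    /// assert_eq!(a.token_sum(), 4);
    /// ```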
    pub fn add_tokens_from_freq(&mut self, other: &TokenFrequency) -> &mut Self {
        for (token, &count) in &other.token_count {
            let entry = self.token_count.entry(token.clone()).or_insert(0);
            *entry += count;
            self.total_token_count += count;
        }
        self
    }

    /// Scale all token counts by a scalar, rounding to the nearest integer
    ///
    /// # Arguments
    /// * `scalar` - Scalar to scale by
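    ///
    /// # Examples
    /// A sketch of doubling every count (module path assumed as above):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "b"]);
    /// tf.scale(2.0);
    /// assert_eq!(tf.token_count("a"), 4);
    /// assert_eq!(tf.token_sum(), 6);
    /// ```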
    pub fn scale(&mut self, scalar: f64) -> &mut Self {
        let mut total_count = 0;
        self.token_count.iter_mut().for_each(|(_, count)| {
            *count = ((*count as f64) * scalar).round() as u64;
            total_count += *count;
        });
        // Rounding can produce zeros; drop them so zero-count tokens never linger in the map.
        self.token_count.retain(|_, count| *count > 0);
        self.total_token_count = total_count;
        self
    }
}

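/// Build a TokenFrequency directly from a slice of tokens.
///
/// # Examples
/// A minimal sketch (module path assumed as in the struct-level example):
/// ```
/// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
/// let tf = TokenFrequency::from(["a", "b", "a"].as_slice());
/// assert_eq!(tf.token_count("a"), 2);
/// ```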
impl<T> From<&[T]> for TokenFrequency
where
    T: AsRef<str>,
{
    fn from(tokens: &[T]) -> Self {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(tokens);
        tf
    }
}

impl From<Corpus> for TokenFrequency {
    fn from(corpus: Corpus) -> Self {
        let mut tf = TokenFrequency::new();
        for entry in corpus.token_counts.iter() {
            let token = entry.key();
            let count = *entry.value();
            tf.set_token_count(token, count);
        }
        tf
    }
}

/// Implementation for retrieving information from TokenFrequency
impl TokenFrequency {
    /// Get an iterator over all tokens and their counts
    ///
    /// # Returns
    /// * `impl Iterator<Item=(&str, u64)>` - Iterator over tokens and their counts
    #[inline]
    pub fn iter(&self) -> impl Iterator<Item=(&str, u64)> {
        self.token_count.iter().map(|(token, &count)| {
            (token.as_str(), count)
        })
    }

    /// Get a vector of all tokens and their counts
    ///
    /// # Returns
    /// * `Vec<(String, u64)>` - Vector of tokens and their counts
    #[inline]
    pub fn token_count_vector(&self) -> Vec<(String, u64)> {
        self.token_count.iter().map(|(token, &count)| {
            (token.clone(), count)
        }).collect()
    }

    /// Get a vector of all tokens and their counts (as &str)
    ///
    /// # Returns
    /// * `Vec<(&str, u64)>` - Vector of tokens and their counts
    #[inline]
    pub fn token_count_vector_ref_str(&self) -> Vec<(&str, u64)> {
        self.token_count.iter().map(|(token, &count)| {
            (token.as_str(), count)
        }).collect()
    }

    /// Get a hashmap of all tokens and their counts (as &str)
    ///
    /// # Returns
    /// * `HashMap<&str, u64>` - HashMap of tokens and their counts
    #[inline]
    pub fn token_count_hashmap_ref_str(&self) -> HashMap<&str, u64, RandomState> {
        self.token_count.iter().map(|(token, &count)| {
            (token.as_str(), count)
        }).collect()
    }

    /// Get the total count of all tokens
    ///
    /// # Returns
    /// * `u64` - Total token count
    #[inline]
    pub fn token_sum(&self) -> u64 {
        self.total_token_count
    }

    /// Get the occurrence count for a specific token
    ///
    /// # Arguments
    /// * `token` - Token
    ///
    /// # Returns
    /// * `u64` - Occurrence count for the token
    #[inline]
    pub fn token_count(&self, token: &str) -> u64 {
        *self.token_count.get(token).unwrap_or(&0)
    }

    /// Get the most frequent tokens
    /// If multiple tokens share the maximum count, all of them are returned
    ///
    /// # Returns
    /// * `Vec<(String, u64)>` - Vector of the most frequent tokens and their counts
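    ///
    /// # Examples
    /// A sketch of tie handling; the result is sorted before asserting, since
    /// map iteration order is arbitrary (module path assumed as above):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "b", "b", "c"]);
    /// let mut top = tf.most_frequent_tokens_vector();
    /// top.sort();
    /// assert_eq!(top, vec![("a".to_string(), 2), ("b".to_string(), 2)]);
    /// ```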
    #[inline]
    pub fn most_frequent_tokens_vector(&self) -> Vec<(String, u64)> {
        if let Some(&max_count) = self.token_count.values().max() {
            self.token_count.iter()
                .filter(|&(_, &count)| count == max_count)
                .map(|(token, &count)| (token.clone(), count))
                .collect()
        } else {
            Vec::new()
        }
    }

    /// Get the count of the most frequent token
    ///
    /// # Returns
    /// * `u64` - Count of the most frequent token, or 0 if empty
    #[inline]
    pub fn most_frequent_token_count(&self) -> u64 {
        self.token_count.values().max().copied().unwrap_or(0)
    }

    /// Check if a token exists
    ///
    /// # Arguments
    /// * `token` - Token
    ///
    /// # Returns
    /// * `bool` - true if the token exists, false otherwise
    #[inline]
    pub fn contains_token(&self, token: &str) -> bool {
        self.token_count.contains_key(token)
    }

    /// Get an iterator over the set of tokens
    ///
    /// # Returns
    /// * `impl Iterator<Item=&str>` - Iterator over the set of tokens
    #[inline]
    pub fn token_set_iter(&self) -> impl Iterator<Item=&str> {
        self.token_count.keys().map(|s| s.as_str())
    }

    /// Get the set of tokens
    ///
    /// # Returns
    /// * `Vec<String>` - Set of tokens
    #[inline]
    pub fn token_set(&self) -> Vec<String> {
        self.token_count.keys().cloned().collect()
    }

    /// Get the set of tokens (as &str)
    ///
    /// # Returns
    /// * `Vec<&str>` - Set of tokens
    #[inline]
    pub fn token_set_ref_str(&self) -> Vec<&str> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    /// Get the set of tokens as a HashSet
    ///
    /// # Returns
    /// * `HashSet<String>` - Set of tokens
    #[inline]
    pub fn token_hashset(&self) -> HashSet<String, RandomState> {
        self.token_count.keys().cloned().collect()
    }

    /// Get the set of tokens as a HashSet (as &str)
    ///
    /// # Returns
    /// * `HashSet<&str>` - Set of tokens
    #[inline]
    pub fn token_hashset_ref_str(&self) -> HashSet<&str, RandomState> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    /// Get the number of unique tokens
    ///
    /// # Returns
    /// * `usize` - Number of unique tokens
    #[inline]
    pub fn token_num(&self) -> usize {
        self.token_count.len()
    }

    /// Remove stop tokens
    ///
    /// # Arguments
    /// * `stop_tokens` - Slice of stop tokens to remove
    ///
    /// # Returns
    /// * `u64` - Total count of removed tokens
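    ///
    /// # Examples
    /// A sketch of stripping stop words (module path assumed as above):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["the", "the", "cat"]);
    /// let removed = tf.remove_stop_tokens(&["the"]);
    /// assert_eq!(removed, 2);
    /// assert_eq!(tf.token_sum(), 1);
    /// ```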
    #[inline]
    pub fn remove_stop_tokens(&mut self, stop_tokens: &[&str]) -> u64 {
        let mut removed_total_count: u64 = 0;
        for &stop_token in stop_tokens {
            if let Some(count) = self.token_count.remove(stop_token) {
                removed_total_count += count;
            }
        }
        self.total_token_count -= removed_total_count;
        removed_total_count
    }

    /// Remove tokens matching a condition
    ///
    /// # Arguments
    /// * `condition` - Closure deciding which tokens to remove
    ///
    /// # Returns
    /// * `u64` - Total count of removed tokens
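    ///
    /// # Examples
    /// A sketch that prunes rare tokens (module path assumed as above):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "b"]);
    /// let removed = tf.remove_tokens_by(|_token, &count| count < 2);
    /// assert_eq!(removed, 1);
    /// assert!(!tf.contains_token("b"));
    /// ```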
    #[inline]
    pub fn remove_tokens_by<F>(&mut self, condition: F) -> u64
    where
        F: Fn(&str, &u64) -> bool,
    {
        let mut removed_total_count: u64 = 0;
        self.token_count.retain(|token, count| {
            if condition(token, count) {
                removed_total_count += *count;
                false
            } else {
                true
            }
        });
        self.total_token_count -= removed_total_count;
        removed_total_count
    }

    /// Get a vector of tokens sorted by frequency (descending)
    ///
    /// # Returns
    /// * `Vec<(String, u64)>` - Vector of tokens sorted by frequency
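    ///
    /// # Examples
    /// A sketch showing descending order (module path assumed as above):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "a", "b"]);
    /// let sorted = tf.sorted_frequency_vector();
    /// assert_eq!(sorted[0], ("a".to_string(), 3));
    /// ```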
    #[inline]
    pub fn sorted_frequency_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| b.1.cmp(&a.1));
        token_list
    }

    /// Get a vector of tokens sorted by dictionary order (ascending)
    ///
    /// # Returns
    /// * `Vec<(String, u64)>` - Vector of tokens sorted by dictionary order
    #[inline]
    pub fn sorted_dict_order_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| a.0.cmp(&b.0));
        token_list
    }

    /// Calculate the lexical diversity of the tokens: the ratio of unique
    /// tokens to the total token count. 1.0 means every token is unique;
    /// values near 0.0 mean heavy repetition. Returns 0.0 when empty.
    ///
    /// # Returns
    /// * `f64` - Diversity of tokens
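    ///
    /// # Examples
    /// A sketch of the ratio (module path assumed as above):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "b", "b"]);
    /// // 2 unique tokens / 4 total tokens
    /// assert_eq!(tf.unique_token_ratio(), 0.5);
    /// ```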
    #[inline]
    pub fn unique_token_ratio(&self) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        self.token_count.len() as f64 / self.total_token_count as f64
    }

    /// Get the probability distribution P(token) (owned String version)
    /// Returns an empty vector if total is 0
    #[inline]
    pub fn probability_vector(&self) -> Vec<(String, f64)> {
        if self.total_token_count == 0 {
            return Vec::new();
        }
        let total = self.total_token_count as f64;
        self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), (count as f64) / total))
            .collect()
    }

    /// Get the probability distribution P(token) (as &str)
    /// Returns an empty vector if total is 0
    #[inline]
    pub fn probability_vector_ref_str(&self) -> Vec<(&str, f64)> {
        if self.total_token_count == 0 {
            return Vec::new();
        }
        let total = self.total_token_count as f64;
        self.token_count
            .iter()
            .map(|(token, &count)| (token.as_str(), (count as f64) / total))
            .collect()
    }

    /// Get the probability P(token) for a specific token
    /// Returns 0.0 if total is 0
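    ///
    /// # Examples
    /// A sketch of a single-token probability (module path assumed as above):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "b"]);
    /// assert!((tf.probability("a") - 2.0 / 3.0).abs() < 1e-12);
    /// ```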
    #[inline]
    pub fn probability(&self, token: &str) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        (self.token_count(token) as f64) / (self.total_token_count as f64)
    }

    /// Reset all counts
    #[inline]
    pub fn clear(&mut self) {
        self.token_count.clear();
        self.total_token_count = 0;
    }

    /// Shrink internal storage to fit current size
    #[inline]
    pub fn shrink_to_fit(&mut self) {
        self.token_count.shrink_to_fit();
    }
}