tf_idf_vectorizer/vectorizer/token.rs

use std::collections::{HashMap, HashSet};
use ahash::RandomState;
use serde::{Deserialize, Serialize};


/// TokenFrequency struct
/// Manages the frequency of token occurrences.
/// Counts the number of times each token appears.
///
/// # Examples
/// ```
/// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
/// let mut token_freq = TokenFrequency::new();
/// token_freq.add_token("token1");
/// token_freq.add_token("token2");
/// token_freq.add_token("token1");
///
/// assert_eq!(token_freq.token_count("token1"), 2);
/// ```
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TokenFrequency {
    token_count: HashMap<String, u64, RandomState>,
    total_token_count: u64,
}

/// Implementation for adding and removing tokens
impl TokenFrequency {
    /// Create a new TokenFrequency
    pub fn new() -> Self {
        TokenFrequency {
            token_count: HashMap::with_hasher(RandomState::new()),
            total_token_count: 0,
        }
    }

    /// Add a token
    ///
    /// # Arguments
    /// * `token` - Token to add
    #[inline]
    pub fn add_token(&mut self, token: &str) -> &mut Self {
        let count = self.token_count.entry(token.to_string()).or_insert(0);
        *count += 1;
        self.total_token_count += 1;
        self
    }

    /// Add multiple tokens
    ///
    /// # Arguments
    /// * `tokens` - Slice of tokens to add
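    ///
    /// Counts accumulate across occurrences, and the total tracks every addition:
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut token_freq = TokenFrequency::new();
    /// token_freq.add_tokens(&["a", "b", "a"]);
    /// assert_eq!(token_freq.token_count("a"), 2);
    /// assert_eq!(token_freq.token_sum(), 3);
    /// ```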
    #[inline]
    pub fn add_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where T: AsRef<str>
    {
        for token in tokens {
            let token_str = token.as_ref();
            self.add_token(token_str);
        }
        self
    }

    /// Subtract a token
    ///
    /// # Arguments
    /// * `token` - Token to subtract
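    ///
    /// A token whose count reaches zero is removed entirely:
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut token_freq = TokenFrequency::new();
    /// token_freq.add_token("a");
    /// token_freq.sub_token("a");
    /// assert!(!token_freq.contains_token("a"));
    /// assert_eq!(token_freq.token_sum(), 0);
    /// ```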
    #[inline]
    pub fn sub_token(&mut self, token: &str) -> &mut Self {
        if let Some(count) = self.token_count.get_mut(token) {
            if *count > 1 {
                *count -= 1;
                self.total_token_count -= 1;
            } else if *count == 1 {
                self.token_count.remove(token);
                self.total_token_count -= 1;
            }
        }
        self
    }

    /// Subtract multiple tokens
    ///
    /// # Arguments
    /// * `tokens` - Slice of tokens to subtract
    #[inline]
    pub fn sub_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where T: AsRef<str>
    {
        for token in tokens {
            let token_str = token.as_ref();
            self.sub_token(token_str);
        }
        self
    }

    /// Set the occurrence count for a token
    ///
    /// # Arguments
    /// * `token` - Token
    /// * `count` - Occurrence count
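    ///
    /// Setting a count of 0 removes the token and deducts it from the total:
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut token_freq = TokenFrequency::new();
    /// token_freq.set_token_count("a", 5);
    /// assert_eq!(token_freq.token_count("a"), 5);
    /// assert_eq!(token_freq.token_sum(), 5);
    /// token_freq.set_token_count("a", 0);
    /// assert!(!token_freq.contains_token("a"));
    /// assert_eq!(token_freq.token_sum(), 0);
    /// ```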
    pub fn set_token_count(&mut self, token: &str, count: u64) -> &mut Self {
        if count == 0 {
            // Keep the total consistent when the token is removed.
            if let Some(removed) = self.token_count.remove(token) {
                self.total_token_count -= removed;
            }
        } else {
            let current_count = self.token_count.entry(token.to_string()).or_insert(0);
            // Add before subtracting so a shrinking count cannot underflow u64.
            self.total_token_count = self.total_token_count + count - *current_count;
            *current_count = count;
        }
        self
    }
}

/// Implementation for retrieving information from TokenFrequency
impl TokenFrequency {
    /// Get a vector of all tokens and their counts
    ///
    /// # Returns
    /// * `Vec<(String, u64)>` - Vector of tokens and their counts
    #[inline]
    pub fn token_count_vector(&self) -> Vec<(String, u64)> {
        self.token_count.iter().map(|(token, &count)| {
            (token.clone(), count)
        }).collect()
    }

    /// Get a vector of all tokens and their counts (as &str)
    ///
    /// # Returns
    /// * `Vec<(&str, u64)>` - Vector of tokens and their counts
    #[inline]
    pub fn token_count_vector_ref_str(&self) -> Vec<(&str, u64)> {
        self.token_count.iter().map(|(token, &count)| {
            (token.as_str(), count)
        }).collect()
    }

    /// Get a hashmap of all tokens and their counts (as &str)
    ///
    /// # Returns
    /// * `HashMap<&str, u64>` - HashMap of tokens and their counts
    #[inline]
    pub fn token_count_hashmap_ref_str(&self) -> HashMap<&str, u64, RandomState> {
        self.token_count.iter().map(|(token, &count)| {
            (token.as_str(), count)
        }).collect()
    }

    /// Get the total count of all tokens
    ///
    /// # Returns
    /// * `u64` - Total token count
    #[inline]
    pub fn token_sum(&self) -> u64 {
        self.total_token_count
    }

    /// Get the occurrence count for a specific token
    ///
    /// # Arguments
    /// * `token` - Token
    ///
    /// # Returns
    /// * `u64` - Occurrence count for the token
    #[inline]
    pub fn token_count(&self, token: &str) -> u64 {
        *self.token_count.get(token).unwrap_or(&0)
    }

    /// Get the most frequent tokens
    /// If multiple tokens have the same count, all are returned
    ///
    /// # Returns
    /// * `Vec<(String, u64)>` - Vector of most frequent tokens and their counts
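    ///
    /// Ties are all returned; the order is unspecified, so sort before comparing:
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut token_freq = TokenFrequency::new();
    /// token_freq.add_tokens(&["a", "a", "b", "b", "c"]);
    /// let mut top = token_freq.most_frequent_tokens_vector();
    /// top.sort();
    /// assert_eq!(top, vec![("a".to_string(), 2), ("b".to_string(), 2)]);
    /// ```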
    #[inline]
    pub fn most_frequent_tokens_vector(&self) -> Vec<(String, u64)> {
        if let Some(&max_count) = self.token_count.values().max() {
            self.token_count.iter()
                .filter(|&(_, &count)| count == max_count)
                .map(|(token, &count)| (token.clone(), count))
                .collect()
        } else {
            Vec::new()
        }
    }

    /// Get the count of the most frequent token
    ///
    /// # Returns
    /// * `u64` - Count of the most frequent token
    #[inline]
    pub fn most_frequent_token_count(&self) -> u64 {
        self.token_count.values().max().copied().unwrap_or(0)
    }

    /// Check if a token exists
    ///
    /// # Arguments
    /// * `token` - Token
    ///
    /// # Returns
    /// * `bool` - true if the token exists, false otherwise
    #[inline]
    pub fn contains_token(&self, token: &str) -> bool {
        self.token_count.contains_key(token)
    }

    /// Get the set of tokens
    ///
    /// # Returns
    /// * `Vec<String>` - Set of tokens
    #[inline]
    pub fn token_set(&self) -> Vec<String> {
        self.token_count.keys().cloned().collect()
    }

    /// Get the set of tokens (as &str)
    ///
    /// # Returns
    /// * `Vec<&str>` - Set of tokens
    #[inline]
    pub fn token_set_ref_str(&self) -> Vec<&str> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    /// Get the set of tokens as a HashSet
    ///
    /// # Returns
    /// * `HashSet<String>` - Set of tokens
    #[inline]
    pub fn token_hashset(&self) -> HashSet<String, RandomState> {
        self.token_count.keys().cloned().collect()
    }

    /// Get the set of tokens as a HashSet (as &str)
    ///
    /// # Returns
    /// * `HashSet<&str>` - Set of tokens
    #[inline]
    pub fn token_hashset_ref_str(&self) -> HashSet<&str, RandomState> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    /// Get the number of unique tokens
    ///
    /// # Returns
    /// * `usize` - Number of unique tokens
    #[inline]
    pub fn token_num(&self) -> usize {
        self.token_count.len()
    }

    /// Remove stop tokens
    ///
    /// # Arguments
    /// * `stop_tokens` - Slice of stop tokens to remove
    ///
    /// # Returns
    /// * `u64` - Total count of removed tokens
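    ///
    /// Removed occurrences are also deducted from the total:
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut token_freq = TokenFrequency::new();
    /// token_freq.add_tokens(&["the", "the", "cat"]);
    /// assert_eq!(token_freq.remove_stop_tokens(&["the"]), 2);
    /// assert_eq!(token_freq.token_sum(), 1);
    /// ```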
    #[inline]
    pub fn remove_stop_tokens(&mut self, stop_tokens: &[&str]) -> u64 {
        let mut removed_total_count: u64 = 0;
        for &stop_token in stop_tokens {
            if let Some(count) = self.token_count.remove(stop_token) {
                removed_total_count += count;
            }
        }
        self.total_token_count -= removed_total_count;
        removed_total_count
    }

    /// Remove tokens by a condition
    ///
    /// # Arguments
    /// * `condition` - Closure to determine which tokens to remove
    ///
    /// # Returns
    /// * `u64` - Total count of removed tokens
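    ///
    /// For example, pruning every token that appears only once:
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut token_freq = TokenFrequency::new();
    /// token_freq.add_tokens(&["a", "a", "b"]);
    /// let removed = token_freq.remove_tokens_by(|_token, &count| count < 2);
    /// assert_eq!(removed, 1);
    /// assert!(!token_freq.contains_token("b"));
    /// assert_eq!(token_freq.token_sum(), 2);
    /// ```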
    #[inline]
    pub fn remove_tokens_by<F>(&mut self, condition: F) -> u64
    where
        F: Fn(&str, &u64) -> bool,
    {
        let mut removed_total_count: u64 = 0;
        self.token_count.retain(|token, count| {
            if condition(token, count) {
                removed_total_count += *count;
                false
            } else {
                true
            }
        });
        self.total_token_count -= removed_total_count;
        removed_total_count
    }

    /// Get a vector of tokens sorted by frequency (descending)
    ///
    /// # Returns
    /// * `Vec<(String, u64)>` - Vector of tokens sorted by frequency
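    ///
    /// The highest count comes first:
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut token_freq = TokenFrequency::new();
    /// token_freq.add_tokens(&["b", "a", "a"]);
    /// assert_eq!(token_freq.sorted_frequency_vector()[0], ("a".to_string(), 2));
    /// ```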
    #[inline]
    pub fn sorted_frequency_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| b.1.cmp(&a.1));
        token_list
    }

    /// Get a vector of tokens sorted by dictionary order (ascending)
    ///
    /// # Returns
    /// * `Vec<(String, u64)>` - Vector of tokens sorted by dictionary order
    #[inline]
    pub fn sorted_dict_order_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| a.0.cmp(&b.0));
        token_list
    }

    /// Calculate the diversity of tokens: the ratio of unique tokens to total
    /// occurrences. 1.0 means every occurrence is a distinct token; values near
    /// 0.0 mean heavy repetition. Returns 0.0 when no tokens have been added.
    ///
    /// # Returns
    /// * `f64` - Diversity of tokens
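    ///
    /// # Examples
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut token_freq = TokenFrequency::new();
    /// token_freq.add_tokens(&["a", "a", "b", "c"]);
    /// // 3 unique tokens out of 4 occurrences.
    /// assert_eq!(token_freq.unique_token_ratio(), 0.75);
    /// ```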
    #[inline]
    pub fn unique_token_ratio(&self) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        self.token_count.len() as f64 / self.total_token_count as f64
    }

    /// Get the probability distribution P(token) (owned String version)
    /// Returns an empty vector if total is 0
    #[inline]
    pub fn probability_vector(&self) -> Vec<(String, f64)> {
        if self.total_token_count == 0 {
            return Vec::new();
        }
        let total = self.total_token_count as f64;
        self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), (count as f64) / total))
            .collect()
    }

    /// Get the probability distribution P(token) (as &str)
    /// Returns an empty vector if total is 0
    #[inline]
    pub fn probability_vector_ref_str(&self) -> Vec<(&str, f64)> {
        if self.total_token_count == 0 {
            return Vec::new();
        }
        let total = self.total_token_count as f64;
        self.token_count
            .iter()
            .map(|(token, &count)| (token.as_str(), (count as f64) / total))
            .collect()
    }

    /// Get the probability P(token) for a specific token
    /// Returns 0.0 if total is 0
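    ///
    /// # Examples
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut token_freq = TokenFrequency::new();
    /// token_freq.add_tokens(&["a", "a", "b", "b"]);
    /// assert_eq!(token_freq.probability("a"), 0.5);
    /// ```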
    #[inline]
    pub fn probability(&self, token: &str) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        (self.token_count(token) as f64) / (self.total_token_count as f64)
    }

    /// Reset all counts
    #[inline]
    pub fn clear(&mut self) {
        self.token_count.clear();
        self.total_token_count = 0;
    }

    /// Shrink internal storage to fit current size
    #[inline]
    pub fn shrink_to_fit(&mut self) {
        self.token_count.shrink_to_fit();
    }
}
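
#[cfg(test)]
mod tests {
    // A few sanity checks mirroring the doc examples above: they assert
    // that total_token_count stays in sync with the per-token counts
    // through every mutation path (add, subtract, set).
    use super::*;

    #[test]
    fn totals_stay_consistent() {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(&["a", "a", "b"]);
        assert_eq!(tf.token_sum(), 3);

        tf.sub_token("b");
        assert_eq!(tf.token_sum(), 2);
        assert!(!tf.contains_token("b"));

        tf.set_token_count("a", 1); // shrinking an existing count
        assert_eq!(tf.token_sum(), 1);

        tf.set_token_count("a", 0); // removal also updates the total
        assert_eq!(tf.token_sum(), 0);
        assert_eq!(tf.token_num(), 0);
    }
}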