tf_idf_vectorizer/vectorizer/token.rs

use std::collections::{HashMap, HashSet};
use ahash::RandomState;
use serde::{Deserialize, Serialize};


/// TokenFrequency struct
/// Tracks how many times each token occurs, along with the running total of all tokens.
///
/// # Examples
/// ```
/// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
/// let mut token_freq = TokenFrequency::new();
/// token_freq.add_token("token1");
/// token_freq.add_token("token2");
/// token_freq.add_token("token1");
///
/// assert_eq!(token_freq.token_count("token1"), 2);
/// ```
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TokenFrequency {
    token_count: HashMap<String, u64, RandomState>,
    total_token_count: u64,
}

/// Implementation for adding and removing tokens
impl TokenFrequency {
    /// Create a new TokenFrequency
    pub fn new() -> Self {
        TokenFrequency {
            token_count: HashMap::with_hasher(RandomState::new()),
            total_token_count: 0,
        }
    }

    /// Add a token
    ///
    /// # Arguments
    /// * `token` - Token to add
    #[inline]
    pub fn add_token(&mut self, token: &str) -> &mut Self {
        let count = self.token_count.entry(token.to_string()).or_insert(0);
        *count += 1;
        self.total_token_count += 1;
        self
    }

    /// Add multiple tokens
    ///
    /// # Arguments
    /// * `tokens` - Slice of tokens to add
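    ///
    /// # Examples
    /// A minimal usage sketch (the import path assumes the crate is named
    /// `tf_idf_vectorizer`, as the file path suggests):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "b", "a"]);
    /// assert_eq!(tf.token_count("a"), 2);
    /// assert_eq!(tf.token_sum(), 3);
    /// ```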
    #[inline]
    pub fn add_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where T: AsRef<str>
    {
        for token in tokens {
            let token_str = token.as_ref();
            self.add_token(token_str);
        }
        self
    }

    /// Subtract a token
    ///
    /// # Arguments
    /// * `token` - Token to subtract
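    ///
    /// # Examples
    /// A short sketch of the removal behavior (import path assumes the crate
    /// is named `tf_idf_vectorizer`):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "b"]);
    /// tf.sub_token("a");
    /// assert_eq!(tf.token_count("a"), 1);
    /// tf.sub_token("b"); // count reaches 0, so the entry is removed
    /// assert!(!tf.contains_token("b"));
    /// ```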
    #[inline]
    pub fn sub_token(&mut self, token: &str) -> &mut Self {
        if let Some(count) = self.token_count.get_mut(token) {
            if *count > 1 {
                *count -= 1;
                self.total_token_count -= 1;
            } else if *count == 1 {
                self.token_count.remove(token);
                self.total_token_count -= 1;
            }
        }
        self
    }

    /// Subtract multiple tokens
    ///
    /// # Arguments
    /// * `tokens` - Slice of tokens to subtract
    #[inline]
    pub fn sub_tokens<T>(&mut self, tokens: &[T]) -> &mut Self
    where T: AsRef<str>
    {
        for token in tokens {
            let token_str = token.as_ref();
            self.sub_token(token_str);
        }
        self
    }

    /// Set the occurrence count for a token
    ///
    /// # Arguments
    /// * `token` - Token
    /// * `count` - Occurrence count
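    ///
    /// # Examples
    /// A sketch of both branches, including `count == 0` removing the entry
    /// (import path assumes the crate is named `tf_idf_vectorizer`):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.set_token_count("a", 5);
    /// assert_eq!(tf.token_count("a"), 5);
    /// assert_eq!(tf.token_sum(), 5);
    /// tf.set_token_count("a", 0); // removes the entry entirely
    /// assert!(!tf.contains_token("a"));
    /// assert_eq!(tf.token_sum(), 0);
    /// ```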
    pub fn set_token_count(&mut self, token: &str, count: u64) -> &mut Self {
        if count == 0 {
            // Removing the entry must also subtract its old count from the total.
            if let Some(old_count) = self.token_count.remove(token) {
                self.total_token_count -= old_count;
            }
        } else {
            let current_count = self.token_count.entry(token.to_string()).or_insert(0);
            // Written this way to avoid u64 underflow when count < *current_count.
            self.total_token_count = self.total_token_count - *current_count + count;
            *current_count = count;
        }
        self
    }

    /// Merge with another TokenFrequency
    /// # Arguments
    /// * `other` - Another TokenFrequency to merge with
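    ///
    /// # Examples
    /// A merge sketch (import path assumes the crate is named
    /// `tf_idf_vectorizer`):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut a = TokenFrequency::new();
    /// a.add_tokens(&["x", "y"]);
    /// let mut b = TokenFrequency::new();
    /// b.add_tokens(&["y", "y"]);
    /// a.add_tokens_from_freq(&b);
    /// assert_eq!(a.token_count("y"), 3);
    /// assert_eq!(a.token_sum(), 4);
    /// ```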
    pub fn add_tokens_from_freq(&mut self, other: &TokenFrequency) -> &mut Self {
        for (token, &count) in &other.token_count {
            let entry = self.token_count.entry(token.clone()).or_insert(0);
            *entry += count;
            self.total_token_count += count;
        }
        self
    }

    /// Scale the token counts by a scalar
    /// Counts are rounded to the nearest integer; tokens whose scaled count
    /// rounds to zero are removed.
    /// # Arguments
    /// * `scalar` - Scalar to scale by
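    ///
    /// # Examples
    /// A scaling sketch (import path assumes the crate is named
    /// `tf_idf_vectorizer`):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "a", "a", "b"]);
    /// tf.scale(0.4);
    /// assert_eq!(tf.token_count("a"), 2); // round(4 * 0.4) = 2
    /// assert!(!tf.contains_token("b"));   // round(1 * 0.4) = 0, entry dropped
    /// assert_eq!(tf.token_sum(), 2);
    /// ```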
    pub fn scale(&mut self, scalar: f64) -> &mut Self {
        let mut total_count = 0;
        self.token_count.retain(|_, count| {
            *count = ((*count as f64) * scalar).round() as u64;
            total_count += *count;
            // Keep the entry only if the scaled count is still positive,
            // preserving the invariant that stored counts are never zero.
            *count > 0
        });
        self.total_token_count = total_count;
        self
    }
}

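/// Build a `TokenFrequency` from a slice of tokens. A conversion sketch
/// (import path assumes the crate is named `tf_idf_vectorizer`):
/// ```
/// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
/// let tf = TokenFrequency::from(&["a", "b", "a"][..]);
/// assert_eq!(tf.token_count("a"), 2);
/// ```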
impl<T> From<&[T]> for TokenFrequency
where
    T: AsRef<str>,
{
    fn from(tokens: &[T]) -> Self {
        let mut tf = TokenFrequency::new();
        tf.add_tokens(tokens);
        tf
    }
}

/// Implementation for retrieving information from TokenFrequency
impl TokenFrequency {
    /// Get a vector of all tokens and their counts
    ///
    /// # Returns
    /// * `Vec<(String, u64)>` - Vector of tokens and their counts
    #[inline]
    pub fn token_count_vector(&self) -> Vec<(String, u64)> {
        self.token_count.iter().map(|(token, &count)| {
            (token.clone(), count)
        }).collect()
    }

    /// Get a vector of all tokens and their counts (as `&str`)
    ///
    /// # Returns
    /// * `Vec<(&str, u64)>` - Vector of tokens and their counts
    #[inline]
    pub fn token_count_vector_ref_str(&self) -> Vec<(&str, u64)> {
        self.token_count.iter().map(|(token, &count)| {
            (token.as_str(), count)
        }).collect()
    }

    /// Get a hashmap of all tokens and their counts (as `&str`)
    ///
    /// # Returns
    /// * `HashMap<&str, u64>` - HashMap of tokens and their counts
    #[inline]
    pub fn token_count_hashmap_ref_str(&self) -> HashMap<&str, u64, RandomState> {
        self.token_count.iter().map(|(token, &count)| {
            (token.as_str(), count)
        }).collect()
    }

    /// Get the total count of all tokens
    ///
    /// # Returns
    /// * `u64` - Total token count
    #[inline]
    pub fn token_sum(&self) -> u64 {
        self.total_token_count
    }

    /// Get the occurrence count for a specific token
    ///
    /// # Arguments
    /// * `token` - Token
    ///
    /// # Returns
    /// * `u64` - Occurrence count for the token
    #[inline]
    pub fn token_count(&self, token: &str) -> u64 {
        *self.token_count.get(token).unwrap_or(&0)
    }

    /// Get the most frequent tokens
    /// If multiple tokens share the maximum count, all of them are returned
    ///
    /// # Returns
    /// * `Vec<(String, u64)>` - Vector of most frequent tokens and their counts
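    ///
    /// # Examples
    /// A tie-handling sketch; results are sorted before comparing because
    /// HashMap iteration order is unspecified (import path assumes the crate
    /// is named `tf_idf_vectorizer`):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "b", "b", "c"]);
    /// let mut top = tf.most_frequent_tokens_vector();
    /// top.sort();
    /// assert_eq!(top, vec![("a".to_string(), 2), ("b".to_string(), 2)]);
    /// ```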
    #[inline]
    pub fn most_frequent_tokens_vector(&self) -> Vec<(String, u64)> {
        if let Some(&max_count) = self.token_count.values().max() {
            self.token_count.iter()
                .filter(|&(_, &count)| count == max_count)
                .map(|(token, &count)| (token.clone(), count))
                .collect()
        } else {
            Vec::new()
        }
    }

    /// Get the count of the most frequent token
    ///
    /// # Returns
    /// * `u64` - Count of the most frequent token
    #[inline]
    pub fn most_frequent_token_count(&self) -> u64 {
        self.token_count.values().copied().max().unwrap_or(0)
    }

    /// Check if a token exists
    ///
    /// # Arguments
    /// * `token` - Token
    ///
    /// # Returns
    /// * `bool` - true if the token exists, false otherwise
    #[inline]
    pub fn contains_token(&self, token: &str) -> bool {
        self.token_count.contains_key(token)
    }

    /// Get the set of tokens
    ///
    /// # Returns
    /// * `Vec<String>` - Vector of unique tokens
    #[inline]
    pub fn token_set(&self) -> Vec<String> {
        self.token_count.keys().cloned().collect()
    }

    /// Get the set of tokens (as `&str`)
    ///
    /// # Returns
    /// * `Vec<&str>` - Vector of unique tokens
    #[inline]
    pub fn token_set_ref_str(&self) -> Vec<&str> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    /// Get the set of tokens as a HashSet
    ///
    /// # Returns
    /// * `HashSet<String>` - Set of tokens
    #[inline]
    pub fn token_hashset(&self) -> HashSet<String, RandomState> {
        self.token_count.keys().cloned().collect()
    }

    /// Get the set of tokens as a HashSet (as `&str`)
    ///
    /// # Returns
    /// * `HashSet<&str>` - Set of tokens
    #[inline]
    pub fn token_hashset_ref_str(&self) -> HashSet<&str, RandomState> {
        self.token_count.keys().map(|s| s.as_str()).collect()
    }

    /// Get the number of unique tokens
    ///
    /// # Returns
    /// * `usize` - Number of unique tokens
    #[inline]
    pub fn token_num(&self) -> usize {
        self.token_count.len()
    }

    /// Remove stop tokens
    ///
    /// # Arguments
    /// * `stop_tokens` - Slice of stop tokens to remove
    ///
    /// # Returns
    /// * `u64` - Total count of removed tokens
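    ///
    /// # Examples
    /// A stop-word removal sketch (import path assumes the crate is named
    /// `tf_idf_vectorizer`):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["the", "the", "cat"]);
    /// let removed = tf.remove_stop_tokens(&["the", "a"]);
    /// assert_eq!(removed, 2);
    /// assert_eq!(tf.token_sum(), 1);
    /// ```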
    #[inline]
    pub fn remove_stop_tokens(&mut self, stop_tokens: &[&str]) -> u64 {
        let mut removed_total_count: u64 = 0;
        for &stop_token in stop_tokens {
            if let Some(count) = self.token_count.remove(stop_token) {
                removed_total_count += count;
            }
        }
        self.total_token_count -= removed_total_count;
        removed_total_count
    }

    /// Remove tokens by a condition
    ///
    /// # Arguments
    /// * `condition` - Closure to determine which tokens to remove
    ///
    /// # Returns
    /// * `u64` - Total count of removed tokens
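    ///
    /// # Examples
    /// A predicate-based pruning sketch, dropping rare tokens (import path
    /// assumes the crate is named `tf_idf_vectorizer`):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "a", "b"]);
    /// let removed = tf.remove_tokens_by(|_token, &count| count < 2);
    /// assert_eq!(removed, 1);
    /// assert!(!tf.contains_token("b"));
    /// ```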
    #[inline]
    pub fn remove_tokens_by<F>(&mut self, condition: F) -> u64
    where
        F: Fn(&str, &u64) -> bool,
    {
        let mut removed_total_count: u64 = 0;
        self.token_count.retain(|token, count| {
            if condition(token, count) {
                removed_total_count += *count;
                false
            } else {
                true
            }
        });
        self.total_token_count -= removed_total_count;

        removed_total_count
    }

    /// Get a vector of tokens sorted by frequency (descending)
    ///
    /// # Returns
    /// * `Vec<(String, u64)>` - Vector of tokens sorted by frequency
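    ///
    /// # Examples
    /// A ranking sketch (import path assumes the crate is named
    /// `tf_idf_vectorizer`):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["b", "a", "a"]);
    /// let ranked = tf.sorted_frequency_vector();
    /// assert_eq!(ranked[0], ("a".to_string(), 2));
    /// ```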
    #[inline]
    pub fn sorted_frequency_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| b.1.cmp(&a.1));
        token_list
    }

    /// Get a vector of tokens sorted by dictionary order (ascending)
    ///
    /// # Returns
    /// * `Vec<(String, u64)>` - Vector of tokens sorted by dictionary order
    #[inline]
    pub fn sorted_dict_order_vector(&self) -> Vec<(String, u64)> {
        let mut token_list: Vec<(String, u64)> = self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), count))
            .collect();

        token_list.sort_by(|a, b| a.0.cmp(&b.0));
        token_list
    }

    /// Calculate the diversity of tokens as the ratio of unique tokens to the
    /// total token count: 1.0 means every token is unique, values near 0.0
    /// mean heavy repetition. Returns 0.0 when empty.
    ///
    /// # Returns
    /// * `f64` - Diversity of tokens
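    ///
    /// # Examples
    /// A ratio sketch (import path assumes the crate is named
    /// `tf_idf_vectorizer`):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "b", "b"]);
    /// assert_eq!(tf.unique_token_ratio(), 0.5); // 2 unique / 4 total
    /// ```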
    #[inline]
    pub fn unique_token_ratio(&self) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        self.token_count.len() as f64 / self.total_token_count as f64
    }

    /// Get the probability distribution P(token) (owned String version)
    /// Returns an empty vector if total is 0
    #[inline]
    pub fn probability_vector(&self) -> Vec<(String, f64)> {
        if self.total_token_count == 0 {
            return Vec::new();
        }
        let total = self.total_token_count as f64;
        self.token_count
            .iter()
            .map(|(token, &count)| (token.clone(), (count as f64) / total))
            .collect()
    }

    /// Get the probability distribution P(token) (as `&str`)
    /// Returns an empty vector if total is 0
    #[inline]
    pub fn probability_vector_ref_str(&self) -> Vec<(&str, f64)> {
        if self.total_token_count == 0 {
            return Vec::new();
        }
        let total = self.total_token_count as f64;
        self.token_count
            .iter()
            .map(|(token, &count)| (token.as_str(), (count as f64) / total))
            .collect()
    }

    /// Get the probability P(token) for a specific token
    /// Returns 0.0 if total is 0
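    ///
    /// # Examples
    /// A probability sketch (import path assumes the crate is named
    /// `tf_idf_vectorizer`):
    /// ```
    /// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
    /// let mut tf = TokenFrequency::new();
    /// tf.add_tokens(&["a", "a", "b", "b"]);
    /// assert_eq!(tf.probability("a"), 0.5);
    /// assert_eq!(tf.probability("missing"), 0.0);
    /// ```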
    #[inline]
    pub fn probability(&self, token: &str) -> f64 {
        if self.total_token_count == 0 {
            return 0.0;
        }
        (self.token_count(token) as f64) / (self.total_token_count as f64)
    }

    /// Reset all counts
    #[inline]
    pub fn clear(&mut self) {
        self.token_count.clear();
        self.total_token_count = 0;
    }

    /// Shrink internal storage to fit current size
    #[inline]
    pub fn shrink_to_fit(&mut self) {
        self.token_count.shrink_to_fit();
    }
}