//! tf_idf_vectorizer/vectorizer/token.rs

1use core::str;
2use std::{collections::{HashMap, HashSet}, fmt::Debug};
3use ahash::RandomState;
4use serde::{Deserialize, Serialize};
5
6
/// TokenFrequency struct
/// Manages the frequency of token occurrences.
/// Counts the number of times each token appears.
///
/// # Examples
/// ```
/// use tf_idf_vectorizer::vectorizer::token::TokenFrequency;
/// let mut token_freq = TokenFrequency::new();
/// token_freq.add_token("token1");
/// token_freq.add_token("token2");
/// token_freq.add_token("token1");
///
/// assert_eq!(token_freq.token_count("token1"), 2);
/// ```
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TokenFrequency {
    // Per-token occurrence counts; ahash::RandomState for faster hashing of string keys.
    token_count: HashMap<String, u64, RandomState>,
    // Running sum of all values in `token_count`; every mutator must keep it in sync.
    total_token_count: u64,
}
26
27/// Implementation for adding and removing tokens
28impl TokenFrequency {
29    /// Create a new TokenFrequency
30    pub fn new() -> Self {
31        TokenFrequency {
32            token_count: HashMap::with_hasher(RandomState::new()),
33            total_token_count: 0,
34        }
35    }
36
37    /// Add a token
38    ///
39    /// # Arguments
40    /// * `token` - Token to add
41    #[inline]
42    pub fn add_token(&mut self, token: &str) -> &mut Self {
43        let count = self.token_count.entry(token.to_string()).or_insert(0);
44        *count += 1;
45        self.total_token_count += 1;
46        self
47    }
48
49    /// Add multiple tokens
50    ///
51    /// # Arguments
52    /// * `tokens` - Slice of tokens to add
53    #[inline]
54    pub fn add_tokens<T>(&mut self, tokens: &[T]) -> &mut Self 
55    where T: AsRef<str> 
56    {
57        for token in tokens {
58            let token_str = token.as_ref();
59            self.add_token(token_str);
60        }
61        self
62    }
63
64    /// Subtract a token
65    ///
66    /// # Arguments
67    /// * `token` - Token to subtract
68    #[inline]
69    pub fn sub_token(&mut self, token: &str) -> &mut Self {
70        if let Some(count) = self.token_count.get_mut(token) {
71            if *count > 1 {
72                *count -= 1;
73                self.total_token_count -= 1;
74            } else if *count == 1 {
75                self.token_count.remove(token);
76                self.total_token_count -= 1;
77            }
78        }
79        self
80    }
81
82    /// Subtract multiple tokens
83    ///
84    /// # Arguments
85    /// * `tokens` - Slice of tokens to subtract
86    #[inline]
87    pub fn sub_tokens<T>(&mut self, tokens: &[T]) -> &mut Self 
88    where T: AsRef<str>
89    {
90        for token in tokens {
91            let token_str = token.as_ref();
92            self.sub_token(token_str);
93        }
94        self
95    }
96
97    /// Set the occurrence count for a token
98    ///
99    /// # Arguments
100    /// * `token` - Token
101    /// * `count` - Occurrence count
102    pub fn set_token_count(&mut self, token: &str, count: u64) -> &mut Self {
103        if count == 0 {
104            self.token_count.remove(token);
105        } else {
106            let current_count = self.token_count.entry(token.to_string()).or_insert(0);
107            self.total_token_count += count - *current_count;
108            *current_count = count;
109        }
110        self
111    }
112
113    /// Merge with another TokenFrequency
114    /// # Arguments
115    /// * `other` - Another TokenFrequency to merge with
116    pub fn add_tokens_from_freq(&mut self, other: &TokenFrequency) -> &mut Self {
117        for (token, &count) in &other.token_count {
118            let entry = self.token_count.entry(token.clone()).or_insert(0);
119            *entry += count;
120            self.total_token_count += count;
121        }
122        self
123    }
124
125    /// Scale the token counts by a scalar
126    /// # Arguments
127    /// * `scalar` - Scalar to scale by
128    pub fn scale(&mut self, scalar: f64) -> &mut Self {
129        let mut total_count = 0;
130        self.token_count.iter_mut().for_each(|(_, count)| {
131            *count = ((*count as f64) * scalar).round() as u64;
132            total_count += *count;
133        });
134        self.total_token_count = total_count;
135        self
136    }
137}
138
139/// Implementation for retrieving information from TokenFrequency
140impl TokenFrequency {
141    /// Get a vector of all tokens and their counts
142    ///
143    /// # Returns
144    /// * `Vec<(String, u64)>` - Vector of tokens and their counts
145    #[inline]
146    pub fn token_count_vector(&self) -> Vec<(String, u64)> {
147        self.token_count.iter().map(|(token, &count)| {
148            (token.clone(), count)
149        }).collect()
150    }
151
152    /// Get a vector of all tokens and their counts (as &str)
153    ///
154    /// # Returns
155    /// * `Vec<(&str, u64)>` - Vector of tokens and their counts
156    #[inline]
157    pub fn token_count_vector_ref_str(&self) -> Vec<(&str, u64)> {
158        self.token_count.iter().map(|(token, &count)| {
159            (token.as_str(), count)
160        }).collect()
161    }
162
163    /// Get a hashmap of all tokens and their counts (as &str)
164    ///
165    /// # Returns
166    /// * `HashMap<&str, u64>` - HashMap of tokens and their counts
167    #[inline]
168    pub fn token_count_hashmap_ref_str(&self) -> HashMap<&str, u64, RandomState> {
169        self.token_count.iter().map(|(token, &count)| {
170            (token.as_str(), count)
171        }).collect()
172    }
173
174    /// Get the total count of all tokens
175    ///
176    /// # Returns
177    /// * `u64` - Total token count
178    #[inline]
179    pub fn token_sum(&self) -> u64 {
180        self.total_token_count
181    }
182
183    /// Get the occurrence count for a specific token
184    ///
185    /// # Arguments
186    /// * `token` - Token
187    ///
188    /// # Returns
189    /// * `u64` - Occurrence count for the token
190    #[inline]
191    pub fn token_count(&self, token: &str) -> u64 {
192        *self.token_count.get(token).unwrap_or(&0)
193    }
194
195    /// Get the most frequent tokens
196    /// If multiple tokens have the same count, all are returned
197    ///
198    /// # Returns
199    /// * `Vec<(String, u64)>` - Vector of most frequent tokens and their counts
200    #[inline]
201    pub fn most_frequent_tokens_vector(&self) -> Vec<(String, u64)> {
202        if let Some(&max_count) = self.token_count.values().max() {
203            self.token_count.iter()
204                .filter(|&(_, &count)| count == max_count)
205                .map(|(token, &count)| (token.clone(), count))
206                .collect()
207        } else {
208            Vec::new()
209        }
210    }
211
212    /// Get the count of the most frequent token
213    ///
214    /// # Returns
215    /// * `u64` - Count of the most frequent token
216    #[inline]
217    pub fn most_frequent_token_count(&self) -> u64 {
218        if let Some(&max_count) = self.token_count.values().max() {
219            max_count
220        } else {
221            0
222        }
223    }
224
225    /// Check if a token exists
226    ///
227    /// # Arguments
228    /// * `token` - Token
229    ///
230    /// # Returns
231    /// * `bool` - true if the token exists, false otherwise
232    #[inline]
233    pub fn contains_token(&self, token: &str) -> bool {
234        self.token_count.contains_key(token)
235    }
236
237    /// Get the set of tokens
238    ///
239    /// # Returns
240    /// * `Vec<String>` - Set of tokens
241    #[inline]
242    pub fn token_set(&self) -> Vec<String> {
243        self.token_count.keys().cloned().collect()
244    }
245
246    /// Get the set of tokens (as &str)
247    ///
248    /// # Returns
249    /// * `Vec<&str>` - Set of tokens
250    #[inline]
251    pub fn token_set_ref_str(&self) -> Vec<&str> {
252        self.token_count.keys().map(|s| s.as_str()).collect()
253    }
254
255    /// Get the set of tokens as a HashSet
256    ///
257    /// # Returns
258    /// * `HashSet<String>` - Set of tokens
259    #[inline]
260    pub fn token_hashset(&self) -> HashSet<String, RandomState> {
261        self.token_count.keys().cloned().collect()
262    }
263
264    /// Get the set of tokens as a HashSet (as &str)
265    ///
266    /// # Returns
267    /// * `HashSet<&str>` - Set of tokens
268    #[inline]
269    pub fn token_hashset_ref_str(&self) -> HashSet<&str, RandomState> {
270        self.token_count.keys().map(|s| s.as_str()).collect()
271    }
272
273    /// Get the number of unique tokens
274    ///
275    /// # Returns
276    /// * `usize` - Number of unique tokens
277    #[inline]
278    pub fn token_num(&self) -> usize {
279        self.token_count.len()
280    }
281
282    /// Remove stop tokens
283    ///
284    /// # Arguments
285    /// * `stop_tokens` - Slice of stop tokens to remove
286    ///
287    /// # Returns
288    /// * `u64` - Total count of removed tokens
289    #[inline]
290    pub fn remove_stop_tokens(&mut self, stop_tokens: &[&str]) -> u64{
291        let mut removed_total_count: u64 = 0;
292        for &stop_token in stop_tokens {
293            if let Some(count) = self.token_count.remove(stop_token) {
294                removed_total_count += count as u64;
295            }
296        }
297        self.total_token_count -= removed_total_count;
298        removed_total_count
299    }
300
301    /// Remove tokens by a condition
302    ///
303    /// # Arguments
304    /// * `condition` - Closure to determine which tokens to remove
305    ///
306    /// # Returns
307    /// * `u64` - Total count of removed tokens
308    #[inline]
309    pub fn remove_tokens_by<F>(&mut self, condition: F) -> u64
310    where
311        F: Fn(&str, &u64) -> bool,
312    {
313        let mut removed_total_count: u64 = 0;
314        self.token_count.retain(|token, count| {
315            if condition(token, count) {
316                removed_total_count += *count as u64;
317                false
318            } else {
319                true
320            }
321        });
322        self.total_token_count -= removed_total_count as u64;
323
324        removed_total_count
325    }
326
327    /// Get a vector of tokens sorted by frequency (descending)
328    ///
329    /// # Returns
330    /// * `Vec<(String, u64)>` - Vector of tokens sorted by frequency
331    #[inline]
332    pub fn sorted_frequency_vector(&self) -> Vec<(String, u64)> {
333        let mut token_list: Vec<(String, u64)> = self.token_count
334            .iter()
335            .map(|(token, &count)| (token.clone(), count))
336            .collect();
337
338        token_list.sort_by(|a, b| b.1.cmp(&a.1));
339        token_list
340    }
341
342    /// Get a vector of tokens sorted by dictionary order (ascending)
343    ///
344    /// # Returns
345    /// * `Vec<(String, u64)>` - Vector of tokens sorted by dictionary order
346    #[inline]
347    pub fn sorted_dict_order_vector(&self) -> Vec<(String, u64)> {
348        let mut token_list: Vec<(String, u64)> = self.token_count
349            .iter()
350            .map(|(token, &count)| (token.clone(), count))
351            .collect();
352
353        token_list.sort_by(|a, b| a.0.cmp(&b.0));
354        token_list
355    }
356
357    /// Calculate the diversity of tokens
358    /// 1.0 indicates complete diversity, 0.0 indicates no diversity
359    ///
360    /// # Returns
361    /// * `f64` - Diversity of tokens
362    #[inline]
363    pub fn unique_token_ratio(&self) -> f64 {
364        if self.total_token_count == 0 {
365            return 0.0;
366        }
367        self.token_count.len() as f64 / self.total_token_count as f64
368    }
369
370    /// Get the probability distribution P(token) (owned String version)
371    /// Returns an empty vector if total is 0
372    #[inline]
373    pub fn probability_vector(&self) -> Vec<(String, f64)> {
374        if self.total_token_count == 0 {
375            return Vec::new();
376        }
377        let total = self.total_token_count as f64;
378        self.token_count
379            .iter()
380            .map(|(token, &count)| (token.clone(), (count as f64) / total))
381            .collect()
382    }
383
384    /// Get the probability distribution P(token) (as &str)
385    /// Returns an empty vector if total is 0
386    #[inline]
387    pub fn probability_vector_ref_str(&self) -> Vec<(&str, f64)> {
388        if self.total_token_count == 0 {
389            return Vec::new();
390        }
391        let total = self.total_token_count as f64;
392        self.token_count
393            .iter()
394            .map(|(token, &count)| (token.as_str(), (count as f64) / total))
395            .collect()
396    }
397
398    /// Get the probability P(token) for a specific token
399    /// Returns 0.0 if total is 0
400    #[inline]
401    pub fn probability(&self, token: &str) -> f64 {
402        if self.total_token_count == 0 {
403            return 0.0;
404        }
405        (self.token_count(token) as f64) / (self.total_token_count as f64)
406    }
407
408    /// Reset all counts
409    #[inline]
410    pub fn clear(&mut self) {
411        self.token_count.clear();
412        self.total_token_count = 0;
413    }
414
415    /// Shrink internal storage to fit current size
416    #[inline]
417    pub fn shrink_to_fit(&mut self) {
418        self.token_count.shrink_to_fit();
419    }
420}