Skip to main content

scirs2_text/
utils.rs

//! Utility functions for text processing
//!
//! This module provides utility functions for text processing operations.
4
5use crate::error::{Result, TextError};
6use crate::tokenize::Tokenizer;
7use scirs2_core::parallel_ops;
8use std::collections::HashMap;
9
10/// Count the frequency of tokens in a text
11///
12/// # Arguments
13///
14/// * `text` - The text to analyze
15/// * `tokenizer` - The tokenizer to use
16///
17/// # Returns
18///
19/// * Result containing a HashMap of token frequencies
20#[allow(dead_code)]
21pub fn count_tokens(text: &str, tokenizer: &dyn Tokenizer) -> Result<HashMap<String, usize>> {
22    let tokens = tokenizer.tokenize(text)?;
23    let mut counts = HashMap::new();
24
25    for token in tokens {
26        *counts.entry(token).or_insert(0) += 1;
27    }
28
29    Ok(counts)
30}
31
32/// Count the frequency of tokens in a batch of texts
33///
34/// # Arguments
35///
36/// * `texts` - The texts to analyze
37/// * `tokenizer` - The tokenizer to use
38///
39/// # Returns
40///
41/// * Result containing a HashMap of token frequencies
42#[allow(dead_code)]
43pub fn count_tokens_batch(
44    texts: &[&str],
45    tokenizer: &dyn Tokenizer,
46) -> Result<HashMap<String, usize>> {
47    // Process texts sequentially (for thread safety)
48    let mut total_counts = HashMap::new();
49
50    for &text in texts {
51        let counts = count_tokens(text, tokenizer)?;
52        for (token, count) in counts {
53            *total_counts.entry(token).or_insert(0) += count;
54        }
55    }
56
57    Ok(total_counts)
58}
59
60/// Count the frequency of tokens in a batch of texts (parallel version)
61///
62/// # Arguments
63///
64/// * `texts` - The texts to analyze
65/// * `tokenizer` - The tokenizer to use (must be Send + Sync)
66///
67/// # Returns
68///
69/// * Result containing a HashMap of token frequencies
70#[allow(dead_code)]
71pub fn count_tokens_batch_parallel<T>(
72    texts: &[&str],
73    tokenizer: &T,
74) -> Result<HashMap<String, usize>>
75where
76    T: Tokenizer + Send + Sync,
77{
78    // Process texts in parallel using scirs2-core::parallel
79    // Clone data to avoid lifetime issues
80    let texts_owned: Vec<String> = texts.iter().map(|&s| s.to_string()).collect();
81    let tokenizer_boxed = tokenizer.clone_box();
82
83    let token_counts = parallel_ops::parallel_map_result(&texts_owned, move |text| {
84        count_tokens(text, &*tokenizer_boxed).map_err(|e| {
85            // Convert TextError to CoreError
86            scirs2_core::CoreError::ComputationError(scirs2_core::error::ErrorContext::new(
87                format!("Text processing error: {e}"),
88            ))
89        })
90    })?;
91
92    // Merge all counts
93    let mut total_counts = HashMap::new();
94    for counts in token_counts {
95        for (token, count) in counts {
96            *total_counts.entry(token).or_insert(0) += count;
97        }
98    }
99
100    Ok(total_counts)
101}
102
103/// Remove tokens from a text based on a predicate function
104///
105/// # Arguments
106///
107/// * `text` - The text to filter
108/// * `tokenizer` - The tokenizer to use
109/// * `predicate` - Function that returns true for tokens to keep
110///
111/// # Returns
112///
113/// * Result containing the filtered text
114#[allow(dead_code)]
115pub fn filter_tokens<F>(text: &str, tokenizer: &dyn Tokenizer, predicate: F) -> Result<String>
116where
117    F: Fn(&str) -> bool,
118{
119    let tokens = tokenizer.tokenize(text)?;
120    let filtered_tokens: Vec<String> = tokens
121        .iter()
122        .filter(|token| predicate(token))
123        .cloned()
124        .collect();
125
126    Ok(filtered_tokens.join(" "))
127}
128
129/// Extract n-grams from a text
130///
131/// # Arguments
132///
133/// * `text` - The text to process
134/// * `tokenizer` - The tokenizer to use
135/// * `n` - The n-gram size
136///
137/// # Returns
138///
139/// * Result containing a vector of n-grams
140#[allow(dead_code)]
141pub fn extract_ngrams(text: &str, tokenizer: &dyn Tokenizer, n: usize) -> Result<Vec<String>> {
142    if n == 0 {
143        return Err(TextError::InvalidInput(
144            "n-gram size must be greater than 0".to_string(),
145        ));
146    }
147
148    let tokens = tokenizer.tokenize(text)?;
149
150    if tokens.is_empty() || tokens.len() < n {
151        return Ok(Vec::new());
152    }
153
154    let ngrams: Vec<String> = (0..=(tokens.len() - n))
155        .map(|i| tokens[i..(i + n)].to_vec().join(" "))
156        .collect();
157
158    Ok(ngrams)
159}
160
161/// Extract collocations (frequently co-occurring words) from a text
162///
163/// # Arguments
164///
165/// * `text` - The text to process
166/// * `tokenizer` - The tokenizer to use
167/// * `window_size` - The window size for considering co-occurrence
168/// * `min_count` - Minimum count for a collocation to be included
169///
170/// # Returns
171///
172/// * Result containing a HashMap of collocations and their frequencies
173#[allow(dead_code)]
174pub fn extract_collocations(
175    text: &str,
176    tokenizer: &dyn Tokenizer,
177    window_size: usize,
178    min_count: usize,
179) -> Result<HashMap<(String, String), usize>> {
180    let tokens = tokenizer.tokenize(text)?;
181    let mut collocations = HashMap::new();
182
183    if tokens.len() < 2 {
184        return Ok(collocations);
185    }
186
187    // Count co-occurrences within the window
188    for i in 0..tokens.len() {
189        let end = std::cmp::min(i + window_size + 1, tokens.len());
190
191        for j in (i + 1)..end {
192            let pair = (tokens[i].clone(), tokens[j].clone());
193            *collocations.entry(pair).or_insert(0) += 1;
194        }
195    }
196
197    // Filter by minimum _count
198    collocations.retain(|_, &mut _count| _count >= min_count);
199
200    Ok(collocations)
201}
202
203/// Split text into training and testing sets
204///
205/// # Arguments
206///
207/// * `texts` - The texts to split
208/// * `test_size` - The proportion of the dataset to use for testing (0.0 to 1.0)
209/// * `random_seed` - Optional random seed for reproducibility
210///
211/// # Returns
212///
213/// * `(Vec<String>, Vec<String>)` - Training and testing sets
214#[allow(dead_code)]
215pub fn train_test_split(
216    texts: &[String],
217    test_size: f64,
218    random_seed: Option<u64>,
219) -> Result<(Vec<String>, Vec<String>)> {
220    use scirs2_core::random::seq::SliceRandom;
221    use scirs2_core::random::SeedableRng;
222
223    if !(0.0..=1.0).contains(&test_size) {
224        return Err(TextError::InvalidInput(
225            "test_size must be between 0.0 and 1.0".to_string(),
226        ));
227    }
228
229    if texts.is_empty() {
230        return Ok((Vec::new(), Vec::new()));
231    }
232
233    // Use the random _seed if provided
234    let mut rng = match random_seed {
235        Some(_seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(_seed),
236        None => {
237            let mut temp_rng = scirs2_core::random::rng();
238            scirs2_core::random::rngs::StdRng::from_rng(&mut temp_rng)
239        }
240    };
241
242    // Shuffle the texts
243    let mut texts_copy = texts.to_vec();
244    texts_copy.shuffle(&mut rng);
245
246    // Split into training and testing sets
247    let test_count = (texts.len() as f64 * test_size).round() as usize;
248    let train_count = texts.len() - test_count;
249
250    let traintexts = texts_copy.iter().take(train_count).cloned().collect();
251    let testtexts = texts_copy.iter().skip(train_count).cloned().collect();
252
253    Ok((traintexts, testtexts))
254}
255
// Unit tests for the text utilities, run against `WordTokenizer`'s default
// configuration.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenize::WordTokenizer;

    #[test]
    fn test_count_tokens() {
        let tokenizer = WordTokenizer::default();
        let text = "this is a test this is only a test";
        let counts = count_tokens(text, &tokenizer).expect("Operation failed");

        assert_eq!(counts.get("this").expect("Operation failed"), &2);
        assert_eq!(counts.get("is").expect("Operation failed"), &2);
        assert_eq!(counts.get("a").expect("Operation failed"), &2);
        assert_eq!(counts.get("test").expect("Operation failed"), &2);
        assert_eq!(counts.get("only").expect("Operation failed"), &1);
    }

    #[test]
    fn test_filter_tokens() {
        let tokenizer = WordTokenizer::default();
        let text = "this is a test this is only a test";

        // Filter out common words
        let predicate = |token: &str| !["this", "is", "a"].contains(&token);
        let filtered = filter_tokens(text, &tokenizer, predicate).expect("Operation failed");

        assert_eq!(filtered, "test only test");
    }

    #[test]
    fn test_extract_ngrams() {
        let tokenizer = WordTokenizer::default();
        let text = "this is a simple test";

        // Extract bigrams
        let bigrams = extract_ngrams(text, &tokenizer, 2).expect("Operation failed");
        assert_eq!(bigrams, vec!["this is", "is a", "a simple", "simple test"]);

        // Extract trigrams
        let trigrams = extract_ngrams(text, &tokenizer, 3).expect("Operation failed");
        assert_eq!(trigrams, vec!["this is a", "is a simple", "a simple test"]);
    }

    #[test]
    fn test_extract_collocations() {
        let tokenizer = WordTokenizer::default();
        let text = "machine learning is a subset of artificial intelligence that provides systems with the ability to learn";

        let collocations = extract_collocations(text, &tokenizer, 2, 1).expect("Operation failed");

        // Check some expected collocations
        assert!(collocations.contains_key(&("machine".to_string(), "learning".to_string())));
        assert!(collocations.contains_key(&("artificial".to_string(), "intelligence".to_string())));
    }

    #[test]
    fn test_train_test_split() {
        let texts = vec![
            "text 1".to_string(),
            "text 2".to_string(),
            "text 3".to_string(),
            "text 4".to_string(),
            "text 5".to_string(),
        ];

        // Split with a fixed seed for reproducibility
        let (train, test) = train_test_split(&texts, 0.4, Some(42)).expect("Operation failed");

        // round(5 * 0.4) = 2 items go to the test set, the remaining 3 to train.
        assert_eq!(train.len(), 3);
        assert_eq!(test.len(), 2);

        // All texts should be present exactly once in either train or test
        for text in &texts {
            assert_eq!(
                train.iter().filter(|&t| t == text).count()
                    + test.iter().filter(|&t| t == text).count(),
                1
            );
        }
    }
}