1use crate::error::{Result, TextError};
6use crate::tokenize::Tokenizer;
7use scirs2_core::parallel_ops;
8use std::collections::HashMap;
9
10#[allow(dead_code)]
21pub fn count_tokens(text: &str, tokenizer: &dyn Tokenizer) -> Result<HashMap<String, usize>> {
22 let tokens = tokenizer.tokenize(text)?;
23 let mut counts = HashMap::new();
24
25 for token in tokens {
26 *counts.entry(token).or_insert(0) += 1;
27 }
28
29 Ok(counts)
30}
31
32#[allow(dead_code)]
43pub fn count_tokens_batch(
44 texts: &[&str],
45 tokenizer: &dyn Tokenizer,
46) -> Result<HashMap<String, usize>> {
47 let mut total_counts = HashMap::new();
49
50 for &text in texts {
51 let counts = count_tokens(text, tokenizer)?;
52 for (token, count) in counts {
53 *total_counts.entry(token).or_insert(0) += count;
54 }
55 }
56
57 Ok(total_counts)
58}
59
60#[allow(dead_code)]
71pub fn count_tokens_batch_parallel<T>(
72 texts: &[&str],
73 tokenizer: &T,
74) -> Result<HashMap<String, usize>>
75where
76 T: Tokenizer + Send + Sync,
77{
78 let texts_owned: Vec<String> = texts.iter().map(|&s| s.to_string()).collect();
81 let tokenizer_boxed = tokenizer.clone_box();
82
83 let token_counts = parallel_ops::parallel_map_result(&texts_owned, move |text| {
84 count_tokens(text, &*tokenizer_boxed).map_err(|e| {
85 scirs2_core::CoreError::ComputationError(scirs2_core::error::ErrorContext::new(
87 format!("Text processing error: {e}"),
88 ))
89 })
90 })?;
91
92 let mut total_counts = HashMap::new();
94 for counts in token_counts {
95 for (token, count) in counts {
96 *total_counts.entry(token).or_insert(0) += count;
97 }
98 }
99
100 Ok(total_counts)
101}
102
103#[allow(dead_code)]
115pub fn filter_tokens<F>(text: &str, tokenizer: &dyn Tokenizer, predicate: F) -> Result<String>
116where
117 F: Fn(&str) -> bool,
118{
119 let tokens = tokenizer.tokenize(text)?;
120 let filtered_tokens: Vec<String> = tokens
121 .iter()
122 .filter(|token| predicate(token))
123 .cloned()
124 .collect();
125
126 Ok(filtered_tokens.join(" "))
127}
128
129#[allow(dead_code)]
141pub fn extract_ngrams(text: &str, tokenizer: &dyn Tokenizer, n: usize) -> Result<Vec<String>> {
142 if n == 0 {
143 return Err(TextError::InvalidInput(
144 "n-gram size must be greater than 0".to_string(),
145 ));
146 }
147
148 let tokens = tokenizer.tokenize(text)?;
149
150 if tokens.is_empty() || tokens.len() < n {
151 return Ok(Vec::new());
152 }
153
154 let ngrams: Vec<String> = (0..=(tokens.len() - n))
155 .map(|i| tokens[i..(i + n)].to_vec().join(" "))
156 .collect();
157
158 Ok(ngrams)
159}
160
161#[allow(dead_code)]
174pub fn extract_collocations(
175 text: &str,
176 tokenizer: &dyn Tokenizer,
177 window_size: usize,
178 min_count: usize,
179) -> Result<HashMap<(String, String), usize>> {
180 let tokens = tokenizer.tokenize(text)?;
181 let mut collocations = HashMap::new();
182
183 if tokens.len() < 2 {
184 return Ok(collocations);
185 }
186
187 for i in 0..tokens.len() {
189 let end = std::cmp::min(i + window_size + 1, tokens.len());
190
191 for j in (i + 1)..end {
192 let pair = (tokens[i].clone(), tokens[j].clone());
193 *collocations.entry(pair).or_insert(0) += 1;
194 }
195 }
196
197 collocations.retain(|_, &mut _count| _count >= min_count);
199
200 Ok(collocations)
201}
202
203#[allow(dead_code)]
215pub fn train_test_split(
216 texts: &[String],
217 test_size: f64,
218 random_seed: Option<u64>,
219) -> Result<(Vec<String>, Vec<String>)> {
220 use scirs2_core::random::seq::SliceRandom;
221 use scirs2_core::random::SeedableRng;
222
223 if !(0.0..=1.0).contains(&test_size) {
224 return Err(TextError::InvalidInput(
225 "test_size must be between 0.0 and 1.0".to_string(),
226 ));
227 }
228
229 if texts.is_empty() {
230 return Ok((Vec::new(), Vec::new()));
231 }
232
233 let mut rng = match random_seed {
235 Some(_seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(_seed),
236 None => {
237 let mut temp_rng = scirs2_core::random::rng();
238 scirs2_core::random::rngs::StdRng::from_rng(&mut temp_rng)
239 }
240 };
241
242 let mut texts_copy = texts.to_vec();
244 texts_copy.shuffle(&mut rng);
245
246 let test_count = (texts.len() as f64 * test_size).round() as usize;
248 let train_count = texts.len() - test_count;
249
250 let traintexts = texts_copy.iter().take(train_count).cloned().collect();
251 let testtexts = texts_copy.iter().skip(train_count).cloned().collect();
252
253 Ok((traintexts, testtexts))
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenize::WordTokenizer;

    #[test]
    fn test_count_tokens() {
        let tokenizer = WordTokenizer::default();
        let text = "this is a test this is only a test";
        let counts = count_tokens(text, &tokenizer).expect("Operation failed");

        assert_eq!(counts.get("this").expect("Operation failed"), &2);
        assert_eq!(counts.get("is").expect("Operation failed"), &2);
        assert_eq!(counts.get("a").expect("Operation failed"), &2);
        assert_eq!(counts.get("test").expect("Operation failed"), &2);
        assert_eq!(counts.get("only").expect("Operation failed"), &1);
    }

    #[test]
    fn test_count_tokens_batch_sums_across_texts() {
        let tokenizer = WordTokenizer::default();
        let texts = ["apple banana", "banana cherry banana"];
        let counts = count_tokens_batch(&texts, &tokenizer).expect("batch counting failed");

        assert_eq!(counts.get("apple"), Some(&1));
        assert_eq!(counts.get("banana"), Some(&3));
        assert_eq!(counts.get("cherry"), Some(&1));
    }

    #[test]
    fn test_filter_tokens() {
        let tokenizer = WordTokenizer::default();
        let text = "this is a test this is only a test";

        let predicate = |token: &str| !["this", "is", "a"].contains(&token);
        let filtered = filter_tokens(text, &tokenizer, predicate).expect("Operation failed");

        assert_eq!(filtered, "test only test");
    }

    #[test]
    fn test_extract_ngrams() {
        let tokenizer = WordTokenizer::default();
        let text = "this is a simple test";

        let bigrams = extract_ngrams(text, &tokenizer, 2).expect("Operation failed");
        assert_eq!(bigrams, vec!["this is", "is a", "a simple", "simple test"]);

        let trigrams = extract_ngrams(text, &tokenizer, 3).expect("Operation failed");
        assert_eq!(trigrams, vec!["this is a", "is a simple", "a simple test"]);
    }

    #[test]
    fn test_extract_ngrams_edge_cases() {
        let tokenizer = WordTokenizer::default();

        // n == 0 is rejected up front.
        assert!(extract_ngrams("one two", &tokenizer, 0).is_err());

        // Fewer tokens than n yields no n-grams rather than an error.
        let too_long = extract_ngrams("one two", &tokenizer, 3).expect("Operation failed");
        assert!(too_long.is_empty());
    }

    #[test]
    fn test_extract_collocations() {
        let tokenizer = WordTokenizer::default();
        let text = "machine learning is a subset of artificial intelligence that provides systems with the ability to learn";

        let collocations = extract_collocations(text, &tokenizer, 2, 1).expect("Operation failed");

        assert!(collocations.contains_key(&("machine".to_string(), "learning".to_string())));
        assert!(collocations.contains_key(&("artificial".to_string(), "intelligence".to_string())));
    }

    #[test]
    fn test_extract_collocations_min_count() {
        let tokenizer = WordTokenizer::default();
        // Adjacent pairs: (new,york) x2, (york,new) x2, (new,jersey) x1.
        let text = "new york new york new jersey";

        let collocations = extract_collocations(text, &tokenizer, 1, 2).expect("Operation failed");

        assert_eq!(
            collocations.get(&("new".to_string(), "york".to_string())),
            Some(&2)
        );
        assert!(!collocations.contains_key(&("new".to_string(), "jersey".to_string())));
    }

    #[test]
    fn test_train_test_split() {
        let texts = vec![
            "text 1".to_string(),
            "text 2".to_string(),
            "text 3".to_string(),
            "text 4".to_string(),
            "text 5".to_string(),
        ];

        let (train, test) = train_test_split(&texts, 0.4, Some(42)).expect("Operation failed");

        assert_eq!(train.len(), 3);
        assert_eq!(test.len(), 2);

        // Every input text lands in exactly one of the two outputs.
        for text in &texts {
            assert_eq!(
                train.iter().filter(|&t| t == text).count()
                    + test.iter().filter(|&t| t == text).count(),
                1
            );
        }
    }

    #[test]
    fn test_train_test_split_validation_and_determinism() {
        let texts: Vec<String> = (1..=10).map(|i| format!("text {i}")).collect();

        assert!(train_test_split(&texts, 1.5, Some(1)).is_err());
        assert!(train_test_split(&texts, -0.1, Some(1)).is_err());
        assert!(train_test_split(&texts, f64::NAN, Some(1)).is_err());

        let (train, test) = train_test_split(&[], 0.5, None).expect("Operation failed");
        assert!(train.is_empty() && test.is_empty());

        // The same seed must reproduce the same split.
        let first = train_test_split(&texts, 0.3, Some(7)).expect("Operation failed");
        let second = train_test_split(&texts, 0.3, Some(7)).expect("Operation failed");
        assert_eq!(first, second);
        assert_eq!(first.0.len(), 7);
        assert_eq!(first.1.len(), 3);
    }
}