1use crate::error::Result;
7use crate::tokenize::Tokenizer;
8use crate::vectorize::Vectorizer;
9use scirs2_core::ndarray::Array2;
10use scirs2_core::parallel_ops::*;
11use std::sync::{Arc, Mutex};
12
/// Tokenizes slices of texts in parallel by delegating to an inner
/// [`Tokenizer`] over fixed-size chunks of the input.
///
/// Trait bounds are intentionally left off the struct definition and placed
/// on the `impl` block instead, keeping the type itself maximally general.
pub struct ParallelTokenizer<T> {
    /// The underlying tokenizer, borrowed by every parallel task.
    tokenizer: T,
    /// Number of texts handed to each parallel task (default: 1000).
    chunk_size: usize,
}
20
21impl<T: Tokenizer + Send + Sync> ParallelTokenizer<T> {
22 pub fn new(tokenizer: T) -> Self {
24 Self {
25 tokenizer,
26 chunk_size: 1000,
27 }
28 }
29
30 pub fn with_chunk_size(mut self, chunksize: usize) -> Self {
32 self.chunk_size = chunksize;
33 self
34 }
35
36 pub fn tokenize(&self, texts: &[&str]) -> Result<Vec<Vec<String>>> {
38 let results: Result<Vec<_>> = texts
39 .par_chunks(self.chunk_size)
40 .flat_map(|chunk| {
41 let mut chunk_results = Vec::new();
42 for &text in chunk {
43 match self.tokenizer.tokenize(text) {
44 Ok(tokens) => chunk_results.push(tokens),
45 Err(e) => return vec![Err(e)],
46 }
47 }
48 chunk_results.into_iter().map(Ok).collect::<Vec<_>>()
49 })
50 .collect();
51
52 results
53 }
54
55 pub fn tokenize_and_map<F, R>(&self, texts: &[&str], mapper: F) -> Result<Vec<R>>
57 where
58 F: Fn(Vec<String>) -> R + Send + Sync,
59 R: Send,
60 {
61 let results: Result<Vec<_>> = texts
62 .par_chunks(self.chunk_size)
63 .flat_map(|chunk| {
64 let mut chunk_results = Vec::new();
65 for &text in chunk {
66 match self.tokenizer.tokenize(text) {
67 Ok(tokens) => chunk_results.push(Ok(mapper(tokens))),
68 Err(e) => return vec![Err(e)],
69 }
70 }
71 chunk_results
72 })
73 .collect();
74
75 results
76 }
77}
78
/// Vectorizes slices of texts in parallel by delegating to an inner
/// [`Vectorizer`] over fixed-size chunks of the input.
///
/// Trait bounds are left off the struct definition and placed on the `impl`
/// block instead, keeping the type itself maximally general.
pub struct ParallelVectorizer<T> {
    /// The underlying vectorizer.
    /// NOTE(review): the `Arc` is never cloned anywhere in this file; a plain
    /// `T` would suffice unless external sharing is planned — confirm intent.
    vectorizer: Arc<T>,
    /// Number of texts handed to each parallel task (default: 100).
    chunk_size: usize,
}
86
87impl<T: Vectorizer + Send + Sync> ParallelVectorizer<T> {
88 pub fn new(vectorizer: T) -> Self {
90 Self {
91 vectorizer: Arc::new(vectorizer),
92 chunk_size: 100,
93 }
94 }
95
96 pub fn with_chunk_size(mut self, chunksize: usize) -> Self {
98 self.chunk_size = chunksize;
99 self
100 }
101
102 pub fn transform(&self, texts: &[&str]) -> Result<Array2<f64>> {
104 let first_features = self.vectorizer.transform_batch(&texts[0..1])?;
106 let n_features = first_features.ncols();
107
108 let n_samples = texts.len();
110 let result = Arc::new(Mutex::new(Array2::zeros((n_samples, n_features))));
111
112 let chunk_size = self.chunk_size;
114 let errors = Arc::new(Mutex::new(Vec::new()));
115
116 texts
117 .par_chunks(chunk_size)
118 .enumerate()
119 .for_each(|(chunk_idx, chunk)| {
120 let start_idx = chunk_idx * chunk_size;
121
122 match self.vectorizer.transform_batch(chunk) {
123 Ok(chunk_vectors) => {
124 let mut result = result.lock().unwrap();
125
126 for (i, row) in chunk_vectors.rows().into_iter().enumerate() {
127 if start_idx + i < n_samples {
128 result.row_mut(start_idx + i).assign(&row);
129 }
130 }
131 }
132 Err(e) => {
133 let mut errors = errors.lock().unwrap();
134 errors.push(e);
135 }
136 }
137 });
138
139 let errors = errors.lock().unwrap();
140 if !errors.is_empty() {
141 return Err(errors[0].clone());
142 }
143
144 let result = Arc::try_unwrap(result)
145 .map_err(|_| {
146 crate::error::TextError::RuntimeError("Failed to unwrap result Arc".to_string())
147 })?
148 .into_inner()
149 .map_err(|_| {
150 crate::error::TextError::RuntimeError("Failed to unwrap result Mutex".to_string())
151 })?;
152
153 Ok(result)
154 }
155}
156
/// Applies per-text closures to slices of texts using the parallel iterator
/// machinery from `scirs2_core::parallel_ops`.
pub struct ParallelTextProcessor {
    // Desired worker count. NOTE(review): this field is written by
    // `new`/`with_threads` but never read by any method in this file — the
    // `par_iter` calls run on the global pool. Confirm whether pool sizing
    // was intended.
    num_threads: usize,
}
162
163impl Default for ParallelTextProcessor {
164 fn default() -> Self {
165 Self {
166 num_threads: num_cpus::get(),
167 }
168 }
169}
170
171impl ParallelTextProcessor {
172 pub fn new() -> Self {
174 Self::default()
175 }
176
177 pub fn with_threads(mut self, numthreads: usize) -> Self {
179 self.num_threads = numthreads;
180 self
181 }
182
183 pub fn process<F, R>(&self, texts: &[&str], f: F) -> Vec<R>
185 where
186 F: Fn(&str) -> R + Send + Sync,
187 R: Send,
188 {
189 texts.par_iter().map(|&text| f(text)).collect()
190 }
191
192 pub fn process_and_flatten<F, R>(&self, texts: &[&str], f: F) -> Vec<R>
194 where
195 F: Fn(&str) -> Vec<R> + Send + Sync,
196 R: Send,
197 {
198 texts.par_iter().flat_map(|&text| f(text)).collect()
199 }
200
201 pub fn process_with_progress<F, R>(
203 &self,
204 texts: &[&str],
205 f: F,
206 update_interval: usize,
207 ) -> Result<(Vec<R>, Vec<usize>)>
208 where
209 F: Fn(&str) -> R + Send + Sync,
210 R: Send,
211 {
212 let progress = Arc::new(Mutex::new(Vec::new()));
213 let total = texts.len();
214
215 let results: Vec<R> = texts
216 .par_iter()
217 .enumerate()
218 .map(|(i, &text)| {
219 let result = f(text);
220
221 if i % update_interval == 0 || i == total - 1 {
223 let mut progress = progress.lock().unwrap();
224 progress.push(i + 1);
225 }
226
227 result
228 })
229 .collect();
230
231 let progress = Arc::try_unwrap(progress)
232 .map_err(|_| {
233 crate::error::TextError::RuntimeError("Failed to unwrap progress Arc".to_string())
234 })?
235 .into_inner()
236 .map_err(|_| {
237 crate::error::TextError::RuntimeError("Failed to unwrap progress Mutex".to_string())
238 })?;
239
240 Ok((results, progress))
241 }
242
243 pub fn batch_process<F, R>(&self, texts: &[&str], chunksize: usize, f: F) -> Vec<Vec<R>>
245 where
246 F: Fn(&[&str]) -> Vec<R> + Send + Sync,
247 R: Send,
248 {
249 texts.par_chunks(chunksize).map(f).collect()
250 }
251}
252
/// Splits a corpus into batches and runs a batch-processing closure over the
/// batches in parallel, reassembling results in input order.
pub struct ParallelCorpusProcessor {
    // Number of documents handed to each parallel task.
    batch_size: usize,
    // Preferred thread count. NOTE(review): written by `with_threads` but
    // never read in this file — presumably intended for pool configuration;
    // confirm before relying on it.
    num_threads: Option<usize>,
    // Memory budget. NOTE(review): written by `with_max_memory` but never
    // read or enforced anywhere in this file.
    max_memory: Option<usize>,
}
262
263impl ParallelCorpusProcessor {
264 pub fn new(_batchsize: usize) -> Self {
266 Self {
267 batch_size: _batchsize,
268 num_threads: None,
269 max_memory: None,
270 }
271 }
272
273 pub fn with_threads(mut self, numthreads: usize) -> Self {
275 self.num_threads = Some(numthreads);
276 self
277 }
278
279 pub fn with_max_memory(mut self, maxmemory: usize) -> Self {
281 self.max_memory = Some(maxmemory);
282 self
283 }
284
285 pub fn process<F, R>(&self, corpus: &[&str], processor: F) -> Result<Vec<R>>
287 where
288 F: Fn(&[&str]) -> Result<Vec<R>> + Send + Sync,
289 R: Send,
290 {
291 let results = Arc::new(Mutex::new(Vec::new()));
293 let errors = Arc::new(Mutex::new(Vec::new()));
294
295 {
297 let indexed_results: Vec<_> = corpus
299 .par_chunks(self.batch_size)
300 .enumerate()
301 .map(|(idx, batch)| match processor(batch) {
302 Ok(batch_results) => Ok((idx, batch_results)),
303 Err(e) => Err(e),
304 })
305 .collect();
306
307 for result in &indexed_results {
309 if let Err(e) = result {
310 let mut errors = errors.lock().unwrap();
311 errors.push(e.clone());
312 return Err(e.clone());
313 }
314 }
315
316 let mut sorted_results: Vec<_> =
318 indexed_results.into_iter().filter_map(|r| r.ok()).collect();
319 sorted_results.sort_by_key(|(idx_, _)| *idx_);
320
321 let mut results_guard = results.lock().unwrap();
322 for (_, batch_results) in sorted_results {
323 results_guard.extend(batch_results);
324 }
325 }
326
327 let errors = errors.lock().unwrap();
329 if !errors.is_empty() {
330 return Err(errors[0].clone());
331 }
332
333 let results = Arc::try_unwrap(results)
335 .map_err(|_| {
336 crate::error::TextError::RuntimeError("Failed to unwrap results Arc".to_string())
337 })?
338 .into_inner()
339 .map_err(|_| {
340 crate::error::TextError::RuntimeError("Failed to unwrap results Mutex".to_string())
341 })?;
342
343 Ok(results)
344 }
345
346 pub fn process_with_progress<F, R>(
348 &self,
349 corpus: &[&str],
350 processor: F,
351 progress_callback: impl Fn(usize, usize) + Send + Sync,
352 ) -> Result<Vec<R>>
353 where
354 F: Fn(&[&str]) -> Result<Vec<R>> + Send + Sync,
355 R: Send,
356 {
357 let errors = Arc::new(Mutex::new(Vec::new()));
358 let processed = Arc::new(std::sync::atomic::AtomicUsize::new(0));
359 let total = corpus.len();
360
361 let batches: Vec<_> = corpus.chunks(self.batch_size).collect();
362
363 let indexed_results: Vec<_> = batches
365 .into_par_iter()
366 .enumerate()
367 .map(|(idx, batch)| {
368 let result = match processor(batch) {
369 Ok(batch_results) => Ok((idx, batch_results)),
370 Err(e) => Err(e),
371 };
372
373 let current = processed.fetch_add(batch.len(), std::sync::atomic::Ordering::SeqCst)
375 + batch.len();
376 progress_callback(current, total);
377
378 result
379 })
380 .collect();
381
382 for result in &indexed_results {
384 if let Err(e) = result {
385 let mut errors = errors.lock().unwrap();
386 errors.push(e.clone());
387 }
388 }
389
390 let errors = errors.lock().unwrap();
392 if !errors.is_empty() {
393 return Err(errors[0].clone());
394 }
395 drop(errors);
396
397 let mut sorted_results: Vec<_> =
399 indexed_results.into_iter().filter_map(|r| r.ok()).collect();
400 sorted_results.sort_by_key(|(idx_, _)| *idx_);
401
402 let mut final_results = Vec::new();
403 for (_, batch_results) in sorted_results {
404 final_results.extend(batch_results);
405 }
406
407 Ok(final_results)
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenize::WhitespaceTokenizer;
    use crate::vectorize::TfidfVectorizer;

    /// Fixed five-document corpus shared by all tests; the per-document
    /// whitespace word counts are 5, 4, 6, 2, 6.
    fn create_testtexts() -> Vec<&'static str> {
        vec![
            "This is a test document",
            "Another test document here",
            "Document with more words for testing",
            "Short text",
            "More documents for parallel processing testing",
        ]
    }

    #[test]
    fn test_parallel_tokenizer() {
        let texts = create_testtexts();
        let tokenizer = ParallelTokenizer::new(WhitespaceTokenizer::new());

        let tokens = tokenizer
            .tokenize(&texts)
            .expect("Tokenization should succeed");

        assert_eq!(tokens.len(), texts.len());
        let expected = vec!["This", "is", "a", "test", "document"];
        assert_eq!(tokens[0], expected);
    }

    #[test]
    fn test_parallel_tokenizer_with_mapper() {
        let texts = create_testtexts();
        let tokenizer = ParallelTokenizer::new(WhitespaceTokenizer::new());

        let token_counts = tokenizer
            .tokenize_and_map(&texts, |tokens| tokens.len())
            .expect("Tokenization and mapping should succeed");

        assert_eq!(token_counts, vec![5, 4, 6, 2, 6]);
    }

    #[test]
    fn test_parallel_vectorizer() {
        let texts = create_testtexts();
        let mut vectorizer = TfidfVectorizer::default();
        vectorizer.fit(&texts).unwrap();

        let vectors = ParallelVectorizer::new(vectorizer)
            .transform(&texts)
            .unwrap();

        assert_eq!(vectors.nrows(), texts.len());
        assert!(vectors.ncols() > 0);
    }

    #[test]
    fn test_paralleltext_processor() {
        let texts = create_testtexts();

        let word_counts =
            ParallelTextProcessor::new().process(&texts, |text| text.split_whitespace().count());

        assert_eq!(word_counts, vec![5, 4, 6, 2, 6]);
    }

    #[test]
    fn test_paralleltext_processor_with_progress() {
        let texts = create_testtexts();

        let (word_counts, progress) = ParallelTextProcessor::new()
            .process_with_progress(&texts, |text| text.split_whitespace().count(), 2)
            .unwrap();

        assert_eq!(word_counts, vec![5, 4, 6, 2, 6]);
        assert!(!progress.is_empty());
    }

    #[test]
    fn test_parallel_corpus_processor() {
        let texts = create_testtexts();

        let result = ParallelCorpusProcessor::new(2)
            .process(&texts, |batch| {
                Ok(batch
                    .iter()
                    .map(|text| text.split_whitespace().count())
                    .collect())
            })
            .unwrap();

        assert_eq!(result, vec![5, 4, 6, 2, 6]);
    }
}