1use crate::{Result, TextError};
2use regex::Regex;
4use scirs2_core::random::Random;
5use scirs2_core::RngExt;
6use std::collections::HashMap;
7
lazy_static::lazy_static! {
    // Each regex is compiled once on first use and shared by all callers;
    // recompiling per call would dominate the cost of these text passes.

    /// Runs of any whitespace (spaces, tabs, newlines); used to collapse to a single space.
    static ref WHITESPACE_RE: Regex = Regex::new(r"\s+")
        .expect("WHITESPACE_RE: compile-time constant regex should be valid");

    /// `http://` or `https://` URLs, matched up to the next whitespace.
    static ref URL_RE: Regex = Regex::new(r"https?://[^\s]+")
        .expect("URL_RE: compile-time constant regex should be valid");

    /// Simple email matcher (ASCII local part and domain; not RFC-complete).
    static ref EMAIL_RE: Regex = Regex::new(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b")
        .expect("EMAIL_RE: compile-time constant regex should be valid");

    /// Anything shaped like an HTML/XML tag: `<...>` with no nested `>`.
    static ref HTML_RE: Regex = Regex::new(r"<[^>]+>")
        .expect("HTML_RE: compile-time constant regex should be valid");

    /// Social-media style @mentions (word characters only).
    static ref MENTION_RE: Regex = Regex::new(r"@\w+")
        .expect("MENTION_RE: compile-time constant regex should be valid");

    /// Social-media style #hashtags (word characters only).
    static ref HASHTAG_RE: Regex = Regex::new(r"#\w+")
        .expect("HASHTAG_RE: compile-time constant regex should be valid");
}
34
/// Configurable text normalizer built with chainable setters.
///
/// Flags are applied by [`TextNormalizer::normalize`] in a fixed order:
/// unicode folding, lowercasing, accent removal, punctuation removal,
/// digit removal, whitespace collapsing.
#[derive(Debug, Clone)]
pub struct TextNormalizer {
    lowercase: bool,           // map to lowercase via `str::to_lowercase`
    remove_accents: bool,      // strip common Latin diacritics (fixed table)
    remove_punctuation: bool,  // drops ASCII punctuation only
    remove_digits: bool,       // drops ASCII digits only
    remove_extra_spaces: bool, // collapse whitespace runs to one space
    normalize_unicode: bool,   // fold smart quotes/dashes/ellipsis to ASCII
}
48
49impl Default for TextNormalizer {
50 fn default() -> Self {
51 Self {
52 lowercase: true,
53 remove_accents: false,
54 remove_punctuation: false,
55 remove_digits: false,
56 remove_extra_spaces: true,
57 normalize_unicode: true,
58 }
59 }
60}
61
62impl TextNormalizer {
63 pub fn new() -> Self {
64 Self::default()
65 }
66
67 pub fn lowercase(mut self, value: bool) -> Self {
68 self.lowercase = value;
69 self
70 }
71
72 pub fn remove_accents(mut self, value: bool) -> Self {
73 self.remove_accents = value;
74 self
75 }
76
77 pub fn remove_punctuation(mut self, value: bool) -> Self {
78 self.remove_punctuation = value;
79 self
80 }
81
82 pub fn remove_digits(mut self, value: bool) -> Self {
83 self.remove_digits = value;
84 self
85 }
86
87 pub fn remove_extra_spaces(mut self, value: bool) -> Self {
88 self.remove_extra_spaces = value;
89 self
90 }
91
92 pub fn normalize_unicode(mut self, value: bool) -> Self {
93 self.normalize_unicode = value;
94 self
95 }
96
97 pub fn normalize(&self, text: &str) -> String {
98 let mut result = text.to_string();
99
100 if self.normalize_unicode {
101 result = self.normalize_unicode_text(&result);
102 }
103
104 if self.lowercase {
105 result = result.to_lowercase();
106 }
107
108 if self.remove_accents {
109 result = self.remove_accents_text(&result);
110 }
111
112 if self.remove_punctuation {
113 result = self.remove_punctuation_text(&result);
114 }
115
116 if self.remove_digits {
117 result = self.remove_digits_text(&result);
118 }
119
120 if self.remove_extra_spaces {
121 result = self.remove_extra_spaces_text(&result);
122 }
123
124 result.trim().to_string()
125 }
126
127 fn normalize_unicode_text(&self, text: &str) -> String {
128 let mut result = String::new();
130 for c in text.chars() {
131 match c {
132 '\u{2018}' | '\u{2019}' => result.push('\''), '\u{201C}' | '\u{201D}' => result.push('"'), '\u{2013}' | '\u{2014}' => result.push('-'), '\u{2026}' => result.push_str("..."), _ => result.push(c),
137 }
138 }
139 result
140 }
141
142 fn remove_accents_text(&self, text: &str) -> String {
143 let accent_map: HashMap<char, char> = [
145 ('à', 'a'),
146 ('á', 'a'),
147 ('â', 'a'),
148 ('ã', 'a'),
149 ('ä', 'a'),
150 ('å', 'a'),
151 ('è', 'e'),
152 ('é', 'e'),
153 ('ê', 'e'),
154 ('ë', 'e'),
155 ('ì', 'i'),
156 ('í', 'i'),
157 ('î', 'i'),
158 ('ï', 'i'),
159 ('ò', 'o'),
160 ('ó', 'o'),
161 ('ô', 'o'),
162 ('õ', 'o'),
163 ('ö', 'o'),
164 ('ù', 'u'),
165 ('ú', 'u'),
166 ('û', 'u'),
167 ('ü', 'u'),
168 ('ý', 'y'),
169 ('ÿ', 'y'),
170 ('ñ', 'n'),
171 ('ç', 'c'),
172 ('À', 'A'),
173 ('Á', 'A'),
174 ('Â', 'A'),
175 ('Ã', 'A'),
176 ('Ä', 'A'),
177 ('Å', 'A'),
178 ('È', 'E'),
179 ('É', 'E'),
180 ('Ê', 'E'),
181 ('Ë', 'E'),
182 ('Ì', 'I'),
183 ('Í', 'I'),
184 ('Î', 'I'),
185 ('Ï', 'I'),
186 ('Ò', 'O'),
187 ('Ó', 'O'),
188 ('Ô', 'O'),
189 ('Õ', 'O'),
190 ('Ö', 'O'),
191 ('Ù', 'U'),
192 ('Ú', 'U'),
193 ('Û', 'U'),
194 ('Ü', 'U'),
195 ('Ý', 'Y'),
196 ('Ÿ', 'Y'),
197 ('Ñ', 'N'),
198 ('Ç', 'C'),
199 ]
200 .iter()
201 .cloned()
202 .collect();
203
204 text.chars()
205 .map(|c| accent_map.get(&c).copied().unwrap_or(c))
206 .collect()
207 }
208
209 fn remove_punctuation_text(&self, text: &str) -> String {
210 text.chars().filter(|c| !c.is_ascii_punctuation()).collect()
211 }
212
213 fn remove_digits_text(&self, text: &str) -> String {
214 text.chars().filter(|c| !c.is_ascii_digit()).collect()
215 }
216
217 fn remove_extra_spaces_text(&self, text: &str) -> String {
218 WHITESPACE_RE.replace_all(text, " ").to_string()
219 }
220}
221
/// Configurable noise remover for web/social text, built with chainable setters.
///
/// [`TextCleaner::clean`] removes each enabled category (replacing matches
/// with the empty string) and always collapses whitespace at the end.
#[derive(Debug, Clone)]
pub struct TextCleaner {
    remove_urls: bool,          // http(s):// links
    remove_emails: bool,        // simple ASCII email addresses
    remove_html: bool,          // `<...>` tag-like spans
    remove_mentions: bool,      // @word
    remove_hashtags: bool,      // #word
    remove_special_chars: bool, // keep only alphanumeric + whitespace
}
235
236impl Default for TextCleaner {
237 fn default() -> Self {
238 Self {
239 remove_urls: true,
240 remove_emails: true,
241 remove_html: true,
242 remove_mentions: false,
243 remove_hashtags: false,
244 remove_special_chars: false,
245 }
246 }
247}
248
249impl TextCleaner {
250 pub fn new() -> Self {
251 Self::default()
252 }
253
254 pub fn remove_urls(mut self, value: bool) -> Self {
255 self.remove_urls = value;
256 self
257 }
258
259 pub fn remove_emails(mut self, value: bool) -> Self {
260 self.remove_emails = value;
261 self
262 }
263
264 pub fn remove_html(mut self, value: bool) -> Self {
265 self.remove_html = value;
266 self
267 }
268
269 pub fn remove_mentions(mut self, value: bool) -> Self {
270 self.remove_mentions = value;
271 self
272 }
273
274 pub fn remove_hashtags(mut self, value: bool) -> Self {
275 self.remove_hashtags = value;
276 self
277 }
278
279 pub fn remove_special_chars(mut self, value: bool) -> Self {
280 self.remove_special_chars = value;
281 self
282 }
283
284 pub fn clean(&self, text: &str) -> String {
285 let mut result = text.to_string();
286
287 if self.remove_urls {
288 result = self.remove_urls_from_text(&result);
289 }
290
291 if self.remove_emails {
292 result = self.remove_emails_from_text(&result);
293 }
294
295 if self.remove_html {
296 result = self.remove_html_from_text(&result);
297 }
298
299 if self.remove_mentions {
300 result = self.remove_mentions_from_text(&result);
301 }
302
303 if self.remove_hashtags {
304 result = self.remove_hashtags_from_text(&result);
305 }
306
307 if self.remove_special_chars {
308 result = self.remove_special_chars_from_text(&result);
309 }
310
311 WHITESPACE_RE.replace_all(&result, " ").trim().to_string()
313 }
314
315 fn remove_urls_from_text(&self, text: &str) -> String {
316 URL_RE.replace_all(text, "").to_string()
317 }
318
319 fn remove_emails_from_text(&self, text: &str) -> String {
320 EMAIL_RE.replace_all(text, "").to_string()
321 }
322
323 fn remove_html_from_text(&self, text: &str) -> String {
324 HTML_RE.replace_all(text, "").to_string()
325 }
326
327 fn remove_mentions_from_text(&self, text: &str) -> String {
328 MENTION_RE.replace_all(text, "").to_string()
329 }
330
331 fn remove_hashtags_from_text(&self, text: &str) -> String {
332 HASHTAG_RE.replace_all(text, "").to_string()
333 }
334
335 fn remove_special_chars_from_text(&self, text: &str) -> String {
336 text.chars()
337 .filter(|c| c.is_alphanumeric() || c.is_whitespace())
338 .collect()
339 }
340}
341
/// Simple rule-based text augmentation (synonym replacement, random
/// insertion/deletion, word swap). Stateless: each method seeds its own RNG.
#[derive(Debug, Clone, Default)]
pub struct TextAugmenter {
}
350
impl TextAugmenter {
    /// Creates a new augmenter.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces each word with a synonym from a small built-in table with
    /// probability `replacement_prob`.
    ///
    /// NOTE(review): the RNG is seeded with a fixed 42 on every call, so the
    /// output is deterministic and identical across repeated calls on the
    /// same input — confirm this reproducibility is intended rather than a
    /// leftover. The draw order (one draw per word, plus one per table hit)
    /// is load-bearing for that determinism.
    pub fn synonym_replacement(&self, text: &str, replacement_prob: f32) -> String {
        let mut rng = Random::seed(42);
        // Demo-sized synonym table; lookup is by lowercased word, and the
        // replacement does not restore the original word's casing.
        let synonyms: HashMap<&str, Vec<&str>> = [
            ("good", vec!["great", "excellent", "wonderful"]),
            ("bad", vec!["terrible", "awful", "horrible"]),
            ("big", vec!["large", "huge", "enormous"]),
            ("small", vec!["tiny", "little", "miniature"]),
        ]
        .iter()
        .cloned()
        .collect();

        let words: Vec<&str> = text.split_whitespace().collect();
        let mut result_words = Vec::new();

        for word in words {
            // Draw first, then look up: the RNG advances once per word
            // whether or not the word has synonyms.
            if rng.random::<f32>() < replacement_prob {
                if let Some(syns) = synonyms.get(word.to_lowercase().as_str()) {
                    let idx = rng.gen_range(0..syns.len());
                    result_words.push(syns[idx].to_string());
                } else {
                    result_words.push(word.to_string());
                }
            } else {
                result_words.push(word.to_string());
            }
        }

        result_words.join(" ")
    }

    /// After each word, inserts a filler word with probability `insertion_prob`.
    /// Deterministic for a given input (fixed seed — see note above).
    pub fn random_insertion(&self, text: &str, insertion_prob: f32) -> String {
        let mut rng = Random::seed(42);
        let words: Vec<&str> = text.split_whitespace().collect();
        let mut result_words = Vec::new();

        let insert_words = ["the", "a", "an", "very", "really", "quite"];

        for word in words {
            result_words.push(word.to_string());

            if rng.random::<f32>() < insertion_prob {
                let idx = rng.gen_range(0..insert_words.len());
                result_words.push(insert_words[idx].to_string());
            }
        }

        result_words.join(" ")
    }

    /// Drops each word with probability `deletion_prob`. If everything would
    /// be deleted, the original text is returned unchanged as a safeguard.
    pub fn random_deletion(&self, text: &str, deletion_prob: f32) -> String {
        let mut rng = Random::seed(42);
        let words: Vec<&str> = text.split_whitespace().collect();
        let mut result_words = Vec::new();

        for word in words {
            if rng.random::<f32>() >= deletion_prob {
                result_words.push(word.to_string());
            }
        }

        if result_words.is_empty() {
            text.to_string()
        } else {
            result_words.join(" ")
        }
    }

    /// For each position, with probability `swap_prob`, swaps the word with a
    /// uniformly random position (which may be itself — a no-op swap).
    /// Inputs shorter than two words are returned unchanged.
    pub fn random_swap(&self, text: &str, swap_prob: f32) -> String {
        let mut rng = Random::seed(42);
        let mut words: Vec<String> = text.split_whitespace().map(|s| s.to_string()).collect();

        if words.len() < 2 {
            return text.to_string();
        }

        for i in 0..words.len() {
            if rng.random::<f32>() < swap_prob {
                let j = rng.gen_range(0..words.len());
                words.swap(i, j);
            }
        }

        words.join(" ")
    }

    /// Cheap stand-in for back-translation: light synonym replacement
    /// followed by light word swapping.
    pub fn back_translation_simulation(&self, text: &str) -> String {
        let mut result = self.synonym_replacement(text, 0.1);
        result = self.random_swap(&result, 0.05);
        result
    }

    /// Applies every augmentation pass with mild probabilities, in order:
    /// synonyms → insertion → deletion → swap. Each pass re-seeds its RNG,
    /// so the passes are correlated across calls (fixed seed).
    pub fn augment(&self, text: &str) -> String {
        let mut result = self.synonym_replacement(text, 0.1);
        result = self.random_insertion(&result, 0.05);
        result = self.random_deletion(&result, 0.05);
        result = self.random_swap(&result, 0.05);
        result
    }
}
465
/// Which side(s) of a sequence receive padding tokens.
#[derive(Debug, Clone, Copy)]
pub enum PaddingStrategy {
    /// All padding before the tokens.
    Left,
    /// All padding after the tokens.
    Right,
    /// Padding split evenly; the extra token (odd counts) goes on the right.
    Center,
}

/// Which side(s) of a sequence lose tokens when truncating.
#[derive(Debug, Clone, Copy)]
pub enum TruncationStrategy {
    /// Keep the last `max_length` tokens.
    Left,
    /// Keep the first `max_length` tokens.
    Right,
    /// Keep a centered window; the extra dropped token (odd counts) comes
    /// off the right.
    Center,
}

/// Pads `tokens` with `pad_token_id` up to `max_length`.
///
/// Sequences already at or above `max_length` are returned unchanged
/// (this function never truncates — pair with [`truncate_sequence`]).
pub fn pad_sequence(
    tokens: &[u32],
    max_length: usize,
    pad_token_id: u32,
    strategy: PaddingStrategy,
) -> Vec<u32> {
    if tokens.len() >= max_length {
        return tokens.to_vec();
    }

    let padding_needed = max_length - tokens.len();
    let mut result = Vec::with_capacity(max_length);

    // `resize` writes the fill value in place; the previous version built a
    // temporary `vec![pad; n]` and then copied it into `result`.
    match strategy {
        PaddingStrategy::Left => {
            result.resize(padding_needed, pad_token_id);
            result.extend_from_slice(tokens);
        }
        PaddingStrategy::Right => {
            result.extend_from_slice(tokens);
            result.resize(max_length, pad_token_id);
        }
        PaddingStrategy::Center => {
            // Left gets floor(n/2); the remainder lands on the right via the
            // final resize to `max_length`.
            result.resize(padding_needed / 2, pad_token_id);
            result.extend_from_slice(tokens);
            result.resize(max_length, pad_token_id);
        }
    }

    result
}

/// Truncates `tokens` down to `max_length` according to `strategy`.
///
/// Sequences already at or below `max_length` are returned unchanged.
pub fn truncate_sequence(
    tokens: &[u32],
    max_length: usize,
    strategy: TruncationStrategy,
) -> Vec<u32> {
    if tokens.len() <= max_length {
        return tokens.to_vec();
    }

    match strategy {
        TruncationStrategy::Left => tokens[tokens.len() - max_length..].to_vec(),
        TruncationStrategy::Right => tokens[..max_length].to_vec(),
        TruncationStrategy::Center => {
            // Drop floor(excess/2) from the left and keep a window of
            // exactly `max_length` tokens.
            let start = (tokens.len() - max_length) / 2;
            tokens[start..start + max_length].to_vec()
        }
    }
}

/// Truncates then pads every sequence to a common length.
///
/// When `max_length` is `None`, the longest input sequence sets the target
/// length (0 for an empty batch), so no truncation occurs in that case.
pub fn pad_and_truncate_sequences(
    sequences: &[Vec<u32>],
    max_length: Option<usize>,
    pad_token_id: u32,
    padding_strategy: PaddingStrategy,
    truncation_strategy: TruncationStrategy,
) -> Vec<Vec<u32>> {
    let max_len =
        max_length.unwrap_or_else(|| sequences.iter().map(|seq| seq.len()).max().unwrap_or(0));

    sequences
        .iter()
        .map(|seq| {
            let truncated = truncate_sequence(seq, max_len, truncation_strategy);
            pad_sequence(&truncated, max_len, pad_token_id, padding_strategy)
        })
        .collect()
}
557
/// Converts token ids into dense one-hot rows of width `vocab_size`.
///
/// Ids at or beyond `vocab_size` produce an all-zero row rather than
/// panicking.
pub fn one_hot_encode(token_ids: &[u32], vocab_size: usize) -> Vec<Vec<f32>> {
    let mut rows = Vec::with_capacity(token_ids.len());
    for &token_id in token_ids {
        let mut row = vec![0.0f32; vocab_size];
        // `get_mut` is the checked equivalent of the `< vocab_size` guard.
        if let Some(slot) = row.get_mut(token_id as usize) {
            *slot = 1.0;
        }
        rows.push(row);
    }
    rows
}
574
/// Maps string labels to dense integer ids in first-seen order.
///
/// Returns the encoded sequence plus the label → id mapping; the first
/// distinct label gets id 0, the next id 1, and so on.
pub fn label_encode(labels: &[String]) -> (Vec<u32>, HashMap<String, u32>) {
    let mut label_to_id: HashMap<String, u32> = HashMap::new();
    let mut encoded = Vec::with_capacity(labels.len());

    for label in labels {
        let id = match label_to_id.get(label) {
            Some(&existing) => existing,
            None => {
                // The map size doubles as the next fresh id.
                let fresh = label_to_id.len() as u32;
                label_to_id.insert(label.clone(), fresh);
                fresh
            }
        };
        encoded.push(id);
    }

    (encoded, label_to_id)
}
595
/// Composable preprocessing pipeline: optional normalization, cleaning and
/// augmentation stages, plus arbitrary user-supplied [`PreprocessingStep`]s.
#[derive(Debug)]
pub struct TextPreprocessingPipeline {
    normalizer: Option<TextNormalizer>, // runs first when present
    cleaner: Option<TextCleaner>,       // runs second when present
    augmenter: Option<TextAugmenter>,   // runs last, after custom steps
    custom_steps: Vec<Box<dyn PreprocessingStep>>, // run in insertion order
}
608
/// A single user-defined pipeline stage. Implementations must be `Send +
/// Sync` so pipelines can process batches in parallel.
pub trait PreprocessingStep: std::fmt::Debug + Send + Sync {
    /// Transforms `text`; returning `Err` aborts the pipeline for this input.
    fn process(&self, text: &str) -> Result<String>;
    /// Human-readable stage name, used by `TextPreprocessingPipeline::summary`.
    fn name(&self) -> &str;
}
614
/// Adapter that wraps a plain closure as a named [`PreprocessingStep`].
pub struct CustomStep<F>
where
    F: Fn(&str) -> String + Send + Sync + 'static,
{
    function: F,  // the infallible transformation to apply
    name: String, // label reported by `name()` / `Debug`
}
623
impl<F> CustomStep<F>
where
    F: Fn(&str) -> String + Send + Sync + 'static,
{
    /// Wraps `function` as a pipeline step labelled `name`.
    pub fn new(function: F, name: String) -> Self {
        Self { function, name }
    }
}
632
// Manual `Debug`: the wrapped closure has no `Debug` impl, so only the
// step name is shown.
impl<F> std::fmt::Debug for CustomStep<F>
where
    F: Fn(&str) -> String + Send + Sync + 'static,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CustomStep")
            .field("name", &self.name)
            .finish()
    }
}
643
impl<F> PreprocessingStep for CustomStep<F>
where
    F: Fn(&str) -> String + Send + Sync + 'static,
{
    /// Always succeeds: the wrapped closure is infallible by construction.
    fn process(&self, text: &str) -> Result<String> {
        Ok((self.function)(text))
    }

    fn name(&self) -> &str {
        &self.name
    }
}
656
// Manual `Clone` because `Box<dyn PreprocessingStep>` is not cloneable.
//
// WARNING: the clone silently DROPS all custom steps — a cloned pipeline
// only keeps the normalizer/cleaner/augmenter stages. Callers relying on
// custom steps must not clone, or must re-add the steps afterwards.
// (Fixing this would require adding a `clone_box`-style method to the
// `PreprocessingStep` trait, which would break existing implementors.)
impl Clone for TextPreprocessingPipeline {
    fn clone(&self) -> Self {
        Self {
            normalizer: self.normalizer.clone(),
            cleaner: self.cleaner.clone(),
            augmenter: self.augmenter.clone(),
            custom_steps: Vec::new(), // intentionally empty — see warning above
        }
    }
}
667
668impl TextPreprocessingPipeline {
669 pub fn new() -> Self {
670 Self {
671 normalizer: None,
672 cleaner: None,
673 augmenter: None,
674 custom_steps: Vec::new(),
675 }
676 }
677
678 pub fn with_normalization(mut self, normalizer: TextNormalizer) -> Self {
679 self.normalizer = Some(normalizer);
680 self
681 }
682
683 pub fn with_cleaning(mut self, cleaner: TextCleaner) -> Self {
684 self.cleaner = Some(cleaner);
685 self
686 }
687
688 pub fn with_augmentation(mut self, augmenter: TextAugmenter) -> Self {
689 self.augmenter = Some(augmenter);
690 self
691 }
692
693 pub fn add_custom_step(mut self, step: Box<dyn PreprocessingStep>) -> Self {
694 self.custom_steps.push(step);
695 self
696 }
697
698 pub fn process_text(&self, text: &str) -> Result<String> {
700 let mut result = text.to_string();
701
702 if let Some(normalizer) = &self.normalizer {
704 result = normalizer.normalize(&result);
705 }
706
707 if let Some(cleaner) = &self.cleaner {
709 result = cleaner.clean(&result);
710 }
711
712 for step in &self.custom_steps {
714 result = step.process(&result)?;
715 }
716
717 if let Some(augmenter) = &self.augmenter {
719 result = augmenter.augment(&result);
720 }
721
722 Ok(result)
723 }
724
725 pub fn process_batch(&self, texts: &[String]) -> Result<Vec<String>> {
727 texts.iter().map(|text| self.process_text(text)).collect()
728 }
729
730 pub fn process_batch_parallel(&self, texts: &[String]) -> Result<Vec<String>> {
732 use scirs2_core::parallel_ops::*;
733
734 texts
735 .par_iter()
736 .map(|text| self.process_text(text))
737 .collect()
738 }
739
740 pub fn summary(&self) -> Vec<String> {
742 let mut steps = Vec::new();
743
744 if self.normalizer.is_some() {
745 steps.push("Text Normalization".to_string());
746 }
747
748 if self.cleaner.is_some() {
749 steps.push("Text Cleaning".to_string());
750 }
751
752 for step in &self.custom_steps {
753 steps.push(format!("Custom: {}", step.name()));
754 }
755
756 if self.augmenter.is_some() {
757 steps.push("Text Augmentation".to_string());
758 }
759
760 steps
761 }
762}
763
impl Default for TextPreprocessingPipeline {
    /// Default pipeline: standard normalization followed by standard
    /// cleaning; no custom steps, no augmentation.
    fn default() -> Self {
        Self::new()
            .with_normalization(TextNormalizer::default())
            .with_cleaning(TextCleaner::default())
    }
}
771
/// Pipeline step that collapses whitespace runs to single spaces and trims.
#[derive(Debug)]
pub struct RemoveExtraWhitespaceStep;

impl PreprocessingStep for RemoveExtraWhitespaceStep {
    fn process(&self, text: &str) -> Result<String> {
        Ok(WHITESPACE_RE.replace_all(text, " ").trim().to_string())
    }

    fn name(&self) -> &str {
        "Remove Extra Whitespace"
    }
}
785
/// Pipeline step that rejects texts shorter than a minimum length.
///
/// NOTE(review): length is measured in BYTES (`str::len`), not characters,
/// so multi-byte UTF-8 text passes the filter "earlier" than its visible
/// length suggests — confirm byte semantics are intended.
#[derive(Debug)]
pub struct MinLengthFilterStep {
    min_length: usize, // minimum byte length (inclusive) to pass
}

impl MinLengthFilterStep {
    /// Creates a filter requiring at least `min_length` bytes.
    pub fn new(min_length: usize) -> Self {
        Self { min_length }
    }
}

impl PreprocessingStep for MinLengthFilterStep {
    /// Returns the text unchanged, or a `ValidationError` if it is too short.
    fn process(&self, text: &str) -> Result<String> {
        if text.len() < self.min_length {
            Err(TextError::ValidationError(format!(
                "Text too short: {} < {}",
                text.len(),
                self.min_length
            )))
        } else {
            Ok(text.to_string())
        }
    }

    fn name(&self) -> &str {
        "Minimum Length Filter"
    }
}
814
/// Pipeline step that truncates texts to at most `max_length` CHARACTERS.
///
/// The fast-path guard compares BYTE length (`str::len`) while truncation
/// counts chars; this is still correct — byte length ≥ char count, so any
/// text needing truncation always takes the truncating branch — but the
/// guard may rebuild an already-short multi-byte string unnecessarily.
#[derive(Debug)]
pub struct MaxLengthTruncateStep {
    max_length: usize, // maximum number of chars to keep
}

impl MaxLengthTruncateStep {
    /// Creates a truncator keeping at most `max_length` characters.
    pub fn new(max_length: usize) -> Self {
        Self { max_length }
    }
}

impl PreprocessingStep for MaxLengthTruncateStep {
    fn process(&self, text: &str) -> Result<String> {
        if text.len() > self.max_length {
            Ok(text.chars().take(self.max_length).collect())
        } else {
            Ok(text.to_string())
        }
    }

    fn name(&self) -> &str {
        "Maximum Length Truncate"
    }
}
839
840pub struct PreprocessingUtils;
842
843impl PreprocessingUtils {
844 pub fn classification_pipeline() -> TextPreprocessingPipeline {
846 TextPreprocessingPipeline::new()
847 .with_normalization(
848 TextNormalizer::new()
849 .lowercase(true)
850 .remove_extra_spaces(true)
851 .normalize_unicode(true),
852 )
853 .with_cleaning(
854 TextCleaner::new()
855 .remove_urls(true)
856 .remove_emails(true)
857 .remove_special_chars(true),
858 )
859 .add_custom_step(Box::new(RemoveExtraWhitespaceStep))
860 }
861
862 pub fn language_modeling_pipeline() -> TextPreprocessingPipeline {
864 TextPreprocessingPipeline::new()
865 .with_normalization(
866 TextNormalizer::new()
867 .normalize_unicode(true)
868 .remove_extra_spaces(true),
869 )
870 .add_custom_step(Box::new(RemoveExtraWhitespaceStep))
871 .add_custom_step(Box::new(MinLengthFilterStep::new(10)))
872 }
873
874 pub fn translation_pipeline() -> TextPreprocessingPipeline {
876 TextPreprocessingPipeline::new()
877 .with_normalization(
878 TextNormalizer::new()
879 .normalize_unicode(true)
880 .remove_extra_spaces(true),
881 )
882 .add_custom_step(Box::new(RemoveExtraWhitespaceStep))
883 .add_custom_step(Box::new(MaxLengthTruncateStep::new(512)))
884 }
885
886 pub fn filter_texts(
888 texts: &[String],
889 min_length: Option<usize>,
890 max_length: Option<usize>,
891 allowed_chars: Option<&str>,
892 ) -> Vec<String> {
893 texts
894 .iter()
895 .filter(|text| {
896 if let Some(min) = min_length {
898 if text.len() < min {
899 return false;
900 }
901 }
902 if let Some(max) = max_length {
903 if text.len() > max {
904 return false;
905 }
906 }
907
908 if let Some(allowed) = allowed_chars {
910 let allowed_set: std::collections::HashSet<char> = allowed.chars().collect();
911 for ch in text.chars() {
912 if !allowed_set.contains(&ch) && !ch.is_whitespace() {
913 return false;
914 }
915 }
916 }
917
918 true
919 })
920 .cloned()
921 .collect()
922 }
923
924 pub fn compute_batch_stats(texts: &[String]) -> PreprocessingStats {
926 let total_texts = texts.len();
927 let total_chars: usize = texts.iter().map(|t| t.len()).sum();
928 let total_words: usize = texts.iter().map(|t| t.split_whitespace().count()).sum();
929
930 let avg_chars = if total_texts > 0 {
931 total_chars as f32 / total_texts as f32
932 } else {
933 0.0
934 };
935 let avg_words = if total_texts > 0 {
936 total_words as f32 / total_texts as f32
937 } else {
938 0.0
939 };
940
941 let min_chars = texts.iter().map(|t| t.len()).min().unwrap_or(0);
942 let max_chars = texts.iter().map(|t| t.len()).max().unwrap_or(0);
943
944 PreprocessingStats {
945 total_texts,
946 total_chars,
947 total_words,
948 avg_chars_per_text: avg_chars,
949 avg_words_per_text: avg_words,
950 min_text_length: min_chars,
951 max_text_length: max_chars,
952 }
953 }
954}
955
/// Aggregate length statistics for a batch of texts, as produced by
/// `PreprocessingUtils::compute_batch_stats`. "Chars" here are byte lengths.
#[derive(Debug, Clone)]
pub struct PreprocessingStats {
    pub total_texts: usize,
    pub total_chars: usize,
    pub total_words: usize,
    pub avg_chars_per_text: f32,
    pub avg_words_per_text: f32,
    pub min_text_length: usize, // 0 for an empty batch
    pub max_text_length: usize, // 0 for an empty batch
}
966
/// Legacy convenience wrapper around `TextNormalizer::default().normalize`.
#[deprecated(
    note = "Use TextPreprocessingPipeline::classification_pipeline().process_text() instead"
)]
pub fn normalize_text(text: &str) -> String {
    TextNormalizer::default().normalize(text)
}
977
/// Naive sentence splitter: cuts after every `.`, `!` or `?`, trimming each
/// piece and dropping empties; any trailing text without a terminator
/// becomes a final sentence. Does not handle abbreviations or decimals.
#[deprecated(note = "Use proper sentence segmentation libraries instead")]
pub fn split_sentences(text: &str) -> Vec<String> {
    let mut out = Vec::new();
    let mut buf = String::new();

    for ch in text.chars() {
        buf.push(ch);
        if matches!(ch, '.' | '!' | '?') {
            let piece = buf.trim();
            if !piece.is_empty() {
                out.push(piece.to_string());
            }
            buf.clear();
        }
    }

    // Flush whatever remains after the last terminator.
    let tail = buf.trim();
    if !tail.is_empty() {
        out.push(tail.to_string());
    }

    out
}
1002
/// Counts whitespace-separated words (Unicode whitespace, runs collapsed).
pub fn count_words(text: &str) -> usize {
    let mut total = 0;
    for _word in text.split_whitespace() {
        total += 1;
    }
    total
}
1006
/// Legacy convenience wrapper around `TextCleaner::default().clean`.
#[deprecated(
    note = "Use TextPreprocessingPipeline::classification_pipeline().process_text() instead"
)]
pub fn clean_text(text: &str) -> String {
    TextCleaner::default().clean(text)
}
1013
/// Configurable batch executor with optional chunked parallelism and an
/// optional in-memory result cache keyed by the input text.
pub struct BatchProcessor {
    chunk_size: usize,  // items per parallel chunk (default 1000)
    parallel: bool,     // whether large batches may run in parallel
    cache_enabled: bool, // mirrors `cache.is_some()`; toggled via `with_cache`
    cache: Option<std::collections::HashMap<String, String>>, // input → output memo
}
1025
impl BatchProcessor {
    /// Creates a processor with defaults: 1000-item chunks, parallelism on,
    /// caching off.
    pub fn new() -> Self {
        Self {
            chunk_size: 1000,
            parallel: true,
            cache_enabled: false,
            cache: None,
        }
    }

    /// Sets the parallel chunk size (chainable).
    pub fn with_chunk_size(mut self, size: usize) -> Self {
        self.chunk_size = size;
        self
    }

    /// Enables or disables parallel execution for large batches (chainable).
    pub fn with_parallel(mut self, parallel: bool) -> Self {
        self.parallel = parallel;
        self
    }

    /// Enables or disables the result cache (chainable). Disabling drops
    /// any previously cached entries.
    pub fn with_cache(mut self, enable: bool) -> Self {
        self.cache_enabled = enable;
        if enable {
            self.cache = Some(std::collections::HashMap::new());
        } else {
            self.cache = None;
        }
        self
    }

    /// Applies `processor` to every text, in parallel chunks when the batch
    /// exceeds `chunk_size` and parallelism is enabled, otherwise serially.
    ///
    /// Note: this path does NOT consult the cache — use
    /// [`Self::process_with_cache`] for memoized processing.
    pub fn process_with_function<F, T>(
        &mut self,
        texts: &[String],
        mut processor: F,
    ) -> Result<Vec<T>>
    where
        F: FnMut(&str) -> Result<T> + Send + Sync,
        T: Send + Sync,
    {
        if self.parallel && texts.len() > self.chunk_size {
            self.process_parallel_chunked(texts, processor)
        } else {
            texts.iter().map(|text| processor(text)).collect()
        }
    }

    /// Applies `processor` serially, reusing cached results for texts seen
    /// before (when caching is enabled). The first error aborts the batch;
    /// results computed before the error are discarded but remain cached.
    pub fn process_with_cache<F>(
        &mut self,
        texts: &[String],
        mut processor: F,
    ) -> Result<Vec<String>>
    where
        F: FnMut(&str) -> Result<String> + Send + Sync,
    {
        let mut results = Vec::with_capacity(texts.len());

        for text in texts {
            // Cache hit: reuse the stored output and skip the processor.
            if self.cache_enabled {
                if let Some(cache) = &self.cache {
                    if let Some(cached_result) = cache.get(text) {
                        results.push(cached_result.clone());
                        continue;
                    }
                }
            }

            let result = processor(text)?;

            if self.cache_enabled {
                if let Some(cache) = &mut self.cache {
                    cache.insert(text.clone(), result.clone());
                }
            }

            results.push(result);
        }

        Ok(results)
    }

    /// Chunked "parallel" execution.
    ///
    /// CAVEAT(review): because `F` is `FnMut`, the closure is wrapped in a
    /// `Mutex` that is locked around EVERY item, so chunks effectively run
    /// one item at a time — the parallelism buys little here. Tightening
    /// the bound to `Fn` would fix this but would break existing callers
    /// that pass stateful closures.
    fn process_parallel_chunked<F, T>(&self, texts: &[String], processor: F) -> Result<Vec<T>>
    where
        F: FnMut(&str) -> Result<T> + Send + Sync,
        T: Send + Sync,
    {
        use scirs2_core::parallel_ops::*;
        use std::sync::Mutex;

        let processor = Mutex::new(processor);

        texts
            .par_chunks(self.chunk_size)
            .map(|chunk| {
                chunk
                    .iter()
                    .map(|text| {
                        let mut proc = processor.lock().expect("lock should not be poisoned");
                        proc(text)
                    })
                    .collect::<Result<Vec<T>>>()
            })
            .collect::<Result<Vec<Vec<T>>>>()
            .map(|chunks| chunks.into_iter().flatten().collect())
    }

    /// Empties the cache (keeps it enabled if it was enabled).
    pub fn clear_cache(&mut self) {
        if let Some(cache) = &mut self.cache {
            cache.clear();
        }
    }

    /// Returns `(entries, capacity)` of the cache, or `None` when disabled.
    pub fn cache_stats(&self) -> Option<(usize, usize)> {
        self.cache
            .as_ref()
            .map(|cache| (cache.len(), cache.capacity()))
    }
}
1147
impl Default for BatchProcessor {
    /// Same configuration as [`BatchProcessor::new`].
    fn default() -> Self {
        Self::new()
    }
}
1153
1154pub struct OptimizedBatchOps;
1156
1157impl OptimizedBatchOps {
1158 pub fn batch_tokenize(
1160 texts: &[String],
1161 tokenizer: &dyn crate::tokenization::Tokenizer,
1162 parallel: bool,
1163 ) -> Result<Vec<Vec<u32>>> {
1164 if parallel && texts.len() > 100 {
1165 use scirs2_core::parallel_ops::*;
1166 texts
1167 .par_iter()
1168 .map(|text| tokenizer.encode(text))
1169 .collect()
1170 } else {
1171 texts.iter().map(|text| tokenizer.encode(text)).collect()
1172 }
1173 }
1174
1175 pub fn batch_clean(texts: &[String], cleaner: &TextCleaner) -> Vec<String> {
1177 let mut processor = BatchProcessor::new()
1178 .with_parallel(true)
1179 .with_chunk_size(500)
1180 .with_cache(texts.len() > 1000);
1181
1182 processor
1183 .process_with_cache(texts, |text| Ok(cleaner.clean(text)))
1184 .unwrap_or_else(|_| texts.iter().map(|t| cleaner.clean(t)).collect())
1185 }
1186
1187 pub fn batch_normalize(texts: &[String], normalizer: &TextNormalizer) -> Vec<String> {
1189 use scirs2_core::parallel_ops::*;
1190
1191 if texts.len() > 100 {
1192 texts
1193 .par_iter()
1194 .map(|text| normalizer.normalize(text))
1195 .collect()
1196 } else {
1197 texts
1198 .iter()
1199 .map(|text| normalizer.normalize(text))
1200 .collect()
1201 }
1202 }
1203
1204 pub fn batch_statistics(texts: &[String]) -> BatchTextStats {
1206 use scirs2_core::parallel_ops::*;
1207
1208 let chunk_size = 1000;
1209
1210 if texts.len() > chunk_size {
1211 let partial_stats: Vec<BatchTextStats> = texts
1213 .par_chunks(chunk_size)
1214 .map(Self::compute_chunk_stats)
1215 .collect();
1216
1217 Self::merge_stats(partial_stats)
1219 } else {
1220 Self::compute_chunk_stats(texts)
1221 }
1222 }
1223
1224 fn compute_chunk_stats(texts: &[String]) -> BatchTextStats {
1225 let mut total_chars = 0;
1226 let mut total_words = 0;
1227 let mut min_length = usize::MAX;
1228 let mut max_length = 0;
1229 let mut char_distribution = std::collections::HashMap::new();
1230
1231 for text in texts {
1232 let char_count = text.chars().count();
1233 let word_count = text.split_whitespace().count();
1234
1235 total_chars += char_count;
1236 total_words += word_count;
1237 min_length = min_length.min(char_count);
1238 max_length = max_length.max(char_count);
1239
1240 if texts.len() < 10000 {
1242 for ch in text.chars() {
1243 *char_distribution.entry(ch).or_insert(0) += 1;
1244 }
1245 }
1246 }
1247
1248 if min_length == usize::MAX {
1249 min_length = 0;
1250 }
1251
1252 BatchTextStats {
1253 text_count: texts.len(),
1254 total_chars,
1255 total_words,
1256 min_length,
1257 max_length,
1258 avg_length: if texts.is_empty() {
1259 0.0
1260 } else {
1261 total_chars as f64 / texts.len() as f64
1262 },
1263 avg_words: if texts.is_empty() {
1264 0.0
1265 } else {
1266 total_words as f64 / texts.len() as f64
1267 },
1268 char_distribution,
1269 }
1270 }
1271
1272 fn merge_stats(stats: Vec<BatchTextStats>) -> BatchTextStats {
1273 let mut merged = BatchTextStats::default();
1274
1275 for stat in stats {
1276 merged.text_count += stat.text_count;
1277 merged.total_chars += stat.total_chars;
1278 merged.total_words += stat.total_words;
1279 merged.min_length = merged.min_length.min(stat.min_length);
1280 merged.max_length = merged.max_length.max(stat.max_length);
1281
1282 for (ch, count) in stat.char_distribution {
1284 *merged.char_distribution.entry(ch).or_insert(0) += count;
1285 }
1286 }
1287
1288 if merged.text_count > 0 {
1290 merged.avg_length = merged.total_chars as f64 / merged.text_count as f64;
1291 merged.avg_words = merged.total_words as f64 / merged.text_count as f64;
1292 }
1293
1294 merged
1295 }
1296
1297 pub fn batch_filter<F>(texts: &[String], predicate: F) -> Vec<String>
1299 where
1300 F: Fn(&str) -> bool + Send + Sync,
1301 {
1302 use scirs2_core::parallel_ops::*;
1303
1304 if texts.len() > 1000 {
1305 texts
1306 .par_iter()
1307 .filter(|text| predicate(text))
1308 .cloned()
1309 .collect()
1310 } else {
1311 texts
1312 .iter()
1313 .filter(|text| predicate(text))
1314 .cloned()
1315 .collect()
1316 }
1317 }
1318
1319 pub fn process_large_file<F>(
1321 file_path: &std::path::Path,
1322 processor: F,
1323 output_path: &std::path::Path,
1324 ) -> Result<()>
1325 where
1326 F: Fn(&str) -> String + Send + Sync,
1327 {
1328 use std::fs::File;
1329 use std::io::{BufRead, BufReader, BufWriter, Write};
1330
1331 let input_file = File::open(file_path)?;
1332 let reader = BufReader::new(input_file);
1333
1334 let output_file = File::create(output_path)?;
1335 let mut writer = BufWriter::new(output_file);
1336
1337 const BATCH_SIZE: usize = 1000;
1338 let mut batch = Vec::with_capacity(BATCH_SIZE);
1339
1340 for line in reader.lines() {
1341 let line = line?;
1342 batch.push(line);
1343
1344 if batch.len() >= BATCH_SIZE {
1345 let processed: Vec<String> = if batch.len() > 100 {
1347 use scirs2_core::parallel_ops::*;
1348 batch.par_iter().map(|text| processor(text)).collect()
1349 } else {
1350 batch.iter().map(|text| processor(text)).collect()
1351 };
1352
1353 for result in processed {
1355 writeln!(writer, "{result}")?;
1356 }
1357
1358 batch.clear();
1359 }
1360 }
1361
1362 if !batch.is_empty() {
1364 let processed: Vec<String> = batch.iter().map(|text| processor(text)).collect();
1365 for result in processed {
1366 writeln!(writer, "{result}")?;
1367 }
1368 }
1369
1370 writer.flush()?;
1371 Ok(())
1372 }
1373}
1374
/// Length and character-frequency statistics for a batch of texts, produced
/// by `OptimizedBatchOps::batch_statistics`. Lengths are `char` counts.
///
/// Note: `Default` yields `min_length = 0`, which is a valid "empty batch"
/// value but must not be used as an accumulator seed for `min()`.
#[derive(Debug, Clone, Default)]
pub struct BatchTextStats {
    pub text_count: usize,
    pub total_chars: usize,
    pub total_words: usize,
    pub min_length: usize,
    pub max_length: usize,
    pub avg_length: f64,
    pub avg_words: f64,
    // Only populated for batches/chunks under 10k texts (cost bound).
    pub char_distribution: std::collections::HashMap<char, usize>,
}
1386
// Boxed batch-transform callback: receives a full buffer, returns the
// processed items (possibly a different count) or an error.
type StreamingProcessorFn<T> = Box<dyn FnMut(&[T]) -> Result<Vec<T>>>;

/// Accumulates items and invokes a user callback once per full batch;
/// call `finish()` to flush the final partial batch.
pub struct StreamingBatchProcessor<T> {
    batch_size: usize,                 // flush threshold
    buffer: Vec<T>,                    // pending items (capacity = batch_size)
    processor: StreamingProcessorFn<T>, // invoked on each full/final batch
}
1396
1397impl<T> StreamingBatchProcessor<T> {
1398 pub fn new<F>(batch_size: usize, processor: F) -> Self
1399 where
1400 F: FnMut(&[T]) -> Result<Vec<T>> + 'static,
1401 {
1402 Self {
1403 batch_size,
1404 buffer: Vec::with_capacity(batch_size),
1405 processor: Box::new(processor),
1406 }
1407 }
1408
1409 pub fn add_item(&mut self, item: T) -> Result<Option<Vec<T>>> {
1410 self.buffer.push(item);
1411
1412 if self.buffer.len() >= self.batch_size {
1413 let result = (self.processor)(&self.buffer)?;
1414 self.buffer.clear();
1415 Ok(Some(result))
1416 } else {
1417 Ok(None)
1418 }
1419 }
1420
1421 pub fn finish(mut self) -> Result<Vec<T>> {
1422 if !self.buffer.is_empty() {
1423 (self.processor)(&self.buffer)
1424 } else {
1425 Ok(Vec::new())
1426 }
1427 }
1428}