1use crate::embeddings::Word2Vec;
33use crate::error::{Result, TextError};
34use crate::tokenize::{Tokenizer, WordTokenizer};
35use scirs2_core::random::{thread_rng, CoreRandom};
36use std::collections::{HashMap, HashSet};
37
/// How a [`Paraphraser`] rewrites its input.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParaphraseStrategy {
    /// Replace words with synonyms, taken from an attached word2vec model when
    /// available, otherwise from the built-in synonym table.
    Synonym,
    /// Apply one heuristic sentence rewrite: passive-to-active, clause
    /// reordering, conjunction substitution, or moving an `-ly` word to the front.
    Restructure,
    /// Imitate back-translation with small lexical perturbations (articles,
    /// prepositions, tense, singular/plural). No translation model is involved.
    BackTranslation,
    /// Pick one of the three concrete strategies at random for every attempt.
    Hybrid,
}
50
51#[derive(Debug, Clone)]
53pub struct ParaphraseConfig {
54 pub num_variations: usize,
56 pub strategy: ParaphraseStrategy,
58 pub preserve_entities: bool,
60 pub min_similarity: f32,
62 pub max_replacement_ratio: f32,
64 pub aggressive: bool,
66}
67
68impl Default for ParaphraseConfig {
69 fn default() -> Self {
70 Self {
71 num_variations: 3,
72 strategy: ParaphraseStrategy::Hybrid,
73 preserve_entities: true,
74 min_similarity: 0.6,
75 max_replacement_ratio: 0.4,
76 aggressive: false,
77 }
78 }
79}
80
81#[derive(Debug, Clone)]
83pub struct ParaphraseResult {
84 pub text: String,
86 pub similarity: f32,
88 pub strategy_used: ParaphraseStrategy,
90 pub replacements: Vec<(String, String)>,
92}
93
/// Generates paraphrases using synonym substitution and lightweight
/// rule-based rewrites.
///
/// Build one with [`Paraphraser::new`]. Customise it with
/// [`Paraphraser::with_word2vec`] and [`Paraphraser::with_tokenizer`].
pub struct Paraphraser {
    /// Generation settings (variation count, strategy, replacement ratio, ...).
    config: ParaphraseConfig,
    /// Splits text into tokens for the synonym strategy; `new` installs `WordTokenizer::default()`.
    tokenizer: Box<dyn Tokenizer>,
    /// Optional embedding model, consulted before `synonym_map` when looking up synonyms.
    word2vec: Option<Word2Vec>,
    /// Lowercase word -> candidate synonyms; the fallback when word2vec is absent or yields nothing.
    synonym_map: HashMap<String, Vec<String>>,
}
101
102impl Paraphraser {
103 pub fn new(config: ParaphraseConfig) -> Self {
105 Self {
106 config,
107 tokenizer: Box::new(WordTokenizer::default()),
108 word2vec: None,
109 synonym_map: Self::build_default_synonym_map(),
110 }
111 }
112
113 pub fn with_word2vec(mut self, model: Word2Vec) -> Self {
115 self.word2vec = Some(model);
116 self
117 }
118
119 pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer>) -> Self {
121 self.tokenizer = tokenizer;
122 self
123 }
124
125 pub fn paraphrase(&self, text: &str) -> Result<Vec<ParaphraseResult>> {
127 if text.trim().is_empty() {
128 return Err(TextError::InvalidInput("Input text is empty".into()));
129 }
130
131 let mut results = Vec::new();
132 let mut seen = HashSet::new();
133 seen.insert(text.to_lowercase());
134
135 let mut attempt = 0;
136 let max_attempts = self.config.num_variations * 3;
137
138 while results.len() < self.config.num_variations && attempt < max_attempts {
139 attempt += 1;
140
141 let strategy = if self.config.strategy == ParaphraseStrategy::Hybrid {
142 self.select_random_strategy()
144 } else {
145 self.config.strategy
146 };
147
148 let paraphrase_result = match strategy {
149 ParaphraseStrategy::Synonym => self.paraphrase_synonym(text)?,
150 ParaphraseStrategy::Restructure => self.paraphrase_restructure(text)?,
151 ParaphraseStrategy::BackTranslation => self.paraphrase_backtranslation(text)?,
152 ParaphraseStrategy::Hybrid => unreachable!(),
153 };
154
155 let paraphrase_lower = paraphrase_result.text.to_lowercase();
157 if !seen.contains(¶phrase_lower) && paraphrase_result.text != text {
158 seen.insert(paraphrase_lower);
159 results.push(paraphrase_result);
160 }
161 }
162
163 if results.is_empty() {
164 return Err(TextError::ProcessingError(
165 "Could not generate any valid paraphrases".to_string(),
166 ));
167 }
168
169 Ok(results)
170 }
171
172 fn paraphrase_synonym(&self, text: &str) -> Result<ParaphraseResult> {
174 let tokens = self.tokenizer.tokenize(text)?;
175 let mut rng = thread_rng();
176 let mut new_tokens = tokens.clone();
177 let mut replacements = Vec::new();
178
179 let max_replacements =
181 ((tokens.len() as f32 * self.config.max_replacement_ratio).ceil() as usize).max(1);
182
183 let mut replaced_count = 0;
184 let mut candidates: Vec<usize> = (0..tokens.len()).collect();
185
186 for i in (1..candidates.len()).rev() {
188 let j = (rng.random::<f32>() * (i + 1) as f32) as usize;
189 candidates.swap(i, j);
190 }
191
192 for &idx in candidates.iter() {
194 if replaced_count >= max_replacements {
195 break;
196 }
197
198 let word = &tokens[idx];
199
200 if word.len() <= 2 || !word.chars().any(|c| c.is_alphabetic()) {
202 continue;
203 }
204
205 if let Some(synonym) = self.find_synonym(word)? {
207 new_tokens[idx] = synonym.clone();
208 replacements.push((word.clone(), synonym));
209 replaced_count += 1;
210 }
211 }
212
213 let paraphrased_text = new_tokens.join(" ");
214 let similarity = self.calculate_similarity(text, ¶phrased_text);
215
216 Ok(ParaphraseResult {
217 text: paraphrased_text,
218 similarity,
219 strategy_used: ParaphraseStrategy::Synonym,
220 replacements,
221 })
222 }
223
224 fn paraphrase_restructure(&self, text: &str) -> Result<ParaphraseResult> {
226 let restructured = self.apply_restructuring_patterns(text)?;
227 let similarity = self.calculate_similarity(text, &restructured);
228
229 Ok(ParaphraseResult {
230 text: restructured,
231 similarity,
232 strategy_used: ParaphraseStrategy::Restructure,
233 replacements: vec![],
234 })
235 }
236
237 fn paraphrase_backtranslation(&self, text: &str) -> Result<ParaphraseResult> {
239 let transformed = self.apply_backtranslation_patterns(text)?;
240 let similarity = self.calculate_similarity(text, &transformed);
241
242 Ok(ParaphraseResult {
243 text: transformed,
244 similarity,
245 strategy_used: ParaphraseStrategy::BackTranslation,
246 replacements: vec![],
247 })
248 }
249
250 fn find_synonym(&self, word: &str) -> Result<Option<String>> {
252 let word_lower = word.to_lowercase();
253
254 if let Some(ref model) = self.word2vec {
256 if let Ok(similar_words) = model.most_similar(&word_lower, 5) {
257 if !similar_words.is_empty() {
258 let mut rng = thread_rng();
259 let idx = (rng.random::<f32>() * similar_words.len() as f32) as usize;
260 let selected = &similar_words[idx.min(similar_words.len() - 1)].0;
261 return Ok(Some(self.match_case(word, selected)));
262 }
263 }
264 }
265
266 if let Some(synonyms) = self.synonym_map.get(&word_lower) {
268 if !synonyms.is_empty() {
269 let mut rng = thread_rng();
270 let idx = (rng.random::<f32>() * synonyms.len() as f32) as usize;
271 let selected = &synonyms[idx.min(synonyms.len() - 1)];
272 return Ok(Some(self.match_case(word, selected)));
273 }
274 }
275
276 Ok(None)
277 }
278
279 fn match_case(&self, original: &str, replacement: &str) -> String {
281 if original.chars().all(|c| c.is_uppercase()) {
282 replacement.to_uppercase()
283 } else if original.chars().next().is_some_and(|c| c.is_uppercase()) {
284 let mut chars = replacement.chars();
285 match chars.next() {
286 None => String::new(),
287 Some(first) => first.to_uppercase().chain(chars).collect(),
288 }
289 } else {
290 replacement.to_lowercase()
291 }
292 }
293
294 fn apply_restructuring_patterns(&self, text: &str) -> Result<String> {
296 let mut rng = thread_rng();
297 let pattern_idx = (rng.random::<f32>() * 4.0) as usize;
298
299 let result = match pattern_idx {
300 0 => self.pattern_passive_to_active(text),
301 1 => self.pattern_clause_reorder(text),
302 2 => self.pattern_conjunction_variation(text),
303 _ => self.pattern_adverb_movement(text),
304 };
305
306 Ok(result)
307 }
308
309 fn pattern_passive_to_active(&self, text: &str) -> String {
311 if text.contains(" is ") && text.contains(" by ") {
314 let parts: Vec<&str> = text.split(" by ").collect();
315 if parts.len() == 2 {
316 let first_parts: Vec<&str> = parts[0].split(" is ").collect();
317 if first_parts.len() == 2 {
318 return format!(
319 "{} {} {}",
320 parts[1].trim(),
321 first_parts[1].trim(),
322 first_parts[0].trim()
323 );
324 }
325 }
326 }
327 text.to_string()
328 }
329
330 fn pattern_clause_reorder(&self, text: &str) -> String {
332 for conj in &[" and ", " but ", " or ", ", "] {
334 if text.contains(conj) {
335 let parts: Vec<&str> = text.splitn(2, conj).collect();
336 if parts.len() == 2 {
337 return format!("{}{}{}", parts[1].trim(), conj, parts[0].trim());
338 }
339 }
340 }
341 text.to_string()
342 }
343
344 fn pattern_conjunction_variation(&self, text: &str) -> String {
346 let replacements = [
347 (" and ", " as well as "),
348 (" but ", " however "),
349 (" because ", " since "),
350 (" so ", " therefore "),
351 ];
352
353 let mut result = text.to_string();
354 let mut rng = thread_rng();
355 let idx = (rng.random::<f32>() * replacements.len() as f32) as usize;
356 let (original, replacement) = replacements[idx.min(replacements.len() - 1)];
357
358 if result.contains(original) {
359 result = result.replacen(original, replacement, 1);
360 }
361
362 result
363 }
364
365 fn pattern_adverb_movement(&self, text: &str) -> String {
367 let tokens: Vec<&str> = text.split_whitespace().collect();
369 if tokens.len() < 3 {
370 return text.to_string();
371 }
372
373 for (i, token) in tokens.iter().enumerate() {
375 if token.ends_with("ly") && i > 0 {
376 let mut new_tokens = tokens.clone();
378 new_tokens.remove(i);
379 new_tokens.insert(0, token);
380 return new_tokens.join(" ");
381 }
382 }
383
384 text.to_string()
385 }
386
387 fn apply_backtranslation_patterns(&self, text: &str) -> Result<String> {
389 let patterns = [
391 self.pattern_article_variation(text),
392 self.pattern_preposition_variation(text),
393 self.pattern_tense_variation(text),
394 self.pattern_number_variation(text),
395 ];
396
397 let mut rng = thread_rng();
398 let idx = (rng.random::<f32>() * patterns.len() as f32) as usize;
399 Ok(patterns[idx.min(patterns.len() - 1)].clone())
400 }
401
402 fn pattern_article_variation(&self, text: &str) -> String {
404 let mut result = text.to_string();
405 let replacements = [(" a ", " the "), (" the ", " a "), (" an ", " the ")];
406
407 let mut rng = thread_rng();
408 let idx = (rng.random::<f32>() * replacements.len() as f32) as usize;
409 let (original, replacement) = replacements[idx.min(replacements.len() - 1)];
410
411 if result.contains(original) {
412 result = result.replacen(original, replacement, 1);
413 }
414
415 result
416 }
417
418 fn pattern_preposition_variation(&self, text: &str) -> String {
420 let replacements = [
421 (" on ", " upon "),
422 (" in ", " within "),
423 (" at ", " in "),
424 (" to ", " towards "),
425 ];
426
427 let mut result = text.to_string();
428 let mut rng = thread_rng();
429 let idx = (rng.random::<f32>() * replacements.len() as f32) as usize;
430 let (original, replacement) = replacements[idx.min(replacements.len() - 1)];
431
432 if result.contains(original) {
433 result = result.replacen(original, replacement, 1);
434 }
435
436 result
437 }
438
439 fn pattern_tense_variation(&self, text: &str) -> String {
441 let replacements = [
442 (" is ", " was "),
443 (" are ", " were "),
444 (" has ", " had "),
445 (" will ", " would "),
446 ];
447
448 let mut result = text.to_string();
449 let mut rng = thread_rng();
450 let idx = (rng.random::<f32>() * replacements.len() as f32) as usize;
451 let (original, replacement) = replacements[idx.min(replacements.len() - 1)];
452
453 if result.contains(original) {
454 result = result.replacen(original, replacement, 1);
455 }
456
457 result
458 }
459
460 fn pattern_number_variation(&self, text: &str) -> String {
462 let tokens: Vec<String> = text.split_whitespace().map(|s| s.to_string()).collect();
464 let mut new_tokens = tokens.clone();
465
466 for (i, token) in tokens.iter().enumerate() {
467 if token.ends_with('s') && token.len() > 2 && !token.ends_with("ss") {
468 new_tokens[i] = token[..token.len() - 1].to_string();
470 break;
471 } else if !token.ends_with('s') && token.chars().all(|c| c.is_alphabetic()) {
472 new_tokens[i] = format!("{}s", token);
474 break;
475 }
476 }
477
478 new_tokens.join(" ")
479 }
480
481 fn calculate_similarity(&self, text1: &str, text2: &str) -> f32 {
483 let tokens1: HashSet<String> = text1
485 .to_lowercase()
486 .split_whitespace()
487 .map(|s| s.to_string())
488 .collect();
489
490 let tokens2: HashSet<String> = text2
491 .to_lowercase()
492 .split_whitespace()
493 .map(|s| s.to_string())
494 .collect();
495
496 let intersection = tokens1.intersection(&tokens2).count();
497 let union = tokens1.union(&tokens2).count();
498
499 if union == 0 {
500 return 0.0;
501 }
502
503 intersection as f32 / union as f32
504 }
505
506 fn select_random_strategy(&self) -> ParaphraseStrategy {
508 let mut rng = thread_rng();
509 let val = rng.random::<f32>();
510
511 if val < 0.33 {
512 ParaphraseStrategy::Synonym
513 } else if val < 0.67 {
514 ParaphraseStrategy::Restructure
515 } else {
516 ParaphraseStrategy::BackTranslation
517 }
518 }
519
520 fn build_default_synonym_map() -> HashMap<String, Vec<String>> {
522 let mut map = HashMap::new();
523
524 map.insert(
526 "good".to_string(),
527 vec![
528 "excellent".to_string(),
529 "great".to_string(),
530 "fine".to_string(),
531 ],
532 );
533 map.insert(
534 "bad".to_string(),
535 vec![
536 "poor".to_string(),
537 "awful".to_string(),
538 "terrible".to_string(),
539 ],
540 );
541 map.insert(
542 "big".to_string(),
543 vec![
544 "large".to_string(),
545 "huge".to_string(),
546 "enormous".to_string(),
547 ],
548 );
549 map.insert(
550 "small".to_string(),
551 vec![
552 "tiny".to_string(),
553 "little".to_string(),
554 "minute".to_string(),
555 ],
556 );
557 map.insert(
558 "fast".to_string(),
559 vec![
560 "quick".to_string(),
561 "rapid".to_string(),
562 "swift".to_string(),
563 ],
564 );
565 map.insert(
566 "slow".to_string(),
567 vec![
568 "gradual".to_string(),
569 "leisurely".to_string(),
570 "sluggish".to_string(),
571 ],
572 );
573 map.insert(
574 "important".to_string(),
575 vec![
576 "significant".to_string(),
577 "crucial".to_string(),
578 "vital".to_string(),
579 ],
580 );
581 map.insert(
582 "easy".to_string(),
583 vec![
584 "simple".to_string(),
585 "effortless".to_string(),
586 "straightforward".to_string(),
587 ],
588 );
589 map.insert(
590 "difficult".to_string(),
591 vec![
592 "hard".to_string(),
593 "challenging".to_string(),
594 "complex".to_string(),
595 ],
596 );
597 map.insert(
598 "beautiful".to_string(),
599 vec![
600 "lovely".to_string(),
601 "attractive".to_string(),
602 "gorgeous".to_string(),
603 ],
604 );
605 map.insert(
606 "happy".to_string(),
607 vec![
608 "joyful".to_string(),
609 "cheerful".to_string(),
610 "delighted".to_string(),
611 ],
612 );
613 map.insert(
614 "sad".to_string(),
615 vec![
616 "unhappy".to_string(),
617 "sorrowful".to_string(),
618 "melancholy".to_string(),
619 ],
620 );
621 map.insert(
622 "smart".to_string(),
623 vec![
624 "intelligent".to_string(),
625 "clever".to_string(),
626 "bright".to_string(),
627 ],
628 );
629 map.insert(
630 "stupid".to_string(),
631 vec![
632 "foolish".to_string(),
633 "silly".to_string(),
634 "ignorant".to_string(),
635 ],
636 );
637 map.insert(
638 "old".to_string(),
639 vec![
640 "ancient".to_string(),
641 "aged".to_string(),
642 "elderly".to_string(),
643 ],
644 );
645 map.insert(
646 "new".to_string(),
647 vec![
648 "recent".to_string(),
649 "modern".to_string(),
650 "fresh".to_string(),
651 ],
652 );
653 map.insert(
654 "strong".to_string(),
655 vec![
656 "powerful".to_string(),
657 "robust".to_string(),
658 "sturdy".to_string(),
659 ],
660 );
661 map.insert(
662 "weak".to_string(),
663 vec![
664 "feeble".to_string(),
665 "frail".to_string(),
666 "fragile".to_string(),
667 ],
668 );
669 map.insert(
670 "clean".to_string(),
671 vec![
672 "spotless".to_string(),
673 "pristine".to_string(),
674 "immaculate".to_string(),
675 ],
676 );
677 map.insert(
678 "dirty".to_string(),
679 vec![
680 "filthy".to_string(),
681 "grimy".to_string(),
682 "soiled".to_string(),
683 ],
684 );
685 map.insert(
686 "quick".to_string(),
687 vec!["fast".to_string(), "rapid".to_string(), "swift".to_string()],
688 );
689 map.insert(
690 "lazy".to_string(),
691 vec![
692 "idle".to_string(),
693 "sluggish".to_string(),
694 "lethargic".to_string(),
695 ],
696 );
697 map.insert(
698 "jumps".to_string(),
699 vec![
700 "leaps".to_string(),
701 "hops".to_string(),
702 "bounds".to_string(),
703 ],
704 );
705 map.insert(
706 "brown".to_string(),
707 vec![
708 "tan".to_string(),
709 "chestnut".to_string(),
710 "tawny".to_string(),
711 ],
712 );
713 map.insert(
714 "over".to_string(),
715 vec![
716 "above".to_string(),
717 "across".to_string(),
718 "past".to_string(),
719 ],
720 );
721
722 map
723 }
724}
725
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: a paraphraser with the default configuration.
    fn default_paraphraser() -> Paraphraser {
        Paraphraser::new(ParaphraseConfig::default())
    }

    #[test]
    fn test_paraphrase_basic() {
        let sentence = "The quick brown fox jumps over the lazy dog";
        let outcome = default_paraphraser().paraphrase(sentence);
        assert!(outcome.is_ok());

        let variants = outcome.expect("Test failure: paraphrasing should succeed");
        assert!(!variants.is_empty());
        assert_ne!(variants[0].text, sentence);
    }

    #[test]
    fn test_synonym_replacement() {
        let synonym_only = Paraphraser::new(ParaphraseConfig {
            num_variations: 1,
            strategy: ParaphraseStrategy::Synonym,
            ..Default::default()
        });

        let outcome = synonym_only.paraphrase("This is a good example");
        assert!(outcome.is_ok());

        let variants = outcome.expect("Test failure: paraphrasing should succeed");
        assert!(!variants.is_empty());
    }

    #[test]
    fn test_case_matching() {
        let paraphraser = default_paraphraser();
        let cases = [
            ("Good", "Excellent"),
            ("GOOD", "EXCELLENT"),
            ("good", "excellent"),
        ];
        for (original, expected) in cases {
            assert_eq!(paraphraser.match_case(original, "excellent"), expected);
        }
    }

    #[test]
    fn test_similarity_calculation() {
        let paraphraser = default_paraphraser();
        let reference = "the quick brown fox";

        let identical = paraphraser.calculate_similarity(reference, "the quick brown fox");
        assert!((identical - 1.0).abs() < 0.001);

        let partial = paraphraser.calculate_similarity(reference, "the slow white cat");
        assert!(partial > 0.0 && partial < 1.0);
    }

    #[test]
    fn test_empty_input() {
        assert!(default_paraphraser().paraphrase("").is_err());
    }
}