1use crate::error::{Result, TextError};
12use crate::tokenize::{Tokenizer, WordTokenizer};
13use crate::vectorize::{TfidfVectorizer, Vectorizer};
14use std::collections::{HashMap, HashSet};
15
/// A scored keyword or phrase produced by one of the extraction methods.
#[derive(Debug, Clone)]
pub struct Keyword {
    /// The keyword text: a single term or a space-joined phrase.
    pub text: String,
    /// Relevance score; higher means more important. The scale depends on
    /// the extraction method and is only comparable within one result set.
    pub score: f64,
}
28
/// Selects the algorithm used by [`extract_keywords`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeywordMethod {
    /// Rank terms by their average TF-IDF weight across sentences.
    TfIdf,
    /// PageRank over a token co-occurrence graph (TextRank).
    TextRank,
    /// Rapid Automatic Keyword Extraction: degree/frequency phrase scoring.
    Rake,
}
39
40pub fn extract_keywords(text: &str, method: KeywordMethod, top_k: usize) -> Result<Vec<Keyword>> {
53 match method {
54 KeywordMethod::TfIdf => {
55 let extractor = TfIdfKeywordExtractor::new();
56 extractor.extract(text, top_k)
57 }
58 KeywordMethod::TextRank => {
59 let extractor = TextRankKeywordExtractor::new();
60 extractor.extract(text, top_k)
61 }
62 KeywordMethod::Rake => {
63 let extractor = RakeKeywordExtractor::new();
64 extractor.extract(text, top_k)
65 }
66 }
67}
68
/// Keyword extractor that ranks terms by their average TF-IDF weight
/// across the sentences of the input text.
pub struct TfIdfKeywordExtractor {
    /// Tokenizer used to build the vocabulary that labels matrix columns.
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
    /// Minimum term length (in bytes) for a term to be reported.
    min_token_len: usize,
}
80
impl TfIdfKeywordExtractor {
    /// Creates an extractor with a default word tokenizer and a minimum
    /// term length of 2 bytes.
    pub fn new() -> Self {
        Self {
            tokenizer: Box::new(WordTokenizer::default()),
            min_token_len: 2,
        }
    }

    /// Sets the minimum term length (in bytes) for reported keywords.
    pub fn with_min_token_len(mut self, len: usize) -> Self {
        self.min_token_len = len;
        self
    }

    /// Extracts up to `top_k` keywords ranked by each term's average
    /// TF-IDF weight over the sentences of `text`, descending.
    ///
    /// # Errors
    /// Propagates failures from the tokenizer or vectorizer.
    pub fn extract(&self, text: &str, top_k: usize) -> Result<Vec<Keyword>> {
        if text.trim().is_empty() {
            return Ok(Vec::new());
        }

        // Each sentence is treated as one "document" for TF-IDF purposes.
        let sentences = split_sentences(text);
        if sentences.is_empty() {
            return Ok(Vec::new());
        }

        let sentence_refs: Vec<&str> = sentences.iter().map(|s| s.as_str()).collect();

        let mut vectorizer = TfidfVectorizer::default();
        vectorizer.fit(&sentence_refs)?;
        let tfidf_matrix = vectorizer.transform_batch(&sentence_refs)?;

        // NOTE(review): this vocabulary is rebuilt with `self.tokenizer` in
        // first-occurrence order and is assumed to match the column order of
        // `tfidf_matrix`. That only holds if `TfidfVectorizer` tokenizes and
        // indexes terms identically — TODO confirm against the vectorizer.
        let vocab = build_vocabulary(&sentences, &*self.tokenizer)?;

        let n_terms = tfidf_matrix.ncols();
        let n_docs = tfidf_matrix.nrows();
        if n_terms == 0 || n_docs == 0 {
            return Ok(Vec::new());
        }

        // Average each term's TF-IDF weight across all sentences (columns
        // are terms, rows are sentences).
        let mut avg_scores: Vec<f64> = Vec::with_capacity(n_terms);
        for col_idx in 0..n_terms {
            let col_sum: f64 = (0..n_docs).map(|row| tfidf_matrix[[row, col_idx]]).sum();
            avg_scores.push(col_sum / n_docs as f64);
        }

        // Keep positively-weighted terms that satisfy the length filter.
        let mut keyword_scores: Vec<Keyword> = Vec::new();
        for (idx, &score) in avg_scores.iter().enumerate() {
            if score <= 0.0 {
                continue;
            }
            if let Some(term) = vocab.get(&idx) {
                if term.len() >= self.min_token_len {
                    keyword_scores.push(Keyword {
                        text: term.clone(),
                        score,
                    });
                }
            }
        }

        // Sort descending by score; NaN-safe comparison falls back to Equal.
        keyword_scores.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        keyword_scores.truncate(top_k);
        Ok(keyword_scores)
    }
}
155
156impl Default for TfIdfKeywordExtractor {
157 fn default() -> Self {
158 Self::new()
159 }
160}
161
/// Keyword extractor implementing TextRank: weighted PageRank over a
/// token co-occurrence graph.
pub struct TextRankKeywordExtractor {
    /// Tokenizer used to produce graph nodes and the phrase token stream.
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
    /// Number of tokens per sliding co-occurrence window.
    window_size: usize,
    /// PageRank damping factor, in [0, 1].
    damping: f64,
    /// Upper bound on power-iteration steps.
    max_iterations: usize,
    /// L1-change threshold below which iteration stops early.
    convergence_threshold: f64,
    /// Minimum token length (in bytes) for a token to enter the graph.
    min_token_len: usize,
}
184
185impl TextRankKeywordExtractor {
186 pub fn new() -> Self {
188 Self {
189 tokenizer: Box::new(WordTokenizer::default()),
190 window_size: 4,
191 damping: 0.85,
192 max_iterations: 100,
193 convergence_threshold: 1e-5,
194 min_token_len: 2,
195 }
196 }
197
198 pub fn with_window_size(mut self, size: usize) -> Result<Self> {
200 if size < 2 {
201 return Err(TextError::InvalidInput(
202 "Window size must be at least 2".to_string(),
203 ));
204 }
205 self.window_size = size;
206 Ok(self)
207 }
208
209 pub fn with_damping(mut self, d: f64) -> Result<Self> {
211 if !(0.0..=1.0).contains(&d) {
212 return Err(TextError::InvalidInput(
213 "Damping factor must be between 0 and 1".to_string(),
214 ));
215 }
216 self.damping = d;
217 Ok(self)
218 }
219
220 pub fn extract(&self, text: &str, top_k: usize) -> Result<Vec<Keyword>> {
222 if text.trim().is_empty() {
223 return Ok(Vec::new());
224 }
225
226 let tokens = self.tokenizer.tokenize(text)?;
227 let filtered: Vec<String> = tokens
228 .into_iter()
229 .filter(|t| t.len() >= self.min_token_len && !is_stopword(t))
230 .collect();
231
232 if filtered.is_empty() {
233 return Ok(Vec::new());
234 }
235
236 let mut graph: HashMap<String, HashMap<String, f64>> = HashMap::new();
238 for window in filtered.windows(self.window_size) {
239 for i in 0..window.len() {
240 for j in (i + 1)..window.len() {
241 let a = &window[i];
242 let b = &window[j];
243 *graph
244 .entry(a.clone())
245 .or_default()
246 .entry(b.clone())
247 .or_insert(0.0) += 1.0;
248 *graph
249 .entry(b.clone())
250 .or_default()
251 .entry(a.clone())
252 .or_insert(0.0) += 1.0;
253 }
254 }
255 }
256
257 let nodes: Vec<String> = graph.keys().cloned().collect();
259 let n = nodes.len();
260 if n == 0 {
261 return Ok(Vec::new());
262 }
263
264 let node_idx: HashMap<&str, usize> = nodes
265 .iter()
266 .enumerate()
267 .map(|(i, w)| (w.as_str(), i))
268 .collect();
269
270 let mut scores = vec![1.0 / n as f64; n];
271
272 let out_sums: Vec<f64> = nodes
274 .iter()
275 .map(|node| {
276 graph
277 .get(node)
278 .map(|neighbors| neighbors.values().sum::<f64>())
279 .unwrap_or(0.0)
280 })
281 .collect();
282
283 for _ in 0..self.max_iterations {
284 let mut new_scores = vec![(1.0 - self.damping) / n as f64; n];
285
286 for (j, node_j) in nodes.iter().enumerate() {
287 if out_sums[j] <= 0.0 {
288 continue;
289 }
290 if let Some(neighbors) = graph.get(node_j) {
291 for (neighbor, weight) in neighbors {
292 if let Some(&i) = node_idx.get(neighbor.as_str()) {
293 new_scores[i] += self.damping * (weight / out_sums[j]) * scores[j];
294 }
295 }
296 }
297 }
298
299 let diff: f64 = scores
300 .iter()
301 .zip(new_scores.iter())
302 .map(|(a, b)| (a - b).abs())
303 .sum();
304
305 scores = new_scores;
306 if diff < self.convergence_threshold {
307 break;
308 }
309 }
310
311 let mut word_scores: HashMap<String, f64> = HashMap::new();
313 for (i, node) in nodes.iter().enumerate() {
314 word_scores.insert(node.clone(), scores[i]);
315 }
316
317 let all_tokens = self.tokenizer.tokenize(text)?;
319 let keywords = merge_adjacent_keywords(&all_tokens, &word_scores, top_k);
320
321 Ok(keywords)
322 }
323}
324
325impl Default for TextRankKeywordExtractor {
326 fn default() -> Self {
327 Self::new()
328 }
329}
330
331fn merge_adjacent_keywords(
334 tokens: &[String],
335 word_scores: &HashMap<String, f64>,
336 top_k: usize,
337) -> Vec<Keyword> {
338 let mut phrases: Vec<(Vec<String>, f64)> = Vec::new();
339 let mut current_phrase: Vec<String> = Vec::new();
340 let mut current_score: f64 = 0.0;
341
342 for token in tokens {
343 if let Some(&score) = word_scores.get(token) {
344 current_phrase.push(token.clone());
345 current_score += score;
346 } else {
347 if !current_phrase.is_empty() {
348 phrases.push((current_phrase.clone(), current_score));
349 current_phrase.clear();
350 current_score = 0.0;
351 }
352 }
353 }
354 if !current_phrase.is_empty() {
355 phrases.push((current_phrase, current_score));
356 }
357
358 let mut seen: HashSet<String> = HashSet::new();
360 let mut keywords: Vec<Keyword> = Vec::new();
361 for (words, score) in phrases {
362 let phrase_text = words.join(" ");
363 if seen.contains(&phrase_text) {
364 continue;
365 }
366 seen.insert(phrase_text.clone());
367 keywords.push(Keyword {
368 text: phrase_text,
369 score,
370 });
371 }
372
373 keywords.sort_by(|a, b| {
374 b.score
375 .partial_cmp(&a.score)
376 .unwrap_or(std::cmp::Ordering::Equal)
377 });
378 keywords.truncate(top_k);
379 keywords
380}
381
/// Keyword extractor implementing RAKE (Rapid Automatic Keyword
/// Extraction): candidate phrases are runs of non-stopwords, scored by
/// summing each member word's degree/frequency ratio.
pub struct RakeKeywordExtractor {
    /// Minimum number of words a candidate phrase must contain.
    min_phrase_len: usize,
    /// Maximum number of words a candidate phrase may contain.
    max_phrase_len: usize,
    /// Minimum word length (in bytes) for a word to count toward scoring.
    min_word_len: usize,
}
401
402impl RakeKeywordExtractor {
403 pub fn new() -> Self {
405 Self {
406 min_phrase_len: 1,
407 max_phrase_len: 4,
408 min_word_len: 2,
409 }
410 }
411
412 pub fn with_min_phrase_len(mut self, len: usize) -> Self {
414 self.min_phrase_len = len;
415 self
416 }
417
418 pub fn with_max_phrase_len(mut self, len: usize) -> Self {
420 self.max_phrase_len = len;
421 self
422 }
423
424 pub fn extract(&self, text: &str, top_k: usize) -> Result<Vec<Keyword>> {
426 if text.trim().is_empty() {
427 return Ok(Vec::new());
428 }
429
430 let candidates = self.generate_candidates(text);
432 if candidates.is_empty() {
433 return Ok(Vec::new());
434 }
435
436 let mut word_freq: HashMap<String, f64> = HashMap::new();
438 let mut word_degree: HashMap<String, f64> = HashMap::new();
439
440 for phrase in &candidates {
441 let words: Vec<&str> = phrase
442 .split_whitespace()
443 .filter(|w| w.len() >= self.min_word_len)
444 .collect();
445 let degree = words.len() as f64;
446 for word in &words {
447 let w = word.to_lowercase();
448 *word_freq.entry(w.clone()).or_insert(0.0) += 1.0;
449 *word_degree.entry(w).or_insert(0.0) += degree;
450 }
451 }
452
453 let mut word_scores: HashMap<String, f64> = HashMap::new();
455 for (word, freq) in &word_freq {
456 let degree = word_degree.get(word).copied().unwrap_or(0.0);
457 if *freq > 0.0 {
458 word_scores.insert(word.clone(), degree / freq);
459 }
460 }
461
462 let mut phrase_scores: Vec<Keyword> = Vec::new();
464 let mut seen: HashSet<String> = HashSet::new();
465
466 for phrase in &candidates {
467 let normalized = phrase.to_lowercase();
468 if seen.contains(&normalized) {
469 continue;
470 }
471 seen.insert(normalized.clone());
472
473 let words: Vec<&str> = normalized
474 .split_whitespace()
475 .filter(|w| w.len() >= self.min_word_len)
476 .collect();
477 if words.is_empty() {
478 continue;
479 }
480
481 let score: f64 = words
482 .iter()
483 .map(|w| word_scores.get(*w).copied().unwrap_or(0.0))
484 .sum();
485
486 phrase_scores.push(Keyword {
487 text: normalized,
488 score,
489 });
490 }
491
492 phrase_scores.sort_by(|a, b| {
493 b.score
494 .partial_cmp(&a.score)
495 .unwrap_or(std::cmp::Ordering::Equal)
496 });
497 phrase_scores.truncate(top_k);
498 Ok(phrase_scores)
499 }
500
501 fn generate_candidates(&self, text: &str) -> Vec<String> {
504 let lower = text.to_lowercase();
505 let mut candidates: Vec<String> = Vec::new();
507 let mut current_phrase: Vec<String> = Vec::new();
508
509 for word in lower.split(|c: char| !c.is_alphanumeric() && c != '\'') {
510 let trimmed = word.trim();
511 if trimmed.is_empty() {
512 if !current_phrase.is_empty() {
513 self.add_candidate(&mut candidates, ¤t_phrase);
514 current_phrase.clear();
515 }
516 continue;
517 }
518
519 if is_stopword(trimmed) {
520 if !current_phrase.is_empty() {
521 self.add_candidate(&mut candidates, ¤t_phrase);
522 current_phrase.clear();
523 }
524 } else {
525 current_phrase.push(trimmed.to_string());
526 }
527 }
528
529 if !current_phrase.is_empty() {
530 self.add_candidate(&mut candidates, ¤t_phrase);
531 }
532
533 candidates
534 }
535
536 fn add_candidate(&self, candidates: &mut Vec<String>, phrase_words: &[String]) {
537 if phrase_words.len() < self.min_phrase_len || phrase_words.len() > self.max_phrase_len {
538 return;
539 }
540 let phrase = phrase_words.join(" ");
541 if phrase
542 .split_whitespace()
543 .any(|w| w.len() >= self.min_word_len)
544 {
545 candidates.push(phrase);
546 }
547 }
548}
549
550impl Default for RakeKeywordExtractor {
551 fn default() -> Self {
552 Self::new()
553 }
554}
555
/// Splits `text` into sentences on `.`, `!`, and `?`, keeping the
/// terminator with its sentence. Any trailing text without a terminator
/// becomes a final sentence; whitespace-only segments are dropped.
fn split_sentences(text: &str) -> Vec<String> {
    let mut out = Vec::new();
    let mut buf = String::new();

    for ch in text.chars() {
        buf.push(ch);
        if matches!(ch, '.' | '!' | '?') {
            let sentence = buf.trim();
            if !sentence.is_empty() {
                out.push(sentence.to_string());
            }
            buf.clear();
        }
    }

    // Flush any trailing, unterminated sentence.
    let tail = buf.trim();
    if !tail.is_empty() {
        out.push(tail.to_string());
    }
    out
}
581
582fn build_vocabulary(
585 sentences: &[String],
586 tokenizer: &dyn Tokenizer,
587) -> Result<HashMap<usize, String>> {
588 let mut term_to_idx: HashMap<String, usize> = HashMap::new();
589 let mut next_idx: usize = 0;
590
591 for sentence in sentences {
592 let tokens = tokenizer.tokenize(sentence)?;
593 for token in tokens {
594 if let std::collections::hash_map::Entry::Vacant(e) = term_to_idx.entry(token) {
595 e.insert(next_idx);
596 next_idx += 1;
597 }
598 }
599 }
600
601 let idx_to_term: HashMap<usize, String> =
602 term_to_idx.into_iter().map(|(t, i)| (i, t)).collect();
603 Ok(idx_to_term)
604}
605
/// Returns `true` if `word` matches one of the built-in English stopwords,
/// ignoring ASCII case.
///
/// Uses `eq_ignore_ascii_case` instead of `word.to_lowercase()` to avoid
/// allocating a lowercased copy on every call; the stopword list is pure
/// ASCII, so ASCII case folding matches the previous behavior.
fn is_stopword(word: &str) -> bool {
    const STOPWORDS: &[&str] = &[
        "a", "an", "the", "and", "or", "but", "if", "in", "on", "at", "to", "for", "of", "with",
        "by", "from", "as", "is", "was", "are", "were", "been", "be", "have", "has", "had", "do",
        "does", "did", "will", "would", "shall", "should", "may", "might", "must", "can", "could",
        "not", "no", "nor", "so", "than", "that", "this", "these", "those", "it", "its", "i", "me",
        "my", "we", "us", "our", "you", "your", "he", "him", "his", "she", "her", "they", "them",
        "their", "what", "which", "who", "whom", "when", "where", "why", "how", "all", "each",
        "every", "both", "few", "more", "most", "other", "some", "such", "only", "own", "same",
        "also", "just", "about", "above", "after", "again", "against", "any", "because", "before",
        "below", "between", "during", "further", "here", "into", "once", "out", "over", "then",
        "there", "through", "under", "until", "up", "very", "while",
    ];
    STOPWORDS.iter().any(|s| s.eq_ignore_ascii_case(word))
}
622
#[cfg(test)]
mod tests {
    use super::*;

    // ---------- TF-IDF ----------

    // Results are non-empty, capped at top_k, and sorted descending by score.
    #[test]
    fn test_tfidf_extracts_keywords() {
        let text = "Machine learning is a powerful tool. \
            Machine learning algorithms process data efficiently. \
            Deep learning extends machine learning with neural networks.";
        let keywords = extract_keywords(text, KeywordMethod::TfIdf, 5)
            .expect("TF-IDF extraction should succeed");
        assert!(!keywords.is_empty());
        assert!(keywords.len() <= 5);
        for pair in keywords.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
    }

    // Empty input yields Ok with no keywords, not an error.
    #[test]
    fn test_tfidf_empty_text() {
        let result =
            extract_keywords("", KeywordMethod::TfIdf, 5).expect("Empty text should not error");
        assert!(result.is_empty());
    }

    // A single sentence (one "document") still produces keywords.
    #[test]
    fn test_tfidf_single_sentence() {
        let result = extract_keywords(
            "Rust programming language is fast and safe.",
            KeywordMethod::TfIdf,
            3,
        )
        .expect("Single sentence should succeed");
        assert!(!result.is_empty());
    }

    #[test]
    fn test_tfidf_respects_top_k() {
        let text = "Alpha beta gamma delta epsilon zeta eta theta iota kappa. \
            Alpha beta gamma delta epsilon zeta eta theta iota kappa.";
        let result =
            extract_keywords(text, KeywordMethod::TfIdf, 3).expect("Extraction should succeed");
        assert!(result.len() <= 3);
    }

    // Every reported keyword must satisfy the configured minimum length.
    #[test]
    fn test_tfidf_min_token_len() {
        let extractor = TfIdfKeywordExtractor::new().with_min_token_len(5);
        let text = "AI and ML are big. Artificial intelligence is growing.";
        let result = extractor
            .extract(text, 10)
            .expect("Extraction should succeed");
        for kw in &result {
            for word in kw.text.split_whitespace() {
                assert!(word.len() >= 5, "Word '{}' is too short", word);
            }
        }
    }

    // ---------- TextRank ----------

    #[test]
    fn test_textrank_extracts_keywords() {
        let text = "Natural language processing enables computers to understand human language. \
            Text mining and information retrieval are subfields of natural language processing. \
            Sentiment analysis determines the emotional tone of text.";
        let keywords = extract_keywords(text, KeywordMethod::TextRank, 5)
            .expect("TextRank extraction should succeed");
        assert!(!keywords.is_empty());
        assert!(keywords.len() <= 5);
    }

    #[test]
    fn test_textrank_empty_text() {
        let result =
            extract_keywords("", KeywordMethod::TextRank, 5).expect("Empty text should not error");
        assert!(result.is_empty());
    }

    #[test]
    fn test_textrank_scores_descending() {
        let text = "Graph algorithms are fundamental in computer science. \
            PageRank is a famous graph algorithm. \
            Many applications use graph-based methods.";
        let keywords = extract_keywords(text, KeywordMethod::TextRank, 10)
            .expect("TextRank extraction should succeed");
        for pair in keywords.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
    }

    // A custom (small but valid) window size still yields keywords.
    #[test]
    fn test_textrank_window_size() {
        let extractor = TextRankKeywordExtractor::new()
            .with_window_size(2)
            .expect("Window size 2 should be valid");
        let text = "Alpha beta gamma delta epsilon. Alpha beta gamma delta.";
        let result = extractor
            .extract(text, 5)
            .expect("Extraction should succeed");
        assert!(!result.is_empty());
    }

    // Window sizes below 2 are rejected at configuration time.
    #[test]
    fn test_textrank_invalid_window() {
        let result = TextRankKeywordExtractor::new().with_window_size(0);
        assert!(result.is_err());
    }

    // ---------- RAKE ----------

    #[test]
    fn test_rake_extracts_keywords() {
        let text =
            "Compatibility of systems of linear constraints over the set of natural numbers. \
            Criteria of compatibility of a system of linear Diophantine equations.";
        let keywords =
            extract_keywords(text, KeywordMethod::Rake, 5).expect("RAKE extraction should succeed");
        assert!(!keywords.is_empty());
        assert!(keywords.len() <= 5);
    }

    #[test]
    fn test_rake_empty_text() {
        let result =
            extract_keywords("", KeywordMethod::Rake, 5).expect("Empty text should not error");
        assert!(result.is_empty());
    }

    #[test]
    fn test_rake_phrase_scoring() {
        let text = "Machine learning algorithms are important. \
            Deep learning algorithms are even more powerful. \
            Algorithms drive modern artificial intelligence.";
        let keywords = extract_keywords(text, KeywordMethod::Rake, 10)
            .expect("RAKE extraction should succeed");
        assert!(!keywords.is_empty());
        for pair in keywords.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
    }

    // Candidate phrases must never contain stopwords — stopwords split phrases.
    #[test]
    fn test_rake_stopword_splitting() {
        let text = "The quick brown fox and the lazy dog.";
        let extractor = RakeKeywordExtractor::new();
        let candidates = extractor.generate_candidates(text);
        for candidate in &candidates {
            for word in candidate.split_whitespace() {
                assert!(!is_stopword(word), "'{}' is a stopword", word);
            }
        }
    }

    // Phrases longer than max_phrase_len are dropped, not truncated.
    #[test]
    fn test_rake_max_phrase_len() {
        let extractor = RakeKeywordExtractor::new().with_max_phrase_len(2);
        let text = "Advanced machine learning algorithms improve natural language processing.";
        let result = extractor
            .extract(text, 10)
            .expect("Extraction should succeed");
        for kw in &result {
            let word_count = kw.text.split_whitespace().count();
            assert!(word_count <= 2, "Phrase '{}' exceeds max length", kw.text);
        }
    }

    // ---------- Cross-method behavior ----------

    #[test]
    fn test_all_methods_non_empty_for_real_text() {
        let text = "Rust is a systems programming language focused on safety and performance. \
            The Rust compiler prevents data races and memory errors at compile time. \
            Many developers choose Rust for building reliable software.";
        for method in &[
            KeywordMethod::TfIdf,
            KeywordMethod::TextRank,
            KeywordMethod::Rake,
        ] {
            let keywords = extract_keywords(text, *method, 5).expect("Extraction should succeed");
            assert!(
                !keywords.is_empty(),
                "Method {:?} returned empty for real text",
                method
            );
        }
    }

    // Whitespace-only input behaves like empty input for every method.
    #[test]
    fn test_all_methods_handle_whitespace_only() {
        for method in &[
            KeywordMethod::TfIdf,
            KeywordMethod::TextRank,
            KeywordMethod::Rake,
        ] {
            let result =
                extract_keywords(" \t\n ", *method, 5).expect("Whitespace should not error");
            assert!(result.is_empty());
        }
    }
}