scirs2_text/summarization.rs

//! Text summarization module
//!
//! This module provides various algorithms for automatic text summarization.
//!
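//! # Examples
//!
//! A minimal usage sketch, mirroring the unit tests at the bottom of this
//! file (the `scirs2_text::summarization` module path is assumed):
//!
//! ```no_run
//! use scirs2_text::summarization::{CentroidSummarizer, KeywordExtractor, TextRank};
//!
//! let text = "First sentence. Second sentence. Third sentence. Fourth sentence.";
//! let summary = TextRank::new(2).summarize(text).unwrap();
//! let centroid_summary = CentroidSummarizer::new(2).summarize(text).unwrap();
//! let keywords = KeywordExtractor::new(5).extract_keywords(text).unwrap();
//! ```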
use crate::error::{Result, TextError};
use crate::tokenize::Tokenizer;
use crate::vectorize::{TfidfVectorizer, Vectorizer};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};

/// TextRank algorithm for extractive summarization
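///
/// Builds a sentence-similarity graph from TF-IDF vectors, ranks sentences
/// with a PageRank iteration, and returns the top-ranked sentences in their
/// original order.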
pub struct TextRank {
    /// Number of sentences to extract
    num_sentences: usize,
    /// Damping factor (usually 0.85)
    damping_factor: f64,
    /// Maximum iterations
    max_iterations: usize,
    /// Convergence threshold
    threshold: f64,
    /// Tokenizer for sentence splitting
    sentence_tokenizer: Box<dyn Tokenizer + Send + Sync>,
}

impl TextRank {
    /// Create a new TextRank summarizer
    pub fn new(num_sentences: usize) -> Self {
        Self {
            num_sentences,
            damping_factor: 0.85,
            max_iterations: 100,
            threshold: 0.0001,
            sentence_tokenizer: Box::new(crate::tokenize::SentenceTokenizer::new()),
        }
    }

    /// Set the damping factor
    pub fn with_damping_factor(mut self, damping_factor: f64) -> Result<Self> {
        if !(0.0..=1.0).contains(&damping_factor) {
            return Err(TextError::InvalidInput(
                "Damping factor must be between 0 and 1".to_string(),
            ));
        }
        self.damping_factor = damping_factor;
        Ok(self)
    }

    /// Extract summary from text
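    ///
    /// # Example
    ///
    /// A sketch of the builder-style API (module path assumed):
    ///
    /// ```no_run
    /// use scirs2_text::summarization::TextRank;
    ///
    /// let summarizer = TextRank::new(2).with_damping_factor(0.9).unwrap();
    /// let summary = summarizer.summarize("One. Two. Three. Four.").unwrap();
    /// ```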
    pub fn summarize(&self, text: &str) -> Result<String> {
        let sentences: Vec<String> = self.sentence_tokenizer.tokenize(text)?;

        if sentences.is_empty() {
            return Ok(String::new());
        }

        if sentences.len() <= self.num_sentences {
            return Ok(text.to_string());
        }

        // Build similarity matrix
        let similarity_matrix = self.build_similarity_matrix(&sentences)?;

        // Apply PageRank algorithm
        let scores = self.page_rank(&similarity_matrix)?;

        // Select top sentences
        let selected_indices = self.select_top_sentences(&scores);

        // Reconstruct summary maintaining original order
        let summary = self.reconstruct_summary(&sentences, &selected_indices);

        Ok(summary)
    }

    /// Build similarity matrix between sentences
    fn build_similarity_matrix(&self, sentences: &[String]) -> Result<Array2<f64>> {
        let n = sentences.len();
        let mut matrix = Array2::zeros((n, n));

        // Use TF-IDF for sentence representation
        let sentence_refs: Vec<&str> = sentences.iter().map(|s| s.as_ref()).collect();
        let mut vectorizer = TfidfVectorizer::default();
        vectorizer.fit(&sentence_refs)?;
        let vectors = vectorizer.transform_batch(&sentence_refs)?;

        // Calculate cosine similarity between all pairs
        for i in 0..n {
            for j in 0..n {
                if i == j {
                    matrix[[i, j]] = 0.0; // No self-loops
                } else {
                    matrix[[i, j]] = self.cosine_similarity(vectors.row(i), vectors.row(j));
                }
            }
        }

        Ok(matrix)
    }

    /// Calculate cosine similarity between two vectors
    fn cosine_similarity(&self, vec1: ArrayView1<f64>, vec2: ArrayView1<f64>) -> f64 {
        let dot_product = vec1.dot(&vec2);
        let norm1 = vec1.dot(&vec1).sqrt();
        let norm2 = vec2.dot(&vec2).sqrt();

        if norm1 == 0.0 || norm2 == 0.0 {
            0.0
        } else {
            dot_product / (norm1 * norm2)
        }
    }

    /// Apply PageRank algorithm
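    ///
    /// One iteration computes `scores = (1 - d) / n + d * Mᵀ · scores`, where
    /// `d` is the damping factor and `M` is the row-normalized similarity
    /// matrix; iteration stops at convergence or after `max_iterations`.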
    fn page_rank(&self, matrix: &Array2<f64>) -> Result<Array1<f64>> {
        let n = matrix.nrows();
        let mut scores = Array1::from_elem(n, 1.0 / n as f64);

        // Normalize rows of the similarity matrix so each row sums to 1
        let mut normalized_matrix = matrix.clone();
        for i in 0..n {
            let row_sum: f64 = matrix.row(i).sum();
            if row_sum > 0.0 {
                normalized_matrix.row_mut(i).mapv_inplace(|x| x / row_sum);
            }
        }

        // Iterate until convergence
        for _ in 0..self.max_iterations {
            let new_scores = Array1::from_elem(n, (1.0 - self.damping_factor) / n as f64)
                + self.damping_factor * normalized_matrix.t().dot(&scores);

            // Check convergence via the L1 distance between iterations
            let diff = (&new_scores - &scores).mapv(f64::abs).sum();
            scores = new_scores;

            if diff < self.threshold {
                break;
            }
        }

        Ok(scores)
    }

    /// Select top scoring sentences
    fn select_top_sentences(&self, scores: &Array1<f64>) -> Vec<usize> {
        let mut indexed_scores: Vec<(usize, f64)> = scores
            .iter()
            .enumerate()
            .map(|(i, &score)| (i, score))
            .collect();

        // Sort by score, descending; treat incomparable pairs as equal
        // rather than panicking on a stray NaN
        indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        indexed_scores
            .iter()
            .take(self.num_sentences)
            .map(|&(idx, _)| idx)
            .collect()
    }

    /// Reconstruct summary maintaining original order
    fn reconstruct_summary(&self, sentences: &[String], indices: &[usize]) -> String {
        let mut sorted_indices = indices.to_vec();
        sorted_indices.sort_unstable();

        sorted_indices
            .iter()
            .map(|&idx| sentences[idx].clone())
            .collect::<Vec<_>>()
            .join(" ")
    }
}

/// Centroid-based summarization
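///
/// Ranks sentences by cosine similarity to the TF-IDF centroid of the whole
/// document, greedily selecting the closest ones while skipping sentences that
/// are nearly identical (above `redundancy_threshold`) to ones already chosen.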
pub struct CentroidSummarizer {
    /// Number of sentences to extract
    num_sentences: usize,
    /// Topic threshold
    topic_threshold: f64,
    /// Redundancy threshold
    redundancy_threshold: f64,
    /// Sentence tokenizer
    sentence_tokenizer: Box<dyn Tokenizer + Send + Sync>,
}

impl CentroidSummarizer {
    /// Create a new centroid summarizer
    pub fn new(num_sentences: usize) -> Self {
        Self {
            num_sentences,
            topic_threshold: 0.1,
            redundancy_threshold: 0.95,
            sentence_tokenizer: Box::new(crate::tokenize::SentenceTokenizer::new()),
        }
    }

    /// Summarize text using centroid method
    pub fn summarize(&self, text: &str) -> Result<String> {
        let sentences: Vec<String> = self.sentence_tokenizer.tokenize(text)?;

        if sentences.is_empty() {
            return Ok(String::new());
        }

        if sentences.len() <= self.num_sentences {
            return Ok(text.to_string());
        }

        // Create TF-IDF vectors
        let sentence_refs: Vec<&str> = sentences.iter().map(|s| s.as_ref()).collect();
        let mut vectorizer = TfidfVectorizer::default();
        vectorizer.fit(&sentence_refs)?;
        let vectors = vectorizer.transform_batch(&sentence_refs)?;

        // Calculate centroid
        let centroid = self.calculate_centroid(&vectors);

        // Select sentences closest to centroid
        let selected_indices = self.select_sentences(&vectors, &centroid);

        // Reconstruct summary
        let summary = self.reconstruct_summary(&sentences, &selected_indices);

        Ok(summary)
    }

    /// Calculate document centroid
    fn calculate_centroid(&self, vectors: &Array2<f64>) -> Array1<f64> {
        let mut centroid = vectors.mean_axis(scirs2_core::ndarray::Axis(0)).unwrap();

        // Zero out components below the topic threshold
        centroid.mapv_inplace(|x| if x > self.topic_threshold { x } else { 0.0 });

        centroid
    }

    /// Select sentences based on centroid similarity
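    ///
    /// Greedy selection: candidates are visited in order of decreasing
    /// centroid similarity, and a candidate is dropped if it is more similar
    /// than `redundancy_threshold` to any already-selected sentence.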
    fn select_sentences(&self, vectors: &Array2<f64>, centroid: &Array1<f64>) -> Vec<usize> {
        let mut selected = Vec::new();

        // Calculate similarities to centroid
        let mut similarities: Vec<(usize, f64)> = Vec::new();
        for i in 0..vectors.nrows() {
            let similarity = self.cosine_similarity(vectors.row(i), centroid.view());
            similarities.push((i, similarity));
        }

        // Sort by similarity, descending
        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Select sentences avoiding redundancy
        for (idx, _similarity) in similarities {
            if selected.len() >= self.num_sentences {
                break;
            }

            // Check redundancy with already selected sentences
            let mut is_redundant = false;
            for &selected_idx in &selected {
                let sim = self.cosine_similarity(vectors.row(idx), vectors.row(selected_idx));
                if sim > self.redundancy_threshold {
                    is_redundant = true;
                    break;
                }
            }

            if !is_redundant {
                selected.push(idx);
            }
        }

        selected
    }

    /// Calculate cosine similarity
    fn cosine_similarity(&self, vec1: ArrayView1<f64>, vec2: ArrayView1<f64>) -> f64 {
        let dot_product = vec1.dot(&vec2);
        let norm1 = vec1.dot(&vec1).sqrt();
        let norm2 = vec2.dot(&vec2).sqrt();

        if norm1 == 0.0 || norm2 == 0.0 {
            0.0
        } else {
            dot_product / (norm1 * norm2)
        }
    }

    /// Reconstruct summary maintaining original order
    fn reconstruct_summary(&self, sentences: &[String], indices: &[usize]) -> String {
        let mut sorted_indices = indices.to_vec();
        sorted_indices.sort_unstable();

        sorted_indices
            .iter()
            .map(|&idx| sentences[idx].clone())
            .collect::<Vec<_>>()
            .join(" ")
    }
}

/// Keyword extraction using TF-IDF
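///
/// Scores candidate terms by their average TF-IDF weight across the
/// document's sentences and returns the highest-scoring ones.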
pub struct KeywordExtractor {
    /// Number of keywords to extract
    num_keywords: usize,
    /// Minimum document frequency
    #[allow(dead_code)]
    min_df: f64,
    /// Maximum document frequency
    #[allow(dead_code)]
    max_df: f64,
    /// N-gram range
    ngram_range: (usize, usize),
}

impl KeywordExtractor {
    /// Create a new keyword extractor
    pub fn new(num_keywords: usize) -> Self {
        Self {
            num_keywords,
            min_df: 0.01, // Unused but kept for API compatibility
            max_df: 0.95, // Unused but kept for API compatibility
            ngram_range: (1, 3),
        }
    }

    /// Configure n-gram range
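    ///
    /// A sketch (module path assumed):
    ///
    /// ```no_run
    /// use scirs2_text::summarization::KeywordExtractor;
    ///
    /// // Consider unigrams and bigrams only
    /// let extractor = KeywordExtractor::new(10).with_ngram_range(1, 2).unwrap();
    /// ```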
    pub fn with_ngram_range(mut self, min_n: usize, max_n: usize) -> Result<Self> {
        if min_n > max_n || min_n == 0 {
            return Err(TextError::InvalidInput("Invalid n-gram range".to_string()));
        }
        self.ngram_range = (min_n, max_n);
        Ok(self)
    }

    /// Extract keywords from text
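    ///
    /// Returns `(term, score)` pairs sorted by descending average TF-IDF.
    /// Term names are approximated from the text's whitespace tokens, since
    /// the vectorizer's vocabulary is not consulted here.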
    pub fn extract_keywords(&self, text: &str) -> Result<Vec<(String, f64)>> {
        // Split into sentences for better TF-IDF
        let sentence_tokenizer = crate::tokenize::SentenceTokenizer::new();
        let sentences = sentence_tokenizer.tokenize(text)?;

        if sentences.is_empty() {
            return Ok(Vec::new());
        }

        let sentence_refs: Vec<&str> = sentences.iter().map(|s| s.as_ref()).collect();

        // Create an enhanced TF-IDF vectorizer configured with the n-gram range
        let mut vectorizer = crate::enhanced_vectorize::EnhancedTfidfVectorizer::new()
            .set_ngram_range((self.ngram_range.0, self.ngram_range.1))?;

        vectorizer.fit(&sentence_refs)?;
        let tfidf_matrix = vectorizer.transform_batch(&sentence_refs)?;

        // Calculate average TF-IDF scores across documents
        let avg_tfidf = tfidf_matrix
            .mean_axis(scirs2_core::ndarray::Axis(0))
            .unwrap();

        // Approximate term names: map feature index i to the i-th whitespace
        // token, falling back to a placeholder when the index is out of range
        let all_words: Vec<String> = text.split_whitespace().map(|w| w.to_string()).collect();

        // Score every feature
        let mut keyword_scores: Vec<(String, f64)> = avg_tfidf
            .iter()
            .enumerate()
            .map(|(i, &score)| {
                let term = if i < all_words.len() {
                    all_words[i].clone()
                } else {
                    format!("term_{i}")
                };
                (term, score)
            })
            .collect();

        // Sort by score, descending
        keyword_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Return top keywords
        Ok(keyword_scores.into_iter().take(self.num_keywords).collect())
    }

    /// Extract keywords with position information
    pub fn extract_keywords_with_positions(
        &self,
        text: &str,
    ) -> Result<Vec<(String, f64, Vec<usize>)>> {
        let keywords = self.extract_keywords(text)?;
        let mut results = Vec::new();

        for (keyword, score) in keywords {
            let positions = self.find_keyword_positions(text, &keyword);
            results.push((keyword, score, positions));
        }

        Ok(results)
    }

    /// Find byte offsets of a keyword in text (case-insensitive)
    fn find_keyword_positions(&self, text: &str, keyword: &str) -> Vec<usize> {
        let mut positions = Vec::new();
        let text_lower = text.to_lowercase();
        let keyword_lower = keyword.to_lowercase();

        // Guard against an empty pattern, which would otherwise loop forever
        if keyword_lower.is_empty() {
            return positions;
        }

        let mut start = 0;
        while let Some(pos) = text_lower[start..].find(&keyword_lower) {
            positions.push(start + pos);
            // Offsets are into the lowercased text; they match the original
            // text whenever lowercasing preserves byte lengths (e.g. ASCII)
            start += pos + keyword_lower.len();
        }

        positions
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_textrank_summarizer() {
        let summarizer = TextRank::new(2);
        let text = "Machine learning is a subset of artificial intelligence. \
                    It enables computers to learn from data. \
                    Deep learning is a subset of machine learning. \
                    Neural networks are used in deep learning. \
                    These technologies are transforming many industries.";

        let summary = summarizer.summarize(text).unwrap();
        assert!(!summary.is_empty());
        assert!(summary.len() < text.len());
    }

    #[test]
    fn test_centroid_summarizer() {
        let summarizer = CentroidSummarizer::new(2);
        let text = "Natural language processing is important. \
                    It helps computers understand human language. \
                    Many applications use NLP technology. \
                    Chatbots and translation are examples. \
                    NLP continues to evolve rapidly.";

        let summary = summarizer.summarize(text).unwrap();
        assert!(!summary.is_empty());
    }

    #[test]
    fn test_keyword_extraction() {
        let extractor = KeywordExtractor::new(5);
        let text = "Machine learning algorithms are essential for artificial intelligence. \
                    Deep learning models use neural networks. \
                    These models can process complex data patterns.";

        let keywords = extractor.extract_keywords(text).unwrap();
        assert!(!keywords.is_empty());
        assert!(keywords.len() <= 5);

        // Check that scores are in descending order
        for i in 1..keywords.len() {
            assert!(keywords[i - 1].1 >= keywords[i].1);
        }
    }

    #[test]
    fn test_keyword_positions() {
        let extractor = KeywordExtractor::new(3);
        let text = "Machine learning is great. Machine learning transforms industries.";

        let keywords_with_pos = extractor.extract_keywords_with_positions(text).unwrap();

        // Should find positions for repeated keywords
        for (keyword, _score, positions) in keywords_with_pos {
            if keyword.to_lowercase().contains("machine learning") {
                assert!(positions.len() >= 2);
            }
        }
    }

    #[test]
    fn test_empty_text() {
        let textrank = TextRank::new(3);
        let centroid = CentroidSummarizer::new(3);
        let keywords = KeywordExtractor::new(5);

        assert_eq!(textrank.summarize("").unwrap(), "");
        assert_eq!(centroid.summarize("").unwrap(), "");
        assert_eq!(keywords.extract_keywords("").unwrap().len(), 0);
    }

    #[test]
    fn test_short_text() {
        let summarizer = TextRank::new(5);
        let short_text = "This is a short text.";

        let summary = summarizer.summarize(short_text).unwrap();
        assert_eq!(summary, short_text);
    }
}