//! Extractive text summarization module
//!
//! This module provides multiple sentence-scoring strategies for extractive
//! summarization:
//!
//! - **TextRank**: Graph-based sentence ranking via PageRank on a similarity
//!   matrix.
//! - **Position-based**: Lead-bias and coverage heuristics.
//! - **TF-IDF**: Sentences scored by their average TF-IDF weight.
//! - **Ensemble**: Weighted combination of multiple scoring methods.
//!
//! All methods are accessible through the unified [`summarize`] function.
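//!
//! # Example
//!
//! A minimal usage sketch; the import path assumes this module is exposed as
//! `scirs2_text::text_summarization`:
//!
//! ```
//! use scirs2_text::text_summarization::{summarize, SummarizationMethod};
//!
//! let text = "Rust is fast. Rust is safe. Many teams adopt it. Adoption keeps growing.";
//! let summary = summarize(text, 0.5, SummarizationMethod::TextRank)
//!     .expect("summarization failed");
//! assert!(summary.len() <= text.len());
//! ```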

use crate::error::Result;
use crate::tokenize::{SentenceTokenizer, Tokenizer};
use crate::vectorize::{TfidfVectorizer, Vectorizer};
use scirs2_core::ndarray::Array2;

// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------

/// Summarization method selector.
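///
/// # Example
///
/// A small sketch of selecting the ensemble strategy (the weights here are
/// illustrative, not tuned):
///
/// ```
/// use scirs2_text::text_summarization::SummarizationMethod;
///
/// let method = SummarizationMethod::Ensemble {
///     textrank_weight: 1.0,
///     position_weight: 0.5,
///     tfidf_weight: 0.5,
/// };
/// ```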
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SummarizationMethod {
    /// TextRank sentence scoring.
    TextRank,
    /// Position-based scoring (lead bias).
    Position,
    /// TF-IDF sentence scoring.
    TfIdf,
    /// Ensemble of all three methods with customisable weights.
    ///
    /// Weights are normalised internally to sum to one; if all weights are
    /// zero, an equal split is used.
    Ensemble {
        /// Weight for TextRank component.
        textrank_weight: f64,
        /// Weight for position component.
        position_weight: f64,
        /// Weight for TF-IDF component.
        tfidf_weight: f64,
    },
}

/// A scored sentence, carrying its index in the original text.
#[derive(Debug, Clone)]
pub struct ScoredSentence {
    /// The original sentence text.
    pub text: String,
    /// Zero-based index in the original sentence list.
    pub index: usize,
    /// Relevance score (higher is more important).
    pub score: f64,
}

// ---------------------------------------------------------------------------
// Unified API
// ---------------------------------------------------------------------------

/// Produce an extractive summary of `text`.
///
/// `ratio` controls how much of the text to keep and is clamped to
/// `[0.0, 1.0]`. A ratio of 0.3 means roughly 30% of the sentences are
/// selected; at least one sentence is always kept, and a ratio of 1.0
/// returns the full text unchanged.
///
/// Returns the summary as a single string with the selected sentences in
/// their original order.
///
/// # Errors
///
/// Returns an error if tokenization or vectorization fails.
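///
/// # Example
///
/// A minimal sketch (module path assumed as in the crate layout above):
///
/// ```
/// use scirs2_text::text_summarization::{summarize, SummarizationMethod};
///
/// let text = "First point. Second point. Third point. Fourth point. Fifth point.";
/// // With five sentences and ratio 0.4, ceil(5 * 0.4) = 2 sentences are kept.
/// let summary = summarize(text, 0.4, SummarizationMethod::TfIdf)
///     .expect("summarization failed");
/// assert!(!summary.is_empty());
/// ```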
pub fn summarize(text: &str, ratio: f64, method: SummarizationMethod) -> Result<String> {
    if text.trim().is_empty() {
        return Ok(String::new());
    }

    let clamped_ratio = ratio.clamp(0.0, 1.0);

    let sentence_tokenizer = SentenceTokenizer::new();
    let sentences: Vec<String> = sentence_tokenizer.tokenize(text)?;

    if sentences.is_empty() {
        return Ok(String::new());
    }

    let n_select = (sentences.len() as f64 * clamped_ratio).ceil().max(1.0) as usize;

    if n_select >= sentences.len() {
        return Ok(text.to_string());
    }

    let scored = match method {
        SummarizationMethod::TextRank => score_textrank(&sentences)?,
        SummarizationMethod::Position => score_position(&sentences),
        SummarizationMethod::TfIdf => score_tfidf(&sentences)?,
        SummarizationMethod::Ensemble {
            textrank_weight,
            position_weight,
            tfidf_weight,
        } => score_ensemble(&sentences, textrank_weight, position_weight, tfidf_weight)?,
    };

    // Select top-n sentences.
    let mut top: Vec<ScoredSentence> = scored;
    top.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    top.truncate(n_select);

    // Restore original order for readability.
    top.sort_by_key(|s| s.index);

    let summary = top
        .iter()
        .map(|s| s.text.clone())
        .collect::<Vec<_>>()
        .join(" ");

    Ok(summary)
}

// ---------------------------------------------------------------------------
// TextRank scoring
// ---------------------------------------------------------------------------

/// Score sentences using TextRank (PageRank over a TF-IDF cosine similarity
/// graph).
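///
/// # Example
///
/// A minimal sketch (module path assumed as above):
///
/// ```
/// use scirs2_text::text_summarization::score_textrank;
///
/// let sentences = vec![
///     "Cats are popular pets.".to_string(),
///     "Dogs are popular pets too.".to_string(),
/// ];
/// let scored = score_textrank(&sentences).expect("scoring failed");
/// assert_eq!(scored.len(), 2);
/// ```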
pub fn score_textrank(sentences: &[String]) -> Result<Vec<ScoredSentence>> {
    let n = sentences.len();
    if n == 0 {
        return Ok(Vec::new());
    }
    if n == 1 {
        return Ok(vec![ScoredSentence {
            text: sentences[0].clone(),
            index: 0,
            score: 1.0,
        }]);
    }

    let similarity_matrix = build_similarity_matrix(sentences)?;
    let scores = pagerank(&similarity_matrix, 0.85, 100, 1e-5)?;

    Ok(sentences
        .iter()
        .enumerate()
        .map(|(i, s)| ScoredSentence {
            text: s.clone(),
            index: i,
            score: scores[i],
        })
        .collect())
}

/// Build a sentence-by-sentence cosine similarity matrix using TF-IDF.
fn build_similarity_matrix(sentences: &[String]) -> Result<Array2<f64>> {
    let refs: Vec<&str> = sentences.iter().map(|s| s.as_str()).collect();
    let mut vectorizer = TfidfVectorizer::default();
    vectorizer.fit(&refs)?;
    let tfidf = vectorizer.transform_batch(&refs)?;

    let n = sentences.len();
    let mut matrix = Array2::zeros((n, n));

    // The matrix is symmetric with a zero diagonal, so only the upper
    // triangle needs to be computed.
    for i in 0..n {
        for j in (i + 1)..n {
            let sim = cosine_sim_rows(&tfidf, i, j);
            matrix[[i, j]] = sim;
            matrix[[j, i]] = sim;
        }
    }

    Ok(matrix)
}

/// Cosine similarity between row `i` and row `j` of a matrix.
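///
/// Computed as `dot(a, b) / (|a| * |b|)`, returning 0.0 when either row has
/// zero norm.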
fn cosine_sim_rows(matrix: &Array2<f64>, i: usize, j: usize) -> f64 {
    let cols = matrix.ncols();
    let mut dot = 0.0_f64;
    let mut n1 = 0.0_f64;
    let mut n2 = 0.0_f64;

    for c in 0..cols {
        let a = matrix[[i, c]];
        let b = matrix[[j, c]];
        dot += a * b;
        n1 += a * a;
        n2 += b * b;
    }

    let denom = n1.sqrt() * n2.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}

/// PageRank on a similarity matrix.
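///
/// Each iteration applies the standard update
/// `score_i = (1 - d) / n + d * sum_j(M[j][i] * score_j)`,
/// where `M` is the row-normalised similarity matrix and `d` is the damping
/// factor. Iteration stops after `max_iter` rounds or once the L1 change in
/// scores drops below `threshold`. Rows with no similarity mass are left
/// unnormalised, so fully disconnected sentences keep only the base score.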
fn pagerank(
    matrix: &Array2<f64>,
    damping: f64,
    max_iter: usize,
    threshold: f64,
) -> Result<Vec<f64>> {
    let n = matrix.nrows();
    let mut scores = vec![1.0 / n as f64; n];

    // Row-normalise the matrix.
    let mut norm_matrix = matrix.clone();
    for i in 0..n {
        let row_sum: f64 = (0..n).map(|j| matrix[[i, j]]).sum();
        if row_sum > 0.0 {
            for j in 0..n {
                norm_matrix[[i, j]] = matrix[[i, j]] / row_sum;
            }
        }
    }

    for _ in 0..max_iter {
        let mut new_scores = vec![(1.0 - damping) / n as f64; n];

        for i in 0..n {
            for j in 0..n {
                new_scores[i] += damping * norm_matrix[[j, i]] * scores[j];
            }
        }

        let diff: f64 = scores
            .iter()
            .zip(new_scores.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();

        scores = new_scores;
        if diff < threshold {
            break;
        }
    }

    Ok(scores)
}

// ---------------------------------------------------------------------------
// Position-based scoring
// ---------------------------------------------------------------------------

/// Score sentences by position: early sentences and the last sentence receive
/// a boost (lead-bias heuristic commonly used in news articles).
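///
/// The score is `(lead + bonuses) * length_factor`, where `lead = 1 - i/n`,
/// the first sentence gets a +0.15 bonus, the last a +0.2 bonus, and
/// `length_factor` grows logarithmically with word count, capped at 1.0.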
pub fn score_position(sentences: &[String]) -> Vec<ScoredSentence> {
    let n = sentences.len();
    if n == 0 {
        return Vec::new();
    }

    sentences
        .iter()
        .enumerate()
        .map(|(i, s)| {
            // Lead sentences score highest, then the conclusion; the middle is lowest.
            let position_score = if n == 1 {
                1.0
            } else {
                let lead_score = 1.0 - (i as f64 / n as f64);
                let conclusion_bonus = if i == n - 1 { 0.2 } else { 0.0 };
                // First sentence gets a small extra boost.
                let first_bonus = if i == 0 { 0.15 } else { 0.0 };
                lead_score + conclusion_bonus + first_bonus
            };

            // Longer sentences are somewhat more informative. Clamp the word
            // count to at least 1 so an empty sentence cannot produce
            // `ln(0) = -inf`.
            let word_count = s.split_whitespace().count().max(1) as f64;
            let length_factor = (word_count.ln() + 1.0).min(3.0) / 3.0;

            ScoredSentence {
                text: s.clone(),
                index: i,
                score: position_score * length_factor,
            }
        })
        .collect()
}

// ---------------------------------------------------------------------------
// TF-IDF scoring
// ---------------------------------------------------------------------------

/// Score sentences by their average TF-IDF value.
///
/// Sentences with rare, important terms score higher.
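///
/// Each sentence's score is the mean of its row in the TF-IDF matrix: the sum
/// of its term weights divided by the vocabulary size. If the vocabulary is
/// empty, every sentence scores 0.0.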
pub fn score_tfidf(sentences: &[String]) -> Result<Vec<ScoredSentence>> {
    let n = sentences.len();
    if n == 0 {
        return Ok(Vec::new());
    }
    if n == 1 {
        return Ok(vec![ScoredSentence {
            text: sentences[0].clone(),
            index: 0,
            score: 1.0,
        }]);
    }

    let refs: Vec<&str> = sentences.iter().map(|s| s.as_str()).collect();
    let mut vectorizer = TfidfVectorizer::default();
    let tfidf = vectorizer.fit_transform(&refs)?;

    let cols = tfidf.ncols();
    if cols == 0 {
        return Ok(sentences
            .iter()
            .enumerate()
            .map(|(i, s)| ScoredSentence {
                text: s.clone(),
                index: i,
                score: 0.0,
            })
            .collect());
    }

    Ok(sentences
        .iter()
        .enumerate()
        .map(|(i, s)| {
            let row_sum: f64 = (0..cols).map(|c| tfidf[[i, c]]).sum();
            let avg = row_sum / cols as f64;
            ScoredSentence {
                text: s.clone(),
                index: i,
                score: avg,
            }
        })
        .collect())
}

// ---------------------------------------------------------------------------
// Ensemble scoring
// ---------------------------------------------------------------------------

/// Combine TextRank, position and TF-IDF scores with the given weights.
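///
/// Each component's scores are min-max normalised to [0, 1] first; the final
/// score is `tw * textrank + pw * position + iw * tfidf`, with the weights
/// scaled to sum to one (an equal split is used if all weights are zero).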
fn score_ensemble(
    sentences: &[String],
    textrank_weight: f64,
    position_weight: f64,
    tfidf_weight: f64,
) -> Result<Vec<ScoredSentence>> {
    let n = sentences.len();
    if n == 0 {
        return Ok(Vec::new());
    }

    let tr_scores = score_textrank(sentences)?;
    let pos_scores = score_position(sentences);
    let tfidf_scores = score_tfidf(sentences)?;

    // Normalise each set of scores to [0, 1].
    let tr_normalised = normalise_scores(&tr_scores);
    let pos_normalised = normalise_scores(&pos_scores);
    let tfidf_normalised = normalise_scores(&tfidf_scores);

    // Scale the weights to sum to one, falling back to an equal split when
    // all weights are zero.
    let total_weight = textrank_weight + position_weight + tfidf_weight;
    let (tw, pw, iw) = if total_weight > 0.0 {
        (
            textrank_weight / total_weight,
            position_weight / total_weight,
            tfidf_weight / total_weight,
        )
    } else {
        (1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0)
    };

    Ok((0..n)
        .map(|i| ScoredSentence {
            text: sentences[i].clone(),
            index: i,
            score: tw * tr_normalised[i] + pw * pos_normalised[i] + iw * tfidf_normalised[i],
        })
        .collect())
}

/// Min-max normalise scores to [0, 1].
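///
/// Returns `(x - min) / (max - min)` for each score; when the range is close
/// to zero (all scores equal), every sentence gets 0.5.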
fn normalise_scores(scored: &[ScoredSentence]) -> Vec<f64> {
    if scored.is_empty() {
        return Vec::new();
    }

    let min = scored.iter().map(|s| s.score).fold(f64::INFINITY, f64::min);
    let max = scored
        .iter()
        .map(|s| s.score)
        .fold(f64::NEG_INFINITY, f64::max);

    let range = max - min;
    if range < 1e-12 {
        return vec![0.5; scored.len()];
    }

    scored.iter().map(|s| (s.score - min) / range).collect()
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    const SAMPLE_TEXT: &str = "Machine learning is a subset of artificial intelligence. \
        It enables computers to learn from data without explicit programming. \
        Deep learning is a subset of machine learning that uses neural networks. \
        Neural networks are modeled loosely after the human brain. \
        These technologies are transforming many industries today.";

    // ---- TextRank tests ----

    #[test]
    fn test_textrank_produces_shorter_summary() {
        let summary =
            summarize(SAMPLE_TEXT, 0.4, SummarizationMethod::TextRank).expect("Should succeed");
        assert!(!summary.is_empty());
        assert!(summary.len() < SAMPLE_TEXT.len());
    }

    #[test]
    fn test_textrank_empty_text() {
        let summary = summarize("", 0.5, SummarizationMethod::TextRank).expect("ok");
        assert!(summary.is_empty());
    }

    #[test]
    fn test_textrank_ratio_one_returns_full() {
        let summary = summarize(SAMPLE_TEXT, 1.0, SummarizationMethod::TextRank).expect("ok");
        assert_eq!(summary, SAMPLE_TEXT);
    }

    #[test]
    fn test_textrank_ratio_zero_returns_one_sentence() {
        let summary = summarize(SAMPLE_TEXT, 0.0, SummarizationMethod::TextRank).expect("ok");
        // ratio clamped to 0.0, ceil(n*0) = 0 but max(1) = 1 sentence.
        assert!(!summary.is_empty());
        // Should be just one sentence.
        let sentence_tokenizer = SentenceTokenizer::new();
        let sentences = sentence_tokenizer.tokenize(&summary).expect("ok");
        assert_eq!(sentences.len(), 1);
    }

    #[test]
    fn test_textrank_single_sentence() {
        let summary =
            summarize("Just one sentence.", 0.5, SummarizationMethod::TextRank).expect("ok");
        assert_eq!(summary, "Just one sentence.");
    }

    #[test]
    fn test_textrank_scores_non_negative() {
        let sentence_tokenizer = SentenceTokenizer::new();
        let sentences = sentence_tokenizer.tokenize(SAMPLE_TEXT).expect("ok");
        let scored = score_textrank(&sentences).expect("ok");
        for s in &scored {
            assert!(s.score >= 0.0, "Score should be non-negative");
        }
    }

    // ---- Position tests ----

    #[test]
    fn test_position_first_sentence_highest() {
        let sentence_tokenizer = SentenceTokenizer::new();
        let sentences = sentence_tokenizer.tokenize(SAMPLE_TEXT).expect("ok");
        let scored = score_position(&sentences);
        // First sentence should have the highest score.
        let first = &scored[0];
        for s in &scored[1..] {
            assert!(
                first.score >= s.score,
                "First sentence should have the highest position score"
            );
        }
    }

    #[test]
    fn test_position_produces_summary() {
        let summary = summarize(SAMPLE_TEXT, 0.4, SummarizationMethod::Position).expect("ok");
        assert!(!summary.is_empty());
        assert!(summary.len() < SAMPLE_TEXT.len());
    }

    #[test]
    fn test_position_empty() {
        let summary = summarize("", 0.5, SummarizationMethod::Position).expect("ok");
        assert!(summary.is_empty());
    }

    #[test]
    fn test_position_scores_non_negative() {
        let sentence_tokenizer = SentenceTokenizer::new();
        let sentences = sentence_tokenizer.tokenize(SAMPLE_TEXT).expect("ok");
        let scored = score_position(&sentences);
        for s in &scored {
            assert!(s.score >= 0.0);
        }
    }

    #[test]
    fn test_position_last_sentence_has_conclusion_bonus() {
        let sentence_tokenizer = SentenceTokenizer::new();
        let sentences = sentence_tokenizer.tokenize(SAMPLE_TEXT).expect("ok");
        let scored = score_position(&sentences);
        let n = scored.len();
        if n >= 2 {
            // The last sentence gets a 0.2 conclusion bonus on top of its
            // lead score, which at position n-1 is 1.0 - (n-1)/n.
            let last_score = scored[n - 1].score;
            let lead_alone = 1.0 - ((n - 1) as f64 / n as f64);
            // The final score is scaled by a length factor of at least 1/3,
            // so compare against a correspondingly relaxed baseline.
            assert!(
                last_score > lead_alone * 0.3,
                "Last sentence should benefit from conclusion bonus"
            );
        }
    }

    // ---- TF-IDF tests ----

    #[test]
    fn test_tfidf_produces_summary() {
        let summary = summarize(SAMPLE_TEXT, 0.4, SummarizationMethod::TfIdf).expect("ok");
        assert!(!summary.is_empty());
        assert!(summary.len() < SAMPLE_TEXT.len());
    }

    #[test]
    fn test_tfidf_empty() {
        let summary = summarize("", 0.5, SummarizationMethod::TfIdf).expect("ok");
        assert!(summary.is_empty());
    }

    #[test]
    fn test_tfidf_single_sentence() {
        let summary = summarize("Only one.", 0.5, SummarizationMethod::TfIdf).expect("ok");
        assert_eq!(summary, "Only one.");
    }

    #[test]
    fn test_tfidf_scores_non_negative() {
        let sentence_tokenizer = SentenceTokenizer::new();
        let sentences = sentence_tokenizer.tokenize(SAMPLE_TEXT).expect("ok");
        let scored = score_tfidf(&sentences).expect("ok");
        for s in &scored {
            assert!(s.score >= 0.0);
        }
    }

    #[test]
    fn test_tfidf_ratio_half() {
        let summary = summarize(SAMPLE_TEXT, 0.5, SummarizationMethod::TfIdf).expect("ok");
        let sentence_tokenizer = SentenceTokenizer::new();
        let original = sentence_tokenizer.tokenize(SAMPLE_TEXT).expect("ok");
        let summarised = sentence_tokenizer.tokenize(&summary).expect("ok");
        // Should keep roughly half the sentences (rounded up).
        let expected = (original.len() as f64 * 0.5).ceil() as usize;
        assert_eq!(summarised.len(), expected);
    }

    // ---- Ensemble tests ----

    #[test]
    fn test_ensemble_produces_summary() {
        let method = SummarizationMethod::Ensemble {
            textrank_weight: 1.0,
            position_weight: 0.5,
            tfidf_weight: 0.5,
        };
        let summary = summarize(SAMPLE_TEXT, 0.4, method).expect("ok");
        assert!(!summary.is_empty());
        assert!(summary.len() < SAMPLE_TEXT.len());
    }

    #[test]
    fn test_ensemble_equal_weights() {
        let method = SummarizationMethod::Ensemble {
            textrank_weight: 1.0,
            position_weight: 1.0,
            tfidf_weight: 1.0,
        };
        let summary = summarize(SAMPLE_TEXT, 0.4, method).expect("ok");
        assert!(!summary.is_empty());
    }

    #[test]
    fn test_ensemble_zero_weights_defaults() {
        let method = SummarizationMethod::Ensemble {
            textrank_weight: 0.0,
            position_weight: 0.0,
            tfidf_weight: 0.0,
        };
        let summary = summarize(SAMPLE_TEXT, 0.4, method).expect("ok");
        assert!(!summary.is_empty());
    }

    #[test]
    fn test_ensemble_empty() {
        let method = SummarizationMethod::Ensemble {
            textrank_weight: 1.0,
            position_weight: 1.0,
            tfidf_weight: 1.0,
        };
        let summary = summarize("", 0.5, method).expect("ok");
        assert!(summary.is_empty());
    }

    #[test]
    fn test_ensemble_scores_bounded() {
        let sentence_tokenizer = SentenceTokenizer::new();
        let sentences = sentence_tokenizer.tokenize(SAMPLE_TEXT).expect("ok");
        let scored = score_ensemble(&sentences, 1.0, 1.0, 1.0).expect("ok");
        for s in &scored {
            assert!(
                (0.0..=1.0).contains(&s.score),
                "Ensemble scores should be in [0, 1]"
            );
        }
    }

    // ---- Original-order tests ----

    #[test]
    fn test_summary_preserves_order() {
        let summary = summarize(SAMPLE_TEXT, 0.6, SummarizationMethod::TextRank).expect("ok");
        let sentence_tokenizer = SentenceTokenizer::new();
        let summary_sentences = sentence_tokenizer.tokenize(&summary).expect("ok");
        let original_sentences = sentence_tokenizer.tokenize(SAMPLE_TEXT).expect("ok");

        // Each summary sentence should appear in order in the original.
        let mut last_idx: Option<usize> = None;
        for ss in &summary_sentences {
            let idx = original_sentences
                .iter()
                .position(|os| os.trim() == ss.trim());
            if let (Some(li), Some(ci)) = (last_idx, idx) {
                assert!(ci > li, "Summary sentences should be in original order");
            }
            last_idx = idx;
        }
    }
}