Skip to main content

scirs2_text/
summarize_advanced.rs

1//! Advanced extractive text summarization algorithms.
2//!
3//! Provides `TextRankSummarizer` (PageRank on sentence similarity graphs),
4//! `ExtractiveSummarizer` (lead-k and frequency-based strategies) and three
5//! sentence-similarity metrics: cosine TF-IDF, BM25, and Jaccard.
6
7use std::collections::HashMap;
8
9use crate::error::{Result, TextError};
10
11// ---------------------------------------------------------------------------
12// Sentence similarity
13// ---------------------------------------------------------------------------
14
/// Selects how the similarity between two sentences is measured.
///
/// The chosen metric supplies the edge weights of the sentence graph that
/// [`TextRankSummarizer`] ranks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SentenceSimilarity {
    /// Cosine of the angle between IDF-weighted term-frequency vectors.
    CosineTFIDF,
    /// BM25 relevance score; asymmetric, one sentence plays the query.
    BM25,
    /// Shared distinct words divided by all distinct words of both sentences.
    Jaccard,
}
25
26// ---------------------------------------------------------------------------
27// Tokenisation helpers
28// ---------------------------------------------------------------------------
29
/// Naive sentence splitter.
///
/// A sentence ends at `.`, `?` or `!` when that mark is the last character of
/// `text` or is immediately followed by whitespace. Each sentence is trimmed
/// and empty fragments are dropped; trailing text without a terminator becomes
/// the final sentence. There is no capitalisation or abbreviation handling, so
/// "Dr. Smith" is split after "Dr.".
fn split_sentences(text: &str) -> Vec<String> {
    let mut out: Vec<String> = Vec::new();
    let mut buf = String::new();
    let mut chars = text.chars().peekable();
    while let Some(ch) = chars.next() {
        buf.push(ch);
        let is_terminator = matches!(ch, '.' | '?' | '!');
        let at_boundary = chars.peek().map_or(true, |next| next.is_whitespace());
        if is_terminator && at_boundary {
            let sentence = buf.trim();
            if !sentence.is_empty() {
                out.push(sentence.to_string());
            }
            buf.clear();
        }
    }
    let tail = buf.trim();
    if !tail.is_empty() {
        out.push(tail.to_string());
    }
    out
}
58
/// Split `sentence` into lowercase word tokens.
///
/// Every non-alphabetic character acts as a separator, so digits and
/// punctuation never appear in the output.
fn tokenize_words(sentence: &str) -> Vec<String> {
    let mut words = Vec::new();
    for piece in sentence.split(|c: char| !c.is_alphabetic()) {
        if piece.is_empty() {
            continue;
        }
        words.push(piece.to_lowercase());
    }
    words
}
67
68// ---------------------------------------------------------------------------
69// Similarity functions
70// ---------------------------------------------------------------------------
71
72/// Jaccard similarity between two sentences.
73fn jaccard(a: &str, b: &str) -> f64 {
74    let ta: std::collections::HashSet<String> = tokenize_words(a).into_iter().collect();
75    let tb: std::collections::HashSet<String> = tokenize_words(b).into_iter().collect();
76    let inter = ta.intersection(&tb).count();
77    let union = ta.union(&tb).count();
78    if union == 0 {
79        0.0
80    } else {
81        inter as f64 / union as f64
82    }
83}
84
85/// Build a TF map for one sentence.
86fn tf_map(sentence: &str) -> HashMap<String, f64> {
87    let words = tokenize_words(sentence);
88    let n = words.len() as f64;
89    if n == 0.0 {
90        return HashMap::new();
91    }
92    let mut counts: HashMap<String, usize> = HashMap::new();
93    for w in words {
94        *counts.entry(w).or_insert(0) += 1;
95    }
96    counts.into_iter().map(|(k, c)| (k, c as f64 / n)).collect()
97}
98
99/// Cosine similarity with IDF weighting across a sentence corpus.
100fn cosine_tfidf(a: &str, b: &str, idf: &HashMap<String, f64>) -> f64 {
101    let ta = tf_map(a);
102    let tb = tf_map(b);
103    let dot: f64 = ta
104        .iter()
105        .filter_map(|(w, &tfa)| {
106            tb.get(w).map(|&tfb| {
107                let idf_w = idf.get(w).copied().unwrap_or(1.0);
108                tfa * idf_w * tfb * idf_w
109            })
110        })
111        .sum();
112    let norm_a: f64 = ta
113        .values()
114        .map(|&v| {
115            let idf_w = idf.get(&String::new()).copied().unwrap_or(1.0);
116            (v * idf_w).powi(2)
117        })
118        .sum::<f64>()
119        .sqrt();
120    let norm_b: f64 = tb
121        .values()
122        .map(|&v| {
123            let idf_w = idf.get(&String::new()).copied().unwrap_or(1.0);
124            (v * idf_w).powi(2)
125        })
126        .sum::<f64>()
127        .sqrt();
128    if norm_a == 0.0 || norm_b == 0.0 {
129        return 0.0;
130    }
131    dot / (norm_a * norm_b)
132}
133
134/// BM25 score of `query_sentence` given `doc_sentence` as the document.
135fn bm25_similarity(
136    query: &str,
137    doc: &str,
138    avgdl: f64,
139    idf: &HashMap<String, f64>,
140    k1: f64,
141    b: f64,
142) -> f64 {
143    let query_words = tokenize_words(query);
144    let doc_words = tokenize_words(doc);
145    let dl = doc_words.len() as f64;
146    let mut freq_map: HashMap<&str, usize> = HashMap::new();
147    for w in &doc_words {
148        *freq_map.entry(w.as_str()).or_insert(0) += 1;
149    }
150    query_words
151        .iter()
152        .map(|w| {
153            let idf_w = idf.get(w).copied().unwrap_or(1.0);
154            let f = *freq_map.get(w.as_str()).unwrap_or(&0) as f64;
155            idf_w * (f * (k1 + 1.0)) / (f + k1 * (1.0 - b + b * dl / avgdl))
156        })
157        .sum()
158}
159
160/// Build IDF map from a corpus of sentences.
161fn build_idf(sentences: &[String]) -> HashMap<String, f64> {
162    let n = sentences.len() as f64;
163    let mut df: HashMap<String, usize> = HashMap::new();
164    for sent in sentences {
165        let words: std::collections::HashSet<String> = tokenize_words(sent).into_iter().collect();
166        for w in words {
167            *df.entry(w).or_insert(0) += 1;
168        }
169    }
170    df.into_iter()
171        .map(|(w, c)| (w, ((n + 1.0) / (c as f64 + 1.0)).ln() + 1.0))
172        .collect()
173}
174
175// ---------------------------------------------------------------------------
176// TextRankSummarizer
177// ---------------------------------------------------------------------------
178
179/// Extractive summariser based on the TextRank algorithm.
180///
181/// Sentences are represented as graph nodes; edges are weighted by
182/// sentence similarity.  PageRank is run to score each sentence.
183#[derive(Debug, Clone)]
184pub struct TextRankSummarizer {
185    /// PageRank damping factor.
186    pub damping: f64,
187    /// Number of PageRank iterations.
188    pub n_iterations: usize,
189    /// Similarity metric used to weight graph edges.
190    pub similarity: SentenceSimilarity,
191}
192
193impl Default for TextRankSummarizer {
194    fn default() -> Self {
195        TextRankSummarizer {
196            damping: 0.85,
197            n_iterations: 50,
198            similarity: SentenceSimilarity::CosineTFIDF,
199        }
200    }
201}
202
203impl TextRankSummarizer {
204    /// Create a summariser with custom parameters.
205    pub fn new(damping: f64, n_iterations: usize, similarity: SentenceSimilarity) -> Self {
206        TextRankSummarizer {
207            damping,
208            n_iterations,
209            similarity,
210        }
211    }
212
213    /// Summarise `text` by extracting the top `n_sentences` ranked sentences.
214    ///
215    /// Sentences are returned in their **original document order**.
216    pub fn summarize(&self, text: &str, n_sentences: usize) -> Result<String> {
217        let sentences = split_sentences(text);
218        if sentences.is_empty() {
219            return Ok(String::new());
220        }
221        let k = n_sentences.min(sentences.len());
222        if k == sentences.len() {
223            return Ok(text.to_string());
224        }
225
226        let idf = build_idf(&sentences);
227        let n = sentences.len();
228        let avgdl = sentences
229            .iter()
230            .map(|s| tokenize_words(s).len())
231            .sum::<usize>() as f64
232            / n as f64;
233
234        // Build adjacency matrix
235        let mut adj = vec![vec![0.0f64; n]; n];
236        for i in 0..n {
237            for j in 0..n {
238                if i == j {
239                    continue;
240                }
241                adj[i][j] = match self.similarity {
242                    SentenceSimilarity::Jaccard => jaccard(&sentences[i], &sentences[j]),
243                    SentenceSimilarity::CosineTFIDF => {
244                        cosine_tfidf(&sentences[i], &sentences[j], &idf)
245                    }
246                    SentenceSimilarity::BM25 => {
247                        bm25_similarity(&sentences[i], &sentences[j], avgdl, &idf, 1.5, 0.75)
248                    }
249                };
250            }
251        }
252
253        // Row-normalise adjacency
254        for row in adj.iter_mut() {
255            let total: f64 = row.iter().sum();
256            if total > 0.0 {
257                for v in row.iter_mut() {
258                    *v /= total;
259                }
260            }
261        }
262
263        // PageRank
264        let mut scores = vec![1.0 / n as f64; n];
265        for _ in 0..self.n_iterations {
266            let mut new_scores = vec![0.0f64; n];
267            for j in 0..n {
268                for i in 0..n {
269                    new_scores[j] += adj[i][j] * scores[i];
270                }
271                new_scores[j] = (1.0 - self.damping) / n as f64 + self.damping * new_scores[j];
272            }
273            scores = new_scores;
274        }
275
276        // Select top-k sentence indices
277        let mut ranked: Vec<(usize, f64)> = scores.iter().cloned().enumerate().collect();
278        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
279        let mut top_indices: Vec<usize> = ranked.iter().take(k).map(|&(i, _)| i).collect();
280        // Restore original order
281        top_indices.sort();
282
283        let summary = top_indices
284            .iter()
285            .map(|&i| sentences[i].as_str())
286            .collect::<Vec<_>>()
287            .join(" ");
288        Ok(summary)
289    }
290}
291
292// ---------------------------------------------------------------------------
293// ExtractiveSummarizer
294// ---------------------------------------------------------------------------
295
/// Simple, stateless extractive summarisation strategies.
///
/// All functionality is exposed through associated functions such as
/// [`ExtractiveSummarizer::lead_k`] and [`ExtractiveSummarizer::frequency_based`];
/// the unit struct only serves as a namespace.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ExtractiveSummarizer;
298
299impl ExtractiveSummarizer {
300    /// Return the first `k` sentences of `text`.
301    pub fn lead_k(text: &str, k: usize) -> Result<String> {
302        if k == 0 {
303            return Err(TextError::InvalidInput("k must be at least 1".to_string()));
304        }
305        let sentences = split_sentences(text);
306        let selected: Vec<&str> = sentences.iter().take(k).map(String::as_str).collect();
307        Ok(selected.join(" "))
308    }
309
310    /// Score sentences by aggregate word-frequency and return the top `k`.
311    ///
312    /// The frequency of each word is computed across the whole document;
313    /// stop-words (high-frequency function words) are down-weighted.
314    pub fn frequency_based(text: &str, k: usize) -> Result<String> {
315        if k == 0 {
316            return Err(TextError::InvalidInput("k must be at least 1".to_string()));
317        }
318        let sentences = split_sentences(text);
319        if sentences.is_empty() {
320            return Ok(String::new());
321        }
322
323        // Word frequency across the full document
324        let mut freq: HashMap<String, usize> = HashMap::new();
325        for sent in &sentences {
326            for w in tokenize_words(sent) {
327                *freq.entry(w).or_insert(0) += 1;
328            }
329        }
330        // Normalise
331        let max_freq = *freq.values().max().unwrap_or(&1) as f64;
332        let norm_freq: HashMap<String, f64> = freq
333            .into_iter()
334            .map(|(k, v)| (k, v as f64 / max_freq))
335            .collect();
336
337        // Score each sentence
338        let mut scored: Vec<(usize, f64)> = sentences
339            .iter()
340            .enumerate()
341            .map(|(i, sent)| {
342                let words = tokenize_words(sent);
343                let score: f64 = words
344                    .iter()
345                    .map(|w| norm_freq.get(w).copied().unwrap_or(0.0))
346                    .sum();
347                (i, score)
348            })
349            .collect();
350
351        // Take top-k in original order
352        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
353        let mut top_indices: Vec<usize> = scored.iter().take(k).map(|&(i, _)| i).collect();
354        top_indices.sort();
355
356        let result = top_indices
357            .iter()
358            .map(|&i| sentences[i].as_str())
359            .collect::<Vec<_>>()
360            .join(" ");
361        Ok(result)
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    const TEXT: &str = "The quick brown fox jumps over the lazy dog. \
         A fox is a cunning animal. \
         Dogs are loyal companions. \
         Foxes live in dens and are mostly nocturnal. \
         The dog slept all afternoon.";

    /// Assert that every sentence in `summary` is one of `TEXT`'s sentences and
    /// that they appear in the same relative order as in `TEXT`.
    fn assert_in_document_order(summary: &str) {
        let originals = split_sentences(TEXT);
        let mut previous: Option<usize> = None;
        for sent in split_sentences(summary) {
            let pos = originals
                .iter()
                .position(|o| *o == sent)
                .unwrap_or_else(|| panic!("sentence not in source text: {:?}", sent));
            assert!(
                previous.map_or(true, |p| pos > p),
                "sentences out of document order in {:?}",
                summary
            );
            previous = Some(pos);
        }
    }

    #[test]
    fn test_split_sentences() {
        let sents = split_sentences(TEXT);
        assert_eq!(sents.len(), 5);
        assert_eq!(sents[0], "The quick brown fox jumps over the lazy dog.");
        assert_eq!(sents[4], "The dog slept all afternoon.");
    }

    #[test]
    fn test_split_sentences_decimal_and_tail() {
        // A period not followed by whitespace does not end a sentence, and
        // trailing text without punctuation is kept as a sentence.
        assert_eq!(
            split_sentences("Pi is 3.14 roughly. ok"),
            vec!["Pi is 3.14 roughly.", "ok"]
        );
    }

    #[test]
    fn test_textrank_summarize_count() {
        let summarizer = TextRankSummarizer::default();
        let summary = summarizer.summarize(TEXT, 2).expect("summarize failed");
        // Count whole sentences rather than '.' characters.
        assert_eq!(split_sentences(&summary).len(), 2, "summary: {:?}", summary);
        assert_in_document_order(&summary);
    }

    #[test]
    fn test_textrank_empty_text() {
        let summarizer = TextRankSummarizer::default();
        let summary = summarizer.summarize("", 3).expect("summarize empty");
        assert!(summary.is_empty());
    }

    #[test]
    fn test_textrank_zero_sentences() {
        let summarizer = TextRankSummarizer::default();
        let summary = summarizer.summarize(TEXT, 0).expect("summarize zero");
        assert!(summary.is_empty());
    }

    #[test]
    fn test_textrank_more_than_available() {
        let summarizer = TextRankSummarizer::default();
        // Requesting 100 sentences from a 5-sentence text → every sentence.
        let summary = summarizer.summarize(TEXT, 100).expect("summarize");
        assert_eq!(summary, TEXT);
    }

    #[test]
    fn test_textrank_bm25() {
        let summarizer = TextRankSummarizer::new(0.85, 20, SentenceSimilarity::BM25);
        let summary = summarizer.summarize(TEXT, 2).expect("summarize bm25");
        assert_eq!(split_sentences(&summary).len(), 2, "summary: {:?}", summary);
        assert_in_document_order(&summary);
    }

    #[test]
    fn test_textrank_jaccard() {
        let summarizer = TextRankSummarizer::new(0.85, 20, SentenceSimilarity::Jaccard);
        let summary = summarizer.summarize(TEXT, 2).expect("summarize jaccard");
        assert_eq!(split_sentences(&summary).len(), 2, "summary: {:?}", summary);
        assert_in_document_order(&summary);
    }

    #[test]
    fn test_lead_k() {
        let summary = ExtractiveSummarizer::lead_k(TEXT, 2).expect("lead_k");
        assert_eq!(
            summary,
            "The quick brown fox jumps over the lazy dog. A fox is a cunning animal."
        );
    }

    #[test]
    fn test_lead_k_more_than_available() {
        let summary = ExtractiveSummarizer::lead_k(TEXT, 10).expect("lead_k");
        assert_eq!(summary, TEXT);
    }

    #[test]
    fn test_lead_k_zero_error() {
        let result = ExtractiveSummarizer::lead_k(TEXT, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_frequency_based() {
        let summary = ExtractiveSummarizer::frequency_based(TEXT, 2).expect("freq_based");
        assert_eq!(split_sentences(&summary).len(), 2, "summary: {:?}", summary);
        assert_in_document_order(&summary);
    }

    #[test]
    fn test_frequency_based_more_than_available() {
        let summary = ExtractiveSummarizer::frequency_based(TEXT, 10).expect("freq_based");
        assert_eq!(summary, TEXT);
    }

    #[test]
    fn test_frequency_based_zero_error_and_empty() {
        assert!(ExtractiveSummarizer::frequency_based(TEXT, 0).is_err());
        let empty = ExtractiveSummarizer::frequency_based("", 3).expect("freq_based empty");
        assert!(empty.is_empty());
    }

    #[test]
    fn test_jaccard_similarity() {
        let a = "the cat sat on the mat";
        let b = "the cat sat on the mat";
        let sim = jaccard(a, b);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_jaccard_partial_overlap() {
        // {a, b} vs {b, c}: one shared word out of three distinct words.
        let sim = jaccard("a b", "b c");
        assert!((sim - 1.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_jaccard_no_overlap() {
        let a = "hello world";
        let b = "foo bar baz";
        let sim = jaccard(a, b);
        assert!(sim < 0.01);
    }

    #[test]
    fn test_cosine_tfidf_no_overlap() {
        let sents = vec!["hello world".to_string(), "foo bar baz".to_string()];
        let idf = build_idf(&sents);
        assert_eq!(cosine_tfidf(&sents[0], &sents[1], &idf), 0.0);
    }

    #[test]
    fn test_build_idf() {
        let sents = vec!["the cat sat".to_string(), "the dog ran".to_string()];
        let idf = build_idf(&sents);
        // "the" appears in all sentences → low IDF
        // "cat" appears in 1 sentence → strictly higher IDF
        let idf_the = idf.get("the").copied().unwrap_or(0.0);
        let idf_cat = idf.get("cat").copied().unwrap_or(0.0);
        assert!(idf_cat > idf_the);
    }
}