Skip to main content

scirs2_text/
abstractive_summary.rs

1//! Abstractive summarization module
2//!
3//! Provides multi-document fusion, sentence compression, and enhanced
4//! centroid-based summarization with position bias and query-focused mode.
5//! Includes ROUGE-like evaluation metrics (ROUGE-N and ROUGE-L).
6//!
7//! # Structures
8//!
9//! - [`FusionSummarizer`] – multi-document fusion via semantic sentence clustering
10//! - [`CompressionSummarizer`] – sentence compression by dropping low-importance tokens
11//! - [`EnhancedCentroidSummarizer`] – centroid summarization with position bias and
12//!   optional query focus
13//!
14//! # Evaluation
15//!
16//! - [`rouge_n`] – ROUGE-N recall (n-gram overlap)
17//! - [`rouge_l`] – ROUGE-L (LCS-based recall)
18
19use crate::error::{Result, TextError};
20use crate::tokenize::{SentenceTokenizer, Tokenizer, WordTokenizer};
21use crate::vectorize::{TfidfVectorizer, Vectorizer};
22use scirs2_core::ndarray::{Array1, Array2};
23use std::collections::{HashMap, HashSet};
24
25// ---------------------------------------------------------------------------
26// Shared helpers
27// ---------------------------------------------------------------------------
28
/// Split `sentence` into lowercase word tokens.
///
/// Any non-alphanumeric character acts as a separator, so punctuation is
/// discarded and empty pieces between consecutive separators are dropped.
fn word_tokens(sentence: &str) -> Vec<String> {
    let is_separator = |c: char| !c.is_alphanumeric();
    sentence
        .split(is_separator)
        .filter_map(|piece| (!piece.is_empty()).then(|| piece.to_lowercase()))
        .collect()
}
37
38/// Cosine similarity between two ndarray row vectors from the same matrix.
39fn cosine_sim_rows(matrix: &Array2<f64>, i: usize, j: usize) -> f64 {
40    let cols = matrix.ncols();
41    let mut dot = 0.0_f64;
42    let mut n1 = 0.0_f64;
43    let mut n2 = 0.0_f64;
44    for c in 0..cols {
45        let a = matrix[[i, c]];
46        let b = matrix[[j, c]];
47        dot += a * b;
48        n1 += a * a;
49        n2 += b * b;
50    }
51    let denom = n1.sqrt() * n2.sqrt();
52    if denom == 0.0 {
53        0.0
54    } else {
55        dot / denom
56    }
57}
58
59/// Cosine similarity between an owned row vector and a centroid vector.
60fn cosine_sim_vec(row: &Array1<f64>, centroid: &Array1<f64>) -> f64 {
61    let dot = row.dot(centroid);
62    let n1 = row.dot(row).sqrt();
63    let n2 = centroid.dot(centroid).sqrt();
64    if n1 == 0.0 || n2 == 0.0 {
65        0.0
66    } else {
67        dot / (n1 * n2)
68    }
69}
70
71/// Build a TF-IDF matrix from a slice of sentence strings.
72/// Returns the (matrix, vectorizer) pair so callers can reuse the vocabulary.
73fn build_tfidf_matrix(sentences: &[String]) -> Result<(Array2<f64>, TfidfVectorizer)> {
74    if sentences.is_empty() {
75        return Err(TextError::InvalidInput(
76            "Cannot build TF-IDF matrix from empty sentence list".to_string(),
77        ));
78    }
79    let refs: Vec<&str> = sentences.iter().map(|s| s.as_str()).collect();
80    let mut vectorizer = TfidfVectorizer::default();
81    let matrix = vectorizer.fit_transform(&refs)?;
82    Ok((matrix, vectorizer))
83}
84
85// ---------------------------------------------------------------------------
86// ScoredSentence (local, richer than text_summarization::ScoredSentence)
87// ---------------------------------------------------------------------------
88
/// A sentence together with where it came from and how relevant it is.
#[derive(Debug, Clone)]
pub struct ScoredSentence {
    /// Raw sentence text.
    pub text: String,
    /// Zero-based position of the sentence, either within its document or
    /// within a global list of sentences spanning several documents.
    pub index: usize,
    /// Zero-based index of the document the sentence was taken from
    /// (meaningful for multi-document fusion).
    pub doc_index: usize,
    /// Relevance score; larger values mean more important sentences.
    pub score: f64,
}
101
102// ---------------------------------------------------------------------------
103// FusionSummarizer
104// ---------------------------------------------------------------------------
105
/// Multi-document fusion summarizer.
///
/// The pipeline is:
/// 1. Split every document into sentences and score each sentence by its mean
///    TF-IDF weight, using one TF-IDF model fitted over all documents.
/// 2. Cluster the sentences semantically with a cosine-similarity k-means variant.
/// 3. Take the highest-scoring sentence of each cluster, order those picks by
///    their position in the combined sentence list, and emit at most
///    `max_words` words of the result.
///
/// # Example
///
/// ```rust
/// use scirs2_text::abstractive_summary::FusionSummarizer;
///
/// let docs = vec![
///     "Rust is a systems programming language. It focuses on safety.",
///     "Safety is the primary goal of Rust. Memory safety without a GC.",
/// ];
/// let summarizer = FusionSummarizer::new(3);
/// let summary = summarizer.summarize(&docs, 50).unwrap();
/// assert!(!summary.is_empty());
/// ```
pub struct FusionSummarizer {
    /// Desired number of clusters (roughly the number of output sentences).
    n_clusters: usize,
    /// Maximum number of k-means assignment/update iterations.
    max_iter: usize,
    /// Intended minimum cosine similarity for assigning a sentence to a cluster.
    ///
    /// NOTE(review): this value is stored and can be set through
    /// `with_cluster_threshold`, but `cluster_sentences` does not currently read it.
    cluster_threshold: f64,
}
135
136impl FusionSummarizer {
137    /// Create a new `FusionSummarizer` with `n_clusters` output sentences.
138    pub fn new(n_clusters: usize) -> Self {
139        Self {
140            n_clusters: n_clusters.max(1),
141            max_iter: 30,
142            cluster_threshold: 0.1,
143        }
144    }
145
146    /// Override maximum clustering iterations.
147    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
148        self.max_iter = max_iter;
149        self
150    }
151
152    /// Override the minimum cosine similarity threshold used during clustering.
153    pub fn with_cluster_threshold(mut self, threshold: f64) -> Self {
154        self.cluster_threshold = threshold.clamp(0.0, 1.0);
155        self
156    }
157
158    /// Extract sentences from multiple documents, scoring each by TF-IDF.
159    ///
160    /// Returns a flat list of [`ScoredSentence`] items drawn from all documents.
161    pub fn extract_sentences(&self, documents: &[&str]) -> Vec<ScoredSentence> {
162        if documents.is_empty() {
163            return Vec::new();
164        }
165
166        let sentence_tokenizer = SentenceTokenizer::new();
167        let mut all_sentences: Vec<ScoredSentence> = Vec::new();
168        let mut global_index = 0usize;
169
170        // Collect all raw sentences first.
171        let mut raw_per_doc: Vec<Vec<String>> = Vec::new();
172        for doc in documents {
173            let sents = sentence_tokenizer
174                .tokenize(doc)
175                .unwrap_or_else(|_| vec![doc.to_string()]);
176            raw_per_doc.push(sents);
177        }
178
179        // Flatten for TF-IDF fitting across all documents.
180        let flat: Vec<String> = raw_per_doc.iter().flatten().cloned().collect();
181        if flat.is_empty() {
182            return Vec::new();
183        }
184
185        // Build TF-IDF on the full corpus.
186        let flat_refs: Vec<&str> = flat.iter().map(|s| s.as_str()).collect();
187        let mut vectorizer = TfidfVectorizer::default();
188        let tfidf = match vectorizer.fit_transform(&flat_refs) {
189            Ok(m) => m,
190            Err(_) => return Vec::new(),
191        };
192
193        let cols = tfidf.ncols();
194        let n = flat.len();
195
196        for (flat_idx, sentence) in flat.iter().enumerate() {
197            // Score = mean TF-IDF of the row.
198            let score = if cols > 0 {
199                let row_sum: f64 = (0..cols).map(|c| tfidf[[flat_idx, c]]).sum();
200                row_sum / cols as f64
201            } else {
202                0.0
203            };
204
205            // Determine which document this sentence came from.
206            let mut doc_index = 0usize;
207            let mut cumulative = 0usize;
208            for (di, sents) in raw_per_doc.iter().enumerate() {
209                if flat_idx < cumulative + sents.len() {
210                    doc_index = di;
211                    break;
212                }
213                cumulative += sents.len();
214            }
215
216            all_sentences.push(ScoredSentence {
217                text: sentence.clone(),
218                index: global_index,
219                doc_index,
220                score,
221            });
222            global_index += 1;
223        }
224
225        // Normalise scores to [0,1].
226        let max_score = all_sentences
227            .iter()
228            .map(|s| s.score)
229            .fold(0.0_f64, f64::max);
230        if max_score > 0.0 {
231            for s in &mut all_sentences {
232                s.score /= max_score;
233            }
234        }
235
236        all_sentences
237    }
238
239    /// Cluster sentences semantically using cosine similarity.
240    ///
241    /// Implements a greedy cosine-based k-means variant:
242    /// 1. Initialise cluster centroids with the highest-scoring sentences.
243    /// 2. Assign each sentence to its nearest centroid.
244    /// 3. Recompute centroids and repeat up to `max_iter` times.
245    ///
246    /// Returns a `Vec` of clusters, each being a `Vec<ScoredSentence>`.
247    pub fn cluster_sentences(
248        &self,
249        sentences: &[ScoredSentence],
250        n_clusters: usize,
251    ) -> Vec<Vec<ScoredSentence>> {
252        let k = n_clusters.min(sentences.len()).max(1);
253        if sentences.is_empty() {
254            return Vec::new();
255        }
256        if sentences.len() <= k {
257            // Each sentence is its own cluster.
258            return sentences.iter().map(|s| vec![s.clone()]).collect();
259        }
260
261        // Build TF-IDF matrix from sentence texts.
262        let texts: Vec<String> = sentences.iter().map(|s| s.text.clone()).collect();
263        let refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
264        let mut vectorizer = TfidfVectorizer::default();
265        let matrix = match vectorizer.fit_transform(&refs) {
266            Ok(m) => m,
267            Err(_) => {
268                // Fallback: put everything in one cluster.
269                return vec![sentences.to_vec()];
270            }
271        };
272
273        let n = sentences.len();
274        let cols = matrix.ncols();
275
276        // Choose initial centroids: the k highest-scoring sentences.
277        let mut sorted_indices: Vec<usize> = (0..n).collect();
278        sorted_indices.sort_by(|&a, &b| {
279            sentences[b]
280                .score
281                .partial_cmp(&sentences[a].score)
282                .unwrap_or(std::cmp::Ordering::Equal)
283        });
284        let centroid_indices: Vec<usize> = sorted_indices.into_iter().take(k).collect();
285
286        // Initialise centroid vectors (k x cols).
287        let mut centroids: Vec<Array1<f64>> = centroid_indices
288            .iter()
289            .map(|&ci| matrix.row(ci).to_owned())
290            .collect();
291
292        let mut assignments: Vec<usize> = vec![0; n];
293
294        for _iter in 0..self.max_iter {
295            let mut changed = false;
296
297            // Assignment step.
298            for i in 0..n {
299                let row = matrix.row(i).to_owned();
300                let best_cluster = centroids
301                    .iter()
302                    .enumerate()
303                    .map(|(ci, centroid)| (ci, cosine_sim_vec(&row, centroid)))
304                    .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
305                    .map(|(ci, _)| ci)
306                    .unwrap_or(0);
307
308                if assignments[i] != best_cluster {
309                    assignments[i] = best_cluster;
310                    changed = true;
311                }
312            }
313
314            if !changed {
315                break;
316            }
317
318            // Update step: recompute centroids as the mean of assigned sentences.
319            for ci in 0..k {
320                let members: Vec<usize> = (0..n).filter(|&i| assignments[i] == ci).collect();
321                if members.is_empty() {
322                    // Keep previous centroid.
323                    continue;
324                }
325                let mut new_centroid = Array1::zeros(cols);
326                for &mi in &members {
327                    new_centroid = new_centroid + matrix.row(mi).to_owned();
328                }
329                let count = members.len() as f64;
330                new_centroid.mapv_inplace(|v| v / count);
331                centroids[ci] = new_centroid;
332            }
333        }
334
335        // Build output clusters.
336        let mut clusters: Vec<Vec<ScoredSentence>> = vec![Vec::new(); k];
337        for (i, &ci) in assignments.iter().enumerate() {
338            clusters[ci].push(sentences[i].clone());
339        }
340
341        // Remove empty clusters that may arise from degenerate inputs.
342        clusters.retain(|c| !c.is_empty());
343        clusters
344    }
345
346    /// Generate a fused summary from clusters, limited to `max_words`.
347    ///
348    /// Picks the highest-scoring sentence from each cluster, then joins them
349    /// in order of cluster appearance (preserving reading flow).
350    pub fn generate_summary(&self, clusters: &[Vec<ScoredSentence>], max_words: usize) -> String {
351        if clusters.is_empty() {
352            return String::new();
353        }
354
355        // Pick best representative per cluster.
356        let mut representatives: Vec<&ScoredSentence> = clusters
357            .iter()
358            .filter_map(|cluster| {
359                cluster.iter().max_by(|a, b| {
360                    a.score
361                        .partial_cmp(&b.score)
362                        .unwrap_or(std::cmp::Ordering::Equal)
363                })
364            })
365            .collect();
366
367        // Sort representatives by their original index to preserve flow.
368        representatives.sort_by_key(|s| s.index);
369
370        // Concatenate up to max_words.
371        let mut words_used = 0usize;
372        let mut collected_words: Vec<&str> = Vec::new();
373
374        'outer: for rep in representatives {
375            let sentence_words: Vec<&str> = rep.text.split_whitespace().collect();
376            for word in &sentence_words {
377                if words_used >= max_words {
378                    break 'outer;
379                }
380                collected_words.push(word);
381                words_used += 1;
382            }
383        }
384
385        collected_words.join(" ")
386    }
387
388    /// Convenience method: extract + cluster + generate in one call.
389    pub fn summarize(&self, documents: &[&str], max_words: usize) -> Result<String> {
390        if documents.is_empty() {
391            return Ok(String::new());
392        }
393        let sentences = self.extract_sentences(documents);
394        if sentences.is_empty() {
395            return Ok(String::new());
396        }
397        let clusters = self.cluster_sentences(&sentences, self.n_clusters);
398        Ok(self.generate_summary(&clusters, max_words))
399    }
400}
401
402// ---------------------------------------------------------------------------
403// CompressionSummarizer
404// ---------------------------------------------------------------------------
405
/// Sentence compression by dropping low-importance tokens.
///
/// Each token is scored with a lightweight heuristic (see
/// [`CompressionSummarizer::importance_score`]). The heuristic combines
/// in-sentence term frequency, a length bonus and a capitalisation bonus, and
/// scales the result down sharply for stop words. No corpus-level statistics
/// are used.
///
/// # Example
///
/// ```rust
/// use scirs2_text::abstractive_summary::CompressionSummarizer;
///
/// let cs = CompressionSummarizer::new();
/// let compressed = cs.compress_sentence("The very quick brown fox jumped lazily", 0.6);
/// assert!(!compressed.is_empty());
/// ```
pub struct CompressionSummarizer {
    /// Stop words whose importance score is heavily penalised. Lookups use the
    /// lowercased token, so entries should be lowercase.
    stop_words: HashSet<String>,
}

impl CompressionSummarizer {
    /// Create a `CompressionSummarizer` with the built-in English stop-word list.
    pub fn new() -> Self {
        let raw = [
            "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", "have", "has",
            "had", "do", "does", "did", "will", "would", "could", "should", "may", "might",
            "shall", "can", "and", "but", "or", "nor", "for", "yet", "so", "in", "on", "at", "to",
            "from", "by", "with", "of", "about", "as", "into", "through", "during", "before",
            "after", "above", "below", "between", "each", "all", "both", "very", "just", "too",
            "also", "then", "than", "that", "this", "these", "those", "i", "me", "my", "we", "our",
            "you", "your", "he", "she", "it", "its", "they", "them", "their", "what", "which",
            "who", "whom", "not", "no",
        ];
        Self {
            stop_words: raw.iter().map(|w| w.to_string()).collect(),
        }
    }

    /// Create a `CompressionSummarizer` with a custom stop-word list.
    ///
    /// Tokens are lowercased before lookup, so entries should be lowercase.
    pub fn with_stop_words(stop_words: HashSet<String>) -> Self {
        Self { stop_words }
    }

    /// Compute the importance score of `token` given its sentence context.
    ///
    /// `sentence_tokens` is the token list of the whole sentence. `token` is
    /// compared against it case-insensitively.
    ///
    /// `score = (tf + length_bonus + capitalisation_bonus) × stop_penalty`, where:
    /// - `tf` is the fraction of `sentence_tokens` equal to `token`.
    /// - `length_bonus = 0.5 × (1 + min(chars / 10, 1))`, which lies in `[0.5, 1.0]`.
    /// - `capitalisation_bonus` is 0.3 if `token` starts with an uppercase
    ///   character (a proper-noun heuristic) and 0 otherwise. Callers decide whether
    ///   a token's casing is meaningful.
    /// - `stop_penalty` is 0.1 for stop words and 1.0 otherwise.
    ///
    /// Returns `0.0` when `sentence_tokens` is empty.
    pub fn importance_score(&self, token: &str, sentence_tokens: &[String]) -> f64 {
        if sentence_tokens.is_empty() {
            return 0.0;
        }
        let token_lower = token.to_lowercase();

        // Term frequency.
        let tf = sentence_tokens
            .iter()
            .filter(|t| t.to_lowercase() == token_lower)
            .count() as f64
            / sentence_tokens.len() as f64;

        // Length bonus in [0.5, 1.0]; counts characters, not UTF-8 bytes.
        let len_bonus = 0.5 * (1.0 + (token.chars().count() as f64 / 10.0).min(1.0));

        // Capitalisation bonus (proper-noun heuristic).
        let cap_bonus = if token.chars().next().is_some_and(char::is_uppercase) {
            0.3
        } else {
            0.0
        };

        // The penalty scales the whole score. Applied to `tf` alone, long stop
        // words (e.g. "between") outranked short content words. Applied to the
        // whole score, a stop word stays at <= 0.23 and a content word at >= 0.5.
        let stop_penalty = if self.stop_words.contains(&token_lower) {
            0.1
        } else {
            1.0
        };

        (tf + len_bonus + cap_bonus) * stop_penalty
    }

    /// Compress `sentence` by keeping the `ratio` fraction of its
    /// whitespace-separated tokens with the highest importance scores.
    ///
    /// `ratio` is clamped to `[0.01, 1.0]`, and at least one token is always kept.
    /// A ratio of 1.0 keeps every token. Kept tokens appear in their original
    /// order, with their original punctuation. Ties are resolved in favour of
    /// earlier tokens.
    ///
    /// Leading and trailing punctuation is stripped before scoring. The first
    /// token is scored in lowercase, because sentence-initial capitalisation says
    /// nothing about proper nouns. Later tokens keep their casing so the
    /// capitalisation bonus can apply.
    ///
    /// Returns an empty string if the sentence has no words.
    pub fn compress_sentence(&self, sentence: &str, ratio: f64) -> String {
        let ratio = ratio.clamp(0.01, 1.0);

        // Preserve original whitespace-split tokens for output.
        let original_tokens: Vec<&str> = sentence.split_whitespace().collect();
        if original_tokens.is_empty() {
            return String::new();
        }

        // Punctuation-stripped tokens, keeping original casing.
        let stripped: Vec<&str> = original_tokens
            .iter()
            .map(|t| t.trim_matches(|c: char| !c.is_alphanumeric()))
            .collect();
        // Lowercased context used for term-frequency counting.
        let norm_tokens: Vec<String> = stripped.iter().map(|t| t.to_lowercase()).collect();

        let n = original_tokens.len();
        let keep_count = ((n as f64 * ratio).ceil() as usize).clamp(1, n);

        // Score each token.
        let mut scored: Vec<(usize, f64)> = stripped
            .iter()
            .enumerate()
            .map(|(i, &t)| {
                let token = if i == 0 { norm_tokens[0].as_str() } else { t };
                (i, self.importance_score(token, &norm_tokens))
            })
            .collect();

        // Stable sort, descending by score: equal scores keep left-to-right order.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Collect the top-k indices, then restore original order.
        let mut keep_indices: Vec<usize> =
            scored.iter().take(keep_count).map(|&(i, _)| i).collect();
        keep_indices.sort_unstable();

        keep_indices
            .iter()
            .map(|&i| original_tokens[i])
            .collect::<Vec<_>>()
            .join(" ")
    }
}

impl Default for CompressionSummarizer {
    /// Equivalent to [`CompressionSummarizer::new`].
    fn default() -> Self {
        Self::new()
    }
}
552
553// ---------------------------------------------------------------------------
554// EnhancedCentroidSummarizer
555// ---------------------------------------------------------------------------
556
/// Centroid-based summarizer with position bias and optional query focus.
///
/// Compared with a plain centroid summarizer, this variant adds:
/// - **Position bias** – earlier sentences receive a configurable score bonus.
/// - **Query focus** – sentences can additionally be ranked by their cosine
///   similarity to a query vector, not only to the document centroid.
///
/// # Example
///
/// ```rust
/// use scirs2_text::abstractive_summary::EnhancedCentroidSummarizer;
///
/// let summarizer = EnhancedCentroidSummarizer::new(2)
///     .with_position_bias(0.3)
///     .with_redundancy_threshold(0.85);
///
/// let text = "Natural language processing is a field of AI. \
///             It allows computers to understand human language. \
///             NLP is used in chatbots, translation, and search engines.";
/// let summary = summarizer.summarize(text).unwrap();
/// assert!(!summary.is_empty());
/// ```
pub struct EnhancedCentroidSummarizer {
    /// Upper bound on the number of sentences placed in the summary.
    num_sentences: usize,
    /// Mean TF-IDF weight a term must exceed to be kept in the document centroid.
    topic_threshold: f64,
    /// A candidate more cosine-similar than this to an already-chosen sentence is skipped.
    redundancy_threshold: f64,
    /// Weight of the position bonus added to the content score (0 disables it).
    position_bias: f64,
}
586
587impl EnhancedCentroidSummarizer {
588    /// Create a new `EnhancedCentroidSummarizer` extracting up to `num_sentences` sentences.
589    pub fn new(num_sentences: usize) -> Self {
590        Self {
591            num_sentences: num_sentences.max(1),
592            topic_threshold: 0.1,
593            redundancy_threshold: 0.95,
594            position_bias: 0.2,
595        }
596    }
597
598    /// Set the position bias weight (0.0 = off, 1.0 = strong bias towards early sentences).
599    pub fn with_position_bias(mut self, bias: f64) -> Self {
600        self.position_bias = bias.clamp(0.0, 1.0);
601        self
602    }
603
604    /// Set the TF-IDF topic threshold (terms below this weight are zeroed in the centroid).
605    pub fn with_topic_threshold(mut self, threshold: f64) -> Self {
606        self.topic_threshold = threshold.clamp(0.0, 1.0);
607        self
608    }
609
610    /// Set the redundancy threshold: sentences more similar than this value are excluded.
611    pub fn with_redundancy_threshold(mut self, threshold: f64) -> Self {
612        self.redundancy_threshold = threshold.clamp(0.0, 1.0);
613        self
614    }
615
616    /// Standard summarization (not query-focused).
617    pub fn summarize(&self, text: &str) -> Result<String> {
618        self.summarize_internal(text, None)
619    }
620
621    /// Query-focused summarization.
622    ///
623    /// Sentences are ranked by a linear combination of their similarity to the
624    /// document centroid and their similarity to the query vector.
625    ///
626    /// # Arguments
627    ///
628    /// * `document` – the source document text.
629    /// * `query` – a short query string describing the desired focus.
630    /// * `max_sentences` – maximum number of sentences to return.
631    pub fn summarize_query_focused(
632        &self,
633        document: &str,
634        query: &str,
635        max_sentences: usize,
636    ) -> Result<String> {
637        let override_self = EnhancedCentroidSummarizer {
638            num_sentences: max_sentences.max(1),
639            ..*self
640        };
641        override_self.summarize_internal(document, Some(query))
642    }
643
644    fn summarize_internal(&self, text: &str, query: Option<&str>) -> Result<String> {
645        let sentence_tokenizer = SentenceTokenizer::new();
646        let sentences: Vec<String> = sentence_tokenizer.tokenize(text)?;
647
648        if sentences.is_empty() {
649            return Ok(String::new());
650        }
651        if sentences.len() <= self.num_sentences {
652            return Ok(text.to_string());
653        }
654
655        // Build TF-IDF vectors for all sentences.
656        let sentence_refs: Vec<&str> = sentences.iter().map(|s| s.as_str()).collect();
657        let mut vectorizer = TfidfVectorizer::default();
658        let tfidf = vectorizer.fit_transform(&sentence_refs)?;
659
660        // Compute document centroid.
661        let centroid = self.compute_centroid(&tfidf);
662
663        // Optionally compute query vector by transforming the query sentence.
664        let query_vec: Option<Array1<f64>> = if let Some(q) = query {
665            vectorizer.transform_batch(&[q]).ok().map(|m| {
666                // The query may produce a 1-row matrix; take row 0.
667                m.row(0).to_owned()
668            })
669        } else {
670            None
671        };
672
673        // Score each sentence.
674        let n = sentences.len();
675        let mut scored: Vec<(usize, f64)> = (0..n)
676            .map(|i| {
677                let row = tfidf.row(i).to_owned();
678                let centroid_sim = cosine_sim_vec(&row, &centroid);
679                let query_sim = query_vec
680                    .as_ref()
681                    .map(|qv| cosine_sim_vec(&row, qv))
682                    .unwrap_or(0.0);
683                // Combine centroid similarity with query similarity (50/50 if query provided).
684                let content_score = if query_vec.is_some() {
685                    0.5 * centroid_sim + 0.5 * query_sim
686                } else {
687                    centroid_sim
688                };
689                // Position bonus: exponential decay from sentence 0.
690                let pos_bonus = (-0.5 * i as f64 / n as f64).exp() * self.position_bias;
691                (i, content_score + pos_bonus)
692            })
693            .collect();
694
695        // Sort descending by score.
696        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
697
698        // Greedy selection avoiding redundancy.
699        let mut selected: Vec<usize> = Vec::new();
700        for (idx, _score) in &scored {
701            if selected.len() >= self.num_sentences {
702                break;
703            }
704            let redundant = selected
705                .iter()
706                .any(|&si| cosine_sim_rows(&tfidf, *idx, si) > self.redundancy_threshold);
707            if !redundant {
708                selected.push(*idx);
709            }
710        }
711
712        // Restore original order.
713        selected.sort_unstable();
714
715        let summary = selected
716            .iter()
717            .map(|&i| sentences[i].as_str())
718            .collect::<Vec<_>>()
719            .join(" ");
720
721        Ok(summary)
722    }
723
724    /// Compute the document centroid as the mean TF-IDF vector, zeroing terms
725    /// below `topic_threshold`.
726    fn compute_centroid(&self, tfidf: &Array2<f64>) -> Array1<f64> {
727        let mean = tfidf
728            .mean_axis(scirs2_core::ndarray::Axis(0))
729            .unwrap_or_else(|| Array1::zeros(tfidf.ncols()));
730
731        mean.mapv(|v| if v > self.topic_threshold { v } else { 0.0 })
732    }
733}
734
735// ---------------------------------------------------------------------------
736// ROUGE evaluation metrics
737// ---------------------------------------------------------------------------
738
739/// Compute ROUGE-N recall between a `hypothesis` and a `reference` string.
740///
741/// ROUGE-N is the fraction of reference n-grams that appear in the hypothesis.
742///
743/// # Arguments
744///
745/// * `hypothesis` – generated summary text.
746/// * `reference` – gold-standard reference text.
747/// * `n` – n-gram order (1 = unigrams, 2 = bigrams, …).
748///
749/// # Returns
750///
751/// A recall value in [0.0, 1.0]. Returns `0.0` when the reference contains no
752/// n-grams (e.g. `n` larger than the reference length).
753///
754/// # Example
755///
756/// ```rust
757/// use scirs2_text::abstractive_summary::rouge_n;
758///
759/// let recall = rouge_n("the cat sat on the mat", "the cat sat", 1);
760/// assert!(recall > 0.5);
761/// ```
762pub fn rouge_n(hypothesis: &str, reference: &str, n: usize) -> f64 {
763    if n == 0 {
764        return 0.0;
765    }
766    let hyp_tokens = word_tokens(hypothesis);
767    let ref_tokens = word_tokens(reference);
768
769    if ref_tokens.len() < n {
770        return 0.0;
771    }
772
773    // Build n-gram counts for reference.
774    let ref_ngrams = count_ngrams(&ref_tokens, n);
775    if ref_ngrams.is_empty() {
776        return 0.0;
777    }
778    let ref_total: usize = ref_ngrams.values().sum();
779
780    // Build n-gram counts for hypothesis.
781    let hyp_ngrams = count_ngrams(&hyp_tokens, n);
782
783    // Clipped overlap count.
784    let overlap: usize = ref_ngrams
785        .iter()
786        .map(|(gram, &ref_count)| {
787            let hyp_count = hyp_ngrams.get(gram).copied().unwrap_or(0);
788            hyp_count.min(ref_count)
789        })
790        .sum();
791
792    overlap as f64 / ref_total as f64
793}
794
/// Count how often each contiguous `n`-token window occurs in `tokens`.
///
/// Returns an empty map when `tokens` is shorter than `n`.
fn count_ngrams(tokens: &[String], n: usize) -> HashMap<Vec<String>, usize> {
    if n > tokens.len() {
        return HashMap::new();
    }
    let last_start = tokens.len() - n;
    (0..=last_start).fold(HashMap::new(), |mut counts, start| {
        *counts.entry(tokens[start..start + n].to_vec()).or_insert(0) += 1;
        counts
    })
}
807
808/// Compute ROUGE-L recall based on the Longest Common Subsequence (LCS).
809///
810/// ROUGE-L is defined as `LCS(hypothesis, reference) / |reference|` in terms of
811/// token counts.
812///
813/// # Arguments
814///
815/// * `hypothesis` – generated summary text.
816/// * `reference` – gold-standard reference text.
817///
818/// # Returns
819///
820/// A recall value in [0.0, 1.0]. Returns `0.0` for empty inputs.
821///
822/// # Example
823///
824/// ```rust
825/// use scirs2_text::abstractive_summary::rouge_l;
826///
827/// let score = rouge_l("the cat sat", "the cat sat on the mat");
828/// assert!(score > 0.4);
829/// ```
830pub fn rouge_l(hypothesis: &str, reference: &str) -> f64 {
831    let hyp_tokens = word_tokens(hypothesis);
832    let ref_tokens = word_tokens(reference);
833
834    if ref_tokens.is_empty() {
835        return 0.0;
836    }
837
838    let lcs_len = lcs_length(&hyp_tokens, &ref_tokens);
839    lcs_len as f64 / ref_tokens.len() as f64
840}
841
/// Length of the Longest Common Subsequence of two token sequences.
///
/// Classic O(|a|·|b|) dynamic programming, kept in a single row of `|b| + 1`
/// cells. `diag` carries the previous row's value one column to the left.
fn lcs_length(a: &[String], b: &[String]) -> usize {
    if a.is_empty() || b.is_empty() {
        return 0;
    }

    let mut row = vec![0usize; b.len() + 1];
    for x in a {
        let mut diag = 0usize;
        for (j, y) in b.iter().enumerate() {
            // `row[j + 1]` still holds the previous row's value here.
            let above = row[j + 1];
            row[j + 1] = if x == y { diag + 1 } else { above.max(row[j]) };
            diag = above;
        }
    }
    row[b.len()]
}
870
871// ---------------------------------------------------------------------------
872// Tests
873// ---------------------------------------------------------------------------
874
#[cfg(test)]
mod tests {
    // Unit tests for the fusion, compression and enhanced-centroid summarizers
    // and for the ROUGE-N / ROUGE-L metrics defined in this module.
    use super::*;

    // Two short, topically related documents about Rust, used as
    // multi-document input for the fusion summarizer.
    const MULTI_DOC_A: &str =
        "Rust is a systems programming language. It focuses on safety and performance.";
    const MULTI_DOC_B: &str =
        "Memory safety without a garbage collector is a key goal of Rust. The language also \
         emphasises zero-cost abstractions.";
    // A five-sentence passage used for single-document summarization tests.
    const LONG_TEXT: &str = "Natural language processing is a field of artificial intelligence. \
        It allows computers to understand and generate human language. \
        Applications include machine translation, chatbots, and sentiment analysis. \
        Deep learning has greatly advanced NLP in recent years. \
        Transformer models such as BERT and GPT are state-of-the-art.";

    // -- FusionSummarizer --

    #[test]
    fn test_fusion_extract_sentences_nonempty() {
        let docs = vec![MULTI_DOC_A, MULTI_DOC_B];
        let fs = FusionSummarizer::new(2);
        let sents = fs.extract_sentences(&docs);
        assert!(!sents.is_empty());
        // Every score should be in [0, 1].
        // NOTE(review): the 1.001 upper bound presumably tolerates
        // floating-point rounding just above 1.0.
        for s in &sents {
            assert!(
                (0.0..=1.001).contains(&s.score),
                "score out of range: {}",
                s.score
            );
        }
    }

    #[test]
    fn test_fusion_extract_empty_docs() {
        // No documents in → no sentences out.
        let fs = FusionSummarizer::new(2);
        let sents = fs.extract_sentences(&[]);
        assert!(sents.is_empty());
    }

    #[test]
    fn test_fusion_cluster_basic() {
        let docs = vec![MULTI_DOC_A, MULTI_DOC_B];
        let fs = FusionSummarizer::new(2);
        let sents = fs.extract_sentences(&docs);
        let clusters = fs.cluster_sentences(&sents, 2);
        assert!(!clusters.is_empty());
        // All sentences accounted for.
        let total: usize = clusters.iter().map(|c| c.len()).sum();
        assert_eq!(total, sents.len());
    }

    #[test]
    fn test_fusion_generate_summary_respects_max_words() {
        // The word budget passed to `generate_summary` must not be exceeded
        // (counted as whitespace-separated words).
        let docs = vec![MULTI_DOC_A, MULTI_DOC_B];
        let fs = FusionSummarizer::new(2);
        let sents = fs.extract_sentences(&docs);
        let clusters = fs.cluster_sentences(&sents, 2);
        let summary = fs.generate_summary(&clusters, 10);
        let words: usize = summary.split_whitespace().count();
        assert!(words <= 10, "Expected ≤10 words, got {}", words);
    }

    #[test]
    fn test_fusion_summarize_end_to_end() {
        let docs = vec![MULTI_DOC_A, MULTI_DOC_B];
        let fs = FusionSummarizer::new(2);
        let summary = fs.summarize(&docs, 60).expect("summarize should succeed");
        assert!(!summary.is_empty());
    }

    #[test]
    fn test_fusion_single_document() {
        // Fusion over a single document should still produce output.
        let docs = vec![LONG_TEXT];
        let fs = FusionSummarizer::new(2);
        let summary = fs.summarize(&docs, 80).expect("should succeed");
        assert!(!summary.is_empty());
    }

    // -- CompressionSummarizer --

    #[test]
    fn test_compression_basic() {
        // Compression never adds words and leaves a non-empty result here.
        let cs = CompressionSummarizer::new();
        let sentence = "The very quick brown fox jumped lazily over the fence";
        let compressed = cs.compress_sentence(sentence, 0.5);
        let orig_words: usize = sentence.split_whitespace().count();
        let comp_words: usize = compressed.split_whitespace().count();
        assert!(comp_words <= orig_words);
        assert!(!compressed.is_empty());
    }

    #[test]
    fn test_compression_ratio_one_keeps_all() {
        // A ratio of 1.0 is expected to keep every word.
        let cs = CompressionSummarizer::new();
        let sentence = "Hello world this is a test sentence";
        let compressed = cs.compress_sentence(sentence, 1.0);
        let orig_words = sentence.split_whitespace().count();
        let comp_words = compressed.split_whitespace().count();
        assert_eq!(comp_words, orig_words);
    }

    #[test]
    fn test_compression_empty_sentence() {
        let cs = CompressionSummarizer::new();
        let result = cs.compress_sentence("", 0.5);
        assert!(result.is_empty());
    }

    #[test]
    fn test_compression_importance_stop_word_lower() {
        let cs = CompressionSummarizer::new();
        let tokens: Vec<String> = vec!["the".to_string(), "quick".to_string(), "fox".to_string()];
        let stop_score = cs.importance_score("the", &tokens);
        let content_score = cs.importance_score("fox", &tokens);
        assert!(
            content_score > stop_score,
            "Content word should score higher than stop word"
        );
    }

    // -- EnhancedCentroidSummarizer --

    #[test]
    fn test_enhanced_centroid_basic() {
        // Selecting 2 of 5 sentences must yield something shorter than the input.
        let s = EnhancedCentroidSummarizer::new(2);
        let summary = s.summarize(LONG_TEXT).expect("should succeed");
        assert!(!summary.is_empty());
        assert!(summary.len() < LONG_TEXT.len());
    }

    #[test]
    fn test_enhanced_centroid_short_text() {
        // Asking for more sentences than the text contains returns it unchanged.
        let s = EnhancedCentroidSummarizer::new(5);
        let text = "One sentence only.";
        let summary = s.summarize(text).expect("should succeed");
        assert_eq!(summary, text);
    }

    #[test]
    fn test_enhanced_centroid_empty() {
        let s = EnhancedCentroidSummarizer::new(2);
        let summary = s.summarize("").expect("should succeed");
        assert!(summary.is_empty());
    }

    #[test]
    fn test_enhanced_centroid_query_focused() {
        let s = EnhancedCentroidSummarizer::new(2);
        let summary = s
            .summarize_query_focused(LONG_TEXT, "transformer models BERT GPT", 2)
            .expect("should succeed");
        assert!(!summary.is_empty());
    }

    #[test]
    fn test_enhanced_centroid_query_focused_max_sentences() {
        let s = EnhancedCentroidSummarizer::new(2);
        let summary = s
            .summarize_query_focused(LONG_TEXT, "deep learning", 1)
            .expect("should succeed");
        // Should return at most 1 sentence.
        let sent_tok = SentenceTokenizer::new();
        let sents = sent_tok.tokenize(&summary).expect("ok");
        assert!(sents.len() <= 1);
    }

    // -- ROUGE-N --

    #[test]
    fn test_rouge1_perfect_match() {
        let recall = rouge_n("the cat sat", "the cat sat", 1);
        assert!((recall - 1.0).abs() < 1e-9, "Expected 1.0, got {recall}");
    }

    #[test]
    fn test_rouge1_partial_overlap() {
        let recall = rouge_n("cat sat", "the cat sat on the mat", 1);
        // 2 out of 6 reference unigrams matched: 2/6 ≈ 0.333
        assert!((recall - 2.0 / 6.0).abs() < 1e-9, "Got {recall}");
    }

    #[test]
    fn test_rouge2_basic() {
        let recall = rouge_n("the cat sat on the mat", "the cat sat on the mat", 2);
        assert!((recall - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_rouge_n_zero_n() {
        // n = 0 is degenerate and scores 0.0.
        assert_eq!(rouge_n("anything", "reference", 0), 0.0);
    }

    #[test]
    fn test_rouge_n_empty_reference() {
        assert_eq!(rouge_n("hypothesis", "", 1), 0.0);
    }

    #[test]
    fn test_rouge_n_empty_hypothesis() {
        // No n-grams match → recall = 0.
        assert_eq!(rouge_n("", "the cat sat", 1), 0.0);
    }

    // -- ROUGE-L --

    #[test]
    fn test_rouge_l_perfect_match() {
        let score = rouge_l("the cat sat", "the cat sat");
        assert!((score - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_rouge_l_partial() {
        // LCS("cat sat", "the cat sat on the mat") = "cat sat" length 2 → 2/6
        let score = rouge_l("cat sat", "the cat sat on the mat");
        assert!((score - 2.0 / 6.0).abs() < 1e-9, "Got {score}");
    }

    #[test]
    fn test_rouge_l_empty_reference() {
        assert_eq!(rouge_l("hypothesis", ""), 0.0);
    }

    #[test]
    fn test_rouge_l_empty_hypothesis() {
        assert_eq!(rouge_l("", "reference text"), 0.0);
    }

    #[test]
    fn test_lcs_symmetric() {
        // Checks only that argument order does not matter, not the LCS value.
        let a = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let b = vec!["b".to_string(), "c".to_string(), "d".to_string()];
        let lcs_ab = lcs_length(&a, &b);
        let lcs_ba = lcs_length(&b, &a);
        assert_eq!(lcs_ab, lcs_ba);
    }
}