1use crate::error::{Result, TextError};
15use crate::tokenize::{SentenceTokenizer, Tokenizer, WordTokenizer};
16use crate::vectorize::{TfidfVectorizer, Vectorizer};
17use scirs2_core::ndarray::{Array1, Array2, Axis};
18use std::collections::{HashMap, HashSet};
19
/// Strategy used to score sentences for extractive summarization.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SummarizationMethod {
    /// Graph-based ranking: sentences are nodes, TF-IDF cosine similarity
    /// forms the edges, and PageRank produces the scores.
    TextRank,
    /// Heuristic scoring based on sentence position (lead bias, plus bonuses
    /// for the first and last sentences) scaled by sentence length.
    Position,
    /// Scores each sentence by its mean TF-IDF weight across the vocabulary.
    TfIdf,
    /// Weighted combination of all three methods; the weights are normalised
    /// so only their relative magnitudes matter.
    Ensemble {
        /// Relative weight of the TextRank score.
        textrank_weight: f64,
        /// Relative weight of the position score.
        position_weight: f64,
        /// Relative weight of the TF-IDF score.
        tfidf_weight: f64,
    },
}
43
/// A sentence together with its position in the source text and its
/// relevance score under some [`SummarizationMethod`].
#[derive(Debug, Clone)]
pub struct ScoredSentence {
    /// The sentence text as produced by the sentence tokenizer.
    pub text: String,
    /// Zero-based index of the sentence in the original document order.
    pub index: usize,
    /// Relevance score; higher means more likely to appear in the summary.
    pub score: f64,
}
54
55pub fn summarize(text: &str, ratio: f64, method: SummarizationMethod) -> Result<String> {
71 if text.trim().is_empty() {
72 return Ok(String::new());
73 }
74
75 let clamped_ratio = ratio.clamp(0.0, 1.0);
76
77 let sentence_tokenizer = SentenceTokenizer::new();
78 let sentences: Vec<String> = sentence_tokenizer.tokenize(text)?;
79
80 if sentences.is_empty() {
81 return Ok(String::new());
82 }
83
84 let n_select = (sentences.len() as f64 * clamped_ratio).ceil().max(1.0) as usize;
85
86 if n_select >= sentences.len() {
87 return Ok(text.to_string());
88 }
89
90 let scored = match method {
91 SummarizationMethod::TextRank => score_textrank(&sentences)?,
92 SummarizationMethod::Position => score_position(&sentences),
93 SummarizationMethod::TfIdf => score_tfidf(&sentences)?,
94 SummarizationMethod::Ensemble {
95 textrank_weight,
96 position_weight,
97 tfidf_weight,
98 } => score_ensemble(&sentences, textrank_weight, position_weight, tfidf_weight)?,
99 };
100
101 let mut top: Vec<ScoredSentence> = scored;
103 top.sort_by(|a, b| {
104 b.score
105 .partial_cmp(&a.score)
106 .unwrap_or(std::cmp::Ordering::Equal)
107 });
108 top.truncate(n_select);
109
110 top.sort_by_key(|s| s.index);
112
113 let summary = top
114 .iter()
115 .map(|s| s.text.clone())
116 .collect::<Vec<_>>()
117 .join(" ");
118
119 Ok(summary)
120}
121
122pub fn score_textrank(sentences: &[String]) -> Result<Vec<ScoredSentence>> {
129 let n = sentences.len();
130 if n == 0 {
131 return Ok(Vec::new());
132 }
133 if n == 1 {
134 return Ok(vec![ScoredSentence {
135 text: sentences[0].clone(),
136 index: 0,
137 score: 1.0,
138 }]);
139 }
140
141 let similarity_matrix = build_similarity_matrix(sentences)?;
142 let scores = pagerank(&similarity_matrix, 0.85, 100, 1e-5)?;
143
144 Ok(sentences
145 .iter()
146 .enumerate()
147 .map(|(i, s)| ScoredSentence {
148 text: s.clone(),
149 index: i,
150 score: scores[i],
151 })
152 .collect())
153}
154
155fn build_similarity_matrix(sentences: &[String]) -> Result<Array2<f64>> {
157 let refs: Vec<&str> = sentences.iter().map(|s| s.as_str()).collect();
158 let mut vectorizer = TfidfVectorizer::default();
159 vectorizer.fit(&refs)?;
160 let tfidf = vectorizer.transform_batch(&refs)?;
161
162 let n = sentences.len();
163 let mut matrix = Array2::zeros((n, n));
164
165 for i in 0..n {
166 for j in (i + 1)..n {
167 let sim = cosine_sim_rows(&tfidf, i, j);
168 matrix[[i, j]] = sim;
169 matrix[[j, i]] = sim;
170 }
171 }
172
173 Ok(matrix)
174}
175
176fn cosine_sim_rows(matrix: &Array2<f64>, i: usize, j: usize) -> f64 {
178 let cols = matrix.ncols();
179 let mut dot = 0.0_f64;
180 let mut n1 = 0.0_f64;
181 let mut n2 = 0.0_f64;
182
183 for c in 0..cols {
184 let a = matrix[[i, c]];
185 let b = matrix[[j, c]];
186 dot += a * b;
187 n1 += a * a;
188 n2 += b * b;
189 }
190
191 let denom = n1.sqrt() * n2.sqrt();
192 if denom == 0.0 {
193 0.0
194 } else {
195 dot / denom
196 }
197}
198
/// Run power-iteration PageRank over a weighted adjacency `matrix`.
///
/// `damping` is the usual PageRank damping factor, `max_iter` caps the
/// number of power iterations, and `threshold` is the L1 convergence
/// tolerance between successive score vectors.
///
/// NOTE(review): `n` must be non-zero or the initial `1.0 / n` divides by
/// zero — the callers in this file (`score_textrank`) guard the empty case
/// before calling.
fn pagerank(
    matrix: &Array2<f64>,
    damping: f64,
    max_iter: usize,
    threshold: f64,
) -> Result<Vec<f64>> {
    let n = matrix.nrows();
    // Uniform initial distribution.
    let mut scores = vec![1.0 / n as f64; n];

    // Row-normalise so each node distributes its weight proportionally.
    // Rows that sum to zero (isolated nodes) are left as-is: they emit no
    // weight, and receive only the (1 - damping) teleport share below.
    let mut norm_matrix = matrix.clone();
    for i in 0..n {
        let row_sum: f64 = (0..n).map(|j| matrix[[i, j]]).sum();
        if row_sum > 0.0 {
            for j in 0..n {
                norm_matrix[[i, j]] = matrix[[i, j]] / row_sum;
            }
        }
    }

    for _ in 0..max_iter {
        // Teleport term: (1 - d) / n for every node.
        let mut new_scores = vec![(1.0 - damping) / n as f64; n];

        // Note the transposed index `[[j, i]]`: node i accumulates the
        // weight that each node j sends along its outgoing edge j -> i.
        for i in 0..n {
            for j in 0..n {
                new_scores[i] += damping * norm_matrix[[j, i]] * scores[j];
            }
        }

        // L1 distance between successive iterates decides convergence.
        let diff: f64 = scores
            .iter()
            .zip(new_scores.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();

        scores = new_scores;
        if diff < threshold {
            break;
        }
    }

    Ok(scores)
}
243
244pub fn score_position(sentences: &[String]) -> Vec<ScoredSentence> {
251 let n = sentences.len();
252 if n == 0 {
253 return Vec::new();
254 }
255
256 sentences
257 .iter()
258 .enumerate()
259 .map(|(i, s)| {
260 let position_score = if n == 1 {
262 1.0
263 } else {
264 let lead_score = 1.0 - (i as f64 / n as f64);
265 let conclusion_bonus = if i == n - 1 { 0.2 } else { 0.0 };
266 let first_bonus = if i == 0 { 0.15 } else { 0.0 };
268 lead_score + conclusion_bonus + first_bonus
269 };
270
271 let word_count = s.split_whitespace().count() as f64;
273 let length_factor = (word_count.ln() + 1.0).min(3.0) / 3.0;
274
275 ScoredSentence {
276 text: s.clone(),
277 index: i,
278 score: position_score * length_factor,
279 }
280 })
281 .collect()
282}
283
284pub fn score_tfidf(sentences: &[String]) -> Result<Vec<ScoredSentence>> {
292 let n = sentences.len();
293 if n == 0 {
294 return Ok(Vec::new());
295 }
296 if n == 1 {
297 return Ok(vec![ScoredSentence {
298 text: sentences[0].clone(),
299 index: 0,
300 score: 1.0,
301 }]);
302 }
303
304 let refs: Vec<&str> = sentences.iter().map(|s| s.as_str()).collect();
305 let mut vectorizer = TfidfVectorizer::default();
306 let tfidf = vectorizer.fit_transform(&refs)?;
307
308 let cols = tfidf.ncols();
309 if cols == 0 {
310 return Ok(sentences
311 .iter()
312 .enumerate()
313 .map(|(i, s)| ScoredSentence {
314 text: s.clone(),
315 index: i,
316 score: 0.0,
317 })
318 .collect());
319 }
320
321 Ok(sentences
322 .iter()
323 .enumerate()
324 .map(|(i, s)| {
325 let row_sum: f64 = (0..cols).map(|c| tfidf[[i, c]]).sum();
326 let avg = row_sum / cols as f64;
327 ScoredSentence {
328 text: s.clone(),
329 index: i,
330 score: avg,
331 }
332 })
333 .collect())
334}
335
336fn score_ensemble(
342 sentences: &[String],
343 textrank_weight: f64,
344 position_weight: f64,
345 tfidf_weight: f64,
346) -> Result<Vec<ScoredSentence>> {
347 let n = sentences.len();
348 if n == 0 {
349 return Ok(Vec::new());
350 }
351
352 let tr_scores = score_textrank(sentences)?;
353 let pos_scores = score_position(sentences);
354 let tfidf_scores = score_tfidf(sentences)?;
355
356 let tr_normalised = normalise_scores(&tr_scores);
358 let pos_normalised = normalise_scores(&pos_scores);
359 let tfidf_normalised = normalise_scores(&tfidf_scores);
360
361 let total_weight = textrank_weight + position_weight + tfidf_weight;
362 let tw = if total_weight > 0.0 {
363 textrank_weight / total_weight
364 } else {
365 1.0 / 3.0
366 };
367 let pw = if total_weight > 0.0 {
368 position_weight / total_weight
369 } else {
370 1.0 / 3.0
371 };
372 let iw = if total_weight > 0.0 {
373 tfidf_weight / total_weight
374 } else {
375 1.0 / 3.0
376 };
377
378 Ok((0..n)
379 .map(|i| ScoredSentence {
380 text: sentences[i].clone(),
381 index: i,
382 score: tw * tr_normalised[i] + pw * pos_normalised[i] + iw * tfidf_normalised[i],
383 })
384 .collect())
385}
386
387fn normalise_scores(scored: &[ScoredSentence]) -> Vec<f64> {
389 if scored.is_empty() {
390 return Vec::new();
391 }
392
393 let min = scored.iter().map(|s| s.score).fold(f64::INFINITY, f64::min);
394 let max = scored
395 .iter()
396 .map(|s| s.score)
397 .fold(f64::NEG_INFINITY, f64::max);
398
399 let range = max - min;
400 if range < 1e-12 {
401 return vec![0.5; scored.len()];
402 }
403
404 scored.iter().map(|s| (s.score - min) / range).collect()
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    // Five-sentence fixture shared across the tests. The trailing `\` at
    // each line end is a Rust string continuation: it removes the newline
    // and any leading whitespace on the next line, so the constant is one
    // single-line string.
    const SAMPLE_TEXT: &str = "Machine learning is a subset of artificial intelligence. \
        It enables computers to learn from data without explicit programming. \
        Deep learning is a subset of machine learning that uses neural networks. \
        Neural networks are modeled loosely after the human brain. \
        These technologies are transforming many industries today.";

    // ----- TextRank -----

    // A 0.4 ratio over five sentences must yield a non-empty, strictly
    // shorter summary.
    #[test]
    fn test_textrank_produces_shorter_summary() {
        let summary =
            summarize(SAMPLE_TEXT, 0.4, SummarizationMethod::TextRank).expect("Should succeed");
        assert!(!summary.is_empty());
        assert!(summary.len() < SAMPLE_TEXT.len());
    }

    // Empty input short-circuits to an empty summary, not an error.
    #[test]
    fn test_textrank_empty_text() {
        let summary = summarize("", 0.5, SummarizationMethod::TextRank).expect("ok");
        assert!(summary.is_empty());
    }

    // ratio = 1.0 selects everything, which returns the input verbatim.
    #[test]
    fn test_textrank_ratio_one_returns_full() {
        let summary = summarize(SAMPLE_TEXT, 1.0, SummarizationMethod::TextRank).expect("ok");
        assert_eq!(summary, SAMPLE_TEXT);
    }

    // ratio = 0.0 still selects at least one sentence (the n_select floor).
    #[test]
    fn test_textrank_ratio_zero_returns_one_sentence() {
        let summary = summarize(SAMPLE_TEXT, 0.0, SummarizationMethod::TextRank).expect("ok");
        assert!(!summary.is_empty());
        let sentence_tokenizer = SentenceTokenizer::new();
        let sentences = sentence_tokenizer.tokenize(&summary).expect("ok");
        assert_eq!(sentences.len(), 1);
    }

    // A single sentence can never be shortened; it is returned verbatim.
    #[test]
    fn test_textrank_single_sentence() {
        let summary =
            summarize("Just one sentence.", 0.5, SummarizationMethod::TextRank).expect("ok");
        assert_eq!(summary, "Just one sentence.");
    }

    // PageRank scores are probabilities-like quantities; none may be negative.
    #[test]
    fn test_textrank_scores_non_negative() {
        let sentence_tokenizer = SentenceTokenizer::new();
        let sentences = sentence_tokenizer.tokenize(SAMPLE_TEXT).expect("ok");
        let scored = score_textrank(&sentences).expect("ok");
        for s in &scored {
            assert!(s.score >= 0.0, "Score should be non-negative");
        }
    }

    // ----- Position -----

    // Lead bias plus the first-sentence bonus should put the opening
    // sentence at (or tied for) the top.
    #[test]
    fn test_position_first_sentence_highest() {
        let sentence_tokenizer = SentenceTokenizer::new();
        let sentences = sentence_tokenizer.tokenize(SAMPLE_TEXT).expect("ok");
        let scored = score_position(&sentences);
        let first = &scored[0];
        for s in &scored[1..] {
            assert!(
                first.score >= s.score,
                "First sentence should have the highest position score"
            );
        }
    }

    #[test]
    fn test_position_produces_summary() {
        let summary = summarize(SAMPLE_TEXT, 0.4, SummarizationMethod::Position).expect("ok");
        assert!(!summary.is_empty());
        assert!(summary.len() < SAMPLE_TEXT.len());
    }

    #[test]
    fn test_position_empty() {
        let summary = summarize("", 0.5, SummarizationMethod::Position).expect("ok");
        assert!(summary.is_empty());
    }

    #[test]
    fn test_position_scores_non_negative() {
        let sentence_tokenizer = SentenceTokenizer::new();
        let sentences = sentence_tokenizer.tokenize(SAMPLE_TEXT).expect("ok");
        let scored = score_position(&sentences);
        for s in &scored {
            assert!(s.score >= 0.0);
        }
    }

    // The last sentence's 0.2 conclusion bonus should keep its score well
    // above a fraction of the bare lead score for that position.
    #[test]
    fn test_position_last_sentence_has_conclusion_bonus() {
        let sentence_tokenizer = SentenceTokenizer::new();
        let sentences = sentence_tokenizer.tokenize(SAMPLE_TEXT).expect("ok");
        let scored = score_position(&sentences);
        let n = scored.len();
        if n >= 2 {
            let last_score = scored[n - 1].score;
            // Lead score the last sentence would get with no bonus at all.
            let lead_alone = 1.0 - ((n - 1) as f64 / n as f64);
            assert!(
                last_score > lead_alone * 0.3,
                "Last sentence should benefit from conclusion bonus"
            );
        }
    }

    // ----- TF-IDF -----

    #[test]
    fn test_tfidf_produces_summary() {
        let summary = summarize(SAMPLE_TEXT, 0.4, SummarizationMethod::TfIdf).expect("ok");
        assert!(!summary.is_empty());
        assert!(summary.len() < SAMPLE_TEXT.len());
    }

    #[test]
    fn test_tfidf_empty() {
        let summary = summarize("", 0.5, SummarizationMethod::TfIdf).expect("ok");
        assert!(summary.is_empty());
    }

    #[test]
    fn test_tfidf_single_sentence() {
        let summary = summarize("Only one.", 0.5, SummarizationMethod::TfIdf).expect("ok");
        assert_eq!(summary, "Only one.");
    }

    #[test]
    fn test_tfidf_scores_non_negative() {
        let sentence_tokenizer = SentenceTokenizer::new();
        let sentences = sentence_tokenizer.tokenize(SAMPLE_TEXT).expect("ok");
        let scored = score_tfidf(&sentences).expect("ok");
        for s in &scored {
            assert!(s.score >= 0.0);
        }
    }

    // The summary's sentence count must equal ceil(len * ratio) exactly.
    #[test]
    fn test_tfidf_ratio_half() {
        let summary = summarize(SAMPLE_TEXT, 0.5, SummarizationMethod::TfIdf).expect("ok");
        let sentence_tokenizer = SentenceTokenizer::new();
        let original = sentence_tokenizer.tokenize(SAMPLE_TEXT).expect("ok");
        let summarised = sentence_tokenizer.tokenize(&summary).expect("ok");
        let expected = (original.len() as f64 * 0.5).ceil() as usize;
        assert_eq!(summarised.len(), expected);
    }

    // ----- Ensemble -----

    #[test]
    fn test_ensemble_produces_summary() {
        let method = SummarizationMethod::Ensemble {
            textrank_weight: 1.0,
            position_weight: 0.5,
            tfidf_weight: 0.5,
        };
        let summary = summarize(SAMPLE_TEXT, 0.4, method).expect("ok");
        assert!(!summary.is_empty());
        assert!(summary.len() < SAMPLE_TEXT.len());
    }

    #[test]
    fn test_ensemble_equal_weights() {
        let method = SummarizationMethod::Ensemble {
            textrank_weight: 1.0,
            position_weight: 1.0,
            tfidf_weight: 1.0,
        };
        let summary = summarize(SAMPLE_TEXT, 0.4, method).expect("ok");
        assert!(!summary.is_empty());
    }

    // All-zero weights trigger the equal-thirds fallback rather than a
    // division by zero.
    #[test]
    fn test_ensemble_zero_weights_defaults() {
        let method = SummarizationMethod::Ensemble {
            textrank_weight: 0.0,
            position_weight: 0.0,
            tfidf_weight: 0.0,
        };
        let summary = summarize(SAMPLE_TEXT, 0.4, method).expect("ok");
        assert!(!summary.is_empty());
    }

    #[test]
    fn test_ensemble_empty() {
        let method = SummarizationMethod::Ensemble {
            textrank_weight: 1.0,
            position_weight: 1.0,
            tfidf_weight: 1.0,
        };
        let summary = summarize("", 0.5, method).expect("ok");
        assert!(summary.is_empty());
    }

    // Each component is min-max normalised and the weights sum to 1, so the
    // combined score must stay inside [0, 1].
    #[test]
    fn test_ensemble_scores_bounded() {
        let sentence_tokenizer = SentenceTokenizer::new();
        let sentences = sentence_tokenizer.tokenize(SAMPLE_TEXT).expect("ok");
        let scored = score_ensemble(&sentences, 1.0, 1.0, 1.0).expect("ok");
        for s in &scored {
            assert!(
                s.score >= 0.0 && s.score <= 1.0,
                "Ensemble scores should be in [0,1]"
            );
        }
    }

    // ----- Ordering invariant -----

    // The summary's sentences must appear in the same relative order as in
    // the original text (summarize re-sorts by index after selection).
    #[test]
    fn test_summary_preserves_order() {
        let summary = summarize(SAMPLE_TEXT, 0.6, SummarizationMethod::TextRank).expect("ok");
        let sentence_tokenizer = SentenceTokenizer::new();
        let summary_sentences = sentence_tokenizer.tokenize(&summary).expect("ok");
        let original_sentences = sentence_tokenizer.tokenize(SAMPLE_TEXT).expect("ok");

        let mut last_idx: Option<usize> = None;
        for ss in &summary_sentences {
            let idx = original_sentences
                .iter()
                .position(|os| os.trim() == ss.trim());
            if let (Some(li), Some(ci)) = (last_idx, idx) {
                assert!(ci > li, "Summary sentences should be in original order");
            }
            last_idx = idx;
        }
    }
}