use crate::error::{Result, TextError};
use crate::tokenize::Tokenizer;
use crate::vectorize::{TfidfVectorizer, Vectorizer};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};

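/// Extractive summarizer based on the TextRank algorithm (Mihalcea & Tarau,
/// 2004): sentences are nodes in a graph weighted by pairwise TF-IDF cosine
/// similarity, ranked with PageRank, and the top-ranked sentences are
/// returned in document order.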
pub struct TextRank {
    /// Number of sentences to include in the summary.
    num_sentences: usize,
    /// PageRank damping factor (typically 0.85).
    damping_factor: f64,
    /// Maximum number of PageRank iterations.
    max_iterations: usize,
    /// Convergence threshold for the PageRank score updates.
    threshold: f64,
    /// Tokenizer used to split input text into sentences.
    sentence_tokenizer: Box<dyn Tokenizer + Send + Sync>,
}

impl TextRank {
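    /// Creates a summarizer that keeps `num_sentences` sentences, with the
    /// conventional defaults: damping factor 0.85, at most 100 PageRank
    /// iterations, and a convergence threshold of 1e-4.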
    pub fn new(num_sentences: usize) -> Self {
        Self {
            num_sentences,
            damping_factor: 0.85,
            max_iterations: 100,
            threshold: 0.0001,
            sentence_tokenizer: Box::new(crate::tokenize::SentenceTokenizer::new()),
        }
    }

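    /// Sets the PageRank damping factor.
    ///
    /// # Errors
    ///
    /// Returns `TextError::InvalidInput` if the factor is outside `[0, 1]`.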
    pub fn with_damping_factor(mut self, damping_factor: f64) -> Result<Self> {
        if !(0.0..=1.0).contains(&damping_factor) {
            return Err(TextError::InvalidInput(
                "Damping factor must be between 0 and 1".to_string(),
            ));
        }
        self.damping_factor = damping_factor;
        Ok(self)
    }

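    /// Produces an extractive summary of `text`.
    ///
    /// Returns an empty string for empty input, and the input unchanged when
    /// it has no more sentences than the requested summary length.
    ///
    /// A minimal usage sketch, assuming the crate is named `scirs2_text` and
    /// re-exports `TextRank` at the root:
    ///
    /// ```ignore
    /// use scirs2_text::TextRank;
    ///
    /// let summarizer = TextRank::new(2);
    /// let summary = summarizer
    ///     .summarize("Rust is fast. Rust is safe. Cargo manages builds.")
    ///     .expect("summarization should succeed");
    /// assert!(!summary.is_empty());
    /// ```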
    pub fn summarize(&self, text: &str) -> Result<String> {
        let sentences: Vec<String> = self.sentence_tokenizer.tokenize(text)?;

        if sentences.is_empty() {
            return Ok(String::new());
        }

        // If the text is already short enough, return it unchanged.
        if sentences.len() <= self.num_sentences {
            return Ok(text.to_string());
        }

        // Build a sentence-similarity graph, rank its nodes with PageRank,
        // and keep the highest-scoring sentences in their original order.
        let similarity_matrix = self.build_similarity_matrix(&sentences)?;
        let scores = self.page_rank(&similarity_matrix)?;
        let selected_indices = self.select_top_sentences(&scores);
        let summary = self.reconstruct_summary(&sentences, &selected_indices);

        Ok(summary)
    }

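    /// Builds the sentence graph: an `n x n` matrix of pairwise TF-IDF
    /// cosine similarities with a zero diagonal.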
    fn build_similarity_matrix(&self, sentences: &[String]) -> Result<Array2<f64>> {
        let n = sentences.len();
        let mut matrix = Array2::zeros((n, n));

        // Represent each sentence as a TF-IDF vector.
        let sentence_refs: Vec<&str> = sentences.iter().map(|s| s.as_ref()).collect();
        let mut vectorizer = TfidfVectorizer::default();
        vectorizer.fit(&sentence_refs)?;
        let vectors = vectorizer.transform_batch(&sentence_refs)?;

        // Pairwise cosine similarities; the diagonal stays zero so a
        // sentence does not vote for itself.
        for i in 0..n {
            for j in 0..n {
                if i != j {
                    matrix[[i, j]] = self.cosine_similarity(vectors.row(i), vectors.row(j));
                }
            }
        }

        Ok(matrix)
    }

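    /// Cosine similarity, `dot(a, b) / (|a| * |b|)`, returning 0.0 when
    /// either vector has zero norm.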
    fn cosine_similarity(&self, vec1: ArrayView1<f64>, vec2: ArrayView1<f64>) -> f64 {
        let dot_product = vec1.dot(&vec2);
        let norm1 = vec1.dot(&vec1).sqrt();
        let norm2 = vec2.dot(&vec2).sqrt();

        if norm1 == 0.0 || norm2 == 0.0 {
            0.0
        } else {
            dot_product / (norm1 * norm2)
        }
    }

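    /// Runs power iteration on the row-normalized similarity matrix using
    /// the standard PageRank update:
    ///
    /// `scores = (1 - d) / n + d * M^T . scores`
    ///
    /// where `d` is the damping factor and `M` is the normalized matrix.
    /// Iteration stops after `max_iterations` rounds or once the L1 change
    /// in scores falls below `threshold`.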
    fn page_rank(&self, matrix: &Array2<f64>) -> Result<Array1<f64>> {
        let n = matrix.nrows();
        let mut scores = Array1::from_elem(n, 1.0 / n as f64);

        // Row-normalize the similarity matrix so each row is a probability
        // distribution over outgoing edges.
        let mut normalized_matrix = matrix.clone();
        for i in 0..n {
            let row_sum: f64 = matrix.row(i).sum();
            if row_sum > 0.0 {
                normalized_matrix.row_mut(i).mapv_inplace(|x| x / row_sum);
            }
        }

        // Power iteration until the scores change less than the threshold.
        for _ in 0..self.max_iterations {
            let new_scores = Array1::from_elem(n, (1.0 - self.damping_factor) / n as f64)
                + self.damping_factor * normalized_matrix.t().dot(&scores);

            let diff = (&new_scores - &scores).mapv(f64::abs).sum();
            scores = new_scores;

            if diff < self.threshold {
                break;
            }
        }

        Ok(scores)
    }

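    /// Returns the indices of the `num_sentences` highest-scoring sentences.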
    fn select_top_sentences(&self, scores: &Array1<f64>) -> Vec<usize> {
        let mut indexed_scores: Vec<(usize, f64)> = scores
            .iter()
            .enumerate()
            .map(|(i, &score)| (i, score))
            .collect();

        // Sort by score, highest first.
        indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("scores must not be NaN"));

        indexed_scores
            .iter()
            .take(self.num_sentences)
            .map(|&(idx, _)| idx)
            .collect()
    }

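    /// Joins the selected sentences in their original document order.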
    fn reconstruct_summary(&self, sentences: &[String], indices: &[usize]) -> String {
        let mut sorted_indices = indices.to_vec();
        sorted_indices.sort_unstable();

        sorted_indices
            .iter()
            .map(|&idx| sentences[idx].clone())
            .collect::<Vec<_>>()
            .join(" ")
    }
}

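/// Centroid-based extractive summarizer: sentences are scored by cosine
/// similarity to the TF-IDF centroid of the document, and near-duplicate
/// sentences are filtered out to reduce redundancy.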
pub struct CentroidSummarizer {
    /// Number of sentences to include in the summary.
    num_sentences: usize,
    /// Minimum centroid weight for a term to be kept as a topic term.
    topic_threshold: f64,
    /// Similarity above which a candidate sentence is considered redundant.
    redundancy_threshold: f64,
    /// Tokenizer used to split input text into sentences.
    sentence_tokenizer: Box<dyn Tokenizer + Send + Sync>,
}

impl CentroidSummarizer {
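    /// Creates a summarizer that keeps `num_sentences` sentences, with a
    /// topic threshold of 0.1 and a redundancy threshold of 0.95.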
    pub fn new(num_sentences: usize) -> Self {
        Self {
            num_sentences,
            topic_threshold: 0.1,
            redundancy_threshold: 0.95,
            sentence_tokenizer: Box::new(crate::tokenize::SentenceTokenizer::new()),
        }
    }

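    /// Produces an extractive summary of `text`.
    ///
    /// Returns an empty string for empty input, and the input unchanged when
    /// it has no more sentences than the requested summary length.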
    pub fn summarize(&self, text: &str) -> Result<String> {
        let sentences: Vec<String> = self.sentence_tokenizer.tokenize(text)?;

        if sentences.is_empty() {
            return Ok(String::new());
        }

        // If the text is already short enough, return it unchanged.
        if sentences.len() <= self.num_sentences {
            return Ok(text.to_string());
        }

        // Represent each sentence as a TF-IDF vector.
        let sentence_refs: Vec<&str> = sentences.iter().map(|s| s.as_ref()).collect();
        let mut vectorizer = TfidfVectorizer::default();
        vectorizer.fit(&sentence_refs)?;
        let vectors = vectorizer.transform_batch(&sentence_refs)?;

        // Score sentences against the document centroid and keep the
        // selected ones in their original order.
        let centroid = self.calculate_centroid(&vectors);
        let selected_indices = self.select_sentences(&vectors, &centroid);
        let summary = self.reconstruct_summary(&sentences, &selected_indices);

        Ok(summary)
    }

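    /// Computes the document centroid as the mean of the sentence TF-IDF
    /// vectors, zeroing every component at or below `topic_threshold`.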
    fn calculate_centroid(&self, vectors: &Array2<f64>) -> Array1<f64> {
        let mut centroid = vectors
            .mean_axis(scirs2_core::ndarray::Axis(0))
            .expect("matrix has at least one row");

        // Keep only the salient topic terms.
        centroid.mapv_inplace(|x| if x > self.topic_threshold { x } else { 0.0 });

        centroid
    }

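    /// Greedily selects sentences in order of centroid similarity, skipping
    /// candidates whose similarity to an already-selected sentence exceeds
    /// `redundancy_threshold`.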
    fn select_sentences(&self, vectors: &Array2<f64>, centroid: &Array1<f64>) -> Vec<usize> {
        let mut selected = Vec::new();

        // Rank sentences by similarity to the centroid, highest first.
        let mut similarities: Vec<(usize, f64)> = (0..vectors.nrows())
            .map(|i| (i, self.cosine_similarity(vectors.row(i), centroid.view())))
            .collect();
        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("scores must not be NaN"));

        // Greedily pick sentences, skipping any that are near-duplicates of
        // sentences already selected.
        for (idx, _similarity) in similarities {
            if selected.len() >= self.num_sentences {
                break;
            }

            let is_redundant = selected.iter().any(|&selected_idx| {
                self.cosine_similarity(vectors.row(idx), vectors.row(selected_idx))
                    > self.redundancy_threshold
            });

            if !is_redundant {
                selected.push(idx);
            }
        }

        selected
    }

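    /// Cosine similarity, `dot(a, b) / (|a| * |b|)`, returning 0.0 when
    /// either vector has zero norm.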
    fn cosine_similarity(&self, vec1: ArrayView1<f64>, vec2: ArrayView1<f64>) -> f64 {
        let dot_product = vec1.dot(&vec2);
        let norm1 = vec1.dot(&vec1).sqrt();
        let norm2 = vec2.dot(&vec2).sqrt();

        if norm1 == 0.0 || norm2 == 0.0 {
            0.0
        } else {
            dot_product / (norm1 * norm2)
        }
    }

    fn reconstruct_summary(&self, sentences: &[String], indices: &[usize]) -> String {
        let mut sorted_indices = indices.to_vec();
        sorted_indices.sort_unstable();

        sorted_indices
            .iter()
            .map(|&idx| sentences[idx].clone())
            .collect::<Vec<_>>()
            .join(" ")
    }
}

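/// TF-IDF based keyword extractor: candidate n-grams are scored by their
/// average TF-IDF weight across sentences, and the top-scoring terms are
/// returned.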
pub struct KeywordExtractor {
    /// Number of keywords to extract.
    num_keywords: usize,
    /// Minimum document frequency for a term to be considered.
    #[allow(dead_code)]
    min_df: f64,
    /// Maximum document frequency for a term to be considered.
    #[allow(dead_code)]
    max_df: f64,
    /// Inclusive (min, max) n-gram sizes used when building candidate terms.
    ngram_range: (usize, usize),
}

impl KeywordExtractor {
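    /// Creates an extractor that returns `num_keywords` keywords, considering
    /// n-grams of length 1 to 3 by default.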
    pub fn new(num_keywords: usize) -> Self {
        Self {
            num_keywords,
            min_df: 0.01,
            max_df: 0.95,
            ngram_range: (1, 3),
        }
    }

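    /// Sets the inclusive n-gram range used when building candidate terms.
    ///
    /// # Errors
    ///
    /// Returns `TextError::InvalidInput` if `min_n` is zero or greater than
    /// `max_n`.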
    pub fn with_ngram_range(mut self, min_n: usize, max_n: usize) -> Result<Self> {
        if min_n > max_n || min_n == 0 {
            return Err(TextError::InvalidInput("Invalid n-gram range".to_string()));
        }
        self.ngram_range = (min_n, max_n);
        Ok(self)
    }

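    /// Extracts up to `num_keywords` keywords with their scores, sorted by
    /// score in descending order.
    ///
    /// A minimal usage sketch, assuming the crate is named `scirs2_text` and
    /// re-exports `KeywordExtractor` at the root:
    ///
    /// ```ignore
    /// use scirs2_text::KeywordExtractor;
    ///
    /// let extractor = KeywordExtractor::new(5);
    /// let keywords = extractor
    ///     .extract_keywords("Deep learning models use neural networks.")
    ///     .expect("extraction should succeed");
    /// for (term, score) in keywords {
    ///     println!("{term}: {score:.3}");
    /// }
    /// ```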
    pub fn extract_keywords(&self, text: &str) -> Result<Vec<(String, f64)>> {
        let sentence_tokenizer = crate::tokenize::SentenceTokenizer::new();
        let sentences = sentence_tokenizer.tokenize(text)?;

        if sentences.is_empty() {
            return Ok(Vec::new());
        }

        let sentence_refs: Vec<&str> = sentences.iter().map(|s| s.as_ref()).collect();

        // Score candidate n-grams with TF-IDF, averaged over sentences.
        let mut vectorizer = crate::enhanced_vectorize::EnhancedTfidfVectorizer::new()
            .set_ngram_range((self.ngram_range.0, self.ngram_range.1))?;
        vectorizer.fit(&sentence_refs)?;
        let tfidf_matrix = vectorizer.transform_batch(&sentence_refs)?;

        let avg_tfidf = tfidf_matrix
            .mean_axis(scirs2_core::ndarray::Axis(0))
            .expect("matrix has at least one row");

        // The vectorizer does not expose its vocabulary here, so map each
        // feature index back to a surface form by position as a rough
        // approximation, falling back to a synthetic `term_{i}` label.
        let all_words: Vec<String> = text.split_whitespace().map(|w| w.to_string()).collect();

        let mut keyword_scores: Vec<(String, f64)> = avg_tfidf
            .iter()
            .enumerate()
            .map(|(i, &score)| {
                let term = if i < all_words.len() {
                    all_words[i].clone()
                } else {
                    format!("term_{i}")
                };
                (term, score)
            })
            .collect();

        // Sort all candidates by score before truncating, so high-scoring
        // terms with large feature indices are not dropped.
        keyword_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("scores must not be NaN"));

        Ok(keyword_scores.into_iter().take(self.num_keywords).collect())
    }

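    /// Like [`extract_keywords`](Self::extract_keywords), but also reports
    /// the byte offsets of each keyword's occurrences in the original text.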
    pub fn extract_keywords_with_positions(
        &self,
        text: &str,
    ) -> Result<Vec<(String, f64, Vec<usize>)>> {
        let keywords = self.extract_keywords(text)?;
        let mut results = Vec::new();

        for (keyword, score) in keywords {
            let positions = self.find_keyword_positions(text, &keyword);
            results.push((keyword, score, positions));
        }

        Ok(results)
    }

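    /// Finds all case-insensitive occurrences of `keyword`, returned as byte
    /// offsets into the lowercased text.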
    fn find_keyword_positions(&self, text: &str, keyword: &str) -> Vec<usize> {
        let mut positions = Vec::new();
        let text_lower = text.to_lowercase();
        let keyword_lower = keyword.to_lowercase();

        // Advance past each match using the lowercased length, since
        // lowercasing can change the byte length for non-ASCII input.
        let mut start = 0;
        while let Some(pos) = text_lower[start..].find(&keyword_lower) {
            positions.push(start + pos);
            start += pos + keyword_lower.len();
        }

        positions
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_textrank_summarizer() {
        let summarizer = TextRank::new(2);
        let text = "Machine learning is a subset of artificial intelligence. \
                    It enables computers to learn from data. \
                    Deep learning is a subset of machine learning. \
                    Neural networks are used in deep learning. \
                    These technologies are transforming many industries.";

        let summary = summarizer.summarize(text).expect("summarization should succeed");
        assert!(!summary.is_empty());
        assert!(summary.len() < text.len());
    }

    #[test]
    fn test_centroid_summarizer() {
        let summarizer = CentroidSummarizer::new(2);
        let text = "Natural language processing is important. \
                    It helps computers understand human language. \
                    Many applications use NLP technology. \
                    Chatbots and translation are examples. \
                    NLP continues to evolve rapidly.";

        let summary = summarizer.summarize(text).expect("summarization should succeed");
        assert!(!summary.is_empty());
    }

    #[test]
    fn test_keyword_extraction() {
        let extractor = KeywordExtractor::new(5);
        let text = "Machine learning algorithms are essential for artificial intelligence. \
                    Deep learning models use neural networks. \
                    These models can process complex data patterns.";

        let keywords = extractor.extract_keywords(text).expect("extraction should succeed");
        assert!(!keywords.is_empty());
        assert!(keywords.len() <= 5);

        // Keywords should come back sorted by score, highest first.
        for i in 1..keywords.len() {
            assert!(keywords[i - 1].1 >= keywords[i].1);
        }
    }

    #[test]
    fn test_keyword_positions() {
        let extractor = KeywordExtractor::new(3);
        let text = "Machine learning is great. Machine learning transforms industries.";

        let keywords_with_pos = extractor
            .extract_keywords_with_positions(text)
            .expect("extraction should succeed");

        // Repeated keywords should be found at every occurrence.
        for (keyword, _score, positions) in keywords_with_pos {
            if keyword.to_lowercase().contains("machine learning") {
                assert!(positions.len() >= 2);
            }
        }
    }

    #[test]
    fn test_empty_text() {
        let textrank = TextRank::new(3);
        let centroid = CentroidSummarizer::new(3);
        let keywords = KeywordExtractor::new(5);

        assert_eq!(textrank.summarize("").expect("summarization should succeed"), "");
        assert_eq!(centroid.summarize("").expect("summarization should succeed"), "");
        assert_eq!(
            keywords
                .extract_keywords("")
                .expect("extraction should succeed")
                .len(),
            0
        );
    }

    #[test]
    fn test_short_text() {
        let summarizer = TextRank::new(5);
        let short_text = "This is a short text.";

        // Texts with fewer sentences than requested are returned unchanged.
        let summary = summarizer.summarize(short_text).expect("summarization should succeed");
        assert_eq!(summary, short_text);
    }
}