yake_rust/lib.rs

#![cfg_attr(not(doctest), doc = include_str!("../README.md"))]
#![allow(clippy::len_zero)]
#![allow(clippy::type_complexity)]
#![deny(rustdoc::broken_intra_doc_links)]
#![deny(rustdoc::private_intra_doc_links)]
#![deny(unused_imports)]
#![warn(missing_docs)]
#![allow(clippy::needless_doctest_main)]

use std::collections::VecDeque;

use hashbrown::hash_set::Entry;
use hashbrown::{HashMap, HashSet};
use indexmap::IndexMap;
use itertools::Itertools;
use plural_helper::PluralHelper;
use preprocessor::{split_into_sentences, split_into_words};
use stats::{median, OnlineStats};

use crate::context::Contexts;
pub use crate::result_item::*;
pub use crate::stopwords::StopWords;
use crate::tag::*;

mod context;
mod counter;
mod plural_helper;
mod preprocessor;
mod result_item;
mod stopwords;
mod tag;

#[cfg(test)]
mod tests;

/// String from the original text
type RawString = String;

/// Lowercased string
type LTerm = String;

/// Lowercased string without punctuation symbols, reduced to singular form
type UTerm = String;

type Sentences = Vec<Sentence>;
type Candidates<'s> = IndexMap<&'s [LTerm], Candidate<'s>>;
type Features<'s> = HashMap<&'s UTerm, TermScore>;
type Words<'s> = HashMap<&'s UTerm, Vec<Occurrence<'s>>>;

#[derive(PartialEq, Eq, Hash, Debug)]
struct Occurrence<'sentence> {
    /// Index (0..) of the sentence where the term occurs
    pub sentence_idx: usize,
    /// The word itself
    pub word: &'sentence RawString,
    pub tag: Tag,
}

#[derive(Debug, Default)]
struct TermScore {
    /// Term frequency. The total number of occurrences in the text.
    tf: f64,
    /// Importance score. The lower, the better.
    score: f64,
}

#[derive(Debug, Default)]
struct TermStats {
    /// Term frequency. The total number of occurrences in the text.
    tf: f64,
    /// The number of times this candidate term is marked as an acronym (=all uppercase).
    tf_a: f64,
    /// The number of occurrences of this candidate term starting with an uppercase letter.
    tf_n: f64,
    /// Term casing heuristic.
    casing: f64,
    /// Term position heuristic.
    position: f64,
    /// Normalized term frequency heuristic.
    frequency: f64,
    /// Term relatedness to context.
    relatedness: f64,
    /// Heuristic for how many different sentences the term occurs in.
    sentences: f64,
    /// Importance score. The lower, the better.
    score: f64,
}

impl From<TermStats> for TermScore {
    fn from(stats: TermStats) -> Self {
        Self { tf: stats.tf, score: stats.score }
    }
}

#[derive(Debug, Clone)]
struct Sentence {
    pub words: Vec<RawString>,
    pub lc_terms: Vec<LTerm>,
    pub uq_terms: Vec<UTerm>,
    pub tags: Vec<Tag>,
}

/// N-gram, a sequence of N terms.
#[derive(Debug, Default, Clone)]
struct Candidate<'s> {
    pub occurrences: usize,
    pub raw: &'s [RawString],
    pub lc_terms: &'s [LTerm],
    pub uq_terms: &'s [UTerm],
    pub score: f64,
}

/// Fine-tunes keyword extraction.
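///
/// # Example
///
/// A minimal sketch of overriding a couple of the defaults below (assuming the crate
/// is imported as `yake_rust`; all other fields keep their default values):
///
/// ```
/// use yake_rust::Config;
///
/// // Restrict key phrases to bigrams and keep near-duplicate phrases.
/// let config = Config { ngrams: 2, remove_duplicates: false, ..Config::default() };
/// assert_eq!(config.window_size, 1);
/// ```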
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    /// How many words a key phrase may contain.
    ///
    /// An _n-gram_ is a contiguous sequence of _n_ words occurring in the text.
    pub ngrams: usize,
    /// List of punctuation symbols.
    ///
    /// They are known as _exclude chars_ in the [original implementation](https://github.com/LIAAD/yake).
    pub punctuation: std::collections::HashSet<char>,

    /// The number of tokens both preceding and following a term, used to calculate the _term dispersion_ metric.
    pub window_size: usize,
    /// When `true`, calculate the _term casing_ metric by counting capitalized terms _without_
    /// intermediate uppercase letters. Thus, `Paypal` is counted while `PayPal` is not.
    ///
    /// The [original implementation](https://github.com/LIAAD/yake) sticks with `true`.
    pub strict_capital: bool,

    /// When `true`, key phrases may contain only alphanumeric characters and hyphens.
    pub only_alphanumeric_and_hyphen: bool,
    /// Key phrases shorter than `minimum_chars` characters in total are discarded.
    pub minimum_chars: usize,

    /// When `true`, similar key phrases are deduplicated.
    ///
    /// Key phrases are considered similar if their normalized Levenshtein similarity is greater than
    /// [`deduplication_threshold`](Config::deduplication_threshold).
    pub remove_duplicates: bool,
    /// A threshold in the range 0..1. Identical strings have a similarity of 1.
    ///
    /// Effective only when [`remove_duplicates`](Config::remove_duplicates) is `true`.
    pub deduplication_threshold: f64,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            punctuation: r##"!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"##.chars().collect(),
            deduplication_threshold: 0.9,
            ngrams: 3,
            remove_duplicates: true,
            window_size: 1,
            strict_capital: true,
            only_alphanumeric_and_hyphen: false,
            minimum_chars: 3,
        }
    }
}

/// Extract the top N most important key phrases from the text.
///
/// If you need all the keywords, pass [`usize::MAX`].
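///
/// # Example
///
/// A minimal usage sketch. How a [`StopWords`] value is obtained depends on your setup;
/// `StopWords::predefined("en")` below is an assumed constructor, so the example is not
/// compiled as a doctest.
///
/// ```ignore
/// use yake_rust::{get_n_best, Config, StopWords};
///
/// let text = "Keyword extraction works on a single document, no corpus needed.";
/// let stop_words = StopWords::predefined("en").unwrap(); // assumed constructor
/// let keywords = get_n_best(3, text, &stop_words, &Config::default());
/// // Results are sorted by ascending score: the lower the score, the more important the phrase.
/// assert!(keywords.len() <= 3);
/// ```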
pub fn get_n_best(n: usize, text: &str, stop_words: &StopWords, config: &Config) -> Vec<ResultItem> {
    Yake::new(stop_words.clone(), config.clone()).get_n_best(text, n)
}

#[derive(Debug, Clone)]
struct Yake {
    config: Config,
    stop_words: StopWords,
}

impl Yake {
    pub fn new(stop_words: StopWords, config: Config) -> Yake {
        Self { config, stop_words }
    }

    fn get_n_best(&self, text: &str, n: usize) -> Vec<ResultItem> {
        let sentences = self.preprocess_text(text);

        let (context, vocabulary) = self.build_context_and_vocabulary(&sentences);
        let features = self.extract_features(&context, vocabulary, &sentences);

        let mut ngrams: Candidates = self.ngram_selection(self.config.ngrams, &sentences);
        self.candidate_weighting(features, &context, &mut ngrams);

        let mut results: Vec<ResultItem> = ngrams.into_values().map(Into::into).collect();
        results.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap());

        if self.config.remove_duplicates {
            remove_duplicates(self.config.deduplication_threshold, results, n)
        } else {
            results.truncate(n);
            results
        }
    }

    fn get_unique_term(&self, word: &str) -> UTerm {
        word.to_single().to_lowercase()
    }

    fn is_stopword(&self, lc_term: &LTerm) -> bool {
        self.stop_words.contains(lc_term)
            || self.stop_words.contains(lc_term.to_single())
            // having fewer than 3 non-punctuation characters is typical for stop words
            || lc_term.to_single().chars().filter(|ch| !self.config.punctuation.contains(ch)).count() < 3
    }

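    /// Splits the text into sentences and, for every sentence, keeps the raw words together
    /// with their lowercased forms, their unique (lowercased, singular) terms, and their tags.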
    fn preprocess_text(&self, text: &str) -> Sentences {
        split_into_sentences(text)
            .into_iter()
            .map(|sentence| {
                let words = split_into_words(&sentence);
                let lc_terms = words.iter().map(|w| w.to_lowercase()).collect::<Vec<LTerm>>();
                let uq_terms = lc_terms.iter().map(|w| self.get_unique_term(w)).collect();
                let tags = words.iter().enumerate().map(|(w_idx, w)| Tag::from(w, w_idx == 0, &self.config)).collect();
                Sentence { words, lc_terms, uq_terms, tags }
            })
            .collect()
    }

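    /// Collects every occurrence of each unique term and, with a sliding window of up to
    /// `window_size` preceding terms, records which terms co-occur (the context used later
    /// for the relatedness feature). Punctuation resets the window; digits and unparsable
    /// words are recorded as occurrences but excluded from the context.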
    fn build_context_and_vocabulary<'s>(&self, sentences: &'s [Sentence]) -> (Contexts<'s>, Words<'s>) {
        let mut ctx = Contexts::default();
        let mut words = Words::new();

        for (idx, sentence) in sentences.iter().enumerate() {
            let mut window: VecDeque<(&UTerm, Tag)> = VecDeque::with_capacity(self.config.window_size + 1);

            for ((word, term), &tag) in sentence.words.iter().zip(&sentence.uq_terms).zip(&sentence.tags) {
                if tag == Tag::Punctuation {
                    window.clear();
                    continue;
                }

                let occurrence = Occurrence { sentence_idx: idx, word, tag };
                words.entry(term).or_default().push(occurrence);

                // Skip the context tracking entirely if the word (not the unique term) is tagged as a digit or as unparsable.
                if tag != Tag::Digit && tag != Tag::Unparsable {
                    for &(left_uterm, left_tag) in window.iter() {
                        if left_tag == Tag::Digit || left_tag == Tag::Unparsable {
                            continue;
                        }

                        ctx.track(left_uterm, term);
                    }
                }

                if window.len() == self.config.window_size {
                    window.pop_front();
                }
                window.push_back((term, tag));
            }
        }

        (ctx, words)
    }

    /// Computes local statistical features that capture informative content within the text,
    /// used to calculate the importance of single terms.
    fn extract_features<'s>(&self, ctx: &Contexts, words: Words<'s>, sentences: &'s Sentences) -> Features<'s> {
        let candidate_words: HashMap<_, _> = sentences
            .iter()
            .flat_map(|sentence| sentence.lc_terms.iter().zip(&sentence.uq_terms).zip(&sentence.tags))
            .filter(|&(_, &tag)| tag != Tag::Punctuation)
            .map(|p| p.0)
            .collect();

        let non_stop_words: HashMap<&UTerm, usize> = candidate_words
            .iter()
            .filter(|&(lc, _)| !self.is_stopword(lc))
            .map(|(_, &uq)| {
                let occurrences = words.get(uq).unwrap().len();
                (uq, occurrences)
            })
            .collect();

        let (nsw_tf_std, nsw_tf_mean) = {
            let tfs: OnlineStats = non_stop_words.values().map(|&freq| freq as f64).collect();
            (tfs.stddev(), tfs.mean())
        };

        let max_tf = words.values().map(Vec::len).max().unwrap_or(0) as f64;

        let mut features = Features::new();

        for (_, u_term) in candidate_words {
            let occurrences = words.get(u_term).unwrap();
            let mut stats = TermStats { tf: occurrences.len() as f64, ..Default::default() };

            {
                // We consider the occurrence of acronyms through a heuristic, where all the letters of the word are capitals.
                stats.tf_a = occurrences.iter().map(|occ| occ.tag).filter(|&tag| tag == Tag::Acronym).count() as f64;
                // We give extra attention to any term beginning with a capital letter (excluding the beginning of sentences).
                stats.tf_n = occurrences.iter().map(|occ| occ.tag).filter(|&tag| tag == Tag::Uppercase).count() as f64;

                // The casing aspect of a term is an important feature when considering the extraction
                // of keywords. The underlying rationale is that uppercase terms tend to be more
                // relevant than lowercase ones.
                stats.casing = stats.tf_a.max(stats.tf_n);

                // The more often the candidate term occurs with a capital letter, the more important
                // it is considered. This means that a candidate term that occurs with a capital letter
                // ten times within ten occurrences will be given a higher value (4.34) than a candidate
                // term that occurs with a capital letter five times within five occurrences (3.10).
                stats.casing /= 1.0 + stats.tf.ln();
            }

            {
                // Another indicator of the importance of a candidate term is its position.
                // The rationale is that relevant keywords tend to appear at the very beginning
                // of a document, whereas words occurring in the middle or at the end of a document
                // tend to be less important.
                //
                // This is particularly evident for both news articles and scientific texts,
                // which tend to concentrate a high degree of important
                // keywords at the top of the text (e.g., in the introduction or abstract).
                //
                // Like Florescu and Caragea, who posit that models that take into account the positions
                // of terms perform better than those that only use the first position or no position
                // at all, we also consider a term’s position to be an important feature. However,
                // unlike their model, we do not consider the positions of the terms,
                // but of the sentences in which the terms occur.
                //
                // Our assumption is that terms that occur in the early
                // sentences of a text should be more highly valued than terms that appear later. Thus,
                // instead of considering a uniform distribution of terms, our model assigns higher
                // scores to terms found in early sentences.
                let sentence_ids = occurrences.iter().map(|o| o.sentence_idx).dedup();
                // When the candidate term only appears in the first sentence, the median function
                // can return a value of 0. To guarantee position > 0, a constant 3 > e=2.71 is added.
                stats.position = 3.0 + median(sentence_ids).unwrap();
                // A double log is applied in order to smooth the difference between terms that occur
                // with a large median difference.
                stats.position = stats.position.ln().ln();
            }

            {
                // The higher the frequency of a candidate term, the greater its importance.
                stats.frequency = stats.tf;
                // To prevent a bias towards high frequency in long documents, we balance it.
                // The mean and the standard deviation are calculated from non-stopword terms,
                // as stopwords usually have high frequencies.
                stats.frequency /= nsw_tf_mean + nsw_tf_std;
            }

            {
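                // Terms that co-occur with many different words on both sides resemble stop words.
                // `dispersion_of` returns the left/right dispersion of the term's context; the more
                // dispersed (and the more frequent) a term is, the higher, i.e. worse, its relatedness.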
                let (dl, dr) = ctx.dispersion_of(u_term);
                stats.relatedness = 1.0 + (dr + dl) * (stats.tf / max_tf);
            }

            {
                // Candidates which appear in many different sentences have a higher probability
                // of being important.
                let distinct = occurrences.iter().map(|o| o.sentence_idx).dedup().count();
                stats.sentences = distinct as f64 / sentences.len() as f64;
            }

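            // Combine the heuristics into a single term score: the lower, the more important.
            // score = (relatedness * position) / (casing + frequency / relatedness + sentences / relatedness)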
            stats.score = (stats.relatedness * stats.position)
                / (stats.casing + (stats.frequency / stats.relatedness) + (stats.sentences / stats.relatedness));

            features.insert(u_term, stats.into());
        }

        features
    }

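    /// Combines the single-term scores into a final score for every candidate n-gram:
    ///
    /// `S(kw) = prod(S(t)) / (TF(kw) * (1 + sum(S(t))))`
    ///
    /// where `TF(kw)` is the number of occurrences of the candidate and `S(t)` are the term
    /// scores computed by `extract_features`. Stop words inside the n-gram do not contribute
    /// their own score; instead, the probability of them co-occurring with their neighbours
    /// is used. Lower scores are better. Note that candidates never start with a stop word
    /// (see `ngram_selection`), so `j > 0` whenever the stop-word branch below is taken.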
    fn candidate_weighting<'s>(&self, features: Features<'s>, ctx: &Contexts<'s>, candidates: &mut Candidates<'s>) {
        for (&lc_terms, candidate) in candidates.iter_mut() {
            let uq_terms = candidate.uq_terms;
            let mut prod_ = 1.0;
            let mut sum_ = 0.0;

            for (j, (lc, uq)) in lc_terms.iter().zip(uq_terms).enumerate() {
                if self.is_stopword(lc) {
                    let prob_prev = match uq_terms.get(j - 1) {
                        None => 0.0,
                        Some(prev_uq) => {
                            // how often the previous term is followed by this one / total occurrences of the previous term
                            ctx.cases_term_is_followed(prev_uq, uq) as f64 / features.get(&prev_uq).unwrap().tf
                        }
                    };

                    let prob_succ = match uq_terms.get(j + 1) {
                        None => 0.0,
                        Some(next_uq) => {
                            // how often the next term occurs right after this one / total occurrences of the next term
                            ctx.cases_term_is_followed(uq, next_uq) as f64 / features.get(&next_uq).unwrap().tf
                        }
                    };

                    let prob = prob_prev * prob_succ;
                    prod_ *= 1.0 + (1.0 - prob);
                    sum_ -= 1.0 - prob;
                } else if let Some(stats) = features.get(uq) {
                    prod_ *= stats.score;
                    sum_ += stats.score;
                }
            }

            if sum_ == -1.0 {
                sum_ = 0.999999999;
            }

            let tf = candidate.occurrences as f64;
            candidate.score = prod_ / (tf * (1.0 + sum_));
        }
    }

    fn is_candidate(&self, lc_terms: &[LTerm], tags: &[Tag]) -> bool {
        let is_bad =
            // has a bad tag
            tags.iter().any(|tag| matches!(tag, Tag::Digit | Tag::Punctuation | Tag::Unparsable))
            // the last word is a stopword
            || self.is_stopword(lc_terms.last().unwrap())
            // not enough symbols in total
            || lc_terms.iter().map(|w| w.chars().count()).sum::<usize>() < self.config.minimum_chars
            // has non-alphanumeric characters
            || self.config.only_alphanumeric_and_hyphen && !lc_terms.iter().all(word_is_alphanumeric_and_hyphen);

        !is_bad
    }

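    /// Generates every candidate n-gram of length 1..=n from each sentence, skipping n-grams
    /// that start with a stop word, and counts how many times each distinct (lowercased)
    /// n-gram occurs. N-grams rejected by `is_candidate` are remembered and never re-checked.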
    fn ngram_selection<'s>(&self, n: usize, sentences: &'s Sentences) -> Candidates<'s> {
        let mut candidates = Candidates::new();
        let mut ignored = HashSet::new();

        for sentence in sentences.iter() {
            let length = sentence.words.len();

            for j in 0..length {
                if self.is_stopword(&sentence.lc_terms[j]) {
                    continue; // all further n-grams starting with this word can't be candidates
                }

                for k in (j + 1..length + 1).take(n) {
                    let lc_terms = &sentence.lc_terms[j..k];

                    if let Entry::Vacant(e) = ignored.entry(lc_terms) {
                        if !self.is_candidate(lc_terms, &sentence.tags[j..k]) {
                            e.insert();
                        } else {
                            candidates
                                .entry(lc_terms)
                                .or_insert_with(|| Candidate {
                                    lc_terms,
                                    uq_terms: &sentence.uq_terms[j..k],
                                    raw: &sentence.words[j..k],
                                    ..Default::default()
                                })
                                .occurrences += 1;
                        }
                    };
                }
            }
        }

        candidates
    }
}

fn word_is_alphanumeric_and_hyphen(word: impl AsRef<str>) -> bool {
    word.as_ref().chars().all(|ch| ch.is_alphanumeric() || ch == '-')
}