#![cfg_attr(not(doctest), doc = include_str!("../README.md"))]
#![allow(clippy::len_zero)]
#![allow(clippy::type_complexity)]
#![deny(rustdoc::broken_intra_doc_links)]
#![deny(rustdoc::private_intra_doc_links)]
#![deny(unused_imports)]
#![warn(missing_docs)]
#![allow(clippy::needless_doctest_main)]

use std::collections::VecDeque;

use hashbrown::hash_set::Entry;
use hashbrown::{HashMap, HashSet};
use indexmap::IndexMap;
use itertools::Itertools;
use plural_helper::PluralHelper;
use preprocessor::{split_into_sentences, split_into_words};
use stats::{median, OnlineStats};

use crate::context::Contexts;
pub use crate::result_item::*;
pub use crate::stopwords::StopWords;
use crate::tag::*;

mod context;
mod counter;
mod plural_helper;
mod preprocessor;
mod result_item;
mod stopwords;
mod tag;

#[cfg(test)]
mod tests;

/// Raw text, as written in the original document.
type RawString = String;

/// Lowercased term.
type LTerm = String;

/// Unique term: lowercased and reduced to its singular form.
type UTerm = String;

type Sentences = Vec<Sentence>;
type Candidates<'s> = IndexMap<&'s [LTerm], Candidate<'s>>;
type Features<'s> = HashMap<&'s UTerm, TermScore>;
type Words<'s> = HashMap<&'s UTerm, Vec<Occurrence<'s>>>;

/// A single occurrence of a term in the text.
#[derive(PartialEq, Eq, Hash, Debug)]
struct Occurrence<'sentence> {
    /// Index of the sentence the word occurs in.
    pub sentence_idx: usize,
    /// The word exactly as written in the original text.
    pub word: &'sentence RawString,
    /// The tag assigned to the word during preprocessing.
    pub tag: Tag,
}

/// The final weight of a term.
#[derive(Debug, Default)]
struct TermScore {
    /// Term frequency: the number of occurrences in the text.
    tf: f64,
    /// Term weight; lower values indicate more important terms.
    score: f64,
}

/// Intermediate statistics collected per term while computing its weight.
#[derive(Debug, Default)]
struct TermStats {
    /// Term frequency: the number of occurrences in the text.
    tf: f64,
    /// Frequency of occurrences as an acronym (all caps).
    tf_a: f64,
    /// Frequency of occurrences starting with an uppercase letter.
    tf_n: f64,
    /// Casing feature: favors terms that tend to appear capitalized or as acronyms.
    casing: f64,
    /// Position feature: favors terms that occur in early sentences.
    position: f64,
    /// Normalized term frequency.
    frequency: f64,
    /// Relatedness feature: penalizes terms that co-occur with many different terms.
    relatedness: f64,
    /// Fraction of the sentences the term appears in.
    sentences: f64,
    /// Combined weight; lower values indicate more important terms.
    score: f64,
}

impl From<TermStats> for TermScore {
    fn from(stats: TermStats) -> Self {
        Self { tf: stats.tf, score: stats.score }
    }
}

/// A sentence, split into words along with their lowercased, unique, and tagged forms.
#[derive(Debug, Clone)]
struct Sentence {
    pub words: Vec<RawString>,
    pub lc_terms: Vec<LTerm>,
    pub uq_terms: Vec<UTerm>,
    pub tags: Vec<Tag>,
}

/// A candidate keyword: a contiguous sequence of terms within one sentence.
#[derive(Debug, Default, Clone)]
struct Candidate<'s> {
    pub occurrences: usize,
    pub raw: &'s [RawString],
    pub lc_terms: &'s [LTerm],
    pub uq_terms: &'s [UTerm],
    pub score: f64,
}

/// Parameters tuning the YAKE keyword-extraction algorithm.
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    /// Maximum length of extracted keyphrases, in words.
    pub ngrams: usize,
    /// Characters treated as punctuation.
    pub punctuation: std::collections::HashSet<char>,

    /// Size of the sliding window used to collect term co-occurrences.
    pub window_size: usize,
    /// When `true`, a stricter rule is applied for counting a word as capitalized.
    pub strict_capital: bool,

    /// When `true`, candidates may only consist of alphanumeric characters and hyphens.
    pub only_alphanumeric_and_hyphen: bool,
    /// Minimum number of characters (summed over its words) a candidate must have.
    pub minimum_chars: usize,

    /// When `true`, near-duplicate keywords are removed from the results.
    pub remove_duplicates: bool,
    /// Similarity threshold above which a lower-ranked keyword is dropped as a
    /// duplicate of a higher-ranked one.
    pub deduplication_threshold: f64,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            punctuation: r##"!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"##.chars().collect(),
            deduplication_threshold: 0.9,
            ngrams: 3,
            remove_duplicates: true,
            window_size: 1,
            strict_capital: true,
            only_alphanumeric_and_hyphen: false,
            minimum_chars: 3,
        }
    }
}
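
/// Extracts the `n` top-ranked keywords from `text` (YAKE scores ascend, so
/// lower scores come first and mean more relevant keywords).
///
/// A minimal usage sketch; the crate name `yake_rust` and the
/// `StopWords::predefined` constructor are assumptions rather than something
/// this file defines, hence the doctest is marked `ignore`:
///
/// ```ignore
/// use yake_rust::{get_n_best, Config, StopWords};
///
/// let stopwords = StopWords::predefined("en").unwrap(); // assumed constructor
/// let top = get_n_best(3, "Rust is a systems programming language.", &stopwords, &Config::default());
/// for item in top {
///     println!("{:?}", item); // each item carries the keyword and its score
/// }
/// ```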
pub fn get_n_best(n: usize, text: &str, stop_words: &StopWords, config: &Config) -> Vec<ResultItem> {
    Yake::new(stop_words.clone(), config.clone()).get_n_best(text, n)
}

#[derive(Debug, Clone)]
struct Yake {
    config: Config,
    stop_words: StopWords,
}

impl Yake {
    pub fn new(stop_words: StopWords, config: Config) -> Yake {
        Self { config, stop_words }
    }

    fn get_n_best(&self, text: &str, n: usize) -> Vec<ResultItem> {
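        // The YAKE pipeline: preprocess the text into tagged sentences, build
        // the term co-occurrence context, score individual terms, then score
        // and rank the n-gram candidates.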
        let sentences = self.preprocess_text(text);

        let (context, vocabulary) = self.build_context_and_vocabulary(&sentences);
        let features = self.extract_features(&context, vocabulary, &sentences);

        let mut ngrams: Candidates = self.ngram_selection(self.config.ngrams, &sentences);
        self.candidate_weighting(features, &context, &mut ngrams);

        let mut results: Vec<ResultItem> = ngrams.into_values().map(Into::into).collect();
        results.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap());

        if self.config.remove_duplicates {
            remove_duplicates(self.config.deduplication_threshold, results, n)
        } else {
            results.truncate(n);
            results
        }
    }

    fn get_unique_term(&self, word: &str) -> UTerm {
        word.to_single().to_lowercase()
    }
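
    /// A term counts as a stopword if it is listed as such (as written or in
    /// singular form), or if it has fewer than three non-punctuation characters.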
    fn is_stopword(&self, lc_term: &LTerm) -> bool {
        self.stop_words.contains(lc_term)
            || self.stop_words.contains(lc_term.to_single())
            || lc_term.to_single().chars().filter(|ch| !self.config.punctuation.contains(ch)).count() < 3
    }

    fn preprocess_text(&self, text: &str) -> Sentences {
        split_into_sentences(text)
            .into_iter()
            .map(|sentence| {
                let words = split_into_words(&sentence);
                let lc_terms = words.iter().map(|w| w.to_lowercase()).collect::<Vec<LTerm>>();
                let uq_terms = lc_terms.iter().map(|w| self.get_unique_term(w)).collect();
                let tags = words.iter().enumerate().map(|(w_idx, w)| Tag::from(w, w_idx == 0, &self.config)).collect();
                Sentence { words, lc_terms, uq_terms, tags }
            })
            .collect()
    }

    fn build_context_and_vocabulary<'s>(&self, sentences: &'s [Sentence]) -> (Contexts<'s>, Words<'s>) {
        let mut ctx = Contexts::default();
        let mut words = Words::new();

        for (idx, sentence) in sentences.iter().enumerate() {
            let mut window: VecDeque<(&UTerm, Tag)> = VecDeque::with_capacity(self.config.window_size + 1);

            for ((word, term), &tag) in sentence.words.iter().zip(&sentence.uq_terms).zip(&sentence.tags) {
                if tag == Tag::Punctuation {
                    window.clear();
                    continue;
                }

                let occurrence = Occurrence { sentence_idx: idx, word, tag };
                words.entry(term).or_default().push(occurrence);

                if tag != Tag::Digit && tag != Tag::Unparsable {
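                    // Record co-occurrences between this term and the terms
                    // still in the window to its left; digits and unparsable
                    // tokens take no part in the context.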
                    for &(left_uterm, left_tag) in window.iter() {
                        if left_tag == Tag::Digit || left_tag == Tag::Unparsable {
                            continue;
                        }

                        ctx.track(left_uterm, term);
                    }
                }

                if window.len() == self.config.window_size {
                    window.pop_front();
                }
                window.push_back((term, tag));
            }
        }

        (ctx, words)
    }

    /// Computes the YAKE statistical features and resulting weight for every
    /// distinct candidate term.
    fn extract_features<'s>(&self, ctx: &Contexts, words: Words<'s>, sentences: &'s Sentences) -> Features<'s> {
        let candidate_words: HashMap<_, _> = sentences
            .iter()
            .flat_map(|sentence| sentence.lc_terms.iter().zip(&sentence.uq_terms).zip(&sentence.tags))
            .filter(|&(_, &tag)| tag != Tag::Punctuation)
            .map(|p| p.0)
            .collect();

        let non_stop_words: HashMap<&UTerm, usize> = candidate_words
            .iter()
            .filter(|&(lc, _)| !self.is_stopword(lc))
            .map(|(_, &uq)| {
                let occurrences = words.get(uq).unwrap().len();
                (uq, occurrences)
            })
            .collect();

        let (nsw_tf_std, nsw_tf_mean) = {
            let tfs: OnlineStats = non_stop_words.values().map(|&freq| freq as f64).collect();
            (tfs.stddev(), tfs.mean())
        };

        let max_tf = words.values().map(Vec::len).max().unwrap_or(0) as f64;

        let mut features = Features::new();

        for (_, u_term) in candidate_words {
            let occurrences = words.get(u_term).unwrap();
            let mut stats = TermStats { tf: occurrences.len() as f64, ..Default::default() };

            {
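                // Casing: how often the term appears as an acronym or starts
                // with an uppercase letter, dampened by the log of its total
                // frequency.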
                stats.tf_a = occurrences.iter().map(|occ| occ.tag).filter(|&tag| tag == Tag::Acronym).count() as f64;
                stats.tf_n = occurrences.iter().map(|occ| occ.tag).filter(|&tag| tag == Tag::Uppercase).count() as f64;

                stats.casing = stats.tf_a.max(stats.tf_n);

                stats.casing /= 1.0 + stats.tf.ln();
            }

            {
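                // Position: the median index of the sentences the term occurs
                // in; terms that show up early in the document get a better
                // (lower) value, and the double log dampens the growth.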
                let sentence_ids = occurrences.iter().map(|o| o.sentence_idx).dedup();
                stats.position = 3.0 + median(sentence_ids).unwrap();
                stats.position = stats.position.ln().ln();
            }

            {
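                // Frequency: term frequency normalized by the mean plus one
                // standard deviation of the non-stopword term frequencies.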
                stats.frequency = stats.tf;
                stats.frequency /= nsw_tf_mean + nsw_tf_std;
            }

            {
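                // Relatedness: terms that co-occur with many distinct terms on
                // either side behave like stopwords and are penalized.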
                let (dl, dr) = ctx.dispersion_of(u_term);
                stats.relatedness = 1.0 + (dr + dl) * (stats.tf / max_tf);
            }

            {
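                // Sentences: the fraction of all sentences the term occurs in.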
                let distinct = occurrences.iter().map(|o| o.sentence_idx).dedup().count();
                stats.sentences = distinct as f64 / sentences.len() as f64;
            }
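
            // Combine the features into the final term weight (lower is
            // better): stopword-like relatedness and late position increase
            // it, while casing, frequency, and sentence spread decrease it.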
            stats.score = (stats.relatedness * stats.position)
                / (stats.casing + (stats.frequency / stats.relatedness) + (stats.sentences / stats.relatedness));

            features.insert(u_term, stats.into());
        }

        features
    }

    /// Scores every candidate n-gram from the per-term features and context.
    fn candidate_weighting<'s>(&self, features: Features<'s>, ctx: &Contexts<'s>, candidates: &mut Candidates<'s>) {
        for (&lc_terms, candidate) in candidates.iter_mut() {
            let uq_terms = candidate.uq_terms;
            let mut prod_ = 1.0;
            let mut sum_ = 0.0;
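
            // Stopwords inside a candidate are not scored directly; they are
            // weighted by how strongly they glue their neighbors together:
            // the probability of following the previous term times the
            // probability of preceding the next one.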
            for (j, (lc, uq)) in lc_terms.iter().zip(uq_terms).enumerate() {
                if self.is_stopword(lc) {
                    let prob_prev = match j.checked_sub(1).and_then(|i| uq_terms.get(i)) {
                        None => 0.0,
                        Some(prev_uq) => {
                            ctx.cases_term_is_followed(prev_uq, uq) as f64 / features.get(&prev_uq).unwrap().tf
                        }
                    };

                    let prob_succ = match uq_terms.get(j + 1) {
                        None => 0.0,
                        Some(next_uq) => {
                            ctx.cases_term_is_followed(uq, next_uq) as f64 / features.get(&next_uq).unwrap().tf
                        }
                    };

                    let prob = prob_prev * prob_succ;
                    prod_ *= 1.0 + (1.0 - prob);
                    sum_ -= 1.0 - prob;
                } else if let Some(stats) = features.get(uq) {
                    prod_ *= stats.score;
                    sum_ += stats.score;
                }
            }
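
            // Guard the denominator below against the accumulated sum coming
            // out as exactly -1, which would mean division by zero.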
            if sum_ == -1.0 {
                sum_ = 0.999999999;
            }

            let tf = candidate.occurrences as f64;
            candidate.score = prod_ / (tf * (1.0 + sum_));
        }
    }

    fn is_candidate(&self, lc_terms: &[LTerm], tags: &[Tag]) -> bool {
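        // A candidate is rejected if it contains digits, punctuation, or
        // unparsable tokens, ends with a stopword, has fewer than
        // `minimum_chars` characters in total, or (optionally) contains words
        // that are not purely alphanumeric-and-hyphen.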
        let is_bad =
            tags.iter().any(|tag| matches!(tag, Tag::Digit | Tag::Punctuation | Tag::Unparsable))
            || self.is_stopword(lc_terms.last().unwrap())
            || lc_terms.iter().map(|w| w.chars().count()).sum::<usize>() < self.config.minimum_chars
            || self.config.only_alphanumeric_and_hyphen && !lc_terms.iter().all(word_is_alphanumeric_and_hyphen);

        !is_bad
    }

    /// Collects every n-gram of up to `n` words that qualifies as a candidate.
    fn ngram_selection<'s>(&self, n: usize, sentences: &'s Sentences) -> Candidates<'s> {
        let mut candidates = Candidates::new();
        let mut ignored = HashSet::new();
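
        // Slide over each sentence, offering every n-gram that starts on a
        // non-stopword; n-grams that already failed `is_candidate` are cached
        // in `ignored` so each distinct n-gram is validated only once.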
        for sentence in sentences.iter() {
            let length = sentence.words.len();

            for j in 0..length {
                if self.is_stopword(&sentence.lc_terms[j]) {
                    continue;
                }

                for k in (j + 1..length + 1).take(n) {
                    let lc_terms = &sentence.lc_terms[j..k];

                    if let Entry::Vacant(e) = ignored.entry(lc_terms) {
                        if !self.is_candidate(lc_terms, &sentence.tags[j..k]) {
                            e.insert();
                        } else {
                            candidates
                                .entry(lc_terms)
                                .or_insert_with(|| Candidate {
                                    lc_terms,
                                    uq_terms: &sentence.uq_terms[j..k],
                                    raw: &sentence.words[j..k],
                                    ..Default::default()
                                })
                                .occurrences += 1;
                        }
                    }
                }
            }
        }

        candidates
    }
}

fn word_is_alphanumeric_and_hyphen(word: impl AsRef<str>) -> bool {
    word.as_ref().chars().all(|ch| ch.is_alphanumeric() || ch == '-')
}