// unobtanium_segmenter/normalization/rust_stemmers.rs

1use rust_stemmers::Algorithm;
2use rust_stemmers::Stemmer;
3use whatlang::Lang;
4
5use crate::augmentation::Augmenter;
6use crate::SegmentedToken;
7
8/// Will run stemming with the language tagged onto the token if an algorithm is available.
9///
10/// This uses the [rust_stemmers] crate under the hood.
11///
12/// This is recommended to be run after an [AugmentationDetectLanguage][crate::augmentation::AugmentationDetectLanguage] has been used, it will not do anything if no language metadata is available!
13///
14/// This will ignore tokens that already have their `normalized_text` set. Apply things like lowercasing after this.
#[derive(Debug, Clone)]
pub struct NormalizationRustStemmers {
	/// Threshold above which the flag about the language detection flagging itself as reliable is ignored and the detected language used for normalization anyway.
	/// Setting this can help with shorter texts.
	///
	/// 1.0 translates to never ignoring the flag.
	/// 0.0 would mean to always ignore it.
	///
	/// Default is 0.4 as that is usually "good enough" for correct stemming.
	pub anyway_above_confidence: f64,
}
26
27impl NormalizationRustStemmers {
28	/// Create a new NormalizationRustStemmers instance with the default settings.
29	pub fn new() -> Self {
30		Default::default()
31	}
32
33	/// Adjust the value of [anyway_above_confidence][Self::anyway_above_confidence] builder style.
34	pub fn set_anyway_above_confidence(mut self, anyway_above_confidence: f64) -> Self {
35		self.anyway_above_confidence = anyway_above_confidence;
36		return self;
37	}
38}
39
40impl Default for NormalizationRustStemmers {
41	fn default() -> Self {
42		Self {
43			anyway_above_confidence: 0.4,
44		}
45	}
46}
47
48impl Augmenter for NormalizationRustStemmers {
49	fn augment<'a>(&self, mut token: SegmentedToken<'a>) -> SegmentedToken<'a> {
50		if (token.is_detected_language_relible
51			|| token.detected_language_confidence > self.anyway_above_confidence)
52			&& token.normalized_text.is_none()
53		{
54			if let Some(language) = token.detected_language {
55				if let Some(algorithm) = get_stemming_algorithm_for_lang(language) {
56					let stemmer = Stemmer::create(algorithm);
57					let stemmed = stemmer.stem(token.get_text_prefer_normalized());
58					if stemmed != token.text {
59						token.normalized_text = Some(stemmed.to_string());
60					}
61				}
62			}
63		}
64		return token;
65	}
66}
67
68/// Map Whatlang languages to Implemented normalization algorithms
69fn get_stemming_algorithm_for_lang(lang: Lang) -> Option<Algorithm> {
70	Some(match lang {
71		Lang::Ara => Algorithm::Arabic,
72		Lang::Dan => Algorithm::Danish,
73		Lang::Nld => Algorithm::Dutch,
74		Lang::Eng => Algorithm::English,
75		Lang::Fin => Algorithm::Finnish,
76		Lang::Fra => Algorithm::French,
77		Lang::Deu => Algorithm::German,
78		Lang::Ell => Algorithm::Greek,
79		Lang::Hun => Algorithm::Hungarian,
80		Lang::Ita => Algorithm::Italian,
81		// Missing: Norwegian, whatlang can't detect it
82		Lang::Por => Algorithm::Portuguese,
83		Lang::Ron => Algorithm::Romanian,
84		Lang::Rus => Algorithm::Russian,
85		Lang::Spa => Algorithm::Spanish,
86		Lang::Swe => Algorithm::Swedish,
87		Lang::Tam => Algorithm::Tamil,
88		Lang::Tur => Algorithm::Turkish,
89		_ => {
90			return None;
91		}
92	})
93}
94
#[cfg(test)]
mod test {

	use super::*;

	use crate::chain::*;

	use crate::augmentation::AugmentationDetectLanguage;
	use crate::segmentation::UnicodeSentenceSplitter;
	use crate::segmentation::UnicodeWordSplitter;

	/// End-to-end check: sentence split -> language detect -> word split -> stem.
	#[test]
	fn test_stemmed_unicode_word_split() {
		let test_text = "Fischers Fritze fischt frische Fische! The jumping brown fox quickly jumps over the sleeping dog.";

		let sentence_splitter = UnicodeSentenceSplitter::new();
		let language_detector = AugmentationDetectLanguage::new();
		let word_splitter = UnicodeWordSplitter::new();
		// Low threshold so the short German sentence is stemmed despite low detector confidence.
		let stemmer = NormalizationRustStemmers::new().set_anyway_above_confidence(0.1);

		let result: Vec<String> = test_text
			.start_segmentation_chain()
			.chain_segmenter(&sentence_splitter)
			.chain_augmenter(&language_detector)
			.inspect(|sentence| println!("{sentence:?}"))
			.chain_segmenter(&word_splitter)
			.inspect(|word| println!("word: {word:?}"))
			.chain_augmenter(&stemmer)
			.map(|token| token.get_text_prefer_normalized_owned())
			.collect();

		let expected_tokens: Vec<String> = [
			"Fisch", " ", "Fritz", " ", "fischt", " ", "frisch", " ", "Fisch", "!", " ", "The",
			" ", "jump", " ", "brown", " ", "fox", " ", "quick", " ", "jump", " ", "over", " ",
			"the", " ", "sleep", " ", "dog", ".",
		]
		.iter()
		.map(|s| s.to_string())
		.collect();

		assert_eq!(result, expected_tokens);
	}
}