unobtanium_segmenter/augmentation/
classify.rs

1use unicode_properties::GeneralCategoryGroup;
2use unicode_properties::UnicodeGeneralCategory;
3
4use crate::augmentation::Augmenter;
5use crate::SegmentedToken;
6use crate::SegmentedTokenKind;
7
/// An augmenter that rewrites the [SegmentedToken::kind] field to match reality.
///
/// It does so by reading the token text (preferring the normalized text)
/// and applying heuristics based on the unicode [GeneralCategoryGroup] of the
/// characters it contains.
///
/// The following heuristics are applied in the given order:
///
/// 1. If it contains **Letters** or **Numbers** -> [SegmentedTokenKind::AlphaNumeric]
/// 2. If it contains **Symbols** or **Other** -> [SegmentedTokenKind::Symbol]
/// 3. If it contains **Punctuation** or **Separators** -> [SegmentedTokenKind::Separator]
///
/// Exceptions from usual unicode classification: `\n` and `\0` are separators.
///
/// The **Mark** category is ignored. If none of the heuristics apply the token kind is reset to `None`.
#[derive(Debug, Clone, Default)]
pub struct AugmentationClassify {}
25
26impl AugmentationClassify {
27	/// Create a new classify augmenter with default settings.
28	pub fn new() -> Self {
29		Default::default()
30	}
31}
32
33impl Augmenter for AugmentationClassify {
34	fn augment<'a>(&self, mut token: SegmentedToken<'a>) -> SegmentedToken<'a> {
35		let mut has_seperators = false;
36		let mut has_symbols = false;
37		for c in token.get_text_prefer_normalized().chars() {
38			match c.general_category_group() {
39				GeneralCategoryGroup::Letter | GeneralCategoryGroup::Number => {
40					token.kind = Some(SegmentedTokenKind::AlphaNumeric);
41					return token;
42				}
43				GeneralCategoryGroup::Punctuation | GeneralCategoryGroup::Separator => {
44					has_seperators = true
45				}
46				GeneralCategoryGroup::Symbol | GeneralCategoryGroup::Other => match c {
47					'\n' | '\0' => has_seperators = true,
48					_ => has_symbols = true,
49				},
50				GeneralCategoryGroup::Mark => { /* ignore */ }
51			}
52		}
53		if has_symbols {
54			token.kind = Some(SegmentedTokenKind::Symbol);
55			return token;
56		}
57		if has_seperators {
58			token.kind = Some(SegmentedTokenKind::Separator);
59			return token;
60		}
61		token.kind = None;
62		return token;
63	}
64}
65
#[cfg(test)]
mod test {

	use super::*;

	use crate::chain::ChainAugmenter;
	use crate::chain::ChainSegmenter;
	use crate::chain::StartSegmentationChain;
	use crate::segmentation::UnicodeWordSplitter;

	/// Shorthand for the expected [SegmentedTokenKind::AlphaNumeric] kind.
	fn a() -> Option<SegmentedTokenKind> {
		Some(SegmentedTokenKind::AlphaNumeric)
	}

	/// Shorthand for the expected [SegmentedTokenKind::Separator] kind.
	fn s() -> Option<SegmentedTokenKind> {
		Some(SegmentedTokenKind::Separator)
	}

	/// Shorthand for the expected [SegmentedTokenKind::Symbol] kind.
	fn y() -> Option<SegmentedTokenKind> {
		Some(SegmentedTokenKind::Symbol)
	}

	/// End-to-end check: unicode word segmentation followed by
	/// classification yields the expected (text, kind) pairs.
	#[test]
	fn test_unicode_word_split() {
		let test_text = "The quick (\"brown\") fox🦊 can't jump 32.3 feet, right?\nThe quick (\"brown\")  fox. The value of π in german is '3,141592…'.";

		let word_splitter = UnicodeWordSplitter::new();
		let classifier = AugmentationClassify::new();

		let result: Vec<(&str, Option<SegmentedTokenKind>)> = test_text
			.start_segmentation_chain()
			.chain_segmenter(&word_splitter)
			.chain_augmenter(&classifier)
			.map(|t| (t.text, t.kind))
			.collect();

		let expected_tokens = vec![
			("The", a()),
			(" ", s()),
			("quick", a()),
			(" ", s()),
			("(", s()),
			("\"", s()),
			("brown", a()),
			("\"", s()),
			(")", s()),
			(" ", s()),
			("fox", a()),
			("🦊", y()),
			(" ", s()),
			("can't", a()),
			(" ", s()),
			("jump", a()),
			(" ", s()),
			("32.3", a()),
			(" ", s()),
			("feet", a()),
			(",", s()),
			(" ", s()),
			("right", a()),
			("?", s()),
			("\n", s()),
			("The", a()),
			(" ", s()),
			("quick", a()),
			(" ", s()),
			("(", s()),
			("\"", s()),
			("brown", a()),
			("\"", s()),
			(")", s()),
			("  ", s()),
			("fox", a()),
			(".", s()),
			(" ", s()),
			("The", a()),
			(" ", s()),
			("value", a()),
			(" ", s()),
			("of", a()),
			(" ", s()),
			// Fixed: this entry was mojibake ("Ï€", the UTF-8 bytes of π
			// decoded as Latin-1); the input text contains the letter π.
			("π", a()),
			(" ", s()),
			("in", a()),
			(" ", s()),
			("german", a()),
			(" ", s()),
			("is", a()),
			(" ", s()),
			("'", s()),
			("3,141592", a()),
			("…", s()),
			("'", s()),
			(".", s()),
		];

		assert_eq!(result, expected_tokens);
	}
}