unobtanium_segmenter/augmentation/
classify.rs

1// SPDX-FileCopyrightText: 2026 Slatian
2//
3// SPDX-License-Identifier: LGPL-3.0-only
4
5use unicode_properties::GeneralCategoryGroup;
6use unicode_properties::UnicodeGeneralCategory;
7
8use crate::SegmentedToken;
9use crate::SegmentedTokenKind;
10use crate::augmentation::Augmenter;
11
/// An augmenter that rewrites the [SegmentedToken::kind] field to match reality.
///
/// It does so by reading the token text (preferring the normalized text)
/// and applying heuristics based on the unicode [GeneralCategoryGroup] of the
/// characters it contains.
///
/// The following heuristics are applied in the given order:
///
/// 1. If it contains **Letters** or **Numbers** -> [SegmentedTokenKind::AlphaNumeric]
/// 2. If it contains **Symbols** or **Other** -> [SegmentedTokenKind::Symbol]
/// 3. If it contains **Punctuation** or **Separators** -> [SegmentedTokenKind::Separator]
///
/// Exceptions from usual unicode classification: `\n` and `\0` are separators.
///
/// The **Mark** category is ignored. If none of the heuristics apply the token kind is reset to `None`.
#[derive(Debug, Clone, Default)]
pub struct AugmentationClassify {}
29
30impl AugmentationClassify {
31	/// Create a new classify augmenter with default settings.
32	pub fn new() -> Self {
33		Default::default()
34	}
35}
36
37impl Augmenter for AugmentationClassify {
38	fn augment<'a>(&self, mut token: SegmentedToken<'a>) -> SegmentedToken<'a> {
39		let mut has_seperators = false;
40		let mut has_symbols = false;
41		for c in token.get_text_prefer_normalized().chars() {
42			match c.general_category_group() {
43				GeneralCategoryGroup::Letter | GeneralCategoryGroup::Number => {
44					token.kind = Some(SegmentedTokenKind::AlphaNumeric);
45					return token;
46				}
47				GeneralCategoryGroup::Punctuation | GeneralCategoryGroup::Separator => {
48					has_seperators = true
49				}
50				GeneralCategoryGroup::Symbol | GeneralCategoryGroup::Other => match c {
51					'\n' | '\0' => has_seperators = true,
52					_ => has_symbols = true,
53				},
54				GeneralCategoryGroup::Mark => { /* ignore */ }
55			}
56		}
57		if has_symbols {
58			token.kind = Some(SegmentedTokenKind::Symbol);
59			return token;
60		}
61		if has_seperators {
62			token.kind = Some(SegmentedTokenKind::Separator);
63			return token;
64		}
65		token.kind = None;
66		return token;
67	}
68}
69
#[cfg(test)]
mod test {

	use super::*;

	use crate::chain::ChainAugmenter;
	use crate::chain::ChainSegmenter;
	use crate::chain::StartSegmentationChain;
	use crate::segmentation::UnicodeWordSplitter;

	/// Shorthand for the expected kind of an alphanumeric token.
	fn a() -> Option<SegmentedTokenKind> {
		Some(SegmentedTokenKind::AlphaNumeric)
	}

	/// Shorthand for the expected kind of a separator token.
	fn s() -> Option<SegmentedTokenKind> {
		Some(SegmentedTokenKind::Separator)
	}

	/// Shorthand for the expected kind of a symbol token.
	fn y() -> Option<SegmentedTokenKind> {
		Some(SegmentedTokenKind::Symbol)
	}

	/// End-to-end check: split a mixed-content sentence into words and
	/// verify the kind the classifier assigns to every resulting token.
	#[test]
	fn test_unicode_word_split() {
		let test_text = "The quick (\"brown\") fox🦊 can't jump 32.3 feet, right?\nThe quick (\"brown\")  fox. The value of π in german is '3,141592…'.";

		let word_splitter = UnicodeWordSplitter::new();
		let classifier = AugmentationClassify::new();

		let result: Vec<(&str, Option<SegmentedTokenKind>)> = test_text
			.start_segmentation_chain()
			.chain_segmenter(&word_splitter)
			.chain_augmenter(&classifier)
			.map(|t| (t.text, t.kind))
			.collect();

		let expected_tokens = vec![
			("The", a()),
			(" ", s()),
			("quick", a()),
			(" ", s()),
			("(", s()),
			("\"", s()),
			("brown", a()),
			("\"", s()),
			(")", s()),
			(" ", s()),
			("fox", a()),
			("🦊", y()),
			(" ", s()),
			("can't", a()),
			(" ", s()),
			("jump", a()),
			(" ", s()),
			("32.3", a()),
			(" ", s()),
			("feet", a()),
			(",", s()),
			(" ", s()),
			("right", a()),
			("?", s()),
			("\n", s()),
			("The", a()),
			(" ", s()),
			("quick", a()),
			(" ", s()),
			("(", s()),
			("\"", s()),
			("brown", a()),
			("\"", s()),
			(")", s()),
			("  ", s()),
			("fox", a()),
			(".", s()),
			(" ", s()),
			("The", a()),
			(" ", s()),
			("value", a()),
			(" ", s()),
			("of", a()),
			(" ", s()),
			// Fixed: this entry was the mojibake "Ï€" (UTF-8 bytes of π read
			// as Latin-1); the input text contains the letter π, which the
			// classifier marks AlphaNumeric.
			("π", a()),
			(" ", s()),
			("in", a()),
			(" ", s()),
			("german", a()),
			(" ", s()),
			("is", a()),
			(" ", s()),
			("'", s()),
			("3,141592", a()),
			("…", s()),
			("'", s()),
			(".", s()),
		];

		assert_eq!(result, expected_tokens);
	}
}