unobtanium_segmenter/segmentation/
lingua.rs1use std::vec::IntoIter;
6
7use lingua::LanguageDetector;
8use lingua::LanguageDetectorBuilder;
9
10use crate::SegmentedToken;
11use crate::UseOrSubdivide;
12use crate::language_detection::DetectionIterator;
13use crate::language_detection::detect_top_n_languages;
14use crate::language_detection::lingua_language_to_whatlang_language;
15use crate::segmentation::Segmenter;
16
17pub struct LinguaLanguageBlockSplitter {
27 language_detector: LanguageDetector,
28}
29
30impl Default for LinguaLanguageBlockSplitter {
31 fn default() -> Self {
32 Self {
33 language_detector: LanguageDetectorBuilder::from_all_languages()
34 .with_preloaded_language_models()
35 .build(),
36 }
37 }
38}
39
40impl LinguaLanguageBlockSplitter {
41 pub fn new() -> Self {
43 Default::default()
44 }
45
46 pub fn new_with_builder(mut builder: LanguageDetectorBuilder) -> Self {
48 Self {
49 language_detector: builder.build(),
50 }
51 }
52}
53
54impl Segmenter for LinguaLanguageBlockSplitter {
55 type SubdivisionIter<'a> = IntoIter<SegmentedToken<'a>>;
56
57 fn subdivide<'a>(
58 &self,
59 token: SegmentedToken<'a>,
60 ) -> UseOrSubdivide<SegmentedToken<'a>, IntoIter<SegmentedToken<'a>>> {
61 let languages = detect_top_n_languages(&self.language_detector, 3, token.text);
62
63 if languages.is_empty() {
64 return UseOrSubdivide::Use(token);
66 }
67
68 let detection_iterator = DetectionIterator::detect(
69 &LanguageDetectorBuilder::from_languages(&languages).build(),
71 token.text,
72 );
73
74 let result_list = self
75 .language_detector
76 .detect_multiple_languages_of(token.text);
77 let mut collection: Vec<SegmentedToken<'_>> = Vec::with_capacity(result_list.len() * 2 + 1);
78
79 let mut last_offset = 0;
80
81 for (start_index, end_index, language) in detection_iterator {
82 if last_offset != start_index {
84 collection.push(SegmentedToken::new_derived_from(
85 &token.text[last_offset..start_index],
86 &token,
87 ));
88 }
89 let mut new_token =
90 SegmentedToken::new_derived_from(&token.text[start_index..end_index], &token);
91 new_token.detected_language = language.and_then(lingua_language_to_whatlang_language);
92 new_token.is_detected_language_relible = true;
93
94 collection.push(new_token);
95
96 last_offset = end_index;
97 }
98
99 if last_offset != token.text.len() {
100 collection.push(SegmentedToken::new_derived_from(
101 &token.text[last_offset..],
102 &token,
103 ));
104 }
105
106 UseOrSubdivide::Subdivide(collection.into_iter())
107 }
108}
109
#[cfg(test)]
mod test {
    use std::time::Instant;

    use super::*;

    use crate::chain::ChainSegmenter;
    use crate::chain::StartSegmentationChain;

    use whatlang::Lang;

    /// A French, a German and an English sentence must come back as three
    /// tokens tagged with their respective languages, on every one of 100 runs.
    #[test]
    fn test_lingua_multilanguage_detection() {
        let test_text = "Parlez-vous français? \
            Ich spreche Französisch nur ein bisschen. \
            A little bit is better than nothing. ";

        let expected_tokens = vec![
            ("Parlez-vous français? ", Some(Lang::Fra)),
            (
                "Ich spreche Französisch nur ein bisschen. ",
                Some(Lang::Deu),
            ),
            ("A little bit is better than nothing. ", Some(Lang::Eng)),
        ];

        let splitter = LinguaLanguageBlockSplitter::new();

        // Repeated runs with the same splitter must keep giving the same split.
        for _ in 0..100 {
            let segments: Vec<(&str, Option<Lang>)> = test_text
                .start_segmentation_chain()
                .chain_segmenter(&splitter)
                .map(|segment| (segment.text, segment.detected_language))
                .collect();

            assert_eq!(segments, expected_tokens);
        }
    }

    /// The first segmentation is expected to be slower than the average of the
    /// following ones, even though a fresh splitter is constructed every time
    /// (presumably because the language models are only loaded once).
    #[test]
    fn test_lingua_performance() {
        let test_text = "Parlez-vous français? \
            Ich spreche Französisch nur ein bisschen. \
            A little bit is better than nothing.";

        let cold_start = Instant::now();
        let _segments: Vec<(&str, Option<Lang>)> = test_text
            .start_segmentation_chain()
            .chain_segmenter(&LinguaLanguageBlockSplitter::new())
            .map(|segment| (segment.text, segment.detected_language))
            .collect();
        let time_first_iteration = cold_start.elapsed();

        let warm_start = Instant::now();
        for _ in 0..100 {
            // A new splitter per run on purpose: the struct is not kept around.
            let _segments: Vec<(&str, Option<Lang>)> = test_text
                .start_segmentation_chain()
                .chain_segmenter(&LinguaLanguageBlockSplitter::new())
                .map(|segment| (segment.text, segment.detected_language))
                .collect();
        }
        let average_warm_iteration = warm_start.elapsed() / 100;

        assert!(
            time_first_iteration > average_warm_iteration,
            "Subsequent iterations should be faster than the initial one, even when not keeping the struct around. {time_first_iteration:?} > {:?}",
            average_warm_iteration
        );
    }
}