ultra_nlp/_cedarwood/
segment_bidirectional_longest.rs

1use crate::{
2    BehaviorForUnmatched,
3    Match,
4};
5use crate::cedarwood::{
6    segment_backward_longest,
7    segment_forward_longest,
8    BackwardDictionary,
9    ForwardDictionary,
10};
11
12// 待generator稳定, 改为generator, 以便返回Iterator.
13pub fn segment_bidirectional_longest<T: AsRef<str>>(
14    text: T,
15    forward_dict: &ForwardDictionary,
16    backward_dict: &BackwardDictionary,
17    behavior_for_unmatched: BehaviorForUnmatched,
18) -> Vec<Match> {
19    let forward_results = segment_forward_longest(
20        &text,
21        forward_dict,
22        behavior_for_unmatched,
23    );
24    let backward_results = segment_backward_longest(
25        &text,
26        backward_dict,
27        behavior_for_unmatched,
28    );
29
30    if forward_results.len() < backward_results.len() {
31        forward_results
32    } else if forward_results.len() > backward_results.len() {
33        backward_results
34    } else {
35        let forward_single_chars_count = count_single_chars(
36            &forward_results,
37            &text
38        );
39        let backward_single_chars_count = count_single_chars(
40            &backward_results,
41            &text
42        );
43
44        if forward_single_chars_count < backward_single_chars_count {
45            forward_results
46        } else {
47            backward_results
48        }
49    }
50}
51
52fn count_single_chars<T: AsRef<str>>(matches: &Vec<Match>, text: T) -> usize {
53    matches
54        .into_iter()
55        .map(|mat| {
56            if mat.range().extract(text.as_ref())
57                .map(|text| text.chars().count() == 1)
58                .unwrap_or(false) {
59                1
60            } else {
61                0
62            }
63        })
64        .fold(0, |acc, cur| acc + cur)
65}
66
#[cfg(test)]
mod tests {
    use crate::BehaviorForUnmatched;
    use crate::cedarwood::{
        segment_bidirectional_longest,
        BackwardDictionary,
        ForwardDictionary,
    };

    const RAINY_DAY_TEXT: &str = " 当下雨天地面积水, hello world ";
    const RAINY_DAY_PATTERNS: &[&str] = &[
        "当",
        "当下",
        "下雨",
        "下雨天",
        "雨天",
        "地面",
        "积水",
        "你好世界",
    ];

    const GOODS_TEXT: &str = " 商品和服务, hello world ";
    const GOODS_PATTERNS: &[&str] = &["商品", "和服", "服务", "你好世界"];

    /// Builds a forward and a backward dictionary from the same pattern list.
    fn build_dictionaries(patterns: &[&str]) -> (ForwardDictionary, BackwardDictionary) {
        (
            ForwardDictionary::new(patterns.to_vec()).unwrap(),
            BackwardDictionary::new(patterns.to_vec()).unwrap(),
        )
    }

    /// Runs the bidirectional segmenter over `text` and returns the matched
    /// substrings in result order.
    fn segment_to_strs<'t>(
        text: &'t str,
        patterns: &[&str],
        behavior_for_unmatched: BehaviorForUnmatched,
    ) -> Vec<&'t str> {
        let (forward_dict, backward_dict) = build_dictionaries(patterns);

        segment_bidirectional_longest(
            text,
            &forward_dict,
            &backward_dict,
            behavior_for_unmatched,
        )
        .into_iter()
        .map(|mat| mat.range().extract(text).unwrap())
        .collect()
    }

    #[test]
    fn test_should_returns_forward_longest_results() {
        // Forward:  [当下, 雨天, 地面, 积水]
        // Backward: [当, 下雨天, 地面, 积水]
        // Same number of matches; the forward result has fewer single
        // characters, so the forward result is returned.
        let result = segment_to_strs(
            RAINY_DAY_TEXT,
            RAINY_DAY_PATTERNS,
            BehaviorForUnmatched::Ignore,
        );

        assert_eq!(result, vec!["当下", "雨天", "地面", "积水"]);
    }

    #[test]
    fn test_should_returns_backward_longest_results() {
        // Forward:  [商品, 和服]
        // Backward: [商品, 服务]
        // Same number of matches and same number of single characters, so the
        // backward result is returned.
        let result = segment_to_strs(GOODS_TEXT, GOODS_PATTERNS, BehaviorForUnmatched::Ignore);

        assert_eq!(result, vec!["商品", "服务"]);
    }

    #[test]
    fn test_ignore_unmatched() {
        let result = segment_to_strs(GOODS_TEXT, GOODS_PATTERNS, BehaviorForUnmatched::Ignore);

        assert_eq!(result, vec!["商品", "服务"]);
    }

    #[test]
    fn test_keep_unmatched_as_chars() {
        let result = segment_to_strs(
            GOODS_TEXT,
            GOODS_PATTERNS,
            BehaviorForUnmatched::KeepAsChars,
        );

        let expected = vec![
            " ", "商品", "和", "服务", ",", " ",
            "h", "e", "l", "l", "o",
            " ",
            "w", "o", "r", "l", "d",
            " ",
        ];
        assert_eq!(result, expected);
    }

    #[test]
    fn test_keep_unmatched_as_words() {
        let result = segment_to_strs(
            RAINY_DAY_TEXT,
            RAINY_DAY_PATTERNS,
            BehaviorForUnmatched::KeepAsWords,
        );

        assert_eq!(
            result,
            vec![" ", "当下", "雨天", "地面", "积水", ", hello world "]
        );
    }

    #[test]
    fn test_value() {
        // Forward:  [当下, 雨天, 地面, 积水]
        // Backward: [当, 下雨天, 地面, 积水]
        // The forward result wins (fewer single characters); its matches come
        // from patterns 1, 4, 5 and 6.
        let (forward_dict, backward_dict) = build_dictionaries(RAINY_DAY_PATTERNS);

        let indexes = segment_bidirectional_longest(
            RAINY_DAY_TEXT,
            &forward_dict,
            &backward_dict,
            BehaviorForUnmatched::Ignore,
        )
        .into_iter()
        .map(|mat| mat.index_of_patterns().unwrap())
        .collect::<Vec<_>>();

        assert_eq!(indexes, vec![1, 4, 5, 6]);
    }

    #[test]
    fn test_chars_on_edge() {
        let result = segment_to_strs("你好世界", &["你好", "世界"], BehaviorForUnmatched::Ignore);

        assert_eq!(result, vec!["你好", "世界"]);
    }
}