ultra_nlp/_hashmap/
segment_bidirectional_longest.rs

1use crate::{
2    BehaviorForUnmatched,
3    Match,
4};
5use crate::hashmap::{
6    segment_backward_longest,
7    segment_forward_longest,
8    Dictionary,
9};
10
11// 待generator稳定, 改为generator, 以便返回Iterator.
12pub fn segment_bidirectional_longest<T: AsRef<str>>(
13    text: T,
14    dict: &Dictionary,
15    behavior_for_unmatched: BehaviorForUnmatched,
16) -> Vec<Match> {
17    let forward_results = segment_forward_longest(
18        &text,
19        dict,
20        behavior_for_unmatched,
21    );
22    let backward_results = segment_backward_longest(
23        &text,
24        dict,
25        behavior_for_unmatched,
26    );
27
28    if forward_results.len() < backward_results.len() {
29        forward_results
30    } else if forward_results.len() > backward_results.len() {
31        backward_results
32    } else {
33        let forward_single_chars_count = count_single_chars(
34            &forward_results,
35            &text,
36        );
37        let backward_single_chars_count = count_single_chars(
38            &backward_results,
39            &text,
40        );
41
42        if forward_single_chars_count < backward_single_chars_count {
43            forward_results
44        } else {
45            backward_results
46        }
47    }
48}
49
50fn count_single_chars<T: AsRef<str>>(matches: &Vec<Match>, text: T) -> usize {
51    matches
52        .into_iter()
53        .map(|mat| {
54            if mat.range().extract(text.as_ref())
55                .map(|text| text.chars().count() == 1)
56                .unwrap_or(false) {
57                1
58            } else {
59                0
60            }
61        })
62        .fold(0, |acc, cur| acc + cur)
63}
64
#[cfg(test)]
mod tests {
    use crate::{BehaviorForUnmatched, Match};
    use crate::hashmap::{
        segment_bidirectional_longest,
        Dictionary,
    };

    /// Returns the slice of `text` covered by each match, preserving order.
    ///
    /// Panics if a match's range cannot be extracted from `text`, which would
    /// indicate a bug in the segmenter under test.
    fn extract_words<'a>(text: &'a str, matches: Vec<Match>) -> Vec<&'a str> {
        matches
            .into_iter()
            .map(|x| x.range().extract(text).unwrap())
            .collect()
    }

    #[test]
    fn test_should_returns_forward_longest_results() {
        let text = " 当下雨天地面积水, hello world ";
        let patterns = vec![
            "当",
            "当下",
            "下雨",
            "下雨天",
            "雨天",
            "地面",
            "积水",
            "你好世界",
        ];
        let dict = Dictionary::new(patterns).unwrap();

        // Forward result:  [当下, 雨天, 地面, 积水]
        // Backward result: [当, 下雨天, 地面, 积水]
        // Both have the same number of words, but the forward result has fewer
        // single-character words, so the forward result is returned.
        let result = segment_bidirectional_longest(
            text,
            &dict,
            BehaviorForUnmatched::Ignore,
        );

        assert_eq!(
            extract_words(text, result),
            vec!["当下", "雨天", "地面", "积水"]
        )
    }

    #[test]
    fn test_should_returns_backward_longest_results() {
        let text = " 商品和服务, hello world ";
        let patterns = vec!["商品", "和服", "服务", "你好世界"];
        let dict = Dictionary::new(patterns).unwrap();

        // Forward result:  [商品, 和服]
        // Backward result: [商品, 服务]
        // Same number of words and same number of single-character words,
        // so the backward result is returned.
        let result = segment_bidirectional_longest(
            text,
            &dict,
            BehaviorForUnmatched::Ignore,
        );

        assert_eq!(
            extract_words(text, result),
            vec!["商品", "服务"]
        )
    }

    #[test]
    fn test_ignore_unmatched() {
        let text = " 商品和服务, hello world ";
        let patterns = vec!["商品", "和服", "服务", "你好世界"];
        let dict = Dictionary::new(patterns).unwrap();

        let result = segment_bidirectional_longest(
            text,
            &dict,
            BehaviorForUnmatched::Ignore,
        );

        // Unmatched text (spaces, punctuation, ASCII) is dropped entirely.
        assert_eq!(
            extract_words(text, result),
            vec!["商品", "服务"]
        )
    }

    #[test]
    fn test_keep_unmatched_as_chars() {
        let text = " 商品和服务, hello world ";
        let patterns = vec!["商品", "和服", "服务", "你好世界"];
        let dict = Dictionary::new(patterns).unwrap();

        let result = segment_bidirectional_longest(
            text,
            &dict,
            BehaviorForUnmatched::KeepAsChars,
        );

        assert_eq!(
            extract_words(text, result),
            vec![
                " ",
                "商品",
                "和",
                "服务",
                ",",
                " ",
                "h",
                "e",
                "l",
                "l",
                "o",
                " ",
                "w",
                "o",
                "r",
                "l",
                "d",
                " ",
            ]
        )
    }

    #[test]
    fn test_keep_unmatched_as_words() {
        let text = " 当下雨天地面积水, hello world ";
        let patterns = vec![
            "当",
            "当下",
            "下雨",
            "下雨天",
            "雨天",
            "地面",
            "积水",
            "你好世界",
        ];
        let dict = Dictionary::new(patterns).unwrap();

        let result = segment_bidirectional_longest(
            text,
            &dict,
            BehaviorForUnmatched::KeepAsWords,
        );

        assert_eq!(
            extract_words(text, result),
            vec![" ", "当下", "雨天", "地面", "积水", ", hello world "]
        )
    }

    #[test]
    fn test_value() {
        let text = " 当下雨天地面积水, hello world ";
        let patterns: Vec<&str> = vec![
            "当",
            "当下",
            "下雨",
            "下雨天",
            "雨天",
            "地面",
            "积水",
            "你好世界",
        ];
        let dict = Dictionary::new(patterns).unwrap();

        // Forward result:  [当下, 雨天, 地面, 积水]
        // Backward result: [当, 下雨天, 地面, 积水]
        // Both have the same number of words, but the forward result has fewer
        // single-character words, so the forward result is returned.
        let result = segment_bidirectional_longest(
            text,
            &dict,
            BehaviorForUnmatched::Ignore,
        );

        // Indices refer to positions in `patterns`: 当下=1, 雨天=4, 地面=5, 积水=6.
        assert_eq!(
            result
                .into_iter()
                .map(|x| x.index_of_patterns().unwrap())
                .collect::<Vec<_>>(),
            vec![1, 4, 5, 6]
        )
    }

    #[test]
    fn test_chars_on_edge() {
        let text = "你好世界";
        let patterns = vec!["你好", "世界"];
        let dict = Dictionary::new(patterns).unwrap();

        // Matches that touch the very start and end of the text must be kept.
        let result = segment_bidirectional_longest(
            text,
            &dict,
            BehaviorForUnmatched::Ignore
        );

        assert_eq!(
            extract_words(text, result),
            vec!["你好", "世界"]
        );
    }
}