use crate::pinyin::{get_pinyin, has_pinyin, split_pinyin};
use rusqlite::Error;
use rusqlite_ext::{TokenizeReason, Tokenizer};
use sqlite_chinese_stopword::STOPWORD;
use sqlite_english_stemmer::{EN_STEMMER, make_lowercase};
use std::ffi::CStr;
use std::ops::Range;
use unicode_segmentation::UnicodeSegmentation;
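
/// FTS5 tokenizer that splits input into Unicode words, expands single CJK
/// characters into their pinyin readings, lowercases and stems English words,
/// and optionally skips stopwords. Exposed to FTS5 under the name `simple`.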
pub struct SimpleTokenizer {
    enable_pinyin: bool,
    enable_stopword: bool,
}

impl Default for SimpleTokenizer {
    fn default() -> Self {
        Self {
            enable_pinyin: true,
            enable_stopword: true,
        }
    }
}

impl SimpleTokenizer {
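    /// Turns off pinyin expansion for CJK characters.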
    pub fn disable_pinyin(&mut self) {
        self.enable_pinyin = false;
    }
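    /// Turns off stopword filtering; stopwords are indexed like any other word.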
    pub fn disable_stopword(&mut self) {
        self.enable_stopword = false;
    }
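    /// Builds an FTS5 MATCH expression from `text`: each word becomes a
    /// parenthesized group of OR-joined prefix queries, and the groups are
    /// AND-joined. Returns `None` when no query terms could be produced.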
    pub fn tokenize_query(text: &str) -> Option<String> {
        let mut match_sql = "".to_owned();
        for (_, word) in text.unicode_word_indices() {
            if need_pinyin(word) {
                // Single character with pinyin readings: add one OR-group per reading.
                if let Some(ch) = word.chars().next()
                    && let Some(pinyin_vec) = get_pinyin(&ch)
                {
                    for pinyin in pinyin_vec {
                        let sql = Self::split_pinyin_to_sql(&pinyin);
                        Self::append_match_sql(sql, &mut match_sql);
                    }
                }
            } else {
                let sql = Self::split_pinyin_to_sql(word);
                Self::append_match_sql(sql, &mut match_sql);
            }
        }
        if match_sql.is_empty() {
            None
        } else {
            Some(match_sql)
        }
    }

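    /// Appends `sql` to `buf` as a parenthesized group, AND-joined with any
    /// groups already present.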
    fn append_match_sql(sql: String, buf: &mut String) {
        if buf.is_empty() {
            buf.push('(');
        } else {
            buf.push_str(" AND (");
        }
        buf.push_str(&sql);
        buf.push(')');
    }

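    /// Splits `word` into candidate pinyin segments (via `split_pinyin`) and
    /// OR-joins them as prefix queries, one `segment*` term per segment.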
    fn split_pinyin_to_sql(word: &str) -> String {
        split_pinyin(word)
            .into_iter()
            .map(|pinyin| format!("{pinyin}*"))
            .collect::<Vec<_>>()
            .join(" OR ")
    }
}

impl Tokenizer for SimpleTokenizer {
    type Global = ();

    fn name() -> &'static CStr {
        c"simple"
    }

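    /// Builds a tokenizer from the FTS5 argument list (typically the extra
    /// words in the table's `tokenize` option); `disable_pinyin` and
    /// `disable_stopword` are recognized, any other argument is ignored.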
    fn new(_global: &Self::Global, args: Vec<String>) -> Result<Self, Error> {
        let mut tokenizer = Self::default();
        for arg in args {
            match arg.as_str() {
                "disable_pinyin" => {
                    tokenizer.disable_pinyin();
                }
                "disable_stopword" => {
                    tokenizer.disable_stopword();
                }
                _ => {}
            }
        }
        Ok(tokenizer)
    }

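    /// Splits `text` into Unicode words and emits one or more tokens per word:
    /// pinyin readings for single CJK characters, lowercased (and, where
    /// applicable, stemmed) forms for everything else, skipping stopwords when
    /// enabled.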
    fn tokenize<TKF>(
        &mut self,
        _reason: TokenizeReason,
        text: &[u8],
        mut push_token: TKF,
    ) -> Result<(), Error>
    where
        TKF: FnMut(&[u8], Range<usize>, bool) -> Result<(), Error>,
    {
        let text = String::from_utf8_lossy(text);
        let mut word_buf = String::new();
        for (index, word) in text.unicode_word_indices() {
            let range = index..index + word.len();
            if self.enable_pinyin && need_pinyin(word) {
                // Single character with pinyin readings: skip stopwords if
                // enabled, then emit one token per reading, all sharing the
                // same byte range.
                if self.enable_stopword && STOPWORD.contains(word) {
                    continue;
                }
                if let Some(ch) = word.chars().next()
                    && let Some(pinyin_vec) = get_pinyin(&ch)
                {
                    for pinyin in pinyin_vec {
                        push_token(pinyin.as_bytes(), range.clone(), false)?;
                    }
                }
            } else {
                // Any other word: lowercase it, skip stopwords if enabled, and
                // stem it when make_lowercase marks it as stemmable.
                let need_stem = make_lowercase(word, &mut word_buf);
                if self.enable_stopword && STOPWORD.contains(word_buf.as_str()) {
                    continue;
                }
                if need_stem {
                    let stemmed = EN_STEMMER.stem(word_buf.as_str());
                    push_token(stemmed.as_bytes(), range, false)?;
                } else {
                    push_token(word_buf.as_bytes(), range, false)?;
                }
            }
        }
        Ok(())
    }
}
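
/// Returns true when `word` is a single character that has a pinyin reading
/// and should therefore be indexed through its pinyin rather than verbatim.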
fn need_pinyin(word: &str) -> bool {
    if word.is_empty() || word.chars().count() > 1 {
        return false;
    }
    if let Some(ch) = word.chars().next() {
        return has_pinyin(&ch);
    }
    false
}

#[cfg(test)]
mod tests {
    use unicode_segmentation::UnicodeSegmentation;

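    // Sanity-check how unicode_word_indices segments mixed English/Chinese
    // text; the tokenizer relies on these byte offsets for token ranges.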
    #[test]
    fn test_tokenize_by_unicode_word_indices() {
        let text = "The quick (\"brown\") fox can't jump 32.3 feet, right? 我将点燃星海！天上的stars全部都是 eye，不要凝视";
        let uwi1 = text.unicode_word_indices().collect::<Vec<(usize, &str)>>();
        let b: &[_] = &[
            (0, "The"),
            (4, "quick"),
            (12, "brown"),
            (20, "fox"),
            (24, "can't"),
            (30, "jump"),
            (35, "32.3"),
            (40, "feet"),
            (46, "right"),
            (53, "我"),
            (56, "将"),
            (59, "点"),
            (62, "燃"),
            (65, "星"),
            (68, "海"),
            (74, "天"),
            (77, "上"),
            (80, "的"),
            (83, "stars"),
            (88, "全"),
            (91, "部"),
            (94, "都"),
            (97, "是"),
            (101, "eye"),
            (107, "不"),
            (110, "要"),
            (113, "凝"),
            (116, "视"),
        ];
        assert_eq!(&uwi1[..], b);
    }
}