sqlite_simple_tokenizer/simple_tokenizer.rs

use crate::pinyin::{get_pinyin, has_pinyin, split_pinyin};
use rusqlite::Error;
use rusqlite_ext::{TokenizeReason, Tokenizer};
use sqlite_chinese_stopword::STOPWORD;
use sqlite_english_stemmer::{EN_STEMMER, make_lowercase};
use std::ffi::CStr;
use std::ops::Range;
use unicode_segmentation::UnicodeSegmentation;

/// A tokenizer for Chinese text and pinyin
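///
/// A minimal configuration sketch (the import path is assumed; adjust it to
/// wherever this crate actually exposes the type):
///
/// ```ignore
/// let mut tokenizer = SimpleTokenizer::default(); // pinyin and stopwords enabled
/// tokenizer.disable_pinyin();                     // opt out of pinyin expansion
/// ```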
pub struct SimpleTokenizer {
    /// Whether pinyin support is enabled; defaults to enabled
    enable_pinyin: bool,
    /// Whether the stopword list is enabled; defaults to enabled
    enable_stopword: bool,
}

impl Default for SimpleTokenizer {
    fn default() -> Self {
        Self {
            enable_pinyin: true,
            enable_stopword: true,
        }
    }
}

impl SimpleTokenizer {
    /// Turn off pinyin tokenization
    pub fn disable_pinyin(&mut self) {
        self.enable_pinyin = false;
    }
    /// Turn off the stopword list
    pub fn disable_stopword(&mut self) {
        self.enable_stopword = false;
    }
    /// Convert a query string into an SQLite FTS5 MATCH expression
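    ///
    /// A hedged sketch of the output shape; the exact pinyin expansions come
    /// from `crate::pinyin`, so the expression shown is illustrative only:
    ///
    /// ```ignore
    /// // "中文 rust" might yield something like:
    /// // (zhong*) AND (wen*) AND (rust*)
    /// let match_expr = SimpleTokenizer::tokenize_query("中文 rust");
    /// ```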
    pub fn tokenize_query(text: &str) -> Option<String> {
        let mut match_sql = "".to_owned();
        for (_, word) in text.unicode_word_indices() {
            // Check whether this is a single character that needs pinyin handling
            if need_pinyin(word) {
                if let Some(ch) = word.chars().next()
                    && let Some(pinyin_vec) = get_pinyin(&ch)
                {
                    for pinyin in pinyin_vec {
                        let sql = Self::split_pinyin_to_sql(&pinyin);
                        Self::append_match_sql(sql, &mut match_sql);
                    }
                }
            } else {
                let sql = Self::split_pinyin_to_sql(word);
                Self::append_match_sql(sql, &mut match_sql);
            }
        }
        // An empty MATCH expression is not valid FTS5 syntax, so signal "no query"
        if match_sql.is_empty() {
            return None;
        }
        Some(match_sql)
    }
57
58    fn append_match_sql(sql: String, buf: &mut String) {
59        if buf.is_empty() {
60            buf.push('(');
61        } else {
62            buf.push_str(" AND (");
63        }
64        buf.push_str(&sql);
65        buf.push(')');
66    }

    /// Split `word` into pinyin fragments and join them with OR as prefix
    /// queries (each fragment suffixed with `*`)
    fn split_pinyin_to_sql(word: &str) -> String {
        let pinyin_set = split_pinyin(word);
        pinyin_set
            .into_iter()
            .fold(String::new(), |mut acc, pinyin| {
                if !acc.is_empty() {
                    acc.push_str(" OR ");
                }
                acc.push_str(&pinyin);
                acc.push('*');
                acc
            })
    }
}

impl Tokenizer for SimpleTokenizer {
    type Global = ();

    fn name() -> &'static CStr {
        c"simple"
    }

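    /// Build a tokenizer from the arguments of the FTS5 `tokenize` option.
    ///
    /// A hedged sketch of how options are expected to arrive (assuming
    /// `rusqlite_ext` forwards the space-separated FTS5 tokenizer arguments
    /// unchanged; the table definition is hypothetical):
    ///
    /// ```sql
    /// -- yields args = ["disable_pinyin"]
    /// CREATE VIRTUAL TABLE notes USING fts5(body, tokenize = 'simple disable_pinyin');
    /// ```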
    fn new(_global: &Self::Global, args: Vec<String>) -> Result<Self, Error> {
        let mut tokenizer = Self::default();
        for arg in args {
            match arg.as_str() {
                "disable_pinyin" => {
                    tokenizer.disable_pinyin();
                }
                "disable_stopword" => {
                    tokenizer.disable_stopword();
                }
                _ => {}
            }
        }
        Ok(tokenizer)
    }
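    /// Tokenize `text` (UTF-8 bytes) and emit each token through `push_token`
    /// along with its byte range and a colocated flag.
    ///
    /// A hedged calling sketch; `REASON` is a placeholder for whichever
    /// `TokenizeReason` variant the host passes, since its variants are
    /// defined by `rusqlite_ext`:
    ///
    /// ```ignore
    /// let mut tokens: Vec<String> = Vec::new();
    /// tokenizer.tokenize(REASON, "hello 世界".as_bytes(), |tok, _range, _colocated| {
    ///     tokens.push(String::from_utf8_lossy(tok).into_owned());
    ///     Ok(())
    /// })?;
    /// ```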
    fn tokenize<TKF>(
        &mut self,
        _reason: TokenizeReason,
        text: &[u8],
        mut push_token: TKF,
    ) -> Result<(), Error>
    where
        TKF: FnMut(&[u8], Range<usize>, bool) -> Result<(), Error>,
    {
        let text = String::from_utf8_lossy(text);
        // Tokenize with unicode_word_indices; every Chinese character should
        // come out as its own single-character word
        let mut word_buf = String::new();
        for (index, word) in text.unicode_word_indices() {
            let range = index..index + word.len();
            // Pinyin is enabled and this is a Chinese character
            if self.enable_pinyin && need_pinyin(word) {
                if self.enable_stopword && STOPWORD.contains(word) {
                    // Skip stopwords
                    continue;
                }
                if let Some(ch) = word.chars().next()
                    && let Some(pinyin_vec) = get_pinyin(&ch)
                {
                    for pinyin in pinyin_vec {
                        push_token(pinyin.as_bytes(), range.clone(), false)?;
                    }
                }
            } else {
                // No pinyin handling needed: normalize the word and lowercase it
                let need_stem = make_lowercase(word, &mut word_buf);
                if self.enable_stopword && STOPWORD.contains(word_buf.as_str()) {
                    // Skip stopwords
                    continue;
                }
                if need_stem {
                    let stemmed = EN_STEMMER.stem(word_buf.as_str()).into_owned();
                    push_token(stemmed.as_bytes(), range, false)?;
                } else {
                    push_token(word_buf.as_bytes(), range, false)?;
                }
            }
        }
        Ok(())
    }
}

/// Decide whether this word should go through the pinyin module
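///
/// Illustrative expectations (the single-character case depends on
/// `has_pinyin` from `crate::pinyin`):
///
/// ```ignore
/// assert!(need_pinyin("中"));    // single character with a pinyin reading
/// assert!(!need_pinyin("rust")); // multi-character words are skipped
/// assert!(!need_pinyin(""));     // empty input is skipped
/// ```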
fn need_pinyin(word: &str) -> bool {
    if word.is_empty() || word.chars().count() > 1 {
        // Empty strings and words of more than one character skip pinyin handling
        return false;
    }
    if let Some(ch) = word.chars().next() {
        return has_pinyin(&ch);
    }
    false
}

#[cfg(test)]
mod tests {
    use unicode_segmentation::UnicodeSegmentation;

    #[test]
    fn test_tokenize_by_unicode_word_indices() {
        let text = "The quick (\"brown\") fox can't jump 32.3 feet, right? 我将点燃星海!天上的stars全部都是 eye,不要凝视";
        let uwi1 = text.unicode_word_indices().collect::<Vec<(usize, &str)>>();
        let b: &[_] = &[
            (0, "The"),
            (4, "quick"),
            (12, "brown"),
            (20, "fox"),
            (24, "can't"),
            (30, "jump"),
            (35, "32.3"),
            (40, "feet"),
            (46, "right"),
            (53, "我"),
            (56, "将"),
            (59, "点"),
            (62, "燃"),
            (65, "星"),
            (68, "海"),
            (74, "天"),
            (77, "上"),
            (80, "的"),
            (83, "stars"),
            (88, "全"),
            (91, "部"),
            (94, "都"),
            (97, "是"),
            (101, "eye"),
            (107, "不"),
            (110, "要"),
            (113, "凝"),
            (116, "视"),
        ];
        assert_eq!(&uwi1[..], b);
    }
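
    /// A minimal sketch exercising `need_pinyin`; only cases fully determined
    /// by this file are asserted, since single-character behavior depends on
    /// `crate::pinyin::has_pinyin`
    #[test]
    fn test_need_pinyin_rejects_empty_and_multichar() {
        assert!(!super::need_pinyin(""));     // empty input never needs pinyin
        assert!(!super::need_pinyin("rust")); // more than one character is skipped
        assert!(!super::need_pinyin("星海")); // two Chinese characters, also skipped
    }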
}