sqlite_simple_tokenizer/
lib.rs

1#[cfg(feature = "build_extension")]
2mod create_extension;
3mod load_extension;
4mod pinyin;
5mod tokenizer;
6mod utils;
7
8include!(concat!(env!("OUT_DIR"), "/stopword_data.rs"));
9
10use load_extension::create_scalar_functions;
11use load_extension::load_fts5_extension;
12use log::LevelFilter;
13use rusqlite::Connection;
14use std::error::Error;
15use utils::init_logging;
16
17pub fn load(connection: &Connection) -> Result<(), Box<dyn Error>> {
18    load_with_loglevel(connection, LevelFilter::Info)
19}
20
21pub fn load_with_loglevel(
22    connection: &Connection,
23    log_level: LevelFilter,
24) -> Result<(), Box<dyn Error>> {
25    // 设置 log
26    init_logging(log_level);
27    // 加载拓展函数
28    create_scalar_functions(connection)?;
29    // 加载 fts5 拓展
30    load_fts5_extension(connection)
31}
32
#[cfg(test)]
mod tests {
    use crate::load;
    use rusqlite::Connection;

    /// Opens an in-memory database and loads the extension into it.
    fn open_loaded_connection() -> Connection {
        let conn = Connection::open_in_memory().expect("failed to open in-memory database");
        load(&conn).expect("failed to load the extension");
        conn
    }

    /// Runs `sql` without parameters and returns column 0 of every row as a `String`.
    ///
    /// Panics (failing the test) if the statement cannot be prepared or executed, or if
    /// any row's first column cannot be read as a `String`.
    fn query_strings(conn: &Connection, sql: &str) -> Vec<String> {
        let mut stmt = conn.prepare(sql).expect("failed to prepare statement");
        // Return the `Result` from `row.get` instead of unwrapping inside the closure;
        // row errors then surface once, when the rows are collected below.
        let rows = stmt
            .query_map([], |row| row.get::<_, String>(0))
            .expect("failed to run query");
        rows.collect::<Result<Vec<_>, _>>()
            .expect("failed to read a row")
    }

    /// `simple_query` is callable as a scalar function after `load` and expands a
    /// single Chinese character into the expected pinyin match expression.
    #[test]
    fn test_simple_query() {
        let conn = open_loaded_connection();
        let rows = query_strings(&conn, "SELECT simple_query('国')");
        assert_eq!(["(g+u+o* OR gu+o* OR guo*)"], rows.as_slice());
    }

    /// An FTS5 table using the `simple` tokenizer can be created, populated and
    /// queried through `simple_query` after `load`.
    #[test]
    fn test_load() {
        let conn = open_loaded_connection();
        // Create a test table.
        conn.execute(
            "CREATE VIRTUAL TABLE t1 USING fts5(text, tokenize = 'simple');",
            [],
        )
        .unwrap();
        // Insert the sample rows.
        conn.execute(
            r#"INSERT INTO t1(text) VALUES ('中华人民共和国国歌'),('静夜思'),('国家'),('举头望明月'),('like'),('liking'),('liked'),('I''m making a sqlite tokenizer'),('I''m learning English');"#,
            [],
        )
        .unwrap();
        let rows = query_strings(
            &conn,
            "SELECT * FROM t1 WHERE text MATCH simple_query('国');",
        );
        assert_eq!(["中华人民共和国国歌", "国家"], rows.as_slice());
    }
}