// sqlite_simple_tokenizer/lib.rs

1#[cfg(feature = "build_extension")]
2mod create_extension;
3mod error;
4mod load_extension;
5mod pinyin;
6mod tokenizer;
7mod utils;
8
9include!(concat!(env!("OUT_DIR"), "/stopword_data.rs"));
10
11pub use error::Error;
12use load_extension::create_scalar_functions;
13use load_extension::load_fts5_extension;
14use log::LevelFilter;
15use rusqlite::Connection;
16use utils::init_logging;
17
18pub fn load(connection: &Connection) -> Result<(), Error> {
19    load_with_loglevel(connection, LevelFilter::Info)
20}
21
22pub fn load_with_loglevel(connection: &Connection, log_level: LevelFilter) -> Result<(), Error> {
23    // 设置 log
24    init_logging(log_level);
25    // 加载拓展函数
26    create_scalar_functions(connection)?;
27    // 加载 fts5 拓展
28    load_fts5_extension(connection)
29}
30
#[cfg(test)]
mod tests {
    use crate::load;
    use rusqlite::Connection;

    /// Runs the parameterless query `sql` and returns the first column of every row as a
    /// `String`.
    ///
    /// Any SQLite error, including a row whose first column is not text, fails the test.
    /// The error is propagated out of the row mapper rather than panicking inside it.
    fn query_strings(conn: &Connection, sql: &str) -> Vec<String> {
        let mut stmt = conn.prepare(sql).unwrap();
        // Bind the row iterator to a local so it is dropped before `stmt`, which it borrows.
        let rows = stmt
            .query_map([], |row| row.get::<_, String>(0))
            .unwrap();
        rows.collect::<Result<Vec<_>, _>>().unwrap()
    }

    #[test]
    fn test_simple_query() {
        let conn = Connection::open_in_memory().unwrap();
        load(&conn).unwrap();
        let rows = query_strings(&conn, "SELECT simple_query('国')");
        assert_eq!(["(g+u+o* OR gu+o* OR guo*)"], rows.as_slice());
    }

    #[test]
    fn test_load() {
        let conn = Connection::open_in_memory().unwrap();
        load(&conn).unwrap();
        // Create a test table that uses the `simple` tokenizer.
        conn.execute(
            "CREATE VIRTUAL TABLE t1 USING fts5(text, tokenize = 'simple');",
            [],
        )
        .unwrap();
        // Insert sample data.
        conn.execute(
            r#"INSERT INTO t1(text) VALUES ('中华人民共和国国歌'),('静夜思'),('国家'),('举头望明月'),('like'),('liking'),('liked'),('I''m making a sqlite tokenizer'),('I''m learning English');"#,
            [],
        )
        .unwrap();
        // ORDER BY rowid makes the asserted order (insertion order) explicit instead of
        // relying on the FTS5 default scan order.
        let rows = query_strings(
            &conn,
            "SELECT text FROM t1 WHERE text MATCH simple_query('国') ORDER BY rowid;",
        );
        assert_eq!(["中华人民共和国国歌", "国家"], rows.as_slice());
    }
}