sqlite_simple_tokenizer/
lib.rs

1#[cfg(feature = "build_extension")]
2mod create_extension;
3mod load_extension;
4mod pinyin;
5mod tokenizer;
6mod utils;
7
8include!(concat!(env!("OUT_DIR"), "/stopword_data.rs"));
9
10use load_extension::create_scalar_functions;
11use load_extension::load_fts5_extension;
12use log::LevelFilter;
13use rusqlite::Connection;
14use utils::init_logging;
15
16pub fn load(connection: &Connection) -> anyhow::Result<()> {
17    load_with_loglevel(connection, LevelFilter::Info)
18}
19
20pub fn load_with_loglevel(connection: &Connection, log_level: LevelFilter) -> anyhow::Result<()> {
21    // 设置 log
22    init_logging(log_level);
23    // 加载拓展函数
24    create_scalar_functions(connection)?;
25    // 加载 fts5 拓展
26    load_fts5_extension(connection)
27}
28
#[cfg(test)]
mod tests {
    use crate::load;
    use rusqlite::Connection;

    /// Opens an in-memory database and loads the extension into it.
    ///
    /// Panics on failure, which is the desired behavior inside a test.
    fn open_loaded_connection() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        load(&conn).unwrap();
        conn
    }

    /// Runs `sql` without parameters and returns column 0 of every row as a `String`.
    ///
    /// Row-level errors are propagated out of the mapping closure and surfaced
    /// by the final `unwrap`, instead of panicking inside the closure.
    fn query_strings(conn: &Connection, sql: &str) -> Vec<String> {
        let mut stmt = conn.prepare(sql).unwrap();
        let rows = stmt
            .query_map([], |row| row.get::<_, String>(0))
            .unwrap();
        rows.collect::<Result<Vec<_>, _>>().unwrap()
    }

    #[test]
    fn test_simple_query() {
        let conn = open_loaded_connection();
        let vec = query_strings(&conn, "SELECT simple_query('国')");
        assert_eq!(["(g+u+o* OR gu+o* OR guo*)"], vec.as_slice());
    }

    #[test]
    fn test_load() {
        let conn = open_loaded_connection();
        // Create a test table.
        conn.execute(
            "CREATE VIRTUAL TABLE t1 USING fts5(text, tokenize = 'simple');",
            [],
        )
        .unwrap();
        // Insert data.
        conn.execute(
            r#"INSERT INTO t1(text) VALUES ('中华人民共和国国歌'),('静夜思'),('国家'),('举头望明月'),('like'),('liking'),('liked'),('I''m making a sqlite tokenizer'),('I''m learning English');"#,
            [],
        )
        .unwrap();
        let vec = query_strings(
            &conn,
            "SELECT * FROM t1 WHERE text MATCH simple_query('国');",
        );
        assert_eq!(["中华人民共和国国歌", "国家"], vec.as_slice());
    }
}