Skip to main content

rusqlite_ext/
lib.rs

1// 代码来自 https://gist.github.com/ColonelThirtyTwo/3dd1fe04e4cff0502fa70d12f3a6e72e/revisions
2// 针对 Rust 和 rusqlite 的新版本做了一些调整
3
4use rusqlite::Connection;
5use rusqlite::ffi::{
6    FTS5_TOKEN_COLOCATED, FTS5_TOKENIZE_AUX, FTS5_TOKENIZE_DOCUMENT, FTS5_TOKENIZE_PREFIX,
7    FTS5_TOKENIZE_QUERY, Fts5Tokenizer, SQLITE_ERROR, SQLITE_OK, SQLITE_PREPARE_PERSISTENT,
8    fts5_api, fts5_tokenizer_v2, sqlite3_bind_pointer, sqlite3_finalize, sqlite3_prepare_v3,
9    sqlite3_step, sqlite3_stmt,
10};
11use std::convert::{TryFrom, TryInto};
12use std::ffi::{CStr, c_char, c_int, c_void};
13use std::fmt::Formatter;
14use std::ops::Range;
15use std::panic::AssertUnwindSafe;
16
17pub mod error;
18
/// Lowest `fts5_api.iVersion` accepted by `register_tokenizer`; a smaller value is rejected
/// with `RegisterTokenizerError::Fts5ApiVersionTooLow`.
const FTS5_API_VERSION: c_int = 3;
/// Value stored in `fts5_tokenizer_v2.iVersion`: tokenizers are registered through the
/// version 2 tokenizer interface.
const FTS5_TOKENIZER_VERSION: c_int = 2;
23
/// Why FTS5 is asking the tokenizer to split a piece of text.
///
/// Decoded from the `flags` argument that FTS5 passes to `xTokenize`. The enum is a plain
/// value type, so it is `Copy` and can be used as a hash key.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum TokenizeReason {
    /// A document is being inserted into, or deleted from, the FTS table.
    Document,
    /// A MATCH query is being run against the FTS index.
    Query {
        /// `true` when the query term is followed by `*` (a prefix query).
        prefix: bool,
    },
    /// Text tokenized on explicit request through the `xTokenize` API (`FTS5_TOKENIZE_AUX`).
    Aux,
}
37
/// Error returned when the `flags` value handed to `xTokenize` does not correspond to any
/// `TokenizeReason`.
#[derive(Debug)]
pub enum IntoTokenizeReasonError {
    /// Carries the raw flag value that could not be decoded.
    UnrecognizedValue(c_int),
}

impl std::fmt::Display for IntoTokenizeReasonError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum: the pattern is irrefutable.
        let Self::UnrecognizedValue(raw) = self;
        write!(f, "Unrecognized flags passed to xTokenize: {raw}")
    }
}

impl std::error::Error for IntoTokenizeReasonError {}
54
55impl TryFrom<c_int> for TokenizeReason {
56    type Error = IntoTokenizeReasonError;
57
58    fn try_from(value: c_int) -> Result<Self, Self::Error> {
59        /// 这个值是针对 FTS 索引执行 MATCH 查询时,在查询字符串后带上 * 的特殊值
60        const FTS5_TOKENIZE_QUERY_PREFIX: c_int = FTS5_TOKENIZE_QUERY | FTS5_TOKENIZE_PREFIX;
61        match value {
62            FTS5_TOKENIZE_DOCUMENT => Ok(Self::Document),
63            FTS5_TOKENIZE_QUERY => Ok(Self::Query { prefix: false }),
64            FTS5_TOKENIZE_QUERY_PREFIX => Ok(Self::Query { prefix: true }),
65            FTS5_TOKENIZE_AUX => Ok(Self::Aux),
66            _ => Err(IntoTokenizeReasonError::UnrecognizedValue(value)),
67        }
68    }
69}
70
71/// Tokenizer
72pub trait Tokenizer: Sized + Send + 'static {
73    /// 一个全局数据的类型
74    type Global: Send + 'static;
75    /// 提供一个 tokenizer 名称
76    fn name() -> &'static CStr;
77    /// 创建 Tokenizer 方法
78    ///
79    /// 在创建 Tokenizer 实例后,通过指定的全局数据访问这个实例
80    ///
81    /// 在 xCreate 中被调用,xCreate 的 azArg 参数转换成 Vec<String>,并以此提供给 new方法使用
82    fn new(global: &Self::Global, args: Vec<String>) -> Result<Self, rusqlite::Error>;
83    /// 分词的具体实现
84    ///
85    /// 应该检查 `text` 对象,并且对每个 `token` 调用 `push_token` 这个回调方法
86    ///
87    /// `push_token` 的参数有
88    /// * &[u8] - token
89    /// * Range<usize> - token 在文本中位置
90    /// * bool - 对应 `FTS5_TOKEN_COLOCATED`
91    ///
92    fn tokenize<TKF>(
93        &mut self,
94        reason: TokenizeReason,
95        text: &[u8],
96        push_token: TKF,
97    ) -> Result<(), rusqlite::Error>
98    where
99        TKF: FnMut(&[u8], Range<usize>, bool) -> Result<(), rusqlite::Error>;
100}
101
102unsafe extern "C" fn x_create<T: Tokenizer>(
103    global: *mut c_void,
104    args: *mut *const c_char,
105    nargs: c_int,
106    out_tokenizer: *mut *mut Fts5Tokenizer,
107) -> c_int {
108    let global = unsafe { &*global.cast::<T::Global>() };
109    let args = (0..nargs as usize)
110        .map(|i| unsafe { *args.add(i) })
111        .map(|s| unsafe { CStr::from_ptr(s).to_string_lossy().into_owned() })
112        .collect::<Vec<String>>();
113    let res = std::panic::catch_unwind(AssertUnwindSafe(move || T::new(global, args)));
114    match res {
115        Ok(Ok(v)) => {
116            let bp = Box::into_raw(Box::new(v));
117            unsafe {
118                *out_tokenizer = bp.cast::<Fts5Tokenizer>();
119            }
120            SQLITE_OK
121        }
122        Ok(Err(rusqlite::Error::SqliteFailure(e, _))) => e.extended_code,
123        Ok(Err(_)) => SQLITE_ERROR,
124        Err(msg) => {
125            log::error!(
126                "<{} as Tokenizer>::new panic: {}",
127                std::any::type_name::<T>(),
128                panic_err_to_str(&msg)
129            );
130            SQLITE_ERROR
131        }
132    }
133}
134
135unsafe extern "C" fn x_delete<T: Tokenizer>(v: *mut Fts5Tokenizer) {
136    let tokenizer = unsafe { Box::from_raw(v.cast::<T>()) };
137    match std::panic::catch_unwind(AssertUnwindSafe(move || drop(tokenizer))) {
138        Ok(()) => {}
139        Err(e) => {
140            log::error!(
141                "{}::drop panic: {}",
142                std::any::type_name::<T>(),
143                panic_err_to_str(&e)
144            );
145        }
146    }
147}
148
149unsafe extern "C" fn x_destroy<T: Tokenizer>(v: *mut c_void) {
150    let tokenizer = unsafe { Box::from_raw(v.cast::<T::Global>()) };
151    match std::panic::catch_unwind(AssertUnwindSafe(move || drop(tokenizer))) {
152        Ok(()) => {}
153        Err(e) => {
154            log::error!(
155                "{}::drop panic: {}",
156                std::any::type_name::<T::Global>(),
157                panic_err_to_str(&e)
158            );
159        }
160    }
161}
162
163/// 忽略 locale 配置
164unsafe extern "C" fn x_tokenize<T: Tokenizer>(
165    this: *mut Fts5Tokenizer,
166    ctx: *mut c_void,
167    flag: c_int,
168    data: *const c_char,
169    data_len: c_int,
170    _locale: *const c_char,
171    _locale_len: c_int,
172    push_token: Option<
173        unsafe extern "C" fn(*mut c_void, c_int, *const c_char, c_int, c_int, c_int) -> c_int,
174    >,
175) -> c_int {
176    let this = unsafe { &mut *this.cast::<T>() };
177    let reason = match TokenizeReason::try_from(flag) {
178        Ok(reason) => reason,
179        Err(error) => {
180            log::error!("{error}");
181            return SQLITE_ERROR;
182        }
183    };
184
185    let data = unsafe { std::slice::from_raw_parts(data.cast::<u8>(), data_len as usize) };
186
187    let push_token = push_token.expect("No provide push token function");
188    let push_token = |token: &[u8],
189                      Range { start, end }: Range<usize>,
190                      colocated: bool|
191     -> Result<(), rusqlite::Error> {
192        let token_len: c_int = token.len().try_into().expect("Token is too long");
193        assert!(
194            start <= data.len() && end <= data.len(),
195            "Token range is invalid. Range is [{start}..{end}], data length is {}",
196            data.len(),
197        );
198        let flags = if colocated { FTS5_TOKEN_COLOCATED } else { 0 };
199
200        let res = unsafe {
201            (push_token)(
202                ctx,
203                flags,
204                token.as_ptr().cast::<c_char>(),
205                token_len,
206                start as c_int,
207                end as c_int,
208            )
209        };
210        if res == SQLITE_OK {
211            Ok(())
212        } else {
213            Err(rusqlite::Error::SqliteFailure(
214                rusqlite::ffi::Error::new(res),
215                None,
216            ))
217        }
218    };
219
220    match std::panic::catch_unwind(AssertUnwindSafe(|| this.tokenize(reason, data, push_token))) {
221        Ok(Ok(())) => SQLITE_OK,
222        Ok(Err(rusqlite::Error::SqliteFailure(e, _))) => e.extended_code,
223        Ok(Err(_)) => SQLITE_ERROR,
224        Err(msg) => {
225            log::error!(
226                "<{} as Tokenizer>::tokenize panic: {}",
227                std::any::type_name::<T>(),
228                panic_err_to_str(&msg)
229            );
230            SQLITE_ERROR
231        }
232    }
233}
234
/// Extracts a human-readable message from a panic payload caught by `catch_unwind`.
///
/// `String` and `&'static str` payloads are returned as text. Any other payload type yields
/// the placeholder `"<non-string panic reason>"`.
///
/// The parameter is deliberately `&Box<..>` rather than `&dyn Any`. Passing `&Box<dyn Any>`
/// to a `&dyn Any` parameter unsize-coerces the Box itself into the trait object, and every
/// downcast would then fail.
fn panic_err_to_str(msg: &Box<dyn std::any::Any + Send>) -> &str {
    // Explicit deref to the payload itself (not the Box) before downcasting.
    let payload: &(dyn std::any::Any + Send) = &**msg;
    payload
        .downcast_ref::<String>()
        .map(String::as_str)
        .or_else(|| payload.downcast_ref::<&'static str>().copied())
        .unwrap_or("<non-string panic reason>")
}
244
/// Reasons `register_tokenizer` can fail.
#[derive(Debug)]
pub enum RegisterTokenizerError {
    /// Preparing the `SELECT fts5(?1)` statement used to obtain the FTS5 API failed.
    SelectFts5Failed,
    /// The `SELECT fts5(?1)` statement ran but left the `fts5_api` pointer null.
    Fts5ApiNul,
    /// `fts5_api.iVersion` is below the minimum version this crate requires.
    Fts5ApiVersionTooLow,
    /// `fts5_api.xCreateTokenizer_v2` is a null function pointer.
    Fts5xCreateTokenizerV2Nul,
    /// `xCreateTokenizer_v2` returned this non-`SQLITE_OK` result code.
    Fts5xCreateTokenizerFailed(i32),
}

impl std::fmt::Display for RegisterTokenizerError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::SelectFts5Failed => "SELECT fts5(?1) failed.",
            Self::Fts5ApiNul => "Could not get fts5 api.",
            Self::Fts5ApiVersionTooLow => "The version of fts5 api is too low.",
            Self::Fts5xCreateTokenizerV2Nul => "Fts5 api xCreateTokenizer_v2 ptr is null.",
            Self::Fts5xCreateTokenizerFailed(rc) => {
                // The only variant whose message carries data.
                return write!(
                    f,
                    "Fts5 xCreateTokenizer failed, the error flag when sqlite returned is {rc}."
                );
            }
        };
        f.write_str(text)
    }
}

impl std::error::Error for RegisterTokenizerError {}
280
281/// 内部获取 fts5_api 指针的方法
282unsafe fn get_fts5_api(db: &Connection) -> Result<*mut fts5_api, RegisterTokenizerError> {
283    // 获取 fts5_api 结构体的指针,并且使用 sqlite3_bind_pointer 绑定指针
284    // 详情 https://sqlite.org/fts5.html#extending_fts5
285    let dbp = unsafe { db.handle() };
286    let mut api: *mut fts5_api = std::ptr::null_mut();
287    let mut stmt: *mut sqlite3_stmt = std::ptr::null_mut();
288    const FTS5_QUERY_STATEMENT: &CStr = c"SELECT fts5(?1)";
289    const FTS5_QUERY_STATEMENT_LEN: c_int = FTS5_QUERY_STATEMENT.count_bytes() as c_int;
290    unsafe {
291        if sqlite3_prepare_v3(
292            dbp,
293            FTS5_QUERY_STATEMENT.as_ptr(),
294            FTS5_QUERY_STATEMENT_LEN,
295            SQLITE_PREPARE_PERSISTENT,
296            &mut stmt,
297            std::ptr::null_mut(),
298        ) != SQLITE_OK
299        {
300            return Err(RegisterTokenizerError::SelectFts5Failed);
301        }
302        sqlite3_bind_pointer(
303            stmt,
304            1,
305            (&mut api) as *mut _ as *mut c_void,
306            c"fts5_api_ptr".as_ptr(),
307            None,
308        );
309        sqlite3_step(stmt);
310        sqlite3_finalize(stmt);
311    }
312    if api.is_null() {
313        return Err(RegisterTokenizerError::Fts5ApiNul);
314    }
315    Ok(api)
316}
317
318/// 注册 Tokenizer
319pub fn register_tokenizer<T: Tokenizer>(
320    db: &Connection,
321    global_data: T::Global,
322) -> Result<(), RegisterTokenizerError> {
323    unsafe {
324        let api: *mut fts5_api = get_fts5_api(db)?;
325        let global_data = Box::into_raw(Box::new(global_data));
326        if (*api).iVersion < FTS5_API_VERSION {
327            return Err(RegisterTokenizerError::Fts5ApiVersionTooLow);
328        }
329        // 注册tokenizer
330        let rc = ((*api)
331            .xCreateTokenizer_v2
332            .as_ref()
333            .ok_or(RegisterTokenizerError::Fts5xCreateTokenizerV2Nul)?)(
334            api,
335            T::name().as_ptr(),
336            global_data.cast::<c_void>(),
337            &mut fts5_tokenizer_v2 {
338                iVersion: FTS5_TOKENIZER_VERSION,
339                xCreate: Some(x_create::<T>),
340                xDelete: Some(x_delete::<T>),
341                xTokenize: Some(x_tokenize::<T>),
342            },
343            Some(x_destroy::<T>),
344        );
345        if rc != SQLITE_OK {
346            return Err(RegisterTokenizerError::Fts5xCreateTokenizerFailed(rc));
347        }
348        Ok(())
349    }
350}