1use rusqlite::Connection;
5use rusqlite::ffi::{
6 FTS5_TOKEN_COLOCATED, FTS5_TOKENIZE_AUX, FTS5_TOKENIZE_DOCUMENT, FTS5_TOKENIZE_PREFIX,
7 FTS5_TOKENIZE_QUERY, Fts5Tokenizer, SQLITE_ERROR, SQLITE_OK, SQLITE_PREPARE_PERSISTENT,
8 fts5_api, fts5_tokenizer_v2, sqlite3_bind_pointer, sqlite3_finalize, sqlite3_prepare_v3,
9 sqlite3_step, sqlite3_stmt,
10};
11use std::convert::{TryFrom, TryInto};
12use std::ffi::{CStr, c_char, c_int, c_void};
13use std::fmt::Formatter;
14use std::ops::Range;
15use std::panic::AssertUnwindSafe;
16
17pub mod error;
18
/// Minimum `fts5_api.iVersion` accepted by `register_tokenizer`; a lower
/// version is rejected with `RegisterTokenizerError::Fts5ApiVersionTooLow`.
const FTS5_API_VERSION: c_int = 3;
/// Value stored in `fts5_tokenizer_v2.iVersion` when the tokenizer callbacks
/// are handed to FTS5 by `register_tokenizer`.
const FTS5_TOKENIZER_VERSION: c_int = 2;
23
/// Why FTS5 is asking for a piece of text to be tokenized.
///
/// Values are produced from the raw `flags` argument of `xTokenize` through
/// the `TryFrom<c_int>` implementation. The type is a small plain-data enum,
/// so it is `Copy` and `Hash` in addition to the comparison traits.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TokenizeReason {
    /// Corresponds to `FTS5_TOKENIZE_DOCUMENT`.
    Document,
    /// Corresponds to `FTS5_TOKENIZE_QUERY`.
    Query {
        /// `true` when `FTS5_TOKENIZE_PREFIX` was set together with
        /// `FTS5_TOKENIZE_QUERY`.
        prefix: bool,
    },
    /// Corresponds to `FTS5_TOKENIZE_AUX`.
    Aux,
}
37
/// Error returned when an `xTokenize` flag value cannot be converted into a
/// [`TokenizeReason`].
#[derive(Debug)]
pub enum IntoTokenizeReasonError {
    /// The flag value matched none of the recognised `FTS5_TOKENIZE_*`
    /// combinations; the raw value is carried for diagnostics.
    UnrecognizedValue(c_int),
}
42
43impl std::fmt::Display for IntoTokenizeReasonError {
44 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
45 match self {
46 Self::UnrecognizedValue(flag) => {
47 write!(f, "Unrecognized flags passed to xTokenize: {flag}")
48 }
49 }
50 }
51}
52
// Marker impl: the provided `std::error::Error` methods are used as-is.
impl std::error::Error for IntoTokenizeReasonError {}
54
55impl TryFrom<c_int> for TokenizeReason {
56 type Error = IntoTokenizeReasonError;
57
58 fn try_from(value: c_int) -> Result<Self, Self::Error> {
59 const FTS5_TOKENIZE_QUERY_PREFIX: c_int = FTS5_TOKENIZE_QUERY | FTS5_TOKENIZE_PREFIX;
61 match value {
62 FTS5_TOKENIZE_DOCUMENT => Ok(Self::Document),
63 FTS5_TOKENIZE_QUERY => Ok(Self::Query { prefix: false }),
64 FTS5_TOKENIZE_QUERY_PREFIX => Ok(Self::Query { prefix: true }),
65 FTS5_TOKENIZE_AUX => Ok(Self::Aux),
66 _ => Err(IntoTokenizeReasonError::UnrecognizedValue(value)),
67 }
68 }
69}
70
/// A custom FTS5 tokenizer that can be registered with `register_tokenizer`.
///
/// An instance is built by [`Tokenizer::new`] whenever the `xCreate` callback
/// runs and is dropped again by the `xDelete` callback.
pub trait Tokenizer: Sized + Send + 'static {
    /// Data supplied once to `register_tokenizer`; a shared reference to it is
    /// passed to every [`Tokenizer::new`] call.
    type Global: Send + 'static;
    /// Name under which the tokenizer is registered with FTS5.
    fn name() -> &'static CStr;
    /// Builds a tokenizer from the global data and the arguments received by
    /// `xCreate` (each converted lossily to a `String`).
    ///
    /// # Errors
    ///
    /// A `rusqlite::Error::SqliteFailure` is reported to SQLite with its
    /// extended code; any other error is reported as `SQLITE_ERROR`.
    fn new(global: &Self::Global, args: Vec<String>) -> Result<Self, rusqlite::Error>;
    /// Splits `text` into tokens, reporting each one through `push_token`.
    ///
    /// `push_token(token, range, colocated)` receives the token bytes, a range
    /// into `text` whose `start` and `end` must both be `<= text.len()` (the
    /// callback panics otherwise), and whether the token is flagged with
    /// `FTS5_TOKEN_COLOCATED`. The callback also panics if the token length
    /// does not fit in a `c_int`.
    ///
    /// # Errors
    ///
    /// A `rusqlite::Error::SqliteFailure` is reported to SQLite with its
    /// extended code; any other error is reported as `SQLITE_ERROR`. An error
    /// returned by `push_token` carries the code SQLite's callback returned.
    fn tokenize<TKF>(
        &mut self,
        reason: TokenizeReason,
        text: &[u8],
        push_token: TKF,
    ) -> Result<(), rusqlite::Error>
    where
        TKF: FnMut(&[u8], Range<usize>, bool) -> Result<(), rusqlite::Error>;
}
101
102unsafe extern "C" fn x_create<T: Tokenizer>(
103 global: *mut c_void,
104 args: *mut *const c_char,
105 nargs: c_int,
106 out_tokenizer: *mut *mut Fts5Tokenizer,
107) -> c_int {
108 let global = unsafe { &*global.cast::<T::Global>() };
109 let args = (0..nargs as usize)
110 .map(|i| unsafe { *args.add(i) })
111 .map(|s| unsafe { CStr::from_ptr(s).to_string_lossy().into_owned() })
112 .collect::<Vec<String>>();
113 let res = std::panic::catch_unwind(AssertUnwindSafe(move || T::new(global, args)));
114 match res {
115 Ok(Ok(v)) => {
116 let bp = Box::into_raw(Box::new(v));
117 unsafe {
118 *out_tokenizer = bp.cast::<Fts5Tokenizer>();
119 }
120 SQLITE_OK
121 }
122 Ok(Err(rusqlite::Error::SqliteFailure(e, _))) => e.extended_code,
123 Ok(Err(_)) => SQLITE_ERROR,
124 Err(msg) => {
125 log::error!(
126 "<{} as Tokenizer>::new panic: {}",
127 std::any::type_name::<T>(),
128 panic_err_to_str(&msg)
129 );
130 SQLITE_ERROR
131 }
132 }
133}
134
135unsafe extern "C" fn x_delete<T: Tokenizer>(v: *mut Fts5Tokenizer) {
136 let tokenizer = unsafe { Box::from_raw(v.cast::<T>()) };
137 match std::panic::catch_unwind(AssertUnwindSafe(move || drop(tokenizer))) {
138 Ok(()) => {}
139 Err(e) => {
140 log::error!(
141 "{}::drop panic: {}",
142 std::any::type_name::<T>(),
143 panic_err_to_str(&e)
144 );
145 }
146 }
147}
148
149unsafe extern "C" fn x_destroy<T: Tokenizer>(v: *mut c_void) {
150 let tokenizer = unsafe { Box::from_raw(v.cast::<T::Global>()) };
151 match std::panic::catch_unwind(AssertUnwindSafe(move || drop(tokenizer))) {
152 Ok(()) => {}
153 Err(e) => {
154 log::error!(
155 "{}::drop panic: {}",
156 std::any::type_name::<T::Global>(),
157 panic_err_to_str(&e)
158 );
159 }
160 }
161}
162
163unsafe extern "C" fn x_tokenize<T: Tokenizer>(
165 this: *mut Fts5Tokenizer,
166 ctx: *mut c_void,
167 flag: c_int,
168 data: *const c_char,
169 data_len: c_int,
170 _locale: *const c_char,
171 _locale_len: c_int,
172 push_token: Option<
173 unsafe extern "C" fn(*mut c_void, c_int, *const c_char, c_int, c_int, c_int) -> c_int,
174 >,
175) -> c_int {
176 let this = unsafe { &mut *this.cast::<T>() };
177 let reason = match TokenizeReason::try_from(flag) {
178 Ok(reason) => reason,
179 Err(error) => {
180 log::error!("{error}");
181 return SQLITE_ERROR;
182 }
183 };
184
185 let data = unsafe { std::slice::from_raw_parts(data.cast::<u8>(), data_len as usize) };
186
187 let push_token = push_token.expect("No provide push token function");
188 let push_token = |token: &[u8],
189 Range { start, end }: Range<usize>,
190 colocated: bool|
191 -> Result<(), rusqlite::Error> {
192 let token_len: c_int = token.len().try_into().expect("Token is too long");
193 assert!(
194 start <= data.len() && end <= data.len(),
195 "Token range is invalid. Range is [{start}..{end}], data length is {}",
196 data.len(),
197 );
198 let flags = if colocated { FTS5_TOKEN_COLOCATED } else { 0 };
199
200 let res = unsafe {
201 (push_token)(
202 ctx,
203 flags,
204 token.as_ptr().cast::<c_char>(),
205 token_len,
206 start as c_int,
207 end as c_int,
208 )
209 };
210 if res == SQLITE_OK {
211 Ok(())
212 } else {
213 Err(rusqlite::Error::SqliteFailure(
214 rusqlite::ffi::Error::new(res),
215 None,
216 ))
217 }
218 };
219
220 match std::panic::catch_unwind(AssertUnwindSafe(|| this.tokenize(reason, data, push_token))) {
221 Ok(Ok(())) => SQLITE_OK,
222 Ok(Err(rusqlite::Error::SqliteFailure(e, _))) => e.extended_code,
223 Ok(Err(_)) => SQLITE_ERROR,
224 Err(msg) => {
225 log::error!(
226 "<{} as Tokenizer>::tokenize panic: {}",
227 std::any::type_name::<T>(),
228 panic_err_to_str(&msg)
229 );
230 SQLITE_ERROR
231 }
232 }
233}
234
/// Extracts a printable message from a panic payload.
///
/// Payloads of type `String` or `&'static str` are returned as-is; any other
/// payload type yields a fixed placeholder.
///
/// The parameter stays `&Box<dyn Any + Send>` on purpose: call sites pass
/// `&payload`, and a `&dyn Any` parameter would let that reference unsize with
/// the `Box` itself as the `Any`, making every downcast fail.
fn panic_err_to_str(msg: &Box<dyn std::any::Any + Send>) -> &str {
    msg.downcast_ref::<String>()
        .map(String::as_str)
        .or_else(|| msg.downcast_ref::<&'static str>().copied())
        .unwrap_or("<non-string panic reason>")
}
244
/// Reasons why `register_tokenizer` can fail.
#[derive(Debug)]
pub enum RegisterTokenizerError {
    /// Preparing the `SELECT fts5(?1)` statement did not return `SQLITE_OK`.
    SelectFts5Failed,
    /// The statement ran but the `fts5_api` pointer was still null afterwards.
    Fts5ApiNul,
    /// `fts5_api.iVersion` is lower than the version this module requires.
    Fts5ApiVersionTooLow,
    /// The `xCreateTokenizer_v2` entry of the `fts5_api` table is null.
    Fts5xCreateTokenizerV2Nul,
    /// `xCreateTokenizer_v2` returned the contained non-`SQLITE_OK` code.
    Fts5xCreateTokenizerFailed(i32),
}
253
254impl std::fmt::Display for RegisterTokenizerError {
255 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
256 match self {
257 RegisterTokenizerError::SelectFts5Failed => {
258 write!(f, "SELECT fts5(?1) failed.")
259 }
260 RegisterTokenizerError::Fts5ApiNul => {
261 write!(f, "Could not get fts5 api.")
262 }
263 RegisterTokenizerError::Fts5ApiVersionTooLow => {
264 write!(f, "The version of fts5 api is too low.")
265 }
266 RegisterTokenizerError::Fts5xCreateTokenizerV2Nul => {
267 write!(f, "Fts5 api xCreateTokenizer_v2 ptr is null.")
268 }
269 RegisterTokenizerError::Fts5xCreateTokenizerFailed(rc) => {
270 write!(
271 f,
272 "Fts5 xCreateTokenizer failed, the error flag when sqlite returned is {rc}."
273 )
274 }
275 }
276 }
277}
278
// Marker impl: the provided `std::error::Error` methods are used as-is.
impl std::error::Error for RegisterTokenizerError {}
280
281unsafe fn get_fts5_api(db: &Connection) -> Result<*mut fts5_api, RegisterTokenizerError> {
283 let dbp = unsafe { db.handle() };
286 let mut api: *mut fts5_api = std::ptr::null_mut();
287 let mut stmt: *mut sqlite3_stmt = std::ptr::null_mut();
288 const FTS5_QUERY_STATEMENT: &CStr = c"SELECT fts5(?1)";
289 const FTS5_QUERY_STATEMENT_LEN: c_int = FTS5_QUERY_STATEMENT.count_bytes() as c_int;
290 unsafe {
291 if sqlite3_prepare_v3(
292 dbp,
293 FTS5_QUERY_STATEMENT.as_ptr(),
294 FTS5_QUERY_STATEMENT_LEN,
295 SQLITE_PREPARE_PERSISTENT,
296 &mut stmt,
297 std::ptr::null_mut(),
298 ) != SQLITE_OK
299 {
300 return Err(RegisterTokenizerError::SelectFts5Failed);
301 }
302 sqlite3_bind_pointer(
303 stmt,
304 1,
305 (&mut api) as *mut _ as *mut c_void,
306 c"fts5_api_ptr".as_ptr(),
307 None,
308 );
309 sqlite3_step(stmt);
310 sqlite3_finalize(stmt);
311 }
312 if api.is_null() {
313 return Err(RegisterTokenizerError::Fts5ApiNul);
314 }
315 Ok(api)
316}
317
318pub fn register_tokenizer<T: Tokenizer>(
320 db: &Connection,
321 global_data: T::Global,
322) -> Result<(), RegisterTokenizerError> {
323 unsafe {
324 let api: *mut fts5_api = get_fts5_api(db)?;
325 let global_data = Box::into_raw(Box::new(global_data));
326 if (*api).iVersion < FTS5_API_VERSION {
327 return Err(RegisterTokenizerError::Fts5ApiVersionTooLow);
328 }
329 let rc = ((*api)
331 .xCreateTokenizer_v2
332 .as_ref()
333 .ok_or(RegisterTokenizerError::Fts5xCreateTokenizerV2Nul)?)(
334 api,
335 T::name().as_ptr(),
336 global_data.cast::<c_void>(),
337 &mut fts5_tokenizer_v2 {
338 iVersion: FTS5_TOKENIZER_VERSION,
339 xCreate: Some(x_create::<T>),
340 xDelete: Some(x_delete::<T>),
341 xTokenize: Some(x_tokenize::<T>),
342 },
343 Some(x_destroy::<T>),
344 );
345 if rc != SQLITE_OK {
346 return Err(RegisterTokenizerError::Fts5xCreateTokenizerFailed(rc));
347 }
348 Ok(())
349 }
350}