1use jieba_rs::{Jieba, TokenizeMode as JiebaTokenizeMode};
2use rusqlite::ffi::{self, fts5_api, fts5_tokenizer_v2};
3use rusqlite::types::ToSqlOutput;
4use rusqlite::{Connection, params};
5use serde::{Deserialize, Serialize};
6use std::collections::{HashMap, HashSet};
7use std::ffi::{CStr, c_char, c_int, c_void};
8use std::ptr;
9use std::slice;
10use std::str;
11use std::sync::{Arc, Mutex, OnceLock, RwLock};
12
/// Name of the SQLite table holding user-defined dictionary words.
pub const VULCAN_DICT_TABLE: &str = "_vulcan_dict";

/// Name under which the jieba tokenizer is registered with FTS5
/// (referenced as `tokenize='jieba'` in `CREATE VIRTUAL TABLE`).
const SQLITE_JIEBA_TOKENIZER_NAME: &CStr = c"jieba";

/// Pointer-type tag required when requesting the `fts5_api` pointer
/// through `SELECT fts5(?)` (see `fetch_fts5_api`).
const SQLITE_FTS5_API_PTR_TYPE: &CStr = c"fts5_api_ptr";
24
/// Selects how `tokenize_text` segments its input.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum TokenizerMode {
    /// Plain splitting on whitespace and ASCII punctuation (the default).
    #[default]
    None,
    /// Chinese-aware segmentation via jieba, honoring custom dictionary words.
    Jieba,
}
38
39impl TokenizerMode {
40 pub fn as_str(self) -> &'static str {
43 match self {
44 Self::None => "none",
45 Self::Jieba => "jieba",
46 }
47 }
48
49 #[allow(dead_code)]
52 pub fn parse(value: &str) -> Option<Self> {
53 match value.trim().to_ascii_lowercase().as_str() {
54 "" | "none" | "plain" | "default" => Some(Self::None),
55 "jieba" | "zh" | "zh_cn" => Some(Self::Jieba),
56 _ => None,
57 }
58 }
59}
60
/// A user-defined dictionary word together with its jieba frequency weight.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CustomWordEntry {
    /// The word itself (stored trimmed by `upsert_custom_word`).
    pub word: String,
    /// jieba frequency weight passed to `Jieba::add_word`; loaded values are
    /// clamped to at least 1 in `load_custom_words`.
    pub weight: usize,
}

/// Result of a `tokenize_text` call.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenizeOutput {
    /// String form of the tokenizer mode used (see `TokenizerMode::as_str`).
    pub tokenizer_mode: String,
    /// Input with whitespace runs collapsed to single spaces and ends trimmed.
    pub normalized_text: String,
    /// The extracted tokens, in order of appearance.
    pub tokens: Vec<String>,
    /// Quoted FTS5 MATCH expression built from `tokens`
    /// (terms joined with `OR` in search mode, spaces otherwise).
    pub fts_query: String,
}

/// Outcome of an insert/update/delete against the custom dictionary.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DictionaryMutationResult {
    /// False only for rejected input (e.g. a blank word); SQL errors surface
    /// as `rusqlite::Error` instead.
    pub success: bool,
    /// Human-readable status (bilingual English / Chinese).
    pub message: String,
    /// Number of rows changed by the statement.
    pub affected_rows: u64,
}

/// Result of listing the enabled custom-dictionary words.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ListCustomWordsResult {
    /// Always true on a successful query.
    pub success: bool,
    /// Human-readable status (bilingual English / Chinese).
    pub message: String,
    /// Enabled words, ordered by word ascending.
    pub words: Vec<CustomWordEntry>,
}

/// In-memory snapshot of the custom dictionary, shared between the Rust-side
/// tokenize helpers and the FTS5 tokenizer instances of one database.
#[derive(Debug, Default)]
struct SharedDictionaryState {
    custom_words: Vec<CustomWordEntry>,
}

/// User-data attached to the FTS5 tokenizer registration; freed by
/// `sqlite_jieba_tokenizer_registration_destroy`.
#[derive(Debug)]
struct RegisteredTokenizerContext {
    /// Raw `sqlite3*` address, used to clear the registration bookkeeping.
    connection_handle: usize,
    /// Dictionary state cloned into every tokenizer instance created from it.
    shared_state: Arc<RwLock<SharedDictionaryState>>,
}

/// One live tokenizer instance handed to FTS5 via `xCreate`.
#[derive(Debug)]
struct JiebaTokenizerInstance {
    shared_state: Arc<RwLock<SharedDictionaryState>>,
}

/// A token plus its byte range `[start_byte, end_byte)` in the original
/// UTF-8 input, as required by the FTS5 token callback.
#[derive(Debug, Clone, PartialEq, Eq)]
struct TokenSpan {
    token: String,
    start_byte: usize,
    end_byte: usize,
}
151
152fn shared_dictionary_registry() -> &'static Mutex<HashMap<String, Arc<RwLock<SharedDictionaryState>>>> {
155 static REGISTRY: OnceLock<Mutex<HashMap<String, Arc<RwLock<SharedDictionaryState>>>>> =
156 OnceLock::new();
157 REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
158}
159
/// Set of raw `sqlite3*` handles that already have the jieba tokenizer
/// registered, so registration runs at most once per connection.
fn registered_connection_handles() -> &'static Mutex<HashSet<usize>> {
    static HANDLES: OnceLock<Mutex<HashSet<usize>>> = OnceLock::new();
    HANDLES.get_or_init(Default::default)
}
166
/// Create the custom-dictionary table if it does not exist yet.
///
/// Schema: `word` (primary key), `weight`, an `enabled` flag, and ISO-8601
/// UTC created/updated timestamps defaulted via `strftime`.
pub fn ensure_vulcan_dict_table(connection: &Connection) -> rusqlite::Result<()> {
    connection.execute_batch(&format!(
        "CREATE TABLE IF NOT EXISTS {table_name} (
            word TEXT PRIMARY KEY,
            weight INTEGER NOT NULL DEFAULT 1,
            enabled INTEGER NOT NULL DEFAULT 1,
            created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')),
            updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
        );",
        table_name = VULCAN_DICT_TABLE
    ))
}
181
/// Register the `jieba` FTS5 tokenizer on this connection (idempotent).
///
/// Also ensures the dictionary table exists and warms the shared in-memory
/// dictionary state for this database before registering. Registration is
/// tracked per raw `sqlite3*` handle so repeated calls are cheap.
pub fn ensure_jieba_tokenizer_registered(connection: &Connection) -> rusqlite::Result<()> {
    ensure_vulcan_dict_table(connection)?;

    let db_key = connection_registry_key(connection)?;
    let shared_state = shared_dictionary_state_for_key(&db_key);
    refresh_shared_dictionary_state(connection, &shared_state)?;

    let connection_handle = sqlite_connection_handle(connection);
    {
        // Fast path: this connection already has the tokenizer installed.
        let registered = registered_connection_handles()
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        if registered.contains(&connection_handle) {
            return Ok(());
        }
    }

    let fts_api = fetch_fts5_api(connection)?;
    // Ownership of the context passes to SQLite; it is reclaimed by
    // sqlite_jieba_tokenizer_registration_destroy when the registration is
    // replaced or the connection closes.
    let registration_context = Box::new(RegisteredTokenizerContext {
        connection_handle,
        shared_state,
    });
    let registration_context_ptr = Box::into_raw(registration_context) as *mut c_void;
    // v2 tokenizer vtable; SQLite copies this struct during registration, so
    // passing a pointer to a stack local is fine.
    let tokenizer = fts5_tokenizer_v2 {
        iVersion: 2,
        xCreate: Some(sqlite_jieba_tokenizer_create),
        xDelete: Some(sqlite_jieba_tokenizer_delete),
        xTokenize: Some(sqlite_jieba_tokenizer_tokenize),
    };

    // xCreateTokenizer_v2 is absent on SQLite builds older than 3.44.
    // NOTE(review): ExecuteReturnedResults is a stand-in error here; a more
    // descriptive error would help, but callers only check Ok/Err.
    let create = unsafe {
        (*fts_api)
            .xCreateTokenizer_v2
            .ok_or_else(|| rusqlite::Error::ExecuteReturnedResults)?
    };

    let rc = unsafe {
        create(
            fts_api,
            SQLITE_JIEBA_TOKENIZER_NAME.as_ptr(),
            registration_context_ptr,
            &tokenizer as *const fts5_tokenizer_v2 as *mut fts5_tokenizer_v2,
            Some(sqlite_jieba_tokenizer_registration_destroy),
        )
    };
    if rc != ffi::SQLITE_OK {
        // Reclaim the leaked context on failure.
        // NOTE(review): confirm SQLite does NOT invoke xDestroy on a failed
        // xCreateTokenizer_v2 — if it does, this would be a double free.
        unsafe {
            drop(Box::from_raw(
                registration_context_ptr as *mut RegisteredTokenizerContext,
            ));
        }
        return Err(rusqlite::Error::SqliteFailure(
            ffi::Error::new(rc),
            Some("register jieba tokenizer failed / 注册 jieba tokenizer 失败".to_string()),
        ));
    }

    registered_connection_handles()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .insert(connection_handle);

    Ok(())
}
248
/// Insert or update a custom dictionary word, then refresh the in-memory
/// dictionary shared with the registered tokenizer.
///
/// Returns a non-success result (not an `Err`) when `word` is blank after
/// trimming; database failures surface as `rusqlite::Error`.
pub fn upsert_custom_word(
    connection: &Connection,
    word: &str,
    weight: usize,
) -> rusqlite::Result<DictionaryMutationResult> {
    // Also creates the dict table and warms the shared state if needed.
    ensure_jieba_tokenizer_registered(connection)?;
    let trimmed = word.trim();
    if trimmed.is_empty() {
        return Ok(DictionaryMutationResult {
            success: false,
            message: "custom word must not be empty / 自定义词不能为空".to_string(),
            affected_rows: 0,
        });
    }

    // Upsert: re-adding a word updates its weight and re-enables it.
    let affected_rows = connection.execute(
        &format!(
            "INSERT INTO {table_name} (word, weight, enabled, created_at, updated_at)
             VALUES (?1, ?2, 1, strftime('%Y-%m-%dT%H:%M:%fZ', 'now'), strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
             ON CONFLICT(word) DO UPDATE SET
                 weight = excluded.weight,
                 enabled = 1,
                 updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')",
            table_name = VULCAN_DICT_TABLE
        ),
        params![trimmed, weight as i64],
    )?;
    // Push the change into the shared state so live tokenizers see it.
    refresh_registered_dictionary(connection)?;

    Ok(DictionaryMutationResult {
        success: true,
        message: "custom word upserted / 自定义词已写入".to_string(),
        affected_rows: affected_rows as u64,
    })
}
286
287pub fn remove_custom_word(
290 connection: &Connection,
291 word: &str,
292) -> rusqlite::Result<DictionaryMutationResult> {
293 ensure_jieba_tokenizer_registered(connection)?;
294 let trimmed = word.trim();
295 if trimmed.is_empty() {
296 return Ok(DictionaryMutationResult {
297 success: false,
298 message: "custom word must not be empty / 自定义词不能为空".to_string(),
299 affected_rows: 0,
300 });
301 }
302
303 let affected_rows = connection.execute(
304 &format!("DELETE FROM {table_name} WHERE word = ?1", table_name = VULCAN_DICT_TABLE),
305 params![trimmed],
306 )?;
307 refresh_registered_dictionary(connection)?;
308
309 Ok(DictionaryMutationResult {
310 success: true,
311 message: if affected_rows > 0 {
312 "custom word removed / 自定义词已删除".to_string()
313 } else {
314 "custom word not found / 自定义词不存在".to_string()
315 },
316 affected_rows: affected_rows as u64,
317 })
318}
319
320pub fn load_custom_words(connection: &Connection) -> rusqlite::Result<Vec<CustomWordEntry>> {
323 ensure_vulcan_dict_table(connection)?;
324 let mut statement = connection.prepare(&format!(
325 "SELECT word, weight
326 FROM {table_name}
327 WHERE enabled = 1
328 ORDER BY word ASC",
329 table_name = VULCAN_DICT_TABLE
330 ))?;
331
332 let mut rows = statement.query([])?;
333 let mut entries = Vec::new();
334 while let Some(row) = rows.next()? {
335 entries.push(CustomWordEntry {
336 word: row.get::<_, String>(0)?,
337 weight: row.get::<_, i64>(1)?.max(1) as usize,
338 });
339 }
340
341 Ok(entries)
342}
343
344pub fn list_custom_words(connection: &Connection) -> rusqlite::Result<ListCustomWordsResult> {
347 let words = load_custom_words(connection)?;
348 Ok(ListCustomWordsResult {
349 success: true,
350 message: format!(
351 "listed {} custom words / 已列出 {} 个自定义词",
352 words.len(),
353 words.len()
354 ),
355 words,
356 })
357}
358
359pub fn tokenize_text(
362 connection: Option<&Connection>,
363 mode: TokenizerMode,
364 text: &str,
365 search_mode: bool,
366) -> rusqlite::Result<TokenizeOutput> {
367 let normalized_text = normalize_text(text);
368 let tokens = match mode {
369 TokenizerMode::None => tokenize_plain(&normalized_text),
370 TokenizerMode::Jieba => tokenize_with_jieba(connection, &normalized_text, search_mode)?,
371 };
372 let fts_query = build_fts_query(&tokens, search_mode);
373
374 Ok(TokenizeOutput {
375 tokenizer_mode: mode.as_str().to_string(),
376 normalized_text,
377 tokens,
378 fts_query,
379 })
380}
381
/// Collapse every whitespace run into a single ASCII space and trim the ends.
fn normalize_text(text: &str) -> String {
    let mut normalized = String::with_capacity(text.len());
    for word in text.split_whitespace() {
        if !normalized.is_empty() {
            normalized.push(' ');
        }
        normalized.push_str(word);
    }
    normalized
}
387
/// Split non-empty text on whitespace and ASCII punctuation.
///
/// If the input consists entirely of separators, fall back to returning the
/// whole input as one token so non-empty text never tokenizes to nothing.
fn tokenize_plain(text: &str) -> Vec<String> {
    if text.is_empty() {
        return Vec::new();
    }

    let is_separator = |ch: char| ch.is_whitespace() || ch.is_ascii_punctuation();
    let mut tokens: Vec<String> = Vec::new();
    for piece in text.split(is_separator) {
        if !piece.is_empty() {
            tokens.push(piece.to_owned());
        }
    }

    if tokens.is_empty() {
        tokens.push(text.to_owned());
    }
    tokens
}
407
408fn tokenize_with_jieba(
411 connection: Option<&Connection>,
412 text: &str,
413 search_mode: bool,
414) -> rusqlite::Result<Vec<String>> {
415 if text.is_empty() {
416 return Ok(Vec::new());
417 }
418
419 let custom_words = if let Some(connection) = connection {
420 ensure_jieba_tokenizer_registered(connection)?;
421 current_custom_words(connection)?
422 } else {
423 Vec::new()
424 };
425
426 Ok(jieba_token_spans(text, search_mode, &custom_words)
427 .into_iter()
428 .map(|span| span.token)
429 .collect())
430}
431
/// Build an FTS5 MATCH expression: each non-empty token double-quoted (with
/// embedded quotes doubled), joined with `OR` in search mode or spaces
/// (implicit AND) otherwise.
fn build_fts_query(tokens: &[String], search_mode: bool) -> String {
    let separator = if search_mode { " OR " } else { " " };
    let mut quoted: Vec<String> = Vec::with_capacity(tokens.len());
    for token in tokens {
        if token.is_empty() {
            continue;
        }
        quoted.push(format!("\"{}\"", token.replace('"', "\"\"")));
    }
    quoted.join(separator)
}
442
443fn current_custom_words(connection: &Connection) -> rusqlite::Result<Vec<CustomWordEntry>> {
446 let db_key = connection_registry_key(connection)?;
447 let shared_state = shared_dictionary_state_for_key(&db_key);
448 Ok(shared_state
449 .read()
450 .unwrap_or_else(|poisoned| poisoned.into_inner())
451 .custom_words
452 .clone())
453}
454
455fn refresh_registered_dictionary(connection: &Connection) -> rusqlite::Result<()> {
458 let db_key = connection_registry_key(connection)?;
459 let shared_state = shared_dictionary_state_for_key(&db_key);
460 refresh_shared_dictionary_state(connection, &shared_state)
461}
462
463fn shared_dictionary_state_for_key(
466 db_key: &str,
467) -> Arc<RwLock<SharedDictionaryState>> {
468 let mut registry = shared_dictionary_registry()
469 .lock()
470 .unwrap_or_else(|poisoned| poisoned.into_inner());
471 registry
472 .entry(db_key.to_string())
473 .or_insert_with(|| Arc::new(RwLock::new(SharedDictionaryState::default())))
474 .clone()
475}
476
477fn refresh_shared_dictionary_state(
480 connection: &Connection,
481 shared_state: &Arc<RwLock<SharedDictionaryState>>,
482) -> rusqlite::Result<()> {
483 let custom_words = load_custom_words(connection)?;
484 let mut guard = shared_state
485 .write()
486 .unwrap_or_else(|poisoned| poisoned.into_inner());
487 guard.custom_words = custom_words;
488 Ok(())
489}
490
491fn connection_registry_key(connection: &Connection) -> rusqlite::Result<String> {
494 let handle = sqlite_connection_handle(connection);
495 match connection.path() {
496 Some(path) if !path.trim().is_empty() => Ok(path.to_string()),
497 _ => Ok(format!(":memory:#{handle:x}")),
498 }
499}
500
/// Per-connection identity: the raw `sqlite3*` handle address as an integer.
///
/// The pointer is only cast for use as a registry key and never dereferenced.
/// NOTE(review): a closed connection's address can be reused by a later
/// allocation; the destroy hook removes stale entries, but confirm ordering.
fn sqlite_connection_handle(connection: &Connection) -> usize {
    unsafe { connection.handle() as usize }
}
507
508fn fetch_fts5_api(connection: &Connection) -> rusqlite::Result<*mut fts5_api> {
511 let p_ret: *mut fts5_api = ptr::null_mut();
512 let ptr_arg = ToSqlOutput::Pointer((&p_ret as *const *mut fts5_api as _, SQLITE_FTS5_API_PTR_TYPE, None));
513 connection.query_row("SELECT fts5(?)", [ptr_arg], |_| Ok(()))?;
514 if p_ret.is_null() {
515 return Err(rusqlite::Error::SqliteFailure(
516 ffi::Error::new(ffi::SQLITE_ERROR),
517 Some("fts5() returned a null API pointer / fts5() 返回了空指针".to_string()),
518 ));
519 }
520 Ok(p_ret)
521}
522
523fn jieba_token_spans(
526 text: &str,
527 search_mode: bool,
528 custom_words: &[CustomWordEntry],
529) -> Vec<TokenSpan> {
530 if text.is_empty() {
531 return Vec::new();
532 }
533
534 let mut jieba = Jieba::new();
535 for entry in custom_words {
536 jieba.add_word(&entry.word, Some(entry.weight), None);
537 }
538
539 let char_to_byte = unicode_char_to_byte_offsets(text);
540 let mode = if search_mode {
541 JiebaTokenizeMode::Search
542 } else {
543 JiebaTokenizeMode::Default
544 };
545
546 jieba
547 .tokenize(text, mode, true)
548 .into_iter()
549 .filter_map(|token| {
550 let trimmed = token.word.trim();
551 if trimmed.is_empty() {
552 return None;
553 }
554 let start_byte = *char_to_byte.get(token.start)?;
555 let end_byte = *char_to_byte.get(token.end)?;
556 Some(TokenSpan {
557 token: trimmed.to_string(),
558 start_byte,
559 end_byte,
560 })
561 })
562 .collect()
563}
564
/// Byte offset of each char boundary in `text`, with `text.len()` appended so
/// index `char_count` (an exclusive end position) is also resolvable.
fn unicode_char_to_byte_offsets(text: &str) -> Vec<usize> {
    let mut offsets = Vec::with_capacity(text.chars().count() + 1);
    offsets.extend(text.char_indices().map(|(byte_index, _)| byte_index));
    offsets.push(text.len());
    offsets
}
572
/// FTS5 `xCreate` callback: build a tokenizer instance sharing this
/// database's custom-dictionary state.
///
/// `user_data` is the `RegisteredTokenizerContext` installed at registration;
/// tokenizer arguments from `tokenize='jieba …'` are ignored.
unsafe extern "C" fn sqlite_jieba_tokenizer_create(
    user_data: *mut c_void,
    _args: *mut *const c_char,
    _arg_count: c_int,
    out_tokenizer: *mut *mut ffi::Fts5Tokenizer,
) -> c_int {
    if user_data.is_null() || out_tokenizer.is_null() {
        return ffi::SQLITE_MISUSE;
    }

    // SAFETY: user_data came from Box::into_raw in
    // ensure_jieba_tokenizer_registered and lives until the destroy hook runs.
    let context = unsafe { &*(user_data as *const RegisteredTokenizerContext) };
    let tokenizer = Box::new(JiebaTokenizerInstance {
        shared_state: Arc::clone(&context.shared_state),
    });
    // Ownership passes to SQLite; reclaimed in sqlite_jieba_tokenizer_delete.
    unsafe {
        *out_tokenizer = Box::into_raw(tokenizer) as *mut ffi::Fts5Tokenizer;
    }
    ffi::SQLITE_OK
}
596
/// FTS5 `xDelete` callback: free a tokenizer instance created by `xCreate`.
unsafe extern "C" fn sqlite_jieba_tokenizer_delete(tokenizer: *mut ffi::Fts5Tokenizer) {
    if tokenizer.is_null() {
        return;
    }
    // SAFETY: the pointer was produced by Box::into_raw in
    // sqlite_jieba_tokenizer_create and is dropped exactly once here.
    unsafe {
        drop(Box::from_raw(tokenizer as *mut JiebaTokenizerInstance));
    }
}
608
/// Destroy hook for the tokenizer registration user-data.
///
/// Frees the `RegisteredTokenizerContext` and clears the connection handle
/// from the registration bookkeeping so a later connection that reuses the
/// same handle address can register again.
unsafe extern "C" fn sqlite_jieba_tokenizer_registration_destroy(user_data: *mut c_void) {
    if user_data.is_null() {
        return;
    }

    // SAFETY: user_data came from Box::into_raw in
    // ensure_jieba_tokenizer_registered; SQLite calls this hook exactly once.
    let context = unsafe { Box::from_raw(user_data as *mut RegisteredTokenizerContext) };
    registered_connection_handles()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .remove(&context.connection_handle);
}
623
/// FTS5 `xTokenize` callback (v2): segment `text_ptr[..text_len]` with jieba
/// and emit each token through `token_callback`.
///
/// Returns `SQLITE_MISUSE` on null arguments, `SQLITE_ERROR` on invalid
/// UTF-8, and otherwise propagates the first non-OK callback result.
#[allow(non_snake_case)]
unsafe extern "C" fn sqlite_jieba_tokenizer_tokenize(
    tokenizer: *mut ffi::Fts5Tokenizer,
    token_context: *mut c_void,
    flags: c_int,
    text_ptr: *const c_char,
    text_len: c_int,
    _locale_ptr: *const c_char,
    _locale_len: c_int,
    token_callback: Option<
        unsafe extern "C" fn(
            pCtx: *mut c_void,
            tflags: c_int,
            pToken: *const c_char,
            nToken: c_int,
            iStart: c_int,
            iEnd: c_int,
        ) -> c_int,
    >,
) -> c_int {
    if tokenizer.is_null() || token_context.is_null() || text_ptr.is_null() || text_len < 0 {
        return ffi::SQLITE_MISUSE;
    }
    let Some(token_callback) = token_callback else {
        return ffi::SQLITE_MISUSE;
    };

    // SAFETY: tokenizer was produced by sqlite_jieba_tokenizer_create.
    let tokenizer = unsafe { &*(tokenizer as *const JiebaTokenizerInstance) };
    // Hold the read lock for the whole pass so the custom-word list cannot
    // change mid-document.
    let shared_state = tokenizer
        .shared_state
        .read()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    // SAFETY: FTS5 guarantees text_ptr points at text_len valid bytes.
    let text_bytes = unsafe { slice::from_raw_parts(text_ptr as *const u8, text_len as usize) };
    let Ok(text) = str::from_utf8(text_bytes) else {
        return ffi::SQLITE_ERROR;
    };

    // Query passes use jieba search-mode segmentation.
    // NOTE(review): AUX passes are also treated as search mode — confirm
    // this matches the intended highlight/snippet behavior.
    let search_mode = (flags & ffi::FTS5_TOKENIZE_QUERY) != 0 || (flags & ffi::FTS5_TOKENIZE_AUX) != 0;
    let spans = jieba_token_spans(text, search_mode, &shared_state.custom_words);
    for span in spans {
        let token_bytes = span.token.as_bytes();
        // Offsets are byte positions into the original buffer, as FTS5 expects.
        let rc = unsafe {
            token_callback(
                token_context,
                0,
                token_bytes.as_ptr() as *const c_char,
                token_bytes.len() as c_int,
                span.start_byte as c_int,
                span.end_byte as c_int,
            )
        };
        if rc != ffi::SQLITE_OK {
            return rc;
        }
    }

    ffi::SQLITE_OK
}
686
#[cfg(test)]
mod tests {
    use super::*;

    /// Adding/removing a custom word changes jieba segmentation of the
    /// Rust-side `tokenize_text` path.
    #[test]
    fn custom_words_affect_jieba_tokenization() -> rusqlite::Result<()> {
        let connection = Connection::open_in_memory()?;
        ensure_jieba_tokenizer_registered(&connection)?;

        // Without the custom entry, "田-女士" is not produced as one token.
        let before =
            tokenize_text(Some(&connection), TokenizerMode::Jieba, "市民田-女士急匆匆", false)?;
        assert!(!before.tokens.iter().any(|token| token == "田-女士"));

        let mutation = upsert_custom_word(&connection, "田-女士", 42)?;
        assert!(mutation.success);

        // The upsert refreshes shared state, so the next call sees the word.
        let after =
            tokenize_text(Some(&connection), TokenizerMode::Jieba, "市民田-女士急匆匆", false)?;
        assert!(after.tokens.iter().any(|token| token == "田-女士"));

        let removed = remove_custom_word(&connection, "田-女士")?;
        assert!(removed.success);
        Ok(())
    }

    /// End-to-end through SQLite: the registered tokenizer is used by an FTS5
    /// table, verified via the fts5vocab shadow table.
    #[test]
    fn sqlite_fts_jieba_tokenizer_is_registered() -> rusqlite::Result<()> {
        let connection = Connection::open_in_memory()?;
        ensure_jieba_tokenizer_registered(&connection)?;
        upsert_custom_word(&connection, "田-女士", 42)?;

        connection.execute_batch(
            "CREATE VIRTUAL TABLE IF NOT EXISTS mcp_memory_fts USING fts5(
                content,
                tokenize='jieba'
            );",
        )?;
        connection.execute(
            "INSERT INTO mcp_memory_fts (content) VALUES (?1)",
            params!["市民田-女士急匆匆"],
        )?;

        // fts5vocab 'instance' exposes each indexed term occurrence.
        connection.execute_batch(
            "CREATE VIRTUAL TABLE IF NOT EXISTS mcp_memory_vocab USING fts5vocab(
                mcp_memory_fts,
                'instance'
            );",
        )?;

        let count: i64 = connection.query_row(
            "SELECT count(*) FROM mcp_memory_vocab WHERE term = ?1",
            params!["田-女士"],
            |row| row.get(0),
        )?;
        assert_eq!(count, 1);

        Ok(())
    }
}
749}