// sqlite_graphrag/tokenizer.rs
use crate::constants::PASSAGE_PREFIX;
use crate::errors::AppError;
use fastembed::{EmbeddingModel, TextEmbedding};
use huggingface_hub::api::sync::ApiBuilder;
use std::path::{Path, PathBuf};
use std::sync::OnceLock;
use tokenizers::Tokenizer;

/// Lazily-initialized tokenizer state shared process-wide.
struct TokenizerRuntime {
    // HuggingFace tokenizer built from the model's `tokenizer.json`.
    tokenizer: Tokenizer,
    // `model_max_length` read from `tokenizer_config.json` — the model's
    // maximum input length in tokens.
    model_max_length: usize,
}

// Set at most once by `get_runtime`; the first successful caller's
// `models_dir` wins for the lifetime of the process.
static TOKENIZER_RUNTIME: OnceLock<TokenizerRuntime> = OnceLock::new();

21pub fn get_tokenizer(models_dir: &Path) -> Result<&'static Tokenizer, AppError> {
22 Ok(&get_runtime(models_dir)?.tokenizer)
23}
24
25pub fn get_model_max_length(models_dir: &Path) -> Result<usize, AppError> {
26 Ok(get_runtime(models_dir)?.model_max_length)
27}
28
29pub fn count_passage_tokens(tokenizer: &Tokenizer, text: &str) -> Result<usize, AppError> {
30 let prefixed = format!("{PASSAGE_PREFIX}{text}");
31 count_tokens(tokenizer, &prefixed)
32}
33
/// Returns `(start, end)` byte ranges into `text` for each token produced
/// when `text` is encoded with the passage prefix prepended.
///
/// Tokens that fall entirely inside the prefix are dropped, and the
/// remaining offsets are shifted back so they index into `text` itself.
/// For non-empty `text` the result is never empty: if every token was
/// filtered out, the whole text is returned as a single span.
pub fn passage_token_offsets(
    tokenizer: &Tokenizer,
    text: &str,
) -> Result<Vec<(usize, usize)>, AppError> {
    let prefixed = format!("{PASSAGE_PREFIX}{text}");
    // NOTE(review): offsets are treated as *byte* offsets here (compared
    // against `text.len()` and validated with `is_char_boundary`) — this
    // assumes the tokenizer reports byte-level offsets; confirm for this
    // model's tokenizer configuration.
    let prefix_len = PASSAGE_PREFIX.len();
    let encoding = tokenizer
        .encode(prefixed, true) // `true`: include special tokens
        .map_err(|e| AppError::Embedding(e.to_string()))?;

    let mut offsets = Vec::new();
    for &(start, end) in encoding.get_offsets() {
        // Skip empty spans (special tokens report zero-width offsets) and
        // tokens that end inside the prefix.
        if end <= start || end <= prefix_len {
            continue;
        }

        // Shift past the prefix; clamp so a token straddling the prefix
        // boundary maps to the start of `text`, and nothing exceeds its end.
        let adjusted_start = start.saturating_sub(prefix_len).min(text.len());
        let adjusted_end = end.saturating_sub(prefix_len).min(text.len());

        // Keep only non-empty spans that land on UTF-8 char boundaries,
        // so callers can slice `text` without panicking.
        if adjusted_end > adjusted_start
            && text.is_char_boundary(adjusted_start)
            && text.is_char_boundary(adjusted_end)
        {
            offsets.push((adjusted_start, adjusted_end));
        }
    }

    // Fallback: never return an empty list for non-empty text.
    if offsets.is_empty() && !text.is_empty() {
        offsets.push((0, text.len()));
    }

    Ok(offsets)
}

68fn count_tokens(tokenizer: &Tokenizer, text: &str) -> Result<usize, AppError> {
69 let encoding = tokenizer
70 .encode(text, true)
71 .map_err(|e| AppError::Embedding(e.to_string()))?;
72 Ok(encoding.len())
73}
74
75fn get_runtime(models_dir: &Path) -> Result<&'static TokenizerRuntime, AppError> {
76 if let Some(runtime) = TOKENIZER_RUNTIME.get() {
77 return Ok(runtime);
78 }
79
80 let runtime = load_runtime(models_dir)?;
81 let _ = TOKENIZER_RUNTIME.set(runtime);
82 Ok(TOKENIZER_RUNTIME
83 .get()
84 .expect("tokenizer runtime just initialized"))
85}
86
87fn load_runtime(models_dir: &Path) -> Result<TokenizerRuntime, AppError> {
88 let model_info = TextEmbedding::get_model_info(&EmbeddingModel::MultilingualE5Small)
89 .map_err(|e| AppError::Embedding(e.to_string()))?;
90
91 let cache_dir = std::env::var("HF_HOME")
92 .map(PathBuf::from)
93 .unwrap_or_else(|_| models_dir.to_path_buf());
94 let endpoint =
95 std::env::var("HF_ENDPOINT").unwrap_or_else(|_| "https://huggingface.co".to_string());
96
97 let api = ApiBuilder::new()
98 .with_cache_dir(cache_dir)
99 .with_endpoint(endpoint)
100 .with_progress(false)
101 .build()
102 .map_err(|e| AppError::Embedding(e.to_string()))?;
103 let repo = api.model(model_info.model_code.clone());
104
105 let tokenizer_bytes =
106 std::fs::read(repo.get("tokenizer.json").map_err(map_hf_err)?).map_err(AppError::Io)?;
107 let tokenizer_config_bytes =
108 std::fs::read(repo.get("tokenizer_config.json").map_err(map_hf_err)?)
109 .map_err(AppError::Io)?;
110
111 let tokenizer =
112 Tokenizer::from_bytes(tokenizer_bytes).map_err(|e| AppError::Embedding(e.to_string()))?;
113 let tokenizer_config: serde_json::Value =
114 serde_json::from_slice(&tokenizer_config_bytes).map_err(AppError::Json)?;
115 let model_max_length = tokenizer_config["model_max_length"]
116 .as_u64()
117 .map(|n| n as usize)
118 .or_else(|| {
119 tokenizer_config["model_max_length"]
120 .as_f64()
121 .map(|n| n as usize)
122 })
123 .ok_or_else(|| AppError::Embedding("tokenizer_config.json sem model_max_length".into()))?;
124
125 Ok(TokenizerRuntime {
126 tokenizer,
127 model_max_length,
128 })
129}
130
131fn map_hf_err(err: huggingface_hub::api::sync::ApiError) -> AppError {
132 AppError::Embedding(err.to_string())
133}