// sqlite_graphrag/tokenizer.rs
use crate::constants::PASSAGE_PREFIX;
2use crate::errors::AppError;
3use fastembed::{EmbeddingModel, TextEmbedding};
4use huggingface_hub::api::sync::ApiBuilder;
5use std::path::{Path, PathBuf};
6use std::sync::OnceLock;
7use tokenizers::Tokenizer;
8
/// Lazily-initialized tokenizer state shared process-wide via `TOKENIZER_RUNTIME`.
struct TokenizerRuntime {
    // HuggingFace tokenizer built from the model repo's tokenizer.json.
    tokenizer: Tokenizer,
    // Maximum sequence length read from tokenizer_config.json's `model_max_length`.
    model_max_length: usize,
}

// Process-wide singleton; populated on first use by `get_runtime`.
static TOKENIZER_RUNTIME: OnceLock<TokenizerRuntime> = OnceLock::new();
15
16pub fn get_tokenizer(models_dir: &Path) -> Result<&'static Tokenizer, AppError> {
17 Ok(&get_runtime(models_dir)?.tokenizer)
18}
19
20pub fn get_model_max_length(models_dir: &Path) -> Result<usize, AppError> {
21 Ok(get_runtime(models_dir)?.model_max_length)
22}
23
24pub fn count_passage_tokens(tokenizer: &Tokenizer, text: &str) -> Result<usize, AppError> {
25 let prefixed = format!("{PASSAGE_PREFIX}{text}");
26 count_tokens(tokenizer, &prefixed)
27}
28
29pub fn passage_token_offsets(
30 tokenizer: &Tokenizer,
31 text: &str,
32) -> Result<Vec<(usize, usize)>, AppError> {
33 let prefixed = format!("{PASSAGE_PREFIX}{text}");
34 let prefix_len = PASSAGE_PREFIX.len();
35 let encoding = tokenizer
36 .encode(prefixed, true)
37 .map_err(|e| AppError::Embedding(e.to_string()))?;
38
39 let mut offsets = Vec::new();
40 for &(start, end) in encoding.get_offsets() {
41 if end <= start || end <= prefix_len {
42 continue;
43 }
44
45 let adjusted_start = start.saturating_sub(prefix_len).min(text.len());
46 let adjusted_end = end.saturating_sub(prefix_len).min(text.len());
47
48 if adjusted_end > adjusted_start
49 && text.is_char_boundary(adjusted_start)
50 && text.is_char_boundary(adjusted_end)
51 {
52 offsets.push((adjusted_start, adjusted_end));
53 }
54 }
55
56 if offsets.is_empty() && !text.is_empty() {
57 offsets.push((0, text.len()));
58 }
59
60 Ok(offsets)
61}
62
63fn count_tokens(tokenizer: &Tokenizer, text: &str) -> Result<usize, AppError> {
64 let encoding = tokenizer
65 .encode(text, true)
66 .map_err(|e| AppError::Embedding(e.to_string()))?;
67 Ok(encoding.len())
68}
69
70fn get_runtime(models_dir: &Path) -> Result<&'static TokenizerRuntime, AppError> {
71 if let Some(runtime) = TOKENIZER_RUNTIME.get() {
72 return Ok(runtime);
73 }
74
75 let runtime = load_runtime(models_dir)?;
76 let _ = TOKENIZER_RUNTIME.set(runtime);
77 Ok(TOKENIZER_RUNTIME
78 .get()
79 .expect("tokenizer runtime just initialized"))
80}
81
/// Builds the tokenizer runtime: resolves the multilingual-e5-small repo,
/// fetches `tokenizer.json` and `tokenizer_config.json` (from local cache or
/// the HF Hub), and extracts `model_max_length`.
///
/// # Errors
/// Returns `AppError::Embedding` for model/API/tokenizer failures,
/// `AppError::Io` for file reads, and `AppError::Json` for config parsing.
fn load_runtime(models_dir: &Path) -> Result<TokenizerRuntime, AppError> {
    // Resolve repo metadata (model code) for the embedding model in use.
    let model_info = TextEmbedding::get_model_info(&EmbeddingModel::MultilingualE5Small)
        .map_err(|e| AppError::Embedding(e.to_string()))?;

    // HF_HOME overrides the cache location; otherwise reuse the app's models dir.
    let cache_dir = std::env::var("HF_HOME")
        .map(PathBuf::from)
        .unwrap_or_else(|_| models_dir.to_path_buf());
    // HF_ENDPOINT allows pointing at a mirror; defaults to the public Hub.
    let endpoint =
        std::env::var("HF_ENDPOINT").unwrap_or_else(|_| "https://huggingface.co".to_string());

    let api = ApiBuilder::new()
        .with_cache_dir(cache_dir)
        .with_endpoint(endpoint)
        .with_progress(false)
        .build()
        .map_err(|e| AppError::Embedding(e.to_string()))?;
    let repo = api.model(model_info.model_code.clone());

    // `repo.get` resolves each file to a local path (downloading if absent);
    // the bytes are then read from disk.
    let tokenizer_bytes =
        std::fs::read(repo.get("tokenizer.json").map_err(map_hf_err)?).map_err(AppError::Io)?;
    let tokenizer_config_bytes =
        std::fs::read(repo.get("tokenizer_config.json").map_err(map_hf_err)?)
            .map_err(AppError::Io)?;

    let tokenizer =
        Tokenizer::from_bytes(tokenizer_bytes).map_err(|e| AppError::Embedding(e.to_string()))?;
    let tokenizer_config: serde_json::Value =
        serde_json::from_slice(&tokenizer_config_bytes).map_err(AppError::Json)?;
    // Try integer first, then float — presumably some configs serialize
    // `model_max_length` as a float; TODO confirm against real config files.
    // Missing/non-numeric values are a hard error.
    let model_max_length = tokenizer_config["model_max_length"]
        .as_u64()
        .map(|n| n as usize)
        .or_else(|| {
            tokenizer_config["model_max_length"]
                .as_f64()
                .map(|n| n as usize)
        })
        .ok_or_else(|| AppError::Embedding("tokenizer_config.json sem model_max_length".into()))?;

    Ok(TokenizerRuntime {
        tokenizer,
        model_max_length,
    })
}
125
126fn map_hf_err(err: huggingface_hub::api::sync::ApiError) -> AppError {
127 AppError::Embedding(err.to_string())
128}