Skip to main content

sqlite_graphrag/
tokenizer.rs

1//! Token-count utilities for embedding input sizing.
2//!
3//! Provides fast approximate token counting used to decide whether a body
4//! fits in a single chunk or requires the multi-chunk splitter.
5
6use crate::constants::PASSAGE_PREFIX;
7use crate::errors::AppError;
8use fastembed::{EmbeddingModel, TextEmbedding};
9use huggingface_hub::api::sync::ApiBuilder;
10use std::path::{Path, PathBuf};
11use std::sync::OnceLock;
12use tokenizers::Tokenizer;
13
/// Lazily-initialized tokenizer state shared process-wide via [`TOKENIZER_RUNTIME`].
struct TokenizerRuntime {
    // HuggingFace tokenizer loaded from the model repo's tokenizer.json.
    tokenizer: Tokenizer,
    // Maximum input length in tokens, read from tokenizer_config.json's
    // `model_max_length` field.
    model_max_length: usize,
}

// Populated on first use by `get_runtime`; subsequent calls reuse the cached value.
static TOKENIZER_RUNTIME: OnceLock<TokenizerRuntime> = OnceLock::new();
20
21pub fn get_tokenizer(models_dir: &Path) -> Result<&'static Tokenizer, AppError> {
22    Ok(&get_runtime(models_dir)?.tokenizer)
23}
24
25pub fn get_model_max_length(models_dir: &Path) -> Result<usize, AppError> {
26    Ok(get_runtime(models_dir)?.model_max_length)
27}
28
29pub fn count_passage_tokens(tokenizer: &Tokenizer, text: &str) -> Result<usize, AppError> {
30    let prefixed = format!("{PASSAGE_PREFIX}{text}");
31    count_tokens(tokenizer, &prefixed)
32}
33
34pub fn passage_token_offsets(
35    tokenizer: &Tokenizer,
36    text: &str,
37) -> Result<Vec<(usize, usize)>, AppError> {
38    let prefixed = format!("{PASSAGE_PREFIX}{text}");
39    let prefix_len = PASSAGE_PREFIX.len();
40    let encoding = tokenizer
41        .encode(prefixed, true)
42        .map_err(|e| AppError::Embedding(e.to_string()))?;
43
44    let mut offsets = Vec::new();
45    for &(start, end) in encoding.get_offsets() {
46        if end <= start || end <= prefix_len {
47            continue;
48        }
49
50        let adjusted_start = start.saturating_sub(prefix_len).min(text.len());
51        let adjusted_end = end.saturating_sub(prefix_len).min(text.len());
52
53        if adjusted_end > adjusted_start
54            && text.is_char_boundary(adjusted_start)
55            && text.is_char_boundary(adjusted_end)
56        {
57            offsets.push((adjusted_start, adjusted_end));
58        }
59    }
60
61    if offsets.is_empty() && !text.is_empty() {
62        offsets.push((0, text.len()));
63    }
64
65    Ok(offsets)
66}
67
68fn count_tokens(tokenizer: &Tokenizer, text: &str) -> Result<usize, AppError> {
69    let encoding = tokenizer
70        .encode(text, true)
71        .map_err(|e| AppError::Embedding(e.to_string()))?;
72    Ok(encoding.len())
73}
74
75fn get_runtime(models_dir: &Path) -> Result<&'static TokenizerRuntime, AppError> {
76    if let Some(runtime) = TOKENIZER_RUNTIME.get() {
77        return Ok(runtime);
78    }
79
80    let runtime = load_runtime(models_dir)?;
81    let _ = TOKENIZER_RUNTIME.set(runtime);
82    Ok(TOKENIZER_RUNTIME
83        .get()
84        .expect("tokenizer runtime just initialized"))
85}
86
87fn load_runtime(models_dir: &Path) -> Result<TokenizerRuntime, AppError> {
88    let model_info = TextEmbedding::get_model_info(&EmbeddingModel::MultilingualE5Small)
89        .map_err(|e| AppError::Embedding(e.to_string()))?;
90
91    let cache_dir = std::env::var("HF_HOME")
92        .map(PathBuf::from)
93        .unwrap_or_else(|_| models_dir.to_path_buf());
94    let endpoint =
95        std::env::var("HF_ENDPOINT").unwrap_or_else(|_| "https://huggingface.co".to_string());
96
97    let api = ApiBuilder::new()
98        .with_cache_dir(cache_dir)
99        .with_endpoint(endpoint)
100        .with_progress(false)
101        .build()
102        .map_err(|e| AppError::Embedding(e.to_string()))?;
103    let repo = api.model(model_info.model_code.clone());
104
105    let tokenizer_bytes =
106        std::fs::read(repo.get("tokenizer.json").map_err(map_hf_err)?).map_err(AppError::Io)?;
107    let tokenizer_config_bytes =
108        std::fs::read(repo.get("tokenizer_config.json").map_err(map_hf_err)?)
109            .map_err(AppError::Io)?;
110
111    let tokenizer =
112        Tokenizer::from_bytes(tokenizer_bytes).map_err(|e| AppError::Embedding(e.to_string()))?;
113    let tokenizer_config: serde_json::Value =
114        serde_json::from_slice(&tokenizer_config_bytes).map_err(AppError::Json)?;
115    let model_max_length = tokenizer_config["model_max_length"]
116        .as_u64()
117        .map(|n| n as usize)
118        .or_else(|| {
119            tokenizer_config["model_max_length"]
120                .as_f64()
121                .map(|n| n as usize)
122        })
123        .ok_or_else(|| AppError::Embedding("tokenizer_config.json sem model_max_length".into()))?;
124
125    Ok(TokenizerRuntime {
126        tokenizer,
127        model_max_length,
128    })
129}
130
131fn map_hf_err(err: huggingface_hub::api::sync::ApiError) -> AppError {
132    AppError::Embedding(err.to_string())
133}