//! Token-count utilities for embedding input sizing.
//!
//! Provides fast approximate token counting used to decide whether a body
//! fits in a single chunk or requires the multi-chunk splitter.

use crate::constants::PASSAGE_PREFIX;
use crate::errors::AppError;
use fastembed::{EmbeddingModel, TextEmbedding};
use huggingface_hub::api::sync::ApiBuilder;
use std::path::{Path, PathBuf};
use std::sync::OnceLock;
use tokenizers::Tokenizer;

/// Lazily-initialized tokenizer state shared by the whole process.
struct TokenizerRuntime {
    // The loaded HF `tokenizer.json` model used for all encodes.
    tokenizer: Tokenizer,
    // `model_max_length` as read from `tokenizer_config.json`.
    model_max_length: usize,
}

/// Process-wide singleton; populated once on first use via `get_runtime`.
static TOKENIZER_RUNTIME: OnceLock<TokenizerRuntime> = OnceLock::new();

21/// Returns the process-wide [`Tokenizer`] singleton, initializing it on first call.
22///
23/// # Errors
24/// Returns `Err` when the tokenizer files cannot be loaded from `models_dir`.
25pub fn get_tokenizer(models_dir: &Path) -> Result<&'static Tokenizer, AppError> {
26    Ok(&get_runtime(models_dir)?.tokenizer)
27}
28
29/// Returns the model's `model_max_length` from `tokenizer_config.json`.
30///
31/// # Errors
32/// Returns `Err` when the tokenizer files cannot be loaded or the field is missing.
33pub fn get_model_max_length(models_dir: &Path) -> Result<usize, AppError> {
34    Ok(get_runtime(models_dir)?.model_max_length)
35}
36
37/// Counts the tokens produced by encoding `text` with the passage prefix.
38///
39/// Prepends `PASSAGE_PREFIX` before tokenizing so the count reflects the actual
40/// number of tokens consumed by the embedding model.
41///
42/// # Errors
43/// Returns `Err` when the tokenizer fails to encode the input.
44pub fn count_passage_tokens(tokenizer: &Tokenizer, text: &str) -> Result<usize, AppError> {
45    let prefixed = format!("{PASSAGE_PREFIX}{text}");
46    count_tokens(tokenizer, &prefixed)
47}
48
49/// Returns the byte-offset pairs `(start, end)` for each token in `text`.
50///
51/// The passage prefix is prepended before tokenizing; offsets in the returned
52/// vector are adjusted back to be relative to the original `text` slice.
53///
54/// # Errors
55/// Returns `Err` when the tokenizer fails to encode the input.
56pub fn passage_token_offsets(
57    tokenizer: &Tokenizer,
58    text: &str,
59) -> Result<Vec<(usize, usize)>, AppError> {
60    let prefixed = format!("{PASSAGE_PREFIX}{text}");
61    let prefix_len = PASSAGE_PREFIX.len();
62    let encoding = tokenizer
63        .encode(prefixed, true)
64        .map_err(|e| AppError::Embedding(e.to_string()))?;
65
66    let mut offsets = Vec::new();
67    for &(start, end) in encoding.get_offsets() {
68        if end <= start || end <= prefix_len {
69            continue;
70        }
71
72        let adjusted_start = start.saturating_sub(prefix_len).min(text.len());
73        let adjusted_end = end.saturating_sub(prefix_len).min(text.len());
74
75        if adjusted_end > adjusted_start
76            && text.is_char_boundary(adjusted_start)
77            && text.is_char_boundary(adjusted_end)
78        {
79            offsets.push((adjusted_start, adjusted_end));
80        }
81    }
82
83    if offsets.is_empty() && !text.is_empty() {
84        offsets.push((0, text.len()));
85    }
86
87    Ok(offsets)
88}
89
90fn count_tokens(tokenizer: &Tokenizer, text: &str) -> Result<usize, AppError> {
91    let encoding = tokenizer
92        .encode(text, true)
93        .map_err(|e| AppError::Embedding(e.to_string()))?;
94    Ok(encoding.len())
95}
96
97fn get_runtime(models_dir: &Path) -> Result<&'static TokenizerRuntime, AppError> {
98    if let Some(runtime) = TOKENIZER_RUNTIME.get() {
99        return Ok(runtime);
100    }
101
102    let runtime = load_runtime(models_dir)?;
103    let _ = TOKENIZER_RUNTIME.set(runtime);
104    Ok(TOKENIZER_RUNTIME
105        .get()
106        .expect("tokenizer runtime just initialized"))
107}
108
109fn load_runtime(models_dir: &Path) -> Result<TokenizerRuntime, AppError> {
110    let model_info = TextEmbedding::get_model_info(&EmbeddingModel::MultilingualE5Small)
111        .map_err(|e| AppError::Embedding(e.to_string()))?;
112
113    let cache_dir = std::env::var("HF_HOME")
114        .map(PathBuf::from)
115        .unwrap_or_else(|_| models_dir.to_path_buf());
116    let endpoint =
117        std::env::var("HF_ENDPOINT").unwrap_or_else(|_| "https://huggingface.co".to_string());
118
119    let api = ApiBuilder::new()
120        .with_cache_dir(cache_dir)
121        .with_endpoint(endpoint)
122        .with_progress(false)
123        .build()
124        .map_err(|e| AppError::Embedding(e.to_string()))?;
125    let repo = api.model(model_info.model_code.clone());
126
127    let tokenizer_bytes =
128        std::fs::read(repo.get("tokenizer.json").map_err(map_hf_err)?).map_err(AppError::Io)?;
129    let tokenizer_config_bytes =
130        std::fs::read(repo.get("tokenizer_config.json").map_err(map_hf_err)?)
131            .map_err(AppError::Io)?;
132
133    let tokenizer =
134        Tokenizer::from_bytes(tokenizer_bytes).map_err(|e| AppError::Embedding(e.to_string()))?;
135    let tokenizer_config: serde_json::Value =
136        serde_json::from_slice(&tokenizer_config_bytes).map_err(AppError::Json)?;
137    let model_max_length = tokenizer_config["model_max_length"]
138        .as_u64()
139        .map(|n| n as usize)
140        .or_else(|| {
141            tokenizer_config["model_max_length"]
142                .as_f64()
143                .map(|n| n as usize)
144        })
145        .ok_or_else(|| {
146            AppError::Embedding("tokenizer_config.json missing model_max_length field".into())
147        })?;
148
149    Ok(TokenizerRuntime {
150        tokenizer,
151        model_max_length,
152    })
153}
154
155fn map_hf_err(err: huggingface_hub::api::sync::ApiError) -> AppError {
156    AppError::Embedding(err.to_string())
157}