Skip to main content

vtcode_commons/
tokens.rs

1//! Token counting via tiktoken BPE tokenizer.
2//!
3//! All token estimation goes through [`tiktoken`]'s `cl100k_base` encoding
4//! (GPT-4, GPT-3.5-turbo). BPE tokenizers are similar enough across providers
5//! that this gives reasonable accuracy for Anthropic, Gemini, and others.
6//!
7//! Provider-reported exact token counts (from API responses) should always be
8//! preferred when available. This module is for pre-call budget estimation and
9//! offline token sizing where no provider response exists yet.
10
11use std::sync::OnceLock;
12use tiktoken::CoreBpe;
13
14/// Return the process-global `cl100k_base` BPE instance.
15///
16/// Loaded once on first call; all subsequent calls return the same reference.
17fn bpe() -> &'static CoreBpe {
18    static BPE: OnceLock<&CoreBpe> = OnceLock::new();
19    BPE.get_or_init(|| tiktoken::get_encoding("cl100k_base").expect("failed to load cl100k_base"))
20}
21
22/// Count the number of tokens in `text` using tiktoken BPE.
23///
24/// Returns 0 for empty strings.
25pub fn estimate_tokens(text: &str) -> usize {
26    if text.is_empty() {
27        return 0;
28    }
29    bpe().count(text)
30}
31
32/// Truncate `text` to at most `max_tokens` tokens.
33///
34/// Decodes the truncated token sequence back to text so the result is always
35/// valid UTF-8 with no mid-token corruption. Falls back to byte-level
36/// truncation if BPE decode fails (should not happen in practice).
37pub fn truncate_to_tokens(text: &str, max_tokens: usize) -> String {
38    if max_tokens == 0 || text.is_empty() {
39        return String::new();
40    }
41    let tokens = bpe().encode_with_special_tokens(text);
42    if tokens.len() <= max_tokens {
43        return text.to_string();
44    }
45    bpe()
46        .decode_to_string(&tokens[..max_tokens])
47        .unwrap_or_else(|_| {
48            // Byte-level fallback (should never fire for valid UTF-8).
49            let end = (max_tokens * 4).min(text.len());
50            let mut end = end;
51            while end > 0 && !text.is_char_boundary(end) {
52                end -= 1;
53            }
54            let mut result = text[..end].to_string();
55            result.push_str("...");
56            result
57        })
58}
59
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty input costs no tokens and truncates to an empty string.
    #[test]
    fn empty_string_returns_zero() {
        let empty = "";
        assert_eq!(estimate_tokens(empty), 0);
        assert_eq!(truncate_to_tokens(empty, 10), "");
    }

    /// A short English sentence lands in a plausible token range.
    #[test]
    fn count_is_reasonable() {
        let count = estimate_tokens("Hello, how are you today?");
        assert!((4..=12).contains(&count), "count={count}");
    }

    /// Truncated output re-tokenizes to the limit, with one token of slack.
    #[test]
    fn truncate_respects_limit() {
        let text = "the quick brown fox jumps over the lazy dog";
        let count = estimate_tokens(&truncate_to_tokens(text, 5));
        assert!(count <= 6, "count={count} should be <= 6");
    }

    /// A zero-token budget always yields an empty string.
    #[test]
    fn truncate_zero_returns_empty() {
        let truncated = truncate_to_tokens("hello", 0);
        assert_eq!(truncated, "");
    }

    /// Both source code and plain prose produce a non-zero token count.
    #[test]
    fn code_and_prose_tokenize() {
        let code = "fn main() { println!(\"hello\"); }";
        let prose = "the main function prints hello to console";
        for sample in [code, prose] {
            assert!(estimate_tokens(sample) > 0);
        }
    }

    /// Compact JSON lands in a plausible token range.
    #[test]
    fn json_tokenizes() {
        let json = r#"{"name":"test","value":123,"nested":{"key":"value"}}"#;
        let count = estimate_tokens(json);
        assert!((10..=40).contains(&count), "json count={count}");
    }
}