Skip to main content

vtcode_commons/
tokens.rs

1//! Token counting via tiktoken BPE tokenizer.
2//!
3//! All token estimation goes through [`tiktoken`]'s `cl100k_base` encoding
4//! (GPT-4, GPT-3.5-turbo). BPE tokenizers are similar enough across providers
5//! that this gives reasonable accuracy for Anthropic, Gemini, and others.
6//!
7//! Provider-reported exact token counts (from API responses) should always be
8//! preferred when available. This module is for pre-call budget estimation and
9//! offline token sizing where no provider response exists yet.
10
11use std::sync::OnceLock;
12use tiktoken::CoreBpe;
13
14/// Return the process-global `cl100k_base` BPE instance, if it could be loaded.
15///
16/// Loaded once on first call; all subsequent calls return the same reference.
17/// Returns `None` only if the builtin encoding fails to load, in which case
18/// callers fall back to a character-based heuristic rather than panicking.
19fn bpe() -> Option<&'static CoreBpe> {
20    static BPE: OnceLock<Option<&'static CoreBpe>> = OnceLock::new();
21    *BPE.get_or_init(|| tiktoken::get_encoding("cl100k_base"))
22}
23
/// Rough token estimate used when no tokenizer is available.
///
/// Assumes about four bytes of UTF-8 per token and rounds up, so any
/// non-empty input counts as at least one token. The length is measured in
/// bytes (`str::len`), not in characters, so multi-byte text is weighted by
/// its encoded size.
fn heuristic_token_count(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    let byte_len = text.len();
    byte_len.div_ceil(BYTES_PER_TOKEN)
}
28
29/// Count the number of tokens in `text` using tiktoken BPE.
30///
31/// Returns 0 for empty strings. Falls back to a character-based heuristic if
32/// the BPE tokenizer is unavailable.
33pub fn estimate_tokens(text: &str) -> usize {
34    if text.is_empty() {
35        return 0;
36    }
37    match bpe() {
38        Some(bpe) => bpe.count(text),
39        None => heuristic_token_count(text),
40    }
41}
42
43/// Truncate `text` to at most `max_tokens` tokens.
44///
45/// Decodes the truncated token sequence back to text so the result is always
46/// valid UTF-8 with no mid-token corruption. Falls back to byte-level
47/// truncation if BPE decode fails (should not happen in practice).
48pub fn truncate_to_tokens(text: &str, max_tokens: usize) -> String {
49    if max_tokens == 0 || text.is_empty() {
50        return String::new();
51    }
52    // Byte-level fallback used when BPE is unavailable or decode fails.
53    let byte_truncate = || {
54        let end = (max_tokens * 4).min(text.len());
55        let mut end = end;
56        while end > 0 && !text.is_char_boundary(end) {
57            end -= 1;
58        }
59        let mut result = text[..end].to_string();
60        result.push_str("...");
61        result
62    };
63    let Some(bpe) = bpe() else {
64        return byte_truncate();
65    };
66    let tokens = bpe.encode_with_special_tokens(text);
67    if tokens.len() <= max_tokens {
68        return text.to_string();
69    }
70    bpe.decode_to_string(&tokens[..max_tokens])
71        .unwrap_or_else(|_| byte_truncate())
72}
73
#[cfg(test)]
mod tests {
    use super::{estimate_tokens, truncate_to_tokens};

    /// Empty input yields zero tokens and an empty truncation result.
    #[test]
    fn empty_string_returns_zero() {
        let empty = "";
        assert_eq!(estimate_tokens(empty), 0);
        assert_eq!(truncate_to_tokens(empty, 10), "");
    }

    /// A short English sentence lands in a plausible token range.
    #[test]
    fn count_is_reasonable() {
        let greeting = "Hello, how are you today?";
        let count = estimate_tokens(greeting);
        assert!(matches!(count, 4..=12), "count={count}");
    }

    /// Truncated output re-counts to the limit, with one token of slack.
    #[test]
    fn truncate_respects_limit() {
        const LIMIT: usize = 5;
        let sentence = "the quick brown fox jumps over the lazy dog";
        let shortened = truncate_to_tokens(sentence, LIMIT);
        let count = estimate_tokens(&shortened);
        assert!(count <= LIMIT + 1, "count={count} should be <= 6");
    }

    /// A zero-token budget always produces an empty string.
    #[test]
    fn truncate_zero_returns_empty() {
        let result = truncate_to_tokens("hello", 0);
        assert_eq!(result, "");
    }

    /// Both source code and plain prose produce a non-zero count.
    #[test]
    fn code_and_prose_tokenize() {
        let snippet = "fn main() { println!(\"hello\"); }";
        let sentence = "the main function prints hello to console";
        assert!(estimate_tokens(snippet) > 0);
        assert!(estimate_tokens(sentence) > 0);
    }

    /// Compact JSON lands in a plausible token range.
    #[test]
    fn json_tokenizes() {
        let payload = r#"{"name":"test","value":123,"nested":{"key":"value"}}"#;
        let count = estimate_tokens(payload);
        assert!(matches!(count, 10..=40), "json count={count}");
    }
}