1use std::sync::OnceLock;
12use tiktoken::CoreBpe;
13
14fn bpe() -> Option<&'static CoreBpe> {
20 static BPE: OnceLock<Option<&'static CoreBpe>> = OnceLock::new();
21 *BPE.get_or_init(|| tiktoken::get_encoding("cl100k_base"))
22}
23
/// Rough token estimate used when no BPE encoder is available.
///
/// Assumes roughly four bytes of UTF-8 per token and rounds up, so any non-empty input
/// counts as at least one token.
fn heuristic_token_count(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    let byte_len = text.len();
    let whole_tokens = byte_len / BYTES_PER_TOKEN;
    let has_partial_token = byte_len % BYTES_PER_TOKEN != 0;
    whole_tokens + usize::from(has_partial_token)
}
28
29pub fn estimate_tokens(text: &str) -> usize {
34 if text.is_empty() {
35 return 0;
36 }
37 match bpe() {
38 Some(bpe) => bpe.count(text),
39 None => heuristic_token_count(text),
40 }
41}
42
43pub fn truncate_to_tokens(text: &str, max_tokens: usize) -> String {
49 if max_tokens == 0 || text.is_empty() {
50 return String::new();
51 }
52 let byte_truncate = || {
54 let end = (max_tokens * 4).min(text.len());
55 let mut end = end;
56 while end > 0 && !text.is_char_boundary(end) {
57 end -= 1;
58 }
59 let mut result = text[..end].to_string();
60 result.push_str("...");
61 result
62 };
63 let Some(bpe) = bpe() else {
64 return byte_truncate();
65 };
66 let tokens = bpe.encode_with_special_tokens(text);
67 if tokens.len() <= max_tokens {
68 return text.to_string();
69 }
70 bpe.decode_to_string(&tokens[..max_tokens])
71 .unwrap_or_else(|_| byte_truncate())
72}
73
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty input must short-circuit both public entry points.
    #[test]
    fn empty_string_returns_zero() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(truncate_to_tokens("", 10), "");
    }

    /// A short English sentence lands in a plausible range, whichever counting
    /// strategy is active.
    #[test]
    fn count_is_reasonable() {
        let sentence = "Hello, how are you today?";
        let count = estimate_tokens(sentence);
        let plausible = 4..=12;
        assert!(plausible.contains(&count), "count={count}");
    }

    /// Re-counting a truncated string stays within the limit, with one token of slack.
    #[test]
    fn truncate_respects_limit() {
        const LIMIT: usize = 5;
        let text = "the quick brown fox jumps over the lazy dog";
        let truncated = truncate_to_tokens(text, LIMIT);
        let count = estimate_tokens(&truncated);
        assert!(count <= LIMIT + 1, "count={count} should be <= 6");
    }

    /// A zero budget yields an empty string even for non-empty input.
    #[test]
    fn truncate_zero_returns_empty() {
        let out = truncate_to_tokens("hello", 0);
        assert_eq!(out, "");
    }

    /// Both source code and natural-language prose produce a non-zero count.
    #[test]
    fn code_and_prose_tokenize() {
        let code = "fn main() { println!(\"hello\"); }";
        let prose = "the main function prints hello to console";
        for sample in [code, prose] {
            assert!(estimate_tokens(sample) > 0);
        }
    }

    /// Compact JSON produces a count in a plausible range.
    #[test]
    fn json_tokenizes() {
        let json = r#"{"name":"test","value":123,"nested":{"key":"value"}}"#;
        let count = estimate_tokens(json);
        let plausible = 10..=40;
        assert!(plausible.contains(&count), "json count={count}");
    }
}