1use std::collections::HashMap;
37use std::sync::Arc;
38use moka::sync::Cache;
39
/// Identifies which tokenizer's behavior a counter emulates or reports.
///
/// Each variant carries an average bytes-per-token ratio (see
/// [`TokenizerModel::bytes_per_token`]) used for heuristic estimates, and a
/// stable string name (see [`TokenizerModel::name`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TokenizerModel {
    /// OpenAI `cl100k_base` encoding (used by the `gpt4` config preset).
    Cl100kBase,
    /// OpenAI `p50k_base` encoding.
    P50kBase,
    /// Anthropic Claude tokenizer (used by the `claude` config preset).
    Claude,
    /// Llama-family tokenizer.
    Llama,
    /// Fallback when no specific model applies.
    Generic,
}
58
59impl TokenizerModel {
60 pub fn bytes_per_token(&self) -> f32 {
62 match self {
63 Self::Cl100kBase => 3.8,
64 Self::P50kBase => 4.0,
65 Self::Claude => 4.2,
66 Self::Llama => 4.0,
67 Self::Generic => 4.0,
68 }
69 }
70
71 pub fn name(&self) -> &'static str {
73 match self {
74 Self::Cl100kBase => "cl100k_base",
75 Self::P50kBase => "p50k_base",
76 Self::Claude => "claude",
77 Self::Llama => "llama",
78 Self::Generic => "generic",
79 }
80 }
81}
82
/// Configuration for [`ExactTokenCounter`].
#[derive(Debug, Clone)]
pub struct ExactTokenConfig {
    /// Tokenizer model reported by the counter (and used for heuristic
    /// fallbacks in `count_for_model`).
    pub model: TokenizerModel,

    /// Maximum number of (text-hash -> token-count) entries in the cache.
    pub cache_size: usize,

    /// Time-to-live for cache entries, in seconds.
    pub cache_ttl_secs: u64,

    /// NOTE(review): never read anywhere in this file — presumably consulted
    /// by callers to decide whether to fall back to heuristic counting on
    /// tokenizer errors; confirm against call sites.
    pub fallback_on_error: bool,

    /// Texts longer than this (in bytes) bypass the cache and are always
    /// tokenized directly.
    pub max_cache_text_len: usize,
}
101
102impl Default for ExactTokenConfig {
103 fn default() -> Self {
104 Self {
105 model: TokenizerModel::Cl100kBase,
106 cache_size: 10_000,
107 cache_ttl_secs: 3600,
108 fallback_on_error: true,
109 max_cache_text_len: 10_000,
110 }
111 }
112}
113
114impl ExactTokenConfig {
115 pub fn gpt4() -> Self {
117 Self {
118 model: TokenizerModel::Cl100kBase,
119 ..Default::default()
120 }
121 }
122
123 pub fn claude() -> Self {
125 Self {
126 model: TokenizerModel::Claude,
127 ..Default::default()
128 }
129 }
130}
131
/// Common interface for token counters, both exact and heuristic.
pub trait TokenCounter: Send + Sync {
    /// Count tokens in `text` using this counter's own model.
    fn count(&self, text: &str) -> usize;

    /// Count tokens as `model` would see them. The default implementation
    /// ignores `model` and defers to [`TokenCounter::count`].
    fn count_for_model(&self, text: &str, model: TokenizerModel) -> usize {
        let _ = model; self.count(text)
    }

    /// Split `text` into token ids.
    fn tokenize(&self, text: &str) -> Vec<u32>;

    /// Reassemble text from token ids (may be lossy or unsupported,
    /// depending on the implementation).
    fn decode(&self, tokens: &[u32]) -> String;

    /// The model this counter emulates or reports.
    fn model(&self) -> TokenizerModel;

    /// `true` when counts come from real tokenization, `false` when they
    /// are bytes-per-token estimates.
    fn is_exact(&self) -> bool;
}
159
/// Token counter that tokenizes text against a simplified BPE vocabulary
/// and caches per-text counts, keyed by a 64-bit hash of the text.
pub struct ExactTokenCounter {
    /// Model identity plus cache sizing, TTL, and eligibility settings.
    config: ExactTokenConfig,

    /// text-hash -> token count. Capacity bounded by `config.cache_size`;
    /// entries expire after `config.cache_ttl_secs` seconds.
    cache: Cache<u64, usize>,

    /// Shared simplified BPE vocabulary used for tokenize/decode.
    vocab: Arc<BpeVocab>,

    /// Hit/miss/tokenization counters, shared so callers can observe them
    /// while the counter is in use.
    stats: Arc<TokenCacheStats>,
}
180
/// Minimal BPE-style vocabulary: printable ASCII characters, a small set of
/// common multi-character tokens, and a few special markers. This is an
/// approximation — not the real 100k-entry cl100k_base vocabulary.
struct BpeVocab {
    /// Token string -> id, used for greedy longest-match lookup.
    token_to_id: HashMap<String, u32>,

    /// Reverse mapping (id -> token string) used by `decode`.
    id_to_token: HashMap<u32, String>,

    /// Reserved for real BPE merge-rule support; currently always empty.
    #[allow(dead_code)]
    merges: HashMap<(String, String), String>,

    /// Special markers (e.g. "<|endoftext|>") matched before ordinary tokens.
    special_tokens: HashMap<String, u32>,
}
196
197impl BpeVocab {
198 fn cl100k_base() -> Self {
200 let mut token_to_id = HashMap::new();
201 let mut id_to_token = HashMap::new();
202
203 for b in 32u8..127 {
205 let token = String::from(b as char);
206 let id = b as u32;
207 token_to_id.insert(token.clone(), id);
208 id_to_token.insert(id, token);
209 }
210
211 let common_tokens = [
213 "the", "ing", "tion", "ed", "er", "es", "en", "al", "re",
214 "on", "an", "or", "ar", "is", "it", "at", "as", "le", "ve",
215 " the", " a", " to", " of", " and", " in", " is", " for",
216 " ", "\n", "\t", "```", "...", "->", "=>", "==", "!=",
217 ];
218
219 let mut id = 200u32;
220 for token in common_tokens {
221 token_to_id.insert(token.to_string(), id);
222 id_to_token.insert(id, token.to_string());
223 id += 1;
224 }
225
226 let mut special_tokens = HashMap::new();
228 special_tokens.insert("<|endoftext|>".to_string(), 100257);
229 special_tokens.insert("<|fim_prefix|>".to_string(), 100258);
230 special_tokens.insert("<|fim_middle|>".to_string(), 100259);
231 special_tokens.insert("<|fim_suffix|>".to_string(), 100260);
232
233 Self {
234 token_to_id,
235 id_to_token,
236 merges: HashMap::new(),
237 special_tokens,
238 }
239 }
240
241 fn tokenize(&self, text: &str) -> Vec<u32> {
243 let mut tokens = Vec::new();
244 let mut remaining = text;
245
246 while !remaining.is_empty() {
247 let mut matched = false;
249
250 for (special, id) in &self.special_tokens {
252 if remaining.starts_with(special) {
253 tokens.push(*id);
254 remaining = &remaining[special.len()..];
255 matched = true;
256 break;
257 }
258 }
259
260 if matched {
261 continue;
262 }
263
264 for len in (1..=remaining.len().min(10)).rev() {
266 if let Some(substr) = remaining.get(..len) {
267 if let Some(&id) = self.token_to_id.get(substr) {
268 tokens.push(id);
269 remaining = &remaining[len..];
270 matched = true;
271 break;
272 }
273 }
274 }
275
276 if !matched {
277 if let Some(c) = remaining.chars().next() {
279 let byte_id = (c as u32).min(255);
280 tokens.push(byte_id);
281 remaining = &remaining[c.len_utf8()..];
282 }
283 }
284 }
285
286 tokens
287 }
288
289 fn decode(&self, tokens: &[u32]) -> String {
291 let mut result = String::new();
292
293 for &id in tokens {
294 if let Some(token) = self.id_to_token.get(&id) {
295 result.push_str(token);
296 } else {
297 if id < 256 {
299 if let Some(c) = char::from_u32(id) {
300 result.push(c);
301 }
302 }
303 }
304 }
305
306 result
307 }
308}
309
/// Lock-free counters describing cache behavior of an [`ExactTokenCounter`].
#[derive(Debug, Default)]
pub struct TokenCacheStats {
    /// Count lookups served from the cache.
    pub hits: std::sync::atomic::AtomicUsize,
    /// Count lookups that required tokenization (includes long texts that
    /// bypass the cache entirely).
    pub misses: std::sync::atomic::AtomicUsize,
    /// Total calls made through `TokenCounter::count`.
    pub tokenizations: std::sync::atomic::AtomicUsize,
    /// Running total of tokens produced when populating the cache.
    /// NOTE(review): the long-text cache-bypass path does not add to this —
    /// confirm whether that undercount is intended.
    pub total_tokens: std::sync::atomic::AtomicUsize,
}
322
323impl TokenCacheStats {
324 pub fn hit_rate(&self) -> f64 {
326 let hits = self.hits.load(std::sync::atomic::Ordering::Relaxed);
327 let misses = self.misses.load(std::sync::atomic::Ordering::Relaxed);
328 let total = hits + misses;
329 if total == 0 {
330 0.0
331 } else {
332 hits as f64 / total as f64
333 }
334 }
335}
336
337impl ExactTokenCounter {
338 pub fn new(config: ExactTokenConfig) -> Self {
340 let cache = Cache::builder()
341 .max_capacity(config.cache_size as u64)
342 .time_to_live(std::time::Duration::from_secs(config.cache_ttl_secs))
343 .build();
344
345 Self {
346 config,
347 cache,
348 vocab: Arc::new(BpeVocab::cl100k_base()),
349 stats: Arc::new(TokenCacheStats::default()),
350 }
351 }
352
353 pub fn default_counter() -> Self {
355 Self::new(ExactTokenConfig::default())
356 }
357
358 pub fn stats(&self) -> &Arc<TokenCacheStats> {
360 &self.stats
361 }
362
363 fn text_hash(text: &str) -> u64 {
365 use std::hash::{Hash, Hasher};
366 use std::collections::hash_map::DefaultHasher;
367
368 let mut hasher = DefaultHasher::new();
369 text.hash(&mut hasher);
370 hasher.finish()
371 }
372
373 fn count_cached(&self, text: &str) -> usize {
375 if text.len() > self.config.max_cache_text_len {
377 self.stats.misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
378 return self.tokenize(text).len();
379 }
380
381 let hash = Self::text_hash(text);
382
383 if let Some(count) = self.cache.get(&hash) {
384 self.stats.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
385 return count;
386 }
387
388 self.stats.misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
389
390 let tokens = self.tokenize(text);
391 let count = tokens.len();
392
393 self.cache.insert(hash, count);
394 self.stats.total_tokens.fetch_add(count, std::sync::atomic::Ordering::Relaxed);
395
396 count
397 }
398
399 #[allow(dead_code)]
401 fn estimate_tokens(&self, text: &str) -> usize {
402 let bytes = text.len();
403 ((bytes as f32) / self.config.model.bytes_per_token()).ceil() as usize
404 }
405}
406
407impl TokenCounter for ExactTokenCounter {
408 fn count(&self, text: &str) -> usize {
409 self.stats.tokenizations.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
410 self.count_cached(text)
411 }
412
413 fn count_for_model(&self, text: &str, model: TokenizerModel) -> usize {
414 if model == self.config.model {
415 self.count(text)
416 } else {
417 let bytes = text.len();
419 ((bytes as f32) / model.bytes_per_token()).ceil() as usize
420 }
421 }
422
423 fn tokenize(&self, text: &str) -> Vec<u32> {
424 self.vocab.tokenize(text)
425 }
426
427 fn decode(&self, tokens: &[u32]) -> String {
428 self.vocab.decode(tokens)
429 }
430
431 fn model(&self) -> TokenizerModel {
432 self.config.model
433 }
434
435 fn is_exact(&self) -> bool {
436 true
437 }
438}
439
/// Estimate-only counter: token count = byte length / bytes-per-token.
/// Never tokenizes; `is_exact` reports `false`.
pub struct HeuristicTokenCounter {
    /// Average number of UTF-8 bytes assumed per token.
    bytes_per_token: f32,

    /// Model reported by `TokenCounter::model`.
    model: TokenizerModel,
}
452
453impl HeuristicTokenCounter {
454 pub fn new() -> Self {
456 Self {
457 bytes_per_token: 4.0,
458 model: TokenizerModel::Generic,
459 }
460 }
461
462 pub fn for_model(model: TokenizerModel) -> Self {
464 Self {
465 bytes_per_token: model.bytes_per_token(),
466 model,
467 }
468 }
469}
470
471impl Default for HeuristicTokenCounter {
472 fn default() -> Self {
473 Self::new()
474 }
475}
476
477impl TokenCounter for HeuristicTokenCounter {
478 fn count(&self, text: &str) -> usize {
479 let bytes = text.len();
480 ((bytes as f32) / self.bytes_per_token).ceil() as usize
481 }
482
483 fn tokenize(&self, text: &str) -> Vec<u32> {
484 text.split_whitespace()
486 .enumerate()
487 .map(|(i, _)| i as u32)
488 .collect()
489 }
490
491 fn decode(&self, _tokens: &[u32]) -> String {
492 "[decode not supported for heuristic counter]".to_string()
494 }
495
496 fn model(&self) -> TokenizerModel {
497 self.model
498 }
499
500 fn is_exact(&self) -> bool {
501 false
502 }
503}
504
/// Tracks token spending against a fixed budget, using an atomic counter so
/// it can be shared across threads.
pub struct ExactBudgetEnforcer<C: TokenCounter> {
    /// Counter used to price each piece of text.
    counter: Arc<C>,

    /// Total token allowance.
    budget: usize,

    /// Tokens consumed so far.
    used: std::sync::atomic::AtomicUsize,
}
520
521impl<C: TokenCounter> ExactBudgetEnforcer<C> {
522 pub fn new(counter: Arc<C>, budget: usize) -> Self {
524 Self {
525 counter,
526 budget,
527 used: std::sync::atomic::AtomicUsize::new(0),
528 }
529 }
530
531 pub fn remaining(&self) -> usize {
533 self.budget.saturating_sub(self.used.load(std::sync::atomic::Ordering::Relaxed))
534 }
535
536 pub fn fits(&self, text: &str) -> bool {
538 let tokens = self.counter.count(text);
539 tokens <= self.remaining()
540 }
541
542 pub fn try_consume(&self, text: &str) -> Option<usize> {
545 let tokens = self.counter.count(text);
546 let remaining = self.remaining();
547
548 if tokens <= remaining {
549 self.used.fetch_add(tokens, std::sync::atomic::Ordering::Relaxed);
550 Some(tokens)
551 } else {
552 None
553 }
554 }
555
556 pub fn force_consume(&self, tokens: usize) {
558 self.used.fetch_add(tokens, std::sync::atomic::Ordering::Relaxed);
559 }
560
561 pub fn truncate_to_fit(&self, text: &str) -> (String, usize) {
563 let remaining = self.remaining();
564 if remaining == 0 {
565 return (String::new(), 0);
566 }
567
568 let mut low = 0;
570 let mut high = text.len();
571 let mut best_len = 0;
572 let mut best_tokens = 0;
573
574 while low < high {
575 let mid = (low + high + 1) / 2;
576
577 let truncated = if mid >= text.len() {
579 text.to_string()
580 } else {
581 let mut end = mid;
582 while !text.is_char_boundary(end) && end > 0 {
583 end -= 1;
584 }
585 text[..end].to_string()
586 };
587
588 let tokens = self.counter.count(&truncated);
589
590 if tokens <= remaining {
591 best_len = truncated.len();
592 best_tokens = tokens;
593 low = mid;
594 } else {
595 high = mid - 1;
596 }
597 }
598
599 if best_len == 0 {
600 (String::new(), 0)
601 } else {
602 (text[..best_len].to_string(), best_tokens)
603 }
604 }
605
606 pub fn summary(&self) -> BudgetSummary {
608 let used = self.used.load(std::sync::atomic::Ordering::Relaxed);
609 BudgetSummary {
610 budget: self.budget,
611 used,
612 remaining: self.budget.saturating_sub(used),
613 utilization: (used as f64) / (self.budget as f64),
614 }
615 }
616}
617
/// Point-in-time view of an `ExactBudgetEnforcer`'s budget usage.
#[derive(Debug, Clone)]
pub struct BudgetSummary {
    /// Total token allowance.
    pub budget: usize,
    /// Tokens consumed so far.
    pub used: usize,
    /// `budget - used`, floored at zero.
    pub remaining: usize,
    /// Fraction of the budget consumed (`used / budget`).
    pub utilization: f64,
}
630
631pub fn count_tokens_exact(text: &str) -> usize {
637 let counter = ExactTokenCounter::default_counter();
638 counter.count(text)
639}
640
641pub fn count_tokens_heuristic(text: &str) -> usize {
643 let counter = HeuristicTokenCounter::new();
644 counter.count(text)
645}
646
647pub fn create_budget_enforcer(budget: usize) -> ExactBudgetEnforcer<ExactTokenCounter> {
649 let counter = Arc::new(ExactTokenCounter::default_counter());
650 ExactBudgetEnforcer::new(counter, budget)
651}
652
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_exact_token_count() {
        let counter = ExactTokenCounter::default_counter();
        let n = counter.count("Hello, world!");
        // Short text: some tokens, but clearly fewer than 20.
        assert!(n > 0 && n < 20);
    }

    #[test]
    fn test_tokenize_and_decode() {
        let counter = ExactTokenCounter::default_counter();
        let ids = counter.tokenize("Hello world");
        assert!(!ids.is_empty());
        assert!(!counter.decode(&ids).is_empty());
    }

    #[test]
    fn test_cache_hits() {
        let counter = ExactTokenCounter::default_counter();

        // First call populates the cache; second call should hit it.
        for _ in 0..2 {
            let _ = counter.count("test text for caching");
        }

        let stats = counter.stats();
        assert!(stats.hits.load(std::sync::atomic::Ordering::Relaxed) >= 1);
        assert!(stats.misses.load(std::sync::atomic::Ordering::Relaxed) >= 1);
    }

    #[test]
    fn test_heuristic_counter() {
        let n = HeuristicTokenCounter::new().count("Hello world");
        // 11 bytes at ~4 bytes/token.
        assert!((2..=5).contains(&n));
    }

    #[test]
    fn test_budget_enforcer() {
        let enforcer =
            ExactBudgetEnforcer::new(Arc::new(ExactTokenCounter::default_counter()), 100);
        assert_eq!(enforcer.remaining(), 100);

        let spent = enforcer.try_consume("Hello world").expect("should fit");
        assert!(spent > 0);
        assert!(enforcer.remaining() < 100);
    }

    #[test]
    fn test_budget_truncation() {
        let enforcer =
            ExactBudgetEnforcer::new(Arc::new(ExactTokenCounter::default_counter()), 5);
        let input = "This is a very long text that definitely exceeds five tokens and should be truncated";

        let (prefix, tokens) = enforcer.truncate_to_fit(input);
        assert!(prefix.len() < input.len());
        assert!(tokens <= 5);
    }

    #[test]
    fn test_budget_summary() {
        let enforcer = ExactBudgetEnforcer::new(Arc::new(HeuristicTokenCounter::new()), 100);
        enforcer.force_consume(25);

        let summary = enforcer.summary();
        assert_eq!(summary.budget, 100);
        assert_eq!(summary.used, 25);
        assert_eq!(summary.remaining, 75);
        assert!((summary.utilization - 0.25).abs() < 0.01);
    }

    #[test]
    fn test_model_specific_counting() {
        let counter = ExactTokenCounter::default_counter();
        let text = "Hello, world!";
        assert!(counter.count_for_model(text, TokenizerModel::Cl100kBase) > 0);
        assert!(counter.count_for_model(text, TokenizerModel::Claude) > 0);
    }
}