1use std::collections::HashMap;
34use std::sync::Arc;
35use moka::sync::Cache;
36
/// Identifies which tokenizer vocabulary a token count should approximate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TokenizerModel {
    Cl100kBase,
    P50kBase,
    Claude,
    Llama,
    Generic,
}

impl TokenizerModel {
    /// Average number of UTF-8 bytes per token for this model; used by the
    /// heuristic (non-exact) estimation paths.
    pub fn bytes_per_token(&self) -> f32 {
        match *self {
            TokenizerModel::Cl100kBase => 3.8,
            TokenizerModel::Claude => 4.2,
            // The remaining vocabularies all average roughly 4 bytes/token.
            TokenizerModel::P50kBase
            | TokenizerModel::Llama
            | TokenizerModel::Generic => 4.0,
        }
    }

    /// Canonical lowercase identifier for this model.
    pub fn name(&self) -> &'static str {
        match *self {
            TokenizerModel::Cl100kBase => "cl100k_base",
            TokenizerModel::P50kBase => "p50k_base",
            TokenizerModel::Claude => "claude",
            TokenizerModel::Llama => "llama",
            TokenizerModel::Generic => "generic",
        }
    }
}
79
/// Tuning knobs for [`ExactTokenCounter`].
#[derive(Debug, Clone)]
pub struct ExactTokenConfig {
    /// Tokenizer vocabulary the counter claims to emulate.
    pub model: TokenizerModel,

    /// Maximum number of (text-hash -> token-count) entries kept in the cache.
    pub cache_size: usize,

    /// Time-to-live for cache entries, in seconds.
    pub cache_ttl_secs: u64,

    /// NOTE(review): read nowhere in this file; presumably consulted by
    /// callers when exact tokenization fails — confirm before relying on it.
    pub fallback_on_error: bool,

    /// Texts longer than this many bytes bypass the count cache entirely.
    pub max_cache_text_len: usize,
}
98
99impl Default for ExactTokenConfig {
100 fn default() -> Self {
101 Self {
102 model: TokenizerModel::Cl100kBase,
103 cache_size: 10_000,
104 cache_ttl_secs: 3600,
105 fallback_on_error: true,
106 max_cache_text_len: 10_000,
107 }
108 }
109}
110
111impl ExactTokenConfig {
112 pub fn gpt4() -> Self {
114 Self {
115 model: TokenizerModel::Cl100kBase,
116 ..Default::default()
117 }
118 }
119
120 pub fn claude() -> Self {
122 Self {
123 model: TokenizerModel::Claude,
124 ..Default::default()
125 }
126 }
127}
128
/// Interface shared by exact and heuristic token counters.
pub trait TokenCounter: Send + Sync {
    /// Number of tokens in `text` under this counter's model.
    fn count(&self, text: &str) -> usize;

    /// Count `text` as if tokenized for `model`. The default implementation
    /// ignores `model` and defers to [`TokenCounter::count`].
    fn count_for_model(&self, text: &str, model: TokenizerModel) -> usize {
        let _ = model; self.count(text)
    }

    /// Split `text` into token ids.
    fn tokenize(&self, text: &str) -> Vec<u32>;

    /// Reassemble text from token ids (may be lossy for some implementations).
    fn decode(&self, tokens: &[u32]) -> String;

    /// Model this counter is configured for.
    fn model(&self) -> TokenizerModel;

    /// `true` when counts come from real tokenization rather than a
    /// bytes-per-token estimate.
    fn is_exact(&self) -> bool;
}
156
/// Token counter that runs a simplified BPE-style tokenizer and memoizes
/// per-text token counts in a bounded, TTL-expiring cache.
pub struct ExactTokenCounter {
    // Behavior knobs: model, cache sizing, TTL, cacheable text length.
    config: ExactTokenConfig,

    // Maps hash(text) -> token count. Keyed by a 64-bit hash, so a hash
    // collision would return the wrong count for the colliding text.
    cache: Cache<u64, usize>,

    // Shared vocabulary used for tokenize/decode.
    vocab: Arc<BpeVocab>,

    // Hit/miss/tokenization counters, shared so callers can observe them.
    stats: Arc<TokenCacheStats>,
}
177
/// Miniature BPE-style vocabulary: printable ASCII bytes, a handful of
/// common merged tokens, and the cl100k special tokens. This approximates
/// cl100k_base; it is not the real vocabulary.
struct BpeVocab {
    /// Token text -> token id.
    token_to_id: HashMap<String, u32>,

    /// Token id -> token text (reverse map used by `decode`).
    id_to_token: HashMap<u32, String>,

    /// BPE merge rules; unused by this simplified tokenizer.
    #[allow(dead_code)]
    merges: HashMap<(String, String), String>,

    /// Control tokens (e.g. `<|endoftext|>`) matched before ordinary text.
    special_tokens: HashMap<String, u32>,
}

impl BpeVocab {
    /// Builds the simplified cl100k_base-like vocabulary.
    fn cl100k_base() -> Self {
        let mut token_to_id = HashMap::new();
        let mut id_to_token = HashMap::new();

        // Printable ASCII: token id equals the byte value.
        for b in 32u8..127 {
            let token = String::from(b as char);
            let id = b as u32;
            token_to_id.insert(token.clone(), id);
            id_to_token.insert(id, token);
        }

        // Frequent English fragments and code punctuation, ids 200+.
        // Note: " " here overwrites the ASCII space entry in token_to_id;
        // both ids remain decodable via id_to_token.
        let common_tokens = [
            "the", "ing", "tion", "ed", "er", "es", "en", "al", "re",
            "on", "an", "or", "ar", "is", "it", "at", "as", "le", "ve",
            " the", " a", " to", " of", " and", " in", " is", " for",
            " ", "\n", "\t", "```", "...", "->", "=>", "==", "!=",
        ];

        let mut id = 200u32;
        for token in common_tokens {
            token_to_id.insert(token.to_string(), id);
            id_to_token.insert(id, token.to_string());
            id += 1;
        }

        // Real cl100k_base special-token ids.
        let mut special_tokens = HashMap::new();
        special_tokens.insert("<|endoftext|>".to_string(), 100257);
        special_tokens.insert("<|fim_prefix|>".to_string(), 100258);
        special_tokens.insert("<|fim_middle|>".to_string(), 100259);
        special_tokens.insert("<|fim_suffix|>".to_string(), 100260);

        // Fix: register specials in the reverse map as well; previously
        // decode() silently dropped special-token ids, so
        // decode(tokenize(text)) could not round-trip them.
        for (token, &sid) in &special_tokens {
            id_to_token.insert(sid, token.clone());
        }

        Self {
            token_to_id,
            id_to_token,
            merges: HashMap::new(),
            special_tokens,
        }
    }

    /// Greedy longest-match tokenization: special tokens first, then the
    /// longest vocabulary entry (up to 10 bytes), then a per-char byte
    /// fallback. The fallback clamps ids to 255, so chars above U+00FF are
    /// lossy by design.
    fn tokenize(&self, text: &str) -> Vec<u32> {
        let mut tokens = Vec::new();
        let mut remaining = text;

        while !remaining.is_empty() {
            let mut matched = false;

            // Special tokens take priority over ordinary vocabulary entries.
            for (special, id) in &self.special_tokens {
                if remaining.starts_with(special) {
                    tokens.push(*id);
                    remaining = &remaining[special.len()..];
                    matched = true;
                    break;
                }
            }

            if matched {
                continue;
            }

            // Longest-match over byte prefixes; get(..len) returns None off
            // a char boundary, so only valid UTF-8 slices are looked up.
            for len in (1..=remaining.len().min(10)).rev() {
                if let Some(substr) = remaining.get(..len) {
                    if let Some(&id) = self.token_to_id.get(substr) {
                        tokens.push(id);
                        remaining = &remaining[len..];
                        matched = true;
                        break;
                    }
                }
            }

            // Fallback: emit a byte-range id for the next char.
            if !matched {
                if let Some(c) = remaining.chars().next() {
                    let byte_id = (c as u32).min(255);
                    tokens.push(byte_id);
                    remaining = &remaining[c.len_utf8()..];
                }
            }
        }

        tokens
    }

    /// Reassembles text from token ids. Vocabulary and special ids map back
    /// to their text; byte-fallback ids (< 256) map to the matching char;
    /// any other unknown id is dropped silently.
    fn decode(&self, tokens: &[u32]) -> String {
        let mut result = String::new();

        for &id in tokens {
            if let Some(token) = self.id_to_token.get(&id) {
                result.push_str(token);
            } else if id < 256 {
                if let Some(c) = char::from_u32(id) {
                    result.push(c);
                }
            }
        }

        result
    }
}
306
/// Lock-free counters describing cache effectiveness.
#[derive(Debug, Default)]
pub struct TokenCacheStats {
    /// Lookups answered from the cache.
    pub hits: std::sync::atomic::AtomicUsize,
    /// Lookups that required a fresh tokenization.
    pub misses: std::sync::atomic::AtomicUsize,
    /// Total calls into `count`.
    pub tokenizations: std::sync::atomic::AtomicUsize,
    /// Tokens produced by tokenizations.
    pub total_tokens: std::sync::atomic::AtomicUsize,
}

impl TokenCacheStats {
    /// Fraction of lookups served from cache, in [0, 1].
    /// Returns 0.0 before any lookup has been recorded.
    pub fn hit_rate(&self) -> f64 {
        use std::sync::atomic::Ordering::Relaxed;

        let hits = self.hits.load(Relaxed);
        match hits + self.misses.load(Relaxed) {
            0 => 0.0,
            total => hits as f64 / total as f64,
        }
    }
}
333
334impl ExactTokenCounter {
335 pub fn new(config: ExactTokenConfig) -> Self {
337 let cache = Cache::builder()
338 .max_capacity(config.cache_size as u64)
339 .time_to_live(std::time::Duration::from_secs(config.cache_ttl_secs))
340 .build();
341
342 Self {
343 config,
344 cache,
345 vocab: Arc::new(BpeVocab::cl100k_base()),
346 stats: Arc::new(TokenCacheStats::default()),
347 }
348 }
349
350 pub fn default_counter() -> Self {
352 Self::new(ExactTokenConfig::default())
353 }
354
355 pub fn stats(&self) -> &Arc<TokenCacheStats> {
357 &self.stats
358 }
359
360 fn text_hash(text: &str) -> u64 {
362 use std::hash::{Hash, Hasher};
363 use std::collections::hash_map::DefaultHasher;
364
365 let mut hasher = DefaultHasher::new();
366 text.hash(&mut hasher);
367 hasher.finish()
368 }
369
370 fn count_cached(&self, text: &str) -> usize {
372 if text.len() > self.config.max_cache_text_len {
374 self.stats.misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
375 return self.tokenize(text).len();
376 }
377
378 let hash = Self::text_hash(text);
379
380 if let Some(count) = self.cache.get(&hash) {
381 self.stats.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
382 return count;
383 }
384
385 self.stats.misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
386
387 let tokens = self.tokenize(text);
388 let count = tokens.len();
389
390 self.cache.insert(hash, count);
391 self.stats.total_tokens.fetch_add(count, std::sync::atomic::Ordering::Relaxed);
392
393 count
394 }
395
396 #[allow(dead_code)]
398 fn estimate_tokens(&self, text: &str) -> usize {
399 let bytes = text.len();
400 ((bytes as f32) / self.config.model.bytes_per_token()).ceil() as usize
401 }
402}
403
404impl TokenCounter for ExactTokenCounter {
405 fn count(&self, text: &str) -> usize {
406 self.stats.tokenizations.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
407 self.count_cached(text)
408 }
409
410 fn count_for_model(&self, text: &str, model: TokenizerModel) -> usize {
411 if model == self.config.model {
412 self.count(text)
413 } else {
414 let bytes = text.len();
416 ((bytes as f32) / model.bytes_per_token()).ceil() as usize
417 }
418 }
419
420 fn tokenize(&self, text: &str) -> Vec<u32> {
421 self.vocab.tokenize(text)
422 }
423
424 fn decode(&self, tokens: &[u32]) -> String {
425 self.vocab.decode(tokens)
426 }
427
428 fn model(&self) -> TokenizerModel {
429 self.config.model
430 }
431
432 fn is_exact(&self) -> bool {
433 true
434 }
435}
436
/// Fast, approximate counter: estimates tokens from byte length alone,
/// with no tokenization and no caching.
pub struct HeuristicTokenCounter {
    // Average bytes per token assumed for the estimate.
    bytes_per_token: f32,

    // Model this estimate is tuned for (reported via `TokenCounter::model`).
    model: TokenizerModel,
}
449
450impl HeuristicTokenCounter {
451 pub fn new() -> Self {
453 Self {
454 bytes_per_token: 4.0,
455 model: TokenizerModel::Generic,
456 }
457 }
458
459 pub fn for_model(model: TokenizerModel) -> Self {
461 Self {
462 bytes_per_token: model.bytes_per_token(),
463 model,
464 }
465 }
466}
467
468impl Default for HeuristicTokenCounter {
469 fn default() -> Self {
470 Self::new()
471 }
472}
473
474impl TokenCounter for HeuristicTokenCounter {
475 fn count(&self, text: &str) -> usize {
476 let bytes = text.len();
477 ((bytes as f32) / self.bytes_per_token).ceil() as usize
478 }
479
480 fn tokenize(&self, text: &str) -> Vec<u32> {
481 text.split_whitespace()
483 .enumerate()
484 .map(|(i, _)| i as u32)
485 .collect()
486 }
487
488 fn decode(&self, _tokens: &[u32]) -> String {
489 "[decode not supported for heuristic counter]".to_string()
491 }
492
493 fn model(&self) -> TokenizerModel {
494 self.model
495 }
496
497 fn is_exact(&self) -> bool {
498 false
499 }
500}
501
/// Thread-safe token-budget tracker backed by any [`TokenCounter`].
pub struct ExactBudgetEnforcer<C: TokenCounter> {
    // Counter used to price each text.
    counter: Arc<C>,

    // Total tokens allowed.
    budget: usize,

    // Tokens consumed so far (may exceed `budget` via `force_consume`).
    used: std::sync::atomic::AtomicUsize,
}
517
518impl<C: TokenCounter> ExactBudgetEnforcer<C> {
519 pub fn new(counter: Arc<C>, budget: usize) -> Self {
521 Self {
522 counter,
523 budget,
524 used: std::sync::atomic::AtomicUsize::new(0),
525 }
526 }
527
528 pub fn remaining(&self) -> usize {
530 self.budget.saturating_sub(self.used.load(std::sync::atomic::Ordering::Relaxed))
531 }
532
533 pub fn fits(&self, text: &str) -> bool {
535 let tokens = self.counter.count(text);
536 tokens <= self.remaining()
537 }
538
539 pub fn try_consume(&self, text: &str) -> Option<usize> {
542 let tokens = self.counter.count(text);
543 let remaining = self.remaining();
544
545 if tokens <= remaining {
546 self.used.fetch_add(tokens, std::sync::atomic::Ordering::Relaxed);
547 Some(tokens)
548 } else {
549 None
550 }
551 }
552
553 pub fn force_consume(&self, tokens: usize) {
555 self.used.fetch_add(tokens, std::sync::atomic::Ordering::Relaxed);
556 }
557
558 pub fn truncate_to_fit(&self, text: &str) -> (String, usize) {
560 let remaining = self.remaining();
561 if remaining == 0 {
562 return (String::new(), 0);
563 }
564
565 let mut low = 0;
567 let mut high = text.len();
568 let mut best_len = 0;
569 let mut best_tokens = 0;
570
571 while low < high {
572 let mid = (low + high + 1) / 2;
573
574 let truncated = if mid >= text.len() {
576 text.to_string()
577 } else {
578 let mut end = mid;
579 while !text.is_char_boundary(end) && end > 0 {
580 end -= 1;
581 }
582 text[..end].to_string()
583 };
584
585 let tokens = self.counter.count(&truncated);
586
587 if tokens <= remaining {
588 best_len = truncated.len();
589 best_tokens = tokens;
590 low = mid;
591 } else {
592 high = mid - 1;
593 }
594 }
595
596 if best_len == 0 {
597 (String::new(), 0)
598 } else {
599 (text[..best_len].to_string(), best_tokens)
600 }
601 }
602
603 pub fn summary(&self) -> BudgetSummary {
605 let used = self.used.load(std::sync::atomic::Ordering::Relaxed);
606 BudgetSummary {
607 budget: self.budget,
608 used,
609 remaining: self.budget.saturating_sub(used),
610 utilization: (used as f64) / (self.budget as f64),
611 }
612 }
613}
614
/// Point-in-time snapshot of an [`ExactBudgetEnforcer`]'s state.
#[derive(Debug, Clone)]
pub struct BudgetSummary {
    /// Total tokens allowed.
    pub budget: usize,
    /// Tokens consumed so far.
    pub used: usize,
    /// Tokens still available (saturates at 0).
    pub remaining: usize,
    /// used / budget as a fraction.
    pub utilization: f64,
}
627
628pub fn count_tokens_exact(text: &str) -> usize {
634 let counter = ExactTokenCounter::default_counter();
635 counter.count(text)
636}
637
638pub fn count_tokens_heuristic(text: &str) -> usize {
640 let counter = HeuristicTokenCounter::new();
641 counter.count(text)
642}
643
644pub fn create_budget_enforcer(budget: usize) -> ExactBudgetEnforcer<ExactTokenCounter> {
646 let counter = Arc::new(ExactTokenCounter::default_counter());
647 ExactBudgetEnforcer::new(counter, budget)
648}
649
#[cfg(test)]
mod tests {
    use super::*;

    // A short string should count to a small positive number of tokens.
    #[test]
    fn test_exact_token_count() {
        let counter = ExactTokenCounter::default_counter();

        let count = counter.count("Hello, world!");
        assert!(count > 0);
        assert!(count < 20); }

    // tokenize should produce ids and decode should yield non-empty text.
    #[test]
    fn test_tokenize_and_decode() {
        let counter = ExactTokenCounter::default_counter();

        let text = "Hello world";
        let tokens = counter.tokenize(text);

        assert!(!tokens.is_empty());

        let decoded = counter.decode(&tokens);
        assert!(!decoded.is_empty());
    }

    // Counting the same text twice: first call misses, second hits the cache.
    #[test]
    fn test_cache_hits() {
        let counter = ExactTokenCounter::default_counter();

        let _ = counter.count("test text for caching");

        let _ = counter.count("test text for caching");

        let stats = counter.stats();
        let hits = stats.hits.load(std::sync::atomic::Ordering::Relaxed);
        let misses = stats.misses.load(std::sync::atomic::Ordering::Relaxed);

        assert!(hits >= 1);
        assert!(misses >= 1);
    }

    // 11 bytes at ~4 bytes/token should estimate to a handful of tokens.
    #[test]
    fn test_heuristic_counter() {
        let counter = HeuristicTokenCounter::new();

        let count = counter.count("Hello world");
        assert!(count >= 2 && count <= 5);
    }

    // Consuming from a fresh budget should succeed and reduce what remains.
    #[test]
    fn test_budget_enforcer() {
        let counter = Arc::new(ExactTokenCounter::default_counter());
        let enforcer = ExactBudgetEnforcer::new(counter, 100);

        assert_eq!(enforcer.remaining(), 100);

        let consumed = enforcer.try_consume("Hello world").unwrap();
        assert!(consumed > 0);
        assert!(enforcer.remaining() < 100);
    }

    // A long text must be truncated down to fit a 5-token budget.
    #[test]
    fn test_budget_truncation() {
        let counter = Arc::new(ExactTokenCounter::default_counter());
        let enforcer = ExactBudgetEnforcer::new(counter, 5);

        let long_text = "This is a very long text that definitely exceeds five tokens and should be truncated";

        let (truncated, tokens) = enforcer.truncate_to_fit(long_text);

        assert!(truncated.len() < long_text.len());
        assert!(tokens <= 5);
    }

    // force_consume plus summary should report consistent numbers.
    #[test]
    fn test_budget_summary() {
        let counter = Arc::new(HeuristicTokenCounter::new());
        let enforcer = ExactBudgetEnforcer::new(counter, 100);

        enforcer.force_consume(25);

        let summary = enforcer.summary();
        assert_eq!(summary.budget, 100);
        assert_eq!(summary.used, 25);
        assert_eq!(summary.remaining, 75);
        assert!((summary.utilization - 0.25).abs() < 0.01);
    }

    // Counting for a non-configured model falls back to a heuristic estimate.
    #[test]
    fn test_model_specific_counting() {
        let counter = ExactTokenCounter::default_counter();

        let text = "Hello, world!";

        let gpt4_count = counter.count_for_model(text, TokenizerModel::Cl100kBase);
        let claude_count = counter.count_for_model(text, TokenizerModel::Claude);

        assert!(gpt4_count > 0);
        assert!(claude_count > 0);
    }
}