1use moka::sync::Cache;
37use std::collections::HashMap;
38use std::sync::Arc;
39
/// Tokenizer families known to this module.
///
/// Each variant carries an average bytes-per-token ratio (see
/// [`TokenizerModel::bytes_per_token`]) used for heuristic estimates, and a stable
/// lowercase name (see [`TokenizerModel::name`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TokenizerModel {
    /// The `cl100k_base` encoding.
    Cl100kBase,
    /// The `p50k_base` encoding.
    P50kBase,
    /// Claude-family models.
    Claude,
    /// Llama-family models.
    Llama,
    /// Model-agnostic fallback.
    Generic,
}

impl TokenizerModel {
    /// Average number of UTF-8 bytes per token for this model.
    ///
    /// Used as the divisor when estimating a token count from a byte length.
    pub fn bytes_per_token(&self) -> f32 {
        match *self {
            TokenizerModel::Cl100kBase => 3.8,
            TokenizerModel::Claude => 4.2,
            TokenizerModel::P50kBase | TokenizerModel::Llama | TokenizerModel::Generic => 4.0,
        }
    }

    /// Stable lowercase identifier for this model.
    pub fn name(&self) -> &'static str {
        match *self {
            TokenizerModel::Cl100kBase => "cl100k_base",
            TokenizerModel::P50kBase => "p50k_base",
            TokenizerModel::Claude => "claude",
            TokenizerModel::Llama => "llama",
            TokenizerModel::Generic => "generic",
        }
    }
}
82
83#[derive(Debug, Clone)]
85pub struct ExactTokenConfig {
86 pub model: TokenizerModel,
88
89 pub cache_size: usize,
91
92 pub cache_ttl_secs: u64,
94
95 pub fallback_on_error: bool,
97
98 pub max_cache_text_len: usize,
100}
101
102impl Default for ExactTokenConfig {
103 fn default() -> Self {
104 Self {
105 model: TokenizerModel::Cl100kBase,
106 cache_size: 10_000,
107 cache_ttl_secs: 3600,
108 fallback_on_error: true,
109 max_cache_text_len: 10_000,
110 }
111 }
112}
113
114impl ExactTokenConfig {
115 pub fn gpt4() -> Self {
117 Self {
118 model: TokenizerModel::Cl100kBase,
119 ..Default::default()
120 }
121 }
122
123 pub fn claude() -> Self {
125 Self {
126 model: TokenizerModel::Claude,
127 ..Default::default()
128 }
129 }
130}
131
132pub trait TokenCounter: Send + Sync {
138 fn count(&self, text: &str) -> usize;
140
141 fn count_for_model(&self, text: &str, model: TokenizerModel) -> usize {
143 let _ = model; self.count(text)
145 }
146
147 fn tokenize(&self, text: &str) -> Vec<u32>;
149
150 fn decode(&self, tokens: &[u32]) -> String;
152
153 fn model(&self) -> TokenizerModel;
155
156 fn is_exact(&self) -> bool;
158}
159
160pub struct ExactTokenCounter {
169 config: ExactTokenConfig,
170
171 cache: Cache<u64, usize>,
173
174 vocab: Arc<BpeVocab>,
176
177 stats: Arc<TokenCacheStats>,
179}
180
/// Small built-in stand-in for a byte-pair-encoding vocabulary.
///
/// NOTE(review): despite the constructor's name, this is not the real `cl100k_base`
/// table. It holds printable ASCII, a short list of common multi-character tokens and
/// four special tokens, so counts are approximations.
struct BpeVocab {
    /// Token text -> id, used for greedy matching in `tokenize`.
    token_to_id: HashMap<String, u32>,
    /// Id -> token text, used by `decode`.
    id_to_token: HashMap<u32, String>,
    /// BPE merge rules. `cl100k_base` leaves this empty and nothing here reads it.
    #[allow(dead_code)]
    merges: HashMap<(String, String), String>,
    /// Special tokens, matched before any ordinary vocabulary entry.
    special_tokens: HashMap<String, u32>,
}

impl BpeVocab {
    /// Builds the built-in vocabulary.
    ///
    /// Id layout:
    /// - `0..=255`: single characters. Printable ASCII is registered explicitly, and
    ///   `tokenize` falls back to a character's code point (clamped to 255) for
    ///   anything it cannot match.
    /// - `256..`: common multi-character tokens. They are kept above the
    ///   single-character range so a fallback id can never decode as one of them.
    /// - `100_257..=100_260`: special tokens.
    fn cl100k_base() -> Self {
        // First id after the 0..=255 single-character / byte-fallback range.
        const FIRST_MULTI_CHAR_ID: u32 = 256;

        let mut token_to_id = HashMap::new();
        let mut id_to_token = HashMap::new();

        // Printable ASCII: one token per character, id == byte value.
        for b in 32u8..127 {
            let token = String::from(b as char);
            let id = u32::from(b);
            token_to_id.insert(token.clone(), id);
            id_to_token.insert(id, token);
        }

        let common_tokens = [
            "the", "ing", "tion", "ed", "er", "es", "en", "al", "re", "on", "an", "or", "ar", "is",
            "it", "at", "as", "le", "ve", " the", " a", " to", " of", " and", " in", " is", " for",
            " ", "\n", "\t", "```", "...", "->", "=>", "==", "!=",
        ];

        // " " was already registered as id 32 above. Re-inserting it here makes
        // `tokenize` emit the multi-character-range id for a space. Both ids still
        // decode to " ".
        for (id, token) in (FIRST_MULTI_CHAR_ID..).zip(common_tokens.iter()) {
            token_to_id.insert(token.to_string(), id);
            id_to_token.insert(id, token.to_string());
        }

        let mut special_tokens = HashMap::new();
        special_tokens.insert("<|endoftext|>".to_string(), 100_257);
        special_tokens.insert("<|fim_prefix|>".to_string(), 100_258);
        special_tokens.insert("<|fim_middle|>".to_string(), 100_259);
        special_tokens.insert("<|fim_suffix|>".to_string(), 100_260);

        // Make special tokens decodable. They are deliberately absent from `token_to_id`,
        // because `tokenize` matches them through `special_tokens` before anything else.
        for (text, &id) in &special_tokens {
            id_to_token.insert(id, text.clone());
        }

        Self {
            token_to_id,
            id_to_token,
            merges: HashMap::new(),
            special_tokens,
        }
    }

    /// Splits `text` into token ids. At each position it tries three steps in order:
    ///
    /// 1. Any special token that prefixes the remaining text.
    /// 2. The longest vocabulary entry of at most 10 bytes (greedy longest match).
    /// 3. Otherwise, one id for the next character: its code point, clamped to 255.
    ///
    /// The fallback in step 3 is lossy for characters above U+00FF, which all map
    /// to 255 and decode as `'ÿ'`.
    fn tokenize(&self, text: &str) -> Vec<u32> {
        // Longest candidate (in bytes) tried by the greedy matcher.
        const MAX_TOKEN_BYTES: usize = 10;

        let mut tokens = Vec::new();
        let mut remaining = text;

        'outer: while !remaining.is_empty() {
            for (special, &id) in &self.special_tokens {
                if let Some(rest) = remaining.strip_prefix(special.as_str()) {
                    tokens.push(id);
                    remaining = rest;
                    continue 'outer;
                }
            }

            // `get(..len)` yields `None` for lengths that split a multi-byte character.
            for len in (1..=remaining.len().min(MAX_TOKEN_BYTES)).rev() {
                if let Some(&id) = remaining
                    .get(..len)
                    .and_then(|prefix| self.token_to_id.get(prefix))
                {
                    tokens.push(id);
                    remaining = &remaining[len..];
                    continue 'outer;
                }
            }

            if let Some(c) = remaining.chars().next() {
                tokens.push(u32::from(c).min(255));
                remaining = &remaining[c.len_utf8()..];
            }
        }

        tokens
    }

    /// Concatenates the text of each id.
    ///
    /// Ids without a vocabulary entry below 256 decode as the character with that code
    /// point (the inverse of the `tokenize` fallback). Other unknown ids are skipped.
    fn decode(&self, tokens: &[u32]) -> String {
        let mut result = String::new();

        for &id in tokens {
            match self.id_to_token.get(&id) {
                Some(token) => result.push_str(token),
                None if id < 256 => result.extend(char::from_u32(id)),
                None => {}
            }
        }

        result
    }
}
308
/// Usage counters for a token counter's cache.
///
/// Every field is a relaxed atomic. A snapshot taken while other threads are counting
/// may therefore be momentarily inconsistent across fields.
#[derive(Debug, Default)]
pub struct TokenCacheStats {
    /// Counts answered from the cache.
    pub hits: std::sync::atomic::AtomicUsize,
    /// Counts that required a fresh tokenization.
    pub misses: std::sync::atomic::AtomicUsize,
    /// Number of count requests received.
    pub tokenizations: std::sync::atomic::AtomicUsize,
    /// Sum of token counts recorded by the owning counter after tokenizing text.
    pub total_tokens: std::sync::atomic::AtomicUsize,
}

impl TokenCacheStats {
    /// Fraction of lookups served from the cache: `hits / (hits + misses)`.
    ///
    /// Returns `0.0` when nothing has been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        use std::sync::atomic::Ordering::Relaxed;

        let hit_count = self.hits.load(Relaxed);
        let lookups = hit_count + self.misses.load(Relaxed);
        match lookups {
            0 => 0.0,
            n => hit_count as f64 / n as f64,
        }
    }
}
335
336impl ExactTokenCounter {
337 pub fn new(config: ExactTokenConfig) -> Self {
339 let cache = Cache::builder()
340 .max_capacity(config.cache_size as u64)
341 .time_to_live(std::time::Duration::from_secs(config.cache_ttl_secs))
342 .build();
343
344 Self {
345 config,
346 cache,
347 vocab: Arc::new(BpeVocab::cl100k_base()),
348 stats: Arc::new(TokenCacheStats::default()),
349 }
350 }
351
352 pub fn default_counter() -> Self {
354 Self::new(ExactTokenConfig::default())
355 }
356
357 pub fn stats(&self) -> &Arc<TokenCacheStats> {
359 &self.stats
360 }
361
362 fn text_hash(text: &str) -> u64 {
364 use std::collections::hash_map::DefaultHasher;
365 use std::hash::{Hash, Hasher};
366
367 let mut hasher = DefaultHasher::new();
368 text.hash(&mut hasher);
369 hasher.finish()
370 }
371
372 fn count_cached(&self, text: &str) -> usize {
374 if text.len() > self.config.max_cache_text_len {
376 self.stats
377 .misses
378 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
379 return self.tokenize(text).len();
380 }
381
382 let hash = Self::text_hash(text);
383
384 if let Some(count) = self.cache.get(&hash) {
385 self.stats
386 .hits
387 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
388 return count;
389 }
390
391 self.stats
392 .misses
393 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
394
395 let tokens = self.tokenize(text);
396 let count = tokens.len();
397
398 self.cache.insert(hash, count);
399 self.stats
400 .total_tokens
401 .fetch_add(count, std::sync::atomic::Ordering::Relaxed);
402
403 count
404 }
405
406 #[allow(dead_code)]
408 fn estimate_tokens(&self, text: &str) -> usize {
409 let bytes = text.len();
410 ((bytes as f32) / self.config.model.bytes_per_token()).ceil() as usize
411 }
412}
413
414impl TokenCounter for ExactTokenCounter {
415 fn count(&self, text: &str) -> usize {
416 self.stats
417 .tokenizations
418 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
419 self.count_cached(text)
420 }
421
422 fn count_for_model(&self, text: &str, model: TokenizerModel) -> usize {
423 if model == self.config.model {
424 self.count(text)
425 } else {
426 let bytes = text.len();
428 ((bytes as f32) / model.bytes_per_token()).ceil() as usize
429 }
430 }
431
432 fn tokenize(&self, text: &str) -> Vec<u32> {
433 self.vocab.tokenize(text)
434 }
435
436 fn decode(&self, tokens: &[u32]) -> String {
437 self.vocab.decode(tokens)
438 }
439
440 fn model(&self) -> TokenizerModel {
441 self.config.model
442 }
443
444 fn is_exact(&self) -> bool {
445 true
446 }
447}
448
449pub struct HeuristicTokenCounter {
455 bytes_per_token: f32,
457
458 model: TokenizerModel,
460}
461
462impl HeuristicTokenCounter {
463 pub fn new() -> Self {
465 Self {
466 bytes_per_token: 4.0,
467 model: TokenizerModel::Generic,
468 }
469 }
470
471 pub fn for_model(model: TokenizerModel) -> Self {
473 Self {
474 bytes_per_token: model.bytes_per_token(),
475 model,
476 }
477 }
478}
479
480impl Default for HeuristicTokenCounter {
481 fn default() -> Self {
482 Self::new()
483 }
484}
485
486impl TokenCounter for HeuristicTokenCounter {
487 fn count(&self, text: &str) -> usize {
488 let bytes = text.len();
489 ((bytes as f32) / self.bytes_per_token).ceil() as usize
490 }
491
492 fn tokenize(&self, text: &str) -> Vec<u32> {
493 text.split_whitespace()
495 .enumerate()
496 .map(|(i, _)| i as u32)
497 .collect()
498 }
499
500 fn decode(&self, _tokens: &[u32]) -> String {
501 "[decode not supported for heuristic counter]".to_string()
503 }
504
505 fn model(&self) -> TokenizerModel {
506 self.model
507 }
508
509 fn is_exact(&self) -> bool {
510 false
511 }
512}
513
514pub struct ExactBudgetEnforcer<C: TokenCounter> {
520 counter: Arc<C>,
522
523 budget: usize,
525
526 used: std::sync::atomic::AtomicUsize,
528}
529
530impl<C: TokenCounter> ExactBudgetEnforcer<C> {
531 pub fn new(counter: Arc<C>, budget: usize) -> Self {
533 Self {
534 counter,
535 budget,
536 used: std::sync::atomic::AtomicUsize::new(0),
537 }
538 }
539
540 pub fn remaining(&self) -> usize {
542 self.budget
543 .saturating_sub(self.used.load(std::sync::atomic::Ordering::Relaxed))
544 }
545
546 pub fn fits(&self, text: &str) -> bool {
548 let tokens = self.counter.count(text);
549 tokens <= self.remaining()
550 }
551
552 pub fn try_consume(&self, text: &str) -> Option<usize> {
555 let tokens = self.counter.count(text);
556 let remaining = self.remaining();
557
558 if tokens <= remaining {
559 self.used
560 .fetch_add(tokens, std::sync::atomic::Ordering::Relaxed);
561 Some(tokens)
562 } else {
563 None
564 }
565 }
566
567 pub fn force_consume(&self, tokens: usize) {
569 self.used
570 .fetch_add(tokens, std::sync::atomic::Ordering::Relaxed);
571 }
572
573 pub fn truncate_to_fit(&self, text: &str) -> (String, usize) {
575 let remaining = self.remaining();
576 if remaining == 0 {
577 return (String::new(), 0);
578 }
579
580 let mut low = 0;
582 let mut high = text.len();
583 let mut best_len = 0;
584 let mut best_tokens = 0;
585
586 while low < high {
587 let mid = (low + high + 1) / 2;
588
589 let truncated = if mid >= text.len() {
591 text.to_string()
592 } else {
593 let mut end = mid;
594 while !text.is_char_boundary(end) && end > 0 {
595 end -= 1;
596 }
597 text[..end].to_string()
598 };
599
600 let tokens = self.counter.count(&truncated);
601
602 if tokens <= remaining {
603 best_len = truncated.len();
604 best_tokens = tokens;
605 low = mid;
606 } else {
607 high = mid - 1;
608 }
609 }
610
611 if best_len == 0 {
612 (String::new(), 0)
613 } else {
614 (text[..best_len].to_string(), best_tokens)
615 }
616 }
617
618 pub fn summary(&self) -> BudgetSummary {
620 let used = self.used.load(std::sync::atomic::Ordering::Relaxed);
621 BudgetSummary {
622 budget: self.budget,
623 used,
624 remaining: self.budget.saturating_sub(used),
625 utilization: (used as f64) / (self.budget as f64),
626 }
627 }
628}
629
/// Point-in-time snapshot of a token budget's state.
#[derive(Debug, Clone)]
pub struct BudgetSummary {
    /// Total tokens the budget allows.
    pub budget: usize,
    /// Tokens consumed so far; may exceed `budget`.
    pub used: usize,
    /// `budget - used`, saturating at zero.
    pub remaining: usize,
    /// `used / budget` as a fraction; above `1.0` when over budget.
    pub utilization: f64,
}
642
643pub fn count_tokens_exact(text: &str) -> usize {
649 let counter = ExactTokenCounter::default_counter();
650 counter.count(text)
651}
652
653pub fn count_tokens_heuristic(text: &str) -> usize {
655 let counter = HeuristicTokenCounter::new();
656 counter.count(text)
657}
658
659pub fn create_budget_enforcer(budget: usize) -> ExactBudgetEnforcer<ExactTokenCounter> {
661 let counter = Arc::new(ExactTokenCounter::default_counter());
662 ExactBudgetEnforcer::new(counter, budget)
663}
664
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    #[test]
    fn test_exact_token_count() {
        let counter = ExactTokenCounter::default_counter();

        let count = counter.count("Hello, world!");
        assert!(count > 0);
        assert!(count < 20);
        // A repeated count (served from the cache) must agree with the first one.
        assert_eq!(counter.count("Hello, world!"), count);
        assert_eq!(counter.count(""), 0);
    }

    #[test]
    fn test_tokenize_and_decode() {
        let counter = ExactTokenCounter::default_counter();

        let text = "Hello world";
        let tokens = counter.tokenize(text);

        assert!(!tokens.is_empty());
        assert_eq!(tokens.len(), counter.count(text));

        // Printable ASCII round-trips exactly through the built-in vocabulary.
        assert_eq!(counter.decode(&tokens), text);
    }

    #[test]
    fn test_cache_hits() {
        let counter = ExactTokenCounter::default_counter();

        let first = counter.count("test text for caching");
        let second = counter.count("test text for caching");
        assert_eq!(first, second);

        let stats = counter.stats();
        assert_eq!(stats.hits.load(Ordering::Relaxed), 1);
        assert_eq!(stats.misses.load(Ordering::Relaxed), 1);
        assert_eq!(stats.tokenizations.load(Ordering::Relaxed), 2);
        assert!((stats.hit_rate() - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_heuristic_counter() {
        let counter = HeuristicTokenCounter::new();

        // 11 bytes / 4.0 bytes per token, rounded up.
        assert_eq!(counter.count("Hello world"), 3);
        assert_eq!(counter.count(""), 0);
        assert!(!counter.is_exact());

        let claude = HeuristicTokenCounter::for_model(TokenizerModel::Claude);
        assert_eq!(claude.model(), TokenizerModel::Claude);
        // 13 bytes / 4.2 bytes per token, rounded up.
        assert_eq!(claude.count("Hello, world!"), 4);
    }

    #[test]
    fn test_budget_enforcer() {
        let counter = Arc::new(ExactTokenCounter::default_counter());
        let enforcer = ExactBudgetEnforcer::new(counter, 100);

        assert_eq!(enforcer.remaining(), 100);

        let consumed = enforcer.try_consume("Hello world").unwrap();
        assert!(consumed > 0);
        assert_eq!(enforcer.remaining(), 100 - consumed);

        // A text that cannot fit is rejected without consuming anything.
        let too_long = "x".repeat(1_000);
        assert!(!enforcer.fits(&too_long));
        assert_eq!(enforcer.try_consume(&too_long), None);
        assert_eq!(enforcer.remaining(), 100 - consumed);
    }

    #[test]
    fn test_budget_truncation() {
        let counter = Arc::new(ExactTokenCounter::default_counter());
        let enforcer = ExactBudgetEnforcer::new(Arc::clone(&counter), 5);

        let long_text =
            "This is a very long text that definitely exceeds five tokens and should be truncated";

        let (truncated, tokens) = enforcer.truncate_to_fit(long_text);

        assert!(!truncated.is_empty());
        assert!(truncated.len() < long_text.len());
        assert!(long_text.starts_with(truncated.as_str()));
        assert!(tokens <= 5);
        assert_eq!(tokens, counter.count(&truncated));
    }

    #[test]
    fn test_truncation_with_exhausted_budget() {
        let counter = Arc::new(HeuristicTokenCounter::new());
        let enforcer = ExactBudgetEnforcer::new(counter, 0);

        assert_eq!(enforcer.truncate_to_fit("anything"), (String::new(), 0));
    }

    #[test]
    fn test_budget_summary() {
        let counter = Arc::new(HeuristicTokenCounter::new());
        let enforcer = ExactBudgetEnforcer::new(counter, 100);

        enforcer.force_consume(25);

        let summary = enforcer.summary();
        assert_eq!(summary.budget, 100);
        assert_eq!(summary.used, 25);
        assert_eq!(summary.remaining, 75);
        assert!((summary.utilization - 0.25).abs() < 0.01);
    }

    #[test]
    fn test_model_specific_counting() {
        let counter = ExactTokenCounter::default_counter();

        let text = "Hello, world!";

        let gpt4_count = counter.count_for_model(text, TokenizerModel::Cl100kBase);
        let claude_count = counter.count_for_model(text, TokenizerModel::Claude);

        // The configured model uses the vocabulary...
        assert!(gpt4_count > 0);
        assert_eq!(gpt4_count, counter.count(text));
        // ...other models fall back to the byte heuristic: ceil(13 / 4.2) = 4.
        assert_eq!(claude_count, 4);
    }
}