use std::sync::Arc;
use tiktoken_rs::{get_bpe_from_model, CoreBPE};
use serde::{Deserialize, Serialize};
use once_cell::sync::Lazy;
use crate::Result;

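/// Configuration for tiktoken-based token counting: the model whose BPE
/// encoding to load (`encoding_model`), a caching flag (`enable_caching`,
/// carried in the config but not consulted in this module), and an optional
/// overall `token_budget`.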
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerConfig {
    pub encoding_model: String,
    pub enable_caching: bool,
    pub token_budget: Option<usize>,
}

impl Default for TokenizerConfig {
    fn default() -> Self {
        Self {
            encoding_model: "gpt-4".to_string(),
            enable_caching: true,
            token_budget: Some(128000),
        }
    }
}

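/// Lazily initialized, process-wide counter built from `TokenizerConfig::default()`.
/// Initialization panics (via `expect`) if the default encoding cannot be loaded.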
static GLOBAL_TOKEN_COUNTER: Lazy<TokenCounter> = Lazy::new(|| {
    TokenCounter::new(TokenizerConfig::default())
        .expect("Failed to initialize global token counter")
});

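/// Counts tokens with a `tiktoken_rs::CoreBPE` encoder selected by model name
/// (see `get_bpe_from_model`). The encoder is stored behind an `Arc`.
///
/// Illustrative sketch (assumes the default `gpt-4` encoding loads successfully):
///
/// ```ignore
/// let counter = TokenCounter::default()?;
/// let n = counter.count_tokens("Hello, world!")?;
/// assert!(n > 0);
/// ```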
pub struct TokenCounter {
    config: TokenizerConfig,
    bpe: Arc<CoreBPE>,
}

impl std::fmt::Debug for TokenCounter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenCounter")
            .field("config", &self.config)
            .field("bpe", &"<CoreBPE>")
            .finish()
    }
}

impl TokenCounter {
    /// Builds a counter for `config.encoding_model`, returning a tokenization
    /// error if the model's BPE encoding cannot be loaded.
    pub fn new(config: TokenizerConfig) -> Result<Self> {
        let bpe = get_bpe_from_model(&config.encoding_model)
            .map_err(|e| crate::ScribeError::tokenization(format!(
                "Failed to load tokenizer for model '{}': {}",
                config.encoding_model, e
            )))?;

        Ok(Self {
            config,
            bpe: Arc::new(bpe),
        })
    }

    /// Convenience constructor using `TokenizerConfig::default()`.
    pub fn default() -> Result<Self> {
        Self::new(TokenizerConfig::default())
    }

    /// Shared, lazily initialized counter for callers that do not need custom
    /// configuration.
    pub fn global() -> &'static TokenCounter {
        &GLOBAL_TOKEN_COUNTER
    }

    /// Counts the tokens in `content`, including special tokens.
    pub fn count_tokens(&self, content: &str) -> Result<usize> {
        let tokens = self.bpe.encode_with_special_tokens(content);
        Ok(tokens.len())
    }

    /// Sums token counts across several pieces of content.
    pub fn count_tokens_batch(&self, contents: &[&str]) -> Result<usize> {
        let mut total = 0;
        for content in contents {
            total += self.count_tokens(content)?;
        }
        Ok(total)
    }

    /// Counts tokens in a file's content and scales the result by a
    /// per-language multiplier derived from the file extension.
    pub fn estimate_file_tokens(&self, content: &str, file_path: &std::path::Path) -> Result<usize> {
        let base_tokens = self.count_tokens(content)?;
        let multiplier = self.get_language_multiplier(file_path);
        Ok((base_tokens as f64 * multiplier).ceil() as usize)
    }

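    /// Heuristic per-extension scaling used by `estimate_file_tokens`: more
    /// verbose languages get a bump, terse or prose-like formats a discount.
    /// The factors are rough tuning constants.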
    fn get_language_multiplier(&self, file_path: &std::path::Path) -> f64 {
        let extension = file_path.extension()
            .and_then(|ext| ext.to_str())
            .unwrap_or("");

        match extension {
            // Verbose, boilerplate-heavy languages
            "java" | "csharp" | "cs" => 1.2,

            // Mainstream application languages
            "py" | "python" => 0.9,
            "js" | "javascript" | "ts" | "typescript" => 0.95,
            "rs" | "rust" => 1.0,
            "go" => 0.95,

            // Structured data and markup
            "json" | "yaml" | "yml" | "toml" => 0.8,
            "xml" | "html" | "htm" => 1.1,

            // Prose
            "md" | "markdown" | "txt" => 0.7,

            _ => 1.0,
        }
    }

    /// Returns `true` when `content` fits within the configured token budget,
    /// or unconditionally `true` when no budget is set.
    pub fn fits_budget(&self, content: &str) -> Result<bool> {
        if let Some(budget) = self.config.token_budget {
            let token_count = self.count_tokens(content)?;
            Ok(token_count <= budget)
        } else {
            Ok(true)
        }
    }

    /// Tokens left under the configured budget, if any (saturating at zero).
    pub fn remaining_budget(&self, used_tokens: usize) -> Option<usize> {
        self.config.token_budget.map(|budget| budget.saturating_sub(used_tokens))
    }

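    /// Splits `content` into chunks of at most `chunk_size` tokens by encoding,
    /// slicing the token stream, and decoding each slice. Note that slicing on
    /// token boundaries can, in principle, split a multi-byte character across
    /// chunks, in which case decoding that chunk fails with a tokenization error.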
    pub fn chunk_content(&self, content: &str, chunk_size: usize) -> Result<Vec<String>> {
        let tokens = self.bpe.encode_with_special_tokens(content);
        let mut chunks = Vec::new();

        for chunk_tokens in tokens.chunks(chunk_size) {
            let chunk_text = self.bpe.decode(chunk_tokens.to_vec())
                .map_err(|e| crate::ScribeError::tokenization(format!("Failed to decode token chunk: {}", e)))?;
            chunks.push(chunk_text);
        }

        Ok(chunks)
    }

    pub fn config(&self) -> &TokenizerConfig {
        &self.config
    }

    pub fn set_token_budget(&mut self, budget: Option<usize>) {
        self.config.token_budget = budget;
    }
}

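/// Tracks token usage against a fixed budget, with a two-phase
/// reserve/confirm flow for allocations that may later be rolled back.
///
/// Illustrative sketch of the intended flow:
///
/// ```ignore
/// let mut budget = TokenBudget::new(1_000);
/// assert!(budget.allocate(300));   // commit 300 tokens immediately
/// assert!(budget.reserve(200));    // hold 200 more without committing
/// budget.confirm_reservation(150); // 150 of the hold becomes usage
/// budget.release_reservation(50);  // the rest is returned
/// assert_eq!(budget.available(), 550);
/// ```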
#[derive(Debug, Clone)]
pub struct TokenBudget {
    total_budget: usize,
    used_tokens: usize,
    reserved_tokens: usize,
}

impl TokenBudget {
    pub fn new(total_budget: usize) -> Self {
        Self {
            total_budget,
            used_tokens: 0,
            reserved_tokens: 0,
        }
    }

    pub fn total(&self) -> usize {
        self.total_budget
    }

    pub fn used(&self) -> usize {
        self.used_tokens
    }

    pub fn reserved(&self) -> usize {
        self.reserved_tokens
    }

    pub fn available(&self) -> usize {
        self.total_budget.saturating_sub(self.used_tokens + self.reserved_tokens)
    }

    pub fn can_allocate(&self, tokens: usize) -> bool {
        self.available() >= tokens
    }

    /// Commits `tokens` against the budget if they fit; returns whether it succeeded.
    pub fn allocate(&mut self, tokens: usize) -> bool {
        if self.can_allocate(tokens) {
            self.used_tokens += tokens;
            true
        } else {
            false
        }
    }

    /// Holds `tokens` without committing them; confirm or release later.
    pub fn reserve(&mut self, tokens: usize) -> bool {
        if self.available() >= tokens {
            self.reserved_tokens += tokens;
            true
        } else {
            false
        }
    }

    /// Moves up to `tokens` from the reserved pool into committed usage.
    pub fn confirm_reservation(&mut self, tokens: usize) {
        let to_confirm = tokens.min(self.reserved_tokens);
        self.reserved_tokens -= to_confirm;
        self.used_tokens += to_confirm;
    }

    /// Returns up to `tokens` from the reserved pool to the available pool.
    pub fn release_reservation(&mut self, tokens: usize) {
        self.reserved_tokens = self.reserved_tokens.saturating_sub(tokens);
    }

    /// Percentage of the total budget that has been committed (0.0 for an
    /// empty budget, to avoid dividing by zero).
    pub fn utilization(&self) -> f64 {
        if self.total_budget == 0 {
            return 0.0;
        }
        (self.used_tokens as f64 / self.total_budget as f64) * 100.0
    }

    /// Clears both committed and reserved tokens.
    pub fn reset(&mut self) {
        self.used_tokens = 0;
        self.reserved_tokens = 0;
    }
}

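/// Helpers for comparing tiktoken counts against the legacy character-based
/// estimate and for choosing per-model token budgets.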
pub mod utils {
    use super::*;

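    /// Rough legacy estimate of about four characters per token; for example,
    /// a 400-character string is estimated at 100 tokens. Kept for comparison
    /// with the exact tiktoken counts.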
    pub fn estimate_tokens_legacy(content: &str) -> usize {
        (content.chars().count() as f64 / 4.0).ceil() as usize
    }

    /// Compares the exact tiktoken count with the legacy estimate and reports
    /// the ratio between them.
    pub fn compare_tokenization_accuracy(content: &str, counter: &TokenCounter) -> Result<TokenizationComparison> {
        let tiktoken_count = counter.count_tokens(content)?;
        let legacy_count = estimate_tokens_legacy(content);

        let accuracy_ratio = if legacy_count > 0 {
            tiktoken_count as f64 / legacy_count as f64
        } else {
            1.0
        };

        Ok(TokenizationComparison {
            tiktoken_count,
            legacy_count,
            accuracy_ratio,
            improvement: if accuracy_ratio < 1.0 {
                Some((1.0 - accuracy_ratio) * 100.0)
            } else {
                None
            },
        })
    }

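    /// Suggests a token budget for a model name, discounted for code-heavy
    /// content. The per-model base values are rough defaults rather than
    /// authoritative context-window sizes.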
    pub fn recommend_token_budget(model: &str, content_type: ContentType) -> usize {
        let base_budget = match model {
            "gpt-4" | "gpt-4-turbo" => 128000,
            "gpt-4-32k" => 32000,
            "gpt-3.5-turbo" => 16000,
            "gpt-3.5-turbo-16k" => 16000,
            _ => 8000,
        };

        match content_type {
            ContentType::Code => (base_budget as f64 * 0.8) as usize,
            ContentType::Documentation => base_budget,
            ContentType::Mixed => (base_budget as f64 * 0.9) as usize,
        }
    }
}

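/// Broad classification of content used when recommending a token budget.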
#[derive(Debug, Clone, Copy)]
pub enum ContentType {
    Code,
    Documentation,
    Mixed,
}

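/// Result of comparing tiktoken's count with the legacy estimate.
/// `accuracy_ratio` is tiktoken / legacy; `improvement` is only set when the
/// legacy heuristic over-estimated (ratio below 1.0).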
#[derive(Debug, Clone)]
pub struct TokenizationComparison {
    pub tiktoken_count: usize,
    pub legacy_count: usize,
    pub accuracy_ratio: f64,
    pub improvement: Option<f64>,
}

impl TokenizationComparison {
    pub fn format(&self) -> String {
        match self.improvement {
            Some(improvement) => format!(
                "Tiktoken: {} tokens, Legacy: {} tokens, {:.1}% more accurate",
                self.tiktoken_count, self.legacy_count, improvement
            ),
            None => format!(
                "Tiktoken: {} tokens, Legacy: {} tokens, {:.2}x ratio",
                self.tiktoken_count, self.legacy_count, self.accuracy_ratio
            ),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    #[test]
    fn test_token_counter_creation() {
        let config = TokenizerConfig::default();
        let counter = TokenCounter::new(config);
        assert!(counter.is_ok());
    }

    #[test]
    fn test_basic_token_counting() {
        let counter = TokenCounter::default().unwrap();

        let simple_text = "Hello, world!";
        let count = counter.count_tokens(simple_text).unwrap();
        assert!(count > 0);
        assert!(count < 10);
    }

    #[test]
    fn test_code_token_counting() {
        let counter = TokenCounter::default().unwrap();

        let rust_code = r#"
fn main() {
    println!("Hello, world!");
    let x = 42;
    if x > 0 {
        println!("Positive number: {}", x);
    }
}
"#;

        let count = counter.count_tokens(rust_code).unwrap();
        assert!(count > 20);
        assert!(count < 100);
    }

    #[test]
    fn test_language_multipliers() {
        let counter = TokenCounter::default().unwrap();

        let content = "function test() { return 42; }";

        let js_tokens = counter.estimate_file_tokens(content, Path::new("test.js")).unwrap();
        let java_tokens = counter.estimate_file_tokens(content, Path::new("test.java")).unwrap();
        let py_tokens = counter.estimate_file_tokens(content, Path::new("test.py")).unwrap();

        assert!(java_tokens >= js_tokens);
        assert!(py_tokens <= js_tokens);
    }

    #[test]
    fn test_token_budget() {
        let mut budget = TokenBudget::new(1000);

        assert_eq!(budget.total(), 1000);
        assert_eq!(budget.used(), 0);
        assert_eq!(budget.available(), 1000);

        assert!(budget.allocate(300));
        assert_eq!(budget.used(), 300);
        assert_eq!(budget.available(), 700);

        assert!(budget.reserve(200));
        assert_eq!(budget.reserved(), 200);
        assert_eq!(budget.available(), 500);

        budget.confirm_reservation(150);
        assert_eq!(budget.used(), 450);
        assert_eq!(budget.reserved(), 50);
        assert_eq!(budget.available(), 500);
    }

    #[test]
    fn test_content_chunking() {
        let counter = TokenCounter::default().unwrap();

        let long_content = "word ".repeat(1000);
        let chunks = counter.chunk_content(&long_content, 100).unwrap();

        assert!(chunks.len() > 1);
        for chunk in &chunks {
            let chunk_tokens = counter.count_tokens(chunk).unwrap();
            // Allow some slack: decoded text may re-encode to a slightly
            // different count at chunk boundaries.
            assert!(chunk_tokens <= 120);
        }
    }

    #[test]
    fn test_tokenization_comparison() {
        let counter = TokenCounter::default().unwrap();

        let code_content = r#"
use std::collections::HashMap;

fn process_data(input: &str) -> Result<HashMap<String, i32>, Box<dyn std::error::Error>> {
    let mut result = HashMap::new();
    for line in input.lines() {
        let parts: Vec<&str> = line.split(':').collect();
        if parts.len() == 2 {
            result.insert(parts[0].to_string(), parts[1].parse()?);
        }
    }
    Ok(result)
}
"#;

        let comparison = utils::compare_tokenization_accuracy(code_content, &counter).unwrap();

        assert!(comparison.tiktoken_count > 0);
        assert!(comparison.legacy_count > 0);
        assert!(comparison.accuracy_ratio > 0.0);

        let formatted = comparison.format();
        assert!(formatted.contains("Tiktoken"));
        assert!(formatted.contains("Legacy"));
    }

    #[test]
    fn test_budget_recommendations() {
        let code_budget = utils::recommend_token_budget("gpt-4", ContentType::Code);
        let doc_budget = utils::recommend_token_budget("gpt-4", ContentType::Documentation);
        let mixed_budget = utils::recommend_token_budget("gpt-4", ContentType::Mixed);

        assert!(code_budget < doc_budget);
        assert!(mixed_budget > code_budget);
        assert!(mixed_budget < doc_budget);
    }
}