//! Token counting built on `tiktoken_rs`, plus budget-tracking helpers.

use crate::Result;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tiktoken_rs::{get_bpe_from_model, CoreBPE};

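/// Configuration for token counting: which model's encoding to load, whether
/// results may be cached, and an optional overall token budget.
///
/// Illustrative sketch of overriding the default budget (the field values
/// here are example choices, not requirements):
///
/// ```ignore
/// let config = TokenizerConfig {
///     encoding_model: "gpt-4".to_string(),
///     enable_caching: true,
///     token_budget: Some(32_000),
/// };
/// ```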
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerConfig {
    /// Model name used to select the tiktoken encoding (e.g. "gpt-4").
    pub encoding_model: String,
    /// Whether tokenization results may be cached.
    pub enable_caching: bool,
    /// Optional token budget; `None` disables budget checks.
    pub token_budget: Option<usize>,
}

impl Default for TokenizerConfig {
    fn default() -> Self {
        Self {
            encoding_model: "gpt-4".to_string(),
            enable_caching: true,
            token_budget: Some(128000),
        }
    }
}

/// Process-wide token counter, lazily initialized with the default configuration.
static GLOBAL_TOKEN_COUNTER: Lazy<TokenCounter> = Lazy::new(|| {
    TokenCounter::new(TokenizerConfig::default())
        .expect("Failed to initialize global token counter")
});

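/// Counts tokens using a `tiktoken_rs` BPE encoding selected by model name.
///
/// A minimal usage sketch (assumes the surrounding crate re-exports this
/// module; errors propagated with `?` for brevity):
///
/// ```ignore
/// let counter = TokenCounter::default()?;
/// let tokens = counter.count_tokens("Hello, world!")?;
/// assert!(tokens > 0);
/// ```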
pub struct TokenCounter {
    /// Active tokenizer configuration.
    config: TokenizerConfig,
    /// Shared BPE encoder loaded from `tiktoken_rs`.
    bpe: Arc<CoreBPE>,
}

impl std::fmt::Debug for TokenCounter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenCounter")
            .field("config", &self.config)
            .field("bpe", &"<CoreBPE>")
            .finish()
    }
}

impl TokenCounter {
    /// Creates a token counter for the encoding model named in `config`.
    /// Fails with a tokenization error if the encoding cannot be loaded.
    pub fn new(config: TokenizerConfig) -> Result<Self> {
        let bpe = get_bpe_from_model(&config.encoding_model).map_err(|e| {
            crate::ScribeError::tokenization(format!(
                "Failed to load tokenizer for model '{}': {}",
                config.encoding_model, e
            ))
        })?;

        Ok(Self {
            config,
            bpe: Arc::new(bpe),
        })
    }

    /// Creates a token counter with the default configuration.
    pub fn default() -> Result<Self> {
        Self::new(TokenizerConfig::default())
    }

    /// Returns the lazily initialized, process-wide token counter.
    pub fn global() -> &'static TokenCounter {
        &GLOBAL_TOKEN_COUNTER
    }

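    /// Counts the tokens in `content` using the configured BPE encoding;
    /// special-token markers in the input are permitted and counted.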
    pub fn count_tokens(&self, content: &str) -> Result<usize> {
        let tokens = self.bpe.encode_with_special_tokens(content);
        Ok(tokens.len())
    }

    /// Sums token counts across a batch of strings.
    pub fn count_tokens_batch(&self, contents: &[&str]) -> Result<usize> {
        let mut total = 0;
        for content in contents {
            total += self.count_tokens(content)?;
        }
        Ok(total)
    }

    /// Estimates tokens for a file, applying a language-specific multiplier
    /// derived from the file extension.
    pub fn estimate_file_tokens(
        &self,
        content: &str,
        file_path: &std::path::Path,
    ) -> Result<usize> {
        let base_tokens = self.count_tokens(content)?;
        let multiplier = self.get_language_multiplier(file_path);
        Ok((base_tokens as f64 * multiplier).ceil() as usize)
    }

    /// Returns a rough per-language adjustment factor based on file extension.
    fn get_language_multiplier(&self, file_path: &std::path::Path) -> f64 {
        let extension = file_path
            .extension()
            .and_then(|ext| ext.to_str())
            .unwrap_or("");

        match extension {
            // JVM / .NET languages.
            "java" | "csharp" | "cs" => 1.2,
            // Scripting and systems languages.
            "py" | "python" => 0.9,
            "js" | "javascript" | "ts" | "typescript" => 0.95,
            "rs" | "rust" => 1.0,
            "go" => 0.95,
            // Structured data formats.
            "json" | "yaml" | "yml" | "toml" => 0.8,
            // Markup formats.
            "xml" | "html" | "htm" => 1.1,
            // Plain prose.
            "md" | "markdown" | "txt" => 0.7,
            // Unknown extensions use the baseline.
            _ => 1.0,
        }
    }

    /// Returns whether `content` fits within the configured token budget.
    /// Always `Ok(true)` when no budget is configured.
    pub fn fits_budget(&self, content: &str) -> Result<bool> {
        if let Some(budget) = self.config.token_budget {
            let token_count = self.count_tokens(content)?;
            Ok(token_count <= budget)
        } else {
            Ok(true)
        }
    }

    /// Returns the remaining budget after `used_tokens`, or `None` if no budget is set.
    pub fn remaining_budget(&self, used_tokens: usize) -> Option<usize> {
        self.config
            .token_budget
            .map(|budget| budget.saturating_sub(used_tokens))
    }

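    /// Splits `content` into chunks of at most `chunk_size` tokens by encoding
    /// the text, slicing the token stream, and decoding each slice back to text.
    ///
    /// Sketch of re-counting the resulting chunks (assumes a successfully
    /// constructed counter):
    ///
    /// ```ignore
    /// let chunks = counter.chunk_content(&long_text, 512)?;
    /// for chunk in &chunks {
    ///     // Each chunk was built from at most 512 tokens; re-encoding the
    ///     // decoded text may shift the count slightly at chunk boundaries.
    ///     let _ = counter.count_tokens(chunk)?;
    /// }
    /// ```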
    pub fn chunk_content(&self, content: &str, chunk_size: usize) -> Result<Vec<String>> {
        let tokens = self.bpe.encode_with_special_tokens(content);
        let mut chunks = Vec::new();

        for chunk_tokens in tokens.chunks(chunk_size) {
            let chunk_text = self.bpe.decode(chunk_tokens.to_vec()).map_err(|e| {
                crate::ScribeError::tokenization(format!("Failed to decode token chunk: {}", e))
            })?;
            chunks.push(chunk_text);
        }

        Ok(chunks)
    }

    /// Returns the active tokenizer configuration.
    pub fn config(&self) -> &TokenizerConfig {
        &self.config
    }

    /// Overrides the token budget; `None` disables budget checks.
    pub fn set_token_budget(&mut self, budget: Option<usize>) {
        self.config.token_budget = budget;
    }
}

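/// Tracks consumption of a fixed token budget, with support for tentative
/// reservations that can later be confirmed or released.
///
/// A small workflow sketch (the numbers are illustrative):
///
/// ```ignore
/// let mut budget = TokenBudget::new(1_000);
/// assert!(budget.allocate(300));   // consume immediately
/// assert!(budget.reserve(200));    // hold tokens tentatively
/// budget.confirm_reservation(150); // 150 reserved tokens become used
/// budget.release_reservation(50);  // return the rest to the pool
/// assert_eq!(budget.used(), 450);
/// assert_eq!(budget.available(), 550);
/// ```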
#[derive(Debug, Clone)]
pub struct TokenBudget {
    total_budget: usize,
    used_tokens: usize,
    reserved_tokens: usize,
}

impl TokenBudget {
    /// Creates a budget of `total_budget` tokens with nothing used or reserved.
    pub fn new(total_budget: usize) -> Self {
        Self {
            total_budget,
            used_tokens: 0,
            reserved_tokens: 0,
        }
    }

    /// Total number of tokens in the budget.
    pub fn total(&self) -> usize {
        self.total_budget
    }

    /// Tokens already consumed.
    pub fn used(&self) -> usize {
        self.used_tokens
    }

    /// Tokens currently held by reservations.
    pub fn reserved(&self) -> usize {
        self.reserved_tokens
    }

    /// Tokens still available after subtracting used and reserved tokens.
    pub fn available(&self) -> usize {
        self.total_budget
            .saturating_sub(self.used_tokens + self.reserved_tokens)
    }

    /// Returns whether `tokens` can be allocated without exceeding the budget.
    pub fn can_allocate(&self, tokens: usize) -> bool {
        self.available() >= tokens
    }

    /// Allocates `tokens` if available, returning whether the allocation succeeded.
    pub fn allocate(&mut self, tokens: usize) -> bool {
        if self.can_allocate(tokens) {
            self.used_tokens += tokens;
            true
        } else {
            false
        }
    }

    /// Reserves `tokens` for later confirmation, returning whether the reservation succeeded.
    pub fn reserve(&mut self, tokens: usize) -> bool {
        if self.available() >= tokens {
            self.reserved_tokens += tokens;
            true
        } else {
            false
        }
    }

    /// Converts up to `tokens` reserved tokens into used tokens.
    pub fn confirm_reservation(&mut self, tokens: usize) {
        let to_confirm = tokens.min(self.reserved_tokens);
        self.reserved_tokens -= to_confirm;
        self.used_tokens += to_confirm;
    }

    /// Releases up to `tokens` reserved tokens back to the available pool.
    pub fn release_reservation(&mut self, tokens: usize) {
        self.reserved_tokens = self.reserved_tokens.saturating_sub(tokens);
    }

    /// Percentage of the total budget that has been used.
    pub fn utilization(&self) -> f64 {
        // Guard against a zero-sized budget, which would otherwise yield NaN.
        if self.total_budget == 0 {
            return 0.0;
        }
        (self.used_tokens as f64 / self.total_budget as f64) * 100.0
    }

    /// Clears all used and reserved tokens.
    pub fn reset(&mut self) {
        self.used_tokens = 0;
        self.reserved_tokens = 0;
    }
}

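/// Helpers for comparing tokenizer accuracy and recommending token budgets.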
pub mod utils {
    use super::*;

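    /// Legacy heuristic: roughly one token per four characters, rounded up.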
    pub fn estimate_tokens_legacy(content: &str) -> usize {
        (content.chars().count() as f64 / 4.0).ceil() as usize
    }

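    /// Compares the tiktoken count against the legacy character-based estimate
    /// for the same content, reporting the ratio and, when tiktoken counts
    /// fewer tokens, the percentage improvement.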
    pub fn compare_tokenization_accuracy(
        content: &str,
        counter: &TokenCounter,
    ) -> Result<TokenizationComparison> {
        let tiktoken_count = counter.count_tokens(content)?;
        let legacy_count = estimate_tokens_legacy(content);

        let accuracy_ratio = if legacy_count > 0 {
            tiktoken_count as f64 / legacy_count as f64
        } else {
            1.0
        };

        Ok(TokenizationComparison {
            tiktoken_count,
            legacy_count,
            accuracy_ratio,
            improvement: if accuracy_ratio < 1.0 {
                Some((1.0 - accuracy_ratio) * 100.0)
            } else {
                None
            },
        })
    }

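    /// Recommends a token budget for a model and content type, scaling a
    /// per-model base budget down for code-heavy or mixed content.
    ///
    /// Illustrative sketch of the resulting values for "gpt-4" (derived from
    /// the multipliers below):
    ///
    /// ```ignore
    /// assert_eq!(recommend_token_budget("gpt-4", ContentType::Code), 102_400);
    /// assert_eq!(recommend_token_budget("gpt-4", ContentType::Documentation), 128_000);
    /// assert_eq!(recommend_token_budget("gpt-4", ContentType::Mixed), 115_200);
    /// ```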
    pub fn recommend_token_budget(model: &str, content_type: ContentType) -> usize {
        let base_budget = match model {
            "gpt-4" | "gpt-4-turbo" => 128000,
            "gpt-4-32k" => 32000,
            "gpt-3.5-turbo" => 16000,
            "gpt-3.5-turbo-16k" => 16000,
            _ => 8000,
        };

        match content_type {
            ContentType::Code => (base_budget as f64 * 0.8) as usize,
            ContentType::Documentation => base_budget,
            ContentType::Mixed => (base_budget as f64 * 0.9) as usize,
        }
    }
}

/// Broad content categories used when recommending token budgets.
#[derive(Debug, Clone, Copy)]
pub enum ContentType {
    Code,
    Documentation,
    Mixed,
}

/// Result of comparing a tiktoken count against the legacy character-based estimate.
#[derive(Debug, Clone)]
pub struct TokenizationComparison {
    pub tiktoken_count: usize,
    pub legacy_count: usize,
    pub accuracy_ratio: f64,
    /// Percentage improvement when the legacy estimate overcounts, if any.
    pub improvement: Option<f64>,
}

impl TokenizationComparison {
    /// Formats the comparison as a short human-readable summary.
    pub fn format(&self) -> String {
        match self.improvement {
            Some(improvement) => format!(
                "Tiktoken: {} tokens, Legacy: {} tokens, {:.1}% more accurate",
                self.tiktoken_count, self.legacy_count, improvement
            ),
            None => format!(
                "Tiktoken: {} tokens, Legacy: {} tokens, {:.2}x ratio",
                self.tiktoken_count, self.legacy_count, self.accuracy_ratio
            ),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    #[test]
    fn test_token_counter_creation() {
        let config = TokenizerConfig::default();
        let counter = TokenCounter::new(config);
        assert!(counter.is_ok());
    }

    #[test]
    fn test_basic_token_counting() {
        let counter = TokenCounter::default().unwrap();

        let simple_text = "Hello, world!";
        let count = counter.count_tokens(simple_text).unwrap();
        assert!(count > 0);
        assert!(count < 10);
    }

    #[test]
    fn test_code_token_counting() {
        let counter = TokenCounter::default().unwrap();

        let rust_code = r#"
fn main() {
    println!("Hello, world!");
    let x = 42;
    if x > 0 {
        println!("Positive number: {}", x);
    }
}
"#;

        let count = counter.count_tokens(rust_code).unwrap();
        assert!(count > 20);
        assert!(count < 100);
    }

    #[test]
    fn test_language_multipliers() {
        let counter = TokenCounter::default().unwrap();

        let content = "function test() { return 42; }";

        let js_tokens = counter
            .estimate_file_tokens(content, Path::new("test.js"))
            .unwrap();
        let java_tokens = counter
            .estimate_file_tokens(content, Path::new("test.java"))
            .unwrap();
        let py_tokens = counter
            .estimate_file_tokens(content, Path::new("test.py"))
            .unwrap();

        // Java (1.2) uses a higher multiplier than JavaScript (0.95), and Python (0.9) a lower one.
        assert!(java_tokens >= js_tokens);
        assert!(py_tokens <= js_tokens);
    }

    #[test]
    fn test_token_budget() {
        let mut budget = TokenBudget::new(1000);

        assert_eq!(budget.total(), 1000);
        assert_eq!(budget.used(), 0);
        assert_eq!(budget.available(), 1000);

        assert!(budget.allocate(300));
        assert_eq!(budget.used(), 300);
        assert_eq!(budget.available(), 700);

        assert!(budget.reserve(200));
        assert_eq!(budget.reserved(), 200);
        assert_eq!(budget.available(), 500);

        budget.confirm_reservation(150);
        assert_eq!(budget.used(), 450);
        assert_eq!(budget.reserved(), 50);
        assert_eq!(budget.available(), 500);
    }

    #[test]
    fn test_content_chunking() {
        let counter = TokenCounter::default().unwrap();

        let long_content = "word ".repeat(1000);
        let chunks = counter.chunk_content(&long_content, 100).unwrap();

        assert!(chunks.len() > 1);
        for chunk in &chunks {
            let chunk_tokens = counter.count_tokens(chunk).unwrap();
            // Allow slack: re-encoding a decoded chunk can shift token boundaries.
            assert!(chunk_tokens <= 120);
        }
    }

    #[test]
    fn test_tokenization_comparison() {
        let counter = TokenCounter::default().unwrap();

        let code_content = r#"
use std::collections::HashMap;

fn process_data(input: &str) -> Result<HashMap<String, i32>, Box<dyn std::error::Error>> {
    let mut result = HashMap::new();
    for line in input.lines() {
        let parts: Vec<&str> = line.split(':').collect();
        if parts.len() == 2 {
            result.insert(parts[0].to_string(), parts[1].parse()?);
        }
    }
    Ok(result)
}
"#;

        let comparison = utils::compare_tokenization_accuracy(code_content, &counter).unwrap();

        assert!(comparison.tiktoken_count > 0);
        assert!(comparison.legacy_count > 0);
        assert!(comparison.accuracy_ratio > 0.0);

        let formatted = comparison.format();
        assert!(formatted.contains("Tiktoken"));
        assert!(formatted.contains("Legacy"));
    }

    #[test]
    fn test_budget_recommendations() {
        let code_budget = utils::recommend_token_budget("gpt-4", ContentType::Code);
        let doc_budget = utils::recommend_token_budget("gpt-4", ContentType::Documentation);
        let mixed_budget = utils::recommend_token_budget("gpt-4", ContentType::Mixed);

        assert!(code_budget < doc_budget);
        assert!(mixed_budget > code_budget);
        assert!(mixed_budget < doc_budget);
    }
}