// token_count/tokenizers/openai.rs
1//! OpenAI tokenization using tiktoken-rs
2
3use crate::tokenizers::{ModelInfo, Tokenizer};
4use anyhow::{Context, Result};
5use tiktoken_rs::CoreBPE;
6
/// OpenAI tokenizer using tiktoken-rs
pub struct OpenAITokenizer {
    // Loaded byte-pair-encoding table used to encode text into token ids.
    bpe: CoreBPE,
    // Metadata for the model this tokenizer serves; returned (cloned) by
    // `get_model_info`.
    model_info: ModelInfo,
}
12
13impl OpenAITokenizer {
14    /// Create a new OpenAI tokenizer for the given encoding
15    pub fn new(encoding_name: &str, model_info: ModelInfo) -> Result<Self> {
16        let tokenizer_enum = match encoding_name {
17            "o200k_base" => tiktoken_rs::tokenizer::Tokenizer::O200kBase,
18            "cl100k_base" => tiktoken_rs::tokenizer::Tokenizer::Cl100kBase,
19            "p50k_base" => tiktoken_rs::tokenizer::Tokenizer::P50kBase,
20            "r50k_base" => tiktoken_rs::tokenizer::Tokenizer::R50kBase,
21            "gpt2" => tiktoken_rs::tokenizer::Tokenizer::Gpt2,
22            _ => anyhow::bail!("Unsupported encoding: {}", encoding_name),
23        };
24
25        let bpe = tiktoken_rs::get_bpe_from_tokenizer(tokenizer_enum)
26            .with_context(|| format!("Failed to load encoding: {}", encoding_name))?;
27
28        Ok(Self { bpe, model_info })
29    }
30}
31
32impl Tokenizer for OpenAITokenizer {
33    fn count_tokens(&self, text: &str) -> Result<usize> {
34        let tokens = self.bpe.encode_with_special_tokens(text);
35        Ok(tokens.len())
36    }
37
38    fn get_model_info(&self) -> ModelInfo {
39        self.model_info.clone()
40    }
41}
42
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: model metadata for a cl100k_base-backed GPT-4
    /// tokenizer, previously copy-pasted into every test.
    fn gpt4_model_info() -> ModelInfo {
        ModelInfo {
            name: "gpt-4".to_string(),
            encoding: "cl100k_base".to_string(),
            context_window: 128000,
            description: "GPT-4 model".to_string(),
        }
    }

    #[test]
    fn test_basic_tokenization() {
        let tokenizer = OpenAITokenizer::new("cl100k_base", gpt4_model_info()).unwrap();
        // "Hello world" encodes to exactly two tokens under cl100k_base.
        assert_eq!(tokenizer.count_tokens("Hello world").unwrap(), 2);
    }

    #[test]
    fn test_empty_string() {
        let tokenizer = OpenAITokenizer::new("cl100k_base", gpt4_model_info()).unwrap();
        assert_eq!(tokenizer.count_tokens("").unwrap(), 0);
    }

    #[test]
    fn test_unsupported_encoding_is_rejected() {
        // Unknown encoding names must surface as an Err, not a panic.
        assert!(OpenAITokenizer::new("not_a_real_encoding", gpt4_model_info()).is_err());
    }
}
74}