toonify_core/
tokens.rs

1use once_cell::sync::OnceCell;
2use tiktoken_rs::{CoreBPE, cl100k_base, o200k_base};
3
4use crate::error::ToonifyError;
5
/// Identifies which tokenizer encoding should be used when counting tokens.
///
/// The `Display` implementation renders each variant as its encoding name
/// (`cl100k_base` or `o200k_base`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenModel {
    /// The `cl100k_base` encoding.
    Cl100k,
    /// The `o200k_base` encoding.
    O200k,
}
11
12impl std::fmt::Display for TokenModel {
13    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14        match self {
15            TokenModel::Cl100k => write!(f, "cl100k_base"),
16            TokenModel::O200k => write!(f, "o200k_base"),
17        }
18    }
19}
20
21static CL100K: OnceCell<CoreBPE> = OnceCell::new();
22static O200K: OnceCell<CoreBPE> = OnceCell::new();
23
24pub fn count_tokens(text: &str, model: TokenModel) -> Result<usize, ToonifyError> {
25    let tokenizer = get_tokenizer(model)?;
26    Ok(tokenizer.encode_ordinary(text).len())
27}
28
29fn get_tokenizer(model: TokenModel) -> Result<&'static CoreBPE, ToonifyError> {
30    match model {
31        TokenModel::Cl100k => CL100K.get_or_try_init(|| {
32            cl100k_base().map_err(|err| ToonifyError::tokenizer(err.to_string()))
33        }),
34        TokenModel::O200k => O200K.get_or_try_init(|| {
35            o200k_base().map_err(|err| ToonifyError::tokenizer(err.to_string()))
36        }),
37    }
38}
39
#[cfg(test)]
mod tests {
    use super::*;

    /// Both supported encodings should yield at least one token for a short,
    /// non-empty sentence.
    #[test]
    fn counts_tokens_for_simple_text() {
        let sample = "Hello world!";
        // Count under each encoding, in the same order as the assertions below.
        let [cl100k_count, o200k_count] = [TokenModel::Cl100k, TokenModel::O200k]
            .map(|model| count_tokens(sample, model).unwrap());
        assert!(cl100k_count > 0);
        assert!(o200k_count > 0);
    }
}