1use once_cell::sync::OnceCell;
2use tiktoken_rs::{CoreBPE, cl100k_base, o200k_base};
3
4use crate::error::ToonifyError;
5
/// Tokenizer encodings that [`count_tokens`] can count with.
///
/// The variants are plain tags: the type is `Copy` and comparable for equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenModel {
    /// The `cl100k_base` encoding.
    Cl100k,
    /// The `o200k_base` encoding.
    O200k,
}
11
12impl std::fmt::Display for TokenModel {
13 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
14 match self {
15 TokenModel::Cl100k => write!(f, "cl100k_base"),
16 TokenModel::O200k => write!(f, "o200k_base"),
17 }
18 }
19}
20
// One cache cell per encoding. Each cell is created empty here and is filled by
// `get_tokenizer` via `get_or_try_init`, which also hands out the `&'static CoreBPE`
// references used for counting.
21static CL100K: OnceCell<CoreBPE> = OnceCell::new();
22static O200K: OnceCell<CoreBPE> = OnceCell::new();
23
24pub fn count_tokens(text: &str, model: TokenModel) -> Result<usize, ToonifyError> {
25 let tokenizer = get_tokenizer(model)?;
26 Ok(tokenizer.encode_ordinary(text).len())
27}
28
29fn get_tokenizer(model: TokenModel) -> Result<&'static CoreBPE, ToonifyError> {
30 match model {
31 TokenModel::Cl100k => CL100K.get_or_try_init(|| {
32 cl100k_base().map_err(|err| ToonifyError::tokenizer(err.to_string()))
33 }),
34 TokenModel::O200k => O200K.get_or_try_init(|| {
35 o200k_base().map_err(|err| ToonifyError::tokenizer(err.to_string()))
36 }),
37 }
38}
39
#[cfg(test)]
mod tests {
    use super::*;

    /// Both supported encodings must report at least one token for a short phrase.
    #[test]
    fn counts_tokens_for_simple_text() {
        let sample = "Hello world!";
        for model in [TokenModel::Cl100k, TokenModel::O200k] {
            let count = count_tokens(sample, model).unwrap();
            assert!(count > 0);
        }
    }
}