swiftide_integrations/tiktoken/mod.rs

use std::sync::Arc;

use anyhow::Result;
use async_trait::async_trait;
use swiftide_core::tokenizer::{Estimatable, EstimateTokens};
use tiktoken_rs::{CoreBPE, get_bpe_from_model, get_bpe_from_tokenizer, tokenizer::Tokenizer};

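/// Estimates token counts with the BPE tokenizers from `tiktoken-rs`.
///
/// Works with anything that implements `Estimatable`, such as prompts,
/// plain strings, and slices of chat messages.
///
/// # Example
///
/// A minimal usage sketch, mirroring the tests below; the crate path is
/// assumed and may differ:
///
/// ```ignore
/// // NOTE: import path assumed; adjust to wherever `TikToken` is exposed.
/// use swiftide_core::tokenizer::EstimateTokens;
/// use swiftide_integrations::tiktoken::TikToken;
///
/// # async fn example() -> anyhow::Result<()> {
/// let tokenizer = TikToken::try_from_model("gpt-4o")?;
/// let tokens = tokenizer.estimate("hello world").await?;
/// println!("{tokens} tokens");
/// # Ok(())
/// # }
/// ```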
#[derive(Clone)]
pub struct TikToken {
    bpe: Arc<CoreBPE>,
}

impl std::fmt::Debug for TikToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TikToken").finish()
    }
}

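/// Defaults to the tokenizer for the `gpt-4o` model.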
impl Default for TikToken {
    fn default() -> Self {
        Self::try_from_model("gpt-4o")
            .expect("infallible; gpt-4o should be a valid model for tiktoken")
    }
}

impl TikToken {
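    /// Creates a `TikToken` estimator from a model name known to
    /// `tiktoken-rs`, for example `"gpt-4o"`.
    ///
    /// # Errors
    ///
    /// Errors if `tiktoken-rs` has no tokenizer registered for the model.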
    pub fn try_from_model(model: impl AsRef<str>) -> Result<Self> {
        let bpe = get_bpe_from_model(model.as_ref())?;
        Ok(Self { bpe: Arc::new(bpe) })
    }

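    /// Creates a `TikToken` estimator directly from a `tiktoken_rs`
    /// `Tokenizer` variant, e.g. `Tokenizer::O200kBase`.
    ///
    /// # Errors
    ///
    /// Errors if the tokenizer's encoding cannot be loaded.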
    pub fn try_from_tokenizer(tokenizer: Tokenizer) -> Result<Self> {
        let bpe = get_bpe_from_tokenizer(tokenizer)?;
        Ok(Self { bpe: Arc::new(bpe) })
    }
}

#[async_trait]
impl EstimateTokens for TikToken {
    async fn estimate(&self, value: impl Estimatable) -> Result<usize> {
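        // Render the value, count its BPE tokens (special tokens included),
        // and add any extra tokens the value reports on top of its text.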
        Ok(self
            .bpe
            .encode_with_special_tokens(value.for_estimate().await?.as_ref())
            .len()
            + value.additional_tokens())
    }
}

#[cfg(test)]
mod tests {
    use swiftide_core::{chat_completion::ChatMessage, prompt::Prompt};

    use super::*;

    #[tokio::test]
    async fn test_estimate_tokens() {
        let tokenizer = TikToken::try_from_model("gpt-4-0314").unwrap();
        let prompt = Prompt::from("hello {{world}}");
        let tokens = tokenizer.estimate(&prompt).await.unwrap();
        assert_eq!(tokens, 4);
    }

    #[tokio::test]
    async fn test_estimate_tokens_from_tokenizer() {
        let tokenizer = TikToken::try_from_tokenizer(Tokenizer::O200kBase).unwrap();
        let prompt = "hello {{world}}";
        let tokens = tokenizer.estimate(prompt).await.unwrap();
        assert_eq!(tokens, 4);
    }

    #[tokio::test]
    async fn test_estimate_chat_messages() {
        let messages = vec![
            ChatMessage::new_user("hello ".repeat(10)),
            ChatMessage::new_system("world"),
        ];

        let tokenizer = TikToken::try_from_model("gpt-4-0314").unwrap();
        dbg!(messages.as_slice().for_estimate().await.unwrap());

        assert_eq!(tokenizer.estimate(messages.as_slice()).await.unwrap(), 23);
    }
}