// swiftide_integrations/tiktoken/mod.rs

//! Use tiktoken-rs to estimate token counts for various common Swiftide types.
//!
//! Intended for use with OpenAI models.
//!
//! Note that the underlying library is heavy on the unwraps.
6
7use std::sync::Arc;
8
9use anyhow::Result;
10use async_trait::async_trait;
11use swiftide_core::tokenizer::{Estimatable, EstimateTokens};
12use tiktoken_rs::{CoreBPE, get_bpe_from_model, get_bpe_from_tokenizer, tokenizer::Tokenizer};
13
/// A tiktoken based tokenizer for openai models. Can also be used for other models.
///
/// Implements `EstimateTokens` for various swiftide types (prompts, chat messages, lists of chat
/// messages) and regular strings.
///
/// Estimates are estimates; not exact counts.
///
/// # Example
///
/// ```no_run
/// # use swiftide_core::tokenizer::EstimateTokens;
/// # use swiftide_integrations::tiktoken::TikToken;
///
/// # async fn test() {
/// let tokenizer = TikToken::try_from_model("gpt-4-0314").unwrap();
/// let estimate = tokenizer.estimate("hello {{world}}").await.unwrap();
///
/// assert_eq!(estimate, 4);
/// # }
/// ```
#[derive(Clone)]
pub struct TikToken {
    /// The BPE encoder backing this tokenizer. Wrapped in an `Arc` so that
    /// cloning a `TikToken` is cheap and all clones share one encoder.
    bpe: Arc<CoreBPE>,
}
39
40impl std::fmt::Debug for TikToken {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        f.debug_struct("TikToken").finish()
43    }
44}
45
46impl Default for TikToken {
47    fn default() -> Self {
48        Self::try_from_model("gpt-4o")
49            .expect("infallible; gpt-4o should be valid model for tiktoken")
50    }
51}
52
53impl TikToken {
54    /// Build a `TikToken` from an openai model name
55    ///
56    /// # Errors
57    ///
58    /// Errors if the tokenizer cannot be found from the model or it cannot be build
59    pub fn try_from_model(model: impl AsRef<str>) -> Result<Self> {
60        let bpe = get_bpe_from_model(model.as_ref())?;
61        Ok(Self { bpe: Arc::new(bpe) })
62    }
63
64    /// Build a `TikToken` from a `tiktoken_rs::tiktoken::Tokenizer`
65    ///
66    /// # Errors
67    ///
68    /// Errors if the tokenizer cannot be build
69    pub fn try_from_tokenizer(tokenizer: Tokenizer) -> Result<Self> {
70        let bpe = get_bpe_from_tokenizer(tokenizer)?;
71        Ok(Self { bpe: Arc::new(bpe) })
72    }
73}
74
75#[async_trait]
76impl EstimateTokens for TikToken {
77    async fn estimate(&self, value: impl Estimatable) -> Result<usize> {
78        Ok(self
79            .bpe
80            .encode_with_special_tokens(value.for_estimate().await?.as_ref())
81            .len()
82            + value.additional_tokens())
83    }
84}
85
#[cfg(test)]
mod tests {
    use swiftide_core::{chat_completion::ChatMessage, prompt::Prompt};

    use super::*;

    /// A prompt estimates to the expected token count for a gpt-4 era model.
    #[tokio::test]
    async fn test_estimate_tokens() {
        let tokenizer = TikToken::try_from_model("gpt-4-0314").unwrap();
        let prompt = Prompt::from("hello {{world}}");
        let tokens = tokenizer.estimate(&prompt).await.unwrap();
        assert_eq!(tokens, 4);
    }

    /// Building from an explicit tokenizer works the same as from a model name.
    #[tokio::test]
    async fn test_estimate_tokens_from_tokenizer() {
        let tokenizer = TikToken::try_from_tokenizer(Tokenizer::O200kBase).unwrap();
        let prompt = "hello {{world}}";
        let tokens = tokenizer.estimate(prompt).await.unwrap();
        assert_eq!(tokens, 4);
    }

    /// Chat message lists include per-message and per-conversation overhead:
    /// 11x "hello" + 1x "world" + 2 messages x 4 + 3 for the conversation
    /// + 2 remaining overhead = 23.
    #[tokio::test]
    async fn test_estimate_chat_messages() {
        let messages = vec![
            ChatMessage::new_user("hello ".repeat(10)),
            ChatMessage::new_system("world"),
        ];

        let tokenizer = TikToken::try_from_model("gpt-4-0314").unwrap();

        // Removed a leftover `dbg!` debugging call here; it printed the
        // estimation input on every test run.
        assert_eq!(tokenizer.estimate(messages.as_slice()).await.unwrap(), 23);
    }
}
122}