swiftide_core/tokenizer.rs

use std::borrow::Cow;

use anyhow::Result;
use async_trait::async_trait;

use crate::{chat_completion::ChatMessage, prompt::Prompt};

/// Estimate the number of tokens in a given value.
#[async_trait]
pub trait EstimateTokens {
    async fn estimate(&self, value: impl Estimatable) -> Result<usize>;
}

/// A value whose token count can be estimated.
#[async_trait]
pub trait Estimatable: Send + Sync {
    async fn for_estimate(&self) -> Result<Cow<'_, str>>;

    /// Optionally return extra tokens that should be added to the estimate.
    fn additional_tokens(&self) -> usize {
        0
    }
}
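
// Illustrative only, and not part of this crate: a minimal sketch of an
// `EstimateTokens` implementor showing how the two traits fit together. It
// approximates tokens as characters divided by four (a rough rule of thumb,
// not a real tokenizer) and adds whatever overhead the value reports via
// `Estimatable::additional_tokens`. The `CharCountEstimator` name is
// hypothetical; real implementors would delegate to an actual tokenizer such
// as tiktoken.
#[cfg(test)]
struct CharCountEstimator;

#[cfg(test)]
#[async_trait]
impl EstimateTokens for CharCountEstimator {
    async fn estimate(&self, value: impl Estimatable) -> Result<usize> {
        // Render the value to text, apply the chars/4 heuristic, then add the
        // extra tokens the value says it carries (e.g. per-message overhead).
        let text = value.for_estimate().await?;
        Ok(text.chars().count().div_ceil(4) + value.additional_tokens())
    }
}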

#[async_trait]
impl Estimatable for &str {
    async fn for_estimate(&self) -> Result<Cow<'_, str>> {
        Ok(Cow::Borrowed(self))
    }
}

#[async_trait]
impl Estimatable for String {
    async fn for_estimate(&self) -> Result<Cow<'_, str>> {
        Ok(Cow::Borrowed(self.as_str()))
    }
}

#[async_trait]
impl Estimatable for &Prompt {
    async fn for_estimate(&self) -> Result<Cow<'_, str>> {
        let rendered = self.render()?;
        Ok(Cow::Owned(rendered))
    }
}

#[async_trait]
impl Estimatable for &ChatMessage {
    async fn for_estimate(&self) -> Result<Cow<'_, str>> {
        Ok(match self {
            ChatMessage::User(msg) | ChatMessage::Summary(msg) | ChatMessage::System(msg) => {
                Cow::Borrowed(msg)
            }
            ChatMessage::Assistant(msg, vec) => {
                // Note that this is not super accurate.
                //
                // It's a bit verbose to avoid unnecessary allocations. Is what it is.
                let tool_calls = vec.as_ref().map(|vec| {
                    vec.iter()
                        .map(std::string::ToString::to_string)
                        .collect::<Vec<String>>()
                        .join(" ")
                });

                if let Some(msg) = msg {
                    if let Some(tool_calls) = tool_calls {
                        format!("{msg} {tool_calls}").into()
                    } else {
                        msg.into()
                    }
                } else if let Some(tool_calls) = tool_calls {
                    tool_calls.into()
                } else {
                    "None".into()
                }
            }
            ChatMessage::ToolOutput(tool_call, tool_output) => {
                let tool_call_id = tool_call.id();
                let tool_output_content = tool_output.content().unwrap_or_default();

                format!("{tool_call_id} {tool_output_content}").into()
            }
        })
    }

    // Roughly 4 extra tokens per message to account for the role and message
    // framing.
    //
    // See https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
    fn additional_tokens(&self) -> usize {
        4
    }
}

#[async_trait]
impl Estimatable for &[ChatMessage] {
    async fn for_estimate(&self) -> Result<Cow<'_, str>> {
        // Render every message and concatenate, so the estimator sees the
        // full conversation text.
        let mut rendered = String::new();
        for msg in *self {
            rendered.push_str(&msg.for_estimate().await?);
            rendered.push(' ');
        }

        Ok(rendered.into())
    }

    // Apparently every reply is primed with a <|start|>assistant<|message|>,
    // which adds another 3 tokens on top of the per-message overhead.
    fn additional_tokens(&self) -> usize {
        self.iter().map(|m| m.additional_tokens()).sum::<usize>() + 3
    }
}
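
#[cfg(test)]
mod tests {
    use super::*;

    // A quick sanity check for the hypothetical `CharCountEstimator` sketched
    // above. Assumes a `tokio` dev-dependency for the async test macro; any
    // async runtime would do.
    #[tokio::test]
    async fn char_count_estimator_estimates_plain_text() {
        let estimator = CharCountEstimator;

        // 12 characters / 4 => 3 "tokens" under the naive heuristic, and plain
        // strings report no additional overhead.
        assert_eq!(estimator.estimate("hello world!").await.unwrap(), 3);
        assert_eq!(
            estimator
                .estimate("hello world!".to_string())
                .await
                .unwrap(),
            3
        );
    }
}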