Skip to main content

tiktoken_rs/
api.rs

1use anyhow::{anyhow, Result};
2
3use crate::{
4    cl100k_base_singleton,
5    model::get_context_size,
6    o200k_base_singleton, o200k_harmony_singleton, p50k_base_singleton, p50k_edit_singleton,
7    r50k_base_singleton,
8    tokenizer::{get_tokenizer, Tokenizer},
9    CoreBPE,
10};
11
12/// Returns the maximum number of tokens available for a text completion, given a model and prompt.
13///
14/// This is for legacy text/prompt completions (single string input). For chat completions,
15/// use [`get_chat_completion_max_tokens`] instead.
16///
17/// Calculates `context_size - prompt_tokens` for the given model.
18///
19/// # Arguments
20///
21/// * `model` - A string slice representing the model name, e.g., `"gpt-4o"`.
22/// * `prompt` - A string slice containing the prompt text.
23///
24/// # Errors
25///
26/// Returns an error if no tokenizer is found for the given model.
27///
28/// # Examples
29///
30/// ```
31/// use tiktoken_rs::get_text_completion_max_tokens;
32///
33/// let max_tokens = get_text_completion_max_tokens("gpt-4o", "Translate to French: '").unwrap();
34/// ```
35pub fn get_text_completion_max_tokens(model: &str, prompt: &str) -> Result<usize> {
36    let context_size = get_context_size(model)
37        .ok_or_else(|| anyhow!("Unknown context size for model {}", model))?;
38    let tokenizer =
39        get_tokenizer(model).ok_or_else(|| anyhow!("No tokenizer found for model {}", model))?;
40    let bpe = bpe_singleton(tokenizer);
41    let prompt_tokens = bpe.count_with_special_tokens(prompt);
42    Ok(context_size.saturating_sub(prompt_tokens))
43}
44
45/// Use [`get_text_completion_max_tokens`] instead.
46#[deprecated(since = "0.10.0", note = "renamed to `get_text_completion_max_tokens`")]
47pub fn get_completion_max_tokens(model: &str, prompt: &str) -> Result<usize> {
48    get_text_completion_max_tokens(model, prompt)
49}
50
/// A function invocation produced by the model: the function's name plus its raw arguments.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct FunctionCall {
    /// Name of the function the model wants to invoke.
    pub name: String,
    /// Arguments for the call, as the JSON text generated by the model.
    ///
    /// The model does not always generate valid JSON, and may hallucinate parameters that are
    /// not defined by your function schema. Validate the arguments in your code before calling
    /// your function.
    pub arguments: String,
}

/// A single chat message in the simplified form consumed by [`num_tokens_from_messages`].
///
/// Build values with struct-update syntax, e.g.
/// `ChatCompletionRequestMessage { role: "user".to_string(), ..Default::default() }`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ChatCompletionRequestMessage {
    /// Author of the message: one of `system`, `developer`, `user`, `assistant`, `tool`,
    /// or `function`.
    pub role: String,
    /// Text of the message.
    ///
    /// Required for every message except assistant messages that carry a function call.
    pub content: Option<String>,
    /// Name of the message author.
    ///
    /// Required when `role` is `function`, in which case it is the name of the function whose
    /// response is in `content`. May contain a-z, A-Z, 0-9 and underscores, with a maximum
    /// length of 64 characters.
    pub name: Option<String>,
    /// Function call (name and arguments) generated by the model.
    pub function_call: Option<FunctionCall>,
    /// Tool calls generated by the model, flattened into [`FunctionCall`] values.
    ///
    /// Tool call IDs and type discriminators are not preserved.
    pub tool_calls: Vec<FunctionCall>,
    /// Refusal message generated by the model, if any.
    pub refusal: Option<String>,
}
79
80/// Based on <https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb>
81///
82/// num_tokens_from_messages returns the number of tokens required to encode the given messages into
83/// the given model. This is used to estimate the number of tokens that will be used for chat
84/// completion.
85///
86/// # Arguments
87///
88/// * model: A string slice containing the model name (e.g. "gpt-3.5").
89/// * messages: A slice of ChatCompletionRequestMessage structs representing chat messages.
90///
91/// # Returns
92///
93/// * `Result<usize>`: A Result containing the total number of tokens needed to encode the messages
94///   for the specified model, or an error if the tokenizer for the model is not found or not supported.
95///
96/// # Errors
97///
98/// This function will return an error if:
99///
100/// * The tokenizer for the specified model is not found.
101/// * The tokenizer is not a supported chat model (i.e., not one of Cl100kBase, O200kBase, or O200kHarmony).
102///
103pub fn num_tokens_from_messages(
104    model: &str,
105    messages: &[ChatCompletionRequestMessage],
106) -> Result<usize> {
107    let tokenizer =
108        get_tokenizer(model).ok_or_else(|| anyhow!("No tokenizer found for model {}", model))?;
109    if tokenizer != Tokenizer::Cl100kBase
110        && tokenizer != Tokenizer::O200kBase
111        && tokenizer != Tokenizer::O200kHarmony
112    {
113        anyhow::bail!(
114            "Chat token counting is not supported for model {:?} (tokenizer {:?}). \
115             Supported tokenizers: Cl100kBase, O200kBase, O200kHarmony.",
116            model,
117            tokenizer
118        )
119    }
120    let bpe = bpe_singleton(tokenizer);
121
122    // Token overhead constants adapted from the OpenAI cookbook:
123    // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
124    //
125    // tokens_per_message: overhead tokens per message for framing (3 for current models)
126    // tokens_per_name: extra tokens when a `name` field is present (1 for current models)
127    //
128    // The gpt-3.5-turbo-0301 branch (4, -1) was removed from the cookbook in later revisions;
129    // we retain it for backward compatibility with that specific snapshot.
130    //
131    // FUNCTION_CALL_OVERHEAD: 1 extra token per function/tool call (heuristic)
132    // REPLY_PRIMING: 3 tokens added once at the end (per cookbook: <|start|>assistant<|message|>)
133    const FUNCTION_CALL_OVERHEAD: i32 = 1;
134    const REPLY_PRIMING: i32 = 3;
135
136    let (tokens_per_message, tokens_per_name) = if model == "gpt-3.5-turbo-0301" {
137        (4, -1)
138    } else {
139        (3, 1)
140    };
141
142    let mut num_tokens: i32 = 0;
143    for message in messages {
144        num_tokens += tokens_per_message;
145        num_tokens += bpe.count_with_special_tokens(&message.role) as i32;
146        if let Some(content) = &message.content {
147            num_tokens += bpe.count_with_special_tokens(content) as i32;
148        }
149        if let Some(name) = &message.name {
150            num_tokens += bpe.count_with_special_tokens(name) as i32;
151            num_tokens += tokens_per_name;
152        }
153        if let Some(function_call) = &message.function_call {
154            num_tokens += bpe.count_with_special_tokens(&function_call.name) as i32;
155            num_tokens += bpe.count_with_special_tokens(&function_call.arguments) as i32;
156            num_tokens += FUNCTION_CALL_OVERHEAD;
157        }
158        for tool_call in &message.tool_calls {
159            num_tokens += bpe.count_with_special_tokens(&tool_call.name) as i32;
160            num_tokens += bpe.count_with_special_tokens(&tool_call.arguments) as i32;
161            num_tokens += FUNCTION_CALL_OVERHEAD;
162        }
163        if let Some(refusal) = &message.refusal {
164            num_tokens += bpe.count_with_special_tokens(refusal) as i32;
165        }
166    }
167    num_tokens += REPLY_PRIMING;
168    Ok(num_tokens as usize)
169}
170
171/// Calculates the maximum number of tokens available for chat completion based on the model and messages provided.
172///
173/// This function determines the number of tokens left for a chat completion task, given the model and a slice of
174/// chat completion request messages. It first retrieves the tokenizer for the given model and checks if chat completion
175/// is supported. Then, it calculates the number of tokens in the existing messages using the appropriate tokenizer.
176///
177/// # Arguments
178///
179/// * `model` - A string slice representing the model name, e.g., "gpt-3.5-turbo".
180/// * `messages` - A slice of `ChatCompletionRequestMessage` instances containing the chat context.
181///
182/// # Errors
183///
184/// This function returns an error in the following cases:
185///
186/// * If there is no tokenizer found for the specified model.
187/// * If chat completion is not supported for the specified model.
188/// * If there is a failure in creating a `CoreBPE` instance for the specified tokenizer.
189///
190/// # Example
191///
192/// ```
193/// use tiktoken_rs::{get_chat_completion_max_tokens, ChatCompletionRequestMessage};
194///
195/// let model = "gpt-3.5-turbo";
196/// let messages = vec![
197///     ChatCompletionRequestMessage {
198///         content: Some("You are a helpful assistant that only speaks French.".to_string()),
199///         role: "system".to_string(),
200///         ..Default::default()
201///     },
202///     ChatCompletionRequestMessage {
203///         content: Some("Hello, how are you?".to_string()),
204///         role: "user".to_string(),
205///         ..Default::default()
206///     },
207///     ChatCompletionRequestMessage {
208///         content: Some("Parlez-vous francais?".to_string()),
209///         role: "system".to_string(),
210///         ..Default::default()
211///     },
212/// ];
213/// let max_tokens = get_chat_completion_max_tokens(model, &messages).unwrap();
214/// ```
215///
216/// # Returns
217///
218/// If successful, the function returns a `Result` containing the maximum number of tokens available for chat completion,
219/// based on the given model and messages.
220pub fn get_chat_completion_max_tokens(
221    model: &str,
222    messages: &[ChatCompletionRequestMessage],
223) -> Result<usize> {
224    let context_size = get_context_size(model)
225        .ok_or_else(|| anyhow!("Unknown context size for model {}", model))?;
226    let prompt_tokens = num_tokens_from_messages(model, messages)?;
227    Ok(context_size.saturating_sub(prompt_tokens))
228}
229
230fn bpe_singleton(tokenizer: Tokenizer) -> &'static CoreBPE {
231    match tokenizer {
232        Tokenizer::O200kHarmony => o200k_harmony_singleton(),
233        Tokenizer::O200kBase => o200k_base_singleton(),
234        Tokenizer::Cl100kBase => cl100k_base_singleton(),
235        Tokenizer::R50kBase => r50k_base_singleton(),
236        Tokenizer::P50kBase => p50k_base_singleton(),
237        Tokenizer::P50kEdit => p50k_edit_singleton(),
238        Tokenizer::Gpt2 => r50k_base_singleton(),
239    }
240}
241
242/// Returns a cached reference to the BPE tokenizer for the given model name.
243///
244/// Looks up which tokenizer the model uses, then returns a `&'static CoreBPE` singleton.
245/// The singleton is initialized once and reused for all subsequent calls.
246///
247/// # Arguments
248///
249/// * `model` - A model name, e.g., `"gpt-4o"`, `"gpt-3.5-turbo"`, `"o3-mini"`.
250///
251/// # Errors
252///
253/// Returns an error if no tokenizer is found for the given model name.
254///
255/// # Examples
256///
257/// ```
258/// use tiktoken_rs::bpe_for_model;
259///
260/// let bpe = bpe_for_model("gpt-4o").unwrap();
261/// let tokens = bpe.encode_with_special_tokens("hello world");
262/// ```
263pub fn bpe_for_model(model: &str) -> Result<&'static CoreBPE> {
264    let tokenizer =
265        get_tokenizer(model).ok_or_else(|| anyhow!("No tokenizer found for model {}", model))?;
266    bpe_for_tokenizer(tokenizer)
267}
268
269/// Use [`bpe_for_model`] instead.
270#[deprecated(since = "0.10.0", note = "renamed to `bpe_for_model`")]
271pub fn get_bpe_from_model(model: &str) -> Result<&'static CoreBPE> {
272    bpe_for_model(model)
273}
274
275/// Returns a cached reference to the BPE tokenizer for the given tokenizer type.
276///
277/// Returns a `&'static CoreBPE` singleton. The singleton is initialized once and reused
278/// for all subsequent calls.
279///
280/// # Arguments
281///
282/// * `tokenizer` - A [`Tokenizer`] enum variant.
283///
284/// # Examples
285///
286/// ```
287/// use tiktoken_rs::bpe_for_tokenizer;
288/// use tiktoken_rs::tokenizer::Tokenizer;
289///
290/// let bpe = bpe_for_tokenizer(Tokenizer::O200kBase).unwrap();
291/// let tokens = bpe.encode_with_special_tokens("hello world");
292/// ```
293pub fn bpe_for_tokenizer(tokenizer: Tokenizer) -> Result<&'static CoreBPE> {
294    Ok(bpe_singleton(tokenizer))
295}
296
297/// Use [`bpe_for_tokenizer`] instead.
298#[deprecated(since = "0.10.0", note = "renamed to `bpe_for_tokenizer`")]
299pub fn get_bpe_from_tokenizer(tokenizer: Tokenizer) -> Result<&'static CoreBPE> {
300    bpe_for_tokenizer(tokenizer)
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bpe_for_tokenizer() {
        let bpe = bpe_for_tokenizer(Tokenizer::Cl100kBase).unwrap();
        assert_eq!(bpe.decode(&[15339]).unwrap(), "hello");
    }

    #[test]
    fn test_bpe_for_tokenizer_returns_singleton() {
        // Repeated lookups must hand back the same cached instance, not a new encoder.
        let first = bpe_for_tokenizer(Tokenizer::Cl100kBase).unwrap();
        let second = bpe_for_tokenizer(Tokenizer::Cl100kBase).unwrap();
        assert!(std::ptr::eq(first, second));
        assert!(std::ptr::eq(first, bpe_singleton(Tokenizer::Cl100kBase)));
    }

    #[test]
    fn test_num_tokens_from_messages() {
        let messages = vec![
            ChatCompletionRequestMessage {
                role: "system".to_string(),
                name: None,
                content: Some("You are a helpful, pattern-following assistant that translates corporate jargon into plain English.".to_string()),
                ..Default::default()
            },
            ChatCompletionRequestMessage {
                role: "system".to_string(),
                name: Some("example_user".to_string()),
                content: Some("New synergies will help drive top-line growth.".to_string()),
                ..Default::default()
            },
            ChatCompletionRequestMessage {
                role: "system".to_string(),
                name: Some("example_assistant".to_string()),
                content: Some("Things working well together will increase revenue.".to_string()),
                ..Default::default()
            },
            ChatCompletionRequestMessage {
                role: "system".to_string(),
                name: Some("example_user".to_string()),
                content: Some("Let's circle back when we have more bandwidth to touch base on opportunities for increased leverage.".to_string()),
                ..Default::default()
            },
            ChatCompletionRequestMessage {
                role: "system".to_string(),
                name: Some("example_assistant".to_string()),
                content: Some("Let's talk later when we're less busy about how to do better.".to_string()),
                ..Default::default()
            },
            ChatCompletionRequestMessage {
                role: "user".to_string(),
                name: None,
                content: Some("This late pivot means we don't have time to boil the ocean for the client deliverable.".to_string()),
                ..Default::default()
            },
        ];
        let num_tokens = num_tokens_from_messages("gpt-3.5-turbo-0301", &messages).unwrap();
        assert_eq!(num_tokens, 127);

        let num_tokens = num_tokens_from_messages("gpt-4-0314", &messages).unwrap();
        assert_eq!(num_tokens, 129);

        let num_tokens = num_tokens_from_messages("gpt-4o-2024-05-13", &messages).unwrap();
        assert_eq!(num_tokens, 124);

        // Newer gpt-3.5 snapshots use (3, 1) like gpt-4, not (4, -1) like gpt-3.5-turbo-0301
        let num_tokens = num_tokens_from_messages("gpt-3.5-turbo-0125", &messages).unwrap();
        assert_eq!(num_tokens, 129);
    }

    #[test]
    fn test_num_tokens_from_messages_unknown_model() {
        assert!(num_tokens_from_messages("not-a-real-model", &[]).is_err());
    }

    #[test]
    fn test_num_tokens_from_messages_rejects_non_chat_tokenizer() {
        // Legacy completion models do not use a chat tokenizer, so counting must fail.
        assert!(num_tokens_from_messages("text-davinci-003", &[]).is_err());
    }

    #[test]
    fn test_num_tokens_from_messages_with_function_call() {
        let messages = vec![
            ChatCompletionRequestMessage {
                role: "system".to_string(),
                content: Some("You are a friendly chatbot.\n".to_string()),
                name: None,
                ..Default::default()
            },
            ChatCompletionRequestMessage {
                role: "assistant".to_string(),
                content: Some("Hello, I am a friendly chatbot!\n".to_string()),
                name: None,
                ..Default::default()
            },
            ChatCompletionRequestMessage {
                role: "user".to_string(),
                content: Some("What is the weather in New York?".to_string()),
                name: None,
                ..Default::default()
            },
            ChatCompletionRequestMessage {
                role: "assistant".to_string(),
                content: Some(String::new()),
                function_call: Some(FunctionCall {
                    name: "get_weather".to_string(),
                    arguments: "{\n  \"city\": \"New York\"\n}".to_string(),
                }),
                ..Default::default()
            },
            ChatCompletionRequestMessage {
                role: "function".to_string(),
                content: Some(
                    "{\"temperature\": 72, \"conditions\": \"partly_cloudy\"}".to_string(),
                ),
                name: Some("get_weather".to_string()),
                ..Default::default()
            },
        ];
        // Validated against OpenAI API response (issue #40)
        let num_tokens = num_tokens_from_messages("gpt-4-0613", &messages).unwrap();
        assert_eq!(num_tokens, 78);
    }

    #[test]
    fn test_num_tokens_from_messages_with_tool_calls() {
        let messages_with = vec![ChatCompletionRequestMessage {
            role: "assistant".to_string(),
            tool_calls: vec![FunctionCall {
                name: "get_weather".to_string(),
                arguments: r#"{"city": "Paris"}"#.to_string(),
            }],
            ..Default::default()
        }];
        let messages_without = vec![ChatCompletionRequestMessage {
            role: "assistant".to_string(),
            ..Default::default()
        }];
        let with = num_tokens_from_messages("gpt-4o", &messages_with).unwrap();
        let without = num_tokens_from_messages("gpt-4o", &messages_without).unwrap();
        assert!(
            with > without,
            "tool_calls should contribute tokens: {with} vs {without}"
        );
    }

    #[test]
    fn test_num_tokens_from_messages_with_multiple_tool_calls() {
        let single = vec![ChatCompletionRequestMessage {
            role: "assistant".to_string(),
            tool_calls: vec![FunctionCall {
                name: "get_weather".to_string(),
                arguments: r#"{"city": "Paris"}"#.to_string(),
            }],
            ..Default::default()
        }];
        let double = vec![ChatCompletionRequestMessage {
            role: "assistant".to_string(),
            tool_calls: vec![
                FunctionCall {
                    name: "get_weather".to_string(),
                    arguments: r#"{"city": "Paris"}"#.to_string(),
                },
                FunctionCall {
                    name: "get_weather".to_string(),
                    arguments: r#"{"city": "London"}"#.to_string(),
                },
            ],
            ..Default::default()
        }];
        let single_tokens = num_tokens_from_messages("gpt-4o", &single).unwrap();
        let double_tokens = num_tokens_from_messages("gpt-4o", &double).unwrap();
        assert!(
            double_tokens > single_tokens,
            "multiple tool_calls should each contribute tokens: {double_tokens} vs {single_tokens}"
        );
    }

    #[test]
    fn test_num_tokens_from_messages_with_refusal() {
        let messages_with = vec![ChatCompletionRequestMessage {
            role: "assistant".to_string(),
            refusal: Some("I cannot help with that request.".to_string()),
            ..Default::default()
        }];
        let messages_without = vec![ChatCompletionRequestMessage {
            role: "assistant".to_string(),
            ..Default::default()
        }];
        let with = num_tokens_from_messages("gpt-4o", &messages_with).unwrap();
        let without = num_tokens_from_messages("gpt-4o", &messages_without).unwrap();
        assert!(
            with > without,
            "refusal should contribute tokens: {with} vs {without}"
        );
    }

    #[test]
    fn test_num_tokens_from_messages_repeated_calls_consistent() {
        let messages = vec![ChatCompletionRequestMessage {
            role: "user".to_string(),
            content: Some("Hello, world!".to_string()),
            ..Default::default()
        }];
        let first = num_tokens_from_messages("gpt-4o", &messages).unwrap();
        for _ in 0..5 {
            let result = num_tokens_from_messages("gpt-4o", &messages).unwrap();
            assert_eq!(first, result);
        }
    }

    #[test]
    fn test_text_completion_max_tokens_repeated_calls_consistent() {
        let first = get_text_completion_max_tokens("gpt-4o", "Hello, world!").unwrap();
        for _ in 0..5 {
            let result = get_text_completion_max_tokens("gpt-4o", "Hello, world!").unwrap();
            assert_eq!(first, result);
        }
    }

    #[test]
    fn test_bpe_singleton_matches_fresh_bpe() {
        // `bpe_for_tokenizer` returns the very same cached singleton, so comparing against it
        // would be tautological; build an independent encoder instead.
        let singleton = bpe_singleton(Tokenizer::Cl100kBase);
        let fresh = crate::cl100k_base().unwrap();
        let text = "The quick brown fox jumps over the lazy dog";
        assert_eq!(
            singleton.encode_with_special_tokens(text),
            fresh.encode_with_special_tokens(text),
        );
    }

    #[test]
    fn test_get_chat_completion_max_tokens() {
        let model = "gpt-3.5-turbo";
        let messages = vec![
            ChatCompletionRequestMessage {
                content: Some("You are a helpful assistant that only speaks French.".to_string()),
                role: "system".to_string(),
                name: None,
                ..Default::default()
            },
            ChatCompletionRequestMessage {
                content: Some("Hello, how are you?".to_string()),
                role: "user".to_string(),
                name: None,
                ..Default::default()
            },
            ChatCompletionRequestMessage {
                content: Some("Parlez-vous francais?".to_string()),
                role: "system".to_string(),
                name: None,
                ..Default::default()
            },
        ];
        let max_tokens = get_chat_completion_max_tokens(model, &messages).unwrap();
        assert!(max_tokens > 0);
    }

    #[test]
    fn test_text_completion_max_tokens() {
        let model = "gpt-3.5-turbo";
        let prompt = "Translate the following English text to French: '";
        let max_tokens = get_text_completion_max_tokens(model, prompt).unwrap();
        assert!(max_tokens > 0);
    }
}
553
/// This module provides support for working with the `async_openai` crate.
555#[cfg(feature = "async-openai")]
556pub mod async_openai {
557    use anyhow::Result;
558    use async_openai::types::chat::{
559        ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageContent,
560        ChatCompletionRequestAssistantMessageContentPart,
561        ChatCompletionRequestDeveloperMessageContent,
562        ChatCompletionRequestDeveloperMessageContentPart, ChatCompletionRequestMessage,
563        ChatCompletionRequestSystemMessageContent, ChatCompletionRequestSystemMessageContentPart,
564        ChatCompletionRequestToolMessageContent, ChatCompletionRequestToolMessageContentPart,
565        ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
566        FunctionCall,
567    };
568
569    impl From<&FunctionCall> for super::FunctionCall {
570        fn from(f: &FunctionCall) -> Self {
571            Self {
572                name: f.name.clone(),
573                arguments: f.arguments.clone(),
574            }
575        }
576    }
577
578    fn join_texts(texts: Vec<String>) -> Option<String> {
579        if texts.is_empty() {
580            None
581        } else {
582            Some(texts.join(""))
583        }
584    }
585
586    fn system_content_text(content: &ChatCompletionRequestSystemMessageContent) -> Option<String> {
587        match content {
588            ChatCompletionRequestSystemMessageContent::Text(s) => Some(s.clone()),
589            ChatCompletionRequestSystemMessageContent::Array(parts) => join_texts(
590                parts
591                    .iter()
592                    .map(|ChatCompletionRequestSystemMessageContentPart::Text(t)| t.text.clone())
593                    .collect(),
594            ),
595        }
596    }
597
598    fn developer_content_text(
599        content: &ChatCompletionRequestDeveloperMessageContent,
600    ) -> Option<String> {
601        match content {
602            ChatCompletionRequestDeveloperMessageContent::Text(s) => Some(s.clone()),
603            ChatCompletionRequestDeveloperMessageContent::Array(parts) => join_texts(
604                parts
605                    .iter()
606                    .map(|ChatCompletionRequestDeveloperMessageContentPart::Text(t)| t.text.clone())
607                    .collect(),
608            ),
609        }
610    }
611
612    fn user_content_text(content: &ChatCompletionRequestUserMessageContent) -> Option<String> {
613        match content {
614            ChatCompletionRequestUserMessageContent::Text(s) => Some(s.clone()),
615            ChatCompletionRequestUserMessageContent::Array(parts) => join_texts(
616                parts
617                    .iter()
618                    .filter_map(|p| match p {
619                        ChatCompletionRequestUserMessageContentPart::Text(t) => {
620                            Some(t.text.clone())
621                        }
622                        ChatCompletionRequestUserMessageContentPart::ImageUrl(_)
623                        | ChatCompletionRequestUserMessageContentPart::InputAudio(_)
624                        | ChatCompletionRequestUserMessageContentPart::File(_) => None,
625                    })
626                    .collect(),
627            ),
628        }
629    }
630
631    fn assistant_content_text(
632        content: &ChatCompletionRequestAssistantMessageContent,
633    ) -> (Option<String>, Option<String>) {
634        match content {
635            ChatCompletionRequestAssistantMessageContent::Text(s) => (Some(s.clone()), None),
636            ChatCompletionRequestAssistantMessageContent::Array(parts) => {
637                let mut texts = Vec::new();
638                let mut refusals = Vec::new();
639                for p in parts {
640                    match p {
641                        ChatCompletionRequestAssistantMessageContentPart::Text(t) => {
642                            texts.push(t.text.clone());
643                        }
644                        ChatCompletionRequestAssistantMessageContentPart::Refusal(r) => {
645                            refusals.push(r.refusal.clone());
646                        }
647                    }
648                }
649                (join_texts(texts), join_texts(refusals))
650            }
651        }
652    }
653
654    fn tool_content_text(content: &ChatCompletionRequestToolMessageContent) -> Option<String> {
655        match content {
656            ChatCompletionRequestToolMessageContent::Text(s) => Some(s.clone()),
657            ChatCompletionRequestToolMessageContent::Array(parts) => join_texts(
658                parts
659                    .iter()
660                    .map(|ChatCompletionRequestToolMessageContentPart::Text(t)| t.text.clone())
661                    .collect(),
662            ),
663        }
664    }
665
666    fn extract_tool_calls(
667        tool_calls: &Option<Vec<ChatCompletionMessageToolCalls>>,
668    ) -> Vec<super::FunctionCall> {
669        tool_calls
670            .as_ref()
671            .map(|calls| {
672                calls
673                    .iter()
674                    .map(|tc| match tc {
675                        ChatCompletionMessageToolCalls::Function(f) => (&f.function).into(),
676                        ChatCompletionMessageToolCalls::Custom(c) => super::FunctionCall {
677                            name: c.custom_tool.name.clone(),
678                            arguments: c.custom_tool.input.clone(),
679                        },
680                    })
681                    .collect()
682            })
683            .unwrap_or_default()
684    }
685
686    #[allow(deprecated)]
687    impl From<&ChatCompletionRequestMessage> for super::ChatCompletionRequestMessage {
688        fn from(m: &ChatCompletionRequestMessage) -> Self {
689            match m {
690                ChatCompletionRequestMessage::System(msg) => Self {
691                    role: "system".to_string(),
692                    name: msg.name.clone(),
693                    content: Some(system_content_text(&msg.content).unwrap_or_default()),
694                    ..Default::default()
695                },
696                ChatCompletionRequestMessage::Developer(msg) => Self {
697                    role: "developer".to_string(),
698                    name: msg.name.clone(),
699                    content: Some(developer_content_text(&msg.content).unwrap_or_default()),
700                    ..Default::default()
701                },
702                ChatCompletionRequestMessage::User(msg) => Self {
703                    role: "user".to_string(),
704                    name: msg.name.clone(),
705                    content: Some(user_content_text(&msg.content).unwrap_or_default()),
706                    ..Default::default()
707                },
708                ChatCompletionRequestMessage::Assistant(msg) => {
709                    let (content, refusal) = msg
710                        .content
711                        .as_ref()
712                        .map(assistant_content_text)
713                        .unwrap_or_default();
714                    let refusal = refusal.or_else(|| msg.refusal.clone());
715                    Self {
716                        role: "assistant".to_string(),
717                        name: msg.name.clone(),
718                        content,
719                        function_call: msg.function_call.as_ref().map(|f| f.into()),
720                        tool_calls: extract_tool_calls(&msg.tool_calls),
721                        refusal,
722                    }
723                }
724                ChatCompletionRequestMessage::Tool(msg) => Self {
725                    role: "tool".to_string(),
726                    name: Some(msg.tool_call_id.clone()),
727                    content: Some(tool_content_text(&msg.content).unwrap_or_default()),
728                    ..Default::default()
729                },
730                ChatCompletionRequestMessage::Function(msg) => Self {
731                    role: "function".to_string(),
732                    name: Some(msg.name.clone()),
733                    content: msg.content.clone(),
734                    ..Default::default()
735                },
736            }
737        }
738    }
739
740    /// Calculates the total number of tokens for the given list of messages.
741    ///
742    /// **Note:** Only text content is counted. Non-text parts (images, audio, files) are
743    /// silently skipped because they use a separate token formula based on resolution/duration,
744    /// not BPE encoding. If your messages contain non-text content, the returned count will
745    /// be lower than the actual API token usage.
746    ///
747    /// # Arguments
748    ///
749    /// * `model` - A string slice representing the name of the model.
750    /// * `messages` - A slice of `async_openai::types::ChatCompletionRequestMessage` instances.
751    ///
752    /// # Returns
753    ///
754    /// * A `Result` containing the total number of tokens (`usize`) or an error if the calculation fails.
755    pub fn num_tokens_from_messages(
756        model: &str,
757        messages: &[ChatCompletionRequestMessage],
758    ) -> Result<usize> {
759        let messages: Vec<super::ChatCompletionRequestMessage> =
760            messages.iter().map(|m| m.into()).collect();
761        super::num_tokens_from_messages(model, &messages)
762    }
763
764    /// Retrieves the maximum token limit for chat completions.
765    ///
766    /// # Arguments
767    ///
768    /// * `model` - A string slice representing the name of the model.
769    /// * `messages` - A slice of `async_openai::types::ChatCompletionRequestMessage` instances.
770    ///
771    /// # Returns
772    ///
773    /// * A `Result` containing the maximum number of tokens (`usize`) allowed for chat completions or an error if the retrieval fails.
774    pub fn get_chat_completion_max_tokens(
775        model: &str,
776        messages: &[ChatCompletionRequestMessage],
777    ) -> Result<usize> {
778        let messages: Vec<super::ChatCompletionRequestMessage> =
779            messages.iter().map(|m| m.into()).collect();
780        super::get_chat_completion_max_tokens(model, &messages)
781    }
782
783    #[cfg(test)]
784    #[allow(deprecated)]
785    mod tests {
786        use super::*;
787        use async_openai::types::chat::{
788            ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessage,
789            ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
790        };
791
792        #[test]
793        fn test_num_tokens_from_messages_system() {
794            let model = "gpt-4o";
795            let messages = &[ChatCompletionRequestMessage::System(
796                ChatCompletionRequestSystemMessage {
797                    content: ChatCompletionRequestSystemMessageContent::Text(
798                        "You are a helpful assistant.".to_string(),
799                    ),
800                    name: None,
801                },
802            )];
803            let num_tokens = num_tokens_from_messages(model, messages).unwrap();
804            assert!(num_tokens > 0);
805        }
806
807        #[test]
808        fn test_num_tokens_from_messages_user() {
809            let model = "gpt-4o";
810            let messages = &[ChatCompletionRequestMessage::User(
811                ChatCompletionRequestUserMessage {
812                    content: ChatCompletionRequestUserMessageContent::Text(
813                        "Hello, how are you?".to_string(),
814                    ),
815                    name: None,
816                },
817            )];
818            let num_tokens = num_tokens_from_messages(model, messages).unwrap();
819            assert!(num_tokens > 0);
820        }
821
822        #[test]
823        fn test_num_tokens_with_tool_calls() {
824            let model = "gpt-4o";
825            let messages = &[ChatCompletionRequestMessage::Assistant(
826                ChatCompletionRequestAssistantMessage {
827                    content: None,
828                    refusal: None,
829                    name: None,
830                    audio: None,
831                    tool_calls: Some(vec![ChatCompletionMessageToolCalls::Function(
832                        ChatCompletionMessageToolCall {
833                            id: "call_123".to_string(),
834                            function: FunctionCall {
835                                name: "get_weather".to_string(),
836                                arguments: r#"{"location": "Paris"}"#.to_string(),
837                            },
838                        },
839                    )]),
840                    function_call: None,
841                },
842            )];
843            let tokens_with = num_tokens_from_messages(model, messages).unwrap();
844
845            let empty = &[ChatCompletionRequestMessage::Assistant(
846                ChatCompletionRequestAssistantMessage {
847                    content: None,
848                    refusal: None,
849                    name: None,
850                    audio: None,
851                    tool_calls: None,
852                    function_call: None,
853                },
854            )];
855            let tokens_without = num_tokens_from_messages(model, empty).unwrap();
856
857            assert!(
858                tokens_with > tokens_without,
859                "tool_calls should contribute tokens: {tokens_with} vs {tokens_without}"
860            );
861        }
862
863        #[test]
864        fn test_num_tokens_with_refusal() {
865            let model = "gpt-4o";
866            let messages = &[ChatCompletionRequestMessage::Assistant(
867                ChatCompletionRequestAssistantMessage {
868                    content: None,
869                    refusal: Some("I cannot help with that request.".to_string()),
870                    name: None,
871                    audio: None,
872                    tool_calls: None,
873                    function_call: None,
874                },
875            )];
876            let tokens_with = num_tokens_from_messages(model, messages).unwrap();
877
878            let empty = &[ChatCompletionRequestMessage::Assistant(
879                ChatCompletionRequestAssistantMessage {
880                    content: None,
881                    refusal: None,
882                    name: None,
883                    audio: None,
884                    tool_calls: None,
885                    function_call: None,
886                },
887            )];
888            let tokens_without = num_tokens_from_messages(model, empty).unwrap();
889
890            assert!(
891                tokens_with > tokens_without,
892                "refusal should contribute tokens: {tokens_with} vs {tokens_without}"
893            );
894        }
895
896        #[test]
897        fn test_get_chat_completion_max_tokens() {
898            let model = "gpt-4o";
899            let messages = &[ChatCompletionRequestMessage::System(
900                ChatCompletionRequestSystemMessage {
901                    content: ChatCompletionRequestSystemMessageContent::Text(
902                        "You are a helpful assistant.".to_string(),
903                    ),
904                    name: None,
905                },
906            )];
907            let max_tokens = get_chat_completion_max_tokens(model, messages).unwrap();
908            assert!(max_tokens > 0);
909        }
910    }
911}