Skip to main content

tiktoken_rs/
api.rs

1use anyhow::{Result, anyhow};
2
3use crate::{
4    CoreBPE, cl100k_base_singleton,
5    model::get_context_size,
6    o200k_base_singleton, o200k_harmony_singleton, p50k_base_singleton, p50k_edit_singleton,
7    r50k_base_singleton,
8    tokenizer::{Tokenizer, get_tokenizer},
9};
10
11/// Returns the maximum number of tokens available for a text completion, given a model and prompt.
12///
13/// This is for legacy text/prompt completions (single string input). For chat completions,
14/// use [`get_chat_completion_max_tokens`] instead.
15///
16/// Calculates `context_size - prompt_tokens` for the given model.
17///
18/// # Arguments
19///
20/// * `model` - A string slice representing the model name, e.g., `"gpt-4o"`.
21/// * `prompt` - A string slice containing the prompt text.
22///
23/// # Errors
24///
25/// Returns an error if no tokenizer is found for the given model.
26///
27/// # Examples
28///
29/// ```
30/// use tiktoken_rs::get_text_completion_max_tokens;
31///
32/// let max_tokens = get_text_completion_max_tokens("gpt-4o", "Translate to French: '").unwrap();
33/// ```
34pub fn get_text_completion_max_tokens(model: &str, prompt: &str) -> Result<usize> {
35    let context_size = get_context_size(model)
36        .ok_or_else(|| anyhow!("Unknown context size for model {}", model))?;
37    let tokenizer =
38        get_tokenizer(model).ok_or_else(|| anyhow!("No tokenizer found for model {}", model))?;
39    let bpe = bpe_singleton(tokenizer);
40    let prompt_tokens = bpe.count_with_special_tokens(prompt);
41    Ok(context_size.saturating_sub(prompt_tokens))
42}
43
44/// Use [`get_text_completion_max_tokens`] instead.
45#[deprecated(since = "0.10.0", note = "renamed to `get_text_completion_max_tokens`")]
46pub fn get_completion_max_tokens(model: &str, prompt: &str) -> Result<usize> {
47    get_text_completion_max_tokens(model, prompt)
48}
49
/// A function invocation produced by the model: the function's name plus its arguments.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct FunctionCall {
    /// Name of the function to call.
    pub name: String,
    /// Arguments for the call, as generated by the model in JSON format.
    ///
    /// The model does not always generate valid JSON and may hallucinate parameters that
    /// are not defined by your function schema. Validate the arguments in your code
    /// before calling your function.
    pub arguments: String,
}
58
59#[derive(Debug, Default, Clone, PartialEq, Eq)]
60pub struct ChatCompletionRequestMessage {
61    /// The role of the messages author. One of `system`, `developer`, `user`, `assistant`, `tool`, or `function`.
62    pub role: String,
63    /// The contents of the message.
64    /// `content` is required for all messages except assistant messages with function calls.
65    pub content: Option<String>,
66    /// The name of the author of this message. `name` is required if role is function,
67    /// and it should be the name of the function whose response is in the `content`.
68    /// May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
69    pub name: Option<String>,
70    /// The name and arguments of a function that should be called, as generated by the model.
71    pub function_call: Option<FunctionCall>,
72    /// Tool calls generated by the model, represented as FunctionCall structs.
73    /// Tool call IDs and type discriminators are not preserved.
74    pub tool_calls: Vec<FunctionCall>,
75    /// The refusal message generated by the model.
76    pub refusal: Option<String>,
77}
78
79/// Based on <https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb>
80///
81/// num_tokens_from_messages returns the number of tokens required to encode the given messages into
82/// the given model. This is used to estimate the number of tokens that will be used for chat
83/// completion.
84///
85/// # Arguments
86///
87/// * model: A string slice containing the model name (e.g. "gpt-3.5").
88/// * messages: A slice of ChatCompletionRequestMessage structs representing chat messages.
89///
90/// # Returns
91///
92/// * `Result<usize>`: A Result containing the total number of tokens needed to encode the messages
93///   for the specified model, or an error if the tokenizer for the model is not found or not supported.
94///
95/// # Errors
96///
97/// This function will return an error if:
98///
99/// * The tokenizer for the specified model is not found.
100/// * The tokenizer is not a supported chat model (i.e., not one of Cl100kBase, O200kBase, or O200kHarmony).
101///
102pub fn num_tokens_from_messages(
103    model: &str,
104    messages: &[ChatCompletionRequestMessage],
105) -> Result<usize> {
106    let tokenizer =
107        get_tokenizer(model).ok_or_else(|| anyhow!("No tokenizer found for model {}", model))?;
108    if tokenizer != Tokenizer::Cl100kBase
109        && tokenizer != Tokenizer::O200kBase
110        && tokenizer != Tokenizer::O200kHarmony
111    {
112        anyhow::bail!(
113            "Chat token counting is not supported for model {:?} (tokenizer {:?}). \
114             Supported tokenizers: Cl100kBase, O200kBase, O200kHarmony.",
115            model,
116            tokenizer
117        )
118    }
119    let bpe = bpe_singleton(tokenizer);
120
121    // Token overhead constants adapted from the OpenAI cookbook:
122    // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
123    //
124    // tokens_per_message: overhead tokens per message for framing (3 for current models)
125    // tokens_per_name: extra tokens when a `name` field is present (1 for current models)
126    //
127    // The gpt-3.5-turbo-0301 branch (4, -1) was removed from the cookbook in later revisions;
128    // we retain it for backward compatibility with that specific snapshot.
129    //
130    // FUNCTION_CALL_OVERHEAD: 1 extra token per function/tool call (heuristic)
131    // REPLY_PRIMING: 3 tokens added once at the end (per cookbook: <|start|>assistant<|message|>)
132    const FUNCTION_CALL_OVERHEAD: i32 = 1;
133    const REPLY_PRIMING: i32 = 3;
134
135    let (tokens_per_message, tokens_per_name) = if model == "gpt-3.5-turbo-0301" {
136        (4, -1)
137    } else {
138        (3, 1)
139    };
140
141    let mut num_tokens: i32 = 0;
142    for message in messages {
143        num_tokens += tokens_per_message;
144        num_tokens += bpe.count_with_special_tokens(&message.role) as i32;
145        if let Some(content) = &message.content {
146            num_tokens += bpe.count_with_special_tokens(content) as i32;
147        }
148        if let Some(name) = &message.name {
149            num_tokens += bpe.count_with_special_tokens(name) as i32;
150            num_tokens += tokens_per_name;
151        }
152        if let Some(function_call) = &message.function_call {
153            num_tokens += bpe.count_with_special_tokens(&function_call.name) as i32;
154            num_tokens += bpe.count_with_special_tokens(&function_call.arguments) as i32;
155            num_tokens += FUNCTION_CALL_OVERHEAD;
156        }
157        for tool_call in &message.tool_calls {
158            num_tokens += bpe.count_with_special_tokens(&tool_call.name) as i32;
159            num_tokens += bpe.count_with_special_tokens(&tool_call.arguments) as i32;
160            num_tokens += FUNCTION_CALL_OVERHEAD;
161        }
162        if let Some(refusal) = &message.refusal {
163            num_tokens += bpe.count_with_special_tokens(refusal) as i32;
164        }
165    }
166    num_tokens += REPLY_PRIMING;
167    Ok(num_tokens as usize)
168}
169
170/// Calculates the maximum number of tokens available for chat completion based on the model and messages provided.
171///
172/// This function determines the number of tokens left for a chat completion task, given the model and a slice of
173/// chat completion request messages. It first retrieves the tokenizer for the given model and checks if chat completion
174/// is supported. Then, it calculates the number of tokens in the existing messages using the appropriate tokenizer.
175///
176/// # Arguments
177///
178/// * `model` - A string slice representing the model name, e.g., "gpt-3.5-turbo".
179/// * `messages` - A slice of `ChatCompletionRequestMessage` instances containing the chat context.
180///
181/// # Errors
182///
183/// This function returns an error in the following cases:
184///
185/// * If there is no tokenizer found for the specified model.
186/// * If chat completion is not supported for the specified model.
187/// * If there is a failure in creating a `CoreBPE` instance for the specified tokenizer.
188///
189/// # Example
190///
191/// ```
192/// use tiktoken_rs::{get_chat_completion_max_tokens, ChatCompletionRequestMessage};
193///
194/// let model = "gpt-3.5-turbo";
195/// let messages = vec![
196///     ChatCompletionRequestMessage {
197///         content: Some("You are a helpful assistant that only speaks French.".to_string()),
198///         role: "system".to_string(),
199///         ..Default::default()
200///     },
201///     ChatCompletionRequestMessage {
202///         content: Some("Hello, how are you?".to_string()),
203///         role: "user".to_string(),
204///         ..Default::default()
205///     },
206///     ChatCompletionRequestMessage {
207///         content: Some("Parlez-vous francais?".to_string()),
208///         role: "system".to_string(),
209///         ..Default::default()
210///     },
211/// ];
212/// let max_tokens = get_chat_completion_max_tokens(model, &messages).unwrap();
213/// ```
214///
215/// # Returns
216///
217/// If successful, the function returns a `Result` containing the maximum number of tokens available for chat completion,
218/// based on the given model and messages.
219pub fn get_chat_completion_max_tokens(
220    model: &str,
221    messages: &[ChatCompletionRequestMessage],
222) -> Result<usize> {
223    let context_size = get_context_size(model)
224        .ok_or_else(|| anyhow!("Unknown context size for model {}", model))?;
225    let prompt_tokens = num_tokens_from_messages(model, messages)?;
226    Ok(context_size.saturating_sub(prompt_tokens))
227}
228
229fn bpe_singleton(tokenizer: Tokenizer) -> &'static CoreBPE {
230    match tokenizer {
231        Tokenizer::O200kHarmony => o200k_harmony_singleton(),
232        Tokenizer::O200kBase => o200k_base_singleton(),
233        Tokenizer::Cl100kBase => cl100k_base_singleton(),
234        Tokenizer::R50kBase => r50k_base_singleton(),
235        Tokenizer::P50kBase => p50k_base_singleton(),
236        Tokenizer::P50kEdit => p50k_edit_singleton(),
237        Tokenizer::Gpt2 => r50k_base_singleton(),
238    }
239}
240
241/// Returns a cached reference to the BPE tokenizer for the given model name.
242///
243/// Looks up which tokenizer the model uses, then returns a `&'static CoreBPE` singleton.
244/// The singleton is initialized once and reused for all subsequent calls.
245///
246/// # Arguments
247///
248/// * `model` - A model name, e.g., `"gpt-4o"`, `"gpt-3.5-turbo"`, `"o3-mini"`.
249///
250/// # Errors
251///
252/// Returns an error if no tokenizer is found for the given model name.
253///
254/// # Examples
255///
256/// ```
257/// use tiktoken_rs::bpe_for_model;
258///
259/// let bpe = bpe_for_model("gpt-4o").unwrap();
260/// let tokens = bpe.encode_with_special_tokens("hello world");
261/// ```
262pub fn bpe_for_model(model: &str) -> Result<&'static CoreBPE> {
263    let tokenizer =
264        get_tokenizer(model).ok_or_else(|| anyhow!("No tokenizer found for model {}", model))?;
265    bpe_for_tokenizer(tokenizer)
266}
267
268/// Use [`bpe_for_model`] instead.
269#[deprecated(since = "0.10.0", note = "renamed to `bpe_for_model`")]
270pub fn get_bpe_from_model(model: &str) -> Result<&'static CoreBPE> {
271    bpe_for_model(model)
272}
273
274/// Returns a cached reference to the BPE tokenizer for the given tokenizer type.
275///
276/// Returns a `&'static CoreBPE` singleton. The singleton is initialized once and reused
277/// for all subsequent calls.
278///
279/// # Arguments
280///
281/// * `tokenizer` - A [`Tokenizer`] enum variant.
282///
283/// # Examples
284///
285/// ```
286/// use tiktoken_rs::bpe_for_tokenizer;
287/// use tiktoken_rs::tokenizer::Tokenizer;
288///
289/// let bpe = bpe_for_tokenizer(Tokenizer::O200kBase).unwrap();
290/// let tokens = bpe.encode_with_special_tokens("hello world");
291/// ```
292pub fn bpe_for_tokenizer(tokenizer: Tokenizer) -> Result<&'static CoreBPE> {
293    Ok(bpe_singleton(tokenizer))
294}
295
296/// Use [`bpe_for_tokenizer`] instead.
297#[deprecated(since = "0.10.0", note = "renamed to `bpe_for_tokenizer`")]
298pub fn get_bpe_from_tokenizer(tokenizer: Tokenizer) -> Result<&'static CoreBPE> {
299    bpe_for_tokenizer(tokenizer)
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message that only has a role and content.
    fn message(role: &str, content: &str) -> ChatCompletionRequestMessage {
        ChatCompletionRequestMessage {
            role: role.to_string(),
            content: Some(content.to_string()),
            ..Default::default()
        }
    }

    /// Builds a message that has a role, an author name and content.
    fn named_message(role: &str, name: &str, content: &str) -> ChatCompletionRequestMessage {
        ChatCompletionRequestMessage {
            name: Some(name.to_string()),
            ..message(role, content)
        }
    }

    /// Builds an assistant message without content that carries the given tool calls.
    fn assistant(tool_calls: Vec<FunctionCall>) -> ChatCompletionRequestMessage {
        ChatCompletionRequestMessage {
            role: "assistant".to_string(),
            tool_calls,
            ..Default::default()
        }
    }

    /// Builds a `get_weather` call whose JSON arguments name the given city.
    fn weather_call(city: &str) -> FunctionCall {
        FunctionCall {
            name: "get_weather".to_string(),
            arguments: format!(r#"{{"city": "{city}"}}"#),
        }
    }

    #[test]
    fn test_bpe_for_tokenizer() {
        let bpe = bpe_for_tokenizer(Tokenizer::Cl100kBase).unwrap();
        assert_eq!(bpe.decode(&[15339]).unwrap(), "hello");
    }

    #[test]
    fn test_num_tokens_from_messages() {
        let messages = vec![
            message(
                "system",
                "You are a helpful, pattern-following assistant that translates corporate jargon into plain English.",
            ),
            named_message(
                "system",
                "example_user",
                "New synergies will help drive top-line growth.",
            ),
            named_message(
                "system",
                "example_assistant",
                "Things working well together will increase revenue.",
            ),
            named_message(
                "system",
                "example_user",
                "Let's circle back when we have more bandwidth to touch base on opportunities for increased leverage.",
            ),
            named_message(
                "system",
                "example_assistant",
                "Let's talk later when we're less busy about how to do better.",
            ),
            message(
                "user",
                "This late pivot means we don't have time to boil the ocean for the client deliverable.",
            ),
        ];

        // Newer gpt-3.5 snapshots use (3, 1) like gpt-4, not (4, -1) like gpt-3.5-turbo-0301
        let expectations = [
            ("gpt-3.5-turbo-0301", 127),
            ("gpt-4-0314", 129),
            ("gpt-4o-2024-05-13", 124),
            ("gpt-3.5-turbo-0125", 129),
        ];
        for (model, expected) in expectations {
            let num_tokens = num_tokens_from_messages(model, &messages).unwrap();
            assert_eq!(num_tokens, expected);
        }
    }

    #[test]
    fn test_num_tokens_from_messages_with_function_call() {
        let messages = vec![
            message("system", "You are a friendly chatbot.\n"),
            message("assistant", "Hello, I am a friendly chatbot!\n"),
            message("user", "What is the weather in New York?"),
            ChatCompletionRequestMessage {
                function_call: Some(FunctionCall {
                    name: "get_weather".to_string(),
                    arguments: "{\n  \"city\": \"New York\"\n}".to_string(),
                }),
                // Empty (but present) content, as in an assistant function-call turn.
                ..message("assistant", "")
            },
            named_message(
                "function",
                "get_weather",
                "{\"temperature\": 72, \"conditions\": \"partly_cloudy\"}",
            ),
        ];
        // Validated against OpenAI API response (issue #40)
        let num_tokens = num_tokens_from_messages("gpt-4-0613", &messages).unwrap();
        assert_eq!(num_tokens, 78);
    }

    #[test]
    fn test_num_tokens_from_messages_with_tool_calls() {
        let messages_with = vec![assistant(vec![weather_call("Paris")])];
        let messages_without = vec![assistant(Vec::new())];
        let with = num_tokens_from_messages("gpt-4o", &messages_with).unwrap();
        let without = num_tokens_from_messages("gpt-4o", &messages_without).unwrap();
        assert!(
            with > without,
            "tool_calls should contribute tokens: {with} vs {without}"
        );
    }

    #[test]
    fn test_num_tokens_from_messages_with_multiple_tool_calls() {
        let single = vec![assistant(vec![weather_call("Paris")])];
        let double = vec![assistant(vec![
            weather_call("Paris"),
            weather_call("London"),
        ])];
        let single_tokens = num_tokens_from_messages("gpt-4o", &single).unwrap();
        let double_tokens = num_tokens_from_messages("gpt-4o", &double).unwrap();
        assert!(
            double_tokens > single_tokens,
            "multiple tool_calls should each contribute tokens: {double_tokens} vs {single_tokens}"
        );
    }

    #[test]
    fn test_num_tokens_from_messages_with_refusal() {
        let messages_with = vec![ChatCompletionRequestMessage {
            refusal: Some("I cannot help with that request.".to_string()),
            ..assistant(Vec::new())
        }];
        let messages_without = vec![assistant(Vec::new())];
        let with = num_tokens_from_messages("gpt-4o", &messages_with).unwrap();
        let without = num_tokens_from_messages("gpt-4o", &messages_without).unwrap();
        assert!(
            with > without,
            "refusal should contribute tokens: {with} vs {without}"
        );
    }

    #[test]
    fn test_num_tokens_from_messages_repeated_calls_consistent() {
        let messages = vec![message("user", "Hello, world!")];
        let first = num_tokens_from_messages("gpt-4o", &messages).unwrap();
        for _ in 0..5 {
            let result = num_tokens_from_messages("gpt-4o", &messages).unwrap();
            assert_eq!(first, result);
        }
    }

    #[test]
    fn test_text_completion_max_tokens_repeated_calls_consistent() {
        let first = get_text_completion_max_tokens("gpt-4o", "Hello, world!").unwrap();
        for _ in 0..5 {
            let result = get_text_completion_max_tokens("gpt-4o", "Hello, world!").unwrap();
            assert_eq!(first, result);
        }
    }

    #[test]
    fn test_bpe_singleton_matches_fresh_bpe() {
        let text = "The quick brown fox jumps over the lazy dog";
        let from_singleton = bpe_singleton(Tokenizer::Cl100kBase).encode_with_special_tokens(text);
        let from_public_api = bpe_for_tokenizer(Tokenizer::Cl100kBase)
            .unwrap()
            .encode_with_special_tokens(text);
        assert_eq!(from_singleton, from_public_api);
    }

    #[test]
    fn test_get_chat_completion_max_tokens() {
        let model = "gpt-3.5-turbo";
        let messages = vec![
            message(
                "system",
                "You are a helpful assistant that only speaks French.",
            ),
            message("user", "Hello, how are you?"),
            message("system", "Parlez-vous francais?"),
        ];
        let max_tokens = get_chat_completion_max_tokens(model, &messages).unwrap();
        assert!(max_tokens > 0);
    }

    #[test]
    fn test_text_completion_max_tokens() {
        let model = "gpt-3.5-turbo";
        let prompt = "Translate the following English text to French: '";
        let max_tokens = get_text_completion_max_tokens(model, prompt).unwrap();
        assert!(max_tokens > 0);
    }
}
552
/// This module provides support for working with the `async_openai` crate.
554#[cfg(feature = "async-openai")]
555pub mod async_openai {
556    use anyhow::Result;
557    use async_openai::types::chat::{
558        ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageContent,
559        ChatCompletionRequestAssistantMessageContentPart,
560        ChatCompletionRequestDeveloperMessageContent,
561        ChatCompletionRequestDeveloperMessageContentPart, ChatCompletionRequestMessage,
562        ChatCompletionRequestSystemMessageContent, ChatCompletionRequestSystemMessageContentPart,
563        ChatCompletionRequestToolMessageContent, ChatCompletionRequestToolMessageContentPart,
564        ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
565        FunctionCall,
566    };
567
568    impl From<&FunctionCall> for super::FunctionCall {
569        fn from(f: &FunctionCall) -> Self {
570            Self {
571                name: f.name.clone(),
572                arguments: f.arguments.clone(),
573            }
574        }
575    }
576
577    fn join_texts(texts: Vec<String>) -> Option<String> {
578        if texts.is_empty() {
579            None
580        } else {
581            Some(texts.join(""))
582        }
583    }
584
585    fn system_content_text(content: &ChatCompletionRequestSystemMessageContent) -> Option<String> {
586        match content {
587            ChatCompletionRequestSystemMessageContent::Text(s) => Some(s.clone()),
588            ChatCompletionRequestSystemMessageContent::Array(parts) => join_texts(
589                parts
590                    .iter()
591                    .map(|ChatCompletionRequestSystemMessageContentPart::Text(t)| t.text.clone())
592                    .collect(),
593            ),
594        }
595    }
596
597    fn developer_content_text(
598        content: &ChatCompletionRequestDeveloperMessageContent,
599    ) -> Option<String> {
600        match content {
601            ChatCompletionRequestDeveloperMessageContent::Text(s) => Some(s.clone()),
602            ChatCompletionRequestDeveloperMessageContent::Array(parts) => join_texts(
603                parts
604                    .iter()
605                    .map(|ChatCompletionRequestDeveloperMessageContentPart::Text(t)| t.text.clone())
606                    .collect(),
607            ),
608        }
609    }
610
611    fn user_content_text(content: &ChatCompletionRequestUserMessageContent) -> Option<String> {
612        match content {
613            ChatCompletionRequestUserMessageContent::Text(s) => Some(s.clone()),
614            ChatCompletionRequestUserMessageContent::Array(parts) => join_texts(
615                parts
616                    .iter()
617                    .filter_map(|p| match p {
618                        ChatCompletionRequestUserMessageContentPart::Text(t) => {
619                            Some(t.text.clone())
620                        }
621                        ChatCompletionRequestUserMessageContentPart::ImageUrl(_)
622                        | ChatCompletionRequestUserMessageContentPart::InputAudio(_)
623                        | ChatCompletionRequestUserMessageContentPart::File(_) => None,
624                    })
625                    .collect(),
626            ),
627        }
628    }
629
630    fn assistant_content_text(
631        content: &ChatCompletionRequestAssistantMessageContent,
632    ) -> (Option<String>, Option<String>) {
633        match content {
634            ChatCompletionRequestAssistantMessageContent::Text(s) => (Some(s.clone()), None),
635            ChatCompletionRequestAssistantMessageContent::Array(parts) => {
636                let mut texts = Vec::new();
637                let mut refusals = Vec::new();
638                for p in parts {
639                    match p {
640                        ChatCompletionRequestAssistantMessageContentPart::Text(t) => {
641                            texts.push(t.text.clone());
642                        }
643                        ChatCompletionRequestAssistantMessageContentPart::Refusal(r) => {
644                            refusals.push(r.refusal.clone());
645                        }
646                    }
647                }
648                (join_texts(texts), join_texts(refusals))
649            }
650        }
651    }
652
653    fn tool_content_text(content: &ChatCompletionRequestToolMessageContent) -> Option<String> {
654        match content {
655            ChatCompletionRequestToolMessageContent::Text(s) => Some(s.clone()),
656            ChatCompletionRequestToolMessageContent::Array(parts) => join_texts(
657                parts
658                    .iter()
659                    .map(|ChatCompletionRequestToolMessageContentPart::Text(t)| t.text.clone())
660                    .collect(),
661            ),
662        }
663    }
664
665    fn extract_tool_calls(
666        tool_calls: &Option<Vec<ChatCompletionMessageToolCalls>>,
667    ) -> Vec<super::FunctionCall> {
668        tool_calls
669            .as_ref()
670            .map(|calls| {
671                calls
672                    .iter()
673                    .map(|tc| match tc {
674                        ChatCompletionMessageToolCalls::Function(f) => (&f.function).into(),
675                        ChatCompletionMessageToolCalls::Custom(c) => super::FunctionCall {
676                            name: c.custom_tool.name.clone(),
677                            arguments: c.custom_tool.input.clone(),
678                        },
679                    })
680                    .collect()
681            })
682            .unwrap_or_default()
683    }
684
685    #[allow(deprecated)]
686    impl From<&ChatCompletionRequestMessage> for super::ChatCompletionRequestMessage {
687        fn from(m: &ChatCompletionRequestMessage) -> Self {
688            match m {
689                ChatCompletionRequestMessage::System(msg) => Self {
690                    role: "system".to_string(),
691                    name: msg.name.clone(),
692                    content: Some(system_content_text(&msg.content).unwrap_or_default()),
693                    ..Default::default()
694                },
695                ChatCompletionRequestMessage::Developer(msg) => Self {
696                    role: "developer".to_string(),
697                    name: msg.name.clone(),
698                    content: Some(developer_content_text(&msg.content).unwrap_or_default()),
699                    ..Default::default()
700                },
701                ChatCompletionRequestMessage::User(msg) => Self {
702                    role: "user".to_string(),
703                    name: msg.name.clone(),
704                    content: Some(user_content_text(&msg.content).unwrap_or_default()),
705                    ..Default::default()
706                },
707                ChatCompletionRequestMessage::Assistant(msg) => {
708                    let (content, refusal) = msg
709                        .content
710                        .as_ref()
711                        .map(assistant_content_text)
712                        .unwrap_or_default();
713                    let refusal = refusal.or_else(|| msg.refusal.clone());
714                    Self {
715                        role: "assistant".to_string(),
716                        name: msg.name.clone(),
717                        content,
718                        function_call: msg.function_call.as_ref().map(|f| f.into()),
719                        tool_calls: extract_tool_calls(&msg.tool_calls),
720                        refusal,
721                    }
722                }
723                ChatCompletionRequestMessage::Tool(msg) => Self {
724                    role: "tool".to_string(),
725                    name: Some(msg.tool_call_id.clone()),
726                    content: Some(tool_content_text(&msg.content).unwrap_or_default()),
727                    ..Default::default()
728                },
729                ChatCompletionRequestMessage::Function(msg) => Self {
730                    role: "function".to_string(),
731                    name: Some(msg.name.clone()),
732                    content: msg.content.clone(),
733                    ..Default::default()
734                },
735            }
736        }
737    }
738
739    /// Calculates the total number of tokens for the given list of messages.
740    ///
741    /// **Note:** Only text content is counted. Non-text parts (images, audio, files) are
742    /// silently skipped because they use a separate token formula based on resolution/duration,
743    /// not BPE encoding. If your messages contain non-text content, the returned count will
744    /// be lower than the actual API token usage.
745    ///
746    /// # Arguments
747    ///
748    /// * `model` - A string slice representing the name of the model.
749    /// * `messages` - A slice of `async_openai::types::ChatCompletionRequestMessage` instances.
750    ///
751    /// # Returns
752    ///
753    /// * A `Result` containing the total number of tokens (`usize`) or an error if the calculation fails.
754    pub fn num_tokens_from_messages(
755        model: &str,
756        messages: &[ChatCompletionRequestMessage],
757    ) -> Result<usize> {
758        let messages: Vec<super::ChatCompletionRequestMessage> =
759            messages.iter().map(|m| m.into()).collect();
760        super::num_tokens_from_messages(model, &messages)
761    }
762
763    /// Retrieves the maximum token limit for chat completions.
764    ///
765    /// # Arguments
766    ///
767    /// * `model` - A string slice representing the name of the model.
768    /// * `messages` - A slice of `async_openai::types::ChatCompletionRequestMessage` instances.
769    ///
770    /// # Returns
771    ///
772    /// * A `Result` containing the maximum number of tokens (`usize`) allowed for chat completions or an error if the retrieval fails.
773    pub fn get_chat_completion_max_tokens(
774        model: &str,
775        messages: &[ChatCompletionRequestMessage],
776    ) -> Result<usize> {
777        let messages: Vec<super::ChatCompletionRequestMessage> =
778            messages.iter().map(|m| m.into()).collect();
779        super::get_chat_completion_max_tokens(model, &messages)
780    }
781
#[cfg(test)]
// NOTE(review): presumably needed because the assistant message's `function_call`
// field is deprecated upstream — confirm against the async-openai version in use.
#[allow(deprecated)]
mod tests {
    use super::*;
    use async_openai::types::chat::{
        ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessage,
        ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
    };

    /// Model name shared by every test in this module.
    const MODEL: &str = "gpt-4o";

    /// Builds an unnamed system message whose content is the plain text `text`.
    fn system_message(text: &str) -> ChatCompletionRequestMessage {
        ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
            content: ChatCompletionRequestSystemMessageContent::Text(text.to_string()),
            name: None,
        })
    }

    /// Builds an unnamed user message whose content is the plain text `text`.
    fn user_message(text: &str) -> ChatCompletionRequestMessage {
        ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
            content: ChatCompletionRequestUserMessageContent::Text(text.to_string()),
            name: None,
        })
    }

    /// Builds an assistant message with no content, name, audio or function call,
    /// varying only the `refusal` and `tool_calls` fields.
    fn assistant_message(
        refusal: Option<String>,
        tool_calls: Option<Vec<ChatCompletionMessageToolCalls>>,
    ) -> ChatCompletionRequestMessage {
        ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
            content: None,
            refusal,
            name: None,
            audio: None,
            tool_calls,
            function_call: None,
        })
    }

    #[test]
    fn test_num_tokens_from_messages_system() {
        let messages = &[system_message("You are a helpful assistant.")];
        let num_tokens = num_tokens_from_messages(MODEL, messages).unwrap();
        assert!(num_tokens > 0);
    }

    #[test]
    fn test_num_tokens_from_messages_user() {
        let messages = &[user_message("Hello, how are you?")];
        let num_tokens = num_tokens_from_messages(MODEL, messages).unwrap();
        assert!(num_tokens > 0);
    }

    #[test]
    fn test_num_tokens_with_tool_calls() {
        let weather_call = ChatCompletionMessageToolCalls::Function(ChatCompletionMessageToolCall {
            id: "call_123".to_string(),
            function: FunctionCall {
                name: "get_weather".to_string(),
                arguments: r#"{"location": "Paris"}"#.to_string(),
            },
        });
        let messages = &[assistant_message(None, Some(vec![weather_call]))];
        let tokens_with = num_tokens_from_messages(MODEL, messages).unwrap();

        // Baseline: the same assistant message without any tool calls.
        let empty = &[assistant_message(None, None)];
        let tokens_without = num_tokens_from_messages(MODEL, empty).unwrap();

        assert!(
            tokens_with > tokens_without,
            "tool_calls should contribute tokens: {tokens_with} vs {tokens_without}"
        );
    }

    #[test]
    fn test_num_tokens_with_refusal() {
        let messages = &[assistant_message(
            Some("I cannot help with that request.".to_string()),
            None,
        )];
        let tokens_with = num_tokens_from_messages(MODEL, messages).unwrap();

        // Baseline: the same assistant message without a refusal.
        let empty = &[assistant_message(None, None)];
        let tokens_without = num_tokens_from_messages(MODEL, empty).unwrap();

        assert!(
            tokens_with > tokens_without,
            "refusal should contribute tokens: {tokens_with} vs {tokens_without}"
        );
    }

    #[test]
    fn test_get_chat_completion_max_tokens() {
        let messages = &[system_message("You are a helpful assistant.")];
        let max_tokens = get_chat_completion_max_tokens(MODEL, messages).unwrap();
        assert!(max_tokens > 0);
    }
}
910}