Skip to main content

rust_genai/
tokenizer.rs

1//! Local token estimation utilities (optional).
2
3use rust_genai_types::config::GenerationConfig;
4use rust_genai_types::content::{Content, FunctionCall, FunctionResponse, PartKind};
5use rust_genai_types::models::CountTokensConfig;
6use rust_genai_types::tool::{FunctionDeclaration, Schema, Tool};
7use serde_json::Value;
8
/// Estimates how many tokens a sequence of [`Content`] values would consume.
///
/// Implemented in this module by [`SimpleTokenEstimator`] (a byte-length heuristic) and,
/// when the `kitoken` feature is enabled, by `kitoken::KitokenEstimator`.
pub trait TokenEstimator {
    /// Returns the estimated number of tokens for `contents`.
    fn estimate_tokens(&self, contents: &[Content]) -> usize;
}
13
/// A very rough heuristic estimator (bytes / 4, rounded up).
///
/// Needs no model files. It counts the byte length of text, inline-data payloads, file
/// URIs, function-call / function-response names, executable code and code-execution
/// output.
///
/// The type is stateless, so it is `Copy` and all instances compare equal.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SimpleTokenEstimator;
17
18impl TokenEstimator for SimpleTokenEstimator {
19    fn estimate_tokens(&self, contents: &[Content]) -> usize {
20        let mut bytes = 0usize;
21        for content in contents {
22            for part in &content.parts {
23                match &part.kind {
24                    PartKind::Text { text } => {
25                        bytes += text.len();
26                    }
27                    PartKind::InlineData { inline_data } => {
28                        bytes += inline_data.data.len();
29                    }
30                    PartKind::FileData { file_data } => {
31                        bytes += file_data.file_uri.len();
32                    }
33                    PartKind::FunctionCall { function_call } => {
34                        if let Some(name) = &function_call.name {
35                            bytes += name.len();
36                        }
37                    }
38                    PartKind::FunctionResponse { function_response } => {
39                        if let Some(name) = &function_response.name {
40                            bytes += name.len();
41                        }
42                    }
43                    PartKind::ExecutableCode { executable_code } => {
44                        bytes += executable_code.code.len();
45                    }
46                    PartKind::CodeExecutionResult {
47                        code_execution_result,
48                    } => {
49                        if let Some(output) = &code_execution_result.output {
50                            bytes += output.len();
51                        }
52                    }
53                }
54            }
55        }
56        // Rough heuristic: 1 token ~ 4 bytes.
57        bytes.div_ceil(4)
58    }
59}
60
61pub(crate) fn build_estimation_contents(
62    contents: &[Content],
63    config: &CountTokensConfig,
64) -> Vec<Content> {
65    let mut combined = Vec::with_capacity(contents.len() + 1);
66    combined.extend_from_slice(contents);
67    if let Some(system_instruction) = &config.system_instruction {
68        combined.push(system_instruction.clone());
69    }
70
71    let mut accumulator = TextAccumulator::default();
72    accumulator.add_function_texts_from_contents(&combined);
73    if let Some(tools) = &config.tools {
74        accumulator.add_tools(tools);
75    }
76    if let Some(generation_config) = &config.generation_config {
77        accumulator.add_generation_config(generation_config);
78    }
79    combined.extend(accumulator.into_contents());
80    combined
81}
82
/// Collects auxiliary strings that should count toward a token estimate even though they
/// are not plain text parts. These are:
/// - function names and JSON arguments/responses;
/// - tool declarations;
/// - response schemas.
#[derive(Debug, Default)]
struct TextAccumulator {
    /// Harvested strings, in insertion order; empty strings are never stored.
    texts: Vec<String>,
}
87
88impl TextAccumulator {
89    fn add_function_texts_from_contents(&mut self, contents: &[Content]) {
90        for content in contents {
91            self.add_function_texts_from_content(content);
92        }
93    }
94
95    fn add_function_texts_from_content(&mut self, content: &Content) {
96        for part in &content.parts {
97            match &part.kind {
98                PartKind::FunctionCall { function_call } => {
99                    self.add_function_call(function_call);
100                }
101                PartKind::FunctionResponse { function_response } => {
102                    self.add_function_response(function_response);
103                }
104                _ => {}
105            }
106        }
107    }
108
109    fn add_function_call(&mut self, function_call: &FunctionCall) {
110        if let Some(name) = &function_call.name {
111            self.push_text(name);
112        }
113        if let Some(args) = &function_call.args {
114            self.add_json(args);
115        }
116    }
117
118    fn add_function_response(&mut self, function_response: &FunctionResponse) {
119        if let Some(name) = &function_response.name {
120            self.push_text(name);
121        }
122        if let Some(response) = &function_response.response {
123            self.add_json(response);
124        }
125    }
126
127    fn add_tools(&mut self, tools: &[Tool]) {
128        for tool in tools {
129            if let Some(functions) = &tool.function_declarations {
130                for function in functions {
131                    self.add_function_declaration(function);
132                }
133            }
134        }
135    }
136
137    fn add_function_declaration(&mut self, declaration: &FunctionDeclaration) {
138        self.push_text(&declaration.name);
139        if let Some(description) = &declaration.description {
140            self.push_text(description);
141        }
142        if let Some(parameters) = &declaration.parameters {
143            self.add_schema(parameters);
144        }
145        if let Some(response) = &declaration.response {
146            self.add_schema(response);
147        }
148        if let Some(parameters_json) = &declaration.parameters_json_schema {
149            self.add_json(parameters_json);
150        }
151        if let Some(response_json) = &declaration.response_json_schema {
152            self.add_json(response_json);
153        }
154    }
155
156    fn add_generation_config(&mut self, generation_config: &GenerationConfig) {
157        if let Some(response_schema) = &generation_config.response_schema {
158            self.add_schema(response_schema);
159        }
160        if let Some(response_json_schema) = &generation_config.response_json_schema {
161            self.add_json(response_json_schema);
162        }
163    }
164
165    fn add_schema(&mut self, schema: &Schema) {
166        if let Some(title) = &schema.title {
167            self.push_text(title);
168        }
169        if let Some(format) = &schema.format {
170            self.push_text(format);
171        }
172        if let Some(description) = &schema.description {
173            self.push_text(description);
174        }
175        if let Some(enum_values) = &schema.enum_values {
176            for value in enum_values {
177                self.push_text(value);
178            }
179        }
180        if let Some(required) = &schema.required {
181            for value in required {
182                self.push_text(value);
183            }
184        }
185        if let Some(properties) = &schema.properties {
186            for (key, value) in properties {
187                self.push_text(key);
188                self.add_schema(value);
189            }
190        }
191        if let Some(items) = &schema.items {
192            self.add_schema(items);
193        }
194        if let Some(any_of) = &schema.any_of {
195            for schema in any_of {
196                self.add_schema(schema);
197            }
198        }
199        if let Some(example) = &schema.example {
200            self.add_json(example);
201        }
202        if let Some(default) = &schema.default {
203            self.add_json(default);
204        }
205    }
206
207    fn add_json(&mut self, value: &Value) {
208        match value {
209            Value::String(value) => self.push_text(value),
210            Value::Array(values) => {
211                for item in values {
212                    self.add_json(item);
213                }
214            }
215            Value::Object(map) => {
216                for (key, value) in map {
217                    self.push_text(key);
218                    self.add_json(value);
219                }
220            }
221            _ => {}
222        }
223    }
224
225    fn push_text(&mut self, value: &str) {
226        if !value.is_empty() {
227            self.texts.push(value.to_string());
228        }
229    }
230
231    fn into_contents(self) -> Vec<Content> {
232        self.texts.into_iter().map(Content::text).collect()
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use rust_genai_types::config::GenerationConfig;
    use rust_genai_types::content::{FunctionCall, FunctionResponse, Part, Role};
    use rust_genai_types::models::CountTokensConfig;
    use rust_genai_types::tool::{FunctionDeclaration, Schema, Tool};
    use serde_json::json;
    use std::collections::HashMap;

    /// Returns the text of every `Text` part across `contents`, in order.
    fn collect_texts(contents: &[Content]) -> Vec<String> {
        contents
            .iter()
            .flat_map(|content| content.parts.iter())
            .filter_map(|part| match &part.kind {
                PartKind::Text { text } => Some(text.clone()),
                _ => None,
            })
            .collect()
    }

    #[test]
    fn simple_token_estimator_counts_various_parts() {
        let call = FunctionCall {
            id: Some("call-1".into()),
            name: Some("lookup".into()),
            args: Some(json!({"q": "rust"})),
            partial_args: None,
            will_continue: None,
        };
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: None,
            id: Some("resp-1".into()),
            name: Some("lookup".into()),
            response: Some(json!({"ok": true})),
        };
        let content = Content::from_parts(
            vec![
                Part::text("hello"),
                Part::inline_data(vec![0, 1, 2, 3], "image/png"),
                Part::file_data("files/abc", "application/pdf"),
                Part::function_call(call),
                Part::function_response(response),
                Part::executable_code("print('hi')", rust_genai_types::enums::Language::Python),
                Part::code_execution_result(rust_genai_types::enums::Outcome::OutcomeOk, "ok"),
            ],
            Role::User,
        );

        let estimator = SimpleTokenEstimator;
        let tokens = estimator.estimate_tokens(&[content]);
        assert!(tokens > 0);
    }

    #[test]
    fn simple_token_estimator_rounds_up_partial_tokens() {
        let estimator = SimpleTokenEstimator;
        assert_eq!(estimator.estimate_tokens(&[Content::text("abcd")]), 1);
        assert_eq!(estimator.estimate_tokens(&[Content::text("abcde")]), 2);
    }

    #[test]
    fn build_estimation_contents_includes_tools_and_config() {
        let declaration = FunctionDeclaration {
            name: "search".to_string(),
            description: Some("desc".to_string()),
            parameters: Some(
                Schema::object()
                    .property("q", Schema::string())
                    .required("q")
                    .build(),
            ),
            parameters_json_schema: Some(
                json!({"type": "object", "properties": {"q": {"type": "string"}}}),
            ),
            response: Some(Schema::string()),
            response_json_schema: Some(json!({"type": "string"})),
            behavior: None,
        };
        let tool = Tool {
            function_declarations: Some(vec![declaration]),
            ..Default::default()
        };
        let generation_config = GenerationConfig {
            response_schema: Some(Schema::object().property("r", Schema::string()).build()),
            response_json_schema: Some(
                json!({"type": "object", "properties": {"r": {"type": "string"}}}),
            ),
            ..Default::default()
        };

        let config = CountTokensConfig {
            system_instruction: Some(Content::text("sys")),
            tools: Some(vec![tool]),
            generation_config: Some(generation_config),
        };

        let contents = vec![Content::text("user")];
        let combined = build_estimation_contents(&contents, &config);
        // Original contents + system instruction + appended text contents.
        assert!(combined.len() > 2);

        let texts = collect_texts(&combined);
        assert_eq!(texts[0], "user");
        assert_eq!(texts[1], "sys");
        // Strings harvested from the tool declaration and the generation config.
        for expected in ["search", "desc", "q", "r", "properties", "object", "string"] {
            assert!(
                texts.iter().any(|text| text == expected),
                "missing harvested text {expected:?}"
            );
        }
    }

    #[test]
    fn build_estimation_contents_without_config_only_copies_inputs() {
        let config = CountTokensConfig {
            system_instruction: None,
            tools: None,
            generation_config: None,
        };
        let contents = vec![Content::text("only")];
        let combined = build_estimation_contents(&contents, &config);
        assert_eq!(combined.len(), 1);
        assert_eq!(collect_texts(&combined), vec!["only".to_string()]);
    }

    #[test]
    fn build_estimation_contents_harvests_system_instruction_function_calls() {
        let call = FunctionCall {
            id: None,
            name: Some("tool_fn".into()),
            args: Some(json!({"arg": "val"})),
            partial_args: None,
            will_continue: None,
        };
        let config = CountTokensConfig {
            system_instruction: Some(Content::from_parts(
                vec![Part::function_call(call)],
                Role::User,
            )),
            tools: None,
            generation_config: None,
        };
        let combined = build_estimation_contents(&[], &config);
        assert_eq!(
            collect_texts(&combined),
            vec!["tool_fn".to_string(), "arg".to_string(), "val".to_string()]
        );
    }

    #[test]
    fn text_accumulator_collects_schema_and_json_fields() {
        let mut properties = HashMap::new();
        properties.insert("prop".to_string(), Box::new(Schema::string()));
        let schema = Schema {
            title: Some("Title".into()),
            format: Some("Fmt".into()),
            description: Some("Desc".into()),
            enum_values: Some(vec!["A".into(), "B".into()]),
            required: Some(vec!["req".into()]),
            properties: Some(properties),
            items: Some(Box::new(Schema::number())),
            any_of: Some(vec![Schema::boolean()]),
            example: Some(json!({"ex_key": "ex_val"})),
            default: Some(json!(["d"])),
            ..Default::default()
        };

        let mut accumulator = TextAccumulator::default();
        accumulator.add_schema(&schema);
        accumulator.add_json(&json!(["a", {"k": "v"}, 1]));
        let texts = accumulator.texts;

        assert!(texts.contains(&"Title".to_string()));
        assert!(texts.contains(&"Fmt".to_string()));
        assert!(texts.contains(&"Desc".to_string()));
        assert!(texts.contains(&"A".to_string()));
        assert!(texts.contains(&"B".to_string()));
        assert!(texts.contains(&"req".to_string()));
        assert!(texts.contains(&"prop".to_string()));
        assert!(texts.contains(&"ex_key".to_string()));
        assert!(texts.contains(&"ex_val".to_string()));
        assert!(texts.contains(&"k".to_string()));
        assert!(texts.contains(&"v".to_string()));
        assert!(texts.contains(&"a".to_string()));
    }

    #[test]
    fn text_accumulator_collects_function_parts() {
        let call = FunctionCall {
            id: None,
            name: None,
            args: Some(json!({"q": "rust"})),
            partial_args: None,
            will_continue: None,
        };
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: None,
            id: None,
            name: None,
            response: Some(json!({"answer": "ok"})),
        };
        let content = Content::from_parts(
            vec![Part::function_call(call), Part::function_response(response)],
            Role::User,
        );

        let mut accumulator = TextAccumulator::default();
        accumulator.add_function_texts_from_content(&content);
        let texts = accumulator.texts;

        assert!(texts.contains(&"q".to_string()));
        assert!(texts.contains(&"rust".to_string()));
        assert!(texts.contains(&"answer".to_string()));
        assert!(texts.contains(&"ok".to_string()));
    }

    #[test]
    fn text_accumulator_collects_named_parts_and_declarations() {
        let call = FunctionCall {
            id: None,
            name: Some("lookup".into()),
            args: Some(json!({"k": "v"})),
            partial_args: None,
            will_continue: None,
        };
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: None,
            id: None,
            name: Some("lookup_result".into()),
            response: Some(json!({"out": "done"})),
        };
        let content = Content::from_parts(
            vec![Part::function_call(call), Part::function_response(response)],
            Role::User,
        );

        let declaration = FunctionDeclaration {
            name: "search".to_string(),
            description: Some("desc".to_string()),
            parameters: Some(Schema::object().property("q", Schema::string()).build()),
            parameters_json_schema: Some(
                json!({"type": "object", "properties": {"q": {"type": "string"}}}),
            ),
            response: Some(Schema::string()),
            response_json_schema: Some(json!({"type": "string"})),
            behavior: None,
        };

        let generation_config = GenerationConfig {
            response_schema: Some(Schema::string()),
            response_json_schema: Some(json!({"type": "string"})),
            ..Default::default()
        };

        let mut accumulator = TextAccumulator::default();
        accumulator.add_function_texts_from_content(&content);
        accumulator.add_function_declaration(&declaration);
        accumulator.add_generation_config(&generation_config);
        let texts = accumulator.texts;

        assert!(texts.contains(&"lookup".to_string()));
        assert!(texts.contains(&"lookup_result".to_string()));
        assert!(texts.contains(&"k".to_string()));
        assert!(texts.contains(&"v".to_string()));
        assert!(texts.contains(&"out".to_string()));
        assert!(texts.contains(&"done".to_string()));
        assert!(texts.contains(&"search".to_string()));
        assert!(texts.contains(&"desc".to_string()));
        assert!(texts.contains(&"q".to_string()));
    }

    #[test]
    fn simple_token_estimator_counts_function_names() {
        let call = FunctionCall {
            id: None,
            name: Some("ping".into()),
            args: None,
            partial_args: None,
            will_continue: None,
        };
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: None,
            id: None,
            name: Some("pong".into()),
            response: None,
        };
        let content = Content::from_parts(
            vec![Part::function_call(call), Part::function_response(response)],
            Role::User,
        );

        let estimator = SimpleTokenEstimator;
        let tokens = estimator.estimate_tokens(&[content]);
        assert_eq!(tokens, 2);
    }

    #[test]
    fn simple_token_estimator_empty_is_zero() {
        let estimator = SimpleTokenEstimator;
        let tokens = estimator.estimate_tokens(&[]);
        assert_eq!(tokens, 0);
    }
}
483
484#[cfg(feature = "kitoken")]
485pub mod kitoken {
486    use super::TokenEstimator;
487    use base64::engine::general_purpose::STANDARD;
488    use base64::Engine as _;
489    use kitoken::convert::ConversionError;
490    use kitoken::EncodeError;
491    use kitoken::Kitoken;
492    use rust_genai_types::content::{
493        Content, FunctionCall, FunctionResponse, Part, PartKind, Role,
494    };
495    use rust_genai_types::models::{ComputeTokensResponse, TokensInfo};
496    use sha2::{Digest, Sha256};
497    use std::collections::HashMap;
498    use std::fmt::Write;
499    use std::fs;
500    use std::path::{Path, PathBuf};
501    use std::sync::Arc;
502
    /// Subdirectory of the system temp dir in which downloaded tokenizer models are cached.
    const CACHE_DIR: &str = "vertexai_tokenizer_model";
504
    /// Where to download a tokenizer model and the hash its bytes must match.
    struct TokenizerConfig {
        /// URL the `SentencePiece` model file is downloaded from.
        model_url: &'static str,
        /// Expected hex SHA-256 of the downloaded model bytes (checked in `load_model_bytes`).
        model_hash: &'static str,
    }
509
    /// Gemini model names (exact match) mapped to the Gemma tokenizer they use.
    ///
    /// `get_tokenizer_name` consults this table before
    /// `GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES`; the tokenizer names must be keys accepted
    /// by `tokenizer_config`.
    const GEMINI_MODELS_TO_TOKENIZER_NAMES: &[(&str, &str)] = &[
        ("gemini-1.0-pro", "gemma2"),
        ("gemini-1.5-pro", "gemma2"),
        ("gemini-1.5-flash", "gemma2"),
        ("gemini-2.5-pro", "gemma3"),
        ("gemini-2.5-flash", "gemma3"),
        ("gemini-2.5-flash-lite", "gemma3"),
        ("gemini-2.0-flash", "gemma3"),
        ("gemini-2.0-flash-lite", "gemma3"),
        ("gemini-3-flash-preview", "gemma3"),
        ("gemini-3.1-flash-lite-preview", "gemma3"),
        ("gemini-3.1-flash-image-preview", "gemma3"),
        ("gemini-3.1-pro-preview", "gemma3"),
        ("gemini-3-pro-preview", "gemma3"),
        ("gemini-3-pro-image-preview", "gemma3"),
    ];
526
    /// Additional, mostly version-suffixed Gemini model names mapped to their Gemma
    /// tokenizer.
    ///
    /// Only consulted by `get_tokenizer_name` when the name is absent from
    /// `GEMINI_MODELS_TO_TOKENIZER_NAMES`.
    ///
    /// NOTE(review): "gemini-3-pro-preview" also appears in that first table with the same
    /// tokenizer, so the entry here is never reached. It is harmless, because the
    /// supported-model list is de-duplicated.
    const GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES: &[(&str, &str)] = &[
        ("gemini-1.0-pro-001", "gemma2"),
        ("gemini-1.0-pro-002", "gemma2"),
        ("gemini-1.5-pro-001", "gemma2"),
        ("gemini-1.5-pro-002", "gemma2"),
        ("gemini-1.5-flash-001", "gemma2"),
        ("gemini-1.5-flash-002", "gemma2"),
        ("gemini-2.5-pro-preview-06-05", "gemma3"),
        ("gemini-2.5-pro-preview-05-06", "gemma3"),
        ("gemini-2.5-pro-exp-03-25", "gemma3"),
        ("gemini-live-2.5-flash", "gemma3"),
        ("gemini-3.1-flash-live-preview", "gemma3"),
        ("gemini-3.1-flash-tts-preview", "gemma3"),
        ("gemini-2.5-flash-native-audio-preview-12-2025", "gemma3"),
        ("gemini-2.5-flash-native-audio-preview-09-2025", "gemma3"),
        ("gemini-2.5-flash-preview-05-20", "gemma3"),
        ("gemini-2.5-flash-preview-04-17", "gemma3"),
        ("gemini-2.5-flash-lite-preview-06-17", "gemma3"),
        ("gemini-2.0-flash-001", "gemma3"),
        ("gemini-2.0-flash-lite-001", "gemma3"),
        ("gemini-3-pro-preview", "gemma3"),
    ];
549
550    fn tokenizer_config(name: &str) -> Option<TokenizerConfig> {
551        match name {
552            "gemma2" => Some(TokenizerConfig {
553                model_url: "https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
554                model_hash: "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2",
555            }),
556            "gemma3" => Some(TokenizerConfig {
557                model_url: "https://raw.githubusercontent.com/google/gemma_pytorch/014acb7ac4563a5f77c76d7ff98f31b568c16508/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
558                model_hash: "1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
559            }),
560            _ => None,
561        }
562    }
563
    /// Errors produced by the local, `kitoken`-backed tokenizer.
    #[derive(Debug, thiserror::Error)]
    pub enum LocalTokenizerError {
        /// The model name is in neither model-to-tokenizer table; `supported` is a sorted,
        /// de-duplicated, comma-separated list of known model names.
        #[error("Model {model} is not supported. Supported models: {supported}")]
        UnsupportedModel { model: String, supported: String },
        /// A tokenizer name has no download configuration in `tokenizer_config`.
        #[error("Tokenizer {name} is not supported")]
        UnsupportedTokenizer { name: String },
        /// The HTTP request for the tokenizer model failed.
        #[error("Failed to download tokenizer model from {url}: {source}")]
        Download {
            url: String,
            #[source]
            source: reqwest::Error,
        },
        /// The model download answered with a non-success HTTP status code.
        #[error("Tokenizer model download returned non-success status {status} for {url}")]
        DownloadStatus { url: String, status: u16 },
        /// The downloaded model bytes do not match the pinned SHA-256.
        #[error("Tokenizer model hash mismatch. expected {expected}, got {actual}")]
        HashMismatch { expected: String, actual: String },
        /// An I/O error, converted from `std::io::Error`.
        #[error("IO error: {source}")]
        Io {
            #[from]
            source: std::io::Error,
        },
        /// The tokenizer failed to encode a text.
        #[error("Tokenizer encode error: {source}")]
        Encode {
            #[from]
            source: EncodeError,
        },
        /// The part kind (`inline_data` or `file_data`) cannot be tokenized locally.
        #[error("Local tokenizer does not support non-text content: {kind}")]
        UnsupportedContent { kind: &'static str },
        /// An encoded token id is absent from the id -> bytes table built at load time.
        #[error("Tokenizer token id {id} not found in vocabulary")]
        MissingToken { id: u32 },
        /// Converting a model into a `Kitoken` encoder failed. Presumably raised by the
        /// `from_sentencepiece_*` constructors; confirm against the kitoken API.
        #[error("Tokenizer conversion error: {source}")]
        Conversion {
            #[from]
            source: ConversionError,
        },
    }
600
    /// Kitoken-based estimator (`SentencePiece` compatible).
    ///
    /// Cloning is cheap: the encoder and the token-bytes table are shared through `Arc`.
    #[derive(Debug, Clone)]
    pub struct KitokenEstimator {
        /// The loaded tokenizer.
        encoder: Arc<Kitoken>,
        /// Token id -> token bytes, with `▁` replaced by a space whenever the bytes are valid
        /// UTF-8. Used to report token strings from `compute_tokens`.
        token_bytes: Arc<HashMap<u32, Vec<u8>>>,
    }
607
608    impl KitokenEstimator {
609        fn from_encoder(encoder: Kitoken) -> Self {
610            let token_bytes = Arc::new(build_token_bytes_map(&encoder));
611            Self {
612                encoder: Arc::new(encoder),
613                token_bytes,
614            }
615        }
616
617        /// Load a `SentencePiece` model from file.
618        ///
619        /// # Errors
620        /// 当模型加载失败或文件无效时返回错误。
621        pub fn from_sentencepiece_file(
622            path: impl AsRef<Path>,
623        ) -> Result<Self, LocalTokenizerError> {
624            let encoder = Kitoken::from_sentencepiece_file(path)?;
625            Ok(Self::from_encoder(encoder))
626        }
627
628        /// Load a known Gemini model tokenizer by model name (downloads & caches).
629        ///
630        /// # Errors
631        /// 当模型名未知、下载失败或解析失败时返回错误。
632        pub async fn from_model_name(model_name: &str) -> Result<Self, LocalTokenizerError> {
633            let tokenizer_name = get_tokenizer_name(model_name)?;
634            let config = tokenizer_config(tokenizer_name).ok_or_else(|| {
635                LocalTokenizerError::UnsupportedTokenizer {
636                    name: tokenizer_name.to_string(),
637                }
638            })?;
639            let model_bytes = load_model_bytes(config.model_url, config.model_hash).await?;
640            let encoder = Kitoken::from_sentencepiece_slice(&model_bytes)?;
641            Ok(Self::from_encoder(encoder))
642        }
643
644        /// Compute token ids and token bytes for text contents.
645        ///
646        /// # Errors
647        /// 当内容不受支持或编码失败时返回错误。
648        pub fn compute_tokens(
649            &self,
650            contents: &[Content],
651        ) -> Result<ComputeTokensResponse, LocalTokenizerError> {
652            let mut tokens_info: Vec<TokensInfo> = Vec::new();
653            for content in contents {
654                let role = content
655                    .role
656                    .map(|role| match role {
657                        Role::User => "user",
658                        Role::Model => "model",
659                        Role::Function => "function",
660                    })
661                    .map(ToString::to_string);
662
663                for part in &content.parts {
664                    let texts = collect_part_texts(part)?;
665                    if texts.is_empty() {
666                        continue;
667                    }
668                    let mut token_ids = Vec::new();
669                    let mut tokens = Vec::new();
670                    for text in texts {
671                        if text.is_empty() {
672                            continue;
673                        }
674                        let ids = self.encoder.encode(&text, true)?;
675                        for id in ids {
676                            let bytes = self
677                                .token_bytes
678                                .get(&id)
679                                .ok_or(LocalTokenizerError::MissingToken { id })?;
680                            tokens.push(STANDARD.encode(bytes));
681                            token_ids.push(i64::from(id));
682                        }
683                    }
684                    if token_ids.is_empty() {
685                        continue;
686                    }
687                    tokens_info.push(TokensInfo {
688                        role: role.clone(),
689                        token_ids: Some(token_ids),
690                        tokens: Some(tokens),
691                    });
692                }
693            }
694
695            Ok(ComputeTokensResponse {
696                sdk_http_response: None,
697                tokens_info: Some(tokens_info),
698            })
699        }
700    }
701
702    impl TokenEstimator for KitokenEstimator {
703        fn estimate_tokens(&self, contents: &[Content]) -> usize {
704            let mut total = 0usize;
705            for content in contents {
706                for part in &content.parts {
707                    if let Some(text) = part.text_value() {
708                        if let Ok(tokens) = self.encoder.encode(text, true) {
709                            total += tokens.len();
710                        }
711                    }
712                }
713            }
714            total
715        }
716    }
717
718    fn build_token_bytes_map(encoder: &Kitoken) -> HashMap<u32, Vec<u8>> {
719        let definition = encoder.to_definition();
720        let mut map = HashMap::new();
721        for token in definition.model.vocab() {
722            map.insert(token.id, normalize_token_bytes(&token.bytes));
723        }
724        for special in definition.specials {
725            map.insert(special.id, normalize_token_bytes(&special.bytes));
726        }
727        map
728    }
729
730    fn normalize_token_bytes(bytes: &[u8]) -> Vec<u8> {
731        std::str::from_utf8(bytes).map_or_else(
732            |_| bytes.to_vec(),
733            |text| text.replace('▁', " ").into_bytes(),
734        )
735    }
736
737    fn collect_part_texts(part: &Part) -> Result<Vec<String>, LocalTokenizerError> {
738        let mut texts = Vec::new();
739        match &part.kind {
740            PartKind::Text { text } => {
741                if !text.is_empty() {
742                    texts.push(text.clone());
743                }
744            }
745            PartKind::FunctionCall { function_call } => {
746                add_function_call_texts(function_call, &mut texts);
747            }
748            PartKind::FunctionResponse { function_response } => {
749                add_function_response_texts(function_response, &mut texts);
750            }
751            PartKind::ExecutableCode { executable_code } => {
752                if !executable_code.code.is_empty() {
753                    texts.push(executable_code.code.clone());
754                }
755            }
756            PartKind::CodeExecutionResult {
757                code_execution_result,
758            } => {
759                if let Some(output) = &code_execution_result.output {
760                    if !output.is_empty() {
761                        texts.push(output.clone());
762                    }
763                }
764            }
765            PartKind::InlineData { .. } => {
766                return Err(LocalTokenizerError::UnsupportedContent {
767                    kind: "inline_data",
768                });
769            }
770            PartKind::FileData { .. } => {
771                return Err(LocalTokenizerError::UnsupportedContent { kind: "file_data" });
772            }
773        }
774        Ok(texts)
775    }
776
777    fn add_function_call_texts(function_call: &FunctionCall, texts: &mut Vec<String>) {
778        if let Some(name) = &function_call.name {
779            if !name.is_empty() {
780                texts.push(name.clone());
781            }
782        }
783        if let Some(args) = &function_call.args {
784            add_json_texts(args, texts);
785        }
786    }
787
788    fn add_function_response_texts(function_response: &FunctionResponse, texts: &mut Vec<String>) {
789        if let Some(name) = &function_response.name {
790            if !name.is_empty() {
791                texts.push(name.clone());
792            }
793        }
794        if let Some(response) = &function_response.response {
795            add_json_texts(response, texts);
796        }
797    }
798
799    fn add_json_texts(value: &serde_json::Value, texts: &mut Vec<String>) {
800        match value {
801            serde_json::Value::String(value) if !value.is_empty() => {
802                texts.push(value.clone());
803            }
804            serde_json::Value::Array(values) => {
805                for item in values {
806                    add_json_texts(item, texts);
807                }
808            }
809            serde_json::Value::Object(map) => {
810                for (key, value) in map {
811                    if !key.is_empty() {
812                        texts.push(key.clone());
813                    }
814                    add_json_texts(value, texts);
815                }
816            }
817            _ => {}
818        }
819    }
820
821    fn get_tokenizer_name(model_name: &str) -> Result<&'static str, LocalTokenizerError> {
822        for (name, tokenizer) in GEMINI_MODELS_TO_TOKENIZER_NAMES {
823            if *name == model_name {
824                return Ok(*tokenizer);
825            }
826        }
827        for (name, tokenizer) in GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES {
828            if *name == model_name {
829                return Ok(*tokenizer);
830            }
831        }
832        let mut supported: Vec<String> = GEMINI_MODELS_TO_TOKENIZER_NAMES
833            .iter()
834            .map(|(name, _)| (*name).to_string())
835            .collect();
836        supported.extend(
837            GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES
838                .iter()
839                .map(|(name, _)| (*name).to_string()),
840        );
841        supported.sort();
842        supported.dedup();
843        Err(LocalTokenizerError::UnsupportedModel {
844            model: model_name.to_string(),
845            supported: supported.join(", "),
846        })
847    }
848
849    async fn load_model_bytes(
850        url: &str,
851        expected_hash: &str,
852    ) -> Result<Vec<u8>, LocalTokenizerError> {
853        let cache_path = cache_path_for(url);
854        if let Some(bytes) = read_cache(&cache_path, expected_hash)? {
855            return Ok(bytes);
856        }
857        let bytes = download_model(url).await?;
858        let actual_hash = sha256_hex(&bytes);
859        if actual_hash != expected_hash {
860            return Err(LocalTokenizerError::HashMismatch {
861                expected: expected_hash.to_string(),
862                actual: actual_hash,
863            });
864        }
865        let _ = write_cache(&cache_path, &bytes);
866        Ok(bytes)
867    }
868
869    fn cache_path_for(url: &str) -> PathBuf {
870        let filename = sha256_hex(url.as_bytes());
871        std::env::temp_dir().join(CACHE_DIR).join(filename)
872    }
873
874    fn read_cache(
875        path: &Path,
876        expected_hash: &str,
877    ) -> Result<Option<Vec<u8>>, LocalTokenizerError> {
878        if !path.exists() {
879            return Ok(None);
880        }
881        let bytes = fs::read(path)?;
882        if sha256_hex(&bytes) == expected_hash {
883            return Ok(Some(bytes));
884        }
885        let _ = fs::remove_file(path);
886        Ok(None)
887    }
888
889    fn write_cache(path: &Path, bytes: &[u8]) -> Result<(), LocalTokenizerError> {
890        if let Some(parent) = path.parent() {
891            fs::create_dir_all(parent)?;
892        }
893        let tmp_path = path.with_extension("tmp");
894        fs::write(&tmp_path, bytes)?;
895        fs::rename(tmp_path, path)?;
896        Ok(())
897    }
898
899    async fn download_model(url: &str) -> Result<Vec<u8>, LocalTokenizerError> {
900        let response = reqwest::get(url)
901            .await
902            .map_err(|source| LocalTokenizerError::Download {
903                url: url.to_string(),
904                source,
905            })?;
906        let status = response.status();
907        if !status.is_success() {
908            return Err(LocalTokenizerError::DownloadStatus {
909                url: url.to_string(),
910                status: status.as_u16(),
911            });
912        }
913        let bytes = response
914            .bytes()
915            .await
916            .map_err(|source| LocalTokenizerError::Download {
917                url: url.to_string(),
918                source,
919            })?;
920        Ok(bytes.to_vec())
921    }
922
923    fn sha256_hex(data: &[u8]) -> String {
924        let digest = Sha256::digest(data);
925        let mut output = String::with_capacity(digest.len() * 2);
926        for byte in digest {
927            let _ = write!(output, "{byte:02x}");
928        }
929        output
930    }
931
932    #[cfg(test)]
933    mod tests {
934        use super::*;
935        use rust_genai_types::content::{Content, FunctionCall, FunctionResponse, Part, Role};
936        use rust_genai_types::enums::{Language, Outcome};
937        use serde_json::json;
938        use std::fs;
939        use std::time::{SystemTime, UNIX_EPOCH};
940
941        fn build_test_encoder() -> Kitoken {
942            let vocab = vec![
943                kitoken::Token {
944                    id: 0,
945                    bytes: b"hi".to_vec(),
946                },
947                kitoken::Token {
948                    id: 1,
949                    bytes: b"lookup".to_vec(),
950                },
951                kitoken::Token {
952                    id: 2,
953                    bytes: b"q".to_vec(),
954                },
955                kitoken::Token {
956                    id: 3,
957                    bytes: b"rust".to_vec(),
958                },
959                kitoken::Token {
960                    id: 4,
961                    bytes: b"resp".to_vec(),
962                },
963                kitoken::Token {
964                    id: 5,
965                    bytes: b"ok".to_vec(),
966                },
967                kitoken::Token {
968                    id: 6,
969                    bytes: b"code".to_vec(),
970                },
971                kitoken::Token {
972                    id: 7,
973                    bytes: b"out".to_vec(),
974                },
975                kitoken::Token {
976                    id: 8,
977                    bytes: "\u{2581}".as_bytes().to_vec(),
978                },
979            ];
980            let specials = vec![kitoken::SpecialToken {
981                id: 99,
982                bytes: b"[UNK]".to_vec(),
983                kind: kitoken::SpecialTokenKind::Unknown,
984                ident: None,
985                score: 0.0,
986                extract: false,
987            }];
988            let model = kitoken::Model::WordPiece {
989                vocab,
990                max_word_chars: 0,
991            };
992            let config = kitoken::Configuration::default();
993            let meta = kitoken::Metadata::default();
994            Kitoken::new(model, specials, config, meta).unwrap()
995        }
996
997        fn unique_cache_key(tag: &str) -> String {
998            let nanos = SystemTime::now()
999                .duration_since(UNIX_EPOCH)
1000                .unwrap_or_default()
1001                .as_nanos();
1002            format!("test://{tag}-{nanos}")
1003        }
1004
1005        #[test]
1006        fn get_tokenizer_name_known_and_unknown() {
1007            assert_eq!(get_tokenizer_name("gemini-1.5-pro").unwrap(), "gemma2");
1008            let err = get_tokenizer_name("unknown-model").unwrap_err();
1009            match err {
1010                LocalTokenizerError::UnsupportedModel { supported, .. } => {
1011                    assert!(supported.contains("gemini-1.0-pro"));
1012                }
1013                _ => panic!("expected UnsupportedModel error"),
1014            }
1015        }
1016
1017        #[test]
1018        fn normalize_token_bytes_replaces_separator_and_handles_invalid_utf8() {
1019            let replaced = normalize_token_bytes("\u{2581}hi".as_bytes());
1020            assert_eq!(replaced, b" hi".to_vec());
1021
1022            let invalid = normalize_token_bytes(&[0xff, 0xfe]);
1023            assert_eq!(invalid, vec![0xff, 0xfe]);
1024        }
1025
1026        #[test]
1027        fn cache_roundtrip_and_mismatch_evicts() {
1028            let key = unique_cache_key("cache-roundtrip");
1029            let path = cache_path_for(&key);
1030            let _ = fs::remove_file(&path);
1031
1032            let bytes = b"cached".to_vec();
1033            write_cache(&path, &bytes).unwrap();
1034            let hash = sha256_hex(&bytes);
1035            let cached = read_cache(&path, &hash).unwrap().unwrap();
1036            assert_eq!(cached, bytes);
1037
1038            let wrong_hash = sha256_hex(b"other");
1039            let result = read_cache(&path, &wrong_hash).unwrap();
1040            assert!(result.is_none());
1041            assert!(!path.exists());
1042        }
1043
1044        #[tokio::test]
1045        async fn load_model_bytes_uses_cache() {
1046            let key = unique_cache_key("load-cache");
1047            let path = cache_path_for(&key);
1048            let _ = fs::remove_file(&path);
1049
1050            let bytes = b"model-bytes".to_vec();
1051            write_cache(&path, &bytes).unwrap();
1052            let hash = sha256_hex(&bytes);
1053
1054            let loaded = load_model_bytes(&key, &hash).await.unwrap();
1055            assert_eq!(loaded, bytes);
1056        }
1057
1058        #[test]
1059        fn collect_part_texts_rejects_binary_parts() {
1060            let inline = Part::inline_data(vec![1, 2, 3], "image/png");
1061            let err = collect_part_texts(&inline).unwrap_err();
1062            assert!(matches!(
1063                err,
1064                LocalTokenizerError::UnsupportedContent {
1065                    kind: "inline_data"
1066                }
1067            ));
1068
1069            let file = Part::file_data("files/1", "application/pdf");
1070            let err = collect_part_texts(&file).unwrap_err();
1071            assert!(matches!(
1072                err,
1073                LocalTokenizerError::UnsupportedContent { kind: "file_data" }
1074            ));
1075        }
1076
1077        #[test]
1078        fn kitoken_estimator_compute_tokens_and_map_normalization() {
1079            let encoder = build_test_encoder();
1080            let estimator = KitokenEstimator::from_encoder(encoder);
1081
1082            let call = FunctionCall {
1083                id: None,
1084                name: Some("lookup".into()),
1085                args: Some(json!({"q": "rust"})),
1086                partial_args: None,
1087                will_continue: None,
1088            };
1089            let response = FunctionResponse {
1090                will_continue: None,
1091                scheduling: None,
1092                parts: None,
1093                id: None,
1094                name: Some("resp".into()),
1095                response: Some(json!({"ok": "ok"})),
1096            };
1097            let content = Content::from_parts(
1098                vec![
1099                    Part::text("hi"),
1100                    Part::function_call(call),
1101                    Part::function_response(response),
1102                    Part::executable_code("code", Language::Python),
1103                    Part::code_execution_result(Outcome::OutcomeOk, "out"),
1104                ],
1105                Role::User,
1106            );
1107
1108            let result = estimator.compute_tokens(&[content]).unwrap();
1109            assert!(!result.tokens_info.as_ref().unwrap().is_empty());
1110
1111            let estimated = estimator.estimate_tokens(&[Content::text("hi")]);
1112            assert!(estimated > 0);
1113
1114            let normalized = estimator.token_bytes.get(&8).unwrap();
1115            assert_eq!(normalized.as_slice(), b" ");
1116        }
1117    }
1118}