Skip to main content

rust_genai/
tokenizer.rs

1//! Local token estimation utilities (optional).
2
3use rust_genai_types::config::GenerationConfig;
4use rust_genai_types::content::{Content, FunctionCall, FunctionResponse, PartKind};
5use rust_genai_types::models::CountTokensConfig;
6use rust_genai_types::tool::{FunctionDeclaration, Schema, Tool};
7use serde_json::Value;
8
/// Token estimator trait.
///
/// Implementors return an approximate token count for a slice of [`Content`].
/// This file provides a byte-length heuristic ([`SimpleTokenEstimator`]) and,
/// behind the `kitoken` feature, a `SentencePiece`-based `KitokenEstimator`.
pub trait TokenEstimator {
    /// Estimates how many tokens `contents` would consume.
    fn estimate_tokens(&self, contents: &[Content]) -> usize;
}
13
/// A very rough heuristic estimator (bytes / 4).
///
/// Sums the byte length of each part's payload (text, inline data, file URI,
/// function-call/response name, executable code, code-execution output) and
/// divides by four, rounding up. No tokenizer model is needed.
#[derive(Debug, Clone, Default)]
pub struct SimpleTokenEstimator;
17
18impl TokenEstimator for SimpleTokenEstimator {
19    fn estimate_tokens(&self, contents: &[Content]) -> usize {
20        let mut bytes = 0usize;
21        for content in contents {
22            for part in &content.parts {
23                match &part.kind {
24                    PartKind::Text { text } => {
25                        bytes += text.len();
26                    }
27                    PartKind::InlineData { inline_data } => {
28                        bytes += inline_data.data.len();
29                    }
30                    PartKind::FileData { file_data } => {
31                        bytes += file_data.file_uri.len();
32                    }
33                    PartKind::FunctionCall { function_call } => {
34                        if let Some(name) = &function_call.name {
35                            bytes += name.len();
36                        }
37                    }
38                    PartKind::FunctionResponse { function_response } => {
39                        if let Some(name) = &function_response.name {
40                            bytes += name.len();
41                        }
42                    }
43                    PartKind::ExecutableCode { executable_code } => {
44                        bytes += executable_code.code.len();
45                    }
46                    PartKind::CodeExecutionResult {
47                        code_execution_result,
48                    } => {
49                        if let Some(output) = &code_execution_result.output {
50                            bytes += output.len();
51                        }
52                    }
53                }
54            }
55        }
56        // Rough heuristic: 1 token ~ 4 bytes.
57        bytes.div_ceil(4)
58    }
59}
60
61pub(crate) fn build_estimation_contents(
62    contents: &[Content],
63    config: &CountTokensConfig,
64) -> Vec<Content> {
65    let mut combined = Vec::with_capacity(contents.len() + 1);
66    combined.extend_from_slice(contents);
67    if let Some(system_instruction) = &config.system_instruction {
68        combined.push(system_instruction.clone());
69    }
70
71    let mut accumulator = TextAccumulator::default();
72    accumulator.add_function_texts_from_contents(&combined);
73    if let Some(tools) = &config.tools {
74        accumulator.add_tools(tools);
75    }
76    if let Some(generation_config) = &config.generation_config {
77        accumulator.add_generation_config(generation_config);
78    }
79    combined.extend(accumulator.into_contents());
80    combined
81}
82
/// Collects strings for local estimation that do not appear as plain text parts:
/// function names/arguments, schema text, and JSON string values and keys.
#[derive(Debug, Default)]
struct TextAccumulator {
    /// Non-empty strings, in the order they were discovered.
    texts: Vec<String>,
}
87
88impl TextAccumulator {
89    fn add_function_texts_from_contents(&mut self, contents: &[Content]) {
90        for content in contents {
91            self.add_function_texts_from_content(content);
92        }
93    }
94
95    fn add_function_texts_from_content(&mut self, content: &Content) {
96        for part in &content.parts {
97            match &part.kind {
98                PartKind::FunctionCall { function_call } => {
99                    self.add_function_call(function_call);
100                }
101                PartKind::FunctionResponse { function_response } => {
102                    self.add_function_response(function_response);
103                }
104                _ => {}
105            }
106        }
107    }
108
109    fn add_function_call(&mut self, function_call: &FunctionCall) {
110        if let Some(name) = &function_call.name {
111            self.push_text(name);
112        }
113        if let Some(args) = &function_call.args {
114            self.add_json(args);
115        }
116    }
117
118    fn add_function_response(&mut self, function_response: &FunctionResponse) {
119        if let Some(name) = &function_response.name {
120            self.push_text(name);
121        }
122        if let Some(response) = &function_response.response {
123            self.add_json(response);
124        }
125    }
126
127    fn add_tools(&mut self, tools: &[Tool]) {
128        for tool in tools {
129            if let Some(functions) = &tool.function_declarations {
130                for function in functions {
131                    self.add_function_declaration(function);
132                }
133            }
134        }
135    }
136
137    fn add_function_declaration(&mut self, declaration: &FunctionDeclaration) {
138        self.push_text(&declaration.name);
139        if let Some(description) = &declaration.description {
140            self.push_text(description);
141        }
142        if let Some(parameters) = &declaration.parameters {
143            self.add_schema(parameters);
144        }
145        if let Some(response) = &declaration.response {
146            self.add_schema(response);
147        }
148        if let Some(parameters_json) = &declaration.parameters_json_schema {
149            self.add_json(parameters_json);
150        }
151        if let Some(response_json) = &declaration.response_json_schema {
152            self.add_json(response_json);
153        }
154    }
155
156    fn add_generation_config(&mut self, generation_config: &GenerationConfig) {
157        if let Some(response_schema) = &generation_config.response_schema {
158            self.add_schema(response_schema);
159        }
160        if let Some(response_json_schema) = &generation_config.response_json_schema {
161            self.add_json(response_json_schema);
162        }
163    }
164
165    fn add_schema(&mut self, schema: &Schema) {
166        if let Some(title) = &schema.title {
167            self.push_text(title);
168        }
169        if let Some(format) = &schema.format {
170            self.push_text(format);
171        }
172        if let Some(description) = &schema.description {
173            self.push_text(description);
174        }
175        if let Some(enum_values) = &schema.enum_values {
176            for value in enum_values {
177                self.push_text(value);
178            }
179        }
180        if let Some(required) = &schema.required {
181            for value in required {
182                self.push_text(value);
183            }
184        }
185        if let Some(properties) = &schema.properties {
186            for (key, value) in properties {
187                self.push_text(key);
188                self.add_schema(value);
189            }
190        }
191        if let Some(items) = &schema.items {
192            self.add_schema(items);
193        }
194        if let Some(any_of) = &schema.any_of {
195            for schema in any_of {
196                self.add_schema(schema);
197            }
198        }
199        if let Some(example) = &schema.example {
200            self.add_json(example);
201        }
202        if let Some(default) = &schema.default {
203            self.add_json(default);
204        }
205    }
206
207    fn add_json(&mut self, value: &Value) {
208        match value {
209            Value::String(value) => self.push_text(value),
210            Value::Array(values) => {
211                for item in values {
212                    self.add_json(item);
213                }
214            }
215            Value::Object(map) => {
216                for (key, value) in map {
217                    self.push_text(key);
218                    self.add_json(value);
219                }
220            }
221            _ => {}
222        }
223    }
224
225    fn push_text(&mut self, value: &str) {
226        if !value.is_empty() {
227            self.texts.push(value.to_string());
228        }
229    }
230
231    fn into_contents(self) -> Vec<Content> {
232        self.texts.into_iter().map(Content::text).collect()
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use rust_genai_types::config::GenerationConfig;
    use rust_genai_types::content::{FunctionCall, FunctionResponse, Part, PartKind, Role};
    use rust_genai_types::models::CountTokensConfig;
    use rust_genai_types::tool::{FunctionDeclaration, Schema, Tool};
    use serde_json::json;
    use std::collections::HashMap;

    /// Collects the text of every `PartKind::Text` part in `contents`, in order.
    fn text_parts(contents: &[Content]) -> Vec<&str> {
        contents
            .iter()
            .flat_map(|content| &content.parts)
            .filter_map(|part| match &part.kind {
                PartKind::Text { text } => Some(text.as_str()),
                _ => None,
            })
            .collect()
    }

    #[test]
    fn simple_token_estimator_counts_various_parts() {
        let call = FunctionCall {
            id: Some("call-1".into()),
            name: Some("lookup".into()),
            args: Some(json!({"q": "rust"})),
            partial_args: None,
            will_continue: None,
        };
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: None,
            id: Some("resp-1".into()),
            name: Some("lookup".into()),
            response: Some(json!({"ok": true})),
        };
        let content = Content::from_parts(
            vec![
                Part::text("hello"),
                Part::inline_data(vec![0, 1, 2, 3], "image/png"),
                Part::file_data("files/abc", "application/pdf"),
                Part::function_call(call),
                Part::function_response(response),
                Part::executable_code("print('hi')", rust_genai_types::enums::Language::Python),
                Part::code_execution_result(rust_genai_types::enums::Outcome::OutcomeOk, "ok"),
            ],
            Role::User,
        );

        let estimator = SimpleTokenEstimator;
        let tokens = estimator.estimate_tokens(&[content]);
        assert!(tokens > 0);
    }

    #[test]
    fn build_estimation_contents_includes_tools_and_config() {
        let declaration = FunctionDeclaration {
            name: "search".to_string(),
            description: Some("desc".to_string()),
            parameters: Some(
                Schema::object()
                    .property("q", Schema::string())
                    .required("q")
                    .build(),
            ),
            parameters_json_schema: Some(
                json!({"type": "object", "properties": {"q": {"type": "string"}}}),
            ),
            response: Some(Schema::string()),
            response_json_schema: Some(json!({"type": "string"})),
            behavior: None,
        };
        let tool = Tool {
            function_declarations: Some(vec![declaration]),
            ..Default::default()
        };
        let generation_config = GenerationConfig {
            response_schema: Some(Schema::object().property("r", Schema::string()).build()),
            response_json_schema: Some(
                json!({"type": "object", "properties": {"r": {"type": "string"}}}),
            ),
            ..Default::default()
        };

        let config = CountTokensConfig {
            system_instruction: Some(Content::text("sys")),
            tools: Some(vec![tool]),
            generation_config: Some(generation_config),
        };

        let contents = vec![Content::text("user")];
        let combined = build_estimation_contents(&contents, &config);
        // Original contents + system instruction + one text content per extracted string.
        // Strictly more than two proves that tool/config text was actually appended.
        assert!(combined.len() > 2);
        let texts = text_parts(&combined);
        for expected in ["user", "sys", "search", "desc", "q", "r", "object", "string"] {
            assert!(
                texts.contains(&expected),
                "missing {expected:?} in {texts:?}"
            );
        }
    }

    #[test]
    fn text_accumulator_collects_schema_and_json_fields() {
        let mut properties = HashMap::new();
        properties.insert("prop".to_string(), Box::new(Schema::string()));
        let schema = Schema {
            title: Some("Title".into()),
            format: Some("Fmt".into()),
            description: Some("Desc".into()),
            enum_values: Some(vec!["A".into(), "B".into()]),
            required: Some(vec!["req".into()]),
            properties: Some(properties),
            items: Some(Box::new(Schema::number())),
            any_of: Some(vec![Schema::boolean()]),
            example: Some(json!({"ex_key": "ex_val"})),
            default: Some(json!(["d"])),
            ..Default::default()
        };

        let mut accumulator = TextAccumulator::default();
        accumulator.add_schema(&schema);
        accumulator.add_json(&json!(["a", {"k": "v"}, 1]));
        let texts = accumulator.texts;

        assert!(texts.contains(&"Title".to_string()));
        assert!(texts.contains(&"Fmt".to_string()));
        assert!(texts.contains(&"Desc".to_string()));
        assert!(texts.contains(&"A".to_string()));
        assert!(texts.contains(&"B".to_string()));
        assert!(texts.contains(&"req".to_string()));
        assert!(texts.contains(&"prop".to_string()));
        assert!(texts.contains(&"ex_key".to_string()));
        assert!(texts.contains(&"ex_val".to_string()));
        assert!(texts.contains(&"k".to_string()));
        assert!(texts.contains(&"v".to_string()));
        assert!(texts.contains(&"a".to_string()));
    }

    #[test]
    fn text_accumulator_collects_function_parts() {
        let call = FunctionCall {
            id: None,
            name: None,
            args: Some(json!({"q": "rust"})),
            partial_args: None,
            will_continue: None,
        };
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: None,
            id: None,
            name: None,
            response: Some(json!({"answer": "ok"})),
        };
        let content = Content::from_parts(
            vec![Part::function_call(call), Part::function_response(response)],
            Role::User,
        );

        let mut accumulator = TextAccumulator::default();
        accumulator.add_function_texts_from_content(&content);
        let texts = accumulator.texts;

        assert!(texts.contains(&"q".to_string()));
        assert!(texts.contains(&"rust".to_string()));
        assert!(texts.contains(&"answer".to_string()));
        assert!(texts.contains(&"ok".to_string()));
    }

    #[test]
    fn text_accumulator_collects_named_parts_and_declarations() {
        let call = FunctionCall {
            id: None,
            name: Some("lookup".into()),
            args: Some(json!({"k": "v"})),
            partial_args: None,
            will_continue: None,
        };
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: None,
            id: None,
            name: Some("lookup_result".into()),
            response: Some(json!({"out": "done"})),
        };
        let content = Content::from_parts(
            vec![Part::function_call(call), Part::function_response(response)],
            Role::User,
        );

        let declaration = FunctionDeclaration {
            name: "search".to_string(),
            description: Some("desc".to_string()),
            parameters: Some(Schema::object().property("q", Schema::string()).build()),
            parameters_json_schema: Some(
                json!({"type": "object", "properties": {"q": {"type": "string"}}}),
            ),
            response: Some(Schema::string()),
            response_json_schema: Some(json!({"type": "string"})),
            behavior: None,
        };

        let generation_config = GenerationConfig {
            response_schema: Some(Schema::string()),
            response_json_schema: Some(json!({"type": "string"})),
            ..Default::default()
        };

        let mut accumulator = TextAccumulator::default();
        accumulator.add_function_texts_from_content(&content);
        accumulator.add_function_declaration(&declaration);
        accumulator.add_generation_config(&generation_config);
        let texts = accumulator.texts;

        assert!(texts.contains(&"lookup".to_string()));
        assert!(texts.contains(&"lookup_result".to_string()));
        assert!(texts.contains(&"k".to_string()));
        assert!(texts.contains(&"v".to_string()));
        assert!(texts.contains(&"out".to_string()));
        assert!(texts.contains(&"done".to_string()));
        assert!(texts.contains(&"search".to_string()));
        assert!(texts.contains(&"desc".to_string()));
        assert!(texts.contains(&"q".to_string()));
    }

    #[test]
    fn simple_token_estimator_counts_function_names() {
        let call = FunctionCall {
            id: None,
            name: Some("ping".into()),
            args: None,
            partial_args: None,
            will_continue: None,
        };
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: None,
            id: None,
            name: Some("pong".into()),
            response: None,
        };
        let content = Content::from_parts(
            vec![Part::function_call(call), Part::function_response(response)],
            Role::User,
        );

        let estimator = SimpleTokenEstimator;
        let tokens = estimator.estimate_tokens(&[content]);
        assert_eq!(tokens, 2);
    }

    #[test]
    fn simple_token_estimator_empty_is_zero() {
        let estimator = SimpleTokenEstimator;
        let tokens = estimator.estimate_tokens(&[]);
        assert_eq!(tokens, 0);
    }
}
483
484#[cfg(feature = "kitoken")]
485pub mod kitoken {
486    use super::TokenEstimator;
487    use base64::engine::general_purpose::STANDARD;
488    use base64::Engine as _;
489    use kitoken::convert::ConversionError;
490    use kitoken::EncodeError;
491    use kitoken::Kitoken;
492    use rust_genai_types::content::{
493        Content, FunctionCall, FunctionResponse, Part, PartKind, Role,
494    };
495    use rust_genai_types::models::{ComputeTokensResponse, TokensInfo};
496    use sha2::{Digest, Sha256};
497    use std::collections::HashMap;
498    use std::fmt::Write;
499    use std::fs;
500    use std::path::{Path, PathBuf};
501    use std::sync::Arc;
502
    /// Directory name, under [`std::env::temp_dir`], where verified tokenizer
    /// model downloads are cached (see `cache_path_for`).
    const CACHE_DIR: &str = "vertexai_tokenizer_model";
504
    /// Download location and expected digest of a `SentencePiece` tokenizer model.
    struct TokenizerConfig {
        /// URL the model file is downloaded from.
        model_url: &'static str,
        /// Hex-encoded SHA-256 of the model file; checked against downloaded bytes
        /// and against cached copies before they are used.
        model_hash: &'static str,
    }
509
    /// Gemini model aliases mapped to the tokenizer name understood by
    /// `tokenizer_config` (`"gemma2"` or `"gemma3"`).
    ///
    /// `get_tokenizer_name` searches this table before
    /// `GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES`.
    const GEMINI_MODELS_TO_TOKENIZER_NAMES: &[(&str, &str)] = &[
        ("gemini-1.0-pro", "gemma2"),
        ("gemini-1.5-pro", "gemma2"),
        ("gemini-1.5-flash", "gemma2"),
        ("gemini-2.5-pro", "gemma3"),
        ("gemini-2.5-flash", "gemma3"),
        ("gemini-2.5-flash-lite", "gemma3"),
        ("gemini-2.0-flash", "gemma3"),
        ("gemini-2.0-flash-lite", "gemma3"),
    ];
520
    /// Pinned Gemini model versions (including preview and experimental releases)
    /// mapped to the tokenizer name understood by `tokenizer_config`.
    ///
    /// `get_tokenizer_name` consults this table only after
    /// `GEMINI_MODELS_TO_TOKENIZER_NAMES` has no match.
    const GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES: &[(&str, &str)] = &[
        ("gemini-1.0-pro-001", "gemma2"),
        ("gemini-1.0-pro-002", "gemma2"),
        ("gemini-1.5-pro-001", "gemma2"),
        ("gemini-1.5-pro-002", "gemma2"),
        ("gemini-1.5-flash-001", "gemma2"),
        ("gemini-1.5-flash-002", "gemma2"),
        ("gemini-2.5-pro-preview-06-05", "gemma3"),
        ("gemini-2.5-pro-preview-05-06", "gemma3"),
        ("gemini-2.5-pro-exp-03-25", "gemma3"),
        ("gemini-live-2.5-flash", "gemma3"),
        ("gemini-2.5-flash-native-audio-preview-12-2025", "gemma3"),
        ("gemini-2.5-flash-native-audio-preview-09-2025", "gemma3"),
        ("gemini-2.5-flash-preview-05-20", "gemma3"),
        ("gemini-2.5-flash-preview-04-17", "gemma3"),
        ("gemini-2.5-flash-lite-preview-06-17", "gemma3"),
        ("gemini-2.0-flash-001", "gemma3"),
        ("gemini-2.0-flash-lite-001", "gemma3"),
        ("gemini-3-pro-preview", "gemma3"),
    ];
541
542    fn tokenizer_config(name: &str) -> Option<TokenizerConfig> {
543        match name {
544            "gemma2" => Some(TokenizerConfig {
545                model_url: "https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
546                model_hash: "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2",
547            }),
548            "gemma3" => Some(TokenizerConfig {
549                model_url: "https://raw.githubusercontent.com/google/gemma_pytorch/014acb7ac4563a5f77c76d7ff98f31b568c16508/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
550                model_hash: "1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
551            }),
552            _ => None,
553        }
554    }
555
    /// Errors from the local `SentencePiece` tokenizer: model lookup, model
    /// download and caching, model loading, and text encoding.
    #[derive(Debug, thiserror::Error)]
    pub enum LocalTokenizerError {
        /// The model name appears in neither model-to-tokenizer table;
        /// `supported` lists every known name.
        #[error("Model {model} is not supported. Supported models: {supported}")]
        UnsupportedModel { model: String, supported: String },
        /// A tokenizer name has no download configuration.
        #[error("Tokenizer {name} is not supported")]
        UnsupportedTokenizer { name: String },
        /// The HTTP client (`reqwest`) failed while fetching the model from `url`.
        #[error("Failed to download tokenizer model from {url}: {source}")]
        Download {
            url: String,
            #[source]
            source: reqwest::Error,
        },
        /// The model download answered with a non-success HTTP status code.
        #[error("Tokenizer model download returned non-success status {status} for {url}")]
        DownloadStatus { url: String, status: u16 },
        /// The downloaded model's SHA-256 digest differs from the pinned value.
        #[error("Tokenizer model hash mismatch. expected {expected}, got {actual}")]
        HashMismatch { expected: String, actual: String },
        /// A filesystem operation failed (e.g. reading or writing the model cache).
        #[error("IO error: {source}")]
        Io {
            #[from]
            source: std::io::Error,
        },
        /// kitoken failed to encode input text.
        #[error("Tokenizer encode error: {source}")]
        Encode {
            #[from]
            source: EncodeError,
        },
        /// `compute_tokens` met a part it cannot tokenize locally
        /// (`"inline_data"` or `"file_data"`).
        #[error("Local tokenizer does not support non-text content: {kind}")]
        UnsupportedContent { kind: &'static str },
        /// An encoded token id has no entry in the precomputed token-bytes table.
        #[error("Tokenizer token id {id} not found in vocabulary")]
        MissingToken { id: u32 },
        /// kitoken could not load/convert the `SentencePiece` model.
        #[error("Tokenizer conversion error: {source}")]
        Conversion {
            #[from]
            source: ConversionError,
        },
    }
592
    /// Kitoken-based estimator (`SentencePiece` compatible).
    ///
    /// Cloning is cheap: the encoder and the token-bytes table are shared
    /// behind [`Arc`]s.
    #[derive(Debug, Clone)]
    pub struct KitokenEstimator {
        /// Loaded tokenizer used to encode text.
        encoder: Arc<Kitoken>,
        /// Bytes of every vocabulary and special token id (with `▁` turned into a
        /// space when the bytes are valid UTF-8); reported by `compute_tokens`.
        token_bytes: Arc<HashMap<u32, Vec<u8>>>,
    }
599
600    impl KitokenEstimator {
601        fn from_encoder(encoder: Kitoken) -> Self {
602            let token_bytes = Arc::new(build_token_bytes_map(&encoder));
603            Self {
604                encoder: Arc::new(encoder),
605                token_bytes,
606            }
607        }
608
609        /// Load a `SentencePiece` model from file.
610        ///
611        /// # Errors
612        /// 当模型加载失败或文件无效时返回错误。
613        pub fn from_sentencepiece_file(
614            path: impl AsRef<Path>,
615        ) -> Result<Self, LocalTokenizerError> {
616            let encoder = Kitoken::from_sentencepiece_file(path)?;
617            Ok(Self::from_encoder(encoder))
618        }
619
620        /// Load a known Gemini model tokenizer by model name (downloads & caches).
621        ///
622        /// # Errors
623        /// 当模型名未知、下载失败或解析失败时返回错误。
624        pub async fn from_model_name(model_name: &str) -> Result<Self, LocalTokenizerError> {
625            let tokenizer_name = get_tokenizer_name(model_name)?;
626            let config = tokenizer_config(tokenizer_name).ok_or_else(|| {
627                LocalTokenizerError::UnsupportedTokenizer {
628                    name: tokenizer_name.to_string(),
629                }
630            })?;
631            let model_bytes = load_model_bytes(config.model_url, config.model_hash).await?;
632            let encoder = Kitoken::from_sentencepiece_slice(&model_bytes)?;
633            Ok(Self::from_encoder(encoder))
634        }
635
636        /// Compute token ids and token bytes for text contents.
637        ///
638        /// # Errors
639        /// 当内容不受支持或编码失败时返回错误。
640        pub fn compute_tokens(
641            &self,
642            contents: &[Content],
643        ) -> Result<ComputeTokensResponse, LocalTokenizerError> {
644            let mut tokens_info: Vec<TokensInfo> = Vec::new();
645            for content in contents {
646                let role = content
647                    .role
648                    .map(|role| match role {
649                        Role::User => "user",
650                        Role::Model => "model",
651                        Role::Function => "function",
652                    })
653                    .map(ToString::to_string);
654
655                for part in &content.parts {
656                    let texts = collect_part_texts(part)?;
657                    if texts.is_empty() {
658                        continue;
659                    }
660                    let mut token_ids = Vec::new();
661                    let mut tokens = Vec::new();
662                    for text in texts {
663                        if text.is_empty() {
664                            continue;
665                        }
666                        let ids = self.encoder.encode(&text, true)?;
667                        for id in ids {
668                            let bytes = self
669                                .token_bytes
670                                .get(&id)
671                                .ok_or(LocalTokenizerError::MissingToken { id })?;
672                            tokens.push(STANDARD.encode(bytes));
673                            token_ids.push(i64::from(id));
674                        }
675                    }
676                    if token_ids.is_empty() {
677                        continue;
678                    }
679                    tokens_info.push(TokensInfo {
680                        role: role.clone(),
681                        token_ids: Some(token_ids),
682                        tokens: Some(tokens),
683                    });
684                }
685            }
686
687            Ok(ComputeTokensResponse {
688                sdk_http_response: None,
689                tokens_info: Some(tokens_info),
690            })
691        }
692    }
693
694    impl TokenEstimator for KitokenEstimator {
695        fn estimate_tokens(&self, contents: &[Content]) -> usize {
696            let mut total = 0usize;
697            for content in contents {
698                for part in &content.parts {
699                    if let Some(text) = part.text_value() {
700                        if let Ok(tokens) = self.encoder.encode(text, true) {
701                            total += tokens.len();
702                        }
703                    }
704                }
705            }
706            total
707        }
708    }
709
710    fn build_token_bytes_map(encoder: &Kitoken) -> HashMap<u32, Vec<u8>> {
711        let definition = encoder.to_definition();
712        let mut map = HashMap::new();
713        for token in definition.model.vocab() {
714            map.insert(token.id, normalize_token_bytes(&token.bytes));
715        }
716        for special in definition.specials {
717            map.insert(special.id, normalize_token_bytes(&special.bytes));
718        }
719        map
720    }
721
722    fn normalize_token_bytes(bytes: &[u8]) -> Vec<u8> {
723        std::str::from_utf8(bytes).map_or_else(
724            |_| bytes.to_vec(),
725            |text| text.replace('▁', " ").into_bytes(),
726        )
727    }
728
729    fn collect_part_texts(part: &Part) -> Result<Vec<String>, LocalTokenizerError> {
730        let mut texts = Vec::new();
731        match &part.kind {
732            PartKind::Text { text } => {
733                if !text.is_empty() {
734                    texts.push(text.clone());
735                }
736            }
737            PartKind::FunctionCall { function_call } => {
738                add_function_call_texts(function_call, &mut texts);
739            }
740            PartKind::FunctionResponse { function_response } => {
741                add_function_response_texts(function_response, &mut texts);
742            }
743            PartKind::ExecutableCode { executable_code } => {
744                if !executable_code.code.is_empty() {
745                    texts.push(executable_code.code.clone());
746                }
747            }
748            PartKind::CodeExecutionResult {
749                code_execution_result,
750            } => {
751                if let Some(output) = &code_execution_result.output {
752                    if !output.is_empty() {
753                        texts.push(output.clone());
754                    }
755                }
756            }
757            PartKind::InlineData { .. } => {
758                return Err(LocalTokenizerError::UnsupportedContent {
759                    kind: "inline_data",
760                });
761            }
762            PartKind::FileData { .. } => {
763                return Err(LocalTokenizerError::UnsupportedContent { kind: "file_data" });
764            }
765        }
766        Ok(texts)
767    }
768
769    fn add_function_call_texts(function_call: &FunctionCall, texts: &mut Vec<String>) {
770        if let Some(name) = &function_call.name {
771            if !name.is_empty() {
772                texts.push(name.clone());
773            }
774        }
775        if let Some(args) = &function_call.args {
776            add_json_texts(args, texts);
777        }
778    }
779
780    fn add_function_response_texts(function_response: &FunctionResponse, texts: &mut Vec<String>) {
781        if let Some(name) = &function_response.name {
782            if !name.is_empty() {
783                texts.push(name.clone());
784            }
785        }
786        if let Some(response) = &function_response.response {
787            add_json_texts(response, texts);
788        }
789    }
790
791    fn add_json_texts(value: &serde_json::Value, texts: &mut Vec<String>) {
792        match value {
793            serde_json::Value::String(value) => {
794                if !value.is_empty() {
795                    texts.push(value.clone());
796                }
797            }
798            serde_json::Value::Array(values) => {
799                for item in values {
800                    add_json_texts(item, texts);
801                }
802            }
803            serde_json::Value::Object(map) => {
804                for (key, value) in map {
805                    if !key.is_empty() {
806                        texts.push(key.clone());
807                    }
808                    add_json_texts(value, texts);
809                }
810            }
811            _ => {}
812        }
813    }
814
815    fn get_tokenizer_name(model_name: &str) -> Result<&'static str, LocalTokenizerError> {
816        for (name, tokenizer) in GEMINI_MODELS_TO_TOKENIZER_NAMES {
817            if *name == model_name {
818                return Ok(*tokenizer);
819            }
820        }
821        for (name, tokenizer) in GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES {
822            if *name == model_name {
823                return Ok(*tokenizer);
824            }
825        }
826        let mut supported: Vec<String> = GEMINI_MODELS_TO_TOKENIZER_NAMES
827            .iter()
828            .map(|(name, _)| (*name).to_string())
829            .collect();
830        supported.extend(
831            GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES
832                .iter()
833                .map(|(name, _)| (*name).to_string()),
834        );
835        supported.sort();
836        supported.dedup();
837        Err(LocalTokenizerError::UnsupportedModel {
838            model: model_name.to_string(),
839            supported: supported.join(", "),
840        })
841    }
842
843    async fn load_model_bytes(
844        url: &str,
845        expected_hash: &str,
846    ) -> Result<Vec<u8>, LocalTokenizerError> {
847        let cache_path = cache_path_for(url);
848        if let Some(bytes) = read_cache(&cache_path, expected_hash)? {
849            return Ok(bytes);
850        }
851        let bytes = download_model(url).await?;
852        let actual_hash = sha256_hex(&bytes);
853        if actual_hash != expected_hash {
854            return Err(LocalTokenizerError::HashMismatch {
855                expected: expected_hash.to_string(),
856                actual: actual_hash,
857            });
858        }
859        let _ = write_cache(&cache_path, &bytes);
860        Ok(bytes)
861    }
862
863    fn cache_path_for(url: &str) -> PathBuf {
864        let filename = sha256_hex(url.as_bytes());
865        std::env::temp_dir().join(CACHE_DIR).join(filename)
866    }
867
868    fn read_cache(
869        path: &Path,
870        expected_hash: &str,
871    ) -> Result<Option<Vec<u8>>, LocalTokenizerError> {
872        if !path.exists() {
873            return Ok(None);
874        }
875        let bytes = fs::read(path)?;
876        if sha256_hex(&bytes) == expected_hash {
877            return Ok(Some(bytes));
878        }
879        let _ = fs::remove_file(path);
880        Ok(None)
881    }
882
883    fn write_cache(path: &Path, bytes: &[u8]) -> Result<(), LocalTokenizerError> {
884        if let Some(parent) = path.parent() {
885            fs::create_dir_all(parent)?;
886        }
887        let tmp_path = path.with_extension("tmp");
888        fs::write(&tmp_path, bytes)?;
889        fs::rename(tmp_path, path)?;
890        Ok(())
891    }
892
893    async fn download_model(url: &str) -> Result<Vec<u8>, LocalTokenizerError> {
894        let response = reqwest::get(url)
895            .await
896            .map_err(|source| LocalTokenizerError::Download {
897                url: url.to_string(),
898                source,
899            })?;
900        let status = response.status();
901        if !status.is_success() {
902            return Err(LocalTokenizerError::DownloadStatus {
903                url: url.to_string(),
904                status: status.as_u16(),
905            });
906        }
907        let bytes = response
908            .bytes()
909            .await
910            .map_err(|source| LocalTokenizerError::Download {
911                url: url.to_string(),
912                source,
913            })?;
914        Ok(bytes.to_vec())
915    }
916
917    fn sha256_hex(data: &[u8]) -> String {
918        let digest = Sha256::digest(data);
919        let mut output = String::with_capacity(digest.len() * 2);
920        for byte in digest {
921            let _ = write!(output, "{byte:02x}");
922        }
923        output
924    }
925
926    #[cfg(test)]
927    mod tests {
928        use super::*;
929        use rust_genai_types::content::{Content, FunctionCall, FunctionResponse, Part, Role};
930        use rust_genai_types::enums::{Language, Outcome};
931        use serde_json::json;
932        use std::fs;
933        use std::time::{SystemTime, UNIX_EPOCH};
934
935        fn build_test_encoder() -> Kitoken {
936            let vocab = vec![
937                kitoken::Token {
938                    id: 0,
939                    bytes: b"hi".to_vec(),
940                },
941                kitoken::Token {
942                    id: 1,
943                    bytes: b"lookup".to_vec(),
944                },
945                kitoken::Token {
946                    id: 2,
947                    bytes: b"q".to_vec(),
948                },
949                kitoken::Token {
950                    id: 3,
951                    bytes: b"rust".to_vec(),
952                },
953                kitoken::Token {
954                    id: 4,
955                    bytes: b"resp".to_vec(),
956                },
957                kitoken::Token {
958                    id: 5,
959                    bytes: b"ok".to_vec(),
960                },
961                kitoken::Token {
962                    id: 6,
963                    bytes: b"code".to_vec(),
964                },
965                kitoken::Token {
966                    id: 7,
967                    bytes: b"out".to_vec(),
968                },
969                kitoken::Token {
970                    id: 8,
971                    bytes: "\u{2581}".as_bytes().to_vec(),
972                },
973            ];
974            let specials = vec![kitoken::SpecialToken {
975                id: 99,
976                bytes: b"[UNK]".to_vec(),
977                kind: kitoken::SpecialTokenKind::Unknown,
978                ident: None,
979                score: 0.0,
980                extract: false,
981            }];
982            let model = kitoken::Model::WordPiece {
983                vocab,
984                max_word_chars: 0,
985            };
986            let config = kitoken::Configuration::default();
987            let meta = kitoken::Metadata::default();
988            Kitoken::new(model, specials, config, meta).unwrap()
989        }
990
991        fn unique_cache_key(tag: &str) -> String {
992            let nanos = SystemTime::now()
993                .duration_since(UNIX_EPOCH)
994                .unwrap_or_default()
995                .as_nanos();
996            format!("test://{tag}-{nanos}")
997        }
998
999        #[test]
1000        fn get_tokenizer_name_known_and_unknown() {
1001            assert_eq!(get_tokenizer_name("gemini-1.5-pro").unwrap(), "gemma2");
1002            let err = get_tokenizer_name("unknown-model").unwrap_err();
1003            match err {
1004                LocalTokenizerError::UnsupportedModel { supported, .. } => {
1005                    assert!(supported.contains("gemini-1.0-pro"));
1006                }
1007                _ => panic!("expected UnsupportedModel error"),
1008            }
1009        }
1010
1011        #[test]
1012        fn normalize_token_bytes_replaces_separator_and_handles_invalid_utf8() {
1013            let replaced = normalize_token_bytes("\u{2581}hi".as_bytes());
1014            assert_eq!(replaced, b" hi".to_vec());
1015
1016            let invalid = normalize_token_bytes(&[0xff, 0xfe]);
1017            assert_eq!(invalid, vec![0xff, 0xfe]);
1018        }
1019
1020        #[test]
1021        fn cache_roundtrip_and_mismatch_evicts() {
1022            let key = unique_cache_key("cache-roundtrip");
1023            let path = cache_path_for(&key);
1024            let _ = fs::remove_file(&path);
1025
1026            let bytes = b"cached".to_vec();
1027            write_cache(&path, &bytes).unwrap();
1028            let hash = sha256_hex(&bytes);
1029            let cached = read_cache(&path, &hash).unwrap().unwrap();
1030            assert_eq!(cached, bytes);
1031
1032            let wrong_hash = sha256_hex(b"other");
1033            let result = read_cache(&path, &wrong_hash).unwrap();
1034            assert!(result.is_none());
1035            assert!(!path.exists());
1036        }
1037
1038        #[tokio::test]
1039        async fn load_model_bytes_uses_cache() {
1040            let key = unique_cache_key("load-cache");
1041            let path = cache_path_for(&key);
1042            let _ = fs::remove_file(&path);
1043
1044            let bytes = b"model-bytes".to_vec();
1045            write_cache(&path, &bytes).unwrap();
1046            let hash = sha256_hex(&bytes);
1047
1048            let loaded = load_model_bytes(&key, &hash).await.unwrap();
1049            assert_eq!(loaded, bytes);
1050        }
1051
1052        #[test]
1053        fn collect_part_texts_rejects_binary_parts() {
1054            let inline = Part::inline_data(vec![1, 2, 3], "image/png");
1055            let err = collect_part_texts(&inline).unwrap_err();
1056            assert!(matches!(
1057                err,
1058                LocalTokenizerError::UnsupportedContent {
1059                    kind: "inline_data"
1060                }
1061            ));
1062
1063            let file = Part::file_data("files/1", "application/pdf");
1064            let err = collect_part_texts(&file).unwrap_err();
1065            assert!(matches!(
1066                err,
1067                LocalTokenizerError::UnsupportedContent { kind: "file_data" }
1068            ));
1069        }
1070
1071        #[test]
1072        fn kitoken_estimator_compute_tokens_and_map_normalization() {
1073            let encoder = build_test_encoder();
1074            let estimator = KitokenEstimator::from_encoder(encoder);
1075
1076            let call = FunctionCall {
1077                id: None,
1078                name: Some("lookup".into()),
1079                args: Some(json!({"q": "rust"})),
1080                partial_args: None,
1081                will_continue: None,
1082            };
1083            let response = FunctionResponse {
1084                will_continue: None,
1085                scheduling: None,
1086                parts: None,
1087                id: None,
1088                name: Some("resp".into()),
1089                response: Some(json!({"ok": "ok"})),
1090            };
1091            let content = Content::from_parts(
1092                vec![
1093                    Part::text("hi"),
1094                    Part::function_call(call),
1095                    Part::function_response(response),
1096                    Part::executable_code("code", Language::Python),
1097                    Part::code_execution_result(Outcome::OutcomeOk, "out"),
1098                ],
1099                Role::User,
1100            );
1101
1102            let result = estimator.compute_tokens(&[content]).unwrap();
1103            assert!(!result.tokens_info.as_ref().unwrap().is_empty());
1104
1105            let estimated = estimator.estimate_tokens(&[Content::text("hi")]);
1106            assert!(estimated > 0);
1107
1108            let normalized = estimator.token_bytes.get(&8).unwrap();
1109            assert_eq!(normalized.as_slice(), b" ");
1110        }
1111    }
1112}