1use rust_genai_types::config::GenerationConfig;
4use rust_genai_types::content::{Content, FunctionCall, FunctionResponse, PartKind};
5use rust_genai_types::models::CountTokensConfig;
6use rust_genai_types::tool::{FunctionDeclaration, Schema, Tool};
7use serde_json::Value;
8
/// Estimates how many tokens a set of contents will consume, without calling
/// the remote CountTokens API.
pub trait TokenEstimator {
    /// Returns an approximate token count for `contents`.
    fn estimate_tokens(&self, contents: &[Content]) -> usize;
}
13
/// Heuristic estimator that counts the bytes of text-like part payloads and
/// assumes roughly four bytes per token (rounded up). Requires no model data.
#[derive(Debug, Clone, Default)]
pub struct SimpleTokenEstimator;
17
18impl TokenEstimator for SimpleTokenEstimator {
19 fn estimate_tokens(&self, contents: &[Content]) -> usize {
20 let mut bytes = 0usize;
21 for content in contents {
22 for part in &content.parts {
23 match &part.kind {
24 PartKind::Text { text } => {
25 bytes += text.len();
26 }
27 PartKind::InlineData { inline_data } => {
28 bytes += inline_data.data.len();
29 }
30 PartKind::FileData { file_data } => {
31 bytes += file_data.file_uri.len();
32 }
33 PartKind::FunctionCall { function_call } => {
34 if let Some(name) = &function_call.name {
35 bytes += name.len();
36 }
37 }
38 PartKind::FunctionResponse { function_response } => {
39 if let Some(name) = &function_response.name {
40 bytes += name.len();
41 }
42 }
43 PartKind::ExecutableCode { executable_code } => {
44 bytes += executable_code.code.len();
45 }
46 PartKind::CodeExecutionResult {
47 code_execution_result,
48 } => {
49 if let Some(output) = &code_execution_result.output {
50 bytes += output.len();
51 }
52 }
53 }
54 }
55 }
56 bytes.div_ceil(4)
58 }
59}
60
61pub(crate) fn build_estimation_contents(
62 contents: &[Content],
63 config: &CountTokensConfig,
64) -> Vec<Content> {
65 let mut combined = Vec::with_capacity(contents.len() + 1);
66 combined.extend_from_slice(contents);
67 if let Some(system_instruction) = &config.system_instruction {
68 combined.push(system_instruction.clone());
69 }
70
71 let mut accumulator = TextAccumulator::default();
72 accumulator.add_function_texts_from_contents(&combined);
73 if let Some(tools) = &config.tools {
74 accumulator.add_tools(tools);
75 }
76 if let Some(generation_config) = &config.generation_config {
77 accumulator.add_generation_config(generation_config);
78 }
79 combined.extend(accumulator.into_contents());
80 combined
81}
82
/// Collects the non-empty strings found in function calls/responses, tool
/// declarations, schemas, and JSON payloads, so they can be turned into
/// synthetic text contents for token estimation.
#[derive(Debug, Default)]
struct TextAccumulator {
    // Every harvested non-empty string, in traversal order.
    texts: Vec<String>,
}
87
88impl TextAccumulator {
89 fn add_function_texts_from_contents(&mut self, contents: &[Content]) {
90 for content in contents {
91 self.add_function_texts_from_content(content);
92 }
93 }
94
95 fn add_function_texts_from_content(&mut self, content: &Content) {
96 for part in &content.parts {
97 match &part.kind {
98 PartKind::FunctionCall { function_call } => {
99 self.add_function_call(function_call);
100 }
101 PartKind::FunctionResponse { function_response } => {
102 self.add_function_response(function_response);
103 }
104 _ => {}
105 }
106 }
107 }
108
109 fn add_function_call(&mut self, function_call: &FunctionCall) {
110 if let Some(name) = &function_call.name {
111 self.push_text(name);
112 }
113 if let Some(args) = &function_call.args {
114 self.add_json(args);
115 }
116 }
117
118 fn add_function_response(&mut self, function_response: &FunctionResponse) {
119 if let Some(name) = &function_response.name {
120 self.push_text(name);
121 }
122 if let Some(response) = &function_response.response {
123 self.add_json(response);
124 }
125 }
126
127 fn add_tools(&mut self, tools: &[Tool]) {
128 for tool in tools {
129 if let Some(functions) = &tool.function_declarations {
130 for function in functions {
131 self.add_function_declaration(function);
132 }
133 }
134 }
135 }
136
137 fn add_function_declaration(&mut self, declaration: &FunctionDeclaration) {
138 self.push_text(&declaration.name);
139 if let Some(description) = &declaration.description {
140 self.push_text(description);
141 }
142 if let Some(parameters) = &declaration.parameters {
143 self.add_schema(parameters);
144 }
145 if let Some(response) = &declaration.response {
146 self.add_schema(response);
147 }
148 if let Some(parameters_json) = &declaration.parameters_json_schema {
149 self.add_json(parameters_json);
150 }
151 if let Some(response_json) = &declaration.response_json_schema {
152 self.add_json(response_json);
153 }
154 }
155
156 fn add_generation_config(&mut self, generation_config: &GenerationConfig) {
157 if let Some(response_schema) = &generation_config.response_schema {
158 self.add_schema(response_schema);
159 }
160 if let Some(response_json_schema) = &generation_config.response_json_schema {
161 self.add_json(response_json_schema);
162 }
163 }
164
165 fn add_schema(&mut self, schema: &Schema) {
166 if let Some(title) = &schema.title {
167 self.push_text(title);
168 }
169 if let Some(format) = &schema.format {
170 self.push_text(format);
171 }
172 if let Some(description) = &schema.description {
173 self.push_text(description);
174 }
175 if let Some(enum_values) = &schema.enum_values {
176 for value in enum_values {
177 self.push_text(value);
178 }
179 }
180 if let Some(required) = &schema.required {
181 for value in required {
182 self.push_text(value);
183 }
184 }
185 if let Some(properties) = &schema.properties {
186 for (key, value) in properties {
187 self.push_text(key);
188 self.add_schema(value);
189 }
190 }
191 if let Some(items) = &schema.items {
192 self.add_schema(items);
193 }
194 if let Some(any_of) = &schema.any_of {
195 for schema in any_of {
196 self.add_schema(schema);
197 }
198 }
199 if let Some(example) = &schema.example {
200 self.add_json(example);
201 }
202 if let Some(default) = &schema.default {
203 self.add_json(default);
204 }
205 }
206
207 fn add_json(&mut self, value: &Value) {
208 match value {
209 Value::String(value) => self.push_text(value),
210 Value::Array(values) => {
211 for item in values {
212 self.add_json(item);
213 }
214 }
215 Value::Object(map) => {
216 for (key, value) in map {
217 self.push_text(key);
218 self.add_json(value);
219 }
220 }
221 _ => {}
222 }
223 }
224
225 fn push_text(&mut self, value: &str) {
226 if !value.is_empty() {
227 self.texts.push(value.to_string());
228 }
229 }
230
231 fn into_contents(self) -> Vec<Content> {
232 self.texts.into_iter().map(Content::text).collect()
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use rust_genai_types::config::GenerationConfig;
    use rust_genai_types::content::{FunctionCall, FunctionResponse, Part, Role};
    use rust_genai_types::models::CountTokensConfig;
    use rust_genai_types::tool::{FunctionDeclaration, Schema, Tool};
    use serde_json::json;
    use std::collections::HashMap;

    // Every part kind contributes bytes, so the estimate must be non-zero.
    #[test]
    fn simple_token_estimator_counts_various_parts() {
        let call = FunctionCall {
            id: Some("call-1".into()),
            name: Some("lookup".into()),
            args: Some(json!({"q": "rust"})),
            partial_args: None,
            will_continue: None,
        };
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: None,
            id: Some("resp-1".into()),
            name: Some("lookup".into()),
            response: Some(json!({"ok": true})),
        };
        let content = Content::from_parts(
            vec![
                Part::text("hello"),
                Part::inline_data(vec![0, 1, 2, 3], "image/png"),
                Part::file_data("files/abc", "application/pdf"),
                Part::function_call(call),
                Part::function_response(response),
                Part::executable_code("print('hi')", rust_genai_types::enums::Language::Python),
                Part::code_execution_result(rust_genai_types::enums::Outcome::OutcomeOk, "ok"),
            ],
            Role::User,
        );

        let estimator = SimpleTokenEstimator;
        let tokens = estimator.estimate_tokens(&[content]);
        assert!(tokens > 0);
    }

    // System instruction, tool declarations, and generation-config schemas
    // must all be folded into the combined contents.
    #[test]
    fn build_estimation_contents_includes_tools_and_config() {
        let declaration = FunctionDeclaration {
            name: "search".to_string(),
            description: Some("desc".to_string()),
            parameters: Some(
                Schema::object()
                    .property("q", Schema::string())
                    .required("q")
                    .build(),
            ),
            parameters_json_schema: Some(
                json!({"type": "object", "properties": {"q": {"type": "string"}}}),
            ),
            response: Some(Schema::string()),
            response_json_schema: Some(json!({"type": "string"})),
            behavior: None,
        };
        let tool = Tool {
            function_declarations: Some(vec![declaration]),
            ..Default::default()
        };
        let generation_config = GenerationConfig {
            response_schema: Some(Schema::object().property("r", Schema::string()).build()),
            response_json_schema: Some(
                json!({"type": "object", "properties": {"r": {"type": "string"}}}),
            ),
            ..Default::default()
        };

        let config = CountTokensConfig {
            system_instruction: Some(Content::text("sys")),
            tools: Some(vec![tool]),
            generation_config: Some(generation_config),
        };

        let contents = vec![Content::text("user")];
        let combined = build_estimation_contents(&contents, &config);
        // At least the original content plus the system instruction.
        assert!(combined.len() >= 2);
    }

    // Exercises every schema field plus nested JSON traversal.
    #[test]
    fn text_accumulator_collects_schema_and_json_fields() {
        let mut properties = HashMap::new();
        properties.insert("prop".to_string(), Box::new(Schema::string()));
        let schema = Schema {
            title: Some("Title".into()),
            format: Some("Fmt".into()),
            description: Some("Desc".into()),
            enum_values: Some(vec!["A".into(), "B".into()]),
            required: Some(vec!["req".into()]),
            properties: Some(properties),
            items: Some(Box::new(Schema::number())),
            any_of: Some(vec![Schema::boolean()]),
            example: Some(json!({"ex_key": "ex_val"})),
            default: Some(json!(["d"])),
            ..Default::default()
        };

        let mut accumulator = TextAccumulator::default();
        accumulator.add_schema(&schema);
        accumulator.add_json(&json!(["a", {"k": "v"}, 1]));
        let texts = accumulator.texts;

        assert!(texts.contains(&"Title".to_string()));
        assert!(texts.contains(&"Fmt".to_string()));
        assert!(texts.contains(&"Desc".to_string()));
        assert!(texts.contains(&"A".to_string()));
        assert!(texts.contains(&"B".to_string()));
        assert!(texts.contains(&"req".to_string()));
        assert!(texts.contains(&"prop".to_string()));
        assert!(texts.contains(&"ex_key".to_string()));
        assert!(texts.contains(&"ex_val".to_string()));
        assert!(texts.contains(&"k".to_string()));
        assert!(texts.contains(&"v".to_string()));
        assert!(texts.contains(&"a".to_string()));
    }

    // JSON args/responses are traversed even when names are absent.
    #[test]
    fn text_accumulator_collects_function_parts() {
        let call = FunctionCall {
            id: None,
            name: None,
            args: Some(json!({"q": "rust"})),
            partial_args: None,
            will_continue: None,
        };
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: None,
            id: None,
            name: None,
            response: Some(json!({"answer": "ok"})),
        };
        let content = Content::from_parts(
            vec![Part::function_call(call), Part::function_response(response)],
            Role::User,
        );

        let mut accumulator = TextAccumulator::default();
        accumulator.add_function_texts_from_content(&content);
        let texts = accumulator.texts;

        assert!(texts.contains(&"q".to_string()));
        assert!(texts.contains(&"rust".to_string()));
        assert!(texts.contains(&"answer".to_string()));
        assert!(texts.contains(&"ok".to_string()));
    }

    // Names, declaration metadata, and generation-config schemas are all
    // accumulated alongside JSON payload strings.
    #[test]
    fn text_accumulator_collects_named_parts_and_declarations() {
        let call = FunctionCall {
            id: None,
            name: Some("lookup".into()),
            args: Some(json!({"k": "v"})),
            partial_args: None,
            will_continue: None,
        };
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: None,
            id: None,
            name: Some("lookup_result".into()),
            response: Some(json!({"out": "done"})),
        };
        let content = Content::from_parts(
            vec![Part::function_call(call), Part::function_response(response)],
            Role::User,
        );

        let declaration = FunctionDeclaration {
            name: "search".to_string(),
            description: Some("desc".to_string()),
            parameters: Some(Schema::object().property("q", Schema::string()).build()),
            parameters_json_schema: Some(
                json!({"type": "object", "properties": {"q": {"type": "string"}}}),
            ),
            response: Some(Schema::string()),
            response_json_schema: Some(json!({"type": "string"})),
            behavior: None,
        };

        let generation_config = GenerationConfig {
            response_schema: Some(Schema::string()),
            response_json_schema: Some(json!({"type": "string"})),
            ..Default::default()
        };

        let mut accumulator = TextAccumulator::default();
        accumulator.add_function_texts_from_content(&content);
        accumulator.add_function_declaration(&declaration);
        accumulator.add_generation_config(&generation_config);
        let texts = accumulator.texts;

        assert!(texts.contains(&"lookup".to_string()));
        assert!(texts.contains(&"lookup_result".to_string()));
        assert!(texts.contains(&"k".to_string()));
        assert!(texts.contains(&"v".to_string()));
        assert!(texts.contains(&"out".to_string()));
        assert!(texts.contains(&"done".to_string()));
        assert!(texts.contains(&"search".to_string()));
        assert!(texts.contains(&"desc".to_string()));
        assert!(texts.contains(&"q".to_string()));
    }

    // Pins the names-only behavior of the simple estimator for function
    // parts: "ping" + "pong" = 8 bytes -> ceil(8 / 4) = 2 tokens.
    #[test]
    fn simple_token_estimator_counts_function_names() {
        let call = FunctionCall {
            id: None,
            name: Some("ping".into()),
            args: None,
            partial_args: None,
            will_continue: None,
        };
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: None,
            id: None,
            name: Some("pong".into()),
            response: None,
        };
        let content = Content::from_parts(
            vec![Part::function_call(call), Part::function_response(response)],
            Role::User,
        );

        let estimator = SimpleTokenEstimator;
        let tokens = estimator.estimate_tokens(&[content]);
        assert_eq!(tokens, 2);
    }

    // No contents means no bytes and therefore zero tokens.
    #[test]
    fn simple_token_estimator_empty_is_zero() {
        let estimator = SimpleTokenEstimator;
        let tokens = estimator.estimate_tokens(&[]);
        assert_eq!(tokens, 0);
    }
}
483
484#[cfg(feature = "kitoken")]
485pub mod kitoken {
486 use super::TokenEstimator;
487 use base64::engine::general_purpose::STANDARD;
488 use base64::Engine as _;
489 use kitoken::convert::ConversionError;
490 use kitoken::EncodeError;
491 use kitoken::Kitoken;
492 use rust_genai_types::content::{
493 Content, FunctionCall, FunctionResponse, Part, PartKind, Role,
494 };
495 use rust_genai_types::models::{ComputeTokensResponse, TokensInfo};
496 use sha2::{Digest, Sha256};
497 use std::collections::HashMap;
498 use std::fmt::Write;
499 use std::fs;
500 use std::path::{Path, PathBuf};
501 use std::sync::Arc;
502
    /// Subdirectory of the system temp dir used to cache downloaded
    /// tokenizer model files.
    const CACHE_DIR: &str = "vertexai_tokenizer_model";

    /// Where to fetch a tokenizer model, and the SHA-256 hex digest its
    /// bytes must match.
    struct TokenizerConfig {
        model_url: &'static str,
        model_hash: &'static str,
    }
509
    /// Maps Gemini model aliases (unversioned / preview names) to the local
    /// tokenizer family used for token counting.
    const GEMINI_MODELS_TO_TOKENIZER_NAMES: &[(&str, &str)] = &[
        ("gemini-1.0-pro", "gemma2"),
        ("gemini-1.5-pro", "gemma2"),
        ("gemini-1.5-flash", "gemma2"),
        ("gemini-2.5-pro", "gemma3"),
        ("gemini-2.5-flash", "gemma3"),
        ("gemini-2.5-flash-lite", "gemma3"),
        ("gemini-2.0-flash", "gemma3"),
        ("gemini-2.0-flash-lite", "gemma3"),
        ("gemini-3-flash-preview", "gemma3"),
        ("gemini-3.1-flash-lite-preview", "gemma3"),
        ("gemini-3.1-flash-image-preview", "gemma3"),
        ("gemini-3.1-pro-preview", "gemma3"),
        ("gemini-3-pro-preview", "gemma3"),
        ("gemini-3-pro-image-preview", "gemma3"),
    ];
526
527 const GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES: &[(&str, &str)] = &[
528 ("gemini-1.0-pro-001", "gemma2"),
529 ("gemini-1.0-pro-002", "gemma2"),
530 ("gemini-1.5-pro-001", "gemma2"),
531 ("gemini-1.5-pro-002", "gemma2"),
532 ("gemini-1.5-flash-001", "gemma2"),
533 ("gemini-1.5-flash-002", "gemma2"),
534 ("gemini-2.5-pro-preview-06-05", "gemma3"),
535 ("gemini-2.5-pro-preview-05-06", "gemma3"),
536 ("gemini-2.5-pro-exp-03-25", "gemma3"),
537 ("gemini-live-2.5-flash", "gemma3"),
538 ("gemini-3.1-flash-live-preview", "gemma3"),
539 ("gemini-3.1-flash-tts-preview", "gemma3"),
540 ("gemini-2.5-flash-native-audio-preview-12-2025", "gemma3"),
541 ("gemini-2.5-flash-native-audio-preview-09-2025", "gemma3"),
542 ("gemini-2.5-flash-preview-05-20", "gemma3"),
543 ("gemini-2.5-flash-preview-04-17", "gemma3"),
544 ("gemini-2.5-flash-lite-preview-06-17", "gemma3"),
545 ("gemini-2.0-flash-001", "gemma3"),
546 ("gemini-2.0-flash-lite-001", "gemma3"),
547 ("gemini-3-pro-preview", "gemma3"),
548 ];
549
550 fn tokenizer_config(name: &str) -> Option<TokenizerConfig> {
551 match name {
552 "gemma2" => Some(TokenizerConfig {
553 model_url: "https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
554 model_hash: "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2",
555 }),
556 "gemma3" => Some(TokenizerConfig {
557 model_url: "https://raw.githubusercontent.com/google/gemma_pytorch/014acb7ac4563a5f77c76d7ff98f31b568c16508/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
558 model_hash: "1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
559 }),
560 _ => None,
561 }
562 }
563
    /// Errors raised while constructing or running a local tokenizer.
    #[derive(Debug, thiserror::Error)]
    pub enum LocalTokenizerError {
        /// The model name has no local tokenizer mapping.
        #[error("Model {model} is not supported. Supported models: {supported}")]
        UnsupportedModel { model: String, supported: String },
        /// A tokenizer family name has no download configuration.
        #[error("Tokenizer {name} is not supported")]
        UnsupportedTokenizer { name: String },
        /// The HTTP request for the tokenizer model failed.
        #[error("Failed to download tokenizer model from {url}: {source}")]
        Download {
            url: String,
            #[source]
            source: reqwest::Error,
        },
        /// The tokenizer model endpoint replied with a non-2xx status.
        #[error("Tokenizer model download returned non-success status {status} for {url}")]
        DownloadStatus { url: String, status: u16 },
        /// Downloaded bytes did not match the pinned SHA-256 digest.
        #[error("Tokenizer model hash mismatch. expected {expected}, got {actual}")]
        HashMismatch { expected: String, actual: String },
        /// Filesystem error while reading/writing the model cache.
        #[error("IO error: {source}")]
        Io {
            #[from]
            source: std::io::Error,
        },
        /// The kitoken encoder rejected the input text.
        #[error("Tokenizer encode error: {source}")]
        Encode {
            #[from]
            source: EncodeError,
        },
        /// Binary parts (inline/file data) cannot be tokenized locally.
        #[error("Local tokenizer does not support non-text content: {kind}")]
        UnsupportedContent { kind: &'static str },
        /// The encoder emitted an id absent from the vocabulary map.
        #[error("Tokenizer token id {id} not found in vocabulary")]
        MissingToken { id: u32 },
        /// The SentencePiece model could not be converted by kitoken.
        #[error("Tokenizer conversion error: {source}")]
        Conversion {
            #[from]
            source: ConversionError,
        },
    }
600
    /// Token estimator backed by a local `kitoken` SentencePiece encoder.
    ///
    /// Cheap to clone: both the encoder and the id -> bytes lookup are
    /// shared via `Arc`.
    #[derive(Debug, Clone)]
    pub struct KitokenEstimator {
        // The underlying local tokenizer.
        encoder: Arc<Kitoken>,
        // Maps a token id to its (normalized) byte representation, used by
        // `compute_tokens` to report token strings.
        token_bytes: Arc<HashMap<u32, Vec<u8>>>,
    }
607
    impl KitokenEstimator {
        /// Wraps a ready-made encoder, precomputing the id -> token-bytes
        /// map used by `compute_tokens`.
        fn from_encoder(encoder: Kitoken) -> Self {
            let token_bytes = Arc::new(build_token_bytes_map(&encoder));
            Self {
                encoder: Arc::new(encoder),
                token_bytes,
            }
        }

        /// Builds an estimator from a SentencePiece model file on disk.
        ///
        /// # Errors
        /// Returns an error if the file cannot be read or converted into a
        /// kitoken encoder.
        pub fn from_sentencepiece_file(
            path: impl AsRef<Path>,
        ) -> Result<Self, LocalTokenizerError> {
            let encoder = Kitoken::from_sentencepiece_file(path)?;
            Ok(Self::from_encoder(encoder))
        }

        /// Builds an estimator for a known Gemini model name, downloading
        /// (and caching) the matching tokenizer model if necessary.
        ///
        /// # Errors
        /// Returns an error when the model is unsupported, or the tokenizer
        /// model cannot be downloaded, hash-verified, or converted.
        pub async fn from_model_name(model_name: &str) -> Result<Self, LocalTokenizerError> {
            let tokenizer_name = get_tokenizer_name(model_name)?;
            let config = tokenizer_config(tokenizer_name).ok_or_else(|| {
                LocalTokenizerError::UnsupportedTokenizer {
                    name: tokenizer_name.to_string(),
                }
            })?;
            let model_bytes = load_model_bytes(config.model_url, config.model_hash).await?;
            let encoder = Kitoken::from_sentencepiece_slice(&model_bytes)?;
            Ok(Self::from_encoder(encoder))
        }

        /// Tokenizes `contents` locally, reporting per-part token ids plus
        /// base64-encoded token bytes. Parts that yield no text (or whose
        /// texts encode to no tokens) are skipped.
        ///
        /// # Errors
        /// Fails for non-text parts (inline/file data), on encode errors, or
        /// when an emitted token id is missing from the vocabulary map.
        pub fn compute_tokens(
            &self,
            contents: &[Content],
        ) -> Result<ComputeTokensResponse, LocalTokenizerError> {
            let mut tokens_info: Vec<TokensInfo> = Vec::new();
            for content in contents {
                // Stringify the role the same way the wire protocol does.
                let role = content
                    .role
                    .map(|role| match role {
                        Role::User => "user",
                        Role::Model => "model",
                        Role::Function => "function",
                    })
                    .map(ToString::to_string);

                for part in &content.parts {
                    let texts = collect_part_texts(part)?;
                    if texts.is_empty() {
                        continue;
                    }
                    let mut token_ids = Vec::new();
                    let mut tokens = Vec::new();
                    for text in texts {
                        if text.is_empty() {
                            continue;
                        }
                        let ids = self.encoder.encode(&text, true)?;
                        for id in ids {
                            let bytes = self
                                .token_bytes
                                .get(&id)
                                .ok_or(LocalTokenizerError::MissingToken { id })?;
                            tokens.push(STANDARD.encode(bytes));
                            token_ids.push(i64::from(id));
                        }
                    }
                    // A part whose texts all encoded to nothing is omitted.
                    if token_ids.is_empty() {
                        continue;
                    }
                    tokens_info.push(TokensInfo {
                        role: role.clone(),
                        token_ids: Some(token_ids),
                        tokens: Some(tokens),
                    });
                }
            }

            Ok(ComputeTokensResponse {
                sdk_http_response: None,
                tokens_info: Some(tokens_info),
            })
        }
    }
701
702 impl TokenEstimator for KitokenEstimator {
703 fn estimate_tokens(&self, contents: &[Content]) -> usize {
704 let mut total = 0usize;
705 for content in contents {
706 for part in &content.parts {
707 if let Some(text) = part.text_value() {
708 if let Ok(tokens) = self.encoder.encode(text, true) {
709 total += tokens.len();
710 }
711 }
712 }
713 }
714 total
715 }
716 }
717
718 fn build_token_bytes_map(encoder: &Kitoken) -> HashMap<u32, Vec<u8>> {
719 let definition = encoder.to_definition();
720 let mut map = HashMap::new();
721 for token in definition.model.vocab() {
722 map.insert(token.id, normalize_token_bytes(&token.bytes));
723 }
724 for special in definition.specials {
725 map.insert(special.id, normalize_token_bytes(&special.bytes));
726 }
727 map
728 }
729
730 fn normalize_token_bytes(bytes: &[u8]) -> Vec<u8> {
731 std::str::from_utf8(bytes).map_or_else(
732 |_| bytes.to_vec(),
733 |text| text.replace('▁', " ").into_bytes(),
734 )
735 }
736
737 fn collect_part_texts(part: &Part) -> Result<Vec<String>, LocalTokenizerError> {
738 let mut texts = Vec::new();
739 match &part.kind {
740 PartKind::Text { text } => {
741 if !text.is_empty() {
742 texts.push(text.clone());
743 }
744 }
745 PartKind::FunctionCall { function_call } => {
746 add_function_call_texts(function_call, &mut texts);
747 }
748 PartKind::FunctionResponse { function_response } => {
749 add_function_response_texts(function_response, &mut texts);
750 }
751 PartKind::ExecutableCode { executable_code } => {
752 if !executable_code.code.is_empty() {
753 texts.push(executable_code.code.clone());
754 }
755 }
756 PartKind::CodeExecutionResult {
757 code_execution_result,
758 } => {
759 if let Some(output) = &code_execution_result.output {
760 if !output.is_empty() {
761 texts.push(output.clone());
762 }
763 }
764 }
765 PartKind::InlineData { .. } => {
766 return Err(LocalTokenizerError::UnsupportedContent {
767 kind: "inline_data",
768 });
769 }
770 PartKind::FileData { .. } => {
771 return Err(LocalTokenizerError::UnsupportedContent { kind: "file_data" });
772 }
773 }
774 Ok(texts)
775 }
776
777 fn add_function_call_texts(function_call: &FunctionCall, texts: &mut Vec<String>) {
778 if let Some(name) = &function_call.name {
779 if !name.is_empty() {
780 texts.push(name.clone());
781 }
782 }
783 if let Some(args) = &function_call.args {
784 add_json_texts(args, texts);
785 }
786 }
787
788 fn add_function_response_texts(function_response: &FunctionResponse, texts: &mut Vec<String>) {
789 if let Some(name) = &function_response.name {
790 if !name.is_empty() {
791 texts.push(name.clone());
792 }
793 }
794 if let Some(response) = &function_response.response {
795 add_json_texts(response, texts);
796 }
797 }
798
799 fn add_json_texts(value: &serde_json::Value, texts: &mut Vec<String>) {
800 match value {
801 serde_json::Value::String(value) if !value.is_empty() => {
802 texts.push(value.clone());
803 }
804 serde_json::Value::Array(values) => {
805 for item in values {
806 add_json_texts(item, texts);
807 }
808 }
809 serde_json::Value::Object(map) => {
810 for (key, value) in map {
811 if !key.is_empty() {
812 texts.push(key.clone());
813 }
814 add_json_texts(value, texts);
815 }
816 }
817 _ => {}
818 }
819 }
820
821 fn get_tokenizer_name(model_name: &str) -> Result<&'static str, LocalTokenizerError> {
822 for (name, tokenizer) in GEMINI_MODELS_TO_TOKENIZER_NAMES {
823 if *name == model_name {
824 return Ok(*tokenizer);
825 }
826 }
827 for (name, tokenizer) in GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES {
828 if *name == model_name {
829 return Ok(*tokenizer);
830 }
831 }
832 let mut supported: Vec<String> = GEMINI_MODELS_TO_TOKENIZER_NAMES
833 .iter()
834 .map(|(name, _)| (*name).to_string())
835 .collect();
836 supported.extend(
837 GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES
838 .iter()
839 .map(|(name, _)| (*name).to_string()),
840 );
841 supported.sort();
842 supported.dedup();
843 Err(LocalTokenizerError::UnsupportedModel {
844 model: model_name.to_string(),
845 supported: supported.join(", "),
846 })
847 }
848
849 async fn load_model_bytes(
850 url: &str,
851 expected_hash: &str,
852 ) -> Result<Vec<u8>, LocalTokenizerError> {
853 let cache_path = cache_path_for(url);
854 if let Some(bytes) = read_cache(&cache_path, expected_hash)? {
855 return Ok(bytes);
856 }
857 let bytes = download_model(url).await?;
858 let actual_hash = sha256_hex(&bytes);
859 if actual_hash != expected_hash {
860 return Err(LocalTokenizerError::HashMismatch {
861 expected: expected_hash.to_string(),
862 actual: actual_hash,
863 });
864 }
865 let _ = write_cache(&cache_path, &bytes);
866 Ok(bytes)
867 }
868
869 fn cache_path_for(url: &str) -> PathBuf {
870 let filename = sha256_hex(url.as_bytes());
871 std::env::temp_dir().join(CACHE_DIR).join(filename)
872 }
873
874 fn read_cache(
875 path: &Path,
876 expected_hash: &str,
877 ) -> Result<Option<Vec<u8>>, LocalTokenizerError> {
878 if !path.exists() {
879 return Ok(None);
880 }
881 let bytes = fs::read(path)?;
882 if sha256_hex(&bytes) == expected_hash {
883 return Ok(Some(bytes));
884 }
885 let _ = fs::remove_file(path);
886 Ok(None)
887 }
888
889 fn write_cache(path: &Path, bytes: &[u8]) -> Result<(), LocalTokenizerError> {
890 if let Some(parent) = path.parent() {
891 fs::create_dir_all(parent)?;
892 }
893 let tmp_path = path.with_extension("tmp");
894 fs::write(&tmp_path, bytes)?;
895 fs::rename(tmp_path, path)?;
896 Ok(())
897 }
898
899 async fn download_model(url: &str) -> Result<Vec<u8>, LocalTokenizerError> {
900 let response = reqwest::get(url)
901 .await
902 .map_err(|source| LocalTokenizerError::Download {
903 url: url.to_string(),
904 source,
905 })?;
906 let status = response.status();
907 if !status.is_success() {
908 return Err(LocalTokenizerError::DownloadStatus {
909 url: url.to_string(),
910 status: status.as_u16(),
911 });
912 }
913 let bytes = response
914 .bytes()
915 .await
916 .map_err(|source| LocalTokenizerError::Download {
917 url: url.to_string(),
918 source,
919 })?;
920 Ok(bytes.to_vec())
921 }
922
923 fn sha256_hex(data: &[u8]) -> String {
924 let digest = Sha256::digest(data);
925 let mut output = String::with_capacity(digest.len() * 2);
926 for byte in digest {
927 let _ = write!(output, "{byte:02x}");
928 }
929 output
930 }
931
    #[cfg(test)]
    mod tests {
        use super::*;
        use rust_genai_types::content::{Content, FunctionCall, FunctionResponse, Part, Role};
        use rust_genai_types::enums::{Language, Outcome};
        use serde_json::json;
        use std::fs;
        use std::time::{SystemTime, UNIX_EPOCH};

        // Builds a tiny WordPiece encoder whose vocabulary covers exactly the
        // strings these tests feed through `compute_tokens`.
        fn build_test_encoder() -> Kitoken {
            let vocab = vec![
                kitoken::Token {
                    id: 0,
                    bytes: b"hi".to_vec(),
                },
                kitoken::Token {
                    id: 1,
                    bytes: b"lookup".to_vec(),
                },
                kitoken::Token {
                    id: 2,
                    bytes: b"q".to_vec(),
                },
                kitoken::Token {
                    id: 3,
                    bytes: b"rust".to_vec(),
                },
                kitoken::Token {
                    id: 4,
                    bytes: b"resp".to_vec(),
                },
                kitoken::Token {
                    id: 5,
                    bytes: b"ok".to_vec(),
                },
                kitoken::Token {
                    id: 6,
                    bytes: b"code".to_vec(),
                },
                kitoken::Token {
                    id: 7,
                    bytes: b"out".to_vec(),
                },
                kitoken::Token {
                    // SentencePiece word-boundary marker, used to exercise
                    // byte normalization.
                    id: 8,
                    bytes: "\u{2581}".as_bytes().to_vec(),
                },
            ];
            let specials = vec![kitoken::SpecialToken {
                id: 99,
                bytes: b"[UNK]".to_vec(),
                kind: kitoken::SpecialTokenKind::Unknown,
                ident: None,
                score: 0.0,
                extract: false,
            }];
            let model = kitoken::Model::WordPiece {
                vocab,
                max_word_chars: 0,
            };
            let config = kitoken::Configuration::default();
            let meta = kitoken::Metadata::default();
            Kitoken::new(model, specials, config, meta).unwrap()
        }

        // Produces a unique pseudo-URL so cache tests never collide with each
        // other or with leftovers from previous runs.
        fn unique_cache_key(tag: &str) -> String {
            let nanos = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_nanos();
            format!("test://{tag}-{nanos}")
        }

        // Known models resolve; unknown models produce the descriptive error.
        #[test]
        fn get_tokenizer_name_known_and_unknown() {
            assert_eq!(get_tokenizer_name("gemini-1.5-pro").unwrap(), "gemma2");
            let err = get_tokenizer_name("unknown-model").unwrap_err();
            match err {
                LocalTokenizerError::UnsupportedModel { supported, .. } => {
                    assert!(supported.contains("gemini-1.0-pro"));
                }
                _ => panic!("expected UnsupportedModel error"),
            }
        }

        // U+2581 becomes a space for UTF-8 tokens; invalid UTF-8 is passed
        // through byte-for-byte.
        #[test]
        fn normalize_token_bytes_replaces_separator_and_handles_invalid_utf8() {
            let replaced = normalize_token_bytes("\u{2581}hi".as_bytes());
            assert_eq!(replaced, b" hi".to_vec());

            let invalid = normalize_token_bytes(&[0xff, 0xfe]);
            assert_eq!(invalid, vec![0xff, 0xfe]);
        }

        // A matching hash returns the cached bytes; a mismatch evicts the
        // cache file and reports a miss.
        #[test]
        fn cache_roundtrip_and_mismatch_evicts() {
            let key = unique_cache_key("cache-roundtrip");
            let path = cache_path_for(&key);
            let _ = fs::remove_file(&path);

            let bytes = b"cached".to_vec();
            write_cache(&path, &bytes).unwrap();
            let hash = sha256_hex(&bytes);
            let cached = read_cache(&path, &hash).unwrap().unwrap();
            assert_eq!(cached, bytes);

            let wrong_hash = sha256_hex(b"other");
            let result = read_cache(&path, &wrong_hash).unwrap();
            assert!(result.is_none());
            assert!(!path.exists());
        }

        // A valid cache entry short-circuits the download entirely (the
        // pseudo-URL is not fetchable, so a hit proves the cache was used).
        #[tokio::test]
        async fn load_model_bytes_uses_cache() {
            let key = unique_cache_key("load-cache");
            let path = cache_path_for(&key);
            let _ = fs::remove_file(&path);

            let bytes = b"model-bytes".to_vec();
            write_cache(&path, &bytes).unwrap();
            let hash = sha256_hex(&bytes);

            let loaded = load_model_bytes(&key, &hash).await.unwrap();
            assert_eq!(loaded, bytes);
        }

        // Inline and file data cannot be tokenized locally.
        #[test]
        fn collect_part_texts_rejects_binary_parts() {
            let inline = Part::inline_data(vec![1, 2, 3], "image/png");
            let err = collect_part_texts(&inline).unwrap_err();
            assert!(matches!(
                err,
                LocalTokenizerError::UnsupportedContent {
                    kind: "inline_data"
                }
            ));

            let file = Part::file_data("files/1", "application/pdf");
            let err = collect_part_texts(&file).unwrap_err();
            assert!(matches!(
                err,
                LocalTokenizerError::UnsupportedContent { kind: "file_data" }
            ));
        }

        // End-to-end: compute_tokens yields info for text/function/code
        // parts, estimate_tokens counts text, and the id -> bytes map holds
        // normalized bytes for the U+2581 token.
        #[test]
        fn kitoken_estimator_compute_tokens_and_map_normalization() {
            let encoder = build_test_encoder();
            let estimator = KitokenEstimator::from_encoder(encoder);

            let call = FunctionCall {
                id: None,
                name: Some("lookup".into()),
                args: Some(json!({"q": "rust"})),
                partial_args: None,
                will_continue: None,
            };
            let response = FunctionResponse {
                will_continue: None,
                scheduling: None,
                parts: None,
                id: None,
                name: Some("resp".into()),
                response: Some(json!({"ok": "ok"})),
            };
            let content = Content::from_parts(
                vec![
                    Part::text("hi"),
                    Part::function_call(call),
                    Part::function_response(response),
                    Part::executable_code("code", Language::Python),
                    Part::code_execution_result(Outcome::OutcomeOk, "out"),
                ],
                Role::User,
            );

            let result = estimator.compute_tokens(&[content]).unwrap();
            assert!(!result.tokens_info.as_ref().unwrap().is_empty());

            let estimated = estimator.estimate_tokens(&[Content::text("hi")]);
            assert!(estimated > 0);

            let normalized = estimator.token_bytes.get(&8).unwrap();
            assert_eq!(normalized.as_slice(), b" ");
        }
    }
1118}