1use rust_genai_types::config::GenerationConfig;
4use rust_genai_types::content::{Content, FunctionCall, FunctionResponse, PartKind};
5use rust_genai_types::models::CountTokensConfig;
6use rust_genai_types::tool::{FunctionDeclaration, Schema, Tool};
7use serde_json::Value;
8
/// Strategy for estimating how many tokens a list of [`Content`] values
/// will consume.
pub trait TokenEstimator {
    /// Returns the estimated token count for `contents`.
    fn estimate_tokens(&self, contents: &[Content]) -> usize;
}
13
/// Heuristic estimator: sums the byte length of textual payloads and assumes
/// roughly four bytes per token (rounding up). Function calls/responses only
/// contribute their names, not their arguments.
#[derive(Debug, Clone, Default)]
pub struct SimpleTokenEstimator;
17
18impl TokenEstimator for SimpleTokenEstimator {
19 fn estimate_tokens(&self, contents: &[Content]) -> usize {
20 let mut bytes = 0usize;
21 for content in contents {
22 for part in &content.parts {
23 match &part.kind {
24 PartKind::Text { text } => {
25 bytes += text.len();
26 }
27 PartKind::InlineData { inline_data } => {
28 bytes += inline_data.data.len();
29 }
30 PartKind::FileData { file_data } => {
31 bytes += file_data.file_uri.len();
32 }
33 PartKind::FunctionCall { function_call } => {
34 if let Some(name) = &function_call.name {
35 bytes += name.len();
36 }
37 }
38 PartKind::FunctionResponse { function_response } => {
39 if let Some(name) = &function_response.name {
40 bytes += name.len();
41 }
42 }
43 PartKind::ExecutableCode { executable_code } => {
44 bytes += executable_code.code.len();
45 }
46 PartKind::CodeExecutionResult {
47 code_execution_result,
48 } => {
49 if let Some(output) = &code_execution_result.output {
50 bytes += output.len();
51 }
52 }
53 }
54 }
55 }
56 bytes.div_ceil(4)
58 }
59}
60
61pub(crate) fn build_estimation_contents(
62 contents: &[Content],
63 config: &CountTokensConfig,
64) -> Vec<Content> {
65 let mut combined = Vec::with_capacity(contents.len() + 1);
66 combined.extend_from_slice(contents);
67 if let Some(system_instruction) = &config.system_instruction {
68 combined.push(system_instruction.clone());
69 }
70
71 let mut accumulator = TextAccumulator::default();
72 accumulator.add_function_texts_from_contents(&combined);
73 if let Some(tools) = &config.tools {
74 accumulator.add_tools(tools);
75 }
76 if let Some(generation_config) = &config.generation_config {
77 accumulator.add_generation_config(generation_config);
78 }
79 combined.extend(accumulator.into_contents());
80 combined
81}
82
/// Collects the loose text fragments (function names, JSON string values and
/// keys, schema metadata) that contribute to a token estimate.
#[derive(Debug, Default)]
struct TextAccumulator {
    /// Non-empty text fragments gathered so far, in insertion order.
    texts: Vec<String>,
}
87
88impl TextAccumulator {
89 fn add_function_texts_from_contents(&mut self, contents: &[Content]) {
90 for content in contents {
91 self.add_function_texts_from_content(content);
92 }
93 }
94
95 fn add_function_texts_from_content(&mut self, content: &Content) {
96 for part in &content.parts {
97 match &part.kind {
98 PartKind::FunctionCall { function_call } => {
99 self.add_function_call(function_call);
100 }
101 PartKind::FunctionResponse { function_response } => {
102 self.add_function_response(function_response);
103 }
104 _ => {}
105 }
106 }
107 }
108
109 fn add_function_call(&mut self, function_call: &FunctionCall) {
110 if let Some(name) = &function_call.name {
111 self.push_text(name);
112 }
113 if let Some(args) = &function_call.args {
114 self.add_json(args);
115 }
116 }
117
118 fn add_function_response(&mut self, function_response: &FunctionResponse) {
119 if let Some(name) = &function_response.name {
120 self.push_text(name);
121 }
122 if let Some(response) = &function_response.response {
123 self.add_json(response);
124 }
125 }
126
127 fn add_tools(&mut self, tools: &[Tool]) {
128 for tool in tools {
129 if let Some(functions) = &tool.function_declarations {
130 for function in functions {
131 self.add_function_declaration(function);
132 }
133 }
134 }
135 }
136
137 fn add_function_declaration(&mut self, declaration: &FunctionDeclaration) {
138 self.push_text(&declaration.name);
139 if let Some(description) = &declaration.description {
140 self.push_text(description);
141 }
142 if let Some(parameters) = &declaration.parameters {
143 self.add_schema(parameters);
144 }
145 if let Some(response) = &declaration.response {
146 self.add_schema(response);
147 }
148 if let Some(parameters_json) = &declaration.parameters_json_schema {
149 self.add_json(parameters_json);
150 }
151 if let Some(response_json) = &declaration.response_json_schema {
152 self.add_json(response_json);
153 }
154 }
155
156 fn add_generation_config(&mut self, generation_config: &GenerationConfig) {
157 if let Some(response_schema) = &generation_config.response_schema {
158 self.add_schema(response_schema);
159 }
160 if let Some(response_json_schema) = &generation_config.response_json_schema {
161 self.add_json(response_json_schema);
162 }
163 }
164
165 fn add_schema(&mut self, schema: &Schema) {
166 if let Some(title) = &schema.title {
167 self.push_text(title);
168 }
169 if let Some(format) = &schema.format {
170 self.push_text(format);
171 }
172 if let Some(description) = &schema.description {
173 self.push_text(description);
174 }
175 if let Some(enum_values) = &schema.enum_values {
176 for value in enum_values {
177 self.push_text(value);
178 }
179 }
180 if let Some(required) = &schema.required {
181 for value in required {
182 self.push_text(value);
183 }
184 }
185 if let Some(properties) = &schema.properties {
186 for (key, value) in properties {
187 self.push_text(key);
188 self.add_schema(value);
189 }
190 }
191 if let Some(items) = &schema.items {
192 self.add_schema(items);
193 }
194 if let Some(any_of) = &schema.any_of {
195 for schema in any_of {
196 self.add_schema(schema);
197 }
198 }
199 if let Some(example) = &schema.example {
200 self.add_json(example);
201 }
202 if let Some(default) = &schema.default {
203 self.add_json(default);
204 }
205 }
206
207 fn add_json(&mut self, value: &Value) {
208 match value {
209 Value::String(value) => self.push_text(value),
210 Value::Array(values) => {
211 for item in values {
212 self.add_json(item);
213 }
214 }
215 Value::Object(map) => {
216 for (key, value) in map {
217 self.push_text(key);
218 self.add_json(value);
219 }
220 }
221 _ => {}
222 }
223 }
224
225 fn push_text(&mut self, value: &str) {
226 if !value.is_empty() {
227 self.texts.push(value.to_string());
228 }
229 }
230
231 fn into_contents(self) -> Vec<Content> {
232 self.texts.into_iter().map(Content::text).collect()
233 }
234}
235
236#[cfg(test)]
237mod tests {
238 use super::*;
239 use rust_genai_types::config::GenerationConfig;
240 use rust_genai_types::content::{FunctionCall, FunctionResponse, Part, Role};
241 use rust_genai_types::models::CountTokensConfig;
242 use rust_genai_types::tool::{FunctionDeclaration, Schema, Tool};
243 use serde_json::json;
244 use std::collections::HashMap;
245
    // Every part kind contributes at least some bytes, so a content mixing
    // all of them must produce a positive estimate.
    #[test]
    fn simple_token_estimator_counts_various_parts() {
        let call = FunctionCall {
            id: Some("call-1".into()),
            name: Some("lookup".into()),
            args: Some(json!({"q": "rust"})),
            partial_args: None,
            will_continue: None,
        };
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: None,
            id: Some("resp-1".into()),
            name: Some("lookup".into()),
            response: Some(json!({"ok": true})),
        };
        let content = Content::from_parts(
            vec![
                Part::text("hello"),
                Part::inline_data(vec![0, 1, 2, 3], "image/png"),
                Part::file_data("files/abc", "application/pdf"),
                Part::function_call(call),
                Part::function_response(response),
                Part::executable_code("print('hi')", rust_genai_types::enums::Language::Python),
                Part::code_execution_result(rust_genai_types::enums::Outcome::OutcomeOk, "ok"),
            ],
            Role::User,
        );

        let estimator = SimpleTokenEstimator;
        let tokens = estimator.estimate_tokens(&[content]);
        assert!(tokens > 0);
    }
280
    // The combined list must contain the original contents plus at least the
    // system instruction (tool/config texts are appended after those).
    #[test]
    fn build_estimation_contents_includes_tools_and_config() {
        let declaration = FunctionDeclaration {
            name: "search".to_string(),
            description: Some("desc".to_string()),
            parameters: Some(
                Schema::object()
                    .property("q", Schema::string())
                    .required("q")
                    .build(),
            ),
            parameters_json_schema: Some(
                json!({"type": "object", "properties": {"q": {"type": "string"}}}),
            ),
            response: Some(Schema::string()),
            response_json_schema: Some(json!({"type": "string"})),
            behavior: None,
        };
        let tool = Tool {
            function_declarations: Some(vec![declaration]),
            ..Default::default()
        };
        let generation_config = GenerationConfig {
            response_schema: Some(Schema::object().property("r", Schema::string()).build()),
            response_json_schema: Some(
                json!({"type": "object", "properties": {"r": {"type": "string"}}}),
            ),
            ..Default::default()
        };

        let config = CountTokensConfig {
            system_instruction: Some(Content::text("sys")),
            tools: Some(vec![tool]),
            generation_config: Some(generation_config),
        };

        let contents = vec![Content::text("user")];
        let combined = build_estimation_contents(&contents, &config);
        assert!(combined.len() >= 2);
    }
322
    // add_schema / add_json must surface every string-bearing field,
    // including nested properties, examples, defaults, and object keys.
    #[test]
    fn text_accumulator_collects_schema_and_json_fields() {
        let mut properties = HashMap::new();
        properties.insert("prop".to_string(), Box::new(Schema::string()));
        let schema = Schema {
            title: Some("Title".into()),
            format: Some("Fmt".into()),
            description: Some("Desc".into()),
            enum_values: Some(vec!["A".into(), "B".into()]),
            required: Some(vec!["req".into()]),
            properties: Some(properties),
            items: Some(Box::new(Schema::number())),
            any_of: Some(vec![Schema::boolean()]),
            example: Some(json!({"ex_key": "ex_val"})),
            default: Some(json!(["d"])),
            ..Default::default()
        };

        let mut accumulator = TextAccumulator::default();
        accumulator.add_schema(&schema);
        accumulator.add_json(&json!(["a", {"k": "v"}, 1]));
        let texts = accumulator.texts;

        assert!(texts.contains(&"Title".to_string()));
        assert!(texts.contains(&"Fmt".to_string()));
        assert!(texts.contains(&"Desc".to_string()));
        assert!(texts.contains(&"A".to_string()));
        assert!(texts.contains(&"B".to_string()));
        assert!(texts.contains(&"req".to_string()));
        assert!(texts.contains(&"prop".to_string()));
        assert!(texts.contains(&"ex_key".to_string()));
        assert!(texts.contains(&"ex_val".to_string()));
        assert!(texts.contains(&"k".to_string()));
        assert!(texts.contains(&"v".to_string()));
        assert!(texts.contains(&"a".to_string()));
    }
359
    // Unnamed calls/responses still contribute their JSON keys and values.
    #[test]
    fn text_accumulator_collects_function_parts() {
        let call = FunctionCall {
            id: None,
            name: None,
            args: Some(json!({"q": "rust"})),
            partial_args: None,
            will_continue: None,
        };
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: None,
            id: None,
            name: None,
            response: Some(json!({"answer": "ok"})),
        };
        let content = Content::from_parts(
            vec![Part::function_call(call), Part::function_response(response)],
            Role::User,
        );

        let mut accumulator = TextAccumulator::default();
        accumulator.add_function_texts_from_content(&content);
        let texts = accumulator.texts;

        assert!(texts.contains(&"q".to_string()));
        assert!(texts.contains(&"rust".to_string()));
        assert!(texts.contains(&"answer".to_string()));
        assert!(texts.contains(&"ok".to_string()));
    }
391
    // Named calls/responses, tool declarations, and the generation config all
    // feed text into the accumulator.
    #[test]
    fn text_accumulator_collects_named_parts_and_declarations() {
        let call = FunctionCall {
            id: None,
            name: Some("lookup".into()),
            args: Some(json!({"k": "v"})),
            partial_args: None,
            will_continue: None,
        };
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: None,
            id: None,
            name: Some("lookup_result".into()),
            response: Some(json!({"out": "done"})),
        };
        let content = Content::from_parts(
            vec![Part::function_call(call), Part::function_response(response)],
            Role::User,
        );

        let declaration = FunctionDeclaration {
            name: "search".to_string(),
            description: Some("desc".to_string()),
            parameters: Some(Schema::object().property("q", Schema::string()).build()),
            parameters_json_schema: Some(
                json!({"type": "object", "properties": {"q": {"type": "string"}}}),
            ),
            response: Some(Schema::string()),
            response_json_schema: Some(json!({"type": "string"})),
            behavior: None,
        };

        let generation_config = GenerationConfig {
            response_schema: Some(Schema::string()),
            response_json_schema: Some(json!({"type": "string"})),
            ..Default::default()
        };

        let mut accumulator = TextAccumulator::default();
        accumulator.add_function_texts_from_content(&content);
        accumulator.add_function_declaration(&declaration);
        accumulator.add_generation_config(&generation_config);
        let texts = accumulator.texts;

        assert!(texts.contains(&"lookup".to_string()));
        assert!(texts.contains(&"lookup_result".to_string()));
        assert!(texts.contains(&"k".to_string()));
        assert!(texts.contains(&"v".to_string()));
        assert!(texts.contains(&"out".to_string()));
        assert!(texts.contains(&"done".to_string()));
        assert!(texts.contains(&"search".to_string()));
        assert!(texts.contains(&"desc".to_string()));
        assert!(texts.contains(&"q".to_string()));
    }
448
    // Only the names count: "ping" + "pong" = 8 bytes -> ceil(8 / 4) = 2.
    #[test]
    fn simple_token_estimator_counts_function_names() {
        let call = FunctionCall {
            id: None,
            name: Some("ping".into()),
            args: None,
            partial_args: None,
            will_continue: None,
        };
        let response = FunctionResponse {
            will_continue: None,
            scheduling: None,
            parts: None,
            id: None,
            name: Some("pong".into()),
            response: None,
        };
        let content = Content::from_parts(
            vec![Part::function_call(call), Part::function_response(response)],
            Role::User,
        );

        let estimator = SimpleTokenEstimator;
        let tokens = estimator.estimate_tokens(&[content]);
        assert_eq!(tokens, 2);
    }
475
    // No contents means zero bytes, and div_ceil keeps zero at zero.
    #[test]
    fn simple_token_estimator_empty_is_zero() {
        let estimator = SimpleTokenEstimator;
        let tokens = estimator.estimate_tokens(&[]);
        assert_eq!(tokens, 0);
    }
482}
483
484#[cfg(feature = "kitoken")]
485pub mod kitoken {
486 use super::TokenEstimator;
487 use base64::engine::general_purpose::STANDARD;
488 use base64::Engine as _;
489 use kitoken::convert::ConversionError;
490 use kitoken::EncodeError;
491 use kitoken::Kitoken;
492 use rust_genai_types::content::{
493 Content, FunctionCall, FunctionResponse, Part, PartKind, Role,
494 };
495 use rust_genai_types::models::{ComputeTokensResponse, TokensInfo};
496 use sha2::{Digest, Sha256};
497 use std::collections::HashMap;
498 use std::fmt::Write;
499 use std::fs;
500 use std::path::{Path, PathBuf};
501 use std::sync::Arc;
502
    /// Subdirectory of the system temp dir used to cache downloaded tokenizer models.
    const CACHE_DIR: &str = "vertexai_tokenizer_model";
504
    /// Where to fetch a tokenizer model and the digest it must match.
    struct TokenizerConfig {
        /// URL of the SentencePiece model file.
        model_url: &'static str,
        /// Expected SHA-256 of the downloaded bytes, lowercase hex.
        model_hash: &'static str,
    }
509
    /// Un-versioned Gemini model names mapped to their tokenizer family.
    const GEMINI_MODELS_TO_TOKENIZER_NAMES: &[(&str, &str)] = &[
        ("gemini-1.0-pro", "gemma2"),
        ("gemini-1.5-pro", "gemma2"),
        ("gemini-1.5-flash", "gemma2"),
        ("gemini-2.5-pro", "gemma3"),
        ("gemini-2.5-flash", "gemma3"),
        ("gemini-2.5-flash-lite", "gemma3"),
        ("gemini-2.0-flash", "gemma3"),
        ("gemini-2.0-flash-lite", "gemma3"),
    ];
520
    /// Pinned (versioned/preview) Gemini model names mapped to their tokenizer family.
    const GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES: &[(&str, &str)] = &[
        ("gemini-1.0-pro-001", "gemma2"),
        ("gemini-1.0-pro-002", "gemma2"),
        ("gemini-1.5-pro-001", "gemma2"),
        ("gemini-1.5-pro-002", "gemma2"),
        ("gemini-1.5-flash-001", "gemma2"),
        ("gemini-1.5-flash-002", "gemma2"),
        ("gemini-2.5-pro-preview-06-05", "gemma3"),
        ("gemini-2.5-pro-preview-05-06", "gemma3"),
        ("gemini-2.5-pro-exp-03-25", "gemma3"),
        ("gemini-live-2.5-flash", "gemma3"),
        ("gemini-2.5-flash-native-audio-preview-12-2025", "gemma3"),
        ("gemini-2.5-flash-native-audio-preview-09-2025", "gemma3"),
        ("gemini-2.5-flash-preview-05-20", "gemma3"),
        ("gemini-2.5-flash-preview-04-17", "gemma3"),
        ("gemini-2.5-flash-lite-preview-06-17", "gemma3"),
        ("gemini-2.0-flash-001", "gemma3"),
        ("gemini-2.0-flash-lite-001", "gemma3"),
        ("gemini-3-pro-preview", "gemma3"),
    ];
541
    /// Returns the pinned download URL and SHA-256 for a tokenizer family
    /// ("gemma2" or "gemma3"), or `None` for an unknown name.
    fn tokenizer_config(name: &str) -> Option<TokenizerConfig> {
        match name {
            "gemma2" => Some(TokenizerConfig {
                model_url: "https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model",
                model_hash: "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2",
            }),
            "gemma3" => Some(TokenizerConfig {
                model_url: "https://raw.githubusercontent.com/google/gemma_pytorch/014acb7ac4563a5f77c76d7ff98f31b568c16508/tokenizer/gemma3_cleaned_262144_v2.spiece.model",
                model_hash: "1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c",
            }),
            _ => None,
        }
    }
555
    /// Errors produced while building or using a local tokenizer.
    #[derive(Debug, thiserror::Error)]
    pub enum LocalTokenizerError {
        /// The model name has no known tokenizer mapping.
        #[error("Model {model} is not supported. Supported models: {supported}")]
        UnsupportedModel { model: String, supported: String },
        /// The mapped tokenizer family has no download configuration.
        #[error("Tokenizer {name} is not supported")]
        UnsupportedTokenizer { name: String },
        /// The HTTP request for the tokenizer model failed.
        #[error("Failed to download tokenizer model from {url}: {source}")]
        Download {
            url: String,
            #[source]
            source: reqwest::Error,
        },
        /// The tokenizer model endpoint answered with a non-success status.
        #[error("Tokenizer model download returned non-success status {status} for {url}")]
        DownloadStatus { url: String, status: u16 },
        /// The downloaded model bytes did not match the pinned SHA-256 digest.
        #[error("Tokenizer model hash mismatch. expected {expected}, got {actual}")]
        HashMismatch { expected: String, actual: String },
        /// Filesystem failure while reading or writing the model cache.
        #[error("IO error: {source}")]
        Io {
            #[from]
            source: std::io::Error,
        },
        /// The encoder failed to tokenize a text fragment.
        #[error("Tokenizer encode error: {source}")]
        Encode {
            #[from]
            source: EncodeError,
        },
        /// A part kind (binary data) that cannot be tokenized locally.
        #[error("Local tokenizer does not support non-text content: {kind}")]
        UnsupportedContent { kind: &'static str },
        /// An encoded token id was missing from the vocabulary map.
        #[error("Tokenizer token id {id} not found in vocabulary")]
        MissingToken { id: u32 },
        /// The model file could not be converted into a kitoken encoder.
        #[error("Tokenizer conversion error: {source}")]
        Conversion {
            #[from]
            source: ConversionError,
        },
    }
592
    /// Local token estimator backed by a [`Kitoken`] SentencePiece encoder.
    ///
    /// Cloning is cheap: the encoder and vocabulary map are shared via [`Arc`].
    #[derive(Debug, Clone)]
    pub struct KitokenEstimator {
        /// Shared encoder used for all tokenization.
        encoder: Arc<Kitoken>,
        /// Token id -> normalized token bytes, used by `compute_tokens` to
        /// render each token.
        token_bytes: Arc<HashMap<u32, Vec<u8>>>,
    }
599
    impl KitokenEstimator {
        /// Wraps a ready encoder, precomputing the id -> bytes vocabulary map
        /// so `compute_tokens` does not have to rebuild it per call.
        fn from_encoder(encoder: Kitoken) -> Self {
            let token_bytes = Arc::new(build_token_bytes_map(&encoder));
            Self {
                encoder: Arc::new(encoder),
                token_bytes,
            }
        }

        /// Loads a tokenizer from a SentencePiece model file on disk.
        ///
        /// # Errors
        /// Returns an error when the file cannot be read or converted into an
        /// encoder.
        pub fn from_sentencepiece_file(
            path: impl AsRef<Path>,
        ) -> Result<Self, LocalTokenizerError> {
            let encoder = Kitoken::from_sentencepiece_file(path)?;
            Ok(Self::from_encoder(encoder))
        }

        /// Builds an estimator for a supported Gemini model name, downloading
        /// (and caching) the matching tokenizer model when necessary.
        ///
        /// # Errors
        /// Fails for unsupported model names, download or hash-verification
        /// failures, or when the model bytes cannot be converted.
        pub async fn from_model_name(model_name: &str) -> Result<Self, LocalTokenizerError> {
            let tokenizer_name = get_tokenizer_name(model_name)?;
            let config = tokenizer_config(tokenizer_name).ok_or_else(|| {
                LocalTokenizerError::UnsupportedTokenizer {
                    name: tokenizer_name.to_string(),
                }
            })?;
            let model_bytes = load_model_bytes(config.model_url, config.model_hash).await?;
            let encoder = Kitoken::from_sentencepiece_slice(&model_bytes)?;
            Ok(Self::from_encoder(encoder))
        }

        /// Tokenizes the textual parts of `contents` locally, returning one
        /// [`TokensInfo`] per part with token ids and base64-encoded token
        /// bytes.
        ///
        /// # Errors
        /// Returns `UnsupportedContent` for binary parts, an encode error when
        /// tokenization fails, or `MissingToken` when an id has no entry in
        /// the vocabulary map.
        pub fn compute_tokens(
            &self,
            contents: &[Content],
        ) -> Result<ComputeTokensResponse, LocalTokenizerError> {
            let mut tokens_info: Vec<TokensInfo> = Vec::new();
            for content in contents {
                let role = content
                    .role
                    .map(|role| match role {
                        Role::User => "user",
                        Role::Model => "model",
                        Role::Function => "function",
                    })
                    .map(ToString::to_string);

                for part in &content.parts {
                    let texts = collect_part_texts(part)?;
                    if texts.is_empty() {
                        continue;
                    }
                    let mut token_ids = Vec::new();
                    let mut tokens = Vec::new();
                    for text in texts {
                        if text.is_empty() {
                            continue;
                        }
                        let ids = self.encoder.encode(&text, true)?;
                        for id in ids {
                            let bytes = self
                                .token_bytes
                                .get(&id)
                                .ok_or(LocalTokenizerError::MissingToken { id })?;
                            tokens.push(STANDARD.encode(bytes));
                            token_ids.push(i64::from(id));
                        }
                    }
                    // Parts that produced no tokens are omitted entirely.
                    if token_ids.is_empty() {
                        continue;
                    }
                    tokens_info.push(TokensInfo {
                        role: role.clone(),
                        token_ids: Some(token_ids),
                        tokens: Some(tokens),
                    });
                }
            }

            Ok(ComputeTokensResponse {
                sdk_http_response: None,
                tokens_info: Some(tokens_info),
            })
        }
    }
693
694 impl TokenEstimator for KitokenEstimator {
695 fn estimate_tokens(&self, contents: &[Content]) -> usize {
696 let mut total = 0usize;
697 for content in contents {
698 for part in &content.parts {
699 if let Some(text) = part.text_value() {
700 if let Ok(tokens) = self.encoder.encode(text, true) {
701 total += tokens.len();
702 }
703 }
704 }
705 }
706 total
707 }
708 }
709
    /// Builds a token-id -> token-bytes lookup from the encoder's definition,
    /// covering both the regular vocabulary and the special tokens. Bytes are
    /// normalized via `normalize_token_bytes`.
    fn build_token_bytes_map(encoder: &Kitoken) -> HashMap<u32, Vec<u8>> {
        let definition = encoder.to_definition();
        let mut map = HashMap::new();
        for token in definition.model.vocab() {
            map.insert(token.id, normalize_token_bytes(&token.bytes));
        }
        for special in definition.specials {
            map.insert(special.id, normalize_token_bytes(&special.bytes));
        }
        map
    }
721
722 fn normalize_token_bytes(bytes: &[u8]) -> Vec<u8> {
723 std::str::from_utf8(bytes).map_or_else(
724 |_| bytes.to_vec(),
725 |text| text.replace('▁', " ").into_bytes(),
726 )
727 }
728
729 fn collect_part_texts(part: &Part) -> Result<Vec<String>, LocalTokenizerError> {
730 let mut texts = Vec::new();
731 match &part.kind {
732 PartKind::Text { text } => {
733 if !text.is_empty() {
734 texts.push(text.clone());
735 }
736 }
737 PartKind::FunctionCall { function_call } => {
738 add_function_call_texts(function_call, &mut texts);
739 }
740 PartKind::FunctionResponse { function_response } => {
741 add_function_response_texts(function_response, &mut texts);
742 }
743 PartKind::ExecutableCode { executable_code } => {
744 if !executable_code.code.is_empty() {
745 texts.push(executable_code.code.clone());
746 }
747 }
748 PartKind::CodeExecutionResult {
749 code_execution_result,
750 } => {
751 if let Some(output) = &code_execution_result.output {
752 if !output.is_empty() {
753 texts.push(output.clone());
754 }
755 }
756 }
757 PartKind::InlineData { .. } => {
758 return Err(LocalTokenizerError::UnsupportedContent {
759 kind: "inline_data",
760 });
761 }
762 PartKind::FileData { .. } => {
763 return Err(LocalTokenizerError::UnsupportedContent { kind: "file_data" });
764 }
765 }
766 Ok(texts)
767 }
768
769 fn add_function_call_texts(function_call: &FunctionCall, texts: &mut Vec<String>) {
770 if let Some(name) = &function_call.name {
771 if !name.is_empty() {
772 texts.push(name.clone());
773 }
774 }
775 if let Some(args) = &function_call.args {
776 add_json_texts(args, texts);
777 }
778 }
779
780 fn add_function_response_texts(function_response: &FunctionResponse, texts: &mut Vec<String>) {
781 if let Some(name) = &function_response.name {
782 if !name.is_empty() {
783 texts.push(name.clone());
784 }
785 }
786 if let Some(response) = &function_response.response {
787 add_json_texts(response, texts);
788 }
789 }
790
791 fn add_json_texts(value: &serde_json::Value, texts: &mut Vec<String>) {
792 match value {
793 serde_json::Value::String(value) => {
794 if !value.is_empty() {
795 texts.push(value.clone());
796 }
797 }
798 serde_json::Value::Array(values) => {
799 for item in values {
800 add_json_texts(item, texts);
801 }
802 }
803 serde_json::Value::Object(map) => {
804 for (key, value) in map {
805 if !key.is_empty() {
806 texts.push(key.clone());
807 }
808 add_json_texts(value, texts);
809 }
810 }
811 _ => {}
812 }
813 }
814
815 fn get_tokenizer_name(model_name: &str) -> Result<&'static str, LocalTokenizerError> {
816 for (name, tokenizer) in GEMINI_MODELS_TO_TOKENIZER_NAMES {
817 if *name == model_name {
818 return Ok(*tokenizer);
819 }
820 }
821 for (name, tokenizer) in GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES {
822 if *name == model_name {
823 return Ok(*tokenizer);
824 }
825 }
826 let mut supported: Vec<String> = GEMINI_MODELS_TO_TOKENIZER_NAMES
827 .iter()
828 .map(|(name, _)| (*name).to_string())
829 .collect();
830 supported.extend(
831 GEMINI_STABLE_MODELS_TO_TOKENIZER_NAMES
832 .iter()
833 .map(|(name, _)| (*name).to_string()),
834 );
835 supported.sort();
836 supported.dedup();
837 Err(LocalTokenizerError::UnsupportedModel {
838 model: model_name.to_string(),
839 supported: supported.join(", "),
840 })
841 }
842
843 async fn load_model_bytes(
844 url: &str,
845 expected_hash: &str,
846 ) -> Result<Vec<u8>, LocalTokenizerError> {
847 let cache_path = cache_path_for(url);
848 if let Some(bytes) = read_cache(&cache_path, expected_hash)? {
849 return Ok(bytes);
850 }
851 let bytes = download_model(url).await?;
852 let actual_hash = sha256_hex(&bytes);
853 if actual_hash != expected_hash {
854 return Err(LocalTokenizerError::HashMismatch {
855 expected: expected_hash.to_string(),
856 actual: actual_hash,
857 });
858 }
859 let _ = write_cache(&cache_path, &bytes);
860 Ok(bytes)
861 }
862
863 fn cache_path_for(url: &str) -> PathBuf {
864 let filename = sha256_hex(url.as_bytes());
865 std::env::temp_dir().join(CACHE_DIR).join(filename)
866 }
867
868 fn read_cache(
869 path: &Path,
870 expected_hash: &str,
871 ) -> Result<Option<Vec<u8>>, LocalTokenizerError> {
872 if !path.exists() {
873 return Ok(None);
874 }
875 let bytes = fs::read(path)?;
876 if sha256_hex(&bytes) == expected_hash {
877 return Ok(Some(bytes));
878 }
879 let _ = fs::remove_file(path);
880 Ok(None)
881 }
882
883 fn write_cache(path: &Path, bytes: &[u8]) -> Result<(), LocalTokenizerError> {
884 if let Some(parent) = path.parent() {
885 fs::create_dir_all(parent)?;
886 }
887 let tmp_path = path.with_extension("tmp");
888 fs::write(&tmp_path, bytes)?;
889 fs::rename(tmp_path, path)?;
890 Ok(())
891 }
892
893 async fn download_model(url: &str) -> Result<Vec<u8>, LocalTokenizerError> {
894 let response = reqwest::get(url)
895 .await
896 .map_err(|source| LocalTokenizerError::Download {
897 url: url.to_string(),
898 source,
899 })?;
900 let status = response.status();
901 if !status.is_success() {
902 return Err(LocalTokenizerError::DownloadStatus {
903 url: url.to_string(),
904 status: status.as_u16(),
905 });
906 }
907 let bytes = response
908 .bytes()
909 .await
910 .map_err(|source| LocalTokenizerError::Download {
911 url: url.to_string(),
912 source,
913 })?;
914 Ok(bytes.to_vec())
915 }
916
917 fn sha256_hex(data: &[u8]) -> String {
918 let digest = Sha256::digest(data);
919 let mut output = String::with_capacity(digest.len() * 2);
920 for byte in digest {
921 let _ = write!(output, "{byte:02x}");
922 }
923 output
924 }
925
926 #[cfg(test)]
927 mod tests {
928 use super::*;
929 use rust_genai_types::content::{Content, FunctionCall, FunctionResponse, Part, Role};
930 use rust_genai_types::enums::{Language, Outcome};
931 use serde_json::json;
932 use std::fs;
933 use std::time::{SystemTime, UNIX_EPOCH};
934
        /// Builds a tiny in-memory WordPiece encoder whose vocabulary covers
        /// exactly the strings the tests feed it; id 8 is the SentencePiece
        /// word-boundary marker (U+2581) used to exercise byte normalization.
        fn build_test_encoder() -> Kitoken {
            let vocab = vec![
                kitoken::Token {
                    id: 0,
                    bytes: b"hi".to_vec(),
                },
                kitoken::Token {
                    id: 1,
                    bytes: b"lookup".to_vec(),
                },
                kitoken::Token {
                    id: 2,
                    bytes: b"q".to_vec(),
                },
                kitoken::Token {
                    id: 3,
                    bytes: b"rust".to_vec(),
                },
                kitoken::Token {
                    id: 4,
                    bytes: b"resp".to_vec(),
                },
                kitoken::Token {
                    id: 5,
                    bytes: b"ok".to_vec(),
                },
                kitoken::Token {
                    id: 6,
                    bytes: b"code".to_vec(),
                },
                kitoken::Token {
                    id: 7,
                    bytes: b"out".to_vec(),
                },
                kitoken::Token {
                    id: 8,
                    bytes: "\u{2581}".as_bytes().to_vec(),
                },
            ];
            let specials = vec![kitoken::SpecialToken {
                id: 99,
                bytes: b"[UNK]".to_vec(),
                kind: kitoken::SpecialTokenKind::Unknown,
                ident: None,
                score: 0.0,
                extract: false,
            }];
            let model = kitoken::Model::WordPiece {
                vocab,
                max_word_chars: 0,
            };
            let config = kitoken::Configuration::default();
            let meta = kitoken::Metadata::default();
            Kitoken::new(model, specials, config, meta).unwrap()
        }
990
        /// Produces a per-invocation cache key (tag + nanosecond timestamp) so
        /// concurrent test runs never collide on the same cache file.
        fn unique_cache_key(tag: &str) -> String {
            let nanos = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_nanos();
            format!("test://{tag}-{nanos}")
        }
998
        // Known models resolve to a family; unknown models report the
        // supported list in the error.
        #[test]
        fn get_tokenizer_name_known_and_unknown() {
            assert_eq!(get_tokenizer_name("gemini-1.5-pro").unwrap(), "gemma2");
            let err = get_tokenizer_name("unknown-model").unwrap_err();
            match err {
                LocalTokenizerError::UnsupportedModel { supported, .. } => {
                    assert!(supported.contains("gemini-1.0-pro"));
                }
                _ => panic!("expected UnsupportedModel error"),
            }
        }
1010
        // U+2581 becomes a space; invalid UTF-8 passes through untouched.
        #[test]
        fn normalize_token_bytes_replaces_separator_and_handles_invalid_utf8() {
            let replaced = normalize_token_bytes("\u{2581}hi".as_bytes());
            assert_eq!(replaced, b" hi".to_vec());

            let invalid = normalize_token_bytes(&[0xff, 0xfe]);
            assert_eq!(invalid, vec![0xff, 0xfe]);
        }
1019
        // A matching hash returns the cached bytes; a mismatch evicts the
        // cache file and reports a miss.
        #[test]
        fn cache_roundtrip_and_mismatch_evicts() {
            let key = unique_cache_key("cache-roundtrip");
            let path = cache_path_for(&key);
            let _ = fs::remove_file(&path);

            let bytes = b"cached".to_vec();
            write_cache(&path, &bytes).unwrap();
            let hash = sha256_hex(&bytes);
            let cached = read_cache(&path, &hash).unwrap().unwrap();
            assert_eq!(cached, bytes);

            let wrong_hash = sha256_hex(b"other");
            let result = read_cache(&path, &wrong_hash).unwrap();
            assert!(result.is_none());
            assert!(!path.exists());
        }
1037
        // A pre-seeded cache entry is served without any network access (the
        // "test://" URL is not fetchable, so a download attempt would fail).
        #[tokio::test]
        async fn load_model_bytes_uses_cache() {
            let key = unique_cache_key("load-cache");
            let path = cache_path_for(&key);
            let _ = fs::remove_file(&path);

            let bytes = b"model-bytes".to_vec();
            write_cache(&path, &bytes).unwrap();
            let hash = sha256_hex(&bytes);

            let loaded = load_model_bytes(&key, &hash).await.unwrap();
            assert_eq!(loaded, bytes);
        }
1051
        // Inline data and file references cannot be tokenized locally and
        // must produce UnsupportedContent with the matching kind label.
        #[test]
        fn collect_part_texts_rejects_binary_parts() {
            let inline = Part::inline_data(vec![1, 2, 3], "image/png");
            let err = collect_part_texts(&inline).unwrap_err();
            assert!(matches!(
                err,
                LocalTokenizerError::UnsupportedContent {
                    kind: "inline_data"
                }
            ));

            let file = Part::file_data("files/1", "application/pdf");
            let err = collect_part_texts(&file).unwrap_err();
            assert!(matches!(
                err,
                LocalTokenizerError::UnsupportedContent { kind: "file_data" }
            ));
        }
1070
        // End-to-end over the tiny encoder: compute_tokens yields entries for
        // text/function/code parts, estimate_tokens counts text, and the
        // vocabulary map normalizes the U+2581 marker (id 8) to a space.
        #[test]
        fn kitoken_estimator_compute_tokens_and_map_normalization() {
            let encoder = build_test_encoder();
            let estimator = KitokenEstimator::from_encoder(encoder);

            let call = FunctionCall {
                id: None,
                name: Some("lookup".into()),
                args: Some(json!({"q": "rust"})),
                partial_args: None,
                will_continue: None,
            };
            let response = FunctionResponse {
                will_continue: None,
                scheduling: None,
                parts: None,
                id: None,
                name: Some("resp".into()),
                response: Some(json!({"ok": "ok"})),
            };
            let content = Content::from_parts(
                vec![
                    Part::text("hi"),
                    Part::function_call(call),
                    Part::function_response(response),
                    Part::executable_code("code", Language::Python),
                    Part::code_execution_result(Outcome::OutcomeOk, "out"),
                ],
                Role::User,
            );

            let result = estimator.compute_tokens(&[content]).unwrap();
            assert!(!result.tokens_info.as_ref().unwrap().is_empty());

            let estimated = estimator.estimate_tokens(&[Content::text("hi")]);
            assert!(estimated > 0);

            let normalized = estimator.token_bytes.get(&8).unwrap();
            assert_eq!(normalized.as_slice(), b" ");
        }
1111 }
1112}