1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::path::Path;
4use trustformers_core::errors::{Result, TrustformersError};
5use trustformers_core::traits::{TokenizedInput, Tokenizer};
6
/// Descriptive metadata attached to a serialized tokenizer model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtobufTokenizerMetadata {
    /// Human-readable tokenizer name.
    pub name: String,
    /// Version string for the tokenizer/model (e.g. "1.0").
    pub version: String,
    /// Tokenizer family identifier (e.g. "bpe"); free-form string.
    pub tokenizer_type: String,
    /// Declared vocabulary size; `validate_model` warns when this disagrees
    /// with the actual vocabulary length.
    pub vocab_size: u32,
    /// Map from special-token text to its numeric id.
    pub special_tokens: HashMap<String, u32>,
    /// Optional maximum sequence length.
    pub max_length: Option<u32>,
    /// Side on which truncation is applied ("left"/"right" — assumed; TODO confirm).
    pub truncation_side: String,
    /// Side on which padding is applied ("left"/"right" — assumed; TODO confirm).
    pub padding_side: String,
    /// Whether input text is lower-cased during normalization.
    pub do_lower_case: bool,
    /// Whether accents are stripped; `None` means "use the tokenizer default".
    pub strip_accents: Option<bool>,
    /// Whether a leading space is prepended before tokenization
    /// (GPT-2-style — assumed from the field name; TODO confirm).
    pub add_prefix_space: bool,
    /// Whether offsets are trimmed of surrounding whitespace (assumed; TODO confirm).
    pub trim_offsets: bool,
    /// Creation timestamp; callers in this file fill it with an RFC 3339 string.
    pub created_at: String,
    /// Optional upstream model identifier.
    pub model_id: Option<String>,
    /// Arbitrary extra attributes as raw bytes, keyed by attribute name.
    pub custom_attributes: HashMap<String, Vec<u8>>,
}
26
/// A single vocabulary entry: token text plus bookkeeping fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtobufVocabEntry {
    /// Token text.
    pub token: String,
    /// Numeric token id.
    pub id: u32,
    /// Token frequency; `serialize_tokenizer` emits placeholder values
    /// (0.0 regular / 1.0 special) because the Tokenizer trait exposes no counts.
    pub frequency: f64,
    /// True when the token is a special token (`<pad>`, `[CLS]`, ...).
    pub is_special: bool,
    /// Numeric type tag; the serializer in this file uses 0 for regular and
    /// 1 for special tokens.
    pub token_type: u32,
}
36
/// A text-normalization step applied before tokenization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtobufNormalizationRule {
    /// Numeric code identifying the kind of rule; semantics are not defined
    /// in this file — TODO confirm against the consumer.
    pub rule_type: u32,
    /// Optional match pattern for pattern-based rules.
    pub pattern: Option<String>,
    /// Optional replacement text for pattern-based rules.
    pub replacement: Option<String>,
    /// Whether the rule is active.
    pub enabled: bool,
    /// Application order (lower runs first — assumed; TODO confirm).
    pub priority: u32,
}
46
/// A BPE-style merge: `first_token` + `second_token` -> `merged_token`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtobufMergeRule {
    /// Left-hand token of the pair.
    pub first_token: String,
    /// Right-hand token of the pair.
    pub second_token: String,
    /// Resulting merged token.
    pub merged_token: String,
    /// Merge priority/rank (lower merges earlier — assumed; TODO confirm).
    pub priority: u32,
}
55
/// Complete serializable tokenizer: metadata plus vocabulary and rules.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtobufTokenizerModel {
    /// Descriptive metadata.
    pub metadata: ProtobufTokenizerMetadata,
    /// Full vocabulary; sorted by id when produced by `serialize_tokenizer`.
    pub vocabulary: Vec<ProtobufVocabEntry>,
    /// Normalization pipeline.
    pub normalization_rules: Vec<ProtobufNormalizationRule>,
    /// BPE merge table (may be empty for non-BPE tokenizers).
    pub merge_rules: Vec<ProtobufMergeRule>,
    /// Special/user-added tokens, duplicated out of `vocabulary`.
    pub added_tokens: Vec<ProtobufVocabEntry>,
}
65
/// Protobuf form of a single tokenized sequence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtobufTokenizedInput {
    /// Token ids.
    pub input_ids: Vec<u32>,
    /// Attention mask, widened from `u8` on `TokenizedInput`
    /// (1 = attend, 0 = padding — assumed; TODO confirm).
    pub attention_mask: Vec<u32>,
    /// Segment ids; empty when the source input had `None`.
    pub token_type_ids: Vec<u32>,
    /// Mask flagging special tokens; the serializer in this file currently
    /// always emits this empty.
    pub special_tokens_mask: Vec<u32>,
    /// Character offsets per token; currently always emitted empty.
    pub offset_mapping: Vec<ProtobufOffset>,
    /// Overflow segments produced by truncation; currently always emitted empty.
    pub overflowing_tokens: Vec<ProtobufTokenizedInput>,
    /// Count of tokens dropped by truncation; currently always 0.
    pub num_truncated_tokens: u32,
}
77
/// Character span into the original text (half-open `[start, end)` is
/// assumed — TODO confirm with the producer of offset mappings).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtobufOffset {
    /// Span start.
    pub start: u32,
    /// Span end.
    pub end: u32,
}
84
/// A batch of tokenized sequences plus padding bookkeeping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtobufBatchTokenizedInput {
    /// The individual sequences.
    pub batch: Vec<ProtobufTokenizedInput>,
    /// Number of sequences in `batch`.
    pub batch_size: u32,
    /// Length sequences were padded/truncated to.
    pub max_length: u32,
    /// Numeric padding-strategy code; semantics are not defined in this
    /// file — TODO confirm against the consumer.
    pub padding_strategy: u32,
}
93
/// Stateless namespace for tokenizer (de)serialization helpers.
pub struct ProtobufSerializer;
96
97impl ProtobufSerializer {
98 pub fn serialize_tokenizer<T: Tokenizer>(
100 tokenizer: &T,
101 metadata: ProtobufTokenizerMetadata,
102 ) -> Result<ProtobufTokenizerModel> {
103 let vocab_map = tokenizer.get_vocab();
105 let mut vocabulary = Vec::new();
106
107 for (token, id) in vocab_map.iter() {
109 vocabulary.push(ProtobufVocabEntry {
110 token: token.clone(),
111 id: *id,
112 frequency: 0.0, is_special: Self::is_special_token(token),
114 token_type: 0, });
116 }
117
118 vocabulary.sort_by_key(|token| token.id);
120
121 let normalization_rules = vec![
123 ProtobufNormalizationRule {
124 rule_type: 1, enabled: true,
126 pattern: None,
127 replacement: None,
128 priority: 1,
129 },
130 ProtobufNormalizationRule {
131 rule_type: 2, enabled: false, pattern: None,
134 replacement: None,
135 priority: 2,
136 },
137 ];
138
139 let merge_rules = vec![];
142
143 let mut added_tokens = Vec::new();
145 for (token, id) in vocab_map.iter() {
146 if Self::is_special_token(token) {
147 added_tokens.push(ProtobufVocabEntry {
148 token: token.clone(),
149 id: *id,
150 frequency: 1.0, is_special: true,
152 token_type: 1, });
154 }
155 }
156
157 Ok(ProtobufTokenizerModel {
158 metadata,
159 vocabulary,
160 normalization_rules,
161 merge_rules,
162 added_tokens,
163 })
164 }
165
166 fn is_special_token(token: &str) -> bool {
168 token.starts_with('<') && token.ends_with('>')
170 || token.starts_with('[') && token.ends_with(']')
171 || matches!(
172 token,
173 "<pad>"
174 | "<unk>"
175 | "<s>"
176 | "</s>"
177 | "<cls>"
178 | "<sep>"
179 | "<mask>"
180 | "[PAD]"
181 | "[UNK]"
182 | "[CLS]"
183 | "[SEP]"
184 | "[MASK]"
185 | "[BOS]"
186 | "[EOS]"
187 )
188 }
189
190 pub fn serialize_tokenized_input(input: &TokenizedInput) -> ProtobufTokenizedInput {
192 ProtobufTokenizedInput {
193 input_ids: input.input_ids.clone(),
194 attention_mask: input.attention_mask.iter().map(|&x| x as u32).collect(),
195 token_type_ids: input.token_type_ids.clone().unwrap_or_default(),
196 special_tokens_mask: vec![], offset_mapping: vec![], overflowing_tokens: vec![],
199 num_truncated_tokens: 0,
200 }
201 }
202
203 pub fn deserialize_tokenized_input(protobuf_input: &ProtobufTokenizedInput) -> TokenizedInput {
205 TokenizedInput {
206 input_ids: protobuf_input.input_ids.clone(),
207 attention_mask: protobuf_input.attention_mask.iter().map(|&x| x as u8).collect(),
208 token_type_ids: if protobuf_input.token_type_ids.is_empty() {
209 None
210 } else {
211 Some(protobuf_input.token_type_ids.clone())
212 },
213 special_tokens_mask: None,
214 offset_mapping: None,
215 overflowing_tokens: None,
216 }
217 }
218
219 pub fn to_protobuf_bytes(model: &ProtobufTokenizerModel) -> Result<Vec<u8>> {
221 oxicode::serde::encode_to_vec(model, oxicode::config::standard()).map_err(|e| {
224 TrustformersError::other(
225 anyhow::anyhow!("Failed to serialize protobuf: {}", e).to_string(),
226 )
227 })
228 }
229
230 pub fn from_protobuf_bytes(bytes: &[u8]) -> Result<ProtobufTokenizerModel> {
232 let (result, _): (ProtobufTokenizerModel, usize) =
233 oxicode::serde::decode_from_slice(bytes, oxicode::config::standard()).map_err(|e| {
234 TrustformersError::other(
235 anyhow::anyhow!("Failed to deserialize protobuf: {}", e).to_string(),
236 )
237 })?;
238 Ok(result)
239 }
240
241 pub fn save_to_file<P: AsRef<Path>>(model: &ProtobufTokenizerModel, path: P) -> Result<()> {
243 let bytes = Self::to_protobuf_bytes(model)?;
244 std::fs::write(path, bytes).map_err(|e| {
245 TrustformersError::other(
246 anyhow::anyhow!("Failed to write protobuf file: {}", e).to_string(),
247 )
248 })
249 }
250
251 pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<ProtobufTokenizerModel> {
253 let bytes = std::fs::read(path).map_err(|e| {
254 TrustformersError::other(
255 anyhow::anyhow!("Failed to read protobuf file: {}", e).to_string(),
256 )
257 })?;
258 Self::from_protobuf_bytes(&bytes)
259 }
260
261 pub fn to_proto_text(model: &ProtobufTokenizerModel) -> Result<String> {
263 let mut text = String::new();
265
266 text.push_str("# Tokenizer Model (Proto Text Format)\n");
267 text.push_str("metadata {\n");
268 text.push_str(&format!(" name: \"{}\"\n", model.metadata.name));
269 text.push_str(&format!(" version: \"{}\"\n", model.metadata.version));
270 text.push_str(&format!(
271 " tokenizer_type: \"{}\"\n",
272 model.metadata.tokenizer_type
273 ));
274 text.push_str(&format!(" vocab_size: {}\n", model.metadata.vocab_size));
275 text.push_str(&format!(
276 " do_lower_case: {}\n",
277 model.metadata.do_lower_case
278 ));
279 text.push_str("}\n\n");
280
281 if !model.vocabulary.is_empty() {
282 text.push_str("vocabulary {\n");
283 for (i, entry) in model.vocabulary.iter().enumerate() {
284 if i >= 10 {
285 text.push_str(&format!(
287 " # ... {} more entries\n",
288 model.vocabulary.len() - 10
289 ));
290 break;
291 }
292 text.push_str(" entry {\n");
293 text.push_str(&format!(" token: \"{}\"\n", entry.token));
294 text.push_str(&format!(" id: {}\n", entry.id));
295 text.push_str(&format!(" frequency: {}\n", entry.frequency));
296 text.push_str(&format!(" is_special: {}\n", entry.is_special));
297 text.push_str(" }\n");
298 }
299 text.push_str("}\n\n");
300 }
301
302 if !model.merge_rules.is_empty() {
303 text.push_str("merge_rules {\n");
304 for (i, rule) in model.merge_rules.iter().enumerate() {
305 if i >= 5 {
306 text.push_str(&format!(
308 " # ... {} more rules\n",
309 model.merge_rules.len() - 5
310 ));
311 break;
312 }
313 text.push_str(" rule {\n");
314 text.push_str(&format!(" first_token: \"{}\"\n", rule.first_token));
315 text.push_str(&format!(" second_token: \"{}\"\n", rule.second_token));
316 text.push_str(&format!(" merged_token: \"{}\"\n", rule.merged_token));
317 text.push_str(&format!(" priority: {}\n", rule.priority));
318 text.push_str(" }\n");
319 }
320 text.push_str("}\n");
321 }
322
323 Ok(text)
324 }
325
326 pub fn from_proto_text(text: &str) -> Result<ProtobufTokenizerModel> {
328 let mut metadata = ProtobufTokenizerMetadata {
332 name: "unknown".to_string(),
333 version: "1.0".to_string(),
334 tokenizer_type: "unknown".to_string(),
335 vocab_size: 0,
336 special_tokens: HashMap::new(),
337 max_length: None,
338 truncation_side: "right".to_string(),
339 padding_side: "right".to_string(),
340 do_lower_case: false,
341 strip_accents: None,
342 add_prefix_space: false,
343 trim_offsets: true,
344 created_at: chrono::Utc::now().to_rfc3339(),
345 model_id: None,
346 custom_attributes: HashMap::new(),
347 };
348
349 for line in text.lines() {
351 let line = line.trim();
352 if line.starts_with("name:") {
353 if let Some(name) = Self::extract_quoted_value(line) {
354 metadata.name = name;
355 }
356 } else if line.starts_with("version:") {
357 if let Some(version) = Self::extract_quoted_value(line) {
358 metadata.version = version;
359 }
360 } else if line.starts_with("tokenizer_type:") {
361 if let Some(tokenizer_type) = Self::extract_quoted_value(line) {
362 metadata.tokenizer_type = tokenizer_type;
363 }
364 } else if line.starts_with("vocab_size:") {
365 if let Some(size_str) = line.split(':').nth(1) {
366 if let Ok(size) = size_str.trim().parse::<u32>() {
367 metadata.vocab_size = size;
368 }
369 }
370 } else if line.starts_with("do_lower_case:") {
371 if let Some(bool_str) = line.split(':').nth(1) {
372 metadata.do_lower_case = bool_str.trim() == "true";
373 }
374 }
375 }
376
377 Ok(ProtobufTokenizerModel {
378 metadata,
379 vocabulary: vec![],
380 normalization_rules: vec![],
381 merge_rules: vec![],
382 added_tokens: vec![],
383 })
384 }
385
386 fn extract_quoted_value(line: &str) -> Option<String> {
388 if let Some(start) = line.find('"') {
389 if let Some(end) = line.rfind('"') {
390 if start < end {
391 return Some(line[start + 1..end].to_string());
392 }
393 }
394 }
395 None
396 }
397
398 pub fn validate_model(model: &ProtobufTokenizerModel) -> Result<Vec<String>> {
400 let mut warnings = Vec::new();
401
402 if model.vocabulary.len() != model.metadata.vocab_size as usize {
404 warnings.push(format!(
405 "Vocabulary size mismatch: metadata claims {} but found {} tokens",
406 model.metadata.vocab_size,
407 model.vocabulary.len()
408 ));
409 }
410
411 let mut seen_ids = std::collections::HashSet::new();
413 for entry in &model.vocabulary {
414 if !seen_ids.insert(entry.id) {
415 warnings.push(format!("Duplicate token ID: {}", entry.id));
416 }
417 }
418
419 for rule in &model.merge_rules {
421 if rule.first_token.is_empty() || rule.second_token.is_empty() {
422 warnings.push("Empty tokens in merge rule".to_string());
423 }
424 }
425
426 Ok(warnings)
427 }
428
429 pub fn get_model_stats(model: &ProtobufTokenizerModel) -> HashMap<String, serde_json::Value> {
431 let mut stats = HashMap::new();
432
433 stats.insert(
434 "vocab_size".to_string(),
435 serde_json::Value::Number(model.vocabulary.len().into()),
436 );
437
438 stats.insert(
439 "special_tokens_count".to_string(),
440 serde_json::Value::Number(model.metadata.special_tokens.len().into()),
441 );
442
443 stats.insert(
444 "merge_rules_count".to_string(),
445 serde_json::Value::Number(model.merge_rules.len().into()),
446 );
447
448 stats.insert(
449 "normalization_rules_count".to_string(),
450 serde_json::Value::Number(model.normalization_rules.len().into()),
451 );
452
453 let special_token_ratio = if model.metadata.vocab_size > 0 {
454 model.metadata.special_tokens.len() as f64 / model.metadata.vocab_size as f64
455 } else {
456 0.0
457 };
458 if let Some(ratio_number) = serde_json::Number::from_f64(special_token_ratio) {
459 stats.insert(
460 "special_token_ratio".to_string(),
461 serde_json::Value::Number(ratio_number),
462 );
463 }
464
465 stats
466 }
467
468 pub fn compress_model(model: &ProtobufTokenizerModel) -> Result<Vec<u8>> {
470 let serialized = Self::to_protobuf_bytes(model)?;
471
472 use oxiarc_deflate::streaming::GzipStreamEncoder;
473 use std::io::Write;
474
475 let mut encoder = GzipStreamEncoder::new(Vec::new(), 6);
476 encoder.write_all(&serialized).map_err(|e| {
477 TrustformersError::other(anyhow::anyhow!("Failed to compress: {}", e).to_string())
478 })?;
479
480 encoder.finish().map_err(|e| {
481 TrustformersError::other(
482 anyhow::anyhow!("Failed to finish compression: {}", e).to_string(),
483 )
484 })
485 }
486
487 pub fn decompress_model(compressed_data: &[u8]) -> Result<ProtobufTokenizerModel> {
489 use oxiarc_deflate::streaming::GzipStreamDecoder;
490 use std::io::Read;
491
492 let mut decoder = GzipStreamDecoder::new(compressed_data);
493 let mut decompressed = Vec::new();
494 decoder.read_to_end(&mut decompressed).map_err(|e| {
495 TrustformersError::other(anyhow::anyhow!("Failed to decompress: {}", e).to_string())
496 })?;
497
498 Self::from_protobuf_bytes(&decompressed)
499 }
500}
501
/// Implemented by tokenizers that can round-trip through [`ProtobufTokenizerModel`].
pub trait ProtobufConvertible {
    /// Export `self` into a protobuf model using the supplied metadata.
    fn to_protobuf_model(
        &self,
        metadata: ProtobufTokenizerMetadata,
    ) -> Result<ProtobufTokenizerModel>;

    /// Reconstruct a tokenizer from a previously exported protobuf model.
    fn from_protobuf_model(model: &ProtobufTokenizerModel) -> Result<Self>
    where
        Self: Sized;
}
515
/// Options controlling how [`ProtobufExporter`] writes and reads models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtobufExportConfig {
    /// Include the vocabulary section in the export.
    pub include_vocabulary: bool,
    /// Include the BPE merge table in the export.
    pub include_merge_rules: bool,
    /// Include the normalization pipeline in the export.
    pub include_normalization_rules: bool,
    /// Gzip-compress the output (honored for the Binary format).
    pub compress_output: bool,
    /// Run `validate_model` before exporting and print warnings to stderr.
    pub validate_output: bool,
    /// On-disk representation to use for both export and import.
    pub export_format: ProtobufFormat,
}
526
/// On-disk representation selected by [`ProtobufExportConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ProtobufFormat {
    /// Raw oxicode-encoded bytes.
    Binary,
    /// Human-readable proto-text summary.
    TextFormat,
    /// Pretty-printed JSON via serde_json.
    Json,
    /// Gzip-compressed binary.
    CompressedBinary,
}
535
536impl Default for ProtobufExportConfig {
537 fn default() -> Self {
538 Self {
539 include_vocabulary: true,
540 include_merge_rules: true,
541 include_normalization_rules: true,
542 compress_output: false,
543 validate_output: true,
544 export_format: ProtobufFormat::Binary,
545 }
546 }
547}
548
/// File exporter/importer for tokenizer models, driven by a
/// [`ProtobufExportConfig`].
pub struct ProtobufExporter {
    // Controls format selection, compression and pre-export validation.
    config: ProtobufExportConfig,
}
553
554impl ProtobufExporter {
555 pub fn new(config: ProtobufExportConfig) -> Self {
557 Self { config }
558 }
559
560 pub fn export_model<P: AsRef<Path>>(
562 &self,
563 model: &ProtobufTokenizerModel,
564 path: P,
565 ) -> Result<()> {
566 if self.config.validate_output {
568 let warnings = ProtobufSerializer::validate_model(model)?;
569 if !warnings.is_empty() {
570 eprintln!("Validation warnings:");
571 for warning in warnings {
572 eprintln!(" - {}", warning);
573 }
574 }
575 }
576
577 match self.config.export_format {
578 ProtobufFormat::Binary => {
579 if self.config.compress_output {
580 let compressed = ProtobufSerializer::compress_model(model)?;
581 std::fs::write(path, compressed).map_err(|e| {
582 TrustformersError::other(
583 anyhow::anyhow!("Failed to write file: {}", e).to_string(),
584 )
585 })?;
586 } else {
587 ProtobufSerializer::save_to_file(model, path)?;
588 }
589 },
590 ProtobufFormat::TextFormat => {
591 let text = ProtobufSerializer::to_proto_text(model)?;
592 std::fs::write(path, text).map_err(|e| {
593 TrustformersError::other(
594 anyhow::anyhow!("Failed to write text file: {}", e).to_string(),
595 )
596 })?;
597 },
598 ProtobufFormat::Json => {
599 let json = serde_json::to_string_pretty(model).map_err(|e| {
600 TrustformersError::other(
601 anyhow::anyhow!("Failed to serialize JSON: {}", e).to_string(),
602 )
603 })?;
604 std::fs::write(path, json).map_err(|e| {
605 TrustformersError::other(
606 anyhow::anyhow!("Failed to write JSON file: {}", e).to_string(),
607 )
608 })?;
609 },
610 ProtobufFormat::CompressedBinary => {
611 let compressed = ProtobufSerializer::compress_model(model)?;
612 std::fs::write(path, compressed).map_err(|e| {
613 TrustformersError::other(
614 anyhow::anyhow!("Failed to write compressed file: {}", e).to_string(),
615 )
616 })?;
617 },
618 }
619
620 Ok(())
621 }
622
623 pub fn import_model<P: AsRef<Path>>(&self, path: P) -> Result<ProtobufTokenizerModel> {
625 match self.config.export_format {
626 ProtobufFormat::Binary => ProtobufSerializer::load_from_file(path),
627 ProtobufFormat::TextFormat => {
628 let text = std::fs::read_to_string(path).map_err(|e| {
629 TrustformersError::other(
630 anyhow::anyhow!("Failed to read text file: {}", e).to_string(),
631 )
632 })?;
633 ProtobufSerializer::from_proto_text(&text)
634 },
635 ProtobufFormat::Json => {
636 let json = std::fs::read_to_string(path).map_err(|e| {
637 TrustformersError::other(
638 anyhow::anyhow!("Failed to read JSON file: {}", e).to_string(),
639 )
640 })?;
641 serde_json::from_str(&json).map_err(|e| {
642 TrustformersError::other(
643 anyhow::anyhow!("Failed to parse JSON: {}", e).to_string(),
644 )
645 })
646 },
647 ProtobufFormat::CompressedBinary => {
648 let compressed = std::fs::read(path).map_err(|e| {
649 TrustformersError::other(
650 anyhow::anyhow!("Failed to read compressed file: {}", e).to_string(),
651 )
652 })?;
653 ProtobufSerializer::decompress_model(&compressed)
654 },
655 }
656 }
657}
658
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds metadata with the given distinguishing fields; every other
    /// field uses the neutral defaults the original tests repeated verbatim
    /// five times.
    fn make_metadata(
        name: &str,
        tokenizer_type: &str,
        vocab_size: u32,
        do_lower_case: bool,
    ) -> ProtobufTokenizerMetadata {
        ProtobufTokenizerMetadata {
            name: name.to_string(),
            version: "1.0".to_string(),
            tokenizer_type: tokenizer_type.to_string(),
            vocab_size,
            special_tokens: HashMap::new(),
            max_length: None,
            truncation_side: "right".to_string(),
            padding_side: "right".to_string(),
            do_lower_case,
            strip_accents: None,
            add_prefix_space: false,
            trim_offsets: true,
            created_at: chrono::Utc::now().to_rfc3339(),
            model_id: None,
            custom_attributes: HashMap::new(),
        }
    }

    /// A model with the given metadata and no vocabulary/rules.
    fn empty_model(metadata: ProtobufTokenizerMetadata) -> ProtobufTokenizerModel {
        ProtobufTokenizerModel {
            metadata,
            vocabulary: vec![],
            normalization_rules: vec![],
            merge_rules: vec![],
            added_tokens: vec![],
        }
    }

    #[test]
    fn test_protobuf_metadata_creation() {
        let metadata = ProtobufTokenizerMetadata {
            max_length: Some(512),
            ..make_metadata("test-tokenizer", "bpe", 1000, false)
        };

        assert_eq!(metadata.name, "test-tokenizer");
        assert_eq!(metadata.vocab_size, 1000);
        assert_eq!(metadata.max_length, Some(512));
    }

    #[test]
    fn test_tokenized_input_conversion() {
        let input = TokenizedInput {
            input_ids: vec![1, 2, 3],
            attention_mask: vec![1, 1, 1],
            token_type_ids: Some(vec![0, 0, 0]),
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        // Round-trip through the protobuf form must preserve these fields.
        let protobuf_input = ProtobufSerializer::serialize_tokenized_input(&input);
        let converted_back = ProtobufSerializer::deserialize_tokenized_input(&protobuf_input);

        assert_eq!(input.input_ids, converted_back.input_ids);
        assert_eq!(input.attention_mask, converted_back.attention_mask);
        assert_eq!(input.token_type_ids, converted_back.token_type_ids);
    }

    #[test]
    fn test_protobuf_serialization() {
        let model = empty_model(make_metadata("test", "test", 0, false));

        let bytes =
            ProtobufSerializer::to_protobuf_bytes(&model).expect("Operation failed in test");
        let recovered =
            ProtobufSerializer::from_protobuf_bytes(&bytes).expect("Operation failed in test");

        assert_eq!(model.metadata.name, recovered.metadata.name);
        assert_eq!(model.metadata.version, recovered.metadata.version);
    }

    #[test]
    fn test_proto_text_format() {
        let model = empty_model(make_metadata("test-tokenizer", "bpe", 100, true));

        let text = ProtobufSerializer::to_proto_text(&model).expect("Operation failed in test");
        assert!(text.contains("name: \"test-tokenizer\""));
        assert!(text.contains("version: \"1.0\""));
        assert!(text.contains("vocab_size: 100"));
        assert!(text.contains("do_lower_case: true"));

        // The text format must round-trip the metadata scalars.
        let parsed = ProtobufSerializer::from_proto_text(&text).expect("Operation failed in test");
        assert_eq!(parsed.metadata.name, "test-tokenizer");
        assert_eq!(parsed.metadata.version, "1.0");
        assert_eq!(parsed.metadata.vocab_size, 100);
        assert!(parsed.metadata.do_lower_case);
    }

    #[test]
    fn test_model_validation() {
        // Metadata claims 2 tokens but only 1 is present -> expect a warning.
        let model = ProtobufTokenizerModel {
            vocabulary: vec![ProtobufVocabEntry {
                token: "hello".to_string(),
                id: 0,
                frequency: 0.1,
                is_special: false,
                token_type: 0,
            }],
            ..empty_model(make_metadata("test", "test", 2, false))
        };

        let warnings =
            ProtobufSerializer::validate_model(&model).expect("Operation failed in test");
        assert!(!warnings.is_empty());
        assert!(warnings[0].contains("Vocabulary size mismatch"));
    }

    #[test]
    fn test_compression() {
        let model = empty_model(make_metadata("test", "test", 0, false));

        let compressed =
            ProtobufSerializer::compress_model(&model).expect("Operation failed in test");
        let decompressed =
            ProtobufSerializer::decompress_model(&compressed).expect("Operation failed in test");

        assert_eq!(model.metadata.name, decompressed.metadata.name);
    }
}