Skip to main content

trustformers_tokenizers/
protobuf_serialization.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::path::Path;
4use trustformers_core::errors::{Result, TrustformersError};
5use trustformers_core::traits::{TokenizedInput, Tokenizer};
6
7/// Protobuf-compatible tokenizer metadata
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct ProtobufTokenizerMetadata {
10    pub name: String,
11    pub version: String,
12    pub tokenizer_type: String,
13    pub vocab_size: u32,
14    pub special_tokens: HashMap<String, u32>,
15    pub max_length: Option<u32>,
16    pub truncation_side: String,
17    pub padding_side: String,
18    pub do_lower_case: bool,
19    pub strip_accents: Option<bool>,
20    pub add_prefix_space: bool,
21    pub trim_offsets: bool,
22    pub created_at: String,
23    pub model_id: Option<String>,
24    pub custom_attributes: HashMap<String, Vec<u8>>, // For extension data
25}
26
27/// Protobuf-compatible vocabulary entry
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ProtobufVocabEntry {
30    pub token: String,
31    pub id: u32,
32    pub frequency: f64,
33    pub is_special: bool,
34    pub token_type: u32, // Enumerated token type
35}
36
37/// Protobuf-compatible normalization rule
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ProtobufNormalizationRule {
40    pub rule_type: u32, // Enumerated rule type
41    pub pattern: Option<String>,
42    pub replacement: Option<String>,
43    pub enabled: bool,
44    pub priority: u32,
45}
46
47/// Protobuf-compatible merge rule (for BPE)
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct ProtobufMergeRule {
50    pub first_token: String,
51    pub second_token: String,
52    pub merged_token: String,
53    pub priority: u32,
54}
55
56/// Complete protobuf tokenizer model
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct ProtobufTokenizerModel {
59    pub metadata: ProtobufTokenizerMetadata,
60    pub vocabulary: Vec<ProtobufVocabEntry>,
61    pub normalization_rules: Vec<ProtobufNormalizationRule>,
62    pub merge_rules: Vec<ProtobufMergeRule>,
63    pub added_tokens: Vec<ProtobufVocabEntry>,
64}
65
66/// Protobuf-compatible tokenized input
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct ProtobufTokenizedInput {
69    pub input_ids: Vec<u32>,
70    pub attention_mask: Vec<u32>,
71    pub token_type_ids: Vec<u32>,
72    pub special_tokens_mask: Vec<u32>,
73    pub offset_mapping: Vec<ProtobufOffset>,
74    pub overflowing_tokens: Vec<ProtobufTokenizedInput>,
75    pub num_truncated_tokens: u32,
76}
77
78/// Offset information for protobuf
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct ProtobufOffset {
81    pub start: u32,
82    pub end: u32,
83}
84
85/// Batch tokenized input for protobuf
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct ProtobufBatchTokenizedInput {
88    pub batch: Vec<ProtobufTokenizedInput>,
89    pub batch_size: u32,
90    pub max_length: u32,
91    pub padding_strategy: u32, // Enumerated padding strategy
92}
93
94/// Protobuf serialization utilities
95pub struct ProtobufSerializer;
96
97impl ProtobufSerializer {
98    /// Convert tokenizer to protobuf model
99    pub fn serialize_tokenizer<T: Tokenizer>(
100        tokenizer: &T,
101        metadata: ProtobufTokenizerMetadata,
102    ) -> Result<ProtobufTokenizerModel> {
103        // Extract vocabulary from tokenizer
104        let vocab_map = tokenizer.get_vocab();
105        let mut vocabulary = Vec::new();
106
107        // Convert vocabulary to protobuf format
108        for (token, id) in vocab_map.iter() {
109            vocabulary.push(ProtobufVocabEntry {
110                token: token.clone(),
111                id: *id,
112                frequency: 0.0, // Default frequency, could be extracted if available
113                is_special: Self::is_special_token(token),
114                token_type: 0, // Default token type (Normal)
115            });
116        }
117
118        // Sort vocabulary by ID for consistency
119        vocabulary.sort_by_key(|token| token.id);
120
121        // Extract normalization rules (basic implementation)
122        let normalization_rules = vec![
123            ProtobufNormalizationRule {
124                rule_type: 1, // NFC normalize
125                enabled: true,
126                pattern: None,
127                replacement: None,
128                priority: 1,
129            },
130            ProtobufNormalizationRule {
131                rule_type: 2,   // Lowercase
132                enabled: false, // Default to false, could be configured
133                pattern: None,
134                replacement: None,
135                priority: 2,
136            },
137        ];
138
139        // For merge rules, we'd need tokenizer-specific logic
140        // This is a basic implementation that works for most tokenizers
141        let merge_rules = vec![];
142
143        // Identify special tokens from the vocabulary
144        let mut added_tokens = Vec::new();
145        for (token, id) in vocab_map.iter() {
146            if Self::is_special_token(token) {
147                added_tokens.push(ProtobufVocabEntry {
148                    token: token.clone(),
149                    id: *id,
150                    frequency: 1.0, // Special tokens usually have high frequency
151                    is_special: true,
152                    token_type: 1, // Special token type
153                });
154            }
155        }
156
157        Ok(ProtobufTokenizerModel {
158            metadata,
159            vocabulary,
160            normalization_rules,
161            merge_rules,
162            added_tokens,
163        })
164    }
165
166    /// Helper function to identify special tokens
167    fn is_special_token(token: &str) -> bool {
168        // Common special token patterns
169        token.starts_with('<') && token.ends_with('>')
170            || token.starts_with('[') && token.ends_with(']')
171            || matches!(
172                token,
173                "<pad>"
174                    | "<unk>"
175                    | "<s>"
176                    | "</s>"
177                    | "<cls>"
178                    | "<sep>"
179                    | "<mask>"
180                    | "[PAD]"
181                    | "[UNK]"
182                    | "[CLS]"
183                    | "[SEP]"
184                    | "[MASK]"
185                    | "[BOS]"
186                    | "[EOS]"
187            )
188    }
189
190    /// Convert tokenized input to protobuf format
191    pub fn serialize_tokenized_input(input: &TokenizedInput) -> ProtobufTokenizedInput {
192        ProtobufTokenizedInput {
193            input_ids: input.input_ids.clone(),
194            attention_mask: input.attention_mask.iter().map(|&x| x as u32).collect(),
195            token_type_ids: input.token_type_ids.clone().unwrap_or_default(),
196            special_tokens_mask: vec![], // Would need to be computed
197            offset_mapping: vec![],      // Would need offset information
198            overflowing_tokens: vec![],
199            num_truncated_tokens: 0,
200        }
201    }
202
203    /// Convert protobuf tokenized input back to standard format
204    pub fn deserialize_tokenized_input(protobuf_input: &ProtobufTokenizedInput) -> TokenizedInput {
205        TokenizedInput {
206            input_ids: protobuf_input.input_ids.clone(),
207            attention_mask: protobuf_input.attention_mask.iter().map(|&x| x as u8).collect(),
208            token_type_ids: if protobuf_input.token_type_ids.is_empty() {
209                None
210            } else {
211                Some(protobuf_input.token_type_ids.clone())
212            },
213            special_tokens_mask: None,
214            offset_mapping: None,
215            overflowing_tokens: None,
216        }
217    }
218
219    /// Serialize to protobuf binary format
220    pub fn to_protobuf_bytes(model: &ProtobufTokenizerModel) -> Result<Vec<u8>> {
221        // Using serde with oxicode as a simplified protobuf-like format
222        // In a real implementation, you'd use actual protobuf libraries like prost
223        oxicode::serde::encode_to_vec(model, oxicode::config::standard()).map_err(|e| {
224            TrustformersError::other(
225                anyhow::anyhow!("Failed to serialize protobuf: {}", e).to_string(),
226            )
227        })
228    }
229
230    /// Deserialize from protobuf binary format
231    pub fn from_protobuf_bytes(bytes: &[u8]) -> Result<ProtobufTokenizerModel> {
232        let (result, _): (ProtobufTokenizerModel, usize) =
233            oxicode::serde::decode_from_slice(bytes, oxicode::config::standard()).map_err(|e| {
234                TrustformersError::other(
235                    anyhow::anyhow!("Failed to deserialize protobuf: {}", e).to_string(),
236                )
237            })?;
238        Ok(result)
239    }
240
241    /// Save tokenizer model to protobuf file
242    pub fn save_to_file<P: AsRef<Path>>(model: &ProtobufTokenizerModel, path: P) -> Result<()> {
243        let bytes = Self::to_protobuf_bytes(model)?;
244        std::fs::write(path, bytes).map_err(|e| {
245            TrustformersError::other(
246                anyhow::anyhow!("Failed to write protobuf file: {}", e).to_string(),
247            )
248        })
249    }
250
251    /// Load tokenizer model from protobuf file
252    pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<ProtobufTokenizerModel> {
253        let bytes = std::fs::read(path).map_err(|e| {
254            TrustformersError::other(
255                anyhow::anyhow!("Failed to read protobuf file: {}", e).to_string(),
256            )
257        })?;
258        Self::from_protobuf_bytes(&bytes)
259    }
260
261    /// Convert to text-based protobuf format (proto text)
262    pub fn to_proto_text(model: &ProtobufTokenizerModel) -> Result<String> {
263        // Simplified proto text format
264        let mut text = String::new();
265
266        text.push_str("# Tokenizer Model (Proto Text Format)\n");
267        text.push_str("metadata {\n");
268        text.push_str(&format!("  name: \"{}\"\n", model.metadata.name));
269        text.push_str(&format!("  version: \"{}\"\n", model.metadata.version));
270        text.push_str(&format!(
271            "  tokenizer_type: \"{}\"\n",
272            model.metadata.tokenizer_type
273        ));
274        text.push_str(&format!("  vocab_size: {}\n", model.metadata.vocab_size));
275        text.push_str(&format!(
276            "  do_lower_case: {}\n",
277            model.metadata.do_lower_case
278        ));
279        text.push_str("}\n\n");
280
281        if !model.vocabulary.is_empty() {
282            text.push_str("vocabulary {\n");
283            for (i, entry) in model.vocabulary.iter().enumerate() {
284                if i >= 10 {
285                    // Limit output for readability
286                    text.push_str(&format!(
287                        "  # ... {} more entries\n",
288                        model.vocabulary.len() - 10
289                    ));
290                    break;
291                }
292                text.push_str("  entry {\n");
293                text.push_str(&format!("    token: \"{}\"\n", entry.token));
294                text.push_str(&format!("    id: {}\n", entry.id));
295                text.push_str(&format!("    frequency: {}\n", entry.frequency));
296                text.push_str(&format!("    is_special: {}\n", entry.is_special));
297                text.push_str("  }\n");
298            }
299            text.push_str("}\n\n");
300        }
301
302        if !model.merge_rules.is_empty() {
303            text.push_str("merge_rules {\n");
304            for (i, rule) in model.merge_rules.iter().enumerate() {
305                if i >= 5 {
306                    // Limit output for readability
307                    text.push_str(&format!(
308                        "  # ... {} more rules\n",
309                        model.merge_rules.len() - 5
310                    ));
311                    break;
312                }
313                text.push_str("  rule {\n");
314                text.push_str(&format!("    first_token: \"{}\"\n", rule.first_token));
315                text.push_str(&format!("    second_token: \"{}\"\n", rule.second_token));
316                text.push_str(&format!("    merged_token: \"{}\"\n", rule.merged_token));
317                text.push_str(&format!("    priority: {}\n", rule.priority));
318                text.push_str("  }\n");
319            }
320            text.push_str("}\n");
321        }
322
323        Ok(text)
324    }
325
326    /// Parse from text-based protobuf format
327    pub fn from_proto_text(text: &str) -> Result<ProtobufTokenizerModel> {
328        // Simplified parser for proto text format
329        // In a real implementation, you'd use a proper protobuf text parser
330
331        let mut metadata = ProtobufTokenizerMetadata {
332            name: "unknown".to_string(),
333            version: "1.0".to_string(),
334            tokenizer_type: "unknown".to_string(),
335            vocab_size: 0,
336            special_tokens: HashMap::new(),
337            max_length: None,
338            truncation_side: "right".to_string(),
339            padding_side: "right".to_string(),
340            do_lower_case: false,
341            strip_accents: None,
342            add_prefix_space: false,
343            trim_offsets: true,
344            created_at: chrono::Utc::now().to_rfc3339(),
345            model_id: None,
346            custom_attributes: HashMap::new(),
347        };
348
349        // Simple pattern matching for key fields
350        for line in text.lines() {
351            let line = line.trim();
352            if line.starts_with("name:") {
353                if let Some(name) = Self::extract_quoted_value(line) {
354                    metadata.name = name;
355                }
356            } else if line.starts_with("version:") {
357                if let Some(version) = Self::extract_quoted_value(line) {
358                    metadata.version = version;
359                }
360            } else if line.starts_with("tokenizer_type:") {
361                if let Some(tokenizer_type) = Self::extract_quoted_value(line) {
362                    metadata.tokenizer_type = tokenizer_type;
363                }
364            } else if line.starts_with("vocab_size:") {
365                if let Some(size_str) = line.split(':').nth(1) {
366                    if let Ok(size) = size_str.trim().parse::<u32>() {
367                        metadata.vocab_size = size;
368                    }
369                }
370            } else if line.starts_with("do_lower_case:") {
371                if let Some(bool_str) = line.split(':').nth(1) {
372                    metadata.do_lower_case = bool_str.trim() == "true";
373                }
374            }
375        }
376
377        Ok(ProtobufTokenizerModel {
378            metadata,
379            vocabulary: vec![],
380            normalization_rules: vec![],
381            merge_rules: vec![],
382            added_tokens: vec![],
383        })
384    }
385
386    /// Extract quoted value from proto text line
387    fn extract_quoted_value(line: &str) -> Option<String> {
388        if let Some(start) = line.find('"') {
389            if let Some(end) = line.rfind('"') {
390                if start < end {
391                    return Some(line[start + 1..end].to_string());
392                }
393            }
394        }
395        None
396    }
397
398    /// Validate protobuf model
399    pub fn validate_model(model: &ProtobufTokenizerModel) -> Result<Vec<String>> {
400        let mut warnings = Vec::new();
401
402        // Check vocabulary consistency
403        if model.vocabulary.len() != model.metadata.vocab_size as usize {
404            warnings.push(format!(
405                "Vocabulary size mismatch: metadata claims {} but found {} tokens",
406                model.metadata.vocab_size,
407                model.vocabulary.len()
408            ));
409        }
410
411        // Check for duplicate token IDs
412        let mut seen_ids = std::collections::HashSet::new();
413        for entry in &model.vocabulary {
414            if !seen_ids.insert(entry.id) {
415                warnings.push(format!("Duplicate token ID: {}", entry.id));
416            }
417        }
418
419        // Check merge rules validity
420        for rule in &model.merge_rules {
421            if rule.first_token.is_empty() || rule.second_token.is_empty() {
422                warnings.push("Empty tokens in merge rule".to_string());
423            }
424        }
425
426        Ok(warnings)
427    }
428
429    /// Get model statistics
430    pub fn get_model_stats(model: &ProtobufTokenizerModel) -> HashMap<String, serde_json::Value> {
431        let mut stats = HashMap::new();
432
433        stats.insert(
434            "vocab_size".to_string(),
435            serde_json::Value::Number(model.vocabulary.len().into()),
436        );
437
438        stats.insert(
439            "special_tokens_count".to_string(),
440            serde_json::Value::Number(model.metadata.special_tokens.len().into()),
441        );
442
443        stats.insert(
444            "merge_rules_count".to_string(),
445            serde_json::Value::Number(model.merge_rules.len().into()),
446        );
447
448        stats.insert(
449            "normalization_rules_count".to_string(),
450            serde_json::Value::Number(model.normalization_rules.len().into()),
451        );
452
453        let special_token_ratio = if model.metadata.vocab_size > 0 {
454            model.metadata.special_tokens.len() as f64 / model.metadata.vocab_size as f64
455        } else {
456            0.0
457        };
458        if let Some(ratio_number) = serde_json::Number::from_f64(special_token_ratio) {
459            stats.insert(
460                "special_token_ratio".to_string(),
461                serde_json::Value::Number(ratio_number),
462            );
463        }
464
465        stats
466    }
467
468    /// Compress protobuf data
469    pub fn compress_model(model: &ProtobufTokenizerModel) -> Result<Vec<u8>> {
470        let serialized = Self::to_protobuf_bytes(model)?;
471
472        use oxiarc_deflate::streaming::GzipStreamEncoder;
473        use std::io::Write;
474
475        let mut encoder = GzipStreamEncoder::new(Vec::new(), 6);
476        encoder.write_all(&serialized).map_err(|e| {
477            TrustformersError::other(anyhow::anyhow!("Failed to compress: {}", e).to_string())
478        })?;
479
480        encoder.finish().map_err(|e| {
481            TrustformersError::other(
482                anyhow::anyhow!("Failed to finish compression: {}", e).to_string(),
483            )
484        })
485    }
486
487    /// Decompress protobuf data
488    pub fn decompress_model(compressed_data: &[u8]) -> Result<ProtobufTokenizerModel> {
489        use oxiarc_deflate::streaming::GzipStreamDecoder;
490        use std::io::Read;
491
492        let mut decoder = GzipStreamDecoder::new(compressed_data);
493        let mut decompressed = Vec::new();
494        decoder.read_to_end(&mut decompressed).map_err(|e| {
495            TrustformersError::other(anyhow::anyhow!("Failed to decompress: {}", e).to_string())
496        })?;
497
498        Self::from_protobuf_bytes(&decompressed)
499    }
500}
501
502/// Helper trait for protobuf conversion
503pub trait ProtobufConvertible {
504    /// Convert to protobuf model
505    fn to_protobuf_model(
506        &self,
507        metadata: ProtobufTokenizerMetadata,
508    ) -> Result<ProtobufTokenizerModel>;
509
510    /// Create from protobuf model
511    fn from_protobuf_model(model: &ProtobufTokenizerModel) -> Result<Self>
512    where
513        Self: Sized;
514}
515
516/// Configuration for protobuf export
517#[derive(Debug, Clone, Serialize, Deserialize)]
518pub struct ProtobufExportConfig {
519    pub include_vocabulary: bool,
520    pub include_merge_rules: bool,
521    pub include_normalization_rules: bool,
522    pub compress_output: bool,
523    pub validate_output: bool,
524    pub export_format: ProtobufFormat,
525}
526
527/// Protobuf export formats
528#[derive(Debug, Clone, Serialize, Deserialize)]
529pub enum ProtobufFormat {
530    Binary,
531    TextFormat,
532    Json,
533    CompressedBinary,
534}
535
536impl Default for ProtobufExportConfig {
537    fn default() -> Self {
538        Self {
539            include_vocabulary: true,
540            include_merge_rules: true,
541            include_normalization_rules: true,
542            compress_output: false,
543            validate_output: true,
544            export_format: ProtobufFormat::Binary,
545        }
546    }
547}
548
549/// Protobuf export utility
550pub struct ProtobufExporter {
551    config: ProtobufExportConfig,
552}
553
554impl ProtobufExporter {
555    /// Create new exporter with configuration
556    pub fn new(config: ProtobufExportConfig) -> Self {
557        Self { config }
558    }
559
560    /// Export tokenizer model
561    pub fn export_model<P: AsRef<Path>>(
562        &self,
563        model: &ProtobufTokenizerModel,
564        path: P,
565    ) -> Result<()> {
566        // Validate if requested
567        if self.config.validate_output {
568            let warnings = ProtobufSerializer::validate_model(model)?;
569            if !warnings.is_empty() {
570                eprintln!("Validation warnings:");
571                for warning in warnings {
572                    eprintln!("  - {}", warning);
573                }
574            }
575        }
576
577        match self.config.export_format {
578            ProtobufFormat::Binary => {
579                if self.config.compress_output {
580                    let compressed = ProtobufSerializer::compress_model(model)?;
581                    std::fs::write(path, compressed).map_err(|e| {
582                        TrustformersError::other(
583                            anyhow::anyhow!("Failed to write file: {}", e).to_string(),
584                        )
585                    })?;
586                } else {
587                    ProtobufSerializer::save_to_file(model, path)?;
588                }
589            },
590            ProtobufFormat::TextFormat => {
591                let text = ProtobufSerializer::to_proto_text(model)?;
592                std::fs::write(path, text).map_err(|e| {
593                    TrustformersError::other(
594                        anyhow::anyhow!("Failed to write text file: {}", e).to_string(),
595                    )
596                })?;
597            },
598            ProtobufFormat::Json => {
599                let json = serde_json::to_string_pretty(model).map_err(|e| {
600                    TrustformersError::other(
601                        anyhow::anyhow!("Failed to serialize JSON: {}", e).to_string(),
602                    )
603                })?;
604                std::fs::write(path, json).map_err(|e| {
605                    TrustformersError::other(
606                        anyhow::anyhow!("Failed to write JSON file: {}", e).to_string(),
607                    )
608                })?;
609            },
610            ProtobufFormat::CompressedBinary => {
611                let compressed = ProtobufSerializer::compress_model(model)?;
612                std::fs::write(path, compressed).map_err(|e| {
613                    TrustformersError::other(
614                        anyhow::anyhow!("Failed to write compressed file: {}", e).to_string(),
615                    )
616                })?;
617            },
618        }
619
620        Ok(())
621    }
622
623    /// Import tokenizer model
624    pub fn import_model<P: AsRef<Path>>(&self, path: P) -> Result<ProtobufTokenizerModel> {
625        match self.config.export_format {
626            ProtobufFormat::Binary => ProtobufSerializer::load_from_file(path),
627            ProtobufFormat::TextFormat => {
628                let text = std::fs::read_to_string(path).map_err(|e| {
629                    TrustformersError::other(
630                        anyhow::anyhow!("Failed to read text file: {}", e).to_string(),
631                    )
632                })?;
633                ProtobufSerializer::from_proto_text(&text)
634            },
635            ProtobufFormat::Json => {
636                let json = std::fs::read_to_string(path).map_err(|e| {
637                    TrustformersError::other(
638                        anyhow::anyhow!("Failed to read JSON file: {}", e).to_string(),
639                    )
640                })?;
641                serde_json::from_str(&json).map_err(|e| {
642                    TrustformersError::other(
643                        anyhow::anyhow!("Failed to parse JSON: {}", e).to_string(),
644                    )
645                })
646            },
647            ProtobufFormat::CompressedBinary => {
648                let compressed = std::fs::read(path).map_err(|e| {
649                    TrustformersError::other(
650                        anyhow::anyhow!("Failed to read compressed file: {}", e).to_string(),
651                    )
652                })?;
653                ProtobufSerializer::decompress_model(&compressed)
654            },
655        }
656    }
657}
658
#[cfg(test)]
mod tests {
    use super::*;

    /// Metadata shared by the tests below; each test overrides only what it checks.
    fn sample_metadata(
        name: &str,
        tokenizer_type: &str,
        vocab_size: u32,
    ) -> ProtobufTokenizerMetadata {
        ProtobufTokenizerMetadata {
            name: name.to_string(),
            version: "1.0".to_string(),
            tokenizer_type: tokenizer_type.to_string(),
            vocab_size,
            special_tokens: HashMap::new(),
            max_length: None,
            truncation_side: "right".to_string(),
            padding_side: "right".to_string(),
            do_lower_case: false,
            strip_accents: None,
            add_prefix_space: false,
            trim_offsets: true,
            created_at: chrono::Utc::now().to_rfc3339(),
            model_id: None,
            custom_attributes: HashMap::new(),
        }
    }

    /// Wrap `metadata` in a model whose collections are all empty.
    fn empty_model(metadata: ProtobufTokenizerMetadata) -> ProtobufTokenizerModel {
        ProtobufTokenizerModel {
            metadata,
            vocabulary: Vec::new(),
            normalization_rules: Vec::new(),
            merge_rules: Vec::new(),
            added_tokens: Vec::new(),
        }
    }

    #[test]
    fn test_protobuf_metadata_creation() {
        let metadata = ProtobufTokenizerMetadata {
            max_length: Some(512),
            ..sample_metadata("test-tokenizer", "bpe", 1000)
        };

        assert_eq!(metadata.name, "test-tokenizer");
        assert_eq!(metadata.vocab_size, 1000);
    }

    #[test]
    fn test_tokenized_input_conversion() {
        let original = TokenizedInput {
            input_ids: vec![1, 2, 3],
            attention_mask: vec![1, 1, 1],
            token_type_ids: Some(vec![0, 0, 0]),
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let encoded = ProtobufSerializer::serialize_tokenized_input(&original);
        let decoded = ProtobufSerializer::deserialize_tokenized_input(&encoded);

        assert_eq!(original.input_ids, decoded.input_ids);
        assert_eq!(original.attention_mask, decoded.attention_mask);
        assert_eq!(original.token_type_ids, decoded.token_type_ids);
    }

    #[test]
    fn test_protobuf_serialization() {
        let model = empty_model(sample_metadata("test", "test", 0));

        let bytes =
            ProtobufSerializer::to_protobuf_bytes(&model).expect("Operation failed in test");
        let recovered =
            ProtobufSerializer::from_protobuf_bytes(&bytes).expect("Operation failed in test");

        assert_eq!(model.metadata.name, recovered.metadata.name);
        assert_eq!(model.metadata.version, recovered.metadata.version);
    }

    #[test]
    fn test_proto_text_format() {
        let model = empty_model(ProtobufTokenizerMetadata {
            do_lower_case: true,
            ..sample_metadata("test-tokenizer", "bpe", 100)
        });

        let text = ProtobufSerializer::to_proto_text(&model).expect("Operation failed in test");
        for expected in [
            "name: \"test-tokenizer\"",
            "version: \"1.0\"",
            "vocab_size: 100",
            "do_lower_case: true",
        ] {
            assert!(text.contains(expected));
        }

        let parsed = ProtobufSerializer::from_proto_text(&text).expect("Operation failed in test");
        assert_eq!(parsed.metadata.name, "test-tokenizer");
        assert_eq!(parsed.metadata.version, "1.0");
        assert_eq!(parsed.metadata.vocab_size, 100);
        assert!(parsed.metadata.do_lower_case);
    }

    #[test]
    fn test_model_validation() {
        let mut model = empty_model(sample_metadata("test", "test", 2));
        // Only one token while the metadata claims two.
        model.vocabulary.push(ProtobufVocabEntry {
            token: "hello".to_string(),
            id: 0,
            frequency: 0.1,
            is_special: false,
            token_type: 0,
        });

        let warnings =
            ProtobufSerializer::validate_model(&model).expect("Operation failed in test");
        assert!(!warnings.is_empty());
        assert!(warnings[0].contains("Vocabulary size mismatch"));
    }

    #[test]
    fn test_compression() {
        let model = empty_model(sample_metadata("test", "test", 0));

        let compressed =
            ProtobufSerializer::compress_model(&model).expect("Operation failed in test");
        let decompressed =
            ProtobufSerializer::decompress_model(&compressed).expect("Operation failed in test");

        assert_eq!(model.metadata.name, decompressed.metadata.name);
    }
}