trustformers_tokenizers/messagepack_serialization.rs

use chrono;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::traits::{TokenizedInput, Tokenizer};

/// MessagePack-compatible tokenizer metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagePackTokenizerMetadata {
    pub name: String,
    pub version: String,
    pub tokenizer_type: String,
    pub vocab_size: u32,
    pub special_tokens: HashMap<String, u32>,
    pub max_length: Option<u32>,
    pub truncation_side: String,
    pub padding_side: String,
    pub do_lower_case: bool,
    pub strip_accents: Option<bool>,
    pub add_prefix_space: bool,
    pub trim_offsets: bool,
    pub created_at: String,
    pub model_id: Option<String>,
    pub custom_attributes: HashMap<String, Vec<u8>>, // For extension data
}

/// MessagePack-compatible vocabulary entry
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagePackVocabEntry {
    pub token: String,
    pub id: u32,
    pub frequency: f64,
    pub is_special: bool,
    pub token_type: u32, // Enumerated token type
}

/// MessagePack-compatible normalization rule
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagePackNormalizationRule {
    pub rule_type: u32, // Enumerated rule type
    pub pattern: Option<String>,
    pub replacement: Option<String>,
    pub enabled: bool,
    pub priority: u32,
}

/// MessagePack-compatible merge rule (for BPE)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagePackMergeRule {
    pub first_token: String,
    pub second_token: String,
    pub merged_token: String,
    pub priority: u32,
    pub frequency: f64,
}

/// MessagePack-compatible tokenizer configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagePackTokenizerConfig {
    pub metadata: MessagePackTokenizerMetadata,
    pub vocabulary: Vec<MessagePackVocabEntry>,
    pub normalization_rules: Vec<MessagePackNormalizationRule>,
    pub merge_rules: Vec<MessagePackMergeRule>,
    pub preprocessing_config: HashMap<String, Vec<u8>>,
    pub postprocessing_config: HashMap<String, Vec<u8>>,
    pub training_config: Option<HashMap<String, Vec<u8>>>,
}

/// MessagePack-compatible tokenized input representation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagePackTokenizedInput {
    pub input_ids: Vec<u32>,
    pub attention_mask: Option<Vec<u32>>,
    pub token_type_ids: Option<Vec<u32>>,
    pub special_tokens_mask: Option<Vec<u32>>,
    pub offsets: Option<Vec<(u32, u32)>>,
    pub tokens: Vec<String>,
    pub overflow: bool,
    pub sequence_length: u32,
    pub metadata: HashMap<String, Vec<u8>>,
}

/// Configuration options for MessagePack serialization
#[derive(Debug, Clone)]
pub struct MessagePackConfig {
    /// Whether to use the compact positional encoding (`rmp_serde::to_vec`)
    /// or the self-describing named-field encoding (`rmp_serde::to_vec_named`)
    pub use_binary_format: bool,

    /// Whether to include metadata in the serialized data
    pub include_metadata: bool,

    /// Whether to include vocabulary in the serialized data
    pub include_vocabulary: bool,

    /// Whether to include training configuration
    pub include_training_config: bool,

    /// Whether to compress the output (reserved: the MessagePack format
    /// defines no built-in compression, so this flag currently has no effect)
    pub compress: bool,

    /// Custom attributes to include
    pub custom_attributes: HashMap<String, Vec<u8>>,
}

impl Default for MessagePackConfig {
    fn default() -> Self {
        Self {
            use_binary_format: true,
            include_metadata: true,
            include_vocabulary: true,
            include_training_config: false,
            compress: false,
            custom_attributes: HashMap::new(),
        }
    }
}
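
// A minimal configuration sketch (hypothetical caller code): start from the
// defaults and opt into the self-describing named-field encoding:
//
//     let config = MessagePackConfig {
//         use_binary_format: false, // write field names alongside values
//         ..MessagePackConfig::default()
//     };
//     let serializer = MessagePackSerializer::new(config);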

/// MessagePack serializer for tokenizers and tokenized inputs
pub struct MessagePackSerializer {
    config: MessagePackConfig,
}

impl Default for MessagePackSerializer {
    /// Create a new MessagePack serializer with default configuration
    fn default() -> Self {
        Self {
            config: MessagePackConfig::default(),
        }
    }
}

impl MessagePackSerializer {
    /// Create a new MessagePack serializer with the given configuration
    pub fn new(config: MessagePackConfig) -> Self {
        Self { config }
    }

    /// Serialize a tokenizer to MessagePack format
    pub fn serialize_tokenizer<T: Tokenizer>(
        &self,
        tokenizer: &T,
        metadata: Option<HashMap<String, String>>,
    ) -> Result<Vec<u8>> {
        let vocab = tokenizer.get_vocab();
        let special_tokens = self.detect_special_tokens(&vocab);
        let vocab_entries: Vec<MessagePackVocabEntry> = vocab
            .iter()
            .map(|(token, &id)| MessagePackVocabEntry {
                token: token.clone(),
                id,
                frequency: 1.0, // Default frequency
                is_special: special_tokens.contains(token),
                token_type: if special_tokens.contains(token) { 1 } else { 0 },
            })
            .collect();

        let tokenizer_metadata = MessagePackTokenizerMetadata {
            name: metadata
                .as_ref()
                .and_then(|m| m.get("name"))
                .cloned()
                .unwrap_or_else(|| "unknown".to_string()),
            version: metadata
                .as_ref()
                .and_then(|m| m.get("version"))
                .cloned()
                .unwrap_or_else(|| "1.0.0".to_string()),
            tokenizer_type: self.get_tokenizer_type(&metadata),
            vocab_size: vocab.len() as u32,
            // Map each special token to its actual vocabulary id
            special_tokens: special_tokens
                .iter()
                .filter_map(|token| vocab.get(token).map(|&id| (token.clone(), id)))
                .collect(),
            max_length: metadata
                .as_ref()
                .and_then(|m| m.get("max_length"))
                .and_then(|v| v.parse().ok()),
            truncation_side: "right".to_string(),
            padding_side: "right".to_string(),
            do_lower_case: false,
            strip_accents: None,
            add_prefix_space: false,
            trim_offsets: true,
            created_at: chrono::Utc::now().to_rfc3339(),
            model_id: metadata.as_ref().and_then(|m| m.get("model_id")).cloned(),
            custom_attributes: self.config.custom_attributes.clone(),
        };

        let config = MessagePackTokenizerConfig {
            metadata: tokenizer_metadata,
            vocabulary: if self.config.include_vocabulary { vocab_entries } else { Vec::new() },
            normalization_rules: self.extract_normalization_rules(&metadata),
            merge_rules: self.extract_merge_rules(&metadata),
            preprocessing_config: HashMap::new(),
            postprocessing_config: HashMap::new(),
            training_config: if self.config.include_training_config {
                Some(HashMap::new())
            } else {
                None
            },
        };

        self.serialize_to_messagepack(&config)
    }
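
    // A minimal round-trip sketch (hypothetical caller code; `my_tokenizer` is
    // an assumed implementor of the `Tokenizer` trait, not part of this module):
    //
    //     let serializer = MessagePackSerializer::default();
    //     let mut meta = HashMap::new();
    //     meta.insert("name".to_string(), "my-bpe".to_string());
    //     meta.insert("tokenizer_type".to_string(), "bpe".to_string());
    //     let bytes = serializer.serialize_tokenizer(&my_tokenizer, Some(meta))?;
    //     let config = serializer.deserialize_tokenizer_config(&bytes)?;
    //     assert_eq!(config.metadata.name, "my-bpe");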

    /// Serialize a tokenized input to MessagePack format
    pub fn serialize_tokenized_input(&self, input: &TokenizedInput) -> Result<Vec<u8>> {
        self.serialize_to_messagepack(&Self::to_msgpack_input(input))
    }

    /// Serialize a `TokenizedInput` batch to MessagePack format
    pub fn serialize_tokenized_batch(&self, batch: &[TokenizedInput]) -> Result<Vec<u8>> {
        let msgpack_batch: Vec<MessagePackTokenizedInput> =
            batch.iter().map(Self::to_msgpack_input).collect();
        self.serialize_to_messagepack(&msgpack_batch)
    }

    /// Deserialize a tokenizer configuration from MessagePack format
    pub fn deserialize_tokenizer_config(&self, data: &[u8]) -> Result<MessagePackTokenizerConfig> {
        self.deserialize_from_messagepack(data)
    }

    /// Deserialize a tokenized input from MessagePack format
    pub fn deserialize_tokenized_input(&self, data: &[u8]) -> Result<TokenizedInput> {
        let msgpack_input: MessagePackTokenizedInput = self.deserialize_from_messagepack(data)?;
        Ok(Self::from_msgpack_input(msgpack_input))
    }

    /// Deserialize a batch of tokenized inputs from MessagePack format
    pub fn deserialize_tokenized_batch(&self, data: &[u8]) -> Result<Vec<TokenizedInput>> {
        let msgpack_batch: Vec<MessagePackTokenizedInput> =
            self.deserialize_from_messagepack(data)?;
        Ok(msgpack_batch.into_iter().map(Self::from_msgpack_input).collect())
    }

    /// Convert a `TokenizedInput` into its MessagePack-compatible representation
    fn to_msgpack_input(input: &TokenizedInput) -> MessagePackTokenizedInput {
        MessagePackTokenizedInput {
            input_ids: input.input_ids.clone(),
            // Widen the u8 attention mask to u32 for a uniform wire type
            attention_mask: Some(input.attention_mask.iter().map(|&x| x as u32).collect()),
            token_type_ids: input.token_type_ids.clone(),
            special_tokens_mask: None,
            offsets: None,
            tokens: Vec::new(),
            overflow: false,
            sequence_length: input.input_ids.len() as u32,
            metadata: HashMap::new(),
        }
    }

    /// Convert a MessagePack representation back into a `TokenizedInput`
    fn from_msgpack_input(msgpack_input: MessagePackTokenizedInput) -> TokenizedInput {
        let input_ids_len = msgpack_input.input_ids.len();
        TokenizedInput {
            input_ids: msgpack_input.input_ids,
            // Default to an all-ones mask when none was serialized
            attention_mask: msgpack_input
                .attention_mask
                .unwrap_or_else(|| vec![1; input_ids_len])
                .into_iter()
                .map(|x| x as u8)
                .collect(),
            token_type_ids: msgpack_input.token_type_ids,
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        }
    }
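
    // A minimal batch round-trip sketch (hypothetical caller code). The batch
    // path serializes each item as-is: sequences of different lengths are
    // preserved without padding.
    //
    //     let bytes = serializer.serialize_tokenized_batch(&batch)?;
    //     let restored = serializer.deserialize_tokenized_batch(&bytes)?;
    //     assert_eq!(batch.len(), restored.len());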

    /// Save a tokenizer to a MessagePack file
    pub fn save_tokenizer_to_file<T: Tokenizer, P: AsRef<Path>>(
        &self,
        tokenizer: &T,
        path: P,
        metadata: Option<HashMap<String, String>>,
    ) -> Result<()> {
        let data = self.serialize_tokenizer(tokenizer, metadata)?;
        let mut file = BufWriter::new(File::create(path)?);
        file.write_all(&data)?;
        file.flush()?;
        Ok(())
    }

    /// Save a tokenized input to a MessagePack file
    pub fn save_tokenized_input_to_file<P: AsRef<Path>>(
        &self,
        input: &TokenizedInput,
        path: P,
    ) -> Result<()> {
        let data = self.serialize_tokenized_input(input)?;
        let mut file = BufWriter::new(File::create(path)?);
        file.write_all(&data)?;
        file.flush()?;
        Ok(())
    }

    /// Load a tokenizer configuration from a MessagePack file
    pub fn load_tokenizer_config_from_file<P: AsRef<Path>>(
        &self,
        path: P,
    ) -> Result<MessagePackTokenizerConfig> {
        let mut file = BufReader::new(File::open(path)?);
        let mut data = Vec::new();
        file.read_to_end(&mut data)?;
        self.deserialize_tokenizer_config(&data)
    }

    /// Load a tokenized input from a MessagePack file
    pub fn load_tokenized_input_from_file<P: AsRef<Path>>(
        &self,
        path: P,
    ) -> Result<TokenizedInput> {
        let mut file = BufReader::new(File::open(path)?);
        let mut data = Vec::new();
        file.read_to_end(&mut data)?;
        self.deserialize_tokenized_input(&data)
    }
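
    // A minimal persistence sketch (hypothetical caller code, tokenizer, and
    // path):
    //
    //     serializer.save_tokenizer_to_file(&my_tokenizer, "tok.msgpack", None)?;
    //     let config = serializer.load_tokenizer_config_from_file("tok.msgpack")?;
    //     println!("vocab size: {}", config.metadata.vocab_size);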

    /// Validate MessagePack data structure
    pub fn validate_messagepack_data(&self, data: &[u8]) -> Result<bool> {
        // Probe by deserializing into a generic `serde_json::Value`; arrays,
        // primitives, and maps with string keys all parse.
        match rmp_serde::from_slice::<serde_json::Value>(data) {
            Ok(_) => Ok(true),
            Err(e) => Err(TrustformersError::serialization_error(format!(
                "Invalid MessagePack data: {}",
                e
            ))),
        }
    }

    /// Get information about MessagePack data
    pub fn get_messagepack_info(&self, data: &[u8]) -> Result<HashMap<String, String>> {
        let mut info = HashMap::new();

        info.insert("format".to_string(), "MessagePack".to_string());
        info.insert("size_bytes".to_string(), data.len().to_string());

        // Try to parse as tokenizer config first
        if let Ok(config) = self.deserialize_tokenizer_config(data) {
            info.insert("data_type".to_string(), "tokenizer_config".to_string());
            info.insert("tokenizer_type".to_string(), config.metadata.tokenizer_type);
            info.insert(
                "vocab_size".to_string(),
                config.metadata.vocab_size.to_string(),
            );
            info.insert("version".to_string(), config.metadata.version);
        } else if let Ok(_input) = self.deserialize_tokenized_input(data) {
            info.insert("data_type".to_string(), "tokenized_input".to_string());
        } else if let Ok(batch) = self.deserialize_tokenized_batch(data) {
            info.insert("data_type".to_string(), "tokenized_batch".to_string());
            info.insert("batch_size".to_string(), batch.len().to_string());
        } else {
            info.insert("data_type".to_string(), "unknown".to_string());
        }

        Ok(info)
    }

    /// Compare two MessagePack files
    pub fn compare_messagepack_files<P1: AsRef<Path>, P2: AsRef<Path>>(
        &self,
        path1: P1,
        path2: P2,
    ) -> Result<HashMap<String, String>> {
        let mut file1 = BufReader::new(File::open(path1)?);
        let mut file2 = BufReader::new(File::open(path2)?);

        let mut data1 = Vec::new();
        let mut data2 = Vec::new();

        file1.read_to_end(&mut data1)?;
        file2.read_to_end(&mut data2)?;

        let mut comparison = HashMap::new();

        comparison.insert("size1_bytes".to_string(), data1.len().to_string());
        comparison.insert("size2_bytes".to_string(), data2.len().to_string());
        comparison.insert(
            "sizes_equal".to_string(),
            (data1.len() == data2.len()).to_string(),
        );
        comparison.insert("contents_equal".to_string(), (data1 == data2).to_string());

        let info1 = self.get_messagepack_info(&data1)?;
        let info2 = self.get_messagepack_info(&data2)?;

        comparison.insert(
            "type1".to_string(),
            info1.get("data_type").cloned().unwrap_or_else(|| "unknown".to_string()),
        );
        comparison.insert(
            "type2".to_string(),
            info2.get("data_type").cloned().unwrap_or_else(|| "unknown".to_string()),
        );

        Ok(comparison)
    }

    /// Generic serialization method; honors `use_binary_format` by choosing
    /// between the compact positional encoding and the named-field encoding
    fn serialize_to_messagepack<T: Serialize>(&self, data: &T) -> Result<Vec<u8>> {
        let encoded = if self.config.use_binary_format {
            rmp_serde::to_vec(data)
        } else {
            rmp_serde::to_vec_named(data)
        };
        encoded.map_err(|e| {
            TrustformersError::serialization_error(format!(
                "MessagePack serialization failed: {}",
                e
            ))
        })
    }

    /// Generic deserialization method; accepts both positional and named encodings
    fn deserialize_from_messagepack<T: for<'de> Deserialize<'de>>(&self, data: &[u8]) -> Result<T> {
        rmp_serde::from_slice(data).map_err(|e| {
            TrustformersError::serialization_error(format!(
                "MessagePack deserialization failed: {}",
                e
            ))
        })
    }
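
    // Encoding note (a sketch of the two layouts, not exhaustive): with the
    // compact encoding, a struct such as `MessagePackVocabEntry` is written as
    // a positional array,
    //
    //     ["hello", 42, 1.0, false, 0]
    //
    // while the named encoding writes a map keyed by field name,
    //
    //     {"token": "hello", "id": 42, "frequency": 1.0, ...}
    //
    // The named form is larger but stays readable for tools that do not know
    // the Rust struct layout; either form deserializes back into the struct.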

    /// Detect common special tokens in vocabulary
    fn detect_special_tokens(&self, vocab: &HashMap<String, u32>) -> HashSet<String> {
        let common_special_tokens = [
            "[PAD]",
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[MASK]",
            "<|endoftext|>",
            "<|startoftext|>",
            "<|padding|>",
            "<pad>",
            "<unk>",
            "<cls>",
            "<sep>",
            "<mask>",
            "<s>",
            "</s>",
            "<eos>",
            "<bos>",
        ];

        vocab
            .keys()
            .filter(|token| {
                common_special_tokens.contains(&token.as_str())
                    || (token.starts_with('<') && token.ends_with('>'))
                    || (token.starts_with('[') && token.ends_with(']'))
            })
            .cloned()
            .collect()
    }
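
    // Heuristic note: the bracket rules above intentionally over-match. For
    // example, "<s>" and "[CLS]" are detected as special, but so would any
    // fully bracketed vocabulary token, such as "<html>".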

    /// Get tokenizer type from metadata
    fn get_tokenizer_type(&self, metadata: &Option<HashMap<String, String>>) -> String {
        metadata
            .as_ref()
            .and_then(|m| m.get("tokenizer_type"))
            .cloned()
            .unwrap_or_else(|| "generic".to_string())
    }

    /// Extract normalization rules from metadata
    fn extract_normalization_rules(
        &self,
        metadata: &Option<HashMap<String, String>>,
    ) -> Vec<MessagePackNormalizationRule> {
        let mut rules = Vec::new();

        if let Some(meta) = metadata {
            if meta.get("normalize_case").map(|v| v == "true").unwrap_or(false) {
                rules.push(MessagePackNormalizationRule {
                    rule_type: 1, // Lowercase
                    pattern: None,
                    replacement: None,
                    enabled: true,
                    priority: 1,
                });
            }
            if meta.get("strip_accents").map(|v| v == "true").unwrap_or(false) {
                rules.push(MessagePackNormalizationRule {
                    rule_type: 2, // Strip accents
                    pattern: None,
                    replacement: None,
                    enabled: true,
                    priority: 2,
                });
            }
        }

        rules
    }

    /// Extract merge rules from metadata (for BPE tokenizers)
    fn extract_merge_rules(
        &self,
        metadata: &Option<HashMap<String, String>>,
    ) -> Vec<MessagePackMergeRule> {
        let mut rules = Vec::new();

        if let Some(meta) = metadata {
            if let Some(merge_data) = meta.get("merge_rules") {
                // Parse merge rules from metadata (simplified implementation);
                // split on any whitespace so multiple spaces or tabs still parse
                for (i, line) in merge_data.lines().enumerate() {
                    let parts: Vec<&str> = line.split_whitespace().collect();
                    if parts.len() >= 2 {
                        rules.push(MessagePackMergeRule {
                            first_token: parts[0].to_string(),
                            second_token: parts[1].to_string(),
                            merged_token: format!("{}{}", parts[0], parts[1]),
                            priority: i as u32,
                            frequency: 1.0,
                        });
                    }
                }
            }
        }

        rules
    }
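
    // Expected `merge_rules` metadata format (an assumption consistent with
    // the parser above): one whitespace-separated token pair per line, in
    // merge order, e.g. "h e\nhe l\nhel lo", which yields the merged tokens
    // "he", "hel", and "hello".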
}

/// Utility functions for MessagePack operations
pub struct MessagePackUtils;

impl MessagePackUtils {
    /// Convert MessagePack to JSON for inspection
    pub fn messagepack_to_json(data: &[u8]) -> Result<String> {
        let value: serde_json::Value = rmp_serde::from_slice(data).map_err(|e| {
            TrustformersError::serialization_error(format!(
                "MessagePack to JSON conversion failed: {}",
                e
            ))
        })?;

        serde_json::to_string_pretty(&value).map_err(|e| {
            TrustformersError::serialization_error(format!("JSON serialization failed: {}", e))
        })
    }

    /// Convert JSON to MessagePack
    pub fn json_to_messagepack(json: &str) -> Result<Vec<u8>> {
        let value: serde_json::Value = serde_json::from_str(json).map_err(|e| {
            TrustformersError::serialization_error(format!("JSON parsing failed: {}", e))
        })?;

        rmp_serde::to_vec(&value).map_err(|e| {
            TrustformersError::serialization_error(format!(
                "JSON to MessagePack conversion failed: {}",
                e
            ))
        })
    }

    /// Get MessagePack data statistics
    pub fn get_statistics(data: &[u8]) -> Result<HashMap<String, String>> {
        let mut stats = HashMap::new();

        stats.insert("format".to_string(), "MessagePack".to_string());
        stats.insert("size_bytes".to_string(), data.len().to_string());

        // Try to parse and count elements
        if let Ok(value) = rmp_serde::from_slice::<serde_json::Value>(data) {
            match &value {
                serde_json::Value::Object(map) => {
                    stats.insert("type".to_string(), "object".to_string());
                    stats.insert("fields_count".to_string(), map.len().to_string());
                },
                serde_json::Value::Array(arr) => {
                    stats.insert("type".to_string(), "array".to_string());
                    stats.insert("elements_count".to_string(), arr.len().to_string());
                },
                _ => {
                    stats.insert("type".to_string(), "primitive".to_string());
                },
            }
        }

        Ok(stats)
    }

    /// Validate MessagePack file integrity
    pub fn validate_file<P: AsRef<Path>>(path: P) -> Result<bool> {
        let mut file = BufReader::new(File::open(path)?);
        let mut data = Vec::new();
        file.read_to_end(&mut data)?;

        match rmp_serde::from_slice::<serde_json::Value>(&data) {
            Ok(_) => Ok(true),
            Err(e) => Err(TrustformersError::serialization_error(format!(
                "MessagePack file validation failed: {}",
                e
            ))),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use tempfile::tempdir;

    #[test]
    fn test_messagepack_config_default() {
        let config = MessagePackConfig::default();
        assert!(config.use_binary_format);
        assert!(config.include_metadata);
        assert!(config.include_vocabulary);
        assert!(!config.include_training_config);
        assert!(!config.compress);
    }

    #[test]
    fn test_messagepack_serializer_creation() {
        let config = MessagePackConfig::default();
        let _serializer = MessagePackSerializer::new(config);

        // Test that the default constructor works
        let default_serializer = MessagePackSerializer::default();
        assert!(default_serializer.config.use_binary_format);
    }

    #[test]
    fn test_serialize_tokenized_input() {
        let serializer = MessagePackSerializer::default();

        let input = TokenizedInput {
            input_ids: vec![1, 2, 3, 4],
            attention_mask: vec![1, 1, 1, 1],
            token_type_ids: Some(vec![0, 0, 1, 1]),
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let serialized =
            serializer.serialize_tokenized_input(&input).expect("Operation failed in test");
        assert!(!serialized.is_empty());

        let deserialized = serializer
            .deserialize_tokenized_input(&serialized)
            .expect("Operation failed in test");
        assert_eq!(input.input_ids, deserialized.input_ids);
        assert_eq!(input.attention_mask, deserialized.attention_mask);
        assert_eq!(input.token_type_ids, deserialized.token_type_ids);
    }

    #[test]
    fn test_serialize_tokenized_batch() {
        let serializer = MessagePackSerializer::default();

        let batch = vec![
            TokenizedInput {
                input_ids: vec![1, 2, 3],
                attention_mask: vec![1, 1, 1],
                token_type_ids: None,
                special_tokens_mask: None,
                offset_mapping: None,
                overflowing_tokens: None,
            },
            TokenizedInput {
                input_ids: vec![4, 5, 6, 7],
                attention_mask: vec![1, 1, 1, 1],
                token_type_ids: None,
                special_tokens_mask: None,
                offset_mapping: None,
                overflowing_tokens: None,
            },
        ];

        let serialized =
            serializer.serialize_tokenized_batch(&batch).expect("Operation failed in test");
        assert!(!serialized.is_empty());

        let deserialized = serializer
            .deserialize_tokenized_batch(&serialized)
            .expect("Operation failed in test");
        assert_eq!(batch.len(), deserialized.len());
        assert_eq!(batch[0].input_ids, deserialized[0].input_ids);
        assert_eq!(batch[1].input_ids, deserialized[1].input_ids);
    }

    #[test]
    fn test_messagepack_validation() {
        let serializer = MessagePackSerializer::default();

        let input = TokenizedInput {
            input_ids: vec![1, 2, 3],
            attention_mask: vec![1, 1, 1],
            token_type_ids: None,
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let serialized =
            serializer.serialize_tokenized_input(&input).expect("Operation failed in test");

        // Valid data should validate successfully
        assert!(serializer
            .validate_messagepack_data(&serialized)
            .expect("Operation failed in test"));

        // Invalid data should fail validation: a truncated fixmap header
        // announcing two key-value pairs, with no actual entries following
        let invalid_data = vec![0x82];
        assert!(serializer.validate_messagepack_data(&invalid_data).is_err());
    }

    #[test]
    fn test_messagepack_info() {
        let serializer = MessagePackSerializer::default();

        let input = TokenizedInput {
            input_ids: vec![1, 2, 3],
            attention_mask: vec![1, 1, 1],
            token_type_ids: None,
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let serialized =
            serializer.serialize_tokenized_input(&input).expect("Operation failed in test");
        let info = serializer.get_messagepack_info(&serialized).expect("Operation failed in test");

        assert_eq!(info.get("format").expect("Key not found"), "MessagePack");
        assert_eq!(
            info.get("data_type").expect("Key not found"),
            "tokenized_input"
        );
        assert_eq!(
            info.get("size_bytes").expect("Key not found"),
            &serialized.len().to_string()
        );
    }

    #[test]
    fn test_file_operations() {
        let serializer = MessagePackSerializer::default();
        let temp_dir = tempdir().expect("Operation failed in test");

        let input = TokenizedInput {
            input_ids: vec![1, 2, 3, 4],
            attention_mask: vec![1, 1, 1, 1],
            token_type_ids: None,
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let file_path = temp_dir.path().join("test_input.msgpack");

        // Save to file
        serializer
            .save_tokenized_input_to_file(&input, &file_path)
            .expect("Operation failed in test");
        assert!(file_path.exists());

        // Load from file
        let loaded_input = serializer
            .load_tokenized_input_from_file(&file_path)
            .expect("Operation failed in test");
        assert_eq!(input.input_ids, loaded_input.input_ids);
        assert_eq!(input.attention_mask, loaded_input.attention_mask);
        assert_eq!(input.token_type_ids, loaded_input.token_type_ids);
    }

    #[test]
    fn test_messagepack_utils() {
        let test_data = r#"{"test": "data", "number": 42}"#;

        // Convert JSON to MessagePack
        let msgpack_data =
            MessagePackUtils::json_to_messagepack(test_data).expect("Operation failed in test");
        assert!(!msgpack_data.is_empty());

        // Convert MessagePack back to JSON
        let json_data =
            MessagePackUtils::messagepack_to_json(&msgpack_data).expect("Operation failed in test");
        assert!(json_data.contains("test"));
        assert!(json_data.contains("42"));

        // Get statistics
        let stats =
            MessagePackUtils::get_statistics(&msgpack_data).expect("Operation failed in test");
        assert_eq!(stats.get("format").expect("Key not found"), "MessagePack");
        assert_eq!(stats.get("type").expect("Key not found"), "object");
    }

    #[test]
    fn test_file_validation() {
        let serializer = MessagePackSerializer::default();
        let temp_dir = tempdir().expect("Operation failed in test");

        let input = TokenizedInput {
            input_ids: vec![1, 2, 3],
            attention_mask: vec![1, 1, 1],
            token_type_ids: None,
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let file_path = temp_dir.path().join("validation_test.msgpack");
        serializer
            .save_tokenized_input_to_file(&input, &file_path)
            .expect("Operation failed in test");

        // Valid file should validate successfully
        assert!(MessagePackUtils::validate_file(&file_path).expect("Operation failed in test"));
    }

    #[test]
    fn test_file_comparison() {
        let serializer = MessagePackSerializer::default();
        let temp_dir = tempdir().expect("Operation failed in test");

        let input1 = TokenizedInput {
            input_ids: vec![1, 2, 3],
            attention_mask: vec![1, 1, 1],
            token_type_ids: None,
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let input2 = TokenizedInput {
            input_ids: vec![4, 5, 6],
            attention_mask: vec![1, 1, 1],
            token_type_ids: None,
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let file1_path = temp_dir.path().join("compare1.msgpack");
        let file2_path = temp_dir.path().join("compare2.msgpack");

        serializer
            .save_tokenized_input_to_file(&input1, &file1_path)
            .expect("Operation failed in test");
        serializer
            .save_tokenized_input_to_file(&input2, &file2_path)
            .expect("Operation failed in test");

        let comparison = serializer
            .compare_messagepack_files(&file1_path, &file2_path)
            .expect("Operation failed in test");

        assert_eq!(
            comparison.get("contents_equal").expect("Key not found"),
            "false"
        );
        assert_eq!(
            comparison.get("type1").expect("Key not found"),
            "tokenized_input"
        );
        assert_eq!(
            comparison.get("type2").expect("Key not found"),
            "tokenized_input"
        );
    }
}
889}