1use chrono;
2use serde::{Deserialize, Serialize};
3use std::collections::{HashMap, HashSet};
4use std::fs::File;
5use std::io::{BufReader, BufWriter, Read, Write};
6use std::path::Path;
7use trustformers_core::errors::{Result, TrustformersError};
8use trustformers_core::traits::{TokenizedInput, Tokenizer};
9
/// Tokenizer-level metadata embedded in a serialized MessagePack config.
///
/// NOTE: rmp_serde encodes structs positionally (as arrays), so the field
/// order here is part of the wire format — do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagePackTokenizerMetadata {
    /// Human-readable tokenizer name (this file defaults it to "unknown").
    pub name: String,
    /// Version string (this file defaults it to "1.0.0").
    pub version: String,
    /// Tokenizer family identifier (this file defaults it to "generic").
    pub tokenizer_type: String,
    /// Number of entries in the vocabulary.
    pub vocab_size: u32,
    /// Map from special-token string to its id.
    pub special_tokens: HashMap<String, u32>,
    /// Optional maximum sequence length, parsed from metadata if present.
    pub max_length: Option<u32>,
    /// Truncation side; this file always writes "right".
    pub truncation_side: String,
    /// Padding side; this file always writes "right".
    pub padding_side: String,
    /// Whether text is lower-cased; this file always writes false.
    pub do_lower_case: bool,
    /// Whether accents are stripped; this file always writes None.
    pub strip_accents: Option<bool>,
    /// Whether a prefix space is added; this file always writes false.
    pub add_prefix_space: bool,
    /// Whether offsets are trimmed; this file always writes true.
    pub trim_offsets: bool,
    /// Creation timestamp in RFC 3339 format (UTC).
    pub created_at: String,
    /// Optional model identifier taken from the metadata map.
    pub model_id: Option<String>,
    /// Opaque, caller-supplied binary attributes copied from the config.
    pub custom_attributes: HashMap<String, Vec<u8>>,
}
29
/// One vocabulary entry in a serialized tokenizer config.
/// Field order is part of the wire format (positional encoding).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagePackVocabEntry {
    /// The token string.
    pub token: String,
    /// The token's vocabulary id.
    pub id: u32,
    /// Token frequency; the serializer in this file always writes 1.0.
    pub frequency: f64,
    /// Whether the token was detected as a special token.
    pub is_special: bool,
    /// 1 for special tokens, 0 otherwise (as written by this file).
    pub token_type: u32,
}
39
/// A single text-normalization rule in a serialized tokenizer config.
/// Field order is part of the wire format (positional encoding).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagePackNormalizationRule {
    /// Numeric rule discriminator; this file writes 1 for case
    /// normalization and 2 for accent stripping.
    pub rule_type: u32,
    /// Optional match pattern (not populated by this file's serializer).
    pub pattern: Option<String>,
    /// Optional replacement text (not populated by this file's serializer).
    pub replacement: Option<String>,
    /// Whether the rule is active.
    pub enabled: bool,
    /// Relative rule priority.
    pub priority: u32,
}
49
/// A BPE-style merge rule in a serialized tokenizer config.
/// Field order is part of the wire format (positional encoding).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagePackMergeRule {
    /// Left-hand token of the merge pair.
    pub first_token: String,
    /// Right-hand token of the merge pair.
    pub second_token: String,
    /// Result of the merge; this file writes first + second concatenated.
    pub merged_token: String,
    /// Rule priority; this file uses the rule's line index.
    pub priority: u32,
    /// Merge frequency; this file always writes 1.0.
    pub frequency: f64,
}
59
/// Top-level container for a fully serialized tokenizer.
/// Field order is part of the wire format (positional encoding).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagePackTokenizerConfig {
    /// Tokenizer-level metadata.
    pub metadata: MessagePackTokenizerMetadata,
    /// Vocabulary entries; empty when `include_vocabulary` is disabled.
    pub vocabulary: Vec<MessagePackVocabEntry>,
    /// Normalization rules derived from metadata flags.
    pub normalization_rules: Vec<MessagePackNormalizationRule>,
    /// Merge rules parsed from the "merge_rules" metadata value.
    pub merge_rules: Vec<MessagePackMergeRule>,
    /// Opaque preprocessing settings (this file writes an empty map).
    pub preprocessing_config: HashMap<String, Vec<u8>>,
    /// Opaque postprocessing settings (this file writes an empty map).
    pub postprocessing_config: HashMap<String, Vec<u8>>,
    /// Optional training settings; Some(empty) only when
    /// `include_training_config` is enabled.
    pub training_config: Option<HashMap<String, Vec<u8>>>,
}
71
/// Wire representation of a tokenized sequence.
/// Field order is part of the wire format (positional encoding).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagePackTokenizedInput {
    /// Token ids of the sequence.
    pub input_ids: Vec<u32>,
    /// Attention mask (widened to u32 for transport); None is treated as
    /// all-ones on deserialization.
    pub attention_mask: Option<Vec<u32>>,
    /// Optional segment/token-type ids.
    pub token_type_ids: Option<Vec<u32>>,
    /// Optional special-tokens mask (not populated by this file).
    pub special_tokens_mask: Option<Vec<u32>>,
    /// Optional character offsets (not populated by this file).
    pub offsets: Option<Vec<(u32, u32)>>,
    /// Token strings; this file writes an empty list.
    pub tokens: Vec<String>,
    /// Overflow flag; this file always writes false.
    pub overflow: bool,
    /// Length of `input_ids` at serialization time.
    pub sequence_length: u32,
    /// Opaque per-input metadata (this file writes an empty map).
    pub metadata: HashMap<String, Vec<u8>>,
}
85
/// Options controlling what a `MessagePackSerializer` emits.
#[derive(Debug, Clone)]
pub struct MessagePackConfig {
    /// Emit binary MessagePack.
    /// NOTE(review): not consulted anywhere in this file — confirm intent.
    pub use_binary_format: bool,

    /// Include tokenizer metadata.
    /// NOTE(review): not consulted anywhere in this file — metadata is
    /// always written; confirm intent.
    pub include_metadata: bool,

    /// Include the full vocabulary in the serialized config.
    pub include_vocabulary: bool,

    /// Include a (currently empty) training-config section.
    pub include_training_config: bool,

    /// Compress the output.
    /// NOTE(review): not consulted anywhere in this file — confirm intent.
    pub compress: bool,

    /// Opaque attributes copied verbatim into the serialized metadata.
    pub custom_attributes: HashMap<String, Vec<u8>>,
}
107
108impl Default for MessagePackConfig {
109 fn default() -> Self {
110 Self {
111 use_binary_format: true,
112 include_metadata: true,
113 include_vocabulary: true,
114 include_training_config: false,
115 compress: false,
116 custom_attributes: HashMap::new(),
117 }
118 }
119}
120
/// Serializer/deserializer for tokenizer data in MessagePack format.
pub struct MessagePackSerializer {
    /// Options controlling what gets serialized.
    config: MessagePackConfig,
}
125
126impl MessagePackSerializer {
127 pub fn new(config: MessagePackConfig) -> Self {
129 Self { config }
130 }
131
132 pub fn default() -> Self {
134 Self {
135 config: MessagePackConfig::default(),
136 }
137 }
138
139 pub fn serialize_tokenizer<T: Tokenizer>(
141 &self,
142 tokenizer: &T,
143 metadata: Option<HashMap<String, String>>,
144 ) -> Result<Vec<u8>> {
145 let vocab = tokenizer.get_vocab();
146 let special_tokens = self.detect_special_tokens(&vocab);
147 let vocab_entries: Vec<MessagePackVocabEntry> = vocab
148 .iter()
149 .map(|(token, &id)| MessagePackVocabEntry {
150 token: token.clone(),
151 id,
152 frequency: 1.0, is_special: special_tokens.contains(token),
154 token_type: if special_tokens.contains(token) { 1 } else { 0 },
155 })
156 .collect();
157
158 let tokenizer_metadata = MessagePackTokenizerMetadata {
159 name: metadata
160 .as_ref()
161 .and_then(|m| m.get("name"))
162 .unwrap_or(&"unknown".to_string())
163 .clone(),
164 version: metadata
165 .as_ref()
166 .and_then(|m| m.get("version"))
167 .unwrap_or(&"1.0.0".to_string())
168 .clone(),
169 tokenizer_type: self.get_tokenizer_type(&metadata),
170 vocab_size: vocab.len() as u32,
171 special_tokens: special_tokens
172 .iter()
173 .enumerate()
174 .map(|(i, token)| (token.clone(), i as u32))
175 .collect(),
176 max_length: metadata
177 .as_ref()
178 .and_then(|m| m.get("max_length"))
179 .and_then(|v| v.parse().ok()),
180 truncation_side: "right".to_string(),
181 padding_side: "right".to_string(),
182 do_lower_case: false,
183 strip_accents: None,
184 add_prefix_space: false,
185 trim_offsets: true,
186 created_at: chrono::Utc::now().to_rfc3339(),
187 model_id: metadata.as_ref().and_then(|m| m.get("model_id")).cloned(),
188 custom_attributes: self.config.custom_attributes.clone(),
189 };
190
191 let config = MessagePackTokenizerConfig {
192 metadata: tokenizer_metadata,
193 vocabulary: if self.config.include_vocabulary { vocab_entries } else { Vec::new() },
194 normalization_rules: self.extract_normalization_rules(&metadata),
195 merge_rules: self.extract_merge_rules(&metadata),
196 preprocessing_config: HashMap::new(),
197 postprocessing_config: HashMap::new(),
198 training_config: if self.config.include_training_config {
199 Some(HashMap::new())
200 } else {
201 None
202 },
203 };
204
205 self.serialize_to_messagepack(&config)
206 }
207
208 pub fn serialize_tokenized_input(&self, input: &TokenizedInput) -> Result<Vec<u8>> {
210 let msgpack_input = MessagePackTokenizedInput {
211 input_ids: input.input_ids.clone(),
212 attention_mask: Some(input.attention_mask.iter().map(|&x| x as u32).collect()),
213 token_type_ids: input.token_type_ids.clone(),
214 special_tokens_mask: None,
215 offsets: None,
216 tokens: Vec::new(),
217 overflow: false,
218 sequence_length: input.input_ids.len() as u32,
219 metadata: HashMap::new(),
220 };
221
222 self.serialize_to_messagepack(&msgpack_input)
223 }
224
225 pub fn serialize_tokenized_batch(&self, batch: &[TokenizedInput]) -> Result<Vec<u8>> {
227 let msgpack_batch: Vec<MessagePackTokenizedInput> = batch
228 .iter()
229 .map(|input| MessagePackTokenizedInput {
230 input_ids: input.input_ids.clone(),
231 attention_mask: Some(input.attention_mask.iter().map(|&x| x as u32).collect()),
232 token_type_ids: input.token_type_ids.clone(),
233 special_tokens_mask: None,
234 offsets: None,
235 tokens: Vec::new(),
236 overflow: false,
237 sequence_length: input.input_ids.len() as u32,
238 metadata: HashMap::new(),
239 })
240 .collect();
241
242 self.serialize_to_messagepack(&msgpack_batch)
243 }
244
245 pub fn deserialize_tokenizer_config(&self, data: &[u8]) -> Result<MessagePackTokenizerConfig> {
247 self.deserialize_from_messagepack(data)
248 }
249
250 pub fn deserialize_tokenized_input(&self, data: &[u8]) -> Result<TokenizedInput> {
252 let msgpack_input: MessagePackTokenizedInput = self.deserialize_from_messagepack(data)?;
253
254 let input_ids_len = msgpack_input.input_ids.len();
255 Ok(TokenizedInput {
256 input_ids: msgpack_input.input_ids,
257 attention_mask: msgpack_input
258 .attention_mask
259 .unwrap_or_else(|| vec![1; input_ids_len])
260 .into_iter()
261 .map(|x| x as u8)
262 .collect(),
263 token_type_ids: msgpack_input.token_type_ids,
264 special_tokens_mask: None,
265 offset_mapping: None,
266 overflowing_tokens: None,
267 })
268 }
269
270 pub fn deserialize_tokenized_batch(&self, data: &[u8]) -> Result<Vec<TokenizedInput>> {
272 let msgpack_batch: Vec<MessagePackTokenizedInput> =
273 self.deserialize_from_messagepack(data)?;
274
275 Ok(msgpack_batch
276 .into_iter()
277 .map(|msgpack_input| {
278 let input_ids_len = msgpack_input.input_ids.len();
279 TokenizedInput {
280 input_ids: msgpack_input.input_ids,
281 attention_mask: msgpack_input
282 .attention_mask
283 .unwrap_or_else(|| vec![1; input_ids_len])
284 .into_iter()
285 .map(|x| x as u8)
286 .collect(),
287 token_type_ids: msgpack_input.token_type_ids,
288 special_tokens_mask: None,
289 offset_mapping: None,
290 overflowing_tokens: None,
291 }
292 })
293 .collect())
294 }
295
296 pub fn save_tokenizer_to_file<T: Tokenizer, P: AsRef<Path>>(
298 &self,
299 tokenizer: &T,
300 path: P,
301 metadata: Option<HashMap<String, String>>,
302 ) -> Result<()> {
303 let data = self.serialize_tokenizer(tokenizer, metadata)?;
304 let mut file = BufWriter::new(File::create(path)?);
305 file.write_all(&data)?;
306 file.flush()?;
307 Ok(())
308 }
309
310 pub fn save_tokenized_input_to_file<P: AsRef<Path>>(
312 &self,
313 input: &TokenizedInput,
314 path: P,
315 ) -> Result<()> {
316 let data = self.serialize_tokenized_input(input)?;
317 let mut file = BufWriter::new(File::create(path)?);
318 file.write_all(&data)?;
319 file.flush()?;
320 Ok(())
321 }
322
323 pub fn load_tokenizer_config_from_file<P: AsRef<Path>>(
325 &self,
326 path: P,
327 ) -> Result<MessagePackTokenizerConfig> {
328 let mut file = BufReader::new(File::open(path)?);
329 let mut data = Vec::new();
330 file.read_to_end(&mut data)?;
331 self.deserialize_tokenizer_config(&data)
332 }
333
334 pub fn load_tokenized_input_from_file<P: AsRef<Path>>(
336 &self,
337 path: P,
338 ) -> Result<TokenizedInput> {
339 let mut file = BufReader::new(File::open(path)?);
340 let mut data = Vec::new();
341 file.read_to_end(&mut data)?;
342 self.deserialize_tokenized_input(&data)
343 }
344
345 pub fn validate_messagepack_data(&self, data: &[u8]) -> Result<bool> {
347 match rmp_serde::from_slice::<serde_json::Value>(data) {
349 Ok(_) => Ok(true),
350 Err(e) => Err(TrustformersError::serialization_error(format!(
351 "Invalid MessagePack data: {}",
352 e
353 ))),
354 }
355 }
356
357 pub fn get_messagepack_info(&self, data: &[u8]) -> Result<HashMap<String, String>> {
359 let mut info = HashMap::new();
360
361 info.insert("format".to_string(), "MessagePack".to_string());
362 info.insert("size_bytes".to_string(), data.len().to_string());
363
364 if let Ok(config) = self.deserialize_tokenizer_config(data) {
366 info.insert("data_type".to_string(), "tokenizer_config".to_string());
367 info.insert("tokenizer_type".to_string(), config.metadata.tokenizer_type);
368 info.insert(
369 "vocab_size".to_string(),
370 config.metadata.vocab_size.to_string(),
371 );
372 info.insert("version".to_string(), config.metadata.version);
373 } else if let Ok(_input) = self.deserialize_tokenized_input(data) {
374 info.insert("data_type".to_string(), "tokenized_input".to_string());
375 } else if let Ok(batch) = self.deserialize_tokenized_batch(data) {
376 info.insert("data_type".to_string(), "tokenized_batch".to_string());
377 info.insert("batch_size".to_string(), batch.len().to_string());
378 } else {
379 info.insert("data_type".to_string(), "unknown".to_string());
380 }
381
382 Ok(info)
383 }
384
385 pub fn compare_messagepack_files<P1: AsRef<Path>, P2: AsRef<Path>>(
387 &self,
388 path1: P1,
389 path2: P2,
390 ) -> Result<HashMap<String, String>> {
391 let mut file1 = BufReader::new(File::open(path1)?);
392 let mut file2 = BufReader::new(File::open(path2)?);
393
394 let mut data1 = Vec::new();
395 let mut data2 = Vec::new();
396
397 file1.read_to_end(&mut data1)?;
398 file2.read_to_end(&mut data2)?;
399
400 let mut comparison = HashMap::new();
401
402 comparison.insert("size1_bytes".to_string(), data1.len().to_string());
403 comparison.insert("size2_bytes".to_string(), data2.len().to_string());
404 comparison.insert(
405 "sizes_equal".to_string(),
406 (data1.len() == data2.len()).to_string(),
407 );
408 comparison.insert("contents_equal".to_string(), (data1 == data2).to_string());
409
410 let info1 = self.get_messagepack_info(&data1)?;
411 let info2 = self.get_messagepack_info(&data2)?;
412
413 comparison.insert(
414 "type1".to_string(),
415 info1.get("data_type").unwrap_or(&"unknown".to_string()).clone(),
416 );
417 comparison.insert(
418 "type2".to_string(),
419 info2.get("data_type").unwrap_or(&"unknown".to_string()).clone(),
420 );
421
422 Ok(comparison)
423 }
424
425 fn serialize_to_messagepack<T: Serialize>(&self, data: &T) -> Result<Vec<u8>> {
427 rmp_serde::to_vec(data).map_err(|e| {
428 TrustformersError::serialization_error(format!(
429 "MessagePack serialization failed: {}",
430 e
431 ))
432 })
433 }
434
435 fn deserialize_from_messagepack<T: for<'de> Deserialize<'de>>(&self, data: &[u8]) -> Result<T> {
437 rmp_serde::from_slice(data).map_err(|e| {
438 TrustformersError::serialization_error(format!(
439 "MessagePack deserialization failed: {}",
440 e
441 ))
442 })
443 }
444
445 fn detect_special_tokens(&self, vocab: &HashMap<String, u32>) -> HashSet<String> {
447 let common_special_tokens = [
448 "[PAD]",
449 "[UNK]",
450 "[CLS]",
451 "[SEP]",
452 "[MASK]",
453 "<|endoftext|>",
454 "<|startoftext|>",
455 "<|padding|>",
456 "<pad>",
457 "<unk>",
458 "<cls>",
459 "<sep>",
460 "<mask>",
461 "<s>",
462 "</s>",
463 "<eos>",
464 "<bos>",
465 ];
466
467 vocab
468 .keys()
469 .filter(|token| {
470 common_special_tokens.contains(&token.as_str())
471 || token.starts_with('<') && token.ends_with('>')
472 || token.starts_with('[') && token.ends_with(']')
473 })
474 .cloned()
475 .collect()
476 }
477
478 fn get_tokenizer_type(&self, metadata: &Option<HashMap<String, String>>) -> String {
480 metadata
481 .as_ref()
482 .and_then(|m| m.get("tokenizer_type"))
483 .cloned()
484 .unwrap_or_else(|| "generic".to_string())
485 }
486
487 fn extract_normalization_rules(
489 &self,
490 metadata: &Option<HashMap<String, String>>,
491 ) -> Vec<MessagePackNormalizationRule> {
492 let mut rules = Vec::new();
493
494 if let Some(meta) = metadata {
495 if meta.get("normalize_case").map(|v| v == "true").unwrap_or(false) {
496 rules.push(MessagePackNormalizationRule {
497 rule_type: 1, pattern: None,
499 replacement: None,
500 enabled: true,
501 priority: 1,
502 });
503 }
504 if meta.get("strip_accents").map(|v| v == "true").unwrap_or(false) {
505 rules.push(MessagePackNormalizationRule {
506 rule_type: 2, pattern: None,
508 replacement: None,
509 enabled: true,
510 priority: 2,
511 });
512 }
513 }
514
515 rules
516 }
517
518 fn extract_merge_rules(
520 &self,
521 metadata: &Option<HashMap<String, String>>,
522 ) -> Vec<MessagePackMergeRule> {
523 let mut rules = Vec::new();
524
525 if let Some(meta) = metadata {
526 if let Some(merge_data) = meta.get("merge_rules") {
527 for (i, line) in merge_data.lines().enumerate() {
529 let parts: Vec<&str> = line.split(' ').collect();
530 if parts.len() >= 2 {
531 rules.push(MessagePackMergeRule {
532 first_token: parts[0].to_string(),
533 second_token: parts[1].to_string(),
534 merged_token: format!("{}{}", parts[0], parts[1]),
535 priority: i as u32,
536 frequency: 1.0,
537 });
538 }
539 }
540 }
541 }
542
543 rules
544 }
545}
546
547pub struct MessagePackUtils;
549
550impl MessagePackUtils {
551 pub fn messagepack_to_json(data: &[u8]) -> Result<String> {
553 let value: serde_json::Value = rmp_serde::from_slice(data).map_err(|e| {
554 TrustformersError::serialization_error(format!(
555 "MessagePack to JSON conversion failed: {}",
556 e
557 ))
558 })?;
559
560 serde_json::to_string_pretty(&value).map_err(|e| {
561 TrustformersError::serialization_error(format!("JSON serialization failed: {}", e))
562 })
563 }
564
565 pub fn json_to_messagepack(json: &str) -> Result<Vec<u8>> {
567 let value: serde_json::Value = serde_json::from_str(json).map_err(|e| {
568 TrustformersError::serialization_error(format!("JSON parsing failed: {}", e))
569 })?;
570
571 rmp_serde::to_vec(&value).map_err(|e| {
572 TrustformersError::serialization_error(format!(
573 "JSON to MessagePack conversion failed: {}",
574 e
575 ))
576 })
577 }
578
579 pub fn get_statistics(data: &[u8]) -> Result<HashMap<String, String>> {
581 let mut stats = HashMap::new();
582
583 stats.insert("format".to_string(), "MessagePack".to_string());
584 stats.insert("size_bytes".to_string(), data.len().to_string());
585
586 if let Ok(value) = rmp_serde::from_slice::<serde_json::Value>(data) {
588 match &value {
589 serde_json::Value::Object(map) => {
590 stats.insert("type".to_string(), "object".to_string());
591 stats.insert("fields_count".to_string(), map.len().to_string());
592 },
593 serde_json::Value::Array(arr) => {
594 stats.insert("type".to_string(), "array".to_string());
595 stats.insert("elements_count".to_string(), arr.len().to_string());
596 },
597 _ => {
598 stats.insert("type".to_string(), "primitive".to_string());
599 },
600 }
601 }
602
603 Ok(stats)
604 }
605
606 pub fn validate_file<P: AsRef<Path>>(path: P) -> Result<bool> {
608 let mut file = BufReader::new(File::open(path)?);
609 let mut data = Vec::new();
610 file.read_to_end(&mut data)?;
611
612 match rmp_serde::from_slice::<serde_json::Value>(&data) {
613 Ok(_) => Ok(true),
614 Err(e) => Err(TrustformersError::serialization_error(format!(
615 "MessagePack file validation failed: {}",
616 e
617 ))),
618 }
619 }
620}
621
#[cfg(test)]
mod tests {
    use super::*;

    use tempfile::tempdir;

    // Default config enables binary format, metadata, and vocabulary;
    // training config and compression are off.
    #[test]
    fn test_messagepack_config_default() {
        let config = MessagePackConfig::default();
        assert!(config.use_binary_format);
        assert!(config.include_metadata);
        assert!(config.include_vocabulary);
        assert!(!config.include_training_config);
        assert!(!config.compress);
    }

    // Both constructors (explicit config and default) produce a usable serializer.
    #[test]
    fn test_messagepack_serializer_creation() {
        let config = MessagePackConfig::default();
        let _serializer = MessagePackSerializer::new(config);

        let default_serializer = MessagePackSerializer::default();
        assert!(default_serializer.config.use_binary_format);
    }

    // Round-trip of a single tokenized input preserves ids, mask, and type ids.
    #[test]
    fn test_serialize_tokenized_input() {
        let serializer = MessagePackSerializer::default();

        let input = TokenizedInput {
            input_ids: vec![1, 2, 3, 4],
            attention_mask: vec![1, 1, 1, 1],
            token_type_ids: Some(vec![0, 0, 1, 1]),
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let serialized =
            serializer.serialize_tokenized_input(&input).expect("Operation failed in test");
        assert!(!serialized.is_empty());

        let deserialized = serializer
            .deserialize_tokenized_input(&serialized)
            .expect("Operation failed in test");
        assert_eq!(input.input_ids, deserialized.input_ids);
        assert_eq!(input.attention_mask, deserialized.attention_mask);
        assert_eq!(input.token_type_ids, deserialized.token_type_ids);
    }

    // Round-trip of a batch preserves batch size and per-element ids,
    // including variable-length sequences.
    #[test]
    fn test_serialize_tokenized_batch() {
        let serializer = MessagePackSerializer::default();

        let batch = vec![
            TokenizedInput {
                input_ids: vec![1, 2, 3],
                attention_mask: vec![1, 1, 1],
                token_type_ids: None,
                special_tokens_mask: None,
                offset_mapping: None,
                overflowing_tokens: None,
            },
            TokenizedInput {
                input_ids: vec![4, 5, 6, 7],
                attention_mask: vec![1, 1, 1, 1],
                token_type_ids: None,
                special_tokens_mask: None,
                offset_mapping: None,
                overflowing_tokens: None,
            },
        ];

        let serialized =
            serializer.serialize_tokenized_batch(&batch).expect("Operation failed in test");
        assert!(!serialized.is_empty());

        let deserialized = serializer
            .deserialize_tokenized_batch(&serialized)
            .expect("Operation failed in test");
        assert_eq!(batch.len(), deserialized.len());
        assert_eq!(batch[0].input_ids, deserialized[0].input_ids);
        assert_eq!(batch[1].input_ids, deserialized[1].input_ids);
    }

    // Valid payloads validate; a truncated payload is rejected with an error.
    #[test]
    fn test_messagepack_validation() {
        let serializer = MessagePackSerializer::default();

        let input = TokenizedInput {
            input_ids: vec![1, 2, 3],
            attention_mask: vec![1, 1, 1],
            token_type_ids: None,
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let serialized =
            serializer.serialize_tokenized_input(&input).expect("Operation failed in test");

        assert!(serializer
            .validate_messagepack_data(&serialized)
            .expect("Operation failed in test"));

        // 0x82 announces a 2-entry map with no entries following: invalid.
        let invalid_data = vec![0x82];
        assert!(serializer.validate_messagepack_data(&invalid_data).is_err());
    }

    // Info extraction classifies a serialized input and reports its size.
    #[test]
    fn test_messagepack_info() {
        let serializer = MessagePackSerializer::default();

        let input = TokenizedInput {
            input_ids: vec![1, 2, 3],
            attention_mask: vec![1, 1, 1],
            token_type_ids: None,
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let serialized =
            serializer.serialize_tokenized_input(&input).expect("Operation failed in test");
        let info = serializer.get_messagepack_info(&serialized).expect("Operation failed in test");

        assert_eq!(info.get("format").expect("Key not found"), "MessagePack");
        assert_eq!(
            info.get("data_type").expect("Key not found"),
            "tokenized_input"
        );
        assert_eq!(
            info.get("size_bytes").expect("Key not found"),
            &serialized.len().to_string()
        );
    }

    // Save-to-file followed by load-from-file round-trips a tokenized input.
    #[test]
    fn test_file_operations() {
        let serializer = MessagePackSerializer::default();
        let temp_dir = tempdir().expect("Operation failed in test");

        let input = TokenizedInput {
            input_ids: vec![1, 2, 3, 4],
            attention_mask: vec![1, 1, 1, 1],
            token_type_ids: None,
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let file_path = temp_dir.path().join("test_input.msgpack");

        serializer
            .save_tokenized_input_to_file(&input, &file_path)
            .expect("Operation failed in test");
        assert!(file_path.exists());

        let loaded_input = serializer
            .load_tokenized_input_from_file(&file_path)
            .expect("Operation failed in test");
        assert_eq!(input.input_ids, loaded_input.input_ids);
        assert_eq!(input.attention_mask, loaded_input.attention_mask);
        assert_eq!(input.token_type_ids, loaded_input.token_type_ids);
    }

    // JSON <-> MessagePack conversion round-trips values, and statistics
    // identify the top-level shape as an object.
    #[test]
    fn test_messagepack_utils() {
        let test_data = r#"{"test": "data", "number": 42}"#;

        let msgpack_data =
            MessagePackUtils::json_to_messagepack(test_data).expect("Operation failed in test");
        assert!(!msgpack_data.is_empty());

        let json_data =
            MessagePackUtils::messagepack_to_json(&msgpack_data).expect("Operation failed in test");
        assert!(json_data.contains("test"));
        assert!(json_data.contains("42"));

        let stats =
            MessagePackUtils::get_statistics(&msgpack_data).expect("Operation failed in test");
        assert_eq!(stats.get("format").expect("Key not found"), "MessagePack");
        assert_eq!(stats.get("type").expect("Key not found"), "object");
    }

    // A file written by the serializer passes standalone file validation.
    #[test]
    fn test_file_validation() {
        let serializer = MessagePackSerializer::default();
        let temp_dir = tempdir().expect("Operation failed in test");

        let input = TokenizedInput {
            input_ids: vec![1, 2, 3],
            attention_mask: vec![1, 1, 1],
            token_type_ids: None,
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let file_path = temp_dir.path().join("validation_test.msgpack");
        serializer
            .save_tokenized_input_to_file(&input, &file_path)
            .expect("Operation failed in test");

        assert!(MessagePackUtils::validate_file(&file_path).expect("Operation failed in test"));
    }

    // Comparing two different serialized inputs reports unequal contents
    // while still classifying both as tokenized inputs.
    #[test]
    fn test_file_comparison() {
        let serializer = MessagePackSerializer::default();
        let temp_dir = tempdir().expect("Operation failed in test");

        let input1 = TokenizedInput {
            input_ids: vec![1, 2, 3],
            attention_mask: vec![1, 1, 1],
            token_type_ids: None,
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let input2 = TokenizedInput {
            input_ids: vec![4, 5, 6],
            attention_mask: vec![1, 1, 1],
            token_type_ids: None,
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        };

        let file1_path = temp_dir.path().join("compare1.msgpack");
        let file2_path = temp_dir.path().join("compare2.msgpack");

        serializer
            .save_tokenized_input_to_file(&input1, &file1_path)
            .expect("Operation failed in test");
        serializer
            .save_tokenized_input_to_file(&input2, &file2_path)
            .expect("Operation failed in test");

        let comparison = serializer
            .compare_messagepack_files(&file1_path, &file2_path)
            .expect("Operation failed in test");

        assert_eq!(
            comparison.get("contents_equal").expect("Key not found"),
            "false"
        );
        assert_eq!(
            comparison.get("type1").expect("Key not found"),
            "tokenized_input"
        );
        assert_eq!(
            comparison.get("type2").expect("Key not found"),
            "tokenized_input"
        );
    }
}