1use anyhow::anyhow;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fs::File;
5use std::io::{BufRead, BufReader, BufWriter, Read, Write};
6use std::path::Path;
7use trustformers_core::errors::{Result, TrustformersError};
8
/// Version of the TFMT binary format produced by this module.
const BINARY_FORMAT_VERSION: u32 = 1;

/// Magic bytes identifying a TFMT tokenizer file.
const MAGIC_BYTES: &[u8] = b"TFMT";

/// Header written at the start of every TFMT file, immediately after the
/// 4 magic bytes and a little-endian `u32` header length.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryHeader {
    /// Binary format version (see `BINARY_FORMAT_VERSION`).
    pub version: u32,

    /// Identifier of the tokenizer family (e.g. "sentencepiece").
    pub tokenizer_type: String,

    /// Compression level used for the payload; 0 means stored uncompressed.
    pub compression_level: u8,

    /// Size of the serialized tokenizer before compression, in bytes.
    pub uncompressed_size: u64,

    /// Size of the payload as stored on disk, in bytes.
    pub compressed_size: u64,

    /// CRC32 checksum of the *uncompressed* payload.
    pub checksum: u32,

    /// Optional free-form metadata (e.g. vocab size, feature flags).
    pub metadata: HashMap<String, String>,

    /// Creation time as seconds since the Unix epoch.
    pub created_at: u64,
}
42
/// Options controlling how tokenizers are written to and read from disk.
#[derive(Debug, Clone)]
pub struct BinaryConfig {
    /// zlib compression level for the payload; 0 disables compression.
    pub compression_level: u8,

    /// Whether to embed descriptive metadata (vocab size, etc.) in the header.
    pub include_metadata: bool,

    /// Whether `deserialize` verifies the payload CRC32 against the header.
    pub verify_checksums: bool,

    /// Capacity of the buffered reader/writer, in bytes.
    pub buffer_size: usize,

    /// NOTE(review): not consulted anywhere in this file; presumably reserved
    /// for future memory-mapped I/O support.
    pub use_memory_mapping: bool,
}
61
62impl Default for BinaryConfig {
63 fn default() -> Self {
64 Self {
65 compression_level: 6,
66 include_metadata: true,
67 verify_checksums: true,
68 buffer_size: 64 * 1024, use_memory_mapping: false,
70 }
71 }
72}
73
/// In-memory representation of a tokenizer as stored in the binary format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryTokenizer {
    /// Token string -> token id.
    pub vocab: HashMap<String, u32>,

    /// Token id -> token string (inverse of `vocab`).
    pub id_to_token: HashMap<u32, String>,

    /// Special tokens (e.g. `<pad>`) and their ids.
    pub special_tokens: HashMap<String, u32>,

    /// Optional per-token scores (populated for SentencePiece conversions).
    pub scores: Option<HashMap<u32, f32>>,

    /// Optional BPE merge pairs.
    pub merges: Option<Vec<(String, String)>>,

    /// Arbitrary tokenizer configuration values.
    pub config: HashMap<String, serde_json::Value>,

    /// Optional text normalization rules.
    pub normalization_rules: Option<Vec<NormalizationRule>>,

    /// Optional pre-tokenization rules.
    pub pre_tokenization_rules: Option<Vec<PreTokenizationRule>>,
}
101
/// A single text-normalization step, identified by `rule_type` and driven by
/// rule-specific parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormalizationRule {
    /// Name of the normalization, e.g. "NFKC" or "RemoveExtraSpaces".
    pub rule_type: String,
    /// Rule-specific parameters (pattern, replacement, regex flag, ...).
    pub parameters: HashMap<String, serde_json::Value>,
}
108
/// A single pre-tokenization step: a pattern with an optional replacement.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreTokenizationRule {
    /// Name of the rule, e.g. "SpaceReplacement".
    pub rule_type: String,
    /// Pattern to match in the input text.
    pub pattern: String,
    /// Optional replacement text.
    pub replacement: Option<String>,
}
116
/// Serializes and deserializes `BinaryTokenizer`s in the TFMT binary format.
pub struct BinarySerializer {
    // Controls buffering, compression level, and checksum verification.
    config: BinaryConfig,
}
121
122impl BinarySerializer {
123 pub fn new(config: BinaryConfig) -> Self {
125 Self { config }
126 }
127
128 pub fn serialize<P: AsRef<Path>>(
130 &self,
131 tokenizer: &BinaryTokenizer,
132 tokenizer_type: &str,
133 path: P,
134 ) -> Result<BinaryHeader> {
135 let file = File::create(path.as_ref())
136 .map_err(|e| TrustformersError::io_error(format!("Failed to create file: {}", e)))?;
137 let mut writer = BufWriter::with_capacity(self.config.buffer_size, file);
138
139 let data =
141 oxicode::serde::encode_to_vec(tokenizer, oxicode::config::standard()).map_err(|e| {
142 TrustformersError::serialization_error(format!(
143 "Failed to serialize tokenizer: {}",
144 e
145 ))
146 })?;
147
148 let checksum = crc32fast::hash(&data);
150
151 let (final_data, compressed_size) = if self.config.compression_level > 0 {
153 let compressed = self.compress_data(&data)?;
154 let size = compressed.len() as u64;
155 (compressed, size)
156 } else {
157 let size = data.len() as u64;
158 (data.clone(), size)
159 };
160
161 let mut metadata = HashMap::new();
163 if self.config.include_metadata {
164 metadata.insert("vocab_size".to_string(), tokenizer.vocab.len().to_string());
165 metadata.insert(
166 "has_scores".to_string(),
167 tokenizer.scores.is_some().to_string(),
168 );
169 metadata.insert(
170 "has_merges".to_string(),
171 tokenizer.merges.is_some().to_string(),
172 );
173 }
174
175 let header = BinaryHeader {
176 version: BINARY_FORMAT_VERSION,
177 tokenizer_type: tokenizer_type.to_string(),
178 compression_level: self.config.compression_level,
179 uncompressed_size: data.len() as u64,
180 compressed_size,
181 checksum,
182 metadata,
183 created_at: std::time::SystemTime::now()
184 .duration_since(std::time::UNIX_EPOCH)
185 .unwrap_or_default()
186 .as_secs(),
187 };
188
189 writer.write_all(MAGIC_BYTES).map_err(|e| {
191 TrustformersError::io_error(format!("Failed to write magic bytes: {}", e))
192 })?;
193
194 let header_data = oxicode::serde::encode_to_vec(&header, oxicode::config::standard())
196 .map_err(|e| {
197 TrustformersError::serialization_error(format!("Failed to serialize header: {}", e))
198 })?;
199 let header_size = header_data.len() as u32;
200
201 writer.write_all(&header_size.to_le_bytes()).map_err(|e| {
202 TrustformersError::io_error(format!("Failed to write header size: {}", e))
203 })?;
204 writer
205 .write_all(&header_data)
206 .map_err(|e| TrustformersError::io_error(format!("Failed to write header: {}", e)))?;
207
208 writer.write_all(&final_data).map_err(|e| {
210 TrustformersError::io_error(format!("Failed to write tokenizer data: {}", e))
211 })?;
212
213 writer
214 .flush()
215 .map_err(|e| TrustformersError::io_error(format!("Failed to flush writer: {}", e)))?;
216
217 Ok(header)
218 }
219
220 pub fn deserialize<P: AsRef<Path>>(&self, path: P) -> Result<(BinaryTokenizer, BinaryHeader)> {
222 let file = File::open(path.as_ref())
223 .map_err(|e| TrustformersError::io_error(format!("Failed to open file: {}", e)))?;
224 let mut reader = BufReader::with_capacity(self.config.buffer_size, file);
225
226 let mut magic = [0u8; 4];
228 reader.read_exact(&mut magic).map_err(|e| {
229 TrustformersError::io_error(format!("Failed to read magic bytes: {}", e))
230 })?;
231
232 if magic != MAGIC_BYTES {
233 return Err(trustformers_core::errors::invalid_format(
234 "TFMT",
235 String::from_utf8_lossy(&magic).to_string(),
236 ));
237 }
238
239 let mut header_size_bytes = [0u8; 4];
241 reader.read_exact(&mut header_size_bytes).map_err(|e| {
242 TrustformersError::io_error(format!("Failed to read header size: {}", e))
243 })?;
244 let header_size = u32::from_le_bytes(header_size_bytes) as usize;
245
246 let mut header_data = vec![0u8; header_size];
248 reader
249 .read_exact(&mut header_data)
250 .map_err(|e| TrustformersError::io_error(format!("Failed to read header: {}", e)))?;
251
252 let (header, _): (BinaryHeader, usize) = oxicode::serde::decode_from_slice(
253 &header_data,
254 oxicode::config::standard(),
255 )
256 .map_err(|e| {
257 TrustformersError::serialization_error(format!("Failed to deserialize header: {}", e))
258 })?;
259
260 if header.version > BINARY_FORMAT_VERSION {
262 return Err(trustformers_core::errors::invalid_format(
263 BINARY_FORMAT_VERSION.to_string(),
264 header.version.to_string(),
265 ));
266 }
267
268 let mut data = vec![0u8; header.compressed_size as usize];
270 reader.read_exact(&mut data).map_err(|e| {
271 TrustformersError::io_error(format!("Failed to read tokenizer data: {}", e))
272 })?;
273
274 let final_data = if header.compression_level > 0 {
276 self.decompress_data(&data, header.uncompressed_size as usize)?
277 } else {
278 data
279 };
280
281 if self.config.verify_checksums {
283 let calculated_checksum = crc32fast::hash(&final_data);
284 if calculated_checksum != header.checksum {
285 return Err(trustformers_core::errors::invalid_format(
286 header.checksum.to_string(),
287 calculated_checksum.to_string(),
288 ));
289 }
290 }
291
292 let (tokenizer, _): (BinaryTokenizer, usize) =
294 oxicode::serde::decode_from_slice(&final_data, oxicode::config::standard()).map_err(
295 |e| {
296 TrustformersError::serialization_error(format!(
297 "Failed to deserialize tokenizer: {}",
298 e
299 ))
300 },
301 )?;
302
303 Ok((tokenizer, header))
304 }
305
306 fn compress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
308 use oxiarc_deflate::streaming::ZlibStreamEncoder;
309
310 let mut encoder = ZlibStreamEncoder::new(Vec::new(), self.config.compression_level);
311 encoder.write_all(data).map_err(|e| {
312 TrustformersError::other(anyhow::anyhow!("Failed to compress data: {}", e).to_string())
313 })?;
314 encoder.finish().map_err(|e| {
315 TrustformersError::other(
316 anyhow::anyhow!("Failed to finish compression: {}", e).to_string(),
317 )
318 })
319 }
320
321 fn decompress_data(&self, compressed_data: &[u8], expected_size: usize) -> Result<Vec<u8>> {
323 use oxiarc_deflate::streaming::ZlibStreamDecoder;
324
325 let mut decoder = ZlibStreamDecoder::new(compressed_data);
326 let mut decompressed = Vec::with_capacity(expected_size);
327 decoder.read_to_end(&mut decompressed).map_err(|e| {
328 TrustformersError::other(
329 anyhow::anyhow!("Failed to decompress data: {}", e).to_string(),
330 )
331 })?;
332
333 if decompressed.len() != expected_size {
334 return Err(TrustformersError::other(
335 anyhow::anyhow!(
336 "Decompressed size mismatch: expected {}, got {}",
337 expected_size,
338 decompressed.len()
339 )
340 .to_string(),
341 ));
342 }
343
344 Ok(decompressed)
345 }
346
347 pub fn get_file_info<P: AsRef<Path>>(&self, path: P) -> Result<BinaryHeader> {
349 let file = File::open(path.as_ref())
350 .map_err(|e| TrustformersError::io_error(format!("Failed to open file: {}", e)))?;
351 let mut reader = BufReader::new(file);
352
353 let mut magic = [0u8; 4];
355 reader.read_exact(&mut magic).map_err(|e| {
356 TrustformersError::io_error(format!("Failed to read magic bytes: {}", e))
357 })?;
358
359 if magic != MAGIC_BYTES {
360 return Err(trustformers_core::errors::invalid_format(
361 "TFMT",
362 String::from_utf8_lossy(&magic).to_string(),
363 ));
364 }
365
366 let mut header_size_bytes = [0u8; 4];
368 reader.read_exact(&mut header_size_bytes).map_err(|e| {
369 TrustformersError::io_error(format!("Failed to read header size: {}", e))
370 })?;
371 let header_size = u32::from_le_bytes(header_size_bytes) as usize;
372
373 let mut header_data = vec![0u8; header_size];
375 reader
376 .read_exact(&mut header_data)
377 .map_err(|e| TrustformersError::io_error(format!("Failed to read header: {}", e)))?;
378
379 let (header, _): (BinaryHeader, usize) = oxicode::serde::decode_from_slice(
380 &header_data,
381 oxicode::config::standard(),
382 )
383 .map_err(|e| {
384 TrustformersError::serialization_error(format!("Failed to deserialize header: {}", e))
385 })?;
386
387 Ok(header)
388 }
389}
390
/// Stateless helper routines for inspecting and migrating TFMT files.
pub struct BinaryUtils;
393
394impl BinaryUtils {
395 pub fn validate_file<P: AsRef<Path>>(path: P, config: &BinaryConfig) -> Result<bool> {
397 let serializer = BinarySerializer::new(config.clone());
398 let header = serializer.get_file_info(path.as_ref())?;
399
400 if header.version > BINARY_FORMAT_VERSION {
402 return Ok(false);
403 }
404
405 if header.compressed_size == 0 || header.uncompressed_size == 0 {
406 return Ok(false);
407 }
408
409 Ok(true)
410 }
411
412 pub fn compare_files<P: AsRef<Path>>(
414 path1: P,
415 path2: P,
416 config: &BinaryConfig,
417 ) -> Result<bool> {
418 let serializer = BinarySerializer::new(config.clone());
419
420 let header1 = serializer.get_file_info(path1.as_ref())?;
421 let header2 = serializer.get_file_info(path2.as_ref())?;
422
423 Ok(header1.checksum == header2.checksum)
425 }
426
427 pub fn get_compression_ratio<P: AsRef<Path>>(path: P, config: &BinaryConfig) -> Result<f64> {
429 let serializer = BinarySerializer::new(config.clone());
430 let header = serializer.get_file_info(path)?;
431
432 if header.compression_level == 0 {
433 return Ok(1.0);
434 }
435
436 Ok(header.uncompressed_size as f64 / header.compressed_size as f64)
437 }
438
439 pub fn migrate_format<P: AsRef<Path>>(
441 old_path: P,
442 new_path: P,
443 config: &BinaryConfig,
444 ) -> Result<BinaryHeader> {
445 let serializer = BinarySerializer::new(config.clone());
446
447 let (tokenizer, old_header) = serializer.deserialize(old_path)?;
449
450 let tokenizer_type = &old_header.tokenizer_type;
452
453 serializer.serialize(&tokenizer, tokenizer_type, new_path)
455 }
456}
457
/// Converts external tokenizer formats (HuggingFace `tokenizer.json`,
/// SentencePiece models/vocabs) into the TFMT binary format.
pub struct TokenizerConverter;
460
461impl TokenizerConverter {
462 pub fn from_tokenizer_json<P: AsRef<Path>>(
464 json_path: P,
465 binary_path: P,
466 config: &BinaryConfig,
467 ) -> Result<BinaryHeader> {
468 let json_content = std::fs::read_to_string(json_path.as_ref())
470 .map_err(|e| TrustformersError::io_error(format!("Failed to read JSON file: {}", e)))?;
471
472 let json_value: serde_json::Value = serde_json::from_str(&json_content).map_err(|e| {
473 TrustformersError::serialization_error(format!("Failed to parse JSON: {}", e))
474 })?;
475
476 let mut vocab = HashMap::new();
478 let mut id_to_token = HashMap::new();
479
480 if let Some(model) = json_value.get("model") {
481 if let Some(vocab_obj) = model.get("vocab") {
482 if let Some(vocab_map) = vocab_obj.as_object() {
483 for (token, id) in vocab_map {
484 if let Some(id_num) = id.as_u64() {
485 let id_u32 = id_num as u32;
486 vocab.insert(token.clone(), id_u32);
487 id_to_token.insert(id_u32, token.clone());
488 }
489 }
490 }
491 }
492 }
493
494 let mut special_tokens = HashMap::new();
496 if let Some(added_tokens) = json_value.get("added_tokens") {
497 if let Some(tokens_array) = added_tokens.as_array() {
498 for token_obj in tokens_array {
499 if let Some(content) = token_obj.get("content") {
500 if let Some(id) = token_obj.get("id") {
501 if let (Some(token_str), Some(id_num)) = (content.as_str(), id.as_u64())
502 {
503 special_tokens.insert(token_str.to_string(), id_num as u32);
504 }
505 }
506 }
507 }
508 }
509 }
510
511 let merges = if let Some(model) = json_value.get("model") {
513 if let Some(merges_array) = model.get("merges") {
514 if let Some(merges_vec) = merges_array.as_array() {
515 let mut extracted_merges = Vec::new();
516 for merge in merges_vec {
517 if let Some(merge_str) = merge.as_str() {
518 let parts: Vec<&str> = merge_str.split(' ').collect();
519 if parts.len() == 2 {
520 extracted_merges.push((parts[0].to_string(), parts[1].to_string()));
521 }
522 }
523 }
524 Some(extracted_merges)
525 } else {
526 None
527 }
528 } else {
529 None
530 }
531 } else {
532 None
533 };
534
535 let binary_tokenizer = BinaryTokenizer {
537 vocab,
538 id_to_token,
539 special_tokens,
540 scores: None, merges,
542 config: HashMap::new(),
543 normalization_rules: None,
544 pre_tokenization_rules: None,
545 };
546
547 let tokenizer_type = if let Some(model) = json_value.get("model") {
549 if let Some(type_str) = model.get("type") {
550 type_str.as_str().unwrap_or("unknown").to_string()
551 } else {
552 "unknown".to_string()
553 }
554 } else {
555 "unknown".to_string()
556 };
557
558 let serializer = BinarySerializer::new(config.clone());
560 serializer.serialize(&binary_tokenizer, &tokenizer_type, binary_path)
561 }
562
563 pub fn from_sentencepiece<P: AsRef<Path>>(
565 sp_path: P,
566 binary_path: P,
567 config: &BinaryConfig,
568 ) -> Result<BinaryHeader> {
569 let sp_path = sp_path.as_ref();
570
571 let (vocab, id_to_token, special_tokens, scores, sp_config) =
573 Self::load_sentencepiece_model(sp_path)?;
574
575 let binary_tokenizer = BinaryTokenizer {
577 vocab,
578 id_to_token,
579 special_tokens,
580 scores: Some(scores),
581 merges: None, config: sp_config
583 .into_iter()
584 .map(|(k, v)| (k, serde_json::Value::String(v.to_string())))
585 .collect(),
586 normalization_rules: Some(Self::extract_normalization_rules()),
587 pre_tokenization_rules: Some(Self::extract_pre_tokenization_rules()),
588 };
589
590 let serializer = BinarySerializer::new(config.clone());
591 serializer.serialize(&binary_tokenizer, "sentencepiece", binary_path)
592 }
593
594 fn load_sentencepiece_model<P: AsRef<Path>>(
596 sp_path: P,
597 ) -> Result<(
598 HashMap<String, u32>,
599 HashMap<u32, String>,
600 HashMap<String, u32>,
601 HashMap<u32, f32>,
602 HashMap<String, String>,
603 )> {
604 let sp_path = sp_path.as_ref();
605
606 if sp_path.extension().and_then(|s| s.to_str()) == Some("model") {
608 Self::load_sentencepiece_protobuf(sp_path)
609 } else {
610 Self::load_sentencepiece_vocab(sp_path)
611 }
612 }
613
614 fn load_sentencepiece_protobuf<P: AsRef<Path>>(
616 model_path: P,
617 ) -> Result<(
618 HashMap<String, u32>,
619 HashMap<u32, String>,
620 HashMap<String, u32>,
621 HashMap<u32, f32>,
622 HashMap<String, String>,
623 )> {
624 let mut file = File::open(model_path).map_err(|e| {
625 TrustformersError::other(
626 anyhow!("Failed to open SentencePiece model file: {}", e).to_string(),
627 )
628 })?;
629
630 let mut buffer = Vec::new();
631 file.read_to_end(&mut buffer).map_err(|e| {
632 TrustformersError::other(
633 anyhow!("Failed to read SentencePiece model file: {}", e).to_string(),
634 )
635 })?;
636
637 Self::parse_sentencepiece_protobuf(&buffer)
639 }
640
641 fn parse_sentencepiece_protobuf(
643 data: &[u8],
644 ) -> Result<(
645 HashMap<String, u32>,
646 HashMap<u32, String>,
647 HashMap<String, u32>,
648 HashMap<u32, f32>,
649 HashMap<String, String>,
650 )> {
651 let mut vocab = HashMap::new();
653 let mut id_to_token = HashMap::new();
654 let mut special_tokens = HashMap::new();
655 let mut scores = HashMap::new();
656 let mut config = HashMap::new();
657
658 let standard_tokens = vec![
660 ("<unk>", 0, -100.0, true),
661 ("<s>", 1, -1.0, true),
662 ("</s>", 2, -1.0, true),
663 ("<pad>", 3, -1.0, true),
664 ];
665
666 for (token, id, score, is_special) in standard_tokens {
667 vocab.insert(token.to_string(), id);
668 id_to_token.insert(id, token.to_string());
669 scores.insert(id, score);
670 if is_special {
671 special_tokens.insert(token.to_string(), id);
672 }
673 }
674
675 let mut current_id = 4;
677 let mut i = 0;
678
679 while i < data.len() {
680 if let Some(token_data) = Self::extract_token_from_protobuf(data, &mut i) {
682 let (token, score) = token_data;
683
684 if !vocab.contains_key(&token) {
685 vocab.insert(token.clone(), current_id);
686 id_to_token.insert(current_id, token.clone());
687 scores.insert(current_id, score);
688 current_id += 1;
689 }
690 } else {
691 i += 1;
692 }
693 }
694
695 config.insert("model_type".to_string(), "sentencepiece".to_string());
697 config.insert("vocab_size".to_string(), vocab.len().to_string());
698 config.insert("normalization".to_string(), "nfkc".to_string());
699 config.insert("add_dummy_prefix".to_string(), "true".to_string());
700
701 Ok((vocab, id_to_token, special_tokens, scores, config))
702 }
703
704 fn extract_token_from_protobuf(data: &[u8], pos: &mut usize) -> Option<(String, f32)> {
706 if *pos >= data.len() {
707 return None;
708 }
709
710 let start = *pos;
712 let mut end = start;
713
714 while end < data.len() && end < start + 50 {
716 if data[end] == 0
717 || (data[end] < 32 && data[end] != 9 && data[end] != 10 && data[end] != 13)
718 {
719 break;
720 }
721 end += 1;
722 }
723
724 if end > start {
725 if let Ok(token) = String::from_utf8(data[start..end].to_vec()) {
726 let clean_token = token.trim().to_string();
727 if !clean_token.is_empty() && Self::is_valid_token(&clean_token) {
728 *pos = end + 1;
729 let score = Self::estimate_token_score(&clean_token);
731 return Some((clean_token, score));
732 }
733 }
734 }
735
736 *pos += 1;
737 None
738 }
739
740 fn load_sentencepiece_vocab<P: AsRef<Path>>(
742 vocab_path: P,
743 ) -> Result<(
744 HashMap<String, u32>,
745 HashMap<u32, String>,
746 HashMap<String, u32>,
747 HashMap<u32, f32>,
748 HashMap<String, String>,
749 )> {
750 let file = File::open(vocab_path).map_err(|e| {
751 TrustformersError::other(
752 anyhow!("Failed to open SentencePiece vocab file: {}", e).to_string(),
753 )
754 })?;
755 let reader = BufReader::new(file);
756
757 let mut vocab = HashMap::new();
758 let mut id_to_token = HashMap::new();
759 let mut special_tokens = HashMap::new();
760 let mut scores = HashMap::new();
761 let mut config = HashMap::new();
762
763 for (line_num, line) in reader.lines().enumerate() {
764 let line = line.map_err(|e| {
765 TrustformersError::other(
766 anyhow!("Failed to read line {}: {}", line_num, e).to_string(),
767 )
768 })?;
769 let line = line.trim();
770
771 if line.is_empty() || line.starts_with('#') {
772 continue;
773 }
774
775 let parts: Vec<&str> = if line.contains('\t') {
777 line.split('\t').collect()
778 } else {
779 line.split_whitespace().collect()
780 };
781
782 if parts.is_empty() {
783 continue;
784 }
785
786 let token = parts[0].to_string();
787 let score = if parts.len() > 1 {
788 parts[1].parse::<f32>().unwrap_or(0.0)
789 } else {
790 Self::estimate_token_score(&token)
791 };
792
793 let id = line_num as u32;
794 vocab.insert(token.clone(), id);
795 id_to_token.insert(id, token.clone());
796 scores.insert(id, score);
797
798 if token.starts_with('<') && token.ends_with('>') {
800 special_tokens.insert(token, id);
801 }
802 }
803
804 config.insert("model_type".to_string(), "sentencepiece".to_string());
806 config.insert("vocab_size".to_string(), vocab.len().to_string());
807 config.insert("normalization".to_string(), "nfkc".to_string());
808
809 Ok((vocab, id_to_token, special_tokens, scores, config))
810 }
811
812 fn is_valid_token(token: &str) -> bool {
814 token.len() <= 100
816 && !token.trim().is_empty()
817 && token.chars().any(|c| !c.is_whitespace())
818 && token.chars().all(|c| c.is_ascii() || c as u32 > 127) }
820
821 fn estimate_token_score(token: &str) -> f32 {
823 match token {
825 "<unk>" => -100.0,
826 "<s>" | "</s>" | "<pad>" => -1.0,
827 _ if token.starts_with('<') && token.ends_with('>') => -10.0, _ if token.starts_with("▁") => -5.0 + (token.len() as f32 * -0.1), _ if token.len() == 1 => -2.0, _ if token.len() <= 3 => -3.0 + (token.len() as f32 * -0.2),
831 _ => -5.0 + (token.len() as f32 * -0.1), }
833 }
834
835 fn extract_normalization_rules() -> Vec<NormalizationRule> {
837 vec![
838 NormalizationRule {
839 rule_type: "NFKC".to_string(),
840 parameters: {
841 let mut params = HashMap::new();
842 params.insert(
843 "pattern".to_string(),
844 serde_json::Value::String(".*".to_string()),
845 );
846 params.insert(
847 "replacement".to_string(),
848 serde_json::Value::String("NFKC_NORMALIZED".to_string()),
849 );
850 params.insert("regex".to_string(), serde_json::Value::Bool(false));
851 params
852 },
853 },
854 NormalizationRule {
855 rule_type: "RemoveExtraSpaces".to_string(),
856 parameters: {
857 let mut params = HashMap::new();
858 params.insert(
859 "pattern".to_string(),
860 serde_json::Value::String(r"\s+".to_string()),
861 );
862 params.insert(
863 "replacement".to_string(),
864 serde_json::Value::String(" ".to_string()),
865 );
866 params.insert("regex".to_string(), serde_json::Value::Bool(true));
867 params
868 },
869 },
870 ]
871 }
872
873 fn extract_pre_tokenization_rules() -> Vec<PreTokenizationRule> {
875 vec![
876 PreTokenizationRule {
877 rule_type: "AddDummyPrefix".to_string(),
878 pattern: "^".to_string(),
879 replacement: Some("▁".to_string()),
880 },
881 PreTokenizationRule {
882 rule_type: "SpaceReplacement".to_string(),
883 pattern: " ".to_string(),
884 replacement: Some("▁".to_string()),
885 },
886 ]
887 }
888}
889
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a minimal three-token tokenizer (`hello`, `world`, `<pad>`)
    /// with `<pad>` registered as a special token.
    fn create_test_tokenizer() -> BinaryTokenizer {
        let mut vocab = HashMap::new();
        let mut id_to_token = HashMap::new();
        let mut special_tokens = HashMap::new();

        vocab.insert("hello".to_string(), 0);
        vocab.insert("world".to_string(), 1);
        vocab.insert("<pad>".to_string(), 2);

        id_to_token.insert(0, "hello".to_string());
        id_to_token.insert(1, "world".to_string());
        id_to_token.insert(2, "<pad>".to_string());

        special_tokens.insert("<pad>".to_string(), 2);

        BinaryTokenizer {
            vocab,
            id_to_token,
            special_tokens,
            scores: None,
            merges: None,
            config: HashMap::new(),
            normalization_rules: None,
            pre_tokenization_rules: None,
        }
    }

    // Round-trip: a tokenizer serialized with default settings deserializes
    // to identical vocab/id maps, and the header records what we wrote.
    #[test]
    fn test_serialize_deserialize() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_tokenizer.bin");

        let config = BinaryConfig::default();
        let serializer = BinarySerializer::new(config);

        let tokenizer = create_test_tokenizer();

        let header = serializer
            .serialize(&tokenizer, "test", &file_path)
            .expect("Operation failed in test");
        assert_eq!(header.tokenizer_type, "test");
        assert_eq!(header.version, BINARY_FORMAT_VERSION);

        let (loaded_tokenizer, loaded_header) =
            serializer.deserialize(&file_path).expect("Operation failed in test");

        assert_eq!(loaded_tokenizer.vocab, tokenizer.vocab);
        assert_eq!(loaded_tokenizer.id_to_token, tokenizer.id_to_token);
        assert_eq!(loaded_header.tokenizer_type, "test");
    }

    // Maximum compression must shrink the payload and still round-trip.
    #[test]
    fn test_compression() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_compressed.bin");

        let config = BinaryConfig {
            compression_level: 9,
            ..Default::default()
        };
        let serializer = BinarySerializer::new(config);

        let tokenizer = create_test_tokenizer();
        let header = serializer
            .serialize(&tokenizer, "test", &file_path)
            .expect("Operation failed in test");

        assert!(header.compressed_size < header.uncompressed_size);
        assert_eq!(header.compression_level, 9);

        let (loaded_tokenizer, _) =
            serializer.deserialize(&file_path).expect("Operation failed in test");
        assert_eq!(loaded_tokenizer.vocab, tokenizer.vocab);
    }

    // get_file_info must return the same header that serialize produced.
    #[test]
    fn test_file_info() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_info.bin");

        let config = BinaryConfig::default();
        let serializer = BinarySerializer::new(config);

        let tokenizer = create_test_tokenizer();
        let original_header = serializer
            .serialize(&tokenizer, "test", &file_path)
            .expect("Operation failed in test");

        let info_header = serializer.get_file_info(&file_path).expect("Operation failed in test");

        assert_eq!(info_header.tokenizer_type, original_header.tokenizer_type);
        assert_eq!(info_header.checksum, original_header.checksum);
    }

    // A freshly written file passes structural validation.
    #[test]
    fn test_validation() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_validate.bin");

        let config = BinaryConfig::default();
        let serializer = BinarySerializer::new(config.clone());

        let tokenizer = create_test_tokenizer();
        serializer
            .serialize(&tokenizer, "test", &file_path)
            .expect("Operation failed in test");

        assert!(BinaryUtils::validate_file(&file_path, &config).expect("Operation failed in test"));
    }

    // With compression enabled the reported ratio is greater than 1.
    #[test]
    fn test_compression_ratio() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_ratio.bin");

        let config = BinaryConfig {
            compression_level: 6,
            ..Default::default()
        };
        let serializer = BinarySerializer::new(config.clone());

        let tokenizer = create_test_tokenizer();
        serializer
            .serialize(&tokenizer, "test", &file_path)
            .expect("Operation failed in test");

        let ratio = BinaryUtils::get_compression_ratio(&file_path, &config)
            .expect("Operation failed in test");
        assert!(ratio > 1.0);
    }
}