Skip to main content

trustformers_tokenizers/
binary_format.rs

1use anyhow::anyhow;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fs::File;
5use std::io::{BufRead, BufReader, BufWriter, Read, Write};
6use std::path::Path;
7use trustformers_core::errors::{Result, TrustformersError};
8
/// Binary format version for compatibility tracking.
///
/// Written into every `BinaryHeader`; `BinarySerializer::deserialize` rejects files whose
/// header version is greater than this value (older versions are accepted).
const BINARY_FORMAT_VERSION: u32 = 1;

/// Magic bytes to identify our binary format ("TFMT" = TrustForMers Tokenizer).
///
/// These are the first four bytes of every file written by `BinarySerializer::serialize`
/// and are checked before anything else is decoded.
const MAGIC_BYTES: &[u8] = b"TFMT"; // TrustForMers Tokenizer
14
/// Header information for the binary format.
///
/// Stored (oxicode-encoded, length-prefixed) right after the magic bytes.
/// NOTE(review): oxicode appears to encode struct fields in declaration order, so
/// reordering or retyping fields here would presumably change the on-disk format — confirm
/// before editing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryHeader {
    /// Format version (`BINARY_FORMAT_VERSION` at write time) for backward compatibility
    pub version: u32,

    /// Caller-supplied tokenizer type identifier (e.g. a `tokenizer.json` `model.type`,
    /// or `"sentencepiece"`)
    pub tokenizer_type: String,

    /// Compression level used: 0 means the payload is stored uncompressed, any other
    /// value means it is zlib-compressed
    pub compression_level: u8,

    /// Size in bytes of the encoded tokenizer before compression
    pub uncompressed_size: u64,

    /// Size in bytes of the payload as stored in the file (equals `uncompressed_size`
    /// when `compression_level` is 0)
    pub compressed_size: u64,

    /// CRC-32 (`crc32fast`) of the uncompressed payload
    pub checksum: u32,

    /// Summary metadata (`vocab_size`, `has_scores`, `has_merges`); empty when
    /// `BinaryConfig::include_metadata` is false
    pub metadata: HashMap<String, String>,

    /// Creation time in seconds since the Unix epoch (0 if the clock reads earlier)
    pub created_at: u64,
}
42
/// Configuration for binary serialization.
#[derive(Debug, Clone)]
pub struct BinaryConfig {
    /// Compression level (0 = no compression, 1-9 = zlib compression levels).
    /// Any non-zero value enables zlib compression; the value is passed to the encoder as-is.
    pub compression_level: u8,

    /// Whether to include summary metadata in the header
    pub include_metadata: bool,

    /// Whether to verify the CRC-32 checksum of the payload on load
    pub verify_checksums: bool,

    /// Buffer size in bytes for the buffered reader/writer used during (de)serialization
    pub buffer_size: usize,

    /// Whether to use memory mapping for large files.
    /// NOTE(review): not read by any code in this file — appears to be reserved/unused.
    pub use_memory_mapping: bool,
}
61
62impl Default for BinaryConfig {
63    fn default() -> Self {
64        Self {
65            compression_level: 6,
66            include_metadata: true,
67            verify_checksums: true,
68            buffer_size: 64 * 1024, // 64KB
69            use_memory_mapping: false,
70        }
71    }
72}
73
/// Binary tokenizer representation.
///
/// This whole struct is oxicode-encoded as the file payload.
/// NOTE(review): field order presumably determines the encoded layout — confirm before
/// reordering or retyping fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinaryTokenizer {
    /// Vocabulary mapping from tokens to IDs
    pub vocab: HashMap<String, u32>,

    /// Reverse mapping from IDs to tokens
    pub id_to_token: HashMap<u32, String>,

    /// Special tokens with their IDs
    pub special_tokens: HashMap<String, u32>,

    /// Per-ID token scores for ranking (if applicable, e.g. SentencePiece)
    pub scores: Option<HashMap<u32, f32>>,

    /// Merge pairs for BPE tokenizers (if applicable)
    pub merges: Option<Vec<(String, String)>>,

    /// Additional tokenizer-specific configuration
    pub config: HashMap<String, serde_json::Value>,

    /// Normalization rules
    pub normalization_rules: Option<Vec<NormalizationRule>>,

    /// Pre-tokenization rules
    pub pre_tokenization_rules: Option<Vec<PreTokenizationRule>>,
}
101
/// Normalization rule for text preprocessing.
///
/// The rule is stored descriptively: `rule_type` names the rule and `parameters` holds its
/// arguments (for example `pattern`, `replacement`, `regex` as produced by the SentencePiece
/// converter). Nothing in this file applies the rule to text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormalizationRule {
    pub rule_type: String,
    pub parameters: HashMap<String, serde_json::Value>,
}
108
/// Pre-tokenization rule for splitting text.
///
/// Stored descriptively: a rule name, the pattern it matches, and an optional replacement.
/// Nothing in this file applies the rule to text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreTokenizationRule {
    pub rule_type: String,
    pub pattern: String,
    pub replacement: Option<String>,
}
116
/// Binary serializer/deserializer for tokenizers, parameterised by a [`BinaryConfig`].
pub struct BinarySerializer {
    config: BinaryConfig,
}
121
122impl BinarySerializer {
123    /// Create a new binary serializer with the given configuration
124    pub fn new(config: BinaryConfig) -> Self {
125        Self { config }
126    }
127
128    /// Serialize a tokenizer to binary format
129    pub fn serialize<P: AsRef<Path>>(
130        &self,
131        tokenizer: &BinaryTokenizer,
132        tokenizer_type: &str,
133        path: P,
134    ) -> Result<BinaryHeader> {
135        let file = File::create(path.as_ref())
136            .map_err(|e| TrustformersError::io_error(format!("Failed to create file: {}", e)))?;
137        let mut writer = BufWriter::with_capacity(self.config.buffer_size, file);
138
139        // Serialize the tokenizer data
140        let data =
141            oxicode::serde::encode_to_vec(tokenizer, oxicode::config::standard()).map_err(|e| {
142                TrustformersError::serialization_error(format!(
143                    "Failed to serialize tokenizer: {}",
144                    e
145                ))
146            })?;
147
148        // Calculate checksum
149        let checksum = crc32fast::hash(&data);
150
151        // Compress data if requested
152        let (final_data, compressed_size) = if self.config.compression_level > 0 {
153            let compressed = self.compress_data(&data)?;
154            let size = compressed.len() as u64;
155            (compressed, size)
156        } else {
157            let size = data.len() as u64;
158            (data.clone(), size)
159        };
160
161        // Create header
162        let mut metadata = HashMap::new();
163        if self.config.include_metadata {
164            metadata.insert("vocab_size".to_string(), tokenizer.vocab.len().to_string());
165            metadata.insert(
166                "has_scores".to_string(),
167                tokenizer.scores.is_some().to_string(),
168            );
169            metadata.insert(
170                "has_merges".to_string(),
171                tokenizer.merges.is_some().to_string(),
172            );
173        }
174
175        let header = BinaryHeader {
176            version: BINARY_FORMAT_VERSION,
177            tokenizer_type: tokenizer_type.to_string(),
178            compression_level: self.config.compression_level,
179            uncompressed_size: data.len() as u64,
180            compressed_size,
181            checksum,
182            metadata,
183            created_at: std::time::SystemTime::now()
184                .duration_since(std::time::UNIX_EPOCH)
185                .unwrap_or_default()
186                .as_secs(),
187        };
188
189        // Write magic bytes
190        writer.write_all(MAGIC_BYTES).map_err(|e| {
191            TrustformersError::io_error(format!("Failed to write magic bytes: {}", e))
192        })?;
193
194        // Write header
195        let header_data = oxicode::serde::encode_to_vec(&header, oxicode::config::standard())
196            .map_err(|e| {
197                TrustformersError::serialization_error(format!("Failed to serialize header: {}", e))
198            })?;
199        let header_size = header_data.len() as u32;
200
201        writer.write_all(&header_size.to_le_bytes()).map_err(|e| {
202            TrustformersError::io_error(format!("Failed to write header size: {}", e))
203        })?;
204        writer
205            .write_all(&header_data)
206            .map_err(|e| TrustformersError::io_error(format!("Failed to write header: {}", e)))?;
207
208        // Write tokenizer data
209        writer.write_all(&final_data).map_err(|e| {
210            TrustformersError::io_error(format!("Failed to write tokenizer data: {}", e))
211        })?;
212
213        writer
214            .flush()
215            .map_err(|e| TrustformersError::io_error(format!("Failed to flush writer: {}", e)))?;
216
217        Ok(header)
218    }
219
220    /// Deserialize a tokenizer from binary format
221    pub fn deserialize<P: AsRef<Path>>(&self, path: P) -> Result<(BinaryTokenizer, BinaryHeader)> {
222        let file = File::open(path.as_ref())
223            .map_err(|e| TrustformersError::io_error(format!("Failed to open file: {}", e)))?;
224        let mut reader = BufReader::with_capacity(self.config.buffer_size, file);
225
226        // Read and verify magic bytes
227        let mut magic = [0u8; 4];
228        reader.read_exact(&mut magic).map_err(|e| {
229            TrustformersError::io_error(format!("Failed to read magic bytes: {}", e))
230        })?;
231
232        if magic != MAGIC_BYTES {
233            return Err(trustformers_core::errors::invalid_format(
234                "TFMT",
235                String::from_utf8_lossy(&magic).to_string(),
236            ));
237        }
238
239        // Read header size
240        let mut header_size_bytes = [0u8; 4];
241        reader.read_exact(&mut header_size_bytes).map_err(|e| {
242            TrustformersError::io_error(format!("Failed to read header size: {}", e))
243        })?;
244        let header_size = u32::from_le_bytes(header_size_bytes) as usize;
245
246        // Read header
247        let mut header_data = vec![0u8; header_size];
248        reader
249            .read_exact(&mut header_data)
250            .map_err(|e| TrustformersError::io_error(format!("Failed to read header: {}", e)))?;
251
252        let (header, _): (BinaryHeader, usize) = oxicode::serde::decode_from_slice(
253            &header_data,
254            oxicode::config::standard(),
255        )
256        .map_err(|e| {
257            TrustformersError::serialization_error(format!("Failed to deserialize header: {}", e))
258        })?;
259
260        // Verify version compatibility
261        if header.version > BINARY_FORMAT_VERSION {
262            return Err(trustformers_core::errors::invalid_format(
263                BINARY_FORMAT_VERSION.to_string(),
264                header.version.to_string(),
265            ));
266        }
267
268        // Read tokenizer data
269        let mut data = vec![0u8; header.compressed_size as usize];
270        reader.read_exact(&mut data).map_err(|e| {
271            TrustformersError::io_error(format!("Failed to read tokenizer data: {}", e))
272        })?;
273
274        // Decompress if needed
275        let final_data = if header.compression_level > 0 {
276            self.decompress_data(&data, header.uncompressed_size as usize)?
277        } else {
278            data
279        };
280
281        // Verify checksum if enabled
282        if self.config.verify_checksums {
283            let calculated_checksum = crc32fast::hash(&final_data);
284            if calculated_checksum != header.checksum {
285                return Err(trustformers_core::errors::invalid_format(
286                    header.checksum.to_string(),
287                    calculated_checksum.to_string(),
288                ));
289            }
290        }
291
292        // Deserialize tokenizer
293        let (tokenizer, _): (BinaryTokenizer, usize) =
294            oxicode::serde::decode_from_slice(&final_data, oxicode::config::standard()).map_err(
295                |e| {
296                    TrustformersError::serialization_error(format!(
297                        "Failed to deserialize tokenizer: {}",
298                        e
299                    ))
300                },
301            )?;
302
303        Ok((tokenizer, header))
304    }
305
306    /// Compress data using zlib
307    fn compress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
308        use oxiarc_deflate::streaming::ZlibStreamEncoder;
309
310        let mut encoder = ZlibStreamEncoder::new(Vec::new(), self.config.compression_level);
311        encoder.write_all(data).map_err(|e| {
312            TrustformersError::other(anyhow::anyhow!("Failed to compress data: {}", e).to_string())
313        })?;
314        encoder.finish().map_err(|e| {
315            TrustformersError::other(
316                anyhow::anyhow!("Failed to finish compression: {}", e).to_string(),
317            )
318        })
319    }
320
321    /// Decompress data using zlib
322    fn decompress_data(&self, compressed_data: &[u8], expected_size: usize) -> Result<Vec<u8>> {
323        use oxiarc_deflate::streaming::ZlibStreamDecoder;
324
325        let mut decoder = ZlibStreamDecoder::new(compressed_data);
326        let mut decompressed = Vec::with_capacity(expected_size);
327        decoder.read_to_end(&mut decompressed).map_err(|e| {
328            TrustformersError::other(
329                anyhow::anyhow!("Failed to decompress data: {}", e).to_string(),
330            )
331        })?;
332
333        if decompressed.len() != expected_size {
334            return Err(TrustformersError::other(
335                anyhow::anyhow!(
336                    "Decompressed size mismatch: expected {}, got {}",
337                    expected_size,
338                    decompressed.len()
339                )
340                .to_string(),
341            ));
342        }
343
344        Ok(decompressed)
345    }
346
347    /// Get file info without fully loading the tokenizer
348    pub fn get_file_info<P: AsRef<Path>>(&self, path: P) -> Result<BinaryHeader> {
349        let file = File::open(path.as_ref())
350            .map_err(|e| TrustformersError::io_error(format!("Failed to open file: {}", e)))?;
351        let mut reader = BufReader::new(file);
352
353        // Read and verify magic bytes
354        let mut magic = [0u8; 4];
355        reader.read_exact(&mut magic).map_err(|e| {
356            TrustformersError::io_error(format!("Failed to read magic bytes: {}", e))
357        })?;
358
359        if magic != MAGIC_BYTES {
360            return Err(trustformers_core::errors::invalid_format(
361                "TFMT",
362                String::from_utf8_lossy(&magic).to_string(),
363            ));
364        }
365
366        // Read header size
367        let mut header_size_bytes = [0u8; 4];
368        reader.read_exact(&mut header_size_bytes).map_err(|e| {
369            TrustformersError::io_error(format!("Failed to read header size: {}", e))
370        })?;
371        let header_size = u32::from_le_bytes(header_size_bytes) as usize;
372
373        // Read header
374        let mut header_data = vec![0u8; header_size];
375        reader
376            .read_exact(&mut header_data)
377            .map_err(|e| TrustformersError::io_error(format!("Failed to read header: {}", e)))?;
378
379        let (header, _): (BinaryHeader, usize) = oxicode::serde::decode_from_slice(
380            &header_data,
381            oxicode::config::standard(),
382        )
383        .map_err(|e| {
384            TrustformersError::serialization_error(format!("Failed to deserialize header: {}", e))
385        })?;
386
387        Ok(header)
388    }
389}
390
/// Stateless utilities for inspecting, comparing, and migrating binary tokenizer files.
pub struct BinaryUtils;
393
394impl BinaryUtils {
395    /// Validate a binary tokenizer file
396    pub fn validate_file<P: AsRef<Path>>(path: P, config: &BinaryConfig) -> Result<bool> {
397        let serializer = BinarySerializer::new(config.clone());
398        let header = serializer.get_file_info(path.as_ref())?;
399
400        // Basic validation checks
401        if header.version > BINARY_FORMAT_VERSION {
402            return Ok(false);
403        }
404
405        if header.compressed_size == 0 || header.uncompressed_size == 0 {
406            return Ok(false);
407        }
408
409        Ok(true)
410    }
411
412    /// Compare two binary tokenizer files
413    pub fn compare_files<P: AsRef<Path>>(
414        path1: P,
415        path2: P,
416        config: &BinaryConfig,
417    ) -> Result<bool> {
418        let serializer = BinarySerializer::new(config.clone());
419
420        let header1 = serializer.get_file_info(path1.as_ref())?;
421        let header2 = serializer.get_file_info(path2.as_ref())?;
422
423        // Compare checksums for quick comparison
424        Ok(header1.checksum == header2.checksum)
425    }
426
427    /// Get compression ratio for a binary file
428    pub fn get_compression_ratio<P: AsRef<Path>>(path: P, config: &BinaryConfig) -> Result<f64> {
429        let serializer = BinarySerializer::new(config.clone());
430        let header = serializer.get_file_info(path)?;
431
432        if header.compression_level == 0 {
433            return Ok(1.0);
434        }
435
436        Ok(header.uncompressed_size as f64 / header.compressed_size as f64)
437    }
438
439    /// Migrate an old format file to the current format
440    pub fn migrate_format<P: AsRef<Path>>(
441        old_path: P,
442        new_path: P,
443        config: &BinaryConfig,
444    ) -> Result<BinaryHeader> {
445        let serializer = BinarySerializer::new(config.clone());
446
447        // Load the old format
448        let (tokenizer, old_header) = serializer.deserialize(old_path)?;
449
450        // Determine tokenizer type from old header or infer it
451        let tokenizer_type = &old_header.tokenizer_type;
452
453        // Save in new format
454        serializer.serialize(&tokenizer, tokenizer_type, new_path)
455    }
456}
457
/// Converts tokenizers from other on-disk formats (HuggingFace `tokenizer.json`,
/// SentencePiece `.model` / `.vocab`) into the binary format.
pub struct TokenizerConverter;
460
461impl TokenizerConverter {
462    /// Convert a HuggingFace tokenizer.json to binary format
463    pub fn from_tokenizer_json<P: AsRef<Path>>(
464        json_path: P,
465        binary_path: P,
466        config: &BinaryConfig,
467    ) -> Result<BinaryHeader> {
468        // Load the JSON tokenizer
469        let json_content = std::fs::read_to_string(json_path.as_ref())
470            .map_err(|e| TrustformersError::io_error(format!("Failed to read JSON file: {}", e)))?;
471
472        let json_value: serde_json::Value = serde_json::from_str(&json_content).map_err(|e| {
473            TrustformersError::serialization_error(format!("Failed to parse JSON: {}", e))
474        })?;
475
476        // Extract vocabulary
477        let mut vocab = HashMap::new();
478        let mut id_to_token = HashMap::new();
479
480        if let Some(model) = json_value.get("model") {
481            if let Some(vocab_obj) = model.get("vocab") {
482                if let Some(vocab_map) = vocab_obj.as_object() {
483                    for (token, id) in vocab_map {
484                        if let Some(id_num) = id.as_u64() {
485                            let id_u32 = id_num as u32;
486                            vocab.insert(token.clone(), id_u32);
487                            id_to_token.insert(id_u32, token.clone());
488                        }
489                    }
490                }
491            }
492        }
493
494        // Extract special tokens
495        let mut special_tokens = HashMap::new();
496        if let Some(added_tokens) = json_value.get("added_tokens") {
497            if let Some(tokens_array) = added_tokens.as_array() {
498                for token_obj in tokens_array {
499                    if let Some(content) = token_obj.get("content") {
500                        if let Some(id) = token_obj.get("id") {
501                            if let (Some(token_str), Some(id_num)) = (content.as_str(), id.as_u64())
502                            {
503                                special_tokens.insert(token_str.to_string(), id_num as u32);
504                            }
505                        }
506                    }
507                }
508            }
509        }
510
511        // Extract merges for BPE
512        let merges = if let Some(model) = json_value.get("model") {
513            if let Some(merges_array) = model.get("merges") {
514                if let Some(merges_vec) = merges_array.as_array() {
515                    let mut extracted_merges = Vec::new();
516                    for merge in merges_vec {
517                        if let Some(merge_str) = merge.as_str() {
518                            let parts: Vec<&str> = merge_str.split(' ').collect();
519                            if parts.len() == 2 {
520                                extracted_merges.push((parts[0].to_string(), parts[1].to_string()));
521                            }
522                        }
523                    }
524                    Some(extracted_merges)
525                } else {
526                    None
527                }
528            } else {
529                None
530            }
531        } else {
532            None
533        };
534
535        // Create binary tokenizer
536        let binary_tokenizer = BinaryTokenizer {
537            vocab,
538            id_to_token,
539            special_tokens,
540            scores: None, // JSON tokenizers typically don't have scores
541            merges,
542            config: HashMap::new(),
543            normalization_rules: None,
544            pre_tokenization_rules: None,
545        };
546
547        // Determine tokenizer type
548        let tokenizer_type = if let Some(model) = json_value.get("model") {
549            if let Some(type_str) = model.get("type") {
550                type_str.as_str().unwrap_or("unknown").to_string()
551            } else {
552                "unknown".to_string()
553            }
554        } else {
555            "unknown".to_string()
556        };
557
558        // Serialize to binary format
559        let serializer = BinarySerializer::new(config.clone());
560        serializer.serialize(&binary_tokenizer, &tokenizer_type, binary_path)
561    }
562
563    /// Convert from SentencePiece model to binary format
564    pub fn from_sentencepiece<P: AsRef<Path>>(
565        sp_path: P,
566        binary_path: P,
567        config: &BinaryConfig,
568    ) -> Result<BinaryHeader> {
569        let sp_path = sp_path.as_ref();
570
571        // Load SentencePiece model
572        let (vocab, id_to_token, special_tokens, scores, sp_config) =
573            Self::load_sentencepiece_model(sp_path)?;
574
575        // Create binary tokenizer with loaded data
576        let binary_tokenizer = BinaryTokenizer {
577            vocab,
578            id_to_token,
579            special_tokens,
580            scores: Some(scores),
581            merges: None, // SentencePiece doesn't use BPE merges
582            config: sp_config
583                .into_iter()
584                .map(|(k, v)| (k, serde_json::Value::String(v.to_string())))
585                .collect(),
586            normalization_rules: Some(Self::extract_normalization_rules()),
587            pre_tokenization_rules: Some(Self::extract_pre_tokenization_rules()),
588        };
589
590        let serializer = BinarySerializer::new(config.clone());
591        serializer.serialize(&binary_tokenizer, "sentencepiece", binary_path)
592    }
593
594    /// Load SentencePiece model from file
595    fn load_sentencepiece_model<P: AsRef<Path>>(
596        sp_path: P,
597    ) -> Result<(
598        HashMap<String, u32>,
599        HashMap<u32, String>,
600        HashMap<String, u32>,
601        HashMap<u32, f32>,
602        HashMap<String, String>,
603    )> {
604        let sp_path = sp_path.as_ref();
605
606        // Check if it's a protobuf file (.model) or text file (.vocab)
607        if sp_path.extension().and_then(|s| s.to_str()) == Some("model") {
608            Self::load_sentencepiece_protobuf(sp_path)
609        } else {
610            Self::load_sentencepiece_vocab(sp_path)
611        }
612    }
613
614    /// Load SentencePiece protobuf model file
615    fn load_sentencepiece_protobuf<P: AsRef<Path>>(
616        model_path: P,
617    ) -> Result<(
618        HashMap<String, u32>,
619        HashMap<u32, String>,
620        HashMap<String, u32>,
621        HashMap<u32, f32>,
622        HashMap<String, String>,
623    )> {
624        let mut file = File::open(model_path).map_err(|e| {
625            TrustformersError::other(
626                anyhow!("Failed to open SentencePiece model file: {}", e).to_string(),
627            )
628        })?;
629
630        let mut buffer = Vec::new();
631        file.read_to_end(&mut buffer).map_err(|e| {
632            TrustformersError::other(
633                anyhow!("Failed to read SentencePiece model file: {}", e).to_string(),
634            )
635        })?;
636
637        // Parse protobuf data (simplified - would use actual protobuf parsing in production)
638        Self::parse_sentencepiece_protobuf(&buffer)
639    }
640
641    /// Parse SentencePiece protobuf data
642    fn parse_sentencepiece_protobuf(
643        data: &[u8],
644    ) -> Result<(
645        HashMap<String, u32>,
646        HashMap<u32, String>,
647        HashMap<String, u32>,
648        HashMap<u32, f32>,
649        HashMap<String, String>,
650    )> {
651        // Simplified protobuf parsing - in production this would use proper protobuf library
652        let mut vocab = HashMap::new();
653        let mut id_to_token = HashMap::new();
654        let mut special_tokens = HashMap::new();
655        let mut scores = HashMap::new();
656        let mut config = HashMap::new();
657
658        // Add standard SentencePiece tokens
659        let standard_tokens = vec![
660            ("<unk>", 0, -100.0, true),
661            ("<s>", 1, -1.0, true),
662            ("</s>", 2, -1.0, true),
663            ("<pad>", 3, -1.0, true),
664        ];
665
666        for (token, id, score, is_special) in standard_tokens {
667            vocab.insert(token.to_string(), id);
668            id_to_token.insert(id, token.to_string());
669            scores.insert(id, score);
670            if is_special {
671                special_tokens.insert(token.to_string(), id);
672            }
673        }
674
675        // Extract vocabulary from protobuf data
676        let mut current_id = 4;
677        let mut i = 0;
678
679        while i < data.len() {
680            // Look for token patterns in the binary data
681            if let Some(token_data) = Self::extract_token_from_protobuf(data, &mut i) {
682                let (token, score) = token_data;
683
684                if !vocab.contains_key(&token) {
685                    vocab.insert(token.clone(), current_id);
686                    id_to_token.insert(current_id, token.clone());
687                    scores.insert(current_id, score);
688                    current_id += 1;
689                }
690            } else {
691                i += 1;
692            }
693        }
694
695        // Add configuration metadata
696        config.insert("model_type".to_string(), "sentencepiece".to_string());
697        config.insert("vocab_size".to_string(), vocab.len().to_string());
698        config.insert("normalization".to_string(), "nfkc".to_string());
699        config.insert("add_dummy_prefix".to_string(), "true".to_string());
700
701        Ok((vocab, id_to_token, special_tokens, scores, config))
702    }
703
704    /// Extract token from SentencePiece protobuf data
705    fn extract_token_from_protobuf(data: &[u8], pos: &mut usize) -> Option<(String, f32)> {
706        if *pos >= data.len() {
707            return None;
708        }
709
710        // Simplified extraction - look for UTF-8 sequences that could be tokens
711        let start = *pos;
712        let mut end = start;
713
714        // Find potential token boundaries
715        while end < data.len() && end < start + 50 {
716            if data[end] == 0
717                || (data[end] < 32 && data[end] != 9 && data[end] != 10 && data[end] != 13)
718            {
719                break;
720            }
721            end += 1;
722        }
723
724        if end > start {
725            if let Ok(token) = String::from_utf8(data[start..end].to_vec()) {
726                let clean_token = token.trim().to_string();
727                if !clean_token.is_empty() && Self::is_valid_token(&clean_token) {
728                    *pos = end + 1;
729                    // Generate a score based on token characteristics
730                    let score = Self::estimate_token_score(&clean_token);
731                    return Some((clean_token, score));
732                }
733            }
734        }
735
736        *pos += 1;
737        None
738    }
739
740    /// Load SentencePiece vocabulary file
741    fn load_sentencepiece_vocab<P: AsRef<Path>>(
742        vocab_path: P,
743    ) -> Result<(
744        HashMap<String, u32>,
745        HashMap<u32, String>,
746        HashMap<String, u32>,
747        HashMap<u32, f32>,
748        HashMap<String, String>,
749    )> {
750        let file = File::open(vocab_path).map_err(|e| {
751            TrustformersError::other(
752                anyhow!("Failed to open SentencePiece vocab file: {}", e).to_string(),
753            )
754        })?;
755        let reader = BufReader::new(file);
756
757        let mut vocab = HashMap::new();
758        let mut id_to_token = HashMap::new();
759        let mut special_tokens = HashMap::new();
760        let mut scores = HashMap::new();
761        let mut config = HashMap::new();
762
763        for (line_num, line) in reader.lines().enumerate() {
764            let line = line.map_err(|e| {
765                TrustformersError::other(
766                    anyhow!("Failed to read line {}: {}", line_num, e).to_string(),
767                )
768            })?;
769            let line = line.trim();
770
771            if line.is_empty() || line.starts_with('#') {
772                continue;
773            }
774
775            // Parse line format: token\tscore or token score
776            let parts: Vec<&str> = if line.contains('\t') {
777                line.split('\t').collect()
778            } else {
779                line.split_whitespace().collect()
780            };
781
782            if parts.is_empty() {
783                continue;
784            }
785
786            let token = parts[0].to_string();
787            let score = if parts.len() > 1 {
788                parts[1].parse::<f32>().unwrap_or(0.0)
789            } else {
790                Self::estimate_token_score(&token)
791            };
792
793            let id = line_num as u32;
794            vocab.insert(token.clone(), id);
795            id_to_token.insert(id, token.clone());
796            scores.insert(id, score);
797
798            // Identify special tokens
799            if token.starts_with('<') && token.ends_with('>') {
800                special_tokens.insert(token, id);
801            }
802        }
803
804        // Add configuration
805        config.insert("model_type".to_string(), "sentencepiece".to_string());
806        config.insert("vocab_size".to_string(), vocab.len().to_string());
807        config.insert("normalization".to_string(), "nfkc".to_string());
808
809        Ok((vocab, id_to_token, special_tokens, scores, config))
810    }
811
812    /// Check if a token is valid
813    fn is_valid_token(token: &str) -> bool {
814        // Token should not be too long, not be all whitespace, and contain printable characters
815        token.len() <= 100
816            && !token.trim().is_empty()
817            && token.chars().any(|c| !c.is_whitespace())
818            && token.chars().all(|c| c.is_ascii() || c as u32 > 127) // Allow ASCII and Unicode
819    }
820
821    /// Estimate token score based on characteristics
822    fn estimate_token_score(token: &str) -> f32 {
823        // Estimate score based on token frequency heuristics
824        match token {
825            "<unk>" => -100.0,
826            "<s>" | "</s>" | "<pad>" => -1.0,
827            _ if token.starts_with('<') && token.ends_with('>') => -10.0, // Special tokens
828            _ if token.starts_with("▁") => -5.0 + (token.len() as f32 * -0.1), // SentencePiece prefix
829            _ if token.len() == 1 => -2.0,                                     // Single characters
830            _ if token.len() <= 3 => -3.0 + (token.len() as f32 * -0.2),
831            _ => -5.0 + (token.len() as f32 * -0.1), // Longer subwords get lower scores
832        }
833    }
834
835    /// Extract normalization rules for SentencePiece
836    fn extract_normalization_rules() -> Vec<NormalizationRule> {
837        vec![
838            NormalizationRule {
839                rule_type: "NFKC".to_string(),
840                parameters: {
841                    let mut params = HashMap::new();
842                    params.insert(
843                        "pattern".to_string(),
844                        serde_json::Value::String(".*".to_string()),
845                    );
846                    params.insert(
847                        "replacement".to_string(),
848                        serde_json::Value::String("NFKC_NORMALIZED".to_string()),
849                    );
850                    params.insert("regex".to_string(), serde_json::Value::Bool(false));
851                    params
852                },
853            },
854            NormalizationRule {
855                rule_type: "RemoveExtraSpaces".to_string(),
856                parameters: {
857                    let mut params = HashMap::new();
858                    params.insert(
859                        "pattern".to_string(),
860                        serde_json::Value::String(r"\s+".to_string()),
861                    );
862                    params.insert(
863                        "replacement".to_string(),
864                        serde_json::Value::String(" ".to_string()),
865                    );
866                    params.insert("regex".to_string(), serde_json::Value::Bool(true));
867                    params
868                },
869            },
870        ]
871    }
872
873    /// Extract pre-tokenization rules for SentencePiece
874    fn extract_pre_tokenization_rules() -> Vec<PreTokenizationRule> {
875        vec![
876            PreTokenizationRule {
877                rule_type: "AddDummyPrefix".to_string(),
878                pattern: "^".to_string(),
879                replacement: Some("▁".to_string()),
880            },
881            PreTokenizationRule {
882                rule_type: "SpaceReplacement".to_string(),
883                pattern: " ".to_string(),
884                replacement: Some("▁".to_string()),
885            },
886        ]
887    }
888}
889
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Build a minimal three-token tokenizer (`hello`, `world`, `<pad>`) with `<pad>`
    /// registered as a special token.
    fn create_test_tokenizer() -> BinaryTokenizer {
        let mut vocab = HashMap::new();
        let mut id_to_token = HashMap::new();
        let mut special_tokens = HashMap::new();

        vocab.insert("hello".to_string(), 0);
        vocab.insert("world".to_string(), 1);
        vocab.insert("<pad>".to_string(), 2);

        id_to_token.insert(0, "hello".to_string());
        id_to_token.insert(1, "world".to_string());
        id_to_token.insert(2, "<pad>".to_string());

        special_tokens.insert("<pad>".to_string(), 2);

        BinaryTokenizer {
            vocab,
            id_to_token,
            special_tokens,
            scores: None,
            merges: None,
            config: HashMap::new(),
            normalization_rules: None,
            pre_tokenization_rules: None,
        }
    }

    /// Build a tokenizer with a few hundred highly redundant vocabulary entries.
    ///
    /// The compression tests assert that the compressed payload is smaller than the
    /// uncompressed one. With the three-entry fixture, fixed compression overhead
    /// can dominate, so those tests use this larger, repetitive fixture instead.
    fn create_compressible_test_tokenizer() -> BinaryTokenizer {
        let mut tokenizer = create_test_tokenizer();
        let base_id = tokenizer.vocab.len();
        for offset in 0..256 {
            let token = format!("repeated_subword_{:04}", offset);
            // `as _` lets the id take whatever integer type the vocab maps use.
            let id = (base_id + offset) as _;
            tokenizer.vocab.insert(token.clone(), id);
            tokenizer.id_to_token.insert(id, token);
        }
        tokenizer
    }

    #[test]
    fn test_serialize_deserialize() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_tokenizer.bin");

        let config = BinaryConfig::default();
        let serializer = BinarySerializer::new(config);

        let tokenizer = create_test_tokenizer();

        // Serialize
        let header = serializer
            .serialize(&tokenizer, "test", &file_path)
            .expect("Operation failed in test");
        assert_eq!(header.tokenizer_type, "test");
        assert_eq!(header.version, BINARY_FORMAT_VERSION);

        // Deserialize
        let (loaded_tokenizer, loaded_header) =
            serializer.deserialize(&file_path).expect("Operation failed in test");

        assert_eq!(loaded_tokenizer.vocab, tokenizer.vocab);
        assert_eq!(loaded_tokenizer.id_to_token, tokenizer.id_to_token);
        assert_eq!(loaded_header.tokenizer_type, "test");
    }

    #[test]
    fn test_compression() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_compressed.bin");

        let config = BinaryConfig {
            compression_level: 9,
            ..Default::default()
        };
        let serializer = BinarySerializer::new(config);

        // Use the redundant fixture so the size assertion is not at the mercy of
        // compressor overhead on a tiny payload.
        let tokenizer = create_compressible_test_tokenizer();
        let header = serializer
            .serialize(&tokenizer, "test", &file_path)
            .expect("Operation failed in test");

        assert!(header.compressed_size < header.uncompressed_size);
        assert_eq!(header.compression_level, 9);

        // Should still deserialize correctly
        let (loaded_tokenizer, _) =
            serializer.deserialize(&file_path).expect("Operation failed in test");
        assert_eq!(loaded_tokenizer.vocab, tokenizer.vocab);
        assert_eq!(loaded_tokenizer.id_to_token, tokenizer.id_to_token);
    }

    #[test]
    fn test_file_info() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_info.bin");

        let config = BinaryConfig::default();
        let serializer = BinarySerializer::new(config);

        let tokenizer = create_test_tokenizer();
        let original_header = serializer
            .serialize(&tokenizer, "test", &file_path)
            .expect("Operation failed in test");

        // Get file info without loading
        let info_header = serializer.get_file_info(&file_path).expect("Operation failed in test");

        assert_eq!(info_header.tokenizer_type, original_header.tokenizer_type);
        assert_eq!(info_header.checksum, original_header.checksum);
    }

    #[test]
    fn test_validation() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_validate.bin");

        let config = BinaryConfig::default();
        let serializer = BinarySerializer::new(config.clone());

        let tokenizer = create_test_tokenizer();
        serializer
            .serialize(&tokenizer, "test", &file_path)
            .expect("Operation failed in test");

        assert!(BinaryUtils::validate_file(&file_path, &config).expect("Operation failed in test"));
    }

    #[test]
    fn test_compression_ratio() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_ratio.bin");

        let config = BinaryConfig {
            compression_level: 6,
            ..Default::default()
        };
        let serializer = BinarySerializer::new(config.clone());

        // Redundant fixture: a ratio > 1.0 should reflect the compressor, not luck.
        let tokenizer = create_compressible_test_tokenizer();
        serializer
            .serialize(&tokenizer, "test", &file_path)
            .expect("Operation failed in test");

        let ratio = BinaryUtils::get_compression_ratio(&file_path, &config)
            .expect("Operation failed in test");
        assert!(ratio > 1.0); // Should have some compression
    }
}