use memmap2::{Mmap, MmapOptions};
use std::collections::HashMap;
use std::fs::File;
use std::path::Path;
use std::slice;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::traits::{TokenizedInput, Tokenizer};
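/// Fixed-size header stored at the start of a zero-copy tokenizer file.
///
/// The file layout is: header, vocabulary entry table, token string data,
/// metadata, special-token data. All offsets are absolute byte offsets from
/// the beginning of the file, and `checksum` is a CRC32 of the vocabulary,
/// metadata, and special-token sections.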
#[repr(C, packed)]
#[derive(Debug, Clone, Copy)]
pub struct ZeroCopyHeader {
    pub magic: [u8; 4],
    pub version: u32,
    pub header_size: u32,
    pub vocab_offset: u64,
    pub vocab_size: u64,
    pub metadata_offset: u64,
    pub metadata_size: u64,
    pub special_tokens_offset: u64,
    pub special_tokens_size: u64,
    pub checksum: u64,
    pub padding: [u8; 8],
}

impl ZeroCopyHeader {
    const MAGIC: [u8; 4] = *b"TFZC";
    const VERSION: u32 = 1;
    const SIZE: usize = std::mem::size_of::<Self>();

    pub fn new(
        vocab_offset: u64,
        vocab_size: u64,
        metadata_offset: u64,
        metadata_size: u64,
        special_tokens_offset: u64,
        special_tokens_size: u64,
        checksum: u64,
    ) -> Self {
        Self {
            magic: Self::MAGIC,
            version: Self::VERSION,
            header_size: Self::SIZE as u32,
            vocab_offset,
            vocab_size,
            metadata_offset,
            metadata_size,
            special_tokens_offset,
            special_tokens_size,
            checksum,
            padding: [0; 8],
        }
    }

    pub fn validate(&self) -> Result<()> {
        if self.magic != Self::MAGIC {
            return Err(TrustformersError::serialization_error(
                "Invalid magic bytes in zero-copy header".to_string(),
            ));
        }

        // Copy packed fields to locals before use to avoid taking references
        // to potentially unaligned fields.
        let version = self.version;
        if version != Self::VERSION {
            return Err(TrustformersError::serialization_error(format!(
                "Unsupported version: {}, expected: {}",
                version,
                Self::VERSION
            )));
        }

        let header_size = self.header_size;
        if header_size != Self::SIZE as u32 {
            return Err(TrustformersError::serialization_error(
                "Invalid header size".to_string(),
            ));
        }

        Ok(())
    }
}
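/// A single fixed-size vocabulary record. `token_offset` and `token_length`
/// locate the token's UTF-8 bytes within the file, and bit 0 of `flags`
/// marks special tokens.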
#[repr(C, packed)]
#[derive(Debug, Clone, Copy)]
pub struct ZeroCopyVocabEntry {
    pub id: u32,
    pub token_offset: u64,
    pub token_length: u32,
    pub frequency: f32,
    pub flags: u32,
    pub padding: [u8; 4],
}

impl ZeroCopyVocabEntry {
    pub fn is_special(&self) -> bool {
        (self.flags & 0x01) != 0
    }

    pub fn set_special(&mut self, is_special: bool) {
        if is_special {
            self.flags |= 0x01;
        } else {
            self.flags &= !0x01;
        }
    }
}
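/// Tokenizer that memory-maps a vocabulary file and reads the entry table
/// directly from the mapping, building small in-memory lookup tables for
/// token/id queries on top of it.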
pub struct ZeroCopyTokenizer {
    /// Keeps the file mapping alive; the borrowed entry slice points into it.
    mmap: Mmap,
    header: ZeroCopyHeader,
    /// Entry table borrowed directly from the mapping. The `'static` lifetime
    /// is a simplification: the slice is only valid because the mapped memory
    /// lives (at a stable address) for as long as this struct owns `mmap`.
    vocab_entries: &'static [ZeroCopyVocabEntry],
    token_to_id: HashMap<String, u32>,
    id_to_token: HashMap<u32, String>,
}

impl ZeroCopyTokenizer {
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = File::open(path)?;
        // SAFETY: the file is mapped read-only; the caller must not truncate
        // or mutate it while the tokenizer is alive.
        let mmap = unsafe { MmapOptions::new().map(&file)? };

        if mmap.len() < ZeroCopyHeader::SIZE {
            return Err(TrustformersError::serialization_error(
                "File too small to contain header".to_string(),
            ));
        }

        let header_bytes = &mmap[0..ZeroCopyHeader::SIZE];
        // The header is `#[repr(C, packed)]` (alignment 1), so reading it from
        // an arbitrary byte offset is well defined.
        let header: ZeroCopyHeader =
            unsafe { std::ptr::read(header_bytes.as_ptr() as *const ZeroCopyHeader) };

        header.validate()?;

        let vocab_start = header.vocab_offset as usize;
        let vocab_end = vocab_start + header.vocab_size as usize;

        if vocab_end > mmap.len() {
            return Err(TrustformersError::serialization_error(
                "Vocabulary section extends beyond file".to_string(),
            ));
        }

        // `vocab_size` covers only the fixed-size entry records; the token
        // string data follows them and is bounds-checked per entry below.
        let entry_size = std::mem::size_of::<ZeroCopyVocabEntry>();
        let num_entries = header.vocab_size as usize / entry_size;

        // SAFETY: the entry region was bounds-checked above, the entry type is
        // packed (alignment 1), and the slice borrows from `mmap`, which this
        // struct keeps alive for as long as the slice is used.
        let vocab_entries = unsafe {
            slice::from_raw_parts(
                mmap[vocab_start..].as_ptr() as *const ZeroCopyVocabEntry,
                num_entries,
            )
        };

        let mut token_to_id = HashMap::new();
        let mut id_to_token = HashMap::new();

        for entry in vocab_entries {
            let token_start = entry.token_offset as usize;
            let token_end = token_start + entry.token_length as usize;

            if token_end > mmap.len() {
                return Err(TrustformersError::serialization_error(
                    "Token string extends beyond file".to_string(),
                ));
            }

            let token_bytes = &mmap[token_start..token_end];
            let token = String::from_utf8(token_bytes.to_vec()).map_err(|e| {
                TrustformersError::serialization_error(format!("Invalid UTF-8 in token: {}", e))
            })?;

            token_to_id.insert(token.clone(), entry.id);
            id_to_token.insert(entry.id, token);
        }

        Ok(Self {
            mmap,
            header,
            vocab_entries,
            token_to_id,
            id_to_token,
        })
    }

    pub fn header(&self) -> &ZeroCopyHeader {
        &self.header
    }

    pub fn vocab_size(&self) -> usize {
        self.vocab_entries.len()
    }

    pub fn get_token_unchecked(&self, id: u32) -> Option<&str> {
        self.vocab_entries.iter().find(|entry| entry.id == id).and_then(|entry| {
            let token_start = entry.token_offset as usize;
            let token_end = token_start + entry.token_length as usize;

            if token_end <= self.mmap.len() {
                std::str::from_utf8(&self.mmap[token_start..token_end]).ok()
            } else {
                None
            }
        })
    }

    pub fn get_id_unchecked(&self, token: &str) -> Option<u32> {
        self.token_to_id.get(token).copied()
    }

    pub fn get_vocab_entry(&self, index: usize) -> Option<&ZeroCopyVocabEntry> {
        self.vocab_entries.get(index)
    }

    pub fn vocab_entries(&self) -> impl Iterator<Item = &ZeroCopyVocabEntry> {
        self.vocab_entries.iter()
    }

    pub fn metadata_bytes(&self) -> &[u8] {
        let start = self.header.metadata_offset as usize;
        let end = start + self.header.metadata_size as usize;
        &self.mmap[start..end]
    }

    pub fn special_tokens_bytes(&self) -> &[u8] {
        let start = self.header.special_tokens_offset as usize;
        let end = start + self.header.special_tokens_size as usize;
        &self.mmap[start..end]
    }

    pub fn memory_stats(&self) -> ZeroCopyMemoryStats {
        ZeroCopyMemoryStats {
            file_size: self.mmap.len(),
            header_size: ZeroCopyHeader::SIZE,
            vocab_size: self.header.vocab_size as usize,
            metadata_size: self.header.metadata_size as usize,
            special_tokens_size: self.header.special_tokens_size as usize,
            // Rough heap estimate for the lookup tables; it counts the maps'
            // fixed-size entry parts, not the token strings' contents.
            lookup_table_size: self.token_to_id.len()
                * (std::mem::size_of::<String>() + std::mem::size_of::<u32>())
                + self.id_to_token.len()
                    * (std::mem::size_of::<u32>() + std::mem::size_of::<String>()),
        }
    }

    pub fn verify_integrity(&self) -> Result<bool> {
        let mut hasher = crc32fast::Hasher::new();

        // Hash the vocabulary entry table together with the token string data
        // that follows it (everything up to the metadata section), matching
        // what the builder hashed when the file was written.
        let vocab_start = self.header.vocab_offset as usize;
        let metadata_start = self.header.metadata_offset as usize;
        hasher.update(&self.mmap[vocab_start..metadata_start]);

        let metadata_end = metadata_start + self.header.metadata_size as usize;
        hasher.update(&self.mmap[metadata_start..metadata_end]);

        let special_start = self.header.special_tokens_offset as usize;
        let special_end = special_start + self.header.special_tokens_size as usize;
        hasher.update(&self.mmap[special_start..special_end]);

        let calculated_checksum = hasher.finalize() as u64;
        Ok(calculated_checksum == self.header.checksum)
    }
}
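// The trait implementation below uses simple whitespace segmentation: tokens
// not present in the vocabulary are silently skipped during encoding, and
// decoding re-joins tokens with single spaces.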
impl Tokenizer for ZeroCopyTokenizer {
    fn encode(&self, text: &str) -> Result<TokenizedInput> {
        let tokens: Vec<&str> = text.split_whitespace().collect();
        let mut input_ids = Vec::new();
        let mut tokens_out = Vec::new();

        for token in tokens {
            if let Some(id) = self.get_id_unchecked(token) {
                input_ids.push(id);
                tokens_out.push(token.to_string());
            }
        }

        Ok(TokenizedInput {
            input_ids,
            attention_mask: vec![1; tokens_out.len()],
            token_type_ids: None,
            special_tokens_mask: None,
            offset_mapping: None,
            overflowing_tokens: None,
        })
    }

    fn decode(&self, token_ids: &[u32]) -> Result<String> {
        let tokens: std::result::Result<Vec<&str>, _> = token_ids
            .iter()
            .map(|&id| {
                self.get_token_unchecked(id)
                    .ok_or_else(|| TrustformersError::other(format!("Unknown token ID: {}", id)))
            })
            .collect();

        Ok(tokens?.join(" "))
    }

    fn get_vocab(&self) -> HashMap<String, u32> {
        self.token_to_id.clone()
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.get_id_unchecked(token)
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        self.get_token_unchecked(id).map(|s| s.to_string())
    }

    fn encode_pair(&self, text_a: &str, text_b: &str) -> Result<TokenizedInput> {
        let combined = format!("{} {}", text_a, text_b);
        self.encode(&combined)
    }

    fn vocab_size(&self) -> usize {
        self.token_to_id.len()
    }
}
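/// Byte-size breakdown of a loaded zero-copy tokenizer: the mapped file
/// sections plus an estimate of the heap used by the in-memory lookup tables.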
#[derive(Debug, Clone)]
pub struct ZeroCopyMemoryStats {
    pub file_size: usize,
    pub header_size: usize,
    pub vocab_size: usize,
    pub metadata_size: usize,
    pub special_tokens_size: usize,
    pub lookup_table_size: usize,
}

impl ZeroCopyMemoryStats {
    /// Heap memory used by the tokenizer itself; the memory-mapped file
    /// sections are not counted because they are file-backed.
    pub fn total_memory(&self) -> usize {
        self.lookup_table_size
    }

    /// Ratio of heap memory to file size; lower values mean more of the
    /// tokenizer is served directly from the mapped file.
    pub fn efficiency_ratio(&self) -> f64 {
        if self.file_size == 0 {
            0.0
        } else {
            self.total_memory() as f64 / self.file_size as f64
        }
    }
}
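/// Builder that serializes a vocabulary, optional metadata, and optional
/// special-token payloads into the zero-copy file format read by
/// [`ZeroCopyTokenizer`].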
pub struct ZeroCopyBuilder {
    /// (token, id, frequency, is_special)
    vocabulary: Vec<(String, u32, f32, bool)>,
    metadata: Vec<u8>,
    special_tokens: Vec<u8>,
}

impl ZeroCopyBuilder {
    pub fn new() -> Self {
        Self {
            vocabulary: Vec::new(),
            metadata: Vec::new(),
            special_tokens: Vec::new(),
        }
    }

    pub fn add_token(
        &mut self,
        token: String,
        id: u32,
        frequency: f32,
        is_special: bool,
    ) -> &mut Self {
        self.vocabulary.push((token, id, frequency, is_special));
        self
    }

    pub fn add_tokens_from_map(&mut self, vocab: &HashMap<String, u32>) -> &mut Self {
        for (token, &id) in vocab {
            self.add_token(token.clone(), id, 1.0, false);
        }
        self
    }

    pub fn set_metadata(&mut self, metadata: Vec<u8>) -> &mut Self {
        self.metadata = metadata;
        self
    }

    pub fn set_special_tokens(&mut self, special_tokens: Vec<u8>) -> &mut Self {
        self.special_tokens = special_tokens;
        self
    }

    pub fn build_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        use std::fs::OpenOptions;
        use std::io::Write;

        let mut file = OpenOptions::new().create(true).write(true).truncate(true).open(path)?;

        let header_size = ZeroCopyHeader::SIZE as u64;
        let vocab_entry_size = std::mem::size_of::<ZeroCopyVocabEntry>() as u64;
        let vocab_entries_size = self.vocabulary.len() as u64 * vocab_entry_size;

        let string_data_size: u64 =
            self.vocabulary.iter().map(|(token, _, _, _)| token.len() as u64).sum();

        // Layout: header | entry table | token string data | metadata | special tokens.
        // `vocab_size` records only the entry-table bytes so readers can recover
        // the entry count as `vocab_size / size_of::<ZeroCopyVocabEntry>()`; the
        // string data sits between the entry table and `metadata_offset`.
        let vocab_offset = header_size;
        let vocab_size = vocab_entries_size;
        let metadata_offset = vocab_offset + vocab_entries_size + string_data_size;
        let metadata_size = self.metadata.len() as u64;
        let special_tokens_offset = metadata_offset + metadata_size;
        let special_tokens_size = self.special_tokens.len() as u64;

        let mut vocab_entries = Vec::new();
        let mut string_data = Vec::new();
        let mut current_string_offset = vocab_offset + vocab_entries_size;

        for (token, id, frequency, is_special) in &self.vocabulary {
            let token_bytes = token.as_bytes();
            let mut entry = ZeroCopyVocabEntry {
                id: *id,
                token_offset: current_string_offset,
                token_length: token_bytes.len() as u32,
                frequency: *frequency,
                flags: 0,
                padding: [0; 4],
            };

            entry.set_special(*is_special);
            vocab_entries.push(entry);

            string_data.extend_from_slice(token_bytes);
            current_string_offset += token_bytes.len() as u64;
        }

        // Checksum everything after the header, in the order it is written.
        let mut hasher = crc32fast::Hasher::new();

        // SAFETY: the entries are `#[repr(C, packed)]` with the padding field
        // explicitly zeroed, so every byte of the buffer is initialized.
        let entries_bytes = unsafe {
            slice::from_raw_parts(
                vocab_entries.as_ptr() as *const u8,
                vocab_entries.len() * vocab_entry_size as usize,
            )
        };
        hasher.update(entries_bytes);
        hasher.update(&string_data);
        hasher.update(&self.metadata);
        hasher.update(&self.special_tokens);

        let checksum = hasher.finalize() as u64;

        let header = ZeroCopyHeader::new(
            vocab_offset,
            vocab_size,
            metadata_offset,
            metadata_size,
            special_tokens_offset,
            special_tokens_size,
            checksum,
        );

        let header_bytes = unsafe {
            slice::from_raw_parts(
                &header as *const ZeroCopyHeader as *const u8,
                ZeroCopyHeader::SIZE,
            )
        };

        // Write the sections in layout order: header, entry table, string
        // data, metadata, special tokens.
        file.write_all(header_bytes)?;
        file.write_all(entries_bytes)?;
        file.write_all(&string_data)?;
        file.write_all(&self.metadata)?;
        file.write_all(&self.special_tokens)?;

        file.flush()?;
        Ok(())
    }
}

impl Default for ZeroCopyBuilder {
    fn default() -> Self {
        Self::new()
    }
}
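/// Convenience helpers for converting existing tokenizers to the zero-copy
/// format and for validating or inspecting zero-copy files.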
pub struct ZeroCopyUtils;

impl ZeroCopyUtils {
    pub fn convert_tokenizer_to_zero_copy<T: Tokenizer, P: AsRef<Path>>(
        tokenizer: &T,
        path: P,
        metadata: Option<&[u8]>,
        special_tokens: Option<&[u8]>,
    ) -> Result<()> {
        let mut builder = ZeroCopyBuilder::new();

        let vocab = tokenizer.get_vocab();
        builder.add_tokens_from_map(&vocab);

        if let Some(meta) = metadata {
            builder.set_metadata(meta.to_vec());
        }

        if let Some(special) = special_tokens {
            builder.set_special_tokens(special.to_vec());
        }

        builder.build_to_file(path)
    }

    pub fn validate_file<P: AsRef<Path>>(path: P) -> Result<bool> {
        let tokenizer = ZeroCopyTokenizer::from_file(path)?;
        tokenizer.verify_integrity()
    }

    pub fn get_file_info<P: AsRef<Path>>(path: P) -> Result<HashMap<String, String>> {
        let file = File::open(path)?;
        let mmap = unsafe { MmapOptions::new().map(&file)? };

        if mmap.len() < ZeroCopyHeader::SIZE {
            return Err(TrustformersError::serialization_error(
                "File too small to contain header".to_string(),
            ));
        }

        let header_bytes = &mmap[0..ZeroCopyHeader::SIZE];
        let header: ZeroCopyHeader =
            unsafe { std::ptr::read(header_bytes.as_ptr() as *const ZeroCopyHeader) };

        header.validate()?;

        let mut info = HashMap::new();
        // Copy packed fields to locals before formatting them.
        let version = header.version;
        let vocab_size = header.vocab_size;
        let metadata_size = header.metadata_size;
        let special_tokens_size = header.special_tokens_size;

        info.insert("format".to_string(), "ZeroCopy".to_string());
        info.insert("version".to_string(), version.to_string());
        info.insert("file_size".to_string(), mmap.len().to_string());
        // Number of vocabulary entries (entry-table bytes / entry size).
        info.insert(
            "vocab_size".to_string(),
            (vocab_size / std::mem::size_of::<ZeroCopyVocabEntry>() as u64).to_string(),
        );
        info.insert("metadata_size".to_string(), metadata_size.to_string());
        info.insert(
            "special_tokens_size".to_string(),
            special_tokens_size.to_string(),
        );

        Ok(info)
    }

    pub fn compare_files<P1: AsRef<Path>, P2: AsRef<Path>>(
        path1: P1,
        path2: P2,
    ) -> Result<HashMap<String, String>> {
        let info1 = Self::get_file_info(path1)?;
        let info2 = Self::get_file_info(path2)?;

        let mut comparison = HashMap::new();

        let default_value = "0".to_string();
        for key in &[
            "file_size",
            "vocab_size",
            "metadata_size",
            "special_tokens_size",
        ] {
            let val1 = info1.get(*key).unwrap_or(&default_value);
            let val2 = info2.get(*key).unwrap_or(&default_value);

            comparison.insert(format!("{}_1", key), val1.clone());
            comparison.insert(format!("{}_2", key), val2.clone());
            comparison.insert(format!("{}_equal", key), (val1 == val2).to_string());
        }

        Ok(comparison)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap as TestHashMap;
    use tempfile::tempdir;

    #[test]
    fn test_zero_copy_header() {
        let header = ZeroCopyHeader::new(100, 200, 300, 50, 350, 25, 0x12345678);

        assert_eq!(header.magic, ZeroCopyHeader::MAGIC);
        let version = header.version;
        let vocab_offset = header.vocab_offset;
        let vocab_size = header.vocab_size;
        let checksum = header.checksum;
        assert_eq!(version, ZeroCopyHeader::VERSION);
        assert_eq!(vocab_offset, 100);
        assert_eq!(vocab_size, 200);
        assert_eq!(checksum, 0x12345678);

        assert!(header.validate().is_ok());
    }

    #[test]
    fn test_vocab_entry_flags() {
        let mut entry = ZeroCopyVocabEntry {
            id: 1,
            token_offset: 0,
            token_length: 5,
            frequency: 1.0,
            flags: 0,
            padding: [0; 4],
        };

        assert!(!entry.is_special());

        entry.set_special(true);
        assert!(entry.is_special());

        entry.set_special(false);
        assert!(!entry.is_special());
    }

    #[test]
    fn test_zero_copy_builder() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_tokenizer.zc");

        let mut builder = ZeroCopyBuilder::new();
        builder
            .add_token("hello".to_string(), 1, 1.0, false)
            .add_token("world".to_string(), 2, 1.0, false)
            .add_token("<pad>".to_string(), 0, 1.0, true);

        assert!(builder.build_to_file(&file_path).is_ok());
        assert!(file_path.exists());
    }

    #[test]
    fn test_zero_copy_tokenizer_loading() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_load.zc");

        let mut builder = ZeroCopyBuilder::new();
        builder
            .add_token("test".to_string(), 1, 1.0, false)
            .add_token("token".to_string(), 2, 1.0, false)
            .add_token("[CLS]".to_string(), 0, 1.0, true);

        builder.build_to_file(&file_path).expect("Operation failed in test");

        let tokenizer = ZeroCopyTokenizer::from_file(&file_path).expect("Operation failed in test");

        assert_eq!(tokenizer.vocab_size(), 3);
        assert_eq!(tokenizer.get_id_unchecked("test"), Some(1));
        assert_eq!(tokenizer.get_id_unchecked("token"), Some(2));
        assert_eq!(tokenizer.get_id_unchecked("[CLS]"), Some(0));

        assert_eq!(tokenizer.get_token_unchecked(1), Some("test"));
        assert_eq!(tokenizer.get_token_unchecked(2), Some("token"));
        assert_eq!(tokenizer.get_token_unchecked(0), Some("[CLS]"));
    }

    #[test]
    fn test_tokenizer_interface() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_interface.zc");

        let mut builder = ZeroCopyBuilder::new();
        builder
            .add_token("hello".to_string(), 1, 1.0, false)
            .add_token("world".to_string(), 2, 1.0, false);

        builder.build_to_file(&file_path).expect("Operation failed in test");

        let tokenizer = ZeroCopyTokenizer::from_file(&file_path).expect("Operation failed in test");

        let encoded = tokenizer.encode("hello world").expect("Encoding failed");
        assert_eq!(encoded.input_ids, vec![1, 2]);
        let tokens: Vec<String> = encoded
            .input_ids
            .iter()
            .map(|&id| tokenizer.id_to_token(id).expect("Operation failed in test"))
            .collect();
        assert_eq!(tokens, vec!["hello", "world"]);

        let decoded = tokenizer.decode(&[1, 2]).expect("Decoding failed");
        assert_eq!(decoded, "hello world");

        let vocab = tokenizer.get_vocab();
        assert_eq!(vocab.len(), 2);
        assert_eq!(vocab.get("hello"), Some(&1));
        assert_eq!(vocab.get("world"), Some(&2));
    }

    #[test]
    fn test_memory_stats() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_stats.zc");

        let mut builder = ZeroCopyBuilder::new();
        builder
            .add_token("test".to_string(), 1, 1.0, false)
            .add_token("memory".to_string(), 2, 1.0, false)
            .set_metadata(b"test metadata".to_vec())
            .set_special_tokens(b"special".to_vec());

        builder.build_to_file(&file_path).expect("Operation failed in test");

        let tokenizer = ZeroCopyTokenizer::from_file(&file_path).expect("Operation failed in test");
        let stats = tokenizer.memory_stats();

        assert!(stats.file_size > 0);
        assert!(stats.vocab_size > 0);
        assert_eq!(stats.metadata_size, 13); // b"test metadata".len()
        assert_eq!(stats.special_tokens_size, 7); // b"special".len()
        assert!(stats.lookup_table_size > 0);
    }

    #[test]
    fn test_integrity_verification() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_integrity.zc");

        let mut builder = ZeroCopyBuilder::new();
        builder
            .add_token("integrity".to_string(), 1, 1.0, false)
            .add_token("check".to_string(), 2, 1.0, false);

        builder.build_to_file(&file_path).expect("Operation failed in test");

        let tokenizer = ZeroCopyTokenizer::from_file(&file_path).expect("Operation failed in test");
        assert!(tokenizer.verify_integrity().expect("Operation failed in test"));
    }

    #[test]
    fn test_utils_file_info() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_info.zc");

        let mut builder = ZeroCopyBuilder::new();
        builder
            .add_token("info".to_string(), 1, 1.0, false)
            .add_token("test".to_string(), 2, 1.0, false);

        builder.build_to_file(&file_path).expect("Operation failed in test");

        let info = ZeroCopyUtils::get_file_info(&file_path).expect("Operation failed in test");

        assert_eq!(info.get("format").expect("Key not found"), "ZeroCopy");
        assert_eq!(info.get("vocab_size").expect("Key not found"), "2");
        assert!(info.contains_key("file_size"));
    }

    #[test]
    fn test_utils_validation() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_validation.zc");

        let mut builder = ZeroCopyBuilder::new();
        builder.add_token("validate".to_string(), 1, 1.0, false);

        builder.build_to_file(&file_path).expect("Operation failed in test");

        assert!(ZeroCopyUtils::validate_file(&file_path).expect("Operation failed in test"));
    }

    #[test]
    fn test_builder_from_map() {
        let temp_dir = tempdir().expect("Operation failed in test");
        let file_path = temp_dir.path().join("test_from_map.zc");

        let mut vocab = TestHashMap::new();
        vocab.insert("from".to_string(), 1);
        vocab.insert("map".to_string(), 2);
        vocab.insert("test".to_string(), 3);

        let mut builder = ZeroCopyBuilder::new();
        builder.add_tokens_from_map(&vocab);

        builder.build_to_file(&file_path).expect("Operation failed in test");

        let tokenizer = ZeroCopyTokenizer::from_file(&file_path).expect("Operation failed in test");
        assert_eq!(tokenizer.vocab_size(), 3);

        for (token, &expected_id) in &vocab {
            assert_eq!(tokenizer.get_id_unchecked(token), Some(expected_id));
        }
    }
}