1use crate::error::{Result, TextError};
25use crate::tokenizers::{BertTokenizer, RobertaTokenizer};
26use std::collections::HashMap;
27use std::fs;
28
/// Normalizer flags stored in the `normalizer` section of a tokenizer description.
///
/// `HfTokenizerJson::to_json` writes the three flags into a `BertNormalizer`
/// object, and the JSON reader in this file fills them back in on load.
#[derive(Debug, Clone, PartialEq)]
pub struct HfNormalizerConfig {
    /// Serialized as the `lowercase` member; `to_bert_tokenizer` forwards it to
    /// `BertTokenizer::new`.
    pub lowercase: bool,
    /// Serialized as the `strip_accents` member.
    pub strip_accents: bool,
    /// Serialized as the `handle_chinese_chars` member.
    pub handle_chinese_chars: bool,
}

impl Default for HfNormalizerConfig {
    /// Returns a configuration with every flag switched on.
    fn default() -> Self {
        let enabled = true;
        Self {
            lowercase: enabled,
            strip_accents: enabled,
            handle_chinese_chars: enabled,
        }
    }
}
51
/// One entry of the top-level `added_tokens` array of a tokenizer description.
///
/// Every field is written verbatim by `HfTokenizerJson::to_json` under the key
/// of the same name.
#[derive(Debug, Clone, PartialEq)]
pub struct HfAddedToken {
    /// Vocabulary id of the token.
    pub id: u32,
    /// Literal text of the token.
    pub content: String,
    /// Value of the `special` flag.
    pub special: bool,
    /// Value of the `lstrip` flag.
    pub lstrip: bool,
    /// Value of the `rstrip` flag.
    pub rstrip: bool,
    /// Value of the `single_word` flag.
    pub single_word: bool,
    /// Value of the `normalized` flag.
    pub normalized: bool,
}

impl HfAddedToken {
    /// Builds an entry with `special` set to `true` and every other flag set
    /// to `false`.
    ///
    /// `content` may be anything convertible into a `String`.
    pub fn special(id: u32, content: impl Into<String>) -> Self {
        let content: String = content.into();
        Self {
            id,
            content,
            special: true,
            lstrip: false,
            rstrip: false,
            single_word: false,
            normalized: false,
        }
    }
}
85
/// In-memory form of a Hugging Face style `tokenizer.json` description.
///
/// Values are produced by the constructors on this type or by `from_json`, and
/// rendered back to text by `to_json`.
#[derive(Debug, Clone)]
pub struct HfTokenizerJson {
    /// Written as the top-level `version` string; the constructors use `"1.0"`.
    pub version: String,
    /// Model kind. `to_json` understands `"WordPiece"` and `"BPE"` and returns
    /// an error for anything else.
    pub model_type: String,
    /// Token text to id mapping, emitted inside the `model` section.
    pub vocab: HashMap<String, u32>,
    /// Merge pairs; each pair is serialized as a single `"left right"` string.
    /// Only emitted for the `"BPE"` model type.
    pub merges: Vec<(String, String)>,
    /// Role name (for example `"cls_token"`) to token id.
    pub special_tokens: HashMap<String, u32>,
    /// Normalizer flags; `None` is serialized as `null`.
    pub normalizer: Option<HfNormalizerConfig>,
    /// Type name of the pre-tokenizer (for example `"BertPreTokenizer"` or
    /// `"ByteLevel"`); `None` is serialized as `null`.
    pub pre_tokenizer: Option<String>,
    /// Entries of the top-level `added_tokens` array.
    pub added_tokens: Vec<HfAddedToken>,
    /// Written as `continuing_subword_prefix` of a WordPiece model and as the
    /// `prefix` of the WordPiece decoder.
    pub wordpiece_prefix: String,
    /// Written as the model's `unk_token`.
    pub unk_token: String,
}
112
113impl HfTokenizerJson {
114 pub fn new_wordpiece(vocab: HashMap<String, u32>) -> Self {
118 HfTokenizerJson {
119 version: "1.0".to_string(),
120 model_type: "WordPiece".to_string(),
121 vocab,
122 merges: Vec::new(),
123 special_tokens: HashMap::new(),
124 normalizer: Some(HfNormalizerConfig::default()),
125 pre_tokenizer: Some("BertPreTokenizer".to_string()),
126 added_tokens: Vec::new(),
127 wordpiece_prefix: "##".to_string(),
128 unk_token: "[UNK]".to_string(),
129 }
130 }
131
132 pub fn new_bpe(vocab: HashMap<String, u32>, merges: Vec<(String, String)>) -> Self {
134 HfTokenizerJson {
135 version: "1.0".to_string(),
136 model_type: "BPE".to_string(),
137 vocab,
138 merges,
139 special_tokens: HashMap::new(),
140 normalizer: None,
141 pre_tokenizer: Some("ByteLevel".to_string()),
142 added_tokens: Vec::new(),
143 wordpiece_prefix: "##".to_string(),
144 unk_token: "<unk>".to_string(),
145 }
146 }
147
148 pub fn from_bert(tokenizer: &BertTokenizer) -> Self {
155 let vocab = tokenizer.vocab().clone();
156
157 let bert_specials = [
158 ("[PAD]", "pad_token"),
159 ("[UNK]", "unk_token"),
160 ("[CLS]", "cls_token"),
161 ("[SEP]", "sep_token"),
162 ("[MASK]", "mask_token"),
163 ];
164
165 let mut special_tokens = HashMap::new();
166 let mut added_tokens: Vec<HfAddedToken> = Vec::new();
167
168 for (token, role) in &bert_specials {
169 if let Some(&id) = vocab.get(*token) {
170 special_tokens.insert(role.to_string(), id);
171 added_tokens.push(HfAddedToken::special(id, *token));
172 }
173 }
174
175 let unk_token = vocab
176 .get("[UNK]")
177 .map(|_| "[UNK]".to_string())
178 .unwrap_or_default();
179
180 let normalizer = Some(HfNormalizerConfig {
181 lowercase: tokenizer.lowercase(),
182 strip_accents: true,
183 handle_chinese_chars: true,
184 });
185
186 HfTokenizerJson {
187 version: "1.0".to_string(),
188 model_type: "WordPiece".to_string(),
189 vocab,
190 merges: Vec::new(),
191 special_tokens,
192 normalizer,
193 pre_tokenizer: Some("BertPreTokenizer".to_string()),
194 added_tokens,
195 wordpiece_prefix: "##".to_string(),
196 unk_token,
197 }
198 }
199
200 pub fn from_roberta(tokenizer: &RobertaTokenizer) -> Self {
202 let vocab = tokenizer.vocab().clone();
203 let merges = tokenizer.merges().to_vec();
204
205 let roberta_specials = [
206 ("<s>", "bos_token"),
207 ("</s>", "eos_token"),
208 ("<pad>", "pad_token"),
209 ("<unk>", "unk_token"),
210 ("<mask>", "mask_token"),
211 ];
212
213 let mut special_tokens = HashMap::new();
214 let mut added_tokens: Vec<HfAddedToken> = Vec::new();
215
216 for (token, role) in &roberta_specials {
217 if let Some(&id) = vocab.get(*token) {
218 special_tokens.insert(role.to_string(), id);
219 added_tokens.push(HfAddedToken::special(id, *token));
220 }
221 }
222
223 let unk_token = "<unk>".to_string();
224
225 HfTokenizerJson {
226 version: "1.0".to_string(),
227 model_type: "BPE".to_string(),
228 vocab,
229 merges,
230 special_tokens,
231 normalizer: None,
232 pre_tokenizer: Some("ByteLevel".to_string()),
233 added_tokens,
234 wordpiece_prefix: "##".to_string(),
235 unk_token,
236 }
237 }
238
239 pub fn to_bert_tokenizer(&self) -> Result<BertTokenizer> {
243 if self.model_type != "WordPiece" {
244 return Err(TextError::InvalidInput(format!(
245 "Cannot create BertTokenizer from model type '{}'",
246 self.model_type
247 )));
248 }
249 let lowercase = self
250 .normalizer
251 .as_ref()
252 .map(|n| n.lowercase)
253 .unwrap_or(true);
254 Ok(BertTokenizer::new(self.vocab.clone(), lowercase))
255 }
256
257 pub fn to_json(&self) -> Result<String> {
264 let mut obj = serde_json_obj();
265
266 obj.insert("version".to_string(), json_string(&self.version));
268
269 obj.insert("truncation".to_string(), "null".to_string());
271 obj.insert("padding".to_string(), "null".to_string());
272
273 let added_tokens_json = self
275 .added_tokens
276 .iter()
277 .map(|t| {
278 format!(
279 "{{\"id\":{},\"content\":{},\"single_word\":{},\"lstrip\":{},\
280 \"rstrip\":{},\"normalized\":{},\"special\":{}}}",
281 t.id,
282 json_string(&t.content),
283 t.single_word,
284 t.lstrip,
285 t.rstrip,
286 t.normalized,
287 t.special,
288 )
289 })
290 .collect::<Vec<_>>()
291 .join(",");
292 obj.insert(
293 "added_tokens".to_string(),
294 format!("[{}]", added_tokens_json),
295 );
296
297 let normalizer_json = match &self.normalizer {
299 None => "null".to_string(),
300 Some(n) => {
301 format!(
302 "{{\"type\":\"BertNormalizer\",\"clean_text\":true,\
303 \"handle_chinese_chars\":{},\"strip_accents\":{},\"lowercase\":{}}}",
304 n.handle_chinese_chars, n.strip_accents, n.lowercase,
305 )
306 }
307 };
308 obj.insert("normalizer".to_string(), normalizer_json);
309
310 let pre_tok_json = match &self.pre_tokenizer {
312 None => "null".to_string(),
313 Some(name) if name == "BertPreTokenizer" => {
314 "{\"type\":\"BertPreTokenizer\"}".to_string()
315 }
316 Some(name) if name == "ByteLevel" => {
317 "{\"type\":\"ByteLevel\",\"add_prefix_space\":false}".to_string()
318 }
319 Some(name) => format!("{{\"type\":{}}}", json_string(name)),
320 };
321 obj.insert("pre_tokenizer".to_string(), pre_tok_json);
322
323 let post_proc_json = if self.model_type == "WordPiece" {
325 if let (Some(&cls_id), Some(&sep_id)) = (
326 self.special_tokens.get("cls_token"),
327 self.special_tokens.get("sep_token"),
328 ) {
329 format!(
330 "{{\"type\":\"TemplateProcessing\",\
331 \"single\":\"[CLS]:0 $A:0 [SEP]:0\",\
332 \"pair\":\"[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1\",\
333 \"special_tokens\":{{\
334 \"[CLS]\":{{\"id\":{},\"ids\":[{}],\"tokens\":[\"[CLS]\"]}},\
335 \"[SEP]\":{{\"id\":{},\"ids\":[{}],\"tokens\":[\"[SEP]\"]}}\
336 }}}}",
337 cls_id, cls_id, sep_id, sep_id,
338 )
339 } else {
340 "null".to_string()
341 }
342 } else {
343 "null".to_string()
344 };
345 obj.insert("post_processor".to_string(), post_proc_json);
346
347 let decoder_json = if self.model_type == "WordPiece" {
349 format!(
350 "{{\"type\":\"WordPiece\",\"prefix\":{},\"cleanup\":true}}",
351 json_string(&self.wordpiece_prefix)
352 )
353 } else {
354 "{\"type\":\"ByteLevel\",\"add_prefix_space\":false}".to_string()
355 };
356 obj.insert("decoder".to_string(), decoder_json);
357
358 let model_json = match self.model_type.as_str() {
360 "WordPiece" => {
361 let vocab_entries = build_vocab_json(&self.vocab);
362 format!(
363 "{{\"type\":\"WordPiece\",\"unk_token\":{},\"continuing_subword_prefix\":{},\
364 \"max_input_chars_per_word\":100,\"vocab\":{{{}}}}}",
365 json_string(&self.unk_token),
366 json_string(&self.wordpiece_prefix),
367 vocab_entries,
368 )
369 }
370 "BPE" => {
371 let vocab_entries = build_vocab_json(&self.vocab);
372 let merges_entries = self
373 .merges
374 .iter()
375 .map(|(a, b)| json_string(&format!("{} {}", a, b)))
376 .collect::<Vec<_>>()
377 .join(",");
378 format!(
379 "{{\"type\":\"BPE\",\"dropout\":null,\"unk_token\":{},\
380 \"continuing_subword_prefix\":null,\"end_of_word_suffix\":null,\
381 \"fuse_unk\":false,\"vocab\":{{{}}},\"merges\":[{}]}}",
382 json_string(&self.unk_token),
383 vocab_entries,
384 merges_entries,
385 )
386 }
387 other => {
388 return Err(TextError::InvalidInput(format!(
389 "Unsupported model type for JSON serialization: {}",
390 other
391 )))
392 }
393 };
394 obj.insert("model".to_string(), model_json);
395
396 let fields = [
398 "version",
399 "truncation",
400 "padding",
401 "added_tokens",
402 "normalizer",
403 "pre_tokenizer",
404 "post_processor",
405 "decoder",
406 "model",
407 ];
408
409 let body = fields
410 .iter()
411 .filter_map(|k| obj.get(*k).map(|v| format!("\"{}\":{}", k, v)))
412 .collect::<Vec<_>>()
413 .join(",");
414
415 Ok(format!("{{{}}}", body))
416 }
417
418 pub fn from_json(json: &str) -> Result<Self> {
420 parse_hf_json(json)
424 }
425
426 pub fn save(&self, path: &str) -> Result<()> {
428 let json = self.to_json()?;
429 fs::write(path, json).map_err(|e| TextError::IoError(e.to_string()))
430 }
431
432 pub fn load(path: &str) -> Result<Self> {
434 let contents = fs::read_to_string(path).map_err(|e| TextError::IoError(e.to_string()))?;
435 Self::from_json(&contents)
436 }
437}
438
/// Wraps `s` in double quotes and escapes it as a JSON string literal.
///
/// The quote, the backslash, newline, carriage return and tab get their short
/// escapes; any other character below U+0020 is written as `\u00XX`. Every
/// other character is copied unchanged.
fn json_string(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    for c in s.chars() {
        let short_escape = match c {
            '"' => Some("\\\""),
            '\\' => Some("\\\\"),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        if let Some(escape) = short_escape {
            quoted.push_str(escape);
        } else if u32::from(c) < 0x20 {
            // Remaining control characters have no short form.
            quoted.push_str(&format!("\\u{:04x}", u32::from(c)));
        } else {
            quoted.push(c);
        }
    }
    quoted.push('"');
    quoted
}
461
462fn build_vocab_json(vocab: &HashMap<String, u32>) -> String {
464 let mut entries: Vec<(&String, &u32)> = vocab.iter().collect();
465 entries.sort_by_key(|(_, &id)| id);
466 entries
467 .iter()
468 .map(|(k, v)| format!("{}:{}", json_string(k), v))
469 .collect::<Vec<_>>()
470 .join(",")
471}
472
/// Returns an empty map from member name to already-rendered JSON text.
///
/// Despite the name, no serde type is involved; this is a plain `HashMap`.
fn serde_json_obj() -> HashMap<String, String> {
    HashMap::default()
}
477
478fn parse_hf_json(json: &str) -> Result<HfTokenizerJson> {
485 let root = JsonValue::parse(json)
486 .ok_or_else(|| TextError::InvalidInput("Failed to parse JSON".to_string()))?;
487
488 if root.as_obj().is_none() {
489 return Err(TextError::InvalidInput(
490 "Root must be a JSON object".to_string(),
491 ));
492 }
493
494 let model = root
496 .get("model")
497 .ok_or_else(|| TextError::InvalidInput("Missing 'model' field".to_string()))?;
498
499 let model_type = model
500 .get("type")
501 .and_then(|t| t.as_str())
502 .unwrap_or("WordPiece")
503 .to_string();
504
505 let vocab: HashMap<String, u32> = model
507 .get("vocab")
508 .and_then(|v| v.as_obj())
509 .map(|obj| {
510 obj.iter()
511 .filter_map(|(k, v)| v.as_u32().map(|id| (k.clone(), id)))
512 .collect()
513 })
514 .unwrap_or_default();
515
516 let merges: Vec<(String, String)> = model
518 .get("merges")
519 .and_then(|m| m.as_arr())
520 .map(|arr| {
521 arr.iter()
522 .filter_map(|item| {
523 let s = item.as_str()?;
524 let mut parts = s.splitn(2, ' ');
525 let a = parts.next()?.to_string();
526 let b = parts.next()?.to_string();
527 Some((a, b))
528 })
529 .collect()
530 })
531 .unwrap_or_default();
532
533 let unk_token = model
534 .get("unk_token")
535 .and_then(|t| t.as_str())
536 .unwrap_or("[UNK]")
537 .to_string();
538
539 let wordpiece_prefix = model
540 .get("continuing_subword_prefix")
541 .and_then(|p| p.as_str())
542 .unwrap_or("##")
543 .to_string();
544
545 let normalizer = root.get("normalizer").and_then(|n| {
547 n.as_obj()?;
548 Some(HfNormalizerConfig {
549 lowercase: n
550 .get("lowercase")
551 .and_then(|v| v.as_bool())
552 .unwrap_or(false),
553 strip_accents: n
554 .get("strip_accents")
555 .and_then(|v| v.as_bool())
556 .unwrap_or(false),
557 handle_chinese_chars: n
558 .get("handle_chinese_chars")
559 .and_then(|v| v.as_bool())
560 .unwrap_or(false),
561 })
562 });
563
564 let pre_tokenizer = root
566 .get("pre_tokenizer")
567 .and_then(|pt| pt.get("type"))
568 .and_then(|t| t.as_str())
569 .map(|s| s.to_string());
570
571 let added_tokens: Vec<HfAddedToken> = root
573 .get("added_tokens")
574 .and_then(|at| at.as_arr())
575 .map(|arr| {
576 arr.iter()
577 .filter_map(|item| {
578 item.as_obj()?;
579 let id = item.get("id")?.as_u32()?;
580 let content = item.get("content")?.as_str()?.to_string();
581 let special = item
582 .get("special")
583 .and_then(|v| v.as_bool())
584 .unwrap_or(false);
585 let lstrip = item
586 .get("lstrip")
587 .and_then(|v| v.as_bool())
588 .unwrap_or(false);
589 let rstrip = item
590 .get("rstrip")
591 .and_then(|v| v.as_bool())
592 .unwrap_or(false);
593 let single_word = item
594 .get("single_word")
595 .and_then(|v| v.as_bool())
596 .unwrap_or(false);
597 let normalized = item
598 .get("normalized")
599 .and_then(|v| v.as_bool())
600 .unwrap_or(false);
601 Some(HfAddedToken {
602 id,
603 content,
604 special,
605 lstrip,
606 rstrip,
607 single_word,
608 normalized,
609 })
610 })
611 .collect()
612 })
613 .unwrap_or_default();
614
615 let mut special_tokens: HashMap<String, u32> = HashMap::new();
617 if model_type == "WordPiece" {
618 let roles = [
619 ("[PAD]", "pad_token"),
620 ("[UNK]", "unk_token"),
621 ("[CLS]", "cls_token"),
622 ("[SEP]", "sep_token"),
623 ("[MASK]", "mask_token"),
624 ];
625 for (tok, role) in &roles {
626 if let Some(&id) = vocab.get(*tok) {
627 special_tokens.insert(role.to_string(), id);
628 }
629 }
630 } else {
631 let roles = [
632 ("<s>", "bos_token"),
633 ("</s>", "eos_token"),
634 ("<pad>", "pad_token"),
635 ("<unk>", "unk_token"),
636 ("<mask>", "mask_token"),
637 ];
638 for (tok, role) in &roles {
639 if let Some(&id) = vocab.get(*tok) {
640 special_tokens.insert(role.to_string(), id);
641 }
642 }
643 }
644
645 Ok(HfTokenizerJson {
646 version: root
647 .get("version")
648 .and_then(|v| v.as_str())
649 .unwrap_or("1.0")
650 .to_string(),
651 model_type,
652 vocab,
653 merges,
654 special_tokens,
655 normalizer,
656 pre_tokenizer,
657 added_tokens,
658 wordpiece_prefix,
659 unk_token,
660 })
661}
662
663#[derive(Debug)]
667enum JsonValue {
668 Null,
669 Bool(bool),
670 Number(f64),
671 Str(String),
672 Array(Vec<JsonValue>),
673 Object(Vec<(String, JsonValue)>),
674}
675
676impl JsonValue {
677 fn as_str(&self) -> Option<&str> {
678 if let JsonValue::Str(s) = self {
679 Some(s.as_str())
680 } else {
681 None
682 }
683 }
684
685 fn as_bool(&self) -> Option<bool> {
686 if let JsonValue::Bool(b) = self {
687 Some(*b)
688 } else {
689 None
690 }
691 }
692
693 fn as_u32(&self) -> Option<u32> {
694 if let JsonValue::Number(n) = self {
695 Some(*n as u32)
696 } else {
697 None
698 }
699 }
700
701 fn as_obj(&self) -> Option<&[(String, JsonValue)]> {
702 if let JsonValue::Object(fields) = self {
703 Some(fields.as_slice())
704 } else {
705 None
706 }
707 }
708
709 fn get(&self, key: &str) -> Option<&JsonValue> {
711 if let JsonValue::Object(fields) = self {
712 fields.iter().find(|(k, _)| k == key).map(|(_, v)| v)
713 } else {
714 None
715 }
716 }
717
718 fn as_arr(&self) -> Option<&[JsonValue]> {
719 if let JsonValue::Array(items) = self {
720 Some(items.as_slice())
721 } else {
722 None
723 }
724 }
725
726 fn parse(s: &str) -> Option<Self> {
728 let mut p = Parser {
729 src: s.as_bytes(),
730 pos: 0,
731 };
732 let v = p.parse_value()?;
733 p.skip_ws();
734 if p.pos == p.src.len() {
735 Some(v)
736 } else {
737 None
738 }
739 }
740}
741
742struct Parser<'a> {
745 src: &'a [u8],
746 pos: usize,
747}
748
749impl<'a> Parser<'a> {
750 fn peek(&self) -> Option<u8> {
751 self.src.get(self.pos).copied()
752 }
753
754 fn advance(&mut self) -> Option<u8> {
755 let b = self.src.get(self.pos).copied();
756 self.pos += 1;
757 b
758 }
759
760 fn skip_ws(&mut self) {
761 while let Some(b) = self.peek() {
762 if b == b' ' || b == b'\t' || b == b'\n' || b == b'\r' {
763 self.pos += 1;
764 } else {
765 break;
766 }
767 }
768 }
769
770 fn expect(&mut self, byte: u8) -> Option<()> {
771 self.skip_ws();
772 if self.peek()? == byte {
773 self.pos += 1;
774 Some(())
775 } else {
776 None
777 }
778 }
779
780 fn parse_value(&mut self) -> Option<JsonValue> {
781 self.skip_ws();
782 match self.peek()? {
783 b'"' => self.parse_string().map(JsonValue::Str),
784 b'{' => self.parse_object(),
785 b'[' => self.parse_array(),
786 b't' => {
787 self.expect_literal(b"true")?;
788 Some(JsonValue::Bool(true))
789 }
790 b'f' => {
791 self.expect_literal(b"false")?;
792 Some(JsonValue::Bool(false))
793 }
794 b'n' => {
795 self.expect_literal(b"null")?;
796 Some(JsonValue::Null)
797 }
798 b'-' | b'0'..=b'9' => self.parse_number().map(JsonValue::Number),
799 _ => None,
800 }
801 }
802
803 fn expect_literal(&mut self, lit: &[u8]) -> Option<()> {
804 let end = self.pos + lit.len();
805 if self.src.get(self.pos..end)? == lit {
806 self.pos = end;
807 Some(())
808 } else {
809 None
810 }
811 }
812
813 fn parse_string(&mut self) -> Option<String> {
814 self.skip_ws();
815 self.expect(b'"')?;
816 let mut s = String::new();
817 loop {
818 match self.advance()? {
819 b'"' => break,
820 b'\\' => {
821 match self.advance()? {
822 b'"' => s.push('"'),
823 b'\\' => s.push('\\'),
824 b'/' => s.push('/'),
825 b'n' => s.push('\n'),
826 b'r' => s.push('\r'),
827 b't' => s.push('\t'),
828 b'b' => s.push('\x08'),
829 b'f' => s.push('\x0C'),
830 b'u' => {
831 let mut code: u32 = 0;
833 for _ in 0..4 {
834 let h = self.advance()?;
835 let digit = match h {
836 b'0'..=b'9' => h - b'0',
837 b'a'..=b'f' => h - b'a' + 10,
838 b'A'..=b'F' => h - b'A' + 10,
839 _ => return None,
840 };
841 code = (code << 4) | digit as u32;
842 }
843 s.push(char::from_u32(code)?);
844 }
845 _ => return None,
846 }
847 }
848 byte => {
849 let start = self.pos - 1;
851 let leading = byte;
853 let extra = if leading < 0x80 {
854 0
855 } else if leading < 0xE0 {
856 1
857 } else if leading < 0xF0 {
858 2
859 } else {
860 3
861 };
862 for _ in 0..extra {
863 self.advance()?;
864 }
865 let slice = &self.src[start..self.pos];
866 let ch = std::str::from_utf8(slice).ok()?.chars().next()?;
867 s.push(ch);
868 }
869 }
870 }
871 Some(s)
872 }
873
874 fn parse_number(&mut self) -> Option<f64> {
875 let start = self.pos;
876 if self.peek() == Some(b'-') {
878 self.pos += 1;
879 }
880 while matches!(self.peek(), Some(b'0'..=b'9')) {
882 self.pos += 1;
883 }
884 if self.peek() == Some(b'.') {
886 self.pos += 1;
887 while matches!(self.peek(), Some(b'0'..=b'9')) {
888 self.pos += 1;
889 }
890 }
891 if matches!(self.peek(), Some(b'e') | Some(b'E')) {
893 self.pos += 1;
894 if matches!(self.peek(), Some(b'+') | Some(b'-')) {
895 self.pos += 1;
896 }
897 while matches!(self.peek(), Some(b'0'..=b'9')) {
898 self.pos += 1;
899 }
900 }
901 let slice = std::str::from_utf8(&self.src[start..self.pos]).ok()?;
902 slice.parse::<f64>().ok()
903 }
904
905 fn parse_object(&mut self) -> Option<JsonValue> {
906 self.expect(b'{')?;
907 let mut fields: Vec<(String, JsonValue)> = Vec::new();
908 self.skip_ws();
909 if self.peek() == Some(b'}') {
910 self.pos += 1;
911 return Some(JsonValue::Object(fields));
912 }
913 loop {
914 self.skip_ws();
915 let key = self.parse_string()?;
916 self.expect(b':')?;
917 let val = self.parse_value()?;
918 fields.push((key, val));
919 self.skip_ws();
920 match self.peek()? {
921 b',' => {
922 self.pos += 1;
923 }
924 b'}' => {
925 self.pos += 1;
926 break;
927 }
928 _ => return None,
929 }
930 }
931 Some(JsonValue::Object(fields))
932 }
933
934 fn parse_array(&mut self) -> Option<JsonValue> {
935 self.expect(b'[')?;
936 let mut items: Vec<JsonValue> = Vec::new();
937 self.skip_ws();
938 if self.peek() == Some(b']') {
939 self.pos += 1;
940 return Some(JsonValue::Array(items));
941 }
942 loop {
943 let val = self.parse_value()?;
944 items.push(val);
945 self.skip_ws();
946 match self.peek()? {
947 b',' => {
948 self.pos += 1;
949 }
950 b']' => {
951 self.pos += 1;
952 break;
953 }
954 _ => return None,
955 }
956 }
957 Some(JsonValue::Array(items))
958 }
959}
960
#[cfg(test)]
mod tests {
    use super::*;

    /// Eight-entry BERT-style vocabulary: the five special tokens, two words
    /// and one continuation piece.
    fn small_bert_vocab() -> HashMap<String, u32> {
        let pairs = [
            ("[PAD]", 0u32),
            ("[UNK]", 1),
            ("[CLS]", 2),
            ("[SEP]", 3),
            ("[MASK]", 4),
            ("hello", 5),
            ("world", 6),
            ("##ing", 7),
        ];
        pairs.iter().map(|(k, v)| (k.to_string(), *v)).collect()
    }

    /// Eight-entry RoBERTa-style vocabulary with a large id for `<mask>`.
    fn small_bpe_vocab() -> HashMap<String, u32> {
        let pairs = [
            ("<s>", 0u32),
            ("<pad>", 1),
            ("</s>", 2),
            ("<unk>", 3),
            ("he", 4),
            ("llo", 5),
            ("hello", 6),
            ("<mask>", 50264),
        ];
        pairs.iter().map(|(k, v)| (k.to_string(), *v)).collect()
    }

    /// A WordPiece description keeps its model type and every vocabulary entry
    /// across serialize/parse.
    #[test]
    fn test_hf_json_wordpiece_roundtrip() {
        let vocab = small_bert_vocab();
        let hf = HfTokenizerJson::new_wordpiece(vocab.clone());

        let json = hf.to_json().expect("serialize");
        let parsed = HfTokenizerJson::from_json(&json).expect("deserialize");

        assert_eq!(parsed.model_type, "WordPiece");
        assert_eq!(parsed.vocab.len(), vocab.len());
        for (k, v) in &vocab {
            assert_eq!(parsed.vocab.get(k), Some(v), "mismatch for token {}", k);
        }
    }

    /// `from_bert` produces a version 1.0 WordPiece description that carries
    /// the tokenizer's vocabulary.
    #[test]
    fn test_hf_json_from_bert() {
        let vocab = small_bert_vocab();
        let tokenizer = BertTokenizer::new(vocab.clone(), true);
        let hf = HfTokenizerJson::from_bert(&tokenizer);

        assert_eq!(hf.model_type, "WordPiece");
        assert_eq!(hf.version, "1.0");
        assert!(hf.vocab.contains_key("[CLS]"));
        assert!(hf.vocab.contains_key("[SEP]"));
    }

    /// `from_bert` registers all five roles and mirrors them as added tokens.
    #[test]
    fn test_hf_json_special_tokens() {
        let vocab = small_bert_vocab();
        let tokenizer = BertTokenizer::new(vocab, true);
        let hf = HfTokenizerJson::from_bert(&tokenizer);

        assert!(hf.special_tokens.contains_key("cls_token"));
        assert!(hf.special_tokens.contains_key("sep_token"));
        assert!(hf.special_tokens.contains_key("pad_token"));
        assert!(hf.special_tokens.contains_key("unk_token"));
        assert!(hf.special_tokens.contains_key("mask_token"));

        let contents: Vec<&str> = hf.added_tokens.iter().map(|t| t.content.as_str()).collect();
        assert!(contents.contains(&"[CLS]"));
        assert!(contents.contains(&"[SEP]"));
    }

    /// BPE merges survive serialize/parse as `(left, right)` pairs.
    #[test]
    fn test_hf_json_bpe_merges() {
        let vocab = small_bpe_vocab();
        let merges = vec![("he".to_string(), "llo".to_string())];
        let hf = HfTokenizerJson::new_bpe(vocab.clone(), merges.clone());

        let json = hf.to_json().expect("serialize");
        let parsed = HfTokenizerJson::from_json(&json).expect("deserialize");

        assert_eq!(parsed.model_type, "BPE");
        assert_eq!(parsed.merges.len(), 1);
        assert_eq!(parsed.merges[0], ("he".to_string(), "llo".to_string()));
    }

    /// A WordPiece description can be turned back into a tokenizer with the
    /// same vocabulary size.
    #[test]
    fn test_hf_json_to_bert_tokenizer() {
        let vocab = small_bert_vocab();
        let hf = HfTokenizerJson::new_wordpiece(vocab.clone());
        let tokenizer = hf.to_bert_tokenizer().expect("reconstruction");
        assert_eq!(tokenizer.vocab_size(), vocab.len());
    }

    /// A non-default continuation prefix is written and read back.
    #[test]
    fn test_hf_json_wordpiece_prefix_preserved() {
        let vocab = small_bert_vocab();
        let mut hf = HfTokenizerJson::new_wordpiece(vocab);
        hf.wordpiece_prefix = "@@".to_string();

        let json = hf.to_json().expect("serialize");
        let parsed = HfTokenizerJson::from_json(&json).expect("deserialize");
        assert_eq!(parsed.wordpiece_prefix, "@@");
    }

    /// Each normalizer flag is written and read back individually.
    #[test]
    fn test_hf_json_normalizer_roundtrip() {
        let vocab = small_bert_vocab();
        let mut hf = HfTokenizerJson::new_wordpiece(vocab);
        hf.normalizer = Some(HfNormalizerConfig {
            lowercase: false,
            strip_accents: true,
            handle_chinese_chars: false,
        });

        let json = hf.to_json().expect("serialize");
        let parsed = HfTokenizerJson::from_json(&json).expect("deserialize");
        let norm = parsed.normalizer.expect("normalizer present");
        assert!(!norm.lowercase);
        assert!(norm.strip_accents);
        assert!(!norm.handle_chinese_chars);
    }

    /// A BPE description without merges round-trips to an empty merge list.
    #[test]
    fn test_hf_json_empty_merges_bpe() {
        let vocab = small_bpe_vocab();
        let hf = HfTokenizerJson::new_bpe(vocab, vec![]);
        let json = hf.to_json().expect("serialize");
        let parsed = HfTokenizerJson::from_json(&json).expect("deserialize");
        assert!(parsed.merges.is_empty());
    }

    /// `save` followed by `load` yields the same model type and vocabulary
    /// size.
    #[test]
    fn test_hf_json_save_and_load() {
        let vocab = small_bert_vocab();
        let hf = HfTokenizerJson::new_wordpiece(vocab.clone());

        // The process id keeps concurrently running test binaries from
        // clobbering each other's file.
        let file_name = format!("test_hf_tokenizer_{}.json", std::process::id());
        let tmp = std::env::temp_dir().join(file_name);
        let path = tmp.to_str().expect("valid path");

        hf.save(path).expect("save");
        let loaded = HfTokenizerJson::load(path);
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(path);
        let loaded = loaded.expect("load");

        assert_eq!(loaded.model_type, "WordPiece");
        assert_eq!(loaded.vocab.len(), vocab.len());
    }
}