1use std::collections::HashMap;
35
36use crate::error::{Result, TextError};
37use crate::gpt_bpe::Gpt2BpeTokenizer;
38use crate::tokenization::wordpiece::WordPieceTokenizer;
39
/// Tokenizer model family named by the `model.type` field of a tokenizer
/// JSON document.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum HfModelType {
    /// Written as `"WordPiece"`.
    WordPiece,
    /// Written as `"BPE"`.
    Bpe,
    /// Written as `"Unigram"`.
    Unigram,
    /// Any other type string, kept exactly as it was read.
    Unknown(String),
}

impl HfModelType {
    /// Returns the canonical spelling of the model type.
    ///
    /// For [`HfModelType::Unknown`] this is the original, unmodified string.
    pub fn as_str(&self) -> &str {
        match self {
            HfModelType::WordPiece => "WordPiece",
            HfModelType::Bpe => "BPE",
            HfModelType::Unigram => "Unigram",
            HfModelType::Unknown(s) => s.as_str(),
        }
    }

    /// Maps a type string onto a variant.
    ///
    /// Known names are matched ASCII case-insensitively, so `"bpe"`, `"BPE"`
    /// and `"Bpe"` (and any other casing) all yield [`HfModelType::Bpe`].
    /// Anything else is preserved verbatim in [`HfModelType::Unknown`]; this
    /// function never fails.
    pub fn parse(s: &str) -> Self {
        if s.eq_ignore_ascii_case("WordPiece") {
            HfModelType::WordPiece
        } else if s.eq_ignore_ascii_case("BPE") {
            HfModelType::Bpe
        } else if s.eq_ignore_ascii_case("Unigram") {
            HfModelType::Unigram
        } else {
            HfModelType::Unknown(s.to_string())
        }
    }
}
77
/// One entry of the `added_tokens` array of a tokenizer JSON document.
///
/// Each field carries the value of the JSON key with the same name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HfAddedToken {
    /// Value of the `id` key.
    pub id: u32,
    /// Value of the `content` key (the token text).
    pub content: String,
    /// Value of the `special` key.
    pub special: bool,
    /// Value of the `single_word` key.
    pub single_word: bool,
    /// Value of the `lstrip` key.
    pub lstrip: bool,
    /// Value of the `rstrip` key.
    pub rstrip: bool,
    /// Value of the `normalized` key.
    pub normalized: bool,
}
98
99impl HfAddedToken {
100 pub fn special(id: u32, content: impl Into<String>) -> Self {
102 HfAddedToken {
103 id,
104 content: content.into(),
105 special: true,
106 single_word: false,
107 lstrip: false,
108 rstrip: false,
109 normalized: false,
110 }
111 }
112
113 fn to_json_object(&self) -> String {
115 format!(
116 r#"{{"id":{},"content":{},"single_word":{},"lstrip":{},"rstrip":{},"normalized":{},"special":{}}}"#,
117 self.id,
118 json_string(&self.content),
119 self.single_word,
120 self.lstrip,
121 self.rstrip,
122 self.normalized,
123 self.special,
124 )
125 }
126
127 fn from_json_object(obj: &str) -> Option<Self> {
129 let id = parse_u32_field(obj, "id")?;
130 let content = parse_string_field(obj, "content")?;
131 let special = parse_bool_field(obj, "special").unwrap_or(false);
132 let single_word = parse_bool_field(obj, "single_word").unwrap_or(false);
133 let lstrip = parse_bool_field(obj, "lstrip").unwrap_or(false);
134 let rstrip = parse_bool_field(obj, "rstrip").unwrap_or(false);
135 let normalized = parse_bool_field(obj, "normalized").unwrap_or(false);
136 Some(HfAddedToken {
137 id,
138 content,
139 special,
140 single_word,
141 lstrip,
142 rstrip,
143 normalized,
144 })
145 }
146}
147
/// The `model` object of a tokenizer JSON document.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HfModel {
    /// Value of the `type` key, kept as the raw string.
    pub model_type: String,
    /// Token text to id mapping from the `vocab` object.
    pub vocab: HashMap<String, u32>,
    /// Entries of the `merges` array, one string per entry; `None` when the
    /// key is absent.
    pub merges: Option<Vec<String>>,
    /// Value of the `unk_token` key, if present.
    pub unk_token: Option<String>,
    /// Value of the `continuing_subword_prefix` key, if present.
    pub continuing_subword_prefix: Option<String>,
    /// Value of the `max_input_chars_per_word` key, if present.
    pub max_input_chars_per_word: Option<u32>,
}
166
167impl HfModel {
168 fn to_json_string(&self) -> String {
170 let mut parts: Vec<String> = Vec::new();
171
172 parts.push(format!(r#""type":{}"#, json_string(&self.model_type)));
173
174 if let Some(ref unk) = self.unk_token {
175 parts.push(format!(r#""unk_token":{}"#, json_string(unk)));
176 }
177
178 if let Some(ref pfx) = self.continuing_subword_prefix {
179 parts.push(format!(
180 r#""continuing_subword_prefix":{}"#,
181 json_string(pfx)
182 ));
183 }
184
185 if let Some(max_chars) = self.max_input_chars_per_word {
186 parts.push(format!(r#""max_input_chars_per_word":{}"#, max_chars));
187 }
188
189 let vocab_entries = {
191 let mut sorted: Vec<(&String, &u32)> = self.vocab.iter().collect();
192 sorted.sort_by_key(|(_, &id)| id);
193 sorted
194 .iter()
195 .map(|(tok, id)| format!("{}:{}", json_string(tok), id))
196 .collect::<Vec<_>>()
197 .join(",")
198 };
199 parts.push(format!(r#""vocab":{{{}}}"#, vocab_entries));
200
201 if let Some(ref merges) = self.merges {
202 let merge_strs = merges
203 .iter()
204 .map(|m| json_string(m))
205 .collect::<Vec<_>>()
206 .join(",");
207 parts.push(format!(r#""merges":[{}]"#, merge_strs));
208 }
209
210 format!("{{{}}}", parts.join(","))
211 }
212
213 fn from_json_str(s: &str) -> Result<Self> {
215 let model_type = parse_string_field(s, "type").ok_or_else(|| {
216 TextError::InvalidInput("HF JSON: missing model.type field".to_string())
217 })?;
218
219 let unk_token = parse_string_field(s, "unk_token");
220 let continuing_subword_prefix = parse_string_field(s, "continuing_subword_prefix");
221 let max_input_chars_per_word = parse_u32_field(s, "max_input_chars_per_word");
222
223 let vocab = parse_vocab_object(s)?;
224 let merges = parse_string_array_field(s, "merges");
225
226 Ok(HfModel {
227 model_type,
228 vocab,
229 merges,
230 unk_token,
231 continuing_subword_prefix,
232 max_input_chars_per_word,
233 })
234 }
235}
236
237#[derive(Debug, Clone)]
241pub struct HfTokenizerJson {
242 pub version: String,
244 pub model: HfModel,
246 pub added_tokens: Vec<HfAddedToken>,
248 pub normalizer_json: Option<String>,
250 pub pre_tokenizer_json: Option<String>,
252 pub post_processor_json: Option<String>,
254 pub decoder_json: Option<String>,
256}
257
258impl HfTokenizerJson {
259 pub fn from_wordpiece(wp: &WordPieceTokenizer) -> Self {
266 let vocab: HashMap<String, u32> = wp.vocab_snapshot();
267
268 let get = |tok: &str, fallback: u32| -> u32 { vocab.get(tok).copied().unwrap_or(fallback) };
270
271 let added_tokens = vec![
272 HfAddedToken::special(get("[PAD]", 0), "[PAD]"),
273 HfAddedToken::special(get("[UNK]", 1), "[UNK]"),
274 HfAddedToken::special(get("[CLS]", 101), "[CLS]"),
275 HfAddedToken::special(get("[SEP]", 102), "[SEP]"),
276 HfAddedToken::special(get("[MASK]", 103), "[MASK]"),
277 ];
278
279 let model = HfModel {
280 model_type: "WordPiece".to_string(),
281 vocab,
282 merges: None,
283 unk_token: Some("[UNK]".to_string()),
284 continuing_subword_prefix: Some("##".to_string()),
285 max_input_chars_per_word: Some(100),
286 };
287
288 HfTokenizerJson {
289 version: "1.0".to_string(),
290 model,
291 added_tokens,
292 normalizer_json: None,
293 pre_tokenizer_json: None,
294 post_processor_json: None,
295 decoder_json: None,
296 }
297 }
298
299 pub fn from_gpt2_bpe(bpe: &Gpt2BpeTokenizer) -> Self {
301 let vocab: HashMap<String, u32> = bpe.vocab_snapshot();
302 let merges: Vec<String> = bpe
303 .merges()
304 .iter()
305 .map(|(a, b)| format!("{} {}", a, b))
306 .collect();
307
308 let model = HfModel {
309 model_type: "BPE".to_string(),
310 vocab,
311 merges: Some(merges),
312 unk_token: None,
313 continuing_subword_prefix: None,
314 max_input_chars_per_word: None,
315 };
316
317 HfTokenizerJson {
318 version: "1.0".to_string(),
319 model,
320 added_tokens: vec![],
321 normalizer_json: None,
322 pre_tokenizer_json: None,
323 post_processor_json: None,
324 decoder_json: None,
325 }
326 }
327
328 pub fn to_json_string(&self) -> String {
335 let added_tokens_str = self
336 .added_tokens
337 .iter()
338 .map(|t| t.to_json_object())
339 .collect::<Vec<_>>()
340 .join(",");
341
342 let null_or =
343 |opt: &Option<String>| -> String { opt.as_deref().unwrap_or("null").to_string() };
344
345 format!(
346 r#"{{"version":{},"truncation":null,"padding":null,"added_tokens":[{}],"normalizer":{},"pre_tokenizer":{},"post_processor":{},"decoder":{},"model":{}}}"#,
347 json_string(&self.version),
348 added_tokens_str,
349 null_or(&self.normalizer_json),
350 null_or(&self.pre_tokenizer_json),
351 null_or(&self.post_processor_json),
352 null_or(&self.decoder_json),
353 self.model.to_json_string(),
354 )
355 }
356
357 pub fn from_json_str(s: &str) -> Result<Self> {
361 let version = parse_string_field(s, "version").unwrap_or_else(|| "1.0".to_string());
362
363 let model_str = extract_object_field(s, "model").ok_or_else(|| {
365 TextError::InvalidInput("HF JSON: missing 'model' object".to_string())
366 })?;
367 let model = HfModel::from_json_str(model_str)?;
368
369 let added_tokens = extract_array_field(s, "added_tokens")
371 .unwrap_or_default()
372 .iter()
373 .filter_map(|obj| HfAddedToken::from_json_object(obj))
374 .collect();
375
376 let normalizer_json = extract_object_field(s, "normalizer").map(|o| o.to_string());
377 let pre_tokenizer_json = extract_object_field(s, "pre_tokenizer").map(|o| o.to_string());
378 let post_processor_json = extract_object_field(s, "post_processor").map(|o| o.to_string());
379 let decoder_json = extract_object_field(s, "decoder").map(|o| o.to_string());
380
381 Ok(HfTokenizerJson {
382 version,
383 model,
384 added_tokens,
385 normalizer_json,
386 pre_tokenizer_json,
387 post_processor_json,
388 decoder_json,
389 })
390 }
391
392 pub fn wordpiece_roundtrip_check(wp: &WordPieceTokenizer) -> bool {
399 let original = Self::from_wordpiece(wp);
400 let json = original.to_json_string();
401 match Self::from_json_str(&json) {
402 Ok(restored) => {
403 restored.model.vocab.len() == original.model.vocab.len()
404 && restored.model.model_type == original.model.model_type
405 }
406 Err(_) => false,
407 }
408 }
409}
410
411pub fn detect_model_type(json: &str) -> Result<HfModelType> {
415 let model_str = extract_object_field(json, "model").ok_or_else(|| {
416 TextError::InvalidInput("HF JSON: could not locate 'model' object".to_string())
417 })?;
418 let type_str = parse_string_field(model_str, "type")
419 .ok_or_else(|| TextError::InvalidInput("HF JSON: missing model.type field".to_string()))?;
420 Ok(HfModelType::parse(&type_str))
421}
422
/// Returns `s` as a double-quoted JSON string literal.
///
/// `"` and `\` are backslash-escaped, newline, carriage return and tab use
/// their two-character escapes, and every other character below U+0020 is
/// written as a lowercase `\u00xx` escape. All other characters are copied
/// through unchanged.
fn json_string(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    for ch in s.chars() {
        let short_escape = match ch {
            '"' => Some(r#"\""#),
            '\\' => Some(r"\\"),
            '\n' => Some(r"\n"),
            '\r' => Some(r"\r"),
            '\t' => Some(r"\t"),
            _ => None,
        };
        if let Some(escape) = short_escape {
            quoted.push_str(escape);
        } else if u32::from(ch) < 0x20 {
            quoted.push_str(&format!("\\u{:04x}", u32::from(ch)));
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('"');
    quoted
}
449
/// Locates member `key` of the JSON object in `json` and returns the text
/// that starts at its value and runs to the end of `json`.
///
/// Only members of the outermost object are considered, so an identically
/// named key nested inside another object or array (for example a vocabulary
/// token that happens to be called `merges`) cannot shadow the real member.
/// Whitespace around the `:` is accepted. If `json` does not start with `{`
/// it is treated as a bare member list.
///
/// The key is compared against the raw text between the quotes, so keys that
/// contain escape sequences are not matched.
///
/// Returns `None` when the key is absent, when its value is `null`, or when
/// a string literal before it is unterminated.
fn extract_json_value<'a>(json: &'a str, key: &str) -> Option<&'a str> {
    let bytes = json.as_bytes();
    // Members of the outermost object sit one level inside its braces.
    let target_depth: usize = if json.trim_start().starts_with('{') {
        1
    } else {
        0
    };

    let mut depth: usize = 0;
    let mut i = 0;
    // Byte-wise scanning is safe: every delimiter is ASCII and therefore
    // never occurs inside a multi-byte UTF-8 sequence.
    while i < bytes.len() {
        match bytes[i] {
            b'{' | b'[' => {
                depth += 1;
                i += 1;
            }
            b'}' | b']' => {
                depth = depth.saturating_sub(1);
                i += 1;
            }
            b'"' => {
                let start = i + 1;
                let mut close = start;
                while close < bytes.len() && bytes[close] != b'"' {
                    // A backslash escapes the following byte.
                    close += if bytes[close] == b'\\' { 2 } else { 1 };
                }
                if close >= bytes.len() {
                    return None;
                }
                if depth == target_depth && &bytes[start..close] == key.as_bytes() {
                    // A string is a key only when a colon follows it.
                    if let Some(value) = json[close + 1..].trim_start().strip_prefix(':') {
                        let value = value.trim_start();
                        return if value.starts_with("null") {
                            None
                        } else {
                            Some(value)
                        };
                    }
                }
                i = close + 1;
            }
            _ => i += 1,
        }
    }
    None
}
469
470fn parse_string_field(json: &str, key: &str) -> Option<String> {
472 let raw = extract_json_value(json, key)?;
473 if !raw.starts_with('"') {
474 return None;
475 }
476 let mut chars = raw.char_indices().skip(1); let mut result = String::new();
479 loop {
480 match chars.next() {
481 None => return None,
482 Some((_, '"')) => break,
483 Some((_, '\\')) => {
484 match chars.next() {
485 Some((_, '"')) => result.push('"'),
486 Some((_, '\\')) => result.push('\\'),
487 Some((_, 'n')) => result.push('\n'),
488 Some((_, 'r')) => result.push('\r'),
489 Some((_, 't')) => result.push('\t'),
490 Some((_, 'u')) => {
491 let mut hex = String::new();
493 for _ in 0..4 {
494 if let Some((_, c)) = chars.next() {
495 hex.push(c);
496 }
497 }
498 if let Ok(n) = u32::from_str_radix(&hex, 16) {
499 if let Some(c) = char::from_u32(n) {
500 result.push(c);
501 }
502 }
503 }
504 Some((_, c)) => result.push(c),
505 None => return None,
506 }
507 }
508 Some((_, c)) => result.push(c),
509 }
510 }
511 Some(result)
512}
513
514fn parse_bool_field(json: &str, key: &str) -> Option<bool> {
516 let raw = extract_json_value(json, key)?;
517 if raw.starts_with("true") {
518 Some(true)
519 } else if raw.starts_with("false") {
520 Some(false)
521 } else {
522 None
523 }
524}
525
526fn parse_u32_field(json: &str, key: &str) -> Option<u32> {
528 let raw = extract_json_value(json, key)?;
529 let num: String = raw.chars().take_while(|c| c.is_ascii_digit()).collect();
530 num.parse().ok()
531}
532
533fn extract_object_field<'a>(json: &'a str, key: &str) -> Option<&'a str> {
536 let raw = extract_json_value(json, key)?;
537 if !raw.starts_with('{') {
538 return None;
539 }
540 let end = find_matching_brace(raw, '{', '}')?;
542 Some(&raw[..=end])
543}
544
545fn extract_array_field(json: &str, key: &str) -> Option<Vec<String>> {
548 let raw = extract_json_value(json, key)?;
549 if !raw.starts_with('[') {
550 return None;
551 }
552 let end = find_matching_brace(raw, '[', ']')?;
553 let inner = &raw[1..end]; Some(split_json_array_objects(inner))
555}
556
557fn parse_string_array_field(json: &str, key: &str) -> Option<Vec<String>> {
559 let raw = extract_json_value(json, key)?;
560 if !raw.starts_with('[') {
561 return None;
562 }
563 let end = find_matching_brace(raw, '[', ']')?;
564 let inner = &raw[1..end];
565
566 let mut result = Vec::new();
567 let mut remainder = inner.trim();
568 while !remainder.is_empty() {
569 if remainder.starts_with('"') {
570 let mut chars = remainder.char_indices().skip(1);
572 let mut s = String::new();
573 let mut end_pos = 0;
574 let mut found = false;
575 loop {
576 match chars.next() {
577 None => break,
578 Some((i, '"')) => {
579 end_pos = i;
580 found = true;
581 break;
582 }
583 Some((_, '\\')) => match chars.next() {
584 Some((_, c)) => s.push(c),
585 None => break,
586 },
587 Some((_, c)) => s.push(c),
588 }
589 }
590 if found {
591 result.push(s);
592 remainder = remainder[end_pos + 1..].trim_start_matches(',').trim();
593 } else {
594 break;
595 }
596 } else {
597 let skip = remainder
599 .find(',')
600 .map(|i| i + 1)
601 .unwrap_or(remainder.len());
602 remainder = &remainder[skip..];
603 }
604 }
605 Some(result)
606}
607
608fn parse_vocab_object(json: &str) -> Result<HashMap<String, u32>> {
611 let vocab_raw = extract_object_field(json, "vocab").ok_or_else(|| {
613 TextError::InvalidInput("HF JSON: missing model.vocab object".to_string())
614 })?;
615
616 let inner = &vocab_raw[1..vocab_raw.len() - 1]; let mut map = HashMap::new();
618
619 let mut remainder = inner.trim();
621 while !remainder.is_empty() {
622 if remainder.starts_with('"') {
623 let key = match parse_json_string_at_start(remainder) {
625 Some((s, consumed)) => {
626 remainder = &remainder[consumed..];
627 s
628 }
629 None => break,
630 };
631 remainder = remainder.trim_start();
632 if !remainder.starts_with(':') {
633 break;
634 }
635 remainder = remainder[1..].trim_start();
636 let num_str: String = remainder
638 .chars()
639 .take_while(|c| c.is_ascii_digit())
640 .collect();
641 if num_str.is_empty() {
642 break;
643 }
644 if let Ok(id) = num_str.parse::<u32>() {
645 map.insert(key, id);
646 }
647 remainder = &remainder[num_str.len()..];
648 remainder = remainder.trim_start();
649 if remainder.starts_with(',') {
650 remainder = remainder[1..].trim_start();
651 }
652 } else {
653 remainder = &remainder[1..];
655 }
656 }
657
658 Ok(map)
659}
660
/// Decodes the JSON string literal at the very start of `s`.
///
/// Returns the decoded text and the number of bytes of `s` consumed, closing
/// quote included. Handles `\"`, `\\`, `\/`, `\b`, `\f`, `\n`, `\r`, `\t`
/// and `\uXXXX`, including UTF-16 surrogate pairs. An unpaired surrogate or
/// a `\u` that is not followed by four hex digits decodes to U+FFFD and
/// consumes nothing beyond the escape itself, so the closing quote is never
/// swallowed. Any other escaped character is copied through as-is.
///
/// Returns `None` when `s` does not start with `"` or the literal is
/// unterminated.
fn parse_json_string_at_start(s: &str) -> Option<(String, usize)> {
    /// Reads four hex digits at the start of `s` as one UTF-16 code unit.
    fn hex4(s: &str) -> Option<u32> {
        let digits = s.get(..4)?;
        if digits.bytes().all(|b| b.is_ascii_hexdigit()) {
            u32::from_str_radix(digits, 16).ok()
        } else {
            None
        }
    }

    let mut rest = s.strip_prefix('"')?;
    let mut out = String::new();
    loop {
        // Copy the literal run up to the next quote or backslash in one go.
        let stop = rest.find(|c: char| c == '"' || c == '\\')?;
        out.push_str(&rest[..stop]);
        if rest[stop..].starts_with('"') {
            // `rest` is always a suffix of `s`, which gives the quote's offset.
            return Some((out, s.len() - rest.len() + stop + 1));
        }

        let mut after_backslash = rest[stop + 1..].chars();
        let code = after_backslash.next()?;
        rest = after_backslash.as_str();
        match code {
            'n' => out.push('\n'),
            'r' => out.push('\r'),
            't' => out.push('\t'),
            'b' => out.push('\u{0008}'),
            'f' => out.push('\u{000C}'),
            'u' => match hex4(rest) {
                Some(hi @ 0xD800..=0xDBFF) => {
                    // A high surrogate is only meaningful with the low one
                    // that must follow as a second `\uXXXX` escape.
                    let low = rest[4..]
                        .strip_prefix("\\u")
                        .and_then(hex4)
                        .filter(|lo| (0xDC00..=0xDFFF).contains(lo));
                    match low {
                        Some(lo) => {
                            let cp = 0x10000 + ((hi - 0xD800) << 10) + (lo - 0xDC00);
                            out.push(char::from_u32(cp).unwrap_or('\u{FFFD}'));
                            rest = &rest[10..];
                        }
                        None => {
                            out.push('\u{FFFD}');
                            rest = &rest[4..];
                        }
                    }
                }
                Some(cp) => {
                    // `from_u32` rejects a lone low surrogate.
                    out.push(char::from_u32(cp).unwrap_or('\u{FFFD}'));
                    rest = &rest[4..];
                }
                None => out.push('\u{FFFD}'),
            },
            other => out.push(other),
        }
    }
}
698
/// Returns the byte index of the `close` delimiter that brings the nesting
/// depth of `open`/`close` back to zero.
///
/// Delimiters inside double-quoted strings are ignored, and a backslash
/// inside a string escapes the character after it. Returns `None` when the
/// depth never returns to zero.
fn find_matching_brace(s: &str, open: char, close: char) -> Option<usize> {
    /// Where the scanner currently is relative to string literals.
    enum Scan {
        /// Outside any string literal.
        Code,
        /// Inside a string literal.
        Text,
        /// Inside a string literal, right after a backslash.
        Escaped,
    }

    let mut state = Scan::Code;
    let mut depth = 0i32;

    for (idx, ch) in s.char_indices() {
        state = match state {
            Scan::Escaped => Scan::Text,
            Scan::Text => match ch {
                '\\' => Scan::Escaped,
                '"' => Scan::Code,
                _ => Scan::Text,
            },
            Scan::Code => {
                if ch == '"' {
                    Scan::Text
                } else {
                    if ch == open {
                        depth += 1;
                    } else if ch == close {
                        depth -= 1;
                        if depth == 0 {
                            return Some(idx);
                        }
                    }
                    Scan::Code
                }
            }
        };
    }
    None
}
732
733fn split_json_array_objects(inner: &str) -> Vec<String> {
735 let mut result = Vec::new();
736 let mut remainder = inner.trim();
737 while !remainder.is_empty() {
738 if remainder.starts_with('{') {
739 match find_matching_brace(remainder, '{', '}') {
740 Some(end) => {
741 result.push(remainder[..=end].to_string());
742 remainder = remainder[end + 1..].trim_start_matches(',').trim();
743 }
744 None => break,
745 }
746 } else {
747 let skip = remainder.find('{').unwrap_or(remainder.len());
749 if skip == remainder.len() {
750 break;
751 }
752 remainder = &remainder[skip..];
753 }
754 }
755 result
756}
757
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenization::wordpiece::WordPieceTokenizer;

    /// Nine-token WordPiece tokenizer containing all five special tokens.
    fn minimal_wp() -> WordPieceTokenizer {
        let tokens = vec![
            "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "hello", "world", "##ing", "foo",
        ];
        WordPieceTokenizer::from_vocab_list(&tokens)
    }

    /// Export of [`minimal_wp`] as an in-memory document.
    fn minimal_hf() -> HfTokenizerJson {
        HfTokenizerJson::from_wordpiece(&minimal_wp())
    }

    /// Export of [`minimal_wp`] as JSON text.
    fn minimal_json() -> String {
        minimal_hf().to_json_string()
    }

    #[test]
    fn from_wordpiece_model_type() {
        assert_eq!(minimal_hf().model.model_type, "WordPiece");
    }

    #[test]
    fn to_json_string_contains_vocab() {
        let s = minimal_json();
        assert!(s.contains("\"vocab\""), "JSON must contain vocab key");
    }

    #[test]
    fn roundtrip_from_json_str() {
        let restored = HfTokenizerJson::from_json_str(&minimal_json()).expect("parse failed");
        assert_eq!(restored.model.model_type, "WordPiece");
    }

    #[test]
    fn detect_model_type_wordpiece() {
        let mt = detect_model_type(&minimal_json()).expect("detect failed");
        assert_eq!(mt, HfModelType::WordPiece);
    }

    #[test]
    fn detect_model_type_bpe() {
        let json = r#"{"version":"1.0","model":{"type":"BPE","vocab":{"hello":0},"merges":["h e"]},"added_tokens":[]}"#;
        let mt = detect_model_type(json).expect("detect failed");
        assert_eq!(mt, HfModelType::Bpe);
    }

    #[test]
    fn added_tokens_contains_cls() {
        let has_cls = minimal_hf()
            .added_tokens
            .iter()
            .any(|t| t.content == "[CLS]");
        assert!(has_cls, "added_tokens must contain [CLS]");
    }

    #[test]
    fn vocab_size_matches_input() {
        let wp = minimal_wp();
        let hf = HfTokenizerJson::from_wordpiece(&wp);
        assert_eq!(hf.model.vocab.len(), wp.vocab_size());
    }

    #[test]
    fn empty_vocab_serialises_without_panic() {
        let tokens: &[&str] = &[];
        let wp = WordPieceTokenizer::from_vocab_list(tokens);
        let json = HfTokenizerJson::from_wordpiece(&wp).to_json_string();
        assert!(json.contains("WordPiece"));
    }

    #[test]
    fn hf_model_type_variants_accessible() {
        // Construct every variant once to prove each one is reachable.
        let _ = HfModelType::WordPiece;
        let _ = HfModelType::Bpe;
        let _ = HfModelType::Unigram;
        let _ = HfModelType::Unknown("X".to_string());
    }

    #[test]
    fn invalid_json_returns_err() {
        let result = HfTokenizerJson::from_json_str("not json at all }{");
        assert!(result.is_err());
    }

    #[test]
    fn roundtrip_check_helper() {
        assert!(HfTokenizerJson::wordpiece_roundtrip_check(&minimal_wp()));
    }

    #[test]
    fn version_field_preserved() {
        let restored = HfTokenizerJson::from_json_str(&minimal_json()).unwrap();
        assert_eq!(restored.version, "1.0");
    }
}