Skip to main content

scirs2_text/tokenization/
hf_json.rs

//! HuggingFace `tokenizers` JSON format (`tokenizer.json`) serialisation / deserialisation.
2//!
3//! The HuggingFace tokenizers library persists tokenizer configurations as
4//! a single JSON file.  This module can read and write that format so that
5//! tokenizers trained with [`super::wordpiece::WordPieceTokenizer`] or
6//! [`crate::gpt_bpe::Gpt2BpeTokenizer`] can be exchanged with the HF
7//! ecosystem.
8//!
9//! # Format summary
10//!
11//! ```json
12//! {
13//!   "version": "1.0",
14//!   "model": {
15//!     "type": "WordPiece",
16//!     "unk_token": "[UNK]",
17//!     "continuing_subword_prefix": "##",
18//!     "max_input_chars_per_word": 100,
19//!     "vocab": { "[PAD]": 0, "[UNK]": 1, ... }
20//!   },
21//!   "added_tokens": [
22//!     { "id": 0, "content": "[PAD]", "special": true },
23//!     ...
24//!   ],
25//!   "normalizer": null,
26//!   "pre_tokenizer": null,
27//!   "post_processor": null,
28//!   "decoder": null,
29//!   "truncation": null,
30//!   "padding": null
31//! }
32//! ```
33
34use std::collections::HashMap;
35
36use crate::error::{Result, TextError};
37use crate::gpt_bpe::Gpt2BpeTokenizer;
38use crate::tokenization::wordpiece::WordPieceTokenizer;
39
40// ── Model-type enum ──────────────────────────────────────────────────────────
41
/// The tokenizer model type encoded in a HuggingFace tokenizers JSON file.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum HfModelType {
    /// BERT WordPiece model.
    WordPiece,
    /// Byte-Pair Encoding model (GPT-2, RoBERTa, …).
    Bpe,
    /// SentencePiece Unigram Language Model.
    Unigram,
    /// Unknown or unrecognised model type; holds the original `type` string.
    Unknown(String),
}

impl HfModelType {
    /// Convert to the string used in the JSON `model.type` field.
    ///
    /// Known variants map to `"WordPiece"`, `"BPE"` and `"Unigram"`;
    /// [`HfModelType::Unknown`] returns its stored string unchanged.
    pub fn as_str(&self) -> &str {
        match self {
            HfModelType::WordPiece => "WordPiece",
            HfModelType::Bpe => "BPE",
            HfModelType::Unigram => "Unigram",
            HfModelType::Unknown(s) => s.as_str(),
        }
    }

    /// Parse from the JSON string representation.
    ///
    /// Matching is ASCII case-insensitive, so `"WordPiece"`, `"wordpiece"`,
    /// `"Wordpiece"` etc. all map to [`HfModelType::WordPiece`] (likewise for
    /// `"BPE"` and `"Unigram"`).  Any other value is preserved verbatim in
    /// [`HfModelType::Unknown`].
    pub fn parse(s: &str) -> Self {
        if s.eq_ignore_ascii_case("WordPiece") {
            HfModelType::WordPiece
        } else if s.eq_ignore_ascii_case("BPE") {
            HfModelType::Bpe
        } else if s.eq_ignore_ascii_case("Unigram") {
            HfModelType::Unigram
        } else {
            HfModelType::Unknown(s.to_string())
        }
    }
}
77
78// ── HfAddedToken ─────────────────────────────────────────────────────────────
79
/// One entry of the HF JSON `added_tokens` array: a token registered on top
/// of the model vocabulary, together with its matching flags.
#[derive(Debug, Clone)]
pub struct HfAddedToken {
    /// Vocabulary id assigned to the token.
    pub id: u32,
    /// The token's surface text.
    pub content: String,
    /// `true` for control tokens such as `[CLS]` or `[SEP]`.
    pub special: bool,
    /// HF `single_word` flag: whether the token may only match a whole word.
    pub single_word: bool,
    /// HF `lstrip` flag: strip whitespace to the left of the token.
    pub lstrip: bool,
    /// HF `rstrip` flag: strip whitespace to the right of the token.
    pub rstrip: bool,
    /// HF `normalized` flag: whether the token is matched against normalised text.
    pub normalized: bool,
}
98
99impl HfAddedToken {
100    /// Construct a simple special token entry.
101    pub fn special(id: u32, content: impl Into<String>) -> Self {
102        HfAddedToken {
103            id,
104            content: content.into(),
105            special: true,
106            single_word: false,
107            lstrip: false,
108            rstrip: false,
109            normalized: false,
110        }
111    }
112
113    /// Render to a JSON object string (no trailing newline).
114    fn to_json_object(&self) -> String {
115        format!(
116            r#"{{"id":{},"content":{},"single_word":{},"lstrip":{},"rstrip":{},"normalized":{},"special":{}}}"#,
117            self.id,
118            json_string(&self.content),
119            self.single_word,
120            self.lstrip,
121            self.rstrip,
122            self.normalized,
123            self.special,
124        )
125    }
126
127    /// Parse from a JSON object string (best-effort).
128    fn from_json_object(obj: &str) -> Option<Self> {
129        let id = parse_u32_field(obj, "id")?;
130        let content = parse_string_field(obj, "content")?;
131        let special = parse_bool_field(obj, "special").unwrap_or(false);
132        let single_word = parse_bool_field(obj, "single_word").unwrap_or(false);
133        let lstrip = parse_bool_field(obj, "lstrip").unwrap_or(false);
134        let rstrip = parse_bool_field(obj, "rstrip").unwrap_or(false);
135        let normalized = parse_bool_field(obj, "normalized").unwrap_or(false);
136        Some(HfAddedToken {
137            id,
138            content,
139            special,
140            single_word,
141            lstrip,
142            rstrip,
143            normalized,
144        })
145    }
146}
147
148// ── HfModel ──────────────────────────────────────────────────────────────────
149
/// Contents of the `model` section of a HuggingFace tokenizer JSON document.
#[derive(Debug, Clone)]
pub struct HfModel {
    /// Value of the `"type"` key, e.g. `"WordPiece"` or `"BPE"`.
    pub model_type: String,
    /// Mapping from token text to its integer id.
    pub vocab: HashMap<String, u32>,
    /// BPE merge rules in priority order, each written as `"left right"`;
    /// `None` for models without merges.
    pub merges: Option<Vec<String>>,
    /// Token emitted for out-of-vocabulary input, if the model defines one.
    pub unk_token: Option<String>,
    /// Prefix marking word-internal pieces (WordPiece uses `"##"`).
    pub continuing_subword_prefix: Option<String>,
    /// Words longer than this many characters are mapped straight to the unknown token.
    pub max_input_chars_per_word: Option<u32>,
}
166
167impl HfModel {
168    /// Serialise to a JSON object string.
169    fn to_json_string(&self) -> String {
170        let mut parts: Vec<String> = Vec::new();
171
172        parts.push(format!(r#""type":{}"#, json_string(&self.model_type)));
173
174        if let Some(ref unk) = self.unk_token {
175            parts.push(format!(r#""unk_token":{}"#, json_string(unk)));
176        }
177
178        if let Some(ref pfx) = self.continuing_subword_prefix {
179            parts.push(format!(
180                r#""continuing_subword_prefix":{}"#,
181                json_string(pfx)
182            ));
183        }
184
185        if let Some(max_chars) = self.max_input_chars_per_word {
186            parts.push(format!(r#""max_input_chars_per_word":{}"#, max_chars));
187        }
188
189        // vocab: sort for determinism
190        let vocab_entries = {
191            let mut sorted: Vec<(&String, &u32)> = self.vocab.iter().collect();
192            sorted.sort_by_key(|(_, &id)| id);
193            sorted
194                .iter()
195                .map(|(tok, id)| format!("{}:{}", json_string(tok), id))
196                .collect::<Vec<_>>()
197                .join(",")
198        };
199        parts.push(format!(r#""vocab":{{{}}}"#, vocab_entries));
200
201        if let Some(ref merges) = self.merges {
202            let merge_strs = merges
203                .iter()
204                .map(|m| json_string(m))
205                .collect::<Vec<_>>()
206                .join(",");
207            parts.push(format!(r#""merges":[{}]"#, merge_strs));
208        }
209
210        format!("{{{}}}", parts.join(","))
211    }
212
213    /// Deserialise from the JSON string of a model object.
214    fn from_json_str(s: &str) -> Result<Self> {
215        let model_type = parse_string_field(s, "type").ok_or_else(|| {
216            TextError::InvalidInput("HF JSON: missing model.type field".to_string())
217        })?;
218
219        let unk_token = parse_string_field(s, "unk_token");
220        let continuing_subword_prefix = parse_string_field(s, "continuing_subword_prefix");
221        let max_input_chars_per_word = parse_u32_field(s, "max_input_chars_per_word");
222
223        let vocab = parse_vocab_object(s)?;
224        let merges = parse_string_array_field(s, "merges");
225
226        Ok(HfModel {
227            model_type,
228            vocab,
229            merges,
230            unk_token,
231            continuing_subword_prefix,
232            max_input_chars_per_word,
233        })
234    }
235}
236
237// ── HfTokenizerJson ───────────────────────────────────────────────────────────
238
239/// A complete HuggingFace `tokenizers.json` document.
240#[derive(Debug, Clone)]
241pub struct HfTokenizerJson {
242    /// Format version (typically `"1.0"`).
243    pub version: String,
244    /// The core tokenizer model (vocab, merges, etc.).
245    pub model: HfModel,
246    /// Extra special tokens added on top of the base vocabulary.
247    pub added_tokens: Vec<HfAddedToken>,
248    /// Raw JSON for the normalizer component (null if absent).
249    pub normalizer_json: Option<String>,
250    /// Raw JSON for the pre-tokenizer component.
251    pub pre_tokenizer_json: Option<String>,
252    /// Raw JSON for the post-processor component.
253    pub post_processor_json: Option<String>,
254    /// Raw JSON for the decoder component.
255    pub decoder_json: Option<String>,
256}
257
258impl HfTokenizerJson {
259    // ── Constructors ───────────────────────────────────────────────────────
260
261    /// Build a [`HfTokenizerJson`] from a trained [`WordPieceTokenizer`].
262    ///
263    /// Extracts the vocabulary and annotates the standard BERT special tokens
264    /// as `added_tokens`.
265    pub fn from_wordpiece(wp: &WordPieceTokenizer) -> Self {
266        let vocab: HashMap<String, u32> = wp.vocab_snapshot();
267
268        // Standard BERT special token IDs (use vocab lookup with fallbacks)
269        let get = |tok: &str, fallback: u32| -> u32 { vocab.get(tok).copied().unwrap_or(fallback) };
270
271        let added_tokens = vec![
272            HfAddedToken::special(get("[PAD]", 0), "[PAD]"),
273            HfAddedToken::special(get("[UNK]", 1), "[UNK]"),
274            HfAddedToken::special(get("[CLS]", 101), "[CLS]"),
275            HfAddedToken::special(get("[SEP]", 102), "[SEP]"),
276            HfAddedToken::special(get("[MASK]", 103), "[MASK]"),
277        ];
278
279        let model = HfModel {
280            model_type: "WordPiece".to_string(),
281            vocab,
282            merges: None,
283            unk_token: Some("[UNK]".to_string()),
284            continuing_subword_prefix: Some("##".to_string()),
285            max_input_chars_per_word: Some(100),
286        };
287
288        HfTokenizerJson {
289            version: "1.0".to_string(),
290            model,
291            added_tokens,
292            normalizer_json: None,
293            pre_tokenizer_json: None,
294            post_processor_json: None,
295            decoder_json: None,
296        }
297    }
298
299    /// Build a [`HfTokenizerJson`] from a trained [`Gpt2BpeTokenizer`].
300    pub fn from_gpt2_bpe(bpe: &Gpt2BpeTokenizer) -> Self {
301        let vocab: HashMap<String, u32> = bpe.vocab_snapshot();
302        let merges: Vec<String> = bpe
303            .merges()
304            .iter()
305            .map(|(a, b)| format!("{} {}", a, b))
306            .collect();
307
308        let model = HfModel {
309            model_type: "BPE".to_string(),
310            vocab,
311            merges: Some(merges),
312            unk_token: None,
313            continuing_subword_prefix: None,
314            max_input_chars_per_word: None,
315        };
316
317        HfTokenizerJson {
318            version: "1.0".to_string(),
319            model,
320            added_tokens: vec![],
321            normalizer_json: None,
322            pre_tokenizer_json: None,
323            post_processor_json: None,
324            decoder_json: None,
325        }
326    }
327
328    // ── Serialisation ──────────────────────────────────────────────────────
329
330    /// Serialise to a JSON string.
331    ///
332    /// When the `serde-support` feature is enabled this delegates to
333    /// `serde_json`; otherwise a manual serialiser is used.
334    pub fn to_json_string(&self) -> String {
335        let added_tokens_str = self
336            .added_tokens
337            .iter()
338            .map(|t| t.to_json_object())
339            .collect::<Vec<_>>()
340            .join(",");
341
342        let null_or =
343            |opt: &Option<String>| -> String { opt.as_deref().unwrap_or("null").to_string() };
344
345        format!(
346            r#"{{"version":{},"truncation":null,"padding":null,"added_tokens":[{}],"normalizer":{},"pre_tokenizer":{},"post_processor":{},"decoder":{},"model":{}}}"#,
347            json_string(&self.version),
348            added_tokens_str,
349            null_or(&self.normalizer_json),
350            null_or(&self.pre_tokenizer_json),
351            null_or(&self.post_processor_json),
352            null_or(&self.decoder_json),
353            self.model.to_json_string(),
354        )
355    }
356
357    // ── Deserialisation ────────────────────────────────────────────────────
358
359    /// Parse a HuggingFace `tokenizers.json` string.
360    pub fn from_json_str(s: &str) -> Result<Self> {
361        let version = parse_string_field(s, "version").unwrap_or_else(|| "1.0".to_string());
362
363        // Extract the model object
364        let model_str = extract_object_field(s, "model").ok_or_else(|| {
365            TextError::InvalidInput("HF JSON: missing 'model' object".to_string())
366        })?;
367        let model = HfModel::from_json_str(model_str)?;
368
369        // Extract added_tokens array
370        let added_tokens = extract_array_field(s, "added_tokens")
371            .unwrap_or_default()
372            .iter()
373            .filter_map(|obj| HfAddedToken::from_json_object(obj))
374            .collect();
375
376        let normalizer_json = extract_object_field(s, "normalizer").map(|o| o.to_string());
377        let pre_tokenizer_json = extract_object_field(s, "pre_tokenizer").map(|o| o.to_string());
378        let post_processor_json = extract_object_field(s, "post_processor").map(|o| o.to_string());
379        let decoder_json = extract_object_field(s, "decoder").map(|o| o.to_string());
380
381        Ok(HfTokenizerJson {
382            version,
383            model,
384            added_tokens,
385            normalizer_json,
386            pre_tokenizer_json,
387            post_processor_json,
388            decoder_json,
389        })
390    }
391
392    // ── Utilities ──────────────────────────────────────────────────────────
393
394    /// Check that the tokenizer round-trips through JSON without data loss.
395    ///
396    /// Returns `true` when the vocab sizes before and after serialisation
397    /// match and the model type is preserved.
398    pub fn wordpiece_roundtrip_check(wp: &WordPieceTokenizer) -> bool {
399        let original = Self::from_wordpiece(wp);
400        let json = original.to_json_string();
401        match Self::from_json_str(&json) {
402            Ok(restored) => {
403                restored.model.vocab.len() == original.model.vocab.len()
404                    && restored.model.model_type == original.model.model_type
405            }
406            Err(_) => false,
407        }
408    }
409}
410
411// ── Free function ─────────────────────────────────────────────────────────────
412
413/// Peek at a HuggingFace tokenizers JSON string and return the model type.
414pub fn detect_model_type(json: &str) -> Result<HfModelType> {
415    let model_str = extract_object_field(json, "model").ok_or_else(|| {
416        TextError::InvalidInput("HF JSON: could not locate 'model' object".to_string())
417    })?;
418    let type_str = parse_string_field(model_str, "type")
419        .ok_or_else(|| TextError::InvalidInput("HF JSON: missing model.type field".to_string()))?;
420    Ok(HfModelType::parse(&type_str))
421}
422
423// Note: vocab_snapshot() methods are defined directly on WordPieceTokenizer
424// (in tokenization/wordpiece.rs) and Gpt2BpeTokenizer (in gpt_bpe.rs).
425// Both return HashMap<String, u32>.
426
427// ── Minimal JSON helpers ─────────────────────────────────────────────────────
428
/// Quote `s` as a JSON string literal.
///
/// Quotes and backslashes are backslash-escaped.  Newline, carriage return
/// and tab use their short escapes.  Every other control character below
/// U+0020 becomes a lowercase `\u00XX` escape.  All remaining characters are
/// copied through unchanged.
fn json_string(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    s.chars().for_each(|ch| match ch {
        '"' => quoted.push_str("\\\""),
        '\\' => quoted.push_str("\\\\"),
        '\n' => quoted.push_str("\\n"),
        '\r' => quoted.push_str("\\r"),
        '\t' => quoted.push_str("\\t"),
        ctrl if u32::from(ctrl) < 0x20 => {
            quoted.push_str(&format!("\\u{:04x}", u32::from(ctrl)));
        }
        other => quoted.push(other),
    });
    quoted.push('"');
    quoted
}
449
/// Locate the value of `key` and return the text from the start of that value
/// to the end of `json`.  Returns `None` when the key is absent or its value
/// is `null`.
///
/// When `json` is an object (first non-whitespace character `{`), only that
/// object's own keys are considered.  Keys of nested objects and arrays, and
/// text inside string literals, can never match.  This matters in practice:
/// a vocabulary token literally named `"merges"` must not shadow the model's
/// real `merges` array.  For a non-object fragment, keys at nesting depth 0
/// are matched instead.
///
/// Whitespace is allowed between the key and the `:`.  The key is compared
/// against the raw (still-escaped) text between the quotes.
fn extract_json_value<'a>(json: &'a str, key: &str) -> Option<&'a str> {
    let bytes = json.as_bytes();
    let key_depth: i32 = if json.trim_start().starts_with('{') { 1 } else { 0 };
    let mut depth = 0i32;
    let mut i = 0usize;

    // Byte-level scan: every structural character (`"`, `\`, brackets) is
    // ASCII, and UTF-8 continuation bytes never equal an ASCII byte, so all
    // slice boundaries used below fall on character boundaries.
    while i < bytes.len() {
        match bytes[i] {
            b'"' => {
                let start = i + 1;
                let mut j = start;
                while j < bytes.len() && bytes[j] != b'"' {
                    // Jump over the escaped byte so `\"` does not end the string.
                    j += if bytes[j] == b'\\' { 2 } else { 1 };
                }
                if j >= bytes.len() {
                    return None; // unterminated string literal
                }
                if depth == key_depth && &json[start..j] == key {
                    // Only a string followed by `:` is a key.
                    if let Some(rest) = json[j + 1..].trim_start().strip_prefix(':') {
                        let value = rest.trim_start();
                        return if value.starts_with("null") {
                            None
                        } else {
                            Some(value)
                        };
                    }
                }
                i = j + 1;
            }
            b'{' | b'[' => {
                depth += 1;
                i += 1;
            }
            b'}' | b']' => {
                depth -= 1;
                i += 1;
            }
            _ => i += 1,
        }
    }
    None
}
469
470/// Parse a JSON string field from a JSON object string.
471fn parse_string_field(json: &str, key: &str) -> Option<String> {
472    let raw = extract_json_value(json, key)?;
473    if !raw.starts_with('"') {
474        return None;
475    }
476    // Walk forward to find the closing unescaped quote
477    let mut chars = raw.char_indices().skip(1); // skip opening "
478    let mut result = String::new();
479    loop {
480        match chars.next() {
481            None => return None,
482            Some((_, '"')) => break,
483            Some((_, '\\')) => {
484                match chars.next() {
485                    Some((_, '"')) => result.push('"'),
486                    Some((_, '\\')) => result.push('\\'),
487                    Some((_, 'n')) => result.push('\n'),
488                    Some((_, 'r')) => result.push('\r'),
489                    Some((_, 't')) => result.push('\t'),
490                    Some((_, 'u')) => {
491                        // \uXXXX
492                        let mut hex = String::new();
493                        for _ in 0..4 {
494                            if let Some((_, c)) = chars.next() {
495                                hex.push(c);
496                            }
497                        }
498                        if let Ok(n) = u32::from_str_radix(&hex, 16) {
499                            if let Some(c) = char::from_u32(n) {
500                                result.push(c);
501                            }
502                        }
503                    }
504                    Some((_, c)) => result.push(c),
505                    None => return None,
506                }
507            }
508            Some((_, c)) => result.push(c),
509        }
510    }
511    Some(result)
512}
513
514/// Parse a JSON boolean field.
515fn parse_bool_field(json: &str, key: &str) -> Option<bool> {
516    let raw = extract_json_value(json, key)?;
517    if raw.starts_with("true") {
518        Some(true)
519    } else if raw.starts_with("false") {
520        Some(false)
521    } else {
522        None
523    }
524}
525
526/// Parse a JSON unsigned-integer field.
527fn parse_u32_field(json: &str, key: &str) -> Option<u32> {
528    let raw = extract_json_value(json, key)?;
529    let num: String = raw.chars().take_while(|c| c.is_ascii_digit()).collect();
530    num.parse().ok()
531}
532
533/// Extract the inner text of a JSON object `{...}` for a given key.
534/// Returns `None` when the value is `null` or not an object.
535fn extract_object_field<'a>(json: &'a str, key: &str) -> Option<&'a str> {
536    let raw = extract_json_value(json, key)?;
537    if !raw.starts_with('{') {
538        return None;
539    }
540    // Find matching closing brace
541    let end = find_matching_brace(raw, '{', '}')?;
542    Some(&raw[..=end])
543}
544
545/// Extract each element of a JSON array `[...]` for a given key.
546/// Only handles arrays of objects `[{...},{...}]`.
547fn extract_array_field(json: &str, key: &str) -> Option<Vec<String>> {
548    let raw = extract_json_value(json, key)?;
549    if !raw.starts_with('[') {
550        return None;
551    }
552    let end = find_matching_brace(raw, '[', ']')?;
553    let inner = &raw[1..end]; // strip [ and ]
554    Some(split_json_array_objects(inner))
555}
556
557/// Parse a JSON string-array field (e.g. the `merges` list).
558fn parse_string_array_field(json: &str, key: &str) -> Option<Vec<String>> {
559    let raw = extract_json_value(json, key)?;
560    if !raw.starts_with('[') {
561        return None;
562    }
563    let end = find_matching_brace(raw, '[', ']')?;
564    let inner = &raw[1..end];
565
566    let mut result = Vec::new();
567    let mut remainder = inner.trim();
568    while !remainder.is_empty() {
569        if remainder.starts_with('"') {
570            // Scan for the end of the string
571            let mut chars = remainder.char_indices().skip(1);
572            let mut s = String::new();
573            let mut end_pos = 0;
574            let mut found = false;
575            loop {
576                match chars.next() {
577                    None => break,
578                    Some((i, '"')) => {
579                        end_pos = i;
580                        found = true;
581                        break;
582                    }
583                    Some((_, '\\')) => match chars.next() {
584                        Some((_, c)) => s.push(c),
585                        None => break,
586                    },
587                    Some((_, c)) => s.push(c),
588                }
589            }
590            if found {
591                result.push(s);
592                remainder = remainder[end_pos + 1..].trim_start_matches(',').trim();
593            } else {
594                break;
595            }
596        } else {
597            // Skip non-string element
598            let skip = remainder
599                .find(',')
600                .map(|i| i + 1)
601                .unwrap_or(remainder.len());
602            remainder = &remainder[skip..];
603        }
604    }
605    Some(result)
606}
607
608/// Parse the `vocab` object (`{"token": id, ...}`) embedded within a model
609/// object string.
610fn parse_vocab_object(json: &str) -> Result<HashMap<String, u32>> {
611    // Find the vocab sub-object
612    let vocab_raw = extract_object_field(json, "vocab").ok_or_else(|| {
613        TextError::InvalidInput("HF JSON: missing model.vocab object".to_string())
614    })?;
615
616    let inner = &vocab_raw[1..vocab_raw.len() - 1]; // strip { }
617    let mut map = HashMap::new();
618
619    // Parse "token": id pairs separated by commas
620    let mut remainder = inner.trim();
621    while !remainder.is_empty() {
622        if remainder.starts_with('"') {
623            // Parse key string
624            let key = match parse_json_string_at_start(remainder) {
625                Some((s, consumed)) => {
626                    remainder = &remainder[consumed..];
627                    s
628                }
629                None => break,
630            };
631            remainder = remainder.trim_start();
632            if !remainder.starts_with(':') {
633                break;
634            }
635            remainder = remainder[1..].trim_start();
636            // Parse numeric value
637            let num_str: String = remainder
638                .chars()
639                .take_while(|c| c.is_ascii_digit())
640                .collect();
641            if num_str.is_empty() {
642                break;
643            }
644            if let Ok(id) = num_str.parse::<u32>() {
645                map.insert(key, id);
646            }
647            remainder = &remainder[num_str.len()..];
648            remainder = remainder.trim_start();
649            if remainder.starts_with(',') {
650                remainder = remainder[1..].trim_start();
651            }
652        } else {
653            // Skip unexpected characters
654            remainder = &remainder[1..];
655        }
656    }
657
658    Ok(map)
659}
660
/// Parse a JSON string literal at the start of `s`.
///
/// Returns the decoded value and the number of bytes consumed (both quotes
/// included).  Returns `None` when `s` does not start with `"` or the literal
/// is unterminated.
///
/// Escape handling (RFC 8259):
/// - `\"`, `\\`, `\/`, `\b`, `\f`, `\n`, `\r`, `\t` and `\uXXXX` are decoded.
/// - UTF-16 surrogate pairs such as `\ud83d\ude00` are combined into one
///   character.
/// - A lone surrogate decodes to U+FFFD.
/// - A `\u` escape with malformed hex is dropped (best effort).
/// - Any other escaped character stands for itself.
fn parse_json_string_at_start(s: &str) -> Option<(String, usize)> {
    /// Take the next four characters from `it` and read them as a hex number.
    /// Returns `None` if fewer than four remain or they are not valid hex.
    fn read_hex4<I: Iterator<Item = (usize, char)>>(it: &mut I) -> Option<u32> {
        let hex: String = it.by_ref().take(4).map(|(_, c)| c).collect();
        if hex.len() == 4 {
            u32::from_str_radix(&hex, 16).ok()
        } else {
            None
        }
    }

    if !s.starts_with('"') {
        return None;
    }
    let mut out = String::new();
    let mut chars = s.char_indices().skip(1);
    loop {
        let (idx, ch) = chars.next()?;
        match ch {
            '"' => return Some((out, idx + '"'.len_utf8())),
            '\\' => {
                let (_, esc) = chars.next()?;
                match esc {
                    'b' => out.push('\u{0008}'),
                    'f' => out.push('\u{000C}'),
                    'n' => out.push('\n'),
                    'r' => out.push('\r'),
                    't' => out.push('\t'),
                    'u' => match read_hex4(&mut chars) {
                        Some(hi @ 0xD800..=0xDBFF) => {
                            // High surrogate: look ahead for a `\uDC00..DFFF` partner
                            // without consuming anything if there is none.
                            let mut lookahead = chars.clone();
                            let low = match (lookahead.next(), lookahead.next()) {
                                (Some((_, '\\')), Some((_, 'u'))) => read_hex4(&mut lookahead),
                                _ => None,
                            };
                            match low {
                                Some(lo @ 0xDC00..=0xDFFF) => {
                                    chars = lookahead;
                                    let code = 0x10000 + ((hi - 0xD800) << 10) + (lo - 0xDC00);
                                    out.push(
                                        char::from_u32(code)
                                            .unwrap_or(char::REPLACEMENT_CHARACTER),
                                    );
                                }
                                _ => out.push(char::REPLACEMENT_CHARACTER),
                            }
                        }
                        // Unpaired low surrogate.
                        Some(0xDC00..=0xDFFF) => out.push(char::REPLACEMENT_CHARACTER),
                        Some(code) => {
                            if let Some(c) = char::from_u32(code) {
                                out.push(c);
                            }
                        }
                        None => {}
                    },
                    // `"`, `\`, `/` and non-standard escapes map to themselves.
                    other => out.push(other),
                }
            }
            other => out.push(other),
        }
    }
}
698
/// Return the byte index of the `close` character that balances the `open`
/// character expected at the start of `s`.
///
/// Brackets inside JSON string literals, including those after `\"`, are
/// ignored.  Returns `None` if the nesting never returns to zero.
fn find_matching_brace(s: &str, open: char, close: char) -> Option<usize> {
    let mut level = 0i32;
    let mut inside_str = false;
    let mut escaped = false;

    s.char_indices().find_map(|(idx, c)| {
        if escaped {
            escaped = false;
            return None;
        }
        if inside_str {
            match c {
                '\\' => escaped = true,
                '"' => inside_str = false,
                _ => {}
            }
            return None;
        }
        if c == '"' {
            inside_str = true;
        } else if c == open {
            level += 1;
        } else if c == close {
            level -= 1;
            if level == 0 {
                return Some(idx);
            }
        }
        None
    })
}
732
733/// Split a JSON array's inner text `{...},{...},...` into object strings.
734fn split_json_array_objects(inner: &str) -> Vec<String> {
735    let mut result = Vec::new();
736    let mut remainder = inner.trim();
737    while !remainder.is_empty() {
738        if remainder.starts_with('{') {
739            match find_matching_brace(remainder, '{', '}') {
740                Some(end) => {
741                    result.push(remainder[..=end].to_string());
742                    remainder = remainder[end + 1..].trim_start_matches(',').trim();
743                }
744                None => break,
745            }
746        } else {
747            // Skip unexpected content
748            let skip = remainder.find('{').unwrap_or(remainder.len());
749            if skip == remainder.len() {
750                break;
751            }
752            remainder = &remainder[skip..];
753        }
754    }
755    result
756}
757
758// ── Tests ─────────────────────────────────────────────────────────────────────
759
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenization::wordpiece::WordPieceTokenizer;

    /// Small BERT-style vocabulary containing every standard special token.
    fn minimal_wp() -> WordPieceTokenizer {
        let tokens = vec![
            "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "hello", "world", "##ing", "foo",
        ];
        WordPieceTokenizer::from_vocab_list(&tokens)
    }

    /// HF document built from `minimal_wp()`.
    fn minimal_hf() -> HfTokenizerJson {
        HfTokenizerJson::from_wordpiece(&minimal_wp())
    }

    /// Serialised JSON of `minimal_hf()`.
    fn minimal_json() -> String {
        minimal_hf().to_json_string()
    }

    #[test]
    fn from_wordpiece_model_type() {
        assert_eq!(minimal_hf().model.model_type, "WordPiece");
    }

    #[test]
    fn to_json_string_contains_vocab() {
        let json = minimal_json();
        assert!(json.contains("\"vocab\""), "JSON must contain vocab key");
    }

    #[test]
    fn roundtrip_from_json_str() {
        let json = minimal_json();
        let restored = HfTokenizerJson::from_json_str(&json).expect("parse failed");
        assert_eq!(restored.model.model_type, "WordPiece");
    }

    #[test]
    fn detect_model_type_wordpiece() {
        let json = minimal_json();
        let detected = detect_model_type(&json).expect("detect failed");
        assert_eq!(detected, HfModelType::WordPiece);
    }

    #[test]
    fn detect_model_type_bpe() {
        // Minimal BPE document written by hand.
        let json = r#"{"version":"1.0","model":{"type":"BPE","vocab":{"hello":0},"merges":["h e"]},"added_tokens":[]}"#;
        let detected = detect_model_type(json).expect("detect failed");
        assert_eq!(detected, HfModelType::Bpe);
    }

    #[test]
    fn added_tokens_contains_cls() {
        let hf = minimal_hf();
        assert!(
            hf.added_tokens.iter().any(|t| t.content == "[CLS]"),
            "added_tokens must contain [CLS]"
        );
    }

    #[test]
    fn vocab_size_matches_input() {
        let wp = minimal_wp();
        let hf = HfTokenizerJson::from_wordpiece(&wp);
        assert_eq!(hf.model.vocab.len(), wp.vocab_size());
    }

    #[test]
    fn empty_vocab_serialises_without_panic() {
        let tokens: &[&str] = &[];
        let wp = WordPieceTokenizer::from_vocab_list(tokens);
        let json = HfTokenizerJson::from_wordpiece(&wp).to_json_string();
        assert!(json.contains("WordPiece"));
    }

    #[test]
    fn hf_model_type_variants_accessible() {
        let variants = [
            HfModelType::WordPiece,
            HfModelType::Bpe,
            HfModelType::Unigram,
            HfModelType::Unknown("X".to_string()),
        ];
        assert_eq!(variants.len(), 4);
    }

    #[test]
    fn invalid_json_returns_err() {
        assert!(HfTokenizerJson::from_json_str("not json at all }{").is_err());
    }

    #[test]
    fn roundtrip_check_helper() {
        assert!(HfTokenizerJson::wordpiece_roundtrip_check(&minimal_wp()));
    }

    #[test]
    fn version_field_preserved() {
        let json = minimal_json();
        let restored = HfTokenizerJson::from_json_str(&json).expect("round-trip parse failed");
        assert_eq!(restored.version, "1.0");
    }
}