Skip to main content

scirs2_text/tokenization/
byte_level_bpe.rs

1//! Byte-Level BPE tokenizer — GPT-2/GPT-4 style.
2//!
3//! Encodes each byte as a unicode character using the GPT-2 byte→unicode table,
4//! then applies BPE merges on top of that representation.
5//!
6//! The mapping is bijective: 256 bytes → 256 distinct unicode code points.
7
8use crate::error::{Result, TextError};
9use std::collections::HashMap;
10use std::io::{BufRead, BufReader, Write};
11
/// Tuple produced by `ByteLevelBpeTokenizer::init_base`:
/// `(byte_encoder, byte_decoder, vocab, id_to_token)`.
type InitBaseMaps = (HashMap<u8, char>, HashMap<char, u8>, HashMap<String, u32>, Vec<String>);
19
20// ─── GPT-2 byte-to-unicode table ─────────────────────────────────────────────
21
/// Build the GPT-2 byte→unicode bijection.
///
/// Two groups of bytes keep their own code point:
/// - printable ASCII (`!`..=`~`, 33–126);
/// - the printable Latin-1 ranges 161–172 and 174–255.
///
/// The remaining 68 bytes (0–32, 127–160 and 173) get consecutive code points
/// starting at U+0100, assigned in increasing byte order. For example, byte
/// 0x20 (space) becomes `Ġ` (U+0120).
pub fn bytes_to_unicode() -> HashMap<u8, char> {
    let keeps_identity = |byte: u8| matches!(byte, b'!'..=b'~' | 0xA1..=0xAC | 0xAE..=0xFF);

    let mut next_code_point = 0x0100u32;
    let mut table = HashMap::with_capacity(256);
    for byte in 0u8..=255u8 {
        let mapped = if keeps_identity(byte) {
            char::from(byte)
        } else {
            // Never fails: the largest value used is U+0143.
            let remapped = char::from_u32(next_code_point).unwrap_or('\u{0100}');
            next_code_point += 1;
            remapped
        };
        table.insert(byte, mapped);
    }
    table
}
50
51// ─── ByteLevelBpeConfig ───────────────────────────────────────────────────────
52
/// Training options for [`ByteLevelBpeTokenizer::train`].
#[derive(Debug, Clone)]
pub struct ByteLevelBpeConfig {
    /// Target vocabulary size, counting the 256 byte-level base tokens.
    /// Training stops as soon as the vocabulary reaches this size.
    pub vocab_size: usize,
    /// A pair must occur at least this many times (weighted by word
    /// frequency) before it can be merged.
    pub min_frequency: usize,
    /// When `true`, training prefixes every word except the first word of each
    /// text with the byte-level space token `Ġ` (U+0120).
    pub add_prefix_space: bool,
}

impl Default for ByteLevelBpeConfig {
    /// GPT-2 sized defaults: up to 50257 tokens, pairs must occur at least
    /// twice, and the `Ġ` word prefix is enabled.
    fn default() -> Self {
        Self {
            vocab_size: 50257,
            min_frequency: 2,
            add_prefix_space: true,
        }
    }
}
73
74// ─── ByteLevelBpeTokenizer ────────────────────────────────────────────────────
75
/// GPT-2-style byte-level BPE tokenizer.
///
/// Input text is processed as raw UTF-8 bytes:
/// 1. Each byte is replaced by its character from the GPT-2 byte→unicode
///    table ([`bytes_to_unicode`]).
/// 2. BPE merges are learned and applied over those characters.
///
/// Once the 256 base tokens are in `vocab` (as after training), every byte
/// has a token, so no `[UNK]` token is needed.
#[derive(Clone, Debug)]
pub struct ByteLevelBpeTokenizer {
    /// Token string → integer id.
    pub vocab: HashMap<String, u32>,
    /// Integer id → token string (indexed by id).
    pub id_to_token: Vec<String>,
    /// Merge rules `(left, right)` in priority order (earlier = applied first).
    pub merges: Vec<(String, String)>,
    /// Byte → byte-level character.
    pub byte_encoder: HashMap<u8, char>,
    /// Byte-level character → byte (inverse of `byte_encoder`).
    pub byte_decoder: HashMap<char, u8>,
}
90    /// byte → unicode char
91    pub byte_encoder: HashMap<u8, char>,
92    /// unicode char → byte  (inverse of byte_encoder)
93    pub byte_decoder: HashMap<char, u8>,
94}
95
96// Internal helpers
97impl ByteLevelBpeTokenizer {
98    /// Build encoder/decoder maps and seed base vocabulary from 256 bytes.
99    fn init_base() -> InitBaseMaps {
100        let byte_encoder = bytes_to_unicode();
101        let byte_decoder: HashMap<char, u8> = byte_encoder.iter().map(|(&b, &c)| (c, b)).collect();
102
103        let mut vocab: HashMap<String, u32> = HashMap::new();
104        let mut id_to_token: Vec<String> = Vec::new();
105        // Add all 256 byte-level characters in byte-value order
106        for b in 0u8..=255u8 {
107            let ch = byte_encoder[&b];
108            let tok = ch.to_string();
109            if !vocab.contains_key(&tok) {
110                let id = id_to_token.len() as u32;
111                vocab.insert(tok.clone(), id);
112                id_to_token.push(tok);
113            }
114        }
115        (byte_encoder, byte_decoder, vocab, id_to_token)
116    }
117
118    /// Encode a single `word` string (already byte-encoded) into a list of
119    /// individual character tokens, then apply all known merge rules.
120    fn apply_merges(&self, chars: Vec<String>) -> Vec<String> {
121        let mut word = chars;
122        // Build a fast merge-priority map
123        let merge_rank: HashMap<(&str, &str), usize> = self
124            .merges
125            .iter()
126            .enumerate()
127            .map(|(i, (a, b))| (a.as_str(), b.as_str()))
128            // We can't borrow from a temporary this way — collect differently
129            .enumerate()
130            .map(|(i, _)| (("", ""), i)) // placeholder, rebuilt below
131            .collect();
132        // Rebuild properly using indices
133        let merge_rank: HashMap<(String, String), usize> = self
134            .merges
135            .iter()
136            .enumerate()
137            .map(|(i, (a, b))| ((a.clone(), b.clone()), i))
138            .collect();
139
140        loop {
141            if word.len() < 2 {
142                break;
143            }
144            // Find the highest-priority (lowest rank) adjacent pair
145            let mut best_rank = usize::MAX;
146            let mut best_idx = usize::MAX;
147            for i in 0..word.len() - 1 {
148                let pair = (word[i].clone(), word[i + 1].clone());
149                if let Some(&rank) = merge_rank.get(&pair) {
150                    if rank < best_rank {
151                        best_rank = rank;
152                        best_idx = i;
153                    }
154                }
155            }
156            if best_idx == usize::MAX {
157                break; // no more merges possible
158            }
159            // Merge at best_idx
160            let merged = format!("{}{}", word[best_idx], word[best_idx + 1]);
161            word.remove(best_idx + 1);
162            word[best_idx] = merged;
163        }
164        word
165    }
166
167    /// Byte-encode a string: each UTF-8 byte is mapped to its unicode char.
168    fn byte_encode_str(&self, s: &str) -> Vec<String> {
169        s.bytes()
170            .map(|b| {
171                self.byte_encoder
172                    .get(&b)
173                    .copied()
174                    .unwrap_or('\u{FFFD}')
175                    .to_string()
176            })
177            .collect()
178    }
179}
180
181// ─── Training ────────────────────────────────────────────────────────────────
182
183impl ByteLevelBpeTokenizer {
184    /// Train a new [`ByteLevelBpeTokenizer`] from raw text slices.
185    ///
186    /// Pre-tokenises on whitespace boundaries and prepends `Ġ` (U+0120) to
187    /// every word that is **not** at the beginning of the pre-tokenised
188    /// sequence.
189    pub fn train(texts: &[&str], config: ByteLevelBpeConfig) -> Self {
190        let (byte_encoder, byte_decoder, mut vocab, mut id_to_token) = Self::init_base();
191
192        // Count word frequencies after byte-encoding.
193        // IMPORTANT: the Ġ prefix is the *byte-level* representation of byte 0x20
194        // (space).  We must NOT format `"\u{0120}word"` and then call `.bytes()`
195        // because that would encode the UTF-8 bytes of Ġ (0xC4 0xA0) rather than
196        // the single byte-level token Ġ.  Instead, prepend the encoded form of
197        // byte 0x20 directly.
198        let space_char = byte_encoder
199            .get(&0x20u8)
200            .copied()
201            .unwrap_or('\u{0120}')
202            .to_string();
203
204        let mut word_freq: HashMap<Vec<String>, usize> = HashMap::new();
205        for text in texts {
206            // simple whitespace pre-tokenisation
207            let mut first = true;
208            for word in text.split_whitespace() {
209                // byte-encode the word's raw UTF-8 bytes
210                let mut encoded: Vec<String> = word
211                    .bytes()
212                    .map(|b| {
213                        byte_encoder
214                            .get(&b)
215                            .copied()
216                            .unwrap_or('\u{FFFD}')
217                            .to_string()
218                    })
219                    .collect();
220                // Prepend the byte-level space token to non-first words
221                if !first && config.add_prefix_space {
222                    encoded.insert(0, space_char.clone());
223                }
224                first = false;
225                *word_freq.entry(encoded).or_insert(0) += 1;
226            }
227        }
228
229        let mut merges: Vec<(String, String)> = Vec::new();
230
231        // BPE merge loop
232        while vocab.len() < config.vocab_size {
233            // Count pair frequencies weighted by word frequency
234            let mut pair_freq: HashMap<(String, String), usize> = HashMap::new();
235            for (word, &freq) in &word_freq {
236                for i in 0..word.len().saturating_sub(1) {
237                    let pair = (word[i].clone(), word[i + 1].clone());
238                    *pair_freq.entry(pair).or_insert(0) += freq;
239                }
240            }
241
242            // Find best pair
243            let best = pair_freq
244                .iter()
245                .filter(|(_, &f)| f >= config.min_frequency)
246                .max_by_key(|(_, &f)| f);
247
248            let (left, right) = match best {
249                Some(((l, r), _)) => (l.clone(), r.clone()),
250                None => break,
251            };
252
253            // Record merge
254            merges.push((left.clone(), right.clone()));
255            let merged = format!("{}{}", left, right);
256            let new_id = id_to_token.len() as u32;
257            vocab.insert(merged.clone(), new_id);
258            id_to_token.push(merged.clone());
259
260            // Apply merge to all words
261            let updated: HashMap<Vec<String>, usize> = word_freq
262                .into_iter()
263                .map(|(word, freq)| {
264                    let new_word = merge_pair_in_word(word, &left, &right);
265                    (new_word, freq)
266                })
267                .collect();
268            word_freq = updated;
269        }
270
271        ByteLevelBpeTokenizer {
272            vocab,
273            id_to_token,
274            merges,
275            byte_encoder,
276            byte_decoder,
277        }
278    }
279}
280
/// Merge every non-overlapping occurrence of the adjacent pair `(left, right)`
/// in `word`, scanning left to right.
///
/// For example, merging `("a", "a")` in `["a", "a", "a"]` yields `["aa", "a"]`.
/// The word is consumed, so tokens that are not merged are moved into the
/// result rather than cloned.
fn merge_pair_in_word(word: Vec<String>, left: &str, right: &str) -> Vec<String> {
    let mut result = Vec::with_capacity(word.len());
    let mut tokens = word.into_iter().peekable();
    while let Some(tok) = tokens.next() {
        if tok == left && matches!(tokens.peek(), Some(next) if next == right) {
            tokens.next(); // consume `right`
            // `tok` equals `left`, so appending `right` yields the merged token.
            result.push(tok + right);
        } else {
            result.push(tok);
        }
    }
    result
}
296
297// ─── Encoding / Decoding ─────────────────────────────────────────────────────
298
299impl ByteLevelBpeTokenizer {
300    /// Encode `text` to a sequence of token IDs.
301    ///
302    /// Applies the same whitespace pre-tokenisation + byte-encoding + BPE
303    /// merges as during training.
304    pub fn encode(&self, text: &str) -> Vec<u32> {
305        let mut ids = Vec::new();
306        // The byte-level token for space (byte 0x20) is used as a word-prefix
307        // marker — it is the *character* that byte 0x20 maps to in the
308        // byte_encoder (i.e. Ġ = U+0120).  We prepend it directly (without
309        // going through byte_encode_str again) so that the Ġ character itself
310        // ends up as a single token rather than its two UTF-8 bytes.
311        let space_tok = self
312            .byte_encoder
313            .get(&0x20u8)
314            .copied()
315            .unwrap_or('\u{0120}')
316            .to_string();
317
318        let mut first = true;
319        for word in text.split_whitespace() {
320            // Byte-encode only the word's own UTF-8 bytes
321            let mut chars = self.byte_encode_str(word);
322            // Prepend the byte-level space token for non-first words
323            if !first {
324                chars.insert(0, space_tok.clone());
325            }
326            first = false;
327            let merged = self.apply_merges(chars);
328            for tok in merged {
329                if let Some(&id) = self.vocab.get(&tok) {
330                    ids.push(id);
331                }
332                // With byte-level encoding, every byte maps to a valid base token.
333                // Unknown tokens should not occur, but we silently skip them if they do.
334            }
335        }
336        ids
337    }
338
339    /// Decode a sequence of token IDs back to a UTF-8 string.
340    ///
341    /// This is lossless: `decode(encode(text)) == text` for any valid UTF-8.
342    pub fn decode(&self, ids: &[u32]) -> String {
343        // Map ids → token strings → bytes
344        let mut bytes: Vec<u8> = Vec::new();
345        for &id in ids {
346            if let Some(tok) = self.id_to_token.get(id as usize) {
347                for ch in tok.chars() {
348                    if let Some(&b) = self.byte_decoder.get(&ch) {
349                        bytes.push(b);
350                    }
351                }
352            }
353        }
354        String::from_utf8_lossy(&bytes).into_owned()
355    }
356}
357
358// ─── Serialisation ───────────────────────────────────────────────────────────
359
360impl ByteLevelBpeTokenizer {
361    /// Save vocabulary (HuggingFace JSON format) and merge rules to separate files.
362    ///
363    /// The vocab file is a JSON object mapping token strings to integer IDs.
364    /// The merges file contains one merge rule per line: `left right`.
365    pub fn save_vocab(&self, vocab_path: &str, merges_path: &str) -> Result<()> {
366        // Write vocab JSON
367        {
368            let mut f =
369                std::fs::File::create(vocab_path).map_err(|e| TextError::IoError(e.to_string()))?;
370            // Manually write JSON to avoid external dependency
371            write!(f, "{{").map_err(|e| TextError::IoError(e.to_string()))?;
372            let mut pairs: Vec<(&String, &u32)> = self.vocab.iter().collect();
373            pairs.sort_by_key(|(_, &id)| id);
374            for (i, (tok, id)) in pairs.iter().enumerate() {
375                let escaped = escape_json_string(tok);
376                if i + 1 < pairs.len() {
377                    write!(f, "\"{}\": {}, ", escaped, id)
378                        .map_err(|e| TextError::IoError(e.to_string()))?;
379                } else {
380                    write!(f, "\"{}\": {}", escaped, id)
381                        .map_err(|e| TextError::IoError(e.to_string()))?;
382                }
383            }
384            writeln!(f, "}}").map_err(|e| TextError::IoError(e.to_string()))?;
385        }
386
387        // Write merges
388        {
389            let mut f = std::fs::File::create(merges_path)
390                .map_err(|e| TextError::IoError(e.to_string()))?;
391            writeln!(f, "#version: 0.2").map_err(|e| TextError::IoError(e.to_string()))?;
392            for (left, right) in &self.merges {
393                writeln!(f, "{} {}", left, right).map_err(|e| TextError::IoError(e.to_string()))?;
394            }
395        }
396        Ok(())
397    }
398
399    /// Load a tokenizer from a HuggingFace-format vocab JSON and merges text file.
400    pub fn load(vocab_path: &str, merges_path: &str) -> Result<Self> {
401        // Parse vocab JSON (minimal, no external dep)
402        let vocab_content =
403            std::fs::read_to_string(vocab_path).map_err(|e| TextError::IoError(e.to_string()))?;
404        let vocab = parse_vocab_json(&vocab_content)?;
405
406        // Build id_to_token
407        let max_id = vocab.values().copied().max().unwrap_or(0) as usize;
408        let mut id_to_token = vec![String::new(); max_id + 1];
409        for (tok, &id) in &vocab {
410            if let Some(slot) = id_to_token.get_mut(id as usize) {
411                *slot = tok.clone();
412            }
413        }
414
415        // Parse merges
416        let merges_file =
417            std::fs::File::open(merges_path).map_err(|e| TextError::IoError(e.to_string()))?;
418        let reader = BufReader::new(merges_file);
419        let mut merges = Vec::new();
420        for line in reader.lines() {
421            let line = line.map_err(|e| TextError::IoError(e.to_string()))?;
422            let line = line.trim();
423            if line.is_empty() || line.starts_with('#') {
424                continue;
425            }
426            let parts: Vec<&str> = line.splitn(2, ' ').collect();
427            if parts.len() == 2 {
428                merges.push((parts[0].to_string(), parts[1].to_string()));
429            }
430        }
431
432        let byte_encoder = bytes_to_unicode();
433        let byte_decoder: HashMap<char, u8> = byte_encoder.iter().map(|(&b, &c)| (c, b)).collect();
434
435        Ok(ByteLevelBpeTokenizer {
436            vocab,
437            id_to_token,
438            merges,
439            byte_encoder,
440            byte_decoder,
441        })
442    }
443
444    /// Return the vocabulary size.
445    pub fn vocab_size(&self) -> usize {
446        self.vocab.len()
447    }
448
449    /// Look up the token string for an ID.
450    pub fn id_to_token(&self, id: u32) -> Option<&str> {
451        self.id_to_token.get(id as usize).map(|s| s.as_str())
452    }
453
454    /// Look up the ID for a token string.
455    pub fn token_to_id(&self, token: &str) -> Option<u32> {
456        self.vocab.get(token).copied()
457    }
458}
459
460// ─── Helpers ─────────────────────────────────────────────────────────────────
461
/// Escape `s` so it can be written between double quotes in a JSON document.
///
/// The escapes are:
/// - `"` and `\` are backslash-escaped;
/// - newline, carriage return and tab use their short forms (`\n`, `\r`, `\t`);
/// - any other control character below U+0020 becomes `\u00XX`.
///
/// All other characters are copied through unchanged.
fn escape_json_string(s: &str) -> String {
    s.chars().fold(String::with_capacity(s.len()), |mut escaped, c| {
        match c {
            '"' => escaped.push_str("\\\""),
            '\\' => escaped.push_str("\\\\"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            ctrl if u32::from(ctrl) < 0x20 => {
                escaped.push_str(&format!("\\u{:04x}", u32::from(ctrl)));
            }
            other => escaped.push(other),
        }
        escaped
    })
}
480
481/// Minimal JSON object parser that only handles `{"key": number, ...}`.
482///
483/// Uses a string-aware comma splitter so that tokens containing `"` or `,`
484/// are handled correctly.
485fn parse_vocab_json(s: &str) -> Result<HashMap<String, u32>> {
486    let s = s.trim();
487    let inner = s
488        .strip_prefix('{')
489        .and_then(|s| s.strip_suffix('}'))
490        .ok_or_else(|| TextError::IoError("Invalid vocab JSON: missing braces".to_string()))?;
491
492    let mut vocab = HashMap::new();
493    // Split on `,` only when outside a JSON string (tracks in-string state).
494    let chars: Vec<char> = inner.chars().collect();
495    let n = chars.len();
496    let mut i = 0;
497    let mut start = 0;
498
499    while i <= n {
500        let at_end = i == n;
501
502        if at_end {
503            // Flush the final entry
504            let entry: String = chars[start..i].iter().collect();
505            let entry = entry.trim();
506            if !entry.is_empty() {
507                parse_vocab_entry(entry, &mut vocab)?;
508            }
509            break;
510        }
511
512        let ch = chars[i];
513
514        if ch == '"' {
515            // Skip over the whole quoted string including any escaped characters
516            i += 1;
517            while i < n {
518                let sc = chars[i];
519                i += 1;
520                if sc == '\\' {
521                    // skip the escaped character
522                    i += 1;
523                } else if sc == '"' {
524                    break;
525                }
526            }
527            // After closing quote, continue the outer loop without incrementing i
528            continue;
529        }
530
531        if ch == ',' {
532            let entry: String = chars[start..i].iter().collect();
533            let entry = entry.trim();
534            if !entry.is_empty() {
535                parse_vocab_entry(entry, &mut vocab)?;
536            }
537            start = i + 1;
538        }
539
540        i += 1;
541    }
542
543    Ok(vocab)
544}
545
546fn parse_vocab_entry(entry: &str, vocab: &mut HashMap<String, u32>) -> Result<()> {
547    // Format: `"token": id`
548    let colon_pos = find_colon_outside_string(entry)
549        .ok_or_else(|| TextError::IoError(format!("Invalid vocab entry (no colon): {}", entry)))?;
550    let key_part = entry[..colon_pos].trim();
551    let val_part = entry[colon_pos + 1..].trim();
552
553    let key = key_part
554        .strip_prefix('"')
555        .and_then(|s| s.strip_suffix('"'))
556        .map(unescape_json_string)
557        .ok_or_else(|| TextError::IoError(format!("Invalid vocab key: {}", key_part)))?;
558
559    let id: u32 = val_part
560        .parse()
561        .map_err(|_| TextError::IoError(format!("Invalid vocab id: {}", val_part)))?;
562
563    vocab.insert(key, id);
564    Ok(())
565}
566
/// Return the byte offset of the first `:` that is not inside a double-quoted
/// string, or `None` if there is none.
///
/// Within a string, a backslash escapes the following character, so `\"` does
/// not close it. Outside a string, a backslash has no special meaning.
fn find_colon_outside_string(s: &str) -> Option<usize> {
    let mut inside = false;
    let mut skip_next = false;
    s.char_indices().find_map(|(pos, c)| {
        if skip_next {
            skip_next = false;
        } else if inside && c == '\\' {
            skip_next = true;
        } else if c == '"' {
            inside = !inside;
        } else if c == ':' && !inside {
            return Some(pos);
        }
        None
    })
}
589
/// Decode the JSON escape sequences in the body of a JSON string (without the
/// surrounding quotes).
///
/// Handles:
/// - the short escapes `\"`, `\\`, `\/`, `\b`, `\f`, `\n`, `\r` and `\t`;
/// - `\uXXXX`, including UTF-16 surrogate pairs (`\ud83d\ude00` → 😀).
///
/// Best-effort on malformed input:
/// - a lone surrogate or an invalid `\u` sequence is dropped;
/// - an unknown escape `\x` yields `x`;
/// - a trailing lone backslash is ignored.
fn unescape_json_string(s: &str) -> String {
    /// Read up to four hex digits and parse them.
    fn read_hex4(it: &mut std::str::Chars<'_>) -> Option<u32> {
        let hex: String = it.by_ref().take(4).collect();
        u32::from_str_radix(&hex, 16).ok()
    }

    let mut out = String::with_capacity(s.len());
    let mut chars = s.chars();
    while let Some(ch) = chars.next() {
        if ch != '\\' {
            out.push(ch);
            continue;
        }
        match chars.next() {
            Some('"') => out.push('"'),
            Some('\\') => out.push('\\'),
            Some('/') => out.push('/'),
            Some('b') => out.push('\u{0008}'),
            Some('f') => out.push('\u{000C}'),
            Some('n') => out.push('\n'),
            Some('r') => out.push('\r'),
            Some('t') => out.push('\t'),
            Some('u') => {
                if let Some(high) = read_hex4(&mut chars) {
                    let code = if (0xD800..0xDC00).contains(&high) {
                        // High surrogate: combine with a following `\uDC00`-`\uDFFF`
                        // low surrogate if present. Otherwise leave the input
                        // untouched so the next escape is decoded on its own.
                        let mut ahead = chars.clone();
                        match (ahead.next(), ahead.next(), read_hex4(&mut ahead)) {
                            (Some('\\'), Some('u'), Some(low)) if (0xDC00..0xE000).contains(&low) => {
                                chars = ahead;
                                0x10000 + ((high - 0xD800) << 10) + (low - 0xDC00)
                            }
                            _ => high,
                        }
                    } else {
                        high
                    };
                    if let Some(c) = char::from_u32(code) {
                        out.push(c);
                    }
                }
            }
            Some(c) => out.push(c),
            None => {}
        }
    }
    out
}
619
620// ─── Tests ───────────────────────────────────────────────────────────────────
621
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bytes_to_unicode_count() {
        let map = bytes_to_unicode();
        assert_eq!(map.len(), 256, "should have exactly 256 entries");
    }

    #[test]
    fn test_bytes_to_unicode_bijective() {
        let map = bytes_to_unicode();
        let mut chars: Vec<char> = map.values().copied().collect();
        chars.sort();
        chars.dedup();
        assert_eq!(
            chars.len(),
            256,
            "all unicode chars must be distinct (bijection)"
        );
    }

    #[test]
    fn test_bytes_to_unicode_ascii_identity() {
        let map = bytes_to_unicode();
        // printable ASCII chars should map to themselves
        for b in b'!'..=b'~' {
            let ch = map[&b];
            assert_eq!(
                ch as u32, b as u32,
                "byte {} should map to itself, got {}",
                b, ch as u32
            );
        }
    }

    #[test]
    fn test_train_vocab_size() {
        let texts = [
            "the quick brown fox jumps over the lazy dog",
            "hello world hello rust hello tokenizer",
            "byte level bpe tokenizer training test data for vocabulary",
            "more text data to train the byte level bpe model properly",
        ];
        let config = ByteLevelBpeConfig {
            vocab_size: 300,
            min_frequency: 1,
            add_prefix_space: true,
        };
        let tok = ByteLevelBpeTokenizer::train(&texts, config);
        assert!(
            tok.vocab_size() <= 300,
            "vocab size should not exceed requested"
        );
        assert!(
            tok.vocab_size() >= 256,
            "should have at least base 256 tokens"
        );
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let texts = [
            "hello world",
            "the quick brown fox",
            "rust programming language",
            "byte level encoding test",
        ];
        let config = ByteLevelBpeConfig {
            vocab_size: 500,
            min_frequency: 1,
            add_prefix_space: true,
        };
        let tok = ByteLevelBpeTokenizer::train(&texts, config);
        let input = "hello world";
        let ids = tok.encode(input);
        let decoded = tok.decode(&ids);
        assert_eq!(decoded, input, "encode/decode roundtrip should be lossless");
    }

    #[test]
    fn test_gword_prefix() {
        // Non-first words should be prefixed with Ġ (U+0120)
        let texts = ["hello world test"];
        let config = ByteLevelBpeConfig {
            vocab_size: 300,
            min_frequency: 1,
            add_prefix_space: true,
        };
        let tok = ByteLevelBpeTokenizer::train(&texts, config);
        // The tokenizer vocab should contain a Ġ-prefixed token
        let has_g_prefix = tok.vocab.keys().any(|k| k.starts_with('\u{0120}'));
        assert!(has_g_prefix, "vocabulary should contain Ġ-prefixed tokens");
    }

    #[test]
    fn test_hello_token() {
        let texts = ["hello world hello hello hello"];
        let config = ByteLevelBpeConfig {
            vocab_size: 300,
            min_frequency: 1,
            add_prefix_space: false,
        };
        let tok = ByteLevelBpeTokenizer::train(&texts, config);
        // After training, "hello" should appear as a merged token
        // (since it's frequent enough)
        assert!(
            tok.vocab.contains_key("hello"),
            "hello should be in vocabulary after training on repeated hello"
        );
    }

    #[test]
    fn test_save_load_roundtrip() {
        let texts = [
            "hello world",
            "test tokenizer save load",
            "byte level bpe tokenizer",
        ];
        let config = ByteLevelBpeConfig {
            vocab_size: 350,
            min_frequency: 1,
            add_prefix_space: true,
        };
        let tok = ByteLevelBpeTokenizer::train(&texts, config);

        // Process-unique file names, so concurrent test runs sharing the temp
        // directory cannot clobber each other's files.
        let dir = std::env::temp_dir();
        let pid = std::process::id();
        let vocab_path = dir
            .join(format!("test_bpe_vocab_{}.json", pid))
            .to_string_lossy()
            .into_owned();
        let merges_path = dir
            .join(format!("test_bpe_merges_{}.txt", pid))
            .to_string_lossy()
            .into_owned();

        let loaded = tok
            .save_vocab(&vocab_path, &merges_path)
            .and_then(|_| ByteLevelBpeTokenizer::load(&vocab_path, &merges_path));
        // Clean up before asserting so a failure does not leave files behind.
        let _ = std::fs::remove_file(&vocab_path);
        let _ = std::fs::remove_file(&merges_path);
        let loaded = loaded.expect("save/load failed");

        assert_eq!(tok.vocab_size(), loaded.vocab_size());
        assert_eq!(tok.merges.len(), loaded.merges.len());
        // Token strings (including the `"` and `\` base tokens) must survive
        // JSON escaping and parsing unchanged.
        assert_eq!(tok.vocab, loaded.vocab);
        assert_eq!(tok.merges, loaded.merges);
        assert_eq!(tok.encode("hello world"), loaded.encode("hello world"));
    }
}