smoltok_core/regex/config_parallel.rs

use crate::regex::{GPT4_SPLIT_PATTERN, parse_pattern};
use crate::tokenizer::{
    MIN_VOCAB_SIZE, VocabSizeTooSmall, get_pair_counts, parse_merges, string_to_token_ids, train_bpe,
    verify_stok_extension,
};
use crate::{
    Deserializable, ParallelRegexBPETokenizer, RegexBPETokenizerConfigError, RegexCompilationError, TokenId, Trainable,
    regex,
};
use fancy_regex::Regex;
use rayon::prelude::*;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader, Error};
use std::path::Path;

/// Configuration for training a regex-based BPE tokenizer in parallel with rayon.
///
/// This struct implements [`Trainable`] and produces a [`ParallelRegexBPETokenizer`].
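///
/// # Examples
///
/// ```no_run
/// // A minimal sketch of the build-then-train flow. It assumes these items are
/// // re-exported at the crate root; the corpus string is illustrative.
/// use smoltok_core::{ParallelRegexBPETokenizerConfig, Trainable};
///
/// let config = ParallelRegexBPETokenizerConfig::build(300, None).unwrap();
/// let tokenizer = config.train("low lower lowest").unwrap();
/// ```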
#[derive(Debug)]
pub struct ParallelRegexBPETokenizerConfig {
    vocab_size: u32,
    pattern: String,
    compiled_pattern: Regex,
}

impl ParallelRegexBPETokenizerConfig {
    /// Create a new configuration for training a regex BPE tokenizer.
    ///
    /// # Arguments
    ///
    /// * `vocab_size` - The desired vocabulary size (must be at least 256).
    /// * `pattern` - Optional custom regex pattern. If `None`, uses the GPT-4 split pattern.
    ///
    /// # Returns
    ///
    /// * `Ok(ParallelRegexBPETokenizerConfig)` if the configuration is valid.
    /// * `Err(RegexBPETokenizerConfigError)` if the vocab size is too small or the pattern is invalid.
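    ///
    /// # Examples
    ///
    /// ```no_run
    /// // Sketch only: assumes the config type is re-exported at the crate root.
    /// use smoltok_core::ParallelRegexBPETokenizerConfig;
    ///
    /// // The default GPT-4 split pattern with a 512-token vocabulary.
    /// let config = ParallelRegexBPETokenizerConfig::build(512, None).unwrap();
    ///
    /// // Anything below 256 cannot cover every byte value and is rejected.
    /// assert!(ParallelRegexBPETokenizerConfig::build(100, None).is_err());
    /// ```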
    pub fn build(vocab_size: u32, pattern: Option<&str>) -> Result<Self, RegexBPETokenizerConfigError> {
        VocabSizeTooSmall::check(vocab_size)?;

        let pattern = pattern.unwrap_or(GPT4_SPLIT_PATTERN).to_string();
        let compiled_pattern = Regex::new(pattern.as_str()).map_err(|e| RegexCompilationError(e.to_string()))?;

        Ok(ParallelRegexBPETokenizerConfig {
            vocab_size,
            pattern,
            compiled_pattern,
        })
    }

    /// Create a new configuration from the number of merges instead of a vocab size.
    ///
    /// The resulting vocabulary size is `MIN_VOCAB_SIZE + merges`.
    ///
    /// # Arguments
    ///
    /// * `merges` - The number of merge operations to perform.
    /// * `pattern` - Optional custom regex pattern. If `None`, uses the GPT-4 split pattern.
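    ///
    /// # Examples
    ///
    /// ```no_run
    /// // Sketch only: assumes the config type is re-exported at the crate root.
    /// use smoltok_core::ParallelRegexBPETokenizerConfig;
    ///
    /// // 50 merges on top of the 256 base byte tokens, i.e. a vocab size of 306.
    /// let config = ParallelRegexBPETokenizerConfig::from_merges(50, None).unwrap();
    /// ```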
    pub fn from_merges(merges: u32, pattern: Option<&str>) -> Result<Self, RegexBPETokenizerConfigError> {
        Self::build(MIN_VOCAB_SIZE + merges, pattern)
    }
}

impl Trainable for ParallelRegexBPETokenizerConfig {
    type Output = ParallelRegexBPETokenizer;
    type TrainingError = std::convert::Infallible;
    fn train(&self, dataset: &str) -> Result<ParallelRegexBPETokenizer, Self::TrainingError> {
        // Split the dataset with the regex so merges never cross chunk boundaries.
        let dataset_chunks = regex::split_text(dataset, &self.compiled_pattern);

        let mut chunks_tokens: Vec<Vec<TokenId>> =
            dataset_chunks.iter().map(|chunk| string_to_token_ids(chunk)).collect();

        let n_iterations = self.vocab_size - MIN_VOCAB_SIZE;
        let merges = train_bpe(&mut chunks_tokens, n_iterations, |chunks| {
            // Parallel pair counting with rayon: each worker folds its share of the
            // chunks into a thread-local map, and the per-thread maps are then
            // reduced into a single combined count map.
            chunks
                .par_iter()
                .fold(HashMap::new, |mut thread_map, tokens| {
                    get_pair_counts(tokens.as_slice(), &mut thread_map);
                    thread_map
                })
                .reduce(HashMap::new, |mut combined, thread_map| {
                    for (pair, count) in thread_map {
                        *combined.entry(pair).or_insert(0) += count;
                    }
                    combined
                })
        });

        Ok(ParallelRegexBPETokenizer::new(merges, self.pattern.clone()))
    }
}

impl Deserializable for ParallelRegexBPETokenizerConfig {
    type Output = ParallelRegexBPETokenizer;

    /// Loads a [`ParallelRegexBPETokenizer`] from a file.
    ///
    /// The file must contain a header line with the regex pattern,
    /// followed by merge rules (one per line).
    ///
    /// # Arguments
    ///
    /// * `path` - The path to load the tokenizer from. Must have a `.stok` extension.
    ///
    /// # Returns
    ///
    /// * `Ok(ParallelRegexBPETokenizer)` if the tokenizer was loaded successfully.
    /// * `Err(std::io::Error)` if the file extension is invalid, reading fails,
    ///   or the file format is invalid.
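    ///
    /// # Examples
    ///
    /// ```no_run
    /// // Sketch only: assumes the items are re-exported at the crate root and
    /// // that "tokenizer.stok" is an illustrative path to a previously saved file.
    /// use std::path::Path;
    /// use smoltok_core::{Deserializable, ParallelRegexBPETokenizerConfig};
    ///
    /// let config = ParallelRegexBPETokenizerConfig::build(512, None).unwrap();
    /// let tokenizer = config.load(Path::new("tokenizer.stok")).unwrap();
    /// ```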
    fn load(&self, path: &Path) -> Result<Self::Output, Error> {
        verify_stok_extension(path)?;

        let file = File::open(path)?;
        let reader = BufReader::new(file);
        let mut lines = reader.lines();

        let first_line = lines
            .next()
            .ok_or_else(|| Error::new(std::io::ErrorKind::InvalidData, "File is empty"))??;

        let pattern = parse_pattern(first_line)?;
        // `i` is zero-based and the pattern header occupies line 1, so `i + 2`
        // yields each merge rule's one-based line number in the file.
        let merges = parse_merges(lines.enumerate().map(|(i, line)| (i + 2, line)))?;

        Ok(ParallelRegexBPETokenizer::new(merges, pattern))
    }
}
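
// A minimal smoke-test sketch: the corpus and merge count are illustrative,
// chosen small so the test runs quickly.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn trains_on_a_small_corpus() {
        let config = ParallelRegexBPETokenizerConfig::from_merges(5, None).unwrap();
        // `TrainingError` is `Infallible`, so this unwrap cannot fail.
        let _tokenizer = config.train("aaabdaaabac aaabdaaabac").unwrap();
    }
}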