smoltok_core/regex/config_parallel.rs

use crate::regex::{GPT4_SPLIT_PATTERN, parse_pattern};
use crate::tokenizer::{
    MIN_VOCAB_SIZE, VocabSizeTooSmall, get_pair_counts, parse_merges, string_to_token_ids, train_bpe,
    verify_stok_extension,
};
use crate::{
    Deserializable, ParallelRegexBPETokenizer, RegexBPETokenizerConfigError, RegexCompilationError, TokenId, Trainable,
    regex,
};
use fancy_regex::Regex;
use rayon::prelude::*;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader, Error};
use std::path::Path;

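/// Configuration for a [`ParallelRegexBPETokenizer`]: the target vocabulary
/// size plus the regex split pattern, kept both as the original string (handed
/// to the trained tokenizer) and pre-compiled for chunking.
///
/// A minimal usage sketch; the vocabulary size is illustrative and only needs
/// to clear `MIN_VOCAB_SIZE`:
///
/// ```ignore
/// let config = ParallelRegexBPETokenizerConfig::build(512, None)?;
/// let tokenizer = config.train(corpus)?;
/// ```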
#[derive(Debug)]
pub struct ParallelRegexBPETokenizerConfig {
    vocab_size: u32,
    pattern: String,
    compiled_pattern: Regex,
}

impl ParallelRegexBPETokenizerConfig {
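    /// Checks that `vocab_size` is large enough, then compiles the split
    /// pattern, defaulting to `GPT4_SPLIT_PATTERN` when none is given.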
    pub fn build(vocab_size: u32, pattern: Option<&str>) -> Result<Self, RegexBPETokenizerConfigError> {
        VocabSizeTooSmall::check(vocab_size)?;

        let pattern = pattern.unwrap_or(GPT4_SPLIT_PATTERN).to_string();
        let compiled_pattern = Regex::new(pattern.as_str()).map_err(|e| RegexCompilationError(e.to_string()))?;

        Ok(ParallelRegexBPETokenizerConfig {
            vocab_size,
            pattern,
            compiled_pattern,
        })
    }

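    /// Sizes the vocabulary as `MIN_VOCAB_SIZE + merges`: the base vocabulary
    /// plus one new token per merge.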
    pub fn from_merges(merges: u32, pattern: Option<&str>) -> Result<Self, RegexBPETokenizerConfigError> {
        Self::build(MIN_VOCAB_SIZE + merges, pattern)
    }
}

impl Trainable for ParallelRegexBPETokenizerConfig {
    type Output = ParallelRegexBPETokenizer;
    type TrainingError = std::convert::Infallible;

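    /// Trains a tokenizer on `dataset`, parallelizing the pair-frequency count
    /// across the regex-split chunks with rayon.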
    fn train(&self, dataset: &str) -> Result<ParallelRegexBPETokenizer, Self::TrainingError> {
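        // Split the corpus with the regex and map each chunk to its initial
        // token ids; pairs are only counted within a chunk, so merges never
        // cross chunk boundaries.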
        let dataset_chunks = regex::split_text(dataset, &self.compiled_pattern);

        let mut chunks_tokens: Vec<Vec<TokenId>> =
            dataset_chunks.iter().map(|chunk| string_to_token_ids(chunk)).collect();

        let n_iterations = self.vocab_size - MIN_VOCAB_SIZE;
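        // Pair counting is the hot path: each rayon worker folds its share of
        // chunks into a thread-local HashMap, then the per-thread maps are
        // reduced into a single combined count table.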
        let merges = train_bpe(&mut chunks_tokens, n_iterations, |chunks| {
            chunks
                .par_iter()
                .fold(HashMap::new, |mut thread_map, tokens| {
                    get_pair_counts(tokens.as_slice(), &mut thread_map);
                    thread_map
                })
                .reduce(HashMap::new, |mut combined, thread_map| {
                    for (pair, count) in thread_map {
                        *combined.entry(pair).or_insert(0) += count;
                    }
                    combined
                })
        });

        Ok(ParallelRegexBPETokenizer::new(merges, self.pattern.clone()))
    }
}

impl Deserializable for ParallelRegexBPETokenizerConfig {
    type Output = ParallelRegexBPETokenizer;

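    /// Loads a serialized tokenizer from a `.stok` file: the first line holds
    /// the split pattern, each following line one merge rule.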
    fn load(&self, path: &Path) -> Result<Self::Output, Error> {
        verify_stok_extension(path)?;

        let file = File::open(path)?;
        let reader = BufReader::new(file);
        let mut lines = reader.lines();

        let first_line = lines
            .next()
            .ok_or_else(|| Error::new(std::io::ErrorKind::InvalidData, "File is empty"))??;

        let pattern = parse_pattern(first_line)?;
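        // `enumerate` starts at 0 and the pattern consumed file line 1, so the
        // first merge sits on file line 2; the `i + 2` offset keeps the line
        // numbers passed to `parse_merges` aligned with the file.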
        let merges = parse_merges(lines.enumerate().map(|(i, line)| (i + 2, line)))?;

        Ok(ParallelRegexBPETokenizer::new(merges, pattern))
    }
}