smoltok_core/simple/
config.rs

1//! Configuration for the simple BPE tokenizer.
2
3use super::SimpleBPETokenizer;
4use crate::Deserializable;
5use crate::tokenizer::{
6    MIN_VOCAB_SIZE, MergeRule, TokenId, Trainable, VocabSizeTooSmall, get_most_common_pair, get_pair_counts, merge,
7    parse_merges, string_to_token_ids, verify_stok_extension,
8};
9use std::collections::HashMap;
10use std::convert::Infallible;
11use std::fs::File;
12use std::io::{BufRead, BufReader};
13use std::path::Path;
14
/// Configuration for training a simple BPE tokenizer.
///
/// This struct implements [`Trainable`] and produces a [`SimpleBPETokenizer`].
#[derive(Debug)]
pub struct SimpleBPETokenizerConfig {
    // Target total vocabulary size (base tokens + learned merges).
    // Invariant: >= MIN_VOCAB_SIZE, enforced by the `build` constructor.
    vocab_size: u32,
}
22
23impl SimpleBPETokenizerConfig {
24    /// Create a new configuration for training a simple BPE tokenizer from vocab size.
25    pub fn build(vocab_size: u32) -> Result<Self, VocabSizeTooSmall> {
26        VocabSizeTooSmall::check(vocab_size)?;
27        Ok(SimpleBPETokenizerConfig { vocab_size })
28    }
29
30    /// Create a new configuration for training a simple BPE tokenizer from number of merges.
31    pub fn from_merges(merges: u32) -> Self {
32        Self::build(MIN_VOCAB_SIZE + merges).unwrap()
33    }
34}
35
36impl Trainable for SimpleBPETokenizerConfig {
37    type Output = SimpleBPETokenizer;
38    type TrainingError = Infallible;
39
40    fn train(&self, dataset: &str) -> Result<SimpleBPETokenizer, Infallible> {
41        let mut tokens = string_to_token_ids(dataset);
42        let n_iterations = self.vocab_size - MIN_VOCAB_SIZE;
43        let mut merges: Vec<MergeRule> = Vec::with_capacity(n_iterations as usize);
44
45        for i in 0..n_iterations {
46            let mut counts = HashMap::new();
47            get_pair_counts(tokens.as_slice(), &mut counts);
48            let Some(most_common_pair) = get_most_common_pair(&counts) else {
49                break;
50            };
51
52            let rule = most_common_pair.with_new_id(TokenId::for_new_token(i));
53
54            merges.push(rule);
55            merge(&mut tokens, rule);
56        }
57
58        Ok(SimpleBPETokenizer::from_merges(merges))
59    }
60}
61
62impl Deserializable for SimpleBPETokenizerConfig {
63    type Output = SimpleBPETokenizer;
64
65    fn load(&self, path: &Path) -> Result<Self::Output, std::io::Error> {
66        verify_stok_extension(path)?;
67
68        let file = File::open(path)?;
69        let reader = BufReader::new(file);
70
71        Ok(SimpleBPETokenizer::from_merges(parse_merges(
72            reader.lines().enumerate().map(|(i, line)| (i + 1, line)),
73        )?))
74    }
75}