// smoltok_core/simple/config.rs
use super::SimpleBPETokenizer;
use crate::Deserializable;
use crate::tokenizer::{
    MIN_VOCAB_SIZE, MergeRule, TokenId, Trainable, VocabSizeTooSmall, get_most_common_pair, get_pair_counts, merge,
    parse_merges, string_to_token_ids, verify_stok_extension,
};
use std::collections::HashMap;
use std::convert::Infallible;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;

/// Configuration for a [`SimpleBPETokenizer`]: the target vocabulary size
/// (base vocabulary plus the number of merge rules to learn).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SimpleBPETokenizerConfig {
    // Invariant: always >= MIN_VOCAB_SIZE — enforced by `build`.
    vocab_size: u32,
}
23impl SimpleBPETokenizerConfig {
24 pub fn build(vocab_size: u32) -> Result<Self, VocabSizeTooSmall> {
26 VocabSizeTooSmall::check(vocab_size)?;
27 Ok(SimpleBPETokenizerConfig { vocab_size })
28 }
29
30 pub fn from_merges(merges: u32) -> Self {
32 Self::build(MIN_VOCAB_SIZE + merges).unwrap()
33 }
34}
36impl Trainable for SimpleBPETokenizerConfig {
37 type Output = SimpleBPETokenizer;
38 type TrainingError = Infallible;
39
40 fn train(&self, dataset: &str) -> Result<SimpleBPETokenizer, Infallible> {
41 let mut tokens = string_to_token_ids(dataset);
42 let n_iterations = self.vocab_size - MIN_VOCAB_SIZE;
43 let mut merges: Vec<MergeRule> = Vec::with_capacity(n_iterations as usize);
44
45 for i in 0..n_iterations {
46 let mut counts = HashMap::new();
47 get_pair_counts(tokens.as_slice(), &mut counts);
48 let Some(most_common_pair) = get_most_common_pair(&counts) else {
49 break;
50 };
51
52 let rule = most_common_pair.with_new_id(TokenId::for_new_token(i));
53
54 merges.push(rule);
55 merge(&mut tokens, rule);
56 }
57
58 Ok(SimpleBPETokenizer::from_merges(merges))
59 }
60}
62impl Deserializable for SimpleBPETokenizerConfig {
63 type Output = SimpleBPETokenizer;
64
65 fn load(&self, path: &Path) -> Result<Self::Output, std::io::Error> {
66 verify_stok_extension(path)?;
67
68 let file = File::open(path)?;
69 let reader = BufReader::new(file);
70
71 Ok(SimpleBPETokenizer::from_merges(parse_merges(
72 reader.lines().enumerate().map(|(i, line)| (i + 1, line)),
73 )?))
74 }
75}