// smoltok_core/simple/tokenizer.rs
use crate::tokenizer::{
    MergeRule, Serializable, TokenId, TokenPair, Tokenizer, UnknownTokenId, bpe_encode, build_merge_lookup,
    build_vocab, save_merges, save_vocab, verify_stok_extension,
};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufWriter, Error, Write};
use std::path::Path;
12
13#[derive(Debug)]
15pub struct SimpleBPETokenizer {
16 merges: Vec<MergeRule>,
18 merge_lookup: HashMap<TokenPair, (usize, TokenId)>,
20 vocab: HashMap<TokenId, Vec<u8>>,
22}
23
24impl SimpleBPETokenizer {
25 pub fn from_merges(merges: Vec<MergeRule>) -> Self {
33 Self {
34 vocab: build_vocab(merges.as_slice()),
35 merge_lookup: build_merge_lookup(merges.as_slice()),
36 merges,
37 }
38 }
39}
40
41impl Tokenizer for SimpleBPETokenizer {
42 type DecodingError = UnknownTokenId;
43
44 fn merges(&self) -> &[MergeRule] {
45 self.merges.as_slice()
46 }
47
48 fn vocab(&self) -> &HashMap<TokenId, Vec<u8>> {
49 &self.vocab
50 }
51
52 fn encode(&self, text: &str) -> Vec<TokenId> {
53 bpe_encode(text, &self.merge_lookup)
54 }
55
56 fn decode(&self, tokens: &[TokenId]) -> Result<String, UnknownTokenId> {
57 let mut bytes = Vec::new();
58
59 for token_id in tokens {
60 if let Some(token_bytes) = self.vocab.get(token_id) {
61 bytes.extend_from_slice(token_bytes);
62 } else {
63 return Err(UnknownTokenId(*token_id));
64 }
65 }
66
67 Ok(String::from_utf8_lossy(&bytes).into_owned())
68 }
69}
70
71impl Serializable for SimpleBPETokenizer {
72 fn save(&self, path: &Path) -> Result<(), Error> {
73 verify_stok_extension(path)?;
74
75 let file = File::create(path)?;
76 let mut writer = BufWriter::new(file);
77
78 save_merges(&mut writer, self.merges.as_slice())?;
79 save_vocab(&path.with_extension("vocab"), self.merges.as_slice(), self.vocab())?;
80
81 Ok(())
82 }
83}