smoltok_core/tokenizer/traits.rs

use super::{MergeRule, TokenId};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
use std::str::FromStr;

/// A trait defining the behavior of a tokenizer for encoding and decoding.
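///
/// # Example
///
/// A minimal sketch of the encode/decode round trip, assuming some
/// `MyTokenizer` type implementing this trait (hypothetical, for
/// illustration only):
///
/// ```ignore
/// let ids = tokenizer.encode("hello world");
/// let text = tokenizer.decode(&ids)?;
/// assert_eq!(text, "hello world");
/// ```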
pub trait Tokenizer {
    /// The error type returned when decoding fails.
    type DecodingError: std::error::Error;

    /// Returns the merge rules learned during training.
    fn merges(&self) -> &[MergeRule];

    /// Returns the number of merge rules learned during training.
    fn num_merges(&self) -> usize {
        self.merges().len()
    }

    /// Returns the vocabulary mapping each token ID to its byte sequence.
    fn vocab(&self) -> &HashMap<TokenId, Vec<u8>>;

    /// Encodes a string into a sequence of token IDs.
    ///
    /// # Arguments
    ///
    /// * `text` - The string to encode.
    ///
    /// # Returns
    ///
    /// * A vector of token IDs.
    fn encode(&self, text: &str) -> Vec<TokenId>;

    /// Decodes a sequence of token IDs back into a string.
    ///
    /// # Arguments
    ///
    /// * `tokens` - The sequence of token IDs.
    ///
    /// # Returns
    ///
    /// * `Ok(String)` containing the decoded string.
    /// * `Err(DecodingError)` if decoding fails.
    fn decode(&self, tokens: &[TokenId]) -> Result<String, Self::DecodingError>;
}

/// A trait for training a tokenizer from a dataset.
///
/// This trait is separate from `Tokenizer` so that configuration types can be
/// trainable while the resulting tokenizer handles encoding and decoding.
pub trait Trainable {
    /// The tokenizer type produced by training.
    type Output: Tokenizer;
    /// The error type returned when training fails.
    type TrainingError: std::error::Error;

    // TODO: support more dataset kinds, like an iterator of strings, a folder of files, etc.

    /// Trains a tokenizer on a given dataset to a target vocabulary size.
    ///
    /// # Arguments
    ///
    /// * `dataset` - The training data as a string.
    ///
    /// # Returns
    ///
    /// * `Ok(Self::Output)` if the tokenizer was trained successfully.
    /// * `Err(TrainingError)` if training fails.
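    ///
    /// # Example
    ///
    /// A sketch of the intended flow, assuming some `BpeConfig` type that
    /// implements this trait (hypothetical, for illustration only):
    ///
    /// ```ignore
    /// let tokenizer = config.train("the training corpus as one string")?;
    /// let ids = tokenizer.encode("hello");
    /// ```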
    fn train(&self, dataset: &str) -> Result<Self::Output, Self::TrainingError>;
}

/// Verifies that `path` ends in a `.stok` extension, returning an
/// `InvalidInput` error otherwise.
pub fn verify_stok_extension(path: &Path) -> Result<(), std::io::Error> {
    match path.extension().and_then(|s| s.to_str()) {
        Some("stok") => Ok(()),
        _ => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "File must have .stok extension",
        )),
    }
}

/// Writes the merge rules to `writer`, one rule per line in its `Display` format.
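///
/// # Example
///
/// A sketch of writing merges to an in-memory buffer (`tokenizer` assumed
/// to implement `Tokenizer`):
///
/// ```ignore
/// let mut buffer: Vec<u8> = Vec::new();
/// save_merges(&mut buffer, tokenizer.merges())?;
/// ```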
pub fn save_merges(writer: &mut dyn Write, merges: &[MergeRule]) -> Result<(), std::io::Error> {
    for merge_rule in merges {
        writeln!(writer, "{}", merge_rule)?;
    }

    writer.flush()?;

    Ok(())
}

/// Saves the vocabulary to `path` in a human-readable `token0 + token1: token2`
/// format, one line per merge rule.
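///
/// # Example output
///
/// Assuming rules that merged `t` + `h` and then `th` + `e` (illustrative values):
///
/// ```text
/// t + h: th
/// th + e: the
/// ```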
pub fn save_vocab(path: &Path, merges: &[MergeRule], vocab: &HashMap<TokenId, Vec<u8>>) -> Result<(), std::io::Error> {
    let file = File::create(path)?;
    let mut writer = BufWriter::new(file);

    // Every ID referenced by a merge rule should be present in the vocab;
    // report a malformed tokenizer as an error instead of panicking.
    let missing = || {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "merge rule references a token ID missing from the vocab",
        )
    };

    for merge_rule in merges {
        let first_bytes = vocab.get(&merge_rule.pair().first()).ok_or_else(missing)?;
        let second_bytes = vocab.get(&merge_rule.pair().second()).ok_or_else(missing)?;
        let new_bytes = vocab.get(&merge_rule.new_id()).ok_or_else(missing)?;

        writeln!(
            writer,
            "{} + {}: {}",
            String::from_utf8_lossy(first_bytes.as_slice()),
            String::from_utf8_lossy(second_bytes.as_slice()),
            String::from_utf8_lossy(new_bytes.as_slice())
        )?;
    }

    writer.flush()?;
    Ok(())
}

/// Parses merge rules from an iterator of lines.
///
/// # Arguments
///
/// * `lines` - An iterator yielding `(line_index, line_content)` tuples, as
///   produced by `lines().enumerate()`. Indices are zero-based; error messages
///   report them one-based.
///
/// # Returns
///
/// * `Ok(Vec<MergeRule>)` containing the parsed merge rules.
/// * `Err(std::io::Error)` if reading fails or any line contains invalid merge rule syntax.
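///
/// # Example
///
/// A sketch of feeding the function lines from a buffered reader (the file
/// name is illustrative):
///
/// ```ignore
/// use std::io::{BufRead, BufReader};
///
/// let reader = BufReader::new(std::fs::File::open("model.stok")?);
/// let merges = parse_merges(reader.lines().enumerate())?;
/// ```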
pub fn parse_merges<I>(lines: I) -> Result<Vec<MergeRule>, std::io::Error>
where
    I: Iterator<Item = (usize, Result<String, std::io::Error>)>,
{
    let mut merges = Vec::new();
    for (line_num, line) in lines {
        let line = line?;

        if line.trim().is_empty() {
            continue; // skip empty lines
        }

        let merge_rule = MergeRule::from_str(&line).map_err(|e| {
            std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Line {}: {}", line_num + 1, e))
        })?;

        merges.push(merge_rule);
    }

    Ok(merges)
}

/// A trait for serializing a tokenizer to a file.
///
/// Implementations save the tokenizer's merge rules in the `.stok` file format.
pub trait Serializable: Tokenizer {
    /// Saves the tokenizer's merge rules to a `.stok` file and vocabulary to a `.vocab` file.
    ///
    /// # Arguments
    ///
    /// * `path` - The path to save the tokenizer. Must have a `.stok` extension.
    ///
    /// # Returns
    ///
    /// * `Ok(())` if the tokenizer was saved successfully.
    /// * `Err(std::io::Error)` if the file extension is invalid or writing fails.
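    ///
    /// # Example
    ///
    /// A sketch, assuming `tokenizer` implements this trait:
    ///
    /// ```ignore
    /// tokenizer.save(Path::new("model.stok"))?;
    /// ```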
    fn save(&self, path: &Path) -> Result<(), std::io::Error>;
}

/// A trait for deserializing a tokenizer from a file.
///
/// This trait provides functionality for loading tokenizer merge rules from the `.stok` file format.
pub trait Deserializable {
    /// The tokenizer type produced by loading.
    type Output: Tokenizer;

    /// Loads a tokenizer from a file.
    ///
    /// # Arguments
    ///
    /// * `path` - The path to load the tokenizer from. Must have a `.stok` extension.
    ///
    /// # Returns
    ///
    /// * `Ok(Self::Output)` if the tokenizer was loaded successfully.
    /// * `Err(std::io::Error)` if the file extension is invalid or reading fails.
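    ///
    /// # Example
    ///
    /// A sketch, assuming some `loader` value whose type implements this
    /// trait (hypothetical, for illustration only):
    ///
    /// ```ignore
    /// let tokenizer = loader.load(Path::new("model.stok"))?;
    /// let ids = tokenizer.encode("hello");
    /// ```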
    fn load(&self, path: &Path) -> Result<Self::Output, std::io::Error>;
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_verify_stok_extension_valid() {
        let path = Path::new("model.stok");
        assert!(verify_stok_extension(path).is_ok());
    }

    #[test]
    fn test_verify_stok_extension_double_extension() {
        let path = Path::new("model.backup.stok");
        assert!(verify_stok_extension(path).is_ok());
    }

    #[test]
    fn test_verify_stok_extension_valid_with_path() {
        let path = Path::new("/some/path/to/model.stok");
        assert!(verify_stok_extension(path).is_ok());
    }

    #[test]
    fn test_verify_stok_extension_invalid() {
        // Covers a wrong extension, another wrong extension, and no extension at all.
        let path = Path::new("model.txt");
        let result = verify_stok_extension(path);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().to_string(), "File must have .stok extension");

        let path = Path::new("model.json");
        let result = verify_stok_extension(path);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().to_string(), "File must have .stok extension");

        let path = Path::new("model");
        let result = verify_stok_extension(path);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().to_string(), "File must have .stok extension");
    }
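
    // parse_merges: blank lines are skipped, and reader errors propagate.
    #[test]
    fn test_parse_merges_skips_blank_lines() {
        let lines: Vec<(usize, Result<String, std::io::Error>)> = vec![
            (0, Ok(String::new())),
            (1, Ok("   ".to_string())),
        ];
        let merges = parse_merges(lines.into_iter()).unwrap();
        assert!(merges.is_empty());
    }

    #[test]
    fn test_parse_merges_propagates_io_errors() {
        let lines = vec![(
            0usize,
            Err::<String, _>(std::io::Error::new(std::io::ErrorKind::Other, "boom")),
        )];
        assert!(parse_merges(lines.into_iter()).is_err());
    }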
}