smoltok_core/tokenizer/traits.rs

use super::{MergeRule, TokenId};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::Path;
use std::str::FromStr;

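/// A trained tokenizer: exposes its learned merge rules and byte-level
/// vocabulary, and converts text to and from token ids.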
pub trait Tokenizer {
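    /// Error returned when a token sequence cannot be decoded back into text.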
    type DecodingError: std::error::Error;

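    /// The ordered list of learned merge rules.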
    fn merges(&self) -> &[MergeRule];

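    /// The number of learned merge rules.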
    fn num_merges(&self) -> usize {
        self.merges().len()
    }

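    /// The mapping from each token id to the byte sequence it represents.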
    fn vocab(&self) -> &HashMap<TokenId, Vec<u8>>;

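    /// Encodes `text` into a sequence of token ids.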
    fn encode(&self, text: &str) -> Vec<TokenId>;

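    /// Decodes `tokens` back into a string.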
    fn decode(&self, tokens: &[TokenId]) -> Result<String, Self::DecodingError>;
}

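/// Something that can train a [`Tokenizer`] from a text dataset.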
pub trait Trainable {
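    /// The tokenizer type produced by a successful training run.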
    type Output: Tokenizer;
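
    /// Error returned when training fails.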
    type TrainingError: std::error::Error;

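    /// Trains a tokenizer on `dataset`.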
    fn train(&self, dataset: &str) -> Result<Self::Output, Self::TrainingError>;
}

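/// Checks that `path` has the `.stok` extension, returning an `InvalidInput`
/// error otherwise.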
pub fn verify_stok_extension(path: &Path) -> Result<(), std::io::Error> {
    match path.extension().and_then(|s| s.to_str()) {
        Some("stok") => Ok(()),
        _ => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "File must have .stok extension",
        )),
    }
}

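/// Writes one merge rule per line to `writer` using `MergeRule`'s `Display`
/// format.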
pub fn save_merges(writer: &mut dyn Write, merges: &[MergeRule]) -> Result<(), std::io::Error> {
    for merge_rule in merges {
        writeln!(writer, "{}", merge_rule)?;
    }

    writer.flush()?;

    Ok(())
}

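/// Writes a human-readable vocabulary listing to `path`, one line per merge
/// rule in the form `first + second: merged`.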
pub fn save_vocab(
    path: &Path,
    merges: &[MergeRule],
    vocab: &HashMap<TokenId, Vec<u8>>,
) -> Result<(), std::io::Error> {
    let file = File::create(path)?;
    let mut writer = BufWriter::new(file);

    for merge_rule in merges {
        // Every id referenced by a merge rule must already exist in the vocab.
        let first_bytes = vocab.get(&merge_rule.pair().first()).expect("id missing from vocab");
        let second_bytes = vocab.get(&merge_rule.pair().second()).expect("id missing from vocab");
        let new_bytes = vocab.get(&merge_rule.new_id()).expect("id missing from vocab");

        writeln!(
            writer,
            "{} + {}: {}",
            String::from_utf8_lossy(first_bytes.as_slice()),
            String::from_utf8_lossy(second_bytes.as_slice()),
            String::from_utf8_lossy(new_bytes.as_slice())
        )?;
    }

    writer.flush()?;
    Ok(())
}

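/// Parses merge rules from numbered lines, e.g. `reader.lines().enumerate()`.
/// Blank lines are skipped; a line that fails to parse is reported as
/// `InvalidData` together with its 1-based line number.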
pub fn parse_merges<I>(lines: I) -> Result<Vec<MergeRule>, std::io::Error>
where
    I: Iterator<Item = (usize, Result<String, std::io::Error>)>,
{
    let mut merges = Vec::new();
    for (line_num, line) in lines {
        let line = line?;

        if line.trim().is_empty() {
            continue;
        }

        let merge_rule = MergeRule::from_str(&line).map_err(|e| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("Line {}: {}", line_num + 1, e),
            )
        })?;

        merges.push(merge_rule);
    }

    Ok(merges)
}

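/// A tokenizer that can be saved to disk.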
pub trait Serializable: Tokenizer {
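    /// Saves this tokenizer to the file at `path`.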
    fn save(&self, path: &Path) -> Result<(), std::io::Error>;
}

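/// A loader that reconstructs a tokenizer from a file on disk.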
pub trait Deserializable {
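    /// The tokenizer type produced by a successful load.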
    type Output: Tokenizer;

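    /// Loads a tokenizer from the file at `path`.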
    fn load(&self, path: &Path) -> Result<Self::Output, std::io::Error>;
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_verify_stok_extension_valid() {
        let path = Path::new("model.stok");
        assert!(verify_stok_extension(path).is_ok());
    }

    #[test]
    fn test_verify_stok_extension_double_extension() {
        let path = Path::new("model.backup.stok");
        assert!(verify_stok_extension(path).is_ok());
    }

    #[test]
    fn test_verify_stok_extension_valid_with_path() {
        let path = Path::new("/some/path/to/model.stok");
        assert!(verify_stok_extension(path).is_ok());
    }

    #[test]
    fn test_verify_stok_extension_invalid() {
        // Covers wrong extensions (.txt, .json) and a missing extension.
        let path = Path::new("model.txt");
        let result = verify_stok_extension(path);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().to_string(), "File must have .stok extension");

        let path = Path::new("model.json");
        let result = verify_stok_extension(path);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().to_string(), "File must have .stok extension");

        let path = Path::new("model");
        let result = verify_stok_extension(path);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().to_string(), "File must have .stok extension");
    }
}