tokenizers/models/wordlevel/
trainer.rs1use super::WordLevel;
2use crate::utils::parallelism::*;
3use crate::{AddedToken, Result, Trainer};
4use serde::{Deserialize, Serialize};
5use std::cmp::Ordering;
6use std::collections::HashMap;
7
/// Trainer capable of building a `WordLevel` model from raw word counts.
#[non_exhaustive]
#[derive(Debug, Clone, Builder, Serialize, Deserialize)]
pub struct WordLevelTrainer {
    /// The minimum frequency a word must have to be included in the vocabulary.
    #[builder(default = "0")]
    pub min_frequency: u64,
    /// The target vocabulary size (special tokens count toward this limit).
    #[builder(default = "30_000")]
    pub vocab_size: usize,
    /// Whether to show progress while training.
    #[builder(default = "true")]
    pub show_progress: bool,
    /// Special tokens to prepend to the vocabulary, in order.
    #[builder(default)]
    pub special_tokens: Vec<AddedToken>,

    // Word counts accumulated by `Trainer::feed` and consumed by `train`.
    // Kept private (and private on the builder) so it can only be populated
    // through the feeding API.
    #[builder(default, private)]
    words: HashMap<String, u64>,
}
27
28impl Default for WordLevelTrainer {
29 fn default() -> Self {
30 Self::builder().build().unwrap()
31 }
32}
33
34impl WordLevelTrainer {
35 pub fn builder() -> WordLevelTrainerBuilder {
36 WordLevelTrainerBuilder::default()
37 }
38
39 fn do_train(
40 &self,
41 word_counts: &HashMap<String, u64>,
42 model: &mut WordLevel,
43 ) -> Result<Vec<AddedToken>> {
44 let mut ordered_counts = word_counts.iter().collect::<Vec<_>>();
45
46 let cmp = |l: &(&String, &u64), r: &(&String, &u64)| -> Ordering {
49 let count_comp: Ordering = l.1.cmp(r.1);
50 if count_comp != Ordering::Equal {
51 return count_comp.reverse();
52 }
53 l.0.cmp(r.0)
54 };
55
56 ordered_counts.sort_by(cmp);
57
58 let word_level = WordLevel::builder()
59 .vocab(
60 self.special_tokens
61 .iter()
62 .map(|token| token.content.clone())
63 .chain(
64 ordered_counts
65 .into_iter()
66 .filter(|(_, n)| **n >= self.min_frequency)
67 .map(|(w, _)| w.to_owned()),
68 )
69 .take(self.vocab_size)
70 .enumerate()
71 .map(|(i, w)| (w, i as u32))
72 .collect(),
73 )
74 .build()?;
75
76 model.vocab = word_level.vocab;
78 model.vocab_r = word_level.vocab_r;
79
80 Ok(self.special_tokens.clone())
81 }
82}
83
84impl Trainer for WordLevelTrainer {
85 type Model = WordLevel;
86
87 fn train(&self, model: &mut WordLevel) -> Result<Vec<AddedToken>> {
89 self.do_train(&self.words, model)
90 }
91
92 fn should_show_progress(&self) -> bool {
94 self.show_progress
95 }
96
97 fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
98 where
99 I: Iterator<Item = S> + Send,
100 S: AsRef<str> + Send,
101 F: Fn(&str) -> Result<Vec<String>> + Sync,
102 {
103 let words: Result<HashMap<String, u64>> = iterator
104 .maybe_par_bridge()
105 .map(|sequence| {
106 let words = process(sequence.as_ref())?;
107 let mut map = HashMap::new();
108 for word in words {
109 map.entry(word).and_modify(|c| *c += 1).or_insert(1);
110 }
111 Ok(map)
112 })
113 .reduce(
114 || Ok(HashMap::new()),
115 |acc, ws| {
116 let mut acc = acc?;
117 for (k, v) in ws? {
118 acc.entry(k).and_modify(|c| *c += v).or_insert(v);
119 }
120 Ok(acc)
121 },
122 );
123
124 self.words = words?;
125 Ok(())
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `do_train` ordering, `vocab_size` truncation, and the
    /// `min_frequency` cutoff. ("voilets" reproduces the fixture spelling.)
    #[test]
    fn test_train() {
        let word_counts: HashMap<String, u64> = [
            ("the", 25),
            ("roses", 22),
            ("are", 24),
            ("red", 12),
            ("voilets", 10),
            ("blue", 16),
        ]
        .iter()
        .map(|(word, count)| (word.to_string(), *count))
        .collect();

        let mut trainer = WordLevelTrainer {
            vocab_size: 5,
            ..Default::default()
        };

        // With vocab_size = 5 only the five most frequent words survive,
        // ids assigned by descending count.
        let mut model = WordLevel::default();
        trainer.do_train(&word_counts, &mut model).unwrap();
        let expected_vocab: HashMap<String, u32> = [
            ("the", 0),
            ("are", 1),
            ("roses", 2),
            ("blue", 3),
            ("red", 4),
        ]
        .iter()
        .map(|(word, id)| (word.to_string(), *id))
        .collect();
        assert_eq!(model.vocab, expected_vocab);

        // Raising min_frequency to 15 additionally drops "red" (count 12).
        trainer.min_frequency = 15;
        let mut model = WordLevel::default();
        trainer.do_train(&word_counts, &mut model).unwrap();
        let expected_vocab: HashMap<String, u32> = [
            ("the", 0),
            ("are", 1),
            ("roses", 2),
            ("blue", 3),
        ]
        .iter()
        .map(|(word, id)| (word.to_string(), *id))
        .collect();

        assert_eq!(model.vocab, expected_vocab);
    }
}