tokenizers/models/wordlevel/trainer.rs

use super::WordLevel;
use crate::utils::parallelism::*;
use crate::{AddedToken, Result, Trainer};
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
7
8#[non_exhaustive]
9#[derive(Debug, Clone, Builder, Serialize, Deserialize)]
10pub struct WordLevelTrainer {
11    /// The minimum frequency a word must have to be part of the vocabulary
12    #[builder(default = "0")]
13    pub min_frequency: u64,
14    /// The target vocabulary size
15    #[builder(default = "30_000")]
16    pub vocab_size: usize,
17    /// Whether to show progress while training
18    #[builder(default = "true")]
19    pub show_progress: bool,
20    /// A list of special tokens that the model should know of
21    #[builder(default)]
22    pub special_tokens: Vec<AddedToken>,
23
24    #[builder(default, private)]
25    words: HashMap<String, u64>,
26}
27
28impl Default for WordLevelTrainer {
29    fn default() -> Self {
30        Self::builder().build().unwrap()
31    }
32}
33
34impl WordLevelTrainer {
35    pub fn builder() -> WordLevelTrainerBuilder {
36        WordLevelTrainerBuilder::default()
37    }
38
39    fn do_train(
40        &self,
41        word_counts: &HashMap<String, u64>,
42        model: &mut WordLevel,
43    ) -> Result<Vec<AddedToken>> {
44        let mut ordered_counts = word_counts.iter().collect::<Vec<_>>();
45
46        //sort the word counts first by inverse counts and then by word, in order
47        //to keep the sorting deterministic in case of equal counts
48        let cmp = |l: &(&String, &u64), r: &(&String, &u64)| -> Ordering {
49            let count_comp: Ordering = l.1.cmp(r.1);
50            if count_comp != Ordering::Equal {
51                return count_comp.reverse();
52            }
53            l.0.cmp(r.0)
54        };
55
56        ordered_counts.sort_by(cmp);
57
58        let word_level = WordLevel::builder()
59            .vocab(
60                self.special_tokens
61                    .iter()
62                    .map(|token| token.content.clone())
63                    .chain(
64                        ordered_counts
65                            .into_iter()
66                            .filter(|(_, n)| **n >= self.min_frequency)
67                            .map(|(w, _)| w.to_owned()),
68                    )
69                    .take(self.vocab_size)
70                    .enumerate()
71                    .map(|(i, w)| (w, i as u32))
72                    .collect(),
73            )
74            .build()?;
75
76        // Transfer the vocab
77        model.vocab = word_level.vocab;
78        model.vocab_r = word_level.vocab_r;
79
80        Ok(self.special_tokens.clone())
81    }
82}
83
84impl Trainer for WordLevelTrainer {
85    type Model = WordLevel;
86
87    /// Train a WordLevel model
88    fn train(&self, model: &mut WordLevel) -> Result<Vec<AddedToken>> {
89        self.do_train(&self.words, model)
90    }
91
92    /// Whether we should show progress
93    fn should_show_progress(&self) -> bool {
94        self.show_progress
95    }
96
97    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
98    where
99        I: Iterator<Item = S> + Send,
100        S: AsRef<str> + Send,
101        F: Fn(&str) -> Result<Vec<String>> + Sync,
102    {
103        let words: Result<HashMap<String, u64>> = iterator
104            .maybe_par_bridge()
105            .map(|sequence| {
106                let words = process(sequence.as_ref())?;
107                let mut map = HashMap::new();
108                for word in words {
109                    map.entry(word).and_modify(|c| *c += 1).or_insert(1);
110                }
111                Ok(map)
112            })
113            .reduce(
114                || Ok(HashMap::new()),
115                |acc, ws| {
116                    let mut acc = acc?;
117                    for (k, v) in ws? {
118                        acc.entry(k).and_modify(|c| *c += v).or_insert(v);
119                    }
120                    Ok(acc)
121                },
122            );
123
124        self.words = words?;
125        Ok(())
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned-key map from a slice of `(key, value)` pairs.
    fn to_map<V: Copy>(pairs: &[(&str, V)]) -> HashMap<String, V> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value))
            .collect()
    }

    /// Runs `do_train` against a fresh default model and hands the model back.
    fn train_on(trainer: &WordLevelTrainer, counts: &HashMap<String, u64>) -> WordLevel {
        let mut model = WordLevel::default();
        trainer.do_train(counts, &mut model).unwrap();
        model
    }

    #[test]
    fn test_train() {
        let word_counts = to_map(&[
            ("the", 25_u64),
            ("roses", 22),
            ("are", 24),
            ("red", 12),
            ("voilets", 10),
            ("blue", 16),
        ]);

        let mut trainer = WordLevelTrainer {
            vocab_size: 5,
            ..Default::default()
        };

        // Only the 5 most frequent words fit, ranked by decreasing count.
        let expected_vocab = to_map(&[
            ("the", 0_u32),
            ("are", 1),
            ("roses", 2),
            ("blue", 3),
            ("red", 4),
        ]);
        assert_eq!(train_on(&trainer, &word_counts).vocab, expected_vocab);

        // A minimum frequency of 15 additionally drops "red" (12) and "voilets" (10).
        trainer.min_frequency = 15;
        let expected_vocab = to_map(&[("the", 0_u32), ("are", 1), ("roses", 2), ("blue", 3)]);
        assert_eq!(train_on(&trainer, &word_counts).vocab, expected_vocab);
    }
}