tokenizers/models/wordpiece/
trainer.rs

1use super::WordPiece;
2use crate::models::bpe::{BpeTrainer, BpeTrainerBuilder, BPE};
3use crate::tokenizer::{AddedToken, Result, Trainer};
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6
7/// A `WordPieceTrainerBuilder` can be used to create a `WordPieceTrainer` with a custom
8/// configuration.
9pub struct WordPieceTrainerBuilder {
10    bpe_trainer_builder: BpeTrainerBuilder,
11}
12
13impl Default for WordPieceTrainerBuilder {
14    fn default() -> Self {
15        Self {
16            bpe_trainer_builder: BpeTrainerBuilder::new().continuing_subword_prefix("##".into()),
17        }
18    }
19}
20
21impl WordPieceTrainerBuilder {
22    /// Constructs a new `WordPieceTrainerBuilder`
23    pub fn new() -> Self {
24        Self::default()
25    }
26
27    /// Set the expected minimum frequency
28    #[must_use]
29    pub fn min_frequency(mut self, frequency: u64) -> Self {
30        self.bpe_trainer_builder = self.bpe_trainer_builder.min_frequency(frequency);
31        self
32    }
33
34    /// Set the vocabulary size
35    #[must_use]
36    pub fn vocab_size(mut self, size: usize) -> Self {
37        self.bpe_trainer_builder = self.bpe_trainer_builder.vocab_size(size);
38        self
39    }
40
41    /// Set whether to show progress
42    #[must_use]
43    pub fn show_progress(mut self, show: bool) -> Self {
44        self.bpe_trainer_builder = self.bpe_trainer_builder.show_progress(show);
45        self
46    }
47
48    /// Set the special tokens
49    #[must_use]
50    pub fn special_tokens(mut self, tokens: Vec<AddedToken>) -> Self {
51        self.bpe_trainer_builder = self.bpe_trainer_builder.special_tokens(tokens);
52        self
53    }
54
55    /// Set whether to limit the alphabet
56    #[must_use]
57    pub fn limit_alphabet(mut self, limit: usize) -> Self {
58        self.bpe_trainer_builder = self.bpe_trainer_builder.limit_alphabet(limit);
59        self
60    }
61
62    /// Set the initial alphabet
63    #[must_use]
64    pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
65        self.bpe_trainer_builder = self.bpe_trainer_builder.initial_alphabet(alphabet);
66        self
67    }
68
69    /// Set the continuing_subword_prefix
70    #[must_use]
71    pub fn continuing_subword_prefix(mut self, prefix: String) -> Self {
72        self.bpe_trainer_builder = self.bpe_trainer_builder.continuing_subword_prefix(prefix);
73        self
74    }
75
76    /// Set the end_of_word_suffix
77    #[must_use]
78    pub fn end_of_word_suffix(mut self, suffix: String) -> Self {
79        self.bpe_trainer_builder = self.bpe_trainer_builder.end_of_word_suffix(suffix);
80        self
81    }
82
83    /// Constructs the final BpeTrainer
84    pub fn build(self) -> WordPieceTrainer {
85        let bpe_trainer = self.bpe_trainer_builder.build();
86        WordPieceTrainer { bpe_trainer }
87    }
88}
89
90/// Trains a `WordPiece` model.
91#[derive(Default, Clone, Deserialize, Serialize)]
92pub struct WordPieceTrainer {
93    bpe_trainer: BpeTrainer,
94}
95
96impl WordPieceTrainer {
97    pub fn min_frequency(&self) -> u64 {
98        self.bpe_trainer.min_frequency
99    }
100
101    pub fn set_min_frequency(&mut self, freq: u64) {
102        self.bpe_trainer.min_frequency = freq;
103    }
104
105    pub fn vocab_size(&self) -> usize {
106        self.bpe_trainer.vocab_size
107    }
108
109    pub fn set_vocab_size(&mut self, size: usize) {
110        self.bpe_trainer.vocab_size = size;
111    }
112
113    pub fn show_progress(&self) -> bool {
114        self.bpe_trainer.show_progress
115    }
116
117    pub fn set_show_progress(&mut self, show_progress: bool) {
118        self.bpe_trainer.show_progress = show_progress;
119    }
120
121    pub fn special_tokens(&self) -> &[AddedToken] {
122        &self.bpe_trainer.special_tokens
123    }
124
125    pub fn set_special_tokens(&mut self, special_tokens: Vec<AddedToken>) {
126        self.bpe_trainer.special_tokens = special_tokens;
127    }
128
129    pub fn limit_alphabet(&self) -> Option<usize> {
130        self.bpe_trainer.limit_alphabet
131    }
132
133    pub fn set_limit_alphabet(&mut self, limit: Option<usize>) {
134        self.bpe_trainer.limit_alphabet = limit;
135    }
136
137    pub fn initial_alphabet(&self) -> &HashSet<char> {
138        &self.bpe_trainer.initial_alphabet
139    }
140
141    pub fn set_initial_alphabet(&mut self, alphabet: HashSet<char>) {
142        self.bpe_trainer.initial_alphabet = alphabet;
143    }
144
145    pub fn continuing_subword_prefix(&self) -> &Option<String> {
146        &self.bpe_trainer.continuing_subword_prefix
147    }
148
149    pub fn set_continuing_subword_prefix(&mut self, prefix: Option<String>) {
150        self.bpe_trainer.continuing_subword_prefix = prefix;
151    }
152
153    pub fn end_of_word_suffix(&self) -> &Option<String> {
154        &self.bpe_trainer.end_of_word_suffix
155    }
156
157    pub fn set_end_of_word_suffix(&mut self, suffix: Option<String>) {
158        self.bpe_trainer.end_of_word_suffix = suffix;
159    }
160
161    pub fn builder() -> WordPieceTrainerBuilder {
162        WordPieceTrainerBuilder::default()
163    }
164
165    pub fn train(&self, model: &mut WordPiece) -> Result<Vec<AddedToken>> {
166        let mut bpe = BPE::default();
167        let special_tokens = self.bpe_trainer.train(&mut bpe)?;
168        let new_wordpiece = WordPiece::from_bpe(&bpe);
169
170        // Transfer the vocab
171        model.vocab = new_wordpiece.vocab;
172        model.vocab_r = new_wordpiece.vocab_r;
173        // The continuing_subword_prefix is the only other option to be overriden by the trainer
174        model.continuing_subword_prefix = new_wordpiece.continuing_subword_prefix;
175
176        Ok(special_tokens)
177    }
178}
179
180impl Trainer for WordPieceTrainer {
181    type Model = WordPiece;
182
183    fn train(&self, model: &mut WordPiece) -> Result<Vec<AddedToken>> {
184        self.train(model)
185    }
186
187    fn should_show_progress(&self) -> bool {
188        self.bpe_trainer.should_show_progress()
189    }
190
191    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
192    where
193        I: Iterator<Item = S> + Send,
194        S: AsRef<str> + Send,
195        F: Fn(&str) -> Result<Vec<String>> + Sync,
196    {
197        self.bpe_trainer.feed(iterator, process)
198    }
199}