tokenizers/models/wordpiece/
trainer.rs1use super::WordPiece;
2use crate::models::bpe::{BpeTrainer, BpeTrainerBuilder, BPE};
3use crate::tokenizer::{AddedToken, Result, Trainer};
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6
7pub struct WordPieceTrainerBuilder {
10 bpe_trainer_builder: BpeTrainerBuilder,
11}
12
13impl Default for WordPieceTrainerBuilder {
14 fn default() -> Self {
15 Self {
16 bpe_trainer_builder: BpeTrainerBuilder::new().continuing_subword_prefix("##".into()),
17 }
18 }
19}
20
21impl WordPieceTrainerBuilder {
22 pub fn new() -> Self {
24 Self::default()
25 }
26
27 #[must_use]
29 pub fn min_frequency(mut self, frequency: u64) -> Self {
30 self.bpe_trainer_builder = self.bpe_trainer_builder.min_frequency(frequency);
31 self
32 }
33
34 #[must_use]
36 pub fn vocab_size(mut self, size: usize) -> Self {
37 self.bpe_trainer_builder = self.bpe_trainer_builder.vocab_size(size);
38 self
39 }
40
41 #[must_use]
43 pub fn show_progress(mut self, show: bool) -> Self {
44 self.bpe_trainer_builder = self.bpe_trainer_builder.show_progress(show);
45 self
46 }
47
48 #[must_use]
50 pub fn special_tokens(mut self, tokens: Vec<AddedToken>) -> Self {
51 self.bpe_trainer_builder = self.bpe_trainer_builder.special_tokens(tokens);
52 self
53 }
54
55 #[must_use]
57 pub fn limit_alphabet(mut self, limit: usize) -> Self {
58 self.bpe_trainer_builder = self.bpe_trainer_builder.limit_alphabet(limit);
59 self
60 }
61
62 #[must_use]
64 pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
65 self.bpe_trainer_builder = self.bpe_trainer_builder.initial_alphabet(alphabet);
66 self
67 }
68
69 #[must_use]
71 pub fn continuing_subword_prefix(mut self, prefix: String) -> Self {
72 self.bpe_trainer_builder = self.bpe_trainer_builder.continuing_subword_prefix(prefix);
73 self
74 }
75
76 #[must_use]
78 pub fn end_of_word_suffix(mut self, suffix: String) -> Self {
79 self.bpe_trainer_builder = self.bpe_trainer_builder.end_of_word_suffix(suffix);
80 self
81 }
82
83 pub fn build(self) -> WordPieceTrainer {
85 let bpe_trainer = self.bpe_trainer_builder.build();
86 WordPieceTrainer { bpe_trainer }
87 }
88}
89
90#[derive(Default, Clone, Deserialize, Serialize)]
92pub struct WordPieceTrainer {
93 bpe_trainer: BpeTrainer,
94}
95
96impl WordPieceTrainer {
97 pub fn min_frequency(&self) -> u64 {
98 self.bpe_trainer.min_frequency
99 }
100
101 pub fn set_min_frequency(&mut self, freq: u64) {
102 self.bpe_trainer.min_frequency = freq;
103 }
104
105 pub fn vocab_size(&self) -> usize {
106 self.bpe_trainer.vocab_size
107 }
108
109 pub fn set_vocab_size(&mut self, size: usize) {
110 self.bpe_trainer.vocab_size = size;
111 }
112
113 pub fn show_progress(&self) -> bool {
114 self.bpe_trainer.show_progress
115 }
116
117 pub fn set_show_progress(&mut self, show_progress: bool) {
118 self.bpe_trainer.show_progress = show_progress;
119 }
120
121 pub fn special_tokens(&self) -> &[AddedToken] {
122 &self.bpe_trainer.special_tokens
123 }
124
125 pub fn set_special_tokens(&mut self, special_tokens: Vec<AddedToken>) {
126 self.bpe_trainer.special_tokens = special_tokens;
127 }
128
129 pub fn limit_alphabet(&self) -> Option<usize> {
130 self.bpe_trainer.limit_alphabet
131 }
132
133 pub fn set_limit_alphabet(&mut self, limit: Option<usize>) {
134 self.bpe_trainer.limit_alphabet = limit;
135 }
136
137 pub fn initial_alphabet(&self) -> &HashSet<char> {
138 &self.bpe_trainer.initial_alphabet
139 }
140
141 pub fn set_initial_alphabet(&mut self, alphabet: HashSet<char>) {
142 self.bpe_trainer.initial_alphabet = alphabet;
143 }
144
145 pub fn continuing_subword_prefix(&self) -> &Option<String> {
146 &self.bpe_trainer.continuing_subword_prefix
147 }
148
149 pub fn set_continuing_subword_prefix(&mut self, prefix: Option<String>) {
150 self.bpe_trainer.continuing_subword_prefix = prefix;
151 }
152
153 pub fn end_of_word_suffix(&self) -> &Option<String> {
154 &self.bpe_trainer.end_of_word_suffix
155 }
156
157 pub fn set_end_of_word_suffix(&mut self, suffix: Option<String>) {
158 self.bpe_trainer.end_of_word_suffix = suffix;
159 }
160
161 pub fn builder() -> WordPieceTrainerBuilder {
162 WordPieceTrainerBuilder::default()
163 }
164
165 pub fn train(&self, model: &mut WordPiece) -> Result<Vec<AddedToken>> {
166 let mut bpe = BPE::default();
167 let special_tokens = self.bpe_trainer.train(&mut bpe)?;
168 let new_wordpiece = WordPiece::from_bpe(&bpe);
169
170 model.vocab = new_wordpiece.vocab;
172 model.vocab_r = new_wordpiece.vocab_r;
173 model.continuing_subword_prefix = new_wordpiece.continuing_subword_prefix;
175
176 Ok(special_tokens)
177 }
178}
179
180impl Trainer for WordPieceTrainer {
181 type Model = WordPiece;
182
183 fn train(&self, model: &mut WordPiece) -> Result<Vec<AddedToken>> {
184 self.train(model)
185 }
186
187 fn should_show_progress(&self) -> bool {
188 self.bpe_trainer.should_show_progress()
189 }
190
191 fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
192 where
193 I: Iterator<Item = S> + Send,
194 S: AsRef<str> + Send,
195 F: Fn(&str) -> Result<Vec<String>> + Sync,
196 {
197 self.bpe_trainer.feed(iterator, process)
198 }
199}