use aksr::Builder;
#[cfg(feature = "tokenizers")]
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};

use crate::ResizeMode;

/// Memory layout of the image tensor handed to the model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ImageTensorLayout {
    /// Batched, channels-first: (batch, channel, height, width).
    NCHW,
    /// Batched, channels-last: (batch, height, width, channel).
    NHWC,
    /// Unbatched, channels-first: (channel, height, width).
    CHW,
    /// Unbatched, channels-last: (height, width, channel).
    HWC,
}

/// Configuration for image pre-processing and tokenization.
#[derive(Builder, Debug, Clone)]
pub struct ProcessorConfig {
    // Image pre-processing
    pub image_width: Option<u32>,
    pub image_height: Option<u32>,
    pub do_resize: bool,
    pub resize_mode: ResizeMode,
    pub resize_filter: Option<&'static str>,
    pub padding_value: u8,
    pub normalize: bool,
    pub image_std: Vec<f32>,
    pub image_mean: Vec<f32>,
    pub nchw: bool,
    pub unsigned: bool,
    pub pad_image: bool,
    pub pad_size: usize,
    pub up_scale: f32,
    pub image_tensor_layout: ImageTensorLayout,

    // Tokenizer and text generation
    pub model_max_length: Option<u64>,
    pub tokenizer_file: Option<String>,
    pub config_file: Option<String>,
    pub special_tokens_map_file: Option<String>,
    pub tokenizer_config_file: Option<String>,
    pub generation_config_file: Option<String>,
    pub vocab_file: Option<String>,
    pub vocab_txt: Option<String>,
    pub temperature: f32,
    pub topp: f32,
}

impl Default for ProcessorConfig {
    fn default() -> Self {
        Self {
            image_width: None,
            image_height: None,
            do_resize: true,
            resize_mode: ResizeMode::FitExact,
            resize_filter: Some("Bilinear"),
            padding_value: 114,
            image_tensor_layout: ImageTensorLayout::NCHW,
            normalize: true,
            image_std: vec![],
            image_mean: vec![],
            nchw: true,
            unsigned: false,
            pad_image: false,
            pad_size: 8,
            up_scale: 2.,
            model_max_length: None,
            tokenizer_file: None,
            config_file: None,
            special_tokens_map_file: None,
            tokenizer_config_file: None,
            generation_config_file: None,
            vocab_file: None,
            vocab_txt: None,
            temperature: 1.0,
            topp: 0.9,
        }
    }
}
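
// A minimal construction sketch. It assumes the chainable `with_*` setters generated by
// the `aksr::Builder` derive (the same setters forwarded by the macro at the bottom of
// this file); the concrete values are illustrative only.
//
// let config = ProcessorConfig::default()
//     .with_image_width(640)
//     .with_image_height(640)
//     .with_resize_mode(ResizeMode::FitExact)
//     .with_image_mean(&[0.485, 0.456, 0.406])
//     .with_image_std(&[0.229, 0.224, 0.225]);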

impl ProcessorConfig {
    #[cfg(feature = "tokenizers")]
    pub fn try_build_tokenizer(&self) -> anyhow::Result<Option<Tokenizer>> {
        use crate::Hub;
        let mut hub = Hub::default();

        // Build the base tokenizer; without a tokenizer file there is nothing to build.
        let mut tokenizer: Tokenizer = match &self.tokenizer_file {
            None => return Ok(None),
            Some(file) => Tokenizer::from_file(hub.try_fetch(file)?)
                .map_err(|err| anyhow::anyhow!("Failed to build tokenizer: {err}"))?,
        };

        // `pad_token_id` lives in the model's config.json (not tokenizer_config.json);
        // fall back to 0 if the file is absent or cannot be fetched.
        let pad_id = match &self.config_file {
            None => 0u32,
            Some(file) => match hub.try_fetch(file) {
                Ok(x) => {
                    let config: serde_json::Value =
                        serde_json::from_str(&std::fs::read_to_string(x)?)?;
                    config["pad_token_id"].as_u64().unwrap_or(0) as u32
                }
                Err(_err) => 0u32,
            },
        };

        // `model_max_length` and `pad_token` come from the tokenizer config file.
        let mut max_length = None;
        let mut pad_token = String::from("[PAD]");

        if let Some(file) = &self.tokenizer_config_file {
            match hub.try_fetch(file) {
                Err(_) => {}
                Ok(x) => {
                    let tokenizer_config: serde_json::Value =
                        serde_json::from_str(&std::fs::read_to_string(x)?)?;
                    max_length = tokenizer_config["model_max_length"].as_u64();
                    pad_token = tokenizer_config["pad_token"]
                        .as_str()
                        .unwrap_or("[PAD]")
                        .to_string();
                }
            }
        }

        // Configure padding (and, when only the tokenizer config bounds the length,
        // truncation) from the effective maximum length.
        let tokenizer = match self.model_max_length {
            Some(n) => {
                // Respect the smaller of the user-supplied and tokenizer-config limits.
                let n = match max_length {
                    None => n,
                    Some(x) => x.min(n),
                };
                tokenizer
                    .with_padding(Some(PaddingParams {
                        strategy: PaddingStrategy::Fixed(n as _),
                        pad_token,
                        pad_id,
                        ..Default::default()
                    }))
                    .clone()
            }
            None => match max_length {
                Some(n) => tokenizer
                    .with_padding(Some(PaddingParams {
                        strategy: PaddingStrategy::BatchLongest,
                        pad_token,
                        pad_id,
                        ..Default::default()
                    }))
                    .with_truncation(Some(TruncationParams {
                        max_length: n as _,
                        ..Default::default()
                    }))
                    .map_err(|err| anyhow::anyhow!("Failed to set truncation: {err}"))?
                    .clone(),
                None => tokenizer
                    .with_padding(Some(PaddingParams {
                        strategy: PaddingStrategy::BatchLongest,
                        pad_token,
                        pad_id,
                        ..Default::default()
                    }))
                    .clone(),
            },
        };

        Ok(Some(tokenizer.into()))
    }
}
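
// Hedged usage sketch for `try_build_tokenizer` (requires the `tokenizers` feature).
// The file names are illustrative; they are resolved through `Hub::try_fetch`, so local
// paths or hub identifiers both work, depending on how `Hub` is configured.
//
// let tokenizer = ProcessorConfig::default()
//     .with_tokenizer_file("tokenizer.json")
//     .with_tokenizer_config_file("tokenizer_config.json")
//     .with_model_max_length(512)
//     .try_build_tokenizer()?
//     .expect("Some(_) is returned whenever tokenizer_file is set");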

// Forwards the `ProcessorConfig` builder-style setters from a wrapper type that stores
// a `ProcessorConfig` in the named field.
macro_rules! impl_processor_config_methods {
    ($ty:ty, $field:ident) => {
        impl $ty {
            pub fn with_image_width(mut self, image_width: u32) -> Self {
                self.$field = self.$field.with_image_width(image_width);
                self
            }
            pub fn with_image_height(mut self, image_height: u32) -> Self {
                self.$field = self.$field.with_image_height(image_height);
                self
            }
            pub fn with_do_resize(mut self, do_resize: bool) -> Self {
                self.$field = self.$field.with_do_resize(do_resize);
                self
            }
            pub fn with_resize_mode(mut self, resize_mode: $crate::ResizeMode) -> Self {
                self.$field = self.$field.with_resize_mode(resize_mode);
                self
            }
            pub fn with_resize_filter(mut self, resize_filter: &'static str) -> Self {
                self.$field = self.$field.with_resize_filter(resize_filter);
                self
            }
            pub fn with_padding_value(mut self, padding_value: u8) -> Self {
                self.$field = self.$field.with_padding_value(padding_value);
                self
            }
            pub fn with_normalize(mut self, normalize: bool) -> Self {
                self.$field = self.$field.with_normalize(normalize);
                self
            }
            pub fn with_image_std(mut self, image_std: &[f32]) -> Self {
                self.$field = self.$field.with_image_std(image_std);
                self
            }
            pub fn with_image_mean(mut self, image_mean: &[f32]) -> Self {
                self.$field = self.$field.with_image_mean(image_mean);
                self
            }
            pub fn with_nchw(mut self, nchw: bool) -> Self {
                self.$field = self.$field.with_nchw(nchw);
                self
            }
            pub fn with_unsigned(mut self, unsigned: bool) -> Self {
                self.$field = self.$field.with_unsigned(unsigned);
                self
            }
            pub fn with_pad_image(mut self, pad_image: bool) -> Self {
                self.$field = self.$field.with_pad_image(pad_image);
                self
            }
            pub fn with_pad_size(mut self, pad_size: usize) -> Self {
                self.$field = self.$field.with_pad_size(pad_size);
                self
            }
            pub fn with_up_scale(mut self, up_scale: f32) -> Self {
                self.$field = self.$field.with_up_scale(up_scale);
                self
            }
            pub fn with_image_tensor_layout(
                mut self,
                image_tensor_layout: $crate::ImageTensorLayout,
            ) -> Self {
                self.$field = self.$field.with_image_tensor_layout(image_tensor_layout);
                self
            }
            pub fn with_model_max_length(mut self, model_max_length: u64) -> Self {
                self.$field = self.$field.with_model_max_length(model_max_length);
                self
            }
            pub fn with_tokenizer_file(mut self, tokenizer_file: &str) -> Self {
                self.$field = self.$field.with_tokenizer_file(tokenizer_file);
                self
            }
            pub fn with_config_file(mut self, config_file: &str) -> Self {
                self.$field = self.$field.with_config_file(config_file);
                self
            }
            pub fn with_special_tokens_map_file(mut self, special_tokens_map_file: &str) -> Self {
                self.$field = self
                    .$field
                    .with_special_tokens_map_file(special_tokens_map_file);
                self
            }
            pub fn with_tokenizer_config_file(mut self, tokenizer_config_file: &str) -> Self {
                self.$field = self
                    .$field
                    .with_tokenizer_config_file(tokenizer_config_file);
                self
            }
            pub fn with_generation_config_file(mut self, generation_config_file: &str) -> Self {
                self.$field = self
                    .$field
                    .with_generation_config_file(generation_config_file);
                self
            }
            pub fn with_vocab_file(mut self, vocab_file: &str) -> Self {
                self.$field = self.$field.with_vocab_file(vocab_file);
                self
            }
            pub fn with_vocab_txt(mut self, vocab_txt: &str) -> Self {
                self.$field = self.$field.with_vocab_txt(vocab_txt);
                self
            }
            pub fn with_temperature(mut self, temperature: f32) -> Self {
                self.$field = self.$field.with_temperature(temperature);
                self
            }
            pub fn with_topp(mut self, topp: f32) -> Self {
                self.$field = self.$field.with_topp(topp);
                self
            }
        }
    };
}
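
// Hedged sketch of how the macro is intended to be invoked. `MyModelConfig` and its
// `processor` field are hypothetical; any type that owns a `ProcessorConfig` in a named
// field can forward the same builder-style setters.
//
// pub struct MyModelConfig {
//     processor: ProcessorConfig,
// }
//
// impl_processor_config_methods!(MyModelConfig, processor);
//
// let cfg = MyModelConfig { processor: ProcessorConfig::default() }
//     .with_image_width(224)
//     .with_image_height(224);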