// usls/core/processor_config.rs

1use aksr::Builder;
2#[cfg(feature = "tokenizers")]
3use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
4
5use crate::ResizeMode;
6
/// Image tensor layout formats for organizing image data in memory.
///
/// This enum defines different ways to arrange image pixel data in tensors:
/// - **Batch formats** (with batch dimension): `NCHW`, `NHWC`
/// - **Single image formats** (no batch dimension): `CHW`, `HWC`
///
/// The format affects how image data is stored and accessed in memory,
/// which is important for compatibility with different model architectures
/// (e.g., PyTorch typically uses NCHW, TensorFlow uses NHWC).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ImageTensorLayout {
    /// NCHW format: (batch, channel, height, width)
    /// Channels-first layout, commonly used in PyTorch models.
    NCHW,
    /// NHWC format: (batch, height, width, channel)
    /// Channels-last layout, commonly used in TensorFlow models.
    NHWC,
    /// CHW format: (channel, height, width)
    /// Single image with channels-first layout (no batch dimension).
    CHW,
    /// HWC format: (height, width, channel)
    /// Single image with channels-last layout (no batch dimension).
    HWC,
}

impl Default for ImageTensorLayout {
    /// Defaults to [`ImageTensorLayout::NCHW`], matching the layout used by
    /// `ProcessorConfig::default()`.
    fn default() -> Self {
        Self::NCHW
    }
}
31
/// Configuration for image and text processing pipelines.
///
/// Vision fields drive resizing, padding, normalization, and tensor layout;
/// text fields locate tokenizer assets and set generation sampling parameters.
/// Setters (`with_*`) are generated by the `aksr::Builder` derive.
#[derive(Builder, Debug, Clone)]
pub struct ProcessorConfig {
    // Vision
    /// Target image width for resizing.
    pub image_width: Option<u32>,
    /// Target image height for resizing.
    pub image_height: Option<u32>,
    /// Whether to resize the image.
    pub do_resize: bool,
    /// Image resizing mode.
    pub resize_mode: ResizeMode,
    /// Image resize filter algorithm.
    pub resize_filter: Option<&'static str>,
    /// Padding value for image borders.
    pub padding_value: u8,
    /// Whether to normalize image values.
    pub normalize: bool,
    /// Standard deviation values for normalization.
    pub image_std: Vec<f32>,
    /// Mean values for normalization.
    pub image_mean: Vec<f32>,
    /// Whether to use NCHW format (channels first).
    // NOTE(review): appears to overlap with `image_tensor_layout` — verify
    // whether consumers read one, the other, or both before consolidating.
    pub nchw: bool,
    /// Whether to use unsigned integer format.
    pub unsigned: bool,
    /// Whether to pad image for super resolution.
    pub pad_image: bool,
    /// Padding size for super resolution.
    pub pad_size: usize,
    /// Up-scaling factor for super resolution.
    pub up_scale: f32,
    /// Image tensor layout format.
    pub image_tensor_layout: ImageTensorLayout,

    // Text
    /// Maximum sequence length for tokenization.
    pub model_max_length: Option<u64>,
    /// Path to tokenizer file.
    pub tokenizer_file: Option<String>,
    /// Path to model configuration file.
    pub config_file: Option<String>,
    /// Path to special tokens mapping file.
    pub special_tokens_map_file: Option<String>,
    /// Path to tokenizer configuration file.
    pub tokenizer_config_file: Option<String>,
    /// Path to generation configuration file.
    pub generation_config_file: Option<String>,
    /// Path to vocabulary file.
    pub vocab_file: Option<String>,
    /// Path to vocabulary text file.
    pub vocab_txt: Option<String>,
    /// Temperature parameter for text generation.
    pub temperature: f32,
    /// Top-p parameter for nucleus sampling.
    pub topp: f32,
}
89
90impl Default for ProcessorConfig {
91    fn default() -> Self {
92        Self {
93            image_width: None,
94            image_height: None,
95            do_resize: true,
96            resize_mode: ResizeMode::FitExact,
97            resize_filter: Some("Bilinear"),
98            padding_value: 114,
99            image_tensor_layout: ImageTensorLayout::NCHW,
100            normalize: true,
101            image_std: vec![],
102            image_mean: vec![],
103            nchw: true,
104            unsigned: false,
105            pad_image: false,
106            pad_size: 8,
107            up_scale: 2.,
108            model_max_length: None,
109            tokenizer_file: None,
110            config_file: None,
111            special_tokens_map_file: None,
112            tokenizer_config_file: None,
113            generation_config_file: None,
114            vocab_file: None,
115            vocab_txt: None,
116            temperature: 1.0,
117            topp: 0.9,
118        }
119    }
120}
121
122impl ProcessorConfig {
123    #[cfg(feature = "tokenizers")]
124    pub fn try_build_tokenizer(&self) -> anyhow::Result<Option<Tokenizer>> {
125        use crate::Hub;
126        let mut hub = Hub::default();
127
128        // tokenizer file
129        let mut tokenizer: Tokenizer = match &self.tokenizer_file {
130            None => return Ok(None),
131            Some(file) => Tokenizer::from_file(hub.try_fetch(file)?)
132                .map_err(|err| anyhow::anyhow!("Faild to build tokenizer: {err}"))?,
133        };
134
135        // config file
136        // TODO: save configs?
137        let pad_id = match &self.tokenizer_config_file {
138            None => 0u32,
139            Some(file) => match hub.try_fetch(file) {
140                Ok(x) => {
141                    let config: serde_json::Value =
142                        serde_json::from_str(&std::fs::read_to_string(x)?)?;
143                    config["pad_token_id"].as_u64().unwrap_or(0) as u32
144                }
145                Err(_err) => 0u32,
146            },
147        };
148
149        // tokenizer_config file
150        let mut max_length = None;
151        let mut pad_token = String::from("[PAD]");
152
153        if let Some(file) = &self.tokenizer_config_file {
154            match hub.try_fetch(file) {
155                Err(_) => {}
156                Ok(x) => {
157                    let tokenizer_config: serde_json::Value =
158                        serde_json::from_str(&std::fs::read_to_string(x)?)?;
159                    max_length = tokenizer_config["model_max_length"].as_u64();
160                    pad_token = tokenizer_config["pad_token"]
161                        .as_str()
162                        .unwrap_or("[PAD]")
163                        .to_string();
164                }
165            }
166        }
167
168        // TODO: padding
169        // if `max_length` specified: use `Fixed` strategy
170        // else: use `BatchLongest` strategy
171        // TODO: if sequence_length is dynamic, `BatchLongest` is fine
172        let tokenizer = match self.model_max_length {
173            Some(n) => {
174                let n = match max_length {
175                    None => n,
176                    Some(x) => x.min(n),
177                };
178                tokenizer
179                    .with_padding(Some(PaddingParams {
180                        strategy: PaddingStrategy::Fixed(n as _),
181                        pad_token,
182                        pad_id,
183                        ..Default::default()
184                    }))
185                    .clone()
186            }
187            None => match max_length {
188                Some(n) => tokenizer
189                    .with_padding(Some(PaddingParams {
190                        strategy: PaddingStrategy::BatchLongest,
191                        pad_token,
192                        pad_id,
193                        ..Default::default()
194                    }))
195                    .with_truncation(Some(TruncationParams {
196                        max_length: n as _,
197                        ..Default::default()
198                    }))
199                    .map_err(|err| anyhow::anyhow!("Failed to truncate: {}", err))?
200                    .clone(),
201                None => tokenizer
202                    .with_padding(Some(PaddingParams {
203                        strategy: PaddingStrategy::BatchLongest,
204                        pad_token,
205                        pad_id,
206                        ..Default::default()
207                    }))
208                    .clone(),
209            },
210        };
211
212        Ok(Some(tokenizer.into()))
213    }
214}
215
/// Generates builder-style `with_*` forwarder methods on `$ty` that delegate
/// to the `ProcessorConfig` stored in field `$field`, so wrapper types expose
/// the same fluent configuration API without hand-writing each method.
///
/// `$ty` must own a field `$field` whose type provides the corresponding
/// `with_*` setters (a `ProcessorConfig`, whose setters come from the
/// `aksr::Builder` derive).
macro_rules! impl_processor_config_methods {
    ($ty:ty, $field:ident) => {
        impl $ty {
            // Each forwarder consumes `self`, rebuilds the embedded config via
            // its own `with_*` setter, and returns `self` for chaining —
            // mirroring `ProcessorConfig`'s builder API one-to-one.
            pub fn with_image_width(mut self, image_width: u32) -> Self {
                self.$field = self.$field.with_image_width(image_width);
                self
            }
            pub fn with_image_height(mut self, image_height: u32) -> Self {
                self.$field = self.$field.with_image_height(image_height);
                self
            }
            pub fn with_do_resize(mut self, do_resize: bool) -> Self {
                self.$field = self.$field.with_do_resize(do_resize);
                self
            }
            pub fn with_resize_mode(mut self, resize_mode: $crate::ResizeMode) -> Self {
                self.$field = self.$field.with_resize_mode(resize_mode);
                self
            }
            pub fn with_resize_filter(mut self, resize_filter: &'static str) -> Self {
                self.$field = self.$field.with_resize_filter(resize_filter);
                self
            }
            pub fn with_padding_value(mut self, padding_value: u8) -> Self {
                self.$field = self.$field.with_padding_value(padding_value);
                self
            }
            pub fn with_normalize(mut self, normalize: bool) -> Self {
                self.$field = self.$field.with_normalize(normalize);
                self
            }
            pub fn with_image_std(mut self, image_std: &[f32]) -> Self {
                self.$field = self.$field.with_image_std(image_std);
                self
            }
            pub fn with_image_mean(mut self, image_mean: &[f32]) -> Self {
                self.$field = self.$field.with_image_mean(image_mean);
                self
            }
            pub fn with_nchw(mut self, nchw: bool) -> Self {
                self.$field = self.$field.with_nchw(nchw);
                self
            }
            pub fn with_unsigned(mut self, unsigned: bool) -> Self {
                self.$field = self.$field.with_unsigned(unsigned);
                self
            }
            pub fn with_pad_image(mut self, pad_image: bool) -> Self {
                self.$field = self.$field.with_pad_image(pad_image);
                self
            }
            pub fn with_pad_size(mut self, pad_size: usize) -> Self {
                self.$field = self.$field.with_pad_size(pad_size);
                self
            }
            pub fn with_up_scale(mut self, up_scale: f32) -> Self {
                self.$field = self.$field.with_up_scale(up_scale);
                self
            }
            pub fn with_image_tensor_layout(
                mut self,
                image_tensor_layout: $crate::ImageTensorLayout,
            ) -> Self {
                self.$field = self.$field.with_image_tensor_layout(image_tensor_layout);
                self
            }
            pub fn with_model_max_length(mut self, model_max_length: u64) -> Self {
                self.$field = self.$field.with_model_max_length(model_max_length);
                self
            }
            pub fn with_tokenizer_file(mut self, tokenizer_file: &str) -> Self {
                self.$field = self.$field.with_tokenizer_file(tokenizer_file);
                self
            }
            pub fn with_config_file(mut self, config_file: &str) -> Self {
                self.$field = self.$field.with_config_file(config_file);
                self
            }
            pub fn with_special_tokens_map_file(mut self, special_tokens_map_file: &str) -> Self {
                self.$field = self
                    .$field
                    .with_special_tokens_map_file(special_tokens_map_file);
                self
            }
            pub fn with_tokenizer_config_file(mut self, tokenizer_config_file: &str) -> Self {
                self.$field = self
                    .$field
                    .with_tokenizer_config_file(tokenizer_config_file);
                self
            }
            pub fn with_generation_config_file(mut self, generation_config_file: &str) -> Self {
                self.$field = self
                    .$field
                    .with_generation_config_file(generation_config_file);
                self
            }
            pub fn with_vocab_file(mut self, vocab_file: &str) -> Self {
                self.$field = self.$field.with_vocab_file(vocab_file);
                self
            }
            pub fn with_vocab_txt(mut self, vocab_txt: &str) -> Self {
                self.$field = self.$field.with_vocab_txt(vocab_txt);
                self
            }
            pub fn with_temperature(mut self, temperature: f32) -> Self {
                self.$field = self.$field.with_temperature(temperature);
                self
            }
            pub fn with_topp(mut self, topp: f32) -> Self {
                self.$field = self.$field.with_topp(topp);
                self
            }
        }
    };
}