1use std::sync::atomic::{AtomicU64, Ordering};
2
3use super::error::ModelError;
4use yscv_tensor::Tensor;
5
6pub trait Transform: Send + Sync {
8 fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError>;
9}
10
11pub struct Compose {
13 transforms: Vec<Box<dyn Transform>>,
14}
15
16impl Default for Compose {
17 fn default() -> Self {
18 Self::new()
19 }
20}
21
22impl Compose {
23 pub fn new() -> Self {
24 Self {
25 transforms: Vec::new(),
26 }
27 }
28
29 pub fn add<T: Transform + 'static>(mut self, t: T) -> Self {
30 self.transforms.push(Box::new(t));
31 self
32 }
33}
34
35impl Transform for Compose {
36 fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError> {
37 let mut current = input.clone();
38 for t in &self.transforms {
39 current = t.apply(¤t)?;
40 }
41 Ok(current)
42 }
43}
44
/// Per-channel standardisation `(x - mean[ch]) / std[ch]`.
pub struct Normalize {
    /// Per-channel offsets subtracted from each element.
    pub mean: Vec<f32>,
    /// Per-channel divisors applied after the offset is removed.
    pub std: Vec<f32>,
}

impl Normalize {
    /// Stores the given per-channel means and standard deviations as-is.
    pub fn new(mean: Vec<f32>, std: Vec<f32>) -> Self {
        Normalize { mean, std }
    }
}
56
57impl Transform for Normalize {
58 fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError> {
59 let data = input.data();
61 let c = self.mean.len();
62 let mut out = data.to_vec();
63 for (i, val) in out.iter_mut().enumerate() {
64 let ch = i % c;
65 *val = (*val - self.mean[ch]) / self.std[ch];
66 }
67 Ok(Tensor::from_vec(input.shape().to_vec(), out)?)
68 }
69}
70
/// Multiplies every element of a tensor by a constant.
pub struct ScaleValues {
    /// Multiplier applied to each element.
    pub factor: f32,
}

impl ScaleValues {
    /// Creates a scaler with the given multiplier.
    pub fn new(factor: f32) -> Self {
        ScaleValues { factor }
    }
}
81
82impl Transform for ScaleValues {
83 fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError> {
84 Ok(input.scale(self.factor))
85 }
86}
87
/// Reorders a tensor's axes.
pub struct PermuteDims {
    /// Axis order passed unchanged to `Tensor::permute`.
    pub order: Vec<usize>,
}

impl PermuteDims {
    /// Creates a permutation transform with the given axis order.
    pub fn new(order: Vec<usize>) -> Self {
        PermuteDims { order }
    }
}
98
99impl Transform for PermuteDims {
100 fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError> {
101 Ok(input.permute(&self.order)?)
102 }
103}
104
/// Bilinear resize of a `[H, W, C]` tensor to a fixed spatial size.
pub struct Resize {
    /// Output height (first axis).
    pub height: usize,
    /// Output width (second axis).
    pub width: usize,
}

impl Resize {
    /// Creates a resize transform targeting `height × width`.
    pub fn new(height: usize, width: usize) -> Self {
        Resize { height, width }
    }
}
117
118impl Transform for Resize {
119 fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError> {
120 let shape = input.shape();
121 if shape.len() != 3 {
122 return Err(ModelError::InvalidInputShape {
123 expected_features: 3,
124 got: shape.to_vec(),
125 });
126 }
127 let (in_h, in_w, c) = (shape[0], shape[1], shape[2]);
128 let data = input.data();
129 let out_h = self.height;
130 let out_w = self.width;
131 let mut out = vec![0.0f32; out_h * out_w * c];
132
133 for oh in 0..out_h {
134 for ow in 0..out_w {
135 let sy = if out_h > 1 {
137 oh as f32 * (in_h as f32 - 1.0) / (out_h as f32 - 1.0)
138 } else {
139 (in_h as f32 - 1.0) / 2.0
140 };
141 let sx = if out_w > 1 {
142 ow as f32 * (in_w as f32 - 1.0) / (out_w as f32 - 1.0)
143 } else {
144 (in_w as f32 - 1.0) / 2.0
145 };
146
147 let y0 = sy.floor() as usize;
148 let x0 = sx.floor() as usize;
149 let y1 = (y0 + 1).min(in_h - 1);
150 let x1 = (x0 + 1).min(in_w - 1);
151 let fy = sy - sy.floor();
152 let fx = sx - sx.floor();
153
154 for ch in 0..c {
155 let v00 = data[(y0 * in_w + x0) * c + ch];
156 let v01 = data[(y0 * in_w + x1) * c + ch];
157 let v10 = data[(y1 * in_w + x0) * c + ch];
158 let v11 = data[(y1 * in_w + x1) * c + ch];
159 let val = v00 * (1.0 - fy) * (1.0 - fx)
160 + v01 * (1.0 - fy) * fx
161 + v10 * fy * (1.0 - fx)
162 + v11 * fy * fx;
163 out[(oh * out_w + ow) * c + ch] = val;
164 }
165 }
166 }
167 Ok(Tensor::from_vec(vec![out_h, out_w, c], out)?)
168 }
169}
170
/// Square crop taken from the centre of a `[H, W, C]` tensor.
pub struct CenterCrop {
    /// Side length of the square crop window.
    pub size: usize,
}

impl CenterCrop {
    /// Creates a centre crop with a `size × size` window.
    pub fn new(size: usize) -> Self {
        CenterCrop { size }
    }
}
182
183impl Transform for CenterCrop {
184 fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError> {
185 let shape = input.shape();
186 if shape.len() != 3 {
187 return Err(ModelError::InvalidInputShape {
188 expected_features: 3,
189 got: shape.to_vec(),
190 });
191 }
192 let (h, w) = (shape[0], shape[1]);
193 let start_h = (h.saturating_sub(self.size)) / 2;
194 let start_w = (w.saturating_sub(self.size)) / 2;
195 let cropped = input.narrow(0, start_h, self.size)?;
196 let cropped = cropped.narrow(1, start_w, self.size)?;
197 Ok(cropped)
198 }
199}
200
/// Randomly mirrors a `[H, W, C]` tensor along its width axis.
///
/// Draws come from an internal xorshift64 generator held in an `AtomicU64`,
/// so the transform can be used through `&self`.
pub struct RandomHorizontalFlip {
    /// Probability that a call flips its input.
    p: f32,
    /// Current xorshift64 state; never zero (zero is a fixed point of xorshift).
    seed: AtomicU64,
}

impl RandomHorizontalFlip {
    /// State used in place of a caller-supplied `0` seed.
    ///
    /// Xorshift maps 0 to 0 forever, which would make every draw 0.0 and the
    /// flip fire on every call whenever `p > 0`.
    const ZERO_SEED_REPLACEMENT: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Creates a flip transform that fires with probability `p`.
    ///
    /// A `seed` of 0 is replaced by a fixed non-zero constant; every other
    /// seed is used as-is, so existing non-zero seeds reproduce the same stream.
    pub fn new(p: f32, seed: u64) -> Self {
        let seed = if seed == 0 {
            Self::ZERO_SEED_REPLACEMENT
        } else {
            seed
        };
        Self {
            p,
            seed: AtomicU64::new(seed),
        }
    }

    /// One xorshift64 step with shift triple (13, 7, 17).
    fn xorshift64(mut s: u64) -> u64 {
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        s
    }

    /// Advances the generator and returns a value in `[0, 1]`.
    fn next_rand(&self) -> f32 {
        // A single atomic read-modify-write, so two concurrent callers can
        // never both read the same state and get the same draw (as a separate
        // load + store could). Relaxed suffices: no other memory is published
        // through this value. The closure always returns `Some`, so the `Err`
        // arm is unreachable in practice.
        let prev = self
            .seed
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |s| {
                Some(Self::xorshift64(s))
            })
            .unwrap_or_else(|s| s);
        let s = Self::xorshift64(prev);
        (s as u32 as f32) / (u32::MAX as f32)
    }
}
227
228impl Transform for RandomHorizontalFlip {
229 fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError> {
230 let shape = input.shape();
231 if shape.len() != 3 {
232 return Err(ModelError::InvalidInputShape {
233 expected_features: 3,
234 got: shape.to_vec(),
235 });
236 }
237 if self.next_rand() >= self.p {
238 return Ok(input.clone());
239 }
240 let (h, w, c) = (shape[0], shape[1], shape[2]);
241 let data = input.data();
242 let mut out = vec![0.0f32; h * w * c];
243 for row in 0..h {
244 for col in 0..w {
245 let src = (row * w + (w - 1 - col)) * c;
246 let dst = (row * w + col) * c;
247 out[dst..dst + c].copy_from_slice(&data[src..src + c]);
248 }
249 }
250 Ok(Tensor::from_vec(shape.to_vec(), out)?)
251 }
252}
253
/// Gaussian blur over the spatial axes of a `[H, W, C]` tensor.
pub struct GaussianBlur {
    /// Side length of the square kernel, in taps.
    pub kernel_size: usize,
    /// Standard deviation of the Gaussian, in taps.
    pub sigma: f32,
}

impl GaussianBlur {
    /// Creates a blur with a `kernel_size × kernel_size` kernel and the given sigma.
    pub fn new(kernel_size: usize, sigma: f32) -> Self {
        Self { kernel_size, sigma }
    }

    /// Builds the normalised kernel in row-major order (`kernel[ky * ks + kx]`).
    ///
    /// Tap offsets are measured from the geometric centre of the kernel, so an
    /// even-sized kernel has no tap at offset 0. Each weight is
    /// `exp(-(d² - d_min²) / (2σ²))`, where `d_min` is the distance of the
    /// nearest tap(s). After normalisation this is the same Gaussian, but the
    /// nearest taps always get weight 1, so the sum can never underflow to 0.
    ///
    /// Consequently `sigma == 0`, or a sigma so small that `exp` underflows,
    /// yields the limiting kernel instead of NaNs:
    /// - odd sizes: all weight on the centre tap;
    /// - even sizes: weight split evenly over the central 2×2 taps.
    ///
    /// For odd sizes with a usable sigma the result is bit-identical to the
    /// plain `exp(-d² / (2σ²))` formulation.
    fn build_kernel(&self) -> Vec<f32> {
        let ks = self.kernel_size;
        let half = ks as f32 / 2.0;
        let denom = 2.0 * self.sigma * self.sigma;
        // Smallest |offset| along one axis: 0 for odd sizes, 0.5 for even ones.
        let nearest = if ks % 2 == 1 { 0.0f32 } else { 0.5 };
        let min_d2 = 2.0 * nearest * nearest;
        let mut kernel = vec![0.0f32; ks * ks];
        let mut sum = 0.0f32;
        for ky in 0..ks {
            for kx in 0..ks {
                let dy = ky as f32 - half + 0.5;
                let dx = kx as f32 - half + 0.5;
                let excess = (dy * dy + dx * dx) - min_d2;
                // Nearest taps get exactly 1; this also sidesteps 0/0 when sigma == 0.
                let val = if excess <= 0.0 {
                    1.0
                } else {
                    (-excess / denom).exp()
                };
                kernel[ky * ks + kx] = val;
                sum += val;
            }
        }
        for v in kernel.iter_mut() {
            *v /= sum;
        }
        kernel
    }
}
286
287impl Transform for GaussianBlur {
288 fn apply(&self, input: &Tensor) -> Result<Tensor, ModelError> {
289 let shape = input.shape();
290 if shape.len() != 3 {
291 return Err(ModelError::InvalidInputShape {
292 expected_features: 3,
293 got: shape.to_vec(),
294 });
295 }
296 let (h, w, c) = (shape[0], shape[1], shape[2]);
297 let ks = self.kernel_size;
298 let pad = ks / 2;
299 let kernel = self.build_kernel();
300 let data = input.data();
301 let mut out = vec![0.0f32; h * w * c];
302
303 for row in 0..h {
304 for col in 0..w {
305 for ch in 0..c {
306 let mut acc = 0.0f32;
307 for ky in 0..ks {
308 for kx in 0..ks {
309 let sy = row as isize + ky as isize - pad as isize;
310 let sx = col as isize + kx as isize - pad as isize;
311 let sy = sy.max(0).min(h as isize - 1) as usize;
313 let sx = sx.max(0).min(w as isize - 1) as usize;
314 acc += data[(sy * w + sx) * c + ch] * kernel[ky * ks + kx];
315 }
316 }
317 out[(row * w + col) * c + ch] = acc;
318 }
319 }
320 }
321 Ok(Tensor::from_vec(shape.to_vec(), out)?)
322 }
323}