1use crate::error::Result;
4use scirs2_core::ndarray::ArrayStatCompat;
5use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
6use scirs2_core::numeric::{Float, FromPrimitive};
7use statrs::statistics::Statistics;
8use std::fmt::Debug;
/// A data transformation mapping one dynamic-dimensional array to another.
///
/// Implementors must be `Send + Sync` so transforms can be shared across
/// data-loading workers, and `Debug` for diagnostics.
pub trait Transform<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync>:
    Send + Sync + Debug
{
    /// Applies the transformation to `input`, returning a new array.
    fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>;
    /// Returns a short human-readable description of the transform.
    fn description(&self) -> String;
    /// Clones this transform into a boxed trait object — the object-safe
    /// stand-in for `Clone`, which cannot be a supertrait of a `dyn Trait`.
    fn box_clone(&self) -> Box<dyn Transform<F> + Send + Sync>;
}
/// Standardizes data to zero mean and unit variance.
///
/// Statistics are captured by `fit` and applied by `transform`.
#[derive(Debug, Clone)]
pub struct StandardScaler<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> {
    // Fitted means (`None` until `fit` is called). Shape `[1]` for 1-D fits,
    // otherwise one entry per sample or per feature.
    mean: Option<Array<F, IxDyn>>,
    // Fitted standard deviations, laid out like `mean`.
    std: Option<Array<F, IxDyn>>,
    // `true`: statistics per row (axis 1); `false`: per column (axis 0).
    fit_per_sample: bool,
}
30
31impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> StandardScaler<F> {
32 pub fn new(fit_per_sample: bool) -> Self {
34 Self {
35 mean: None,
36 std: None,
37 fit_per_sample,
38 }
39 }
40 pub fn fit(&mut self, data: &Array<F, IxDyn>) -> Result<&mut Self> {
42 let zero = F::from(0.0).unwrap_or(F::zero());
43 if data.ndim() < 2 {
44 let mean = data.mean_or(F::zero());
46 let std = data.std(zero);
47 self.mean = Some(Array::from_elem(IxDyn(&[1]), mean));
48 self.std = Some(Array::from_elem(IxDyn(&[1]), std));
49 } else if self.fit_per_sample {
50 let axis = 1; let mean = data
53 .mean_axis(scirs2_core::ndarray::Axis(axis))
54 .unwrap_or(Array::zeros(IxDyn(&[data.shape()[0]])));
55 let std = data.std_axis(scirs2_core::ndarray::Axis(axis), zero);
56 self.mean = Some(mean);
57 self.std = Some(std);
58 } else {
59 let axis = 0; let mean = data
62 .mean_axis(scirs2_core::ndarray::Axis(axis))
63 .unwrap_or(Array::zeros(IxDyn(&[data.shape()[1]])));
64 let std = data.std_axis(scirs2_core::ndarray::Axis(axis), zero);
65 self.mean = Some(mean);
66 self.std = Some(std);
67 }
68 Ok(self)
69 }
70
71 pub fn transform(&self, data: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
73 if self.mean.is_none() || self.std.is_none() {
74 return Err(crate::error::NeuralError::InferenceError(
75 "StandardScaler has not been fitted".to_string(),
76 ));
77 }
78
79 let mean = self.mean.as_ref().expect("Operation failed");
80 let std = self.std.as_ref().expect("Operation failed");
81 let mut result = data.clone();
82
83 if data.ndim() < 2 {
84 let mean_val = mean[[0]];
86 let std_val = std[[0]].max(F::epsilon());
87 for item in result.iter_mut() {
88 *item = (*item - mean_val) / std_val;
89 }
90 } else if self.fit_per_sample {
91 for i in 0..data.shape()[0] {
93 let mean_val = mean[[i]];
94 let std_val = std[[i]].max(F::epsilon());
95 for j in 0..data.shape()[1] {
96 result[[i, j]] = (data[[i, j]] - mean_val) / std_val;
97 }
98 }
99 } else {
100 for j in 0..data.shape()[1] {
102 let mean_val = mean[[j]];
103 let std_val = std[[j]].max(F::epsilon());
104 for i in 0..data.shape()[0] {
105 result[[i, j]] = (data[[i, j]] - mean_val) / std_val;
106 }
107 }
108 }
109
110 Ok(result)
111 }
112
113 pub fn fit_transform(&mut self, data: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
115 self.fit(data)?;
116 self.transform(data)
117 }
118}
119
120impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> Transform<F>
121 for StandardScaler<F>
122{
123 fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
124 self.transform(input)
125 }
126
127 fn description(&self) -> String {
128 if self.fit_per_sample {
129 "StandardScaler (per-sample)".to_string()
130 } else {
131 "StandardScaler (per-feature)".to_string()
132 }
133 }
134
135 fn box_clone(&self) -> Box<dyn Transform<F> + Send + Sync> {
136 Box::new(self.clone())
137 }
138}
139
/// Rescales data into a configurable target range (default `[0, 1]`).
///
/// Per-element minima/maxima are captured by `fit` and applied by `transform`.
#[derive(Debug, Clone)]
pub struct MinMaxScaler<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> {
    // Fitted minima (`None` until `fit` is called). Shape `[1]` for 1-D fits,
    // otherwise one entry per sample or per feature.
    min: Option<Array<F, IxDyn>>,
    // Fitted maxima, laid out like `min`.
    max: Option<Array<F, IxDyn>>,
    // Target output range `(low, high)` for scaled values.
    range: (F, F),
    // `true`: bounds per row (axis 1); `false`: per column (axis 0).
    fit_per_sample: bool,
}
152
153impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> MinMaxScaler<F> {
154 pub fn new(fit_per_sample: bool) -> Self {
156 Self::with_range(F::zero(), F::one(), fit_per_sample)
157 }
158
159 pub fn with_range(min_val: F, max_val: F, fit_per_sample: bool) -> Self {
161 Self {
162 min: None,
163 max: None,
164 range: (min_val, max_val),
165 fit_per_sample,
166 }
167 }
168
169 pub fn fit(&mut self, data: &Array<F, IxDyn>) -> Result<&mut Self> {
171 if data.ndim() < 2 {
172 let min = match data
174 .iter()
175 .min_by(|a, b| a.partial_cmp(b).expect("Operation failed"))
176 {
177 Some(&val) => val,
178 None => F::zero(),
179 };
180 let max = match data
181 .iter()
182 .max_by(|a, b| a.partial_cmp(b).expect("Operation failed"))
183 {
184 Some(&val) => val,
185 None => F::one(),
186 };
187 self.min = Some(Array::from_elem(IxDyn(&[1]), min));
188 self.max = Some(Array::from_elem(IxDyn(&[1]), max));
189 } else if self.fit_per_sample {
190 let mut min_vals = Array::zeros(IxDyn(&[data.shape()[0]]));
192 let mut max_vals = Array::zeros(IxDyn(&[data.shape()[0]]));
193 for i in 0..data.shape()[0] {
194 let row = data.slice(scirs2_core::ndarray::s![i, ..]);
195 let min = match row
196 .iter()
197 .min_by(|a, b| a.partial_cmp(b).expect("Operation failed"))
198 {
199 Some(&val) => val,
200 None => F::zero(),
201 };
202 let max = match row
203 .iter()
204 .max_by(|a, b| a.partial_cmp(b).expect("Operation failed"))
205 {
206 Some(&val) => val,
207 None => F::one(),
208 };
209 min_vals[[i]] = min;
210 max_vals[[i]] = max;
211 }
212 self.min = Some(min_vals);
213 self.max = Some(max_vals);
214 } else {
215 let mut min_vals = Array::zeros(IxDyn(&[data.shape()[1]]));
217 let mut max_vals = Array::zeros(IxDyn(&[data.shape()[1]]));
218 for j in 0..data.shape()[1] {
219 let col = data.slice(scirs2_core::ndarray::s![.., j]);
220 let min = match col
221 .iter()
222 .min_by(|a, b| a.partial_cmp(b).expect("Operation failed"))
223 {
224 Some(&val) => val,
225 None => F::zero(),
226 };
227 let max = match col
228 .iter()
229 .max_by(|a, b| a.partial_cmp(b).expect("Operation failed"))
230 {
231 Some(&val) => val,
232 None => F::one(),
233 };
234 min_vals[[j]] = min;
235 max_vals[[j]] = max;
236 }
237 self.min = Some(min_vals);
238 self.max = Some(max_vals);
239 }
240 Ok(self)
241 }
242
243 pub fn transform(&self, data: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
245 if self.min.is_none() || self.max.is_none() {
246 return Err(crate::error::NeuralError::InferenceError(
247 "MinMaxScaler has not been fitted".to_string(),
248 ));
249 }
250
251 let min = self.min.as_ref().expect("Operation failed");
252 let max = self.max.as_ref().expect("Operation failed");
253 let (range_min, range_max) = self.range;
254 let range_diff = range_max - range_min;
255 let mut result = data.clone();
256
257 if data.ndim() < 2 {
258 let min_val = min[[0]];
260 let max_val = max[[0]];
261 let scale = if max_val > min_val {
262 F::one() / (max_val - min_val)
263 } else {
264 F::one()
265 };
266 for item in result.iter_mut() {
267 *item = range_min + range_diff * ((*item - min_val) * scale);
268 }
269 } else if self.fit_per_sample {
270 for i in 0..data.shape()[0] {
272 let min_val = min[[i]];
273 let max_val = max[[i]];
274 let scale = if max_val > min_val {
275 F::one() / (max_val - min_val)
276 } else {
277 F::one()
278 };
279 for j in 0..data.shape()[1] {
280 result[[i, j]] = range_min + range_diff * ((data[[i, j]] - min_val) * scale);
281 }
282 }
283 } else {
284 for j in 0..data.shape()[1] {
286 let min_val = min[[j]];
287 let max_val = max[[j]];
288 let scale = if max_val > min_val {
289 F::one() / (max_val - min_val)
290 } else {
291 F::one()
292 };
293 for i in 0..data.shape()[0] {
294 result[[i, j]] = range_min + range_diff * ((data[[i, j]] - min_val) * scale);
295 }
296 }
297 }
298
299 Ok(result)
300 }
301
302 pub fn fit_transform(&mut self, data: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
304 self.fit(data)?;
305 self.transform(data)
306 }
307}
308
309impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> Transform<F>
310 for MinMaxScaler<F>
311{
312 fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
313 self.transform(input)
314 }
315
316 fn description(&self) -> String {
317 format!(
318 "MinMaxScaler (range: [{:.1}, {:.1}], {})",
319 self.range.0.to_f64().unwrap_or(0.0),
320 self.range.1.to_f64().unwrap_or(1.0),
321 if self.fit_per_sample {
322 "per-sample"
323 } else {
324 "per-feature"
325 }
326 )
327 }
328
329 fn box_clone(&self) -> Box<dyn Transform<F> + Send + Sync> {
330 Box::new(self.clone())
331 }
332}
333
/// Encodes integer class labels into one-hot vectors.
#[derive(Debug, Clone)]
pub struct OneHotEncoder<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> {
    // Width of each one-hot row; labels must convert to an index < n_classes.
    n_classes: usize,
    // Marker tying the encoder to its element type `F`; no data is stored.
    _phantom: std::marker::PhantomData<F>,
}
342
343impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> OneHotEncoder<F> {
344 pub fn new(n_classes: usize) -> Self {
346 Self {
347 n_classes,
348 _phantom: std::marker::PhantomData,
349 }
350 }
351
352 pub fn transform(&self, data: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
354 let shape = data.shape();
355 let n_samples = shape[0];
356
357 let mut result = Array::zeros(IxDyn(&[n_samples, self.n_classes]));
359
360 for i in 0..n_samples {
362 let class_idx = data[[i]].to_usize().unwrap_or(0);
363 if class_idx >= self.n_classes {
364 return Err(crate::error::NeuralError::InferenceError(format!(
365 "Class index {} is out of bounds for {} classes",
366 class_idx, self.n_classes
367 )));
368 }
369 result[[i, class_idx]] = F::one();
370 }
371
372 Ok(result)
373 }
374}
375
376impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> Transform<F>
377 for OneHotEncoder<F>
378{
379 fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
380 self.transform(input)
381 }
382
383 fn description(&self) -> String {
384 format!("OneHotEncoder (n_classes: {})", self.n_classes)
385 }
386
387 fn box_clone(&self) -> Box<dyn Transform<F> + Send + Sync> {
388 Box::new(self.clone())
389 }
390}
391
/// Applies a sequence of boxed transforms in order, feeding each output
/// into the next.
pub struct ComposeTransform<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> {
    // Pipeline stages, applied front-to-back by `apply`.
    transforms: Vec<Box<dyn Transform<F> + Send + Sync>>,
}
397
/// Borrowed adapter that renders a `dyn Transform` via its `description()`
/// rather than its derive-style `Debug` output, for compact pipeline listings.
struct DebugTransformWrapper<'a, F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> {
    // The wrapped transform being formatted.
    inner: &'a (dyn Transform<F> + Send + Sync),
}
403
404impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> Debug
405 for DebugTransformWrapper<'_, F>
406{
407 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408 write!(f, "Transform({})", self.inner.description())
409 }
410}
411
412impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> Debug for ComposeTransform<F> {
413 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
414 let mut debug_list = f.debug_list();
415 for transform in &self.transforms {
416 debug_list.entry(&DebugTransformWrapper {
417 inner: transform.as_ref(),
418 });
419 }
420 debug_list.finish()
421 }
422}
423
424impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> Clone for ComposeTransform<F> {
425 fn clone(&self) -> Self {
426 Self {
427 transforms: self
428 .transforms
429 .iter()
430 .map(|transform| transform.box_clone())
431 .collect(),
432 }
433 }
434}
435
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> ComposeTransform<F> {
    /// Builds a pipeline from `transforms`, applied in the given order.
    pub fn new(transforms: Vec<Box<dyn Transform<F> + Send + Sync>>) -> Self {
        Self { transforms }
    }
}
442
443impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> Transform<F>
444 for ComposeTransform<F>
445{
446 fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
447 let mut data = input.clone();
448 for transform in &self.transforms {
449 data = transform.apply(&data)?;
450 }
451 Ok(data)
452 }
453
454 fn description(&self) -> String {
455 let descriptions: Vec<String> = self.transforms.iter().map(|t| t.description()).collect();
456 format!("Compose({})", descriptions.join(", "))
457 }
458
459 fn box_clone(&self) -> Box<dyn Transform<F> + Send + Sync> {
460 Box::new(self.clone())
461 }
462}