use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use sklears_core::error::SklearsError;
use std::collections::HashMap;

/// Common interface shared by all baseline prediction strategies.
pub trait BaselineStrategy: Send + Sync + std::fmt::Debug {
    /// Strategy-specific configuration type.
    type Config: Clone + std::fmt::Debug;
    /// State produced by fitting the strategy.
    type FittedData: Clone + std::fmt::Debug;
    /// Type of a single prediction.
    type Prediction: Clone + std::fmt::Debug;

    /// Short, stable name of the strategy.
    fn name(&self) -> &'static str;

    /// Fit the strategy to training data.
    fn fit(
        &self,
        config: &Self::Config,
        x: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> Result<Self::FittedData, SklearsError>;

    /// Produce one prediction per row of `x`.
    fn predict(
        &self,
        fitted_data: &Self::FittedData,
        x: &ArrayView2<f64>,
    ) -> Result<Vec<Self::Prediction>, SklearsError>;

    /// Check that a configuration is valid before fitting.
    fn validate_config(&self, config: &Self::Config) -> Result<(), SklearsError>;
}

/// Baseline strategies that predict integer class labels.
pub trait ClassificationStrategy: BaselineStrategy<Prediction = i32> {
    /// Per-row class probabilities. The default implementation assigns
    /// probability 1.0 to the predicted class.
    fn predict_proba(
        &self,
        fitted_data: &Self::FittedData,
        x: &ArrayView2<f64>,
    ) -> Result<Vec<HashMap<i32, f64>>, SklearsError> {
        let predictions = self.predict(fitted_data, x)?;
        let point_mass_proba = predictions
            .iter()
            .map(|&pred| HashMap::from([(pred, 1.0)]))
            .collect();
        Ok(point_mass_proba)
    }

    /// Per-row decision scores. The default implementation returns zeros.
    fn decision_function(
        &self,
        _fitted_data: &Self::FittedData,
        x: &ArrayView2<f64>,
    ) -> Result<Vec<f64>, SklearsError> {
        Ok(vec![0.0; x.nrows()])
    }
}

/// Baseline strategies that predict continuous values.
pub trait RegressionStrategy: BaselineStrategy<Prediction = f64> {
    /// Per-row prediction intervals. The default implementation ignores
    /// `confidence` and returns the degenerate interval `(prediction, prediction)`.
    fn predict_interval(
        &self,
        fitted_data: &Self::FittedData,
        x: &ArrayView2<f64>,
        _confidence: f64,
    ) -> Result<Vec<(f64, f64)>, SklearsError> {
        let predictions = self.predict(fitted_data, x)?;
        let intervals = predictions.iter().map(|&pred| (pred, pred)).collect();
        Ok(intervals)
    }
}

/// Configuration for [`MostFrequentStrategy`].
#[derive(Debug, Clone)]
pub struct MostFrequentConfig {
    pub random_state: Option<u64>,
}

/// State learned by [`MostFrequentStrategy`] during fitting.
#[derive(Debug, Clone)]
pub struct MostFrequentFittedData {
    pub most_frequent_class: i32,
    pub class_counts: HashMap<i32, usize>,
    pub class_priors: HashMap<i32, f64>,
}

/// Classification baseline that always predicts the most frequent training class.
#[derive(Debug, Clone)]
pub struct MostFrequentStrategy;

impl BaselineStrategy for MostFrequentStrategy {
    type Config = MostFrequentConfig;
    type FittedData = MostFrequentFittedData;
    type Prediction = i32;

    fn name(&self) -> &'static str {
        "most_frequent"
    }

    fn fit(
        &self,
        config: &Self::Config,
        _x: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> Result<Self::FittedData, SklearsError> {
        self.validate_config(config)?;

        let mut class_counts = HashMap::new();
        let n_samples = y.len();

        for &value in y.iter() {
            let class = value as i32;
            *class_counts.entry(class).or_insert(0) += 1;
        }

        if class_counts.is_empty() {
            return Err(SklearsError::InvalidInput("No classes found".to_string()));
        }

        // Pick the class with the highest count; ties are broken towards the
        // smaller label so the result is independent of hash-map iteration order.
        let most_frequent_class = class_counts
            .iter()
            .max_by_key(|(&class, &count)| (count, std::cmp::Reverse(class)))
            .map(|(&class, _)| class)
            .unwrap();

        let class_priors = class_counts
            .iter()
            .map(|(&class, &count)| (class, count as f64 / n_samples as f64))
            .collect();

        Ok(MostFrequentFittedData {
            most_frequent_class,
            class_counts,
            class_priors,
        })
    }

    fn predict(
        &self,
        fitted_data: &Self::FittedData,
        x: &ArrayView2<f64>,
    ) -> Result<Vec<Self::Prediction>, SklearsError> {
        Ok(vec![fitted_data.most_frequent_class; x.nrows()])
    }

    fn validate_config(&self, _config: &Self::Config) -> Result<(), SklearsError> {
        Ok(())
    }
}

impl ClassificationStrategy for MostFrequentStrategy {
    fn predict_proba(
        &self,
        fitted_data: &Self::FittedData,
        x: &ArrayView2<f64>,
    ) -> Result<Vec<HashMap<i32, f64>>, SklearsError> {
        // Report the empirical class priors observed during fitting for every row.
        let probabilities = vec![fitted_data.class_priors.clone(); x.nrows()];
        Ok(probabilities)
    }
}

/// Configuration for [`MeanStrategy`].
#[derive(Debug, Clone)]
pub struct MeanConfig {
    pub random_state: Option<u64>,
}

/// State learned by [`MeanStrategy`] during fitting.
#[derive(Debug, Clone)]
pub struct MeanFittedData {
    pub target_mean: f64,
    pub target_std: f64,
    pub n_samples: usize,
}

/// Regression baseline that always predicts the mean of the training targets.
#[derive(Debug, Clone)]
pub struct MeanStrategy;

impl BaselineStrategy for MeanStrategy {
    type Config = MeanConfig;
    type FittedData = MeanFittedData;
    type Prediction = f64;

    fn name(&self) -> &'static str {
        "mean"
    }

    fn fit(
        &self,
        config: &Self::Config,
        _x: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> Result<Self::FittedData, SklearsError> {
        self.validate_config(config)?;

        if y.is_empty() {
            return Err(SklearsError::InvalidInput("Empty target array".to_string()));
        }

        let n_samples = y.len();
        let target_mean = y.iter().sum::<f64>() / n_samples as f64;

        // Sample (ddof = 1) standard deviation of the targets; zero for a single sample.
        let target_std = if n_samples > 1 {
            let variance = y
                .iter()
                .map(|&value| (value - target_mean).powi(2))
                .sum::<f64>()
                / (n_samples - 1) as f64;
            variance.sqrt()
        } else {
            0.0
        };

        Ok(MeanFittedData {
            target_mean,
            target_std,
            n_samples,
        })
    }

    fn predict(
        &self,
        fitted_data: &Self::FittedData,
        x: &ArrayView2<f64>,
    ) -> Result<Vec<Self::Prediction>, SklearsError> {
        Ok(vec![fitted_data.target_mean; x.nrows()])
    }

    fn validate_config(&self, _config: &Self::Config) -> Result<(), SklearsError> {
        Ok(())
    }
}

impl RegressionStrategy for MeanStrategy {
    fn predict_interval(
        &self,
        fitted_data: &Self::FittedData,
        x: &ArrayView2<f64>,
        confidence: f64,
    ) -> Result<Vec<(f64, f64)>, SklearsError> {
        if !(0.0..=1.0).contains(&confidence) {
            return Err(SklearsError::InvalidInput(
                "Confidence must be between 0 and 1".to_string(),
            ));
        }

        // Coarse normal-approximation z-scores for common confidence levels.
        let z_score = if confidence >= 0.99 {
            2.576
        } else if confidence >= 0.95 {
            1.96
        } else {
            1.0
        };

        let margin = z_score * fitted_data.target_std;
        let lower = fitted_data.target_mean - margin;
        let upper = fitted_data.target_mean + margin;

        Ok(vec![(lower, upper); x.nrows()])
    }
}

/// Registry of the built-in baseline strategies, keyed by strategy name.
pub struct StrategyRegistry {
    classification_strategies: Vec<String>,
    regression_strategies: Vec<String>,
}

impl Default for StrategyRegistry {
    fn default() -> Self {
        Self::new()
    }
}

impl StrategyRegistry {
    pub fn new() -> Self {
        Self {
            classification_strategies: vec!["most_frequent".to_string()],
            regression_strategies: vec!["mean".to_string()],
        }
    }

    pub fn list_classification_strategies(&self) -> Vec<String> {
        self.classification_strategies.clone()
    }

    pub fn list_regression_strategies(&self) -> Vec<String> {
        self.regression_strategies.clone()
    }
}

/// A baseline strategy combined with optional pre- and postprocessing steps.
pub struct PredictionPipeline<S: BaselineStrategy + Clone> {
    strategy: S,
    preprocessors: Vec<Box<dyn Preprocessor>>,
    postprocessors: Vec<Box<dyn Postprocessor<S::Prediction>>>,
}

impl<S: BaselineStrategy + Clone> PredictionPipeline<S> {
    pub fn new(strategy: S) -> Self {
        Self {
            strategy,
            preprocessors: Vec::new(),
            postprocessors: Vec::new(),
        }
    }

    pub fn with_preprocessor(mut self, preprocessor: Box<dyn Preprocessor>) -> Self {
        self.preprocessors.push(preprocessor);
        self
    }

    pub fn with_postprocessor(
        mut self,
        postprocessor: Box<dyn Postprocessor<S::Prediction>>,
    ) -> Self {
        self.postprocessors.push(postprocessor);
        self
    }

    /// Fit the underlying strategy, consuming the pipeline so that its configured
    /// pre- and postprocessors are carried into the returned [`FittedPipeline`].
    pub fn fit(
        self,
        config: &S::Config,
        x: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> Result<FittedPipeline<S>, SklearsError> {
        let mut processed_x = x.to_owned();
        let mut processed_y = y.to_owned();

        for preprocessor in &self.preprocessors {
            let (new_x, new_y) =
                preprocessor.transform(&processed_x.view(), &processed_y.view())?;
            processed_x = new_x;
            processed_y = new_y;
        }

        let fitted_data = self
            .strategy
            .fit(config, &processed_x.view(), &processed_y.view())?;

        Ok(FittedPipeline {
            strategy: self.strategy,
            fitted_data,
            preprocessors: self.preprocessors,
            postprocessors: self.postprocessors,
        })
    }
}

/// A fitted baseline strategy together with the processors applied around it.
pub struct FittedPipeline<S: BaselineStrategy + Clone> {
    strategy: S,
    fitted_data: S::FittedData,
    preprocessors: Vec<Box<dyn Preprocessor>>,
    postprocessors: Vec<Box<dyn Postprocessor<S::Prediction>>>,
}

impl<S: BaselineStrategy + Clone> FittedPipeline<S> {
    pub fn predict(&self, x: &ArrayView2<f64>) -> Result<Vec<S::Prediction>, SklearsError> {
        let mut processed_x = x.to_owned();
        // Preprocessors receive a single-element placeholder target, since no real
        // targets are available at prediction time.
        for preprocessor in &self.preprocessors {
            let (new_x, _) =
                preprocessor.transform(&processed_x.view(), &ArrayView1::from(&[0.0][..]))?;
            processed_x = new_x;
        }

        let mut predictions = self
            .strategy
            .predict(&self.fitted_data, &processed_x.view())?;

        for postprocessor in &self.postprocessors {
            predictions = postprocessor.transform(&predictions)?;
        }

        Ok(predictions)
    }
}

/// Transformation applied to features and targets before fitting or prediction.
pub trait Preprocessor: Send + Sync + std::fmt::Debug {
    fn transform(
        &self,
        x: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> Result<(Array2<f64>, Array1<f64>), SklearsError>;
}

/// Transformation applied to predictions after the strategy has produced them.
pub trait Postprocessor<T>: Send + Sync + std::fmt::Debug {
    fn transform(&self, predictions: &[T]) -> Result<Vec<T>, SklearsError>;
}

/// Feature standardizer that removes the per-column mean and scales by the
/// per-column sample standard deviation. It must be fitted before transforming.
#[derive(Debug, Clone)]
pub struct StandardScaler {
    mean: Vec<f64>,
    std: Vec<f64>,
    fitted: bool,
}

impl Default for StandardScaler {
    fn default() -> Self {
        Self::new()
    }
}

impl StandardScaler {
    pub fn new() -> Self {
        Self {
            mean: Vec::new(),
            std: Vec::new(),
            fitted: false,
        }
    }

    pub fn fit(&mut self, x: &ArrayView2<f64>) -> Result<(), SklearsError> {
        let n_features = x.ncols();
        let n_samples = x.nrows();

        if n_samples == 0 {
            return Err(SklearsError::InvalidInput("Empty input array".to_string()));
        }

        self.mean = vec![0.0; n_features];
        self.std = vec![1.0; n_features];

        for j in 0..n_features {
            self.mean[j] = x.column(j).iter().sum::<f64>() / n_samples as f64;
        }

        if n_samples > 1 {
            for j in 0..n_features {
                let variance = x
                    .column(j)
                    .iter()
                    .map(|&value| (value - self.mean[j]).powi(2))
                    .sum::<f64>()
                    / (n_samples - 1) as f64;
                // Floor the standard deviation so constant features do not divide by zero.
                self.std[j] = variance.sqrt().max(1e-8);
            }
        }

        self.fitted = true;
        Ok(())
    }

    pub fn transform(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>, SklearsError> {
        if !self.fitted {
            return Err(SklearsError::InvalidInput("Scaler not fitted".to_string()));
        }

        let mut result = Array2::zeros(x.raw_dim());
        for (i, row) in x.outer_iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                result[[i, j]] = (value - self.mean[j]) / self.std[j];
            }
        }
        Ok(result)
    }
}

impl Preprocessor for StandardScaler {
    fn transform(
        &self,
        x: &ArrayView2<f64>,
        y: &ArrayView1<f64>,
    ) -> Result<(Array2<f64>, Array1<f64>), SklearsError> {
        // Standardize the features; targets are passed through unchanged.
        let transformed_x = self.transform(x)?;
        Ok((transformed_x, y.to_owned()))
    }
}

/// Postprocessor that clips predictions to the closed interval `[min_value, max_value]`.
#[derive(Debug, Clone)]
pub struct ClippingPostprocessor {
    min_value: f64,
    max_value: f64,
}

impl ClippingPostprocessor {
    pub fn new(min_value: f64, max_value: f64) -> Self {
        Self {
            min_value,
            max_value,
        }
    }
}

impl Postprocessor<f64> for ClippingPostprocessor {
    fn transform(&self, predictions: &[f64]) -> Result<Vec<f64>, SklearsError> {
        let clipped = predictions
            .iter()
            .map(|&pred| pred.max(self.min_value).min(self.max_value))
            .collect();
        Ok(clipped)
    }
}

/// Robust statistical estimators that can back additional baseline strategies.
pub mod statistical_methods {
    use super::*;

    /// A statistic computed from a slice of observations.
    pub trait StatisticalEstimator: Send + Sync + std::fmt::Debug {
        type Input: ?Sized;
        type Output;

        fn estimate(&self, data: &Self::Input) -> Result<Self::Output, SklearsError>;
    }

    /// Mean of the data after trimming a fraction of the smallest and largest values.
    #[derive(Debug, Clone)]
    pub struct TrimmedMeanEstimator {
        trim_percentage: f64,
    }

    impl TrimmedMeanEstimator {
        pub fn new(trim_percentage: f64) -> Result<Self, SklearsError> {
            if !(0.0..=0.5).contains(&trim_percentage) {
                return Err(SklearsError::InvalidInput(
                    "Trim percentage must be between 0 and 0.5".to_string(),
                ));
            }
            Ok(Self { trim_percentage })
        }
    }

    impl StatisticalEstimator for TrimmedMeanEstimator {
        type Input = [f64];
        type Output = f64;

        fn estimate(&self, data: &Self::Input) -> Result<Self::Output, SklearsError> {
            if data.is_empty() {
                return Err(SklearsError::InvalidInput("Empty data array".to_string()));
            }

            let mut sorted_data = data.to_vec();
            sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

            let n = sorted_data.len();
            let trim_count = (n as f64 * self.trim_percentage).floor() as usize;

            // If trimming would remove everything, fall back to the middle element.
            if trim_count * 2 >= n {
                return Ok(sorted_data[n / 2]);
            }

            let trimmed_data = &sorted_data[trim_count..n - trim_count];
            let mean = trimmed_data.iter().sum::<f64>() / trimmed_data.len() as f64;

            Ok(mean)
        }
    }

    /// Median absolute deviation (MAD) from the median, a robust scale estimate.
    #[derive(Debug, Clone)]
    pub struct MedianAbsoluteDeviationEstimator;

    impl StatisticalEstimator for MedianAbsoluteDeviationEstimator {
        type Input = [f64];
        type Output = f64;

        fn estimate(&self, data: &Self::Input) -> Result<Self::Output, SklearsError> {
            if data.is_empty() {
                return Err(SklearsError::InvalidInput("Empty data array".to_string()));
            }

            let mut sorted_data = data.to_vec();
            sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

            let n = sorted_data.len();
            let median = if n % 2 == 0 {
                (sorted_data[n / 2 - 1] + sorted_data[n / 2]) / 2.0
            } else {
                sorted_data[n / 2]
            };

            let deviations: Vec<f64> = sorted_data.iter().map(|&x| (x - median).abs()).collect();

            let mut sorted_deviations = deviations;
            sorted_deviations.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

            let mad = if n % 2 == 0 {
                (sorted_deviations[n / 2 - 1] + sorted_deviations[n / 2]) / 2.0
            } else {
                sorted_deviations[n / 2]
            };

            Ok(mad)
        }
    }

    /// Empirical quantile with linear interpolation between order statistics.
    #[derive(Debug, Clone)]
    pub struct QuantileEstimator {
        quantile: f64,
    }

    impl QuantileEstimator {
        pub fn new(quantile: f64) -> Result<Self, SklearsError> {
            if !(0.0..=1.0).contains(&quantile) {
                return Err(SklearsError::InvalidInput(
                    "Quantile must be between 0 and 1".to_string(),
                ));
            }
            Ok(Self { quantile })
        }
    }

    impl StatisticalEstimator for QuantileEstimator {
        type Input = [f64];
        type Output = f64;

        fn estimate(&self, data: &Self::Input) -> Result<Self::Output, SklearsError> {
            if data.is_empty() {
                return Err(SklearsError::InvalidInput("Empty data array".to_string()));
            }

            let mut sorted_data = data.to_vec();
            sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

            let n = sorted_data.len();
            let index = (self.quantile * (n - 1) as f64).floor() as usize;
            let fraction = self.quantile * (n - 1) as f64 - index as f64;

            let quantile_value = if index >= n - 1 {
                sorted_data[n - 1]
            } else {
                sorted_data[index] + fraction * (sorted_data[index + 1] - sorted_data[index])
            };

            Ok(quantile_value)
        }
    }
}

/// Convenience constructors for the built-in strategies and pipelines.
pub struct BaselineStrategyFactory;

impl BaselineStrategyFactory {
    pub fn most_frequent() -> MostFrequentStrategy {
        MostFrequentStrategy
    }

    pub fn mean() -> MeanStrategy {
        MeanStrategy
    }

    pub fn standard_pipeline<S: BaselineStrategy + Clone>(strategy: S) -> PredictionPipeline<S> {
        PredictionPipeline::new(strategy).with_preprocessor(Box::new(StandardScaler::new()))
    }

    pub fn robust_regression_pipeline() -> PredictionPipeline<MeanStrategy> {
        PredictionPipeline::new(MeanStrategy)
            .with_preprocessor(Box::new(StandardScaler::new()))
            .with_postprocessor(Box::new(ClippingPostprocessor::new(-1e6, 1e6)))
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::statistical_methods::StatisticalEstimator;
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_most_frequent_strategy() {
        let strategy = MostFrequentStrategy;
        let config = MostFrequentConfig {
            random_state: Some(42),
        };

        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0.0, 1.0, 1.0, 0.0];

        let fitted = strategy.fit(&config, &x.view(), &y.view()).unwrap();
        assert!(fitted.class_counts.contains_key(&0));
        assert!(fitted.class_counts.contains_key(&1));

        let predictions = strategy.predict(&fitted, &x.view()).unwrap();
        assert_eq!(predictions.len(), 4);
        assert!(predictions.iter().all(|&p| p == 0 || p == 1));
    }

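    // Illustrative example: predict_proba on the most-frequent baseline reports the
    // empirical class priors (here 3/4 and 1/4) for every row of `x`.
    #[test]
    fn test_most_frequent_predict_proba() {
        let strategy = MostFrequentStrategy;
        let config = MostFrequentConfig { random_state: None };

        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0.0, 0.0, 0.0, 1.0];

        let fitted = strategy.fit(&config, &x.view(), &y.view()).unwrap();
        let proba = strategy.predict_proba(&fitted, &x.view()).unwrap();

        assert_eq!(proba.len(), 4);
        assert!((proba[0][&0] - 0.75).abs() < 1e-12);
        assert!((proba[0][&1] - 0.25).abs() < 1e-12);
    }
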
    #[test]
    fn test_mean_strategy() {
        let strategy = MeanStrategy;
        let config = MeanConfig {
            random_state: Some(42),
        };

        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let fitted = strategy.fit(&config, &x.view(), &y.view()).unwrap();
        assert_eq!(fitted.target_mean, 2.5);

        let predictions = strategy.predict(&fitted, &x.view()).unwrap();
        assert_eq!(predictions.len(), 4);
        assert!(predictions.iter().all(|&p| p == 2.5));
    }

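    // Illustrative example: the mean strategy's 95% interval is the fitted mean plus
    // or minus 1.96 sample standard deviations, and invalid confidences are rejected.
    #[test]
    fn test_mean_predict_interval() {
        let strategy = MeanStrategy;
        let config = MeanConfig { random_state: None };

        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let fitted = strategy.fit(&config, &x.view(), &y.view()).unwrap();
        let intervals = strategy.predict_interval(&fitted, &x.view(), 0.95).unwrap();

        assert_eq!(intervals.len(), 4);
        let (lower, upper) = intervals[0];
        assert!((upper - lower - 2.0 * 1.96 * fitted.target_std).abs() < 1e-10);
        assert!(strategy.predict_interval(&fitted, &x.view(), 1.5).is_err());
    }
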
    #[test]
    fn test_prediction_pipeline() {
        let strategy = MeanStrategy;
        let config = MeanConfig {
            random_state: Some(42),
        };

        let pipeline = PredictionPipeline::new(strategy)
            .with_postprocessor(Box::new(ClippingPostprocessor::new(0.0, 10.0)));

        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let fitted_pipeline = pipeline.fit(&config, &x.view(), &y.view()).unwrap();
        let predictions = fitted_pipeline.predict(&x.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        assert!(predictions.iter().all(|&p| p >= 0.0 && p <= 10.0));
    }

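    // Illustrative example: the registry lists the two built-in strategy names.
    #[test]
    fn test_strategy_registry() {
        let registry = StrategyRegistry::new();
        assert!(registry
            .list_classification_strategies()
            .contains(&"most_frequent".to_string()));
        assert!(registry
            .list_regression_strategies()
            .contains(&"mean".to_string()));
    }
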
    #[test]
    fn test_trimmed_mean_estimator() {
        let estimator = statistical_methods::TrimmedMeanEstimator::new(0.1).unwrap();
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 100.0];

        let result = estimator.estimate(&data).unwrap();
        assert!(result > 0.0 && result < 50.0);
    }

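    // Illustrative example: with 25% trimming on five values the smallest and largest
    // observations are dropped before averaging, and out-of-range trims are rejected.
    #[test]
    fn test_trimmed_mean_estimator_trims_extremes() {
        let estimator = statistical_methods::TrimmedMeanEstimator::new(0.25).unwrap();
        let data = vec![1.0, 2.0, 3.0, 4.0, 100.0];

        let result = estimator.estimate(&data).unwrap();
        assert!((result - 3.0).abs() < 1e-12);

        assert!(statistical_methods::TrimmedMeanEstimator::new(0.6).is_err());
    }
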
    #[test]
    fn test_mad_estimator() {
        let estimator = statistical_methods::MedianAbsoluteDeviationEstimator;
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let result = estimator.estimate(&data).unwrap();
        assert!(result > 0.0);
    }

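    // Illustrative example: for 1..=5 the median is 3, so the absolute deviations are
    // [2, 1, 0, 1, 2] and their median (the MAD) is exactly 1.
    #[test]
    fn test_mad_estimator_exact_value() {
        let estimator = statistical_methods::MedianAbsoluteDeviationEstimator;
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let result = estimator.estimate(&data).unwrap();
        assert!((result - 1.0).abs() < 1e-12);
    }
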
    #[test]
    fn test_quantile_estimator() {
        let estimator = statistical_methods::QuantileEstimator::new(0.5).unwrap();
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let result = estimator.estimate(&data).unwrap();
        assert_eq!(result, 3.0);
    }

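    // Illustrative example: quantiles between order statistics are linearly
    // interpolated, so the 0.9 quantile of 1..=5 is 4.6; invalid quantiles are rejected.
    #[test]
    fn test_quantile_estimator_interpolation() {
        let estimator = statistical_methods::QuantileEstimator::new(0.9).unwrap();
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let result = estimator.estimate(&data).unwrap();
        assert!((result - 4.6).abs() < 1e-12);

        assert!(statistical_methods::QuantileEstimator::new(1.5).is_err());
    }
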
    #[test]
    fn test_factory_methods() {
        let most_frequent = BaselineStrategyFactory::most_frequent();
        assert_eq!(most_frequent.name(), "most_frequent");

        let mean = BaselineStrategyFactory::mean();
        assert_eq!(mean.name(), "mean");

        let pipeline = BaselineStrategyFactory::standard_pipeline(mean);
        assert_eq!(pipeline.strategy.name(), "mean");
    }

    #[test]
    fn test_standard_scaler() {
        let mut scaler = StandardScaler::new();
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();

        scaler.fit(&x.view()).unwrap();
        let transformed = scaler.transform(&x.view()).unwrap();

        assert_eq!(transformed.shape(), x.shape());

        // Each standardized column should have (numerically) zero mean.
        for j in 0..transformed.ncols() {
            let col_mean = transformed.column(j).iter().sum::<f64>() / transformed.nrows() as f64;
            assert!(col_mean.abs() < 1e-10);
        }
    }

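    // Illustrative example: transforming with an unfitted scaler is rejected.
    #[test]
    fn test_standard_scaler_requires_fit() {
        let scaler = StandardScaler::new();
        let x = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();

        assert!(scaler.transform(&x.view()).is_err());
    }
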
    #[test]
    fn test_clipping_postprocessor() {
        let clipper = ClippingPostprocessor::new(-1.0, 1.0);
        let predictions = vec![-2.0, -0.5, 0.0, 0.5, 2.0];

        let clipped = clipper.transform(&predictions).unwrap();
        assert_eq!(clipped, vec![-1.0, -0.5, 0.0, 0.5, 1.0]);
    }
}