use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{
    essentials::Normal, prelude::*, rngs::StdRng, Distribution, Rng, SeedableRng,
};
use sklears_core::error::Result;
use sklears_core::traits::{Estimator, Fit, Predict};
use sklears_core::types::{Features, Float};
use std::collections::HashMap;

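/// Strategy used by the context-aware dummy estimators to turn feature
/// context into a baseline prediction.
///
/// The example below is an illustrative sketch; the `sklears_dummy` module
/// path is an assumption and may differ in your workspace.
///
/// ```ignore
/// use sklears_dummy::context_aware::ContextAwareStrategy;
///
/// // Bin each feature into 5 intervals and require at least 3 samples per bin.
/// let strategy = ContextAwareStrategy::Conditional {
///     n_bins: 5,
///     min_samples_per_bin: 3,
/// };
/// ```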
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum ContextAwareStrategy {
    /// Bin every feature and predict the per-bin target mean.
    Conditional {
        n_bins: usize,
        min_samples_per_bin: usize,
    },
    /// Combine features with a weighting scheme into a linear baseline.
    FeatureWeighted {
        weighting: FeatureWeighting,
    },
    /// Run a simple k-means clustering and predict the per-cluster target mean.
    ClusterBased {
        n_clusters: usize,
        max_iter: usize,
    },
    /// Predict with inverse-distance-weighted k-nearest neighbors.
    LocalitySensitive {
        n_neighbors: usize,
        distance_power: Float,
    },
    /// Sample from local target statistics gathered around random centers.
    AdaptiveLocal {
        radius: Float,
        min_local_samples: usize,
    },
}

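/// How features are weighted under [`ContextAwareStrategy::FeatureWeighted`].
///
/// Illustrative sketch only; the `sklears_dummy` module path is an assumption.
///
/// ```ignore
/// use scirs2_core::ndarray::array;
/// use sklears_dummy::context_aware::FeatureWeighting;
///
/// // Weight the first feature more heavily than the second.
/// let weighting = FeatureWeighting::Custom(array![0.7, 0.3]);
/// ```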
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum FeatureWeighting {
    /// Equal weight for every feature.
    Uniform,
    /// Weight each feature by its variance.
    Variance,
    /// Weight each feature by its absolute correlation with the target.
    Correlation,
    /// User-supplied weights; must have one entry per feature.
    Custom(Array1<Float>),
}

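/// Dummy regressor whose baseline predictions depend on feature context
/// (bins, feature weights, clusters, neighbors, or local regions).
///
/// Usage sketch, mirroring the unit tests below; the `sklears_dummy` module
/// path is an assumption, hence the `ignore` marker:
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array2};
/// use sklears_core::traits::{Fit, Predict};
/// use sklears_dummy::context_aware::{ContextAwareDummyRegressor, ContextAwareStrategy};
///
/// let x = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0]).unwrap();
/// let y = array![1.0, 2.0, 3.0, 4.0];
///
/// let fitted = ContextAwareDummyRegressor::new(ContextAwareStrategy::Conditional {
///     n_bins: 2,
///     min_samples_per_bin: 1,
/// })
/// .fit(&x, &y)
/// .unwrap();
///
/// let predictions = fitted.predict(&x).unwrap();
/// assert_eq!(predictions.len(), 4);
/// ```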
#[derive(Debug, Clone)]
pub struct ContextAwareDummyRegressor<State = sklears_core::traits::Untrained> {
    pub strategy: ContextAwareStrategy,
    pub random_state: Option<u64>,

    // Fitted state for the Conditional strategy.
    pub(crate) feature_bins_: Option<Vec<Array1<Float>>>,
    pub(crate) bin_predictions_: Option<HashMap<Vec<usize>, Float>>,

    // Fitted state for the FeatureWeighted strategy.
    pub(crate) feature_weights_: Option<Array1<Float>>,
    pub(crate) weighted_intercept_: Option<Float>,
    pub(crate) weighted_coefficients_: Option<Array1<Float>>,

    // Fitted state for the ClusterBased strategy.
    pub(crate) cluster_centers_: Option<Array2<Float>>,
    pub(crate) cluster_predictions_: Option<Array1<Float>>,

    // Training data kept for the LocalitySensitive strategy.
    pub(crate) training_features_: Option<Array2<Float>>,
    pub(crate) training_targets_: Option<Array1<Float>>,

    // Fitted state for the AdaptiveLocal strategy.
    pub(crate) local_means_: Option<Array1<Float>>,
    pub(crate) local_stds_: Option<Array1<Float>>,
    pub(crate) local_centers_: Option<Array2<Float>>,

    pub(crate) _state: std::marker::PhantomData<State>,
}

impl ContextAwareDummyRegressor {
    pub fn new(strategy: ContextAwareStrategy) -> Self {
        Self {
            strategy,
            random_state: None,
            feature_bins_: None,
            bin_predictions_: None,
            feature_weights_: None,
            weighted_intercept_: None,
            weighted_coefficients_: None,
            cluster_centers_: None,
            cluster_predictions_: None,
            training_features_: None,
            training_targets_: None,
            local_means_: None,
            local_stds_: None,
            local_centers_: None,
            _state: std::marker::PhantomData,
        }
    }

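    /// Sets the seed used by the stochastic strategies (cluster initialization
    /// and adaptive-local center selection and sampling).
    ///
    /// A minimal sketch; the `sklears_dummy` module path is an assumption.
    ///
    /// ```ignore
    /// use sklears_dummy::context_aware::{ContextAwareDummyRegressor, ContextAwareStrategy};
    ///
    /// let regressor = ContextAwareDummyRegressor::new(ContextAwareStrategy::ClusterBased {
    ///     n_clusters: 3,
    ///     max_iter: 10,
    /// })
    /// .with_random_state(42);
    /// ```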
    pub fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for ContextAwareDummyRegressor {
    fn default() -> Self {
        Self::new(ContextAwareStrategy::Conditional {
            n_bins: 5,
            min_samples_per_bin: 3,
        })
    }
}

impl Estimator for ContextAwareDummyRegressor {
    type Config = ();
    type Error = sklears_core::error::SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Features, Array1<Float>> for ContextAwareDummyRegressor {
    type Fitted = ContextAwareDummyRegressor<sklears_core::traits::Trained>;

    fn fit(self, x: &Features, y: &Array1<Float>) -> Result<Self::Fitted> {
        if x.is_empty() || y.is_empty() {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "Input cannot be empty".to_string(),
            ));
        }

        if x.nrows() != y.len() {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "Number of samples in X and y must be equal".to_string(),
            ));
        }

        let mut fitted = ContextAwareDummyRegressor {
            strategy: self.strategy.clone(),
            random_state: self.random_state,
            feature_bins_: None,
            bin_predictions_: None,
            feature_weights_: None,
            weighted_intercept_: None,
            weighted_coefficients_: None,
            cluster_centers_: None,
            cluster_predictions_: None,
            training_features_: None,
            training_targets_: None,
            local_means_: None,
            local_stds_: None,
            local_centers_: None,
            _state: std::marker::PhantomData,
        };

        match &self.strategy {
            ContextAwareStrategy::Conditional {
                n_bins,
                min_samples_per_bin,
            } => {
                fitted.fit_conditional(x, y, *n_bins, *min_samples_per_bin)?;
            }
            ContextAwareStrategy::FeatureWeighted { weighting } => {
                fitted.fit_feature_weighted(x, y, weighting)?;
            }
            ContextAwareStrategy::ClusterBased {
                n_clusters,
                max_iter,
            } => {
                fitted.fit_cluster_based(x, y, *n_clusters, *max_iter)?;
            }
            ContextAwareStrategy::LocalitySensitive {
                n_neighbors,
                distance_power,
            } => {
                fitted.fit_locality_sensitive(x, y, *n_neighbors, *distance_power)?;
            }
            ContextAwareStrategy::AdaptiveLocal {
                radius,
                min_local_samples,
            } => {
                fitted.fit_adaptive_local(x, y, *radius, *min_local_samples)?;
            }
        }

        Ok(fitted)
    }
}

impl ContextAwareDummyRegressor<sklears_core::traits::Trained> {
    fn fit_conditional(
        &mut self,
        x: &Features,
        y: &Array1<Float>,
        n_bins: usize,
        min_samples_per_bin: usize,
    ) -> Result<()> {
        let n_features = x.ncols();
        let mut feature_bins = Vec::with_capacity(n_features);
        let mut bin_predictions = HashMap::new();

        // Compute equal-width bin edges for every feature.
        for feature_idx in 0..n_features {
            let feature_values = x.column(feature_idx);
            let min_val = feature_values
                .iter()
                .fold(Float::INFINITY, |a, &b| a.min(b));
            let max_val = feature_values
                .iter()
                .fold(Float::NEG_INFINITY, |a, &b| a.max(b));

            let bin_width = (max_val - min_val) / n_bins as Float;
            let mut bins = Array1::zeros(n_bins + 1);

            for i in 0..=n_bins {
                bins[i] = min_val + i as Float * bin_width;
            }
            // Nudge the upper edge so the maximum value falls inside the last bin.
            bins[n_bins] = max_val + 1e-10;

            feature_bins.push(bins);
        }

        // Group the targets by the joint bin index of each sample.
        for i in 0..x.nrows() {
            let mut bin_indices = Vec::with_capacity(n_features);

            for (feature_idx, bins) in feature_bins.iter().enumerate() {
                let value = x[[i, feature_idx]];
                let bin_idx = bins
                    .iter()
                    .position(|&bin_edge| value < bin_edge)
                    .unwrap_or(bins.len() - 1)
                    .saturating_sub(1);
                bin_indices.push(bin_idx);
            }

            let entry = bin_predictions.entry(bin_indices).or_insert_with(Vec::new);
            entry.push(y[i]);
        }

        // Keep only bins with enough samples and store their target means.
        let mut final_bin_predictions = HashMap::new();
        for (bin_key, targets) in bin_predictions {
            if targets.len() >= min_samples_per_bin {
                let mean = targets.iter().sum::<Float>() / targets.len() as Float;
                final_bin_predictions.insert(bin_key, mean);
            }
        }

        self.feature_bins_ = Some(feature_bins);
        self.bin_predictions_ = Some(final_bin_predictions);
        Ok(())
    }

    fn fit_feature_weighted(
        &mut self,
        x: &Features,
        y: &Array1<Float>,
        weighting: &FeatureWeighting,
    ) -> Result<()> {
        let n_features = x.ncols();
        let weights = match weighting {
            FeatureWeighting::Uniform => Array1::from_elem(n_features, 1.0 / n_features as Float),
            FeatureWeighting::Variance => {
                let mut weights = Array1::zeros(n_features);
                for i in 0..n_features {
                    let feature = x.column(i);
                    let mean = feature.mean().unwrap_or(0.0);
                    let variance = feature
                        .iter()
                        .map(|&val| (val - mean).powi(2))
                        .sum::<Float>()
                        / feature.len() as Float;
                    weights[i] = variance;
                }
                let sum_weights = weights.sum();
                if sum_weights > 0.0 {
                    weights / sum_weights
                } else {
                    Array1::from_elem(n_features, 1.0 / n_features as Float)
                }
            }
            FeatureWeighting::Correlation => {
                let mut weights = Array1::zeros(n_features);
                let y_mean = y.mean().unwrap_or(0.0);

                for i in 0..n_features {
                    let feature = x.column(i);
                    let x_mean = feature.mean().unwrap_or(0.0);

                    let mut numerator = 0.0;
                    let mut x_var = 0.0;
                    let mut y_var = 0.0;

                    for j in 0..feature.len() {
                        let x_diff = feature[j] - x_mean;
                        let y_diff = y[j] - y_mean;
                        numerator += x_diff * y_diff;
                        x_var += x_diff * x_diff;
                        y_var += y_diff * y_diff;
                    }

                    let correlation = if x_var > 0.0 && y_var > 0.0 {
                        numerator / (x_var * y_var).sqrt()
                    } else {
                        0.0
                    };

                    weights[i] = correlation.abs();
                }

                let sum_weights = weights.sum();
                if sum_weights > 0.0 {
                    weights / sum_weights
                } else {
                    Array1::from_elem(n_features, 1.0 / n_features as Float)
                }
            }
            FeatureWeighting::Custom(custom_weights) => {
                if custom_weights.len() != n_features {
                    return Err(sklears_core::error::SklearsError::InvalidInput(
                        "Custom weights length must match number of features".to_string(),
                    ));
                }
                custom_weights.clone()
            }
        };

        let y_mean = y.mean().unwrap_or(0.0);
        let mut coefficients = Array1::zeros(n_features);

        for i in 0..n_features {
            let feature = x.column(i);
            let x_mean = feature.mean().unwrap_or(0.0);
            coefficients[i] = weights[i] * (y_mean - x_mean);
        }

        self.feature_weights_ = Some(weights);
        self.weighted_intercept_ = Some(y_mean);
        self.weighted_coefficients_ = Some(coefficients);
        Ok(())
    }

    fn fit_cluster_based(
        &mut self,
        x: &Features,
        y: &Array1<Float>,
        n_clusters: usize,
        max_iter: usize,
    ) -> Result<()> {
        let n_samples = x.nrows();
        let n_features = x.ncols();

        if n_clusters > n_samples {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "Number of clusters cannot exceed number of samples".to_string(),
            ));
        }

        let mut rng = if let Some(seed) = self.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::seed_from_u64(0)
        };

        // Initialize cluster centers from randomly chosen training samples.
        let mut centers = Array2::zeros((n_clusters, n_features));
        for i in 0..n_clusters {
            let sample_idx = rng.gen_range(0..n_samples);
            for j in 0..n_features {
                centers[[i, j]] = x[[sample_idx, j]];
            }
        }

        // Lloyd-style iterations: assign samples to the nearest center, then
        // recompute each center as the mean of its assigned samples.
        let mut assignments = vec![0; n_samples];

        for _iter in 0..max_iter {
            let mut changed = false;

            for i in 0..n_samples {
                let mut min_distance = Float::INFINITY;
                let mut best_cluster = 0;

                for cluster in 0..n_clusters {
                    let mut distance = 0.0;
                    for j in 0..n_features {
                        let diff = x[[i, j]] - centers[[cluster, j]];
                        distance += diff * diff;
                    }

                    if distance < min_distance {
                        min_distance = distance;
                        best_cluster = cluster;
                    }
                }

                if assignments[i] != best_cluster {
                    assignments[i] = best_cluster;
                    changed = true;
                }
            }

            if !changed {
                break;
            }

            let mut cluster_counts = vec![0; n_clusters];
            centers.fill(0.0);

            for i in 0..n_samples {
                let cluster = assignments[i];
                cluster_counts[cluster] += 1;
                for j in 0..n_features {
                    centers[[cluster, j]] += x[[i, j]];
                }
            }

            for cluster in 0..n_clusters {
                if cluster_counts[cluster] > 0 {
                    for j in 0..n_features {
                        centers[[cluster, j]] /= cluster_counts[cluster] as Float;
                    }
                }
            }
        }

        // Each cluster predicts the mean target of its assigned samples.
        let mut cluster_targets: Vec<Vec<Float>> = vec![Vec::new(); n_clusters];
        for i in 0..n_samples {
            cluster_targets[assignments[i]].push(y[i]);
        }

        let mut cluster_predictions = Array1::zeros(n_clusters);
        for i in 0..n_clusters {
            if !cluster_targets[i].is_empty() {
                cluster_predictions[i] =
                    cluster_targets[i].iter().sum::<Float>() / cluster_targets[i].len() as Float;
            }
        }

        self.cluster_centers_ = Some(centers);
        self.cluster_predictions_ = Some(cluster_predictions);
        Ok(())
    }

    fn fit_locality_sensitive(
        &mut self,
        x: &Features,
        y: &Array1<Float>,
        _n_neighbors: usize,
        _distance_power: Float,
    ) -> Result<()> {
        // Lazy strategy: just store the training data; the neighbor search and
        // distance weighting happen at prediction time.
        self.training_features_ = Some(x.clone());
        self.training_targets_ = Some(y.clone());
        Ok(())
    }

    fn fit_adaptive_local(
        &mut self,
        x: &Features,
        y: &Array1<Float>,
        radius: Float,
        min_local_samples: usize,
    ) -> Result<()> {
        let n_samples = x.nrows();
        let n_features = x.ncols();

        let n_centers = (n_samples / min_local_samples).max(1);
        let mut centers = Array2::zeros((n_centers, n_features));
        let mut local_means = Array1::zeros(n_centers);
        let mut local_stds = Array1::zeros(n_centers);

        let mut rng = if let Some(seed) = self.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::seed_from_u64(0)
        };

        // Pick local centers from randomly chosen training samples.
        for i in 0..n_centers {
            let sample_idx = rng.gen_range(0..n_samples);
            for j in 0..n_features {
                centers[[i, j]] = x[[sample_idx, j]];
            }
        }

        // Estimate the mean and standard deviation of the targets within `radius`
        // of each center, falling back to global statistics for sparse regions.
        for i in 0..n_centers {
            let mut local_targets = Vec::new();

            for j in 0..n_samples {
                let mut distance = 0.0;
                for k in 0..n_features {
                    let diff = x[[j, k]] - centers[[i, k]];
                    distance += diff * diff;
                }
                distance = distance.sqrt();

                if distance <= radius {
                    local_targets.push(y[j]);
                }
            }

            if local_targets.len() >= min_local_samples {
                let mean = local_targets.iter().sum::<Float>() / local_targets.len() as Float;
                let variance = local_targets
                    .iter()
                    .map(|&val| (val - mean).powi(2))
                    .sum::<Float>()
                    / local_targets.len() as Float;
                let std_dev = variance.sqrt();

                local_means[i] = mean;
                local_stds[i] = std_dev;
            } else {
                let global_mean = y.mean().unwrap_or(0.0);
                let global_variance = y
                    .iter()
                    .map(|&val| (val - global_mean).powi(2))
                    .sum::<Float>()
                    / y.len() as Float;

                local_means[i] = global_mean;
                local_stds[i] = global_variance.sqrt();
            }
        }

        self.local_centers_ = Some(centers);
        self.local_means_ = Some(local_means);
        self.local_stds_ = Some(local_stds);
        Ok(())
    }
}

impl Predict<Features, Array1<Float>>
    for ContextAwareDummyRegressor<sklears_core::traits::Trained>
{
    fn predict(&self, x: &Features) -> Result<Array1<Float>> {
        if x.is_empty() {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "Input cannot be empty".to_string(),
            ));
        }

        let n_samples = x.nrows();
        let mut predictions = Array1::zeros(n_samples);

        match &self.strategy {
            ContextAwareStrategy::Conditional { .. } => {
                self.predict_conditional(x, &mut predictions)?;
            }
            ContextAwareStrategy::FeatureWeighted { .. } => {
                self.predict_feature_weighted(x, &mut predictions)?;
            }
            ContextAwareStrategy::ClusterBased { .. } => {
                self.predict_cluster_based(x, &mut predictions)?;
            }
            ContextAwareStrategy::LocalitySensitive {
                n_neighbors,
                distance_power,
            } => {
                self.predict_locality_sensitive(
                    x,
                    &mut predictions,
                    *n_neighbors,
                    *distance_power,
                )?;
            }
            ContextAwareStrategy::AdaptiveLocal { radius, .. } => {
                self.predict_adaptive_local(x, &mut predictions, *radius)?;
            }
        }

        Ok(predictions)
    }
}

impl ContextAwareDummyRegressor<sklears_core::traits::Trained> {
    fn predict_conditional(&self, x: &Features, predictions: &mut Array1<Float>) -> Result<()> {
        let feature_bins = self.feature_bins_.as_ref().unwrap();
        let bin_predictions = self.bin_predictions_.as_ref().unwrap();
        let global_mean = bin_predictions.values().sum::<Float>() / bin_predictions.len() as Float;

        for i in 0..x.nrows() {
            let mut bin_indices = Vec::with_capacity(feature_bins.len());

            for (feature_idx, bins) in feature_bins.iter().enumerate() {
                let value = x[[i, feature_idx]];
                let bin_idx = bins
                    .iter()
                    .position(|&bin_edge| value < bin_edge)
                    .unwrap_or(bins.len() - 1)
                    .saturating_sub(1);
                bin_indices.push(bin_idx);
            }

            predictions[i] = *bin_predictions.get(&bin_indices).unwrap_or(&global_mean);
        }

        Ok(())
    }

    fn predict_feature_weighted(
        &self,
        x: &Features,
        predictions: &mut Array1<Float>,
    ) -> Result<()> {
        let weights = self.feature_weights_.as_ref().unwrap();
        let intercept = self.weighted_intercept_.unwrap();
        let coefficients = self.weighted_coefficients_.as_ref().unwrap();

        for i in 0..x.nrows() {
            let mut weighted_sum = intercept;
            for j in 0..x.ncols() {
                weighted_sum += x[[i, j]] * weights[j] + coefficients[j];
            }
            predictions[i] = weighted_sum;
        }

        Ok(())
    }

    fn predict_cluster_based(&self, x: &Features, predictions: &mut Array1<Float>) -> Result<()> {
        let centers = self.cluster_centers_.as_ref().unwrap();
        let cluster_predictions = self.cluster_predictions_.as_ref().unwrap();

        for i in 0..x.nrows() {
            let mut min_distance = Float::INFINITY;
            let mut best_cluster = 0;

            for cluster in 0..centers.nrows() {
                let mut distance = 0.0;
                for j in 0..x.ncols() {
                    let diff = x[[i, j]] - centers[[cluster, j]];
                    distance += diff * diff;
                }

                if distance < min_distance {
                    min_distance = distance;
                    best_cluster = cluster;
                }
            }

            predictions[i] = cluster_predictions[best_cluster];
        }

        Ok(())
    }

    fn predict_locality_sensitive(
        &self,
        x: &Features,
        predictions: &mut Array1<Float>,
        n_neighbors: usize,
        distance_power: Float,
    ) -> Result<()> {
        let training_features = self.training_features_.as_ref().unwrap();
        let training_targets = self.training_targets_.as_ref().unwrap();

        for i in 0..x.nrows() {
            let mut distances = Vec::new();

            // Euclidean distance to every training sample.
            for j in 0..training_features.nrows() {
                let mut distance = 0.0;
                for k in 0..x.ncols() {
                    let diff = x[[i, k]] - training_features[[j, k]];
                    distance += diff * diff;
                }
                distance = distance.sqrt();
                distances.push((distance, j));
            }

            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
            let k_nearest = distances.into_iter().take(n_neighbors).collect::<Vec<_>>();

            // Inverse-distance weighting over the k nearest neighbors.
            let mut weighted_sum = 0.0;
            let mut weight_sum = 0.0;

            for (distance, idx) in k_nearest {
                let weight = if distance == 0.0 {
                    // Exact match: use a large finite weight instead of dividing by zero.
                    1000.0
                } else {
                    1.0 / distance.powf(distance_power)
                };

                weighted_sum += weight * training_targets[idx];
                weight_sum += weight;
            }

            predictions[i] = if weight_sum > 0.0 {
                weighted_sum / weight_sum
            } else {
                training_targets.mean().unwrap_or(0.0)
            };
        }

        Ok(())
    }

    fn predict_adaptive_local(
        &self,
        x: &Features,
        predictions: &mut Array1<Float>,
        radius: Float,
    ) -> Result<()> {
        let centers = self.local_centers_.as_ref().unwrap();
        let local_means = self.local_means_.as_ref().unwrap();
        let local_stds = self.local_stds_.as_ref().unwrap();

        let mut rng = if let Some(seed) = self.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::seed_from_u64(0)
        };

        for i in 0..x.nrows() {
            // Find the closest local center within the radius.
            let mut min_distance = Float::INFINITY;
            let mut best_center = 0;

            for j in 0..centers.nrows() {
                let mut distance = 0.0;
                for k in 0..x.ncols() {
                    let diff = x[[i, k]] - centers[[j, k]];
                    distance += diff * diff;
                }
                distance = distance.sqrt();

                if distance <= radius && distance < min_distance {
                    min_distance = distance;
                    best_center = j;
                }
            }

            if min_distance <= radius {
                // Sample from the local target distribution of the best center.
                let mean = local_means[best_center];
                let std = local_stds[best_center];

                if std > 0.0 {
                    let normal = Normal::new(mean, std).unwrap();
                    predictions[i] = normal.sample(&mut rng);
                } else {
                    predictions[i] = mean;
                }
            } else {
                // No center covers this sample: fall back to the mean of the local means.
                predictions[i] = local_means.mean().unwrap_or(0.0);
            }
        }

        Ok(())
    }
}

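/// Dummy classifier counterpart of [`ContextAwareDummyRegressor`]: the
/// `Conditional` strategy samples a class from per-bin class frequencies,
/// while the remaining strategies fall back to the smallest class label.
///
/// Usage sketch, mirroring the unit tests below; the `sklears_dummy` module
/// path is an assumption:
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array2};
/// use sklears_core::traits::{Fit, Predict};
/// use sklears_dummy::context_aware::{ContextAwareDummyClassifier, ContextAwareStrategy};
///
/// let x = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0]).unwrap();
/// let y = array![0, 0, 1, 1];
///
/// let fitted = ContextAwareDummyClassifier::new(ContextAwareStrategy::Conditional {
///     n_bins: 2,
///     min_samples_per_bin: 1,
/// })
/// .with_random_state(42)
/// .fit(&x, &y)
/// .unwrap();
///
/// let predictions = fitted.predict(&x).unwrap();
/// assert!(predictions.iter().all(|&p| p == 0 || p == 1));
/// ```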
#[derive(Debug, Clone)]
pub struct ContextAwareDummyClassifier<State = sklears_core::traits::Untrained> {
    pub strategy: ContextAwareStrategy,
    pub random_state: Option<u64>,

    // Fitted state for the Conditional strategy.
    pub(crate) feature_bins_: Option<Vec<Array1<Float>>>,
    pub(crate) bin_class_probs_: Option<HashMap<Vec<usize>, HashMap<i32, Float>>>,
    // Sorted unique class labels seen during fitting.
    pub(crate) classes_: Option<Array1<i32>>,
    // Training data kept for the non-Conditional strategies.
    pub(crate) training_features_: Option<Array2<Float>>,
    pub(crate) training_targets_: Option<Array1<i32>>,

    pub(crate) _state: std::marker::PhantomData<State>,
}

impl ContextAwareDummyClassifier {
    pub fn new(strategy: ContextAwareStrategy) -> Self {
        Self {
            strategy,
            random_state: None,
            feature_bins_: None,
            bin_class_probs_: None,
            classes_: None,
            training_features_: None,
            training_targets_: None,
            _state: std::marker::PhantomData,
        }
    }

    pub fn with_random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for ContextAwareDummyClassifier {
    fn default() -> Self {
        Self::new(ContextAwareStrategy::Conditional {
            n_bins: 5,
            min_samples_per_bin: 3,
        })
    }
}

impl Estimator for ContextAwareDummyClassifier {
    type Config = ();
    type Error = sklears_core::error::SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Features, Array1<i32>> for ContextAwareDummyClassifier {
    type Fitted = ContextAwareDummyClassifier<sklears_core::traits::Trained>;

    fn fit(self, x: &Features, y: &Array1<i32>) -> Result<Self::Fitted> {
        if x.is_empty() || y.is_empty() {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "Input cannot be empty".to_string(),
            ));
        }

        if x.nrows() != y.len() {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "Number of samples in X and y must be equal".to_string(),
            ));
        }

        // Collect the sorted unique class labels.
        let mut unique_classes = y.iter().cloned().collect::<Vec<_>>();
        unique_classes.sort_unstable();
        unique_classes.dedup();
        let classes = Array1::from_vec(unique_classes);

        let mut fitted = ContextAwareDummyClassifier {
            strategy: self.strategy.clone(),
            random_state: self.random_state,
            feature_bins_: None,
            bin_class_probs_: None,
            classes_: Some(classes),
            training_features_: None,
            training_targets_: None,
            _state: std::marker::PhantomData,
        };

        match &self.strategy {
            ContextAwareStrategy::Conditional {
                n_bins,
                min_samples_per_bin,
            } => {
                fitted.fit_conditional_classifier(x, y, *n_bins, *min_samples_per_bin)?;
            }
            _ => {
                // The remaining strategies simply keep the training data around.
                fitted.training_features_ = Some(x.clone());
                fitted.training_targets_ = Some(y.clone());
            }
        }

        Ok(fitted)
    }
}

impl ContextAwareDummyClassifier<sklears_core::traits::Trained> {
    fn fit_conditional_classifier(
        &mut self,
        x: &Features,
        y: &Array1<i32>,
        n_bins: usize,
        min_samples_per_bin: usize,
    ) -> Result<()> {
        let n_features = x.ncols();
        let mut feature_bins = Vec::with_capacity(n_features);
        let mut bin_class_counts: HashMap<Vec<usize>, HashMap<i32, usize>> = HashMap::new();

        // Compute equal-width bin edges for every feature.
        for feature_idx in 0..n_features {
            let feature_values = x.column(feature_idx);
            let min_val = feature_values
                .iter()
                .fold(Float::INFINITY, |a, &b| a.min(b));
            let max_val = feature_values
                .iter()
                .fold(Float::NEG_INFINITY, |a, &b| a.max(b));

            let bin_width = (max_val - min_val) / n_bins as Float;
            let mut bins = Array1::zeros(n_bins + 1);

            for i in 0..=n_bins {
                bins[i] = min_val + i as Float * bin_width;
            }
            // Nudge the upper edge so the maximum value falls inside the last bin.
            bins[n_bins] = max_val + 1e-10;

            feature_bins.push(bins);
        }

        // Count class occurrences per joint bin index.
        for i in 0..x.nrows() {
            let mut bin_indices = Vec::with_capacity(n_features);

            for (feature_idx, bins) in feature_bins.iter().enumerate() {
                let value = x[[i, feature_idx]];
                let bin_idx = bins
                    .iter()
                    .position(|&bin_edge| value < bin_edge)
                    .unwrap_or(bins.len() - 1)
                    .saturating_sub(1);
                bin_indices.push(bin_idx);
            }

            let class_counts = bin_class_counts.entry(bin_indices).or_default();
            *class_counts.entry(y[i]).or_insert(0) += 1;
        }

        // Convert counts into class probabilities for bins with enough samples.
        let mut bin_class_probs = HashMap::new();
        for (bin_key, class_counts) in bin_class_counts {
            let total_count: usize = class_counts.values().sum();
            if total_count >= min_samples_per_bin {
                let mut class_probs = HashMap::new();
                for (&class, &count) in &class_counts {
                    class_probs.insert(class, count as Float / total_count as Float);
                }
                bin_class_probs.insert(bin_key, class_probs);
            }
        }

        self.feature_bins_ = Some(feature_bins);
        self.bin_class_probs_ = Some(bin_class_probs);
        Ok(())
    }
}

impl Predict<Features, Array1<i32>> for ContextAwareDummyClassifier<sklears_core::traits::Trained> {
    fn predict(&self, x: &Features) -> Result<Array1<i32>> {
        if x.is_empty() {
            return Err(sklears_core::error::SklearsError::InvalidInput(
                "Input cannot be empty".to_string(),
            ));
        }

        let n_samples = x.nrows();
        let mut predictions = Array1::zeros(n_samples);
        let classes = self.classes_.as_ref().unwrap();

        let mut rng = if let Some(seed) = self.random_state {
            StdRng::seed_from_u64(seed)
        } else {
            StdRng::seed_from_u64(0)
        };

        match &self.strategy {
            ContextAwareStrategy::Conditional { .. } => {
                let feature_bins = self.feature_bins_.as_ref().unwrap();
                let bin_class_probs = self.bin_class_probs_.as_ref().unwrap();

                // Fallback label for samples that land in a bin dropped during fitting.
                let global_class = classes[0];

                for i in 0..x.nrows() {
                    let mut bin_indices = Vec::with_capacity(feature_bins.len());

                    for (feature_idx, bins) in feature_bins.iter().enumerate() {
                        let value = x[[i, feature_idx]];
                        let bin_idx = bins
                            .iter()
                            .position(|&bin_edge| value < bin_edge)
                            .unwrap_or(bins.len() - 1)
                            .saturating_sub(1);
                        bin_indices.push(bin_idx);
                    }

                    if let Some(class_probs) = bin_class_probs.get(&bin_indices) {
                        // Sample a class according to the per-bin probabilities.
                        let rand_val: Float = rng.gen();
                        let mut cumulative_prob = 0.0;
                        let mut selected_class = global_class;

                        for (&class, &prob) in class_probs {
                            cumulative_prob += prob;
                            if rand_val <= cumulative_prob {
                                selected_class = class;
                                break;
                            }
                        }
                        predictions[i] = selected_class;
                    } else {
                        predictions[i] = global_class;
                    }
                }
            }
            _ => {
                // Non-Conditional strategies fall back to the first (smallest) class label.
                let fallback_class = classes[0];
                predictions.fill(fallback_class);
            }
        }

        Ok(predictions)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    #[test]
    fn test_conditional_regressor() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let regressor = ContextAwareDummyRegressor::new(ContextAwareStrategy::Conditional {
            n_bins: 2,
            min_samples_per_bin: 1,
        });

        let fitted = regressor.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 6);
        assert!(predictions.iter().all(|&p| p >= 1.0 && p <= 6.0));
    }

    #[test]
    fn test_feature_weighted_regressor() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let regressor = ContextAwareDummyRegressor::new(ContextAwareStrategy::FeatureWeighted {
            weighting: FeatureWeighting::Uniform,
        });

        let fitted = regressor.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_cluster_based_regressor() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 1.0, 1.1, 1.1, 5.0, 5.0, 5.1, 5.1, 9.0, 9.0, 9.1, 9.1],
        )
        .unwrap();
        let y = array![1.0, 1.0, 5.0, 5.0, 9.0, 9.0];

        let regressor = ContextAwareDummyRegressor::new(ContextAwareStrategy::ClusterBased {
            n_clusters: 3,
            max_iter: 10,
        })
        .with_random_state(42);

        let fitted = regressor.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 6);
    }

    #[test]
    fn test_locality_sensitive_regressor() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let regressor = ContextAwareDummyRegressor::new(ContextAwareStrategy::LocalitySensitive {
            n_neighbors: 2,
            distance_power: 2.0,
        });

        let fitted = regressor.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_adaptive_local_regressor() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let regressor = ContextAwareDummyRegressor::new(ContextAwareStrategy::AdaptiveLocal {
            radius: 2.0,
            min_local_samples: 2,
        })
        .with_random_state(42);

        let fitted = regressor.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 6);
    }

    #[test]
    fn test_conditional_classifier() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0],
        )
        .unwrap();
        let y = array![0, 0, 1, 1, 0, 1];

        let classifier = ContextAwareDummyClassifier::new(ContextAwareStrategy::Conditional {
            n_bins: 2,
            min_samples_per_bin: 1,
        })
        .with_random_state(42);

        let fitted = classifier.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 6);
        assert!(predictions.iter().all(|&p| p == 0 || p == 1));
    }
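
    // Additional sketches (not exhaustive): exercise the custom feature-weighting
    // path and its input validation. Exact predicted values are treated as an
    // implementation detail, so only shapes and error behaviour are asserted.
    #[test]
    fn test_custom_feature_weighting() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let regressor = ContextAwareDummyRegressor::new(ContextAwareStrategy::FeatureWeighted {
            weighting: FeatureWeighting::Custom(array![0.7, 0.3]),
        });

        let fitted = regressor.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
        assert!(predictions.iter().all(|p| p.is_finite()));
    }

    #[test]
    fn test_custom_feature_weighting_rejects_wrong_length() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        // Three weights for two features must be rejected at fit time.
        let regressor = ContextAwareDummyRegressor::new(ContextAwareStrategy::FeatureWeighted {
            weighting: FeatureWeighting::Custom(array![0.5, 0.3, 0.2]),
        });

        assert!(regressor.fit(&x, &y).is_err());
    }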
}