sklears_dummy/
context_aware.rs

//! Context-aware dummy estimators that use feature information.
//!
//! This module provides dummy estimators that incorporate the input features into their
//! baseline predictions. They are more sophisticated than traditional dummy methods while
//! still remaining simple, interpretable baselines.
6
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::random::{
9    essentials::Normal, prelude::*, rngs::StdRng, Distribution, Rng, SeedableRng,
10};
11use sklears_core::error::Result;
12use sklears_core::traits::{Estimator, Fit, Predict};
13use sklears_core::types::{Features, Float};
14use std::collections::HashMap;
15
16/// Strategy for context-aware predictions
17#[derive(Debug, Clone, PartialEq)]
18#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
19pub enum ContextAwareStrategy {
20    /// Make predictions conditional on feature bins/intervals
21    Conditional {
22        /// Number of bins for each feature
23        n_bins: usize,
24        /// Minimum samples per bin to make predictions
25        min_samples_per_bin: usize,
26    },
27    /// Use feature-weighted predictions based on feature importance
28    FeatureWeighted {
29        /// Weighting method for features
30        weighting: FeatureWeighting,
31    },
32    /// Cluster-based predictions using simple k-means
33    ClusterBased {
34        /// Number of clusters
35        n_clusters: usize,
36        /// Maximum iterations for clustering
37        max_iter: usize,
38    },
39    /// Locality-sensitive predictions using nearest neighbors
40    LocalitySensitive {
41        /// Number of neighbors to consider
42        n_neighbors: usize,
43        /// Distance metric weighting factor
44        distance_power: Float,
45    },
46    /// Adaptive local baselines that adjust based on local feature statistics
47    AdaptiveLocal {
48        /// Radius for local neighborhood
49        radius: Float,
50        /// Minimum samples in neighborhood
51        min_local_samples: usize,
52    },
53}
54
55/// Feature weighting methods
56#[derive(Debug, Clone, PartialEq)]
57#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
58pub enum FeatureWeighting {
59    /// Equal weights for all features
60    Uniform,
61    /// Weights based on feature variance
62    Variance,
63    /// Weights based on correlation with target
64    Correlation,
65    /// Custom user-specified weights
66    Custom(Array1<Float>),
67}
68
69/// Context-aware dummy regressor
70#[derive(Debug, Clone)]
71pub struct ContextAwareDummyRegressor<State = sklears_core::traits::Untrained> {
72    /// Strategy for context-aware predictions
73    pub strategy: ContextAwareStrategy,
74    /// Random state for reproducible output
75    pub random_state: Option<u64>,
76
77    // Fitted parameters
78    /// Feature bins for conditional strategy
79    pub(crate) feature_bins_: Option<Vec<Array1<Float>>>,
80    /// Bin predictions for conditional strategy
81    pub(crate) bin_predictions_: Option<HashMap<Vec<usize>, Float>>,
82
83    /// Feature weights for weighted strategy
84    pub(crate) feature_weights_: Option<Array1<Float>>,
85    /// Weighted prediction function parameters
86    pub(crate) weighted_intercept_: Option<Float>,
87    pub(crate) weighted_coefficients_: Option<Array1<Float>>,
88
89    /// Cluster centers for cluster-based strategy
90    pub(crate) cluster_centers_: Option<Array2<Float>>,
91    /// Cluster predictions
92    pub(crate) cluster_predictions_: Option<Array1<Float>>,
93
94    /// Training data for locality-sensitive strategy
95    pub(crate) training_features_: Option<Array2<Float>>,
96    pub(crate) training_targets_: Option<Array1<Float>>,
97
98    /// Local statistics for adaptive strategy
99    pub(crate) local_means_: Option<Array1<Float>>,
100    pub(crate) local_stds_: Option<Array1<Float>>,
101    pub(crate) local_centers_: Option<Array2<Float>>,
102
103    /// Phantom data for state
104    pub(crate) _state: std::marker::PhantomData<State>,
105}
106
107impl ContextAwareDummyRegressor {
108    /// Create a new context-aware dummy regressor
109    pub fn new(strategy: ContextAwareStrategy) -> Self {
110        Self {
111            strategy,
112            random_state: None,
113            feature_bins_: None,
114            bin_predictions_: None,
115            feature_weights_: None,
116            weighted_intercept_: None,
117            weighted_coefficients_: None,
118            cluster_centers_: None,
119            cluster_predictions_: None,
120            training_features_: None,
121            training_targets_: None,
122            local_means_: None,
123            local_stds_: None,
124            local_centers_: None,
125            _state: std::marker::PhantomData,
126        }
127    }
128
129    /// Set the random state for reproducible output
130    pub fn with_random_state(mut self, random_state: u64) -> Self {
131        self.random_state = Some(random_state);
132        self
133    }
134}
135
136impl Default for ContextAwareDummyRegressor {
137    fn default() -> Self {
138        Self::new(ContextAwareStrategy::Conditional {
139            n_bins: 5,
140            min_samples_per_bin: 3,
141        })
142    }
143}
144
145impl Estimator for ContextAwareDummyRegressor {
146    type Config = ();
147    type Error = sklears_core::error::SklearsError;
148    type Float = Float;
149
150    fn config(&self) -> &Self::Config {
151        &()
152    }
153}
154
155impl Fit<Features, Array1<Float>> for ContextAwareDummyRegressor {
156    type Fitted = ContextAwareDummyRegressor<sklears_core::traits::Trained>;
157
158    fn fit(self, x: &Features, y: &Array1<Float>) -> Result<Self::Fitted> {
159        if x.is_empty() || y.is_empty() {
160            return Err(sklears_core::error::SklearsError::InvalidInput(
161                "Input cannot be empty".to_string(),
162            ));
163        }
164
165        if x.nrows() != y.len() {
166            return Err(sklears_core::error::SklearsError::InvalidInput(
167                "Number of samples in X and y must be equal".to_string(),
168            ));
169        }
170
171        let mut fitted = ContextAwareDummyRegressor {
172            strategy: self.strategy.clone(),
173            random_state: self.random_state,
174            feature_bins_: None,
175            bin_predictions_: None,
176            feature_weights_: None,
177            weighted_intercept_: None,
178            weighted_coefficients_: None,
179            cluster_centers_: None,
180            cluster_predictions_: None,
181            training_features_: None,
182            training_targets_: None,
183            local_means_: None,
184            local_stds_: None,
185            local_centers_: None,
186            _state: std::marker::PhantomData,
187        };
188
189        match &self.strategy {
190            ContextAwareStrategy::Conditional {
191                n_bins,
192                min_samples_per_bin,
193            } => {
194                fitted.fit_conditional(x, y, *n_bins, *min_samples_per_bin)?;
195            }
196            ContextAwareStrategy::FeatureWeighted { weighting } => {
197                fitted.fit_feature_weighted(x, y, weighting)?;
198            }
199            ContextAwareStrategy::ClusterBased {
200                n_clusters,
201                max_iter,
202            } => {
203                fitted.fit_cluster_based(x, y, *n_clusters, *max_iter)?;
204            }
205            ContextAwareStrategy::LocalitySensitive {
206                n_neighbors,
207                distance_power,
208            } => {
209                fitted.fit_locality_sensitive(x, y, *n_neighbors, *distance_power)?;
210            }
211            ContextAwareStrategy::AdaptiveLocal {
212                radius,
213                min_local_samples,
214            } => {
215                fitted.fit_adaptive_local(x, y, *radius, *min_local_samples)?;
216            }
217        }
218
219        Ok(fitted)
220    }
221}
222
223impl ContextAwareDummyRegressor<sklears_core::traits::Trained> {
224    /// Fit conditional strategy
225    fn fit_conditional(
226        &mut self,
227        x: &Features,
228        y: &Array1<Float>,
229        n_bins: usize,
230        min_samples_per_bin: usize,
231    ) -> Result<()> {
232        let n_features = x.ncols();
233        let mut feature_bins = Vec::with_capacity(n_features);
234        let mut bin_predictions = HashMap::new();
235
236        // Create bins for each feature
237        for feature_idx in 0..n_features {
238            let feature_values = x.column(feature_idx);
239            let min_val = feature_values
240                .iter()
241                .fold(Float::INFINITY, |a, &b| a.min(b));
242            let max_val = feature_values
243                .iter()
244                .fold(Float::NEG_INFINITY, |a, &b| a.max(b));
245
246            let bin_width = (max_val - min_val) / n_bins as Float;
247            let mut bins = Array1::zeros(n_bins + 1);
248
249            for i in 0..=n_bins {
250                bins[i] = min_val + i as Float * bin_width;
251            }
252            bins[n_bins] = max_val + 1e-10; // Ensure max value is included
253
254            feature_bins.push(bins);
255        }
256
257        // Compute predictions for each bin combination
258        for i in 0..x.nrows() {
259            let mut bin_indices = Vec::with_capacity(n_features);
260
261            for (feature_idx, bins) in feature_bins.iter().enumerate() {
262                let value = x[[i, feature_idx]];
263                let bin_idx = bins
264                    .iter()
265                    .position(|&bin_edge| value < bin_edge)
266                    .unwrap_or(bins.len() - 1)
267                    .saturating_sub(1);
268                bin_indices.push(bin_idx);
269            }
270
271            let entry = bin_predictions.entry(bin_indices).or_insert_with(Vec::new);
272            entry.push(y[i]);
273        }
274
275        // Compute mean for each bin with sufficient samples
276        let mut final_bin_predictions = HashMap::new();
277        for (bin_key, targets) in bin_predictions {
278            if targets.len() >= min_samples_per_bin {
279                let mean = targets.iter().sum::<Float>() / targets.len() as Float;
280                final_bin_predictions.insert(bin_key, mean);
281            }
282        }
283
284        self.feature_bins_ = Some(feature_bins);
285        self.bin_predictions_ = Some(final_bin_predictions);
286        Ok(())
287    }
288
289    /// Fit feature-weighted strategy
290    fn fit_feature_weighted(
291        &mut self,
292        x: &Features,
293        y: &Array1<Float>,
294        weighting: &FeatureWeighting,
295    ) -> Result<()> {
296        let n_features = x.ncols();
297        let weights = match weighting {
298            FeatureWeighting::Uniform => Array1::from_elem(n_features, 1.0 / n_features as Float),
299            FeatureWeighting::Variance => {
300                let mut weights = Array1::zeros(n_features);
301                for i in 0..n_features {
302                    let feature = x.column(i);
303                    let mean = feature.mean().unwrap_or(0.0);
304                    let variance = feature
305                        .iter()
306                        .map(|&val| (val - mean).powi(2))
307                        .sum::<Float>()
308                        / feature.len() as Float;
309                    weights[i] = variance;
310                }
311                let sum_weights = weights.sum();
312                if sum_weights > 0.0 {
313                    weights / sum_weights
314                } else {
315                    Array1::from_elem(n_features, 1.0 / n_features as Float)
316                }
317            }
318            FeatureWeighting::Correlation => {
319                let mut weights = Array1::zeros(n_features);
320                let y_mean = y.mean().unwrap_or(0.0);
321
322                for i in 0..n_features {
323                    let feature = x.column(i);
324                    let x_mean = feature.mean().unwrap_or(0.0);
325
326                    let mut numerator = 0.0;
327                    let mut x_var = 0.0;
328                    let mut y_var = 0.0;
329
330                    for j in 0..feature.len() {
331                        let x_diff = feature[j] - x_mean;
332                        let y_diff = y[j] - y_mean;
333                        numerator += x_diff * y_diff;
334                        x_var += x_diff * x_diff;
335                        y_var += y_diff * y_diff;
336                    }
337
338                    let correlation = if x_var > 0.0 && y_var > 0.0 {
339                        numerator / (x_var * y_var).sqrt()
340                    } else {
341                        0.0
342                    };
343
344                    weights[i] = correlation.abs();
345                }
346
347                let sum_weights = weights.sum();
348                if sum_weights > 0.0 {
349                    weights / sum_weights
350                } else {
351                    Array1::from_elem(n_features, 1.0 / n_features as Float)
352                }
353            }
354            FeatureWeighting::Custom(custom_weights) => {
355                if custom_weights.len() != n_features {
356                    return Err(sklears_core::error::SklearsError::InvalidInput(
357                        "Custom weights length must match number of features".to_string(),
358                    ));
359                }
360                custom_weights.clone()
361            }
362        };
363
364        // Compute weighted linear combination parameters
365        let y_mean = y.mean().unwrap_or(0.0);
366        let mut coefficients = Array1::zeros(n_features);
367
368        for i in 0..n_features {
369            let feature = x.column(i);
370            let x_mean = feature.mean().unwrap_or(0.0);
371            coefficients[i] = weights[i] * (y_mean - x_mean);
372        }
373
374        self.feature_weights_ = Some(weights);
375        self.weighted_intercept_ = Some(y_mean);
376        self.weighted_coefficients_ = Some(coefficients);
377        Ok(())
378    }
379
380    /// Fit cluster-based strategy using simple k-means
381    fn fit_cluster_based(
382        &mut self,
383        x: &Features,
384        y: &Array1<Float>,
385        n_clusters: usize,
386        max_iter: usize,
387    ) -> Result<()> {
388        let n_samples = x.nrows();
389        let n_features = x.ncols();
390
391        if n_clusters > n_samples {
392            return Err(sklears_core::error::SklearsError::InvalidInput(
393                "Number of clusters cannot exceed number of samples".to_string(),
394            ));
395        }
396
397        let mut rng = if let Some(seed) = self.random_state {
398            StdRng::seed_from_u64(seed)
399        } else {
400            StdRng::seed_from_u64(0)
401        };
402
403        // Initialize cluster centers randomly
404        let mut centers = Array2::zeros((n_clusters, n_features));
405        for i in 0..n_clusters {
406            let sample_idx = rng.gen_range(0..n_samples);
407            for j in 0..n_features {
408                centers[[i, j]] = x[[sample_idx, j]];
409            }
410        }
411
412        // K-means iterations
413        let mut assignments = vec![0; n_samples];
414
415        for _iter in 0..max_iter {
416            let mut changed = false;
417
418            // Assign points to nearest centers
419            for i in 0..n_samples {
420                let mut min_distance = Float::INFINITY;
421                let mut best_cluster = 0;
422
423                for cluster in 0..n_clusters {
424                    let mut distance = 0.0;
425                    for j in 0..n_features {
426                        let diff = x[[i, j]] - centers[[cluster, j]];
427                        distance += diff * diff;
428                    }
429
430                    if distance < min_distance {
431                        min_distance = distance;
432                        best_cluster = cluster;
433                    }
434                }
435
436                if assignments[i] != best_cluster {
437                    assignments[i] = best_cluster;
438                    changed = true;
439                }
440            }
441
442            if !changed {
443                break;
444            }
445
446            // Update cluster centers
447            let mut cluster_counts = vec![0; n_clusters];
448            centers.fill(0.0);
449
450            for i in 0..n_samples {
451                let cluster = assignments[i];
452                cluster_counts[cluster] += 1;
453                for j in 0..n_features {
454                    centers[[cluster, j]] += x[[i, j]];
455                }
456            }
457
458            for cluster in 0..n_clusters {
459                if cluster_counts[cluster] > 0 {
460                    for j in 0..n_features {
461                        centers[[cluster, j]] /= cluster_counts[cluster] as Float;
462                    }
463                }
464            }
465        }
466
467        // Compute cluster predictions
468        let mut cluster_targets: Vec<Vec<Float>> = vec![Vec::new(); n_clusters];
469        for i in 0..n_samples {
470            cluster_targets[assignments[i]].push(y[i]);
471        }
472
473        let mut cluster_predictions = Array1::zeros(n_clusters);
474        for i in 0..n_clusters {
475            if !cluster_targets[i].is_empty() {
476                cluster_predictions[i] =
477                    cluster_targets[i].iter().sum::<Float>() / cluster_targets[i].len() as Float;
478            }
479        }
480
481        self.cluster_centers_ = Some(centers);
482        self.cluster_predictions_ = Some(cluster_predictions);
483        Ok(())
484    }
485
486    /// Fit locality-sensitive strategy
487    fn fit_locality_sensitive(
488        &mut self,
489        x: &Features,
490        y: &Array1<Float>,
491        _n_neighbors: usize,
492        _distance_power: Float,
493    ) -> Result<()> {
494        // Store training data for prediction time
495        self.training_features_ = Some(x.clone());
496        self.training_targets_ = Some(y.clone());
497        Ok(())
498    }
499
500    /// Fit adaptive local strategy
501    fn fit_adaptive_local(
502        &mut self,
503        x: &Features,
504        y: &Array1<Float>,
505        radius: Float,
506        min_local_samples: usize,
507    ) -> Result<()> {
508        let n_samples = x.nrows();
509        let n_features = x.ncols();
510
511        // Create local statistics for representative points
512        let n_centers = (n_samples / min_local_samples).max(1);
513        let mut centers = Array2::zeros((n_centers, n_features));
514        let mut local_means = Array1::zeros(n_centers);
515        let mut local_stds = Array1::zeros(n_centers);
516
517        let mut rng = if let Some(seed) = self.random_state {
518            StdRng::seed_from_u64(seed)
519        } else {
520            StdRng::seed_from_u64(0)
521        };
522
523        // Select representative centers
524        for i in 0..n_centers {
525            let sample_idx = rng.gen_range(0..n_samples);
526            for j in 0..n_features {
527                centers[[i, j]] = x[[sample_idx, j]];
528            }
529        }
530
531        // Compute local statistics for each center
532        for i in 0..n_centers {
533            let mut local_targets = Vec::new();
534
535            for j in 0..n_samples {
536                let mut distance = 0.0;
537                for k in 0..n_features {
538                    let diff = x[[j, k]] - centers[[i, k]];
539                    distance += diff * diff;
540                }
541                distance = distance.sqrt();
542
543                if distance <= radius {
544                    local_targets.push(y[j]);
545                }
546            }
547
548            if local_targets.len() >= min_local_samples {
549                let mean = local_targets.iter().sum::<Float>() / local_targets.len() as Float;
550                let variance = local_targets
551                    .iter()
552                    .map(|&val| (val - mean).powi(2))
553                    .sum::<Float>()
554                    / local_targets.len() as Float;
555                let std_dev = variance.sqrt();
556
557                local_means[i] = mean;
558                local_stds[i] = std_dev;
559            } else {
560                // Fallback to global statistics
561                let global_mean = y.mean().unwrap_or(0.0);
562                let global_variance = y
563                    .iter()
564                    .map(|&val| (val - global_mean).powi(2))
565                    .sum::<Float>()
566                    / y.len() as Float;
567
568                local_means[i] = global_mean;
569                local_stds[i] = global_variance.sqrt();
570            }
571        }
572
573        self.local_centers_ = Some(centers);
574        self.local_means_ = Some(local_means);
575        self.local_stds_ = Some(local_stds);
576        Ok(())
577    }
578}
579
580impl Predict<Features, Array1<Float>>
581    for ContextAwareDummyRegressor<sklears_core::traits::Trained>
582{
583    fn predict(&self, x: &Features) -> Result<Array1<Float>> {
584        if x.is_empty() {
585            return Err(sklears_core::error::SklearsError::InvalidInput(
586                "Input cannot be empty".to_string(),
587            ));
588        }
589
590        let n_samples = x.nrows();
591        let mut predictions = Array1::zeros(n_samples);
592
593        match &self.strategy {
594            ContextAwareStrategy::Conditional { .. } => {
595                self.predict_conditional(x, &mut predictions)?;
596            }
597            ContextAwareStrategy::FeatureWeighted { .. } => {
598                self.predict_feature_weighted(x, &mut predictions)?;
599            }
600            ContextAwareStrategy::ClusterBased { .. } => {
601                self.predict_cluster_based(x, &mut predictions)?;
602            }
603            ContextAwareStrategy::LocalitySensitive {
604                n_neighbors,
605                distance_power,
606            } => {
607                self.predict_locality_sensitive(
608                    x,
609                    &mut predictions,
610                    *n_neighbors,
611                    *distance_power,
612                )?;
613            }
614            ContextAwareStrategy::AdaptiveLocal { radius, .. } => {
615                self.predict_adaptive_local(x, &mut predictions, *radius)?;
616            }
617        }
618
619        Ok(predictions)
620    }
621}
622
623impl ContextAwareDummyRegressor<sklears_core::traits::Trained> {
624    /// Predict using conditional strategy
625    fn predict_conditional(&self, x: &Features, predictions: &mut Array1<Float>) -> Result<()> {
626        let feature_bins = self.feature_bins_.as_ref().unwrap();
627        let bin_predictions = self.bin_predictions_.as_ref().unwrap();
628        let global_mean = bin_predictions.values().sum::<Float>() / bin_predictions.len() as Float;
629
630        for i in 0..x.nrows() {
631            let mut bin_indices = Vec::with_capacity(feature_bins.len());
632
633            for (feature_idx, bins) in feature_bins.iter().enumerate() {
634                let value = x[[i, feature_idx]];
635                let bin_idx = bins
636                    .iter()
637                    .position(|&bin_edge| value < bin_edge)
638                    .unwrap_or(bins.len() - 1)
639                    .saturating_sub(1);
640                bin_indices.push(bin_idx);
641            }
642
643            predictions[i] = *bin_predictions.get(&bin_indices).unwrap_or(&global_mean);
644        }
645
646        Ok(())
647    }
648
649    /// Predict using feature-weighted strategy
650    fn predict_feature_weighted(
651        &self,
652        x: &Features,
653        predictions: &mut Array1<Float>,
654    ) -> Result<()> {
655        let weights = self.feature_weights_.as_ref().unwrap();
656        let intercept = self.weighted_intercept_.unwrap();
657        let coefficients = self.weighted_coefficients_.as_ref().unwrap();
658
659        for i in 0..x.nrows() {
660            let mut weighted_sum = intercept;
661            for j in 0..x.ncols() {
662                weighted_sum += x[[i, j]] * weights[j] + coefficients[j];
663            }
664            predictions[i] = weighted_sum;
665        }
666
667        Ok(())
668    }
669
670    /// Predict using cluster-based strategy
671    fn predict_cluster_based(&self, x: &Features, predictions: &mut Array1<Float>) -> Result<()> {
672        let centers = self.cluster_centers_.as_ref().unwrap();
673        let cluster_predictions = self.cluster_predictions_.as_ref().unwrap();
674
675        for i in 0..x.nrows() {
676            let mut min_distance = Float::INFINITY;
677            let mut best_cluster = 0;
678
679            for cluster in 0..centers.nrows() {
680                let mut distance = 0.0;
681                for j in 0..x.ncols() {
682                    let diff = x[[i, j]] - centers[[cluster, j]];
683                    distance += diff * diff;
684                }
685
686                if distance < min_distance {
687                    min_distance = distance;
688                    best_cluster = cluster;
689                }
690            }
691
692            predictions[i] = cluster_predictions[best_cluster];
693        }
694
695        Ok(())
696    }
697
698    /// Predict using locality-sensitive strategy
699    fn predict_locality_sensitive(
700        &self,
701        x: &Features,
702        predictions: &mut Array1<Float>,
703        n_neighbors: usize,
704        distance_power: Float,
705    ) -> Result<()> {
706        let training_features = self.training_features_.as_ref().unwrap();
707        let training_targets = self.training_targets_.as_ref().unwrap();
708
709        for i in 0..x.nrows() {
710            let mut distances = Vec::new();
711
712            // Compute distances to all training points
713            for j in 0..training_features.nrows() {
714                let mut distance = 0.0;
715                for k in 0..x.ncols() {
716                    let diff = x[[i, k]] - training_features[[j, k]];
717                    distance += diff * diff;
718                }
719                distance = distance.sqrt();
720                distances.push((distance, j));
721            }
722
723            // Sort by distance and take k nearest neighbors
724            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
725            let k_nearest = distances.into_iter().take(n_neighbors).collect::<Vec<_>>();
726
727            // Weighted average based on inverse distance
728            let mut weighted_sum = 0.0;
729            let mut weight_sum = 0.0;
730
731            for (distance, idx) in k_nearest {
732                let weight = if distance == 0.0 {
733                    1000.0 // Large weight for exact matches
734                } else {
735                    1.0 / distance.powf(distance_power)
736                };
737
738                weighted_sum += weight * training_targets[idx];
739                weight_sum += weight;
740            }
741
742            predictions[i] = if weight_sum > 0.0 {
743                weighted_sum / weight_sum
744            } else {
745                training_targets.mean().unwrap_or(0.0)
746            };
747        }
748
749        Ok(())
750    }
751
752    /// Predict using adaptive local strategy
753    fn predict_adaptive_local(
754        &self,
755        x: &Features,
756        predictions: &mut Array1<Float>,
757        radius: Float,
758    ) -> Result<()> {
759        let centers = self.local_centers_.as_ref().unwrap();
760        let local_means = self.local_means_.as_ref().unwrap();
761        let local_stds = self.local_stds_.as_ref().unwrap();
762
763        let mut rng = if let Some(seed) = self.random_state {
764            StdRng::seed_from_u64(seed)
765        } else {
766            StdRng::seed_from_u64(0)
767        };
768
769        for i in 0..x.nrows() {
770            // Find nearest center within radius
771            let mut min_distance = Float::INFINITY;
772            let mut best_center = 0;
773
774            for j in 0..centers.nrows() {
775                let mut distance = 0.0;
776                for k in 0..x.ncols() {
777                    let diff = x[[i, k]] - centers[[j, k]];
778                    distance += diff * diff;
779                }
780                distance = distance.sqrt();
781
782                if distance <= radius && distance < min_distance {
783                    min_distance = distance;
784                    best_center = j;
785                }
786            }
787
788            // Sample from local distribution
789            if min_distance <= radius {
790                let mean = local_means[best_center];
791                let std = local_stds[best_center];
792
793                if std > 0.0 {
794                    let normal = Normal::new(mean, std).unwrap();
795                    predictions[i] = normal.sample(&mut rng);
796                } else {
797                    predictions[i] = mean;
798                }
799            } else {
800                // Fallback to global mean
801                predictions[i] = local_means.mean().unwrap_or(0.0);
802            }
803        }
804
805        Ok(())
806    }
807}
808
809/// Context-aware dummy classifier
810#[derive(Debug, Clone)]
811pub struct ContextAwareDummyClassifier<State = sklears_core::traits::Untrained> {
812    /// Strategy for context-aware predictions
813    pub strategy: ContextAwareStrategy,
814    /// Random state for reproducible output
815    pub random_state: Option<u64>,
816
817    // Fitted parameters (similar structure to regressor but for classification)
818    pub(crate) feature_bins_: Option<Vec<Array1<Float>>>,
819    pub(crate) bin_class_probs_: Option<HashMap<Vec<usize>, HashMap<i32, Float>>>,
820    pub(crate) classes_: Option<Array1<i32>>,
821    pub(crate) training_features_: Option<Array2<Float>>,
822    pub(crate) training_targets_: Option<Array1<i32>>,
823
824    /// Phantom data for state
825    pub(crate) _state: std::marker::PhantomData<State>,
826}
827
828impl ContextAwareDummyClassifier {
829    /// Create a new context-aware dummy classifier
830    pub fn new(strategy: ContextAwareStrategy) -> Self {
831        Self {
832            strategy,
833            random_state: None,
834            feature_bins_: None,
835            bin_class_probs_: None,
836            classes_: None,
837            training_features_: None,
838            training_targets_: None,
839            _state: std::marker::PhantomData,
840        }
841    }
842
843    /// Set the random state for reproducible output
844    pub fn with_random_state(mut self, random_state: u64) -> Self {
845        self.random_state = Some(random_state);
846        self
847    }
848}
849
850impl Default for ContextAwareDummyClassifier {
851    fn default() -> Self {
852        Self::new(ContextAwareStrategy::Conditional {
853            n_bins: 5,
854            min_samples_per_bin: 3,
855        })
856    }
857}
858
859impl Estimator for ContextAwareDummyClassifier {
860    type Config = ();
861    type Error = sklears_core::error::SklearsError;
862    type Float = Float;
863
864    fn config(&self) -> &Self::Config {
865        &()
866    }
867}
868
869impl Fit<Features, Array1<i32>> for ContextAwareDummyClassifier {
870    type Fitted = ContextAwareDummyClassifier<sklears_core::traits::Trained>;
871
872    fn fit(self, x: &Features, y: &Array1<i32>) -> Result<Self::Fitted> {
873        if x.is_empty() || y.is_empty() {
874            return Err(sklears_core::error::SklearsError::InvalidInput(
875                "Input cannot be empty".to_string(),
876            ));
877        }
878
879        if x.nrows() != y.len() {
880            return Err(sklears_core::error::SklearsError::InvalidInput(
881                "Number of samples in X and y must be equal".to_string(),
882            ));
883        }
884
885        // Get unique classes
886        let mut unique_classes = y.iter().cloned().collect::<Vec<_>>();
887        unique_classes.sort_unstable();
888        unique_classes.dedup();
889        let classes = Array1::from_vec(unique_classes);
890
891        let mut fitted = ContextAwareDummyClassifier {
892            strategy: self.strategy.clone(),
893            random_state: self.random_state,
894            feature_bins_: None,
895            bin_class_probs_: None,
896            classes_: Some(classes),
897            training_features_: None,
898            training_targets_: None,
899            _state: std::marker::PhantomData,
900        };
901
902        // For now, implement only conditional strategy for classifier
903        match &self.strategy {
904            ContextAwareStrategy::Conditional {
905                n_bins,
906                min_samples_per_bin,
907            } => {
908                fitted.fit_conditional_classifier(x, y, *n_bins, *min_samples_per_bin)?;
909            }
910            _ => {
911                // Store training data for other strategies
912                fitted.training_features_ = Some(x.clone());
913                fitted.training_targets_ = Some(y.clone());
914            }
915        }
916
917        Ok(fitted)
918    }
919}
920
921impl ContextAwareDummyClassifier<sklears_core::traits::Trained> {
922    /// Fit conditional strategy for classification
923    fn fit_conditional_classifier(
924        &mut self,
925        x: &Features,
926        y: &Array1<i32>,
927        n_bins: usize,
928        min_samples_per_bin: usize,
929    ) -> Result<()> {
930        let n_features = x.ncols();
931        let mut feature_bins = Vec::with_capacity(n_features);
932        let mut bin_class_counts: HashMap<Vec<usize>, HashMap<i32, usize>> = HashMap::new();
933
934        // Create bins for each feature
935        for feature_idx in 0..n_features {
936            let feature_values = x.column(feature_idx);
937            let min_val = feature_values
938                .iter()
939                .fold(Float::INFINITY, |a, &b| a.min(b));
940            let max_val = feature_values
941                .iter()
942                .fold(Float::NEG_INFINITY, |a, &b| a.max(b));
943
944            let bin_width = (max_val - min_val) / n_bins as Float;
945            let mut bins = Array1::zeros(n_bins + 1);
946
947            for i in 0..=n_bins {
948                bins[i] = min_val + i as Float * bin_width;
949            }
950            bins[n_bins] = max_val + 1e-10;
951
952            feature_bins.push(bins);
953        }
954
955        // Count classes in each bin combination
956        for i in 0..x.nrows() {
957            let mut bin_indices = Vec::with_capacity(n_features);
958
959            for (feature_idx, bins) in feature_bins.iter().enumerate() {
960                let value = x[[i, feature_idx]];
961                let bin_idx = bins
962                    .iter()
963                    .position(|&bin_edge| value < bin_edge)
964                    .unwrap_or(bins.len() - 1)
965                    .saturating_sub(1);
966                bin_indices.push(bin_idx);
967            }
968
969            let class_counts = bin_class_counts.entry(bin_indices).or_default();
970            *class_counts.entry(y[i]).or_insert(0) += 1;
971        }
972
973        // Convert counts to probabilities for bins with sufficient samples
974        let mut bin_class_probs = HashMap::new();
975        for (bin_key, class_counts) in bin_class_counts {
976            let total_count: usize = class_counts.values().sum();
977            if total_count >= min_samples_per_bin {
978                let mut class_probs = HashMap::new();
979                for (&class, &count) in &class_counts {
980                    class_probs.insert(class, count as Float / total_count as Float);
981                }
982                bin_class_probs.insert(bin_key, class_probs);
983            }
984        }
985
986        self.feature_bins_ = Some(feature_bins);
987        self.bin_class_probs_ = Some(bin_class_probs);
988        Ok(())
989    }
990}
991
992impl Predict<Features, Array1<i32>> for ContextAwareDummyClassifier<sklears_core::traits::Trained> {
993    fn predict(&self, x: &Features) -> Result<Array1<i32>> {
994        if x.is_empty() {
995            return Err(sklears_core::error::SklearsError::InvalidInput(
996                "Input cannot be empty".to_string(),
997            ));
998        }
999
1000        let n_samples = x.nrows();
1001        let mut predictions = Array1::zeros(n_samples);
1002        let classes = self.classes_.as_ref().unwrap();
1003
1004        let mut rng = if let Some(seed) = self.random_state {
1005            StdRng::seed_from_u64(seed)
1006        } else {
1007            StdRng::seed_from_u64(0)
1008        };
1009
1010        match &self.strategy {
1011            ContextAwareStrategy::Conditional { .. } => {
1012                let feature_bins = self.feature_bins_.as_ref().unwrap();
1013                let bin_class_probs = self.bin_class_probs_.as_ref().unwrap();
1014
1015                // Global class distribution as fallback
1016                let global_class = classes[0]; // Simplified fallback
1017
1018                for i in 0..x.nrows() {
1019                    let mut bin_indices = Vec::with_capacity(feature_bins.len());
1020
1021                    for (feature_idx, bins) in feature_bins.iter().enumerate() {
1022                        let value = x[[i, feature_idx]];
1023                        let bin_idx = bins
1024                            .iter()
1025                            .position(|&bin_edge| value < bin_edge)
1026                            .unwrap_or(bins.len() - 1)
1027                            .saturating_sub(1);
1028                        bin_indices.push(bin_idx);
1029                    }
1030
1031                    if let Some(class_probs) = bin_class_probs.get(&bin_indices) {
1032                        // Sample from class distribution
1033                        let rand_val: Float = rng.gen();
1034                        let mut cumulative_prob = 0.0;
1035                        let mut selected_class = global_class;
1036
1037                        for (&class, &prob) in class_probs {
1038                            cumulative_prob += prob;
1039                            if rand_val <= cumulative_prob {
1040                                selected_class = class;
1041                                break;
1042                            }
1043                        }
1044                        predictions[i] = selected_class;
1045                    } else {
1046                        predictions[i] = global_class;
1047                    }
1048                }
1049            }
1050            _ => {
1051                // Fallback for other strategies - use most frequent class
1052                let most_frequent_class = classes[0];
1053                predictions.fill(most_frequent_class);
1054            }
1055        }
1056
1057        Ok(predictions)
1058    }
1059}
1060
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    #[test]
    fn test_conditional_regressor() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let regressor = ContextAwareDummyRegressor::new(ContextAwareStrategy::Conditional {
            n_bins: 2,
            min_samples_per_bin: 1,
        });

        let fitted = regressor.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 6);
        assert!(predictions.iter().all(|&p| p >= 1.0 && p <= 6.0));
    }

    #[test]
    fn test_feature_weighted_regressor() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let regressor = ContextAwareDummyRegressor::new(ContextAwareStrategy::FeatureWeighted {
            weighting: FeatureWeighting::Uniform,
        });

        let fitted = regressor.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_cluster_based_regressor() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 1.0, 1.1, 1.1, 5.0, 5.0, 5.1, 5.1, 9.0, 9.0, 9.1, 9.1],
        )
        .unwrap();
        let y = array![1.0, 1.0, 5.0, 5.0, 9.0, 9.0];

        let regressor = ContextAwareDummyRegressor::new(ContextAwareStrategy::ClusterBased {
            n_clusters: 3,
            max_iter: 10,
        })
        .with_random_state(42);

        let fitted = regressor.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 6);
    }

    #[test]
    fn test_locality_sensitive_regressor() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];

        let regressor = ContextAwareDummyRegressor::new(ContextAwareStrategy::LocalitySensitive {
            n_neighbors: 2,
            distance_power: 2.0,
        });

        let fitted = regressor.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
    }

    #[test]
    fn test_adaptive_local_regressor() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let regressor = ContextAwareDummyRegressor::new(ContextAwareStrategy::AdaptiveLocal {
            radius: 2.0,
            min_local_samples: 2,
        })
        .with_random_state(42);

        let fitted = regressor.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 6);
    }

    #[test]
    fn test_conditional_classifier() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0],
        )
        .unwrap();
        let y = array![0, 0, 1, 1, 0, 1];

        let classifier = ContextAwareDummyClassifier::new(ContextAwareStrategy::Conditional {
            n_bins: 2,
            min_samples_per_bin: 1,
        })
        .with_random_state(42);

        let fitted = classifier.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 6);
        assert!(predictions.iter().all(|&p| p == 0 || p == 1));
    }

    /// Predicting twice with the same fitted model and input must give identical labels.
    #[test]
    fn test_conditional_classifier_repeatable_predictions() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0, 7.0],
        )
        .unwrap();
        let y = array![0, 0, 1, 1, 0, 1];

        let fitted = ContextAwareDummyClassifier::new(ContextAwareStrategy::Conditional {
            n_bins: 2,
            min_samples_per_bin: 1,
        })
        .with_random_state(7)
        .fit(&x, &y)
        .unwrap();

        assert_eq!(fitted.predict(&x).unwrap(), fitted.predict(&x).unwrap());
    }

    /// `fit` must reject empty input instead of panicking.
    #[test]
    fn test_classifier_rejects_empty_input() {
        let x = Array2::<Float>::zeros((0, 2));
        let y: Array1<i32> = Array1::from_vec(Vec::new());

        let classifier = ContextAwareDummyClassifier::new(ContextAwareStrategy::Conditional {
            n_bins: 2,
            min_samples_per_bin: 1,
        });

        assert!(classifier.fit(&x, &y).is_err());
    }

    /// `fit` must reject X/y with different sample counts.
    #[test]
    fn test_classifier_rejects_length_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 1];

        let classifier = ContextAwareDummyClassifier::new(ContextAwareStrategy::Conditional {
            n_bins: 2,
            min_samples_per_bin: 1,
        });

        assert!(classifier.fit(&x, &y).is_err());
    }

    /// Non-conditional strategies predict a single constant, known class label.
    #[test]
    fn test_classifier_non_conditional_strategy_is_constant() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 1, 1, 1];

        let classifier = ContextAwareDummyClassifier::new(ContextAwareStrategy::ClusterBased {
            n_clusters: 2,
            max_iter: 5,
        });

        let fitted = classifier.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x).unwrap();

        assert_eq!(predictions.len(), 4);
        assert!(predictions.iter().all(|&p| p == 0 || p == 1));
        let first = predictions[0];
        assert!(predictions.iter().all(|&p| p == first));
    }
}