sklears_multioutput/chains.rs

//! Chain-based multi-output learning algorithms
//!
//! This module provides chain-based approaches for multi-label and multi-output problems,
//! including ClassifierChain, RegressorChain, EnsembleOfChains, and BayesianClassifierChain.
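//!
//! All chain types share the same core idea: estimators are trained in a fixed
//! label/target order, and each later estimator receives the earlier estimators'
//! predictions as additional input features.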

use crate::utils::*;
// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2, Axis};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};

/// Classifier Chain
///
/// A multi-label model that arranges binary classifiers into a chain.
/// Each model makes a prediction in the order specified by the chain using
/// all of the available features provided to the model plus the predictions
/// of models that are earlier in the chain.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::chains::ClassifierChain;
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// // This is a simple example showing the structure
/// let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let labels = array![[0, 1], [1, 0], [1, 1]];
/// ```
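///
/// A fuller end-to-end sketch of the fit/predict flow (a minimal illustration
/// assuming the `fit_simple` and `Predict` APIs defined below; marked `no_run`
/// because it is meant as a usage sketch rather than a tested guarantee):
///
/// ```no_run
/// use sklears_multioutput::chains::ClassifierChain;
/// use sklears_core::traits::Predict;
/// use scirs2_core::ndarray::array;
///
/// let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let labels = array![[0, 1], [1, 0], [1, 1]];
///
/// // Fit the chain in label order 0 -> 1, then predict on the training data.
/// let chain = ClassifierChain::new().order(vec![0, 1]).random_state(42);
/// let trained = chain.fit_simple(&data.view(), &labels).unwrap();
/// let predictions = trained.predict(&data.view()).unwrap();
/// assert_eq!(predictions.dim(), (3, 2));
/// ```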
#[derive(Debug, Clone)]
pub struct ClassifierChain<S = Untrained> {
    state: S,
    order: Option<Vec<usize>>,
    cv: Option<usize>,
    random_state: Option<u64>,
}

impl ClassifierChain<Untrained> {
    /// Create a new ClassifierChain instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            order: None,
            cv: None,
            random_state: None,
        }
    }

    /// Set the chain order
    pub fn order(mut self, order: Vec<usize>) -> Self {
        self.order = Some(order);
        self
    }

    /// Set cross-validation folds for training
    pub fn cv(mut self, cv: usize) -> Self {
        self.cv = Some(cv);
        self
    }

    /// Set random state for reproducibility
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for ClassifierChain<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for ClassifierChain<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl ClassifierChain<Untrained> {
    /// Fit the classifier chain using a simple built-in binary classifier for each label
    pub fn fit_simple(
        self,
        X: &ArrayView2<'_, Float>,
        y: &Array2<i32>,
    ) -> SklResult<ClassifierChain<ClassifierChainTrained>> {
        let (n_samples, n_features) = X.dim();
        let n_labels = y.ncols();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        // Determine chain order
        let order = self
            .order
            .clone()
            .unwrap_or_else(|| (0..n_labels).collect());

        if order.len() != n_labels {
            return Err(SklearsError::InvalidInput(
                "Chain order must contain all label indices".to_string(),
            ));
        }

        // Train models in the chain
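        // Chain mechanics: the classifier for the i-th label in `order` is trained on the
        // original features plus one extra column per earlier label, holding that earlier
        // model's predictions on the training set (this implementation chains on its own
        // predictions rather than on the ground-truth labels). The feature width therefore
        // grows from n_features to n_features + n_labels - 1 along the chain.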
        let mut models = Vec::new();
        let mut current_features = X.to_owned();

        for (i, &label_idx) in order.iter().enumerate() {
            let y_binary = y.column(label_idx).to_owned();

            // Train binary classifier
            let model = train_binary_classifier(&current_features.view(), &y_binary)?;
            models.push(model);

            // Add predictions as features for next model (except for the last one)
            if i < order.len() - 1 {
                let predictions = predict_binary_classifier(&current_features.view(), &models[i]);
                let n_current_features = current_features.ncols();
                let mut new_features = Array2::<Float>::zeros((n_samples, n_current_features + 1));

                // Copy existing features
                new_features
                    .slice_mut(s![.., ..n_current_features])
                    .assign(&current_features);

                // Add predictions as new feature
                for j in 0..n_samples {
                    new_features[[j, n_current_features]] = predictions[j] as Float;
                }

                current_features = new_features;
            }
        }

        let trained_state = ClassifierChainTrained {
            models,
            order,
            n_features,
            n_labels,
        };

        Ok(ClassifierChain {
            state: trained_state,
            order: self.order,
            cv: self.cv,
            random_state: self.random_state,
        })
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<i32>, ClassifierChainTrained>
    for ClassifierChain<Untrained>
{
    type Fitted = ClassifierChain<ClassifierChainTrained>;

    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        self.fit_simple(X, y)
    }
}

/// Trained state for ClassifierChain
#[derive(Debug, Clone)]
pub struct ClassifierChainTrained {
    models: Vec<SimpleBinaryModel>,
    order: Vec<usize>,
    n_features: usize,
    n_labels: usize,
}

impl Predict<ArrayView2<'_, Float>, Array2<i32>> for ClassifierChain<ClassifierChainTrained> {
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let (n_samples, n_features) = X.dim();
        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));
        let mut current_features = X.to_owned();

        // Make predictions following the chain order
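        // At prediction time the chain is rolled forward with the model's own hard
        // predictions, appended as extra feature columns in chain order.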
        for (i, &label_idx) in self.state.order.iter().enumerate() {
            let model = &self.state.models[i];
            let label_predictions = predict_binary_classifier(&current_features.view(), model);

            // Store predictions
            for j in 0..n_samples {
                predictions[[j, label_idx]] = label_predictions[j];
            }

            // Add predictions as features for next model (if not last)
            if i < self.state.order.len() - 1 {
                let n_current_features = current_features.ncols();
                let mut new_features = Array2::<Float>::zeros((n_samples, n_current_features + 1));

                // Copy existing features
                new_features
                    .slice_mut(s![.., ..n_current_features])
                    .assign(&current_features);

                // Add current label predictions as feature
                for j in 0..n_samples {
                    new_features[[j, n_current_features]] = label_predictions[j] as Float;
                }

                current_features = new_features;
            }
        }

        Ok(predictions)
    }
}

impl ClassifierChain<ClassifierChainTrained> {
    /// Predict probabilities for each label
    pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let (n_samples, n_features) = X.dim();
        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut probabilities = Array2::<Float>::zeros((n_samples, self.state.n_labels));
        let mut current_features = X.to_owned();

        // Make probability predictions following the chain order
        for (i, &label_idx) in self.state.order.iter().enumerate() {
            let model = &self.state.models[i];
            let label_probabilities = predict_binary_probabilities(&current_features.view(), model);

            // Store probabilities
            for j in 0..n_samples {
                probabilities[[j, label_idx]] = label_probabilities[j];
            }

            // Add predictions as features for next model (if not last)
            if i < self.state.order.len() - 1 {
                let label_predictions =
                    label_probabilities.mapv(|p| if p > 0.5 { 1.0 } else { 0.0 });
                let n_current_features = current_features.ncols();
                let mut new_features = Array2::<Float>::zeros((n_samples, n_current_features + 1));

                // Copy existing features
                new_features
                    .slice_mut(s![.., ..n_current_features])
                    .assign(&current_features);

                // Add predictions as feature
                for j in 0..n_samples {
                    new_features[[j, n_current_features]] = label_predictions[j];
                }

                current_features = new_features;
            }
        }

        Ok(probabilities)
    }

    /// Get the chain order used during training
    pub fn chain_order(&self) -> &[usize] {
        &self.state.order
    }

    /// Get the number of models in the chain
    pub fn n_models(&self) -> usize {
        self.state.models.len()
    }

    /// Get number of targets/labels
    pub fn n_targets(&self) -> usize {
        self.state.n_labels
    }

    /// Simple prediction method (alias for predict)
    pub fn predict_simple(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        self.predict(X)
    }

    /// Monte Carlo prediction (simplified)
    pub fn predict_monte_carlo(
        &self,
        X: &ArrayView2<'_, Float>,
        n_samples: usize,
        _random_state: Option<u64>,
    ) -> SklResult<Array2<Float>> {
        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "n_samples must be greater than 0".to_string(),
            ));
        }
        // For now, just return probabilities
        self.predict_proba(X)
    }

    /// Monte Carlo prediction for labels (simplified)
    pub fn predict_monte_carlo_labels(
        &self,
        X: &ArrayView2<'_, Float>,
        n_samples: usize,
        _random_state: Option<u64>,
    ) -> SklResult<Array2<i32>> {
        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "n_samples must be greater than 0".to_string(),
            ));
        }
        // For now, just return predictions
        self.predict(X)
    }
}

/// Regressor Chain
///
/// A multi-output model that arranges regressors into a chain.
/// Each model makes a prediction in the order specified by the chain using
/// all of the available features provided to the model plus the predictions
/// of models that are earlier in the chain.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::chains::RegressorChain;
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// // This is a simple example showing the structure
/// let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let targets = array![[1.5, 2.5], [2.5, 3.5], [3.5, 1.5]];
/// ```
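///
/// A fuller sketch of the fit/predict flow (a minimal illustration assuming the
/// `fit_simple` and `Predict` APIs defined below; `no_run` because it is a usage
/// sketch rather than a tested guarantee):
///
/// ```no_run
/// use sklears_multioutput::chains::RegressorChain;
/// use sklears_core::traits::Predict;
/// use scirs2_core::ndarray::array;
///
/// let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let targets = array![[1.5, 2.5], [2.5, 3.5], [3.5, 1.5]];
///
/// // Fit targets in chain order 0 -> 1; the second regressor also sees the first's predictions.
/// let chain = RegressorChain::new().order(vec![0, 1]);
/// let trained = chain.fit_simple(&data.view(), &targets).unwrap();
/// let predictions = trained.predict(&data.view()).unwrap();
/// assert_eq!(predictions.dim(), (3, 2));
/// ```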
#[derive(Debug, Clone)]
pub struct RegressorChain<S = Untrained> {
    state: S,
    order: Option<Vec<usize>>,
    cv: Option<usize>,
    random_state: Option<u64>,
}

impl RegressorChain<Untrained> {
    /// Create a new RegressorChain instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            order: None,
            cv: None,
            random_state: None,
        }
    }

    /// Set the chain order
    pub fn order(mut self, order: Vec<usize>) -> Self {
        self.order = Some(order);
        self
    }

    /// Set cross-validation folds for training
    pub fn cv(mut self, cv: usize) -> Self {
        self.cv = Some(cv);
        self
    }

    /// Set random state for reproducibility
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for RegressorChain<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for RegressorChain<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl RegressorChain<Untrained> {
    /// Fit the regressor chain using a simple linear approach
    pub fn fit_simple(
        self,
        X: &ArrayView2<'_, Float>,
        y: &Array2<Float>,
    ) -> SklResult<RegressorChain<RegressorChainTrained>> {
        let (n_samples, n_features) = X.dim();
        let n_targets = y.ncols();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        // Determine chain order
        let order = self
            .order
            .clone()
            .unwrap_or_else(|| (0..n_targets).collect());

        if order.len() != n_targets {
            return Err(SklearsError::InvalidInput(
                "Chain order must contain all target indices".to_string(),
            ));
        }

        // Train models in the chain
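        // As in ClassifierChain, each regressor sees the original features plus one
        // prediction column per earlier target in the chain.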
        let mut models = Vec::new();
        let mut current_features = X.to_owned();

        for (i, &target_idx) in order.iter().enumerate() {
            let y_target = y.column(target_idx).to_owned();

            // Train linear regressor
            let model = train_simple_linear_classifier(&current_features.view(), &y_target)?;
            models.push(model);

            // Add predictions as features for next model (except for the last one)
            if i < order.len() - 1 {
                let predictions = predict_simple_linear(&current_features.view(), &models[i]);
                let n_current_features = current_features.ncols();
                let mut new_features = Array2::<Float>::zeros((n_samples, n_current_features + 1));

                // Copy existing features
                new_features
                    .slice_mut(s![.., ..n_current_features])
                    .assign(&current_features);

                // Add predictions as new feature
                for j in 0..n_samples {
                    new_features[[j, n_current_features]] = predictions[j];
                }

                current_features = new_features;
            }
        }

        let trained_state = RegressorChainTrained {
            models,
            order,
            n_features,
            n_targets,
        };

        Ok(RegressorChain {
            state: trained_state,
            order: self.order,
            cv: self.cv,
            random_state: self.random_state,
        })
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<Float>, RegressorChainTrained>
    for RegressorChain<Untrained>
{
    type Fitted = RegressorChain<RegressorChainTrained>;

    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<Float>) -> SklResult<Self::Fitted> {
        self.fit_simple(X, y)
    }
}

/// Trained state for RegressorChain
#[derive(Debug, Clone)]
pub struct RegressorChainTrained {
    models: Vec<SimpleLinearClassifier>,
    order: Vec<usize>,
    n_features: usize,
    n_targets: usize,
}

impl Predict<ArrayView2<'_, Float>, Array2<Float>> for RegressorChain<RegressorChainTrained> {
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let (n_samples, n_features) = X.dim();
        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut predictions = Array2::<Float>::zeros((n_samples, self.state.n_targets));
        let mut current_features = X.to_owned();

        // Make predictions following the chain order
        for (i, &target_idx) in self.state.order.iter().enumerate() {
            let model = &self.state.models[i];
            let target_predictions = predict_simple_linear(&current_features.view(), model);

            // Store predictions
            for j in 0..n_samples {
                predictions[[j, target_idx]] = target_predictions[j];
            }

            // Add predictions as features for next model (if not last)
            if i < self.state.order.len() - 1 {
                let n_current_features = current_features.ncols();
                let mut new_features = Array2::<Float>::zeros((n_samples, n_current_features + 1));

                // Copy existing features
                new_features
                    .slice_mut(s![.., ..n_current_features])
                    .assign(&current_features);

                // Add current target predictions as feature
                for j in 0..n_samples {
                    new_features[[j, n_current_features]] = target_predictions[j];
                }

                current_features = new_features;
            }
        }

        Ok(predictions)
    }
}

impl RegressorChain<RegressorChainTrained> {
    /// Get the chain order used during training
    pub fn chain_order(&self) -> &[usize] {
        &self.state.order
    }

    /// Get the number of models in the chain
    pub fn n_models(&self) -> usize {
        self.state.models.len()
    }

    /// Get model at specified index
    pub fn get_model(&self, index: usize) -> Option<&SimpleLinearClassifier> {
        self.state.models.get(index)
    }

    /// Get the number of targets
    pub fn n_targets(&self) -> usize {
        self.state.n_targets
    }

    /// Simple prediction method (alias for predict)
    pub fn predict_simple(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        self.predict(X)
    }
}

/// Ensemble of Chains
///
/// An ensemble approach that combines multiple ClassifierChain models
/// with different chain orders or different random seeds to improve
/// prediction performance and robustness.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::chains::EnsembleOfChains;
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// // This is a simple example showing the structure
/// let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let labels = array![[0, 1], [1, 0], [1, 1]];
/// ```
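///
/// A fuller sketch showing ensemble construction and majority-vote prediction
/// (a minimal illustration assuming the builder and `Predict` APIs defined below;
/// `no_run` because it is a usage sketch rather than a tested guarantee):
///
/// ```no_run
/// use sklears_multioutput::chains::{ChainMethod, EnsembleOfChains};
/// use sklears_core::traits::Predict;
/// use scirs2_core::ndarray::array;
///
/// let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let labels = array![[0, 1], [1, 0], [1, 1]];
///
/// // Five chains with random label orders, combined by majority vote at prediction time.
/// let ensemble = EnsembleOfChains::new()
///     .n_chains(5)
///     .chain_method(ChainMethod::Random)
///     .random_state(7);
/// let trained = ensemble.fit_simple(&data.view(), &labels).unwrap();
/// let predictions = trained.predict(&data.view()).unwrap();
/// assert_eq!(predictions.dim(), (3, 2));
/// ```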
#[derive(Debug, Clone)]
pub struct EnsembleOfChains<S = Untrained> {
    state: S,
    n_chains: usize,
    chain_method: ChainMethod,
    random_state: Option<u64>,
}

/// Method for generating chains in ensemble
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ChainMethod {
    /// Random chain orders
    Random,
    /// Fixed different orders
    Fixed,
    /// Bootstrap sampling with chains
    Bootstrap,
}

impl EnsembleOfChains<Untrained> {
    /// Create a new EnsembleOfChains instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_chains: 10,
            chain_method: ChainMethod::Random,
            random_state: None,
        }
    }

    /// Set number of chains in ensemble
    pub fn n_chains(mut self, n_chains: usize) -> Self {
        self.n_chains = n_chains;
        self
    }

    /// Set chain generation method
    pub fn chain_method(mut self, method: ChainMethod) -> Self {
        self.chain_method = method;
        self
    }

    /// Set random state for reproducibility
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for EnsembleOfChains<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for EnsembleOfChains<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl EnsembleOfChains<Untrained> {
    /// Fit the ensemble of chains
    pub fn fit_simple(
        self,
        X: &ArrayView2<'_, Float>,
        y: &Array2<i32>,
    ) -> SklResult<EnsembleOfChains<EnsembleOfChainsTrained>> {
        let (n_samples, n_features) = X.dim();
        let n_labels = y.ncols();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let mut chains = Vec::new();
        let mut rng_state = self.random_state.unwrap_or(42);

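        // Each chain gets its own label ordering. A small linear congruential generator
        // (the classic 1664525 / 1013904223 constants) drives the Fisher-Yates shuffles,
        // so results are reproducible for a given `random_state`.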
        for i in 0..self.n_chains {
            // Generate chain order based on method
            let chain_order = match self.chain_method {
                ChainMethod::Random => {
                    let mut order: Vec<usize> = (0..n_labels).collect();
                    // Simple shuffle using deterministic random
                    for j in (1..order.len()).rev() {
                        rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
                        let k = (rng_state as usize) % (j + 1);
                        order.swap(j, k);
                    }
                    order
                }
                ChainMethod::Fixed => {
                    // Create different fixed orders
                    let mut order: Vec<usize> = (0..n_labels).collect();
                    order.rotate_left(i % n_labels);
                    order
                }
                ChainMethod::Bootstrap => {
                    // Bootstrap: currently uses a random chain order; bootstrap resampling
                    // of the training data is not applied here
                    let mut order: Vec<usize> = (0..n_labels).collect();
                    for j in (1..order.len()).rev() {
                        rng_state = rng_state.wrapping_mul(1664525).wrapping_add(1013904223);
                        let k = (rng_state as usize) % (j + 1);
                        order.swap(j, k);
                    }
                    order
                }
            };

            // Create and train individual chain
            let chain = ClassifierChain::new()
                .order(chain_order)
                .random_state(rng_state);

            let trained_chain = chain.fit_simple(X, y)?;
            chains.push(trained_chain);

            rng_state = rng_state.wrapping_add(1);
        }

        let trained_state = EnsembleOfChainsTrained {
            chains,
            n_features,
            n_labels,
        };

        Ok(EnsembleOfChains {
            state: trained_state,
            n_chains: self.n_chains,
            chain_method: self.chain_method,
            random_state: self.random_state,
        })
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<i32>, EnsembleOfChainsTrained>
    for EnsembleOfChains<Untrained>
{
    type Fitted = EnsembleOfChains<EnsembleOfChainsTrained>;

    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        self.fit_simple(X, y)
    }
}

/// Trained state for EnsembleOfChains
#[derive(Debug, Clone)]
pub struct EnsembleOfChainsTrained {
    chains: Vec<ClassifierChain<ClassifierChainTrained>>,
    n_features: usize,
    n_labels: usize,
}

impl Predict<ArrayView2<'_, Float>, Array2<i32>> for EnsembleOfChains<EnsembleOfChainsTrained> {
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let (n_samples, n_features) = X.dim();
        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        // Collect predictions from all chains
        let mut all_predictions = Vec::new();
        for chain in &self.state.chains {
            let predictions = chain.predict(X)?;
            all_predictions.push(predictions);
        }

        // Ensemble predictions by majority voting
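        // A label is set to 1 only if strictly more than half of the chains vote for it,
        // so ties in an even-sized ensemble resolve to 0.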
        let mut final_predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));

        for i in 0..n_samples {
            for j in 0..self.state.n_labels {
                let mut votes = 0;
                for predictions in &all_predictions {
                    votes += predictions[[i, j]];
                }
                // Majority vote
                final_predictions[[i, j]] = if votes > (self.state.chains.len() as i32) / 2 {
                    1
                } else {
                    0
                };
            }
        }

        Ok(final_predictions)
    }
}

impl EnsembleOfChains<EnsembleOfChainsTrained> {
    /// Predict probabilities using ensemble voting
    pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let (n_samples, n_features) = X.dim();
        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        // Collect probability predictions from all chains
        let mut all_probabilities = Vec::new();
        for chain in &self.state.chains {
            let probabilities = chain.predict_proba(X)?;
            all_probabilities.push(probabilities);
        }

        // Average probabilities across chains
        let mut final_probabilities = Array2::<Float>::zeros((n_samples, self.state.n_labels));

        for i in 0..n_samples {
            for j in 0..self.state.n_labels {
                let mut prob_sum = 0.0;
                for probabilities in &all_probabilities {
                    prob_sum += probabilities[[i, j]];
                }
                final_probabilities[[i, j]] = prob_sum / self.state.chains.len() as Float;
            }
        }

        Ok(final_probabilities)
    }

    /// Get number of chains in ensemble
    pub fn n_chains(&self) -> usize {
        self.state.chains.len()
    }

    /// Get individual chain at specified index
    pub fn get_chain(&self, index: usize) -> Option<&ClassifierChain<ClassifierChainTrained>> {
        self.state.chains.get(index)
    }

    /// Average pairwise disagreement between chain orders: 0.0 means all chains share
    /// the same order, 1.0 means no two chains agree at any position
    pub fn chain_diversity(&self) -> Float {
        if self.state.chains.len() < 2 {
            return 0.0;
        }

        let mut diversity_sum = 0.0;
        let mut count = 0;

        // Compare chain orders pairwise
        for i in 0..self.state.chains.len() {
            for j in (i + 1)..self.state.chains.len() {
                let order1 = self.state.chains[i].chain_order();
                let order2 = self.state.chains[j].chain_order();

                // Order similarity: fraction of positions at which both chains place the
                // same label (positional agreement rather than a true Kendall's tau)
                let mut agreements = 0;
                for k in 0..order1.len() {
                    if order1[k] == order2[k] {
                        agreements += 1;
                    }
                }

                let similarity = agreements as Float / order1.len() as Float;
                diversity_sum += 1.0 - similarity;
                count += 1;
            }
        }

        if count > 0 {
            diversity_sum / count as Float
        } else {
            0.0
        }
    }

    /// Get the number of targets
    pub fn n_targets(&self) -> usize {
        self.state.n_labels
    }

    /// Simple prediction method (alias for predict)
    pub fn predict_simple(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        self.predict(X)
    }

    /// Simple probability prediction method (alias for predict_proba)
    pub fn predict_proba_simple(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        self.predict_proba(X)
    }
}

/// Bayesian Classifier Chain
///
/// A probabilistic variant of classifier chain that uses Bayesian inference
/// for the binary classifiers, providing uncertainty quantification alongside
/// predictions.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::chains::BayesianClassifierChain;
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// // This is a simple example showing the structure
/// let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let labels = array![[0, 1], [1, 0], [1, 1]];
/// let model = BayesianClassifierChain::new()
///     .n_samples(100)
///     .prior_strength(1.0);
/// ```
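///
/// A fuller sketch covering fitting, probabilistic prediction, and uncertainty
/// estimates (a minimal illustration assuming the APIs defined below; `no_run`
/// because it is a usage sketch rather than a tested guarantee):
///
/// ```no_run
/// use sklears_multioutput::chains::BayesianClassifierChain;
/// use scirs2_core::ndarray::array;
///
/// let data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let labels = array![[0, 1], [1, 0], [1, 1]];
///
/// let trained = BayesianClassifierChain::new()
///     .n_samples(100)
///     .prior_strength(1.0)
///     .fit_simple(&data.view(), &labels)
///     .unwrap();
/// // Per-label probabilities and per-label predictive variances share the same shape.
/// let proba = trained.predict_proba(&data.view()).unwrap();
/// let variance = trained.predict_uncertainty(&data.view()).unwrap();
/// assert_eq!(proba.dim(), variance.dim());
/// ```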
#[derive(Debug, Clone)]
pub struct BayesianClassifierChain<S = Untrained> {
    state: S,
    /// Optional chain order over label indices
    pub order: Option<Vec<usize>>,
    /// Number of posterior samples to draw
    pub n_samples: usize,
    /// Prior strength (regularization) for the Bayesian binary models
    pub prior_strength: Float,
    /// Random seed for reproducibility
    pub random_state: Option<u64>,
}

impl BayesianClassifierChain<Untrained> {
    /// Create a new BayesianClassifierChain instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            order: None,
            n_samples: 100,
            prior_strength: 1.0,
            random_state: None,
        }
    }

    /// Set the chain order
    pub fn order(mut self, order: Vec<usize>) -> Self {
        self.order = Some(order);
        self
    }

    /// Set number of posterior samples
    pub fn n_samples(mut self, n_samples: usize) -> Self {
        self.n_samples = n_samples;
        self
    }

    /// Set prior strength (regularization parameter)
    pub fn prior_strength(mut self, prior_strength: Float) -> Self {
        self.prior_strength = prior_strength;
        self
    }

    /// Set random state for reproducibility
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for BayesianClassifierChain<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for BayesianClassifierChain<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl BayesianClassifierChain<Untrained> {
    /// Fit the Bayesian classifier chain
    #[allow(non_snake_case)]
    pub fn fit_simple(
        self,
        X: &ArrayView2<'_, Float>,
        y: &Array2<i32>,
    ) -> SklResult<BayesianClassifierChain<BayesianClassifierChainTrained>> {
        let (n_samples, n_features) = X.dim();
        let n_labels = y.ncols();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        // Validate binary labels
        for &val in y.iter() {
            if val != 0 && val != 1 {
                return Err(SklearsError::InvalidInput(
                    "y must contain only binary values (0 or 1)".to_string(),
                ));
            }
        }

        // Determine chain order
        let order = self
            .order
            .clone()
            .unwrap_or_else(|| (0..n_labels).collect());

        if order.len() != n_labels {
            return Err(SklearsError::InvalidInput(
                "Chain order must contain all label indices".to_string(),
            ));
        }

        // Standardize features
        let feature_means = X.mean_axis(Axis(0)).unwrap();
        let feature_stds = X.std_axis(Axis(0), 0.0);
        let X_standardized = standardize_features_simple(X, &feature_means, &feature_stds);

        // Train Bayesian models in the chain
        let mut bayesian_models = Vec::new();
        let mut current_features = X_standardized;

        for (i, &label_idx) in order.iter().enumerate() {
            let y_binary = y.column(label_idx).to_owned();

            // Train Bayesian binary classifier
            let model = train_bayesian_binary_classifier(
                &current_features,
                &y_binary,
                self.prior_strength,
            )?;
            bayesian_models.push(model);

            // Add predictions as features for next model (except for the last one)
            if i < order.len() - 1 {
                let predictions =
                    predict_bayesian_mean(&current_features.view(), &bayesian_models[i]);
                let n_current_features = current_features.ncols();
                let mut new_features = Array2::<Float>::zeros((n_samples, n_current_features + 1));

                // Copy existing features
                new_features
                    .slice_mut(s![.., ..n_current_features])
                    .assign(&current_features);

                // Add predictions as new feature
                for j in 0..n_samples {
                    new_features[[j, n_current_features]] = predictions[j];
                }

                current_features = new_features;
            }
        }

        let trained_state = BayesianClassifierChainTrained {
            bayesian_models,
            order,
            n_features,
            n_labels,
            feature_means,
            feature_stds,
        };

        Ok(BayesianClassifierChain {
            state: trained_state,
            order: self.order,
            n_samples: self.n_samples,
            prior_strength: self.prior_strength,
            random_state: self.random_state,
        })
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<i32>, BayesianClassifierChainTrained>
    for BayesianClassifierChain<Untrained>
{
    type Fitted = BayesianClassifierChain<BayesianClassifierChainTrained>;

    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        self.fit_simple(X, y)
    }
}

/// Trained state for Bayesian Classifier Chain
#[derive(Debug, Clone)]
pub struct BayesianClassifierChainTrained {
    bayesian_models: Vec<BayesianBinaryModel>,
    order: Vec<usize>,
    n_features: usize,
    n_labels: usize,
    feature_means: Array1<Float>,
    feature_stds: Array1<Float>,
}

impl Predict<ArrayView2<'_, Float>, Array2<i32>>
    for BayesianClassifierChain<BayesianClassifierChainTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let (n_samples, n_features) = X.dim();
        if n_features != self.state.feature_means.len() {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        // Standardize features
        let X_standardized =
            standardize_features_simple(X, &self.state.feature_means, &self.state.feature_stds);

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));
        let mut current_features = X_standardized;

        // Make predictions following the chain order
        for (chain_pos, &label_idx) in self.state.order.iter().enumerate() {
            let model = &self.state.bayesian_models[chain_pos];

            // Sample from posterior distribution and make predictions
            let label_predictions = predict_bayesian_binary(&current_features.view(), model);

            // Convert probabilities to binary predictions
            for i in 0..n_samples {
                predictions[[i, label_idx]] = if label_predictions[i] > 0.5 { 1 } else { 0 };
            }

            // Add predictions as features for next model (if not last)
            if chain_pos < self.state.order.len() - 1 {
                let mut new_features =
                    Array2::<Float>::zeros((n_samples, current_features.ncols() + 1));

                // Copy existing features
                new_features
                    .slice_mut(s![.., ..current_features.ncols()])
                    .assign(&current_features);

                // Add current label predictions as feature
                for i in 0..n_samples {
                    new_features[[i, current_features.ncols()]] =
                        predictions[[i, label_idx]] as Float;
                }

                current_features = new_features;
            }
        }

        Ok(predictions)
    }
}

impl BayesianClassifierChain<BayesianClassifierChainTrained> {
    /// Predict with uncertainty quantification
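    ///
    /// Returns an `(n_samples, n_labels)` array of per-label predictive variances taken
    /// from each Bayesian model's posterior; larger values indicate less confident
    /// predictions. For chaining, the posterior mean predictions are appended as the
    /// extra feature columns.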
    #[allow(non_snake_case)]
    pub fn predict_uncertainty(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let (n_samples, n_features) = X.dim();
        if n_features != self.state.feature_means.len() {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        // Standardize features
        let X_standardized =
            standardize_features_simple(X, &self.state.feature_means, &self.state.feature_stds);

        let mut uncertainties = Array2::<Float>::zeros((n_samples, self.state.n_labels));
        let mut current_features = X_standardized;

        // Make predictions following the chain order with uncertainty estimation
        for (chain_pos, &label_idx) in self.state.order.iter().enumerate() {
            let model = &self.state.bayesian_models[chain_pos];

            // Get uncertainty estimates
            let (means, variances) = predict_bayesian_uncertainty(&current_features.view(), model)?;

            // Store uncertainties
            for i in 0..n_samples {
                uncertainties[[i, label_idx]] = variances[i];
            }

            // For chaining, use mean predictions as features
            if chain_pos < self.state.order.len() - 1 {
                let mut new_features =
                    Array2::<Float>::zeros((n_samples, current_features.ncols() + 1));

                // Copy existing features
                new_features
                    .slice_mut(s![.., ..current_features.ncols()])
                    .assign(&current_features);

                // Add mean predictions as feature
                for i in 0..n_samples {
                    new_features[[i, current_features.ncols()]] = means[i];
                }

                current_features = new_features;
            }
        }

        Ok(uncertainties)
    }

    /// Predict probabilities with Bayesian averaging
    #[allow(non_snake_case)]
    pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let (n_samples, n_features) = X.dim();
        if n_features != self.state.feature_means.len() {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        // Standardize features
        let X_standardized =
            standardize_features_simple(X, &self.state.feature_means, &self.state.feature_stds);

        let mut probabilities = Array2::<Float>::zeros((n_samples, self.state.n_labels));
        let mut current_features = X_standardized;

        // Make probability predictions following the chain order
        for (chain_pos, &label_idx) in self.state.order.iter().enumerate() {
            let model = &self.state.bayesian_models[chain_pos];

            // Get probability predictions
            let label_probabilities = predict_bayesian_binary(&current_features.view(), model);

            // Store probabilities
            for i in 0..n_samples {
                probabilities[[i, label_idx]] = label_probabilities[i];
            }

            // Add mean predictions as features for next model (if not last)
            if chain_pos < self.state.order.len() - 1 {
                let mut new_features =
                    Array2::<Float>::zeros((n_samples, current_features.ncols() + 1));

                // Copy existing features
                new_features
                    .slice_mut(s![.., ..current_features.ncols()])
                    .assign(&current_features);

                // Add mean predictions as feature
                for i in 0..n_samples {
                    new_features[[i, current_features.ncols()]] = label_probabilities[i];
                }

                current_features = new_features;
            }
        }

        Ok(probabilities)
    }

    /// Get the chain order used during training
    pub fn chain_order(&self) -> &[usize] {
        &self.state.order
    }

    /// Get number of Bayesian models in the chain
    pub fn n_models(&self) -> usize {
        self.state.bayesian_models.len()
    }

    /// Get posterior statistics for a specific model in the chain
    pub fn model_posterior_stats(
        &self,
        model_idx: usize,
    ) -> Option<(&Array1<Float>, &Array2<Float>)> {
        self.state
            .bayesian_models
            .get(model_idx)
            .map(|model| (&model.weight_mean, &model.weight_cov))
    }

    /// Chain order used during training (alias for [`Self::chain_order`])
    pub fn order(&self) -> &[usize] {
        &self.state.order
    }
}

// Chain-specific utility functions

/// Helper function to predict binary classification
fn predict_binary_classifier(X: &ArrayView2<Float>, model: &SimpleBinaryModel) -> Array1<i32> {
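    // Linear decision function: predict class 1 when w . x + b > 0, otherwise class 0.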
    let raw_scores = X.dot(&model.weights) + model.bias;
    raw_scores.mapv(|x| if x > 0.0 { 1 } else { 0 })
}