Skip to main content

survival/ml/
ensemble_surv.rs

1#![allow(clippy::too_many_arguments)]
2#![allow(dead_code)]
3
4use pyo3::prelude::*;
5use rayon::prelude::*;
6
7#[derive(Debug, Clone)]
8#[pyclass]
9pub struct SuperLearnerConfig {
10    #[pyo3(get, set)]
11    pub n_folds: usize,
12    #[pyo3(get, set)]
13    pub meta_learner: String,
14    #[pyo3(get, set)]
15    pub include_original_features: bool,
16    #[pyo3(get, set)]
17    pub optimize_weights: bool,
18    #[pyo3(get, set)]
19    pub seed: Option<u64>,
20}
21
22#[pymethods]
23impl SuperLearnerConfig {
24    #[new]
25    #[pyo3(signature = (
26        n_folds=5,
27        meta_learner="nnls",
28        include_original_features=false,
29        optimize_weights=true,
30        seed=None
31    ))]
32    pub fn new(
33        n_folds: usize,
34        meta_learner: &str,
35        include_original_features: bool,
36        optimize_weights: bool,
37        seed: Option<u64>,
38    ) -> PyResult<Self> {
39        if n_folds < 2 {
40            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
41                "n_folds must be at least 2",
42            ));
43        }
44        Ok(Self {
45            n_folds,
46            meta_learner: meta_learner.to_string(),
47            include_original_features,
48            optimize_weights,
49            seed,
50        })
51    }
52}
53
/// Randomly partitions the indices `0..n` into `n_folds` disjoint folds.
///
/// The indices are shuffled with a Fisher–Yates pass driven by a small
/// deterministic generator seeded with `seed`, so identical arguments always
/// yield identical folds. Fold sizes differ by at most one: the first
/// `n % n_folds` folds receive one extra index.
///
/// Returns an empty vector when `n_folds` is zero. When `n_folds > n` the
/// trailing folds are empty, so no single fold ever swallows the whole data
/// set unless only one fold was requested.
fn create_cv_folds(n: usize, n_folds: usize, seed: u64) -> Vec<Vec<usize>> {
    // Zero folds would otherwise divide by zero below.
    if n_folds == 0 {
        return Vec::new();
    }

    let mut indices: Vec<usize> = (0..n).collect();
    let mut rng_state = seed;
    for i in (1..n).rev() {
        rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
        // Draw from the high bits: the low bits of a power-of-two-modulus
        // congruential generator cycle with very short periods.
        let j = ((rng_state >> 33) as usize) % (i + 1);
        indices.swap(i, j);
    }

    let base_size = n / n_folds;
    let n_larger = n % n_folds;
    let mut folds = Vec::with_capacity(n_folds);
    let mut start = 0;

    for fold in 0..n_folds {
        // Spread the remainder over the leading folds instead of piling it
        // onto the last one.
        let len = base_size + usize::from(fold < n_larger);
        folds.push(indices[start..start + len].to_vec());
        start += len;
    }

    folds
}
78
/// Fits Cox regression coefficients by gradient ascent on the partial
/// log-likelihood, using only the observations listed in `train_indices`.
///
/// * `time`, `event`, `covariates` are indexed by observation; `event[i] == 1`
///   marks an observed event.
/// * `learning_rate` scales each step; the gradient is additionally divided
///   by the number of training observations.
/// * `n_iter` is the number of gradient steps.
///
/// Observations sharing the same time are added to the risk set together
/// before any of their events is scored (Breslow handling of ties), so the
/// result does not depend on the order of tied observations.
///
/// Returns one coefficient per column of `covariates[0]`. The coefficients are
/// all zero when `train_indices` is empty, and the vector is empty when
/// `covariates` is empty.
fn fit_base_cox(
    time: &[f64],
    event: &[i32],
    covariates: &[Vec<f64>],
    train_indices: &[usize],
    learning_rate: f64,
    n_iter: usize,
) -> Vec<f64> {
    let n_features = covariates.first().map_or(0, Vec::len);
    let mut coefficients = vec![0.0; n_features];

    // Nothing to fit; this also avoids a 0/0 = NaN update for an empty
    // training set.
    if train_indices.is_empty() || n_features == 0 {
        return coefficients;
    }

    // Descending time, so a running sum over the prefix is the risk set.
    // `total_cmp` keeps the sort from panicking on NaN.
    let mut sorted_indices: Vec<usize> = train_indices.to_vec();
    sorted_indices.sort_by(|&a, &b| time[b].total_cmp(&time[a]));

    let n_sorted = sorted_indices.len();
    let n_train = train_indices.len() as f64;

    for _ in 0..n_iter {
        let exp_lp: Vec<f64> = sorted_indices
            .iter()
            .map(|&i| {
                covariates[i]
                    .iter()
                    .zip(coefficients.iter())
                    .map(|(&x, &b)| x * b)
                    .sum::<f64>()
                    .exp()
            })
            .collect();

        let mut gradient = vec![0.0; n_features];
        let mut risk_sum = 0.0;
        let mut weighted_sum = vec![0.0; n_features];

        let mut start = 0;
        while start < n_sorted {
            // Find the run of observations tied at this time. A NaN never
            // compares equal, so it forms a group of its own.
            let t = time[sorted_indices[start]];
            let mut end = start + 1;
            while end < n_sorted && time[sorted_indices[end]] == t {
                end += 1;
            }

            // First put the whole tie group into the risk set ...
            for (&i, &e) in sorted_indices[start..end]
                .iter()
                .zip(exp_lp[start..end].iter())
            {
                risk_sum += e;
                for (ws, &xij) in weighted_sum.iter_mut().zip(covariates[i].iter()) {
                    *ws += xij * e;
                }
            }

            // ... then score every event in the group against it.
            for &i in &sorted_indices[start..end] {
                if event[i] == 1 {
                    for (j, g) in gradient.iter_mut().enumerate() {
                        *g += covariates[i][j] - weighted_sum[j] / risk_sum;
                    }
                }
            }

            start = end;
        }

        for (b, g) in coefficients.iter_mut().zip(gradient.iter()) {
            *b += learning_rate * g / n_train;
        }
    }

    coefficients
}
136
/// Estimates non-negative combination weights for `n_models` prediction
/// vectors by projected gradient descent on the mean squared error against
/// `outcomes`.
///
/// `predictions[m][i]` is model `m`'s prediction for observation `i`. Weights
/// start uniform; each of the fixed number of steps moves against the
/// gradient, clips negative weights to zero and, when the total is positive,
/// rescales the weights to sum to one.
fn nnls_weights(predictions: &[Vec<f64>], outcomes: &[f64], n_models: usize) -> Vec<f64> {
    const N_STEPS: usize = 100;
    const STEP_SIZE: f64 = 0.01;

    let n_obs = outcomes.len();
    let mut weights = vec![1.0 / n_models as f64; n_models];

    for _step in 0..N_STEPS {
        let mut grad = vec![0.0; n_models];

        for (obs, &target) in outcomes.iter().enumerate() {
            let fitted: f64 = weights
                .iter()
                .enumerate()
                .map(|(m, &w)| w * predictions[m][obs])
                .sum();
            let residual = fitted - target;

            for (m, slot) in grad.iter_mut().enumerate() {
                *slot += 2.0 * residual * predictions[m][obs] / n_obs as f64;
            }
        }

        // Gradient step followed by projection onto the non-negative orthant.
        for m in 0..n_models {
            weights[m] = (weights[m] - STEP_SIZE * grad[m]).max(0.0);
        }

        let total: f64 = weights.iter().sum();
        if total > 0.0 {
            weights.iter_mut().for_each(|w| *w /= total);
        }
    }

    weights
}
167
168#[derive(Debug, Clone)]
169#[pyclass]
170pub struct SuperLearnerResult {
171    #[pyo3(get)]
172    pub weights: Vec<f64>,
173    #[pyo3(get)]
174    pub cv_risks: Vec<f64>,
175    #[pyo3(get)]
176    pub model_names: Vec<String>,
177    #[pyo3(get)]
178    pub ensemble_c_index: f64,
179    #[pyo3(get)]
180    pub individual_c_indices: Vec<f64>,
181}
182
183#[pymethods]
184impl SuperLearnerResult {
185    fn __repr__(&self) -> String {
186        format!(
187            "SuperLearnerResult(n_models={}, C-index={:.4})",
188            self.weights.len(),
189            self.ensemble_c_index
190        )
191    }
192
193    fn best_model(&self) -> (String, f64) {
194        let (idx, &max_c) = self
195            .individual_c_indices
196            .iter()
197            .enumerate()
198            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
199            .unwrap_or((0, &0.0));
200        (self.model_names[idx].clone(), max_c)
201    }
202}
203
/// Computes a concordance index for `risk` against the observed outcomes.
///
/// Every pair `(i, j)` where `i` has an event (`event[i] == 1`) and
/// `time[j] > time[i]` is comparable. The pair is concordant when `i` has
/// the strictly higher risk, discordant when strictly lower, and counts half
/// towards each otherwise. Returns `0.5` when there are no comparable pairs.
fn compute_c_index(time: &[f64], event: &[i32], risk: &[f64]) -> f64 {
    use std::cmp::Ordering;

    let n = time.len();
    let mut higher: u64 = 0;
    let mut lower: u64 = 0;
    let mut undecided: u64 = 0;

    for i in (0..n).filter(|&i| event[i] == 1) {
        for j in (0..n).filter(|&j| time[j] > time[i]) {
            match risk[i].partial_cmp(&risk[j]) {
                Some(Ordering::Greater) => higher += 1,
                Some(Ordering::Less) => lower += 1,
                // Equal risks and incomparable (NaN) risks are split evenly.
                _ => undecided += 1,
            }
        }
    }

    let comparable = (higher + lower + undecided) as f64;
    if comparable > 0.0 {
        (higher as f64 + 0.5 * undecided as f64) / comparable
    } else {
        0.5
    }
}
232
233#[pyfunction]
234#[pyo3(signature = (
235    time,
236    event,
237    covariates,
238    base_learner_predictions,
239    model_names,
240    config
241))]
242pub fn super_learner_survival(
243    time: Vec<f64>,
244    event: Vec<i32>,
245    covariates: Vec<Vec<f64>>,
246    base_learner_predictions: Vec<Vec<f64>>,
247    model_names: Vec<String>,
248    config: SuperLearnerConfig,
249) -> PyResult<SuperLearnerResult> {
250    let n = time.len();
251    let n_models = base_learner_predictions.len();
252
253    if n == 0 || event.len() != n || covariates.len() != n {
254        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
255            "Input arrays must have the same non-zero length",
256        ));
257    }
258    if n_models == 0 || model_names.len() != n_models {
259        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
260            "Must provide predictions from at least one model",
261        ));
262    }
263
264    let seed = config.seed.unwrap_or(42);
265    let folds = create_cv_folds(n, config.n_folds, seed);
266
267    let mut cv_predictions: Vec<Vec<f64>> = vec![vec![0.0; n]; n_models];
268
269    for test_indices in folds.iter() {
270        let train_indices: Vec<usize> = (0..n).filter(|i| !test_indices.contains(i)).collect();
271
272        for m in 0..n_models {
273            let train_preds: Vec<f64> = train_indices
274                .iter()
275                .map(|&i| base_learner_predictions[m][i])
276                .collect();
277            let test_preds: Vec<f64> = test_indices
278                .iter()
279                .map(|&i| base_learner_predictions[m][i])
280                .collect();
281
282            let scale = if !train_preds.is_empty() {
283                train_preds.iter().sum::<f64>() / train_preds.len() as f64
284            } else {
285                1.0
286            };
287
288            for (idx, &test_i) in test_indices.iter().enumerate() {
289                cv_predictions[m][test_i] = test_preds[idx] / scale.max(1e-10);
290            }
291        }
292    }
293
294    let outcomes: Vec<f64> = event.iter().map(|&e| e as f64).collect();
295    let weights = if config.optimize_weights {
296        nnls_weights(&cv_predictions, &outcomes, n_models)
297    } else {
298        vec![1.0 / n_models as f64; n_models]
299    };
300
301    let ensemble_risk: Vec<f64> = (0..n)
302        .map(|i| {
303            (0..n_models)
304                .map(|m| weights[m] * base_learner_predictions[m][i])
305                .sum()
306        })
307        .collect();
308
309    let ensemble_c_index = compute_c_index(&time, &event, &ensemble_risk);
310
311    let individual_c_indices: Vec<f64> = base_learner_predictions
312        .iter()
313        .map(|preds| compute_c_index(&time, &event, preds))
314        .collect();
315
316    let cv_risks: Vec<f64> = (0..n_models)
317        .map(|m| {
318            let mse: f64 = cv_predictions[m]
319                .iter()
320                .zip(outcomes.iter())
321                .map(|(&p, &o)| (p - o).powi(2))
322                .sum::<f64>()
323                / n as f64;
324            mse
325        })
326        .collect();
327
328    Ok(SuperLearnerResult {
329        weights,
330        cv_risks,
331        model_names,
332        ensemble_c_index,
333        individual_c_indices,
334    })
335}
336
337#[derive(Debug, Clone)]
338#[pyclass]
339pub struct StackingConfig {
340    #[pyo3(get, set)]
341    pub n_folds: usize,
342    #[pyo3(get, set)]
343    pub meta_model: String,
344    #[pyo3(get, set)]
345    pub use_probabilities: bool,
346    #[pyo3(get, set)]
347    pub seed: Option<u64>,
348}
349
350#[pymethods]
351impl StackingConfig {
352    #[new]
353    #[pyo3(signature = (
354        n_folds=5,
355        meta_model="cox",
356        use_probabilities=true,
357        seed=None
358    ))]
359    pub fn new(
360        n_folds: usize,
361        meta_model: &str,
362        use_probabilities: bool,
363        seed: Option<u64>,
364    ) -> PyResult<Self> {
365        if n_folds < 2 {
366            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
367                "n_folds must be at least 2",
368            ));
369        }
370        Ok(Self {
371            n_folds,
372            meta_model: meta_model.to_string(),
373            use_probabilities,
374            seed,
375        })
376    }
377}
378
379#[derive(Debug, Clone)]
380#[pyclass]
381pub struct StackingResult {
382    #[pyo3(get)]
383    pub meta_coefficients: Vec<f64>,
384    #[pyo3(get)]
385    pub stacked_predictions: Vec<f64>,
386    #[pyo3(get)]
387    pub c_index: f64,
388    #[pyo3(get)]
389    pub base_model_importance: Vec<f64>,
390}
391
392#[pymethods]
393impl StackingResult {
394    fn __repr__(&self) -> String {
395        format!(
396            "StackingResult(n_base_models={}, C-index={:.4})",
397            self.meta_coefficients.len(),
398            self.c_index
399        )
400    }
401}
402
403#[pyfunction]
404#[pyo3(signature = (
405    time,
406    event,
407    base_predictions,
408    config
409))]
410pub fn stacking_survival(
411    time: Vec<f64>,
412    event: Vec<i32>,
413    base_predictions: Vec<Vec<f64>>,
414    config: StackingConfig,
415) -> PyResult<StackingResult> {
416    let n = time.len();
417    let n_models = base_predictions.len();
418
419    if n == 0 || event.len() != n {
420        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
421            "time and event must have the same non-zero length",
422        ));
423    }
424    if n_models == 0 {
425        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
426            "Must provide at least one base model",
427        ));
428    }
429
430    let seed = config.seed.unwrap_or(42);
431    let folds = create_cv_folds(n, config.n_folds, seed);
432
433    let mut oof_predictions: Vec<Vec<f64>> = vec![vec![0.0; n]; n_models];
434
435    for test_indices in &folds {
436        let train_indices: Vec<usize> = (0..n).filter(|i| !test_indices.contains(i)).collect();
437
438        for m in 0..n_models {
439            let train_mean: f64 = train_indices
440                .iter()
441                .map(|&i| base_predictions[m][i])
442                .sum::<f64>()
443                / train_indices.len() as f64;
444
445            for &test_i in test_indices {
446                oof_predictions[m][test_i] = base_predictions[m][test_i] / train_mean.max(1e-10);
447            }
448        }
449    }
450
451    let meta_features: Vec<Vec<f64>> = (0..n)
452        .map(|i| (0..n_models).map(|m| oof_predictions[m][i]).collect())
453        .collect();
454
455    let train_indices: Vec<usize> = (0..n).collect();
456    let meta_coefficients = fit_base_cox(&time, &event, &meta_features, &train_indices, 0.01, 100);
457
458    let stacked_predictions: Vec<f64> = meta_features
459        .iter()
460        .map(|x| {
461            x.iter()
462                .zip(meta_coefficients.iter())
463                .map(|(&xi, &bi)| xi * bi)
464                .sum::<f64>()
465                .exp()
466        })
467        .collect();
468
469    let c_index = compute_c_index(&time, &event, &stacked_predictions);
470
471    let total_abs: f64 = meta_coefficients.iter().map(|&c| c.abs()).sum();
472    let base_model_importance: Vec<f64> = if total_abs > 0.0 {
473        meta_coefficients
474            .iter()
475            .map(|&c| c.abs() / total_abs)
476            .collect()
477    } else {
478        vec![1.0 / n_models as f64; n_models]
479    };
480
481    Ok(StackingResult {
482        meta_coefficients,
483        stacked_predictions,
484        c_index,
485        base_model_importance,
486    })
487}
488
489#[derive(Debug, Clone)]
490#[pyclass]
491pub struct ComponentwiseBoostingConfig {
492    #[pyo3(get, set)]
493    pub n_iterations: usize,
494    #[pyo3(get, set)]
495    pub learning_rate: f64,
496    #[pyo3(get, set)]
497    pub early_stopping_rounds: Option<usize>,
498    #[pyo3(get, set)]
499    pub subsample_ratio: f64,
500    #[pyo3(get, set)]
501    pub seed: Option<u64>,
502}
503
504#[pymethods]
505impl ComponentwiseBoostingConfig {
506    #[new]
507    #[pyo3(signature = (
508        n_iterations=100,
509        learning_rate=0.1,
510        early_stopping_rounds=None,
511        subsample_ratio=1.0,
512        seed=None
513    ))]
514    pub fn new(
515        n_iterations: usize,
516        learning_rate: f64,
517        early_stopping_rounds: Option<usize>,
518        subsample_ratio: f64,
519        seed: Option<u64>,
520    ) -> PyResult<Self> {
521        if learning_rate <= 0.0 || learning_rate > 1.0 {
522            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
523                "learning_rate must be in (0, 1]",
524            ));
525        }
526        if subsample_ratio <= 0.0 || subsample_ratio > 1.0 {
527            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
528                "subsample_ratio must be in (0, 1]",
529            ));
530        }
531        Ok(Self {
532            n_iterations,
533            learning_rate,
534            early_stopping_rounds,
535            subsample_ratio,
536            seed,
537        })
538    }
539}
540
541#[derive(Debug, Clone)]
542#[pyclass]
543pub struct ComponentwiseBoostingResult {
544    #[pyo3(get)]
545    pub coefficients: Vec<f64>,
546    #[pyo3(get)]
547    pub selected_features: Vec<usize>,
548    #[pyo3(get)]
549    pub iteration_log_likelihood: Vec<f64>,
550    #[pyo3(get)]
551    pub feature_importance: Vec<f64>,
552    #[pyo3(get)]
553    pub optimal_iterations: usize,
554}
555
556#[pymethods]
557impl ComponentwiseBoostingResult {
558    fn __repr__(&self) -> String {
559        format!(
560            "ComponentwiseBoostingResult(n_selected={}, iterations={})",
561            self.selected_features
562                .iter()
563                .collect::<std::collections::HashSet<_>>()
564                .len(),
565            self.optimal_iterations
566        )
567    }
568
569    fn predict_risk(&self, covariates: Vec<Vec<f64>>) -> Vec<f64> {
570        covariates
571            .par_iter()
572            .map(|x| {
573                x.iter()
574                    .zip(self.coefficients.iter())
575                    .map(|(&xi, &bi)| xi * bi)
576                    .sum::<f64>()
577                    .exp()
578            })
579            .collect()
580    }
581}
582
/// Evaluates the Cox partial log-likelihood of `linear_pred`.
///
/// Observations are visited in descending `time`, so the running sum of
/// `exp(linear_pred)` is the risk-set total. Observations sharing the same
/// time enter the risk set together before any of their events is scored
/// (Breslow handling of ties), which makes the value independent of the
/// order of tied observations. `event[i] == 1` marks an observed event.
///
/// Returns `0.0` for empty input. NaN times do not panic; each is ordered by
/// `f64::total_cmp` and treated as a tie group of its own.
fn compute_partial_log_likelihood(time: &[f64], event: &[i32], linear_pred: &[f64]) -> f64 {
    let n = time.len();
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&a, &b| time[b].total_cmp(&time[a]));

    let mut ll = 0.0;
    let mut risk_sum = 0.0;

    let mut start = 0;
    while start < n {
        // Find the run of observations tied at this time.
        let t = time[order[start]];
        let mut end = start + 1;
        while end < n && time[order[end]] == t {
            end += 1;
        }
        let group = &order[start..end];

        // The whole tie group joins the risk set before scoring.
        for &i in group {
            risk_sum += linear_pred[i].exp();
        }

        let log_risk = risk_sum.ln();
        for &i in group {
            if event[i] == 1 {
                ll += linear_pred[i] - log_risk;
            }
        }

        start = end;
    }

    ll
}
602
603#[pyfunction]
604#[pyo3(signature = (
605    time,
606    event,
607    covariates,
608    config
609))]
610pub fn componentwise_boosting(
611    time: Vec<f64>,
612    event: Vec<i32>,
613    covariates: Vec<Vec<f64>>,
614    config: ComponentwiseBoostingConfig,
615) -> PyResult<ComponentwiseBoostingResult> {
616    let n = time.len();
617    if n == 0 || event.len() != n || covariates.len() != n {
618        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
619            "Input arrays must have the same non-zero length",
620        ));
621    }
622
623    let n_features = if covariates.is_empty() {
624        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
625            "Covariates cannot be empty",
626        ));
627    } else {
628        covariates[0].len()
629    };
630
631    let seed = config.seed.unwrap_or(42);
632    let mut rng_state = seed;
633
634    let mut coefficients: Vec<f64> = vec![0.0; n_features];
635    let mut linear_pred: Vec<f64> = vec![0.0; n];
636    let mut selected_features = Vec::new();
637    let mut iteration_log_likelihood = Vec::new();
638    let mut feature_selection_count = vec![0usize; n_features];
639
640    let mut best_ll = f64::NEG_INFINITY;
641    let mut rounds_without_improvement = 0;
642    let mut optimal_iterations = 0;
643
644    for iter in 0..config.n_iterations {
645        let sample_indices: Vec<usize> = if config.subsample_ratio < 1.0 {
646            let sample_size = (n as f64 * config.subsample_ratio).ceil() as usize;
647            let mut indices: Vec<usize> = (0..n).collect();
648            for i in (1..n).rev() {
649                rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
650                let j = (rng_state as usize) % (i + 1);
651                indices.swap(i, j);
652            }
653            indices.truncate(sample_size);
654            indices
655        } else {
656            (0..n).collect()
657        };
658
659        let exp_lp: Vec<f64> = linear_pred.iter().map(|&lp| lp.exp()).collect();
660
661        let mut sorted_indices: Vec<usize> = sample_indices.clone();
662        sorted_indices.sort_by(|&a, &b| time[b].partial_cmp(&time[a]).unwrap());
663
664        let mut best_feature = 0;
665        let mut best_score = f64::NEG_INFINITY;
666        let mut best_update = 0.0;
667
668        #[allow(clippy::needless_range_loop)]
669        for j in 0..n_features {
670            let mut gradient = 0.0;
671            let mut hessian = 0.0;
672            let mut risk_sum = 0.0;
673            let mut weighted_sum = 0.0;
674            let mut weighted_sq_sum = 0.0;
675
676            for &i in &sorted_indices {
677                risk_sum += exp_lp[i];
678                weighted_sum += covariates[i][j] * exp_lp[i];
679                weighted_sq_sum += covariates[i][j].powi(2) * exp_lp[i];
680
681                if event[i] == 1 {
682                    let mean = weighted_sum / risk_sum;
683                    gradient += covariates[i][j] - mean;
684                    hessian += weighted_sq_sum / risk_sum - mean.powi(2);
685                }
686            }
687
688            if hessian.abs() > 1e-10 {
689                let update = gradient / hessian;
690                let score = gradient.abs();
691
692                if score > best_score {
693                    best_score = score;
694                    best_feature = j;
695                    best_update = update;
696                }
697            }
698        }
699
700        coefficients[best_feature] += config.learning_rate * best_update;
701        selected_features.push(best_feature);
702        feature_selection_count[best_feature] += 1;
703
704        for i in 0..n {
705            linear_pred[i] = coefficients
706                .iter()
707                .zip(covariates[i].iter())
708                .map(|(&b, &x)| b * x)
709                .sum();
710        }
711
712        let ll = compute_partial_log_likelihood(&time, &event, &linear_pred);
713        iteration_log_likelihood.push(ll);
714
715        if ll > best_ll {
716            best_ll = ll;
717            optimal_iterations = iter + 1;
718            rounds_without_improvement = 0;
719        } else {
720            rounds_without_improvement += 1;
721        }
722
723        if let Some(patience) = config.early_stopping_rounds
724            && rounds_without_improvement >= patience
725        {
726            break;
727        }
728    }
729
730    let total_selections: f64 = feature_selection_count.iter().sum::<usize>() as f64;
731    let feature_importance: Vec<f64> = if total_selections > 0.0 {
732        feature_selection_count
733            .iter()
734            .map(|&c| c as f64 / total_selections)
735            .collect()
736    } else {
737        vec![0.0; n_features]
738    };
739
740    Ok(ComponentwiseBoostingResult {
741        coefficients,
742        selected_features,
743        iteration_log_likelihood,
744        feature_importance,
745        optimal_iterations,
746    })
747}
748
749#[derive(Debug, Clone)]
750#[pyclass]
751pub struct BlendingResult {
752    #[pyo3(get)]
753    pub blend_weights: Vec<f64>,
754    #[pyo3(get)]
755    pub blended_predictions: Vec<f64>,
756    #[pyo3(get)]
757    pub validation_c_index: f64,
758}
759
760#[pymethods]
761impl BlendingResult {
762    fn __repr__(&self) -> String {
763        format!(
764            "BlendingResult(n_models={}, val_C={:.4})",
765            self.blend_weights.len(),
766            self.validation_c_index
767        )
768    }
769}
770
771#[pyfunction]
772#[pyo3(signature = (
773    val_time,
774    val_event,
775    val_predictions,
776    test_predictions
777))]
778pub fn blending_survival(
779    val_time: Vec<f64>,
780    val_event: Vec<i32>,
781    val_predictions: Vec<Vec<f64>>,
782    test_predictions: Vec<Vec<f64>>,
783) -> PyResult<BlendingResult> {
784    let n_val = val_time.len();
785    let n_models = val_predictions.len();
786
787    if n_val == 0 || val_event.len() != n_val {
788        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
789            "Validation arrays must have the same non-zero length",
790        ));
791    }
792    if n_models == 0 || test_predictions.len() != n_models {
793        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
794            "Must have same number of models for validation and test",
795        ));
796    }
797
798    let outcomes: Vec<f64> = val_event.iter().map(|&e| e as f64).collect();
799    let blend_weights = nnls_weights(&val_predictions, &outcomes, n_models);
800
801    let n_test = test_predictions[0].len();
802    let blended_predictions: Vec<f64> = (0..n_test)
803        .map(|i| {
804            (0..n_models)
805                .map(|m| blend_weights[m] * test_predictions[m][i])
806                .sum()
807        })
808        .collect();
809
810    let val_blended: Vec<f64> = (0..n_val)
811        .map(|i| {
812            (0..n_models)
813                .map(|m| blend_weights[m] * val_predictions[m][i])
814                .sum()
815        })
816        .collect();
817
818    let validation_c_index = compute_c_index(&val_time, &val_event, &val_blended);
819
820    Ok(BlendingResult {
821        blend_weights,
822        blended_predictions,
823        validation_c_index,
824    })
825}
826
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten observations: times 1 through 10, events alternating 1, 0, 1, 0, ...
    fn ten_subjects() -> (Vec<f64>, Vec<i32>) {
        let time: Vec<f64> = (1..=10).map(f64::from).collect();
        let event: Vec<i32> = (0..10).map(|i| i32::from(i % 2 == 0)).collect();
        (time, event)
    }

    /// One increasing covariate per observation.
    fn single_covariate() -> Vec<Vec<f64>> {
        (0..10).map(|i| vec![i as f64 * 0.1]).collect()
    }

    /// One increasing and one decreasing covariate per observation.
    fn opposing_covariates() -> Vec<Vec<f64>> {
        (0..10)
            .map(|i| vec![i as f64 * 0.1, (10 - i) as f64 * 0.1])
            .collect()
    }

    /// Predictions from two linear "models" over the ten observations.
    fn two_model_predictions() -> Vec<Vec<f64>> {
        let first: Vec<f64> = (0..10).map(|i| 0.1 + i as f64 * 0.05).collect();
        let second: Vec<f64> = (0..10).map(|i| 0.2 + i as f64 * 0.03).collect();
        vec![first, second]
    }

    /// Boosting configuration shared by the early-stopping tests.
    fn boosting_config(early_stopping_rounds: Option<usize>) -> ComponentwiseBoostingConfig {
        ComponentwiseBoostingConfig::new(50, 0.1, early_stopping_rounds, 1.0, Some(42)).unwrap()
    }

    #[test]
    fn test_super_learner() {
        let (time, event) = ten_subjects();
        let config = SuperLearnerConfig::new(3, "nnls", false, true, Some(42)).unwrap();

        let result = super_learner_survival(
            time,
            event,
            single_covariate(),
            two_model_predictions(),
            vec!["model1".to_string(), "model2".to_string()],
            config,
        )
        .unwrap();

        assert_eq!(result.weights.len(), 2);
        let weight_total: f64 = result.weights.iter().sum();
        assert!((weight_total - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_stacking() {
        let (time, event) = ten_subjects();
        let config = StackingConfig::new(3, "cox", true, Some(42)).unwrap();

        let result = stacking_survival(time, event, two_model_predictions(), config).unwrap();

        assert_eq!(result.meta_coefficients.len(), 2);
    }

    #[test]
    fn test_componentwise_boosting() {
        let (time, event) = ten_subjects();

        let result =
            componentwise_boosting(time, event, opposing_covariates(), boosting_config(Some(10)))
                .unwrap();

        assert_eq!(result.coefficients.len(), 2);
        assert!(!result.selected_features.is_empty());
    }

    #[test]
    fn test_blending() {
        let val_time = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let val_event = vec![1, 0, 1, 0, 1];
        let val_predictions = vec![
            vec![0.1, 0.2, 0.3, 0.4, 0.5],
            vec![0.15, 0.25, 0.35, 0.45, 0.55],
        ];
        let test_predictions = vec![vec![0.2, 0.3, 0.4], vec![0.25, 0.35, 0.45]];

        let result =
            blending_survival(val_time, val_event, val_predictions, test_predictions).unwrap();

        assert_eq!(result.blend_weights.len(), 2);
        assert_eq!(result.blended_predictions.len(), 3);
    }

    #[test]
    fn test_super_learner_config_validation() {
        assert!(SuperLearnerConfig::new(1, "nnls", false, true, None).is_err());
    }

    #[test]
    fn test_stacking_config_validation() {
        assert!(StackingConfig::new(1, "cox", true, None).is_err());
    }

    #[test]
    fn test_componentwise_boosting_config_validation() {
        // Learning rate at or below zero, learning rate above one, and a
        // zero subsample ratio are all rejected.
        assert!(ComponentwiseBoostingConfig::new(100, 0.0, None, 1.0, None).is_err());
        assert!(ComponentwiseBoostingConfig::new(100, 1.5, None, 1.0, None).is_err());
        assert!(ComponentwiseBoostingConfig::new(100, 0.1, None, 0.0, None).is_err());
    }

    #[test]
    fn test_super_learner_empty_input() {
        let config = SuperLearnerConfig::new(3, "nnls", false, true, Some(42)).unwrap();
        let outcome = super_learner_survival(vec![], vec![], vec![], vec![], vec![], config);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_super_learner_uniform_weights() {
        let (time, event) = ten_subjects();
        let config = SuperLearnerConfig::new(3, "nnls", false, false, Some(42)).unwrap();

        let result = super_learner_survival(
            time,
            event,
            single_covariate(),
            two_model_predictions(),
            vec!["m1".to_string(), "m2".to_string()],
            config,
        )
        .unwrap();

        assert!((result.weights[0] - 0.5).abs() < 1e-6);
        assert!((result.weights[1] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_componentwise_boosting_predict_risk() {
        let (time, event) = ten_subjects();
        let covariates = opposing_covariates();

        let result =
            componentwise_boosting(time, event, covariates.clone(), boosting_config(Some(10)))
                .unwrap();

        let risks = result.predict_risk(covariates);
        assert_eq!(risks.len(), 10);
        assert!(risks.iter().all(|&r| r > 0.0));
    }

    #[test]
    fn test_stacking_empty_input() {
        let config = StackingConfig::new(3, "cox", true, Some(42)).unwrap();
        let outcome = stacking_survival(vec![], vec![], vec![], config);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_blending_empty_input() {
        let outcome = blending_survival(vec![], vec![], vec![], vec![]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_blending_mismatched_models() {
        let val_time = vec![1.0, 2.0, 3.0];
        let val_event = vec![1, 0, 1];
        let val_preds = vec![vec![0.1, 0.2, 0.3]];
        let test_preds = vec![vec![0.1, 0.2], vec![0.3, 0.4]];

        let outcome = blending_survival(val_time, val_event, val_preds, test_preds);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_componentwise_boosting_feature_importance() {
        let (time, event) = ten_subjects();
        let covariates: Vec<Vec<f64>> = (0..10)
            .map(|i| vec![i as f64 * 0.1, (10 - i) as f64 * 0.1, 0.5])
            .collect();

        let result =
            componentwise_boosting(time, event, covariates, boosting_config(None)).unwrap();

        assert_eq!(result.feature_importance.len(), 3);
        let total: f64 = result.feature_importance.iter().sum();
        assert!((total - 1.0).abs() < 1e-6);
    }
}