1#![allow(clippy::too_many_arguments)]
2#![allow(dead_code)]
3
4use pyo3::prelude::*;
5use rayon::prelude::*;
6
/// Configuration for the survival Super Learner ensemble
/// (`super_learner_survival`).
#[derive(Debug, Clone)]
#[pyclass]
pub struct SuperLearnerConfig {
    /// Number of cross-validation folds; `new` enforces a minimum of 2.
    #[pyo3(get, set)]
    pub n_folds: usize,
    /// Meta-learner label (default "nnls"). NOTE(review): stored for the
    /// Python side; `super_learner_survival` does not branch on it in this
    /// file — confirm intended use against callers.
    #[pyo3(get, set)]
    pub meta_learner: String,
    /// NOTE(review): not read by `super_learner_survival` in this file —
    /// presumably appends raw covariates to the meta-level features; verify.
    #[pyo3(get, set)]
    pub include_original_features: bool,
    /// When true, blend weights are fit by NNLS-style projected gradient
    /// descent; when false, uniform weights are used.
    #[pyo3(get, set)]
    pub optimize_weights: bool,
    /// RNG seed for fold shuffling; `None` falls back to 42.
    #[pyo3(get, set)]
    pub seed: Option<u64>,
}
21
22#[pymethods]
23impl SuperLearnerConfig {
24 #[new]
25 #[pyo3(signature = (
26 n_folds=5,
27 meta_learner="nnls",
28 include_original_features=false,
29 optimize_weights=true,
30 seed=None
31 ))]
32 pub fn new(
33 n_folds: usize,
34 meta_learner: &str,
35 include_original_features: bool,
36 optimize_weights: bool,
37 seed: Option<u64>,
38 ) -> PyResult<Self> {
39 if n_folds < 2 {
40 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
41 "n_folds must be at least 2",
42 ));
43 }
44 Ok(Self {
45 n_folds,
46 meta_learner: meta_learner.to_string(),
47 include_original_features,
48 optimize_weights,
49 seed,
50 })
51 }
52}
53
/// Partition `0..n` into `n_folds` disjoint index sets, shuffled
/// deterministically by `seed`. The final fold absorbs any remainder, so
/// earlier folds all have size `n / n_folds`.
fn create_cv_folds(n: usize, n_folds: usize, seed: u64) -> Vec<Vec<usize>> {
    // Fisher-Yates shuffle driven by a simple LCG so the layout is fully
    // determined by the seed (no external RNG dependency).
    let mut order: Vec<usize> = (0..n).collect();
    let mut state = seed;
    for i in (1..n).rev() {
        state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        order.swap(i, (state as usize) % (i + 1));
    }

    let base = n / n_folds;
    (0..n_folds)
        .map(|f| {
            let lo = f * base;
            // Last fold runs to the end so every index is assigned.
            let hi = if f + 1 == n_folds { n } else { lo + base };
            order[lo..hi].to_vec()
        })
        .collect()
}
78
/// Fit a Cox proportional-hazards model on `train_indices` by plain
/// gradient ascent on the partial log-likelihood, returning one
/// coefficient per covariate column.
///
/// Subjects are scanned in decreasing time order so the risk-set
/// denominator can be maintained as a running sum (Breslow-style).
fn fit_base_cox(
    time: &[f64],
    event: &[i32],
    covariates: &[Vec<f64>],
    train_indices: &[usize],
    learning_rate: f64,
    n_iter: usize,
) -> Vec<f64> {
    let p = covariates.first().map_or(0, |row| row.len());
    let mut beta = vec![0.0; p];

    // Latest event time first.
    let mut order: Vec<usize> = train_indices.to_vec();
    order.sort_by(|&a, &b| time[b].partial_cmp(&time[a]).unwrap());

    for _ in 0..n_iter {
        // Linear predictor and hazard for each subject, in scan order.
        let eta: Vec<f64> = order
            .iter()
            .map(|&i| {
                covariates[i]
                    .iter()
                    .zip(beta.iter())
                    .map(|(&x, &b)| x * b)
                    .sum()
            })
            .collect();
        let hazard: Vec<f64> = eta.iter().map(|&e| e.exp()).collect();

        let mut grad = vec![0.0; p];
        let mut denom = 0.0;
        let mut num = vec![0.0; p];

        for (pos, &i) in order.iter().enumerate() {
            // Running risk-set sums over subjects with time >= time[i].
            denom += hazard[pos];
            for (j, &x) in covariates[i].iter().enumerate() {
                num[j] += x * hazard[pos];
            }

            if event[i] == 1 {
                // Score contribution: x_ij minus the risk-set mean of x_j.
                for (j, g) in grad.iter_mut().enumerate() {
                    *g += covariates[i][j] - num[j] / denom;
                }
            }
        }

        // Averaged gradient-ascent step.
        for (b, g) in beta.iter_mut().zip(grad.iter()) {
            *b += learning_rate * g / train_indices.len() as f64;
        }
    }

    beta
}
136
/// Approximate non-negative least-squares blend weights via 100 rounds of
/// projected gradient descent: step, clip at zero, then renormalize onto
/// the probability simplex.
fn nnls_weights(predictions: &[Vec<f64>], outcomes: &[f64], n_models: usize) -> Vec<f64> {
    let n_obs = outcomes.len();
    // Start from the uniform blend.
    let mut w = vec![1.0 / n_models as f64; n_models];

    for _ in 0..100 {
        let mut grad = vec![0.0; n_models];

        for (i, &y) in outcomes.iter().enumerate() {
            let fitted: f64 = w
                .iter()
                .zip(predictions.iter())
                .map(|(&wi, p)| wi * p[i])
                .sum();
            let residual = fitted - y;
            // d/dw_m of the mean squared error.
            for (g, p) in grad.iter_mut().zip(predictions.iter()) {
                *g += 2.0 * residual * p[i] / n_obs as f64;
            }
        }

        // Gradient step projected onto the nonnegative orthant.
        for (wi, &g) in w.iter_mut().zip(grad.iter()) {
            *wi = (*wi - 0.01 * g).max(0.0);
        }

        // Renormalize whenever any mass remains.
        let total: f64 = w.iter().sum();
        if total > 0.0 {
            w.iter_mut().for_each(|wi| *wi /= total);
        }
    }

    w
}
167
/// Output of `super_learner_survival`.
#[derive(Debug, Clone)]
#[pyclass]
pub struct SuperLearnerResult {
    /// Simplex blend weight per base model (non-negative, sums to ~1).
    #[pyo3(get)]
    pub weights: Vec<f64>,
    /// Cross-validated mean squared error per model (lower is better).
    #[pyo3(get)]
    pub cv_risks: Vec<f64>,
    /// Model names, positionally aligned with `weights` and `cv_risks`.
    #[pyo3(get)]
    pub model_names: Vec<String>,
    /// Harrell's C-index of the weighted ensemble on the full data.
    #[pyo3(get)]
    pub ensemble_c_index: f64,
    /// C-index of each base model on its own, aligned with `model_names`.
    #[pyo3(get)]
    pub individual_c_indices: Vec<f64>,
}
182
183#[pymethods]
184impl SuperLearnerResult {
185 fn __repr__(&self) -> String {
186 format!(
187 "SuperLearnerResult(n_models={}, C-index={:.4})",
188 self.weights.len(),
189 self.ensemble_c_index
190 )
191 }
192
193 fn best_model(&self) -> (String, f64) {
194 let (idx, &max_c) = self
195 .individual_c_indices
196 .iter()
197 .enumerate()
198 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
199 .unwrap_or((0, &0.0));
200 (self.model_names[idx].clone(), max_c)
201 }
202}
203
/// Harrell's concordance index: among comparable pairs (an observed event
/// at t_i versus a subject whose time exceeds t_i), count how often the
/// event case carries the strictly higher predicted risk. Risk ties split
/// half-and-half; no comparable pairs yields the chance value 0.5.
fn compute_c_index(time: &[f64], event: &[i32], risk: &[f64]) -> f64 {
    let n = time.len();
    let (mut agree, mut disagree) = (0.0_f64, 0.0_f64);

    for i in 0..n {
        if event[i] != 1 {
            continue;
        }
        for j in 0..n {
            // Negated strict comparison so NaN times are skipped exactly
            // as in the positive-form test.
            if !(time[j] > time[i]) {
                continue;
            }
            if risk[i] > risk[j] {
                agree += 1.0;
            } else if risk[i] < risk[j] {
                disagree += 1.0;
            } else {
                agree += 0.5;
                disagree += 0.5;
            }
        }
    }

    let pairs = agree + disagree;
    if pairs > 0.0 { agree / pairs } else { 0.5 }
}
232
233#[pyfunction]
234#[pyo3(signature = (
235 time,
236 event,
237 covariates,
238 base_learner_predictions,
239 model_names,
240 config
241))]
242pub fn super_learner_survival(
243 time: Vec<f64>,
244 event: Vec<i32>,
245 covariates: Vec<Vec<f64>>,
246 base_learner_predictions: Vec<Vec<f64>>,
247 model_names: Vec<String>,
248 config: SuperLearnerConfig,
249) -> PyResult<SuperLearnerResult> {
250 let n = time.len();
251 let n_models = base_learner_predictions.len();
252
253 if n == 0 || event.len() != n || covariates.len() != n {
254 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
255 "Input arrays must have the same non-zero length",
256 ));
257 }
258 if n_models == 0 || model_names.len() != n_models {
259 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
260 "Must provide predictions from at least one model",
261 ));
262 }
263
264 let seed = config.seed.unwrap_or(42);
265 let folds = create_cv_folds(n, config.n_folds, seed);
266
267 let mut cv_predictions: Vec<Vec<f64>> = vec![vec![0.0; n]; n_models];
268
269 for test_indices in folds.iter() {
270 let train_indices: Vec<usize> = (0..n).filter(|i| !test_indices.contains(i)).collect();
271
272 for m in 0..n_models {
273 let train_preds: Vec<f64> = train_indices
274 .iter()
275 .map(|&i| base_learner_predictions[m][i])
276 .collect();
277 let test_preds: Vec<f64> = test_indices
278 .iter()
279 .map(|&i| base_learner_predictions[m][i])
280 .collect();
281
282 let scale = if !train_preds.is_empty() {
283 train_preds.iter().sum::<f64>() / train_preds.len() as f64
284 } else {
285 1.0
286 };
287
288 for (idx, &test_i) in test_indices.iter().enumerate() {
289 cv_predictions[m][test_i] = test_preds[idx] / scale.max(1e-10);
290 }
291 }
292 }
293
294 let outcomes: Vec<f64> = event.iter().map(|&e| e as f64).collect();
295 let weights = if config.optimize_weights {
296 nnls_weights(&cv_predictions, &outcomes, n_models)
297 } else {
298 vec![1.0 / n_models as f64; n_models]
299 };
300
301 let ensemble_risk: Vec<f64> = (0..n)
302 .map(|i| {
303 (0..n_models)
304 .map(|m| weights[m] * base_learner_predictions[m][i])
305 .sum()
306 })
307 .collect();
308
309 let ensemble_c_index = compute_c_index(&time, &event, &ensemble_risk);
310
311 let individual_c_indices: Vec<f64> = base_learner_predictions
312 .iter()
313 .map(|preds| compute_c_index(&time, &event, preds))
314 .collect();
315
316 let cv_risks: Vec<f64> = (0..n_models)
317 .map(|m| {
318 let mse: f64 = cv_predictions[m]
319 .iter()
320 .zip(outcomes.iter())
321 .map(|(&p, &o)| (p - o).powi(2))
322 .sum::<f64>()
323 / n as f64;
324 mse
325 })
326 .collect();
327
328 Ok(SuperLearnerResult {
329 weights,
330 cv_risks,
331 model_names,
332 ensemble_c_index,
333 individual_c_indices,
334 })
335}
336
/// Configuration for `stacking_survival`.
#[derive(Debug, Clone)]
#[pyclass]
pub struct StackingConfig {
    /// Number of cross-validation folds; `new` enforces a minimum of 2.
    #[pyo3(get, set)]
    pub n_folds: usize,
    /// Meta-model label (default "cox"). NOTE(review): `stacking_survival`
    /// always fits the gradient-ascent Cox meta-model and does not branch
    /// on this value — confirm intended use.
    #[pyo3(get, set)]
    pub meta_model: String,
    /// NOTE(review): not read by `stacking_survival` in this file —
    /// presumably indicates base outputs are probabilities; verify.
    #[pyo3(get, set)]
    pub use_probabilities: bool,
    /// RNG seed for fold shuffling; `None` falls back to 42.
    #[pyo3(get, set)]
    pub seed: Option<u64>,
}
349
350#[pymethods]
351impl StackingConfig {
352 #[new]
353 #[pyo3(signature = (
354 n_folds=5,
355 meta_model="cox",
356 use_probabilities=true,
357 seed=None
358 ))]
359 pub fn new(
360 n_folds: usize,
361 meta_model: &str,
362 use_probabilities: bool,
363 seed: Option<u64>,
364 ) -> PyResult<Self> {
365 if n_folds < 2 {
366 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
367 "n_folds must be at least 2",
368 ));
369 }
370 Ok(Self {
371 n_folds,
372 meta_model: meta_model.to_string(),
373 use_probabilities,
374 seed,
375 })
376 }
377}
378
/// Output of `stacking_survival`.
#[derive(Debug, Clone)]
#[pyclass]
pub struct StackingResult {
    /// Coefficients of the Cox meta-model, one per base model.
    #[pyo3(get)]
    pub meta_coefficients: Vec<f64>,
    /// Stacked risk score exp(meta linear predictor), one per subject.
    #[pyo3(get)]
    pub stacked_predictions: Vec<f64>,
    /// Harrell's C-index of the stacked predictions on the input data.
    #[pyo3(get)]
    pub c_index: f64,
    /// |meta coefficient| normalized to sum to 1 (uniform when all zero).
    #[pyo3(get)]
    pub base_model_importance: Vec<f64>,
}
391
392#[pymethods]
393impl StackingResult {
394 fn __repr__(&self) -> String {
395 format!(
396 "StackingResult(n_base_models={}, C-index={:.4})",
397 self.meta_coefficients.len(),
398 self.c_index
399 )
400 }
401}
402
403#[pyfunction]
404#[pyo3(signature = (
405 time,
406 event,
407 base_predictions,
408 config
409))]
410pub fn stacking_survival(
411 time: Vec<f64>,
412 event: Vec<i32>,
413 base_predictions: Vec<Vec<f64>>,
414 config: StackingConfig,
415) -> PyResult<StackingResult> {
416 let n = time.len();
417 let n_models = base_predictions.len();
418
419 if n == 0 || event.len() != n {
420 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
421 "time and event must have the same non-zero length",
422 ));
423 }
424 if n_models == 0 {
425 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
426 "Must provide at least one base model",
427 ));
428 }
429
430 let seed = config.seed.unwrap_or(42);
431 let folds = create_cv_folds(n, config.n_folds, seed);
432
433 let mut oof_predictions: Vec<Vec<f64>> = vec![vec![0.0; n]; n_models];
434
435 for test_indices in &folds {
436 let train_indices: Vec<usize> = (0..n).filter(|i| !test_indices.contains(i)).collect();
437
438 for m in 0..n_models {
439 let train_mean: f64 = train_indices
440 .iter()
441 .map(|&i| base_predictions[m][i])
442 .sum::<f64>()
443 / train_indices.len() as f64;
444
445 for &test_i in test_indices {
446 oof_predictions[m][test_i] = base_predictions[m][test_i] / train_mean.max(1e-10);
447 }
448 }
449 }
450
451 let meta_features: Vec<Vec<f64>> = (0..n)
452 .map(|i| (0..n_models).map(|m| oof_predictions[m][i]).collect())
453 .collect();
454
455 let train_indices: Vec<usize> = (0..n).collect();
456 let meta_coefficients = fit_base_cox(&time, &event, &meta_features, &train_indices, 0.01, 100);
457
458 let stacked_predictions: Vec<f64> = meta_features
459 .iter()
460 .map(|x| {
461 x.iter()
462 .zip(meta_coefficients.iter())
463 .map(|(&xi, &bi)| xi * bi)
464 .sum::<f64>()
465 .exp()
466 })
467 .collect();
468
469 let c_index = compute_c_index(&time, &event, &stacked_predictions);
470
471 let total_abs: f64 = meta_coefficients.iter().map(|&c| c.abs()).sum();
472 let base_model_importance: Vec<f64> = if total_abs > 0.0 {
473 meta_coefficients
474 .iter()
475 .map(|&c| c.abs() / total_abs)
476 .collect()
477 } else {
478 vec![1.0 / n_models as f64; n_models]
479 };
480
481 Ok(StackingResult {
482 meta_coefficients,
483 stacked_predictions,
484 c_index,
485 base_model_importance,
486 })
487}
488
/// Configuration for `componentwise_boosting`.
#[derive(Debug, Clone)]
#[pyclass]
pub struct ComponentwiseBoostingConfig {
    /// Maximum number of boosting iterations.
    #[pyo3(get, set)]
    pub n_iterations: usize,
    /// Step-size multiplier applied to each update; validated to (0, 1].
    #[pyo3(get, set)]
    pub learning_rate: f64,
    /// Stop after this many iterations without partial log-likelihood
    /// improvement; `None` disables early stopping.
    #[pyo3(get, set)]
    pub early_stopping_rounds: Option<usize>,
    /// Fraction of rows used per iteration; validated to (0, 1]. Values
    /// below 1.0 enable stochastic subsampling.
    #[pyo3(get, set)]
    pub subsample_ratio: f64,
    /// RNG seed for subsampling; `None` falls back to 42.
    #[pyo3(get, set)]
    pub seed: Option<u64>,
}
503
504#[pymethods]
505impl ComponentwiseBoostingConfig {
506 #[new]
507 #[pyo3(signature = (
508 n_iterations=100,
509 learning_rate=0.1,
510 early_stopping_rounds=None,
511 subsample_ratio=1.0,
512 seed=None
513 ))]
514 pub fn new(
515 n_iterations: usize,
516 learning_rate: f64,
517 early_stopping_rounds: Option<usize>,
518 subsample_ratio: f64,
519 seed: Option<u64>,
520 ) -> PyResult<Self> {
521 if learning_rate <= 0.0 || learning_rate > 1.0 {
522 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
523 "learning_rate must be in (0, 1]",
524 ));
525 }
526 if subsample_ratio <= 0.0 || subsample_ratio > 1.0 {
527 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
528 "subsample_ratio must be in (0, 1]",
529 ));
530 }
531 Ok(Self {
532 n_iterations,
533 learning_rate,
534 early_stopping_rounds,
535 subsample_ratio,
536 seed,
537 })
538 }
539}
540
/// Output of `componentwise_boosting`.
#[derive(Debug, Clone)]
#[pyclass]
pub struct ComponentwiseBoostingResult {
    /// Final coefficient per feature.
    #[pyo3(get)]
    pub coefficients: Vec<f64>,
    /// Feature index chosen at each iteration (indices may repeat).
    #[pyo3(get)]
    pub selected_features: Vec<usize>,
    /// Cox partial log-likelihood recorded after each iteration.
    #[pyo3(get)]
    pub iteration_log_likelihood: Vec<f64>,
    /// Per-feature selection frequency (sums to 1 once any selection
    /// occurred; all zeros otherwise).
    #[pyo3(get)]
    pub feature_importance: Vec<f64>,
    /// 1-based iteration index that achieved the best log-likelihood.
    #[pyo3(get)]
    pub optimal_iterations: usize,
}
555
556#[pymethods]
557impl ComponentwiseBoostingResult {
558 fn __repr__(&self) -> String {
559 format!(
560 "ComponentwiseBoostingResult(n_selected={}, iterations={})",
561 self.selected_features
562 .iter()
563 .collect::<std::collections::HashSet<_>>()
564 .len(),
565 self.optimal_iterations
566 )
567 }
568
569 fn predict_risk(&self, covariates: Vec<Vec<f64>>) -> Vec<f64> {
570 covariates
571 .par_iter()
572 .map(|x| {
573 x.iter()
574 .zip(self.coefficients.iter())
575 .map(|(&xi, &bi)| xi * bi)
576 .sum::<f64>()
577 .exp()
578 })
579 .collect()
580 }
581}
582
/// Breslow-style Cox partial log-likelihood. Walking subjects from the
/// latest time backwards lets the risk-set denominator be maintained as a
/// running sum of exp(linear predictor).
fn compute_partial_log_likelihood(time: &[f64], event: &[i32], linear_pred: &[f64]) -> f64 {
    let n = time.len();
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&a, &b| time[b].partial_cmp(&time[a]).unwrap());

    let hazards: Vec<f64> = linear_pred.iter().map(|&lp| lp.exp()).collect();

    let mut denom = 0.0;
    order.iter().fold(0.0, |ll, &i| {
        denom += hazards[i];
        if event[i] == 1 {
            // Events contribute lp_i minus the log risk-set mass.
            ll + (linear_pred[i] - denom.ln())
        } else {
            ll
        }
    })
}
602
603#[pyfunction]
604#[pyo3(signature = (
605 time,
606 event,
607 covariates,
608 config
609))]
610pub fn componentwise_boosting(
611 time: Vec<f64>,
612 event: Vec<i32>,
613 covariates: Vec<Vec<f64>>,
614 config: ComponentwiseBoostingConfig,
615) -> PyResult<ComponentwiseBoostingResult> {
616 let n = time.len();
617 if n == 0 || event.len() != n || covariates.len() != n {
618 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
619 "Input arrays must have the same non-zero length",
620 ));
621 }
622
623 let n_features = if covariates.is_empty() {
624 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
625 "Covariates cannot be empty",
626 ));
627 } else {
628 covariates[0].len()
629 };
630
631 let seed = config.seed.unwrap_or(42);
632 let mut rng_state = seed;
633
634 let mut coefficients: Vec<f64> = vec![0.0; n_features];
635 let mut linear_pred: Vec<f64> = vec![0.0; n];
636 let mut selected_features = Vec::new();
637 let mut iteration_log_likelihood = Vec::new();
638 let mut feature_selection_count = vec![0usize; n_features];
639
640 let mut best_ll = f64::NEG_INFINITY;
641 let mut rounds_without_improvement = 0;
642 let mut optimal_iterations = 0;
643
644 for iter in 0..config.n_iterations {
645 let sample_indices: Vec<usize> = if config.subsample_ratio < 1.0 {
646 let sample_size = (n as f64 * config.subsample_ratio).ceil() as usize;
647 let mut indices: Vec<usize> = (0..n).collect();
648 for i in (1..n).rev() {
649 rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
650 let j = (rng_state as usize) % (i + 1);
651 indices.swap(i, j);
652 }
653 indices.truncate(sample_size);
654 indices
655 } else {
656 (0..n).collect()
657 };
658
659 let exp_lp: Vec<f64> = linear_pred.iter().map(|&lp| lp.exp()).collect();
660
661 let mut sorted_indices: Vec<usize> = sample_indices.clone();
662 sorted_indices.sort_by(|&a, &b| time[b].partial_cmp(&time[a]).unwrap());
663
664 let mut best_feature = 0;
665 let mut best_score = f64::NEG_INFINITY;
666 let mut best_update = 0.0;
667
668 #[allow(clippy::needless_range_loop)]
669 for j in 0..n_features {
670 let mut gradient = 0.0;
671 let mut hessian = 0.0;
672 let mut risk_sum = 0.0;
673 let mut weighted_sum = 0.0;
674 let mut weighted_sq_sum = 0.0;
675
676 for &i in &sorted_indices {
677 risk_sum += exp_lp[i];
678 weighted_sum += covariates[i][j] * exp_lp[i];
679 weighted_sq_sum += covariates[i][j].powi(2) * exp_lp[i];
680
681 if event[i] == 1 {
682 let mean = weighted_sum / risk_sum;
683 gradient += covariates[i][j] - mean;
684 hessian += weighted_sq_sum / risk_sum - mean.powi(2);
685 }
686 }
687
688 if hessian.abs() > 1e-10 {
689 let update = gradient / hessian;
690 let score = gradient.abs();
691
692 if score > best_score {
693 best_score = score;
694 best_feature = j;
695 best_update = update;
696 }
697 }
698 }
699
700 coefficients[best_feature] += config.learning_rate * best_update;
701 selected_features.push(best_feature);
702 feature_selection_count[best_feature] += 1;
703
704 for i in 0..n {
705 linear_pred[i] = coefficients
706 .iter()
707 .zip(covariates[i].iter())
708 .map(|(&b, &x)| b * x)
709 .sum();
710 }
711
712 let ll = compute_partial_log_likelihood(&time, &event, &linear_pred);
713 iteration_log_likelihood.push(ll);
714
715 if ll > best_ll {
716 best_ll = ll;
717 optimal_iterations = iter + 1;
718 rounds_without_improvement = 0;
719 } else {
720 rounds_without_improvement += 1;
721 }
722
723 if let Some(patience) = config.early_stopping_rounds
724 && rounds_without_improvement >= patience
725 {
726 break;
727 }
728 }
729
730 let total_selections: f64 = feature_selection_count.iter().sum::<usize>() as f64;
731 let feature_importance: Vec<f64> = if total_selections > 0.0 {
732 feature_selection_count
733 .iter()
734 .map(|&c| c as f64 / total_selections)
735 .collect()
736 } else {
737 vec![0.0; n_features]
738 };
739
740 Ok(ComponentwiseBoostingResult {
741 coefficients,
742 selected_features,
743 iteration_log_likelihood,
744 feature_importance,
745 optimal_iterations,
746 })
747}
748
/// Output of `blending_survival`.
#[derive(Debug, Clone)]
#[pyclass]
pub struct BlendingResult {
    /// Simplex blend weight per model, learned on the validation set.
    #[pyo3(get)]
    pub blend_weights: Vec<f64>,
    /// Weighted combination of the test-set predictions.
    #[pyo3(get)]
    pub blended_predictions: Vec<f64>,
    /// Harrell's C-index of the blend on the validation set.
    #[pyo3(get)]
    pub validation_c_index: f64,
}
759
760#[pymethods]
761impl BlendingResult {
762 fn __repr__(&self) -> String {
763 format!(
764 "BlendingResult(n_models={}, val_C={:.4})",
765 self.blend_weights.len(),
766 self.validation_c_index
767 )
768 }
769}
770
771#[pyfunction]
772#[pyo3(signature = (
773 val_time,
774 val_event,
775 val_predictions,
776 test_predictions
777))]
778pub fn blending_survival(
779 val_time: Vec<f64>,
780 val_event: Vec<i32>,
781 val_predictions: Vec<Vec<f64>>,
782 test_predictions: Vec<Vec<f64>>,
783) -> PyResult<BlendingResult> {
784 let n_val = val_time.len();
785 let n_models = val_predictions.len();
786
787 if n_val == 0 || val_event.len() != n_val {
788 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
789 "Validation arrays must have the same non-zero length",
790 ));
791 }
792 if n_models == 0 || test_predictions.len() != n_models {
793 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
794 "Must have same number of models for validation and test",
795 ));
796 }
797
798 let outcomes: Vec<f64> = val_event.iter().map(|&e| e as f64).collect();
799 let blend_weights = nnls_weights(&val_predictions, &outcomes, n_models);
800
801 let n_test = test_predictions[0].len();
802 let blended_predictions: Vec<f64> = (0..n_test)
803 .map(|i| {
804 (0..n_models)
805 .map(|m| blend_weights[m] * test_predictions[m][i])
806 .sum()
807 })
808 .collect();
809
810 let val_blended: Vec<f64> = (0..n_val)
811 .map(|i| {
812 (0..n_models)
813 .map(|m| blend_weights[m] * val_predictions[m][i])
814 .sum()
815 })
816 .collect();
817
818 let validation_c_index = compute_c_index(&val_time, &val_event, &val_blended);
819
820 Ok(BlendingResult {
821 blend_weights,
822 blended_predictions,
823 validation_c_index,
824 })
825}
826
#[cfg(test)]
mod tests {
    use super::*;

    // Happy path: two base models, weights come back simplex-normalized.
    #[test]
    fn test_super_learner() {
        let time = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let event = vec![1, 0, 1, 0, 1, 0, 1, 0, 1, 0];
        let covariates: Vec<Vec<f64>> = (0..10).map(|i| vec![i as f64 * 0.1]).collect();
        let pred1: Vec<f64> = (0..10).map(|i| 0.1 + i as f64 * 0.05).collect();
        let pred2: Vec<f64> = (0..10).map(|i| 0.2 + i as f64 * 0.03).collect();

        let config = SuperLearnerConfig::new(3, "nnls", false, true, Some(42)).unwrap();
        let result = super_learner_survival(
            time,
            event,
            covariates,
            vec![pred1, pred2],
            vec!["model1".to_string(), "model2".to_string()],
            config,
        )
        .unwrap();

        assert_eq!(result.weights.len(), 2);
        assert!((result.weights.iter().sum::<f64>() - 1.0).abs() < 0.01);
    }

    // Happy path: stacking returns one meta-coefficient per base model.
    #[test]
    fn test_stacking() {
        let time = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let event = vec![1, 0, 1, 0, 1, 0, 1, 0, 1, 0];
        let pred1: Vec<f64> = (0..10).map(|i| 0.1 + i as f64 * 0.05).collect();
        let pred2: Vec<f64> = (0..10).map(|i| 0.2 + i as f64 * 0.03).collect();

        let config = StackingConfig::new(3, "cox", true, Some(42)).unwrap();
        let result = stacking_survival(time, event, vec![pred1, pred2], config).unwrap();

        assert_eq!(result.meta_coefficients.len(), 2);
    }

    // Happy path: boosting produces a coefficient per feature and makes
    // at least one selection.
    #[test]
    fn test_componentwise_boosting() {
        let time = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let event = vec![1, 0, 1, 0, 1, 0, 1, 0, 1, 0];
        let covariates: Vec<Vec<f64>> = (0..10)
            .map(|i| vec![i as f64 * 0.1, (10 - i) as f64 * 0.1])
            .collect();

        let config = ComponentwiseBoostingConfig::new(50, 0.1, Some(10), 1.0, Some(42)).unwrap();
        let result = componentwise_boosting(time, event, covariates, config).unwrap();

        assert_eq!(result.coefficients.len(), 2);
        assert!(!result.selected_features.is_empty());
    }

    // Happy path: blended test predictions have the test-set length.
    #[test]
    fn test_blending() {
        let val_time = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let val_event = vec![1, 0, 1, 0, 1];
        let val_pred1 = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        let val_pred2 = vec![0.15, 0.25, 0.35, 0.45, 0.55];
        let test_pred1 = vec![0.2, 0.3, 0.4];
        let test_pred2 = vec![0.25, 0.35, 0.45];

        let result = blending_survival(
            val_time,
            val_event,
            vec![val_pred1, val_pred2],
            vec![test_pred1, test_pred2],
        )
        .unwrap();

        assert_eq!(result.blend_weights.len(), 2);
        assert_eq!(result.blended_predictions.len(), 3);
    }

    // n_folds below 2 must be rejected.
    #[test]
    fn test_super_learner_config_validation() {
        let result = SuperLearnerConfig::new(1, "nnls", false, true, None);
        assert!(result.is_err());
    }

    // n_folds below 2 must be rejected.
    #[test]
    fn test_stacking_config_validation() {
        let result = StackingConfig::new(1, "cox", true, None);
        assert!(result.is_err());
    }

    // learning_rate and subsample_ratio must lie in (0, 1].
    #[test]
    fn test_componentwise_boosting_config_validation() {
        let result = ComponentwiseBoostingConfig::new(100, 0.0, None, 1.0, None);
        assert!(result.is_err());

        let result = ComponentwiseBoostingConfig::new(100, 1.5, None, 1.0, None);
        assert!(result.is_err());

        let result = ComponentwiseBoostingConfig::new(100, 0.1, None, 0.0, None);
        assert!(result.is_err());
    }

    // Empty inputs produce an error, not a panic.
    #[test]
    fn test_super_learner_empty_input() {
        let config = SuperLearnerConfig::new(3, "nnls", false, true, Some(42)).unwrap();
        let result = super_learner_survival(vec![], vec![], vec![], vec![], vec![], config);
        assert!(result.is_err());
    }

    // With optimize_weights = false, weights stay exactly uniform.
    #[test]
    fn test_super_learner_uniform_weights() {
        let time = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let event = vec![1, 0, 1, 0, 1, 0, 1, 0, 1, 0];
        let covariates: Vec<Vec<f64>> = (0..10).map(|i| vec![i as f64 * 0.1]).collect();
        let pred1: Vec<f64> = (0..10).map(|i| 0.1 + i as f64 * 0.05).collect();
        let pred2: Vec<f64> = (0..10).map(|i| 0.2 + i as f64 * 0.03).collect();

        let config = SuperLearnerConfig::new(3, "nnls", false, false, Some(42)).unwrap();
        let result = super_learner_survival(
            time,
            event,
            covariates,
            vec![pred1, pred2],
            vec!["m1".to_string(), "m2".to_string()],
            config,
        )
        .unwrap();

        assert!((result.weights[0] - 0.5).abs() < 1e-6);
        assert!((result.weights[1] - 0.5).abs() < 1e-6);
    }

    // predict_risk returns exp(x . beta), so every risk is positive.
    #[test]
    fn test_componentwise_boosting_predict_risk() {
        let time = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let event = vec![1, 0, 1, 0, 1, 0, 1, 0, 1, 0];
        let covariates: Vec<Vec<f64>> = (0..10)
            .map(|i| vec![i as f64 * 0.1, (10 - i) as f64 * 0.1])
            .collect();

        let config = ComponentwiseBoostingConfig::new(50, 0.1, Some(10), 1.0, Some(42)).unwrap();
        let result = componentwise_boosting(time, event, covariates.clone(), config).unwrap();

        let risks = result.predict_risk(covariates);
        assert_eq!(risks.len(), 10);
        assert!(risks.iter().all(|&r| r > 0.0));
    }

    // Empty inputs produce an error, not a panic.
    #[test]
    fn test_stacking_empty_input() {
        let config = StackingConfig::new(3, "cox", true, Some(42)).unwrap();
        let result = stacking_survival(vec![], vec![], vec![], config);
        assert!(result.is_err());
    }

    // Empty inputs produce an error, not a panic.
    #[test]
    fn test_blending_empty_input() {
        let result = blending_survival(vec![], vec![], vec![], vec![]);
        assert!(result.is_err());
    }

    // Validation/test model-count mismatch is rejected.
    #[test]
    fn test_blending_mismatched_models() {
        let val_time = vec![1.0, 2.0, 3.0];
        let val_event = vec![1, 0, 1];
        let val_preds = vec![vec![0.1, 0.2, 0.3]];
        let test_preds = vec![vec![0.1, 0.2], vec![0.3, 0.4]];
        let result = blending_survival(val_time, val_event, val_preds, test_preds);
        assert!(result.is_err());
    }

    // Importance is a frequency distribution over features (sums to 1).
    #[test]
    fn test_componentwise_boosting_feature_importance() {
        let time = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let event = vec![1, 0, 1, 0, 1, 0, 1, 0, 1, 0];
        let covariates: Vec<Vec<f64>> = (0..10)
            .map(|i| vec![i as f64 * 0.1, (10 - i) as f64 * 0.1, 0.5])
            .collect();

        let config = ComponentwiseBoostingConfig::new(50, 0.1, None, 1.0, Some(42)).unwrap();
        let result = componentwise_boosting(time, event, covariates, config).unwrap();

        assert_eq!(result.feature_importance.len(), 3);
        let total: f64 = result.feature_importance.iter().sum();
        assert!((total - 1.0).abs() < 1e-6);
    }
}