// scirs2_stats/survival.rs

//! Survival Analysis
//!
//! This module provides survival analysis functions, including the Kaplan-Meier
//! estimator, the Cox proportional hazards model, and related statistical tests.

use crate::error::{StatsError, StatsResult as Result};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::validation::*;

/// Kaplan-Meier survival estimator
///
/// Computes the Kaplan-Meier survival function from time-to-event data.
#[derive(Debug, Clone)]
pub struct KaplanMeierEstimator {
    /// Unique event times
    pub event_times: Array1<f64>,
    /// Survival probabilities at each event time
    pub survival_function: Array1<f64>,
    /// Confidence intervals (lower bound, upper bound)
    pub confidence_intervals: Option<(Array1<f64>, Array1<f64>)>,
    /// Number at risk at each time point
    pub at_risk: Array1<usize>,
    /// Number of events at each time point
    pub events: Array1<usize>,
    /// Median survival time
    pub median_survival_time: Option<f64>,
}

impl KaplanMeierEstimator {
    /// Fit Kaplan-Meier estimator
    ///
    /// # Arguments
    /// * `durations` - Time to event or censoring
    /// * `event_observed` - Whether event was observed (true) or censored (false)
    /// * `confidence_level` - Confidence level for intervals (0.0 to 1.0)
    ///
    /// # Returns
    /// * Kaplan-Meier estimator with survival function and statistics
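    ///
    /// # Example
    ///
    /// A minimal sketch of the intended call pattern, marked `ignore` because it
    /// assumes the crate's re-exported `array!` macro and error plumbing:
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::array;
    ///
    /// let durations = array![2.0, 3.0, 3.0, 5.0, 8.0];
    /// let observed = array![true, false, true, true, false];
    /// let km = KaplanMeierEstimator::fit(durations.view(), observed.view(), Some(0.95))?;
    /// // Survival probabilities are non-increasing and lie in [0, 1]
    /// assert!(km.survival_function.iter().all(|&s| (0.0..=1.0).contains(&s)));
    /// ```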
    pub fn fit(
        durations: ArrayView1<f64>,
        event_observed: ArrayView1<bool>,
        confidence_level: Option<f64>,
    ) -> Result<Self> {
        checkarray_finite(&durations, "durations")?;

        if durations.len() != event_observed.len() {
            return Err(StatsError::DimensionMismatch(format!(
                "durations length ({durations_len}) must match event_observed length ({events_len})",
                durations_len = durations.len(),
                events_len = event_observed.len()
            )));
        }

        if durations.is_empty() {
            return Err(StatsError::InvalidArgument(
                "Input arrays cannot be empty".to_string(),
            ));
        }

        if let Some(conf) = confidence_level {
            check_probability(conf, "confidence_level")?;
        }

        // Create time-event pairs and sort by time
        let mut time_event_pairs: Vec<(f64, bool)> = durations
            .iter()
            .zip(event_observed.iter())
            .map(|(&t, &e)| (t, e))
            .collect();
        time_event_pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

        // Calculate survival function using the Kaplan-Meier product-limit estimator
        let mut unique_times = Vec::new();
        let mut at_risk_counts = Vec::new();
        let mut event_counts = Vec::new();
        let mut survival_probs = Vec::new();

        let n = time_event_pairs.len();
        let mut current_survival = 1.0;
        let mut current_at_risk = n;

        let mut i = 0;
        while i < time_event_pairs.len() {
            let current_time = time_event_pairs[i].0;
            let mut events_at_time = 0;
            let mut censored_at_time = 0;

            // Count events and censored observations at the current time
            while i < time_event_pairs.len() && time_event_pairs[i].0 == current_time {
                if time_event_pairs[i].1 {
                    events_at_time += 1;
                } else {
                    censored_at_time += 1;
                }
                i += 1;
            }

            if events_at_time > 0 {
                // Update survival probability only if there were events
                let survival_this_time = 1.0 - (events_at_time as f64) / (current_at_risk as f64);
                current_survival *= survival_this_time;

                unique_times.push(current_time);
                at_risk_counts.push(current_at_risk);
                event_counts.push(events_at_time);
                survival_probs.push(current_survival);
            }

            // Update at-risk count
            current_at_risk -= events_at_time + censored_at_time;
        }

        let event_times = Array1::from_vec(unique_times);
        let survival_function = Array1::from_vec(survival_probs);
        let at_risk = Array1::from_vec(at_risk_counts);
        let events = Array1::from_vec(event_counts);

        // Calculate confidence intervals if requested
        let confidence_intervals = if let Some(conf_level) = confidence_level {
            Some(Self::calculate_confidence_intervals(
                &survival_function,
                &at_risk,
                &events,
                conf_level,
            )?)
        } else {
            None
        };

        // Calculate median survival time
        let median_survival_time =
            Self::calculate_median_survival(&event_times, &survival_function);

        Ok(Self {
            event_times,
            survival_function,
            confidence_intervals,
            at_risk,
            events,
            median_survival_time,
        })
    }

    /// Calculate confidence intervals using Greenwood's formula
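    ///
    /// Greenwood's formula estimates the variance of the survival estimate as
    /// `Var[S(t)] ~= S(t)^2 * sum over t_i <= t of d_i / (n_i * (n_i - d_i))`,
    /// where `d_i` is the number of events and `n_i` the number at risk at event
    /// time `t_i`; the interval is then built on the `ln(-ln S)` scale and mapped back.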
    fn calculate_confidence_intervals(
        survival_function: &Array1<f64>,
        at_risk: &Array1<usize>,
        events: &Array1<usize>,
        confidence_level: f64,
    ) -> Result<(Array1<f64>, Array1<f64>)> {
        let _alpha = 1.0 - confidence_level;
        let z_score = 1.96; // Hard-coded 95% z; other levels would need a proper inverse normal

        let mut lower_bounds = Array1::zeros(survival_function.len());
        let mut upper_bounds = Array1::zeros(survival_function.len());

        // Cumulative Greenwood variance
        let mut cumulative_variance = 0.0;

        for i in 0..survival_function.len() {
            // Greenwood's formula for the variance contribution at this event time
            let n_i = at_risk[i] as f64;
            let d_i = events[i] as f64;

            if n_i > d_i && n_i > 0.0 {
                cumulative_variance += d_i / (n_i * (n_i - d_i));
            }

            let s_t = survival_function[i];
            if s_t > 0.0 && s_t < 1.0 {
                let se = s_t * cumulative_variance.sqrt();

                // Using the log-log transformation for better CI properties
                let log_log_s = (-(s_t.ln())).ln();
                let se_log_log = se / (s_t * s_t.ln().abs());

                let lower_log_log = log_log_s - z_score * se_log_log;
                let upper_log_log = log_log_s + z_score * se_log_log;

                // Map back via S = exp(-exp(theta)); S is decreasing in theta,
                // so the bounds swap when transformed to the survival scale
                lower_bounds[i] = (-upper_log_log.exp()).exp().max(0.0);
                upper_bounds[i] = (-lower_log_log.exp()).exp().min(1.0);
            } else {
                // Degenerate cases: the interval collapses to the point estimate
                lower_bounds[i] = s_t;
                upper_bounds[i] = s_t;
            }
        }

        Ok((lower_bounds, upper_bounds))
    }

    /// Calculate median survival time
    fn calculate_median_survival(
        event_times: &Array1<f64>,
        survival_function: &Array1<f64>,
    ) -> Option<f64> {
        // Find the time where survival probability first drops to or below 0.5
        for i in 0..survival_function.len() {
            if survival_function[i] <= 0.5 {
                return Some(event_times[i]);
            }
        }
        None // Median not reached
    }

    /// Predict survival probability at given times
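    ///
    /// A short sketch (`ignore`d; `km` is a previously fitted estimator):
    ///
    /// ```ignore
    /// let probs = km.predict(array![1.0, 4.0, 10.0].view())?;
    /// // probs[i] is the step-function value S(t) at each requested time, with S(0) = 1
    /// ```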
    pub fn predict(&self, times: ArrayView1<f64>) -> Result<Array1<f64>> {
        checkarray_finite(&times, "times")?;

        let mut predictions = Array1::zeros(times.len());

        for (i, &t) in times.iter().enumerate() {
            if t < 0.0 {
                return Err(StatsError::InvalidArgument(
                    "Times must be non-negative".to_string(),
                ));
            }

            // Find the survival probability at time t
            let mut survival_prob = 1.0; // Start with 100% survival at time 0

            for j in 0..self.event_times.len() {
                if self.event_times[j] <= t {
                    survival_prob = self.survival_function[j];
                } else {
                    break;
                }
            }

            predictions[i] = survival_prob;
        }

        Ok(predictions)
    }
}

/// Log-rank test for comparing survival curves
///
/// Tests the null hypothesis that two or more survival curves are identical.
pub struct LogRankTest;

impl LogRankTest {
    /// Perform log-rank test comparing two survival curves
    ///
    /// # Arguments
    /// * `durations1` - Time to event or censoring for group 1
    /// * `events1` - Whether event was observed for group 1
    /// * `durations2` - Time to event or censoring for group 2
    /// * `events2` - Whether event was observed for group 2
    ///
    /// # Returns
    /// * Tuple of (test statistic, p-value)
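    ///
    /// # Example
    ///
    /// A sketch of a two-group comparison (assumed data, `ignore`d):
    ///
    /// ```ignore
    /// let (stat, p_value) = LogRankTest::compare_two_groups(
    ///     treatment_times.view(),
    ///     treatment_events.view(),
    ///     control_times.view(),
    ///     control_events.view(),
    /// )?;
    /// if p_value < 0.05 {
    ///     println!("survival curves differ (chi-square = {stat:.2})");
    /// }
    /// ```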
    pub fn compare_two_groups(
        durations1: ArrayView1<f64>,
        events1: ArrayView1<bool>,
        durations2: ArrayView1<f64>,
        events2: ArrayView1<bool>,
    ) -> Result<(f64, f64)> {
        checkarray_finite(&durations1, "durations1")?;
        checkarray_finite(&durations2, "durations2")?;

        if durations1.len() != events1.len() || durations2.len() != events2.len() {
            return Err(StatsError::DimensionMismatch(
                "Durations and events arrays must have same length".to_string(),
            ));
        }

        // Combine all observations with group labels
        let mut combined_data = Vec::new();

        for i in 0..durations1.len() {
            combined_data.push((durations1[i], events1[i], 0)); // Group 0
        }
        for i in 0..durations2.len() {
            combined_data.push((durations2[i], events2[i], 1)); // Group 1
        }

        // Sort by time
        combined_data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

        // Calculate log-rank statistic
        let mut observed_group1 = 0.0;
        let mut expected_group1 = 0.0;
        let mut variance = 0.0;

        let n1 = durations1.len() as f64;
        let n2 = durations2.len() as f64;
        let mut at_risk1 = n1;
        let mut at_risk2 = n2;

        let mut i = 0;
        while i < combined_data.len() {
            let current_time = combined_data[i].0;
            let mut events_group1 = 0.0;
            let mut events_group2 = 0.0;
            let mut censored_group1 = 0.0;
            let mut censored_group2 = 0.0;

            // Count events and censoring at current time
            while i < combined_data.len() && combined_data[i].0 == current_time {
                let (_, is_event, group) = combined_data[i];
                if group == 0 {
                    if is_event {
                        events_group1 += 1.0;
                    } else {
                        censored_group1 += 1.0;
                    }
                } else if is_event {
                    events_group2 += 1.0;
                } else {
                    censored_group2 += 1.0;
                }
                i += 1;
            }

            let total_events = events_group1 + events_group2;
            let total_at_risk = at_risk1 + at_risk2;

            if total_events > 0.0 && total_at_risk > 0.0 {
                // Expected events in group 1
                let expected_events1 = (at_risk1 / total_at_risk) * total_events;

                // Variance contribution
                let var_contrib =
                    (at_risk1 * at_risk2 * total_events * (total_at_risk - total_events))
                        / (total_at_risk.powi(2) * (total_at_risk - 1.0).max(1.0));

                observed_group1 += events_group1;
                expected_group1 += expected_events1;
                variance += var_contrib;
            }

            // Update at-risk counts
            at_risk1 -= events_group1 + censored_group1;
            at_risk2 -= events_group2 + censored_group2;
        }

        // Calculate test statistic
        if variance <= 0.0 {
            return Ok((0.0, 1.0)); // No variance, no difference
        }

        let test_statistic = (observed_group1 - expected_group1).powi(2) / variance;

        // Calculate p-value using the chi-square distribution with 1 df
        let p_value = Self::chi_square_survival(test_statistic, 1.0);

        Ok((test_statistic, p_value))
    }

    /// Approximate survival function for the chi-square distribution
    fn chi_square_survival(x: f64, df: f64) -> f64 {
        if x <= 0.0 {
            return 1.0;
        }

        // Exact for 1 degree of freedom (the case used by the log-rank test):
        // P(chi2_1 > x) = 2 * (1 - Phi(sqrt(x))) = 1 - erf(sqrt(x / 2))
        if (df - 1.0).abs() < f64::EPSILON {
            return 1.0 - Self::erf((x / 2.0).sqrt());
        }

        let mean = df;
        let var = 2.0 * df;
        let std = var.sqrt();

        // Normal approximation for large df
        if df > 30.0 {
            let z = (x - mean) / std;
            return 0.5 * (1.0 - Self::erf(z / std::f64::consts::SQRT_2));
        }

        // Crude exponential approximation for other small df (exact only for
        // df = 2); in practice a proper chi-square CDF would be used
        (-x / mean).exp()
    }

    /// Error function approximation
    fn erf(x: f64) -> f64 {
        // Abramowitz and Stegun approximation
        let a1 = 0.254829592;
        let a2 = -0.284496736;
        let a3 = 1.421413741;
        let a4 = -1.453152027;
        let a5 = 1.061405429;
        let p = 0.3275911;

        let sign = if x >= 0.0 { 1.0 } else { -1.0 };
        let x = x.abs();

        let t = 1.0 / (1.0 + p * x);
        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();

        sign * y
    }
}

/// Cox Proportional Hazards Model
///
/// Implements Cox regression for survival analysis with covariates.
#[derive(Debug, Clone)]
pub struct CoxPHModel {
    /// Regression coefficients
    pub coefficients: Array1<f64>,
    /// Covariance matrix of coefficients
    pub covariance_matrix: Array2<f64>,
    /// Log-likelihood of the fitted model
    pub log_likelihood: f64,
    /// Baseline cumulative hazard
    pub baseline_cumulative_hazard: Array1<f64>,
    /// Time points for baseline hazard
    pub baseline_times: Array1<f64>,
    /// Number of iterations until convergence
    pub n_iter: usize,
}

impl CoxPHModel {
    /// Fit Cox proportional hazards model
    ///
    /// # Arguments
    /// * `durations` - Time to event or censoring
    /// * `events` - Whether event was observed
    /// * `covariates` - Covariate matrix (n_samples x n_features)
    /// * `max_iter` - Maximum number of iterations
    /// * `tol` - Convergence tolerance
    ///
    /// # Returns
    /// * Fitted Cox model
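    ///
    /// # Example
    ///
    /// A minimal sketch (assumed data, `ignore`d):
    ///
    /// ```ignore
    /// // covariates: one row per subject, one column per feature
    /// let model = CoxPHModel::fit(
    ///     durations.view(),
    ///     events.view(),
    ///     covariates.view(),
    ///     Some(100),  // max_iter
    ///     Some(1e-6), // tol
    /// )?;
    /// // exp(coefficient) is the hazard ratio per unit increase in that covariate
    /// let hazard_ratios = model.coefficients.mapv(f64::exp);
    /// ```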
    pub fn fit(
        durations: ArrayView1<f64>,
        events: ArrayView1<bool>,
        covariates: ArrayView2<f64>,
        max_iter: Option<usize>,
        tol: Option<f64>,
    ) -> Result<Self> {
        checkarray_finite(&durations, "durations")?;
        checkarray_finite(&covariates, "covariates")?;

        let (n_samples, n_features) = covariates.dim();
        let max_iter = max_iter.unwrap_or(100);
        let tol = tol.unwrap_or(1e-6);

        if durations.len() != n_samples || events.len() != n_samples {
            return Err(StatsError::DimensionMismatch(
                "All input arrays must have the same number of samples".to_string(),
            ));
        }

        // Initialize coefficients
        let mut beta = Array1::zeros(n_features);
        let mut prev_log_likelihood = f64::NEG_INFINITY;

        for iteration in 0..max_iter {
            // Calculate partial likelihood and its derivatives
            let (log_likelihood, gradient, hessian) =
                Self::partial_likelihood_derivatives(&durations, &events, &covariates, &beta)?;

            // Check convergence
            if (log_likelihood - prev_log_likelihood).abs() < tol {
                let covariance_matrix = Self::invert_hessian(&hessian)?;
                let (baseline_times, baseline_cumulative_hazard) =
                    Self::calculate_baseline_hazard(&durations, &events, &covariates, &beta)?;

                return Ok(Self {
                    coefficients: beta,
                    covariance_matrix,
                    log_likelihood,
                    baseline_cumulative_hazard,
                    baseline_times,
                    n_iter: iteration + 1,
                });
            }

            // Newton-Raphson update
            let hessian_inv = Self::invert_hessian(&hessian)?;
            let delta = hessian_inv.dot(&gradient);
            beta = &beta + &delta;

            prev_log_likelihood = log_likelihood;
        }

        Err(StatsError::ConvergenceError(format!(
            "Cox model failed to converge after {max_iter} iterations"
        )))
    }

    /// Calculate partial likelihood and its derivatives
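    ///
    /// The (Breslow-style) log partial likelihood maximized here is
    /// `log L(beta) = sum over events i of [x_i' beta - ln(sum over j in R(t_i) of exp(x_j' beta))]`,
    /// where `R(t_i)` is the set of subjects still at risk at time `t_i`; the
    /// gradient and Hessian below are its first and second derivatives in `beta`.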
    fn partial_likelihood_derivatives(
        durations: &ArrayView1<f64>,
        events: &ArrayView1<bool>,
        covariates: &ArrayView2<f64>,
        beta: &Array1<f64>,
    ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
        let n_samples = durations.len();
        let n_features = beta.len();

        // Sort by event time
        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.sort_by(|&i, &j| durations[i].partial_cmp(&durations[j]).unwrap());

        let mut log_likelihood = 0.0;
        let mut gradient = Array1::zeros(n_features);
        let mut hessian = Array2::zeros((n_features, n_features));

        for &i in &indices {
            if !events[i] {
                continue; // Skip censored observations for likelihood
            }

            let t_i = durations[i];
            let x_i = covariates.row(i);

            // Calculate risk set (all subjects at risk at time t_i)
            let mut risk_set = Vec::new();
            for &j in &indices {
                if durations[j] >= t_i {
                    risk_set.push(j);
                }
            }

            if risk_set.is_empty() {
                continue;
            }

            // Calculate exp(beta' * x) for the risk set
            let mut exp_beta_x = Array1::zeros(risk_set.len());
            for (k, &j) in risk_set.iter().enumerate() {
                let x_j = covariates.row(j);
                exp_beta_x[k] = x_j.dot(beta).exp();
            }

            let sum_exp = exp_beta_x.sum();
            if sum_exp <= 0.0 {
                continue;
            }

            // Update log-likelihood
            log_likelihood += x_i.dot(beta) - sum_exp.ln();

            // Update gradient
            let mut weighted_x = Array1::<f64>::zeros(n_features);
            for (k, &j) in risk_set.iter().enumerate() {
                let x_j = covariates.row(j);
                let weight = exp_beta_x[k] / sum_exp;
                weighted_x = &weighted_x + &(weight * &x_j.to_owned());
            }
            gradient = &gradient + &(&x_i.to_owned() - &weighted_x);

            // Update Hessian with the information-matrix term:
            // -(weighted second moment - outer product of the weighted mean)
            for p in 0..n_features {
                for q in 0..n_features {
                    let mut weighted_sum = 0.0;
                    for (k, &j) in risk_set.iter().enumerate() {
                        let x_j = covariates.row(j);
                        let weight = exp_beta_x[k] / sum_exp;
                        weighted_sum += weight * x_j[p] * x_j[q];
                    }
                    hessian[[p, q]] -= weighted_sum - (weighted_x[p] * weighted_x[q]);
                }
            }
        }

        Ok((log_likelihood, gradient, hessian))
    }

    /// Invert Hessian matrix (negated for Newton-Raphson)
    fn invert_hessian(hessian: &Array2<f64>) -> Result<Array2<f64>> {
        let neg_hessian = -hessian;
        scirs2_linalg::inv(&neg_hessian.view(), None)
            .map_err(|e| StatsError::ComputationError(format!("Failed to invert Hessian: {e}")))
    }

    /// Calculate baseline hazard function
    fn calculate_baseline_hazard(
        durations: &ArrayView1<f64>,
        events: &ArrayView1<bool>,
        covariates: &ArrayView2<f64>,
        beta: &Array1<f64>,
    ) -> Result<(Array1<f64>, Array1<f64>)> {
        let n_samples = durations.len();

        // Sort by event time
        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.sort_by(|&i, &j| durations[i].partial_cmp(&durations[j]).unwrap());

        let mut times = Vec::new();
        let mut cumulative_hazard = Vec::new();
        let mut current_cumhaz = 0.0;

        for &i in &indices {
            if !events[i] {
                continue;
            }

            let t_i = durations[i];

            // Calculate risk set
            let mut risk_sum = 0.0;
            for &j in &indices {
                if durations[j] >= t_i {
                    let x_j = covariates.row(j);
                    risk_sum += x_j.dot(beta).exp();
                }
            }

            if risk_sum > 0.0 {
                current_cumhaz += 1.0 / risk_sum; // Breslow estimator
                times.push(t_i);
                cumulative_hazard.push(current_cumhaz);
            }
        }

        Ok((Array1::from_vec(times), Array1::from_vec(cumulative_hazard)))
    }

    /// Predict hazard ratios for new data
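    ///
    /// Returns `exp(x' beta)` for each row. A short sketch (`ignore`d, assumes a fitted `model`):
    ///
    /// ```ignore
    /// let hr = model.predict_hazard_ratio(new_covariates.view())?;
    /// // hr[i] > 1.0 indicates an elevated hazard relative to baseline
    /// ```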
    pub fn predict_hazard_ratio(&self, covariates: ArrayView2<f64>) -> Result<Array1<f64>> {
        checkarray_finite(&covariates, "covariates")?;

        if covariates.ncols() != self.coefficients.len() {
            return Err(StatsError::DimensionMismatch(format!(
                "covariates has {features} features, expected {expected}",
                features = covariates.ncols(),
                expected = self.coefficients.len()
            )));
        }

        let mut hazard_ratios = Array1::zeros(covariates.nrows());

        for i in 0..covariates.nrows() {
            let x_i = covariates.row(i);
            hazard_ratios[i] = x_i.dot(&self.coefficients).exp();
        }

        Ok(hazard_ratios)
    }
}

/// Accelerated Failure Time (AFT) model
///
/// Models the logarithm of survival time as a linear function of covariates.
#[derive(Debug, Clone)]
pub struct AFTModel {
    /// Regression coefficients
    pub coefficients: Array1<f64>,
    /// Scale parameter
    pub scale: f64,
    /// Distribution type
    pub distribution: AFTDistribution,
}

/// Distribution types for AFT models
#[derive(Debug, Clone, Copy)]
pub enum AFTDistribution {
    /// Weibull distribution
    Weibull,
    /// Lognormal distribution
    Lognormal,
    /// Exponential distribution (special case of Weibull)
    Exponential,
}

impl AFTModel {
    /// Fit AFT model (simplified implementation)
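    ///
    /// A minimal sketch (assumed data, `ignore`d):
    ///
    /// ```ignore
    /// let aft = AFTModel::fit(
    ///     durations.view(),
    ///     events.view(),
    ///     covariates.view(),
    ///     AFTDistribution::Weibull,
    /// )?;
    /// let predicted_times = aft.predict(new_covariates.view())?;
    /// ```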
    pub fn fit(
        durations: ArrayView1<f64>,
        events: ArrayView1<bool>,
        covariates: ArrayView2<f64>,
        distribution: AFTDistribution,
    ) -> Result<Self> {
        checkarray_finite(&durations, "durations")?;
        checkarray_finite(&covariates, "covariates")?;

        let (n_samples, n_features) = covariates.dim();

        if durations.len() != n_samples || events.len() != n_samples {
            return Err(StatsError::DimensionMismatch(
                "All input arrays must have the same number of samples".to_string(),
            ));
        }

        // The log-linear formulation requires strictly positive durations
        if durations.iter().any(|&t| t <= 0.0) {
            return Err(StatsError::InvalidArgument(
                "durations must be strictly positive for the AFT model".to_string(),
            ));
        }

        // Simplified implementation: use log-linear regression on observed times.
        // In practice, this would use maximum likelihood estimation.

        let mut y = Array1::zeros(n_samples);
        let mut weights = Array1::zeros(n_samples);

        for i in 0..n_samples {
            y[i] = durations[i].ln();
            weights[i] = if events[i] { 1.0 } else { 0.5 }; // Downweight censored observations
        }

        // Weighted least squares (simplified)
        let mut xtx = Array2::zeros((n_features, n_features));
        let mut xty = Array1::zeros(n_features);

        for i in 0..n_samples {
            let x_i = covariates.row(i);
            let w = weights[i];

            for j in 0..n_features {
                xty[j] += w * x_i[j] * y[i];
                for k in 0..n_features {
                    xtx[[j, k]] += w * x_i[j] * x_i[k];
                }
            }
        }

        let coefficients = scirs2_linalg::solve(&xtx.view(), &xty.view(), None).map_err(|e| {
            StatsError::ComputationError(format!("Failed to solve regression: {e}"))
        })?;

        // Estimate the scale parameter from residuals of uncensored observations
        let mut residual_sum = 0.0;
        let mut count = 0;

        for i in 0..n_samples {
            if events[i] {
                let x_i = covariates.row(i);
                let predicted = x_i.dot(&coefficients);
                let residual = y[i] - predicted;
                residual_sum += residual * residual;
                count += 1;
            }
        }

        let scale = if count > 0 {
            (residual_sum / count as f64).sqrt()
        } else {
            1.0
        };

        Ok(Self {
            coefficients,
            scale,
            distribution,
        })
    }

    /// Predict survival times
    pub fn predict(&self, covariates: ArrayView2<f64>) -> Result<Array1<f64>> {
        checkarray_finite(&covariates, "covariates")?;

        if covariates.ncols() != self.coefficients.len() {
            return Err(StatsError::DimensionMismatch(format!(
                "covariates has {features} features, expected {expected}",
                features = covariates.ncols(),
                expected = self.coefficients.len()
            )));
        }

        let mut predictions = Array1::zeros(covariates.nrows());

        for i in 0..covariates.nrows() {
            let x_i = covariates.row(i);
            let log_time = x_i.dot(&self.coefficients);
            predictions[i] = log_time.exp();
        }

        Ok(predictions)
    }
}

/// Extended Cox model with time-dependent covariates and stratification
///
/// Supports time-varying covariates and stratified analysis for heterogeneous populations.
#[derive(Debug, Clone)]
pub struct ExtendedCoxModel {
    /// Regression coefficients
    pub coefficients: Array1<f64>,
    /// Covariance matrix of coefficients
    pub covariance_matrix: Array2<f64>,
    /// Log-likelihood of the fitted model
    pub log_likelihood: f64,
    /// Baseline cumulative hazard for each stratum
    pub stratum_baseline_hazards: Vec<(Array1<f64>, Array1<f64>)>, // (times, cumulative hazards)
    /// Stratum labels
    pub strata: Option<Array1<usize>>,
    /// Number of strata
    pub n_strata: usize,
    /// Time-dependent covariate indices
    pub time_varying_indices: Vec<usize>,
    /// Number of iterations until convergence
    pub n_iter: usize,
}

impl ExtendedCoxModel {
    /// Fit extended Cox model with optional stratification and time-dependent covariates
    ///
    /// # Arguments
    /// * `durations` - Time to event or censoring
    /// * `events` - Whether event was observed
    /// * `covariates` - Covariate matrix (n_samples x n_features)
    /// * `strata` - Optional stratification variable (0-based labels)
    /// * `time_varying_indices` - Indices of time-varying covariates
    /// * `max_iter` - Maximum number of iterations
    /// * `tol` - Convergence tolerance
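    ///
    /// # Example
    ///
    /// A sketch of a stratified fit (assumed data, `ignore`d):
    ///
    /// ```ignore
    /// // strata holds 0-based labels, e.g. one per study site
    /// let model = ExtendedCoxModel::fit_stratified(
    ///     durations.view(),
    ///     events.view(),
    ///     covariates.view(),
    ///     Some(strata.view()),
    ///     None,       // no time-varying covariates
    ///     Some(100),  // max_iter
    ///     Some(1e-6), // tol
    /// )?;
    /// ```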
    pub fn fit_stratified(
        durations: ArrayView1<f64>,
        events: ArrayView1<bool>,
        covariates: ArrayView2<f64>,
        strata: Option<ArrayView1<usize>>,
        time_varying_indices: Option<Vec<usize>>,
        max_iter: Option<usize>,
        tol: Option<f64>,
    ) -> Result<Self> {
        checkarray_finite(&durations, "durations")?;
        checkarray_finite(&covariates, "covariates")?;

        let (n_samples, n_features) = covariates.dim();
        let max_iter = max_iter.unwrap_or(100);
        let tol = tol.unwrap_or(1e-6);

        if durations.len() != n_samples || events.len() != n_samples {
            return Err(StatsError::DimensionMismatch(
                "All input arrays must have the same number of samples".to_string(),
            ));
        }

        // Handle stratification
        let (strata_array, n_strata) = if let Some(strata_input) = strata {
            if strata_input.len() != n_samples {
                return Err(StatsError::DimensionMismatch(
                    "Strata length must match number of samples".to_string(),
                ));
            }
            let max_stratum = strata_input.iter().cloned().max().unwrap_or(0);
            (Some(strata_input.to_owned()), max_stratum + 1)
        } else {
            (None, 1)
        };

        // Recorded on the fitted model; time-varying handling is not yet applied during fitting
        let time_varying_indices = time_varying_indices.unwrap_or_default();

        // Initialize coefficients
        let mut beta = Array1::zeros(n_features);
        let mut prev_log_likelihood = f64::NEG_INFINITY;

        for iteration in 0..max_iter {
            // Calculate stratified partial likelihood and derivatives
            let (log_likelihood, gradient, hessian) =
                Self::stratified_partial_likelihood_derivatives(
                    &durations,
                    &events,
                    &covariates,
                    &beta,
                    &strata_array,
                    n_strata,
                )?;

            // Check convergence
            if (log_likelihood - prev_log_likelihood).abs() < tol {
                let covariance_matrix = Self::invert_hessian(&hessian)?;
                let baseline_hazards = Self::calculate_stratified_baseline_hazards(
                    &durations,
                    &events,
                    &covariates,
                    &beta,
                    &strata_array,
                    n_strata,
                )?;

                return Ok(Self {
                    coefficients: beta,
                    covariance_matrix,
                    log_likelihood,
                    stratum_baseline_hazards: baseline_hazards,
                    strata: strata_array,
                    n_strata,
                    time_varying_indices,
                    n_iter: iteration + 1,
                });
            }

            // Newton-Raphson update
            let hessian_inv = Self::invert_hessian(&hessian)?;
            let delta = hessian_inv.dot(&gradient);
            beta = &beta + &delta;

            prev_log_likelihood = log_likelihood;
        }

        Err(StatsError::ConvergenceError(format!(
            "Extended Cox model failed to converge after {max_iter} iterations"
        )))
    }

    /// Calculate stratified partial likelihood and derivatives
    fn stratified_partial_likelihood_derivatives(
        durations: &ArrayView1<f64>,
        events: &ArrayView1<bool>,
        covariates: &ArrayView2<f64>,
        beta: &Array1<f64>,
        strata: &Option<Array1<usize>>,
        n_strata: usize,
    ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
        let n_samples = durations.len();
        let n_features = beta.len();

        let mut total_log_likelihood = 0.0;
        let mut total_gradient = Array1::zeros(n_features);
        let mut total_hessian = Array2::zeros((n_features, n_features));

        // Process each stratum separately
        for stratum in 0..n_strata {
            // Get indices for this stratum
            let stratum_indices: Vec<usize> = if let Some(ref strata_array) = strata {
                (0..n_samples)
                    .filter(|&i| strata_array[i] == stratum)
                    .collect()
            } else {
                (0..n_samples).collect()
            };

            if stratum_indices.is_empty() {
                continue;
            }

            // Sort by event time within stratum
            let mut sorted_indices = stratum_indices.clone();
            sorted_indices.sort_by(|&i, &j| durations[i].partial_cmp(&durations[j]).unwrap());

            // Calculate partial likelihood for this stratum
            let (stratum_ll, stratum_grad, stratum_hess) = Self::stratum_partial_likelihood(
                durations,
                events,
                covariates,
                beta,
                &sorted_indices,
            )?;

            total_log_likelihood += stratum_ll;
            total_gradient = &total_gradient + &stratum_grad;
            total_hessian = &total_hessian + &stratum_hess;
        }

        Ok((total_log_likelihood, total_gradient, total_hessian))
    }

    /// Calculate partial likelihood for a single stratum
    fn stratum_partial_likelihood(
        durations: &ArrayView1<f64>,
        events: &ArrayView1<bool>,
        covariates: &ArrayView2<f64>,
        beta: &Array1<f64>,
        sorted_indices: &[usize],
    ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
        let n_features = beta.len();

        let mut log_likelihood = 0.0;
        let mut gradient = Array1::zeros(n_features);
        let mut hessian = Array2::zeros((n_features, n_features));

        for &i in sorted_indices {
            if !events[i] {
                continue; // Skip censored observations
            }

            let t_i = durations[i];
            let x_i = covariates.row(i);

            // Find risk set for this event time within stratum
            let mut risk_set = Vec::new();
            for &j in sorted_indices {
                if durations[j] >= t_i {
                    risk_set.push(j);
                }
            }

            if risk_set.is_empty() {
                continue;
            }

            // Calculate exp(beta' * x) for risk set
            let mut exp_beta_x = Array1::zeros(risk_set.len());
            for (k, &j) in risk_set.iter().enumerate() {
                let x_j = covariates.row(j);
                exp_beta_x[k] = x_j.dot(beta).exp();
            }

            let sum_exp = exp_beta_x.sum();
            if sum_exp <= 0.0 {
                continue;
            }

            // Update log-likelihood
            log_likelihood += x_i.dot(beta) - sum_exp.ln();

            // Update gradient
            let mut weighted_x = Array1::<f64>::zeros(n_features);
            for (k, &j) in risk_set.iter().enumerate() {
                let x_j = covariates.row(j);
                let weight = exp_beta_x[k] / sum_exp;
                weighted_x = &weighted_x + &(weight * &x_j.to_owned());
            }
            gradient = &gradient + &(&x_i.to_owned() - &weighted_x);

            // Update Hessian
            for p in 0..n_features {
                for q in 0..n_features {
                    let mut weighted_sum = 0.0;
                    for (k, &j) in risk_set.iter().enumerate() {
                        let x_j = covariates.row(j);
                        let weight = exp_beta_x[k] / sum_exp;
                        weighted_sum += weight * x_j[p] * x_j[q];
                    }
                    hessian[[p, q]] -= weighted_sum - (weighted_x[p] * weighted_x[q]);
                }
            }
        }

        Ok((log_likelihood, gradient, hessian))
    }

    /// Calculate baseline hazards for each stratum
    fn calculate_stratified_baseline_hazards(
        durations: &ArrayView1<f64>,
        events: &ArrayView1<bool>,
        covariates: &ArrayView2<f64>,
        beta: &Array1<f64>,
        strata: &Option<Array1<usize>>,
        n_strata: usize,
    ) -> Result<Vec<(Array1<f64>, Array1<f64>)>> {
        let n_samples = durations.len();
        let mut baseline_hazards = Vec::new();

        for stratum in 0..n_strata {
            // Get indices for this stratum
            let stratum_indices: Vec<usize> = if let Some(ref strata_array) = strata {
                (0..n_samples)
                    .filter(|&i| strata_array[i] == stratum)
                    .collect()
            } else {
                (0..n_samples).collect()
            };

            if stratum_indices.is_empty() {
                baseline_hazards.push((Array1::zeros(0), Array1::zeros(0)));
                continue;
            }

            // Sort by event time
            let mut sorted_indices = stratum_indices.clone();
            sorted_indices.sort_by(|&i, &j| durations[i].partial_cmp(&durations[j]).unwrap());

            let mut times = Vec::new();
            let mut cumulative_hazard = Vec::new();
            let mut current_cumhaz = 0.0;

            for &i in &sorted_indices {
                if !events[i] {
                    continue;
                }

                let t_i = durations[i];

                // Calculate risk sum for this stratum
                let mut risk_sum = 0.0;
                for &j in &sorted_indices {
                    if durations[j] >= t_i {
                        let x_j = covariates.row(j);
                        risk_sum += x_j.dot(beta).exp();
                    }
                }

                if risk_sum > 0.0 {
                    current_cumhaz += 1.0 / risk_sum; // Breslow estimator
                    times.push(t_i);
                    cumulative_hazard.push(current_cumhaz);
                }
            }

            baseline_hazards.push((Array1::from_vec(times), Array1::from_vec(cumulative_hazard)));
        }

        Ok(baseline_hazards)
    }

    /// Invert Hessian matrix
    fn invert_hessian(hessian: &Array2<f64>) -> Result<Array2<f64>> {
        let neg_hessian = -hessian;
        scirs2_linalg::inv(&neg_hessian.view(), None)
            .map_err(|e| StatsError::ComputationError(format!("Failed to invert Hessian: {e}")))
    }

    /// Predict hazard ratios with optional stratum labels
    ///
    /// Hazard ratios `exp(x' beta)` do not depend on the stratum-specific
    /// baseline, so `strata` is only validated here, not used in the computation.
    pub fn predict_hazard_ratio_stratified(
        &self,
        covariates: ArrayView2<f64>,
        strata: Option<ArrayView1<usize>>,
    ) -> Result<Array1<f64>> {
        checkarray_finite(&covariates, "covariates")?;

        if covariates.ncols() != self.coefficients.len() {
            return Err(StatsError::DimensionMismatch(format!(
                "covariates has {features} features, expected {expected}",
                features = covariates.ncols(),
                expected = self.coefficients.len()
            )));
        }

        if let Some(ref strata_input) = strata {
            if strata_input.len() != covariates.nrows() {
                return Err(StatsError::DimensionMismatch(
                    "Strata length must match number of predictions".to_string(),
                ));
            }
        }

        let mut hazard_ratios = Array1::zeros(covariates.nrows());

        for i in 0..covariates.nrows() {
            let x_i = covariates.row(i);
            hazard_ratios[i] = x_i.dot(&self.coefficients).exp();
        }

        Ok(hazard_ratios)
    }

    /// Compute confidence intervals for coefficients
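    ///
    /// Returns an `n_features x 2` array of (lower, upper) bounds. A short sketch (`ignore`d):
    ///
    /// ```ignore
    /// let ci = model.coefficient_confidence_intervals(0.95)?;
    /// let (lo, hi) = (ci[[0, 0]], ci[[0, 1]]); // bounds for the first coefficient
    /// ```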
    pub fn coefficient_confidence_intervals(&self, confidence_level: f64) -> Result<Array2<f64>> {
        check_probability(confidence_level, "confidence_level")?;

        let n_features = self.coefficients.len();
        let mut intervals = Array2::zeros((n_features, 2));
        let _alpha = (1.0 - confidence_level) / 2.0;
        let z_critical = 1.96; // Hard-coded 95% z; other levels would need an inverse normal

        for i in 0..n_features {
            let coeff = self.coefficients[i];
            let se = self.covariance_matrix[[i, i]].sqrt();

            intervals[[i, 0]] = coeff - z_critical * se; // Lower bound
            intervals[[i, 1]] = coeff + z_critical * se; // Upper bound
        }

        Ok(intervals)
    }
}

/// Competing risks analysis using subdistribution hazards (Fine-Gray model)
///
/// Handles multiple competing events where the occurrence of one event
/// prevents the observation of the others.
#[derive(Debug, Clone)]
pub struct CompetingRisksModel {
    /// Coefficients for each competing risk
    pub coefficients: Vec<Array1<f64>>,
    /// Covariance matrices for each competing risk
    pub covariance_matrices: Vec<Array2<f64>>,
    /// Baseline cumulative incidence functions
    pub baseline_cifs: Vec<(Array1<f64>, Array1<f64>)>, // (times, cumulative incidence)
    /// Number of competing risks
    pub n_risks: usize,
    /// Log-likelihood of the fitted model
    pub log_likelihood: f64,
}

impl CompetingRisksModel {
    /// Fit competing risks model using Fine-Gray subdistribution hazards
    ///
    /// # Arguments
    /// * `durations` - Time to event or censoring
    /// * `events` - Event type (0 = censored, 1 = risk 1, 2 = risk 2, etc.)
    /// * `covariates` - Covariate matrix
    /// * `n_risks` - Number of competing risks
    /// * `target_risk` - Risk of interest (1-based)
    /// * `max_iter` - Maximum number of iterations
    /// * `tol` - Convergence tolerance
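    ///
    /// # Example
    ///
    /// A sketch with two competing risks, modeling risk 1 (assumed data, `ignore`d):
    ///
    /// ```ignore
    /// // events: 0 = censored, 1 = relapse, 2 = death from other causes
    /// let model = CompetingRisksModel::fit(
    ///     durations.view(),
    ///     events.view(),
    ///     covariates.view(),
    ///     2,          // n_risks
    ///     1,          // target_risk
    ///     Some(100),  // max_iter
    ///     Some(1e-6), // tol
    /// )?;
    /// ```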
    pub fn fit(
        durations: ArrayView1<f64>,
        events: ArrayView1<usize>,
        covariates: ArrayView2<f64>,
        n_risks: usize,
        target_risk: usize,
        max_iter: Option<usize>,
        tol: Option<f64>,
    ) -> Result<Self> {
        checkarray_finite(&durations, "durations")?;
        checkarray_finite(&covariates, "covariates")?;
        check_positive(n_risks, "n_risks")?;

        let (n_samples, n_features) = covariates.dim();
        let max_iter = max_iter.unwrap_or(100);
        let tol = tol.unwrap_or(1e-6);

        if durations.len() != n_samples || events.len() != n_samples {
            return Err(StatsError::DimensionMismatch(
                "All input arrays must have the same number of samples".to_string(),
            ));
        }

        if target_risk == 0 || target_risk > n_risks {
            return Err(StatsError::InvalidArgument(
                "target_risk must be between 1 and n_risks".to_string(),
            ));
        }

        // For now, fit a single risk at a time; a full implementation
        // would fit all risks simultaneously
        let mut coefficients = vec![Array1::zeros(n_features); n_risks];
        let mut covariance_matrices = vec![Array2::zeros((n_features, n_features)); n_risks];
        let mut baseline_cifs = vec![(Array1::zeros(0), Array1::zeros(0)); n_risks];

        // Fit model for the target risk using modified data
        let (modified_durations, modified_events, modified_weights) =
            Self::prepare_fine_gray_data(&durations, &events, target_risk)?;

        // Initialize coefficients for the target risk
        let mut beta = Array1::zeros(n_features);
        let mut prev_log_likelihood = f64::NEG_INFINITY;

        for _iteration in 0..max_iter {
            // Calculate weighted partial likelihood for the subdistribution hazard
            let (log_likelihood, gradient, hessian) = Self::subdistribution_partial_likelihood(
                &modified_durations,
                &modified_events,
                &covariates,
                &modified_weights,
                &beta,
            )?;

            // Check convergence
            if (log_likelihood - prev_log_likelihood).abs() < tol {
                coefficients[target_risk - 1] = beta.clone();
                covariance_matrices[target_risk - 1] = Self::invert_hessian(&hessian)?;

                // Calculate baseline cumulative incidence
                let (times, cif) = Self::calculate_baseline_cif(
                    &modified_durations,
                    &modified_events,
                    &covariates,
                    &modified_weights,
                    &beta,
                )?;
                baseline_cifs[target_risk - 1] = (times, cif);

                return Ok(Self {
                    coefficients,
                    covariance_matrices,
                    baseline_cifs,
                    n_risks,
                    log_likelihood,
                });
            }

            // Newton-Raphson update
            let hessian_inv = Self::invert_hessian(&hessian)?;
            let delta = hessian_inv.dot(&gradient);
            beta = &beta + &delta;

            prev_log_likelihood = log_likelihood;
        }

        Err(StatsError::ConvergenceError(format!(
            "Competing risks model failed to converge after {max_iter} iterations"
        )))
    }

    /// Prepare data for the Fine-Gray model with artificial censoring
    fn prepare_fine_gray_data(
        durations: &ArrayView1<f64>,
        events: &ArrayView1<usize>,
        target_risk: usize,
    ) -> Result<(Array1<f64>, Array1<bool>, Array1<f64>)> {
        let n_samples = durations.len();
        let modified_durations = durations.to_owned();
        let mut modified_events = Array1::from_elem(n_samples, false);
        let mut weights = Array1::ones(n_samples);

        // Kaplan-Meier estimate of the censoring distribution (for IPCW weights)
        let censoring_km = Self::kaplan_meier_censoring(durations, events)?;

        for i in 0..n_samples {
            if events[i] == target_risk {
                // Event of interest occurred
                modified_events[i] = true;
                weights[i] = 1.0;
            } else if events[i] == 0 {
                // Censored observation
                modified_events[i] = false;
                weights[i] = 1.0;
            } else {
                // Competing event occurred - use artificial censoring
                modified_events[i] = false;

                // Weight by the inverse probability of censoring
                let km_prob = Self::interpolate_km_probability(
                    &censoring_km.0,
                    &censoring_km.1,
                    durations[i],
                );
                weights[i] = if km_prob > 0.0 { 1.0 / km_prob } else { 0.0 };
            }
        }

        Ok((modified_durations, modified_events, weights))
    }

    /// Calculate Kaplan-Meier estimator for the censoring distribution
    fn kaplan_meier_censoring(
        durations: &ArrayView1<f64>,
        events: &ArrayView1<usize>,
    ) -> Result<(Array1<f64>, Array1<f64>)> {
        // For the censoring distribution, censored observations (event type 0)
        // play the role of "events"
        let censoring_events: Array1<bool> = events.mapv(|e| e == 0);

        // Create time-event pairs and sort
        let mut time_event_pairs: Vec<(f64, bool)> = durations
            .iter()
            .zip(censoring_events.iter())
            .map(|(&t, &e)| (t, e))
            .collect();
        time_event_pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

        let mut times = Vec::new();
        let mut survival_probs = Vec::new();
        let mut current_survival = 1.0;
        let mut current_at_risk = time_event_pairs.len();

        let mut i = 0;
        while i < time_event_pairs.len() {
            let current_time = time_event_pairs[i].0;
            let mut events_at_time = 0;
            let mut total_at_time = 0;

            while i < time_event_pairs.len() && time_event_pairs[i].0 == current_time {
                if time_event_pairs[i].1 {
                    events_at_time += 1;
                }
                total_at_time += 1;
                i += 1;
            }

            if events_at_time > 0 {
                let survival_this_time = 1.0 - (events_at_time as f64) / (current_at_risk as f64);
                current_survival *= survival_this_time;

                times.push(current_time);
                survival_probs.push(current_survival);
            }

            current_at_risk -= total_at_time;
        }

        Ok((Array1::from_vec(times), Array1::from_vec(survival_probs)))
    }

    /// Evaluate the censoring Kaplan-Meier step function at a given time
    fn interpolate_km_probability(times: &Array1<f64>, probs: &Array1<f64>, t: f64) -> f64 {
        // The KM estimate is a right-continuous step function: return the value
        // at the last event time <= t, and 1.0 before the first event time
        let mut value = 1.0;
        for i in 0..times.len() {
            if times[i] <= t {
                value = probs[i];
            } else {
                break;
            }
        }
        value
    }

    /// Calculate subdistribution partial likelihood
    fn subdistribution_partial_likelihood(
        durations: &Array1<f64>,
        events: &Array1<bool>,
        covariates: &ArrayView2<f64>,
        weights: &Array1<f64>,
        beta: &Array1<f64>,
    ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
        let n_samples = durations.len();
        let n_features = beta.len();

        // Sort by event time
        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.sort_by(|&i, &j| durations[i].partial_cmp(&durations[j]).unwrap());

        let mut log_likelihood = 0.0;
        let mut gradient = Array1::zeros(n_features);
        let mut hessian = Array2::zeros((n_features, n_features));

        for &i in &indices {
            if !events[i] {
                continue; // Skip non-events
            }

            let t_i = durations[i];
            let x_i = covariates.row(i);
            let w_i = weights[i];

            // Calculate weighted risk set
            let mut weighted_exp_sum = 0.0;
            let mut weighted_x_sum = Array1::zeros(n_features);
            let mut weighted_xx_sum = Array2::zeros((n_features, n_features));

            for &j in &indices {
                if durations[j] >= t_i {
                    let x_j = covariates.row(j);
                    let w_j = weights[j];
                    let exp_beta_x = x_j.dot(beta).exp();
                    let weighted_exp = w_j * exp_beta_x;

                    weighted_exp_sum += weighted_exp;
                    weighted_x_sum = &weighted_x_sum + &(weighted_exp * &x_j.to_owned());

                    for p in 0..n_features {
                        for q in 0..n_features {
                            weighted_xx_sum[[p, q]] += weighted_exp * x_j[p] * x_j[q];
                        }
                    }
                }
            }

            if weighted_exp_sum <= 0.0 {
                continue;
            }

            // Update likelihood components
            let weighted_mean_x = &weighted_x_sum / weighted_exp_sum;

            log_likelihood += w_i * (x_i.dot(beta) - weighted_exp_sum.ln());
            gradient = &gradient + &(w_i * (&x_i.to_owned() - &weighted_mean_x));

            // Update Hessian
            let weighted_mean_xx = &weighted_xx_sum / weighted_exp_sum;
            let outer_product = outer_product_array(&weighted_mean_x);
            hessian = &hessian - &(w_i * (&weighted_mean_xx - &outer_product));
        }

        Ok((log_likelihood, gradient, hessian))
    }

    /// Calculate baseline cumulative incidence function
    fn calculate_baseline_cif(
        durations: &Array1<f64>,
        events: &Array1<bool>,
        covariates: &ArrayView2<f64>,
        weights: &Array1<f64>,
        beta: &Array1<f64>,
    ) -> Result<(Array1<f64>, Array1<f64>)> {
        let n_samples = durations.len();

        // Sort by event time
        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.sort_by(|&i, &j| durations[i].partial_cmp(&durations[j]).unwrap());

        let mut times = Vec::new();
        let mut cumulative_incidence = Vec::new();
        let mut current_cif = 0.0;

        for &i in &indices {
            if !events[i] {
                continue;
            }

            let t_i = durations[i];
            let w_i = weights[i];

            // Calculate weighted risk sum
            let mut weighted_risk_sum = 0.0;
            for &j in &indices {
                if durations[j] >= t_i {
                    let x_j = covariates.row(j);
                    let w_j = weights[j];
                    weighted_risk_sum += w_j * x_j.dot(beta).exp();
                }
            }

            if weighted_risk_sum > 0.0 {
                current_cif += w_i / weighted_risk_sum;
                times.push(t_i);
                cumulative_incidence.push(current_cif);
            }
        }

        Ok((
            Array1::from_vec(times),
            Array1::from_vec(cumulative_incidence),
        ))
    }

    /// Invert Hessian matrix
    fn invert_hessian(hessian: &Array2<f64>) -> Result<Array2<f64>> {
        let neg_hessian = -hessian;
        scirs2_linalg::inv(&neg_hessian.view(), None)
            .map_err(|e| StatsError::ComputationError(format!("Failed to invert Hessian: {e}")))
    }

    /// Predict cumulative incidence for the target risk
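    ///
    /// A short sketch (`ignore`d, assumes a fitted `model`):
    ///
    /// ```ignore
    /// let times = array![1.0, 2.0, 5.0];
    /// let cif = model.predict_cumulative_incidence(new_covariates.view(), 1, times.view())?;
    /// // cif[[i, j]] is the predicted incidence of risk 1 for subject i by times[j]
    /// ```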
    pub fn predict_cumulative_incidence(
        &self,
        covariates: ArrayView2<f64>,
        target_risk: usize,
        times: ArrayView1<f64>,
    ) -> Result<Array2<f64>> {
        checkarray_finite(&covariates, "covariates")?;
        checkarray_finite(&times, "times")?;

        if target_risk == 0 || target_risk > self.n_risks {
            return Err(StatsError::InvalidArgument(
                "target_risk must be between 1 and n_risks".to_string(),
            ));
        }

        let risk_idx = target_risk - 1;
        let n_samples = covariates.nrows();
        let n_times = times.len();
        let mut predictions = Array2::zeros((n_samples, n_times));

        let beta = &self.coefficients[risk_idx];
        let (baseline_times, baseline_cif) = &self.baseline_cifs[risk_idx];

        for i in 0..n_samples {
            let x_i = covariates.row(i);
            let hazard_ratio = x_i.dot(beta).exp();

            for (j, &t) in times.iter().enumerate() {
                // Evaluate the baseline CIF at time t
                let baseline_value =
                    Self::interpolate_baseline_cif(baseline_times, baseline_cif, t);

                // Transform the baseline CIF using the subdistribution hazard ratio;
                // this is a simplified transformation - a full implementation would be more complex
                predictions[[i, j]] = 1.0 - (1.0 - baseline_value).powf(hazard_ratio);
            }
        }

        Ok(predictions)
    }

    /// Evaluate the baseline cumulative incidence step function at a given time
    fn interpolate_baseline_cif(times: &Array1<f64>, cif: &Array1<f64>, t: f64) -> f64 {
        // The baseline CIF is a right-continuous step function: return the value
        // at the last event time <= t, and 0.0 before the first event time
        let mut value = 0.0;
        for i in 0..times.len() {
            if times[i] <= t {
                value = cif[i];
            } else {
                break;
            }
        }
        value
    }
}

/// Helper function to compute the outer product of a vector with itself
fn outer_product_array(v: &Array1<f64>) -> Array2<f64> {
    let n = v.len();
    let mut result = Array2::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            result[[i, j]] = v[i] * v[j];
        }
    }
    result
}