
scry_learn/linear/logistic.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Logistic regression via L-BFGS (default) or gradient descent.
3//!
4//! Supports configurable [`Penalty`] regularization: `None`, `L1`, `L2` (default),
5//! and `ElasticNet(l1_ratio)`. L1 and ElasticNet use proximal gradient descent
6//! (soft-thresholding); L-BFGS only supports `L2` and `None`.
7
8use rayon::prelude::*;
9
10use crate::dataset::Dataset;
11use crate::error::{Result, ScryLearnError};
12use crate::partial_fit::PartialFit;
13use crate::sparse::{CscMatrix, CsrMatrix};
14use crate::weights::{compute_sample_weights, ClassWeight};
15
16use super::lbfgs;
17
18/// Regularization penalty for logistic regression.
19///
20/// Controls the type of regularization applied during training:
21/// - `None` — no regularization
22/// - `L1` — Lasso penalty (promotes sparsity via proximal gradient descent)
23/// - `L2` — Ridge penalty (default, shrinks coefficients)
24/// - `ElasticNet(l1_ratio)` — Mix of L1 and L2; `l1_ratio` ∈ \[0, 1\]
25///   where 1.0 = pure L1, 0.0 = pure L2
26///
27/// # Solver compatibility
28///
29/// | Penalty | GradientDescent | L-BFGS |
30/// |---------|:-:|:-:|
31/// | `None` | ✓ | ✓ |
32/// | `L1` | ✓ | ✗ (error) |
33/// | `L2` | ✓ | ✓ |
34/// | `ElasticNet` | ✓ | ✗ (error) |
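///
/// # Example
///
/// A small sketch using the builder methods defined below (L1 and ElasticNet
/// require the gradient-descent solver):
///
/// ```ignore
/// let model = LogisticRegression::new()
///     .solver(Solver::GradientDescent)
///     .penalty(Penalty::ElasticNet(0.3)) // 30% L1, 70% L2
///     .alpha(0.5);
/// ```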
35#[derive(Debug, Clone, PartialEq, Default)]
36#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
37#[non_exhaustive]
38pub enum Penalty {
39    /// No regularization.
40    None,
41    /// L1 (Lasso) penalty — promotes sparse coefficients.
42    L1,
43    /// L2 (Ridge) penalty — shrinks all coefficients (default).
44    #[default]
45    L2,
46    /// Elastic Net — mix of L1 and L2. The `f64` is the L1 ratio ∈ \[0, 1\].
47    ElasticNet(f64),
48}
49
50/// Solver algorithm for logistic regression.
51///
52/// L-BFGS is the default and recommended solver — it converges in ~10-20
53/// iterations vs 200+ for gradient descent, matching scikit-learn's default.
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
55#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
56#[non_exhaustive]
57pub enum Solver {
58    /// L-BFGS quasi-Newton optimizer (default). Fast, recommended.
59    #[default]
60    Lbfgs,
61    /// Vanilla batch gradient descent. Slower, kept for backward compatibility.
62    GradientDescent,
63}
64
65/// Logistic regression for binary/multiclass classification.
66///
67/// Uses L-BFGS (default) or gradient descent with configurable learning rate,
68/// iterations, and L2 regularization.
69#[derive(Clone)]
70#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
71#[non_exhaustive]
72pub struct LogisticRegression {
73    learning_rate: f64,
74    max_iter: usize,
75    alpha: f64, // regularization strength
76    tolerance: f64,
77    class_weight: ClassWeight,
78    #[cfg_attr(feature = "serde", serde(default))]
79    solver: Solver,
80    #[cfg_attr(feature = "serde", serde(default))]
81    penalty: Penalty,
82    weights: Vec<Vec<f64>>, // [n_classes][n_features + 1] (includes bias)
83    n_classes: usize,
84    fitted: bool,
85    #[cfg_attr(feature = "serde", serde(default))]
86    _schema_version: u32,
87}
88
89impl LogisticRegression {
90    /// Create a new logistic regression model.
91    pub fn new() -> Self {
92        Self {
93            learning_rate: 0.01,
94            max_iter: 1000,
95            alpha: 1.0,
96            tolerance: crate::constants::STRICT_TOL,
97            class_weight: ClassWeight::Uniform,
98            solver: Solver::default(),
99            penalty: Penalty::default(),
100            weights: Vec::new(),
101            n_classes: 0,
102            fitted: false,
103            _schema_version: crate::version::SCHEMA_VERSION,
104        }
105    }
106
107    /// Set the learning rate (used by `GradientDescent` solver only).
108    pub fn learning_rate(mut self, lr: f64) -> Self {
109        self.learning_rate = lr;
110        self
111    }
112
113    /// Set maximum iterations.
114    pub fn max_iter(mut self, n: usize) -> Self {
115        self.max_iter = n;
116        self
117    }
118
119    /// Set regularization strength (equivalent to `1/C` in scikit-learn).
120    ///
121    /// The meaning depends on the [`Penalty`]:
122    /// - `L2` / `L1` — multiplier on the penalty term
123    /// - `ElasticNet` — total regularization strength (split by l1_ratio)
124    /// - `None` — ignored
125    ///
126    /// To match scikit-learn's `LogisticRegression(C=x)`, use `alpha(1.0 / x)`.
127    /// The default `alpha = 1.0` corresponds to `C = 1.0`.
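    ///
    /// A quick sketch of the conversion:
    ///
    /// ```ignore
    /// // scikit-learn: LogisticRegression(C=10.0)  ->  here: alpha = 1.0 / 10.0
    /// let model = LogisticRegression::new().alpha(0.1);
    /// ```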
128    pub fn alpha(mut self, a: f64) -> Self {
129        self.alpha = a;
130        self
131    }
132
133    /// Set the regularization penalty.
134    ///
135    /// Default is [`Penalty::L2`]. Use [`Penalty::L1`] for sparse feature selection.
136    ///
137    /// # Errors
138    ///
139    /// `L1` and `ElasticNet` are **not** supported with the `Lbfgs` solver — calling
140    /// `fit()` will return `Err(InvalidParameter)`. Switch to `GradientDescent`.
141    pub fn penalty(mut self, p: Penalty) -> Self {
142        self.penalty = p;
143        self
144    }
145
146    /// Set convergence tolerance.
147    pub fn tolerance(mut self, t: f64) -> Self {
148        self.tolerance = t;
149        self
150    }
151
152    /// Alias for [`tolerance`](Self::tolerance) (sklearn convention).
153    pub fn tol(self, t: f64) -> Self {
154        self.tolerance(t)
155    }
156
157    /// Set class weighting strategy for imbalanced datasets.
158    pub fn class_weight(mut self, cw: ClassWeight) -> Self {
159        self.class_weight = cw;
160        self
161    }
162
163    /// Set the solver algorithm.
164    ///
165    /// Defaults to `Solver::Lbfgs` which is ~10-20× faster than gradient descent.
166    pub fn solver(mut self, s: Solver) -> Self {
167        self.solver = s;
168        self
169    }
170
171    /// Train the model using the configured solver.
172    ///
173    /// Training and inference use the same multinomial softmax formulation (not one-vs-rest sigmoids).
174    ///
175    /// # Errors
176    ///
177    /// Returns `InvalidParameter` if `Penalty::L1` or `Penalty::ElasticNet` is used
178    /// with the `Lbfgs` solver (L-BFGS requires a differentiable objective).
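    ///
    /// # Example
    ///
    /// A minimal end-to-end sketch; the `Dataset::new` and `feature_matrix` calls
    /// are the ones used by the unit tests below:
    ///
    /// ```ignore
    /// let features = vec![(0..20).map(|i| i as f64).collect()];
    /// let target: Vec<f64> = (0..20).map(|i| if i < 10 { 0.0 } else { 1.0 }).collect();
    /// let data = Dataset::new(features, target, vec!["x".into()], "class");
    ///
    /// let mut model = LogisticRegression::new().max_iter(200);
    /// model.fit(&data)?;
    /// let preds = model.predict(&data.feature_matrix())?;
    /// ```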
179    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
180        data.validate_finite()?;
181        if let Some(csc) = data.sparse_csc() {
182            return self.fit_sparse(csc, &data.target);
183        }
184        // Classification requires at least 2 distinct classes.
185        if data.n_classes() < 2 {
186            return Err(ScryLearnError::InvalidParameter(
187                "LogisticRegression requires at least 2 distinct classes in the target.".into(),
188            ));
189        }
190        // Validate solver/penalty compatibility.
191        if matches!(self.solver, Solver::Lbfgs)
192            && matches!(self.penalty, Penalty::L1 | Penalty::ElasticNet(_))
193        {
194            return Err(ScryLearnError::InvalidParameter(
195                "L-BFGS solver does not support L1 or ElasticNet penalties \
196                 (non-differentiable). Use Solver::GradientDescent instead."
197                    .into(),
198            ));
199        }
200        match self.solver {
201            Solver::Lbfgs => self.fit_lbfgs(data),
202            Solver::GradientDescent => self.fit_gd(data),
203        }
204    }
205
206    /// L-BFGS solver: flatten weights, optimize, unflatten.
207    ///
208    /// Uses vectorized (batch) gradient computation for cache-friendly
209    /// column-major access patterns.
210    #[allow(clippy::needless_range_loop)]
211    fn fit_lbfgs(&mut self, data: &Dataset) -> Result<()> {
212        let n = data.n_samples();
213        let m = data.n_features();
214        if n == 0 {
215            return Err(ScryLearnError::EmptyDataset);
216        }
217
218        self.n_classes = data.n_classes();
219        let k = self.n_classes;
220
221        // Binary fast path: use sigmoid instead of softmax (halves parameter count).
222        if k == 2 {
223            return self.fit_lbfgs_binary(data);
224        }
225
226        let dim = m + 1; // features + bias
227
228        // Compute per-sample weights for class imbalance (skip for uniform).
229        let uniform = matches!(self.class_weight, ClassWeight::Uniform);
230        let sample_weights = if uniform {
231            Vec::new()
232        } else {
233            compute_sample_weights(&data.target, &self.class_weight)
234        };
235
236        // Pre-convert targets to usize.
237        let target_class: Vec<usize> = data.target.iter().map(|&t| t as usize).collect();
238
239        let alpha = self.alpha;
240        let inv_n = 1.0 / n as f64;
241
242        // Flatten initial weights: [class0_bias, class0_w1, ..., class1_bias, ...]
243        let total_params = k * dim;
244        let mut params = vec![0.0; total_params];
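        // Layout example with k = 3 classes and m = 2 features (dim = 3):
        //   params = [b_0, w_{0,1}, w_{0,2},  b_1, w_{1,1}, w_{1,2},  b_2, w_{2,1}, w_{2,2}]
        // so bias_c lives at params[c * dim] and w_{c,j} at params[c * dim + j + 1].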
245
246        let config = lbfgs::LbfgsConfig {
247            max_iter: self.max_iter,
248            tolerance: self.tolerance,
249            history_size: 10,
250            wolfe: false,
251        };
252
253        // Pre-allocate batch buffers — reused every closure call.
254        let mut logits = vec![0.0; n * k]; // row-major: logits[i * k + c]
255        let mut max_logit = vec![0.0; n];
256        let mut sum_exp = vec![0.0; n];
257        let use_par = n * m >= crate::constants::LOGREG_PAR_THRESHOLD;
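        // Parallelize the per-feature gradient accumulation only when the n*m work
        // is large enough to amortize rayon's scheduling overhead.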
258        let mut feature_grad_buf = if use_par {
259            vec![0.0; m * k]
260        } else {
261            Vec::new()
262        };
263
264        lbfgs::minimize(
265            &mut params,
266            |x, grad| {
267                // ── 1. Batch compute logits: logits[i,c] = bias_c + Σ_j w_{c,j} * X_{j,i}
268                // Initialize with bias terms.
269                for i in 0..n {
270                    for c in 0..k {
271                        logits[i * k + c] = x[c * dim]; // bias
272                    }
273                }
274                // Accumulate feature contributions column-by-column (cache-friendly).
275                for j in 0..m {
276                    let feat_col = &data.features[j];
277                    for c in 0..k {
278                        let w = x[c * dim + j + 1];
279                        for i in 0..n {
280                            logits[i * k + c] += w * feat_col[i];
281                        }
282                    }
283                }
284
285                // ── 2. Batch softmax + loss computation.
286                let mut loss = 0.0;
287
288                // Find max logit per sample (for numerical stability).
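                // (Log-sum-exp trick: log Σ_c exp(z_c) = max_z + log Σ_c exp(z_c - max_z),
                //  so every exp() argument is <= 0 and cannot overflow.)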
289                for i in 0..n {
290                    let row = &logits[i * k..(i + 1) * k];
291                    max_logit[i] = row.iter().copied().fold(f64::NEG_INFINITY, f64::max);
292                }
293
294                // Exponentiate and sum.
295                for i in 0..n {
296                    let mut se = 0.0;
297                    for c in 0..k {
298                        let val = (logits[i * k + c] - max_logit[i]).exp();
299                        logits[i * k + c] = val; // now stores exp(logit - max)
300                        se += val;
301                    }
302                    sum_exp[i] = se;
303                }
304
305                // Cross-entropy loss: sw * (log_sum_exp - logit_tc).
306                // logit_tc was overwritten, so reconstruct from max + log(exp_val).
307                for i in 0..n {
308                    let tc = target_class[i];
309                    let log_sum = max_logit[i] + sum_exp[i].ln();
310                    let logit_tc_val = max_logit[i] + logits[i * k + tc].ln();
311                    let sw = if uniform { 1.0 } else { sample_weights[i] };
312                    loss += sw * (log_sum - logit_tc_val);
313                }
314
315                // Normalize to probabilities.
316                for i in 0..n {
317                    let se = sum_exp[i];
318                    for c in 0..k {
319                        logits[i * k + c] /= se;
320                    }
321                }
322
323                // ── 3. Batch gradient: grad_{c,j} = Σ_i sw_i * (prob_{i,c} - 1_{tc==c}) * x_{i,j}
324                // Zero gradient.
325                for g in grad.iter_mut() {
326                    *g = 0.0;
327                }
328
329                // Compute error = prob - one_hot, weighted by sample weight.
330                // Then accumulate bias gradients.
331                // We modify logits in-place to store sw * error.
332                for i in 0..n {
333                    let tc = target_class[i];
334                    let sw = if uniform { 1.0 } else { sample_weights[i] };
335                    for c in 0..k {
336                        let y_i = if tc == c { 1.0 } else { 0.0 };
337                        let error = sw * (logits[i * k + c] - y_i);
338                        logits[i * k + c] = error; // reuse buffer for weighted errors
339                        grad[c * dim] += error; // bias gradient
340                    }
341                }
342
343                // Accumulate feature gradients column-by-column (cache-friendly).
344                if use_par {
345                    let errors: &[f64] = &logits;
346                    feature_grad_buf
347                        .par_chunks_mut(k)
348                        .zip(data.features.par_iter())
349                        .for_each(|(chunk, feat_col)| {
350                            for c in 0..k {
351                                let mut acc = 0.0;
352                                for i in 0..n {
353                                    acc += errors[i * k + c] * feat_col[i];
354                                }
355                                chunk[c] = acc;
356                            }
357                        });
358                    for j in 0..m {
359                        for c in 0..k {
360                            grad[c * dim + j + 1] += feature_grad_buf[j * k + c];
361                        }
362                    }
363                } else {
364                    for j in 0..m {
365                        let feat_col = &data.features[j];
366                        for c in 0..k {
367                            let grad_idx = c * dim + j + 1;
368                            let mut acc = 0.0;
369                            for i in 0..n {
370                                acc += logits[i * k + c] * feat_col[i];
371                            }
372                            grad[grad_idx] += acc;
373                        }
374                    }
375                }
376
377                // ── 4. Average over samples + L2 regularization.
378                //      sklearn objective:  min_w  C * Σ_i log_loss_i + 0.5 * ||w||²
379                //      Divide by C*n:      min_w  mean(log_loss) + 0.5 * (1/(C*n)) * ||w||²
380                //      Our `alpha` = 1/C, so we scale the penalty by inv_n as well,
381                //      giving: mean(log_loss) + 0.5 * alpha * inv_n * ||w||²
382                //      The effective regularization therefore weakens as the dataset
383                //      grows, matching sklearn's behavior.
384                loss *= inv_n;
385                for g in grad.iter_mut() {
386                    *g *= inv_n;
387                }
388
389                if alpha > 0.0 {
390                    let reg_scale = alpha * inv_n;
391                    for c in 0..k {
392                        let base = c * dim;
393                        for j in 1..dim {
394                            let w = x[base + j];
395                            loss += 0.5 * reg_scale * w * w;
396                            grad[base + j] += reg_scale * w;
397                        }
398                    }
399                }
400
401                loss
402            },
403            &config,
404        );
405
406        // Unflatten back to [n_classes][dim].
407        self.weights = (0..k)
408            .map(|c| params[c * dim..(c + 1) * dim].to_vec())
409            .collect();
410
411        self.fitted = true;
412        Ok(())
413    }
414
415    /// Binary L-BFGS fast path: single weight vector with sigmoid.
416    ///
417    /// For 2-class problems, uses sigmoid(z) instead of softmax over 2 classes,
418    /// halving the parameter count and gradient work.
419    #[allow(clippy::needless_range_loop)]
420    fn fit_lbfgs_binary(&mut self, data: &Dataset) -> Result<()> {
421        let n = data.n_samples();
422        let m = data.n_features();
423        let dim = m + 1; // features + bias
424
425        let uniform = matches!(self.class_weight, ClassWeight::Uniform);
426        let sample_weights = if uniform {
427            Vec::new()
428        } else {
429            compute_sample_weights(&data.target, &self.class_weight)
430        };
431        let target_bin: Vec<f64> = data
432            .target
433            .iter()
434            .map(|&t| if t as usize == 1 { 1.0 } else { 0.0 })
435            .collect();
436
437        let alpha = self.alpha;
438        let inv_n = 1.0 / n as f64;
439
440        let mut params = vec![0.0; dim];
441
442        let config = lbfgs::LbfgsConfig {
443            max_iter: self.max_iter,
444            tolerance: self.tolerance,
445            history_size: 10,
446            wolfe: false,
447        };
448
449        // Pre-allocate buffers reused every closure call.
450        let mut prob = vec![0.0; n];
451        let use_par = n * m >= crate::constants::LOGREG_PAR_THRESHOLD;
452
453        lbfgs::minimize(
454            &mut params,
455            |x, grad| {
456                // ── 1. Compute z_i = bias + Σ_j w_j * X_{j,i}, then sigmoid.
457                for i in 0..n {
458                    prob[i] = x[0]; // bias
459                }
460                for j in 0..m {
461                    let w = x[j + 1];
462                    let col = &data.features[j];
463                    for i in 0..n {
464                        prob[i] += w * col[i];
465                    }
466                }
467
468                // Sigmoid + loss.
469                let mut loss = 0.0;
470                for i in 0..n {
471                    let z = prob[i];
472                    // Numerically stable sigmoid and log-loss.
473                    let p = if z >= 0.0 {
474                        1.0 / (1.0 + (-z).exp())
475                    } else {
476                        let ez = z.exp();
477                        ez / (1.0 + ez)
478                    };
479                    prob[i] = p;
480
481                    // Binary cross-entropy: -[y*log(p) + (1-y)*log(1-p)]
482                    let y = target_bin[i];
483                    let log_loss = if z >= 0.0 {
484                        (1.0 - y) * z + (-z).exp().ln_1p()
485                    } else {
486                        -y * z + z.exp().ln_1p()
487                    };
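                    // Both branches equal max(z, 0) - y*z + ln(1 + exp(-|z|)), so the
                    // argument of exp() is always <= 0 and the loss never overflows.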
488                    let sw = if uniform { 1.0 } else { sample_weights[i] };
489                    loss += sw * log_loss;
490                }
491
492                // ── 2. Gradient: (1/n) * Σ sw_i * (p_i - y_i) * x_i
493                for g in grad.iter_mut() {
494                    *g = 0.0;
495                }
496
497                // Bias gradient.
498                let mut bias_grad = 0.0;
499                for i in 0..n {
500                    let sw = if uniform { 1.0 } else { sample_weights[i] };
501                    let err = sw * (prob[i] - target_bin[i]);
502                    prob[i] = err; // reuse buffer for weighted errors
503                    bias_grad += err;
504                }
505                grad[0] = bias_grad;
506
507                // Feature gradients column-by-column.
508                let errors: &[f64] = &prob;
509                if use_par {
510                    data.features
511                        .par_iter()
512                        .zip(grad[1..=m].par_iter_mut())
513                        .for_each(|(col, g)| {
514                            let mut acc = 0.0;
515                            for i in 0..n {
516                                acc += errors[i] * col[i];
517                            }
518                            *g = acc;
519                        });
520                } else {
521                    for j in 0..m {
522                        let col = &data.features[j];
523                        let mut acc = 0.0;
524                        for i in 0..n {
525                            acc += errors[i] * col[i];
526                        }
527                        grad[j + 1] = acc;
528                    }
529                }
530
531                // ── 3. Average + L2 regularization.
532                loss *= inv_n;
533                for g in grad.iter_mut() {
534                    *g *= inv_n;
535                }
536
537                if alpha > 0.0 {
538                    let reg_scale = alpha * inv_n;
539                    for j in 1..dim {
540                        let w = x[j];
541                        loss += 0.5 * reg_scale * w * w;
542                        grad[j] += reg_scale * w;
543                    }
544                }
545
546                loss
547            },
548            &config,
549        );
550
551        // Unflatten: class 0 = zero weights (reference), class 1 = learned weights.
552        // softmax([0, z]) produces the same probabilities as sigmoid(z).
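        // (softmax([0, z])[1] = exp(z) / (1 + exp(z)) = sigmoid(z), so predict_proba's
        //  softmax over both rows reproduces the binary probabilities exactly.)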
553        self.weights = vec![vec![0.0; dim], params];
554        self.fitted = true;
555        Ok(())
556    }
557
558    /// Gradient descent solver (legacy).
559    #[allow(clippy::needless_range_loop)]
560    fn fit_gd(&mut self, data: &Dataset) -> Result<()> {
561        let n = data.n_samples();
562        let m = data.n_features();
563        if n == 0 {
564            return Err(ScryLearnError::EmptyDataset);
565        }
566
567        self.n_classes = data.n_classes();
568        let dim = m + 1; // features + bias
569
570        // Initialize weights to zero.
571        self.weights = vec![vec![0.0; dim]; self.n_classes];
572
573        // Compute per-sample weights for class imbalance.
574        let sample_weights = compute_sample_weights(&data.target, &self.class_weight);
575
576        // softmax gradient descent.
577        let mut probs = vec![0.0; self.n_classes];
578
579        for _epoch in 0..self.max_iter {
580            let mut max_grad = 0.0_f64;
581            let mut gradient = vec![vec![0.0; dim]; self.n_classes];
582
583            for (i, (&sw, &target_val)) in sample_weights.iter().zip(data.target.iter()).enumerate()
584            {
585                let target_class = target_val as usize;
586
587                // Compute logits for all classes.
588                for (c, prob) in probs.iter_mut().enumerate().take(self.n_classes) {
589                    let mut z = self.weights[c][0]; // bias
590                    for j in 0..m {
591                        z += self.weights[c][j + 1] * data.features[j][i];
592                    }
593                    *prob = z;
594                }
595
596                // Softmax.
597                let max_s = probs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
598                let mut sum = 0.0;
599                for p in &mut probs[..self.n_classes] {
600                    *p = (*p - max_s).exp();
601                    sum += *p;
602                }
603                for p in &mut probs[..self.n_classes] {
604                    *p /= sum;
605                }
606
607                // Gradient: sample_weight * (softmax_prob - one_hot_target).
608                for (c, (&pc, gc)) in probs
609                    .iter()
610                    .zip(gradient.iter_mut())
611                    .enumerate()
612                    .take(self.n_classes)
613                {
614                    let y_i = if target_class == c { 1.0 } else { 0.0 };
615                    let error = sw * (pc - y_i);
616
617                    gc[0] += error; // bias
618                    for j in 0..m {
619                        gc[j + 1] += error * data.features[j][i];
620                    }
621                }
622            }
623
624            // Split `alpha` into L1 and L2 components according to the penalty.
625            let (l1_ratio, l2_ratio) = match &self.penalty {
626                Penalty::None => (0.0, 0.0),
627                Penalty::L1 => (1.0, 0.0),
628                Penalty::L2 => (0.0, 1.0),
629                Penalty::ElasticNet(r) => (*r, 1.0 - *r),
630            };
631
632            let inv_n = 1.0 / n as f64;
633
634            // Update weights: gradient step + L2 regularization in gradient.
635            for (c_grad, c_w) in gradient
636                .iter_mut()
637                .zip(self.weights.iter_mut())
638                .take(self.n_classes)
639            {
640                for (j, (g, w)) in c_grad.iter_mut().zip(c_w.iter_mut()).enumerate().take(dim) {
641                    *g *= inv_n;
642                    if j > 0 {
643                        // L2 component goes into the gradient (scaled by inv_n like sklearn).
644                        *g += self.alpha * inv_n * l2_ratio * *w;
645                    }
646                    max_grad = max_grad.max(g.abs());
647                    *w -= self.learning_rate * *g;
648                }
649            }
650
651            // Proximal step for L1 component (soft-thresholding).
652            // Applied after the gradient update, only to feature weights (skip bias j=0).
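            // Soft-threshold: w <- sign(w) * max(|w| - threshold, 0) with
            // threshold = learning_rate * alpha * inv_n * l1_ratio, so weights whose
            // magnitude falls below the threshold are snapped exactly to zero.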
653            if l1_ratio > 0.0 {
654                let threshold = self.learning_rate * self.alpha * inv_n * l1_ratio;
655                for c_w in self.weights.iter_mut().take(self.n_classes) {
656                    for w in c_w.iter_mut().skip(1) {
657                        let sign = w.signum();
658                        *w = sign * (*w * sign - threshold).max(0.0);
659                    }
660                }
661            }
662
663            if max_grad < self.tolerance {
664                break;
665            }
666        }
667
668        self.fitted = true;
669        Ok(())
670    }
671
672    /// Predict class labels.
673    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
674        crate::version::check_schema_version(self._schema_version)?;
675        if !self.fitted {
676            return Err(ScryLearnError::NotFitted);
677        }
678        let probas = self.predict_proba(features)?;
679        Ok(probas
680            .iter()
681            .map(|probs| {
682                probs
683                    .iter()
684                    .enumerate()
685                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
686                    .map_or(0.0, |(idx, _)| idx as f64)
687            })
688            .collect())
689    }
690
691    /// Predict class probabilities.
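    ///
    /// Each inner vector holds one probability per class and sums to 1. A sketch
    /// (mirroring `test_predict_proba_sums_to_one` below):
    ///
    /// ```ignore
    /// let probas = model.predict_proba(&[vec![2.0]])?;
    /// assert!((probas[0].iter().sum::<f64>() - 1.0).abs() < 1e-6);
    /// ```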
692    pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
693        if !self.fitted {
694            return Err(ScryLearnError::NotFitted);
695        }
696
697        Ok(features
698            .iter()
699            .map(|row| {
700                let mut scores: Vec<f64> = self
701                    .weights
702                    .iter()
703                    .map(|w| {
704                        let mut z = w[0]; // bias
705                        for (j, &x) in row.iter().enumerate() {
706                            if j + 1 < w.len() {
707                                z += w[j + 1] * x;
708                            }
709                        }
710                        z
711                    })
712                    .collect();
713
714                // Softmax.
715                let max_s = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
716                let mut sum = 0.0;
717                for s in &mut scores {
718                    *s = (*s - max_s).exp();
719                    sum += *s;
720                }
721                for s in &mut scores {
722                    *s /= sum;
723                }
724                scores
725            })
726            .collect())
727    }
728
729    /// Fit on sparse features using gradient descent.
730    ///
731    /// Accepts `CscMatrix` (column-oriented) for efficient gradient computation.
732    /// Applies only L2 regularization of strength `alpha` (the [`Penalty`] setting is ignored here); uses gradient descent, not L-BFGS.
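    ///
    /// A rough sketch (the `from_dense` conversions are the ones used in the
    /// sparse round-trip test below):
    ///
    /// ```ignore
    /// let csc = CscMatrix::from_dense(&features); // column-oriented input
    /// model.fit_sparse(&csc, &target)?;
    /// let preds = model.predict_sparse(&CsrMatrix::from_dense(&rows))?;
    /// ```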
733    #[allow(clippy::needless_range_loop)]
734    pub fn fit_sparse(&mut self, features: &CscMatrix, target: &[f64]) -> Result<()> {
735        let n = features.n_rows();
736        let m = features.n_cols();
737        if n == 0 {
738            return Err(ScryLearnError::EmptyDataset);
739        }
740        if target.len() != n {
741            return Err(ScryLearnError::InvalidParameter(format!(
742                "target length {} != n_rows {}",
743                target.len(),
744                n
745            )));
746        }
747
748        // Determine n_classes from target.
749        let max_class = target.iter().map(|&t| t as usize).max().unwrap_or(0);
750        self.n_classes = max_class + 1;
751        if self.n_classes < 2 {
752            return Err(ScryLearnError::InvalidParameter(
753                "LogisticRegression requires at least 2 distinct classes.".into(),
754            ));
755        }
756
757        let k = self.n_classes;
758        let dim = m + 1;
759        let sample_weights = compute_sample_weights(target, &self.class_weight);
760        let target_class: Vec<usize> = target.iter().map(|&t| t as usize).collect();
761
762        self.weights = vec![vec![0.0; dim]; k];
763
764        let mut probs = vec![0.0; k];
765        let inv_n = 1.0 / n as f64;
766
767        for _epoch in 0..self.max_iter {
768            let mut max_grad = 0.0_f64;
769            let mut gradient = vec![vec![0.0; dim]; k];
770
771            for i in 0..n {
772                let tc = target_class[i];
773                let sw = sample_weights[i];
774
775                // Compute logits: bias + sparse dot.
776                for c in 0..k {
777                    probs[c] = self.weights[c][0]; // bias
778                }
779                // Accumulate feature contributions for sample i. CSC is column-oriented,
780                // so per-sample row access goes through `features.get(i, j)`, an
781                // O(log nnz_col) lookup; simple, though a CSR conversion would be faster.
786                for j in 0..m {
787                    let xij = features.get(i, j);
788                    if xij != 0.0 {
789                        for c in 0..k {
790                            probs[c] += self.weights[c][j + 1] * xij;
791                        }
792                    }
793                }
794
795                // Softmax.
796                let max_s = probs[..k].iter().copied().fold(f64::NEG_INFINITY, f64::max);
797                let mut sum = 0.0;
798                for p in &mut probs[..k] {
799                    *p = (*p - max_s).exp();
800                    sum += *p;
801                }
802                for p in &mut probs[..k] {
803                    *p /= sum;
804                }
805
806                // Gradient.
807                for c in 0..k {
808                    let y_i = if tc == c { 1.0 } else { 0.0 };
809                    let error = sw * (probs[c] - y_i);
810                    gradient[c][0] += error; // bias
811                    for j in 0..m {
812                        let xij = features.get(i, j);
813                        if xij != 0.0 {
814                            gradient[c][j + 1] += error * xij;
815                        }
816                    }
817                }
818            }
819
820            // Update weights.
821            for c in 0..k {
822                for j in 0..dim {
823                    gradient[c][j] *= inv_n;
824                    if j > 0 && self.alpha > 0.0 {
825                        gradient[c][j] += self.alpha * inv_n * self.weights[c][j];
826                    }
827                    max_grad = max_grad.max(gradient[c][j].abs());
828                    self.weights[c][j] -= self.learning_rate * gradient[c][j];
829                }
830            }
831
832            if max_grad < self.tolerance {
833                break;
834            }
835        }
836
837        self.fitted = true;
838        Ok(())
839    }
840
841    /// Predict class labels from sparse features (CSR format).
842    pub fn predict_sparse(&self, features: &CsrMatrix) -> Result<Vec<f64>> {
843        if !self.fitted {
844            return Err(ScryLearnError::NotFitted);
845        }
846        let probas = self.predict_proba_sparse(features)?;
847        Ok(probas
848            .iter()
849            .map(|probs| {
850                probs
851                    .iter()
852                    .enumerate()
853                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
854                    .map_or(0.0, |(idx, _)| idx as f64)
855            })
856            .collect())
857    }
858
859    /// Predict class probabilities from sparse features (CSR format).
860    pub fn predict_proba_sparse(&self, features: &CsrMatrix) -> Result<Vec<Vec<f64>>> {
861        if !self.fitted {
862            return Err(ScryLearnError::NotFitted);
863        }
864        Ok((0..features.n_rows())
865            .map(|i| {
866                let row = features.row(i);
867                let mut scores: Vec<f64> = self
868                    .weights
869                    .iter()
870                    .map(|w| {
871                        let mut z = w[0]; // bias
872                        for (col, val) in row.iter() {
873                            if col + 1 < w.len() {
874                                z += w[col + 1] * val;
875                            }
876                        }
877                        z
878                    })
879                    .collect();
880
881                // Softmax.
882                let max_s = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
883                let mut sum = 0.0;
884                for s in &mut scores {
885                    *s = (*s - max_s).exp();
886                    sum += *s;
887                }
888                for s in &mut scores {
889                    *s /= sum;
890                }
891                scores
892            })
893            .collect())
894    }
895
896    /// Get learned weights (coefficients + bias) for each class.
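    ///
    /// Layout: `weights()[c][0]` is the bias for class `c` and `weights()[c][j + 1]`
    /// the coefficient for feature `j` (the L1 test below indexes it this way).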
897    pub fn weights(&self) -> &[Vec<f64>] {
898        &self.weights
899    }
900}
901
902impl Default for LogisticRegression {
903    fn default() -> Self {
904        Self::new()
905    }
906}
907
908impl PartialFit for LogisticRegression {
909    /// Run one pass of gradient descent on the given batch.
910    ///
911    /// On the first call, initializes weights from the data dimensions and
912    /// class count. Subsequent calls preserve weights and continue updating.
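    ///
    /// A streaming sketch (the `batches` iterator of `Dataset` values is assumed;
    /// batch construction mirrors the tests below):
    ///
    /// ```ignore
    /// let mut model = LogisticRegression::new()
    ///     .solver(Solver::GradientDescent)
    ///     .learning_rate(0.1);
    /// for batch in batches {
    ///     model.partial_fit(&batch)?;
    /// }
    /// ```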
913    #[allow(clippy::needless_range_loop)]
914    fn partial_fit(&mut self, data: &Dataset) -> Result<()> {
915        let n = data.n_samples();
916        let m = data.n_features();
917        if n == 0 {
918            if self.is_initialized() {
919                return Ok(());
920            }
921            return Err(ScryLearnError::EmptyDataset);
922        }
923
924        if !self.is_initialized() {
925            if data.n_classes() < 2 {
926                return Err(ScryLearnError::InvalidParameter(
927                    "LogisticRegression requires at least 2 distinct classes.".into(),
928                ));
929            }
930            self.n_classes = data.n_classes();
931            let dim = m + 1;
932            self.weights = vec![vec![0.0; dim]; self.n_classes];
933        }
934
935        let dim = m + 1;
936        let sample_weights = compute_sample_weights(&data.target, &self.class_weight);
937
938        // Pre-scan for new classes and grow weights if needed.
939        let max_class = data.target.iter().map(|&t| t as usize).max().unwrap_or(0);
940        if max_class >= self.n_classes {
941            let new_n = max_class + 1;
942            self.weights.resize(new_n, vec![0.0; dim]);
943            self.n_classes = new_n;
944        }
945
946        let mut probs = vec![0.0; self.n_classes];
947        let mut gradient = vec![vec![0.0; dim]; self.n_classes];
948
949        for (i, (&sw, &target_val)) in sample_weights.iter().zip(data.target.iter()).enumerate() {
950            let target_class = target_val as usize;
951
952            // Compute logits for all classes.
953            for (c, prob) in probs.iter_mut().enumerate().take(self.n_classes) {
954                let mut z = self.weights[c][0]; // bias
955                for j in 0..m {
956                    z += self.weights[c][j + 1] * data.features[j][i];
957                }
958                *prob = z;
959            }
960
961            // Softmax.
962            let max_s = probs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
963            let mut sum = 0.0;
964            for p in &mut probs[..self.n_classes] {
965                *p = (*p - max_s).exp();
966                sum += *p;
967            }
968            for p in &mut probs[..self.n_classes] {
969                *p /= sum;
970            }
971
972            // Accumulate gradient.
973            for (c, (&pc, gc)) in probs
974                .iter()
975                .zip(gradient.iter_mut())
976                .enumerate()
977                .take(self.n_classes)
978            {
979                let y_i = if target_class == c { 1.0 } else { 0.0 };
980                let error = sw * (pc - y_i);
981                gc[0] += error;
982                for j in 0..m {
983                    gc[j + 1] += error * data.features[j][i];
984                }
985            }
986        }
987
988        // Penalty ratios.
989        let (l1_ratio, l2_ratio) = match &self.penalty {
990            Penalty::None => (0.0, 0.0),
991            Penalty::L1 => (1.0, 0.0),
992            Penalty::L2 => (0.0, 1.0),
993            Penalty::ElasticNet(r) => (*r, 1.0 - *r),
994        };
995
996        let inv_n = 1.0 / n as f64;
997
998        // Update weights with L2 gradient and learning rate.
999        for (c_grad, c_w) in gradient
1000            .iter_mut()
1001            .zip(self.weights.iter_mut())
1002            .take(self.n_classes)
1003        {
1004            for (j, (g, w)) in c_grad.iter_mut().zip(c_w.iter_mut()).enumerate().take(dim) {
1005                *g *= inv_n;
1006                if j > 0 {
1007                    *g += self.alpha * inv_n * l2_ratio * *w;
1008                }
1009                *w -= self.learning_rate * *g;
1010            }
1011        }
1012
1013        // Proximal step for L1 component (soft-thresholding).
1014        if l1_ratio > 0.0 {
1015            let threshold = self.learning_rate * self.alpha * inv_n * l1_ratio;
1016            for c_w in self.weights.iter_mut().take(self.n_classes) {
1017                for w in c_w.iter_mut().skip(1) {
1018                    let sign = w.signum();
1019                    *w = sign * (*w * sign - threshold).max(0.0);
1020                }
1021            }
1022        }
1023
1024        self.fitted = true;
1025        Ok(())
1026    }
1027
1028    fn is_initialized(&self) -> bool {
1029        !self.weights.is_empty()
1030    }
1031}
1032
1033#[cfg(test)]
1034mod tests {
1035    use super::*;
1036
1037    #[test]
1038    fn test_logistic_linearly_separable() {
1039        // Class 0: x < 10, Class 1: x >= 10
1040        let features = vec![(0..20).map(|i| i as f64).collect()];
1041        let target: Vec<f64> = (0..20).map(|i| if i < 10 { 0.0 } else { 1.0 }).collect();
1042        let data = Dataset::new(features, target, vec!["x".into()], "class");
1043
1044        let mut lr = LogisticRegression::new().alpha(0.0).max_iter(200);
1045        lr.fit(&data).unwrap();
1046
1047        let matrix = data.feature_matrix();
1048        let preds = lr.predict(&matrix).unwrap();
1049        let acc = preds
1050            .iter()
1051            .zip(data.target.iter())
1052            .filter(|(p, t)| (*p - *t).abs() < 1e-6)
1053            .count() as f64
1054            / data.n_samples() as f64;
1055
1056        assert!(
1057            acc >= 0.85,
1058            "expected ≥85% accuracy, got {:.1}%",
1059            acc * 100.0
1060        );
1061    }
1062
1063    #[test]
1064    fn test_predict_proba_sums_to_one() {
1065        let features = vec![vec![1.0, 2.0, 3.0]];
1066        let target = vec![0.0, 1.0, 0.0];
1067        let data = Dataset::new(features, target, vec!["x".into()], "class");
1068
1069        let mut lr = LogisticRegression::new().max_iter(100);
1070        lr.fit(&data).unwrap();
1071
1072        let probas = lr.predict_proba(&[vec![2.0]]).unwrap();
1073        let sum: f64 = probas[0].iter().sum();
1074        assert!((sum - 1.0).abs() < 1e-6);
1075    }
1076
1077    #[test]
1078    fn test_gd_solver_still_works() {
1079        let features = vec![(0..20).map(|i| i as f64).collect()];
1080        let target: Vec<f64> = (0..20).map(|i| if i < 10 { 0.0 } else { 1.0 }).collect();
1081        let data = Dataset::new(features, target, vec!["x".into()], "class");
1082
1083        let mut lr = LogisticRegression::new()
1084            .solver(Solver::GradientDescent)
1085            .learning_rate(0.1)
1086            .max_iter(1000);
1087        lr.fit(&data).unwrap();
1088
1089        let matrix = data.feature_matrix();
1090        let preds = lr.predict(&matrix).unwrap();
1091        let acc = preds
1092            .iter()
1093            .zip(data.target.iter())
1094            .filter(|(p, t)| (*p - *t).abs() < 1e-6)
1095            .count() as f64
1096            / data.n_samples() as f64;
1097
1098        assert!(
1099            acc >= 0.85,
1100            "GD solver: expected ≥85% accuracy, got {:.1}%",
1101            acc * 100.0
1102        );
1103    }
1104
1105    #[test]
1106    fn test_lbfgs_is_default() {
1107        let lr = LogisticRegression::new();
1108        assert_eq!(lr.solver, Solver::Lbfgs);
1109    }
1110
1111    #[test]
1112    fn test_l1_sparsity() {
1113        // 4 features: x0 = signal, x1 = signal, x2 = noise, x3 = noise.
1114        // Signal features have clear class separation; noise features are random.
1115        let n = 200;
1116        let mut f0 = Vec::with_capacity(n);
1117        let mut f1 = Vec::with_capacity(n);
1118        let mut f2 = Vec::with_capacity(n);
1119        let mut f3 = Vec::with_capacity(n);
1120        let mut target = Vec::with_capacity(n);
1121
1122        for i in 0..n {
1123            // Strong signal: class 0 centered at -3, class 1 centered at +3
1124            let class = i32::from(i >= n / 2);
1125            let offset = if class == 0 { -3.0 } else { 3.0 };
1126            f0.push(offset + (i % 7) as f64 * 0.1);
1127            f1.push(offset * 0.5 + (i % 5) as f64 * 0.05);
1128            // Noise: same distribution regardless of class
1129            f2.push((i % 3) as f64 * 0.01);
1130            f3.push((i % 5) as f64 * 0.01);
1131            target.push(class as f64);
1132        }
1133
1134        let data = Dataset::new(
1135            vec![f0, f1, f2, f3],
1136            target,
1137            vec![
1138                "sig0".into(),
1139                "sig1".into(),
1140                "noise0".into(),
1141                "noise1".into(),
1142            ],
1143            "class",
1144        );
1145
1146        let mut lr = LogisticRegression::new()
1147            .solver(Solver::GradientDescent)
1148            .penalty(Penalty::L1)
1149            .alpha(0.1)
1150            .learning_rate(0.1)
1151            .max_iter(3000);
1152        lr.fit(&data).unwrap();
1153
1154        // Noise coefficients should be driven toward zero.
1155        let w = &lr.weights()[0];
1156        let noise_mag = w[3].abs() + w[4].abs(); // indices 3,4 = features 2,3 (skip bias)
1157        let signal_mag = w[1].abs() + w[2].abs();
1158        assert!(
1159            signal_mag > 0.01,
1160            "L1: signal coefficients should be nonzero, got {signal_mag:.6}"
1161        );
1162        assert!(
1163            noise_mag < signal_mag * 0.3,
1164            "L1: noise coefficients ({noise_mag:.4}) should be much smaller than signal ({signal_mag:.4})"
1165        );
1166    }
1167
1168    #[test]
1169    fn test_l2_no_sparsity() {
1170        let n = 200;
1171        let mut f0 = Vec::with_capacity(n);
1172        let mut f1 = Vec::with_capacity(n);
1173        let mut f2 = Vec::with_capacity(n);
1174        let mut f3 = Vec::with_capacity(n);
1175        let mut target = Vec::with_capacity(n);
1176
1177        for i in 0..n {
1178            let x = i as f64 / n as f64;
1179            f0.push(x);
1180            f1.push(x * 2.0);
1181            f2.push(0.5 + (i % 3) as f64 * 0.01);
1182            f3.push(0.5 - (i % 5) as f64 * 0.01);
1183            target.push(if x < 0.5 { 0.0 } else { 1.0 });
1184        }
1185
1186        let data = Dataset::new(
1187            vec![f0, f1, f2, f3],
1188            target,
1189            vec![
1190                "sig0".into(),
1191                "sig1".into(),
1192                "noise0".into(),
1193                "noise1".into(),
1194            ],
1195            "class",
1196        );
1197
1198        let mut lr = LogisticRegression::new()
1199            .solver(Solver::GradientDescent)
1200            .penalty(Penalty::L2)
1201            .alpha(0.01)
1202            .learning_rate(0.5)
1203            .max_iter(2000);
1204        lr.fit(&data).unwrap();
1205
1206        // L2 should keep ALL coefficients nonzero (no sparsity).
1207        let w = &lr.weights()[0];
1208        for (j, &wj) in w.iter().enumerate().skip(1) {
1209            assert!(
1210                wj.abs() > 1e-6,
1211                "L2: coefficient w[{j}] = {wj:.6} should be nonzero"
1212            );
1213        }
1214    }
1215
1216    #[test]
1217    fn test_elasticnet_middle_ground() {
1218        let n = 200;
1219        let mut f0 = Vec::with_capacity(n);
1220        let mut f1 = Vec::with_capacity(n);
1221        let mut f2 = Vec::with_capacity(n);
1222        let mut f3 = Vec::with_capacity(n);
1223        let mut target = Vec::with_capacity(n);
1224
1225        for i in 0..n {
1226            let class = i32::from(i >= n / 2);
1227            let offset = if class == 0 { -3.0 } else { 3.0 };
1228            f0.push(offset + (i % 7) as f64 * 0.1);
1229            f1.push(offset * 0.5 + (i % 5) as f64 * 0.05);
1230            f2.push((i % 3) as f64 * 0.01);
1231            f3.push((i % 5) as f64 * 0.01);
1232            target.push(class as f64);
1233        }
1234
1235        let data = Dataset::new(
1236            vec![f0, f1, f2, f3],
1237            target,
1238            vec![
1239                "sig0".into(),
1240                "sig1".into(),
1241                "noise0".into(),
1242                "noise1".into(),
1243            ],
1244            "class",
1245        );
1246
1247        let mut lr = LogisticRegression::new()
1248            .solver(Solver::GradientDescent)
1249            .penalty(Penalty::ElasticNet(0.5))
1250            .alpha(0.1)
1251            .learning_rate(0.1)
1252            .max_iter(3000);
1253        lr.fit(&data).unwrap();
1254
1255        // ElasticNet: signal coefficients should remain present.
1256        let w = &lr.weights()[0];
1257        let signal_mag = w[1].abs() + w[2].abs();
1258        assert!(
1259            signal_mag > 0.01,
1260            "ElasticNet: signal coefficients should remain nonzero, got {signal_mag:.6}"
1261        );
1262    }
1263
1264    #[test]
1265    fn test_lbfgs_rejects_l1() {
1266        let features = vec![vec![1.0, 2.0, 3.0]];
1267        let target = vec![0.0, 1.0, 0.0];
1268        let data = Dataset::new(features, target, vec!["x".into()], "class");
1269
1270        let mut lr = LogisticRegression::new()
1271            .solver(Solver::Lbfgs)
1272            .penalty(Penalty::L1)
1273            .alpha(0.1);
1274        let result = lr.fit(&data);
1275        assert!(result.is_err(), "L-BFGS should reject L1 penalty");
1276
1277        // Also reject ElasticNet.
1278        let mut lr2 = LogisticRegression::new()
1279            .solver(Solver::Lbfgs)
1280            .penalty(Penalty::ElasticNet(0.5))
1281            .alpha(0.1);
1282        let result2 = lr2.fit(&data);
1283        assert!(result2.is_err(), "L-BFGS should reject ElasticNet penalty");
1284    }
1285
1286    #[test]
1287    fn test_partial_fit_is_initialized() {
1288        let mut lr = LogisticRegression::new()
1289            .solver(Solver::GradientDescent)
1290            .learning_rate(0.1);
1291        assert!(!lr.is_initialized());
1292
1293        let features = vec![(0..20).map(|i| i as f64).collect()];
1294        let target: Vec<f64> = (0..20).map(|i| if i < 10 { 0.0 } else { 1.0 }).collect();
1295        let data = Dataset::new(features, target, vec!["x".into()], "class");
1296        lr.partial_fit(&data).unwrap();
1297        assert!(lr.is_initialized());
1298    }
1299
1300    #[test]
1301    fn test_partial_fit_convergence_10_batches() {
1302        // Linearly separable: class 0 = low x, class 1 = high x.
1303        // 10 batches of 100 samples each.
1304        let mut lr = LogisticRegression::new()
1305            .solver(Solver::GradientDescent)
1306            .learning_rate(0.1)
1307            .alpha(0.0);
1308
1309        let mut rng = fastrand::Rng::with_seed(42);
1310        for _ in 0..10 {
1311            let mut feats = Vec::with_capacity(100);
1312            let mut tgt = Vec::with_capacity(100);
1313            for _ in 0..50 {
1314                feats.push(rng.f64() * 3.0); // class 0: [0, 3)
1315                tgt.push(0.0);
1316            }
1317            for _ in 0..50 {
1318                feats.push(7.0 + rng.f64() * 3.0); // class 1: [7, 10)
1319                tgt.push(1.0);
1320            }
1321            let batch = Dataset::new(vec![feats], tgt, vec!["x".into()], "class");
1322            lr.partial_fit(&batch).unwrap();
1323        }
1324
1325        // Test on held-out points.
1326        let preds = lr.predict(&[vec![1.0], vec![9.0]]).unwrap();
1327        assert!(
1328            (preds[0] - 0.0).abs() < f64::EPSILON,
1329            "expected class 0 for x=1"
1330        );
1331        assert!(
1332            (preds[1] - 1.0).abs() < f64::EPSILON,
1333            "expected class 1 for x=9"
1334        );
1335    }
1336
1337    #[test]
1338    fn test_partial_fit_single_batch_approximates_fit() {
1339        // Normalized features to avoid large gradient magnitudes.
1340        let features = vec![(0..40).map(|i| i as f64 / 40.0).collect()];
1341        let target: Vec<f64> = (0..40).map(|i| if i < 20 { 0.0 } else { 1.0 }).collect();
1342        let data = Dataset::new(features, target, vec!["x".into()], "class");
1343
1344        // partial_fit many passes on same data
1345        let mut lr_partial = LogisticRegression::new()
1346            .solver(Solver::GradientDescent)
1347            .learning_rate(1.0)
1348            .alpha(0.0);
1349        for _ in 0..500 {
1350            lr_partial.partial_fit(&data).unwrap();
1351        }
1352
1353        // Full fit with same settings
1354        let mut lr_full = LogisticRegression::new()
1355            .solver(Solver::GradientDescent)
1356            .learning_rate(1.0)
1357            .alpha(0.0)
1358            .max_iter(500);
1359        lr_full.fit(&data).unwrap();
1360
1361        // Both should classify correctly
1362        let matrix = data.feature_matrix();
1363        let preds_partial = lr_partial.predict(&matrix).unwrap();
1364        let preds_full = lr_full.predict(&matrix).unwrap();
1365
1366        let acc_partial = preds_partial
1367            .iter()
1368            .zip(data.target.iter())
1369            .filter(|(p, t)| (*p - *t).abs() < 1e-6)
1370            .count() as f64
1371            / 40.0;
1372        let acc_full = preds_full
1373            .iter()
1374            .zip(data.target.iter())
1375            .filter(|(p, t)| (*p - *t).abs() < 1e-6)
1376            .count() as f64
1377            / 40.0;
1378
1379        assert!(
1380            acc_partial >= 0.85,
1381            "partial_fit accuracy {:.1}% too low",
1382            acc_partial * 100.0
1383        );
1384        assert!(
1385            acc_full >= 0.85,
1386            "full fit accuracy {:.1}% too low",
1387            acc_full * 100.0
1388        );
1389    }
1390
1391    #[test]
1392    fn test_sparse_fit_predict_matches_dense() {
1393        let features = vec![(0..20).map(|i| i as f64).collect()];
1394        let target: Vec<f64> = (0..20).map(|i| if i < 10 { 0.0 } else { 1.0 }).collect();
1395        let data = Dataset::new(features.clone(), target.clone(), vec!["x".into()], "class");
1396
1397        let mut lr_dense = LogisticRegression::new()
1398            .solver(Solver::GradientDescent)
1399            .alpha(0.0)
1400            .learning_rate(0.1)
1401            .max_iter(500);
1402        lr_dense.fit(&data).unwrap();
1403
1404        let csc = CscMatrix::from_dense(&features);
1405        let mut lr_sparse = LogisticRegression::new()
1406            .alpha(0.0)
1407            .learning_rate(0.1)
1408            .max_iter(500);
1409        lr_sparse.fit_sparse(&csc, &target).unwrap();
1410
1411        let matrix = data.feature_matrix();
1412        let preds_dense = lr_dense.predict(&matrix).unwrap();
1413        let csr = CsrMatrix::from_dense(&matrix);
1414        let preds_sparse = lr_sparse.predict_sparse(&csr).unwrap();
1415
1416        let acc_dense: usize = preds_dense
1417            .iter()
1418            .zip(target.iter())
1419            .filter(|(p, t)| (*p - *t).abs() < 1e-6)
1420            .count();
1421        let acc_sparse: usize = preds_sparse
1422            .iter()
1423            .zip(target.iter())
1424            .filter(|(p, t)| (*p - *t).abs() < 1e-6)
1425            .count();
1426
1427        assert!(acc_dense >= 17, "Dense accuracy too low: {acc_dense}/20");
1428        assert!(acc_sparse >= 17, "Sparse accuracy too low: {acc_sparse}/20");
1429    }
1430
1431    #[test]
1432    fn test_binary_sigmoid_matches_predictions() {
1433        // Verify binary sigmoid fast path produces correct classifications.
1434        let features = vec![(0..40).map(|i| i as f64).collect()];
1435        let target: Vec<f64> = (0..40).map(|i| if i < 20 { 0.0 } else { 1.0 }).collect();
1436        let data = Dataset::new(features, target.clone(), vec!["x".into()], "class");
1437
1438        let mut lr = LogisticRegression::new().alpha(0.01).max_iter(200);
1439        lr.fit(&data).unwrap();
1440
1441        // Verify weights structure: class 0 should be all zeros (reference class).
1442        assert_eq!(lr.weights().len(), 2, "should have 2 weight vectors");
1443        assert!(
1444            lr.weights()[0].iter().all(|&w| w == 0.0),
1445            "class 0 weights should all be zero (reference class)"
1446        );
1447        assert!(
1448            lr.weights()[1].iter().any(|&w| w != 0.0),
1449            "class 1 weights should be non-zero"
1450        );
1451
1452        // Verify predictions are correct.
1453        let matrix = data.feature_matrix();
1454        let preds = lr.predict(&matrix).unwrap();
1455        let acc = preds
1456            .iter()
1457            .zip(target.iter())
1458            .filter(|(p, t)| (*p - *t).abs() < 1e-6)
1459            .count() as f64
1460            / 40.0;
1461        assert!(
1462            acc >= 0.90,
1463            "binary sigmoid: expected ≥90% accuracy, got {:.1}%",
1464            acc * 100.0
1465        );
1466
1467        // Verify probabilities sum to 1.
1468        let probas = lr.predict_proba(&[vec![5.0], vec![35.0]]).unwrap();
1469        for (idx, probs) in probas.iter().enumerate() {
1470            let sum: f64 = probs.iter().sum();
1471            assert!(
1472                (sum - 1.0).abs() < 1e-6,
1473                "probabilities for sample {idx} should sum to 1, got {sum}"
1474            );
1475        }
1476        // Low x should predict class 0, high x should predict class 1.
1477        assert!(
1478            probas[0][0] > probas[0][1],
1479            "x=5 should have higher prob for class 0"
1480        );
1481        assert!(
1482            probas[1][1] > probas[1][0],
1483            "x=35 should have higher prob for class 1"
1484        );
1485    }
1486}