Skip to main content

scirs2_cluster/
density_ratio.rs

1//! Density ratio estimation algorithms.
2//!
3//! Density ratio estimation estimates the ratio w(x) = p(x) / q(x) between
4//! two probability densities from finite samples. These estimates are used in
5//! covariate shift adaptation, importance weighting, two-sample testing, and
6//! change-point detection.
7//!
8//! # Algorithms
9//!
10//! * [`RuLSIF`] - Relative Unconstrained Least-Squares Importance Fitting
11//! * [`KLIEP`] - KL Importance Estimation Procedure
12//! * [`KullbackLeibler`] - Plug-in KL divergence ratio estimator
13//!
14//! # Example
15//!
16//! ```rust
17//! use scirs2_cluster::density_ratio::{DensityRatioEstimator, KLIEP, importance_weights, ImportanceMethod};
18//! use scirs2_core::ndarray::Array2;
19//!
20//! // Source and target samples
21//! let source = Array2::from_shape_vec((6, 1), vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5]).expect("operation should succeed");
22//! let target = Array2::from_shape_vec((6, 1), vec![0.5, 0.6, 0.7, 0.8, 0.9, 1.0]).expect("operation should succeed");
23//!
24//! let weights = importance_weights(source.view(), target.view(), ImportanceMethod::Kliep).expect("operation should succeed");
25//! assert_eq!(weights.len(), 6);
26//! ```
27
28use std::f64::consts::PI;
29
30use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
31
32use crate::error::{ClusteringError, Result};
33
34// ─── Trait ──────────────────────────────────────────────────────────────────
35
36/// Trait for density ratio estimators: w(x) ≈ p_numerator(x) / p_denominator(x).
37pub trait DensityRatioEstimator {
38    /// Fit the estimator using numerator (`p`) and denominator (`q`) samples.
39    ///
40    /// Both arrays have shape `(n_samples, n_features)`.
41    fn fit(&mut self, numerator: ArrayView2<f64>, denominator: ArrayView2<f64>) -> Result<()>;
42
43    /// Predict importance weights w(x) ≈ p(x) / q(x) for new `query` points.
44    ///
45    /// Returns a vector of length `query.shape()[0]`.
46    fn predict(&self, query: ArrayView2<f64>) -> Result<Vec<f64>>;
47}
48
49// ─── Kernel helpers ─────────────────────────────────────────────────────────
50
51/// Compute an RBF (Gaussian) kernel matrix between rows of `a` and rows of `b`.
52///
53/// `K[i, j] = exp(-||a_i - b_j||^2 / (2 * sigma^2))`
54fn rbf_kernel_matrix(a: ArrayView2<f64>, b: ArrayView2<f64>, sigma: f64) -> Array2<f64> {
55    let na = a.shape()[0];
56    let nb = b.shape()[0];
57    let d = a.shape()[1];
58    let two_s2 = 2.0 * sigma * sigma;
59
60    let mut k = Array2::<f64>::zeros((na, nb));
61    for i in 0..na {
62        for j in 0..nb {
63            let mut sq = 0.0_f64;
64            for f in 0..d {
65                let diff = a[[i, f]] - b[[j, f]];
66                sq += diff * diff;
67            }
68            k[[i, j]] = (-sq / two_s2).exp();
69        }
70    }
71    k
72}
73
74/// Compute squared Euclidean distance between two row slices.
75fn sq_dist(a: ArrayView1<f64>, b: ArrayView1<f64>) -> f64 {
76    a.iter()
77        .zip(b.iter())
78        .map(|(&ai, &bi)| (ai - bi) * (ai - bi))
79        .sum()
80}
81
82/// Median heuristic for the RBF bandwidth: sigma = median(pairwise distances) / sqrt(2).
83fn median_bandwidth(data: ArrayView2<f64>) -> f64 {
84    let n = data.shape()[0];
85    if n <= 1 {
86        return 1.0;
87    }
88    let mut dists = Vec::with_capacity(n * (n - 1) / 2);
89    for i in 0..n {
90        for j in (i + 1)..n {
91            dists.push(sq_dist(data.row(i), data.row(j)).sqrt());
92        }
93    }
94    dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
95    let med = dists[dists.len() / 2];
96    if med < 1e-10 {
97        1.0
98    } else {
99        med / (2.0_f64).sqrt()
100    }
101}
102
103/// Simple Cholesky-based solver for small positive-definite systems.
104/// Solves `A x = b` where `A` is `(m×m)` symmetric positive-definite.
105fn solve_pd_system(a: &Array2<f64>, b: &Array1<f64>) -> Result<Array1<f64>> {
106    let m = a.shape()[0];
107    if m == 0 {
108        return Err(ClusteringError::ComputationError(
109            "Empty system".to_string(),
110        ));
111    }
112    // Cholesky: L s.t. A = L L^T
113    let mut l = Array2::<f64>::zeros((m, m));
114    for i in 0..m {
115        for j in 0..=i {
116            let mut s = a[[i, j]];
117            for k in 0..j {
118                s -= l[[i, k]] * l[[j, k]];
119            }
120            if i == j {
121                if s < 0.0 {
122                    s = 1e-12; // numerical fallback
123                }
124                l[[i, j]] = s.sqrt();
125            } else if l[[j, j]].abs() < 1e-15 {
126                l[[i, j]] = 0.0;
127            } else {
128                l[[i, j]] = s / l[[j, j]];
129            }
130        }
131    }
132    // Forward solve: L y = b
133    let mut y = Array1::<f64>::zeros(m);
134    for i in 0..m {
135        let mut s = b[i];
136        for k in 0..i {
137            s -= l[[i, k]] * y[k];
138        }
139        if l[[i, i]].abs() < 1e-15 {
140            y[i] = 0.0;
141        } else {
142            y[i] = s / l[[i, i]];
143        }
144    }
145    // Backward solve: L^T x = y
146    let mut x = Array1::<f64>::zeros(m);
147    for i in (0..m).rev() {
148        let mut s = y[i];
149        for k in (i + 1)..m {
150            s -= l[[k, i]] * x[k];
151        }
152        if l[[i, i]].abs() < 1e-15 {
153            x[i] = 0.0;
154        } else {
155            x[i] = s / l[[i, i]];
156        }
157    }
158    Ok(x)
159}
160
161// ─── KullbackLeibler ─────────────────────────────────────────────────────────
162
163/// Plug-in KL divergence density ratio estimator.
164///
165/// Estimates w(x) = p(x)/q(x) as a ratio of kernel density estimates.
166/// Uses Gaussian kernels with a common bandwidth selected by the median heuristic.
167///
168/// This estimator is simple and intuitive but can suffer from the curse of
169/// dimensionality in high dimensions; prefer [`RuLSIF`] or [`KLIEP`] in
170/// production usage.
171#[derive(Debug, Clone)]
172pub struct KullbackLeibler {
173    /// Bandwidth for the kernel density estimates.
174    pub sigma: Option<f64>,
175    /// Numerator samples stored after `fit`.
176    numerator_samples: Option<Array2<f64>>,
177    /// Denominator samples stored after `fit`.
178    denominator_samples: Option<Array2<f64>>,
179    /// Resolved bandwidth.
180    resolved_sigma: f64,
181}
182
183impl KullbackLeibler {
184    /// Create a new estimator.  Pass `None` to use the median heuristic.
185    pub fn new(sigma: Option<f64>) -> Self {
186        Self {
187            sigma,
188            numerator_samples: None,
189            denominator_samples: None,
190            resolved_sigma: 1.0,
191        }
192    }
193}
194
195impl DensityRatioEstimator for KullbackLeibler {
196    fn fit(&mut self, numerator: ArrayView2<f64>, denominator: ArrayView2<f64>) -> Result<()> {
197        if numerator.shape()[1] != denominator.shape()[1] {
198            return Err(ClusteringError::InvalidInput(
199                "numerator and denominator must have the same number of features".to_string(),
200            ));
201        }
202        self.resolved_sigma = self.sigma.unwrap_or_else(|| {
203            let s1 = median_bandwidth(numerator);
204            let s2 = median_bandwidth(denominator);
205            (s1 + s2) / 2.0
206        });
207        self.numerator_samples = Some(numerator.to_owned());
208        self.denominator_samples = Some(denominator.to_owned());
209        Ok(())
210    }
211
212    fn predict(&self, query: ArrayView2<f64>) -> Result<Vec<f64>> {
213        let num_samples = self.numerator_samples.as_ref().ok_or_else(|| {
214            ClusteringError::InvalidState("Call fit() before predict()".to_string())
215        })?;
216        let den_samples = self.denominator_samples.as_ref().ok_or_else(|| {
217            ClusteringError::InvalidState("Call fit() before predict()".to_string())
218        })?;
219
220        let sigma = self.resolved_sigma;
221        let n_num = num_samples.shape()[0] as f64;
222        let n_den = den_samples.shape()[0] as f64;
223        let n_query = query.shape()[0];
224
225        // Two-pi normalization factor (same for both, cancels in ratio)
226        let d = query.shape()[1] as f64;
227        let norm = (2.0 * PI * sigma * sigma).powf(d / 2.0);
228
229        let mut weights = Vec::with_capacity(n_query);
230        for qi in 0..n_query {
231            let xq = query.row(qi);
232
233            // KDE for numerator
234            let p_num: f64 = num_samples
235                .rows()
236                .into_iter()
237                .map(|row| (-sq_dist(xq, row) / (2.0 * sigma * sigma)).exp())
238                .sum::<f64>()
239                / (n_num * norm);
240
241            // KDE for denominator
242            let p_den: f64 = den_samples
243                .rows()
244                .into_iter()
245                .map(|row| (-sq_dist(xq, row) / (2.0 * sigma * sigma)).exp())
246                .sum::<f64>()
247                / (n_den * norm);
248
249            let ratio = if p_den < 1e-300 { 0.0 } else { p_num / p_den };
250            weights.push(ratio.max(0.0));
251        }
252        Ok(weights)
253    }
254}
255
256// ─── KLIEP ───────────────────────────────────────────────────────────────────
257
258/// KL Importance Estimation Procedure (KLIEP).
259///
260/// Estimates w(x) = p(x)/q(x) by directly minimising the KL divergence from
261/// p to the model distribution ĝ = w·q.  The model is a non-negative linear
262/// combination of RBF basis functions centred on numerator samples (or a
263/// randomly selected subset).
264///
265/// Reference: Sugiyama et al. (2008), "Direct importance estimation with model
266/// selection and its application to covariate shift adaptation", NIPS.
267#[derive(Debug, Clone)]
268pub struct KLIEP {
269    /// RBF bandwidth. `None` selects via median heuristic.
270    pub sigma: Option<f64>,
271    /// L2 regularisation on the coefficient vector.
272    pub lambda: f64,
273    /// Maximum gradient ascent iterations.
274    pub max_iter: usize,
275    /// Learning rate for gradient ascent.
276    pub learning_rate: f64,
277    /// Number of basis centres (subsampled from numerator).  `None` = all.
278    pub n_centers: Option<usize>,
279
280    // fitted state
281    alpha: Option<Array1<f64>>,
282    centers: Option<Array2<f64>>,
283    resolved_sigma: f64,
284}
285
286impl KLIEP {
287    /// Create a new KLIEP estimator with default hyper-parameters.
288    pub fn new() -> Self {
289        Self {
290            sigma: None,
291            lambda: 1e-3,
292            max_iter: 1000,
293            learning_rate: 1e-3,
294            n_centers: None,
295            alpha: None,
296            centers: None,
297            resolved_sigma: 1.0,
298        }
299    }
300
301    /// Builder: set bandwidth.
302    pub fn with_sigma(mut self, sigma: f64) -> Self {
303        self.sigma = Some(sigma);
304        self
305    }
306
307    /// Builder: set regularisation.
308    pub fn with_lambda(mut self, lambda: f64) -> Self {
309        self.lambda = lambda;
310        self
311    }
312
313    /// Builder: set number of RBF centres.
314    pub fn with_n_centers(mut self, n: usize) -> Self {
315        self.n_centers = Some(n);
316        self
317    }
318}
319
320impl Default for KLIEP {
321    fn default() -> Self {
322        Self::new()
323    }
324}
325
326impl DensityRatioEstimator for KLIEP {
327    fn fit(&mut self, numerator: ArrayView2<f64>, denominator: ArrayView2<f64>) -> Result<()> {
328        let n_num = numerator.shape()[0];
329        let n_den = denominator.shape()[0];
330        let n_feat = numerator.shape()[1];
331
332        if n_feat != denominator.shape()[1] {
333            return Err(ClusteringError::InvalidInput(
334                "Feature dimension mismatch".to_string(),
335            ));
336        }
337        if n_num == 0 || n_den == 0 {
338            return Err(ClusteringError::InvalidInput(
339                "Empty sample arrays".to_string(),
340            ));
341        }
342
343        // Resolve bandwidth
344        let sigma = self.sigma.unwrap_or_else(|| {
345            let s1 = median_bandwidth(numerator);
346            let s2 = median_bandwidth(denominator);
347            (s1 + s2) / 2.0
348        });
349        self.resolved_sigma = sigma;
350
351        // Select basis centres (subset of numerator)
352        let n_centers = self.n_centers.unwrap_or(n_num).min(n_num);
353        let centers: Array2<f64> = if n_centers == n_num {
354            numerator.to_owned()
355        } else {
356            // Evenly-spaced subsample (deterministic, no randomness needed here)
357            let step = n_num / n_centers;
358            let mut c = Array2::<f64>::zeros((n_centers, n_feat));
359            for (ci, ni) in (0..n_num).step_by(step.max(1)).take(n_centers).enumerate() {
360                c.row_mut(ci).assign(&numerator.row(ni));
361            }
362            c
363        };
364
365        // Kernel matrices: phi_num (n_num × n_centers), phi_den (n_den × n_centers)
366        let phi_num = rbf_kernel_matrix(numerator, centers.view(), sigma);
367        let phi_den = rbf_kernel_matrix(denominator, centers.view(), sigma);
368
369        // Initialise alpha uniformly
370        let mut alpha = Array1::<f64>::from_elem(n_centers, 1.0 / n_centers as f64);
371
372        let lr = self.learning_rate;
373        let lambda = self.lambda;
374
375        // Gradient ascent on KL objective:
376        //   J(alpha) = E_p[log w(x)] - log E_q[w(x)]
377        // with constraint that w >= 0 and is normalised on q.
378        for _iter in 0..self.max_iter {
379            // w_num[i] = phi_num[i,:] . alpha
380            let w_num: Vec<f64> = (0..n_num)
381                .map(|i| {
382                    (0..n_centers)
383                        .map(|j| phi_num[[i, j]] * alpha[j])
384                        .sum::<f64>()
385                        .max(1e-15)
386                })
387                .collect();
388
389            // w_den[i] = phi_den[i,:] . alpha
390            let w_den: Vec<f64> = (0..n_den)
391                .map(|i| {
392                    (0..n_centers)
393                        .map(|j| phi_den[[i, j]] * alpha[j])
394                        .sum::<f64>()
395                        .max(1e-15)
396                })
397                .collect();
398
399            let z_den: f64 = w_den.iter().sum::<f64>() / n_den as f64;
400            if z_den < 1e-300 {
401                break;
402            }
403
404            // Gradient w.r.t. alpha
405            let mut grad = Array1::<f64>::zeros(n_centers);
406            // Positive part: E_p[phi / w]
407            for i in 0..n_num {
408                for j in 0..n_centers {
409                    grad[j] += phi_num[[i, j]] / w_num[i];
410                }
411            }
412            for j in 0..n_centers {
413                grad[j] /= n_num as f64;
414            }
415            // Negative part: E_q[phi] / z_den
416            for i in 0..n_den {
417                for j in 0..n_centers {
418                    grad[j] -= phi_den[[i, j]] / (n_den as f64 * z_den);
419                }
420            }
421            // L2 regularisation gradient
422            for j in 0..n_centers {
423                grad[j] -= lambda * alpha[j];
424            }
425
426            // Update + project non-negative
427            for j in 0..n_centers {
428                alpha[j] = (alpha[j] + lr * grad[j]).max(0.0);
429            }
430
431            // Re-normalise so that mean w on denominator = 1
432            let z: f64 = (0..n_den)
433                .map(|i| {
434                    (0..n_centers)
435                        .map(|j| phi_den[[i, j]] * alpha[j])
436                        .sum::<f64>()
437                        .max(0.0)
438                })
439                .sum::<f64>()
440                / n_den as f64;
441            if z > 1e-15 {
442                for j in 0..n_centers {
443                    alpha[j] /= z;
444                }
445            }
446        }
447
448        self.alpha = Some(alpha);
449        self.centers = Some(centers);
450        Ok(())
451    }
452
453    fn predict(&self, query: ArrayView2<f64>) -> Result<Vec<f64>> {
454        let alpha = self.alpha.as_ref().ok_or_else(|| {
455            ClusteringError::InvalidState("Call fit() before predict()".to_string())
456        })?;
457        let centers = self.centers.as_ref().ok_or_else(|| {
458            ClusteringError::InvalidState("Call fit() before predict()".to_string())
459        })?;
460
461        let phi = rbf_kernel_matrix(query, centers.view(), self.resolved_sigma);
462        let n_query = query.shape()[0];
463        let n_centers = centers.shape()[0];
464
465        let mut weights = Vec::with_capacity(n_query);
466        for i in 0..n_query {
467            let w: f64 = (0..n_centers)
468                .map(|j| phi[[i, j]] * alpha[j])
469                .sum::<f64>()
470                .max(0.0);
471            weights.push(w);
472        }
473        Ok(weights)
474    }
475}
476
477// ─── RuLSIF ──────────────────────────────────────────────────────────────────
478
479/// Relative Unconstrained Least-Squares Importance Fitting (RuLSIF).
480///
481/// Minimises a least-squares criterion with L2 regularisation to estimate
482/// the *relative* density ratio:
483///
484///   r_alpha(x) = p(x) / (alpha * p(x) + (1 - alpha) * q(x))
485///
486/// Setting `alpha = 0` recovers plain uLSIF.
487///
488/// Reference: Yamada et al. (2013), "Relative Density-Ratio Estimation for
489/// Robust Distribution Comparison", Neural Computation.
490#[derive(Debug, Clone)]
491pub struct RuLSIF {
492    /// Relative mixture parameter alpha ∈ [0, 1).  0 = uLSIF.
493    pub alpha: f64,
494    /// RBF bandwidth.  `None` = median heuristic.
495    pub sigma: Option<f64>,
496    /// L2 regularisation.
497    pub lambda: f64,
498    /// Number of RBF centres.  `None` = use all denominator samples.
499    pub n_centers: Option<usize>,
500
501    // fitted state
502    theta: Option<Array1<f64>>,
503    centers: Option<Array2<f64>>,
504    resolved_sigma: f64,
505}
506
507impl RuLSIF {
508    /// Create a new RuLSIF estimator.
509    pub fn new(alpha: f64, lambda: f64) -> Self {
510        Self {
511            alpha: alpha.clamp(0.0, 1.0 - 1e-6),
512            sigma: None,
513            lambda,
514            n_centers: None,
515            theta: None,
516            centers: None,
517            resolved_sigma: 1.0,
518        }
519    }
520
521    /// Builder: set bandwidth.
522    pub fn with_sigma(mut self, sigma: f64) -> Self {
523        self.sigma = Some(sigma);
524        self
525    }
526
527    /// Builder: set number of RBF centres.
528    pub fn with_n_centers(mut self, n: usize) -> Self {
529        self.n_centers = Some(n);
530        self
531    }
532}
533
534impl Default for RuLSIF {
535    fn default() -> Self {
536        Self::new(0.1, 1e-3)
537    }
538}
539
540impl DensityRatioEstimator for RuLSIF {
541    fn fit(&mut self, numerator: ArrayView2<f64>, denominator: ArrayView2<f64>) -> Result<()> {
542        let n_num = numerator.shape()[0];
543        let n_den = denominator.shape()[0];
544        let n_feat = numerator.shape()[1];
545
546        if n_feat != denominator.shape()[1] {
547            return Err(ClusteringError::InvalidInput(
548                "Feature dimension mismatch".to_string(),
549            ));
550        }
551        if n_num == 0 || n_den == 0 {
552            return Err(ClusteringError::InvalidInput(
553                "Empty sample arrays".to_string(),
554            ));
555        }
556
557        // Resolve bandwidth
558        let sigma = self.sigma.unwrap_or_else(|| {
559            let s1 = median_bandwidth(numerator);
560            let s2 = median_bandwidth(denominator);
561            (s1 + s2) / 2.0
562        });
563        self.resolved_sigma = sigma;
564
565        // Select centres from denominator (or subsample)
566        let n_centers = self.n_centers.unwrap_or(n_den).min(n_den);
567        let centers: Array2<f64> = if n_centers == n_den {
568            denominator.to_owned()
569        } else {
570            let step = (n_den / n_centers).max(1);
571            let mut c = Array2::<f64>::zeros((n_centers, n_feat));
572            for (ci, di) in (0..n_den).step_by(step).take(n_centers).enumerate() {
573                c.row_mut(ci).assign(&denominator.row(di));
574            }
575            c
576        };
577
578        // Kernel matrices
579        let phi_num = rbf_kernel_matrix(numerator, centers.view(), sigma); // n_num × n_centers
580        let phi_den = rbf_kernel_matrix(denominator, centers.view(), sigma); // n_den × n_centers
581
582        let m = n_centers;
583
584        // hat_H = (alpha/n_num) * Phi_num^T Phi_num + ((1-alpha)/n_den) * Phi_den^T Phi_den + lambda * I
585        let mut h2 = Array2::<f64>::zeros((m, m));
586        for i in 0..n_num {
587            for j in 0..m {
588                for k in j..m {
589                    let v = phi_num[[i, j]] * phi_num[[i, k]];
590                    h2[[j, k]] += v;
591                    if k != j {
592                        h2[[k, j]] += v;
593                    }
594                }
595            }
596        }
597        let mut h3 = Array2::<f64>::zeros((m, m));
598        for i in 0..n_den {
599            for j in 0..m {
600                for k in j..m {
601                    let v = phi_den[[i, j]] * phi_den[[i, k]];
602                    h3[[j, k]] += v;
603                    if k != j {
604                        h3[[k, j]] += v;
605                    }
606                }
607            }
608        }
609        let mut hat_h = Array2::<f64>::zeros((m, m));
610        for j in 0..m {
611            for k in 0..m {
612                hat_h[[j, k]] = self.alpha * h2[[j, k]] / n_num as f64
613                    + (1.0 - self.alpha) * h3[[j, k]] / n_den as f64;
614            }
615            hat_h[[j, j]] += self.lambda;
616        }
617
618        // h_vec = (1/n_num) * sum_i phi_num[i,:]
619        let mut h_vec = Array1::<f64>::zeros(m);
620        for i in 0..n_num {
621            for j in 0..m {
622                h_vec[j] += phi_num[[i, j]];
623            }
624        }
625        for j in 0..m {
626            h_vec[j] /= n_num as f64;
627        }
628
629        // Solve hat_h * theta = h_vec
630        let theta = solve_pd_system(&hat_h, &h_vec)?;
631        // Project non-negative
632        let theta = theta.mapv(|v: f64| v.max(0.0));
633
634        self.theta = Some(theta);
635        self.centers = Some(centers);
636        Ok(())
637    }
638
639    fn predict(&self, query: ArrayView2<f64>) -> Result<Vec<f64>> {
640        let theta = self.theta.as_ref().ok_or_else(|| {
641            ClusteringError::InvalidState("Call fit() before predict()".to_string())
642        })?;
643        let centers = self.centers.as_ref().ok_or_else(|| {
644            ClusteringError::InvalidState("Call fit() before predict()".to_string())
645        })?;
646
647        let phi = rbf_kernel_matrix(query, centers.view(), self.resolved_sigma);
648        let n_query = query.shape()[0];
649        let n_centers = centers.shape()[0];
650
651        let mut weights = Vec::with_capacity(n_query);
652        for i in 0..n_query {
653            let w: f64 = (0..n_centers)
654                .map(|j| phi[[i, j]] * theta[j])
655                .sum::<f64>()
656                .max(0.0);
657            weights.push(w);
658        }
659        Ok(weights)
660    }
661}
662
663// ─── Convenience API ─────────────────────────────────────────────────────────
664
/// Selects which density ratio estimator `importance_weights` runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ImportanceMethod {
    /// Use the KL Importance Estimation Procedure (KLIEP).
    Kliep,
    /// Use Relative Unconstrained Least-Squares Importance Fitting.
    RuLSIF,
    /// Use the plug-in ratio of two kernel density estimates.
    KullbackLeibler,
}
675
676/// Compute importance weights w(x_i) ≈ p_target(x_i) / p_source(x_i) for
677/// each source sample x_i.
678///
679/// # Arguments
680///
681/// * `source_samples` - Samples from the source (denominator) distribution.
682/// * `target_samples` - Samples from the target (numerator) distribution.
683/// * `method`         - Which estimator to use.
684///
685/// # Returns
686///
687/// A vector of non-negative importance weights, one per source sample.
688pub fn importance_weights(
689    source_samples: ArrayView2<f64>,
690    target_samples: ArrayView2<f64>,
691    method: ImportanceMethod,
692) -> Result<Vec<f64>> {
693    match method {
694        ImportanceMethod::Kliep => {
695            let mut est = KLIEP::new();
696            est.fit(target_samples, source_samples)?;
697            est.predict(source_samples)
698        }
699        ImportanceMethod::RuLSIF => {
700            let mut est = RuLSIF::default();
701            est.fit(target_samples, source_samples)?;
702            est.predict(source_samples)
703        }
704        ImportanceMethod::KullbackLeibler => {
705            let mut est = KullbackLeibler::new(None);
706            est.fit(target_samples, source_samples)?;
707            est.predict(source_samples)
708        }
709    }
710}
711
712/// Direct KLIEP density ratio estimation.
713///
714/// # Arguments
715///
716/// * `numerator_samples`   - Samples from p.
717/// * `denominator_samples` - Samples from q.
718/// * `kernel_centers`      - RBF centres (if `None`, uses numerator samples).
719/// * `lambda`              - L2 regularisation.
720///
721/// # Returns
722///
723/// Importance weight w(x_i) for each numerator sample.
724pub fn density_ratio_kliep(
725    numerator_samples: ArrayView2<f64>,
726    denominator_samples: ArrayView2<f64>,
727    kernel_centers: Option<ArrayView2<f64>>,
728    lambda: f64,
729) -> Result<Vec<f64>> {
730    let n_centers = kernel_centers.map(|c: ArrayView2<f64>| c.shape()[0]);
731    let mut est = KLIEP {
732        sigma: None,
733        lambda,
734        max_iter: 1000,
735        learning_rate: 1e-3,
736        n_centers,
737        alpha: None,
738        centers: None,
739        resolved_sigma: 1.0,
740    };
741    est.fit(numerator_samples, denominator_samples)?;
742    est.predict(numerator_samples)
743}
744
745// ─── Two-sample testing / covariate shift ────────────────────────────────────
746
747/// Compute the covariate shift score between source and target distributions.
748///
749/// The score is defined as the mean log-importance-weight estimated by KLIEP:
750///
751///   score = (1/n) Σ_i log(w(x_i) + ε)
752///
753/// A score near zero indicates negligible shift; larger positive values indicate
754/// that the source and target distributions differ substantially.
755///
756/// # Arguments
757///
758/// * `source` - Source domain samples.
759/// * `target` - Target domain samples.
760///
761/// # Returns
762///
763/// A non-negative scalar; larger values indicate greater covariate shift.
764pub fn covariate_shift_score(source: ArrayView2<f64>, target: ArrayView2<f64>) -> Result<f64> {
765    let weights = importance_weights(source, target, ImportanceMethod::Kliep)?;
766    if weights.is_empty() {
767        return Ok(0.0);
768    }
769    let eps = 1e-15_f64;
770    let mean_log_w =
771        weights.iter().map(|&w: &f64| (w + eps).ln()).sum::<f64>() / weights.len() as f64;
772    Ok(mean_log_w.abs())
773}
774
775/// Two-sample test statistic based on the RuLSIF density ratio.
776///
777/// Returns the Pearson divergence PE(p ‖ α p + (1−α) q) estimated from finite
778/// samples, which is zero iff p = q and positive otherwise.
779pub fn two_sample_test_statistic(
780    samples_p: ArrayView2<f64>,
781    samples_q: ArrayView2<f64>,
782    alpha: f64,
783    lambda: f64,
784) -> Result<f64> {
785    let mut est = RuLSIF::new(alpha, lambda);
786    est.fit(samples_p, samples_q)?;
787    let w_p = est.predict(samples_p)?;
788    let w_q = est.predict(samples_q)?;
789
790    // PE divergence = 0.5 * E_q[w^2] - E_p[w] + 0.5
791    let mean_w2_q = w_q.iter().map(|&w| w * w).sum::<f64>() / w_q.len() as f64;
792    let mean_w_p = w_p.iter().sum::<f64>() / w_p.len() as f64;
793    let pe = 0.5 * mean_w2_q - mean_w_p + 0.5_f64;
794    Ok(pe.max(0.0_f64))
795}
796
797// ─── Tests ───────────────────────────────────────────────────────────────────
798
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Draw `n` pseudo-Gaussian values as an `(n, 1)` array.
    ///
    /// Uses a 64-bit LCG plus the Box-Muller transform so the tests stay
    /// deterministic and need no `rand` dependency.
    fn make_gaussian_samples(mean: f64, std: f64, n: usize, seed: u64) -> Array2<f64> {
        fn lcg_step(s: u64) -> u64 {
            s.wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407)
        }

        let mut state = seed;
        let mut draws: Vec<f64> = Vec::with_capacity(n);
        for _ in 0..n {
            state = lcg_step(state);
            // Shifted away from zero so the logarithm below stays finite.
            let u1 = (state as f64) / u64::MAX as f64 * (1.0 - 1e-10) + 1e-10;
            state = lcg_step(state);
            let u2 = (state as f64) / u64::MAX as f64;
            let z = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
            draws.push(mean + std * z);
        }
        Array2::from_shape_vec((n, 1), draws).expect("shape")
    }

    /// True when every weight is `>= 0.0`.
    fn all_non_negative(weights: &[f64]) -> bool {
        weights.iter().all(|&v| v >= 0.0)
    }

    #[test]
    fn test_kliep_same_distribution() {
        let source = make_gaussian_samples(0.0, 1.0, 40, 1);
        let target = make_gaussian_samples(0.0, 1.0, 40, 2);
        let weights = importance_weights(source.view(), target.view(), ImportanceMethod::Kliep)
            .expect("kliep failed");
        assert_eq!(weights.len(), 40);
        // Importance weights can never be negative.
        assert!(all_non_negative(&weights));
    }

    #[test]
    fn test_kliep_shifted_distribution() {
        let source = make_gaussian_samples(0.0, 1.0, 40, 3);
        let target = make_gaussian_samples(3.0, 1.0, 40, 4);
        let weights = importance_weights(source.view(), target.view(), ImportanceMethod::Kliep)
            .expect("kliep shifted");
        assert!(all_non_negative(&weights));
    }

    #[test]
    fn test_rulsif_basic() {
        let source = make_gaussian_samples(0.0, 1.0, 30, 5);
        let target = make_gaussian_samples(1.0, 1.0, 30, 6);
        let weights = importance_weights(source.view(), target.view(), ImportanceMethod::RuLSIF)
            .expect("rulsif");
        assert_eq!(weights.len(), 30);
        assert!(all_non_negative(&weights));
    }

    #[test]
    fn test_kde_ratio_basic() {
        let source = make_gaussian_samples(0.0, 1.0, 30, 7);
        let target = make_gaussian_samples(0.0, 1.0, 30, 8);
        let weights = importance_weights(
            source.view(),
            target.view(),
            ImportanceMethod::KullbackLeibler,
        )
        .expect("kde ratio");
        assert_eq!(weights.len(), 30);
        assert!(all_non_negative(&weights));
    }

    #[test]
    fn test_covariate_shift_score_zero_shift() {
        let source = make_gaussian_samples(0.0, 1.0, 30, 9);
        let target = make_gaussian_samples(0.0, 1.0, 30, 10);
        let score = covariate_shift_score(source.view(), target.view()).expect("score");
        // The score must be a finite, non-negative number.
        assert!(score.is_finite());
        assert!(score >= 0.0);
    }

    #[test]
    fn test_two_sample_test_nonneg() {
        let samples_p = make_gaussian_samples(0.0, 1.0, 30, 11);
        let samples_q = make_gaussian_samples(2.0, 1.0, 30, 12);
        let stat = two_sample_test_statistic(samples_p.view(), samples_q.view(), 0.1, 1e-3)
            .expect("test stat");
        assert!(stat >= 0.0);
        assert!(stat.is_finite());
    }

    #[test]
    fn test_density_ratio_kliep_fn() {
        let numerator = make_gaussian_samples(1.0, 0.5, 20, 13);
        let denominator = make_gaussian_samples(0.0, 1.0, 20, 14);
        let weights = density_ratio_kliep(numerator.view(), denominator.view(), None, 1e-3)
            .expect("kliep fn");
        assert_eq!(weights.len(), 20);
    }

    #[test]
    fn test_feature_mismatch_error() {
        let two_cols = Array2::from_shape_vec((5, 2), vec![0.0; 10]).expect("a");
        let three_cols = Array2::from_shape_vec((5, 3), vec![0.0; 15]).expect("b");
        let mut estimator = KLIEP::new();
        let outcome = estimator.fit(two_cols.view(), three_cols.view());
        assert!(outcome.is_err());
    }
}