sklears_kernel_approximation/
chi2_samplers.rs

1//! Chi-squared kernel approximation methods
2use scirs2_core::ndarray::{Array1, Array2, Axis};
3use scirs2_core::random::essentials::Uniform as RandUniform;
4use scirs2_core::random::rngs::StdRng as RealStdRng;
5use scirs2_core::random::Rng;
6use sklears_core::{
7    error::{Result, SklearsError},
8    traits::{Estimator, Fit, Trained, Transform, Untrained},
9    types::Float,
10};
11use std::marker::PhantomData;
12
13use scirs2_core::random::{thread_rng, SeedableRng};
14/// Additive Chi-Squared Kernel Approximation
15///
16/// Approximates the additive chi-squared kernel: K(x,y) = Σᵢ (2xᵢyᵢ)/(xᵢ+yᵢ)
17/// Used with histogram data in computer vision. This is a stateless transformer.
18///
19/// # Parameters
20///
21/// * `sample_steps` - Number of sampling points (default: 2)
22/// * `sample_interval` - Sampling interval (auto-computed if None)
23///
24/// # Examples
25///
26/// ```rust,ignore
27/// use sklears_kernel_approximation::AdditiveChi2Sampler;
28/// use sklears_core::traits::Transform;
29/// use scirs2_core::ndarray::array;
30///
31/// let X = array![[1.0, 2.0], [3.0, 4.0]];
32///
33/// let chi2 = AdditiveChi2Sampler::new(2);
34/// let X_transformed = chi2.transform(&X).unwrap();
35/// assert_eq!(X_transformed.shape(), &[2, 6]); // 2 features * 3 = 6
36/// ```
37#[derive(Debug, Clone)]
38/// AdditiveChi2Sampler
39pub struct AdditiveChi2Sampler {
40    /// Number of sampling points
41    pub sample_steps: usize,
42    /// Sampling interval
43    pub sample_interval: Float,
44}
45
46impl AdditiveChi2Sampler {
47    /// Create a new Additive Chi2 sampler
48    pub fn new(sample_steps: usize) -> Self {
49        let sample_interval = match sample_steps {
50            1 => 0.8,
51            2 => 0.5,
52            3 => 0.4,
53            _ => 0.5, // Default fallback
54        };
55
56        Self {
57            sample_steps,
58            sample_interval,
59        }
60    }
61
62    /// Set the sample interval
63    pub fn sample_interval(mut self, interval: Float) -> Self {
64        self.sample_interval = interval;
65        self
66    }
67}
68
69impl Transform<Array2<Float>, Array2<Float>> for AdditiveChi2Sampler {
70    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
71        let (n_samples, n_features) = x.dim();
72
73        // Check for non-negative values
74        for val in x.iter() {
75            if *val < 0.0 {
76                return Err(SklearsError::InvalidInput(
77                    "Additive chi2 kernel requires non-negative features".to_string(),
78                ));
79            }
80        }
81
82        let n_output_features = n_features * (2 * self.sample_steps - 1);
83        let mut result = Array2::zeros((n_samples, n_output_features));
84
85        for i in 0..n_samples {
86            let mut feature_idx = 0;
87
88            for j in 0..n_features {
89                let x_val = x[[i, j]];
90
91                // First component: sqrt(X * sample_interval)
92                result[[i, feature_idx]] = (x_val * self.sample_interval).sqrt();
93                feature_idx += 1;
94
95                // Additional components: factor * cos/sin(k * log(X) * sample_interval)
96                if x_val > 0.0 {
97                    let log_x = x_val.ln();
98
99                    for k in 1..self.sample_steps {
100                        let k_float = k as Float;
101                        let arg = k_float * log_x * self.sample_interval;
102                        let factor = (2.0 * x_val * self.sample_interval
103                            / (std::f64::consts::PI * k_float * self.sample_interval).cosh())
104                        .sqrt();
105
106                        // Cosine component
107                        result[[i, feature_idx]] = factor * arg.cos();
108                        feature_idx += 1;
109
110                        // Sine component
111                        result[[i, feature_idx]] = factor * arg.sin();
112                        feature_idx += 1;
113                    }
114                } else {
115                    // For x_val == 0, set remaining components to 0
116                    feature_idx += 2 * (self.sample_steps - 1);
117                }
118            }
119        }
120
121        Ok(result)
122    }
123}
124
125/// Skewed Chi-Squared Kernel Approximation
126///
127/// Approximates the skewed chi-squared kernel using Monte Carlo sampling.
128/// K(x,y) = ∏ᵢ (2√(xᵢ+c)√(yᵢ+c))/(xᵢ+yᵢ+2c)
129///
130/// # Parameters
131///
132/// * `skewedness` - The "c" parameter (default: 1.0)
133/// * `n_components` - Number of Monte Carlo samples (default: 100)
134/// * `random_state` - Random seed for reproducibility
135///
136/// # Examples
137///
138/// ```rust,ignore
139/// use sklears_kernel_approximation::SkewedChi2Sampler;
140/// use sklears_core::traits::{Transform, Fit, Untrained}
141/// use scirs2_core::ndarray::array;
142///
143/// let X = array![[1.0, 2.0], [3.0, 4.0]];
144///
145/// let skewed_chi2 = SkewedChi2Sampler::new(50);
146/// let fitted_chi2 = skewed_chi2.fit(&X, &()).unwrap();
147/// let X_transformed = fitted_chi2.transform(&X).unwrap();
148/// assert_eq!(X_transformed.shape(), &[2, 50]);
149/// ```
150#[derive(Debug, Clone)]
151/// SkewedChi2Sampler
152pub struct SkewedChi2Sampler<State = Untrained> {
153    /// Skewedness parameter
154    pub skewedness: Float,
155    /// Number of Monte Carlo samples
156    pub n_components: usize,
157    /// Random seed
158    pub random_state: Option<u64>,
159
160    // Fitted attributes
161    random_weights_: Option<Array2<Float>>,
162    random_offset_: Option<Array1<Float>>,
163
164    _state: PhantomData<State>,
165}
166
167impl SkewedChi2Sampler<Untrained> {
168    /// Create a new Skewed Chi2 sampler
169    pub fn new(n_components: usize) -> Self {
170        Self {
171            skewedness: 1.0,
172            n_components,
173            random_state: None,
174            random_weights_: None,
175            random_offset_: None,
176            _state: PhantomData,
177        }
178    }
179
180    /// Set the skewedness parameter
181    pub fn skewedness(mut self, skewedness: Float) -> Self {
182        self.skewedness = skewedness;
183        self
184    }
185
186    /// Set random state for reproducibility
187    pub fn random_state(mut self, seed: u64) -> Self {
188        self.random_state = Some(seed);
189        self
190    }
191}
192
193impl Estimator for SkewedChi2Sampler<Untrained> {
194    type Config = ();
195    type Error = SklearsError;
196    type Float = Float;
197
198    fn config(&self) -> &Self::Config {
199        &()
200    }
201}
202
203impl Fit<Array2<Float>, ()> for SkewedChi2Sampler<Untrained> {
204    type Fitted = SkewedChi2Sampler<Trained>;
205
206    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
207        let (_, n_features) = x.dim();
208
209        if self.skewedness <= 0.0 {
210            return Err(SklearsError::InvalidInput(
211                "skewedness must be positive".to_string(),
212            ));
213        }
214
215        // Check that all values > -skewedness
216        for val in x.iter() {
217            if *val <= -self.skewedness {
218                return Err(SklearsError::InvalidInput(format!(
219                    "All values must be > -skewedness ({})",
220                    -self.skewedness
221                )));
222            }
223        }
224
225        let mut rng = if let Some(seed) = self.random_state {
226            RealStdRng::seed_from_u64(seed)
227        } else {
228            RealStdRng::from_seed(thread_rng().gen())
229        };
230
231        // Sample random weights from inverse CDF of sech distribution
232        let uniform = RandUniform::new(0.0, 1.0).unwrap();
233        let mut weights = Array2::zeros((n_features, self.n_components));
234
235        for mut col in weights.columns_mut() {
236            for weight in col.iter_mut() {
237                let u = rng.sample(uniform);
238                // Inverse CDF of sech: (1/π) * log(tan(π/2 * u))
239                *weight =
240                    (1.0 / std::f64::consts::PI) * ((std::f64::consts::PI / 2.0 * u).tan()).ln();
241            }
242        }
243
244        // Sample random offsets from Uniform(0, 2π)
245        let offset_uniform = RandUniform::new(0.0, 2.0 * std::f64::consts::PI).unwrap();
246        let mut random_offset = Array1::zeros(self.n_components);
247        for val in random_offset.iter_mut() {
248            *val = rng.sample(offset_uniform);
249        }
250
251        Ok(SkewedChi2Sampler {
252            skewedness: self.skewedness,
253            n_components: self.n_components,
254            random_state: self.random_state,
255            random_weights_: Some(weights),
256            random_offset_: Some(random_offset),
257            _state: PhantomData,
258        })
259    }
260}
261
262impl Transform<Array2<Float>, Array2<Float>> for SkewedChi2Sampler<Trained> {
263    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
264        let (_n_samples, n_features) = x.dim();
265        let weights = self.random_weights_.as_ref().unwrap();
266        let offset = self.random_offset_.as_ref().unwrap();
267
268        if n_features != weights.nrows() {
269            return Err(SklearsError::InvalidInput(format!(
270                "X has {} features, but SkewedChi2Sampler was fitted with {} features",
271                n_features,
272                weights.nrows()
273            )));
274        }
275
276        // Check input validity
277        for val in x.iter() {
278            if *val <= -self.skewedness {
279                return Err(SklearsError::InvalidInput(format!(
280                    "All values must be > -skewedness ({})",
281                    -self.skewedness
282                )));
283            }
284        }
285
286        // Transform: log(X + skewedness)
287        let x_shifted = x.mapv(|v| (v + self.skewedness).ln());
288
289        // Compute projection and apply cosine
290        let projection = x_shifted.dot(weights) + offset.view().insert_axis(Axis(0));
291        let normalization = (2.0 / self.n_components as Float).sqrt();
292        let result = projection.mapv(|v| normalization * v.cos());
293
294        Ok(result)
295    }
296}
297
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Output shape and sign of the leading (square-root) components.
    #[test]
    fn test_additive_chi2_sampler_basic() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];

        let features = AdditiveChi2Sampler::new(2).transform(&data).unwrap();

        // 2 features * (2*2-1) = 6 output features
        assert_eq!(features.shape(), &[2, 6]);

        // Columns 0 and 3 hold the sqrt component of each input feature.
        for col in [0, 3] {
            assert!(features[[0, col]] >= 0.0);
        }
    }

    /// A negative entry must be rejected.
    #[test]
    fn test_additive_chi2_sampler_negative_input() {
        // Second entry is negative.
        let data = array![[1.0, -2.0]];

        let outcome = AdditiveChi2Sampler::new(2).transform(&data);
        assert!(outcome.is_err());
    }

    /// Fit followed by transform yields one column per component.
    #[test]
    fn test_skewed_chi2_sampler_basic() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];

        let model = SkewedChi2Sampler::new(50).fit(&data, &()).unwrap();
        let features = model.transform(&data).unwrap();

        assert_eq!(features.shape(), &[2, 50]);

        // Check that values are in reasonable range
        for value in features.iter() {
            assert!(value.abs() <= 2.0);
        }
    }

    /// A negative skewedness must make fitting fail.
    #[test]
    fn test_skewed_chi2_sampler_invalid_skewedness() {
        let data = array![[1.0, 2.0]];

        let outcome = SkewedChi2Sampler::new(10).skewedness(-1.0).fit(&data, &());
        assert!(outcome.is_err());
    }

    /// Transform must re-validate its input against the skewedness bound.
    #[test]
    fn test_skewed_chi2_sampler_input_validation() {
        let train = array![[1.0, 2.0]];
        // -1.5 is below -skewedness for the default skewedness of 1.0.
        let test = array![[-1.5, 2.0]];

        let model = SkewedChi2Sampler::new(10).fit(&train, &()).unwrap();
        assert!(model.transform(&test).is_err());
    }
}