sklears_kernel_approximation/
chi2_samplers.rs1use scirs2_core::ndarray::{Array1, Array2, Axis};
3use scirs2_core::random::essentials::Uniform as RandUniform;
4use scirs2_core::random::rngs::StdRng as RealStdRng;
5use sklears_core::{
6 error::{Result, SklearsError},
7 traits::{Estimator, Fit, Trained, Transform, Untrained},
8 types::Float,
9};
10use std::marker::PhantomData;
11
12use scirs2_core::random::RngExt;
13use scirs2_core::random::{thread_rng, SeedableRng};
14#[derive(Debug, Clone)]
38pub struct AdditiveChi2Sampler {
40 pub sample_steps: usize,
42 pub sample_interval: Float,
44}
45
46impl AdditiveChi2Sampler {
47 pub fn new(sample_steps: usize) -> Self {
49 let sample_interval = match sample_steps {
50 1 => 0.8,
51 2 => 0.5,
52 3 => 0.4,
53 _ => 0.5, };
55
56 Self {
57 sample_steps,
58 sample_interval,
59 }
60 }
61
62 pub fn sample_interval(mut self, interval: Float) -> Self {
64 self.sample_interval = interval;
65 self
66 }
67}
68
69impl Transform<Array2<Float>, Array2<Float>> for AdditiveChi2Sampler {
70 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
71 let (n_samples, n_features) = x.dim();
72
73 for val in x.iter() {
75 if *val < 0.0 {
76 return Err(SklearsError::InvalidInput(
77 "Additive chi2 kernel requires non-negative features".to_string(),
78 ));
79 }
80 }
81
82 let n_output_features = n_features * (2 * self.sample_steps - 1);
83 let mut result = Array2::zeros((n_samples, n_output_features));
84
85 for i in 0..n_samples {
86 let mut feature_idx = 0;
87
88 for j in 0..n_features {
89 let x_val = x[[i, j]];
90
91 result[[i, feature_idx]] = (x_val * self.sample_interval).sqrt();
93 feature_idx += 1;
94
95 if x_val > 0.0 {
97 let log_x = x_val.ln();
98
99 for k in 1..self.sample_steps {
100 let k_float = k as Float;
101 let arg = k_float * log_x * self.sample_interval;
102 let factor = (2.0 * x_val * self.sample_interval
103 / (std::f64::consts::PI * k_float * self.sample_interval).cosh())
104 .sqrt();
105
106 result[[i, feature_idx]] = factor * arg.cos();
108 feature_idx += 1;
109
110 result[[i, feature_idx]] = factor * arg.sin();
112 feature_idx += 1;
113 }
114 } else {
115 feature_idx += 2 * (self.sample_steps - 1);
117 }
118 }
119 }
120
121 Ok(result)
122 }
123}
124
125#[derive(Debug, Clone)]
151pub struct SkewedChi2Sampler<State = Untrained> {
153 pub skewedness: Float,
155 pub n_components: usize,
157 pub random_state: Option<u64>,
159
160 random_weights_: Option<Array2<Float>>,
162 random_offset_: Option<Array1<Float>>,
163
164 _state: PhantomData<State>,
165}
166
167impl SkewedChi2Sampler<Untrained> {
168 pub fn new(n_components: usize) -> Self {
170 Self {
171 skewedness: 1.0,
172 n_components,
173 random_state: None,
174 random_weights_: None,
175 random_offset_: None,
176 _state: PhantomData,
177 }
178 }
179
180 pub fn skewedness(mut self, skewedness: Float) -> Self {
182 self.skewedness = skewedness;
183 self
184 }
185
186 pub fn random_state(mut self, seed: u64) -> Self {
188 self.random_state = Some(seed);
189 self
190 }
191}
192
193impl Estimator for SkewedChi2Sampler<Untrained> {
194 type Config = ();
195 type Error = SklearsError;
196 type Float = Float;
197
198 fn config(&self) -> &Self::Config {
199 &()
200 }
201}
202
203impl Fit<Array2<Float>, ()> for SkewedChi2Sampler<Untrained> {
204 type Fitted = SkewedChi2Sampler<Trained>;
205
206 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
207 let (_, n_features) = x.dim();
208
209 if self.skewedness <= 0.0 {
210 return Err(SklearsError::InvalidInput(
211 "skewedness must be positive".to_string(),
212 ));
213 }
214
215 for val in x.iter() {
217 if *val <= -self.skewedness {
218 return Err(SklearsError::InvalidInput(format!(
219 "All values must be > -skewedness ({})",
220 -self.skewedness
221 )));
222 }
223 }
224
225 let mut rng = if let Some(seed) = self.random_state {
226 RealStdRng::seed_from_u64(seed)
227 } else {
228 RealStdRng::from_seed(thread_rng().random())
229 };
230
231 let uniform = RandUniform::new(0.0, 1.0).expect("operation should succeed");
233 let mut weights = Array2::zeros((n_features, self.n_components));
234
235 for mut col in weights.columns_mut() {
236 for weight in col.iter_mut() {
237 let u = rng.sample(uniform);
238 *weight =
240 (1.0 / std::f64::consts::PI) * ((std::f64::consts::PI / 2.0 * u).tan()).ln();
241 }
242 }
243
244 let offset_uniform =
246 RandUniform::new(0.0, 2.0 * std::f64::consts::PI).expect("operation should succeed");
247 let mut random_offset = Array1::zeros(self.n_components);
248 for val in random_offset.iter_mut() {
249 *val = rng.sample(offset_uniform);
250 }
251
252 Ok(SkewedChi2Sampler {
253 skewedness: self.skewedness,
254 n_components: self.n_components,
255 random_state: self.random_state,
256 random_weights_: Some(weights),
257 random_offset_: Some(random_offset),
258 _state: PhantomData,
259 })
260 }
261}
262
263impl Transform<Array2<Float>, Array2<Float>> for SkewedChi2Sampler<Trained> {
264 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
265 let (_n_samples, n_features) = x.dim();
266 let weights = self
267 .random_weights_
268 .as_ref()
269 .expect("operation should succeed");
270 let offset = self
271 .random_offset_
272 .as_ref()
273 .expect("operation should succeed");
274
275 if n_features != weights.nrows() {
276 return Err(SklearsError::InvalidInput(format!(
277 "X has {} features, but SkewedChi2Sampler was fitted with {} features",
278 n_features,
279 weights.nrows()
280 )));
281 }
282
283 for val in x.iter() {
285 if *val <= -self.skewedness {
286 return Err(SklearsError::InvalidInput(format!(
287 "All values must be > -skewedness ({})",
288 -self.skewedness
289 )));
290 }
291 }
292
293 let x_shifted = x.mapv(|v| (v + self.skewedness).ln());
295
296 let projection = x_shifted.dot(weights) + offset.view().insert_axis(Axis(0));
298 let normalization = (2.0 / self.n_components as Float).sqrt();
299 let result = projection.mapv(|v| normalization * v.cos());
300
301 Ok(result)
302 }
303}
304
// NOTE(review): no non-snake-case identifier is visible in this module;
// confirm whether the `non_snake_case` allowance is still needed.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Two sampling steps expand each of the two features into three columns.
    #[test]
    fn test_additive_chi2_sampler_basic() {
        let input = array![[1.0, 2.0], [3.0, 4.0]];

        let mapped = AdditiveChi2Sampler::new(2)
            .transform(&input)
            .expect("operation should succeed");

        assert_eq!(mapped.shape(), &[2, 6]);

        // Columns 0 and 3 hold the square-root terms of the two features.
        assert!(mapped[[0, 0]] >= 0.0);
        assert!(mapped[[0, 3]] >= 0.0);
    }

    /// A negative entry must be rejected.
    #[test]
    fn test_additive_chi2_sampler_negative_input() {
        let input = array![[1.0, -2.0]];

        let outcome = AdditiveChi2Sampler::new(2).transform(&input);

        assert!(outcome.is_err());
    }

    /// Fitting then transforming yields one column per component, with
    /// bounded entries.
    #[test]
    fn test_skewed_chi2_sampler_basic() {
        let input = array![[1.0, 2.0], [3.0, 4.0]];

        let fitted = SkewedChi2Sampler::new(50)
            .fit(&input, &())
            .expect("operation should succeed");
        let mapped = fitted.transform(&input).expect("operation should succeed");

        assert_eq!(mapped.shape(), &[2, 50]);
        assert!(mapped.iter().all(|v| v.abs() <= 2.0));
    }

    /// A negative skewedness is refused at fit time.
    #[test]
    fn test_skewed_chi2_sampler_invalid_skewedness() {
        let input = array![[1.0, 2.0]];

        let outcome = SkewedChi2Sampler::new(10).skewedness(-1.0).fit(&input, &());

        assert!(outcome.is_err());
    }

    /// Values at or below `-skewedness` are refused at transform time.
    #[test]
    fn test_skewed_chi2_sampler_input_validation() {
        let train = array![[1.0, 2.0]];
        // -1.5 lies below the default lower bound of -1.0.
        let probe = array![[-1.5, 2.0]];

        let fitted = SkewedChi2Sampler::new(10)
            .fit(&train, &())
            .expect("operation should succeed");

        assert!(fitted.transform(&probe).is_err());
    }
}