sklears_kernel_approximation/
chi2_samplers.rs1use scirs2_core::ndarray::{Array1, Array2, Axis};
3use scirs2_core::random::essentials::Uniform as RandUniform;
4use scirs2_core::random::rngs::StdRng as RealStdRng;
5use scirs2_core::random::Rng;
6use sklears_core::{
7 error::{Result, SklearsError},
8 traits::{Estimator, Fit, Trained, Transform, Untrained},
9 types::Float,
10};
11use std::marker::PhantomData;
12
13use scirs2_core::random::{thread_rng, SeedableRng};
14#[derive(Debug, Clone)]
38pub struct AdditiveChi2Sampler {
40 pub sample_steps: usize,
42 pub sample_interval: Float,
44}
45
46impl AdditiveChi2Sampler {
47 pub fn new(sample_steps: usize) -> Self {
49 let sample_interval = match sample_steps {
50 1 => 0.8,
51 2 => 0.5,
52 3 => 0.4,
53 _ => 0.5, };
55
56 Self {
57 sample_steps,
58 sample_interval,
59 }
60 }
61
62 pub fn sample_interval(mut self, interval: Float) -> Self {
64 self.sample_interval = interval;
65 self
66 }
67}
68
69impl Transform<Array2<Float>, Array2<Float>> for AdditiveChi2Sampler {
70 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
71 let (n_samples, n_features) = x.dim();
72
73 for val in x.iter() {
75 if *val < 0.0 {
76 return Err(SklearsError::InvalidInput(
77 "Additive chi2 kernel requires non-negative features".to_string(),
78 ));
79 }
80 }
81
82 let n_output_features = n_features * (2 * self.sample_steps - 1);
83 let mut result = Array2::zeros((n_samples, n_output_features));
84
85 for i in 0..n_samples {
86 let mut feature_idx = 0;
87
88 for j in 0..n_features {
89 let x_val = x[[i, j]];
90
91 result[[i, feature_idx]] = (x_val * self.sample_interval).sqrt();
93 feature_idx += 1;
94
95 if x_val > 0.0 {
97 let log_x = x_val.ln();
98
99 for k in 1..self.sample_steps {
100 let k_float = k as Float;
101 let arg = k_float * log_x * self.sample_interval;
102 let factor = (2.0 * x_val * self.sample_interval
103 / (std::f64::consts::PI * k_float * self.sample_interval).cosh())
104 .sqrt();
105
106 result[[i, feature_idx]] = factor * arg.cos();
108 feature_idx += 1;
109
110 result[[i, feature_idx]] = factor * arg.sin();
112 feature_idx += 1;
113 }
114 } else {
115 feature_idx += 2 * (self.sample_steps - 1);
117 }
118 }
119 }
120
121 Ok(result)
122 }
123}
124
125#[derive(Debug, Clone)]
151pub struct SkewedChi2Sampler<State = Untrained> {
153 pub skewedness: Float,
155 pub n_components: usize,
157 pub random_state: Option<u64>,
159
160 random_weights_: Option<Array2<Float>>,
162 random_offset_: Option<Array1<Float>>,
163
164 _state: PhantomData<State>,
165}
166
167impl SkewedChi2Sampler<Untrained> {
168 pub fn new(n_components: usize) -> Self {
170 Self {
171 skewedness: 1.0,
172 n_components,
173 random_state: None,
174 random_weights_: None,
175 random_offset_: None,
176 _state: PhantomData,
177 }
178 }
179
180 pub fn skewedness(mut self, skewedness: Float) -> Self {
182 self.skewedness = skewedness;
183 self
184 }
185
186 pub fn random_state(mut self, seed: u64) -> Self {
188 self.random_state = Some(seed);
189 self
190 }
191}
192
193impl Estimator for SkewedChi2Sampler<Untrained> {
194 type Config = ();
195 type Error = SklearsError;
196 type Float = Float;
197
198 fn config(&self) -> &Self::Config {
199 &()
200 }
201}
202
203impl Fit<Array2<Float>, ()> for SkewedChi2Sampler<Untrained> {
204 type Fitted = SkewedChi2Sampler<Trained>;
205
206 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
207 let (_, n_features) = x.dim();
208
209 if self.skewedness <= 0.0 {
210 return Err(SklearsError::InvalidInput(
211 "skewedness must be positive".to_string(),
212 ));
213 }
214
215 for val in x.iter() {
217 if *val <= -self.skewedness {
218 return Err(SklearsError::InvalidInput(format!(
219 "All values must be > -skewedness ({})",
220 -self.skewedness
221 )));
222 }
223 }
224
225 let mut rng = if let Some(seed) = self.random_state {
226 RealStdRng::seed_from_u64(seed)
227 } else {
228 RealStdRng::from_seed(thread_rng().gen())
229 };
230
231 let uniform = RandUniform::new(0.0, 1.0).unwrap();
233 let mut weights = Array2::zeros((n_features, self.n_components));
234
235 for mut col in weights.columns_mut() {
236 for weight in col.iter_mut() {
237 let u = rng.sample(uniform);
238 *weight =
240 (1.0 / std::f64::consts::PI) * ((std::f64::consts::PI / 2.0 * u).tan()).ln();
241 }
242 }
243
244 let offset_uniform = RandUniform::new(0.0, 2.0 * std::f64::consts::PI).unwrap();
246 let mut random_offset = Array1::zeros(self.n_components);
247 for val in random_offset.iter_mut() {
248 *val = rng.sample(offset_uniform);
249 }
250
251 Ok(SkewedChi2Sampler {
252 skewedness: self.skewedness,
253 n_components: self.n_components,
254 random_state: self.random_state,
255 random_weights_: Some(weights),
256 random_offset_: Some(random_offset),
257 _state: PhantomData,
258 })
259 }
260}
261
262impl Transform<Array2<Float>, Array2<Float>> for SkewedChi2Sampler<Trained> {
263 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
264 let (_n_samples, n_features) = x.dim();
265 let weights = self.random_weights_.as_ref().unwrap();
266 let offset = self.random_offset_.as_ref().unwrap();
267
268 if n_features != weights.nrows() {
269 return Err(SklearsError::InvalidInput(format!(
270 "X has {} features, but SkewedChi2Sampler was fitted with {} features",
271 n_features,
272 weights.nrows()
273 )));
274 }
275
276 for val in x.iter() {
278 if *val <= -self.skewedness {
279 return Err(SklearsError::InvalidInput(format!(
280 "All values must be > -skewedness ({})",
281 -self.skewedness
282 )));
283 }
284 }
285
286 let x_shifted = x.mapv(|v| (v + self.skewedness).ln());
288
289 let projection = x_shifted.dot(weights) + offset.view().insert_axis(Axis(0));
291 let normalization = (2.0 / self.n_components as Float).sqrt();
292 let result = projection.mapv(|v| normalization * v.cos());
293
294 Ok(result)
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_additive_chi2_sampler_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let chi2 = AdditiveChi2Sampler::new(2);
        let x_transformed = chi2.transform(&x).unwrap();

        // 2 input features * (2 * 2 - 1) output columns per feature.
        assert_eq!(x_transformed.shape(), &[2, 6]);

        // Columns are grouped per input feature: 0 and 3 hold sqrt(x * sample_interval).
        assert!((x_transformed[[0, 0]] - (0.5 as Float).sqrt()).abs() < 1e-12);
        assert!((x_transformed[[0, 3]] - 1.0).abs() < 1e-12);
        // For x = 1, ln(x) = 0, so the sine column of feature 0 vanishes.
        assert_eq!(x_transformed[[0, 2]], 0.0);
    }

    #[test]
    fn test_additive_chi2_sampler_zero_feature_block_is_zero() {
        let x = array![[0.0, 1.0]];
        let x_transformed = AdditiveChi2Sampler::new(3).transform(&x).unwrap();

        assert_eq!(x_transformed.shape(), &[1, 10]);
        // Feature 0 owns columns 0..5; a zero input leaves all of them at zero.
        for col in 0..5 {
            assert_eq!(x_transformed[[0, col]], 0.0);
        }
    }

    #[test]
    fn test_additive_chi2_sampler_custom_interval() {
        let x = array![[2.0]];
        let x_transformed = AdditiveChi2Sampler::new(1)
            .sample_interval(2.0)
            .transform(&x)
            .unwrap();

        // A single step keeps only the sqrt(x * interval) column.
        assert_eq!(x_transformed.shape(), &[1, 1]);
        assert!((x_transformed[[0, 0]] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_additive_chi2_sampler_negative_input() {
        let x = array![[1.0, -2.0]];

        let chi2 = AdditiveChi2Sampler::new(2);
        let result = chi2.transform(&x);
        assert!(result.is_err());
    }

    #[test]
    fn test_skewed_chi2_sampler_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let skewed_chi2 = SkewedChi2Sampler::new(50);
        let fitted = skewed_chi2.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.shape(), &[2, 50]);

        // |cos| <= 1, so every feature is bounded by the sqrt(2 / n_components) scale.
        let bound = (2.0 / 50.0 as Float).sqrt();
        for val in x_transformed.iter() {
            assert!(val.abs() <= bound + 1e-12);
        }
    }

    #[test]
    fn test_skewed_chi2_sampler_invalid_skewedness() {
        let x = array![[1.0, 2.0]];
        let skewed_chi2 = SkewedChi2Sampler::new(10).skewedness(-1.0);
        let result = skewed_chi2.fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_skewed_chi2_sampler_input_validation() {
        let x_train = array![[1.0, 2.0]];
        // -1.5 <= -skewedness (-1.0), so transform must reject it.
        let x_test = array![[-1.5, 2.0]];
        let skewed_chi2 = SkewedChi2Sampler::new(10);
        let fitted = skewed_chi2.fit(&x_train, &()).unwrap();
        let result = fitted.transform(&x_test);
        assert!(result.is_err());
    }

    #[test]
    fn test_skewed_chi2_sampler_feature_mismatch() {
        let fitted = SkewedChi2Sampler::new(10)
            .fit(&array![[1.0, 2.0]], &())
            .unwrap();
        assert!(fitted.transform(&array![[1.0, 2.0, 3.0]]).is_err());
    }

    #[test]
    fn test_skewed_chi2_sampler_seed_is_reproducible() {
        let x = array![[0.5, 1.0, 1.5], [2.0, 2.5, 3.0]];
        let run = || {
            SkewedChi2Sampler::new(16)
                .random_state(42)
                .fit(&x, &())
                .unwrap()
                .transform(&x)
                .unwrap()
        };
        assert_eq!(run(), run());
    }
}