1use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::random::essentials::Uniform as RandUniform;
9use scirs2_core::random::rngs::StdRng as RealStdRng;
10use scirs2_core::random::seq::SliceRandom;
11use scirs2_core::random::Rng;
12use scirs2_core::random::{thread_rng, SeedableRng};
13use sklears_core::{
14 error::{Result, SklearsError},
15 prelude::{Fit, Transform},
16 traits::{Estimator, Trained, Untrained},
17 types::Float,
18};
19use std::f64::consts::PI;
20use std::marker::PhantomData;
21
22use crate::structured_random_features::FastWalshHadamardTransform;
23
24#[derive(Debug, Clone)]
58pub struct FastfoodTransform<State = Untrained> {
60 pub n_components: usize,
62 pub gamma: Float,
64 pub random_state: Option<u64>,
66
67 scaling_b_: Option<Array1<Float>>, permutation_: Option<Array1<usize>>, scaling_g_: Option<Array1<Float>>, random_offset_: Option<Array1<Float>>, padded_dim_: Option<usize>, n_blocks_: Option<usize>, _state: PhantomData<State>,
77}
78
79impl FastfoodTransform<Untrained> {
80 pub fn new(n_components: usize) -> Self {
82 Self {
83 n_components,
84 gamma: 1.0,
85 random_state: None,
86 scaling_b_: None,
87 permutation_: None,
88 scaling_g_: None,
89 random_offset_: None,
90 padded_dim_: None,
91 n_blocks_: None,
92 _state: PhantomData,
93 }
94 }
95
96 pub fn gamma(mut self, gamma: Float) -> Self {
98 self.gamma = gamma;
99 self
100 }
101
102 pub fn random_state(mut self, seed: u64) -> Self {
104 self.random_state = Some(seed);
105 self
106 }
107}
108
109impl Estimator for FastfoodTransform<Untrained> {
110 type Config = ();
111 type Error = SklearsError;
112 type Float = Float;
113
114 fn config(&self) -> &Self::Config {
115 &()
116 }
117}
118
119impl Fit<Array2<Float>, ()> for FastfoodTransform<Untrained> {
120 type Fitted = FastfoodTransform<Trained>;
121
122 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
123 let (_, n_features) = x.dim();
124
125 let mut rng = match self.random_state {
126 Some(seed) => RealStdRng::seed_from_u64(seed),
127 None => RealStdRng::from_seed(thread_rng().gen()),
128 };
129
130 let padded_dim = next_power_of_2(n_features);
132
133 let n_blocks = (self.n_components + padded_dim - 1) / padded_dim;
135
136 let scaling_b = self.generate_random_scaling(padded_dim * n_blocks, &mut rng);
138 let scaling_g = self.generate_random_scaling(padded_dim * n_blocks, &mut rng);
139
140 let permutation = self.generate_random_permutation(padded_dim * n_blocks, &mut rng);
142
143 let uniform = RandUniform::new(0.0, 2.0 * PI).unwrap();
145 let random_offset = Array1::from_shape_fn(self.n_components, |_| rng.sample(uniform));
146
147 Ok(FastfoodTransform {
148 n_components: self.n_components,
149 gamma: self.gamma,
150 random_state: self.random_state,
151 scaling_b_: Some(scaling_b),
152 permutation_: Some(permutation),
153 scaling_g_: Some(scaling_g),
154 random_offset_: Some(random_offset),
155 padded_dim_: Some(padded_dim),
156 n_blocks_: Some(n_blocks),
157 _state: PhantomData,
158 })
159 }
160}
161
162impl FastfoodTransform<Untrained> {
163 fn generate_random_scaling(&self, size: usize, rng: &mut RealStdRng) -> Array1<Float> {
165 let mut scaling = Array1::zeros(size);
166 for i in 0..size {
167 scaling[i] = if rng.gen::<bool>() { 1.0 } else { -1.0 };
168 }
169 scaling
170 }
171
172 fn generate_random_permutation(&self, size: usize, rng: &mut RealStdRng) -> Array1<usize> {
174 let mut permutation: Vec<usize> = (0..size).collect();
175 permutation.shuffle(rng);
176 Array1::from_vec(permutation)
177 }
178}
179
180impl Transform<Array2<Float>> for FastfoodTransform<Trained> {
181 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
182 let scaling_b = self
183 .scaling_b_
184 .as_ref()
185 .ok_or_else(|| SklearsError::NotFitted {
186 operation: "transform".to_string(),
187 })?;
188
189 let permutation = self
190 .permutation_
191 .as_ref()
192 .ok_or_else(|| SklearsError::NotFitted {
193 operation: "transform".to_string(),
194 })?;
195
196 let scaling_g = self
197 .scaling_g_
198 .as_ref()
199 .ok_or_else(|| SklearsError::NotFitted {
200 operation: "transform".to_string(),
201 })?;
202
203 let random_offset =
204 self.random_offset_
205 .as_ref()
206 .ok_or_else(|| SklearsError::NotFitted {
207 operation: "transform".to_string(),
208 })?;
209
210 let padded_dim = *self
211 .padded_dim_
212 .as_ref()
213 .ok_or_else(|| SklearsError::NotFitted {
214 operation: "transform".to_string(),
215 })?;
216
217 let n_blocks = *self
218 .n_blocks_
219 .as_ref()
220 .ok_or_else(|| SklearsError::NotFitted {
221 operation: "transform".to_string(),
222 })?;
223
224 let (n_samples, n_features) = x.dim();
225 let mut features = Array2::zeros((n_samples, self.n_components));
226
227 for sample_idx in 0..n_samples {
229 let sample = x.row(sample_idx);
230
231 let transformed_sample = self.apply_fastfood_transform(
233 &sample,
234 scaling_b,
235 permutation,
236 scaling_g,
237 padded_dim,
238 n_blocks,
239 n_features,
240 )?;
241
242 for j in 0..(self.n_components.min(transformed_sample.len())) {
244 let phase = transformed_sample[j] * (2.0 * self.gamma).sqrt() + random_offset[j];
245 features[[sample_idx, j]] =
246 (2.0 / (self.n_components as Float)).sqrt() * phase.cos();
247 }
248 }
249
250 Ok(features)
251 }
252}
253
254impl FastfoodTransform<Trained> {
255 fn apply_fastfood_transform(
257 &self,
258 x: &scirs2_core::ndarray::ArrayBase<
259 scirs2_core::ndarray::ViewRepr<&Float>,
260 scirs2_core::ndarray::Dim<[usize; 1]>,
261 >,
262 scaling_b: &Array1<Float>,
263 permutation: &Array1<usize>,
264 scaling_g: &Array1<Float>,
265 padded_dim: usize,
266 n_blocks: usize,
267 n_features: usize,
268 ) -> Result<Array1<Float>> {
269 let mut result = Array1::zeros(padded_dim * n_blocks);
270
271 for block in 0..n_blocks {
273 let block_start = block * padded_dim;
274 let _block_end = block_start + padded_dim;
275
276 let mut padded_input = Array1::zeros(padded_dim);
278 for i in 0..n_features.min(padded_dim) {
279 padded_input[i] = x[i];
280 }
281
282 let mut transformed = FastWalshHadamardTransform::transform(padded_input)?;
284
285 for i in 0..padded_dim {
287 transformed[i] *= scaling_b[block_start + i];
288 }
289
290 let mut permuted = Array1::zeros(padded_dim);
292 for i in 0..padded_dim {
293 let perm_idx = permutation[block_start + i] % padded_dim;
294 permuted[i] = transformed[perm_idx];
295 }
296
297 transformed = FastWalshHadamardTransform::transform(permuted)?;
299
300 for i in 0..padded_dim {
302 transformed[i] *= scaling_g[block_start + i];
303 }
304
305 for i in 0..padded_dim {
307 result[block_start + i] = transformed[i];
308 }
309 }
310
311 Ok(result)
312 }
313}
314
/// Returns the smallest power of two that is greater than or equal to `n`.
///
/// Both `0` and `1` map to `1`.
///
/// # Panics
///
/// Panics if the result does not fit in a `usize`, i.e. when `n` is larger
/// than the largest representable power of two. A hand-rolled doubling loop
/// would instead overflow in that case, wrapping to zero and never
/// terminating in builds without overflow checks.
fn next_power_of_2(n: usize) -> usize {
    n.checked_next_power_of_two()
        .expect("next_power_of_2: result does not fit in usize")
}
326
327#[derive(Debug, Clone)]
332pub struct FastfoodKernel<State = Untrained> {
334 pub n_components: usize,
336 pub kernel_params: FastfoodKernelParams,
338 pub random_state: Option<u64>,
340
341 fastfood_transforms_: Option<Vec<FastfoodTransform<Trained>>>,
343
344 _state: PhantomData<State>,
345}
346
347#[derive(Debug, Clone)]
349pub enum FastfoodKernelParams {
351 Rbf { gamma: Float },
353 Matern { nu: Float, length_scale: Float },
355 RationalQuadratic { alpha: Float, length_scale: Float },
357}
358
359impl FastfoodKernel<Untrained> {
360 pub fn new(n_components: usize, kernel_params: FastfoodKernelParams) -> Self {
362 Self {
363 n_components,
364 kernel_params,
365 random_state: None,
366 fastfood_transforms_: None,
367 _state: PhantomData,
368 }
369 }
370
371 pub fn random_state(mut self, seed: u64) -> Self {
373 self.random_state = Some(seed);
374 self
375 }
376}
377
378impl Estimator for FastfoodKernel<Untrained> {
379 type Config = ();
380 type Error = SklearsError;
381 type Float = Float;
382
383 fn config(&self) -> &Self::Config {
384 &()
385 }
386}
387
388impl Fit<Array2<Float>, ()> for FastfoodKernel<Untrained> {
389 type Fitted = FastfoodKernel<Trained>;
390
391 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
392 let gamma = match &self.kernel_params {
394 FastfoodKernelParams::Rbf { gamma } => *gamma,
395 _ => {
396 return Err(SklearsError::InvalidInput(
397 "Only RBF kernel is currently supported for FastfoodKernel".to_string(),
398 ))
399 }
400 };
401
402 let fastfood = FastfoodTransform::new(self.n_components).gamma(gamma);
403 let fastfood = match self.random_state {
404 Some(seed) => fastfood.random_state(seed),
405 None => fastfood,
406 };
407
408 let fitted_fastfood = fastfood.fit(x, &())?;
409 let transforms = vec![fitted_fastfood];
410
411 Ok(FastfoodKernel {
412 n_components: self.n_components,
413 kernel_params: self.kernel_params,
414 random_state: self.random_state,
415 fastfood_transforms_: Some(transforms),
416 _state: PhantomData,
417 })
418 }
419}
420
421impl Transform<Array2<Float>> for FastfoodKernel<Trained> {
422 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
423 let transforms =
424 self.fastfood_transforms_
425 .as_ref()
426 .ok_or_else(|| SklearsError::NotFitted {
427 operation: "transform".to_string(),
428 })?;
429
430 transforms[0].transform(x)
432 }
433}
434
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Fits `model` on `x` and transforms the same data, panicking if either
    /// step fails.
    fn fit_then_transform(
        model: FastfoodTransform<Untrained>,
        x: &Array2<Float>,
    ) -> Array2<Float> {
        model.fit(x, &()).unwrap().transform(x).unwrap()
    }

    /// Checks the rounding helper on exact powers and in-between values.
    #[test]
    fn test_next_power_of_2() {
        let cases: [(usize, usize); 7] =
            [(1, 1), (2, 2), (3, 4), (7, 8), (8, 8), (15, 16), (16, 16)];
        for &(input, expected) in cases.iter() {
            assert_eq!(next_power_of_2(input), expected);
        }
    }

    /// Three features, eight components: output has one row per sample.
    #[test]
    fn test_fastfood_transform_basic() {
        let x = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]];

        let transformed = fit_then_transform(FastfoodTransform::new(8).gamma(0.5), &x);

        assert_eq!(transformed.shape(), &[3, 8]);
    }

    /// Input width that is already a power of two.
    #[test]
    fn test_fastfood_transform_power_of_2() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        let transformed = fit_then_transform(FastfoodTransform::new(4).gamma(1.0), &x);

        assert_eq!(transformed.shape(), &[2, 4]);
    }

    /// The kernel wrapper with RBF parameters produces the requested width.
    #[test]
    fn test_fastfood_kernel_rbf() {
        let x = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];

        let kernel = FastfoodKernel::new(6, FastfoodKernelParams::Rbf { gamma: 0.5 });
        let transformed = kernel.fit(&x, &()).unwrap().transform(&x).unwrap();

        assert_eq!(transformed.shape(), &[2, 6]);
    }

    /// Two fits with the same seed must yield the same features.
    #[test]
    fn test_fastfood_reproducibility() {
        let x = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];

        let result1 = fit_then_transform(FastfoodTransform::new(8).random_state(42), &x);
        let result2 = fit_then_transform(FastfoodTransform::new(8).random_state(42), &x);

        assert_eq!(result1.shape(), result2.shape());
        for (a, b) in result1.iter().zip(result2.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    /// Different `gamma` values must lead to different features.
    #[test]
    fn test_fastfood_different_gamma() {
        let x = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];

        let result_low = fit_then_transform(FastfoodTransform::new(4).gamma(0.1), &x);
        let result_high = fit_then_transform(FastfoodTransform::new(4).gamma(10.0), &x);

        assert_eq!(result_low.shape(), result_high.shape());
        let diff_sum: Float = result_low
            .iter()
            .zip(result_high.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff_sum > 1e-6);
    }

    /// Ten features are padded up to sixteen internally.
    #[test]
    fn test_fastfood_large_dimensions() {
        let x = array![
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
            [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0]
        ];

        let transformed = fit_then_transform(FastfoodTransform::new(16).gamma(0.1), &x);

        assert_eq!(transformed.shape(), &[2, 16]);
    }

    /// A single-row input still produces a single-row output.
    #[test]
    fn test_fastfood_single_sample() {
        let x = array![[1.0, 2.0, 3.0, 4.0]];

        let transformed = fit_then_transform(FastfoodTransform::new(8).gamma(1.0), &x);

        assert_eq!(transformed.shape(), &[1, 8]);
    }

    /// One-feature input, and far more components than features.
    #[test]
    fn test_fastfood_edge_cases() {
        let x = array![[1.0], [2.0]];
        let transformed = fit_then_transform(FastfoodTransform::new(2).gamma(1.0), &x);
        assert_eq!(transformed.shape(), &[2, 2]);

        let x2 = array![[1.0, 2.0], [3.0, 4.0]];
        let transformed2 = fit_then_transform(FastfoodTransform::new(32).gamma(0.5), &x2);
        assert_eq!(transformed2.shape(), &[2, 32]);
    }
}