sklears_kernel_approximation/
fastfood.rs

1//! Fastfood Transform for efficient random feature approximation
2//!
3//! This module implements the Fastfood transform, which approximates
4//! Gaussian random projections using structured matrices, reducing
5//! computational complexity from O(d²) to O(d log d).
6
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::random::essentials::Uniform as RandUniform;
9use scirs2_core::random::rngs::StdRng as RealStdRng;
10use scirs2_core::random::seq::SliceRandom;
11use scirs2_core::random::Rng;
12use scirs2_core::random::{thread_rng, SeedableRng};
13use sklears_core::{
14    error::{Result, SklearsError},
15    prelude::{Fit, Transform},
16    traits::{Estimator, Trained, Untrained},
17    types::Float,
18};
19use std::f64::consts::PI;
20use std::marker::PhantomData;
21
22use crate::structured_random_features::FastWalshHadamardTransform;
23
/// Fastfood Transform for efficient random Fourier features
///
/// The Fastfood transform uses structured matrices to approximate Gaussian
/// random projections efficiently. It combines four structured building
/// blocks (the Hadamard transform is applied twice):
/// 1. Random diagonal scaling (B)
/// 2. Fast Walsh-Hadamard Transform (H)
/// 3. Random permutation (Π)
/// 4. Random diagonal scaling (G)
///
/// The overall transform is: G * H * Π * B * H, i.e. an input vector is
/// first passed through H, then scaled by B, permuted by Π, passed through H
/// again and finally scaled by G. In this implementation both diagonal
/// matrices hold random ±1 entries.
/// This reduces complexity from O(d²) to O(d log d) while maintaining
/// approximation quality for RBF kernels.
///
/// # Parameters
///
/// * `n_components` - Number of random features to generate
/// * `gamma` - RBF kernel parameter (default: 1.0)
/// * `random_state` - Random seed for reproducibility; when unset, fitting
///   seeds its generator from the thread-local RNG
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::fastfood::FastfoodTransform;
/// use sklears_core::traits::{Transform, Fit, Untrained};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];
///
/// let fastfood = FastfoodTransform::new(8).gamma(0.5);
/// let fitted = fastfood.fit(&X, &()).unwrap();
/// let X_transformed = fitted.transform(&X).unwrap();
/// assert_eq!(X_transformed.shape(), &[2, 8]);
/// ```
#[derive(Debug, Clone)]
pub struct FastfoodTransform<State = Untrained> {
    /// Number of random features
    pub n_components: usize,
    /// RBF kernel gamma parameter
    pub gamma: Float,
    /// Random seed
    pub random_state: Option<u64>,

    // Fitted parameters (all `None` until `fit` has run)
    // The Fastfood transform consists of: G * H * Π * B * H
    scaling_b_: Option<Array1<Float>>,     // First diagonal scaling, one entry per padded slot of every block
    permutation_: Option<Array1<usize>>,   // Random permutation indices, laid out block after block
    scaling_g_: Option<Array1<Float>>,     // Second diagonal scaling, same layout as `scaling_b_`
    random_offset_: Option<Array1<Float>>, // Phase offsets, one per output component
    padded_dim_: Option<usize>,            // Padded dimension (power of 2)
    n_blocks_: Option<usize>,              // Number of Fastfood blocks

    // Zero-sized marker distinguishing the `Untrained` and `Trained` states
    _state: PhantomData<State>,
}
78
79impl FastfoodTransform<Untrained> {
80    /// Create a new Fastfood transform
81    pub fn new(n_components: usize) -> Self {
82        Self {
83            n_components,
84            gamma: 1.0,
85            random_state: None,
86            scaling_b_: None,
87            permutation_: None,
88            scaling_g_: None,
89            random_offset_: None,
90            padded_dim_: None,
91            n_blocks_: None,
92            _state: PhantomData,
93        }
94    }
95
96    /// Set the gamma parameter for RBF kernel
97    pub fn gamma(mut self, gamma: Float) -> Self {
98        self.gamma = gamma;
99        self
100    }
101
102    /// Set random state for reproducibility
103    pub fn random_state(mut self, seed: u64) -> Self {
104        self.random_state = Some(seed);
105        self
106    }
107}
108
109impl Estimator for FastfoodTransform<Untrained> {
110    type Config = ();
111    type Error = SklearsError;
112    type Float = Float;
113
114    fn config(&self) -> &Self::Config {
115        &()
116    }
117}
118
119impl Fit<Array2<Float>, ()> for FastfoodTransform<Untrained> {
120    type Fitted = FastfoodTransform<Trained>;
121
122    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
123        let (_, n_features) = x.dim();
124
125        let mut rng = match self.random_state {
126            Some(seed) => RealStdRng::seed_from_u64(seed),
127            None => RealStdRng::from_seed(thread_rng().gen()),
128        };
129
130        // Find the smallest power of 2 that is >= n_features
131        let padded_dim = next_power_of_2(n_features);
132
133        // Number of Fastfood blocks needed
134        let n_blocks = (self.n_components + padded_dim - 1) / padded_dim;
135
136        // Generate random diagonal scaling matrices B and G
137        let scaling_b = self.generate_random_scaling(padded_dim * n_blocks, &mut rng);
138        let scaling_g = self.generate_random_scaling(padded_dim * n_blocks, &mut rng);
139
140        // Generate random permutations
141        let permutation = self.generate_random_permutation(padded_dim * n_blocks, &mut rng);
142
143        // Generate random phase offsets
144        let uniform = RandUniform::new(0.0, 2.0 * PI).unwrap();
145        let random_offset = Array1::from_shape_fn(self.n_components, |_| rng.sample(uniform));
146
147        Ok(FastfoodTransform {
148            n_components: self.n_components,
149            gamma: self.gamma,
150            random_state: self.random_state,
151            scaling_b_: Some(scaling_b),
152            permutation_: Some(permutation),
153            scaling_g_: Some(scaling_g),
154            random_offset_: Some(random_offset),
155            padded_dim_: Some(padded_dim),
156            n_blocks_: Some(n_blocks),
157            _state: PhantomData,
158        })
159    }
160}
161
162impl FastfoodTransform<Untrained> {
163    /// Generate random diagonal scaling with Rademacher distribution
164    fn generate_random_scaling(&self, size: usize, rng: &mut RealStdRng) -> Array1<Float> {
165        let mut scaling = Array1::zeros(size);
166        for i in 0..size {
167            scaling[i] = if rng.gen::<bool>() { 1.0 } else { -1.0 };
168        }
169        scaling
170    }
171
172    /// Generate random permutation
173    fn generate_random_permutation(&self, size: usize, rng: &mut RealStdRng) -> Array1<usize> {
174        let mut permutation: Vec<usize> = (0..size).collect();
175        permutation.shuffle(rng);
176        Array1::from_vec(permutation)
177    }
178}
179
180impl Transform<Array2<Float>> for FastfoodTransform<Trained> {
181    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
182        let scaling_b = self
183            .scaling_b_
184            .as_ref()
185            .ok_or_else(|| SklearsError::NotFitted {
186                operation: "transform".to_string(),
187            })?;
188
189        let permutation = self
190            .permutation_
191            .as_ref()
192            .ok_or_else(|| SklearsError::NotFitted {
193                operation: "transform".to_string(),
194            })?;
195
196        let scaling_g = self
197            .scaling_g_
198            .as_ref()
199            .ok_or_else(|| SklearsError::NotFitted {
200                operation: "transform".to_string(),
201            })?;
202
203        let random_offset =
204            self.random_offset_
205                .as_ref()
206                .ok_or_else(|| SklearsError::NotFitted {
207                    operation: "transform".to_string(),
208                })?;
209
210        let padded_dim = *self
211            .padded_dim_
212            .as_ref()
213            .ok_or_else(|| SklearsError::NotFitted {
214                operation: "transform".to_string(),
215            })?;
216
217        let n_blocks = *self
218            .n_blocks_
219            .as_ref()
220            .ok_or_else(|| SklearsError::NotFitted {
221                operation: "transform".to_string(),
222            })?;
223
224        let (n_samples, n_features) = x.dim();
225        let mut features = Array2::zeros((n_samples, self.n_components));
226
227        // Process each sample
228        for sample_idx in 0..n_samples {
229            let sample = x.row(sample_idx);
230
231            // Apply Fastfood transform: G * H * Π * B * H * x
232            let transformed_sample = self.apply_fastfood_transform(
233                &sample,
234                scaling_b,
235                permutation,
236                scaling_g,
237                padded_dim,
238                n_blocks,
239                n_features,
240            )?;
241
242            // Take first n_components and add phase offsets
243            for j in 0..(self.n_components.min(transformed_sample.len())) {
244                let phase = transformed_sample[j] * (2.0 * self.gamma).sqrt() + random_offset[j];
245                features[[sample_idx, j]] =
246                    (2.0 / (self.n_components as Float)).sqrt() * phase.cos();
247            }
248        }
249
250        Ok(features)
251    }
252}
253
254impl FastfoodTransform<Trained> {
255    /// Apply the full Fastfood transform: G * H * Π * B * H * x
256    fn apply_fastfood_transform(
257        &self,
258        x: &scirs2_core::ndarray::ArrayBase<
259            scirs2_core::ndarray::ViewRepr<&Float>,
260            scirs2_core::ndarray::Dim<[usize; 1]>,
261        >,
262        scaling_b: &Array1<Float>,
263        permutation: &Array1<usize>,
264        scaling_g: &Array1<Float>,
265        padded_dim: usize,
266        n_blocks: usize,
267        n_features: usize,
268    ) -> Result<Array1<Float>> {
269        let mut result = Array1::zeros(padded_dim * n_blocks);
270
271        // Process each Fastfood block
272        for block in 0..n_blocks {
273            let block_start = block * padded_dim;
274            let _block_end = block_start + padded_dim;
275
276            // Step 1: Pad input to power of 2
277            let mut padded_input = Array1::zeros(padded_dim);
278            for i in 0..n_features.min(padded_dim) {
279                padded_input[i] = x[i];
280            }
281
282            // Step 2: First Hadamard transform (H)
283            let mut transformed = FastWalshHadamardTransform::transform(padded_input)?;
284
285            // Step 3: Apply first diagonal scaling (B)
286            for i in 0..padded_dim {
287                transformed[i] *= scaling_b[block_start + i];
288            }
289
290            // Step 4: Apply permutation (Π)
291            let mut permuted = Array1::zeros(padded_dim);
292            for i in 0..padded_dim {
293                let perm_idx = permutation[block_start + i] % padded_dim;
294                permuted[i] = transformed[perm_idx];
295            }
296
297            // Step 5: Second Hadamard transform (H)
298            transformed = FastWalshHadamardTransform::transform(permuted)?;
299
300            // Step 6: Apply second diagonal scaling (G)
301            for i in 0..padded_dim {
302                transformed[i] *= scaling_g[block_start + i];
303            }
304
305            // Store result for this block
306            for i in 0..padded_dim {
307                result[block_start + i] = transformed[i];
308            }
309        }
310
311        Ok(result)
312    }
313}
314
/// Find the next power of 2 greater than or equal to n
///
/// Inputs of `0` and `1` both yield `1`.
///
/// # Panics
///
/// Panics when no power of two greater than or equal to `n` fits in a
/// `usize` (that is, `n > 2^(usize::BITS - 1)`). A manual doubling loop would
/// wrap to zero in release builds for such inputs and never terminate.
fn next_power_of_2(n: usize) -> usize {
    n.checked_next_power_of_two()
        .expect("next_power_of_2: no power of two >= n is representable as usize")
}
326
/// Fastfood kernel approximation for multiple kernels
///
/// This variant is parameterised by a [`FastfoodKernelParams`] value so that
/// different kernels can be selected through the scaling and normalization
/// factors. Fitting currently accepts only the RBF variant and returns an
/// error for the others; the RBF case is delegated to a single
/// [`FastfoodTransform`].
#[derive(Debug, Clone)]
pub struct FastfoodKernel<State = Untrained> {
    /// Number of random features
    pub n_components: usize,
    /// Kernel parameters
    pub kernel_params: FastfoodKernelParams,
    /// Random seed
    pub random_state: Option<u64>,

    // Fitted parameters: the trained transforms (`None` until `fit` has run;
    // `fit` stores exactly one element)
    fastfood_transforms_: Option<Vec<FastfoodTransform<Trained>>>,

    // Zero-sized marker distinguishing the `Untrained` and `Trained` states
    _state: PhantomData<State>,
}
346
/// Kernel parameters for Fastfood approximation
///
/// Selects which kernel a `FastfoodKernel` should approximate. Only the
/// `Rbf` variant is accepted by `FastfoodKernel::fit` at present; the other
/// variants make it return an `InvalidInput` error.
#[derive(Debug, Clone)]
pub enum FastfoodKernelParams {
    /// RBF kernel with gamma parameter
    Rbf { gamma: Float },
    /// Matern kernel with nu and length_scale parameters (not yet supported by `fit`)
    Matern { nu: Float, length_scale: Float },
    /// Rational quadratic kernel with alpha and length_scale (not yet supported by `fit`)
    RationalQuadratic { alpha: Float, length_scale: Float },
}
358
359impl FastfoodKernel<Untrained> {
360    /// Create a new Fastfood kernel approximation
361    pub fn new(n_components: usize, kernel_params: FastfoodKernelParams) -> Self {
362        Self {
363            n_components,
364            kernel_params,
365            random_state: None,
366            fastfood_transforms_: None,
367            _state: PhantomData,
368        }
369    }
370
371    /// Set random state
372    pub fn random_state(mut self, seed: u64) -> Self {
373        self.random_state = Some(seed);
374        self
375    }
376}
377
378impl Estimator for FastfoodKernel<Untrained> {
379    type Config = ();
380    type Error = SklearsError;
381    type Float = Float;
382
383    fn config(&self) -> &Self::Config {
384        &()
385    }
386}
387
388impl Fit<Array2<Float>, ()> for FastfoodKernel<Untrained> {
389    type Fitted = FastfoodKernel<Trained>;
390
391    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
392        // For now, implement RBF kernel case
393        let gamma = match &self.kernel_params {
394            FastfoodKernelParams::Rbf { gamma } => *gamma,
395            _ => {
396                return Err(SklearsError::InvalidInput(
397                    "Only RBF kernel is currently supported for FastfoodKernel".to_string(),
398                ))
399            }
400        };
401
402        let fastfood = FastfoodTransform::new(self.n_components).gamma(gamma);
403        let fastfood = match self.random_state {
404            Some(seed) => fastfood.random_state(seed),
405            None => fastfood,
406        };
407
408        let fitted_fastfood = fastfood.fit(x, &())?;
409        let transforms = vec![fitted_fastfood];
410
411        Ok(FastfoodKernel {
412            n_components: self.n_components,
413            kernel_params: self.kernel_params,
414            random_state: self.random_state,
415            fastfood_transforms_: Some(transforms),
416            _state: PhantomData,
417        })
418    }
419}
420
421impl Transform<Array2<Float>> for FastfoodKernel<Trained> {
422    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
423        let transforms =
424            self.fastfood_transforms_
425                .as_ref()
426                .ok_or_else(|| SklearsError::NotFitted {
427                    operation: "transform".to_string(),
428                })?;
429
430        // For now, use the first (and only) transform
431        transforms[0].transform(x)
432    }
433}
434
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Exact powers of two map to themselves, everything else rounds up.
    #[test]
    fn test_next_power_of_2() {
        let expectations = [(1, 1), (2, 2), (3, 4), (7, 8), (8, 8), (15, 16), (16, 16)];
        for (input, expected) in expectations {
            assert_eq!(next_power_of_2(input), expected);
        }
    }

    /// Three features (padded to four) expanded to eight components.
    #[test]
    fn test_fastfood_transform_basic() {
        let data = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]];

        let model = FastfoodTransform::new(8).gamma(0.5).fit(&data, &()).unwrap();
        let output = model.transform(&data).unwrap();

        assert_eq!(output.dim(), (3, 8));
    }

    /// Feature count already a power of two, single block.
    #[test]
    fn test_fastfood_transform_power_of_2() {
        let data = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        let model = FastfoodTransform::new(4).gamma(1.0).fit(&data, &()).unwrap();
        let output = model.transform(&data).unwrap();

        assert_eq!(output.dim(), (2, 4));
    }

    /// The kernel wrapper with RBF parameters yields the requested width.
    #[test]
    fn test_fastfood_kernel_rbf() {
        let data = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];

        let params = FastfoodKernelParams::Rbf { gamma: 0.5 };
        let model = FastfoodKernel::new(6, params).fit(&data, &()).unwrap();
        let output = model.transform(&data).unwrap();

        assert_eq!(output.dim(), (2, 6));
    }

    /// Two fits with the same seed must agree element by element.
    #[test]
    fn test_fastfood_reproducibility() {
        let data = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];

        let first = FastfoodTransform::new(8)
            .random_state(42)
            .fit(&data, &())
            .unwrap()
            .transform(&data)
            .unwrap();

        let second = FastfoodTransform::new(8)
            .random_state(42)
            .fit(&data, &())
            .unwrap()
            .transform(&data)
            .unwrap();

        assert_eq!(first.dim(), second.dim());
        first
            .iter()
            .zip(second.iter())
            .for_each(|(lhs, rhs)| assert!((lhs - rhs).abs() < 1e-10));
    }

    /// Changing gamma must change the produced features.
    #[test]
    fn test_fastfood_different_gamma() {
        let data = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];

        let low = FastfoodTransform::new(4)
            .gamma(0.1)
            .fit(&data, &())
            .unwrap()
            .transform(&data)
            .unwrap();

        let high = FastfoodTransform::new(4)
            .gamma(10.0)
            .fit(&data, &())
            .unwrap()
            .transform(&data)
            .unwrap();

        assert_eq!(low.dim(), high.dim());

        // Results should be different with different gamma values
        let mut total_difference: Float = 0.0;
        for (lhs, rhs) in low.iter().zip(high.iter()) {
            total_difference += (lhs - rhs).abs();
        }
        assert!(total_difference > 1e-6);
    }

    /// Ten features are padded to sixteen; sixteen components requested.
    #[test]
    fn test_fastfood_large_dimensions() {
        let data = array![
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
            [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0]
        ];

        let model = FastfoodTransform::new(16).gamma(0.1).fit(&data, &()).unwrap();
        let output = model.transform(&data).unwrap();

        assert_eq!(output.dim(), (2, 16));
    }

    /// A single-row input is handled like any other batch.
    #[test]
    fn test_fastfood_single_sample() {
        let data = array![[1.0, 2.0, 3.0, 4.0]];

        let model = FastfoodTransform::new(8).gamma(1.0).fit(&data, &()).unwrap();
        let output = model.transform(&data).unwrap();

        assert_eq!(output.dim(), (1, 8));
    }

    /// Smallest possible feature count, and far more components than features.
    #[test]
    fn test_fastfood_edge_cases() {
        // Test with minimal dimensions
        let narrow = array![[1.0], [2.0]];
        let narrow_model = FastfoodTransform::new(2).gamma(1.0).fit(&narrow, &()).unwrap();
        let narrow_output = narrow_model.transform(&narrow).unwrap();
        assert_eq!(narrow_output.dim(), (2, 2));

        // Test with many components
        let small = array![[1.0, 2.0], [3.0, 4.0]];
        let wide_model = FastfoodTransform::new(32).gamma(0.5).fit(&small, &()).unwrap();
        let wide_output = wide_model.transform(&small).unwrap();
        assert_eq!(wide_output.dim(), (2, 32));
    }
}