// sklears_kernel_approximation/structured_random_features.rs
1//! Structured orthogonal random features for efficient kernel approximation
2//!
3//! This module implements structured random feature methods that use structured
4//! random matrices (like Hadamard matrices) to reduce computational complexity
5//! while maintaining approximation quality. Also includes quasi-random and
6//! low-discrepancy sequence methods for improved feature distribution.
7
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
10use scirs2_core::random::rngs::StdRng as RealStdRng;
11use scirs2_core::random::Distribution;
12use scirs2_core::random::RngExt;
13use scirs2_core::random::{thread_rng, SeedableRng};
14use scirs2_core::StandardNormal;
15use sklears_core::{
16    error::{Result, SklearsError},
17    prelude::{Fit, Transform},
18    traits::{Estimator, Trained, Untrained},
19    types::Float,
20};
21use std::f64::consts::PI;
22use std::marker::PhantomData;
23
/// Type of structured matrix used for the orthogonal random-feature
/// transform; each variant trades randomness against structure differently.
#[derive(Debug, Clone)]
pub enum StructuredMatrix {
    /// Random ±1 sign matrix approximating a Walsh–Hadamard matrix.
    Hadamard,
    /// Orthonormal discrete cosine transform (DCT-II) matrix.
    DCT,
    /// Circulant matrix generated from one random Gaussian row.
    Circulant,
    /// Toeplitz matrix with random Gaussian diagonals.
    Toeplitz,
}
33
/// Type of quasi-random sequence to use for feature generation.
#[derive(Debug, Clone)]
pub enum QuasiRandomSequence {
    /// 1D base-2 radical-inverse (Van der Corput) sequence.
    VanDerCorput,
    /// Multi-dimensional sequence using a distinct prime base per dimension.
    Halton,
    /// Simplified Sobol-style sequence (see `LowDiscrepancySequences::sobol`).
    Sobol,
    /// Ordinary pseudo-random numbers (no low-discrepancy guarantee).
    PseudoRandom,
}
43
44/// Low-discrepancy sequence generators for quasi-random features
45pub struct LowDiscrepancySequences;
46
47impl LowDiscrepancySequences {
48    /// Generate Van der Corput sequence in base 2
49    /// This provides a 1D low-discrepancy sequence
50    pub fn van_der_corput(n: usize) -> Vec<Float> {
51        (0..n)
52            .map(|i| {
53                let mut value = 0.0;
54                let mut base_inv = 0.5;
55                let mut n = i + 1;
56
57                while n > 0 {
58                    if n % 2 == 1 {
59                        value += base_inv;
60                    }
61                    base_inv *= 0.5;
62                    n /= 2;
63                }
64                value
65            })
66            .collect()
67    }
68
69    /// Generate Halton sequence for multi-dimensional low-discrepancy
70    /// Uses prime bases for each dimension
71    pub fn halton(n: usize, dimensions: usize) -> Array2<Float> {
72        let primes = [
73            2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
74        ];
75
76        if dimensions > primes.len() {
77            panic!("Maximum supported dimensions: {}", primes.len());
78        }
79
80        let mut sequence = Array2::zeros((n, dimensions));
81
82        for dim in 0..dimensions {
83            let base = primes[dim] as Float;
84            for i in 0..n {
85                let mut value = 0.0;
86                let mut base_inv = 1.0 / base;
87                let mut n = i + 1;
88
89                while n > 0 {
90                    value += (n % primes[dim]) as Float * base_inv;
91                    base_inv /= base;
92                    n /= primes[dim];
93                }
94                sequence[[i, dim]] = value;
95            }
96        }
97
98        sequence
99    }
100
101    /// Generate simplified Sobol sequence for higher dimensions
102    /// This is a basic implementation focusing on the key properties
103    pub fn sobol(n: usize, dimensions: usize) -> Array2<Float> {
104        let mut sequence = Array2::zeros((n, dimensions));
105
106        // For the first dimension, use Van der Corput base 2
107        let first_dim = Self::van_der_corput(n);
108        for i in 0..n {
109            sequence[[i, 0]] = first_dim[i];
110        }
111
112        // For additional dimensions, use Van der Corput with different bases
113        let bases = [
114            3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73,
115        ];
116
117        for dim in 1..dimensions {
118            let base = if dim - 1 < bases.len() {
119                bases[dim - 1]
120            } else {
121                2 + dim
122            };
123            for i in 0..n {
124                let mut value = 0.0;
125                let mut base_inv = 1.0 / (base as Float);
126                let mut n = i + 1;
127
128                while n > 0 {
129                    value += (n % base) as Float * base_inv;
130                    base_inv /= base as Float;
131                    n /= base;
132                }
133                sequence[[i, dim]] = value;
134            }
135        }
136
137        sequence
138    }
139
140    /// Transform uniform low-discrepancy sequence to Gaussian using inverse normal CDF
141    /// Box-Muller-like transformation for quasi-random Gaussian variables
142    pub fn uniform_to_gaussian(uniform_sequence: &Array2<Float>) -> Array2<Float> {
143        let (n, dim) = uniform_sequence.dim();
144        let mut gaussian_sequence = Array2::zeros((n, dim));
145
146        for i in 0..n {
147            for j in 0..dim {
148                let u = uniform_sequence[[i, j]];
149                // Prevent extreme values
150                let u_clamped = u.max(1e-10).min(1.0 - 1e-10);
151
152                // Approximate inverse normal CDF using Beasley-Springer-Moro algorithm
153                let x = Self::inverse_normal_cdf(u_clamped);
154                gaussian_sequence[[i, j]] = x;
155            }
156        }
157
158        gaussian_sequence
159    }
160
161    /// Approximate inverse normal CDF using rational approximation
162    fn inverse_normal_cdf(u: Float) -> Float {
163        if u <= 0.0 || u >= 1.0 {
164            return if u <= 0.0 {
165                Float::NEG_INFINITY
166            } else {
167                Float::INFINITY
168            };
169        }
170
171        let u = if u > 0.5 { 1.0 - u } else { u };
172        let sign = if u == 1.0 - u { 1.0 } else { -1.0 };
173
174        let t = (-2.0 * u.ln()).sqrt();
175
176        // Coefficients for rational approximation
177        let c0 = 2.515517;
178        let c1 = 0.802853;
179        let c2 = 0.010328;
180        let d1 = 1.432788;
181        let d2 = 0.189269;
182        let d3 = 0.001308;
183
184        let numerator = c0 + c1 * t + c2 * t * t;
185        let denominator = 1.0 + d1 * t + d2 * t * t + d3 * t * t * t;
186
187        sign * (t - numerator / denominator)
188    }
189}
190
/// Structured orthogonal random features
///
/// Uses structured random matrices to approximate RBF kernels more efficiently
/// than standard random Fourier features. The structured approach reduces
/// computational complexity from O(d*D) to O(d*log(D)) for certain operations.
///
/// # Parameters
///
/// * `n_components` - Number of random features to generate
/// * `gamma` - RBF kernel parameter (default: 1.0)
/// * `structured_matrix` - Type of structured matrix to use
/// * `random_state` - Random seed for reproducibility
///
/// # Examples
///
/// ```text
/// use sklears_kernel_approximation::structured_random_features::{
///     StructuredRandomFeatures, StructuredMatrix,
/// };
///
/// let srf = StructuredRandomFeatures::new(100)
///     .gamma(0.5)
///     .structured_matrix(StructuredMatrix::Hadamard);
/// ```
#[derive(Debug, Clone)]
pub struct StructuredRandomFeatures<State = Untrained> {
    /// Number of random features to generate.
    pub n_components: usize,
    /// RBF kernel bandwidth parameter.
    pub gamma: Float,
    /// Which structured matrix family to apply before the random projection.
    pub structured_matrix: StructuredMatrix,
    /// Optional RNG seed for reproducible fits.
    pub random_state: Option<u64>,

    // Fitted parameters (populated by `fit`, `None` while untrained)
    random_weights_: Option<Array2<Float>>,
    random_offset_: Option<Array1<Float>>,
    structured_transform_: Option<Array2<Float>>,

    _state: PhantomData<State>,
}
229
230impl StructuredRandomFeatures<Untrained> {
231    /// Create a new structured random features transformer
232    pub fn new(n_components: usize) -> Self {
233        Self {
234            n_components,
235            gamma: 1.0,
236            structured_matrix: StructuredMatrix::Hadamard,
237            random_state: None,
238            random_weights_: None,
239            random_offset_: None,
240            structured_transform_: None,
241            _state: PhantomData,
242        }
243    }
244
245    /// Set the gamma parameter for RBF kernel
246    pub fn gamma(mut self, gamma: Float) -> Self {
247        self.gamma = gamma;
248        self
249    }
250
251    /// Set the structured matrix type
252    pub fn structured_matrix(mut self, matrix_type: StructuredMatrix) -> Self {
253        self.structured_matrix = matrix_type;
254        self
255    }
256
257    /// Set random state for reproducibility
258    pub fn random_state(mut self, seed: u64) -> Self {
259        self.random_state = Some(seed);
260        self
261    }
262}
263
impl Estimator for StructuredRandomFeatures<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    // This transformer carries no separate configuration object; the unit
    // value stands in for an empty config.
    fn config(&self) -> &Self::Config {
        &()
    }
}
273
274impl Fit<Array2<Float>, ()> for StructuredRandomFeatures<Untrained> {
275    type Fitted = StructuredRandomFeatures<Trained>;
276
277    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
278        let (_, n_features) = x.dim();
279
280        let mut rng = match self.random_state {
281            Some(seed) => RealStdRng::seed_from_u64(seed),
282            None => RealStdRng::from_seed(thread_rng().random()),
283        };
284
285        // Generate structured transform matrix
286        let structured_transform = self.generate_structured_matrix(n_features, &mut rng)?;
287
288        // Generate random weights for mixing
289        let normal =
290            RandNormal::new(0.0, (2.0 * self.gamma).sqrt()).expect("operation should succeed");
291        let random_weights =
292            Array2::from_shape_fn((n_features, self.n_components), |_| rng.sample(normal));
293
294        // Generate random offset for phase
295        let uniform = RandUniform::new(0.0, 2.0 * PI).expect("operation should succeed");
296        let random_offset = Array1::from_shape_fn(self.n_components, |_| rng.sample(uniform));
297
298        Ok(StructuredRandomFeatures {
299            n_components: self.n_components,
300            gamma: self.gamma,
301            structured_matrix: self.structured_matrix,
302            random_state: self.random_state,
303            random_weights_: Some(random_weights),
304            random_offset_: Some(random_offset),
305            structured_transform_: Some(structured_transform),
306            _state: PhantomData,
307        })
308    }
309}
310
311impl StructuredRandomFeatures<Untrained> {
312    /// Generate structured matrix based on the specified type
313    fn generate_structured_matrix(
314        &self,
315        n_features: usize,
316        rng: &mut RealStdRng,
317    ) -> Result<Array2<Float>> {
318        match &self.structured_matrix {
319            StructuredMatrix::Hadamard => self.generate_hadamard_matrix(n_features, rng),
320            StructuredMatrix::DCT => self.generate_dct_matrix(n_features),
321            StructuredMatrix::Circulant => self.generate_circulant_matrix(n_features, rng),
322            StructuredMatrix::Toeplitz => self.generate_toeplitz_matrix(n_features, rng),
323        }
324    }
325
326    /// Generate (approximate) Hadamard matrix
327    fn generate_hadamard_matrix(
328        &self,
329        n_features: usize,
330        rng: &mut RealStdRng,
331    ) -> Result<Array2<Float>> {
332        // For simplicity, generate a randomized orthogonal-like matrix
333        // True Hadamard matrices exist only for specific sizes
334        let mut matrix = Array2::zeros((n_features, n_features));
335
336        // Generate random signs for each entry
337        for i in 0..n_features {
338            for j in 0..n_features {
339                matrix[[i, j]] = if rng.random::<bool>() { 1.0 } else { -1.0 };
340            }
341        }
342
343        // Normalize to make approximately orthogonal
344        for mut row in matrix.rows_mut() {
345            let norm = (row.dot(&row) as Float).sqrt();
346            if norm > 1e-10 {
347                row /= norm;
348            }
349        }
350
351        Ok(matrix)
352    }
353
354    /// Generate Discrete Cosine Transform matrix
355    fn generate_dct_matrix(&self, n_features: usize) -> Result<Array2<Float>> {
356        let mut matrix = Array2::zeros((n_features, n_features));
357
358        for i in 0..n_features {
359            for j in 0..n_features {
360                let coeff = if i == 0 {
361                    (1.0 / (n_features as Float)).sqrt()
362                } else {
363                    (2.0 / (n_features as Float)).sqrt()
364                };
365
366                matrix[[i, j]] = coeff
367                    * ((PI * (i as Float) * (2.0 * (j as Float) + 1.0))
368                        / (2.0 * (n_features as Float)))
369                        .cos();
370            }
371        }
372
373        Ok(matrix)
374    }
375
376    /// Generate Circulant matrix
377    fn generate_circulant_matrix(
378        &self,
379        n_features: usize,
380        rng: &mut RealStdRng,
381    ) -> Result<Array2<Float>> {
382        let mut matrix = Array2::zeros((n_features, n_features));
383
384        // Generate first row randomly
385        let first_row: Vec<Float> = (0..n_features)
386            .map(|_| StandardNormal.sample(rng))
387            .collect();
388
389        // Fill circulant matrix
390        for i in 0..n_features {
391            for j in 0..n_features {
392                matrix[[i, j]] = first_row[(j + n_features - i) % n_features];
393            }
394        }
395
396        // Normalize rows
397        for mut row in matrix.rows_mut() {
398            let norm = (row.dot(&row) as Float).sqrt();
399            if norm > 1e-10 {
400                row /= norm;
401            }
402        }
403
404        Ok(matrix)
405    }
406
407    /// Generate Toeplitz matrix
408    fn generate_toeplitz_matrix(
409        &self,
410        n_features: usize,
411        rng: &mut RealStdRng,
412    ) -> Result<Array2<Float>> {
413        let mut matrix = Array2::zeros((n_features, n_features));
414
415        // Generate values for first row and first column
416        let first_row: Vec<Float> = (0..n_features)
417            .map(|_| StandardNormal.sample(rng))
418            .collect();
419        let first_col: Vec<Float> = (1..n_features)
420            .map(|_| StandardNormal.sample(rng))
421            .collect();
422
423        // Fill Toeplitz matrix
424        for i in 0..n_features {
425            for j in 0..n_features {
426                if i <= j {
427                    matrix[[i, j]] = first_row[j - i];
428                } else {
429                    matrix[[i, j]] = first_col[i - j - 1];
430                }
431            }
432        }
433
434        // Normalize rows
435        for mut row in matrix.rows_mut() {
436            let norm = (row.dot(&row) as Float).sqrt();
437            if norm > 1e-10 {
438                row /= norm;
439            }
440        }
441
442        Ok(matrix)
443    }
444}
445
446impl Transform<Array2<Float>> for StructuredRandomFeatures<Trained> {
447    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
448        let random_weights =
449            self.random_weights_
450                .as_ref()
451                .ok_or_else(|| SklearsError::NotFitted {
452                    operation: "transform".to_string(),
453                })?;
454
455        let random_offset =
456            self.random_offset_
457                .as_ref()
458                .ok_or_else(|| SklearsError::NotFitted {
459                    operation: "transform".to_string(),
460                })?;
461
462        let structured_transform =
463            self.structured_transform_
464                .as_ref()
465                .ok_or_else(|| SklearsError::NotFitted {
466                    operation: "transform".to_string(),
467                })?;
468
469        let (n_samples, _) = x.dim();
470
471        // Apply structured transformation first
472        let structured_x = x.dot(structured_transform);
473
474        // Apply random projections
475        let projected = structured_x.dot(random_weights);
476
477        // Add phase offsets and compute cosine features
478        let mut features = Array2::zeros((n_samples, self.n_components));
479
480        for i in 0..n_samples {
481            for j in 0..self.n_components {
482                let phase = projected[[i, j]] + random_offset[j];
483                features[[i, j]] = (2.0 / (self.n_components as Float)).sqrt() * phase.cos();
484            }
485        }
486
487        Ok(features)
488    }
489}
490
491/// Fast Walsh-Hadamard Transform for efficient structured features
492///
493/// This implements the Fast Walsh-Hadamard Transform (FWHT) which can be used
494/// to accelerate computations with Hadamard-structured random features.
495pub struct FastWalshHadamardTransform;
496
497impl FastWalshHadamardTransform {
498    /// Apply Fast Walsh-Hadamard Transform to input vector
499    ///
500    /// Time complexity: O(n log n) where n is the length of the input
501    /// Input length must be a power of 2
502    pub fn transform(mut data: Array1<Float>) -> Result<Array1<Float>> {
503        let n = data.len();
504
505        // Check if n is a power of 2
506        if n & (n - 1) != 0 {
507            return Err(SklearsError::InvalidInput(
508                "Input length must be a power of 2 for FWHT".to_string(),
509            ));
510        }
511
512        // Perform FWHT
513        let mut h = 1;
514        while h < n {
515            for i in (0..n).step_by(h * 2) {
516                for j in i..i + h {
517                    let u = data[j];
518                    let v = data[j + h];
519                    data[j] = u + v;
520                    data[j + h] = u - v;
521                }
522            }
523            h *= 2;
524        }
525
526        // Normalize
527        data /= (n as Float).sqrt();
528
529        Ok(data)
530    }
531
532    /// Apply FWHT to each row of a 2D array
533    pub fn transform_rows(mut data: Array2<Float>) -> Result<Array2<Float>> {
534        let (n_rows, n_cols) = data.dim();
535
536        // Check if n_cols is a power of 2
537        if n_cols & (n_cols - 1) != 0 {
538            return Err(SklearsError::InvalidInput(
539                "Number of columns must be a power of 2 for FWHT".to_string(),
540            ));
541        }
542
543        for i in 0..n_rows {
544            let row = data.row(i).to_owned();
545            let transformed_row = Self::transform(row)?;
546            data.row_mut(i).assign(&transformed_row);
547        }
548
549        Ok(data)
550    }
551}
552
/// Structured Random Features using Fast Walsh-Hadamard Transform
///
/// This is an optimized version that uses FWHT for better computational efficiency.
/// Input dimension must be a power of 2 for optimal performance.
#[derive(Debug, Clone)]
pub struct StructuredRFFHadamard<State = Untrained> {
    /// Number of random features
    pub n_components: usize,
    /// RBF kernel gamma parameter
    pub gamma: Float,
    /// Random seed
    pub random_state: Option<u64>,

    // Fitted parameters (populated by `fit`, `None` while untrained)
    random_signs_: Option<Array2<Float>>,
    random_offset_: Option<Array1<Float>>,
    gaussian_weights_: Option<Array1<Float>>,

    _state: PhantomData<State>,
}
574
575impl StructuredRFFHadamard<Untrained> {
576    /// Create a new structured RFF with Hadamard transforms
577    pub fn new(n_components: usize) -> Self {
578        Self {
579            n_components,
580            gamma: 1.0,
581            random_state: None,
582            random_signs_: None,
583            random_offset_: None,
584            gaussian_weights_: None,
585            _state: PhantomData,
586        }
587    }
588
589    /// Set gamma parameter
590    pub fn gamma(mut self, gamma: Float) -> Self {
591        self.gamma = gamma;
592        self
593    }
594
595    /// Set random state
596    pub fn random_state(mut self, seed: u64) -> Self {
597        self.random_state = Some(seed);
598        self
599    }
600}
601
impl Estimator for StructuredRFFHadamard<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    // This transformer carries no separate configuration object; the unit
    // value stands in for an empty config.
    fn config(&self) -> &Self::Config {
        &()
    }
}
611
612impl Fit<Array2<Float>, ()> for StructuredRFFHadamard<Untrained> {
613    type Fitted = StructuredRFFHadamard<Trained>;
614
615    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
616        let (_, n_features) = x.dim();
617
618        // Check if n_features is a power of 2
619        if n_features & (n_features - 1) != 0 {
620            return Err(SklearsError::InvalidInput(
621                "Number of features must be a power of 2 for structured Hadamard RFF".to_string(),
622            ));
623        }
624
625        let mut rng = match self.random_state {
626            Some(seed) => RealStdRng::seed_from_u64(seed),
627            None => RealStdRng::from_seed(thread_rng().random()),
628        };
629
630        // Generate random signs for Hadamard transforms
631        let mut random_signs = Array2::zeros((self.n_components, n_features));
632        for i in 0..self.n_components {
633            for j in 0..n_features {
634                random_signs[[i, j]] = if rng.random::<bool>() { 1.0 } else { -1.0 };
635            }
636        }
637
638        // Generate Gaussian weights
639        let normal = RandNormal::new(0.0, 1.0).expect("operation should succeed");
640        let gaussian_weights = Array1::from_shape_fn(n_features, |_| rng.sample(normal));
641
642        // Generate random offsets
643        let uniform = RandUniform::new(0.0, 2.0 * PI).expect("operation should succeed");
644        let random_offset = Array1::from_shape_fn(self.n_components, |_| rng.sample(uniform));
645
646        Ok(StructuredRFFHadamard {
647            n_components: self.n_components,
648            gamma: self.gamma,
649            random_state: self.random_state,
650            random_signs_: Some(random_signs),
651            random_offset_: Some(random_offset),
652            gaussian_weights_: Some(gaussian_weights),
653            _state: PhantomData,
654        })
655    }
656}
657
658impl Transform<Array2<Float>> for StructuredRFFHadamard<Trained> {
659    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
660        let random_signs = self
661            .random_signs_
662            .as_ref()
663            .ok_or_else(|| SklearsError::NotFitted {
664                operation: "transform".to_string(),
665            })?;
666
667        let random_offset =
668            self.random_offset_
669                .as_ref()
670                .ok_or_else(|| SklearsError::NotFitted {
671                    operation: "transform".to_string(),
672                })?;
673
674        let gaussian_weights =
675            self.gaussian_weights_
676                .as_ref()
677                .ok_or_else(|| SklearsError::NotFitted {
678                    operation: "transform".to_string(),
679                })?;
680
681        let (n_samples, n_features) = x.dim();
682        let mut features = Array2::zeros((n_samples, self.n_components));
683
684        // For each component, apply structured transformation
685        for comp in 0..self.n_components {
686            for sample in 0..n_samples {
687                // Element-wise multiplication with random signs
688                let mut signed_input = Array1::zeros(n_features);
689                for feat in 0..n_features {
690                    signed_input[feat] = x[[sample, feat]] * random_signs[[comp, feat]];
691                }
692
693                // Apply Fast Walsh-Hadamard Transform
694                let transformed = FastWalshHadamardTransform::transform(signed_input)?;
695
696                // Scale by Gaussian weights and gamma
697                let mut projected = 0.0;
698                for feat in 0..n_features {
699                    projected += transformed[feat] * gaussian_weights[feat];
700                }
701                projected *= (2.0 * self.gamma).sqrt();
702
703                // Add offset and compute cosine
704                let phase = projected + random_offset[comp];
705                features[[sample, comp]] =
706                    (2.0 / (self.n_components as Float)).sqrt() * phase.cos();
707            }
708        }
709
710        Ok(features)
711    }
712}
713
// Unit tests: shape checks, per-matrix-type smoke tests, FWHT behavior,
// error paths, reproducibility with a fixed seed, and gamma sensitivity.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Smoke test: fit + transform yield (n_samples, n_components).
    #[test]
    fn test_structured_random_features_basic() {
        let x = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0]
        ];

        let srf = StructuredRandomFeatures::new(8).gamma(0.5);
        let fitted = srf.fit(&x, &()).expect("operation should succeed");
        let transformed = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(transformed.shape(), &[3, 8]);
    }

    // Each structured-matrix variant should fit and produce the right shape.
    #[test]
    fn test_structured_matrices() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        // Test Hadamard
        let hadamard_srf =
            StructuredRandomFeatures::new(4).structured_matrix(StructuredMatrix::Hadamard);
        let hadamard_fitted = hadamard_srf.fit(&x, &()).expect("operation should succeed");
        let hadamard_result = hadamard_fitted
            .transform(&x)
            .expect("operation should succeed");
        assert_eq!(hadamard_result.shape(), &[2, 4]);

        // Test DCT
        let dct_srf = StructuredRandomFeatures::new(4).structured_matrix(StructuredMatrix::DCT);
        let dct_fitted = dct_srf.fit(&x, &()).expect("operation should succeed");
        let dct_result = dct_fitted.transform(&x).expect("operation should succeed");
        assert_eq!(dct_result.shape(), &[2, 4]);

        // Test Circulant
        let circulant_srf =
            StructuredRandomFeatures::new(4).structured_matrix(StructuredMatrix::Circulant);
        let circulant_fitted = circulant_srf
            .fit(&x, &())
            .expect("operation should succeed");
        let circulant_result = circulant_fitted
            .transform(&x)
            .expect("operation should succeed");
        assert_eq!(circulant_result.shape(), &[2, 4]);
    }

    // FWHT preserves length for 1D input and shape for row-wise 2D input.
    #[test]
    fn test_fast_walsh_hadamard_transform() {
        let data = array![1.0, 2.0, 3.0, 4.0];
        let result = FastWalshHadamardTransform::transform(data).expect("operation should succeed");
        assert_eq!(result.len(), 4);

        // Test with 2D array
        let data_2d = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];
        let result_2d =
            FastWalshHadamardTransform::transform_rows(data_2d).expect("operation should succeed");
        assert_eq!(result_2d.shape(), &[2, 4]);
    }

    // Hadamard-structured RFF smoke test (4 features = power of 2).
    #[test]
    fn test_structured_rff_hadamard() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        let srf_h = StructuredRFFHadamard::new(6).gamma(0.5);
        let fitted = srf_h.fit(&x, &()).expect("operation should succeed");
        let transformed = fitted.transform(&x).expect("operation should succeed");

        assert_eq!(transformed.shape(), &[2, 6]);
    }

    // Non-power-of-2 input must be rejected by the FWHT.
    #[test]
    fn test_fwht_invalid_size() {
        let data = array![1.0, 2.0, 3.0]; // Not a power of 2
        let result = FastWalshHadamardTransform::transform(data);
        assert!(result.is_err());
    }

    // Identical seeds must yield identical features.
    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        let srf1 = StructuredRandomFeatures::new(8).random_state(42);
        let fitted1 = srf1.fit(&x, &()).expect("operation should succeed");
        let result1 = fitted1.transform(&x).expect("operation should succeed");

        let srf2 = StructuredRandomFeatures::new(8).random_state(42);
        let fitted2 = srf2.fit(&x, &()).expect("operation should succeed");
        let result2 = fitted2.transform(&x).expect("operation should succeed");

        assert_eq!(result1.shape(), result2.shape());
        for i in 0..result1.len() {
            assert!(
                (result1.as_slice().expect("operation should succeed")[i]
                    - result2.as_slice().expect("operation should succeed")[i])
                    .abs()
                    < 1e-10
            );
        }
    }

    // Different gamma values should change the sampled weights and hence
    // produce measurably different features.
    #[test]
    fn test_different_gamma_values() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        let srf_low = StructuredRandomFeatures::new(4).gamma(0.1);
        let fitted_low = srf_low.fit(&x, &()).expect("operation should succeed");
        let result_low = fitted_low.transform(&x).expect("operation should succeed");

        let srf_high = StructuredRandomFeatures::new(4).gamma(10.0);
        let fitted_high = srf_high.fit(&x, &()).expect("operation should succeed");
        let result_high = fitted_high.transform(&x).expect("operation should succeed");

        assert_eq!(result_low.shape(), result_high.shape());
        // Results should be different with different gamma values
        let diff_sum: Float = result_low
            .iter()
            .zip(result_high.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff_sum > 1e-6);
    }
}