sklears_kernel_approximation/
structured_random_features.rs

//! Structured orthogonal random features for efficient kernel approximation
//!
//! This module implements structured random feature methods that use structured
//! random matrices (like Hadamard matrices) to reduce computational complexity
//! while maintaining approximation quality. Also includes quasi-random and
//! low-discrepancy sequence methods for improved feature distribution.

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
use scirs2_core::random::rngs::StdRng as RealStdRng;
use scirs2_core::random::Distribution;
use scirs2_core::random::Rng;
use scirs2_core::random::{thread_rng, SeedableRng};
use scirs2_core::StandardNormal;
use sklears_core::{
    error::{Result, SklearsError},
    prelude::{Fit, Transform},
    traits::{Estimator, Trained, Untrained},
    types::Float,
};
use std::f64::consts::PI;
use std::marker::PhantomData;

/// Type of structured matrix to use
#[derive(Debug, Clone)]
pub enum StructuredMatrix {
    /// Randomized Hadamard-style sign matrix
    Hadamard,
    /// Discrete cosine transform (DCT-II) matrix
    DCT,
    /// Circulant matrix built from a random first row
    Circulant,
    /// Toeplitz matrix built from a random first row and column
    Toeplitz,
}

/// Type of quasi-random sequence to use for feature generation
#[derive(Debug, Clone)]
pub enum QuasiRandomSequence {
    /// Van der Corput sequence in base 2
    VanDerCorput,
    /// Halton sequence with prime bases
    Halton,
    /// Simplified Sobol-style sequence
    Sobol,
    /// Standard pseudo-random numbers
    PseudoRandom,
}

/// Low-discrepancy sequence generators for quasi-random features
pub struct LowDiscrepancySequences;

impl LowDiscrepancySequences {
    /// Generate Van der Corput sequence in base 2
    /// This provides a 1D low-discrepancy sequence
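    ///
    /// Illustrative sketch of the expected output (not compiled as a doctest):
    ///
    /// ```rust,ignore
    /// // Indices 1..=4 in base 2 give 0.5, 0.25, 0.75, 0.125
    /// let seq = LowDiscrepancySequences::van_der_corput(4);
    /// assert!((seq[0] - 0.5).abs() < 1e-12);
    /// assert!((seq[3] - 0.125).abs() < 1e-12);
    /// ```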
    pub fn van_der_corput(n: usize) -> Vec<Float> {
        (0..n)
            .map(|i| {
                let mut value = 0.0;
                let mut base_inv = 0.5;
                let mut n = i + 1;

                while n > 0 {
                    if n % 2 == 1 {
                        value += base_inv;
                    }
                    base_inv *= 0.5;
                    n /= 2;
                }
                value
            })
            .collect()
    }

    /// Generate Halton sequence for multi-dimensional low-discrepancy
    /// Uses prime bases for each dimension
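    ///
    /// Minimal usage sketch (illustrative, not compiled as a doctest):
    ///
    /// ```rust,ignore
    /// // 100 points in 3 dimensions, using bases 2, 3 and 5
    /// let points = LowDiscrepancySequences::halton(100, 3);
    /// assert_eq!(points.dim(), (100, 3));
    /// assert!(points.iter().all(|&v| v > 0.0 && v < 1.0));
    /// ```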
    pub fn halton(n: usize, dimensions: usize) -> Array2<Float> {
        let primes = [
            2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
        ];

        if dimensions > primes.len() {
            panic!("Maximum supported dimensions: {}", primes.len());
        }

        let mut sequence = Array2::zeros((n, dimensions));

        for dim in 0..dimensions {
            let base = primes[dim] as Float;
            for i in 0..n {
                let mut value = 0.0;
                let mut base_inv = 1.0 / base;
                let mut n = i + 1;

                while n > 0 {
                    value += (n % primes[dim]) as Float * base_inv;
                    base_inv /= base;
                    n /= primes[dim];
                }
                sequence[[i, dim]] = value;
            }
        }

        sequence
    }

    /// Generate a simplified Sobol-style sequence for higher dimensions
    /// This basic implementation uses radical-inverse (Van der Corput) sequences
    /// in distinct bases per dimension rather than true Sobol direction numbers
    pub fn sobol(n: usize, dimensions: usize) -> Array2<Float> {
        let mut sequence = Array2::zeros((n, dimensions));

        // For the first dimension, use Van der Corput base 2
        let first_dim = Self::van_der_corput(n);
        for i in 0..n {
            sequence[[i, 0]] = first_dim[i];
        }

        // For additional dimensions, use Van der Corput with different bases
        let bases = [
            3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73,
        ];

        for dim in 1..dimensions {
            let base = if dim - 1 < bases.len() {
                bases[dim - 1]
            } else {
                2 + dim
            };
            for i in 0..n {
                let mut value = 0.0;
                let mut base_inv = 1.0 / (base as Float);
                let mut n = i + 1;

                while n > 0 {
                    value += (n % base) as Float * base_inv;
                    base_inv /= base as Float;
                    n /= base;
                }
                sequence[[i, dim]] = value;
            }
        }

        sequence
    }

    /// Transform a uniform low-discrepancy sequence to Gaussian variables using
    /// inverse transform sampling with an approximate inverse normal CDF
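    ///
    /// Sketch of chaining a Halton sequence into quasi-random Gaussian draws
    /// (illustrative, not compiled as a doctest):
    ///
    /// ```rust,ignore
    /// let uniform = LowDiscrepancySequences::halton(256, 4);
    /// let gaussian = LowDiscrepancySequences::uniform_to_gaussian(&uniform);
    /// assert_eq!(gaussian.dim(), (256, 4));
    /// ```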
    pub fn uniform_to_gaussian(uniform_sequence: &Array2<Float>) -> Array2<Float> {
        let (n, dim) = uniform_sequence.dim();
        let mut gaussian_sequence = Array2::zeros((n, dim));

        for i in 0..n {
            for j in 0..dim {
                let u = uniform_sequence[[i, j]];
                // Prevent extreme values
                let u_clamped = u.clamp(1e-10, 1.0 - 1e-10);

                // Map to a Gaussian via the approximate inverse normal CDF
                let x = Self::inverse_normal_cdf(u_clamped);
                gaussian_sequence[[i, j]] = x;
            }
        }

        gaussian_sequence
    }

    /// Approximate inverse normal CDF using a rational approximation (Abramowitz-Stegun 26.2.23)
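    ///
    /// Quick sanity check (illustrative; this is a private helper, so the
    /// snippet is not compiled as a doctest):
    ///
    /// ```rust,ignore
    /// // Roughly the standard normal 97.5% quantile; the approximation error is about 4.5e-4
    /// assert!((LowDiscrepancySequences::inverse_normal_cdf(0.975) - 1.96).abs() < 1e-2);
    /// ```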
    fn inverse_normal_cdf(u: Float) -> Float {
        if u <= 0.0 || u >= 1.0 {
            return if u <= 0.0 {
                Float::NEG_INFINITY
            } else {
                Float::INFINITY
            };
        }

        // Work with the tail probability p = min(u, 1 - u); the result is
        // positive for u > 0.5 and negative for u < 0.5.
        let sign = if u > 0.5 { 1.0 } else { -1.0 };
        let p = if u > 0.5 { 1.0 - u } else { u };

        let t = (-2.0 * p.ln()).sqrt();

        // Coefficients for rational approximation
        let c0 = 2.515517;
        let c1 = 0.802853;
        let c2 = 0.010328;
        let d1 = 1.432788;
        let d2 = 0.189269;
        let d3 = 0.001308;

        let numerator = c0 + c1 * t + c2 * t * t;
        let denominator = 1.0 + d1 * t + d2 * t * t + d3 * t * t * t;

        sign * (t - numerator / denominator)
    }
}

/// Structured orthogonal random features
///
/// Uses structured random matrices to approximate RBF kernels more efficiently
/// than standard random Fourier features. The structured approach reduces
/// computational complexity from O(d*D) to O(d*log(D)) for certain operations.
///
/// # Parameters
///
/// * `n_components` - Number of random features to generate
/// * `gamma` - RBF kernel parameter (default: 1.0)
/// * `structured_matrix` - Type of structured matrix to use
/// * `random_state` - Random seed for reproducibility
///
/// # Examples
///
/// ```rust,ignore
/// use scirs2_core::ndarray::array;
/// use sklears_core::prelude::{Fit, Transform};
/// use sklears_kernel_approximation::structured_random_features::{
///     StructuredMatrix, StructuredRandomFeatures,
/// };
///
/// let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];
/// let srf = StructuredRandomFeatures::new(8)
///     .gamma(0.5)
///     .structured_matrix(StructuredMatrix::Hadamard)
///     .random_state(42);
/// let features = srf.fit(&x, &()).unwrap().transform(&x).unwrap();
/// assert_eq!(features.shape(), &[2, 8]);
/// ```
#[derive(Debug, Clone)]
pub struct StructuredRandomFeatures<State = Untrained> {
    pub n_components: usize,
    pub gamma: Float,
    pub structured_matrix: StructuredMatrix,
    pub random_state: Option<u64>,

    // Fitted parameters
    random_weights_: Option<Array2<Float>>,
    random_offset_: Option<Array1<Float>>,
    structured_transform_: Option<Array2<Float>>,

    _state: PhantomData<State>,
}

impl StructuredRandomFeatures<Untrained> {
    /// Create a new structured random features transformer
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            gamma: 1.0,
            structured_matrix: StructuredMatrix::Hadamard,
            random_state: None,
            random_weights_: None,
            random_offset_: None,
            structured_transform_: None,
            _state: PhantomData,
        }
    }

    /// Set the gamma parameter for RBF kernel
    pub fn gamma(mut self, gamma: Float) -> Self {
        self.gamma = gamma;
        self
    }

    /// Set the structured matrix type
    pub fn structured_matrix(mut self, matrix_type: StructuredMatrix) -> Self {
        self.structured_matrix = matrix_type;
        self
    }

    /// Set random state for reproducibility
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl Estimator for StructuredRandomFeatures<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, ()> for StructuredRandomFeatures<Untrained> {
    type Fitted = StructuredRandomFeatures<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        let (_, n_features) = x.dim();

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        // Generate structured transform matrix
        let structured_transform = self.generate_structured_matrix(n_features, &mut rng)?;

        // Generate random weights for mixing
        let normal = RandNormal::new(0.0, (2.0 * self.gamma).sqrt()).unwrap();
        let random_weights =
            Array2::from_shape_fn((n_features, self.n_components), |_| rng.sample(normal));

        // Generate random offset for phase
        let uniform = RandUniform::new(0.0, 2.0 * PI).unwrap();
        let random_offset = Array1::from_shape_fn(self.n_components, |_| rng.sample(uniform));

        Ok(StructuredRandomFeatures {
            n_components: self.n_components,
            gamma: self.gamma,
            structured_matrix: self.structured_matrix,
            random_state: self.random_state,
            random_weights_: Some(random_weights),
            random_offset_: Some(random_offset),
            structured_transform_: Some(structured_transform),
            _state: PhantomData,
        })
    }
}

impl StructuredRandomFeatures<Untrained> {
    /// Generate structured matrix based on the specified type
    fn generate_structured_matrix(
        &self,
        n_features: usize,
        rng: &mut RealStdRng,
    ) -> Result<Array2<Float>> {
        match &self.structured_matrix {
            StructuredMatrix::Hadamard => self.generate_hadamard_matrix(n_features, rng),
            StructuredMatrix::DCT => self.generate_dct_matrix(n_features),
            StructuredMatrix::Circulant => self.generate_circulant_matrix(n_features, rng),
            StructuredMatrix::Toeplitz => self.generate_toeplitz_matrix(n_features, rng),
        }
    }

    /// Generate (approximate) Hadamard matrix
    fn generate_hadamard_matrix(
        &self,
        n_features: usize,
        rng: &mut RealStdRng,
    ) -> Result<Array2<Float>> {
        // For simplicity, generate a randomized orthogonal-like matrix
        // True Hadamard matrices exist only for specific sizes
        let mut matrix = Array2::zeros((n_features, n_features));

        // Generate random signs for each entry
        for i in 0..n_features {
            for j in 0..n_features {
                matrix[[i, j]] = if rng.gen::<bool>() { 1.0 } else { -1.0 };
            }
        }

        // Normalize to make approximately orthogonal
        for mut row in matrix.rows_mut() {
            let norm = (row.dot(&row) as Float).sqrt();
            if norm > 1e-10 {
                row /= norm;
            }
        }

        Ok(matrix)
    }

    /// Generate Discrete Cosine Transform matrix
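    ///
    /// The result is the orthonormal DCT-II basis,
    /// `M[i][j] = c_i * cos(PI * i * (2j + 1) / (2n))` with `c_0 = sqrt(1/n)` and
    /// `c_i = sqrt(2/n)` otherwise, so its rows are orthonormal by construction.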
    fn generate_dct_matrix(&self, n_features: usize) -> Result<Array2<Float>> {
        let mut matrix = Array2::zeros((n_features, n_features));

        for i in 0..n_features {
            for j in 0..n_features {
                let coeff = if i == 0 {
                    (1.0 / (n_features as Float)).sqrt()
                } else {
                    (2.0 / (n_features as Float)).sqrt()
                };

                matrix[[i, j]] = coeff
                    * ((PI * (i as Float) * (2.0 * (j as Float) + 1.0))
                        / (2.0 * (n_features as Float)))
                        .cos();
            }
        }

        Ok(matrix)
    }

    /// Generate Circulant matrix
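    ///
    /// Each row is a cyclic shift of a random Gaussian first row, and rows are
    /// normalized to unit length afterwards.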
    fn generate_circulant_matrix(
        &self,
        n_features: usize,
        rng: &mut RealStdRng,
    ) -> Result<Array2<Float>> {
        let mut matrix = Array2::zeros((n_features, n_features));

        // Generate first row randomly
        let first_row: Vec<Float> = (0..n_features)
            .map(|_| StandardNormal.sample(rng))
            .collect();

        // Fill circulant matrix
        for i in 0..n_features {
            for j in 0..n_features {
                matrix[[i, j]] = first_row[(j + n_features - i) % n_features];
            }
        }

        // Normalize rows
        for mut row in matrix.rows_mut() {
            let norm = (row.dot(&row) as Float).sqrt();
            if norm > 1e-10 {
                row /= norm;
            }
        }

        Ok(matrix)
    }

    /// Generate Toeplitz matrix
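    ///
    /// Entries are constant along each diagonal: entry `(i, j)` depends only on
    /// `j - i`, drawn from random Gaussian first-row/first-column values, with
    /// rows normalized to unit length afterwards.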
    fn generate_toeplitz_matrix(
        &self,
        n_features: usize,
        rng: &mut RealStdRng,
    ) -> Result<Array2<Float>> {
        let mut matrix = Array2::zeros((n_features, n_features));

        // Generate values for first row and first column
        let first_row: Vec<Float> = (0..n_features)
            .map(|_| StandardNormal.sample(rng))
            .collect();
        let first_col: Vec<Float> = (1..n_features)
            .map(|_| StandardNormal.sample(rng))
            .collect();

        // Fill Toeplitz matrix
        for i in 0..n_features {
            for j in 0..n_features {
                if i <= j {
                    matrix[[i, j]] = first_row[j - i];
                } else {
                    matrix[[i, j]] = first_col[i - j - 1];
                }
            }
        }

        // Normalize rows
        for mut row in matrix.rows_mut() {
            let norm = (row.dot(&row) as Float).sqrt();
            if norm > 1e-10 {
                row /= norm;
            }
        }

        Ok(matrix)
    }
}

impl Transform<Array2<Float>> for StructuredRandomFeatures<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let random_weights =
            self.random_weights_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let random_offset =
            self.random_offset_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let structured_transform =
            self.structured_transform_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let (n_samples, _) = x.dim();

        // Apply structured transformation first
        let structured_x = x.dot(structured_transform);

        // Apply random projections
        let projected = structured_x.dot(random_weights);

        // Add phase offsets and compute cosine features
        let mut features = Array2::zeros((n_samples, self.n_components));

        for i in 0..n_samples {
            for j in 0..self.n_components {
                let phase = projected[[i, j]] + random_offset[j];
                features[[i, j]] = (2.0 / (self.n_components as Float)).sqrt() * phase.cos();
            }
        }

        Ok(features)
    }
}

/// Fast Walsh-Hadamard Transform for efficient structured features
///
/// This implements the Fast Walsh-Hadamard Transform (FWHT) which can be used
/// to accelerate computations with Hadamard-structured random features.
pub struct FastWalshHadamardTransform;

impl FastWalshHadamardTransform {
    /// Apply Fast Walsh-Hadamard Transform to input vector
    ///
    /// Time complexity: O(n log n) where n is the length of the input
    /// Input length must be a power of 2
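    ///
    /// Illustrative sketch (not compiled as a doctest): the normalized FWHT of a
    /// constant vector concentrates all of its energy in the first coefficient.
    ///
    /// ```rust,ignore
    /// use scirs2_core::ndarray::array;
    /// let out = FastWalshHadamardTransform::transform(array![1.0, 1.0, 1.0, 1.0]).unwrap();
    /// assert!((out[0] - 2.0).abs() < 1e-12); // 4 / sqrt(4)
    /// assert!(out.iter().skip(1).all(|&v| v.abs() < 1e-12));
    /// ```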
    pub fn transform(mut data: Array1<Float>) -> Result<Array1<Float>> {
        let n = data.len();

        // Check that n is a non-zero power of 2
        if n == 0 || n & (n - 1) != 0 {
            return Err(SklearsError::InvalidInput(
                "Input length must be a power of 2 for FWHT".to_string(),
            ));
        }

        // Perform FWHT
        let mut h = 1;
        while h < n {
            for i in (0..n).step_by(h * 2) {
                for j in i..i + h {
                    let u = data[j];
                    let v = data[j + h];
                    data[j] = u + v;
                    data[j + h] = u - v;
                }
            }
            h *= 2;
        }

        // Normalize
        data /= (n as Float).sqrt();

        Ok(data)
    }

    /// Apply FWHT to each row of a 2D array
    pub fn transform_rows(mut data: Array2<Float>) -> Result<Array2<Float>> {
        let (n_rows, n_cols) = data.dim();

        // Check that n_cols is a non-zero power of 2
        if n_cols == 0 || n_cols & (n_cols - 1) != 0 {
            return Err(SklearsError::InvalidInput(
                "Number of columns must be a power of 2 for FWHT".to_string(),
            ));
        }

        for i in 0..n_rows {
            let row = data.row(i).to_owned();
            let transformed_row = Self::transform(row)?;
            data.row_mut(i).assign(&transformed_row);
        }

        Ok(data)
    }
}

/// Structured Random Features using Fast Walsh-Hadamard Transform
///
/// This is an optimized version that uses FWHT for better computational efficiency.
/// Input dimension must be a power of 2 for optimal performance.
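///
/// Usage sketch (illustrative; mirrors the unit tests below):
///
/// ```rust,ignore
/// use scirs2_core::ndarray::array;
/// use sklears_core::prelude::{Fit, Transform};
///
/// let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]]; // 4 = 2^2 features
/// let model = StructuredRFFHadamard::new(64).gamma(0.5).random_state(0);
/// let features = model.fit(&x, &()).unwrap().transform(&x).unwrap();
/// assert_eq!(features.shape(), &[2, 64]);
/// ```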
#[derive(Debug, Clone)]
pub struct StructuredRFFHadamard<State = Untrained> {
    /// Number of random features
    pub n_components: usize,
    /// RBF kernel gamma parameter
    pub gamma: Float,
    /// Random seed
    pub random_state: Option<u64>,

    // Fitted parameters
    random_signs_: Option<Array2<Float>>,
    random_offset_: Option<Array1<Float>>,
    gaussian_weights_: Option<Array1<Float>>,

    _state: PhantomData<State>,
}

impl StructuredRFFHadamard<Untrained> {
    /// Create a new structured RFF with Hadamard transforms
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            gamma: 1.0,
            random_state: None,
            random_signs_: None,
            random_offset_: None,
            gaussian_weights_: None,
            _state: PhantomData,
        }
    }

    /// Set gamma parameter
    pub fn gamma(mut self, gamma: Float) -> Self {
        self.gamma = gamma;
        self
    }

    /// Set random state
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl Estimator for StructuredRFFHadamard<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, ()> for StructuredRFFHadamard<Untrained> {
    type Fitted = StructuredRFFHadamard<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        let (_, n_features) = x.dim();

        // Check if n_features is a power of 2
        if n_features & (n_features - 1) != 0 {
            return Err(SklearsError::InvalidInput(
                "Number of features must be a power of 2 for structured Hadamard RFF".to_string(),
            ));
        }

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        // Generate random signs for Hadamard transforms
        let mut random_signs = Array2::zeros((self.n_components, n_features));
        for i in 0..self.n_components {
            for j in 0..n_features {
                random_signs[[i, j]] = if rng.gen::<bool>() { 1.0 } else { -1.0 };
            }
        }

        // Generate Gaussian weights
        let normal = RandNormal::new(0.0, 1.0).unwrap();
        let gaussian_weights = Array1::from_shape_fn(n_features, |_| rng.sample(normal));

        // Generate random offsets
        let uniform = RandUniform::new(0.0, 2.0 * PI).unwrap();
        let random_offset = Array1::from_shape_fn(self.n_components, |_| rng.sample(uniform));

        Ok(StructuredRFFHadamard {
            n_components: self.n_components,
            gamma: self.gamma,
            random_state: self.random_state,
            random_signs_: Some(random_signs),
            random_offset_: Some(random_offset),
            gaussian_weights_: Some(gaussian_weights),
            _state: PhantomData,
        })
    }
}

impl Transform<Array2<Float>> for StructuredRFFHadamard<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let random_signs = self
            .random_signs_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "transform".to_string(),
            })?;

        let random_offset =
            self.random_offset_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let gaussian_weights =
            self.gaussian_weights_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let (n_samples, n_features) = x.dim();
        let mut features = Array2::zeros((n_samples, self.n_components));

        // For each component, apply structured transformation
        for comp in 0..self.n_components {
            for sample in 0..n_samples {
                // Element-wise multiplication with random signs
                let mut signed_input = Array1::zeros(n_features);
                for feat in 0..n_features {
                    signed_input[feat] = x[[sample, feat]] * random_signs[[comp, feat]];
                }

                // Apply Fast Walsh-Hadamard Transform
                let transformed = FastWalshHadamardTransform::transform(signed_input)?;

                // Scale by Gaussian weights and gamma
                let mut projected = 0.0;
                for feat in 0..n_features {
                    projected += transformed[feat] * gaussian_weights[feat];
                }
                projected *= (2.0 * self.gamma).sqrt();

                // Add offset and compute cosine
                let phase = projected + random_offset[comp];
                features[[sample, comp]] =
                    (2.0 / (self.n_components as Float)).sqrt() * phase.cos();
            }
        }

        Ok(features)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_structured_random_features_basic() {
        let x = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0]
        ];

        let srf = StructuredRandomFeatures::new(8).gamma(0.5);
        let fitted = srf.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.shape(), &[3, 8]);
    }

    #[test]
    fn test_structured_matrices() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        // Test Hadamard
        let hadamard_srf =
            StructuredRandomFeatures::new(4).structured_matrix(StructuredMatrix::Hadamard);
        let hadamard_fitted = hadamard_srf.fit(&x, &()).unwrap();
        let hadamard_result = hadamard_fitted.transform(&x).unwrap();
        assert_eq!(hadamard_result.shape(), &[2, 4]);

        // Test DCT
        let dct_srf = StructuredRandomFeatures::new(4).structured_matrix(StructuredMatrix::DCT);
        let dct_fitted = dct_srf.fit(&x, &()).unwrap();
        let dct_result = dct_fitted.transform(&x).unwrap();
        assert_eq!(dct_result.shape(), &[2, 4]);

        // Test Circulant
        let circulant_srf =
            StructuredRandomFeatures::new(4).structured_matrix(StructuredMatrix::Circulant);
        let circulant_fitted = circulant_srf.fit(&x, &()).unwrap();
        let circulant_result = circulant_fitted.transform(&x).unwrap();
        assert_eq!(circulant_result.shape(), &[2, 4]);
    }

    #[test]
    fn test_fast_walsh_hadamard_transform() {
        let data = array![1.0, 2.0, 3.0, 4.0];
        let result = FastWalshHadamardTransform::transform(data).unwrap();
        assert_eq!(result.len(), 4);

        // Test with 2D array
        let data_2d = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];
        let result_2d = FastWalshHadamardTransform::transform_rows(data_2d).unwrap();
        assert_eq!(result_2d.shape(), &[2, 4]);
    }
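
    // Sanity check (sketch): the low-discrepancy generators should yield values
    // strictly inside (0, 1), and mapping them through the approximate inverse
    // normal CDF should give finite quasi-random Gaussian draws.
    #[test]
    fn test_low_discrepancy_sequences() {
        let vdc = LowDiscrepancySequences::van_der_corput(16);
        assert_eq!(vdc.len(), 16);
        assert!(vdc.iter().all(|&v| v > 0.0 && v < 1.0));

        let halton = LowDiscrepancySequences::halton(32, 3);
        assert_eq!(halton.dim(), (32, 3));
        assert!(halton.iter().all(|&v| v > 0.0 && v < 1.0));

        let gaussian = LowDiscrepancySequences::uniform_to_gaussian(&halton);
        assert_eq!(gaussian.dim(), (32, 3));
        assert!(gaussian.iter().all(|v| v.is_finite()));
    }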

    #[test]
    fn test_structured_rff_hadamard() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        let srf_h = StructuredRFFHadamard::new(6).gamma(0.5);
        let fitted = srf_h.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.shape(), &[2, 6]);
    }

    #[test]
    fn test_fwht_invalid_size() {
        let data = array![1.0, 2.0, 3.0]; // Not a power of 2
        let result = FastWalshHadamardTransform::transform(data);
        assert!(result.is_err());
    }
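
    // Sanity check (sketch): the inverse normal CDF approximation should be
    // antisymmetric about u = 0.5 and close to the standard normal 97.5% quantile.
    #[test]
    fn test_inverse_normal_cdf_approximation() {
        let hi = LowDiscrepancySequences::inverse_normal_cdf(0.975);
        let lo = LowDiscrepancySequences::inverse_normal_cdf(0.025);
        assert!((hi - 1.96).abs() < 1e-2);
        assert!((hi + lo).abs() < 1e-6);
    }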

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        let srf1 = StructuredRandomFeatures::new(8).random_state(42);
        let fitted1 = srf1.fit(&x, &()).unwrap();
        let result1 = fitted1.transform(&x).unwrap();

        let srf2 = StructuredRandomFeatures::new(8).random_state(42);
        let fitted2 = srf2.fit(&x, &()).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        assert_eq!(result1.shape(), result2.shape());
        for i in 0..result1.len() {
            assert!(
                (result1.as_slice().unwrap()[i] - result2.as_slice().unwrap()[i]).abs() < 1e-10
            );
        }
    }

    #[test]
    fn test_different_gamma_values() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        let srf_low = StructuredRandomFeatures::new(4).gamma(0.1);
        let fitted_low = srf_low.fit(&x, &()).unwrap();
        let result_low = fitted_low.transform(&x).unwrap();

        let srf_high = StructuredRandomFeatures::new(4).gamma(10.0);
        let fitted_high = srf_high.fit(&x, &()).unwrap();
        let result_high = fitted_high.transform(&x).unwrap();

        assert_eq!(result_low.shape(), result_high.shape());
        // Results should be different with different gamma values
        let diff_sum: Float = result_low
            .iter()
            .zip(result_high.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff_sum > 1e-6);
    }
823}