sklears_kernel_approximation/structured_random_features.rs

//! Structured orthogonal random features for efficient kernel approximation
//!
//! This module implements random feature methods that use structured random
//! matrices (such as Hadamard matrices) to reduce computational complexity
//! while maintaining approximation quality. It also includes quasi-random
//! (low-discrepancy) sequence generators for better-distributed features.
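//!
//! A minimal usage sketch (doctest ignored; assumes `Float` is `f64` and that
//! `Fit`/`Transform` are brought into scope from the `sklears_core` prelude):
//!
//! ```rust,ignore
//! use sklears_kernel_approximation::structured_random_features::{
//!     StructuredMatrix, StructuredRandomFeatures,
//! };
//! use sklears_core::prelude::{Fit, Transform};
//! use scirs2_core::ndarray::Array2;
//!
//! // 64 random features approximating an RBF kernel with gamma = 0.5
//! let x = Array2::<f64>::zeros((10, 8));
//! let srf = StructuredRandomFeatures::new(64)
//!     .gamma(0.5)
//!     .structured_matrix(StructuredMatrix::Hadamard);
//! let fitted = srf.fit(&x, &()).unwrap();
//! let features = fitted.transform(&x).unwrap(); // shape (10, 64)
//! ```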

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
use scirs2_core::random::rngs::StdRng as RealStdRng;
use scirs2_core::random::Distribution;
use scirs2_core::random::{thread_rng, Rng, SeedableRng};
use scirs2_core::StandardNormal;
use sklears_core::{
    error::{Result, SklearsError},
    prelude::{Fit, Transform},
    traits::{Estimator, Trained, Untrained},
    types::Float,
};
use std::f64::consts::PI;
use std::marker::PhantomData;

/// Type of structured matrix to use
#[derive(Debug, Clone)]
pub enum StructuredMatrix {
    /// Randomized sign matrix with row normalization (Hadamard-like)
    Hadamard,
    /// Discrete Cosine Transform (DCT-II) matrix
    DCT,
    /// Circulant matrix built from a random first row
    Circulant,
    /// Toeplitz matrix built from a random first row and column
    Toeplitz,
}

/// Type of quasi-random sequence to use for feature generation
#[derive(Debug, Clone)]
pub enum QuasiRandomSequence {
    /// Van der Corput sequence (base 2)
    VanDerCorput,
    /// Halton sequence with prime bases
    Halton,
    /// Simplified Sobol-style sequence
    Sobol,
    /// Standard pseudo-random numbers
    PseudoRandom,
}

/// Low-discrepancy sequence generators for quasi-random features
pub struct LowDiscrepancySequences;

impl LowDiscrepancySequences {
    /// Generate the Van der Corput sequence in base 2.
    ///
    /// This provides a one-dimensional low-discrepancy sequence.
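    ///
    /// A small sketch of the expected prefix (doctest ignored):
    ///
    /// ```rust,ignore
    /// let seq = LowDiscrepancySequences::van_der_corput(4);
    /// // Radical inverses of 1, 2, 3, 4 in base 2
    /// assert_eq!(seq, vec![0.5, 0.25, 0.75, 0.125]);
    /// ```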
    pub fn van_der_corput(n: usize) -> Vec<Float> {
        (0..n)
            .map(|i| {
                let mut value = 0.0;
                let mut base_inv = 0.5;
                // Radical inverse of (i + 1) in base 2
                let mut k = i + 1;

                while k > 0 {
                    if k % 2 == 1 {
                        value += base_inv;
                    }
                    base_inv *= 0.5;
                    k /= 2;
                }
                value
            })
            .collect()
    }

    /// Generate a Halton sequence for multi-dimensional low-discrepancy sampling.
    ///
    /// Uses a distinct prime base for each dimension.
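    ///
    /// Sketch (doctest ignored): the points fill the unit hypercube more evenly
    /// than pseudo-random draws.
    ///
    /// ```rust,ignore
    /// let pts = LowDiscrepancySequences::halton(128, 2);
    /// assert_eq!(pts.dim(), (128, 2));
    /// // First dimension uses base 2, second uses base 3
    /// assert!((pts[[0, 0]] - 0.5).abs() < 1e-12);
    /// assert!((pts[[0, 1]] - 1.0 / 3.0).abs() < 1e-12);
    /// ```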
    pub fn halton(n: usize, dimensions: usize) -> Array2<Float> {
        let primes = [
            2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
        ];

        if dimensions > primes.len() {
            panic!("Maximum supported dimensions: {}", primes.len());
        }

        let mut sequence = Array2::zeros((n, dimensions));

        for dim in 0..dimensions {
            let base = primes[dim] as Float;
            for i in 0..n {
                let mut value = 0.0;
                let mut base_inv = 1.0 / base;
                // Radical inverse of (i + 1) in this dimension's prime base
                let mut k = i + 1;

                while k > 0 {
                    value += (k % primes[dim]) as Float * base_inv;
                    base_inv /= base;
                    k /= primes[dim];
                }
                sequence[[i, dim]] = value;
            }
        }

        sequence
    }

    /// Generate a simplified Sobol-style sequence for higher dimensions.
    ///
    /// This is a basic implementation (a Halton-like construction with distinct
    /// prime bases per dimension) that focuses on the key low-discrepancy properties.
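    ///
    /// Sketch (doctest ignored): the first dimension reduces to the base-2
    /// Van der Corput sequence.
    ///
    /// ```rust,ignore
    /// let pts = LowDiscrepancySequences::sobol(64, 4);
    /// assert_eq!(pts.dim(), (64, 4));
    /// assert_eq!(pts[[0, 0]], 0.5);
    /// ```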
    pub fn sobol(n: usize, dimensions: usize) -> Array2<Float> {
        let mut sequence = Array2::zeros((n, dimensions));

        // For the first dimension, use Van der Corput base 2
        if dimensions > 0 {
            let first_dim = Self::van_der_corput(n);
            for i in 0..n {
                sequence[[i, 0]] = first_dim[i];
            }
        }

        // For additional dimensions, use Van der Corput with different bases
        let bases = [
            3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73,
        ];

        for dim in 1..dimensions {
            let base = if dim - 1 < bases.len() {
                bases[dim - 1]
            } else {
                2 + dim
            };
            for i in 0..n {
                let mut value = 0.0;
                let mut base_inv = 1.0 / (base as Float);
                let mut k = i + 1;

                while k > 0 {
                    value += (k % base) as Float * base_inv;
                    base_inv /= base as Float;
                    k /= base;
                }
                sequence[[i, dim]] = value;
            }
        }

        sequence
    }

    /// Transform a uniform low-discrepancy sequence into Gaussian variables using the
    /// inverse normal CDF (inverse transform sampling), yielding quasi-random
    /// Gaussian draws.
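    ///
    /// Sketch (doctest ignored): maps unit-cube points to approximately
    /// standard-normal draws, preserving the array shape.
    ///
    /// ```rust,ignore
    /// let u = LowDiscrepancySequences::halton(256, 3);
    /// let g = LowDiscrepancySequences::uniform_to_gaussian(&u);
    /// assert_eq!(g.dim(), (256, 3));
    /// // u = 0.5 maps (approximately) to the normal median, i.e. ~0.0
    /// ```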
    pub fn uniform_to_gaussian(uniform_sequence: &Array2<Float>) -> Array2<Float> {
        let (n, dim) = uniform_sequence.dim();
        let mut gaussian_sequence = Array2::zeros((n, dim));

        for i in 0..n {
            for j in 0..dim {
                let u = uniform_sequence[[i, j]];
                // Clamp away from 0 and 1 to avoid infinite quantiles
                let u_clamped = u.clamp(1e-10, 1.0 - 1e-10);

                // Approximate inverse normal CDF (rational approximation)
                let x = Self::inverse_normal_cdf(u_clamped);
                gaussian_sequence[[i, j]] = x;
            }
        }

        gaussian_sequence
    }

    /// Approximate the inverse normal CDF using the Abramowitz-Stegun 26.2.23
    /// rational approximation (absolute error below 4.5e-4).
    fn inverse_normal_cdf(u: Float) -> Float {
        if u <= 0.0 || u >= 1.0 {
            return if u <= 0.0 {
                Float::NEG_INFINITY
            } else {
                Float::INFINITY
            };
        }

        // Work on the lower tail and restore the sign afterwards:
        // Phi^{-1}(u) = -Phi^{-1}(1 - u), and the approximation below targets p <= 0.5.
        let (p, sign) = if u > 0.5 { (1.0 - u, 1.0) } else { (u, -1.0) };

        let t = (-2.0 * p.ln()).sqrt();

        // Coefficients for the rational approximation
        let c0 = 2.515517;
        let c1 = 0.802853;
        let c2 = 0.010328;
        let d1 = 1.432788;
        let d2 = 0.189269;
        let d3 = 0.001308;

        let numerator = c0 + c1 * t + c2 * t * t;
        let denominator = 1.0 + d1 * t + d2 * t * t + d3 * t * t * t;

        sign * (t - numerator / denominator)
    }
}

/// Structured orthogonal random features
///
/// Uses structured random matrices to approximate RBF kernels more efficiently
/// than standard random Fourier features. With fast structured transforms (such
/// as the Fast Walsh-Hadamard Transform), the random projection cost can drop
/// from O(d * D) to O(D * log d), where `d` is the input dimension and `D` the
/// number of random features.
///
/// # Parameters
///
/// * `n_components` - Number of random features to generate
/// * `gamma` - RBF kernel parameter (default: 1.0)
/// * `structured_matrix` - Type of structured matrix to use
/// * `random_state` - Random seed for reproducibility
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::structured_random_features::{
///     StructuredMatrix, StructuredRandomFeatures,
/// };
///
/// let srf = StructuredRandomFeatures::new(128)
///     .gamma(0.5)
///     .structured_matrix(StructuredMatrix::DCT)
///     .random_state(42);
/// ```
#[derive(Debug, Clone)]
pub struct StructuredRandomFeatures<State = Untrained> {
    /// Number of random features to generate
    pub n_components: usize,
    /// RBF kernel gamma parameter
    pub gamma: Float,
    /// Type of structured matrix used for the transform
    pub structured_matrix: StructuredMatrix,
    /// Random seed for reproducibility
    pub random_state: Option<u64>,

    // Fitted parameters
    random_weights_: Option<Array2<Float>>,
    random_offset_: Option<Array1<Float>>,
    structured_transform_: Option<Array2<Float>>,

    _state: PhantomData<State>,
}

impl StructuredRandomFeatures<Untrained> {
    /// Create a new structured random features transformer
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            gamma: 1.0,
            structured_matrix: StructuredMatrix::Hadamard,
            random_state: None,
            random_weights_: None,
            random_offset_: None,
            structured_transform_: None,
            _state: PhantomData,
        }
    }

    /// Set the gamma parameter for the RBF kernel
    pub fn gamma(mut self, gamma: Float) -> Self {
        self.gamma = gamma;
        self
    }

    /// Set the structured matrix type
    pub fn structured_matrix(mut self, matrix_type: StructuredMatrix) -> Self {
        self.structured_matrix = matrix_type;
        self
    }

    /// Set random state for reproducibility
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl Estimator for StructuredRandomFeatures<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, ()> for StructuredRandomFeatures<Untrained> {
    type Fitted = StructuredRandomFeatures<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        let (_, n_features) = x.dim();

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        // Generate structured transform matrix
        let structured_transform = self.generate_structured_matrix(n_features, &mut rng)?;

        // Generate random weights for mixing; for the RBF kernel exp(-gamma * ||x - y||^2)
        // the spectral distribution is N(0, 2 * gamma), hence std dev sqrt(2 * gamma)
        let normal = RandNormal::new(0.0, (2.0 * self.gamma).sqrt()).unwrap();
        let random_weights =
            Array2::from_shape_fn((n_features, self.n_components), |_| rng.sample(normal));

        // Generate random offset for phase
        let uniform = RandUniform::new(0.0, 2.0 * PI).unwrap();
        let random_offset = Array1::from_shape_fn(self.n_components, |_| rng.sample(uniform));

        Ok(StructuredRandomFeatures {
            n_components: self.n_components,
            gamma: self.gamma,
            structured_matrix: self.structured_matrix,
            random_state: self.random_state,
            random_weights_: Some(random_weights),
            random_offset_: Some(random_offset),
            structured_transform_: Some(structured_transform),
            _state: PhantomData,
        })
    }
}

impl StructuredRandomFeatures<Untrained> {
    /// Generate structured matrix based on the specified type
    fn generate_structured_matrix(
        &self,
        n_features: usize,
        rng: &mut RealStdRng,
    ) -> Result<Array2<Float>> {
        match &self.structured_matrix {
            StructuredMatrix::Hadamard => self.generate_hadamard_matrix(n_features, rng),
            StructuredMatrix::DCT => self.generate_dct_matrix(n_features),
            StructuredMatrix::Circulant => self.generate_circulant_matrix(n_features, rng),
            StructuredMatrix::Toeplitz => self.generate_toeplitz_matrix(n_features, rng),
        }
    }

    /// Generate (approximate) Hadamard matrix
    fn generate_hadamard_matrix(
        &self,
        n_features: usize,
        rng: &mut RealStdRng,
    ) -> Result<Array2<Float>> {
        // For simplicity, generate a randomized orthogonal-like matrix
        // True Hadamard matrices exist only for specific sizes
        let mut matrix = Array2::zeros((n_features, n_features));

        // Generate random signs for each entry
        for i in 0..n_features {
            for j in 0..n_features {
                matrix[[i, j]] = if rng.gen::<bool>() { 1.0 } else { -1.0 };
            }
        }

        // Normalize to make approximately orthogonal
        for mut row in matrix.rows_mut() {
            let norm = (row.dot(&row) as Float).sqrt();
            if norm > 1e-10 {
                row /= norm;
            }
        }

        Ok(matrix)
    }

    /// Generate the orthonormal Discrete Cosine Transform (DCT-II) matrix
    fn generate_dct_matrix(&self, n_features: usize) -> Result<Array2<Float>> {
        let mut matrix = Array2::zeros((n_features, n_features));

        for i in 0..n_features {
            for j in 0..n_features {
                let coeff = if i == 0 {
                    (1.0 / (n_features as Float)).sqrt()
                } else {
                    (2.0 / (n_features as Float)).sqrt()
                };

                matrix[[i, j]] = coeff
                    * ((PI * (i as Float) * (2.0 * (j as Float) + 1.0))
                        / (2.0 * (n_features as Float)))
                        .cos();
            }
        }

        Ok(matrix)
    }

    /// Generate a circulant matrix from a random first row
    fn generate_circulant_matrix(
        &self,
        n_features: usize,
        rng: &mut RealStdRng,
    ) -> Result<Array2<Float>> {
        let mut matrix = Array2::zeros((n_features, n_features));

        // Generate first row randomly
        let first_row: Vec<Float> = (0..n_features)
            .map(|_| StandardNormal.sample(rng))
            .collect();

        // Fill circulant matrix
        for i in 0..n_features {
            for j in 0..n_features {
                matrix[[i, j]] = first_row[(j + n_features - i) % n_features];
            }
        }

        // Normalize rows
        for mut row in matrix.rows_mut() {
            let norm = (row.dot(&row) as Float).sqrt();
            if norm > 1e-10 {
                row /= norm;
            }
        }

        Ok(matrix)
    }

    /// Generate a Toeplitz matrix from a random first row and column
    fn generate_toeplitz_matrix(
        &self,
        n_features: usize,
        rng: &mut RealStdRng,
    ) -> Result<Array2<Float>> {
        let mut matrix = Array2::zeros((n_features, n_features));

        // Generate values for first row and first column
        let first_row: Vec<Float> = (0..n_features)
            .map(|_| StandardNormal.sample(rng))
            .collect();
        let first_col: Vec<Float> = (1..n_features)
            .map(|_| StandardNormal.sample(rng))
            .collect();

        // Fill Toeplitz matrix
        for i in 0..n_features {
            for j in 0..n_features {
                if i <= j {
                    matrix[[i, j]] = first_row[j - i];
                } else {
                    matrix[[i, j]] = first_col[i - j - 1];
                }
            }
        }

        // Normalize rows
        for mut row in matrix.rows_mut() {
            let norm = (row.dot(&row) as Float).sqrt();
            if norm > 1e-10 {
                row /= norm;
            }
        }

        Ok(matrix)
    }
}

impl Transform<Array2<Float>> for StructuredRandomFeatures<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let random_weights =
            self.random_weights_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let random_offset =
            self.random_offset_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let structured_transform =
            self.structured_transform_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let (n_samples, _) = x.dim();

        // Apply structured transformation first
        let structured_x = x.dot(structured_transform);

        // Apply random projections
        let projected = structured_x.dot(random_weights);

        // Add phase offsets and compute cosine features
        let mut features = Array2::zeros((n_samples, self.n_components));

        for i in 0..n_samples {
            for j in 0..self.n_components {
                let phase = projected[[i, j]] + random_offset[j];
                features[[i, j]] = (2.0 / (self.n_components as Float)).sqrt() * phase.cos();
            }
        }

        Ok(features)
    }
}

/// Fast Walsh-Hadamard Transform for efficient structured features
///
/// This implements the Fast Walsh-Hadamard Transform (FWHT), which can be used
/// to accelerate computations with Hadamard-structured random features.
pub struct FastWalshHadamardTransform;

impl FastWalshHadamardTransform {
    /// Apply the Fast Walsh-Hadamard Transform to an input vector.
    ///
    /// Time complexity: O(n log n), where n is the length of the input.
    /// The input length must be a power of 2.
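    ///
    /// Worked sketch (doctest ignored), using the unnormalized butterfly followed by
    /// the `1/sqrt(n)` scaling below:
    ///
    /// ```rust,ignore
    /// use scirs2_core::ndarray::array;
    ///
    /// let result = FastWalshHadamardTransform::transform(array![1.0, 2.0, 3.0, 4.0]).unwrap();
    /// // Unnormalized FWHT gives [10, -2, -4, 0]; dividing by sqrt(4) = 2 yields:
    /// assert_eq!(result, array![5.0, -1.0, -2.0, 0.0]);
    /// ```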
    pub fn transform(mut data: Array1<Float>) -> Result<Array1<Float>> {
        let n = data.len();

        // Check if n is a power of 2
        if n == 0 || n & (n - 1) != 0 {
            return Err(SklearsError::InvalidInput(
                "Input length must be a power of 2 for FWHT".to_string(),
            ));
        }

        // Perform FWHT
        let mut h = 1;
        while h < n {
            for i in (0..n).step_by(h * 2) {
                for j in i..i + h {
                    let u = data[j];
                    let v = data[j + h];
                    data[j] = u + v;
                    data[j + h] = u - v;
                }
            }
            h *= 2;
        }

        // Normalize so the transform is orthonormal
        data /= (n as Float).sqrt();

        Ok(data)
    }

    /// Apply FWHT to each row of a 2D array
    pub fn transform_rows(mut data: Array2<Float>) -> Result<Array2<Float>> {
        let (n_rows, n_cols) = data.dim();

        // Check if n_cols is a power of 2
        if n_cols == 0 || n_cols & (n_cols - 1) != 0 {
            return Err(SklearsError::InvalidInput(
                "Number of columns must be a power of 2 for FWHT".to_string(),
            ));
        }

        for i in 0..n_rows {
            let row = data.row(i).to_owned();
            let transformed_row = Self::transform(row)?;
            data.row_mut(i).assign(&transformed_row);
        }

        Ok(data)
    }
}

/// Structured Random Features using the Fast Walsh-Hadamard Transform
///
/// This is an optimized variant that uses the FWHT for better computational
/// efficiency. The input dimension must be a power of 2.
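///
/// A minimal usage sketch (doctest ignored; assumes `Float` is `f64`):
///
/// ```rust,ignore
/// use scirs2_core::ndarray::Array2;
/// use sklears_core::prelude::{Fit, Transform};
///
/// // 4 input features (a power of 2), mapped to 16 random features
/// let x = Array2::<f64>::zeros((32, 4));
/// let rff = StructuredRFFHadamard::new(16).gamma(0.5).random_state(7);
/// let features = rff.fit(&x, &()).unwrap().transform(&x).unwrap(); // shape (32, 16)
/// ```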
#[derive(Debug, Clone)]
pub struct StructuredRFFHadamard<State = Untrained> {
    /// Number of random features
    pub n_components: usize,
    /// RBF kernel gamma parameter
    pub gamma: Float,
    /// Random seed
    pub random_state: Option<u64>,

    // Fitted parameters
    random_signs_: Option<Array2<Float>>,
    random_offset_: Option<Array1<Float>>,
    gaussian_weights_: Option<Array1<Float>>,

    _state: PhantomData<State>,
}

impl StructuredRFFHadamard<Untrained> {
    /// Create a new structured RFF with Hadamard transforms
    pub fn new(n_components: usize) -> Self {
        Self {
            n_components,
            gamma: 1.0,
            random_state: None,
            random_signs_: None,
            random_offset_: None,
            gaussian_weights_: None,
            _state: PhantomData,
        }
    }

    /// Set gamma parameter
    pub fn gamma(mut self, gamma: Float) -> Self {
        self.gamma = gamma;
        self
    }

    /// Set random state
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

impl Estimator for StructuredRFFHadamard<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, ()> for StructuredRFFHadamard<Untrained> {
    type Fitted = StructuredRFFHadamard<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        let (_, n_features) = x.dim();

        // Check if n_features is a power of 2
        if n_features == 0 || n_features & (n_features - 1) != 0 {
            return Err(SklearsError::InvalidInput(
                "Number of features must be a power of 2 for structured Hadamard RFF".to_string(),
            ));
        }

        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        // Generate random signs for Hadamard transforms
        let mut random_signs = Array2::zeros((self.n_components, n_features));
        for i in 0..self.n_components {
            for j in 0..n_features {
                random_signs[[i, j]] = if rng.gen::<bool>() { 1.0 } else { -1.0 };
            }
        }

        // Generate Gaussian weights
        let normal = RandNormal::new(0.0, 1.0).unwrap();
        let gaussian_weights = Array1::from_shape_fn(n_features, |_| rng.sample(normal));

        // Generate random offsets
        let uniform = RandUniform::new(0.0, 2.0 * PI).unwrap();
        let random_offset = Array1::from_shape_fn(self.n_components, |_| rng.sample(uniform));

        Ok(StructuredRFFHadamard {
            n_components: self.n_components,
            gamma: self.gamma,
            random_state: self.random_state,
            random_signs_: Some(random_signs),
            random_offset_: Some(random_offset),
            gaussian_weights_: Some(gaussian_weights),
            _state: PhantomData,
        })
    }
}

impl Transform<Array2<Float>> for StructuredRFFHadamard<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let random_signs = self
            .random_signs_
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "transform".to_string(),
            })?;

        let random_offset =
            self.random_offset_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let gaussian_weights =
            self.gaussian_weights_
                .as_ref()
                .ok_or_else(|| SklearsError::NotFitted {
                    operation: "transform".to_string(),
                })?;

        let (n_samples, n_features) = x.dim();
        let mut features = Array2::zeros((n_samples, self.n_components));

        // For each component, apply structured transformation
        for comp in 0..self.n_components {
            for sample in 0..n_samples {
                // Element-wise multiplication with random signs
                let mut signed_input = Array1::zeros(n_features);
                for feat in 0..n_features {
                    signed_input[feat] = x[[sample, feat]] * random_signs[[comp, feat]];
                }

                // Apply Fast Walsh-Hadamard Transform
                let transformed = FastWalshHadamardTransform::transform(signed_input)?;

                // Scale by Gaussian weights and gamma
                let mut projected = 0.0;
                for feat in 0..n_features {
                    projected += transformed[feat] * gaussian_weights[feat];
                }
                projected *= (2.0 * self.gamma).sqrt();

                // Add offset and compute cosine
                let phase = projected + random_offset[comp];
                features[[sample, comp]] =
                    (2.0 / (self.n_components as Float)).sqrt() * phase.cos();
            }
        }

        Ok(features)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_structured_random_features_basic() {
        let x = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0]
        ];

        let srf = StructuredRandomFeatures::new(8).gamma(0.5);
        let fitted = srf.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.shape(), &[3, 8]);
    }

    #[test]
    fn test_structured_matrices() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        // Test Hadamard
        let hadamard_srf =
            StructuredRandomFeatures::new(4).structured_matrix(StructuredMatrix::Hadamard);
        let hadamard_fitted = hadamard_srf.fit(&x, &()).unwrap();
        let hadamard_result = hadamard_fitted.transform(&x).unwrap();
        assert_eq!(hadamard_result.shape(), &[2, 4]);

        // Test DCT
        let dct_srf = StructuredRandomFeatures::new(4).structured_matrix(StructuredMatrix::DCT);
        let dct_fitted = dct_srf.fit(&x, &()).unwrap();
        let dct_result = dct_fitted.transform(&x).unwrap();
        assert_eq!(dct_result.shape(), &[2, 4]);

        // Test Circulant
        let circulant_srf =
            StructuredRandomFeatures::new(4).structured_matrix(StructuredMatrix::Circulant);
        let circulant_fitted = circulant_srf.fit(&x, &()).unwrap();
        let circulant_result = circulant_fitted.transform(&x).unwrap();
        assert_eq!(circulant_result.shape(), &[2, 4]);
    }

    #[test]
    fn test_fast_walsh_hadamard_transform() {
        let data = array![1.0, 2.0, 3.0, 4.0];
        let result = FastWalshHadamardTransform::transform(data).unwrap();
        assert_eq!(result.len(), 4);

        // Test with 2D array
        let data_2d = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];
        let result_2d = FastWalshHadamardTransform::transform_rows(data_2d).unwrap();
        assert_eq!(result_2d.shape(), &[2, 4]);
    }

    #[test]
    fn test_structured_rff_hadamard() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        let srf_h = StructuredRFFHadamard::new(6).gamma(0.5);
        let fitted = srf_h.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.shape(), &[2, 6]);
    }

    #[test]
    fn test_fwht_invalid_size() {
        let data = array![1.0, 2.0, 3.0]; // Not a power of 2
        let result = FastWalshHadamardTransform::transform(data);
        assert!(result.is_err());
    }

    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        let srf1 = StructuredRandomFeatures::new(8).random_state(42);
        let fitted1 = srf1.fit(&x, &()).unwrap();
        let result1 = fitted1.transform(&x).unwrap();

        let srf2 = StructuredRandomFeatures::new(8).random_state(42);
        let fitted2 = srf2.fit(&x, &()).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        assert_eq!(result1.shape(), result2.shape());
        for i in 0..result1.len() {
            assert!(
                (result1.as_slice().unwrap()[i] - result2.as_slice().unwrap()[i]).abs() < 1e-10
            );
        }
    }

    #[test]
    fn test_different_gamma_values() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        let srf_low = StructuredRandomFeatures::new(4).gamma(0.1);
        let fitted_low = srf_low.fit(&x, &()).unwrap();
        let result_low = fitted_low.transform(&x).unwrap();

        let srf_high = StructuredRandomFeatures::new(4).gamma(10.0);
        let fitted_high = srf_high.fit(&x, &()).unwrap();
        let result_high = fitted_high.transform(&x).unwrap();

        assert_eq!(result_low.shape(), result_high.shape());
        // Results should be different with different gamma values
        let diff_sum: Float = result_low
            .iter()
            .zip(result_high.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff_sum > 1e-6);
    }
}