// sklears_kernel_approximation/structured_random_features.rs

1//! Structured orthogonal random features for efficient kernel approximation
2//!
3//! This module implements structured random feature methods that use structured
4//! random matrices (like Hadamard matrices) to reduce computational complexity
5//! while maintaining approximation quality. Also includes quasi-random and
6//! low-discrepancy sequence methods for improved feature distribution.
7
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::random::essentials::{Normal as RandNormal, Uniform as RandUniform};
10use scirs2_core::random::rngs::StdRng as RealStdRng;
11use scirs2_core::random::Distribution;
12use scirs2_core::random::Rng;
13use scirs2_core::random::{thread_rng, SeedableRng};
14use scirs2_core::StandardNormal;
15use sklears_core::{
16    error::{Result, SklearsError},
17    prelude::{Fit, Transform},
18    traits::{Estimator, Trained, Untrained},
19    types::Float,
20};
21use std::f64::consts::PI;
22use std::marker::PhantomData;
23
/// Type of structured matrix to use as the pre-transform in
/// [`StructuredRandomFeatures`].
#[derive(Debug, Clone)]
pub enum StructuredMatrix {
    /// Random +/-1 matrix with row normalization (Hadamard-like; true
    /// Hadamard matrices only exist for specific sizes).
    Hadamard,
    /// Orthonormal Discrete Cosine Transform (type-II) matrix.
    DCT,
    /// Circulant matrix: one random row, cyclically shifted per row.
    Circulant,
    /// Toeplitz matrix: constant along diagonals, built from a random
    /// first row and first column.
    Toeplitz,
}
33
/// Type of quasi-random sequence to use for feature generation.
///
/// The first three variants correspond to the generators on
/// [`LowDiscrepancySequences`].
#[derive(Debug, Clone)]
pub enum QuasiRandomSequence {
    /// 1-D base-2 Van der Corput low-discrepancy sequence.
    VanDerCorput,
    /// Multi-dimensional Halton sequence using prime bases.
    Halton,
    /// Simplified Sobol-style sequence (radical inverse with distinct bases).
    Sobol,
    /// Plain pseudo-random draws rather than a low-discrepancy sequence.
    PseudoRandom,
}
43
/// Low-discrepancy sequence generators for quasi-random features.
///
/// Stateless namespace type: all generators are associated functions.
pub struct LowDiscrepancySequences;
46
47impl LowDiscrepancySequences {
48    /// Generate Van der Corput sequence in base 2
49    /// This provides a 1D low-discrepancy sequence
50    pub fn van_der_corput(n: usize) -> Vec<Float> {
51        (0..n)
52            .map(|i| {
53                let mut value = 0.0;
54                let mut base_inv = 0.5;
55                let mut n = i + 1;
56
57                while n > 0 {
58                    if n % 2 == 1 {
59                        value += base_inv;
60                    }
61                    base_inv *= 0.5;
62                    n /= 2;
63                }
64                value
65            })
66            .collect()
67    }
68
69    /// Generate Halton sequence for multi-dimensional low-discrepancy
70    /// Uses prime bases for each dimension
71    pub fn halton(n: usize, dimensions: usize) -> Array2<Float> {
72        let primes = [
73            2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
74        ];
75
76        if dimensions > primes.len() {
77            panic!("Maximum supported dimensions: {}", primes.len());
78        }
79
80        let mut sequence = Array2::zeros((n, dimensions));
81
82        for dim in 0..dimensions {
83            let base = primes[dim] as Float;
84            for i in 0..n {
85                let mut value = 0.0;
86                let mut base_inv = 1.0 / base;
87                let mut n = i + 1;
88
89                while n > 0 {
90                    value += (n % primes[dim]) as Float * base_inv;
91                    base_inv /= base;
92                    n /= primes[dim];
93                }
94                sequence[[i, dim]] = value;
95            }
96        }
97
98        sequence
99    }
100
101    /// Generate simplified Sobol sequence for higher dimensions
102    /// This is a basic implementation focusing on the key properties
103    pub fn sobol(n: usize, dimensions: usize) -> Array2<Float> {
104        let mut sequence = Array2::zeros((n, dimensions));
105
106        // For the first dimension, use Van der Corput base 2
107        let first_dim = Self::van_der_corput(n);
108        for i in 0..n {
109            sequence[[i, 0]] = first_dim[i];
110        }
111
112        // For additional dimensions, use Van der Corput with different bases
113        let bases = [
114            3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73,
115        ];
116
117        for dim in 1..dimensions {
118            let base = if dim - 1 < bases.len() {
119                bases[dim - 1]
120            } else {
121                2 + dim
122            };
123            for i in 0..n {
124                let mut value = 0.0;
125                let mut base_inv = 1.0 / (base as Float);
126                let mut n = i + 1;
127
128                while n > 0 {
129                    value += (n % base) as Float * base_inv;
130                    base_inv /= base as Float;
131                    n /= base;
132                }
133                sequence[[i, dim]] = value;
134            }
135        }
136
137        sequence
138    }
139
140    /// Transform uniform low-discrepancy sequence to Gaussian using inverse normal CDF
141    /// Box-Muller-like transformation for quasi-random Gaussian variables
142    pub fn uniform_to_gaussian(uniform_sequence: &Array2<Float>) -> Array2<Float> {
143        let (n, dim) = uniform_sequence.dim();
144        let mut gaussian_sequence = Array2::zeros((n, dim));
145
146        for i in 0..n {
147            for j in 0..dim {
148                let u = uniform_sequence[[i, j]];
149                // Prevent extreme values
150                let u_clamped = u.max(1e-10).min(1.0 - 1e-10);
151
152                // Approximate inverse normal CDF using Beasley-Springer-Moro algorithm
153                let x = Self::inverse_normal_cdf(u_clamped);
154                gaussian_sequence[[i, j]] = x;
155            }
156        }
157
158        gaussian_sequence
159    }
160
161    /// Approximate inverse normal CDF using rational approximation
162    fn inverse_normal_cdf(u: Float) -> Float {
163        if u <= 0.0 || u >= 1.0 {
164            return if u <= 0.0 {
165                Float::NEG_INFINITY
166            } else {
167                Float::INFINITY
168            };
169        }
170
171        let u = if u > 0.5 { 1.0 - u } else { u };
172        let sign = if u == 1.0 - u { 1.0 } else { -1.0 };
173
174        let t = (-2.0 * u.ln()).sqrt();
175
176        // Coefficients for rational approximation
177        let c0 = 2.515517;
178        let c1 = 0.802853;
179        let c2 = 0.010328;
180        let d1 = 1.432788;
181        let d2 = 0.189269;
182        let d3 = 0.001308;
183
184        let numerator = c0 + c1 * t + c2 * t * t;
185        let denominator = 1.0 + d1 * t + d2 * t * t + d3 * t * t * t;
186
187        sign * (t - numerator / denominator)
188    }
189}
190
/// Structured orthogonal random features
///
/// Uses structured random matrices to approximate RBF kernels more efficiently
/// than standard random Fourier features. The structured approach reduces
/// computational complexity from O(d*D) to O(d*log(D)) for certain operations.
///
/// # Parameters
///
/// * `n_components` - Number of random features to generate
/// * `gamma` - RBF kernel parameter (default: 1.0)
/// * `structured_matrix` - Type of structured matrix to use
/// * `random_state` - Random seed for reproducibility
///
/// # Examples
///
/// ```text
/// use sklears_kernel_approximation::structured_random_features::{
///     StructuredRandomFeatures, StructuredMatrix,
/// };
///
/// let srf = StructuredRandomFeatures::new(100)
///     .gamma(0.5)
///     .structured_matrix(StructuredMatrix::Hadamard);
/// ```
#[derive(Debug, Clone)]
pub struct StructuredRandomFeatures<State = Untrained> {
    /// Number of random features to generate.
    pub n_components: usize,
    /// RBF kernel bandwidth parameter.
    pub gamma: Float,
    /// Structured matrix family used for the pre-transform.
    pub structured_matrix: StructuredMatrix,
    /// Optional seed for reproducible fits.
    pub random_state: Option<u64>,

    // Fitted parameters (populated by `fit`; `None` while untrained).
    /// Gaussian projection weights, shape (n_features, n_components).
    random_weights_: Option<Array2<Float>>,
    /// Per-component phase offsets drawn from U(0, 2*pi).
    random_offset_: Option<Array1<Float>>,
    /// The (n_features, n_features) structured pre-transform matrix.
    structured_transform_: Option<Array2<Float>>,

    /// Zero-sized typestate marker (`Untrained` -> `Trained`).
    _state: PhantomData<State>,
}
229
230impl StructuredRandomFeatures<Untrained> {
231    /// Create a new structured random features transformer
232    pub fn new(n_components: usize) -> Self {
233        Self {
234            n_components,
235            gamma: 1.0,
236            structured_matrix: StructuredMatrix::Hadamard,
237            random_state: None,
238            random_weights_: None,
239            random_offset_: None,
240            structured_transform_: None,
241            _state: PhantomData,
242        }
243    }
244
245    /// Set the gamma parameter for RBF kernel
246    pub fn gamma(mut self, gamma: Float) -> Self {
247        self.gamma = gamma;
248        self
249    }
250
251    /// Set the structured matrix type
252    pub fn structured_matrix(mut self, matrix_type: StructuredMatrix) -> Self {
253        self.structured_matrix = matrix_type;
254        self
255    }
256
257    /// Set random state for reproducibility
258    pub fn random_state(mut self, seed: u64) -> Self {
259        self.random_state = Some(seed);
260        self
261    }
262}
263
impl Estimator for StructuredRandomFeatures<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// This estimator has no separate config object; all settings live on the
    /// struct itself, so the config is the unit value.
    fn config(&self) -> &Self::Config {
        &()
    }
}
273
impl Fit<Array2<Float>, ()> for StructuredRandomFeatures<Untrained> {
    type Fitted = StructuredRandomFeatures<Trained>;

    /// Fit on `x` (shape: n_samples x n_features); `_y` is ignored
    /// (unsupervised transform).
    ///
    /// Draws the structured pre-transform, Gaussian projection weights with
    /// standard deviation sqrt(2 * gamma) (the RBF spectral scale used by
    /// random Fourier features), and uniform phase offsets in [0, 2*pi).
    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        let (_, n_features) = x.dim();

        // Seeded RNG when `random_state` is set; otherwise seed from the
        // thread-local RNG for a non-deterministic fit.
        let mut rng = match self.random_state {
            Some(seed) => RealStdRng::seed_from_u64(seed),
            None => RealStdRng::from_seed(thread_rng().gen()),
        };

        // Structured (n_features x n_features) pre-transform matrix.
        let structured_transform = self.generate_structured_matrix(n_features, &mut rng)?;

        // Gaussian mixing weights, (n_features x n_components).
        let normal = RandNormal::new(0.0, (2.0 * self.gamma).sqrt()).unwrap();
        let random_weights =
            Array2::from_shape_fn((n_features, self.n_components), |_| rng.sample(normal));

        // One random phase offset per output component.
        let uniform = RandUniform::new(0.0, 2.0 * PI).unwrap();
        let random_offset = Array1::from_shape_fn(self.n_components, |_| rng.sample(uniform));

        Ok(StructuredRandomFeatures {
            n_components: self.n_components,
            gamma: self.gamma,
            structured_matrix: self.structured_matrix,
            random_state: self.random_state,
            random_weights_: Some(random_weights),
            random_offset_: Some(random_offset),
            structured_transform_: Some(structured_transform),
            _state: PhantomData,
        })
    }
}
309
310impl StructuredRandomFeatures<Untrained> {
311    /// Generate structured matrix based on the specified type
312    fn generate_structured_matrix(
313        &self,
314        n_features: usize,
315        rng: &mut RealStdRng,
316    ) -> Result<Array2<Float>> {
317        match &self.structured_matrix {
318            StructuredMatrix::Hadamard => self.generate_hadamard_matrix(n_features, rng),
319            StructuredMatrix::DCT => self.generate_dct_matrix(n_features),
320            StructuredMatrix::Circulant => self.generate_circulant_matrix(n_features, rng),
321            StructuredMatrix::Toeplitz => self.generate_toeplitz_matrix(n_features, rng),
322        }
323    }
324
325    /// Generate (approximate) Hadamard matrix
326    fn generate_hadamard_matrix(
327        &self,
328        n_features: usize,
329        rng: &mut RealStdRng,
330    ) -> Result<Array2<Float>> {
331        // For simplicity, generate a randomized orthogonal-like matrix
332        // True Hadamard matrices exist only for specific sizes
333        let mut matrix = Array2::zeros((n_features, n_features));
334
335        // Generate random signs for each entry
336        for i in 0..n_features {
337            for j in 0..n_features {
338                matrix[[i, j]] = if rng.gen::<bool>() { 1.0 } else { -1.0 };
339            }
340        }
341
342        // Normalize to make approximately orthogonal
343        for mut row in matrix.rows_mut() {
344            let norm = (row.dot(&row) as Float).sqrt();
345            if norm > 1e-10 {
346                row /= norm;
347            }
348        }
349
350        Ok(matrix)
351    }
352
353    /// Generate Discrete Cosine Transform matrix
354    fn generate_dct_matrix(&self, n_features: usize) -> Result<Array2<Float>> {
355        let mut matrix = Array2::zeros((n_features, n_features));
356
357        for i in 0..n_features {
358            for j in 0..n_features {
359                let coeff = if i == 0 {
360                    (1.0 / (n_features as Float)).sqrt()
361                } else {
362                    (2.0 / (n_features as Float)).sqrt()
363                };
364
365                matrix[[i, j]] = coeff
366                    * ((PI * (i as Float) * (2.0 * (j as Float) + 1.0))
367                        / (2.0 * (n_features as Float)))
368                        .cos();
369            }
370        }
371
372        Ok(matrix)
373    }
374
375    /// Generate Circulant matrix
376    fn generate_circulant_matrix(
377        &self,
378        n_features: usize,
379        rng: &mut RealStdRng,
380    ) -> Result<Array2<Float>> {
381        let mut matrix = Array2::zeros((n_features, n_features));
382
383        // Generate first row randomly
384        let first_row: Vec<Float> = (0..n_features)
385            .map(|_| StandardNormal.sample(rng))
386            .collect();
387
388        // Fill circulant matrix
389        for i in 0..n_features {
390            for j in 0..n_features {
391                matrix[[i, j]] = first_row[(j + n_features - i) % n_features];
392            }
393        }
394
395        // Normalize rows
396        for mut row in matrix.rows_mut() {
397            let norm = (row.dot(&row) as Float).sqrt();
398            if norm > 1e-10 {
399                row /= norm;
400            }
401        }
402
403        Ok(matrix)
404    }
405
406    /// Generate Toeplitz matrix
407    fn generate_toeplitz_matrix(
408        &self,
409        n_features: usize,
410        rng: &mut RealStdRng,
411    ) -> Result<Array2<Float>> {
412        let mut matrix = Array2::zeros((n_features, n_features));
413
414        // Generate values for first row and first column
415        let first_row: Vec<Float> = (0..n_features)
416            .map(|_| StandardNormal.sample(rng))
417            .collect();
418        let first_col: Vec<Float> = (1..n_features)
419            .map(|_| StandardNormal.sample(rng))
420            .collect();
421
422        // Fill Toeplitz matrix
423        for i in 0..n_features {
424            for j in 0..n_features {
425                if i <= j {
426                    matrix[[i, j]] = first_row[j - i];
427                } else {
428                    matrix[[i, j]] = first_col[i - j - 1];
429                }
430            }
431        }
432
433        // Normalize rows
434        for mut row in matrix.rows_mut() {
435            let norm = (row.dot(&row) as Float).sqrt();
436            if norm > 1e-10 {
437                row /= norm;
438            }
439        }
440
441        Ok(matrix)
442    }
443}
444
445impl Transform<Array2<Float>> for StructuredRandomFeatures<Trained> {
446    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
447        let random_weights =
448            self.random_weights_
449                .as_ref()
450                .ok_or_else(|| SklearsError::NotFitted {
451                    operation: "transform".to_string(),
452                })?;
453
454        let random_offset =
455            self.random_offset_
456                .as_ref()
457                .ok_or_else(|| SklearsError::NotFitted {
458                    operation: "transform".to_string(),
459                })?;
460
461        let structured_transform =
462            self.structured_transform_
463                .as_ref()
464                .ok_or_else(|| SklearsError::NotFitted {
465                    operation: "transform".to_string(),
466                })?;
467
468        let (n_samples, _) = x.dim();
469
470        // Apply structured transformation first
471        let structured_x = x.dot(structured_transform);
472
473        // Apply random projections
474        let projected = structured_x.dot(random_weights);
475
476        // Add phase offsets and compute cosine features
477        let mut features = Array2::zeros((n_samples, self.n_components));
478
479        for i in 0..n_samples {
480            for j in 0..self.n_components {
481                let phase = projected[[i, j]] + random_offset[j];
482                features[[i, j]] = (2.0 / (self.n_components as Float)).sqrt() * phase.cos();
483            }
484        }
485
486        Ok(features)
487    }
488}
489
/// Fast Walsh-Hadamard Transform for efficient structured features
///
/// This implements the Fast Walsh-Hadamard Transform (FWHT) which can be used
/// to accelerate computations with Hadamard-structured random features.
/// Stateless namespace type: both transforms are associated functions, and
/// both require the transformed length to be a power of 2.
pub struct FastWalshHadamardTransform;
495
496impl FastWalshHadamardTransform {
497    /// Apply Fast Walsh-Hadamard Transform to input vector
498    ///
499    /// Time complexity: O(n log n) where n is the length of the input
500    /// Input length must be a power of 2
501    pub fn transform(mut data: Array1<Float>) -> Result<Array1<Float>> {
502        let n = data.len();
503
504        // Check if n is a power of 2
505        if n & (n - 1) != 0 {
506            return Err(SklearsError::InvalidInput(
507                "Input length must be a power of 2 for FWHT".to_string(),
508            ));
509        }
510
511        // Perform FWHT
512        let mut h = 1;
513        while h < n {
514            for i in (0..n).step_by(h * 2) {
515                for j in i..i + h {
516                    let u = data[j];
517                    let v = data[j + h];
518                    data[j] = u + v;
519                    data[j + h] = u - v;
520                }
521            }
522            h *= 2;
523        }
524
525        // Normalize
526        data /= (n as Float).sqrt();
527
528        Ok(data)
529    }
530
531    /// Apply FWHT to each row of a 2D array
532    pub fn transform_rows(mut data: Array2<Float>) -> Result<Array2<Float>> {
533        let (n_rows, n_cols) = data.dim();
534
535        // Check if n_cols is a power of 2
536        if n_cols & (n_cols - 1) != 0 {
537            return Err(SklearsError::InvalidInput(
538                "Number of columns must be a power of 2 for FWHT".to_string(),
539            ));
540        }
541
542        for i in 0..n_rows {
543            let row = data.row(i).to_owned();
544            let transformed_row = Self::transform(row)?;
545            data.row_mut(i).assign(&transformed_row);
546        }
547
548        Ok(data)
549    }
550}
551
/// Structured Random Features using Fast Walsh-Hadamard Transform
///
/// This is an optimized version that uses FWHT for better computational efficiency.
/// Input dimension must be a power of 2 for optimal performance.
#[derive(Debug, Clone)]
pub struct StructuredRFFHadamard<State = Untrained> {
    /// Number of random features
    pub n_components: usize,
    /// RBF kernel gamma parameter
    pub gamma: Float,
    /// Random seed
    pub random_state: Option<u64>,

    // Fitted parameters (populated by `fit`; `None` while untrained).
    /// Per-component +/-1 sign patterns, shape (n_components, n_features).
    random_signs_: Option<Array2<Float>>,
    /// Per-component phase offsets drawn from U(0, 2*pi).
    random_offset_: Option<Array1<Float>>,
    /// Shared standard-normal projection weights, length n_features.
    gaussian_weights_: Option<Array1<Float>>,

    /// Zero-sized typestate marker (`Untrained` -> `Trained`).
    _state: PhantomData<State>,
}
573
574impl StructuredRFFHadamard<Untrained> {
575    /// Create a new structured RFF with Hadamard transforms
576    pub fn new(n_components: usize) -> Self {
577        Self {
578            n_components,
579            gamma: 1.0,
580            random_state: None,
581            random_signs_: None,
582            random_offset_: None,
583            gaussian_weights_: None,
584            _state: PhantomData,
585        }
586    }
587
588    /// Set gamma parameter
589    pub fn gamma(mut self, gamma: Float) -> Self {
590        self.gamma = gamma;
591        self
592    }
593
594    /// Set random state
595    pub fn random_state(mut self, seed: u64) -> Self {
596        self.random_state = Some(seed);
597        self
598    }
599}
600
impl Estimator for StructuredRFFHadamard<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// This estimator has no separate config object; all settings live on the
    /// struct itself, so the config is the unit value.
    fn config(&self) -> &Self::Config {
        &()
    }
}
610
611impl Fit<Array2<Float>, ()> for StructuredRFFHadamard<Untrained> {
612    type Fitted = StructuredRFFHadamard<Trained>;
613
614    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
615        let (_, n_features) = x.dim();
616
617        // Check if n_features is a power of 2
618        if n_features & (n_features - 1) != 0 {
619            return Err(SklearsError::InvalidInput(
620                "Number of features must be a power of 2 for structured Hadamard RFF".to_string(),
621            ));
622        }
623
624        let mut rng = match self.random_state {
625            Some(seed) => RealStdRng::seed_from_u64(seed),
626            None => RealStdRng::from_seed(thread_rng().gen()),
627        };
628
629        // Generate random signs for Hadamard transforms
630        let mut random_signs = Array2::zeros((self.n_components, n_features));
631        for i in 0..self.n_components {
632            for j in 0..n_features {
633                random_signs[[i, j]] = if rng.gen::<bool>() { 1.0 } else { -1.0 };
634            }
635        }
636
637        // Generate Gaussian weights
638        let normal = RandNormal::new(0.0, 1.0).unwrap();
639        let gaussian_weights = Array1::from_shape_fn(n_features, |_| rng.sample(normal));
640
641        // Generate random offsets
642        let uniform = RandUniform::new(0.0, 2.0 * PI).unwrap();
643        let random_offset = Array1::from_shape_fn(self.n_components, |_| rng.sample(uniform));
644
645        Ok(StructuredRFFHadamard {
646            n_components: self.n_components,
647            gamma: self.gamma,
648            random_state: self.random_state,
649            random_signs_: Some(random_signs),
650            random_offset_: Some(random_offset),
651            gaussian_weights_: Some(gaussian_weights),
652            _state: PhantomData,
653        })
654    }
655}
656
657impl Transform<Array2<Float>> for StructuredRFFHadamard<Trained> {
658    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
659        let random_signs = self
660            .random_signs_
661            .as_ref()
662            .ok_or_else(|| SklearsError::NotFitted {
663                operation: "transform".to_string(),
664            })?;
665
666        let random_offset =
667            self.random_offset_
668                .as_ref()
669                .ok_or_else(|| SklearsError::NotFitted {
670                    operation: "transform".to_string(),
671                })?;
672
673        let gaussian_weights =
674            self.gaussian_weights_
675                .as_ref()
676                .ok_or_else(|| SklearsError::NotFitted {
677                    operation: "transform".to_string(),
678                })?;
679
680        let (n_samples, n_features) = x.dim();
681        let mut features = Array2::zeros((n_samples, self.n_components));
682
683        // For each component, apply structured transformation
684        for comp in 0..self.n_components {
685            for sample in 0..n_samples {
686                // Element-wise multiplication with random signs
687                let mut signed_input = Array1::zeros(n_features);
688                for feat in 0..n_features {
689                    signed_input[feat] = x[[sample, feat]] * random_signs[[comp, feat]];
690                }
691
692                // Apply Fast Walsh-Hadamard Transform
693                let transformed = FastWalshHadamardTransform::transform(signed_input)?;
694
695                // Scale by Gaussian weights and gamma
696                let mut projected = 0.0;
697                for feat in 0..n_features {
698                    projected += transformed[feat] * gaussian_weights[feat];
699                }
700                projected *= (2.0 * self.gamma).sqrt();
701
702                // Add offset and compute cosine
703                let phase = projected + random_offset[comp];
704                features[[sample, comp]] =
705                    (2.0 / (self.n_components as Float)).sqrt() * phase.cos();
706            }
707        }
708
709        Ok(features)
710    }
711}
712
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Smoke test: fit + transform produce one row per sample and
    // n_components columns.
    #[test]
    fn test_structured_random_features_basic() {
        let x = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 4.0, 5.0, 6.0]
        ];

        let srf = StructuredRandomFeatures::new(8).gamma(0.5);
        let fitted = srf.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.shape(), &[3, 8]);
    }

    // Each structured matrix family should fit and transform without error
    // and preserve the expected output shape.
    #[test]
    fn test_structured_matrices() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        // Test Hadamard
        let hadamard_srf =
            StructuredRandomFeatures::new(4).structured_matrix(StructuredMatrix::Hadamard);
        let hadamard_fitted = hadamard_srf.fit(&x, &()).unwrap();
        let hadamard_result = hadamard_fitted.transform(&x).unwrap();
        assert_eq!(hadamard_result.shape(), &[2, 4]);

        // Test DCT
        let dct_srf = StructuredRandomFeatures::new(4).structured_matrix(StructuredMatrix::DCT);
        let dct_fitted = dct_srf.fit(&x, &()).unwrap();
        let dct_result = dct_fitted.transform(&x).unwrap();
        assert_eq!(dct_result.shape(), &[2, 4]);

        // Test Circulant
        let circulant_srf =
            StructuredRandomFeatures::new(4).structured_matrix(StructuredMatrix::Circulant);
        let circulant_fitted = circulant_srf.fit(&x, &()).unwrap();
        let circulant_result = circulant_fitted.transform(&x).unwrap();
        assert_eq!(circulant_result.shape(), &[2, 4]);
    }

    // FWHT on power-of-2 inputs succeeds and preserves lengths, for both 1-D
    // and row-wise 2-D variants.
    #[test]
    fn test_fast_walsh_hadamard_transform() {
        let data = array![1.0, 2.0, 3.0, 4.0];
        let result = FastWalshHadamardTransform::transform(data).unwrap();
        assert_eq!(result.len(), 4);

        // Test with 2D array
        let data_2d = array![[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];
        let result_2d = FastWalshHadamardTransform::transform_rows(data_2d).unwrap();
        assert_eq!(result_2d.shape(), &[2, 4]);
    }

    // FWHT-based variant: 4 features (a power of 2) must fit and transform.
    #[test]
    fn test_structured_rff_hadamard() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        let srf_h = StructuredRFFHadamard::new(6).gamma(0.5);
        let fitted = srf_h.fit(&x, &()).unwrap();
        let transformed = fitted.transform(&x).unwrap();

        assert_eq!(transformed.shape(), &[2, 6]);
    }

    // Non-power-of-2 lengths must be rejected with an error, not a panic.
    #[test]
    fn test_fwht_invalid_size() {
        let data = array![1.0, 2.0, 3.0]; // Not a power of 2
        let result = FastWalshHadamardTransform::transform(data);
        assert!(result.is_err());
    }

    // Two fits with the same seed must yield identical features.
    #[test]
    fn test_reproducibility() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        let srf1 = StructuredRandomFeatures::new(8).random_state(42);
        let fitted1 = srf1.fit(&x, &()).unwrap();
        let result1 = fitted1.transform(&x).unwrap();

        let srf2 = StructuredRandomFeatures::new(8).random_state(42);
        let fitted2 = srf2.fit(&x, &()).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        assert_eq!(result1.shape(), result2.shape());
        for i in 0..result1.len() {
            assert!(
                (result1.as_slice().unwrap()[i] - result2.as_slice().unwrap()[i]).abs() < 1e-10
            );
        }
    }

    // Changing gamma rescales the projections, so outputs must differ.
    #[test]
    fn test_different_gamma_values() {
        let x = array![[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]];

        let srf_low = StructuredRandomFeatures::new(4).gamma(0.1);
        let fitted_low = srf_low.fit(&x, &()).unwrap();
        let result_low = fitted_low.transform(&x).unwrap();

        let srf_high = StructuredRandomFeatures::new(4).gamma(10.0);
        let fitted_high = srf_high.fit(&x, &()).unwrap();
        let result_high = fitted_high.transform(&x).unwrap();

        assert_eq!(result_low.shape(), result_high.shape());
        // Results should be different with different gamma values
        let diff_sum: Float = result_low
            .iter()
            .zip(result_high.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff_sum > 1e-6);
    }
}