sklears_kernel_approximation/
polynomial_count_sketch.rs

1//! Polynomial kernel approximation via Tensor Sketch
2use rustfft::num_complex::Complex;
3use rustfft::FftPlanner;
4use scirs2_core::ndarray::{Array1, Array2, Array3};
5use scirs2_core::random::essentials::Uniform as RandUniform;
6use scirs2_core::random::rngs::StdRng as RealStdRng;
7use scirs2_core::random::Rng;
8use scirs2_core::random::{thread_rng, SeedableRng};
9use sklears_core::{
10    error::{Result, SklearsError},
11    traits::{Estimator, Fit, Trained, Transform, Untrained},
12    types::Float,
13};
14use std::marker::PhantomData;
15
16/// Polynomial kernel approximation via Tensor Sketch
17///
18/// Implements Tensor Sketch for polynomial kernel approximation:
19/// K(x,y) = (gamma * <x,y> + coef0)^degree
20///
21/// # Parameters
22///
23/// * `gamma` - Polynomial kernel parameter (default: 1.0)
24/// * `degree` - Polynomial degree (default: 2)
25/// * `coef0` - Constant term (default: 0.0)
26/// * `n_components` - Output dimensionality (default: 100)
27/// * `random_state` - Random seed for reproducibility
28///
29/// # Examples
30///
31/// ```rust,ignore
32/// use sklears_kernel_approximation::polynomial_count_sketch::PolynomialCountSketch;
33/// use sklears_core::traits::{Transform, Fit, Untrained}
34/// use scirs2_core::ndarray::array;
35///
36/// let X = array![[1.0, 2.0], [3.0, 4.0]];
37///
38/// let poly_sketch = PolynomialCountSketch::new(50);
39/// let fitted_sketch = poly_sketch.fit(&X, &()).unwrap();
40/// let X_transformed = fitted_sketch.transform(&X).unwrap();
41/// assert_eq!(X_transformed.shape(), &[2, 50]);
42/// ```
43#[derive(Debug, Clone)]
44/// PolynomialCountSketch
45pub struct PolynomialCountSketch<State = Untrained> {
46    /// Polynomial kernel gamma parameter
47    pub gamma: Float,
48    /// Polynomial degree
49    pub degree: u32,
50    /// Constant term
51    pub coef0: Float,
52    /// Output dimensionality
53    pub n_components: usize,
54    /// Random seed
55    pub random_state: Option<u64>,
56
57    // Fitted attributes
58    index_hash_: Option<Array3<usize>>, // (degree, n_features, 1)
59    bit_hash_: Option<Array3<Float>>,   // (degree, n_features, 1)
60
61    _state: PhantomData<State>,
62}
63
64impl PolynomialCountSketch<Untrained> {
65    /// Create a new Polynomial Count Sketch
66    pub fn new(n_components: usize) -> Self {
67        Self {
68            gamma: 1.0,
69            degree: 2,
70            coef0: 0.0,
71            n_components,
72            random_state: None,
73            index_hash_: None,
74            bit_hash_: None,
75            _state: PhantomData,
76        }
77    }
78
79    /// Set the gamma parameter
80    pub fn gamma(mut self, gamma: Float) -> Self {
81        self.gamma = gamma;
82        self
83    }
84
85    /// Set the polynomial degree
86    pub fn degree(mut self, degree: u32) -> Self {
87        self.degree = degree;
88        self
89    }
90
91    /// Set the constant term
92    pub fn coef0(mut self, coef0: Float) -> Self {
93        self.coef0 = coef0;
94        self
95    }
96
97    /// Set random state for reproducibility
98    pub fn random_state(mut self, seed: u64) -> Self {
99        self.random_state = Some(seed);
100        self
101    }
102}
103
104impl Estimator for PolynomialCountSketch<Untrained> {
105    type Config = ();
106    type Error = SklearsError;
107    type Float = Float;
108
109    fn config(&self) -> &Self::Config {
110        &()
111    }
112}
113
114impl Fit<Array2<Float>, ()> for PolynomialCountSketch<Untrained> {
115    type Fitted = PolynomialCountSketch<Trained>;
116
117    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
118        let (_, n_features) = x.dim();
119
120        if self.degree == 0 {
121            return Err(SklearsError::InvalidInput(
122                "degree must be positive".to_string(),
123            ));
124        }
125
126        if self.n_components == 0 {
127            return Err(SklearsError::InvalidInput(
128                "n_components must be positive".to_string(),
129            ));
130        }
131
132        let mut rng = if let Some(seed) = self.random_state {
133            RealStdRng::seed_from_u64(seed)
134        } else {
135            RealStdRng::from_seed(thread_rng().gen())
136        };
137
138        // Generate hash functions for Count Sketch
139        let mut index_hash = Array3::zeros((self.degree as usize, n_features, 1));
140        let mut bit_hash = Array3::zeros((self.degree as usize, n_features, 1));
141
142        let index_uniform = RandUniform::new(0, self.n_components).unwrap();
143
144        for d in 0..self.degree as usize {
145            for j in 0..n_features {
146                // Index hash: uniform random in [0, n_components)
147                index_hash[[d, j, 0]] = rng.sample(index_uniform);
148
149                // Bit hash: random ±1
150                bit_hash[[d, j, 0]] = if rng.gen::<bool>() { 1.0 } else { -1.0 };
151            }
152        }
153
154        Ok(PolynomialCountSketch {
155            gamma: self.gamma,
156            degree: self.degree,
157            coef0: self.coef0,
158            n_components: self.n_components,
159            random_state: self.random_state,
160            index_hash_: Some(index_hash),
161            bit_hash_: Some(bit_hash),
162            _state: PhantomData,
163        })
164    }
165}
166
167impl Transform<Array2<Float>, Array2<Float>> for PolynomialCountSketch<Trained> {
168    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
169        let (n_samples, n_features) = x.dim();
170        let index_hash = self.index_hash_.as_ref().unwrap();
171        let bit_hash = self.bit_hash_.as_ref().unwrap();
172
173        if n_features != index_hash.shape()[1] {
174            return Err(SklearsError::InvalidInput(format!(
175                "X has {} features, but PolynomialCountSketch was fitted with {} features",
176                n_features,
177                index_hash.shape()[1]
178            )));
179        }
180
181        let mut result = Array2::zeros((n_samples, self.n_components));
182
183        // FFT planner
184        let mut planner = FftPlanner::new();
185        let fft = planner.plan_fft_forward(self.n_components);
186        let ifft = planner.plan_fft_inverse(self.n_components);
187
188        for i in 0..n_samples {
189            let sample = x.row(i);
190
191            // Add constant term if coef0 != 0
192            let extended_sample = if self.coef0 != 0.0 {
193                let mut vec = sample.to_vec();
194                vec.push(self.coef0.sqrt());
195                Array1::from(vec)
196            } else {
197                sample.to_owned()
198            };
199
200            // Compute Count Sketch for each degree
201            let mut sketches = Vec::new();
202
203            for d in 0..self.degree as usize {
204                let mut sketch = vec![Complex::new(0.0, 0.0); self.n_components];
205
206                for (j, &feature_val) in extended_sample.iter().enumerate() {
207                    if j < n_features || self.coef0 != 0.0 {
208                        // Include constant term
209                        let scaled_val = if j < n_features {
210                            self.gamma.sqrt() * feature_val
211                        } else {
212                            feature_val // Already sqrt(coef0)
213                        };
214
215                        let hash_idx = if j < n_features {
216                            index_hash[[d, j, 0]]
217                        } else {
218                            0 // Use index 0 for constant term
219                        };
220
221                        let hash_sign = if j < n_features {
222                            bit_hash[[d, j, 0]]
223                        } else {
224                            1.0 // Always positive for constant term
225                        };
226
227                        sketch[hash_idx] += Complex::new(hash_sign * scaled_val, 0.0);
228                    }
229                }
230
231                sketches.push(sketch);
232            }
233
234            // Compute FFT of each sketch
235            let mut fft_sketches = Vec::new();
236            for mut sketch in sketches {
237                fft.process(&mut sketch);
238                fft_sketches.push(sketch);
239            }
240
241            // Compute element-wise product of FFTs (convolution in time domain)
242            let mut product = vec![Complex::new(1.0, 0.0); self.n_components];
243            for fft_sketch in fft_sketches {
244                for (k, val) in fft_sketch.into_iter().enumerate() {
245                    product[k] *= val;
246                }
247            }
248
249            // Compute inverse FFT
250            ifft.process(&mut product);
251
252            // Extract real part and normalize
253            for (k, val) in product.into_iter().enumerate() {
254                result[[i, k]] = val.re / self.n_components as Float;
255            }
256        }
257
258        Ok(result)
259    }
260}
261
262impl PolynomialCountSketch<Trained> {
263    /// Get the index hash table
264    pub fn index_hash(&self) -> &Array3<usize> {
265        self.index_hash_.as_ref().unwrap()
266    }
267
268    /// Get the bit hash table
269    pub fn bit_hash(&self) -> &Array3<Float> {
270        self.bit_hash_.as_ref().unwrap()
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Default settings produce one output column per component.
    #[test]
    fn test_polynomial_count_sketch_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let poly_sketch = PolynomialCountSketch::new(32);
        let fitted = poly_sketch.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.shape(), &[2, 32]);
    }

    /// A non-zero constant term does not change the output shape.
    #[test]
    fn test_polynomial_count_sketch_with_coef0() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let poly_sketch = PolynomialCountSketch::new(16).coef0(1.0).degree(3);
        let fitted = poly_sketch.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.shape(), &[2, 16]);
    }

    /// Two estimators seeded identically yield the same features.
    #[test]
    fn test_polynomial_count_sketch_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];

        let poly1 = PolynomialCountSketch::new(16).random_state(42);
        let fitted1 = poly1.fit(&x, &()).unwrap();
        let result1 = fitted1.transform(&x).unwrap();

        let poly2 = PolynomialCountSketch::new(16).random_state(42);
        let fitted2 = poly2.fit(&x, &()).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        // Results should be identical with same random state
        for (a, b) in result1.iter().zip(result2.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    /// `degree == 0` is rejected at fit time.
    #[test]
    fn test_polynomial_count_sketch_invalid_degree() {
        let x = array![[1.0, 2.0]];
        let poly_sketch = PolynomialCountSketch::new(16).degree(0);
        let result = poly_sketch.fit(&x, &());
        assert!(result.is_err());
    }

    /// `n_components == 0` is rejected at fit time.
    #[test]
    fn test_polynomial_count_sketch_zero_components() {
        let x = array![[1.0, 2.0]];
        let poly_sketch = PolynomialCountSketch::new(0);
        let result = poly_sketch.fit(&x, &());
        assert!(result.is_err());
    }

    /// Transforming data with a different column count is an error.
    #[test]
    fn test_polynomial_count_sketch_feature_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let fitted = PolynomialCountSketch::new(16)
            .random_state(0)
            .fit(&x, &())
            .unwrap();

        let wrong_width = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&wrong_width).is_err());
    }

    /// The fitted tables have shape `(degree, n_features, 1)`, bucket indices
    /// stay below `n_components`, and signs are exactly +1 or -1.
    #[test]
    fn test_polynomial_count_sketch_hash_tables() {
        let x = array![[1.0, 2.0, 3.0]];
        let fitted = PolynomialCountSketch::new(8)
            .degree(3)
            .random_state(7)
            .fit(&x, &())
            .unwrap();

        assert_eq!(fitted.index_hash().shape(), &[3, 3, 1]);
        assert_eq!(fitted.bit_hash().shape(), &[3, 3, 1]);
        assert!(fitted.index_hash().iter().all(|&bucket| bucket < 8));
        assert!(fitted
            .bit_hash()
            .iter()
            .all(|&sign| sign == 1.0 || sign == -1.0));
    }

    /// With `coef0 == 0`, an all-zero sample maps to an all-zero sketch.
    #[test]
    fn test_polynomial_count_sketch_zero_input() {
        let x = array![[0.0, 0.0]];
        let fitted = PolynomialCountSketch::new(8)
            .random_state(3)
            .fit(&x, &())
            .unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.shape(), &[1, 8]);
        for value in x_transformed.iter() {
            assert!(value.abs() < 1e-12);
        }
    }
}