sklears_kernel_approximation/
polynomial_count_sketch.rs1use rustfft::num_complex::Complex;
3use rustfft::FftPlanner;
4use scirs2_core::ndarray::{Array1, Array2, Array3};
5use scirs2_core::random::essentials::Uniform as RandUniform;
6use scirs2_core::random::rngs::StdRng as RealStdRng;
7use scirs2_core::random::Rng;
8use scirs2_core::random::{thread_rng, SeedableRng};
9use sklears_core::{
10 error::{Result, SklearsError},
11 traits::{Estimator, Fit, Trained, Transform, Untrained},
12 types::Float,
13};
14use std::marker::PhantomData;
15
16#[derive(Debug, Clone)]
44pub struct PolynomialCountSketch<State = Untrained> {
46 pub gamma: Float,
48 pub degree: u32,
50 pub coef0: Float,
52 pub n_components: usize,
54 pub random_state: Option<u64>,
56
57 index_hash_: Option<Array3<usize>>, bit_hash_: Option<Array3<Float>>, _state: PhantomData<State>,
62}
63
64impl PolynomialCountSketch<Untrained> {
65 pub fn new(n_components: usize) -> Self {
67 Self {
68 gamma: 1.0,
69 degree: 2,
70 coef0: 0.0,
71 n_components,
72 random_state: None,
73 index_hash_: None,
74 bit_hash_: None,
75 _state: PhantomData,
76 }
77 }
78
79 pub fn gamma(mut self, gamma: Float) -> Self {
81 self.gamma = gamma;
82 self
83 }
84
85 pub fn degree(mut self, degree: u32) -> Self {
87 self.degree = degree;
88 self
89 }
90
91 pub fn coef0(mut self, coef0: Float) -> Self {
93 self.coef0 = coef0;
94 self
95 }
96
97 pub fn random_state(mut self, seed: u64) -> Self {
99 self.random_state = Some(seed);
100 self
101 }
102}
103
104impl Estimator for PolynomialCountSketch<Untrained> {
105 type Config = ();
106 type Error = SklearsError;
107 type Float = Float;
108
109 fn config(&self) -> &Self::Config {
110 &()
111 }
112}
113
114impl Fit<Array2<Float>, ()> for PolynomialCountSketch<Untrained> {
115 type Fitted = PolynomialCountSketch<Trained>;
116
117 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
118 let (_, n_features) = x.dim();
119
120 if self.degree == 0 {
121 return Err(SklearsError::InvalidInput(
122 "degree must be positive".to_string(),
123 ));
124 }
125
126 if self.n_components == 0 {
127 return Err(SklearsError::InvalidInput(
128 "n_components must be positive".to_string(),
129 ));
130 }
131
132 let mut rng = if let Some(seed) = self.random_state {
133 RealStdRng::seed_from_u64(seed)
134 } else {
135 RealStdRng::from_seed(thread_rng().gen())
136 };
137
138 let mut index_hash = Array3::zeros((self.degree as usize, n_features, 1));
140 let mut bit_hash = Array3::zeros((self.degree as usize, n_features, 1));
141
142 let index_uniform = RandUniform::new(0, self.n_components).unwrap();
143
144 for d in 0..self.degree as usize {
145 for j in 0..n_features {
146 index_hash[[d, j, 0]] = rng.sample(index_uniform);
148
149 bit_hash[[d, j, 0]] = if rng.gen::<bool>() { 1.0 } else { -1.0 };
151 }
152 }
153
154 Ok(PolynomialCountSketch {
155 gamma: self.gamma,
156 degree: self.degree,
157 coef0: self.coef0,
158 n_components: self.n_components,
159 random_state: self.random_state,
160 index_hash_: Some(index_hash),
161 bit_hash_: Some(bit_hash),
162 _state: PhantomData,
163 })
164 }
165}
166
167impl Transform<Array2<Float>, Array2<Float>> for PolynomialCountSketch<Trained> {
168 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
169 let (n_samples, n_features) = x.dim();
170 let index_hash = self.index_hash_.as_ref().unwrap();
171 let bit_hash = self.bit_hash_.as_ref().unwrap();
172
173 if n_features != index_hash.shape()[1] {
174 return Err(SklearsError::InvalidInput(format!(
175 "X has {} features, but PolynomialCountSketch was fitted with {} features",
176 n_features,
177 index_hash.shape()[1]
178 )));
179 }
180
181 let mut result = Array2::zeros((n_samples, self.n_components));
182
183 let mut planner = FftPlanner::new();
185 let fft = planner.plan_fft_forward(self.n_components);
186 let ifft = planner.plan_fft_inverse(self.n_components);
187
188 for i in 0..n_samples {
189 let sample = x.row(i);
190
191 let extended_sample = if self.coef0 != 0.0 {
193 let mut vec = sample.to_vec();
194 vec.push(self.coef0.sqrt());
195 Array1::from(vec)
196 } else {
197 sample.to_owned()
198 };
199
200 let mut sketches = Vec::new();
202
203 for d in 0..self.degree as usize {
204 let mut sketch = vec![Complex::new(0.0, 0.0); self.n_components];
205
206 for (j, &feature_val) in extended_sample.iter().enumerate() {
207 if j < n_features || self.coef0 != 0.0 {
208 let scaled_val = if j < n_features {
210 self.gamma.sqrt() * feature_val
211 } else {
212 feature_val };
214
215 let hash_idx = if j < n_features {
216 index_hash[[d, j, 0]]
217 } else {
218 0 };
220
221 let hash_sign = if j < n_features {
222 bit_hash[[d, j, 0]]
223 } else {
224 1.0 };
226
227 sketch[hash_idx] += Complex::new(hash_sign * scaled_val, 0.0);
228 }
229 }
230
231 sketches.push(sketch);
232 }
233
234 let mut fft_sketches = Vec::new();
236 for mut sketch in sketches {
237 fft.process(&mut sketch);
238 fft_sketches.push(sketch);
239 }
240
241 let mut product = vec![Complex::new(1.0, 0.0); self.n_components];
243 for fft_sketch in fft_sketches {
244 for (k, val) in fft_sketch.into_iter().enumerate() {
245 product[k] *= val;
246 }
247 }
248
249 ifft.process(&mut product);
251
252 for (k, val) in product.into_iter().enumerate() {
254 result[[i, k]] = val.re / self.n_components as Float;
255 }
256 }
257
258 Ok(result)
259 }
260}
261
262impl PolynomialCountSketch<Trained> {
263 pub fn index_hash(&self) -> &Array3<usize> {
265 self.index_hash_.as_ref().unwrap()
266 }
267
268 pub fn bit_hash(&self) -> &Array3<Float> {
270 self.bit_hash_.as_ref().unwrap()
271 }
272}
273
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Default settings produce a finite output of the requested width.
    #[test]
    fn test_polynomial_count_sketch_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0],];

        let poly_sketch = PolynomialCountSketch::new(32);
        let fitted = poly_sketch.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.shape(), &[2, 32]);
        assert!(x_transformed.iter().all(|v| v.is_finite()));
    }

    /// A non-zero `coef0` adds the constant feature without changing the
    /// output shape.
    #[test]
    fn test_polynomial_count_sketch_with_coef0() {
        let x = array![[1.0, 2.0], [3.0, 4.0],];

        let poly_sketch = PolynomialCountSketch::new(16).coef0(1.0).degree(3);
        let fitted = poly_sketch.fit(&x, &()).unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        assert_eq!(x_transformed.shape(), &[2, 16]);
        assert!(x_transformed.iter().all(|v| v.is_finite()));
    }

    /// Two estimators with the same seed must produce the same features.
    #[test]
    fn test_polynomial_count_sketch_reproducibility() {
        let x = array![[1.0, 2.0], [3.0, 4.0],];

        let poly1 = PolynomialCountSketch::new(16).random_state(42);
        let fitted1 = poly1.fit(&x, &()).unwrap();
        let result1 = fitted1.transform(&x).unwrap();

        let poly2 = PolynomialCountSketch::new(16).random_state(42);
        let fitted2 = poly2.fit(&x, &()).unwrap();
        let result2 = fitted2.transform(&x).unwrap();

        for (a, b) in result1.iter().zip(result2.iter()) {
            assert!((a - b).abs() < 1e-10);
        }
    }

    /// With `degree = 1`, default `gamma` and no `coef0`, the FFT round trip
    /// must return the plain count sketch defined by the fitted hash tables.
    #[test]
    fn test_polynomial_count_sketch_degree_one_matches_count_sketch() {
        let x = array![[1.0, 2.0, 3.0], [-4.0, 5.0, 0.5],];
        let n_components = 8;

        let fitted = PolynomialCountSketch::new(n_components)
            .degree(1)
            .random_state(7)
            .fit(&x, &())
            .unwrap();
        let x_transformed = fitted.transform(&x).unwrap();

        for i in 0..2 {
            let mut expected: Vec<Float> = vec![0.0; n_components];
            for j in 0..3 {
                let bucket = fitted.index_hash()[[0, j, 0]];
                expected[bucket] += fitted.bit_hash()[[0, j, 0]] * x[[i, j]];
            }

            for (k, want) in expected.iter().enumerate() {
                assert!((x_transformed[[i, k]] - *want).abs() < 1e-5);
            }
        }
    }

    /// `transform` rejects inputs whose width differs from the fitted data.
    #[test]
    fn test_polynomial_count_sketch_feature_mismatch() {
        let x = array![[1.0, 2.0], [3.0, 4.0],];
        let fitted = PolynomialCountSketch::new(16).fit(&x, &()).unwrap();

        let wider = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&wider).is_err());
    }

    #[test]
    fn test_polynomial_count_sketch_invalid_degree() {
        let x = array![[1.0, 2.0]];
        let poly_sketch = PolynomialCountSketch::new(16).degree(0);
        let result = poly_sketch.fit(&x, &());
        assert!(result.is_err());
    }

    #[test]
    fn test_polynomial_count_sketch_zero_components() {
        let x = array![[1.0, 2.0]];
        let poly_sketch = PolynomialCountSketch::new(0);
        let result = poly_sketch.fit(&x, &());
        assert!(result.is_err());
    }
}