sklears_gaussian_process/
features.rs1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
14use sklears_core::{
15 error::{Result as SklResult, SklearsError},
16 traits::{Estimator, Fit, Predict, Untrained},
17};
18
19use crate::{classification::GpcConfig, utils};
20
21#[derive(Debug, Clone)]
28pub struct RandomFourierFeatures {
29 n_components: usize,
30 gamma: f64,
31 random_state: Option<u64>,
32 omega: Option<Array2<f64>>, phi: Option<Array1<f64>>, fitted: bool,
35}
36
37impl RandomFourierFeatures {
38 pub fn new(n_components: usize, gamma: f64, random_state: Option<u64>) -> Self {
45 Self {
46 n_components,
47 gamma,
48 random_state,
49 omega: None,
50 phi: None,
51 fitted: false,
52 }
53 }
54
55 pub fn fit(&mut self, X: &ArrayView2<f64>) -> SklResult<()> {
57 let (_, n_features) = X.dim();
58
59 let mut rng_state = self.random_state.unwrap_or(42);
61
62 let mut omega = Array2::<f64>::zeros((self.n_components, n_features));
65 let std_dev = (2.0 * self.gamma).sqrt();
66
67 for i in 0..self.n_components {
68 for j in 0..n_features {
69 let (u1, u2) = self.uniform_pair(&mut rng_state);
71 let normal_sample =
72 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
73 omega[[i, j]] = std_dev * normal_sample;
74 }
75 }
76
77 let mut phi = Array1::<f64>::zeros(self.n_components);
79 for i in 0..self.n_components {
80 let u = self.uniform(&mut rng_state);
81 phi[i] = 2.0 * std::f64::consts::PI * u;
82 }
83
84 self.omega = Some(omega);
85 self.phi = Some(phi);
86 self.fitted = true;
87
88 Ok(())
89 }
90
91 pub fn transform(&self, X: &ArrayView2<f64>) -> SklResult<Array2<f64>> {
93 if !self.fitted {
94 return Err(SklearsError::InvalidInput(
95 "RandomFourierFeatures must be fitted before transform".to_string(),
96 ));
97 }
98
99 let omega = self.omega.as_ref().unwrap();
100 let phi = self.phi.as_ref().unwrap();
101
102 let (n_samples, _) = X.dim();
103 let mut features = Array2::<f64>::zeros((n_samples, self.n_components));
104
105 let normalization = (2.0 / self.n_components as f64).sqrt();
107
108 for i in 0..n_samples {
109 for j in 0..self.n_components {
110 let dot_product: f64 = X
111 .row(i)
112 .iter()
113 .zip(omega.row(j).iter())
114 .map(|(x, w)| x * w)
115 .sum();
116 features[[i, j]] = normalization * (dot_product + phi[j]).cos();
117 }
118 }
119
120 Ok(features)
121 }
122
123 pub fn fit_transform(&mut self, X: &ArrayView2<f64>) -> SklResult<Array2<f64>> {
125 self.fit(X)?;
126 self.transform(X)
127 }
128
129 pub fn n_components(&self) -> usize {
131 self.n_components
132 }
133
134 pub fn gamma(&self) -> f64 {
136 self.gamma
137 }
138
139 pub fn is_fitted(&self) -> bool {
141 self.fitted
142 }
143
144 fn uniform(&self, rng_state: &mut u64) -> f64 {
146 *rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
147 (*rng_state & 0x7fffffff) as f64 / 0x7fffffff as f64
148 }
149
150 fn uniform_pair(&self, rng_state: &mut u64) -> (f64, f64) {
152 let u1 = self.uniform(rng_state);
153 let u2 = self.uniform(rng_state);
154 (u1, u2)
155 }
156}
157
158impl Default for RandomFourierFeatures {
159 fn default() -> Self {
160 Self::new(100, 1.0, Some(42))
161 }
162}
163
164#[derive(Debug, Clone)]
190pub struct RandomFourierFeaturesGPR<S = Untrained> {
191 state: S,
192 n_components: usize,
193 gamma: f64,
194 alpha: f64,
195 random_state: Option<u64>,
196 config: GpcConfig,
197}
198
199#[derive(Debug, Clone)]
201pub struct RffGprTrained {
202 pub rff: RandomFourierFeatures, pub weights: Array1<f64>, pub weights_cov: Array2<f64>, pub alpha: f64, pub y_mean: f64, pub log_marginal_likelihood_value: f64, }
215
216impl RandomFourierFeaturesGPR<Untrained> {
217 pub fn new() -> Self {
219 Self {
220 state: Untrained,
221 n_components: 100,
222 gamma: 1.0,
223 alpha: 1e-3,
224 random_state: Some(42),
225 config: GpcConfig::default(),
226 }
227 }
228
229 pub fn n_components(mut self, n_components: usize) -> Self {
231 self.n_components = n_components;
232 self
233 }
234
235 pub fn gamma(mut self, gamma: f64) -> Self {
237 self.gamma = gamma;
238 self
239 }
240
241 pub fn alpha(mut self, alpha: f64) -> Self {
243 self.alpha = alpha;
244 self
245 }
246
247 pub fn random_state(mut self, random_state: Option<u64>) -> Self {
249 self.random_state = random_state;
250 self
251 }
252}
253
254impl Estimator for RandomFourierFeaturesGPR<Untrained> {
255 type Config = GpcConfig;
256 type Error = SklearsError;
257 type Float = f64;
258
259 fn config(&self) -> &Self::Config {
260 &self.config
261 }
262}
263
264impl Estimator for RandomFourierFeaturesGPR<RffGprTrained> {
265 type Config = GpcConfig;
266 type Error = SklearsError;
267 type Float = f64;
268
269 fn config(&self) -> &Self::Config {
270 &self.config
271 }
272}
273
274impl Fit<ArrayView2<'_, f64>, ArrayView1<'_, f64>> for RandomFourierFeaturesGPR<Untrained> {
275 type Fitted = RandomFourierFeaturesGPR<RffGprTrained>;
276
277 #[allow(non_snake_case)]
278 fn fit(self, X: &ArrayView2<f64>, y: &ArrayView1<f64>) -> SklResult<Self::Fitted> {
279 if X.nrows() != y.len() {
280 return Err(SklearsError::InvalidInput(
281 "X and y must have the same number of samples".to_string(),
282 ));
283 }
284
285 let mut rff = RandomFourierFeatures::new(self.n_components, self.gamma, self.random_state);
287 let features = rff.fit_transform(X)?;
288
289 let y_mean = y.mean().unwrap_or(0.0);
291 let y_centered = y.mapv(|yi| yi - y_mean);
292
293 let n_features = features.ncols();
299 let sigma_n_sq = 1.0 / self.alpha;
300 let sigma_p_sq = 1.0; let mut gram_matrix = features.t().dot(&features) / sigma_n_sq;
304 for i in 0..n_features {
305 gram_matrix[[i, i]] += 1.0 / sigma_p_sq;
306 }
307
308 let L = utils::robust_cholesky(&gram_matrix)?;
310
311 let phi_t_y = features.t().dot(&y_centered) / sigma_n_sq;
313 let weights = utils::triangular_solve(&L, &phi_t_y)?;
314
315 let I = Array2::<f64>::eye(n_features);
317 let weights_cov = utils::triangular_solve_matrix(&L, &I)?;
318
319 let predictions = features.dot(&weights);
321 let residuals = &y_centered - &predictions;
322 let data_fit = -0.5 * residuals.mapv(|r| r * r).sum() * self.alpha;
323 let complexity_penalty = -0.5 * weights.mapv(|w| w * w).sum() / sigma_p_sq;
324 let normalization = -0.5 * y.len() as f64 * (2.0 * std::f64::consts::PI / self.alpha).ln();
325 let log_marginal_likelihood_value = data_fit + complexity_penalty + normalization;
326
327 Ok(RandomFourierFeaturesGPR {
328 state: RffGprTrained {
329 rff,
330 weights,
331 weights_cov,
332 alpha: self.alpha,
333 y_mean,
334 log_marginal_likelihood_value,
335 },
336 n_components: self.n_components,
337 gamma: self.gamma,
338 alpha: self.alpha,
339 random_state: self.random_state,
340 config: self.config.clone(),
341 })
342 }
343}
344
345impl Predict<ArrayView2<'_, f64>, Array1<f64>> for RandomFourierFeaturesGPR<RffGprTrained> {
346 fn predict(&self, X: &ArrayView2<f64>) -> SklResult<Array1<f64>> {
347 let (mean, _) = self.predict_with_std(X)?;
348 Ok(mean)
349 }
350}
351
352impl RandomFourierFeaturesGPR<RffGprTrained> {
353 pub fn predict_with_std(&self, X: &ArrayView2<f64>) -> SklResult<(Array1<f64>, Array1<f64>)> {
355 let features = self.state.rff.transform(X)?;
357
358 let mean_centered = features.dot(&self.state.weights);
360 let mean = mean_centered.mapv(|m| m + self.state.y_mean);
361
362 let feature_uncertainty = features.dot(&self.state.weights_cov);
364 let mut variance = Array1::<f64>::zeros(X.nrows());
365
366 for i in 0..X.nrows() {
367 let phi_i = features.row(i);
368 let var_from_weights = phi_i.dot(&feature_uncertainty.row(i));
369 let noise_var = 1.0 / self.state.alpha;
370 variance[i] = var_from_weights + noise_var;
371 }
372
373 let std = variance.mapv(|v| v.sqrt().max(0.0));
374
375 Ok((mean, std))
376 }
377
378 pub fn log_marginal_likelihood(&self) -> f64 {
380 self.state.log_marginal_likelihood_value
381 }
382
383 pub fn weights(&self) -> &Array1<f64> {
385 &self.state.weights
386 }
387
388 pub fn weights_covariance(&self) -> &Array2<f64> {
390 &self.state.weights_cov
391 }
392
393 pub fn rff_transformer(&self) -> &RandomFourierFeatures {
395 &self.state.rff
396 }
397}
398
399impl Default for RandomFourierFeaturesGPR<Untrained> {
400 fn default() -> Self {
401 Self::new()
402 }
403}