sklears_ensemble/stacking/
simd_operations.rs1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use sklears_core::{
9 error::{Result, SklearsError},
10 types::Float,
11};
12
13pub fn simd_linear_prediction(
25 x: &ArrayView1<Float>,
26 weights: &ArrayView1<Float>,
27 intercept: Float,
28) -> Float {
29 if x.len() != weights.len() {
30 return intercept;
32 }
33
34 x.dot(weights) + intercept
36}
37
38pub fn simd_generate_meta_features(
51 x: &ArrayView2<Float>,
52 base_weights: &ArrayView2<Float>,
53 base_intercepts: &ArrayView1<Float>,
54) -> Result<Array2<Float>> {
55 let (n_samples, n_features) = x.dim();
56 let (n_estimators, weight_features) = base_weights.dim();
57
58 if n_features != weight_features {
59 return Err(SklearsError::ShapeMismatch {
60 expected: format!("{} features", n_features),
61 actual: format!("{} features", weight_features),
62 });
63 }
64
65 if n_estimators != base_intercepts.len() {
66 return Err(SklearsError::ShapeMismatch {
67 expected: format!("{} estimators", n_estimators),
68 actual: format!("{} estimators", base_intercepts.len()),
69 });
70 }
71
72 let mut meta_features = Array2::zeros((n_samples, n_estimators));
73
74 for i in 0..n_estimators {
76 let weights = base_weights.row(i);
77 let intercept = base_intercepts[i];
78
79 for j in 0..n_samples {
80 let x_sample = x.row(j);
81 meta_features[[j, i]] = simd_linear_prediction(&x_sample, &weights, intercept);
82 }
83 }
84
85 Ok(meta_features)
86}
87
88pub fn simd_aggregate_predictions(
100 meta_features: &ArrayView2<Float>,
101 meta_weights: &ArrayView1<Float>,
102 meta_intercept: Float,
103) -> Result<Array1<Float>> {
104 let (n_samples, n_meta_features) = meta_features.dim();
105
106 if n_meta_features != meta_weights.len() {
107 return Err(SklearsError::ShapeMismatch {
108 expected: format!("{} meta-features", n_meta_features),
109 actual: format!("{} weights", meta_weights.len()),
110 });
111 }
112
113 let mut predictions = Array1::zeros(n_samples);
114
115 for i in 0..n_samples {
117 let meta_sample = meta_features.row(i);
118 predictions[i] = simd_linear_prediction(&meta_sample, meta_weights, meta_intercept);
119 }
120
121 Ok(predictions)
122}
123
124pub fn simd_batch_matmul(a: &ArrayView2<Float>, b: &ArrayView2<Float>) -> Result<Array2<Float>> {
135 let (m, k1) = a.dim();
136 let (k2, n) = b.dim();
137
138 if k1 != k2 {
139 return Err(SklearsError::ShapeMismatch {
140 expected: format!("k={}", k1),
141 actual: format!("k={}", k2),
142 });
143 }
144
145 Ok(a.dot(b))
147}
148
149pub fn simd_weighted_average(
160 predictions: &ArrayView2<Float>,
161 weights: &ArrayView1<Float>,
162) -> Result<Array1<Float>> {
163 let (n_samples, n_estimators) = predictions.dim();
164
165 if n_estimators != weights.len() {
166 return Err(SklearsError::ShapeMismatch {
167 expected: format!("{} estimators", n_estimators),
168 actual: format!("{} weights", weights.len()),
169 });
170 }
171
172 let mut result = Array1::zeros(n_samples);
173
174 for i in 0..n_samples {
176 let pred_row = predictions.row(i);
177 result[i] = pred_row.dot(weights);
178 }
179
180 Ok(result)
181}
182
183pub fn simd_variance(data: &ArrayView1<Float>, mean: Float) -> Float {
194 if data.len() <= 1 {
195 return 0.0;
196 }
197
198 let sum_sq_diff: Float = data.iter().map(|&x| (x - mean).powi(2)).sum();
199 sum_sq_diff / (data.len() - 1) as Float
200}
201
202pub fn simd_std(data: &ArrayView1<Float>, mean: Float) -> Float {
213 simd_variance(data, mean).sqrt()
214}
215
216pub fn simd_correlation(x: &ArrayView1<Float>, y: &ArrayView1<Float>) -> Result<Float> {
227 if x.len() != y.len() {
228 return Err(SklearsError::InvalidInput(
229 "Vectors must have the same length".to_string(),
230 ));
231 }
232
233 let n = x.len() as Float;
234 if n < 2.0 {
235 return Ok(0.0);
236 }
237
238 let mean_x = x.sum() / n;
239 let mean_y = y.sum() / n;
240
241 let mut numerator = 0.0;
242 let mut sum_sq_x = 0.0;
243 let mut sum_sq_y = 0.0;
244
245 for i in 0..x.len() {
247 let dx = x[i] - mean_x;
248 let dy = y[i] - mean_y;
249
250 numerator += dx * dy;
251 sum_sq_x += dx * dx;
252 sum_sq_y += dy * dy;
253 }
254
255 let denominator = (sum_sq_x * sum_sq_y).sqrt();
256
257 if denominator < 1e-12 {
258 Ok(0.0)
259 } else {
260 Ok(numerator / denominator)
261 }
262}
263
264pub fn simd_entropy(probabilities: &ArrayView1<Float>) -> Float {
274 probabilities
275 .iter()
276 .filter(|&&p| p > 1e-12)
277 .map(|&p| -p * p.ln())
278 .sum()
279}
280
281pub fn simd_soft_threshold(x: Float, threshold: Float) -> Float {
292 if x > threshold {
293 x - threshold
294 } else if x < -threshold {
295 x + threshold
296 } else {
297 0.0
298 }
299}
300
301pub fn simd_elementwise<F>(data: &ArrayView1<Float>, func: F) -> Array1<Float>
312where
313 F: Fn(Float) -> Float,
314{
315 data.iter().map(|&x| func(x)).collect::<Vec<_>>().into()
316}
317
318pub fn simd_reduce(data: &ArrayView1<Float>, operation: &str) -> Result<Float> {
329 match operation {
330 "sum" => Ok(data.sum()),
331 "mean" => Ok(data.mean().unwrap_or(0.0)),
332 "max" => Ok(data.iter().fold(Float::NEG_INFINITY, |a, &b| a.max(b))),
333 "min" => Ok(data.iter().fold(Float::INFINITY, |a, &b| a.min(b))),
334 "std" => {
335 let mean = data.mean().unwrap_or(0.0);
336 Ok(simd_std(data, mean))
337 }
338 _ => Err(SklearsError::InvalidInput(format!(
339 "Unknown reduction operation: {}",
340 operation
341 ))),
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Absolute tolerance for floating-point comparisons.
    const TOL: Float = 1e-10;

    #[test]
    fn test_simd_linear_prediction() {
        let x = array![1.0, 2.0, 3.0];
        let weights = array![0.5, 0.3, 0.2];
        let intercept = 1.0;

        let result = simd_linear_prediction(&x.view(), &weights.view(), intercept);
        let expected = 1.0 * 0.5 + 2.0 * 0.3 + 3.0 * 0.2 + 1.0;
        assert!((result - expected).abs() < TOL);
    }

    #[test]
    fn test_simd_linear_prediction_length_mismatch_returns_intercept() {
        let x = array![1.0, 2.0];
        let weights = array![0.5, 0.3, 0.2];

        let result = simd_linear_prediction(&x.view(), &weights.view(), 1.5);
        assert_eq!(result, 1.5);
    }

    #[test]
    fn test_simd_generate_meta_features() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let base_weights = array![[0.5, 0.5], [0.3, 0.7]];
        let base_intercepts = array![0.1, 0.2];

        let result =
            simd_generate_meta_features(&x.view(), &base_weights.view(), &base_intercepts.view())
                .unwrap();

        assert_eq!(result.dim(), (3, 2));
        // Check every cell, not just the first one: [j, i] = x_j · w_i + b_i.
        let expected: [[Float; 2]; 3] = [[1.6, 1.9], [3.6, 3.9], [5.6, 5.9]];
        for (j, row) in expected.iter().enumerate() {
            for (i, &want) in row.iter().enumerate() {
                assert!((result[[j, i]] - want).abs() < TOL);
            }
        }
    }

    #[test]
    fn test_simd_aggregate_predictions() {
        let meta_features = array![[1.0, 2.0], [3.0, 4.0]];
        let meta_weights = array![0.6, 0.4];
        let meta_intercept = 0.5;

        let result =
            simd_aggregate_predictions(&meta_features.view(), &meta_weights.view(), meta_intercept)
                .unwrap();

        assert_eq!(result.len(), 2);
        assert!((result[0] - 1.9).abs() < TOL);
        assert!((result[1] - 3.9).abs() < TOL);
    }

    #[test]
    fn test_simd_aggregate_predictions_mismatch() {
        let meta_features = array![[1.0, 2.0], [3.0, 4.0]];
        let meta_weights = array![0.6, 0.3, 0.1];

        let result = simd_aggregate_predictions(&meta_features.view(), &meta_weights.view(), 0.0);
        assert!(matches!(result, Err(SklearsError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_simd_batch_matmul() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];

        let result = simd_batch_matmul(&a.view(), &b.view()).unwrap();
        assert_eq!(result, array![[58.0, 64.0], [139.0, 154.0]]);

        let wrong_b = array![[1.0, 2.0], [3.0, 4.0]];
        let mismatch = simd_batch_matmul(&a.view(), &wrong_b.view());
        assert!(matches!(mismatch, Err(SklearsError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_simd_weighted_average() {
        let predictions = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let weights = array![0.5, 0.3, 0.2];

        let result = simd_weighted_average(&predictions.view(), &weights.view()).unwrap();

        assert_eq!(result.len(), 2);
        assert!((result[0] - 1.7).abs() < TOL);
        assert!((result[1] - 4.7).abs() < TOL);

        let wrong_weights = array![0.5, 0.5];
        let mismatch = simd_weighted_average(&predictions.view(), &wrong_weights.view());
        assert!(matches!(mismatch, Err(SklearsError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_simd_variance() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let mean = 3.0;

        let result = simd_variance(&data.view(), mean);
        let expected = 2.5;
        assert!((result - expected).abs() < TOL);
    }

    #[test]
    fn test_simd_variance_short_inputs() {
        let single = array![42.0];
        assert_eq!(simd_variance(&single.view(), 42.0), 0.0);

        let empty = Array1::<Float>::zeros(0);
        assert_eq!(simd_variance(&empty.view(), 0.0), 0.0);
    }

    #[test]
    fn test_simd_correlation() {
        let x = array![1.0, 2.0, 3.0, 4.0];
        let y = array![2.0, 4.0, 6.0, 8.0];
        let result = simd_correlation(&x.view(), &y.view()).unwrap();
        assert!((result - 1.0).abs() < TOL);

        let y_reversed = array![8.0, 6.0, 4.0, 2.0];
        let anti = simd_correlation(&x.view(), &y_reversed.view()).unwrap();
        assert!((anti + 1.0).abs() < TOL);
    }

    #[test]
    fn test_simd_correlation_degenerate_inputs() {
        // A constant input has zero variance, which is reported as 0.0.
        let constant = array![2.0, 2.0, 2.0];
        let varying = array![1.0, 2.0, 3.0];
        assert_eq!(
            simd_correlation(&constant.view(), &varying.view()).unwrap(),
            0.0
        );

        // Fewer than two samples.
        let one = array![1.0];
        assert_eq!(simd_correlation(&one.view(), &one.view()).unwrap(), 0.0);

        // Length mismatch is an error.
        let short = array![1.0, 2.0];
        let mismatch = simd_correlation(&varying.view(), &short.view());
        assert!(matches!(mismatch, Err(SklearsError::InvalidInput(_))));
    }

    #[test]
    fn test_simd_entropy() {
        let probabilities = array![0.5, 0.3, 0.2];

        let result = simd_entropy(&probabilities.view());
        let expected: Float = -(0.5 * (0.5 as Float).ln()
            + 0.3 * (0.3 as Float).ln()
            + 0.2 * (0.2 as Float).ln());
        assert!((result - expected).abs() < TOL);
    }

    #[test]
    fn test_simd_entropy_ignores_zero_probabilities() {
        let certain = array![1.0, 0.0];
        assert!(simd_entropy(&certain.view()).abs() < TOL);

        let fair_coin = array![0.5, 0.5, 0.0];
        let expected = (2.0 as Float).ln();
        assert!((simd_entropy(&fair_coin.view()) - expected).abs() < TOL);
    }

    #[test]
    fn test_simd_soft_threshold() {
        assert_eq!(simd_soft_threshold(5.0, 2.0), 3.0);
        assert_eq!(simd_soft_threshold(-5.0, 2.0), -3.0);
        assert_eq!(simd_soft_threshold(1.0, 2.0), 0.0);
        assert_eq!(simd_soft_threshold(2.0, 2.0), 0.0);
        assert_eq!(simd_soft_threshold(-2.0, 2.0), 0.0);
    }

    #[test]
    fn test_simd_elementwise() {
        let data = array![1.0, 2.0, 3.0];

        let result = simd_elementwise(&data.view(), |v| v * 2.0);
        assert_eq!(result, array![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_simd_reduce() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];

        assert_eq!(simd_reduce(&data.view(), "sum").unwrap(), 15.0);
        assert_eq!(simd_reduce(&data.view(), "mean").unwrap(), 3.0);
        assert_eq!(simd_reduce(&data.view(), "max").unwrap(), 5.0);
        assert_eq!(simd_reduce(&data.view(), "min").unwrap(), 1.0);

        let std = simd_reduce(&data.view(), "std").unwrap();
        assert!((std - (2.5 as Float).sqrt()).abs() < TOL);

        let result = simd_reduce(&data.view(), "invalid");
        assert!(matches!(result, Err(SklearsError::InvalidInput(_))));
    }

    #[test]
    fn test_simd_reduce_empty_mean_defaults_to_zero() {
        let empty = Array1::<Float>::zeros(0);
        assert_eq!(simd_reduce(&empty.view(), "mean").unwrap(), 0.0);
    }

    #[test]
    fn test_dimension_mismatch_errors() {
        let x = array![[1.0, 2.0]];

        // Feature-count mismatch between x and the base weights.
        let wrong_weights = array![[0.5]];
        let intercepts = array![0.1];
        let result =
            simd_generate_meta_features(&x.view(), &wrong_weights.view(), &intercepts.view());
        assert!(matches!(result, Err(SklearsError::ShapeMismatch { .. })));

        // Intercept-count mismatch: two estimators but only one intercept.
        let weights = array![[0.5, 0.5], [0.3, 0.7]];
        let result = simd_generate_meta_features(&x.view(), &weights.view(), &intercepts.view());
        assert!(matches!(result, Err(SklearsError::ShapeMismatch { .. })));
    }
}