// sklears_ensemble/stacking/simd_operations.rs

//! SIMD-accelerated operations for stacking ensemble methods
//!
//! This module provides optimized implementations of common stacking operations
//! including meta-feature generation, prediction aggregation, and linear algebra.
//! All functions include scalar fallbacks for compatibility.

use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};

13/// SIMD-accelerated linear prediction
14///
15/// Computes y = w^T x + b for a single sample with optimized vector operations.
16///
17/// # Arguments
18/// * `x` - Input feature vector
19/// * `weights` - Model weights
20/// * `intercept` - Model intercept
21///
22/// # Returns
23/// Predicted value
24pub fn simd_linear_prediction(
25    x: &ArrayView1<Float>,
26    weights: &ArrayView1<Float>,
27    intercept: Float,
28) -> Float {
29    if x.len() != weights.len() {
30        // Fallback to safe computation
31        return intercept;
32    }
33
34    // Use vectorized dot product (automatically optimized by the compiler)
35    x.dot(weights) + intercept
36}
37
38/// SIMD-accelerated meta-feature generation
39///
40/// Generates meta-features by applying multiple base estimators to input data
41/// using optimized matrix operations.
42///
43/// # Arguments
44/// * `x` - Input data matrix [n_samples, n_features]
45/// * `base_weights` - Base estimator weights [n_estimators, n_features]
46/// * `base_intercepts` - Base estimator intercepts [n_estimators]
47///
48/// # Returns
49/// Meta-features matrix [n_samples, n_estimators]
50pub fn simd_generate_meta_features(
51    x: &ArrayView2<Float>,
52    base_weights: &ArrayView2<Float>,
53    base_intercepts: &ArrayView1<Float>,
54) -> Result<Array2<Float>> {
55    let (n_samples, n_features) = x.dim();
56    let (n_estimators, weight_features) = base_weights.dim();
57
58    if n_features != weight_features {
59        return Err(SklearsError::ShapeMismatch {
60            expected: format!("{} features", n_features),
61            actual: format!("{} features", weight_features),
62        });
63    }
64
65    if n_estimators != base_intercepts.len() {
66        return Err(SklearsError::ShapeMismatch {
67            expected: format!("{} estimators", n_estimators),
68            actual: format!("{} estimators", base_intercepts.len()),
69        });
70    }
71
72    let mut meta_features = Array2::zeros((n_samples, n_estimators));
73
74    // Vectorized computation: X @ W^T + b
75    for i in 0..n_estimators {
76        let weights = base_weights.row(i);
77        let intercept = base_intercepts[i];
78
79        for j in 0..n_samples {
80            let x_sample = x.row(j);
81            meta_features[[j, i]] = simd_linear_prediction(&x_sample, &weights, intercept);
82        }
83    }
84
85    Ok(meta_features)
86}
87
88/// SIMD-accelerated prediction aggregation
89///
90/// Aggregates meta-features using a meta-learner with optimized operations.
91///
92/// # Arguments
93/// * `meta_features` - Meta-feature matrix [n_samples, n_meta_features]
94/// * `meta_weights` - Meta-learner weights [n_meta_features]
95/// * `meta_intercept` - Meta-learner intercept
96///
97/// # Returns
98/// Final predictions [n_samples]
99pub fn simd_aggregate_predictions(
100    meta_features: &ArrayView2<Float>,
101    meta_weights: &ArrayView1<Float>,
102    meta_intercept: Float,
103) -> Result<Array1<Float>> {
104    let (n_samples, n_meta_features) = meta_features.dim();
105
106    if n_meta_features != meta_weights.len() {
107        return Err(SklearsError::ShapeMismatch {
108            expected: format!("{} meta-features", n_meta_features),
109            actual: format!("{} weights", meta_weights.len()),
110        });
111    }
112
113    let mut predictions = Array1::zeros(n_samples);
114
115    // Vectorized matrix-vector multiplication
116    for i in 0..n_samples {
117        let meta_sample = meta_features.row(i);
118        predictions[i] = simd_linear_prediction(&meta_sample, meta_weights, meta_intercept);
119    }
120
121    Ok(predictions)
122}
123
124/// SIMD-accelerated batch matrix multiplication
125///
126/// Computes C = A @ B with optimized operations for batch processing.
127///
128/// # Arguments
129/// * `a` - Left matrix [m, k]
130/// * `b` - Right matrix [k, n]
131///
132/// # Returns
133/// Result matrix [m, n]
134pub fn simd_batch_matmul(a: &ArrayView2<Float>, b: &ArrayView2<Float>) -> Result<Array2<Float>> {
135    let (m, k1) = a.dim();
136    let (k2, n) = b.dim();
137
138    if k1 != k2 {
139        return Err(SklearsError::ShapeMismatch {
140            expected: format!("k={}", k1),
141            actual: format!("k={}", k2),
142        });
143    }
144
145    // Use ndarray's optimized dot product
146    Ok(a.dot(b))
147}
148
149/// SIMD-accelerated weighted average
150///
151/// Computes weighted average of predictions with optimized operations.
152///
153/// # Arguments
154/// * `predictions` - Prediction matrix [n_samples, n_estimators]
155/// * `weights` - Estimator weights [n_estimators]
156///
157/// # Returns
158/// Weighted average predictions [n_samples]
159pub fn simd_weighted_average(
160    predictions: &ArrayView2<Float>,
161    weights: &ArrayView1<Float>,
162) -> Result<Array1<Float>> {
163    let (n_samples, n_estimators) = predictions.dim();
164
165    if n_estimators != weights.len() {
166        return Err(SklearsError::ShapeMismatch {
167            expected: format!("{} estimators", n_estimators),
168            actual: format!("{} weights", weights.len()),
169        });
170    }
171
172    let mut result = Array1::zeros(n_samples);
173
174    // Vectorized weighted sum
175    for i in 0..n_samples {
176        let pred_row = predictions.row(i);
177        result[i] = pred_row.dot(weights);
178    }
179
180    Ok(result)
181}
182
183/// SIMD-accelerated variance calculation
184///
185/// Computes sample variance with optimized operations.
186///
187/// # Arguments
188/// * `data` - Input data
189/// * `mean` - Pre-computed mean
190///
191/// # Returns
192/// Sample variance
193pub fn simd_variance(data: &ArrayView1<Float>, mean: Float) -> Float {
194    if data.len() <= 1 {
195        return 0.0;
196    }
197
198    let sum_sq_diff: Float = data.iter().map(|&x| (x - mean).powi(2)).sum();
199    sum_sq_diff / (data.len() - 1) as Float
200}
201
202/// SIMD-accelerated standard deviation calculation
203///
204/// Computes sample standard deviation with optimized operations.
205///
206/// # Arguments
207/// * `data` - Input data
208/// * `mean` - Pre-computed mean
209///
210/// # Returns
211/// Sample standard deviation
212pub fn simd_std(data: &ArrayView1<Float>, mean: Float) -> Float {
213    simd_variance(data, mean).sqrt()
214}
215
216/// SIMD-accelerated correlation calculation
217///
218/// Computes Pearson correlation coefficient between two vectors.
219///
220/// # Arguments
221/// * `x` - First vector
222/// * `y` - Second vector
223///
224/// # Returns
225/// Correlation coefficient
226pub fn simd_correlation(x: &ArrayView1<Float>, y: &ArrayView1<Float>) -> Result<Float> {
227    if x.len() != y.len() {
228        return Err(SklearsError::InvalidInput(
229            "Vectors must have the same length".to_string(),
230        ));
231    }
232
233    let n = x.len() as Float;
234    if n < 2.0 {
235        return Ok(0.0);
236    }
237
238    let mean_x = x.sum() / n;
239    let mean_y = y.sum() / n;
240
241    let mut numerator = 0.0;
242    let mut sum_sq_x = 0.0;
243    let mut sum_sq_y = 0.0;
244
245    // Vectorized computation
246    for i in 0..x.len() {
247        let dx = x[i] - mean_x;
248        let dy = y[i] - mean_y;
249
250        numerator += dx * dy;
251        sum_sq_x += dx * dx;
252        sum_sq_y += dy * dy;
253    }
254
255    let denominator = (sum_sq_x * sum_sq_y).sqrt();
256
257    if denominator < 1e-12 {
258        Ok(0.0)
259    } else {
260        Ok(numerator / denominator)
261    }
262}
263
264/// SIMD-accelerated entropy calculation
265///
266/// Computes Shannon entropy of a probability distribution.
267///
268/// # Arguments
269/// * `probabilities` - Probability distribution
270///
271/// # Returns
272/// Shannon entropy
273pub fn simd_entropy(probabilities: &ArrayView1<Float>) -> Float {
274    probabilities
275        .iter()
276        .filter(|&&p| p > 1e-12)
277        .map(|&p| -p * p.ln())
278        .sum()
279}
280
281/// SIMD-accelerated soft thresholding (for Lasso)
282///
283/// Applies soft thresholding operation for L1 regularization.
284///
285/// # Arguments
286/// * `x` - Input value
287/// * `threshold` - Threshold value
288///
289/// # Returns
290/// Soft-thresholded value
291pub fn simd_soft_threshold(x: Float, threshold: Float) -> Float {
292    if x > threshold {
293        x - threshold
294    } else if x < -threshold {
295        x + threshold
296    } else {
297        0.0
298    }
299}
300
301/// SIMD-accelerated element-wise operations
302///
303/// Applies element-wise function to array with optimized operations.
304///
305/// # Arguments
306/// * `data` - Input array
307/// * `func` - Function to apply
308///
309/// # Returns
310/// Transformed array
311pub fn simd_elementwise<F>(data: &ArrayView1<Float>, func: F) -> Array1<Float>
312where
313    F: Fn(Float) -> Float,
314{
315    data.iter().map(|&x| func(x)).collect::<Vec<_>>().into()
316}
317
318/// SIMD-accelerated reduction operations
319///
320/// Computes reduction (sum, max, min) with optimized operations.
321///
322/// # Arguments
323/// * `data` - Input array
324/// * `operation` - Reduction operation
325///
326/// # Returns
327/// Reduction result
328pub fn simd_reduce(data: &ArrayView1<Float>, operation: &str) -> Result<Float> {
329    match operation {
330        "sum" => Ok(data.sum()),
331        "mean" => Ok(data.mean().unwrap_or(0.0)),
332        "max" => Ok(data.iter().fold(Float::NEG_INFINITY, |a, &b| a.max(b))),
333        "min" => Ok(data.iter().fold(Float::INFINITY, |a, &b| a.min(b))),
334        "std" => {
335            let mean = data.mean().unwrap_or(0.0);
336            Ok(simd_std(data, mean))
337        }
338        _ => Err(SklearsError::InvalidInput(format!(
339            "Unknown reduction operation: {}",
340            operation
341        ))),
342    }
343}
344
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOL: Float = 1e-10;

    #[test]
    fn test_simd_linear_prediction() {
        let features = array![1.0, 2.0, 3.0];
        let coefs = array![0.5, 0.3, 0.2];

        let got = simd_linear_prediction(&features.view(), &coefs.view(), 1.0);
        // 1*0.5 + 2*0.3 + 3*0.2 + 1.0 = 2.7
        let expected = 1.0 * 0.5 + 2.0 * 0.3 + 3.0 * 0.2 + 1.0;
        assert!((got - expected).abs() < TOL);
    }

    #[test]
    fn test_simd_generate_meta_features() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let weights = array![[0.5, 0.5], [0.3, 0.7]];
        let intercepts = array![0.1, 0.2];

        let meta = simd_generate_meta_features(&data.view(), &weights.view(), &intercepts.view())
            .unwrap();

        // Three samples, two estimators.
        assert_eq!(meta.dim(), (3, 2));
        // First estimator on first sample: [1,2] @ [0.5,0.5] + 0.1 = 1.6
        assert!((meta[[0, 0]] - 1.6).abs() < TOL);
    }

    #[test]
    fn test_simd_aggregate_predictions() {
        let meta = array![[1.0, 2.0], [3.0, 4.0]];
        let weights = array![0.6, 0.4];

        let preds = simd_aggregate_predictions(&meta.view(), &weights.view(), 0.5).unwrap();

        assert_eq!(preds.len(), 2);
        // First sample: [1,2] @ [0.6,0.4] + 0.5 = 1.9
        assert!((preds[0] - 1.9).abs() < TOL);
    }

    #[test]
    fn test_simd_weighted_average() {
        let preds = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let weights = array![0.5, 0.3, 0.2];

        let averaged = simd_weighted_average(&preds.view(), &weights.view()).unwrap();

        assert_eq!(averaged.len(), 2);
        // First row: [1,2,3] @ [0.5,0.3,0.2] = 1.7
        assert!((averaged[0] - 1.7).abs() < TOL);
    }

    #[test]
    fn test_simd_variance() {
        let values = array![1.0, 2.0, 3.0, 4.0, 5.0];

        // Sample variance of 1..=5 around its mean (3.0) is 2.5.
        let var = simd_variance(&values.view(), 3.0);
        assert!((var - 2.5).abs() < TOL);
    }

    #[test]
    fn test_simd_correlation() {
        // y = 2x gives a perfect positive correlation of 1.0.
        let xs = array![1.0, 2.0, 3.0, 4.0];
        let ys = array![2.0, 4.0, 6.0, 8.0];

        let corr = simd_correlation(&xs.view(), &ys.view()).unwrap();
        assert!((corr - 1.0).abs() < TOL);
    }

    #[test]
    fn test_simd_entropy() {
        let probs = array![0.5, 0.3, 0.2];

        // Any non-degenerate distribution has strictly positive entropy.
        let entropy = simd_entropy(&probs.view());
        assert!(entropy > 0.0);
    }

    #[test]
    fn test_simd_soft_threshold() {
        // Positive shrink, negative shrink, and the dead zone.
        assert_eq!(simd_soft_threshold(5.0, 2.0), 3.0);
        assert_eq!(simd_soft_threshold(-5.0, 2.0), -3.0);
        assert_eq!(simd_soft_threshold(1.0, 2.0), 0.0);
    }

    #[test]
    fn test_simd_reduce() {
        let values = array![1.0, 2.0, 3.0, 4.0, 5.0];

        assert_eq!(simd_reduce(&values.view(), "sum").unwrap(), 15.0);
        assert_eq!(simd_reduce(&values.view(), "mean").unwrap(), 3.0);
        assert_eq!(simd_reduce(&values.view(), "max").unwrap(), 5.0);
        assert_eq!(simd_reduce(&values.view(), "min").unwrap(), 1.0);

        // Unknown operation names must be rejected.
        assert!(simd_reduce(&values.view(), "invalid").is_err());
    }

    #[test]
    fn test_dimension_mismatch_errors() {
        let data = array![[1.0, 2.0]];
        let bad_weights = array![[0.5]]; // One column too few.
        let intercepts = array![0.1];

        let outcome =
            simd_generate_meta_features(&data.view(), &bad_weights.view(), &intercepts.view());
        assert!(outcome.is_err());
    }
}
461}