scirs2_metrics/optimization/
parallel.rs

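//! Parallel computation utilities for metrics
//!
//! This module provides tools for computing metrics in parallel: evaluating several metric
//! functions over the same inputs, processing large slices chunk by chunk, and a trait for
//! metrics that accumulate their result incrementally. Parallel execution relies on the
//! operations exported by `scirs2_core::parallel_ops`.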
4
5use parking_lot;
6use scirs2_core::ndarray::{ArrayBase, Data, Dimension};
7use scirs2_core::parallel_ops::*;
8use std::sync::Arc;
9
10use crate::error::Result;
11
/// Type alias for a metric function that can be executed in parallel.
///
/// The alias names an unsized `dyn Fn` trait object: a callable that borrows two arrays
/// (storage types `S1`/`S2`, dimensionalities `D1`/`D2`) and returns a `Result<f64>`.
/// Because it is unsized it has to live behind a pointer, e.g.
/// `Box<ParallelMetricFn<S1, S2, D1, D2>>`, which is how `compute_metrics_batch` accepts it.
///
/// The `Send + Sync` bounds are part of the trait-object type, so every callable stored
/// under this alias can be shared with and invoked from other threads.
pub type ParallelMetricFn<S1, S2, D1, D2> =
    dyn Fn(&ArrayBase<S1, D1>, &ArrayBase<S2, D2>) -> Result<f64> + Send + Sync;
15
/// Configuration for parallel metrics computation.
///
/// This struct bundles the options that control whether and how metric calculations
/// are executed in parallel. Create one with `ParallelConfig::new()` (or `Default`) and
/// adjust it through the `with_*` builder methods, or set the public fields directly.
///
/// The type derives `PartialEq`/`Eq` so whole configurations can be compared with
/// `==` / `assert_eq!` instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParallelConfig {
    /// Minimum chunk size for parallel processing.
    ///
    /// NOTE(review): none of the functions visible in this file read this field; the
    /// chunked helpers take their chunk size as an explicit argument. Confirm where it is
    /// consumed.
    pub min_chunk_size: usize,
    /// Whether to use parallel processing. The helpers in this module fall back to a
    /// sequential code path when this is `false`.
    pub parallel_enabled: bool,
    /// Number of threads to use (`None` = leave the choice to the default thread pool).
    ///
    /// NOTE(review): no function visible in this file reads this field, so setting it has
    /// no effect on the helpers defined here. Confirm against the rest of the crate.
    pub num_threads: Option<usize>,
}
29
30impl Default for ParallelConfig {
31    fn default() -> Self {
32        ParallelConfig {
33            min_chunk_size: 1000,
34            parallel_enabled: true,
35            num_threads: None,
36        }
37    }
38}
39
40impl ParallelConfig {
41    /// Create a new ParallelConfig with default values
42    pub fn new() -> Self {
43        Default::default()
44    }
45
46    /// Set the minimum chunk size for parallel processing
47    pub fn with_min_chunk_size(mut self, size: usize) -> Self {
48        self.min_chunk_size = size;
49        self
50    }
51
52    /// Enable or disable parallel processing
53    pub fn with_parallel_enabled(mut self, enabled: bool) -> Self {
54        self.parallel_enabled = enabled;
55        self
56    }
57
58    /// Set the number of threads to use
59    pub fn with_num_threads(mut self, threads: Option<usize>) -> Self {
60        self.num_threads = threads;
61        self
62    }
63}
64
65/// Trait for metrics that can be computed in parallel
66pub trait ParallelMetric<T, D>
67where
68    T: Send + Sync,
69    D: Dimension,
70{
71    /// Compute the metric in parallel
72    fn compute_parallel(
73        &self,
74        x: &ArrayBase<impl Data<Elem = T>, D>,
75        config: &ParallelConfig,
76    ) -> Result<f64>;
77}
78
79/// Compute multiple metrics in parallel
80///
81/// This function computes multiple metrics in parallel using Rayon.
82///
83/// # Arguments
84///
85/// * `y_true` - True values
86/// * `y_pred` - Predicted values
87/// * `metric_fns` - Vector of metric functions
88/// * `config` - Parallel configuration
89///
90/// # Returns
91///
92/// * Vector of metric values
93///
94/// # Examples
95///
96/// ```
97/// use scirs2_core::ndarray::Array1;
98/// use scirs2_metrics::optimization::parallel::{compute_metrics_batch, ParallelConfig};
99/// use scirs2_metrics::error::Result;
100/// use scirs2_metrics::classification::{accuracy_score, precision_score};
101///
102/// // Create sample data
103/// let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2]);
104/// let y_pred = Array1::from_vec(vec![0, 2, 1, 0, 0, 2]);
105///
106/// // Define metric functions
107/// let metric_fns: Vec<Box<dyn Fn(&Array1<i32>, &Array1<i32>) -> Result<f64> + Send + Sync>> = vec![
108///     Box::new(|a, b| accuracy_score(a, b)),
109///     Box::new(|a, b| precision_score(a, b, 1)),
110/// ];
111///
112/// // Compute metrics in parallel
113/// let config = ParallelConfig::default();
114/// let results = compute_metrics_batch(&y_true, &y_pred, &metric_fns, &config).unwrap();
115///
116/// // Check results
117/// assert_eq!(results.len(), 2);
118/// ```
119#[allow(dead_code)]
120pub fn compute_metrics_batch<T, S1, S2, D1, D2>(
121    y_true: &ArrayBase<S1, D1>,
122    y_pred: &ArrayBase<S2, D2>,
123    metric_fns: &[Box<ParallelMetricFn<S1, S2, D1, D2>>],
124    config: &ParallelConfig,
125) -> Result<Vec<f64>>
126where
127    T: Clone + Send + Sync,
128    S1: Data<Elem = T> + Sync,
129    S2: Data<Elem = T> + Sync,
130    D1: Dimension + Sync,
131    D2: Dimension + Sync,
132{
133    if !config.parallel_enabled || metric_fns.len() < 2 {
134        // Sequential computation if parallel is disabled or only one metric
135        let mut results = Vec::with_capacity(metric_fns.len());
136        for metric_fn in metric_fns {
137            let value = metric_fn(y_true, y_pred)?;
138            results.push(value);
139        }
140        return Ok(results);
141    }
142
143    // Parallel computation of metrics
144    let results: Result<Vec<f64>> = metric_fns
145        .par_iter()
146        .map(|metric_fn| metric_fn(y_true, y_pred))
147        .collect();
148
149    results
150}
151
152/// Process a large array in chunks with parallel execution
153///
154/// This function splits a large array into chunks and processes each chunk in parallel.
155///
156/// # Arguments
157///
158/// * `data` - Input data
159/// * `chunk_size` - Size of each chunk
160/// * `chunk_op` - Operation to perform on each chunk
161/// * `reducer` - Function to combine results from all chunks
162///
163/// # Returns
164///
165/// * Combined result
166///
167/// # Examples
168///
169/// ```
170/// use scirs2_metrics::optimization::parallel::{chunked_parallel_compute, ParallelConfig};
171/// use scirs2_metrics::error::Result;
172///
173/// // Create sample data
174/// let data: Vec<f64> = (0..1000).map(|x| x as f64).collect();
175///
176/// // Define chunk operation (sum of squares)
177/// let chunk_op = |chunk: &[f64]| -> Result<f64> {
178///     Ok(chunk.iter().map(|x| x * x).sum())
179/// };
180///
181/// // Define reducer (sum of partial results)
182/// let reducer = |results: Vec<f64>| -> Result<f64> {
183///     Ok(results.iter().sum())
184/// };
185///
186/// // Process data in chunks
187/// let result = chunked_parallel_compute(&data, 100, chunk_op, reducer).unwrap();
188///
189/// // Verify result
190/// let expected: f64 = (0..1000).map(|x| (x * x) as f64).sum();
191/// assert!((result - expected).abs() < 1e-10);
192/// ```
193#[allow(dead_code)]
194pub fn chunked_parallel_compute<T, R>(
195    data: &[T],
196    chunk_size: usize,
197    chunk_op: impl Fn(&[T]) -> Result<R> + Send + Sync,
198    reducer: impl Fn(Vec<R>) -> Result<R>,
199) -> Result<R>
200where
201    T: Clone + Send + Sync,
202    R: Send + Sync,
203{
204    if data.len() <= chunk_size {
205        // If data fits in a single chunk, just process it directly
206        return chunk_op(data);
207    }
208
209    // Split data into chunks
210    let chunks: Vec<&[T]> = data.chunks(chunk_size).collect();
211
212    // Process chunks in parallel
213    let results: Result<Vec<R>> = chunks.par_iter().map(|chunk| chunk_op(chunk)).collect();
214
215    // Combine results
216    reducer(results?)
217}
218
/// Trait for defining chunked metric operations
///
/// A chunked metric is computed incrementally: `init_state` creates an accumulator,
/// `process_chunk` folds one slice of input elements into it, and `finalize` turns the
/// accumulated state into the final `f64` value. `T` is the element type of the input.
pub trait ChunkedMetric<T> {
    /// Type for intermediate state
    ///
    /// The `Send + Sync` bounds allow a state value to be shared between threads, which
    /// `compute_chunked_metric` does by placing it behind a mutex.
    type State: Send + Sync;

    /// Initialize state
    ///
    /// Returns a fresh accumulator that has not seen any data yet.
    fn init_state(&self) -> Self::State;

    /// Process a chunk and update state
    ///
    /// Folds the elements of `chunk` into `state`; returns an error if the chunk cannot
    /// be processed.
    ///
    /// NOTE(review): `compute_chunked_metric` feeds chunks from a parallel iterator, so
    /// presumably chunks are not guaranteed to arrive in slice order — confirm before
    /// implementing an order-sensitive metric.
    fn process_chunk(&self, state: &mut Self::State, chunk: &[T]) -> Result<()>;

    /// Finalize computation from state
    ///
    /// Reads the accumulated `state` and returns the metric value, or an error if no
    /// value can be produced from it.
    fn finalize(&self, state: &Self::State) -> Result<f64>;
}
233
234/// Process a large array using chunked metric computation
235///
236/// # Arguments
237///
238/// * `data` - Input data
239/// * `metric` - Chunked metric implementation
240/// * `chunk_size` - Size of each chunk
241/// * `config` - Parallel configuration
242///
243/// # Returns
244///
245/// * Computed metric value
246#[allow(dead_code)]
247pub fn compute_chunked_metric<T, M>(
248    data: &[T],
249    metric: &M,
250    chunk_size: usize,
251    config: &ParallelConfig,
252) -> Result<f64>
253where
254    T: Clone + Send + Sync,
255    M: ChunkedMetric<T> + Send + Sync,
256{
257    if data.len() <= chunk_size || !config.parallel_enabled {
258        // If data fits in a single chunk or parallel is disabled
259        let mut state = metric.init_state();
260        metric.process_chunk(&mut state, data)?;
261        return metric.finalize(&state);
262    }
263
264    // Create shared state
265    let state = Arc::new(parking_lot::Mutex::new(metric.init_state()));
266    let metric = Arc::new(metric);
267
268    // Split data into chunks
269    let chunks: Vec<&[T]> = data.chunks(chunk_size).collect();
270
271    // Process chunks in parallel
272    let result: Result<()> = chunks.par_iter().try_for_each(|chunk| {
273        let mut local_state = metric.init_state();
274        metric.process_chunk(&mut local_state, chunk)?;
275
276        // Update global state with mutex
277        let mut global_state = state.lock();
278        metric.process_chunk(&mut *global_state, chunk)?;
279        Ok(())
280    });
281
282    // Check for errors during processing
283    result?;
284
285    // Finalize computation
286    let state_lock = state.lock();
287    let result = metric.finalize(&*state_lock);
288    drop(state_lock); // Explicitly drop the lock before the end of the function
289    result
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::MetricsError;
    use scirs2_core::ndarray::Array1;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOLERANCE: f64 = 1e-10;

    /// Panics unless `actual` lies within `TOLERANCE` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    /// The integers 0..1000 as `f64`, shared by the chunking tests.
    fn sample_data() -> Vec<f64> {
        (0..1000).map(|x| x as f64).collect()
    }

    /// Example implementation of `ChunkedMetric` for testing: the arithmetic mean,
    /// accumulated as a running `(sum, count)` pair.
    struct MeanChunkedMetric;

    impl ChunkedMetric<f64> for MeanChunkedMetric {
        type State = (f64, usize); // (sum, count)

        fn init_state(&self) -> Self::State {
            (0.0, 0)
        }

        fn process_chunk(&self, state: &mut Self::State, chunk: &[f64]) -> Result<()> {
            let (sum, count) = state;
            // Add element by element so the summation order matches a plain loop.
            for value in chunk.iter().copied() {
                *sum += value;
                *count += 1;
            }
            Ok(())
        }

        fn finalize(&self, state: &Self::State) -> Result<f64> {
            match *state {
                (_, 0) => Err(MetricsError::DivisionByZero),
                (sum, count) => Ok(sum / count as f64),
            }
        }
    }

    #[test]
    fn test_parallel_config() {
        let built = ParallelConfig::new()
            .with_min_chunk_size(500)
            .with_parallel_enabled(true)
            .with_num_threads(Some(4));

        let ParallelConfig {
            min_chunk_size,
            parallel_enabled,
            num_threads,
        } = built;

        assert_eq!(min_chunk_size, 500);
        assert!(parallel_enabled);
        assert_eq!(num_threads, Some(4));
    }

    #[test]
    fn test_compute_metrics_batch() {
        // Sample labels: positions 0, 3 and 5 agree, so accuracy is 3/6.
        let y_true = Array1::from_vec(vec![0, 1, 2, 0, 1, 2]);
        let y_pred = Array1::from_vec(vec![0, 2, 1, 0, 0, 2]);

        type MetricFn = Box<dyn Fn(&Array1<i32>, &Array1<i32>) -> Result<f64> + Send + Sync>;
        let metric_fns: Vec<MetricFn> = vec![
            // Simple accuracy calculation for test
            Box::new(|truth, pred| {
                if truth.len() != pred.len() {
                    return Err(MetricsError::InvalidInput("Lengths must match".to_string()));
                }
                let hits = truth
                    .iter()
                    .zip(pred.iter())
                    .filter(|&(t, p)| t == p)
                    .count();
                Ok(hits as f64 / truth.len() as f64)
            }),
            // Another dummy metric: the number of samples
            Box::new(|truth, _pred| Ok(truth.len() as f64)),
        ];

        // The sequential path first, then the parallel path; both must agree.
        for parallel_enabled in [false, true] {
            let config = ParallelConfig::new().with_parallel_enabled(parallel_enabled);
            let results = compute_metrics_batch(&y_true, &y_pred, &metric_fns, &config).unwrap();

            assert_eq!(results.len(), 2);
            assert_close(results[0], 0.5); // 3/6 correct
            assert_close(results[1], 6.0); // Length is 6
        }
    }

    #[test]
    fn test_chunked_parallel_compute() {
        // Chunk operation: sum of squares of the chunk
        fn sum_of_squares(chunk: &[f64]) -> Result<f64> {
            Ok(chunk.iter().map(|x| x * x).sum())
        }

        // Reducer: sum of the partial results
        fn add_partials(partials: Vec<f64>) -> Result<f64> {
            Ok(partials.iter().sum())
        }

        let data = sample_data();
        let total = chunked_parallel_compute(&data, 100, sum_of_squares, add_partials).unwrap();

        // Verify result against direct calculation
        let expected: f64 = (0..1000).map(|x| (x * x) as f64).sum();
        assert_close(total, expected);
    }

    #[test]
    fn test_compute_chunked_metric() {
        let data = sample_data();
        let config = ParallelConfig::default();

        // Compute the mean with chunking
        let mean = compute_chunked_metric(&data, &MeanChunkedMetric, 100, &config).unwrap();

        // Verify result against direct calculation
        let expected: f64 = data.iter().sum::<f64>() / data.len() as f64;
        assert_close(mean, expected);
    }
}