//! Memory-aware batch processing utilities for neighbor searches.
//! (sklears_neighbors/batch_processing.rs)

1use crate::distance::Distance;
2use crate::{NeighborsError, NeighborsResult};
3#[cfg(feature = "parallel")]
4use rayon::prelude::*;
5use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
6use sklears_core::types::Float;
7use std::collections::VecDeque;
8
/// Tunable knobs controlling how [`BatchProcessor`] splits and schedules work.
#[derive(Debug, Clone)]
pub struct BatchConfiguration {
    /// Preferred number of samples per batch; may be reduced at runtime by the
    /// memory budget or the input size.
    pub batch_size: usize,
    /// Hard ceiling on tracked memory usage, in megabytes.
    pub max_memory_mb: usize,
    /// Use parallel batch processing when the processor also supports it.
    pub parallel_processing: bool,
    /// Samples shared between consecutive batches (sequential path only).
    pub chunk_overlap: usize,
    /// Number of batches to prefetch. NOTE(review): not read anywhere in this
    /// file — confirm whether a consumer elsewhere uses it.
    pub prefetch_batches: usize,
}
17
18impl Default for BatchConfiguration {
19    fn default() -> Self {
20        Self {
21            batch_size: 1000,
22            max_memory_mb: 512,
23            parallel_processing: true,
24            chunk_overlap: 0,
25            prefetch_batches: 2,
26        }
27    }
28}
29
/// Drives batched (and optionally parallel) execution of a [`BatchProcessable`]
/// over a dataset, while tracking coarse memory usage.
pub struct BatchProcessor {
    // Immutable-after-construction settings for batching behavior.
    config: BatchConfiguration,
    // Self-reported memory accounting; threshold is taken from `config`.
    memory_monitor: MemoryMonitor,
}
34
/// Megabyte-granularity memory tracker. Usage is self-reported via
/// `BatchProcessor::update_memory_usage`, not measured from the allocator.
#[derive(Debug, Clone)]
pub struct MemoryMonitor {
    // Highest value `current_memory_mb` has reached since the last reset.
    peak_memory_mb: usize,
    // Running total of reported usage, in MB.
    current_memory_mb: usize,
    // Budget copied from `BatchConfiguration::max_memory_mb`; exceeding it
    // aborts processing with an error.
    memory_threshold_mb: usize,
}
41
/// Output of a full `BatchProcessor::process_data` run.
#[derive(Debug, Clone)]
pub struct BatchResult<T> {
    /// Per-sample results, concatenated in batch order.
    pub results: Vec<T>,
    /// Aggregate statistics gathered during the run.
    pub batch_stats: BatchStatistics,
}
47
/// Aggregate metrics describing one batched processing run.
#[derive(Debug, Clone)]
pub struct BatchStatistics {
    /// Number of batches executed.
    pub total_batches: usize,
    /// Number of input rows (rows re-processed due to overlap count once).
    pub processed_samples: usize,
    /// Wall-clock duration of the run, in milliseconds.
    pub processing_time_ms: u128,
    /// Peak tracked memory in MB (see `MemoryMonitor`).
    pub peak_memory_mb: usize,
    /// `processed_samples / total_batches`, integer division.
    pub average_batch_size: usize,
    /// `1 - peak / threshold`: fraction of the memory budget left unused.
    pub memory_efficiency: Float,
}
57
/// Implemented by algorithms that can consume a 2-D batch of samples
/// (rows = samples, columns = features).
pub trait BatchProcessable<T> {
    /// Process one batch and return its results (typically one `T` per row).
    fn process_batch(&self, batch_data: ArrayView2<Float>) -> NeighborsResult<Vec<T>>;
    /// Estimated bytes of working memory required per input sample.
    fn estimate_memory_per_sample(&self) -> usize;
    /// Whether `process_batch` may safely be invoked from multiple threads.
    fn supports_parallel_processing(&self) -> bool;
}
63
64impl BatchProcessor {
65    pub fn new(config: BatchConfiguration) -> Self {
66        let memory_monitor = MemoryMonitor {
67            peak_memory_mb: 0,
68            current_memory_mb: 0,
69            memory_threshold_mb: config.max_memory_mb,
70        };
71
72        Self {
73            config,
74            memory_monitor,
75        }
76    }
77
78    pub fn builder() -> BatchProcessorBuilder {
79        BatchProcessorBuilder::new()
80    }
81
82    pub fn process_data<T, P>(
83        &mut self,
84        data: ArrayView2<Float>,
85        processor: &P,
86    ) -> NeighborsResult<BatchResult<T>>
87    where
88        T: Send + Sync + Clone,
89        P: BatchProcessable<T> + Sync,
90    {
91        let start_time = std::time::Instant::now();
92        let num_samples = data.nrows();
93
94        if num_samples == 0 {
95            return Err(NeighborsError::EmptyInput);
96        }
97
98        let optimal_batch_size = self.calculate_optimal_batch_size(num_samples, processor)?;
99
100        let mut all_results = Vec::new();
101        let mut batch_count = 0;
102
103        if self.config.parallel_processing && processor.supports_parallel_processing() {
104            all_results = self.process_data_parallel(data, optimal_batch_size, processor)?;
105            batch_count = (num_samples + optimal_batch_size - 1) / optimal_batch_size;
106        } else {
107            let mut start_idx = 0;
108            while start_idx < num_samples {
109                let end_idx = std::cmp::min(start_idx + optimal_batch_size, num_samples);
110                let batch = data.slice(scirs2_core::ndarray::s![start_idx..end_idx, ..]);
111
112                let batch_results = processor.process_batch(batch)?;
113                all_results.extend(batch_results);
114                batch_count += 1;
115                self.update_memory_usage(batch.nrows() * batch.ncols() * 8)?;
116
117                start_idx = end_idx - self.config.chunk_overlap;
118                if start_idx >= end_idx {
119                    break;
120                }
121            }
122        }
123
124        let processing_time = start_time.elapsed().as_millis();
125        let stats = BatchStatistics {
126            total_batches: batch_count,
127            processed_samples: num_samples,
128            processing_time_ms: processing_time,
129            peak_memory_mb: self.memory_monitor.peak_memory_mb,
130            average_batch_size: num_samples / batch_count.max(1),
131            memory_efficiency: self.calculate_memory_efficiency(),
132        };
133
134        Ok(BatchResult {
135            results: all_results,
136            batch_stats: stats,
137        })
138    }
139
140    pub fn process_streaming_data<T, P>(
141        &mut self,
142        data_stream: impl Iterator<Item = Array1<Float>>,
143        processor: &P,
144    ) -> NeighborsResult<Vec<T>>
145    where
146        T: Send + Sync + Clone,
147        P: BatchProcessable<T> + Sync,
148    {
149        let mut buffer = VecDeque::new();
150        let batch_size = self.config.batch_size;
151        let mut all_results = Vec::new();
152
153        for sample in data_stream {
154            buffer.push_back(sample);
155
156            if buffer.len() >= batch_size {
157                let batch_data: Vec<Array1<Float>> = buffer.drain(..batch_size).collect();
158                let batch_matrix = self.vec_to_array2(batch_data)?;
159                let results = processor.process_batch(batch_matrix.view())?;
160                all_results.extend(results);
161            }
162        }
163
164        // Process remaining samples in buffer
165        if !buffer.is_empty() {
166            let batch_data: Vec<Array1<Float>> = buffer.drain(..).collect();
167            let batch_matrix = self.vec_to_array2(batch_data)?;
168            let results = processor.process_batch(batch_matrix.view())?;
169            all_results.extend(results);
170        }
171
172        Ok(all_results)
173    }
174
175    fn calculate_optimal_batch_size<T, P>(
176        &self,
177        num_samples: usize,
178        processor: &P,
179    ) -> NeighborsResult<usize>
180    where
181        P: BatchProcessable<T>,
182    {
183        let memory_per_sample = processor.estimate_memory_per_sample();
184        let max_samples_per_batch = (self.config.max_memory_mb * 1024 * 1024) / memory_per_sample;
185
186        let optimal_size = std::cmp::min(
187            std::cmp::min(self.config.batch_size, max_samples_per_batch),
188            num_samples,
189        );
190
191        if optimal_size == 0 {
192            return Err(NeighborsError::InvalidInput(
193                "Batch size too small for available memory".to_string(),
194            ));
195        }
196
197        Ok(optimal_size)
198    }
199
200    #[cfg(feature = "parallel")]
201    fn process_data_parallel<T, P>(
202        &self,
203        data: ArrayView2<Float>,
204        batch_size: usize,
205        processor: &P,
206    ) -> NeighborsResult<Vec<T>>
207    where
208        T: Send + Sync + Clone,
209        P: BatchProcessable<T> + Sync,
210    {
211        let num_samples = data.nrows();
212        let chunk_indices: Vec<(usize, usize)> = (0..num_samples)
213            .step_by(batch_size)
214            .map(|start| {
215                let end = std::cmp::min(start + batch_size, num_samples);
216                (start, end)
217            })
218            .collect();
219
220        let results: Result<Vec<Vec<T>>, NeighborsError> = chunk_indices
221            .par_iter()
222            .map(|&(start, end)| {
223                let batch = data.slice(scirs2_core::ndarray::s![start..end, ..]);
224                processor.process_batch(batch)
225            })
226            .collect();
227
228        match results {
229            Ok(batch_results) => Ok(batch_results.into_iter().flatten().collect()),
230            Err(e) => Err(e),
231        }
232    }
233
234    #[cfg(not(feature = "parallel"))]
235    fn process_data_parallel<T, P>(
236        &self,
237        data: ArrayView2<Float>,
238        batch_size: usize,
239        processor: &P,
240    ) -> NeighborsResult<Vec<T>>
241    where
242        T: Send + Sync + Clone,
243        P: BatchProcessable<T> + Sync,
244    {
245        // Fallback to sequential processing if parallel feature is not enabled
246        let mut all_results = Vec::new();
247        let num_samples = data.nrows();
248        let mut start_idx = 0;
249
250        while start_idx < num_samples {
251            let end_idx = std::cmp::min(start_idx + batch_size, num_samples);
252            let batch = data.slice(scirs2_core::ndarray::s![start_idx..end_idx, ..]);
253
254            let batch_results = processor.process_batch(batch)?;
255            all_results.extend(batch_results);
256
257            start_idx = end_idx;
258        }
259
260        Ok(all_results)
261    }
262
263    fn update_memory_usage(&mut self, additional_bytes: usize) -> NeighborsResult<()> {
264        let additional_mb = additional_bytes / (1024 * 1024);
265        self.memory_monitor.current_memory_mb += additional_mb;
266
267        if self.memory_monitor.current_memory_mb > self.memory_monitor.peak_memory_mb {
268            self.memory_monitor.peak_memory_mb = self.memory_monitor.current_memory_mb;
269        }
270
271        if self.memory_monitor.current_memory_mb > self.memory_monitor.memory_threshold_mb {
272            return Err(NeighborsError::InvalidInput(format!(
273                "Memory usage exceeded threshold: {} MB",
274                self.memory_monitor.memory_threshold_mb
275            )));
276        }
277
278        Ok(())
279    }
280
281    fn calculate_memory_efficiency(&self) -> Float {
282        if self.memory_monitor.memory_threshold_mb == 0 {
283            return 1.0;
284        }
285
286        1.0 - (self.memory_monitor.peak_memory_mb as Float
287            / self.memory_monitor.memory_threshold_mb as Float)
288    }
289
290    fn vec_to_array2(&self, vec_data: Vec<Array1<Float>>) -> NeighborsResult<Array2<Float>> {
291        if vec_data.is_empty() {
292            return Err(NeighborsError::EmptyInput);
293        }
294
295        let n_samples = vec_data.len();
296        let n_features = vec_data[0].len();
297
298        let mut result = Array2::zeros((n_samples, n_features));
299        for (i, row) in vec_data.iter().enumerate() {
300            if row.len() != n_features {
301                return Err(NeighborsError::ShapeMismatch {
302                    expected: vec![n_features],
303                    actual: vec![row.len()],
304                });
305            }
306            result.row_mut(i).assign(row);
307        }
308
309        Ok(result)
310    }
311
312    pub fn get_memory_stats(&self) -> &MemoryMonitor {
313        &self.memory_monitor
314    }
315
316    pub fn reset_memory_monitor(&mut self) {
317        self.memory_monitor.current_memory_mb = 0;
318        self.memory_monitor.peak_memory_mb = 0;
319    }
320}
321
/// Fluent builder for [`BatchProcessor`]; start from `BatchProcessor::builder()`.
pub struct BatchProcessorBuilder {
    // Accumulated configuration, seeded with `BatchConfiguration::default()`.
    config: BatchConfiguration,
}
325
impl Default for BatchProcessorBuilder {
    /// Equivalent to [`BatchProcessorBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
331
332impl BatchProcessorBuilder {
333    pub fn new() -> Self {
334        Self {
335            config: BatchConfiguration::default(),
336        }
337    }
338
339    pub fn batch_size(mut self, size: usize) -> Self {
340        self.config.batch_size = size;
341        self
342    }
343
344    pub fn max_memory_mb(mut self, memory_mb: usize) -> Self {
345        self.config.max_memory_mb = memory_mb;
346        self
347    }
348
349    pub fn parallel_processing(mut self, enabled: bool) -> Self {
350        self.config.parallel_processing = enabled;
351        self
352    }
353
354    pub fn chunk_overlap(mut self, overlap: usize) -> Self {
355        self.config.chunk_overlap = overlap;
356        self
357    }
358
359    pub fn prefetch_batches(mut self, count: usize) -> Self {
360        self.config.prefetch_batches = count;
361        self
362    }
363
364    pub fn build(self) -> BatchProcessor {
365        BatchProcessor::new(self.config)
366    }
367}
368
/// Brute-force k-nearest-neighbor search designed to be driven in batches by
/// [`BatchProcessor`] through the [`BatchProcessable`] trait.
pub struct BatchNeighborSearch {
    // Number of neighbors returned per query row.
    k: usize,
    // Metric used to compare query rows with training rows.
    distance: Distance,
    // Reference dataset searched for neighbors (rows = samples).
    training_data: Array2<Float>,
}
374
375impl BatchNeighborSearch {
376    pub fn new(k: usize, distance: Distance, training_data: Array2<Float>) -> Self {
377        Self {
378            k,
379            distance,
380            training_data,
381        }
382    }
383}
384
385impl BatchProcessable<(Vec<usize>, Vec<Float>)> for BatchNeighborSearch {
386    fn process_batch(
387        &self,
388        batch_data: ArrayView2<Float>,
389    ) -> NeighborsResult<Vec<(Vec<usize>, Vec<Float>)>> {
390        let mut results = Vec::new();
391
392        for query_row in batch_data.axis_iter(Axis(0)) {
393            let mut distances: Vec<(Float, usize)> = self
394                .training_data
395                .axis_iter(Axis(0))
396                .enumerate()
397                .map(|(idx, train_row)| {
398                    let dist = self.distance.calculate(&query_row, &train_row);
399                    (dist, idx)
400                })
401                .collect();
402
403            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
404            distances.truncate(self.k);
405
406            let indices: Vec<usize> = distances.iter().map(|(_, idx)| *idx).collect();
407            let dists: Vec<Float> = distances.iter().map(|(dist, _)| *dist).collect();
408
409            results.push((indices, dists));
410        }
411
412        Ok(results)
413    }
414
415    fn estimate_memory_per_sample(&self) -> usize {
416        let feature_memory = self.training_data.ncols() * 8; // 8 bytes per Float
417        let distance_memory = self.training_data.nrows() * 16; // distance + index pairs
418        let result_memory = self.k * 16; // k neighbors with distance and index
419
420        feature_memory + distance_memory + result_memory
421    }
422
423    fn supports_parallel_processing(&self) -> bool {
424        true
425    }
426}
427
#[allow(non_snake_case)] // NOTE(review): no non-snake-case items below — likely removable
#[cfg(test)]
mod tests {
    use super::*;

    // Builder setters should land in the processor's configuration verbatim.
    #[test]
    fn test_batch_processor_creation() {
        let processor = BatchProcessor::builder()
            .batch_size(500)
            .max_memory_mb(256)
            .parallel_processing(true)
            .build();

        assert_eq!(processor.config.batch_size, 500);
        assert_eq!(processor.config.max_memory_mb, 256);
        assert!(processor.config.parallel_processing);
    }

    // End-to-end run over 50 queries against 100 training rows: every query
    // must produce a result and the statistics must reflect the input size.
    #[test]
    fn test_memory_efficient_batch_processing() {
        let training_data =
            Array2::from_shape_vec((100, 4), (0..400).map(|x| x as Float).collect()).unwrap();
        let test_data =
            Array2::from_shape_vec((50, 4), (0..200).map(|x| x as Float).collect()).unwrap();

        let search = BatchNeighborSearch::new(5, Distance::Euclidean, training_data);
        let mut processor = BatchProcessor::builder()
            .batch_size(10)
            .max_memory_mb(64)
            .build();

        let result = processor.process_data(test_data.view(), &search).unwrap();

        assert_eq!(result.results.len(), 50);
        assert!(result.batch_stats.total_batches > 0);
        assert_eq!(result.batch_stats.processed_samples, 50);
    }

    // With a tiny (1 MB) memory budget the chosen batch size must shrink but
    // stay positive and never exceed the sample count.
    #[test]
    fn test_optimal_batch_size_calculation() {
        let training_data = Array2::zeros((100, 10));
        let search = BatchNeighborSearch::new(5, Distance::Euclidean, training_data);
        let processor = BatchProcessor::builder()
            .batch_size(1000)
            .max_memory_mb(1)
            .build();

        let optimal_size = processor.calculate_optimal_batch_size(50, &search).unwrap();

        // Should be constrained by memory limit
        assert!(optimal_size <= 50);
        assert!(optimal_size > 0);
    }

    // Configuration-only check: chunk_overlap is stored as given.
    #[test]
    fn test_batch_processing_with_overlap() {
        let training_data =
            Array2::from_shape_vec((20, 2), (0..40).map(|x| x as Float).collect()).unwrap();
        let _test_data =
            Array2::from_shape_vec((10, 2), (0..20).map(|x| x as Float).collect()).unwrap();

        let _search = BatchNeighborSearch::new(3, Distance::Euclidean, training_data);
        let processor = BatchProcessor::builder()
            .batch_size(4)
            .chunk_overlap(2)
            .build();

        // Just verify that overlap configuration doesn't break processing
        let config = &processor.config;
        assert_eq!(config.chunk_overlap, 2);
        assert_eq!(config.batch_size, 4);
    }

    // Reported usage within the threshold succeeds; exceeding it errors.
    #[test]
    fn test_memory_monitoring() {
        let mut processor = BatchProcessor::builder().max_memory_mb(1).build();

        // Simulate memory usage - using a larger amount to ensure it's tracked
        let result = processor.update_memory_usage(1024 * 1024); // 1 MB
        assert!(result.is_ok());

        let stats = processor.get_memory_stats();
        assert!(stats.current_memory_mb >= 1); // Should be at least 1 MB

        // Test memory limit exceeded
        let result = processor.update_memory_usage(2 * 1024 * 1024); // 2 MB more
        assert!(result.is_err());
    }

    // Parallel path: every query yields exactly k neighbors.
    #[test]
    fn test_parallel_processing_basic() {
        let training_data =
            Array2::from_shape_vec((30, 2), (0..60).map(|x| x as Float).collect()).unwrap();
        let test_data =
            Array2::from_shape_vec((10, 2), (0..20).map(|x| x as Float).collect()).unwrap();

        let search = BatchNeighborSearch::new(3, Distance::Euclidean, training_data);
        let mut processor = BatchProcessor::builder()
            .batch_size(5)
            .parallel_processing(true)
            .build();

        let result = processor.process_data(test_data.view(), &search).unwrap();

        // Should process all 10 test samples
        assert_eq!(result.results.len(), 10);
        assert_eq!(result.batch_stats.processed_samples, 10);

        // Each result should have 3 neighbors (k=3)
        for (indices, distances) in &result.results {
            assert_eq!(indices.len(), 3);
            assert_eq!(distances.len(), 3);
        }
    }

    // Zero-row input must be rejected with EmptyInput, not processed.
    #[test]
    fn test_empty_input_handling() {
        let training_data = Array2::zeros((10, 2));
        let empty_data = Array2::zeros((0, 2));
        let search = BatchNeighborSearch::new(3, Distance::Euclidean, training_data);
        let mut processor = BatchProcessor::builder().build();

        let result = processor.process_data(empty_data.view(), &search);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), NeighborsError::EmptyInput));
    }

    // Efficiency is 1 - peak/threshold: 50 MB of a 100 MB budget gives 0.5.
    #[test]
    fn test_memory_efficiency_calculation() {
        let mut processor = BatchProcessor::builder().max_memory_mb(100).build();

        processor.memory_monitor.peak_memory_mb = 50;
        let efficiency = processor.calculate_memory_efficiency();

        assert!((efficiency - 0.5).abs() < 1e-6);
    }
}