use crate::distance::Distance;
use crate::{NeighborsError, NeighborsResult};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2, Axis};
use sklears_core::types::Float;
use std::collections::VecDeque;

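/// Tuning knobs for memory-bounded batch processing. `batch_size` is the
/// preferred number of rows per batch (the processor may shrink it to honor
/// `max_memory_mb`), and `chunk_overlap` makes consecutive sequential batches
/// re-read that many trailing rows. `prefetch_batches` is stored but not yet
/// consulted by `BatchProcessor` in this module.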
#[derive(Debug, Clone)]
pub struct BatchConfiguration {
    pub batch_size: usize,
    pub max_memory_mb: usize,
    pub parallel_processing: bool,
    pub chunk_overlap: usize,
    pub prefetch_batches: usize,
}

impl Default for BatchConfiguration {
    fn default() -> Self {
        Self {
            batch_size: 1000,
            max_memory_mb: 512,
            parallel_processing: true,
            chunk_overlap: 0,
            prefetch_batches: 2,
        }
    }
}

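/// Drives a [`BatchProcessable`] workload batch by batch while tracking an
/// approximate memory high-water mark.
///
/// A minimal usage sketch, assuming `data: Array2<Float>` and a
/// `BatchNeighborSearch` built as in the tests at the bottom of this file:
///
/// ```ignore
/// let search = BatchNeighborSearch::new(5, Distance::Euclidean, training_data);
/// let mut processor = BatchProcessor::builder()
///     .batch_size(500)
///     .max_memory_mb(256)
///     .build();
/// let result = processor.process_data(data.view(), &search)?;
/// println!("{} batches", result.batch_stats.total_batches);
/// ```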
pub struct BatchProcessor {
    config: BatchConfiguration,
    memory_monitor: MemoryMonitor,
}

#[derive(Debug, Clone)]
pub struct MemoryMonitor {
    peak_memory_mb: usize,
    current_memory_mb: usize,
    memory_threshold_mb: usize,
}

#[derive(Debug, Clone)]
pub struct BatchResult<T> {
    pub results: Vec<T>,
    pub batch_stats: BatchStatistics,
}

#[derive(Debug, Clone)]
pub struct BatchStatistics {
    pub total_batches: usize,
    pub processed_samples: usize,
    pub processing_time_ms: u128,
    pub peak_memory_mb: usize,
    pub average_batch_size: usize,
    pub memory_efficiency: Float,
}

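/// A unit of work that can consume one batch of rows at a time.
///
/// A sketch of a trivial implementor (the `RowSum` type and its per-sample
/// estimate are illustrative, not part of this crate):
///
/// ```ignore
/// struct RowSum;
///
/// impl BatchProcessable<Float> for RowSum {
///     fn process_batch(&self, batch: ArrayView2<Float>) -> NeighborsResult<Vec<Float>> {
///         Ok(batch.axis_iter(Axis(0)).map(|row| row.sum()).collect())
///     }
///
///     fn estimate_memory_per_sample(&self) -> usize {
///         // Rough guess: one ~8-feature row of Floats plus one output value.
///         (8 + 1) * std::mem::size_of::<Float>()
///     }
///
///     fn supports_parallel_processing(&self) -> bool {
///         true
///     }
/// }
/// ```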
pub trait BatchProcessable<T> {
    fn process_batch(&self, batch_data: ArrayView2<Float>) -> NeighborsResult<Vec<T>>;
    fn estimate_memory_per_sample(&self) -> usize;
    fn supports_parallel_processing(&self) -> bool;
}

impl BatchProcessor {
    pub fn new(config: BatchConfiguration) -> Self {
        let memory_monitor = MemoryMonitor {
            peak_memory_mb: 0,
            current_memory_mb: 0,
            memory_threshold_mb: config.max_memory_mb,
        };

        Self {
            config,
            memory_monitor,
        }
    }

    pub fn builder() -> BatchProcessorBuilder {
        BatchProcessorBuilder::new()
    }

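    /// Processes `data` batch by batch, dispatching to the parallel path when
    /// both the configuration and the processor allow it. Sequential batches
    /// advance by `batch_size - chunk_overlap` rows; memory accounting is only
    /// updated on the sequential path.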
    pub fn process_data<T, P>(
        &mut self,
        data: ArrayView2<Float>,
        processor: &P,
    ) -> NeighborsResult<BatchResult<T>>
    where
        T: Send + Sync + Clone,
        P: BatchProcessable<T> + Sync,
    {
        let start_time = std::time::Instant::now();
        let num_samples = data.nrows();

        if num_samples == 0 {
            return Err(NeighborsError::EmptyInput);
        }

        let optimal_batch_size = self.calculate_optimal_batch_size(num_samples, processor)?;

        let mut all_results = Vec::new();
        let mut batch_count = 0;

        if self.config.parallel_processing && processor.supports_parallel_processing() {
            all_results = self.process_data_parallel(data, optimal_batch_size, processor)?;
            batch_count = num_samples.div_ceil(optimal_batch_size);
        } else {
            let mut start_idx = 0;
            while start_idx < num_samples {
                let end_idx = std::cmp::min(start_idx + optimal_batch_size, num_samples);
                let batch = data.slice(s![start_idx..end_idx, ..]);

                let batch_results = processor.process_batch(batch)?;
                all_results.extend(batch_results);
                batch_count += 1;
                self.update_memory_usage(
                    batch.nrows() * batch.ncols() * std::mem::size_of::<Float>(),
                )?;

                if end_idx == num_samples {
                    break;
                }
                // Step forward, re-reading `chunk_overlap` rows from the end of
                // this batch. Guard against a zero step (overlap >= batch size),
                // which would otherwise loop forever.
                let next_start = end_idx.saturating_sub(self.config.chunk_overlap);
                if next_start <= start_idx {
                    break;
                }
                start_idx = next_start;
            }
        }

        let processing_time = start_time.elapsed().as_millis();
        let stats = BatchStatistics {
            total_batches: batch_count,
            processed_samples: num_samples,
            processing_time_ms: processing_time,
            peak_memory_mb: self.memory_monitor.peak_memory_mb,
            average_batch_size: num_samples / batch_count.max(1),
            memory_efficiency: self.calculate_memory_efficiency(),
        };

        Ok(BatchResult {
            results: all_results,
            batch_stats: stats,
        })
    }

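    /// Buffers a stream of rows, flushing a batch to `processor` whenever
    /// `batch_size` rows have accumulated, and once more for the final partial
    /// batch. A sketch with a made-up stream of 25 four-feature rows:
    ///
    /// ```ignore
    /// let stream = (0..25).map(|i| Array1::from_elem(4, i as Float));
    /// let results = processor.process_streaming_data(stream, &search)?;
    /// ```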
    pub fn process_streaming_data<T, P>(
        &mut self,
        data_stream: impl Iterator<Item = Array1<Float>>,
        processor: &P,
    ) -> NeighborsResult<Vec<T>>
    where
        T: Send + Sync + Clone,
        P: BatchProcessable<T> + Sync,
    {
        let mut buffer = VecDeque::new();
        let batch_size = self.config.batch_size;
        let mut all_results = Vec::new();

        for sample in data_stream {
            buffer.push_back(sample);

            if buffer.len() >= batch_size {
                let batch_data: Vec<Array1<Float>> = buffer.drain(..batch_size).collect();
                let batch_matrix = self.vec_to_array2(batch_data)?;
                let results = processor.process_batch(batch_matrix.view())?;
                all_results.extend(results);
            }
        }

        if !buffer.is_empty() {
            let batch_data: Vec<Array1<Float>> = buffer.drain(..).collect();
            let batch_matrix = self.vec_to_array2(batch_data)?;
            let results = processor.process_batch(batch_matrix.view())?;
            all_results.extend(results);
        }

        Ok(all_results)
    }

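    /// Picks the largest batch that fits the budget:
    /// `min(batch_size, max_memory_mb / per-sample estimate, num_samples)`.
    /// For example, with a 1 MB budget and the `BatchNeighborSearch` estimate
    /// of 1760 bytes per sample (100 training rows of 10 features, k = 5),
    /// about 595 samples fit, so a request for 50 samples stays at 50.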
    fn calculate_optimal_batch_size<T, P>(
        &self,
        num_samples: usize,
        processor: &P,
    ) -> NeighborsResult<usize>
    where
        P: BatchProcessable<T>,
    {
        // Guard against a zero estimate, which would otherwise divide by zero.
        let memory_per_sample = processor.estimate_memory_per_sample().max(1);
        let max_samples_per_batch = (self.config.max_memory_mb * 1024 * 1024) / memory_per_sample;

        let optimal_size = std::cmp::min(
            std::cmp::min(self.config.batch_size, max_samples_per_batch),
            num_samples,
        );

        if optimal_size == 0 {
            return Err(NeighborsError::InvalidInput(
                "Batch size too small for available memory".to_string(),
            ));
        }

        Ok(optimal_size)
    }

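    /// Fans disjoint `batch_size` chunks out with rayon; `chunk_overlap` is
    /// not applied on this path, so no row is processed twice.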
    #[cfg(feature = "parallel")]
    fn process_data_parallel<T, P>(
        &self,
        data: ArrayView2<Float>,
        batch_size: usize,
        processor: &P,
    ) -> NeighborsResult<Vec<T>>
    where
        T: Send + Sync + Clone,
        P: BatchProcessable<T> + Sync,
    {
        let num_samples = data.nrows();
        let chunk_indices: Vec<(usize, usize)> = (0..num_samples)
            .step_by(batch_size)
            .map(|start| {
                let end = std::cmp::min(start + batch_size, num_samples);
                (start, end)
            })
            .collect();

        let results: Result<Vec<Vec<T>>, NeighborsError> = chunk_indices
            .par_iter()
            .map(|&(start, end)| {
                let batch = data.slice(s![start..end, ..]);
                processor.process_batch(batch)
            })
            .collect();

        Ok(results?.into_iter().flatten().collect())
    }

    /// Sequential fallback used when the `parallel` feature is disabled.
    #[cfg(not(feature = "parallel"))]
    fn process_data_parallel<T, P>(
        &self,
        data: ArrayView2<Float>,
        batch_size: usize,
        processor: &P,
    ) -> NeighborsResult<Vec<T>>
    where
        T: Send + Sync + Clone,
        P: BatchProcessable<T> + Sync,
    {
        let mut all_results = Vec::new();
        let num_samples = data.nrows();
        let mut start_idx = 0;

        while start_idx < num_samples {
            let end_idx = std::cmp::min(start_idx + batch_size, num_samples);
            let batch = data.slice(s![start_idx..end_idx, ..]);

            let batch_results = processor.process_batch(batch)?;
            all_results.extend(batch_results);

            start_idx = end_idx;
        }

        Ok(all_results)
    }

    fn update_memory_usage(&mut self, additional_bytes: usize) -> NeighborsResult<()> {
        // Round up so that sub-megabyte batches still register.
        let additional_mb = additional_bytes.div_ceil(1024 * 1024);
        self.memory_monitor.current_memory_mb += additional_mb;

        if self.memory_monitor.current_memory_mb > self.memory_monitor.peak_memory_mb {
            self.memory_monitor.peak_memory_mb = self.memory_monitor.current_memory_mb;
        }

        if self.memory_monitor.current_memory_mb > self.memory_monitor.memory_threshold_mb {
            return Err(NeighborsError::InvalidInput(format!(
                "Memory usage exceeded threshold: {} MB",
                self.memory_monitor.memory_threshold_mb
            )));
        }

        Ok(())
    }

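    /// Reports efficiency as `1 - peak_mb / threshold_mb`: 1.0 means the
    /// budget was untouched, 0.0 means the peak reached the limit.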
    fn calculate_memory_efficiency(&self) -> Float {
        if self.memory_monitor.memory_threshold_mb == 0 {
            return 1.0;
        }

        1.0 - (self.memory_monitor.peak_memory_mb as Float
            / self.memory_monitor.memory_threshold_mb as Float)
    }

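    /// Stacks equal-length rows into one `(n_samples, n_features)` matrix,
    /// returning `ShapeMismatch` if any row's length differs from the first.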
    fn vec_to_array2(&self, vec_data: Vec<Array1<Float>>) -> NeighborsResult<Array2<Float>> {
        if vec_data.is_empty() {
            return Err(NeighborsError::EmptyInput);
        }

        let n_samples = vec_data.len();
        let n_features = vec_data[0].len();

        let mut result = Array2::zeros((n_samples, n_features));
        for (i, row) in vec_data.iter().enumerate() {
            if row.len() != n_features {
                return Err(NeighborsError::ShapeMismatch {
                    expected: vec![n_features],
                    actual: vec![row.len()],
                });
            }
            result.row_mut(i).assign(row);
        }

        Ok(result)
    }

    pub fn get_memory_stats(&self) -> &MemoryMonitor {
        &self.memory_monitor
    }

    pub fn reset_memory_monitor(&mut self) {
        self.memory_monitor.current_memory_mb = 0;
        self.memory_monitor.peak_memory_mb = 0;
    }
}

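/// Fluent builder for [`BatchProcessor`]; options left unset fall back to
/// [`BatchConfiguration::default`].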
pub struct BatchProcessorBuilder {
    config: BatchConfiguration,
}

impl Default for BatchProcessorBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl BatchProcessorBuilder {
    pub fn new() -> Self {
        Self {
            config: BatchConfiguration::default(),
        }
    }

    pub fn batch_size(mut self, size: usize) -> Self {
        self.config.batch_size = size;
        self
    }

    pub fn max_memory_mb(mut self, memory_mb: usize) -> Self {
        self.config.max_memory_mb = memory_mb;
        self
    }

    pub fn parallel_processing(mut self, enabled: bool) -> Self {
        self.config.parallel_processing = enabled;
        self
    }

    pub fn chunk_overlap(mut self, overlap: usize) -> Self {
        self.config.chunk_overlap = overlap;
        self
    }

    pub fn prefetch_batches(mut self, count: usize) -> Self {
        self.config.prefetch_batches = count;
        self
    }

    pub fn build(self) -> BatchProcessor {
        BatchProcessor::new(self.config)
    }
}

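/// Brute-force k-nearest-neighbor search over a fixed training set, exposed
/// through [`BatchProcessable`] so queries can be batched. Each query yields
/// `(indices, distances)` for its `k` nearest training rows, sorted by
/// ascending distance.
///
/// A usage sketch (mirrors `test_memory_efficient_batch_processing` below):
///
/// ```ignore
/// let search = BatchNeighborSearch::new(5, Distance::Euclidean, training_data);
/// let mut processor = BatchProcessor::builder().batch_size(10).build();
/// let result = processor.process_data(queries.view(), &search)?;
/// ```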
pub struct BatchNeighborSearch {
    k: usize,
    distance: Distance,
    training_data: Array2<Float>,
}

impl BatchNeighborSearch {
    pub fn new(k: usize, distance: Distance, training_data: Array2<Float>) -> Self {
        Self {
            k,
            distance,
            training_data,
        }
    }
}

impl BatchProcessable<(Vec<usize>, Vec<Float>)> for BatchNeighborSearch {
    fn process_batch(
        &self,
        batch_data: ArrayView2<Float>,
    ) -> NeighborsResult<Vec<(Vec<usize>, Vec<Float>)>> {
        let mut results = Vec::new();

        for query_row in batch_data.axis_iter(Axis(0)) {
            let mut distances: Vec<(Float, usize)> = self
                .training_data
                .axis_iter(Axis(0))
                .enumerate()
                .map(|(idx, train_row)| {
                    let dist = self.distance.calculate(&query_row, &train_row);
                    (dist, idx)
                })
                .collect();

            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
            distances.truncate(self.k);

            let indices: Vec<usize> = distances.iter().map(|(_, idx)| *idx).collect();
            let dists: Vec<Float> = distances.iter().map(|(dist, _)| *dist).collect();

            results.push((indices, dists));
        }

        Ok(results)
    }

    fn estimate_memory_per_sample(&self) -> usize {
        // One query row at 8 bytes per feature (the constants assume an
        // 8-byte Float)...
        let feature_memory = self.training_data.ncols() * 8;
        // ...a (distance, index) pair per training row...
        let distance_memory = self.training_data.nrows() * 16;
        // ...and the k results kept per query.
        let result_memory = self.k * 16;
        feature_memory + distance_memory + result_memory
    }

    fn supports_parallel_processing(&self) -> bool {
        true
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_processor_creation() {
        let processor = BatchProcessor::builder()
            .batch_size(500)
            .max_memory_mb(256)
            .parallel_processing(true)
            .build();

        assert_eq!(processor.config.batch_size, 500);
        assert_eq!(processor.config.max_memory_mb, 256);
        assert!(processor.config.parallel_processing);
    }

    #[test]
    fn test_memory_efficient_batch_processing() {
        let training_data =
            Array2::from_shape_vec((100, 4), (0..400).map(|x| x as Float).collect()).unwrap();
        let test_data =
            Array2::from_shape_vec((50, 4), (0..200).map(|x| x as Float).collect()).unwrap();

        let search = BatchNeighborSearch::new(5, Distance::Euclidean, training_data);
        let mut processor = BatchProcessor::builder()
            .batch_size(10)
            .max_memory_mb(64)
            .build();

        let result = processor.process_data(test_data.view(), &search).unwrap();

        assert_eq!(result.results.len(), 50);
        assert!(result.batch_stats.total_batches > 0);
        assert_eq!(result.batch_stats.processed_samples, 50);
    }

    #[test]
    fn test_optimal_batch_size_calculation() {
        let training_data = Array2::zeros((100, 10));
        let search = BatchNeighborSearch::new(5, Distance::Euclidean, training_data);
        let processor = BatchProcessor::builder()
            .batch_size(1000)
            .max_memory_mb(1)
            .build();

        let optimal_size = processor.calculate_optimal_batch_size(50, &search).unwrap();

        assert!(optimal_size <= 50);
        assert!(optimal_size > 0);
    }

    #[test]
    fn test_batch_processing_with_overlap() {
        let training_data =
            Array2::from_shape_vec((20, 2), (0..40).map(|x| x as Float).collect()).unwrap();
        let _test_data =
            Array2::from_shape_vec((10, 2), (0..20).map(|x| x as Float).collect()).unwrap();

        let _search = BatchNeighborSearch::new(3, Distance::Euclidean, training_data);
        let processor = BatchProcessor::builder()
            .batch_size(4)
            .chunk_overlap(2)
            .build();

        let config = &processor.config;
        assert_eq!(config.chunk_overlap, 2);
        assert_eq!(config.batch_size, 4);
    }

    #[test]
    fn test_memory_monitoring() {
        let mut processor = BatchProcessor::builder().max_memory_mb(1).build();

        let result = processor.update_memory_usage(1024 * 1024);
        assert!(result.is_ok());

        let stats = processor.get_memory_stats();
        assert!(stats.current_memory_mb >= 1);

        let result = processor.update_memory_usage(2 * 1024 * 1024);
        assert!(result.is_err());
    }

    #[test]
    fn test_parallel_processing_basic() {
        let training_data =
            Array2::from_shape_vec((30, 2), (0..60).map(|x| x as Float).collect()).unwrap();
        let test_data =
            Array2::from_shape_vec((10, 2), (0..20).map(|x| x as Float).collect()).unwrap();

        let search = BatchNeighborSearch::new(3, Distance::Euclidean, training_data);
        let mut processor = BatchProcessor::builder()
            .batch_size(5)
            .parallel_processing(true)
            .build();

        let result = processor.process_data(test_data.view(), &search).unwrap();

        assert_eq!(result.results.len(), 10);
        assert_eq!(result.batch_stats.processed_samples, 10);

        for (indices, distances) in &result.results {
            assert_eq!(indices.len(), 3);
            assert_eq!(distances.len(), 3);
        }
    }

    #[test]
    fn test_empty_input_handling() {
        let training_data = Array2::zeros((10, 2));
        let empty_data = Array2::zeros((0, 2));
        let search = BatchNeighborSearch::new(3, Distance::Euclidean, training_data);
        let mut processor = BatchProcessor::builder().build();

        let result = processor.process_data(empty_data.view(), &search);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), NeighborsError::EmptyInput));
    }

    #[test]
    fn test_memory_efficiency_calculation() {
        let mut processor = BatchProcessor::builder().max_memory_mb(100).build();

        processor.memory_monitor.peak_memory_mb = 50;
        let efficiency = processor.calculate_memory_efficiency();

        assert!((efficiency - 0.5).abs() < 1e-6);
    }
}