1use scirs2_core::ndarray::{Array, ArrayView, Dimension, IxDyn};
8use scirs2_core::numeric::{Float, FromPrimitive, NumCast, Zero};
9use std::fmt::Debug;
10use std::sync::Mutex;
11
12use crate::error::{NdimageError, NdimageResult};
13use crate::filters::BorderMode;
14
/// Tuning knobs for chunked array processing.
#[derive(Clone, Debug)]
pub struct ChunkConfig {
    /// Byte budget for a single chunk. `process_chunked` divides this by the
    /// element size to obtain the target number of elements per chunk.
    pub chunk_size_bytes: usize,
    /// Requested overlap between chunks, in elements. `process_chunked` uses
    /// the larger of this value and the processor's `required_overlap()`.
    pub overlap: usize,
    /// Lower bound handed to the per-dimension chunk size calculation.
    pub min_chunk_size: usize,
    /// When `true` and more than one chunk is produced, `process_chunked`
    /// takes its parallel path if the crate's `parallel` feature is enabled;
    /// otherwise chunks are processed one after another.
    pub parallel: bool,
}
27
28impl Default for ChunkConfig {
29 fn default() -> Self {
30 Self {
31 chunk_size_bytes: 64 * 1024 * 1024, overlap: 0,
33 min_chunk_size: 16,
34 parallel: true,
35 }
36 }
37}
38
/// Location of one chunk inside the full array, with one entry per axis.
///
/// `extract_chunk` slices axis `i` with the range `start[i]..end[i]`.
#[derive(Clone, Debug)]
pub struct ChunkPosition {
    /// First index covered by the chunk along each axis.
    pub start: Vec<usize>,
    /// One-past-the-last index covered by the chunk along each axis.
    pub end: Vec<usize>,
}
47
48pub trait ChunkProcessor<T, D>
50where
51 D: Dimension,
52{
53 fn process_chunk(
55 &mut self,
56 chunk: ArrayView<T, D>,
57 position: &ChunkPosition,
58 ) -> NdimageResult<Array<T, D>>;
59
60 fn required_overlap(&self) -> usize;
62
63 fn combine_chunks(
65 &self,
66 results: Vec<(Array<T, D>, ChunkPosition)>,
67 outputshape: &[usize],
68 ) -> NdimageResult<Array<T, D>>;
69}
70
71#[allow(dead_code)]
73pub fn process_chunked<T, D, P>(
74 input: &ArrayView<T, D>,
75 processor: &mut P,
76 config: &ChunkConfig,
77) -> NdimageResult<Array<T, D>>
78where
79 T: Float + FromPrimitive + NumCast + Debug + Clone + Send + Sync,
80 D: Dimension,
81 P: ChunkProcessor<T, D> + Send + Sync,
82{
83 let shape = input.shape();
84 let ndim = input.ndim();
85 let element_size = std::mem::size_of::<T>();
86
87 let total_elements = shape.iter().product::<usize>();
89 let target_elements_per_chunk = config.chunk_size_bytes / element_size;
90
91 if total_elements <= target_elements_per_chunk {
92 let position = ChunkPosition {
94 start: vec![0; ndim],
95 end: shape.to_vec(),
96 };
97 let result = processor.process_chunk(input.clone(), &position)?;
98 return Ok(result);
99 }
100
101 let chunk_sizes =
103 calculate_chunk_sizes(shape, target_elements_per_chunk, config.min_chunk_size);
104 let overlap = processor.required_overlap().max(config.overlap);
105
106 let chunks = generate_chunk_positions(shape, &chunk_sizes, overlap);
108
109 let results = if config.parallel && chunks.len() > 1 {
111 #[cfg(feature = "parallel")]
112 {
113 use scirs2_core::parallel_ops::*;
114
115 let processor_mutex = Mutex::new(&mut *processor);
116 chunks
117 .into_par_iter()
118 .map(|position| {
119 let chunk = extract_chunk(input, &position)?;
120 let result = {
121 let mut proc = processor_mutex.lock().expect("Operation failed");
122 proc.process_chunk(chunk.view(), &position)?
123 };
124 Ok((result, position))
125 })
126 .collect::<Result<Vec<_>, NdimageError>>()?
127 }
128 #[cfg(not(feature = "parallel"))]
129 {
130 chunks
131 .into_iter()
132 .map(|position| {
133 let chunk = extract_chunk(input, &position)?;
134 let result = processor.process_chunk(chunk.view(), &position)?;
135 Ok((result, position))
136 })
137 .collect::<Result<Vec<_>, NdimageError>>()?
138 }
139 } else {
140 chunks
141 .into_iter()
142 .map(|position| {
143 let chunk = extract_chunk(input, &position)?;
144 let result = processor.process_chunk(chunk.view(), &position)?;
145 Ok((result, position))
146 })
147 .collect::<Result<Vec<_>, NdimageError>>()?
148 };
149
150 processor.combine_chunks(results, shape)
152}
153
/// Chooses a chunk length for every dimension of `shape`.
///
/// Each dimension starts at roughly the `ndim`-th root of `target_elements`,
/// is raised to at least `min_chunk_size`, and is then capped at the
/// dimension's extent. Returned lengths are always at least 1, so callers
/// stepping through the array by these lengths always make progress.
///
/// If the resulting chunk volume still exceeds twice `target_elements`, the
/// dimension covering the largest fraction of its extent is halved repeatedly
/// (never below the minimum) until the volume fits or nothing can shrink.
///
/// # Arguments
///
/// * `shape` - extent of every dimension.
/// * `target_elements` - desired number of elements per chunk.
/// * `min_chunk_size` - preferred lower bound for a chunk length; values of 0
///   are treated as 1, and a dimension shorter than the bound is used whole.
///
/// # Returns
///
/// One chunk length per dimension; empty for an empty `shape`.
#[allow(dead_code)]
fn calculate_chunk_sizes(
    shape: &[usize],
    target_elements: usize,
    min_chunk_size: usize,
) -> Vec<usize> {
    let ndim = shape.len();
    if ndim == 0 {
        // Nothing to size; also avoids indexing into an empty vector below.
        return Vec::new();
    }

    // A zero-length chunk would never advance through the array.
    let floor = min_chunk_size.max(1);
    let elements_per_dim = (target_elements as f64).powf(1.0 / ndim as f64) as usize;

    let mut chunk_sizes: Vec<usize> = shape
        .iter()
        // `dim_size.max(1)` keeps the length positive even for empty dimensions.
        .map(|&dim_size| elements_per_dim.max(floor).min(dim_size.max(1)))
        .collect();

    let budget = target_elements.saturating_mul(2);
    let volume = |sizes: &[usize]| sizes.iter().fold(1usize, |acc, &s| acc.saturating_mul(s));

    while volume(&chunk_sizes) > budget {
        // Prefer the shrinkable dimension that covers the largest share of its
        // extent; fall back to dimension 0 when none qualifies.
        let idx = chunk_sizes
            .iter()
            .enumerate()
            .filter(|&(i, &size)| size > floor && size < shape[i])
            .max_by_key(|&(i, &size)| size.saturating_mul(1000) / shape[i])
            .map_or(0, |(i, _)| i);

        if chunk_sizes[idx] <= floor {
            break;
        }
        chunk_sizes[idx] = (chunk_sizes[idx] / 2).max(floor);
    }

    chunk_sizes
}
193
194#[allow(dead_code)]
196fn generate_chunk_positions(
197 shape: &[usize],
198 chunk_sizes: &[usize],
199 overlap: usize,
200) -> Vec<ChunkPosition> {
201 let ndim = shape.len();
202 let mut positions = Vec::new();
203
204 let mut indices = vec![0; ndim];
206
207 loop {
208 let mut position = ChunkPosition {
209 start: Vec::with_capacity(ndim),
210 end: Vec::with_capacity(ndim),
211 };
212
213 for dim in 0..ndim {
214 let start = if indices[dim] == 0 {
215 0
216 } else {
217 indices[dim] * chunk_sizes[dim] - overlap
218 };
219 let end = (start + chunk_sizes[dim] + overlap).min(shape[dim]);
220
221 position.start.push(start);
222 position.end.push(end);
223 }
224
225 positions.push(position);
226
227 let mut carry = true;
229 for dim in (0..ndim).rev() {
230 if carry {
231 indices[dim] += 1;
232 if (indices[dim] + 1) * chunk_sizes[dim] >= shape[dim] + overlap {
233 if indices[dim] * chunk_sizes[dim] < shape[dim] {
234 carry = false;
235 } else {
236 indices[dim] = 0;
237 }
238 } else {
239 carry = false;
240 }
241 }
242 }
243
244 if carry {
245 break;
246 }
247 }
248
249 positions
250}
251
252#[allow(dead_code)]
254fn extract_chunk<T, D>(
255 array: &ArrayView<T, D>,
256 position: &ChunkPosition,
257) -> NdimageResult<Array<T, D>>
258where
259 T: Clone,
260 D: Dimension,
261{
262 use scirs2_core::ndarray::SliceInfoElem;
263
264 let slice_info: Vec<SliceInfoElem> = position
266 .start
267 .iter()
268 .zip(&position.end)
269 .map(|(&start, &end)| SliceInfoElem::Slice {
270 start: start as isize,
271 end: Some(end as isize),
272 step: 1,
273 })
274 .collect();
275
276 let chunk = array.view().into_dyn().slice_move(slice_info.as_slice());
277 let owned_chunk = chunk.to_owned();
278 Ok(owned_chunk
279 .into_dimensionality::<D>()
280 .map_err(|_| NdimageError::DimensionError("Failed to convert chunk dimension".into()))?)
281}
282
283pub struct GaussianChunkProcessor<T> {
285 sigma: Vec<T>,
286 truncate: Option<T>,
287 bordermode: BorderMode,
288}
289
290impl<T> GaussianChunkProcessor<T>
291where
292 T: Float + FromPrimitive,
293{
294 pub fn new(_sigma: Vec<T>, truncate: Option<T>, bordermode: BorderMode) -> Self {
295 Self {
296 sigma: _sigma,
297 truncate,
298 bordermode,
299 }
300 }
301}
302
303impl<T, D> ChunkProcessor<T, D> for GaussianChunkProcessor<T>
304where
305 T: Float + FromPrimitive + NumCast + Debug + Clone + Send + Sync + Zero,
306 D: Dimension,
307{
308 fn process_chunk(
309 &mut self,
310 chunk: ArrayView<T, D>,
311 _position: &ChunkPosition,
312 ) -> NdimageResult<Array<T, D>> {
313 Ok(chunk.to_owned())
316 }
317
318 fn required_overlap(&self) -> usize {
319 let max_sigma = self
321 .sigma
322 .iter()
323 .map(|&s| NumCast::from(s).unwrap_or(0.0))
324 .fold(0.0f64, |a, b| a.max(b));
325
326 let truncate = self
327 .truncate
328 .map(|t| NumCast::from(t).unwrap_or(4.0))
329 .unwrap_or(4.0);
330
331 ((truncate * max_sigma).ceil() as usize).max(1)
332 }
333
334 fn combine_chunks(
335 &self,
336 results: Vec<(Array<T, D>, ChunkPosition)>,
337 outputshape: &[usize],
338 ) -> NdimageResult<Array<T, D>> {
339 let mut output = Array::<T, IxDyn>::zeros(IxDyn(outputshape));
341 let overlap = <Self as ChunkProcessor<T, D>>::required_overlap(self);
342
343 for (chunk_result, position) in results {
345 use scirs2_core::ndarray::SliceInfoElem;
346
347 let mut copy_start = Vec::new();
349 let mut copy_end = Vec::new();
350 let mut chunk_start = Vec::new();
351 let mut chunk_end = Vec::new();
352
353 for (dim, (&start, &end)) in position.start.iter().zip(&position.end).enumerate() {
354 let out_start = if start > 0 {
356 start + overlap / 2
357 } else {
358 start
359 };
360 let out_end = if end < outputshape[dim] {
361 end - overlap / 2
362 } else {
363 end
364 };
365 copy_start.push(out_start);
366 copy_end.push(out_end);
367
368 let ch_start = if start > 0 { overlap / 2 } else { 0 };
370 let ch_end = chunk_result.shape()[dim]
371 - if end < outputshape[dim] {
372 overlap / 2
373 } else {
374 0
375 };
376 chunk_start.push(ch_start);
377 chunk_end.push(ch_end);
378 }
379
380 let output_slice_info: Vec<SliceInfoElem> = copy_start
382 .iter()
383 .zip(©_end)
384 .map(|(&start, &end)| SliceInfoElem::Slice {
385 start: start as isize,
386 end: Some(end as isize),
387 step: 1,
388 })
389 .collect();
390
391 let chunk_slice_info: Vec<SliceInfoElem> = chunk_start
393 .iter()
394 .zip(&chunk_end)
395 .map(|(&start, &end)| SliceInfoElem::Slice {
396 start: start as isize,
397 end: Some(end as isize),
398 step: 1,
399 })
400 .collect();
401
402 let chunk_dyn = chunk_result.view().into_dyn();
404 let chunk_slice = chunk_dyn.slice(chunk_slice_info.as_slice());
405 let mut output_slice = output.slice_mut(output_slice_info.as_slice());
406 output_slice.assign(&chunk_slice);
407 }
408
409 output
411 .into_dimensionality::<D>()
412 .map_err(|_| NdimageError::DimensionError("Failed to convert output dimension".into()))
413 }
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Chunk sizes respect the minimum, the dimension extents and a loose
    /// upper bound on the chunk volume.
    #[test]
    fn test_calculate_chunk_sizes() {
        let shape = vec![1000, 1000];
        let target_elements = 10000;
        let min_chunk_size = 10;

        let chunk_sizes = calculate_chunk_sizes(&shape, target_elements, min_chunk_size);

        assert_eq!(chunk_sizes.len(), 2);
        assert!(chunk_sizes[0] >= min_chunk_size);
        assert!(chunk_sizes[1] >= min_chunk_size);
        assert!(chunk_sizes[0] <= shape[0]);
        assert!(chunk_sizes[1] <= shape[1]);

        let total_elements: usize = chunk_sizes.iter().product();
        assert!(total_elements <= target_elements * 3);
    }

    /// A 100x100 array split into 50x50 chunks yields four positions, the
    /// first of which extends past its nominal end by the overlap.
    #[test]
    fn test_generate_chunk_positions() {
        let shape = vec![100, 100];
        let chunk_sizes = vec![50, 50];
        let overlap = 5;

        let positions = generate_chunk_positions(&shape, &chunk_sizes, overlap);

        assert_eq!(positions.len(), 4);

        assert_eq!(positions[0].start, vec![0, 0]);
        assert_eq!(positions[0].end, vec![55, 55]);
    }

    /// Processor that returns each chunk unchanged and pastes results back at
    /// their original positions.
    struct IdentityProcessor;

    impl<T, D> ChunkProcessor<T, D> for IdentityProcessor
    where
        T: Clone + Zero,
        D: Dimension,
    {
        fn process_chunk(
            &mut self,
            chunk: ArrayView<T, D>,
            _position: &ChunkPosition,
        ) -> NdimageResult<Array<T, D>> {
            Ok(chunk.to_owned())
        }

        fn required_overlap(&self) -> usize {
            0
        }

        fn combine_chunks(
            &self,
            results: Vec<(Array<T, D>, ChunkPosition)>,
            outputshape: &[usize],
        ) -> NdimageResult<Array<T, D>> {
            use scirs2_core::ndarray::SliceInfoElem;

            let mut output = Array::zeros(IxDyn(outputshape));

            for (chunk, position) in results {
                // Destination range of this chunk inside the output.
                let mut slice_info: Vec<SliceInfoElem> = Vec::new();
                for (&start, &end) in position.start.iter().zip(&position.end) {
                    slice_info.push(SliceInfoElem::Slice {
                        start: start as isize,
                        end: Some(end as isize),
                        step: 1,
                    });
                }

                let mut output_slice = output.slice_mut(slice_info.as_slice());
                output_slice.assign(&chunk.view().into_dyn());
            }

            output
                .into_dimensionality::<D>()
                .map_err(|_| NdimageError::DimensionError("Dimension conversion failed".into()))
        }
    }

    /// Chunked processing with the identity processor reproduces the input.
    #[test]
    fn test_process_chunked_identity() {
        let input = Array2::<f64>::ones((100, 100));
        let mut processor = IdentityProcessor;
        // 800 bytes hold 100 f64 values, forcing the array to be split.
        let config = ChunkConfig {
            chunk_size_bytes: 800,
            overlap: 0,
            min_chunk_size: 10,
            parallel: false,
        };

        let result =
            process_chunked(&input.view(), &mut processor, &config).expect("Operation failed");

        assert_eq!(result.shape(), input.shape());
        assert_eq!(result, input);
    }
}