1use crate::error::{SpecialError, SpecialResult};
8use scirs2_core::ndarray::{Array, ArrayView, ArrayViewMut, Ix1};
9use scirs2_core::numeric::Float;
10use std::marker::PhantomData;
11
/// Tuning knobs for [`ChunkedProcessor`].
#[derive(Debug, Clone)]
pub struct ChunkedConfig {
    /// Upper bound, in bytes, on the data handed to one `apply_chunk` call.
    /// The processor divides this by `size_of::<T>()` to get an element count.
    pub max_chunk_bytes: usize,
    /// When `true`, chunks are processed in parallel if the crate is built with
    /// the `parallel` feature; otherwise (or without the feature) they are
    /// processed one after another.
    pub parallel_chunks: bool,
    /// Arrays with fewer elements than this are processed in a single call,
    /// without chunking.
    pub min_arraysize: usize,
    /// Prefetch hint. NOTE(review): not read anywhere in this module — confirm
    /// whether another component consults it.
    pub prefetch: bool,
}

impl Default for ChunkedConfig {
    /// 64 MiB chunks, parallel processing enabled, chunking only for arrays of
    /// at least 100 000 elements, prefetch enabled.
    fn default() -> Self {
        const SIXTY_FOUR_MIB: usize = 64 << 20;

        ChunkedConfig {
            max_chunk_bytes: SIXTY_FOUR_MIB,
            parallel_chunks: true,
            min_arraysize: 100_000,
            prefetch: true,
        }
    }
}
36
37pub trait ChunkableFunction<T> {
39 fn apply_chunk(
41 &self,
42 input: &ArrayView<T, Ix1>,
43 output: &mut ArrayViewMut<T, Ix1>,
44 ) -> SpecialResult<()>;
45
46 fn name(&self) -> &str;
48}
49
50pub struct ChunkedProcessor<T, F> {
52 config: ChunkedConfig,
53 function: F,
54 _phantom: PhantomData<T>,
55}
56
57impl<T, F> ChunkedProcessor<T, F>
58where
59 T: Float + Send + Sync,
60 F: ChunkableFunction<T> + Send + Sync,
61{
62 pub fn new(config: ChunkedConfig, function: F) -> Self {
64 Self {
65 config,
66 function,
67 _phantom: PhantomData,
68 }
69 }
70
71 fn calculate_chunksize(&self, totalelements: usize) -> usize {
73 let elementsize = std::mem::size_of::<T>();
74 let max_elements = self.config.max_chunk_bytes / elementsize;
75
76 if totalelements < self.config.min_arraysize {
78 return totalelements;
79 }
80
81 let ideal_chunk = max_elements.min(totalelements);
83
84 for divisor in 1..=100 {
86 let chunksize = totalelements / divisor;
87 if chunksize <= ideal_chunk && totalelements.is_multiple_of(divisor) {
88 return chunksize;
89 }
90 }
91
92 ideal_chunk
93 }
94
95 pub fn process_1d(
97 &self,
98 input: &Array<T, Ix1>,
99 output: &mut Array<T, Ix1>,
100 ) -> SpecialResult<()> {
101 if input.len() != output.len() {
102 return Err(SpecialError::ValueError(
103 "Input and output arrays must have the same length".to_string(),
104 ));
105 }
106
107 let totalelements = input.len();
108 let chunksize = self.calculate_chunksize(totalelements);
109
110 if chunksize == totalelements {
111 self.function
113 .apply_chunk(&input.view(), &mut output.view_mut())?;
114 return Ok(());
115 }
116
117 if self.config.parallel_chunks {
119 self.process_chunks_parallel(input, output, chunksize)
120 } else {
121 self.process_chunks_sequential(input, output, chunksize)
122 }
123 }
124
125 fn process_chunks_sequential(
127 &self,
128 input: &Array<T, Ix1>,
129 output: &mut Array<T, Ix1>,
130 chunksize: usize,
131 ) -> SpecialResult<()> {
132 let totalelements = input.len();
133 let mut offset = 0;
134
135 while offset < totalelements {
136 let end = (offset + chunksize).min(totalelements);
137 let input_chunk = input.slice(scirs2_core::ndarray::s![offset..end]);
138 let mut output_chunk = output.slice_mut(scirs2_core::ndarray::s![offset..end]);
139
140 self.function.apply_chunk(&input_chunk, &mut output_chunk)?;
141
142 offset = end;
143 }
144
145 Ok(())
146 }
147
148 #[cfg(feature = "parallel")]
150 fn process_chunks_parallel(
151 &self,
152 input: &Array<T, Ix1>,
153 output: &mut Array<T, Ix1>,
154 chunksize: usize,
155 ) -> SpecialResult<()> {
156 use scirs2_core::parallel_ops::*;
157
158 let totalelements = input.len();
159 let num_chunks = (totalelements + chunksize - 1) / chunksize;
160
161 let chunks: Vec<(usize, usize)> = (0..num_chunks)
163 .map(|i| {
164 let start = i * chunksize;
165 let end = ((i + 1) * chunksize).min(totalelements);
166 (start, end)
167 })
168 .collect();
169
170 use scirs2_core::parallel_ops::IndexedParallelIterator;
173
174 let results: Vec<_> = chunks
175 .par_iter()
176 .enumerate()
177 .map(|(idx, (start, end))| {
178 let input_chunk = input.slice(scirs2_core::ndarray::s![*start..*end]);
179 let mut temp_output = Array::zeros(end - start);
180 let mut temp_view = temp_output.view_mut();
181
182 match self.function.apply_chunk(&input_chunk, &mut temp_view) {
183 Ok(_) => Ok((idx, temp_output)),
184 Err(e) => Err(e),
185 }
186 })
187 .collect();
188
189 for result in results {
191 match result {
192 Ok((idx, temp_output)) => {
193 let (start, end) = chunks[idx];
194 output
195 .slice_mut(scirs2_core::ndarray::s![start..end])
196 .assign(&temp_output);
197 }
198 Err(e) => return Err(e),
199 }
200 }
201
202 Ok(())
203 }
204
205 #[cfg(not(feature = "parallel"))]
206 fn process_chunks_parallel(
207 &self,
208 input: &Array<T, Ix1>,
209 output: &mut Array<T, Ix1>,
210 chunksize: usize,
211 ) -> SpecialResult<()> {
212 self.process_chunks_sequential(input, output, chunksize)
214 }
215}
216
217pub struct ChunkedGamma;
219
220impl Default for ChunkedGamma {
221 fn default() -> Self {
222 Self::new()
223 }
224}
225
226impl ChunkedGamma {
227 pub fn new() -> Self {
228 Self
229 }
230}
231
232impl<T> ChunkableFunction<T> for ChunkedGamma
233where
234 T: Float + scirs2_core::numeric::FromPrimitive + std::fmt::Debug + std::ops::AddAssign,
235{
236 fn apply_chunk(
237 &self,
238 input: &ArrayView<T, Ix1>,
239 output: &mut ArrayViewMut<T, Ix1>,
240 ) -> SpecialResult<()> {
241 use crate::gamma::gamma;
242
243 for (inp, out) in input.iter().zip(output.iter_mut()) {
244 *out = gamma(*inp);
245 }
246 Ok(())
247 }
248
249 fn name(&self) -> &str {
250 "gamma"
251 }
252}
253
254pub struct ChunkedBesselJ0;
256
257impl Default for ChunkedBesselJ0 {
258 fn default() -> Self {
259 Self::new()
260 }
261}
262
263impl ChunkedBesselJ0 {
264 pub fn new() -> Self {
265 Self
266 }
267}
268
269impl<T> ChunkableFunction<T> for ChunkedBesselJ0
270where
271 T: Float + scirs2_core::numeric::FromPrimitive + std::fmt::Debug,
272{
273 fn apply_chunk(
274 &self,
275 input: &ArrayView<T, Ix1>,
276 output: &mut ArrayViewMut<T, Ix1>,
277 ) -> SpecialResult<()> {
278 use crate::bessel::j0;
279
280 for (inp, out) in input.iter().zip(output.iter_mut()) {
281 *out = j0(*inp);
282 }
283 Ok(())
284 }
285
286 fn name(&self) -> &str {
287 "bessel_j0"
288 }
289}
290
291pub struct ChunkedErf;
293
294impl Default for ChunkedErf {
295 fn default() -> Self {
296 Self::new()
297 }
298}
299
300impl ChunkedErf {
301 pub fn new() -> Self {
302 Self
303 }
304}
305
306impl<T> ChunkableFunction<T> for ChunkedErf
307where
308 T: Float + scirs2_core::numeric::FromPrimitive,
309{
310 fn apply_chunk(
311 &self,
312 input: &ArrayView<T, Ix1>,
313 output: &mut ArrayViewMut<T, Ix1>,
314 ) -> SpecialResult<()> {
315 use crate::erf::erf;
316
317 for (inp, out) in input.iter().zip(output.iter_mut()) {
318 *out = erf(*inp);
319 }
320 Ok(())
321 }
322
323 fn name(&self) -> &str {
324 "erf"
325 }
326}
327
328#[allow(dead_code)]
331pub fn gamma_chunked<T>(
332 input: &Array<T, Ix1>,
333 config: Option<ChunkedConfig>,
334) -> SpecialResult<Array<T, Ix1>>
335where
336 T: Float
337 + scirs2_core::numeric::FromPrimitive
338 + std::fmt::Debug
339 + std::ops::AddAssign
340 + Send
341 + Sync,
342{
343 let config = config.unwrap_or_default();
344 let processor = ChunkedProcessor::new(config, ChunkedGamma::new());
345 let mut output = Array::zeros(input.raw_dim());
346 processor.process_1d(input, &mut output)?;
347 Ok(output)
348}
349
350#[allow(dead_code)]
352pub fn j0_chunked<T>(
353 input: &Array<T, Ix1>,
354 config: Option<ChunkedConfig>,
355) -> SpecialResult<Array<T, Ix1>>
356where
357 T: Float + scirs2_core::numeric::FromPrimitive + std::fmt::Debug + Send + Sync,
358{
359 let config = config.unwrap_or_default();
360 let processor = ChunkedProcessor::new(config, ChunkedBesselJ0::new());
361 let mut output = Array::zeros(input.raw_dim());
362 processor.process_1d(input, &mut output)?;
363 Ok(output)
364}
365
366#[allow(dead_code)]
368pub fn erf_chunked<T>(
369 input: &Array<T, Ix1>,
370 config: Option<ChunkedConfig>,
371) -> SpecialResult<Array<T, Ix1>>
372where
373 T: Float + scirs2_core::numeric::FromPrimitive + Send + Sync,
374{
375 let config = config.unwrap_or_default();
376 let processor = ChunkedProcessor::new(config, ChunkedErf::new());
377 let mut output = Array::zeros(input.raw_dim());
378 processor.process_1d(input, &mut output)?;
379 Ok(output)
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gamma::gamma;
    use scirs2_core::ndarray::Array1;

    /// Checks `result` element-by-element against a direct `gamma` evaluation
    /// of `input`, to within 1e-10.
    fn assert_close_to_gamma(input: &Array1<f64>, result: &Array1<f64>) {
        assert_eq!(result.len(), input.len());
        for (&x, &got) in input.iter().zip(result.iter()) {
            assert!((got - gamma(x)).abs() < 1e-10);
        }
    }

    #[test]
    fn test_chunksize_calculation() {
        let processor: ChunkedProcessor<f64, ChunkedGamma> =
            ChunkedProcessor::new(ChunkedConfig::default(), ChunkedGamma::new());

        // Below the default `min_arraysize`, the whole array is one chunk.
        assert_eq!(processor.calculate_chunksize(1000), 1000);

        // Large arrays are split into non-empty chunks.
        let big = 10_000_000;
        let chunksize = processor.calculate_chunksize(big);
        assert!(chunksize < big);
        assert!(chunksize > 0);
    }

    #[test]
    fn test_gamma_chunked() {
        let input = Array1::linspace(0.1, 5.0, 1000);
        let result = gamma_chunked(&input, None).expect("Operation failed");
        assert_close_to_gamma(&input, &result);
    }

    #[test]
    fn test_chunked_with_custom_config() {
        let config = ChunkedConfig {
            max_chunk_bytes: 1024,
            parallel_chunks: false,
            min_arraysize: 10,
            prefetch: false,
        };

        let input = Array1::linspace(0.1, 5.0, 100);
        let result = gamma_chunked(&input, Some(config)).expect("Operation failed");
        assert_close_to_gamma(&input, &result);
    }
}