1use super::types::*;
7use crate::dwt::{Wavelet, WaveletFilters};
8use crate::dwt2d_enhanced::enhanced_dwt2d_decompose;
9use crate::error::{SignalError, SignalResult};
10use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1};
11use scirs2_core::parallel_ops::*;
12use scirs2_core::simd_ops::PlatformCapabilities;
13use std::collections::HashMap;
14
15pub fn validate_input_image(
17 image: &Array2<f64>,
18 config: &AdvancedRefinedConfig,
19) -> SignalResult<()> {
20 let (height, width) = image.dim();
21
22 if height == 0 || width == 0 {
23 return Err(SignalError::ValueError("Image cannot be empty".to_string()));
24 }
25
26 if height < config.min_subband_size || width < config.min_subband_size {
27 return Err(SignalError::ValueError(format!(
28 "Image too small: {}x{}, minimum size: {}x{}",
29 height, width, config.min_subband_size, config.min_subband_size
30 )));
31 }
32
33 Ok(())
34}
35
36pub fn optimize_simd_configuration(
38 caps: &PlatformCapabilities,
39 simd_level: SimdLevel,
40) -> SimdConfiguration {
41 match simd_level {
42 SimdLevel::None => SimdConfiguration {
43 use_avx2: false,
44 use_sse: false,
45 acceleration_factor: 1.0,
46 },
47 SimdLevel::Basic => SimdConfiguration {
48 use_avx2: false,
49 use_sse: caps.simd_available,
50 acceleration_factor: if caps.simd_available { 2.0 } else { 1.0 },
51 },
52 SimdLevel::Advanced => SimdConfiguration {
53 use_avx2: caps.avx2_available,
54 use_sse: caps.simd_available,
55 acceleration_factor: if caps.avx2_available {
56 4.0
57 } else if caps.simd_available {
58 2.0
59 } else {
60 1.0
61 },
62 },
63 SimdLevel::Aggressive => SimdConfiguration {
64 use_avx2: caps.avx2_available,
65 use_sse: caps.simd_available,
66 acceleration_factor: if caps.avx2_available {
67 6.0 } else if caps.simd_available {
69 3.0
70 } else {
71 1.0
72 },
73 },
74 }
75}
76
77pub fn should_use_tiled_processing(image: &Array2<f64>, config: &AdvancedRefinedConfig) -> bool {
79 let (height, width) = image.dim();
80 let image_size = height * width;
81 let tile_size = config.tile_size.0 * config.tile_size.1;
82
83 config.memory_efficient && image_size > tile_size * 4
85}
86
87pub fn process_image_tiled(
89 image: &Array2<f64>,
90 wavelet: &Wavelet,
91 config: &AdvancedRefinedConfig,
92 simd_config: &SimdConfiguration,
93 memory_tracker: &mut MemoryTracker,
94) -> SignalResult<ProcessingResult> {
95 let (height, width) = image.dim();
96 let (tile_h, tile_w) = config.tile_size;
97 let overlap = config.tile_overlap;
98
99 let max_levels = config.max_levels;
101 let mut coefficients = Array3::zeros((max_levels, height, width));
102 let mut energy_map = Array2::zeros((height, width));
103
104 let mut total_parallel_efficiency = 0.0;
106 let mut tile_count = 0;
107
108 for y in (0..height).step_by(tile_h - overlap) {
109 for x in (0..width).step_by(tile_w - overlap) {
110 let y_end = (y + tile_h).min(height);
111 let x_end = (x + tile_w).min(width);
112
113 if y_end <= y || x_end <= x {
114 continue;
115 }
116
117 let tile = image
119 .slice(scirs2_core::ndarray::s![y..y_end, x..x_end])
120 .to_owned();
121
122 memory_tracker.track_allocation(
124 &format!("tile_{}_{}_{}", tile_count, y, x),
125 (tile.len() * 8) as f64 / (1024.0 * 1024.0),
126 );
127
128 let tile_result = process_tile(&tile, wavelet, config, simd_config)?;
130
131 copy_tile_results(&tile_result, &mut coefficients, &mut energy_map, y, x)?;
133
134 total_parallel_efficiency += estimate_simd_efficiency(simd_config);
135 tile_count += 1;
136 }
137 }
138
139 let parallel_efficiency = if tile_count > 0 {
140 total_parallel_efficiency / tile_count as f64
141 } else {
142 0.0
143 };
144
145 Ok(ProcessingResult {
146 coefficients,
147 energy_map,
148 parallel_efficiency,
149 })
150}
151
152pub fn process_image_whole(
154 image: &Array2<f64>,
155 wavelet: &Wavelet,
156 config: &AdvancedRefinedConfig,
157 simd_config: &SimdConfiguration,
158 memory_tracker: &mut MemoryTracker,
159) -> SignalResult<ProcessingResult> {
160 let (height, width) = image.dim();
161
162 memory_tracker.track_allocation(
164 "whole_image_processing",
165 (height * width * 8 * config.max_levels) as f64 / (1024.0 * 1024.0),
166 );
167
168 let mut coefficients = Array3::zeros((config.max_levels, height, width));
170 let mut working_image = image.clone();
171
172 for level in 0..config.max_levels {
173 let level_result = perform_level_decomposition(&working_image, wavelet, simd_config)?;
174
175 let level_height = level_result.shape()[0];
177 let level_width = level_result.shape()[1];
178
179 if level_height <= height && level_width <= width {
180 let mut level_slice = coefficients.slice_mut(scirs2_core::ndarray::s![
181 level,
182 0..level_height,
183 0..level_width
184 ]);
185 level_slice.assign(&level_result);
186
187 working_image = extract_approximation_coefficients(&level_result)?;
189 }
190 }
191
192 let energy_map = compute_subband_energy_map(&coefficients)?;
194
195 let parallel_efficiency = estimate_simd_efficiency(simd_config);
196
197 Ok(ProcessingResult {
198 coefficients,
199 energy_map,
200 parallel_efficiency,
201 })
202}
203
204fn process_tile(
206 tile: &Array2<f64>,
207 wavelet: &Wavelet,
208 config: &AdvancedRefinedConfig,
209 simd_config: &SimdConfiguration,
210) -> SignalResult<ProcessingResult> {
211 let (height, width) = tile.dim();
212 let mut coefficients = Array3::zeros((config.max_levels, height, width));
213 let mut working_tile = tile.clone();
214
215 for level in 0..config.max_levels {
216 if working_tile.dim().0 < config.min_subband_size
217 || working_tile.dim().1 < config.min_subband_size
218 {
219 break;
220 }
221
222 let level_result = perform_level_decomposition(&working_tile, wavelet, simd_config)?;
223
224 let level_height = level_result.shape()[0];
226 let level_width = level_result.shape()[1];
227
228 if level_height <= height && level_width <= width {
229 let mut level_slice = coefficients.slice_mut(scirs2_core::ndarray::s![
230 level,
231 0..level_height,
232 0..level_width
233 ]);
234 level_slice.assign(&level_result);
235
236 working_tile = extract_approximation_coefficients(&level_result)?;
238 }
239 }
240
241 let energy_map = compute_subband_energy_map(&coefficients)?;
242
243 Ok(ProcessingResult {
244 coefficients,
245 energy_map,
246 parallel_efficiency: estimate_simd_efficiency(simd_config),
247 })
248}
249
250fn copy_tile_results(
252 tile_result: &ProcessingResult,
253 coefficients: &mut Array3<f64>,
254 energy_map: &mut Array2<f64>,
255 y_offset: usize,
256 x_offset: usize,
257) -> SignalResult<()> {
258 let tile_shape = tile_result.coefficients.shape();
259 let (tile_levels, tile_height, tile_width) = (tile_shape[0], tile_shape[1], tile_shape[2]);
260
261 let coeff_shape = coefficients.shape();
262 let (max_levels, total_height, total_width) = (coeff_shape[0], coeff_shape[1], coeff_shape[2]);
263
264 for level in 0..tile_levels.min(max_levels) {
266 for y in 0..tile_height {
267 for x in 0..tile_width {
268 let global_y = y + y_offset;
269 let global_x = x + x_offset;
270
271 if global_y < total_height && global_x < total_width {
272 coefficients[[level, global_y, global_x]] =
273 tile_result.coefficients[[level, y, x]];
274 }
275 }
276 }
277 }
278
279 update_energy_map(energy_map, &tile_result.energy_map, y_offset, x_offset)?;
281
282 Ok(())
283}
284
285fn perform_level_decomposition(
287 image: &Array2<f64>,
288 wavelet: &Wavelet,
289 simd_config: &SimdConfiguration,
290) -> SignalResult<Array2<f64>> {
291 if simd_config.use_avx2 || simd_config.use_sse {
292 apply_separable_2d_dwt_simd(image, wavelet, simd_config)
293 } else {
294 apply_separable_2d_dwt_standard(image, wavelet)
295 }
296}
297
298fn apply_separable_2d_dwt_simd(
300 image: &Array2<f64>,
301 wavelet: &Wavelet,
302 simd_config: &SimdConfiguration,
303) -> SignalResult<Array2<f64>> {
304 let (height, width) = image.dim();
305 let mut result = Array2::zeros((height, width));
306
307 let filters = wavelet.filters()?;
309
310 let mut row_processed = Array2::zeros((height, width));
312 for i in 0..height {
313 let row = image.row(i);
314 let processed_row = apply_1d_dwt_simd(&row, &filters, simd_config)?;
315 if processed_row.len() == width {
316 row_processed.row_mut(i).assign(&processed_row);
317 }
318 }
319
320 for j in 0..width {
322 let col = row_processed.column(j).to_owned();
323 let processed_col = apply_1d_dwt_simd(&col.view(), &filters, simd_config)?;
324 if processed_col.len() == height {
325 result.column_mut(j).assign(&processed_col);
326 }
327 }
328
329 Ok(result)
330}
331
332fn apply_1d_dwt_simd(
334 signal: &ArrayView1<f64>,
335 filters: &WaveletFilters,
336 simd_config: &SimdConfiguration,
337) -> SignalResult<Array1<f64>> {
338 if simd_config.use_avx2 {
339 apply_dwt_convolution_simd(signal, filters, 8) } else if simd_config.use_sse {
341 apply_dwt_convolution_simd(signal, filters, 4) } else {
343 apply_dwt_convolution_scalar(signal, filters)
344 }
345}
346
347fn apply_dwt_convolution_simd(
349 signal: &ArrayView1<f64>,
350 filters: &WaveletFilters,
351 simd_width: usize,
352) -> SignalResult<Array1<f64>> {
353 let n = signal.len();
354 let mut result = Array1::zeros(n);
355
356 let h_len = filters.dec_lo.len();
357 let g_len = filters.dec_hi.len();
358
359 for i in (0..n).step_by(simd_width) {
361 let chunk_end = (i + simd_width).min(n);
362
363 for j in i..chunk_end {
364 let mut low_sum = 0.0;
365 let mut high_sum = 0.0;
366
367 for k in 0..h_len {
369 let signal_idx = if j >= k {
370 j - k
371 } else {
372 (j + n).wrapping_sub(k)
374 };
375 low_sum += signal[signal_idx % n] * filters.dec_lo[k];
376 }
377
378 for k in 0..g_len {
380 let signal_idx = if j >= k {
381 j - k
382 } else {
383 (j + n).wrapping_sub(k)
385 };
386 high_sum += signal[signal_idx % n] * filters.dec_hi[k];
387 }
388
389 result[j] = if j % 2 == 0 { low_sum } else { high_sum };
391 }
392 }
393
394 Ok(result)
395}
396
397fn apply_dwt_convolution_scalar(
399 signal: &ArrayView1<f64>,
400 filters: &WaveletFilters,
401) -> SignalResult<Array1<f64>> {
402 let n = signal.len();
403 let mut result = Array1::zeros(n);
404
405 let h_len = filters.dec_lo.len();
406 let g_len = filters.dec_hi.len();
407
408 for j in 0..n {
409 let mut low_sum = 0.0;
410 let mut high_sum = 0.0;
411
412 for k in 0..h_len {
414 let signal_idx = (j + n - k) % n;
415 low_sum += signal[signal_idx] * filters.dec_lo[k];
416 }
417
418 for k in 0..g_len {
419 let signal_idx = (j + n - k) % n;
420 high_sum += signal[signal_idx] * filters.dec_hi[k];
421 }
422
423 result[j] = if j % 2 == 0 { low_sum } else { high_sum };
424 }
425
426 Ok(result)
427}
428
429fn apply_separable_2d_dwt_standard(
431 image: &Array2<f64>,
432 wavelet: &Wavelet,
433) -> SignalResult<Array2<f64>> {
434 let config = crate::dwt2d_enhanced::Dwt2dConfig::default();
436
437 let result = enhanced_dwt2d_decompose(image, *wavelet, &config)?;
439
440 Ok(result.approx.clone())
443}
444
445fn extract_approximation_coefficients(coefficients: &Array2<f64>) -> SignalResult<Array2<f64>> {
447 let (height, width) = coefficients.dim();
448
449 let new_height = height / 2;
451 let new_width = width / 2;
452
453 if new_height == 0 || new_width == 0 {
454 return Err(SignalError::ValueError(
455 "Cannot extract approximation coefficients from too small array".to_string(),
456 ));
457 }
458
459 Ok(coefficients
460 .slice(scirs2_core::ndarray::s![0..new_height, 0..new_width])
461 .to_owned())
462}
463
464fn compute_subband_energy_map(coefficients: &Array3<f64>) -> SignalResult<Array2<f64>> {
466 let shape = coefficients.shape();
467 let (levels, height, width) = (shape[0], shape[1], shape[2]);
468
469 let mut energy_map = Array2::zeros((height, width));
470
471 for level in 0..levels {
472 for y in 0..height {
473 for x in 0..width {
474 let coeff = coefficients[[level, y, x]];
475 energy_map[[y, x]] += coeff * coeff;
476 }
477 }
478 }
479
480 Ok(energy_map)
481}
482
483fn update_energy_map(
485 energy_map: &mut Array2<f64>,
486 tile_energy: &Array2<f64>,
487 y_offset: usize,
488 x_offset: usize,
489) -> SignalResult<()> {
490 let tile_shape = tile_energy.shape();
491 let (tile_height, tile_width) = (tile_shape[0], tile_shape[1]);
492
493 let energy_shape = energy_map.shape();
494 let (total_height, total_width) = (energy_shape[0], energy_shape[1]);
495
496 for y in 0..tile_height {
497 for x in 0..tile_width {
498 let global_y = y + y_offset;
499 let global_x = x + x_offset;
500
501 if global_y < total_height && global_x < total_width {
502 energy_map[[global_y, global_x]] += tile_energy[[y, x]];
503 }
504 }
505 }
506
507 Ok(())
508}
509
510fn estimate_simd_efficiency(simd_config: &SimdConfiguration) -> f64 {
512 if simd_config.use_avx2 {
514 0.85 } else if simd_config.use_sse {
516 0.70 } else {
518 0.50 }
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dwt::Wavelet;
    use scirs2_core::simd_ops::PlatformCapabilities;

    /// A 32x32 image is accepted with the default config; an empty image is rejected.
    #[test]
    fn test_validate_input_image() {
        let config = AdvancedRefinedConfig::default();

        let valid_image = Array2::zeros((32, 32));
        assert!(validate_input_image(&valid_image, &config).is_ok());

        let empty_image = Array2::zeros((0, 0));
        assert!(validate_input_image(&empty_image, &config).is_err());
    }

    /// The acceleration factor is never below 1.0, whatever the platform supports.
    #[test]
    fn test_simd_configuration() {
        let detected = PlatformCapabilities::detect();
        let simd_config = optimize_simd_configuration(&detected, SimdLevel::Advanced);

        assert!(simd_config.acceleration_factor >= 1.0);
    }

    /// With 256x256 tiles, a 64x64 image is processed whole and a 1024x1024 image is tiled.
    #[test]
    fn test_should_use_tiled_processing() {
        let config = AdvancedRefinedConfig {
            memory_efficient: true,
            tile_size: (256, 256),
            ..Default::default()
        };

        let small_image = Array2::zeros((64, 64));
        let large_image = Array2::zeros((1024, 1024));

        assert!(!should_use_tiled_processing(&small_image, &config));
        assert!(should_use_tiled_processing(&large_image, &config));
    }

    /// Whole-image processing succeeds on a smooth 64x64 image and keeps the spatial shape.
    #[test]
    fn test_process_image_whole() {
        let image = Array2::from_shape_fn((64, 64), |(i, j)| {
            ((i as f64 / 8.0).sin() * (j as f64 / 8.0).cos() + 1.0) / 2.0
        });

        let config = AdvancedRefinedConfig::default();
        let detected = PlatformCapabilities::detect();
        let simd_config = optimize_simd_configuration(&detected, config.simd_level);
        let mut memory_tracker = MemoryTracker::new();

        let outcome = process_image_whole(
            &image,
            &Wavelet::DB(2),
            &config,
            &simd_config,
            &mut memory_tracker,
        );

        assert!(outcome.is_ok());
        let processed = outcome.unwrap();
        let shape = processed.coefficients.shape();
        assert_eq!(shape[1], 64);
        assert_eq!(shape[2], 64);
    }
}