1use crate::{next_fast_len, FFTResult};
8use scirs2_core::ndarray::{s, Array1, ArrayBase, ArrayD, Data, Dimension};
9use scirs2_core::numeric::Complex;
10use scirs2_core::numeric::Zero;
11
/// Strategy used to fill the samples that the `auto_pad_*` functions add
/// around a signal.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PaddingMode {
    /// No dedicated fill; the padding functions in this module still pad and
    /// treat this exactly like [`PaddingMode::Zero`].
    None,
    /// Fill the padding with zeros.
    Zero,
    /// Fill the padding with a constant (used as the real part for complex
    /// data; where a function cannot represent it, zeros are used).
    Constant(f64),
    /// Repeat the nearest edge sample.
    Edge,
    /// Mirror-image extension, intended to follow NumPy's `reflect` mode.
    Reflect,
    /// Mirror-image extension including the edge sample, intended to follow
    /// NumPy's `symmetric` mode.
    Symmetric,
    /// Periodic continuation of the signal.
    Wrap,
    /// Linear ramp between the edge sample and zero (where supported).
    LinearRamp,
}
32
33#[derive(Debug, Clone)]
35pub struct AutoPadConfig {
36 pub mode: PaddingMode,
38 pub min_pad: usize,
40 pub max_pad: Option<usize>,
42 pub power_of_2: bool,
44 pub center: bool,
46}
47
48impl Default for AutoPadConfig {
49 fn default() -> Self {
50 Self {
51 mode: PaddingMode::Zero,
52 min_pad: 0,
53 max_pad: None,
54 power_of_2: false,
55 center: false,
56 }
57 }
58}
59
60impl AutoPadConfig {
61 pub fn new(mode: PaddingMode) -> Self {
63 Self {
64 mode,
65 ..Default::default()
66 }
67 }
68
69 pub fn with_min_pad(mut self, minpad: usize) -> Self {
71 self.min_pad = minpad;
72 self
73 }
74
75 pub fn with_max_pad(mut self, maxpad: usize) -> Self {
77 self.max_pad = Some(maxpad);
78 self
79 }
80
81 pub fn with_power_of_2(mut self) -> Self {
83 self.power_of_2 = true;
84 self
85 }
86
87 pub fn with_center(mut self) -> Self {
89 self.center = true;
90 self
91 }
92}
93
94#[allow(dead_code)]
96pub fn auto_pad_1d<T>(x: &Array1<T>, config: &AutoPadConfig) -> FFTResult<Array1<T>>
97where
98 T: Clone + Zero,
99{
100 let n = x.len();
101
102 let target_size = if config.power_of_2 {
104 let min_size = n + config.min_pad;
106 let mut size = 1;
107 while size < min_size {
108 size *= 2;
109 }
110 size
111 } else {
112 next_fast_len(n + config.min_pad, false)
114 };
115
116 let padded_size = if let Some(max_pad) = config.max_pad {
118 target_size.min(n + max_pad)
119 } else {
120 target_size
121 };
122
123 if padded_size == n {
125 return Ok(x.clone());
126 }
127
128 let mut padded = Array1::zeros(padded_size);
130
131 let start_idx = if config.center {
133 (padded_size - n) / 2
134 } else {
135 0
136 };
137
138 padded.slice_mut(s![start_idx..start_idx + n]).assign(x);
140
141 match config.mode {
143 PaddingMode::None | PaddingMode::Zero => {
144 }
146 PaddingMode::Constant(_value) => {
147 let const_val = T::zero(); if start_idx > 0 {
149 padded.slice_mut(s![..start_idx]).fill(const_val.clone());
150 }
151 if start_idx + n < padded_size {
152 padded.slice_mut(s![start_idx + n..]).fill(const_val);
153 }
154 }
155 PaddingMode::Edge => {
156 if start_idx > 0 {
158 let left_val = x[0].clone();
159 padded.slice_mut(s![..start_idx]).fill(left_val);
160 }
161 if start_idx + n < padded_size {
162 let right_val = x[n - 1].clone();
163 padded.slice_mut(s![start_idx + n..]).fill(right_val);
164 }
165 }
166 PaddingMode::Reflect => {
167 for i in 0..start_idx {
169 let offset = start_idx - i - 1;
170 let cycle = 2 * (n - 1);
171 let src_idx = offset % cycle;
172 let src_idx = if src_idx >= n {
173 cycle - src_idx
174 } else {
175 src_idx
176 };
177 padded[i] = x[src_idx].clone();
178 }
179 for i in (start_idx + n)..padded_size {
180 let offset = i - (start_idx + n);
181 let cycle = 2 * (n - 1);
182 let src_idx = (n - 1 - (offset % cycle)).max(0);
183 padded[i] = x[src_idx].clone();
184 }
185 }
186 PaddingMode::Symmetric => {
187 for i in 0..start_idx {
189 let offset = start_idx - i;
190 let cycle = 2 * n;
191 let src_idx = (offset - 1) % cycle;
192 let src_idx = if src_idx >= n {
193 cycle - 1 - src_idx
194 } else {
195 src_idx
196 };
197 padded[i] = x[src_idx].clone();
198 }
199 for i in (start_idx + n)..padded_size {
200 let offset = i - (start_idx + n);
201 let cycle = 2 * n;
202 let src_idx = (n - 1 - (offset % cycle)).max(0);
203 padded[i] = x[src_idx].clone();
204 }
205 }
206 PaddingMode::Wrap => {
207 for i in 0..start_idx {
209 let src_idx = (n - (start_idx - i) % n) % n;
210 padded[i] = x[src_idx].clone();
211 }
212 for i in (start_idx + n)..padded_size {
213 let src_idx = (i - start_idx) % n;
214 padded[i] = x[src_idx].clone();
215 }
216 }
217 PaddingMode::LinearRamp => {
218 if start_idx > 0 {
220 for i in 0..start_idx {
221 padded[i] = T::zero();
223 }
224 }
225 if start_idx + n < padded_size {
226 for i in (start_idx + n)..padded_size {
227 padded[i] = T::zero();
229 }
230 }
231 }
232 }
233
234 Ok(padded)
235}
236
237#[allow(dead_code)]
239pub fn auto_pad_complex(
240 x: &Array1<Complex<f64>>,
241 config: &AutoPadConfig,
242) -> FFTResult<Array1<Complex<f64>>> {
243 let n = x.len();
244
245 let target_size = if config.power_of_2 {
247 let min_size = n + config.min_pad;
248 let mut size = 1;
249 while size < min_size {
250 size *= 2;
251 }
252 size
253 } else {
254 next_fast_len(n + config.min_pad, false)
255 };
256
257 let padded_size = if let Some(max_pad) = config.max_pad {
259 target_size.min(n + max_pad)
260 } else {
261 target_size
262 };
263
264 if padded_size == n {
265 return Ok(x.clone());
266 }
267
268 let mut padded = Array1::zeros(padded_size);
269 let start_idx = if config.center {
270 (padded_size - n) / 2
271 } else {
272 0
273 };
274
275 padded.slice_mut(s![start_idx..start_idx + n]).assign(x);
276
277 match config.mode {
279 PaddingMode::None | PaddingMode::Zero => {}
280 PaddingMode::Constant(value) => {
281 let const_val = Complex::new(value, 0.0);
282 if start_idx > 0 {
283 padded.slice_mut(s![..start_idx]).fill(const_val);
284 }
285 if start_idx + n < padded_size {
286 padded.slice_mut(s![start_idx + n..]).fill(const_val);
287 }
288 }
289 PaddingMode::Edge => {
290 if start_idx > 0 {
291 let left_val = x[0];
292 padded.slice_mut(s![..start_idx]).fill(left_val);
293 }
294 if start_idx + n < padded_size {
295 let right_val = x[n - 1];
296 padded.slice_mut(s![start_idx + n..]).fill(right_val);
297 }
298 }
299 PaddingMode::LinearRamp => {
300 if start_idx > 0 {
302 let edge_val = x[0];
303 for i in 0..start_idx {
304 let t = i as f64 / start_idx as f64;
305 padded[start_idx - 1 - i] = edge_val * t;
306 }
307 }
308 if start_idx + n < padded_size {
309 let edge_val = x[n - 1];
310 let pad_len = padded_size - (start_idx + n);
311 for i in 0..pad_len {
312 let t = 1.0 - (i as f64 / pad_len as f64);
313 padded[start_idx + n + i] = edge_val * t;
314 }
315 }
316 }
317 _ => {
318 return auto_pad_1d(x, config);
320 }
321 }
322
323 Ok(padded)
324}
325
326#[allow(dead_code)]
328pub fn remove_padding_1d<T>(
329 padded: &Array1<T>,
330 original_size: usize,
331 config: &AutoPadConfig,
332) -> Array1<T>
333where
334 T: Clone,
335{
336 let padded_size = padded.len();
337
338 if padded_size == original_size {
339 return padded.clone();
340 }
341
342 let start_idx = if config.center {
343 (padded_size - original_size) / 2
344 } else {
345 0
346 };
347
348 padded
349 .slice(s![start_idx..start_idx + original_size])
350 .to_owned()
351}
352
353#[allow(dead_code)]
355pub fn auto_pad_nd<S, D>(
356 x: &ArrayBase<S, D>,
357 config: &AutoPadConfig,
358 axes: Option<&[usize]>,
359) -> FFTResult<ArrayD<Complex<f64>>>
360where
361 S: Data<Elem = Complex<f64>>,
362 D: Dimension,
363{
364 let shape = x.shape();
365 let default_axes = (0..shape.len()).collect::<Vec<_>>();
366 let axes = axes.unwrap_or(&default_axes[..]);
367
368 let mut paddedshape = shape.to_vec();
369
370 for &axis in axes {
372 let n = shape[axis];
373 let target_size = if config.power_of_2 {
374 let min_size = n + config.min_pad;
375 let mut size = 1;
376 while size < min_size {
377 size *= 2;
378 }
379 size
380 } else {
381 next_fast_len(n + config.min_pad, false)
382 };
383
384 paddedshape[axis] = if let Some(max_pad) = config.max_pad {
385 target_size.min(n + max_pad)
386 } else {
387 target_size
388 };
389 }
390
391 let mut padded = ArrayD::zeros(paddedshape.clone());
393
394 let x_dyn = x
396 .to_owned()
397 .into_shape_with_order(x.shape().to_vec())
398 .unwrap()
399 .into_dyn();
400
401 match x_dyn.ndim() {
403 1 => {
404 let start = if config.center {
405 (paddedshape[0] - shape[0]) / 2
406 } else {
407 0
408 };
409 padded.slice_mut(s![start..start + shape[0]]).assign(&x_dyn);
410 }
411 2 => {
412 let start0 = if config.center && axes.contains(&0) {
413 (paddedshape[0] - shape[0]) / 2
414 } else {
415 0
416 };
417 let start1 = if config.center && axes.contains(&1) {
418 (paddedshape[1] - shape[1]) / 2
419 } else {
420 0
421 };
422 padded
423 .slice_mut(s![start0..start0 + shape[0], start1..start1 + shape[1]])
424 .assign(
425 &x_dyn
426 .view()
427 .into_dimensionality::<scirs2_core::ndarray::Ix2>()
428 .unwrap(),
429 );
430 }
431 _ => {
432 return Err(crate::FFTError::ValueError(
435 "auto_pad_nd currently only supports 1D and 2D arrays".to_string(),
436 ));
437 }
438 }
439
440 match config.mode {
442 PaddingMode::None | PaddingMode::Zero => {
443 }
445 PaddingMode::Constant(value) => {
446 let _const_val = Complex::new(value, 0.0);
448 }
450 _ => {
451 }
453 }
454
455 Ok(padded)
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Lifts real values to a complex array with zero imaginary parts.
    fn complex_from(values: &[f64]) -> Array1<Complex<f64>> {
        Array1::from_vec(values.to_vec()).mapv(|v| Complex::new(v, 0.0))
    }

    #[test]
    fn test_auto_pad_zero() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let config = AutoPadConfig::new(PaddingMode::Zero);

        let padded = auto_pad_complex(&x.mapv(|v| Complex::new(v, 0.0)), &config).unwrap();

        assert!(padded.len() >= x.len());

        for i in 0..x.len() {
            assert_abs_diff_eq!(padded[i].re, x[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_auto_pad_power_of_2() {
        let x = Array1::from_vec(vec![1.0; 5]);
        let config = AutoPadConfig::new(PaddingMode::Zero).with_power_of_2();

        let padded = auto_pad_complex(&x.mapv(|v| Complex::new(v, 0.0)), &config).unwrap();

        assert_eq!(padded.len(), 8);
    }

    #[test]
    fn test_remove_padding() {
        let padded = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 0.0, 0.0]);
        let config = AutoPadConfig::new(PaddingMode::Zero);

        let unpadded = remove_padding_1d(&padded, 4, &config);
        assert_eq!(unpadded.len(), 4);
        assert_eq!(unpadded.as_slice().unwrap(), &[0.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_remove_padding_center() {
        let padded = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0, 0.0]);
        let config = AutoPadConfig::new(PaddingMode::Zero).with_center();

        let unpadded = remove_padding_1d(&padded, 4, &config);
        assert_eq!(unpadded.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_pad_then_unpad_round_trip() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let config = AutoPadConfig::new(PaddingMode::Zero)
            .with_center()
            .with_min_pad(4);

        let padded = auto_pad_1d(&x, &config).unwrap();
        let restored = remove_padding_1d(&padded, x.len(), &config);
        assert_eq!(restored, x);
    }

    #[test]
    fn test_auto_pad_center() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let config = AutoPadConfig::new(PaddingMode::Zero)
            .with_center()
            .with_min_pad(3);

        let padded = auto_pad_complex(&x.mapv(|v| Complex::new(v, 0.0)), &config).unwrap();

        assert!(padded.len() >= 6);
        let start = (padded.len() - 3) / 2;
        assert_abs_diff_eq!(padded[start].re, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(padded[start + 1].re, 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(padded[start + 2].re, 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_auto_pad_constant_complex() {
        let config = AutoPadConfig::new(PaddingMode::Constant(7.0)).with_min_pad(2);

        let padded = auto_pad_complex(&complex_from(&[1.0, 2.0, 3.0]), &config).unwrap();

        assert!(padded.len() >= 5);
        for value in padded.iter().skip(3) {
            assert_abs_diff_eq!(value.re, 7.0, epsilon = 1e-12);
            assert_abs_diff_eq!(value.im, 0.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_auto_pad_edge_complex() {
        let config = AutoPadConfig::new(PaddingMode::Edge)
            .with_center()
            .with_min_pad(3);

        let padded = auto_pad_complex(&complex_from(&[1.0, 2.0, 3.0]), &config).unwrap();

        assert!(padded.len() >= 6);
        assert_abs_diff_eq!(padded[0].re, 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(padded[padded.len() - 1].re, 3.0, epsilon = 1e-12);
    }

    #[test]
    fn test_auto_pad_wrap_center() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let config = AutoPadConfig::new(PaddingMode::Wrap)
            .with_center()
            .with_min_pad(3);

        let padded = auto_pad_1d(&x, &config).unwrap();

        let start = (padded.len() - 3) / 2;
        assert!(start >= 1);
        assert_abs_diff_eq!(padded[start - 1], 3.0, epsilon = 1e-12);
        assert_abs_diff_eq!(padded[start + 3], 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_auto_pad_symmetric_center() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let config = AutoPadConfig::new(PaddingMode::Symmetric)
            .with_center()
            .with_min_pad(3);

        let padded = auto_pad_1d(&x, &config).unwrap();

        let start = (padded.len() - 3) / 2;
        assert!(start >= 1);
        // Symmetric padding repeats the edge sample next to the data.
        assert_abs_diff_eq!(padded[start - 1], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(padded[start + 3], 3.0, epsilon = 1e-12);
    }
}