1use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};
35use std::path::PathBuf;
36
37use crate::error::{FFTError, FFTResult};
38use crate::fft::{fft, ifft};
39use scirs2_core::numeric::Complex64;
40
/// Configuration for [`OutOfCoreFft2D`].
///
/// Describes a row-major `rows x cols` real matrix and how the row pass of
/// the transform walks it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OutOfCoreConfig {
    /// Number of rows in the row-major matrix.
    pub rows: usize,
    /// Number of columns in the row-major matrix.
    pub cols: usize,
    /// Rows visited per chunk during the row pass. `OutOfCoreFft2D` clamps
    /// this to at least 1 and at most `rows`; it does not change the output.
    pub chunk_rows: usize,
    /// Directory designated for temporary files.
    ///
    /// NOTE(review): `OutOfCoreFft2D`'s transforms currently run entirely in
    /// memory and never read this field.
    pub temp_dir: PathBuf,
}
58
59pub struct OutOfCoreFft2D {
65 config: OutOfCoreConfig,
66}
67
68impl OutOfCoreFft2D {
73 pub fn new(rows: usize, cols: usize) -> Self {
78 Self {
79 config: OutOfCoreConfig {
80 rows,
81 cols,
82 chunk_rows: rows,
83 temp_dir: std::env::temp_dir(),
84 },
85 }
86 }
87
88 pub fn with_config(config: OutOfCoreConfig) -> Self {
90 Self { config }
91 }
92
93 pub fn fft2d(&self, data: &[f64]) -> FFTResult<Vec<(f64, f64)>> {
104 let rows = self.config.rows;
105 let cols = self.config.cols;
106
107 if data.len() != rows * cols {
108 return Err(FFTError::DimensionError(format!(
109 "outofcore fft2d: data length {} != rows*cols {}",
110 data.len(),
111 rows * cols
112 )));
113 }
114
115 let chunk_rows = self.config.chunk_rows.max(1).min(rows);
116
117 let row_fft_results = self.fft_all_rows(data, rows, cols, chunk_rows)?;
120
121 let col_fft_results = self.fft_all_cols(&row_fft_results, rows, cols, chunk_rows)?;
123
124 Ok(col_fft_results)
125 }
126
127 pub fn ifft2d(&self, spectrum: &[(f64, f64)]) -> FFTResult<Vec<f64>> {
138 let rows = self.config.rows;
139 let cols = self.config.cols;
140
141 if spectrum.len() != rows * cols {
142 return Err(FFTError::DimensionError(format!(
143 "outofcore ifft2d: spectrum length {} != rows*cols {}",
144 spectrum.len(),
145 rows * cols
146 )));
147 }
148
149 let chunk_rows = self.config.chunk_rows.max(1).min(rows);
150
151 let row_ifft = self.ifft_all_rows(spectrum, rows, cols, chunk_rows)?;
153
154 let col_ifft = self.ifft_all_cols(&row_ifft, rows, cols, chunk_rows)?;
156
157 Ok(col_ifft.iter().map(|&(re, _im)| re).collect())
162 }
163
164 fn fft_all_rows(
170 &self,
171 data: &[f64],
172 rows: usize,
173 cols: usize,
174 chunk_rows: usize,
175 ) -> FFTResult<Vec<(f64, f64)>> {
176 let mut row_spectra: Vec<(f64, f64)> = Vec::with_capacity(rows * cols);
177
178 let mut row = 0_usize;
179 while row < rows {
180 let end = (row + chunk_rows).min(rows);
181 for r in row..end {
182 let start = r * cols;
183 let row_data = &data[start..start + cols];
184 let spectrum = fft(row_data, Some(cols))?;
185 for c in spectrum.iter().take(cols) {
186 row_spectra.push((c.re, c.im));
187 }
188 }
189 row = end;
190 }
191
192 Ok(row_spectra)
193 }
194
195 fn fft_all_cols(
201 &self,
202 data: &[(f64, f64)],
203 rows: usize,
204 cols: usize,
205 _chunk_rows: usize,
206 ) -> FFTResult<Vec<(f64, f64)>> {
207 let total = rows * cols;
208 let mut result: Vec<(f64, f64)> = vec![(0.0, 0.0); total];
209
210 for c in 0..cols {
211 let col_data: Vec<Complex64> = (0..rows)
213 .map(|r| {
214 let (re, im) = data[r * cols + c];
215 Complex64::new(re, im)
216 })
217 .collect();
218
219 let col_spectrum = fft(&col_data, Some(rows))?;
222
223 for (r, val) in col_spectrum.iter().take(rows).enumerate() {
225 result[r * cols + c] = (val.re, val.im);
226 }
227 }
228
229 Ok(result)
230 }
231
232 fn ifft_all_rows(
234 &self,
235 spectrum: &[(f64, f64)],
236 rows: usize,
237 cols: usize,
238 chunk_rows: usize,
239 ) -> FFTResult<Vec<(f64, f64)>> {
240 let mut row_results: Vec<(f64, f64)> = Vec::with_capacity(rows * cols);
241
242 let mut row = 0_usize;
243 while row < rows {
244 let end = (row + chunk_rows).min(rows);
245 for r in row..end {
246 let start = r * cols;
247 let row_data: Vec<Complex64> = spectrum[start..start + cols]
248 .iter()
249 .map(|&(re, im)| Complex64::new(re, im))
250 .collect();
251 let time_row = ifft(&row_data, Some(cols))?;
252 for v in time_row.iter().take(cols) {
253 row_results.push((v.re, v.im));
254 }
255 }
256 row = end;
257 }
258
259 Ok(row_results)
260 }
261
262 fn ifft_all_cols(
264 &self,
265 data: &[(f64, f64)],
266 rows: usize,
267 cols: usize,
268 _chunk_rows: usize,
269 ) -> FFTResult<Vec<(f64, f64)>> {
270 let total = rows * cols;
271 let mut result: Vec<(f64, f64)> = vec![(0.0, 0.0); total];
272
273 for c in 0..cols {
274 let col_data: Vec<Complex64> = (0..rows)
275 .map(|r| {
276 let (re, im) = data[r * cols + c];
277 Complex64::new(re, im)
278 })
279 .collect();
280
281 let col_result = ifft(&col_data, Some(rows))?;
282
283 for (r, val) in col_result.iter().take(rows).enumerate() {
284 result[r * cols + c] = (val.re, val.im);
285 }
286 }
287
288 Ok(result)
289 }
290}
291
292fn write_complex_pairs<W: Write>(writer: &mut W, data: &[(f64, f64)]) -> FFTResult<()> {
298 for &(re, im) in data {
299 writer
300 .write_all(&re.to_le_bytes())
301 .map_err(|e| FFTError::IOError(format!("write re: {}", e)))?;
302 writer
303 .write_all(&im.to_le_bytes())
304 .map_err(|e| FFTError::IOError(format!("write im: {}", e)))?;
305 }
306 Ok(())
307}
308
309fn read_complex_pair<R: Read + Seek>(
312 reader: &mut BufReader<R>,
313 byte_offset: u64,
314) -> FFTResult<(f64, f64)> {
315 reader
316 .seek(SeekFrom::Start(byte_offset))
317 .map_err(|e| FFTError::IOError(format!("seek: {}", e)))?;
318
319 let mut buf = [0u8; 8];
320 reader
321 .read_exact(&mut buf)
322 .map_err(|e| FFTError::IOError(format!("read re: {}", e)))?;
323 let re = f64::from_le_bytes(buf);
324
325 reader
326 .read_exact(&mut buf)
327 .map_err(|e| FFTError::IOError(format!("read im: {}", e)))?;
328 let im = f64::from_le_bytes(buf);
329
330 Ok((re, im))
331}
332
333pub fn disk_based_fft2d(
347 data: &[f64],
348 rows: usize,
349 cols: usize,
350 temp_dir: &PathBuf,
351) -> FFTResult<Vec<(f64, f64)>> {
352 if data.len() != rows * cols {
353 return Err(FFTError::DimensionError(format!(
354 "disk_based_fft2d: data {} != {}*{}",
355 data.len(),
356 rows,
357 cols
358 )));
359 }
360
361 let tmp_row = tempfile::NamedTempFile::new_in(temp_dir)
363 .map_err(|e| FFTError::IOError(format!("create temp row file: {}", e)))?;
364
365 {
366 let mut writer = BufWriter::new(tmp_row.as_file());
367 for r in 0..rows {
368 let row_data = &data[r * cols..(r + 1) * cols];
369 let spectrum = fft(row_data, Some(cols))?;
370 write_complex_pairs(
371 &mut writer,
372 &spectrum
373 .iter()
374 .take(cols)
375 .map(|c| (c.re, c.im))
376 .collect::<Vec<_>>(),
377 )?;
378 }
379 writer
380 .flush()
381 .map_err(|e| FFTError::IOError(format!("flush row temp: {}", e)))?;
382 }
383
384 const COMPLEX_BYTES: u64 = 16;
387 let row_stride = cols as u64 * COMPLEX_BYTES;
388
389 let mut result: Vec<(f64, f64)> = vec![(0.0, 0.0); rows * cols];
390
391 let row_file = std::fs::File::open(tmp_row.path())
392 .map_err(|e| FFTError::IOError(format!("open row temp: {}", e)))?;
393 let mut reader = BufReader::new(row_file);
394
395 for c in 0..cols {
396 let col_data: Result<Vec<Complex64>, FFTError> = (0..rows)
398 .map(|r| {
399 let offset = r as u64 * row_stride + c as u64 * COMPLEX_BYTES;
400 let (re, im) = read_complex_pair(&mut reader, offset)?;
401 Ok(Complex64::new(re, im))
402 })
403 .collect();
404
405 let col_vec = col_data?;
406 let col_spectrum = fft(&col_vec, Some(rows))?;
407
408 for (r, val) in col_spectrum.iter().take(rows).enumerate() {
409 result[r * cols + c] = (val.re, val.im);
410 }
411 }
412
413 Ok(result)
414}
415
416pub fn small_fft2d(data: &[f64], rows: usize, cols: usize) -> Vec<(f64, f64)> {
444 if data.len() != rows * cols || rows == 0 || cols == 0 {
446 return Vec::new();
447 }
448
449 let proc = OutOfCoreFft2D::new(rows, cols);
450 proc.fft2d(data).unwrap_or_default()
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    /// Direct O((rows*cols)^2) evaluation of the forward 2-D DFT,
    /// `X[k, l] = sum_{r, c} x[r, c] * exp(-2*pi*i*(k*r/rows + l*c/cols))`,
    /// used as an implementation-independent reference for small inputs.
    fn naive_dft2d(data: &[f64], rows: usize, cols: usize) -> Vec<(f64, f64)> {
        let mut out = Vec::with_capacity(rows * cols);
        for k in 0..rows {
            for l in 0..cols {
                let (mut re, mut im) = (0.0_f64, 0.0_f64);
                for r in 0..rows {
                    for c in 0..cols {
                        // Reduce the phase modulo one period to keep the angle small.
                        let phase = ((k * r) % rows) as f64 / rows as f64
                            + ((l * c) % cols) as f64 / cols as f64;
                        let angle = -2.0 * std::f64::consts::PI * phase;
                        let x = data[r * cols + c];
                        re += x * angle.cos();
                        im += x * angle.sin();
                    }
                }
                out.push((re, im));
            }
        }
        out
    }

    /// Asserts two spectra have equal length and agree element-wise within `tol`.
    fn assert_spectra_close(a: &[(f64, f64)], b: &[(f64, f64)], tol: f64) {
        assert_eq!(a.len(), b.len(), "spectrum lengths differ");
        for (i, (&(re_a, im_a), &(re_b, im_b))) in a.iter().zip(b.iter()).enumerate() {
            assert!((re_a - re_b).abs() < tol, "bin {i}: re {re_a} vs {re_b}");
            assert!((im_a - im_b).abs() < tol, "bin {i}: im {im_a} vs {im_b}");
        }
    }

    #[test]
    fn test_small_fft2d_impulse_dc() {
        let rows = 8_usize;
        let cols = 8_usize;
        let mut data = vec![0.0_f64; rows * cols];
        data[0] = 1.0;

        let spectrum = small_fft2d(&data, rows, cols);
        assert_eq!(spectrum.len(), rows * cols);

        // A unit impulse at the origin has a flat, purely real spectrum.
        for (i, &(re, im)) in spectrum.iter().enumerate() {
            assert!((re - 1.0).abs() < 1e-10, "bin {i}: re={re} expected 1.0");
            assert!(im.abs() < 1e-10, "bin {i}: im={im} expected 0.0");
        }
    }

    #[test]
    fn test_small_fft2d_invalid_dims_returns_empty() {
        assert!(small_fft2d(&[], 0, 4).is_empty());
        assert!(small_fft2d(&[], 4, 0).is_empty());
        assert!(small_fft2d(&[1.0, 2.0, 3.0], 2, 2).is_empty());
    }

    #[test]
    fn test_fft2d_ifft2d_roundtrip_8x8() {
        let rows = 8_usize;
        let cols = 8_usize;
        let data: Vec<f64> = (0..rows * cols)
            .map(|k| {
                let r = k / cols;
                let c = k % cols;
                (r as f64 * 0.3).sin() + (c as f64 * 0.7).cos()
            })
            .collect();

        let proc = OutOfCoreFft2D::new(rows, cols);
        let spectrum = proc.fft2d(&data).expect("fft2d failed");
        let recovered = proc.ifft2d(&spectrum).expect("ifft2d failed");

        assert_eq!(recovered.len(), data.len());
        for (i, (&orig, &rec)) in data.iter().zip(recovered.iter()).enumerate() {
            assert!(
                (orig - rec).abs() < 1e-10,
                "index {i}: original={orig} recovered={rec} diff={}",
                (orig - rec).abs()
            );
        }
    }

    #[test]
    fn test_fft2d_matches_naive_dft_16x16() {
        let rows = 16_usize;
        let cols = 16_usize;
        let data: Vec<f64> = (0..rows * cols)
            .map(|k| (k as f64 * std::f64::consts::PI / 16.0).sin())
            .collect();

        let proc = OutOfCoreFft2D::new(rows, cols);
        let fast = proc.fft2d(&data).expect("fft2d failed");
        let reference = naive_dft2d(&data, rows, cols);

        assert_spectra_close(&fast, &reference, 1e-9);
    }

    #[test]
    fn test_parseval_theorem() {
        let rows = 8_usize;
        let cols = 8_usize;
        let n = (rows * cols) as f64;
        let data: Vec<f64> = (0..rows * cols)
            .map(|k| ((k as f64) * 0.4).sin() * 2.0)
            .collect();

        let proc = OutOfCoreFft2D::new(rows, cols);
        let spectrum = proc.fft2d(&data).expect("fft2d failed");

        // sum |x|^2 == (1/N) * sum |X|^2 for an unnormalised forward DFT.
        let spatial_energy: f64 = data.iter().map(|&x| x * x).sum();
        let spectral_energy: f64 = spectrum
            .iter()
            .map(|&(re, im)| re * re + im * im)
            .sum::<f64>()
            / n;

        assert!(
            (spatial_energy - spectral_energy).abs() < 1e-9,
            "Parseval: spatial={spatial_energy} spectral/N={spectral_energy}"
        );
    }

    #[test]
    fn test_chunk_rows_1_vs_full() {
        let rows = 8_usize;
        let cols = 8_usize;
        let data: Vec<f64> = (0..rows * cols).map(|k| (k as f64 * 0.2).cos()).collect();

        let proc_full = OutOfCoreFft2D::with_config(OutOfCoreConfig {
            rows,
            cols,
            chunk_rows: rows,
            temp_dir: std::env::temp_dir(),
        });
        let result_full = proc_full.fft2d(&data).expect("full fft2d failed");

        let proc_one = OutOfCoreFft2D::with_config(OutOfCoreConfig {
            rows,
            cols,
            chunk_rows: 1,
            temp_dir: std::env::temp_dir(),
        });
        let result_one = proc_one.fft2d(&data).expect("chunk-1 fft2d failed");

        assert_spectra_close(&result_full, &result_one, 1e-10);
    }

    #[test]
    fn test_non_square_uneven_chunks() {
        // chunk_rows = 3 does not divide rows = 4, leaving a short final chunk;
        // a non-square shape also exposes any row/column index mix-up.
        let rows = 4_usize;
        let cols = 6_usize;
        let data: Vec<f64> = (0..rows * cols)
            .map(|k| (k as f64 * 0.37).sin() + 0.5)
            .collect();

        let proc = OutOfCoreFft2D::with_config(OutOfCoreConfig {
            rows,
            cols,
            chunk_rows: 3,
            temp_dir: std::env::temp_dir(),
        });
        let spectrum = proc.fft2d(&data).expect("fft2d failed");
        assert_spectra_close(&spectrum, &naive_dft2d(&data, rows, cols), 1e-9);

        let recovered = proc.ifft2d(&spectrum).expect("ifft2d failed");
        assert_eq!(recovered.len(), data.len());
        for (i, (&orig, &rec)) in data.iter().zip(recovered.iter()).enumerate() {
            assert!(
                (orig - rec).abs() < 1e-10,
                "index {i}: original={orig} recovered={rec}"
            );
        }
    }

    #[test]
    fn test_disk_based_matches_in_memory() {
        let rows = 8_usize;
        let cols = 8_usize;
        let data: Vec<f64> = (0..rows * cols).map(|k| (k as f64 * 0.5).sin()).collect();

        let in_memory = small_fft2d(&data, rows, cols);
        let on_disk =
            disk_based_fft2d(&data, rows, cols, &std::env::temp_dir()).expect("disk fft2d failed");

        assert_spectra_close(&in_memory, &on_disk, 1e-10);
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let proc = OutOfCoreFft2D::new(8, 8);

        let result = proc.fft2d(&[1.0, 2.0, 3.0]);
        assert!(result.is_err(), "expected error for length mismatch");

        let inverse = proc.ifft2d(&[(1.0, 0.0); 3]);
        assert!(inverse.is_err(), "expected ifft2d error for length mismatch");

        let disk = disk_based_fft2d(&[1.0, 2.0, 3.0], 8, 8, &std::env::temp_dir());
        assert!(disk.is_err(), "expected disk_based_fft2d error for length mismatch");
    }

    #[test]
    fn test_pure_tone_bin_location() {
        let rows = 8_usize;
        let cols = 8_usize;
        let f_r = 1_usize;
        let f_c = 2_usize;
        let data: Vec<f64> = (0..rows * cols)
            .map(|k| {
                let r = k / cols;
                let c = k % cols;
                (2.0 * std::f64::consts::PI * f_r as f64 * r as f64 / rows as f64).cos()
                    * (2.0 * std::f64::consts::PI * f_c as f64 * c as f64 / cols as f64).cos()
            })
            .collect();

        let spectrum = small_fft2d(&data, rows, cols);
        let n = (rows * cols) as f64;

        // cos * cos expands into four complex exponentials of amplitude 1/4,
        // so all energy lands in the (+-f_r, +-f_c) bins with magnitude N/4.
        let expected_magnitude = n / 4.0;
        let peaks = [
            (f_r, f_c),
            (rows - f_r, f_c),
            (f_r, cols - f_c),
            (rows - f_r, cols - f_c),
        ];

        let bin_at = |r: usize, c: usize| -> f64 {
            let (re, im) = spectrum[r * cols + c];
            (re * re + im * im).sqrt()
        };

        for r in 0..rows {
            for c in 0..cols {
                let expected = if peaks.contains(&(r, c)) {
                    expected_magnitude
                } else {
                    0.0
                };
                assert!(
                    (bin_at(r, c) - expected).abs() < 1e-9,
                    "bin ({r},{c}): expected magnitude {expected}, got {}",
                    bin_at(r, c)
                );
            }
        }
    }
}