1use std::f32::consts::PI;
7use torsh_core::{Result as TorshResult, TorshError};
8use torsh_tensor::Tensor;
9
10use crate::spectral::{fft, ifft, rfft};
11
/// Window functions available for short-time Fourier analysis.
///
/// All variants are generated in their symmetric form (denominator
/// `size - 1`) by `generate_window`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WindowFunction {
    /// All-ones (boxcar) window.
    Rectangular,
    /// Hann window: `0.5 - 0.5*cos(2*pi*n/(N-1))`; zero at both endpoints.
    Hann,
    /// Hamming window: `0.54 - 0.46*cos(2*pi*n/(N-1))`; ~0.08 at the endpoints.
    Hamming,
    /// Blackman window with the classic 0.42 / 0.5 / 0.08 coefficients.
    Blackman,
    /// Triangular (Bartlett) window; zero at both endpoints.
    Bartlett,
    /// Kaiser window with shape parameter beta.
    /// NOTE(review): beta is an integer here, but Kaiser beta is conventionally
    /// a float (e.g. 8.6 for Blackman-like behavior) — confirm callers do not
    /// need fractional values before relying on this granularity.
    Kaiser(i32), }
28
29pub fn generate_window(window_type: WindowFunction, size: usize) -> TorshResult<Vec<f32>> {
58 if size == 0 {
59 return Err(TorshError::InvalidArgument(
60 "Window size must be positive".to_string(),
61 ));
62 }
63
64 let mut window = vec![0.0; size];
65
66 match window_type {
67 WindowFunction::Rectangular => {
68 window.fill(1.0);
69 }
70 WindowFunction::Hann => {
71 for (i, w) in window.iter_mut().enumerate() {
72 let n = i as f32;
73 let n_size = size as f32;
74 *w = 0.5 - 0.5 * (2.0 * PI * n / (n_size - 1.0)).cos();
75 }
76 }
77 WindowFunction::Hamming => {
78 for (i, w) in window.iter_mut().enumerate() {
79 let n = i as f32;
80 let n_size = size as f32;
81 *w = 0.54 - 0.46 * (2.0 * PI * n / (n_size - 1.0)).cos();
82 }
83 }
84 WindowFunction::Blackman => {
85 for (i, w) in window.iter_mut().enumerate() {
86 let n = i as f32;
87 let n_size = size as f32;
88 let factor = 2.0 * PI * n / (n_size - 1.0);
89 *w = 0.42 - 0.5 * factor.cos() + 0.08 * (2.0 * factor).cos();
90 }
91 }
92 WindowFunction::Bartlett => {
93 for (i, w) in window.iter_mut().enumerate() {
94 let n = i as f32;
95 let n_size = size as f32;
96 *w = 1.0 - (2.0 * n / (n_size - 1.0) - 1.0).abs();
97 }
98 }
99 WindowFunction::Kaiser(beta) => {
100 let beta_f = beta as f32;
102 let i0_beta = bessel_i0(beta_f);
103
104 for (i, w) in window.iter_mut().enumerate() {
105 let n = i as f32;
106 let n_size = size as f32;
107 let x = beta_f * (1.0 - (2.0 * n / (n_size - 1.0) - 1.0).powi(2)).sqrt();
108 *w = bessel_i0(x) / i0_beta;
109 }
110 }
111 }
112
113 Ok(window)
114}
115
/// Modified Bessel function of the first kind, order zero: I0(x).
///
/// Uses the classic polynomial approximations from Abramowitz & Stegun
/// (formulas 9.8.1 for |x| < 3.75 and 9.8.2 otherwise), which are accurate
/// to roughly single precision — sufficient for Kaiser window generation.
/// The function is even in x because both branches depend on |x| only.
fn bessel_i0(x: f32) -> f32 {
    // Series coefficients in t = (x/3.75)^2 for the small-argument branch.
    const SMALL: [f32; 6] = [
        3.5156229, 3.0899424, 1.2067492, 0.2659732, 0.360768e-1, 0.45813e-2,
    ];
    // Asymptotic-expansion coefficients in t = 3.75/|x| for the large branch.
    const LARGE: [f32; 9] = [
        0.39894228,
        0.1328592e-1,
        0.225319e-2,
        -0.157565e-2,
        0.916281e-2,
        -0.2057706e-1,
        0.2635537e-1,
        -0.1647633e-1,
        0.392377e-2,
    ];

    let ax = x.abs();
    if ax < 3.75 {
        let t = (x / 3.75).powi(2);
        // Horner evaluation of 1 + t*(c0 + t*(c1 + ...)).
        let poly = SMALL.iter().rev().fold(0.0f32, |acc, &c| c + t * acc);
        1.0 + t * poly
    } else {
        let t = 3.75 / ax;
        // I0(x) ~ exp(x)/sqrt(x) * polynomial(3.75/x).
        let poly = LARGE.iter().rev().fold(0.0f32, |acc, &c| c + t * acc);
        (ax.exp() / ax.sqrt()) * poly
    }
}
138
139pub fn stft_complete(
188 input: &Tensor,
189 n_fft: usize,
190 hop_length: Option<usize>,
191 win_length: Option<usize>,
192 window: WindowFunction,
193 center: bool,
194 normalized: bool,
195 onesided: bool,
196) -> TorshResult<Tensor> {
197 let input_shape = input.shape();
198 let ndim = input_shape.ndim();
199
200 if ndim == 0 || ndim > 2 {
201 return Err(TorshError::InvalidArgument(
202 "STFT input must be 1D or 2D".to_string(),
203 ));
204 }
205
206 let hop_len = hop_length.unwrap_or(n_fft / 4);
207 let win_len = win_length.unwrap_or(n_fft);
208
209 if win_len > n_fft {
210 return Err(TorshError::InvalidArgument(
211 "Window length cannot exceed FFT size".to_string(),
212 ));
213 }
214
215 if hop_len == 0 {
216 return Err(TorshError::InvalidArgument(
217 "Hop length must be positive".to_string(),
218 ));
219 }
220
221 let window_data = generate_window(window, win_len)?;
223
224 let window_data = if normalized {
226 let energy: f32 = window_data.iter().map(|w| w * w).sum();
227 let scale = (energy / win_len as f32).sqrt();
228 window_data.iter().map(|w| w / scale).collect()
229 } else {
230 window_data
231 };
232
233 let signal_data = input.data()?;
235 let dims = input_shape.dims();
236
237 let (batch_size, signal_len) = if ndim == 1 {
238 (1, dims[0])
239 } else {
240 (dims[0], dims[1])
241 };
242
243 let (padded_signal, padded_len) = if center {
245 let pad_amount = n_fft / 2;
246 let new_len = signal_len + 2 * pad_amount;
247 let mut padded = vec![0.0; batch_size * new_len];
248
249 for b in 0..batch_size {
250 let src_start = b * signal_len;
251 let dst_start = b * new_len + pad_amount;
252
253 for i in 0..signal_len {
255 padded[dst_start + i] = signal_data[src_start + i];
256 }
257
258 for i in 0..pad_amount {
260 if i < signal_len {
261 padded[b * new_len + i] = signal_data[src_start + pad_amount - i];
262 }
263 }
264 for i in 0..pad_amount {
265 if signal_len > i + 1 {
266 padded[dst_start + signal_len + i] =
267 signal_data[src_start + signal_len - 2 - i];
268 }
269 }
270 }
271
272 (padded, new_len)
273 } else {
274 (signal_data.to_vec(), signal_len)
275 };
276
277 let n_frames = if padded_len >= n_fft {
279 (padded_len - n_fft) / hop_len + 1
280 } else {
281 0
282 };
283
284 if n_frames == 0 {
285 return Err(TorshError::InvalidArgument(
286 "Signal too short for STFT".to_string(),
287 ));
288 }
289
290 let freq_bins = if onesided { n_fft / 2 + 1 } else { n_fft };
292
293 let mut stft_data = Vec::with_capacity(batch_size * freq_bins * n_frames * 2);
295
296 for b in 0..batch_size {
297 let signal_start = b * padded_len;
298
299 for frame_idx in 0..n_frames {
300 let frame_start = signal_start + frame_idx * hop_len;
301
302 let mut frame = vec![0.0; n_fft];
304 for i in 0..win_len.min(n_fft) {
305 if frame_start + i < signal_start + padded_len {
306 frame[i] = padded_signal[frame_start + i] * window_data[i];
307 }
308 }
309
310 let frame_tensor = Tensor::from_data(frame, vec![n_fft], input.device())?;
312
313 let fft_result = if onesided {
315 rfft(&frame_tensor, Some(n_fft), None, None)?
316 } else {
317 use torsh_core::dtype::Complex32;
319 let complex_frame: Vec<Complex32> = frame_tensor
320 .data()?
321 .iter()
322 .map(|&x| Complex32::new(x, 0.0))
323 .collect();
324 let complex_tensor = Tensor::from_data(complex_frame, vec![n_fft], input.device())?;
325 fft(&complex_tensor, Some(n_fft), None, None)?
326 };
327
328 let fft_data = fft_result.data()?;
330 for val in fft_data.iter() {
331 stft_data.push(val.re);
332 stft_data.push(val.im);
333 }
334 }
335 }
336
337 let output_shape = if ndim == 1 {
339 vec![freq_bins, n_frames, 2]
340 } else {
341 vec![batch_size, freq_bins, n_frames, 2]
342 };
343
344 Tensor::from_data(stft_data, output_shape, input.device())
345}
346
347pub fn istft_complete(
380 stft: &Tensor,
381 n_fft: usize,
382 hop_length: Option<usize>,
383 win_length: Option<usize>,
384 window: WindowFunction,
385 center: bool,
386 normalized: bool,
387 onesided: bool,
388 length: Option<usize>,
389) -> TorshResult<Tensor> {
390 let stft_shape = stft.shape();
391 let ndim = stft_shape.ndim();
392
393 if ndim < 3 || ndim > 4 {
394 return Err(TorshError::InvalidArgument(
395 "ISTFT input must be 3D [freq, time, 2] or 4D [batch, freq, time, 2]".to_string(),
396 ));
397 }
398
399 let dims = stft_shape.dims();
400 if dims[ndim - 1] != 2 {
401 return Err(TorshError::InvalidArgument(
402 "Last dimension must be 2 (real, imag)".to_string(),
403 ));
404 }
405
406 let hop_len = hop_length.unwrap_or(n_fft / 4);
407 let win_len = win_length.unwrap_or(n_fft);
408
409 let window_data = generate_window(window, win_len)?;
411 let window_data = if normalized {
412 let energy: f32 = window_data.iter().map(|w| w * w).sum();
413 let scale = (energy / win_len as f32).sqrt();
414 window_data.iter().map(|w| w / scale).collect()
415 } else {
416 window_data
417 };
418
419 let (batch_size, freq_bins, n_frames) = if ndim == 3 {
420 (1, dims[0], dims[1])
421 } else {
422 (dims[0], dims[1], dims[2])
423 };
424
425 let expected_bins = if onesided { n_fft / 2 + 1 } else { n_fft };
427 if freq_bins != expected_bins {
428 return Err(TorshError::InvalidArgument(format!(
429 "Frequency bins mismatch: expected {}, got {}",
430 expected_bins, freq_bins
431 )));
432 }
433
434 let output_len = length.unwrap_or((n_frames - 1) * hop_len + n_fft);
436
437 let mut output_data = vec![0.0; batch_size * output_len];
439 let mut window_sum = vec![0.0; output_len];
440
441 let stft_data = stft.data()?;
443
444 for b in 0..batch_size {
446 let batch_offset = if ndim == 3 {
447 0
448 } else {
449 b * freq_bins * n_frames * 2
450 };
451
452 for frame_idx in 0..n_frames {
453 use torsh_core::dtype::Complex32;
455 let mut frame_fft = Vec::with_capacity(freq_bins);
456
457 for f in 0..freq_bins {
458 let idx = batch_offset + (f * n_frames + frame_idx) * 2;
459 if idx + 1 < stft_data.len() {
460 frame_fft.push(Complex32::new(stft_data[idx], stft_data[idx + 1]));
461 } else {
462 frame_fft.push(Complex32::new(0.0, 0.0));
463 }
464 }
465
466 let fft_tensor = Tensor::from_data(frame_fft, vec![freq_bins], stft.device())?;
468
469 let frame_signal = if onesided {
470 super::spectral_advanced::irfft(&fft_tensor, Some(n_fft), None, None)?
471 } else {
472 let ifft_result = ifft(&fft_tensor, Some(n_fft), None, None)?;
473 let ifft_data = ifft_result.data()?;
474 let real_data: Vec<f32> = ifft_data.iter().map(|c| c.re).collect();
475 Tensor::from_data(real_data, vec![n_fft], ifft_result.device())?
476 };
477
478 let frame_data = frame_signal.data()?;
479
480 let frame_start = frame_idx * hop_len;
482 for i in 0..win_len.min(n_fft) {
483 let output_idx = b * output_len + frame_start + i;
484 if output_idx < (b + 1) * output_len && i < frame_data.len() {
485 output_data[output_idx] += frame_data[i] * window_data[i];
486 if b == 0 {
487 window_sum[frame_start + i] += window_data[i] * window_data[i];
489 }
490 }
491 }
492 }
493 }
494
495 for b in 0..batch_size {
497 for i in 0..output_len {
498 let idx = b * output_len + i;
499 if window_sum[i] > 1e-8 {
500 output_data[idx] /= window_sum[i];
501 }
502 }
503 }
504
505 let final_data = if center {
507 let pad_amount = n_fft / 2;
508 let unpadded_len = output_len.saturating_sub(2 * pad_amount);
509 let mut unpadded = Vec::with_capacity(batch_size * unpadded_len);
510
511 for b in 0..batch_size {
512 let src_start = b * output_len + pad_amount;
513 for i in 0..unpadded_len {
514 if src_start + i < output_data.len() {
515 unpadded.push(output_data[src_start + i]);
516 } else {
517 unpadded.push(0.0);
518 }
519 }
520 }
521
522 unpadded
523 } else {
524 output_data
525 };
526
527 let final_len = if center {
529 output_len.saturating_sub(n_fft)
530 } else {
531 output_len
532 };
533
534 let output_shape = if batch_size == 1 {
535 vec![final_len]
536 } else {
537 vec![batch_size, final_len]
538 };
539
540 Tensor::from_data(final_data, output_shape, stft.device())
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random_ops::randn;

    /// Window generation sanity: correct length, Hann edge/peak shape,
    /// and an all-ones rectangular window.
    #[test]
    fn test_window_generation() -> TorshResult<()> {
        let hann = generate_window(WindowFunction::Hann, 256)?;
        assert_eq!(hann.len(), 256);
        // Hann is ~0 at the first sample and ~1 at the midpoint.
        assert!(hann[0] < 0.01);
        assert!(hann[128] > 0.99);
        let hamming = generate_window(WindowFunction::Hamming, 256)?;
        assert_eq!(hamming.len(), 256);

        let rect = generate_window(WindowFunction::Rectangular, 256)?;
        assert!(rect.iter().all(|&x| (x - 1.0).abs() < 1e-6));

        Ok(())
    }

    /// 1D input yields a 3D [freq, frames, 2] result; a one-sided
    /// 256-point FFT has 256/2 + 1 = 129 bins.
    #[test]
    fn test_stft_basic() -> TorshResult<()> {
        let signal = randn(&[1024], None, None, None)?;

        let stft_result = stft_complete(
            &signal,
            256,
            Some(128),
            None,
            WindowFunction::Hann,
            true,
            false,
            true,
        )?;

        let stft_result_shape = stft_result.shape();
        let shape = stft_result_shape.dims();
        assert_eq!(shape.len(), 3);
        // freq bins first, trailing (re, im) axis last.
        assert_eq!(shape[0], 129);
        assert_eq!(shape[2], 2);
        Ok(())
    }

    /// STFT followed by ISTFT should approximately reconstruct the input.
    #[test]
    fn test_stft_istft_roundtrip() -> TorshResult<()> {
        let signal_len = 2048;
        let signal = randn(&[signal_len], None, None, None)?;

        let n_fft = 256;
        let hop_length = 64;

        let stft_result = stft_complete(
            &signal,
            n_fft,
            Some(hop_length),
            None,
            WindowFunction::Hann,
            true,
            false,
            true,
        )?;

        let reconstructed = istft_complete(
            &stft_result,
            n_fft,
            Some(hop_length),
            None,
            WindowFunction::Hann,
            true,
            false,
            true,
            Some(signal_len),
        )?;

        let signal_data = signal.data()?;
        let recon_data = reconstructed.data()?;

        // Compare elementwise over the common prefix.
        let mut max_error = 0.0f32;
        for i in 0..signal_len.min(recon_data.len()) {
            let error = (signal_data[i] - recon_data[i]).abs();
            max_error = max_error.max(error);
        }

        // NOTE(review): a tolerance of 5.0 is extremely loose for a
        // roundtrip test — tighten once reconstruction accuracy is verified.
        assert!(max_error < 5.0, "Max reconstruction error: {}", max_error);

        Ok(())
    }

    /// The STFT succeeds and keeps a 3D shape for each common window type.
    #[test]
    fn test_stft_different_windows() -> TorshResult<()> {
        let signal = randn(&[1024], None, None, None)?;

        for window in &[
            WindowFunction::Hann,
            WindowFunction::Hamming,
            WindowFunction::Blackman,
            WindowFunction::Bartlett,
        ] {
            let stft_result =
                stft_complete(&signal, 256, Some(128), None, *window, false, false, true)?;

            assert_eq!(stft_result.shape().ndim(), 3);
        }

        Ok(())
    }

    /// 2D input produces a 4D [batch, freq, frames, 2] result with the
    /// batch dimension preserved.
    #[test]
    fn test_stft_batch_processing() -> TorshResult<()> {
        let batch_size = 4;
        let signal_len = 1024;
        let batch_signal = randn(&[batch_size, signal_len], None, None, None)?;

        let stft_result = stft_complete(
            &batch_signal,
            256,
            Some(128),
            None,
            WindowFunction::Hann,
            true,
            false,
            true,
        )?;

        let stft_result_shape = stft_result.shape();
        let shape = stft_result_shape.dims();
        assert_eq!(shape.len(), 4);
        assert_eq!(shape[0], batch_size);
        assert_eq!(shape[1], 129);
        assert_eq!(shape[3], 2);

        Ok(())
    }

    /// Invalid arguments are rejected: zero hop length, a window longer
    /// than n_fft, and a signal too short for even one frame.
    #[test]
    fn test_error_handling() {
        let signal = randn(&[64], None, None, None).unwrap();

        // Zero hop length must be rejected.
        let result = stft_complete(
            &signal,
            256,
            Some(0),
            None,
            WindowFunction::Hann,
            false,
            false,
            true,
        );
        assert!(result.is_err());

        // win_length (256) > n_fft (128) must be rejected.
        let result = stft_complete(
            &signal,
            128,
            Some(64),
            Some(256),
            WindowFunction::Hann,
            false,
            false,
            true,
        );
        assert!(result.is_err());

        // Without center-padding a 32-sample signal cannot fill a
        // 256-point frame.
        let tiny_signal = randn(&[32], None, None, None).unwrap();
        let result = stft_complete(
            &tiny_signal,
            256,
            Some(128),
            None,
            WindowFunction::Hann,
            false,
            false,
            true,
        );
        assert!(result.is_err());
    }

    /// Analytical window properties: Hann sums to roughly N/2, Hamming
    /// never reaches zero, Blackman is ~0 at both endpoints.
    #[test]
    fn test_window_properties() -> TorshResult<()> {
        let size = 256;

        let hann = generate_window(WindowFunction::Hann, size)?;
        let hann_sum: f32 = hann.iter().sum();
        assert!((hann_sum - (size as f32 / 2.0)).abs() < 10.0);

        let hamming = generate_window(WindowFunction::Hamming, size)?;
        assert!(hamming.iter().all(|&x| x > 0.08));
        let blackman = generate_window(WindowFunction::Blackman, size)?;
        assert!(blackman[0] < 0.01);
        assert!(blackman[size - 1] < 0.01);

        Ok(())
    }

    /// The ratio of spectral energy to time-domain energy stays within a
    /// broad plausible range (window and overlap scale the total energy).
    #[test]
    fn test_stft_energy_conservation() -> TorshResult<()> {
        let signal_len = 2048;
        let signal = randn(&[signal_len], None, None, None)?;
        let signal_data = signal.data()?;

        let signal_energy: f32 = signal_data.iter().map(|&x| x * x).sum();

        let stft_result = stft_complete(
            &signal,
            256,
            Some(64),
            None,
            WindowFunction::Hann,
            false,
            false,
            true,
        )?;

        // Sum |X|^2 over interleaved (re, im) pairs.
        let stft_data = stft_result.data()?;
        let mut stft_energy = 0.0f32;
        for chunk in stft_data.chunks_exact(2) {
            stft_energy += chunk[0] * chunk[0] + chunk[1] * chunk[1];
        }

        // NOTE(review): very loose bounds — only guards against total
        // energy collapse or blow-up, not Parseval-level accuracy.
        let ratio = stft_energy / signal_energy;
        assert!(ratio > 0.01 && ratio < 200.0, "Energy ratio: {}", ratio);

        Ok(())
    }

    /// Two unit impulses at different positions produce STFTs of the same
    /// shape (only the shape is asserted here, not the content shift).
    #[test]
    fn test_stft_time_shift_property() -> TorshResult<()> {
        let signal_len = 1024;
        let mut signal1 = vec![0.0; signal_len];
        let mut signal2 = vec![0.0; signal_len];

        // Impulses at sample 256 and sample 512 respectively.
        signal1[256] = 1.0;
        signal2[512] = 1.0;

        let tensor1 = Tensor::from_data(
            signal1,
            vec![signal_len],
            torsh_core::device::DeviceType::Cpu,
        )?;
        let tensor2 = Tensor::from_data(
            signal2,
            vec![signal_len],
            torsh_core::device::DeviceType::Cpu,
        )?;

        let stft1 = stft_complete(
            &tensor1,
            256,
            Some(128),
            None,
            WindowFunction::Hann,
            false,
            false,
            true,
        )?;
        let stft2 = stft_complete(
            &tensor2,
            256,
            Some(128),
            None,
            WindowFunction::Hann,
            false,
            false,
            true,
        )?;

        assert_eq!(stft1.shape().dims(), stft2.shape().dims());

        Ok(())
    }

    /// With COLA-friendly parameters (Hann, 75% overlap) the reconstructed
    /// length stays within n_fft of the original signal length.
    #[test]
    fn test_istft_perfect_reconstruction_conditions() -> TorshResult<()> {
        let signal_len = 2048;
        let signal = randn(&[signal_len], None, None, None)?;

        let n_fft = 256;
        let hop_length = 64;
        let stft_result = stft_complete(
            &signal,
            n_fft,
            Some(hop_length),
            None,
            WindowFunction::Hann,
            true,
            false,
            true,
        )?;

        let reconstructed = istft_complete(
            &stft_result,
            n_fft,
            Some(hop_length),
            None,
            WindowFunction::Hann,
            true,
            false,
            true,
            Some(signal_len),
        )?;

        let recon_len = reconstructed.shape().dims()[0];
        assert!(
            recon_len >= signal_len - n_fft && recon_len <= signal_len + n_fft,
            "Reconstructed length {} not close to signal length {}",
            recon_len,
            signal_len
        );

        Ok(())
    }

    /// One-sided output has exactly n_fft/2 + 1 frequency bins.
    #[test]
    fn test_stft_onesided_vs_twosided() -> TorshResult<()> {
        let signal = randn(&[1024], None, None, None)?;
        let n_fft = 256;

        let onesided = stft_complete(
            &signal,
            n_fft,
            Some(128),
            None,
            WindowFunction::Hann,
            false,
            false,
            true,
        )?;
        let onesided_shape = onesided.shape();
        let onesided_freqs = onesided_shape.dims()[0];

        assert_eq!(onesided_freqs, n_fft / 2 + 1);

        Ok(())
    }

    /// Every WindowFunction variant (including Kaiser) works end-to-end
    /// through the STFT.
    #[test]
    fn test_stft_with_all_window_types() -> TorshResult<()> {
        let signal = randn(&[1024], None, None, None)?;

        let windows = vec![
            WindowFunction::Rectangular,
            WindowFunction::Hann,
            WindowFunction::Hamming,
            WindowFunction::Blackman,
            WindowFunction::Bartlett,
            WindowFunction::Kaiser(5),
        ];

        for window in windows {
            let result = stft_complete(&signal, 256, Some(128), None, window, false, false, true)?;
            assert_eq!(result.shape().ndim(), 3);
        }

        Ok(())
    }

    /// Window normalization changes values but not the output shape.
    #[test]
    fn test_stft_normalized_vs_unnormalized() -> TorshResult<()> {
        let signal = randn(&[1024], None, None, None)?;

        let normalized = stft_complete(
            &signal,
            256,
            Some(128),
            None,
            WindowFunction::Hann,
            false,
            true,
            true,
        )?;
        let unnormalized = stft_complete(
            &signal,
            256,
            Some(128),
            None,
            WindowFunction::Hann,
            false,
            false,
            true,
        )?;

        assert_eq!(normalized.shape().dims(), unnormalized.shape().dims());

        let norm_data = normalized.data()?;
        let unnorm_data = unnormalized.data()?;

        // The two spectrograms must actually differ somewhere.
        let mut diff_count = 0;
        for i in 0..norm_data.len().min(unnorm_data.len()) {
            if (norm_data[i] - unnorm_data[i]).abs() > 1e-6 {
                diff_count += 1;
            }
        }
        assert!(diff_count > 0);

        Ok(())
    }

    /// Higher Kaiser beta narrows the window: its endpoints drop relative
    /// to a low-beta (nearly rectangular) Kaiser window.
    #[test]
    fn test_kaiser_window_beta_parameter() -> TorshResult<()> {
        let size = 256;

        let kaiser_low = generate_window(WindowFunction::Kaiser(0), size)?;
        let kaiser_high = generate_window(WindowFunction::Kaiser(10), size)?;

        assert!(kaiser_high[0] < kaiser_low[0]);
        assert!(kaiser_high[size - 1] < kaiser_low[size - 1]);

        Ok(())
    }
}