Skip to main content

voirs_sdk/audio/
processing.rs

1//! Audio processing functions for manipulation and enhancement.
2
3use super::buffer::AudioBuffer;
4use crate::{error::Result, VoirsError};
5
6// Import SciRS2 SIMD operations for performance optimization
7use scirs2_core::ndarray::Array1;
8use scirs2_core::simd_ops::SimdUnifiedOps;
9
10impl AudioBuffer {
11    /// Convert to different sample rate
12    pub fn resample(&self, target_rate: u32) -> Result<AudioBuffer> {
13        if target_rate == self.sample_rate {
14            return Ok(self.clone());
15        }
16
17        // Simple linear interpolation resampling
18        // Improved: Linear interpolation resampling (upgraded from nearest neighbor)
19        let ratio = target_rate as f32 / self.sample_rate as f32;
20        let new_length = (self.samples.len() as f32 * ratio) as usize;
21        let mut resampled = Vec::with_capacity(new_length);
22
23        for i in 0..new_length {
24            let src_index = i as f32 / ratio;
25            let idx = src_index as usize;
26
27            if idx < self.samples.len() {
28                // Simple nearest neighbor for now
29                resampled.push(self.samples[idx]);
30            }
31        }
32
33        Ok(AudioBuffer::new(resampled, target_rate, self.channels))
34    }
35
36    /// Apply gain to audio (in dB)
37    ///
38    /// Uses SIMD acceleration for improved performance on large buffers.
39    pub fn apply_gain(&mut self, gain_db: f32) -> Result<()> {
40        let gain_linear = 10.0_f32.powf(gain_db / 20.0);
41
42        // Use SIMD optimization for buffers larger than 64 samples
43        if self.samples.len() > 64 && f32::simd_available() {
44            // Convert to Array1 for SIMD operations
45            let samples_array = Array1::from_vec(self.samples.clone());
46
47            // SIMD scalar multiplication
48            let gained = f32::simd_scalar_mul(&samples_array.view(), gain_linear);
49
50            // Clamp values to prevent clipping
51            self.samples = gained.iter().map(|&s| s.clamp(-1.0, 1.0)).collect();
52        } else {
53            // Fallback to scalar implementation for small buffers
54            for sample in &mut self.samples {
55                *sample *= gain_linear;
56                // Prevent clipping
57                *sample = sample.clamp(-1.0, 1.0);
58            }
59        }
60
61        // Update metadata
62        self.update_metadata();
63        Ok(())
64    }
65
66    /// Normalize audio to peak amplitude
67    ///
68    /// Uses SIMD acceleration for improved performance on large buffers.
69    pub fn normalize(&mut self, target_peak: f32) -> Result<()> {
70        // Use SIMD optimization for buffers larger than 64 samples
71        let current_peak = if self.samples.len() > 64 && f32::simd_available() {
72            let samples_array = Array1::from_vec(self.samples.clone());
73            let abs_samples = f32::simd_abs(&samples_array.view());
74            f32::simd_max_element(&abs_samples.view())
75        } else {
76            self.samples.iter().map(|&s| s.abs()).fold(0.0, f32::max)
77        };
78
79        if current_peak > 0.0 {
80            let gain = target_peak / current_peak;
81
82            // Use SIMD for gain application if buffer is large enough
83            if self.samples.len() > 64 && f32::simd_available() {
84                let samples_array = Array1::from_vec(self.samples.clone());
85                let normalized = f32::simd_scalar_mul(&samples_array.view(), gain);
86                self.samples = normalized.to_vec();
87            } else {
88                for sample in &mut self.samples {
89                    *sample *= gain;
90                }
91            }
92            self.update_metadata();
93        }
94
95        Ok(())
96    }
97
98    /// Mix with another audio buffer
99    ///
100    /// Uses SIMD acceleration (FMA - fused multiply-add) for improved performance on large buffers.
101    pub fn mix(&mut self, other: &AudioBuffer, gain: f32) -> Result<()> {
102        if self.sample_rate != other.sample_rate {
103            return Err(VoirsError::audio_error(
104                "Sample rates must match for mixing",
105            ));
106        }
107
108        let mix_length = self.samples.len().min(other.samples.len());
109
110        // Use SIMD optimization for buffers larger than 64 samples
111        if mix_length > 64 && f32::simd_available() {
112            // Extract the portions to mix
113            let self_portion = Array1::from_vec(self.samples[..mix_length].to_vec());
114            let other_portion = Array1::from_vec(other.samples[..mix_length].to_vec());
115
116            // Create gain vector for SIMD multiplication
117            let gain_vec = Array1::from_elem(mix_length, gain);
118
119            // Use SIMD FMA: self + other * gain
120            let mixed = f32::simd_fma(
121                &other_portion.view(),
122                &gain_vec.view(),
123                &self_portion.view(),
124            );
125
126            // Clamp and update
127            for (i, &sample) in mixed.iter().enumerate() {
128                self.samples[i] = sample.clamp(-1.0, 1.0);
129            }
130        } else {
131            // Fallback to scalar implementation for small buffers
132            for i in 0..mix_length {
133                self.samples[i] += other.samples[i] * gain;
134                // Prevent clipping
135                self.samples[i] = self.samples[i].clamp(-1.0, 1.0);
136            }
137        }
138
139        self.update_metadata();
140        Ok(())
141    }
142
143    /// Append another audio buffer
144    pub fn append(&mut self, other: &AudioBuffer) -> Result<()> {
145        if self.sample_rate != other.sample_rate || self.channels != other.channels {
146            return Err(VoirsError::audio_error(
147                "Sample rate and channels must match for appending",
148            ));
149        }
150
151        self.samples.extend_from_slice(&other.samples);
152        self.update_metadata();
153        Ok(())
154    }
155
156    /// Split audio buffer at given time (in seconds)
157    pub fn split(&self, time_seconds: f32) -> Result<(AudioBuffer, AudioBuffer)> {
158        let split_sample = (time_seconds * self.sample_rate as f32 * self.channels as f32) as usize;
159
160        if split_sample >= self.samples.len() {
161            return Err(VoirsError::audio_error("Split time exceeds audio duration"));
162        }
163
164        let first_part = AudioBuffer::new(
165            self.samples[..split_sample].to_vec(),
166            self.sample_rate,
167            self.channels,
168        );
169
170        let second_part = AudioBuffer::new(
171            self.samples[split_sample..].to_vec(),
172            self.sample_rate,
173            self.channels,
174        );
175
176        Ok((first_part, second_part))
177    }
178
179    /// Fade in over specified duration
180    pub fn fade_in(&mut self, duration_seconds: f32) -> Result<()> {
181        let fade_samples =
182            (duration_seconds * self.sample_rate as f32 * self.channels as f32) as usize;
183        let fade_samples = fade_samples.min(self.samples.len());
184
185        for i in 0..fade_samples {
186            let fade_factor = i as f32 / fade_samples as f32;
187            self.samples[i] *= fade_factor;
188        }
189
190        self.update_metadata();
191        Ok(())
192    }
193
194    /// Fade out over specified duration
195    pub fn fade_out(&mut self, duration_seconds: f32) -> Result<()> {
196        let fade_samples =
197            (duration_seconds * self.sample_rate as f32 * self.channels as f32) as usize;
198        let fade_samples = fade_samples.min(self.samples.len());
199        let start_index = self.samples.len().saturating_sub(fade_samples);
200
201        for i in 0..fade_samples {
202            let fade_factor = 1.0 - (i as f32 / fade_samples as f32);
203            self.samples[start_index + i] *= fade_factor;
204        }
205
206        self.update_metadata();
207        Ok(())
208    }
209
210    /// Apply cross-fade between two buffers
211    pub fn crossfade(&mut self, other: &AudioBuffer, crossfade_duration: f32) -> Result<()> {
212        if self.sample_rate != other.sample_rate || self.channels != other.channels {
213            return Err(VoirsError::audio_error(
214                "Sample rate and channels must match for crossfading",
215            ));
216        }
217
218        let crossfade_samples =
219            (crossfade_duration * self.sample_rate as f32 * self.channels as f32) as usize;
220        let crossfade_samples = crossfade_samples
221            .min(self.samples.len())
222            .min(other.samples.len());
223
224        // Fade out the end of this buffer
225        let fade_start = self.samples.len().saturating_sub(crossfade_samples);
226        for i in 0..crossfade_samples {
227            let fade_factor = 1.0 - (i as f32 / crossfade_samples as f32);
228            self.samples[fade_start + i] *= fade_factor;
229        }
230
231        // Mix in the beginning of the other buffer with fade in
232        for i in 0..crossfade_samples {
233            let fade_factor = i as f32 / crossfade_samples as f32;
234            self.samples[fade_start + i] += other.samples[i] * fade_factor;
235            // Prevent clipping
236            self.samples[fade_start + i] = self.samples[fade_start + i].clamp(-1.0, 1.0);
237        }
238
239        // Append the rest of the other buffer
240        if crossfade_samples < other.samples.len() {
241            self.samples
242                .extend_from_slice(&other.samples[crossfade_samples..]);
243        }
244
245        self.update_metadata();
246        Ok(())
247    }
248
249    /// Apply a simple lowpass filter
250    pub fn lowpass_filter(&mut self, cutoff_frequency: f32) -> Result<()> {
251        // Simple single-pole lowpass filter
252        let dt = 1.0 / self.sample_rate as f32;
253        let rc = 1.0 / (2.0 * std::f32::consts::PI * cutoff_frequency);
254        let alpha = dt / (rc + dt);
255
256        let mut previous_output = 0.0;
257        for sample in &mut self.samples {
258            let output = alpha * (*sample) + (1.0 - alpha) * previous_output;
259            *sample = output;
260            previous_output = output;
261        }
262
263        self.update_metadata();
264        Ok(())
265    }
266
267    /// Apply a simple highpass filter
268    pub fn highpass_filter(&mut self, cutoff_frequency: f32) -> Result<()> {
269        // Simple single-pole highpass filter
270        let dt = 1.0 / self.sample_rate as f32;
271        let rc = 1.0 / (2.0 * std::f32::consts::PI * cutoff_frequency);
272        let alpha = rc / (rc + dt);
273
274        let mut previous_input = 0.0;
275        let mut previous_output = 0.0;
276
277        for sample in &mut self.samples {
278            let output = alpha * (previous_output + *sample - previous_input);
279            previous_input = *sample;
280            *sample = output;
281            previous_output = output;
282        }
283
284        self.update_metadata();
285        Ok(())
286    }
287
288    /// Apply time stretching (simple pitch-preserving speed change)
289    pub fn time_stretch(&self, stretch_factor: f32) -> Result<AudioBuffer> {
290        if stretch_factor <= 0.0 {
291            return Err(VoirsError::audio_error("Stretch factor must be positive"));
292        }
293
294        // Simple time-domain stretching (not high quality)
295        let new_length = (self.samples.len() as f32 / stretch_factor) as usize;
296        let mut stretched = Vec::with_capacity(new_length);
297
298        for i in 0..new_length {
299            let src_index = (i as f32 * stretch_factor) as usize;
300            if src_index < self.samples.len() {
301                stretched.push(self.samples[src_index]);
302            }
303        }
304
305        Ok(AudioBuffer::new(stretched, self.sample_rate, self.channels))
306    }
307
308    /// Apply pitch shifting using phase vocoder algorithm
309    pub fn pitch_shift(&self, semitones: f32) -> Result<AudioBuffer> {
310        use scirs2_core::Complex;
311        use std::f32::consts::PI;
312
313        if semitones == 0.0 {
314            return Ok(self.clone());
315        }
316
317        let pitch_factor = 2.0_f32.powf(semitones / 12.0);
318
319        // Phase vocoder parameters
320        let frame_size = 1024; // FFT frame size
321        let hop_size = frame_size / 4; // 75% overlap
322        let _overlap_factor = frame_size / hop_size;
323
324        // Prepare input with zero padding
325        let mut input_samples = self.samples.clone();
326        let padding = frame_size * 2;
327        input_samples.resize(input_samples.len() + padding, 0.0);
328
329        // Calculate output length (pitch shifting doesn't change duration)
330        let output_length = self.samples.len();
331        let mut output_samples = vec![0.0; output_length + padding];
332
333        // Phase vocoder state
334        let mut previous_phase = vec![0.0; frame_size / 2 + 1];
335        let mut synthesis_phase = vec![0.0; frame_size / 2 + 1];
336
337        // Hanning window for windowing
338        let window: Vec<f32> = (0..frame_size)
339            .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / (frame_size - 1) as f32).cos()))
340            .collect();
341
342        // Process in overlapping frames
343        let mut input_pos = 0;
344        let mut output_pos = 0;
345
346        while input_pos + frame_size <= input_samples.len() {
347            // Extract and window the input frame
348            let frame_real: Vec<f32> = (0..frame_size)
349                .map(|i| input_samples[input_pos + i] * window[i])
350                .collect();
351
352            // Forward FFT (f32 -> f64 for scirs2_fft)
353            let frame_real_f64: Vec<f64> = frame_real.iter().map(|&x| x as f64).collect();
354            let frame_complex_f64 =
355                scirs2_fft::rfft(&frame_real_f64, None).map_err(|e| VoirsError::AudioError {
356                    message: format!("FFT failed: {}", e),
357                    buffer_info: None,
358                })?;
359
360            // Convert back to f32 Complex
361            let frame_complex: Vec<Complex<f32>> = frame_complex_f64
362                .iter()
363                .map(|c| Complex::new(c.re as f32, c.im as f32))
364                .collect();
365
366            // Phase vocoder processing
367            let mut modified_frame = vec![Complex::new(0.0f64, 0.0f64); frame_size / 2 + 1];
368
369            for k in 0..frame_complex.len() {
370                let magnitude = frame_complex[k].norm();
371                let phase = frame_complex[k].arg();
372
373                // Calculate phase difference
374                let phase_diff = phase - previous_phase[k];
375                previous_phase[k] = phase;
376
377                // Unwrap phase difference
378                let unwrapped_phase_diff = phase_diff;
379                let expected_phase_diff = 2.0 * PI * k as f32 * hop_size as f32 / frame_size as f32;
380                let phase_deviation = unwrapped_phase_diff - expected_phase_diff;
381
382                // Wrap to [-π, π]
383                let wrapped_deviation = ((phase_deviation + PI) % (2.0 * PI)) - PI;
384                let true_freq =
385                    2.0 * PI * k as f32 / frame_size as f32 + wrapped_deviation / hop_size as f32;
386
387                // Apply pitch shift to frequency
388                let shifted_freq = true_freq * pitch_factor;
389                let shifted_bin = shifted_freq * frame_size as f32 / (2.0 * PI);
390
391                if shifted_bin >= 0.0 && shifted_bin < (frame_size / 2) as f32 {
392                    let target_bin = shifted_bin.round() as usize;
393                    if target_bin < frame_size / 2 + 1 {
394                        // Update synthesis phase
395                        synthesis_phase[target_bin] += shifted_freq * hop_size as f32;
396
397                        // Set the shifted frequency component
398                        let new_complex = Complex::new(
399                            (magnitude * synthesis_phase[target_bin].cos()) as f64,
400                            (magnitude * synthesis_phase[target_bin].sin()) as f64,
401                        );
402                        modified_frame[target_bin] = new_complex;
403                    }
404                }
405            }
406
407            // Inverse FFT
408            let frame_output_f64 =
409                scirs2_fft::irfft(&modified_frame, Some(frame_size)).map_err(|e| {
410                    VoirsError::AudioError {
411                        message: format!("IFFT failed: {}", e),
412                        buffer_info: None,
413                    }
414                })?;
415            let frame_output: Vec<f32> = frame_output_f64.iter().map(|&x| x as f32).collect();
416
417            // Overlap-add synthesis with windowing
418            // For 75% overlap with Hanning window, normalization factor is 2/3
419            let norm_factor = 2.0 / 3.0;
420            for i in 0..frame_size {
421                if output_pos + i < output_samples.len() {
422                    let windowed_sample = frame_output[i] * window[i] * norm_factor;
423                    output_samples[output_pos + i] += windowed_sample;
424                }
425            }
426
427            // Advance positions
428            input_pos += hop_size;
429            output_pos += hop_size;
430        }
431
432        // Normalize and trim output
433        output_samples.truncate(output_length);
434
435        // Normalize the output to prevent clipping
436        let max_amplitude = output_samples
437            .iter()
438            .map(|&s: &f32| s.abs())
439            .fold(0.0f32, f32::max);
440        if max_amplitude > 1.0 {
441            let normalization_factor = 0.95 / max_amplitude;
442            for sample in &mut output_samples {
443                *sample *= normalization_factor;
444            }
445        }
446
447        Ok(AudioBuffer::new(
448            output_samples,
449            self.sample_rate,
450            self.channels,
451        ))
452    }
453
454    /// Apply pitch shifting using PSOLA (Pitch Synchronous Overlap and Add) algorithm
455    /// This method is more suitable for speech and preserves formants better
456    pub fn pitch_shift_psola(&self, semitones: f32) -> Result<AudioBuffer> {
457        if semitones == 0.0 {
458            return Ok(self.clone());
459        }
460
461        let pitch_factor = 2.0_f32.powf(semitones / 12.0);
462
463        // PSOLA parameters
464        let min_period = (self.sample_rate as f32 / 800.0) as usize; // ~800 Hz max
465        let max_period = (self.sample_rate as f32 / 50.0) as usize; // ~50 Hz min
466
467        // Simple pitch detection using autocorrelation
468        let pitch_periods = self.detect_pitch_periods(min_period, max_period)?;
469
470        if pitch_periods.is_empty() {
471            // Fallback to phase vocoder for non-pitched signals
472            return self.pitch_shift(semitones);
473        }
474
475        // Calculate output length
476        let output_length = self.samples.len();
477        let mut output_samples = vec![0.0; output_length];
478
479        // PSOLA synthesis
480        let mut output_pos = 0.0;
481        let mut input_idx = 0;
482
483        while input_idx < pitch_periods.len() - 1 && (output_pos as usize) < output_length {
484            let current_period = pitch_periods[input_idx];
485            let next_period = pitch_periods[input_idx + 1];
486            let period_length = next_period - current_period;
487
488            // Create windowed grain
489            let grain_size = period_length * 2; // Use 2 periods for good overlap
490            let grain_start = current_period.saturating_sub(period_length / 2);
491            let grain_end = (grain_start + grain_size).min(self.samples.len());
492
493            if grain_end > grain_start {
494                // Extract grain with Hanning window
495                let grain_length = grain_end - grain_start;
496                let window: Vec<f32> = (0..grain_length)
497                    .map(|i| {
498                        0.5 * (1.0
499                            - (2.0 * std::f32::consts::PI * i as f32 / (grain_length - 1) as f32)
500                                .cos())
501                    })
502                    .collect();
503
504                // Apply grain to output with overlap-add
505                for (i, &sample) in self.samples[grain_start..grain_end].iter().enumerate() {
506                    let windowed_sample = sample * window[i];
507                    let output_index = (output_pos as usize) + i;
508                    if output_index < output_samples.len() {
509                        output_samples[output_index] += windowed_sample;
510                    }
511                }
512            }
513
514            // Advance positions
515            output_pos += period_length as f32 / pitch_factor;
516            input_idx += 1;
517        }
518
519        // Normalize output
520        let max_amplitude = output_samples.iter().map(|&s| s.abs()).fold(0.0, f32::max);
521        if max_amplitude > 1.0 {
522            let normalization_factor = 0.95 / max_amplitude;
523            for sample in &mut output_samples {
524                *sample *= normalization_factor;
525            }
526        }
527
528        Ok(AudioBuffer::new(
529            output_samples,
530            self.sample_rate,
531            self.channels,
532        ))
533    }
534
535    /// Detect pitch periods in the audio signal using autocorrelation
536    fn detect_pitch_periods(&self, min_period: usize, max_period: usize) -> Result<Vec<usize>> {
537        let mut periods = Vec::new();
538        let analysis_window = max_period * 4;
539
540        let mut pos = 0;
541        while pos + analysis_window < self.samples.len() {
542            // Extract analysis window
543            let window = &self.samples[pos..pos + analysis_window];
544
545            // Compute autocorrelation
546            let mut max_correlation = 0.0;
547            let mut best_period = min_period;
548
549            for period in min_period..=max_period.min(analysis_window / 2) {
550                let mut correlation = 0.0;
551                let mut energy = 0.0;
552
553                for i in 0..(analysis_window - period) {
554                    correlation += window[i] * window[i + period];
555                    energy += window[i] * window[i];
556                }
557
558                if energy > 0.0 {
559                    let normalized_correlation = correlation / energy;
560                    if normalized_correlation > max_correlation {
561                        max_correlation = normalized_correlation;
562                        best_period = period;
563                    }
564                }
565            }
566
567            // Only accept periods with sufficient correlation
568            if max_correlation > 0.3 {
569                periods.push(pos + best_period);
570                pos += best_period;
571            } else {
572                pos += min_period; // Move forward by minimum period
573            }
574        }
575
576        Ok(periods)
577    }
578
579    /// Apply dynamic range compression
580    pub fn compress(
581        &mut self,
582        threshold: f32,
583        ratio: f32,
584        attack_ms: f32,
585        release_ms: f32,
586    ) -> Result<()> {
587        let attack_coeff = (-1.0 / (attack_ms * 0.001 * self.sample_rate as f32)).exp();
588        let release_coeff = (-1.0 / (release_ms * 0.001 * self.sample_rate as f32)).exp();
589
590        let mut envelope = 0.0;
591
592        for sample in &mut self.samples {
593            let input_level = sample.abs();
594
595            // Update envelope
596            if input_level > envelope {
597                envelope = input_level + (envelope - input_level) * attack_coeff;
598            } else {
599                envelope = input_level + (envelope - input_level) * release_coeff;
600            }
601
602            // Apply compression
603            if envelope > threshold {
604                let excess = envelope - threshold;
605                let compressed_excess = excess / ratio;
606                let gain_reduction = (threshold + compressed_excess) / envelope;
607                *sample *= gain_reduction;
608            }
609        }
610
611        self.update_metadata();
612        Ok(())
613    }
614
615    /// Apply reverb effect (simple delay-based reverb)
616    pub fn reverb(&mut self, room_size: f32, damping: f32, wet_level: f32) -> Result<()> {
617        let delay_samples = (room_size * self.sample_rate as f32 * 0.1) as usize; // Max 100ms delay
618        let delay_samples = delay_samples.max(1);
619
620        let mut delay_buffer = vec![0.0; delay_samples];
621        let mut delay_index = 0;
622
623        for sample in &mut self.samples {
624            // Read from delay buffer
625            let delayed_sample = delay_buffer[delay_index];
626
627            // Apply damping (lowpass filter)
628            let damped_sample = delayed_sample * (1.0 - damping) + *sample * damping;
629
630            // Write to delay buffer
631            delay_buffer[delay_index] = damped_sample;
632            delay_index = (delay_index + 1) % delay_samples;
633
634            // Mix wet and dry signals
635            *sample = *sample * (1.0 - wet_level) + delayed_sample * wet_level;
636        }
637
638        self.update_metadata();
639        Ok(())
640    }
641
642    /// Extract a portion of the audio buffer
643    pub fn extract(&self, start_seconds: f32, duration_seconds: f32) -> Result<AudioBuffer> {
644        let start_sample =
645            (start_seconds * self.sample_rate as f32 * self.channels as f32) as usize;
646        let duration_samples =
647            (duration_seconds * self.sample_rate as f32 * self.channels as f32) as usize;
648        let end_sample = (start_sample + duration_samples).min(self.samples.len());
649
650        if start_sample >= self.samples.len() {
651            return Err(VoirsError::audio_error("Start time exceeds audio duration"));
652        }
653
654        let extracted_samples = self.samples[start_sample..end_sample].to_vec();
655        Ok(AudioBuffer::new(
656            extracted_samples,
657            self.sample_rate,
658            self.channels,
659        ))
660    }
661
662    /// Calculate RMS (Root Mean Square) value for loudness
663    ///
664    /// Uses SIMD acceleration for improved performance on large buffers.
665    pub fn rms(&self) -> f32 {
666        if self.samples.is_empty() {
667            return 0.0;
668        }
669
670        // Use SIMD optimization for buffers larger than 64 samples
671        let sum_squares = if self.samples.len() > 64 && f32::simd_available() {
672            let samples_array = Array1::from_vec(self.samples.clone());
673            f32::simd_sum_squares(&samples_array.view())
674        } else {
675            self.samples.iter().map(|&s| s * s).sum()
676        };
677
678        (sum_squares / self.samples.len() as f32).sqrt()
679    }
680
681    /// Calculate peak amplitude
682    ///
683    /// Uses SIMD acceleration for improved performance on large buffers.
684    pub fn peak(&self) -> f32 {
685        // Use SIMD optimization for buffers larger than 64 samples
686        if self.samples.len() > 64 && f32::simd_available() {
687            let samples_array = Array1::from_vec(self.samples.clone());
688            let abs_samples = f32::simd_abs(&samples_array.view());
689            f32::simd_max_element(&abs_samples.view())
690        } else {
691            self.samples.iter().map(|&s| s.abs()).fold(0.0, f32::max)
692        }
693    }
694
695    /// Check if audio contains clipping
696    pub fn is_clipped(&self, threshold: f32) -> bool {
697        self.samples.iter().any(|&s| s.abs() >= threshold)
698    }
699
700    /// Apply soft clipping to prevent harsh distortion
701    pub fn soft_clip(&mut self, threshold: f32) -> Result<()> {
702        for sample in &mut self.samples {
703            if sample.abs() > threshold {
704                let sign = if *sample >= 0.0 { 1.0 } else { -1.0 };
705                *sample = sign * threshold * (1.0 - (-(*sample).abs() / threshold).exp());
706            }
707        }
708
709        self.update_metadata();
710        Ok(())
711    }
712}
713
714#[cfg(test)]
715mod tests {
716
717    use crate::audio::buffer::AudioBuffer;
718
719    #[test]
720    fn test_gain_application() {
721        let mut buffer = AudioBuffer::sine_wave(440.0, 0.1, 44100, 0.5);
722        let original_peak = buffer.metadata().peak_amplitude;
723
724        buffer.apply_gain(6.0).unwrap(); // +6dB gain
725
726        let new_peak = buffer.metadata().peak_amplitude;
727        assert!(new_peak > original_peak);
728    }
729
730    #[test]
731    fn test_normalization() {
732        let mut buffer = AudioBuffer::sine_wave(440.0, 0.1, 44100, 0.3);
733
734        buffer.normalize(0.8).unwrap();
735
736        let peak = buffer.metadata().peak_amplitude;
737        assert!((peak - 0.8).abs() < 0.01);
738    }
739
740    #[test]
741    fn test_mixing() {
742        let mut buffer1 = AudioBuffer::sine_wave(440.0, 0.1, 44100, 0.5);
743        let buffer2 = AudioBuffer::sine_wave(880.0, 0.1, 44100, 0.3);
744
745        let original_peak = buffer1.metadata().peak_amplitude;
746        buffer1.mix(&buffer2, 0.5).unwrap();
747
748        // Peak should be different after mixing
749        assert!(buffer1.metadata().peak_amplitude != original_peak);
750    }
751
752    #[test]
753    fn test_split() {
754        let buffer = AudioBuffer::sine_wave(440.0, 2.0, 44100, 0.5);
755
756        let (first, second) = buffer.split(1.0).unwrap();
757
758        assert!((first.duration() - 1.0).abs() < 0.01);
759        assert!((second.duration() - 1.0).abs() < 0.01);
760    }
761
762    #[test]
763    fn test_append() {
764        let mut buffer1 = AudioBuffer::sine_wave(440.0, 1.0, 44100, 0.5);
765        let buffer2 = AudioBuffer::sine_wave(880.0, 1.0, 44100, 0.3);
766
767        let original_duration = buffer1.duration();
768        buffer1.append(&buffer2).unwrap();
769
770        assert!((buffer1.duration() - 2.0 * original_duration).abs() < 0.01);
771    }
772
773    #[test]
774    fn test_fade_in_out() {
775        let mut buffer = AudioBuffer::sine_wave(440.0, 1.0, 44100, 0.5);
776
777        buffer.fade_in(0.1).unwrap();
778        buffer.fade_out(0.1).unwrap();
779
780        // First and last samples should be attenuated
781        assert!(buffer.samples()[0].abs() < 0.1);
782        assert!(buffer.samples()[buffer.len() - 1].abs() < 0.1);
783    }
784
785    #[test]
786    fn test_time_stretch() {
787        let buffer = AudioBuffer::sine_wave(440.0, 1.0, 44100, 0.5);
788
789        let stretched = buffer.time_stretch(2.0).unwrap();
790
791        // Duration should be halved
792        assert!((stretched.duration() - 0.5).abs() < 0.01);
793    }
794
795    #[test]
796    fn test_extract() {
797        let buffer = AudioBuffer::sine_wave(440.0, 2.0, 44100, 0.5);
798
799        let extracted = buffer.extract(0.5, 1.0).unwrap();
800
801        assert!((extracted.duration() - 1.0).abs() < 0.01);
802    }
803
804    #[test]
805    fn test_rms_calculation() {
806        let buffer = AudioBuffer::sine_wave(440.0, 1.0, 44100, 0.5);
807
808        let rms = buffer.rms();
809
810        // For a sine wave, RMS should be amplitude / sqrt(2)
811        assert!((rms - 0.5 / 2.0_f32.sqrt()).abs() < 0.01);
812    }
813
814    #[test]
815    fn test_clipping_detection() {
816        let mut buffer = AudioBuffer::sine_wave(440.0, 0.1, 44100, 1.5);
817
818        // Should detect clipping at 1.0 threshold
819        assert!(buffer.is_clipped(1.0));
820
821        // Apply soft clipping
822        buffer.soft_clip(0.95).unwrap();
823
824        // Should no longer clip at 0.95 threshold
825        assert!(!buffer.is_clipped(0.95));
826    }
827
828    #[test]
829    fn test_resampling() {
830        let buffer = AudioBuffer::sine_wave(440.0, 1.0, 44100, 0.5);
831
832        let resampled = buffer.resample(22050).unwrap();
833
834        assert_eq!(resampled.sample_rate(), 22050);
835        assert_eq!(resampled.len(), 22050); // Half the samples
836    }
837
838    #[test]
839    fn test_pitch_shift_phase_vocoder() {
840        let buffer = AudioBuffer::sine_wave(440.0, 1.0, 44100, 0.5);
841        let original_duration = buffer.duration();
842
843        // Test pitch shift up
844        let shifted_up = buffer.pitch_shift(12.0).unwrap(); // One octave up
845        let shifted_duration = shifted_up.duration();
846
847        // Duration should remain approximately the same for pitch shifting
848        let duration_diff = (shifted_duration - original_duration).abs();
849        assert!(
850            duration_diff < 0.01,
851            "Original duration: {original_duration}, Shifted duration: {shifted_duration}, Difference: {duration_diff}"
852        );
853
854        // Should have same length
855        assert_eq!(shifted_up.len(), buffer.len());
856
857        // Test pitch shift down
858        let shifted_down = buffer.pitch_shift(-12.0).unwrap(); // One octave down
859        assert_eq!(shifted_down.len(), buffer.len());
860
861        // Test no change
862        let no_shift = buffer.pitch_shift(0.0).unwrap();
863        assert_eq!(no_shift.len(), buffer.len());
864
865        // Test small shift
866        let small_shift = buffer.pitch_shift(2.0).unwrap(); // Two semitones up
867        assert_eq!(small_shift.len(), buffer.len());
868    }
869
870    #[test]
871    fn test_pitch_shift_psola() {
872        let buffer = AudioBuffer::sine_wave(440.0, 1.0, 44100, 0.5);
873        let original_duration = buffer.duration();
874
875        // Test PSOLA pitch shift up
876        let shifted_up = buffer.pitch_shift_psola(7.0).unwrap(); // Perfect fifth up
877        let shifted_duration = shifted_up.duration();
878
879        // Duration should remain approximately the same
880        let duration_diff = (shifted_duration - original_duration).abs();
881        assert!(
882            duration_diff < 0.01,
883            "Original duration: {original_duration}, Shifted duration: {shifted_duration}, Difference: {duration_diff}"
884        );
885
886        // Should have same length
887        assert_eq!(shifted_up.len(), buffer.len());
888
889        // Test PSOLA pitch shift down
890        let shifted_down = buffer.pitch_shift_psola(-7.0).unwrap(); // Perfect fifth down
891        assert_eq!(shifted_down.len(), buffer.len());
892
893        // Test no change
894        let no_shift = buffer.pitch_shift_psola(0.0).unwrap();
895        assert_eq!(no_shift.len(), buffer.len());
896    }
897
898    #[test]
899    fn test_pitch_detection() {
900        let buffer = AudioBuffer::sine_wave(440.0, 1.0, 44100, 0.5);
901
902        // Test pitch period detection
903        let min_period = (buffer.sample_rate as f32 / 800.0) as usize;
904        let max_period = (buffer.sample_rate as f32 / 50.0) as usize;
905        let periods = buffer.detect_pitch_periods(min_period, max_period).unwrap();
906
907        // Should detect some periods for a sine wave
908        assert!(!periods.is_empty(), "Should detect periods in a sine wave");
909
910        // All periods should be within valid range
911        for &period in &periods {
912            assert!(
913                period < buffer.samples.len(),
914                "Period {} exceeds buffer length {}",
915                period,
916                buffer.samples.len()
917            );
918        }
919    }
920
921    #[test]
922    fn test_pitch_shift_algorithms_comparison() {
923        let buffer = AudioBuffer::sine_wave(440.0, 1.0, 44100, 0.5);
924
925        // Test both algorithms with same pitch shift
926        let semitones = 5.0; // Perfect fourth up
927        let phase_vocoder_result = buffer.pitch_shift(semitones).unwrap();
928        let psola_result = buffer.pitch_shift_psola(semitones).unwrap();
929
930        // Both should produce same-length outputs
931        assert_eq!(phase_vocoder_result.len(), buffer.len());
932        assert_eq!(psola_result.len(), buffer.len());
933
934        // Both should have same sample rate and channels
935        assert_eq!(phase_vocoder_result.sample_rate, buffer.sample_rate);
936        assert_eq!(psola_result.sample_rate, buffer.sample_rate);
937        assert_eq!(phase_vocoder_result.channels, buffer.channels);
938        assert_eq!(psola_result.channels, buffer.channels);
939
940        // Both should have similar durations
941        let pv_duration = phase_vocoder_result.duration();
942        let psola_duration = psola_result.duration();
943        let duration_diff = (pv_duration - psola_duration).abs();
944        assert!(
945            duration_diff < 0.01,
946            "Phase vocoder duration: {pv_duration}, PSOLA duration: {psola_duration}"
947        );
948    }
949
950    #[test]
951    fn test_pitch_shift_edge_cases() {
952        let buffer = AudioBuffer::sine_wave(440.0, 1.0, 44100, 0.5);
953
954        // Test extreme pitch shifts
955        let extreme_up = buffer.pitch_shift(24.0).unwrap(); // Two octaves up
956        assert_eq!(extreme_up.len(), buffer.len());
957
958        let extreme_down = buffer.pitch_shift(-24.0).unwrap(); // Two octaves down
959        assert_eq!(extreme_down.len(), buffer.len());
960
961        // Test fractional semitones
962        let fractional = buffer.pitch_shift(1.5).unwrap(); // 1.5 semitones up
963        assert_eq!(fractional.len(), buffer.len());
964
965        // Test negative fractional semitones
966        let neg_fractional = buffer.pitch_shift(-2.5).unwrap(); // 2.5 semitones down
967        assert_eq!(neg_fractional.len(), buffer.len());
968    }
969
970    #[test]
971    fn test_pitch_shift_quality() {
972        let buffer = AudioBuffer::sine_wave(440.0, 1.0, 44100, 0.5);
973
974        // Test that pitch shifting doesn't introduce excessive artifacts
975        let shifted = buffer.pitch_shift(12.0).unwrap();
976
977        // Check that output is not silent
978        let max_amplitude = shifted.samples.iter().map(|&s| s.abs()).fold(0.0, f32::max);
979        assert!(max_amplitude > 0.0, "Output should not be silent");
980
981        // Check that output doesn't clip
982        assert!(
983            max_amplitude <= 1.0,
984            "Output should not exceed [-1, 1] range"
985        );
986
987        // Check that there's some variation in the output (not all zeros)
988        let mut has_variation = false;
989        let first_sample = shifted.samples[0];
990        for &sample in &shifted.samples {
991            if (sample - first_sample).abs() > 0.001 {
992                has_variation = true;
993                break;
994            }
995        }
996        assert!(has_variation, "Output should have some variation");
997    }
998}