// voirs_conversion/transforms.rs
1//! Voice transformation algorithms
2
3use crate::{Error, Result};
4use scirs2_core::Complex;
5use scirs2_fft::RealFftPlanner;
6use serde::{Deserialize, Serialize};
7use std::f32::consts::PI;
8
9/// Generic transform trait for audio transformation operations
/// Generic transform trait for audio transformation operations.
///
/// Implementors consume a sample buffer and produce a transformed buffer;
/// the output length may differ from the input length (e.g. time stretching).
pub trait Transform {
    /// Apply transform to audio samples (f32 PCM; assumed mono — TODO confirm).
    ///
    /// # Errors
    /// Returns an error when the transform cannot be applied to the input.
    fn apply(&self, input: &[f32]) -> Result<Vec<f32>>;

    /// Get transform parameters as a name → value map.
    /// Boolean parameters are encoded as 0.0 / 1.0.
    fn get_parameters(&self) -> std::collections::HashMap<String, f32>;
}
17
18/// Pitch transformation
/// Pitch transformation.
///
/// Shifts the perceived pitch of a signal by `pitch_factor`, optionally
/// preserving formants via a phase-vocoder implementation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PitchTransform {
    /// Pitch shift factor (1.0 = no change, 2.0 = one octave up)
    pub pitch_factor: f32,
    /// Preserve formants (phase vocoder when true; simple scaling otherwise)
    pub preserve_formants: bool,
}
26
27impl PitchTransform {
28    /// Create new pitch transform
29    pub fn new(pitch_factor: f32) -> Self {
30        Self {
31            pitch_factor,
32            preserve_formants: true,
33        }
34    }
35}
36
37impl Transform for PitchTransform {
38    fn apply(&self, input: &[f32]) -> Result<Vec<f32>> {
39        if input.is_empty() {
40            return Ok(input.to_vec());
41        }
42
43        if (self.pitch_factor - 1.0).abs() < f32::EPSILON {
44            return Ok(input.to_vec());
45        }
46
47        // Use phase vocoder for high-quality pitch shifting
48        if self.preserve_formants {
49            self.apply_phase_vocoder_pitch_shift(input)
50        } else {
51            self.apply_simple_pitch_shift(input)
52        }
53    }
54
55    fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
56        let mut params = std::collections::HashMap::new();
57        params.insert("pitch_factor".to_string(), self.pitch_factor);
58        params.insert(
59            "preserve_formants".to_string(),
60            if self.preserve_formants { 1.0 } else { 0.0 },
61        );
62        params
63    }
64}
65
66/// Speed transformation
/// Speed transformation.
///
/// Changes playback speed; can optionally preserve the perceived pitch.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SpeedTransform {
    /// Speed factor (1.0 = no change, 2.0 = double speed / half duration)
    pub speed_factor: f32,
    /// Preserve pitch (PSOLA when true; plain resampling otherwise)
    pub preserve_pitch: bool,
}
74
75impl SpeedTransform {
76    /// Create new speed transform
77    pub fn new(speed_factor: f32) -> Self {
78        Self {
79            speed_factor,
80            preserve_pitch: true,
81        }
82    }
83}
84
85impl Transform for SpeedTransform {
86    fn apply(&self, input: &[f32]) -> Result<Vec<f32>> {
87        if input.is_empty() {
88            return Ok(input.to_vec());
89        }
90
91        if (self.speed_factor - 1.0).abs() < f32::EPSILON {
92            return Ok(input.to_vec());
93        }
94
95        if self.preserve_pitch {
96            // Use PSOLA (Pitch Synchronous Overlap and Add) for pitch preservation
97            self.apply_psola_time_stretch(input)
98        } else {
99            // Simple resampling without pitch preservation
100            self.apply_linear_interpolation(input)
101        }
102    }
103
104    fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
105        let mut params = std::collections::HashMap::new();
106        params.insert("speed_factor".to_string(), self.speed_factor);
107        params.insert(
108            "preserve_pitch".to_string(),
109            if self.preserve_pitch { 1.0 } else { 0.0 },
110        );
111        params
112    }
113}
114
115/// Age transformation
/// Age transformation.
///
/// Maps a voice from `source_age` to `target_age` by simulating age-related
/// vocal tract changes (formant scaling, tremor, brightness).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AgeTransform {
    /// Target age (in years)
    pub target_age: f32,
    /// Source age (in years)
    pub source_age: f32,
}
123
124impl AgeTransform {
125    /// Create new age transform
126    pub fn new(source_age: f32, target_age: f32) -> Self {
127        Self {
128            target_age,
129            source_age,
130        }
131    }
132}
133
134impl Transform for AgeTransform {
135    fn apply(&self, input: &[f32]) -> Result<Vec<f32>> {
136        if input.is_empty() {
137            return Ok(input.to_vec());
138        }
139
140        if (self.target_age - self.source_age).abs() < 1.0 {
141            return Ok(input.to_vec());
142        }
143
144        // Apply age-related vocal tract modifications
145        self.apply_age_related_modifications(input)
146    }
147
148    fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
149        let mut params = std::collections::HashMap::new();
150        params.insert("target_age".to_string(), self.target_age);
151        params.insert("source_age".to_string(), self.source_age);
152        params
153    }
154}
155
156/// Gender transformation
/// Gender transformation.
///
/// Shifts voice characteristics along a male/female axis via formant and
/// amplitude adjustments.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GenderTransform {
    /// Target gender (-1.0 = male, 0.0 = neutral, 1.0 = female)
    pub target_gender: f32,
    /// Formant shift strength in [0, 1] (scales the formant adjustment)
    pub formant_shift_strength: f32,
}
164
165impl GenderTransform {
166    /// Create new gender transform
167    pub fn new(target_gender: f32) -> Self {
168        Self {
169            target_gender: target_gender.clamp(-1.0, 1.0),
170            formant_shift_strength: 0.5,
171        }
172    }
173}
174
175impl Transform for GenderTransform {
176    fn apply(&self, input: &[f32]) -> Result<Vec<f32>> {
177        if input.is_empty() {
178            return Ok(input.to_vec());
179        }
180
181        if self.target_gender.abs() < f32::EPSILON {
182            return Ok(input.to_vec());
183        }
184
185        // Apply gender-specific formant and pitch modifications
186        self.apply_gender_modifications(input)
187    }
188
189    fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
190        let mut params = std::collections::HashMap::new();
191        params.insert("target_gender".to_string(), self.target_gender);
192        params.insert(
193            "formant_shift_strength".to_string(),
194            self.formant_shift_strength,
195        );
196        params
197    }
198}
199
200/// Voice morpher for blending multiple voices with various interpolation methods
201#[derive(Debug, Clone)]
202pub struct VoiceMorpher {
203    /// Voice blend weights
204    pub blend_weights: Vec<f32>,
205    /// Source voices
206    pub source_voices: Vec<String>,
207    /// Morphing method
208    pub method: MorphingMethod,
209    /// Spectral interpolation strength
210    pub spectral_strength: f32,
211}
212
213/// Methods for voice morphing between multiple sources
214#[derive(Debug, Clone, Copy, PartialEq, Eq)]
215pub enum MorphingMethod {
216    /// Linear blending in time domain
217    LinearBlend,
218    /// Spectral interpolation
219    SpectralInterpolation,
220    /// Cross-fade morphing
221    CrossFade,
222    /// Feature-based morphing
223    FeatureBased,
224}
225
impl VoiceMorpher {
    /// Create new voice morpher with linear blending and default spectral
    /// strength (0.5).
    ///
    /// `blend_weights` pairs positionally with `source_voices`.
    pub fn new(source_voices: Vec<String>, blend_weights: Vec<f32>) -> Self {
        Self {
            blend_weights,
            source_voices,
            method: MorphingMethod::LinearBlend,
            spectral_strength: 0.5,
        }
    }

    /// Create morpher with a specific method (builder style).
    pub fn with_method(mut self, method: MorphingMethod) -> Self {
        self.method = method;
        self
    }

    /// Set spectral interpolation strength, clamped to [0, 1].
    pub fn with_spectral_strength(mut self, strength: f32) -> Self {
        self.spectral_strength = strength.clamp(0.0, 1.0);
        self
    }

    /// Morph between voices according to the configured method.
    ///
    /// # Errors
    /// Returns an error when `inputs` is empty. A single input is returned
    /// unchanged.
    pub fn morph(&self, inputs: &[Vec<f32>]) -> Result<Vec<f32>> {
        if inputs.is_empty() {
            return Err(Error::transform("No input voices for morphing".to_string()));
        }

        if inputs.len() == 1 {
            return Ok(inputs[0].clone());
        }

        match self.method {
            MorphingMethod::LinearBlend => self.linear_blend(inputs),
            MorphingMethod::SpectralInterpolation => self.spectral_interpolation(inputs),
            MorphingMethod::CrossFade => self.cross_fade(inputs),
            MorphingMethod::FeatureBased => self.feature_based_morph(inputs),
        }
    }

    // Weighted time-domain sum of all inputs. The output length is the
    // longest input; shorter inputs implicitly contribute zeros past their end.
    fn linear_blend(&self, inputs: &[Vec<f32>]) -> Result<Vec<f32>> {
        let output_len = inputs.iter().map(|v| v.len()).max().unwrap_or(0);
        let mut output = vec![0.0; output_len];

        // Normalize weights to sum to 1; fall back to equal weighting when
        // the configured weights sum to zero or less.
        let total_weight: f32 = self.blend_weights.iter().sum();
        let normalized_weights: Vec<f32> = if total_weight > 0.0 {
            self.blend_weights
                .iter()
                .map(|w| w / total_weight)
                .collect()
        } else {
            vec![1.0 / inputs.len() as f32; inputs.len()]
        };

        for (i, input) in inputs.iter().enumerate() {
            // Inputs without a matching weight entry contribute nothing.
            let weight = normalized_weights.get(i).copied().unwrap_or(0.0);
            for (j, &sample) in input.iter().enumerate() {
                if j < output_len {
                    output[j] += sample * weight;
                }
            }
        }

        Ok(output)
    }

    // Two-input spectral morph; any other input count falls back to linear blend.
    fn spectral_interpolation(&self, inputs: &[Vec<f32>]) -> Result<Vec<f32>> {
        if inputs.len() != 2 {
            // Fall back to linear blend for more than 2 inputs
            return self.linear_blend(inputs);
        }

        let input1 = &inputs[0];
        let input2 = &inputs[1];
        // NOTE(review): the *raw* second weight (not the normalized one) is
        // used as the blend factor, defaulting to 0.5 — confirm intended.
        let blend_factor = self.blend_weights.get(1).copied().unwrap_or(0.5);

        // Perform spectral interpolation using FFT
        self.spectral_blend(input1, input2, blend_factor)
    }

    // Frame-wise FFT magnitude/phase interpolation between two signals,
    // reconstructed by overlap-add.
    fn spectral_blend(
        &self,
        input1: &[f32],
        input2: &[f32],
        blend_factor: f32,
    ) -> Result<Vec<f32>> {
        let window_size = 1024;
        let min_len = input1.len().min(input2.len());

        if min_len < window_size {
            // For short audio, use time-domain blending
            let mut output = vec![0.0; min_len];
            for i in 0..min_len {
                output[i] = input1[i] * (1.0 - blend_factor) + input2[i] * blend_factor;
            }
            return Ok(output);
        }

        let mut planner = RealFftPlanner::<f32>::new();
        let fft = planner.plan_fft_forward(window_size);
        let ifft = planner.plan_fft_inverse(window_size);

        let mut output = Vec::new();
        let hop_size = window_size / 4; // 75% frame overlap

        // NOTE(review): frames past `min_len - window_size` are dropped, so
        // the signal tail is truncated; and the overlap-add below applies no
        // window-gain normalization, so output amplitude depends on the
        // hop/window ratio. Confirm both are intended.
        for window_start in (0..min_len.saturating_sub(window_size)).step_by(hop_size) {
            let window_end = (window_start + window_size).min(min_len);

            // Extract windows
            let mut window1 = vec![0.0; window_size];
            let mut window2 = vec![0.0; window_size];

            // Hann-window both inputs to reduce spectral leakage.
            for (i, (&s1, &s2)) in input1[window_start..window_end]
                .iter()
                .zip(input2[window_start..window_end].iter())
                .enumerate()
            {
                let hann = 0.5 - 0.5 * (2.0 * PI * i as f32 / (window_size - 1) as f32).cos();
                window1[i] = s1 * hann;
                window2[i] = s2 * hann;
            }

            // FFT (real-input transform: window_size/2 + 1 bins)
            let mut spectrum1 = vec![Complex::new(0.0, 0.0); window_size / 2 + 1];
            let mut spectrum2 = vec![Complex::new(0.0, 0.0); window_size / 2 + 1];

            fft.process(&window1, &mut spectrum1);
            fft.process(&window2, &mut spectrum2);

            // Interpolate in frequency domain
            let mut blended_spectrum = vec![Complex::new(0.0, 0.0); window_size / 2 + 1];
            for (i, (s1, s2)) in spectrum1.iter().zip(spectrum2.iter()).enumerate() {
                let mag1 = s1.norm();
                let mag2 = s2.norm();
                let phase1 = s1.arg();
                let phase2 = s2.arg();

                // Interpolate magnitude and phase independently.
                // NOTE(review): naive linear phase interpolation ignores
                // phase wrapping, which can cause artifacts near ±π.
                let blended_mag = mag1 * (1.0 - blend_factor) + mag2 * blend_factor;
                let blended_phase = phase1 * (1.0 - blend_factor) + phase2 * blend_factor;

                // Ensure DC and Nyquist components are real-valued
                if i == 0 || i == blended_spectrum.len() - 1 {
                    // DC and Nyquist components must be purely real
                    blended_spectrum[i] = Complex::new(blended_mag, 0.0);
                } else {
                    blended_spectrum[i] = Complex::new(
                        blended_mag * blended_phase.cos(),
                        blended_mag * blended_phase.sin(),
                    );
                }
            }

            // IFFT
            let mut time_domain = vec![0.0; window_size];
            ifft.process(&blended_spectrum, &mut time_domain);

            // Overlap-add
            for (i, &sample) in time_domain.iter().enumerate() {
                let output_idx = window_start + i;
                if output_idx >= output.len() {
                    output.resize(output_idx + 1, 0.0);
                }
                output[output_idx] += sample;
            }
        }

        Ok(output)
    }

    // Fade from the first input to the second across the overlapping length.
    fn cross_fade(&self, inputs: &[Vec<f32>]) -> Result<Vec<f32>> {
        if inputs.len() != 2 {
            return self.linear_blend(inputs);
        }

        let input1 = &inputs[0];
        let input2 = &inputs[1];
        let min_len = input1.len().min(input2.len());
        let mut output = vec![0.0; min_len];

        for i in 0..min_len {
            // Linear fade weight: 0 at the start (all input1), → 1 (all input2).
            let fade_factor = i as f32 / min_len as f32;
            output[i] = input1[i] * (1.0 - fade_factor) + input2[i] * fade_factor;
        }

        Ok(output)
    }

    // Placeholder: delegates to linear blending until feature extraction
    // (formants, pitch, etc.) is implemented.
    fn feature_based_morph(&self, inputs: &[Vec<f32>]) -> Result<Vec<f32>> {
        // Simplified feature-based morphing
        // In a full implementation, this would extract and morph features like formants, pitch, etc.
        self.linear_blend(inputs)
    }
}
422
423impl Transform for VoiceMorpher {
424    fn apply(&self, input: &[f32]) -> Result<Vec<f32>> {
425        // For single input, just return weighted result
426        let weight = self.blend_weights.first().copied().unwrap_or(1.0);
427        Ok(input.iter().map(|x| x * weight).collect())
428    }
429
430    fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
431        let mut params = std::collections::HashMap::new();
432        for (i, &weight) in self.blend_weights.iter().enumerate() {
433            params.insert(format!("weight_{i}"), weight);
434        }
435        params.insert("spectral_strength".to_string(), self.spectral_strength);
436        params.insert("method".to_string(), self.method as u8 as f32);
437        params
438    }
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pitch_transform() {
        // Input is shorter than the 1024-sample analysis window, so the
        // phase-vocoder path falls back to simple amplitude scaling (×2).
        let transform = PitchTransform::new(2.0);
        let input = vec![0.1, 0.2, 0.3];
        let output = transform.apply(&input).unwrap();

        assert_eq!(output.len(), input.len());
        assert_eq!(output[0], 0.2);
        assert_eq!(output[1], 0.4);
    }

    #[test]
    fn test_speed_transform() {
        // preserve_pitch defaults to true, so this exercises the PSOLA path.
        let transform = SpeedTransform::new(2.0);
        let input = vec![0.1, 0.2, 0.3, 0.4];
        let output = transform.apply(&input).unwrap();

        assert_eq!(output.len(), 2); // Half length due to 2x speed
    }

    #[test]
    fn test_age_transform() {
        // 30 → 60 years takes the adult formant-scaling path; length preserved.
        let transform = AgeTransform::new(30.0, 60.0);
        let input = vec![0.1, 0.2, 0.3];
        let output = transform.apply(&input).unwrap();

        assert_eq!(output.len(), input.len());
    }

    #[test]
    fn test_gender_transform() {
        let transform = GenderTransform::new(1.0); // Female
        let input = vec![0.1, 0.2, 0.3];
        let output = transform.apply(&input).unwrap();

        assert_eq!(output.len(), input.len());
        assert!(output[0] > input[0]); // Should be scaled up
    }

    #[test]
    fn test_voice_morpher() {
        // Equal weights: output is the per-sample average of both inputs.
        let morpher = VoiceMorpher::new(
            vec!["voice1".to_string(), "voice2".to_string()],
            vec![0.5, 0.5],
        );

        let inputs = vec![vec![0.1, 0.2], vec![0.3, 0.4]];

        let output = morpher.morph(&inputs).unwrap();
        assert_eq!(output.len(), 2);
        assert_eq!(output[0], 0.2); // (0.1 * 0.5) + (0.3 * 0.5)
        assert_eq!(output[1], 0.3); // (0.2 * 0.5) + (0.4 * 0.5)
    }
}
499
500// Implementation methods for transforms
501
502impl PitchTransform {
503    /// Apply phase vocoder pitch shifting with formant preservation
504    fn apply_phase_vocoder_pitch_shift(&self, input: &[f32]) -> Result<Vec<f32>> {
505        let window_size = 1024;
506        let hop_size = window_size / 4;
507        let overlap = window_size - hop_size;
508
509        if input.len() < window_size {
510            // For short audio, use simple pitch scaling
511            return self.apply_simple_pitch_shift(input);
512        }
513
514        let mut planner = RealFftPlanner::<f32>::new();
515        let fft = planner.plan_fft_forward(window_size);
516        let ifft = planner.plan_fft_inverse(window_size);
517
518        let mut output = Vec::new();
519        let mut phase_accum = vec![0.0; window_size / 2 + 1];
520        let mut last_phase = vec![0.0; window_size / 2 + 1];
521
522        // Process overlapping windows
523        for window_start in (0..input.len().saturating_sub(window_size)).step_by(hop_size) {
524            let window_end = (window_start + window_size).min(input.len());
525            let mut window = vec![0.0; window_size];
526
527            // Copy window with Hann windowing
528            for (i, &sample) in input[window_start..window_end].iter().enumerate() {
529                let hann = 0.5 - 0.5 * (2.0 * PI * i as f32 / (window_size - 1) as f32).cos();
530                window[i] = sample * hann;
531            }
532
533            // Forward FFT
534            let mut spectrum = vec![Complex::new(0.0, 0.0); window_size / 2 + 1];
535            fft.process(&window, &mut spectrum);
536
537            // Phase vocoder processing
538            let mut modified_spectrum = vec![Complex::new(0.0, 0.0); window_size / 2 + 1];
539
540            for (k, &bin) in spectrum.iter().enumerate() {
541                let magnitude = bin.norm();
542                let phase = bin.arg();
543
544                // Calculate expected phase advance
545                let expected_phase_advance =
546                    2.0 * PI * k as f32 * hop_size as f32 / window_size as f32;
547                let phase_diff = phase - last_phase[k] - expected_phase_advance;
548
549                // Wrap phase difference to [-π, π]
550                let wrapped_phase_diff = ((phase_diff + PI) % (2.0 * PI)) - PI;
551
552                // Calculate instantaneous frequency
553                let inst_freq = (k as f32 + wrapped_phase_diff / (2.0 * PI)) * self.pitch_factor;
554
555                // Update phase accumulator
556                phase_accum[k] += inst_freq * 2.0 * PI * hop_size as f32 / window_size as f32;
557
558                // Create modified spectrum
559                let new_k = (inst_freq.round() as usize).min(spectrum.len() - 1);
560                if new_k < modified_spectrum.len() {
561                    // Ensure DC and Nyquist components are real-valued
562                    if new_k == 0 || new_k == modified_spectrum.len() - 1 {
563                        // DC and Nyquist components must be purely real
564                        modified_spectrum[new_k] = Complex::new(magnitude, 0.0);
565                    } else {
566                        modified_spectrum[new_k] = Complex::new(
567                            magnitude * phase_accum[k].cos(),
568                            magnitude * phase_accum[k].sin(),
569                        );
570                    }
571                }
572
573                last_phase[k] = phase;
574            }
575
576            // Inverse FFT
577            let mut time_domain = vec![0.0; window_size];
578            ifft.process(&modified_spectrum, &mut time_domain);
579
580            // Apply window and overlap-add
581            for (i, &sample) in time_domain.iter().enumerate() {
582                let hann = 0.5 - 0.5 * (2.0 * PI * i as f32 / (window_size - 1) as f32).cos();
583                let windowed_sample = sample * hann;
584
585                let output_idx = window_start + i;
586                if output_idx >= output.len() {
587                    output.resize(output_idx + 1, 0.0);
588                }
589                output[output_idx] += windowed_sample;
590            }
591        }
592
593        Ok(output)
594    }
595
596    /// Apply simple pitch shifting using time-domain scaling
597    fn apply_simple_pitch_shift(&self, input: &[f32]) -> Result<Vec<f32>> {
598        if self.pitch_factor == 1.0 {
599            return Ok(input.to_vec());
600        }
601
602        // For simple pitch shift, maintain same length but apply frequency scaling
603        // This is a simplified version - real pitch shifting would use PSOLA or phase vocoder
604        let mut output = Vec::with_capacity(input.len());
605
606        for &sample in input {
607            // Simple approach: scale amplitude based on pitch factor for testing
608            let scaled_sample = sample * self.pitch_factor;
609            output.push(scaled_sample);
610        }
611
612        Ok(output)
613    }
614}
615
616impl SpeedTransform {
617    /// Apply PSOLA-based time stretching with pitch preservation
618    fn apply_psola_time_stretch(&self, input: &[f32]) -> Result<Vec<f32>> {
619        // Simplified PSOLA implementation
620        // In a full implementation, this would involve pitch period detection
621        // and pitch-synchronous windowing
622
623        let pitch_period = 100; // Estimated pitch period in samples
624        let output_len = (input.len() as f32 / self.speed_factor) as usize;
625        let mut output = vec![0.0; output_len];
626
627        let mut input_pos = 0;
628        let mut output_pos = 0;
629
630        while input_pos + pitch_period < input.len() && output_pos + pitch_period < output.len() {
631            // Extract pitch period
632            let period_start = input_pos;
633            let period_end = (input_pos + pitch_period).min(input.len());
634
635            // Apply Hann window to the period
636            for i in 0..(period_end - period_start) {
637                let hann = 0.5 - 0.5 * (2.0 * PI * i as f32 / pitch_period as f32).cos();
638                let sample = input[period_start + i] * hann;
639
640                if output_pos + i < output.len() {
641                    output[output_pos + i] += sample;
642                }
643            }
644
645            // Advance positions
646            input_pos += (pitch_period as f32 * self.speed_factor) as usize;
647            output_pos += pitch_period;
648        }
649
650        Ok(output)
651    }
652
653    /// Apply linear interpolation for speed change
654    fn apply_linear_interpolation(&self, input: &[f32]) -> Result<Vec<f32>> {
655        let output_len = (input.len() as f32 / self.speed_factor) as usize;
656        let mut output = Vec::with_capacity(output_len);
657
658        for i in 0..output_len {
659            let src_idx = i as f32 * self.speed_factor;
660            let idx = src_idx as usize;
661
662            if idx + 1 < input.len() {
663                let frac = src_idx - idx as f32;
664                let sample = input[idx] * (1.0 - frac) + input[idx + 1] * frac;
665                output.push(sample);
666            } else if idx < input.len() {
667                output.push(input[idx]);
668            } else {
669                output.push(0.0);
670            }
671        }
672
673        Ok(output)
674    }
675}
676
677impl AgeTransform {
678    /// Apply age-related vocal tract modifications
679    fn apply_age_related_modifications(&self, input: &[f32]) -> Result<Vec<f32>> {
680        let mut output = input.to_vec();
681
682        // Age affects vocal tract length and thus formant frequencies
683        let age_ratio = self.target_age / self.source_age.max(1.0);
684
685        // Children have shorter vocal tracts (higher formants)
686        // Adults have longer vocal tracts (lower formants)
687        let formant_shift = if self.target_age < 18.0 {
688            // Child voice: higher formants, brighter sound
689            1.0 + (18.0 - self.target_age) * 0.02
690        } else if self.target_age > 60.0 {
691            // Elderly voice: slightly lower formants, reduced breath support
692            1.0 - (self.target_age - 60.0) * 0.01
693        } else {
694            // Adult voice
695            age_ratio.sqrt()
696        };
697
698        // Apply spectral modifications
699        output = self.apply_spectral_scaling(&output, formant_shift)?;
700
701        // Apply age-related amplitude modifications
702        if self.target_age > 60.0 {
703            // Elderly voice: add slight tremor and reduced amplitude
704            output = self.apply_age_tremor(&output)?;
705        } else if self.target_age < 12.0 {
706            // Child voice: higher pitch variability
707            output = self.apply_child_characteristics(&output)?;
708        }
709
710        Ok(output)
711    }
712
713    fn apply_spectral_scaling(&self, input: &[f32], scale_factor: f32) -> Result<Vec<f32>> {
714        // Simplified spectral scaling
715        Ok(input.iter().map(|&x| x * scale_factor).collect())
716    }
717
718    fn apply_age_tremor(&self, input: &[f32]) -> Result<Vec<f32>> {
719        // Add slight tremor characteristic of elderly voices
720        let tremor_freq = 6.0; // Hz
721        let tremor_depth = 0.05;
722
723        Ok(input
724            .iter()
725            .enumerate()
726            .map(|(i, &x)| {
727                let tremor =
728                    1.0 + tremor_depth * (2.0 * PI * tremor_freq * i as f32 / 22050.0).sin();
729                x * tremor
730            })
731            .collect())
732    }
733
734    fn apply_child_characteristics(&self, input: &[f32]) -> Result<Vec<f32>> {
735        // Apply characteristics of child voices
736        Ok(input.iter().map(|&x| x * 1.1).collect()) // Slightly brighter
737    }
738}
739
740impl GenderTransform {
741    /// Apply gender-specific modifications
742    fn apply_gender_modifications(&self, input: &[f32]) -> Result<Vec<f32>> {
743        let mut output = input.to_vec();
744
745        if self.target_gender > 0.0 {
746            // Feminize: raise formants, adjust pitch contour
747            output = self.apply_feminization(&output)?;
748        } else if self.target_gender < 0.0 {
749            // Masculinize: lower formants, deepen resonance
750            output = self.apply_masculinization(&output)?;
751        }
752
753        Ok(output)
754    }
755
756    fn apply_feminization(&self, input: &[f32]) -> Result<Vec<f32>> {
757        // Raise formant frequencies (shorter vocal tract simulation)
758        let formant_shift = 1.0 + (self.target_gender * self.formant_shift_strength * 0.15);
759
760        // Apply formant shifting and brightness enhancement
761        let mut output = input
762            .iter()
763            .map(|&x| x * formant_shift)
764            .collect::<Vec<f32>>();
765
766        // Add slight breathiness
767        for (i, sample) in output.iter_mut().enumerate() {
768            let breathiness = 0.02 * (i as f32 * 0.01).sin();
769            *sample += breathiness * self.target_gender;
770        }
771
772        Ok(output)
773    }
774
775    fn apply_masculinization(&self, input: &[f32]) -> Result<Vec<f32>> {
776        // Lower formant frequencies (longer vocal tract simulation)
777        let formant_shift = 1.0 + (self.target_gender * self.formant_shift_strength * 0.15);
778
779        // Apply formant shifting and add depth
780        let output = input
781            .iter()
782            .map(|&x| x * formant_shift)
783            .collect::<Vec<f32>>();
784
785        Ok(output)
786    }
787}
788
789/// Multi-channel audio data structure with per-channel samples
790#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
791pub struct MultiChannelAudio {
792    /// Audio samples organized as \[channel\]\[sample\]
793    pub channels: Vec<Vec<f32>>,
794    /// Sample rate
795    pub sample_rate: u32,
796}
797
798impl MultiChannelAudio {
799    /// Create new multi-channel audio
800    pub fn new(channels: Vec<Vec<f32>>, sample_rate: u32) -> Result<Self> {
801        if channels.is_empty() {
802            return Err(Error::transform("No channels provided".to_string()));
803        }
804
805        // Validate all channels have the same length
806        let first_len = channels[0].len();
807        if !channels.iter().all(|ch| ch.len() == first_len) {
808            return Err(Error::transform(
809                "All channels must have the same length".to_string(),
810            ));
811        }
812
813        Ok(Self {
814            channels,
815            sample_rate,
816        })
817    }
818
819    /// Create from interleaved samples
820    pub fn from_interleaved(data: &[f32], num_channels: usize, sample_rate: u32) -> Result<Self> {
821        if num_channels == 0 {
822            return Err(Error::Transform {
823                transform_type: "channel_validation".to_string(),
824                message: "Number of channels must be greater than 0".to_string(),
825                context: None,
826                recovery_suggestions: Box::new(vec![
827                    "Ensure num_channels parameter is greater than 0".to_string(),
828                    "Check audio format specification".to_string(),
829                ]),
830            });
831        }
832
833        if !data.len().is_multiple_of(num_channels) {
834            return Err(Error::Transform {
835                transform_type: "interleaved_validation".to_string(),
836                message: "Data length must be divisible by number of channels".to_string(),
837                context: None,
838                recovery_suggestions: Box::new(vec![
839                    "Ensure audio data length matches channel count".to_string(),
840                    "Verify audio format is properly structured".to_string(),
841                ]),
842            });
843        }
844
845        let samples_per_channel = data.len() / num_channels;
846        let mut channels = vec![Vec::with_capacity(samples_per_channel); num_channels];
847
848        for (i, &sample) in data.iter().enumerate() {
849            channels[i % num_channels].push(sample);
850        }
851
852        Ok(Self {
853            channels,
854            sample_rate,
855        })
856    }
857
858    /// Convert to interleaved samples
859    pub fn to_interleaved(&self) -> Vec<f32> {
860        let num_channels = self.channels.len();
861        let samples_per_channel = self.channels[0].len();
862        let mut interleaved = Vec::with_capacity(num_channels * samples_per_channel);
863
864        for sample_idx in 0..samples_per_channel {
865            for channel in &self.channels {
866                interleaved.push(channel[sample_idx]);
867            }
868        }
869
870        interleaved
871    }
872
    /// Get number of channels
    ///
    /// Returns 0 for an audio buffer with no channels.
    pub fn num_channels(&self) -> usize {
        self.channels.len()
    }
877
878    /// Get number of samples per channel
879    pub fn num_samples(&self) -> usize {
880        self.channels.first().map(|ch| ch.len()).unwrap_or(0)
881    }
882
883    /// Convert to mono by averaging channels
884    pub fn to_mono(&self) -> Vec<f32> {
885        let samples_per_channel = self.num_samples();
886        let num_channels = self.num_channels() as f32;
887
888        let mut mono = Vec::with_capacity(samples_per_channel);
889
890        for sample_idx in 0..samples_per_channel {
891            let sum: f32 = self.channels.iter().map(|ch| ch[sample_idx]).sum();
892            mono.push(sum / num_channels);
893        }
894
895        mono
896    }
897}
898
/// Multi-channel transform trait for processing multi-channel audio
///
/// Mirrors [`Transform`] but operates on a whole [`MultiChannelAudio`]
/// buffer instead of a single mono sample slice.
pub trait MultiChannelTransform {
    /// Apply transform to multi-channel audio
    ///
    /// # Errors
    ///
    /// Returns an error when per-channel processing fails or the input does
    /// not satisfy the implementation's channel-count requirements.
    fn apply_multichannel(&self, input: &MultiChannelAudio) -> Result<MultiChannelAudio>;

    /// Get transform parameters
    ///
    /// Returns a flat `name -> value` map describing the transform.
    fn get_parameters(&self) -> std::collections::HashMap<String, f32>;
}
907
/// Channel processing strategy for multi-channel transformations
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ChannelStrategy {
    /// Process each channel independently
    Independent,
    /// Process channels with cross-channel correlation
    /// (each channel is lightly mixed with the others in proportion to
    /// their pairwise correlation before processing)
    Correlated,
    /// Convert to mono, process, then expand to multi-channel
    MonoExpanded,
    /// Use mid/side processing for stereo (requires exactly 2 channels)
    MidSide,
}
920
/// Multi-channel configuration defining processing parameters
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MultiChannelConfig {
    /// Processing strategy
    pub strategy: ChannelStrategy,
    /// Channel gains (for output balancing); one entry per channel —
    /// channels beyond the end of this list fall back to unity gain
    pub channel_gains: Vec<f32>,
    /// Enable channel crosstalk simulation
    pub enable_crosstalk: bool,
    /// Crosstalk amount (0.0-1.0): fraction of each sample replaced by the
    /// average of the other channels' samples at the same index
    pub crosstalk_amount: f32,
}
933
934impl Default for MultiChannelConfig {
935    fn default() -> Self {
936        Self {
937            strategy: ChannelStrategy::Independent,
938            channel_gains: vec![1.0, 1.0], // Default stereo
939            enable_crosstalk: false,
940            crosstalk_amount: 0.05,
941        }
942    }
943}
944
/// Multi-channel pitch transform with per-channel control
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MultiChannelPitchTransform {
    /// Base pitch transform; its factor is the fallback for channels that
    /// have no entry in `channel_pitch_factors`
    pub base_transform: PitchTransform,
    /// Multi-channel configuration
    pub config: MultiChannelConfig,
    /// Per-channel pitch adjustments; index `i` applies to channel `i`
    pub channel_pitch_factors: Vec<f32>,
}
955
956impl MultiChannelPitchTransform {
957    /// Create new multi-channel pitch transform
958    pub fn new(pitch_factor: f32, num_channels: usize) -> Self {
959        Self {
960            base_transform: PitchTransform::new(pitch_factor),
961            config: MultiChannelConfig {
962                channel_gains: vec![1.0; num_channels],
963                ..Default::default()
964            },
965            channel_pitch_factors: vec![pitch_factor; num_channels],
966        }
967    }
968
969    /// Create stereo pitch transform with independent channel factors
970    pub fn stereo(left_pitch: f32, right_pitch: f32) -> Self {
971        Self {
972            base_transform: PitchTransform::new((left_pitch + right_pitch) / 2.0),
973            config: MultiChannelConfig::default(),
974            channel_pitch_factors: vec![left_pitch, right_pitch],
975        }
976    }
977
978    /// Set channel-specific pitch factors
979    pub fn set_channel_pitch_factors(&mut self, factors: Vec<f32>) {
980        self.channel_pitch_factors = factors;
981    }
982}
983
984impl MultiChannelTransform for MultiChannelPitchTransform {
985    fn apply_multichannel(&self, input: &MultiChannelAudio) -> Result<MultiChannelAudio> {
986        match self.config.strategy {
987            ChannelStrategy::Independent => {
988                let mut output_channels = Vec::new();
989
990                for (ch_idx, channel) in input.channels.iter().enumerate() {
991                    let pitch_factor = self
992                        .channel_pitch_factors
993                        .get(ch_idx)
994                        .copied()
995                        .unwrap_or(self.base_transform.pitch_factor);
996
997                    let mut channel_transform = self.base_transform.clone();
998                    channel_transform.pitch_factor = pitch_factor;
999
1000                    let processed_channel = channel_transform.apply(channel)?;
1001                    output_channels.push(processed_channel);
1002                }
1003
1004                self.apply_channel_processing(output_channels, input.sample_rate)
1005            }
1006
1007            ChannelStrategy::MidSide => {
1008                if input.num_channels() != 2 {
1009                    return Err(Error::Transform {
1010                        transform_type: "mid_side_validation".to_string(),
1011                        message: "Mid/Side processing requires exactly 2 channels".to_string(),
1012                        context: None,
1013                        recovery_suggestions: Box::new(vec![
1014                            "Convert audio to stereo format before Mid/Side processing".to_string(),
1015                            "Use a different transform for non-stereo audio".to_string(),
1016                        ]),
1017                    });
1018                }
1019
1020                let left = &input.channels[0];
1021                let right = &input.channels[1];
1022
1023                // Convert to Mid/Side
1024                let mid: Vec<f32> = left
1025                    .iter()
1026                    .zip(right.iter())
1027                    .map(|(&l, &r)| (l + r) / 2.0)
1028                    .collect();
1029
1030                let side: Vec<f32> = left
1031                    .iter()
1032                    .zip(right.iter())
1033                    .map(|(&l, &r)| (l - r) / 2.0)
1034                    .collect();
1035
1036                // Process Mid and Side independently
1037                let mid_factor = self
1038                    .channel_pitch_factors
1039                    .first()
1040                    .copied()
1041                    .unwrap_or(self.base_transform.pitch_factor);
1042                let side_factor = self
1043                    .channel_pitch_factors
1044                    .get(1)
1045                    .copied()
1046                    .unwrap_or(self.base_transform.pitch_factor);
1047
1048                let mut mid_transform = self.base_transform.clone();
1049                mid_transform.pitch_factor = mid_factor;
1050                let processed_mid = mid_transform.apply(&mid)?;
1051
1052                let mut side_transform = self.base_transform.clone();
1053                side_transform.pitch_factor = side_factor;
1054                let processed_side = side_transform.apply(&side)?;
1055
1056                // Convert back to Left/Right
1057                let processed_left: Vec<f32> = processed_mid
1058                    .iter()
1059                    .zip(processed_side.iter())
1060                    .map(|(&m, &s)| m + s)
1061                    .collect();
1062
1063                let processed_right: Vec<f32> = processed_mid
1064                    .iter()
1065                    .zip(processed_side.iter())
1066                    .map(|(&m, &s)| m - s)
1067                    .collect();
1068
1069                self.apply_channel_processing(
1070                    vec![processed_left, processed_right],
1071                    input.sample_rate,
1072                )
1073            }
1074
1075            ChannelStrategy::MonoExpanded => {
1076                let mono = input.to_mono();
1077                let processed_mono = self.base_transform.apply(&mono)?;
1078
1079                // Expand mono to all channels
1080                let mut output_channels = Vec::new();
1081                for ch_idx in 0..input.num_channels() {
1082                    let gain = self
1083                        .config
1084                        .channel_gains
1085                        .get(ch_idx)
1086                        .copied()
1087                        .unwrap_or(1.0);
1088                    let channel = processed_mono.iter().map(|&s| s * gain).collect();
1089                    output_channels.push(channel);
1090                }
1091
1092                self.apply_channel_processing(output_channels, input.sample_rate)
1093            }
1094
1095            ChannelStrategy::Correlated => {
1096                // Process channels with correlation awareness
1097                let mut output_channels = Vec::new();
1098                let correlation_matrix = self.calculate_channel_correlation(input);
1099
1100                for (ch_idx, channel) in input.channels.iter().enumerate() {
1101                    let mut correlated_channel = channel.clone();
1102
1103                    // Apply correlation-based adjustments
1104                    for (other_idx, other_channel) in input.channels.iter().enumerate() {
1105                        if ch_idx != other_idx {
1106                            let correlation = correlation_matrix[ch_idx][other_idx];
1107                            let influence = correlation * 0.1; // Limit influence
1108
1109                            for (i, &other_sample) in other_channel.iter().enumerate() {
1110                                if i < correlated_channel.len() {
1111                                    correlated_channel[i] += other_sample * influence;
1112                                }
1113                            }
1114                        }
1115                    }
1116
1117                    let pitch_factor = self
1118                        .channel_pitch_factors
1119                        .get(ch_idx)
1120                        .copied()
1121                        .unwrap_or(self.base_transform.pitch_factor);
1122
1123                    let mut channel_transform = self.base_transform.clone();
1124                    channel_transform.pitch_factor = pitch_factor;
1125
1126                    let processed_channel = channel_transform.apply(&correlated_channel)?;
1127                    output_channels.push(processed_channel);
1128                }
1129
1130                self.apply_channel_processing(output_channels, input.sample_rate)
1131            }
1132        }
1133    }
1134
1135    fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
1136        let mut params = Transform::get_parameters(&self.base_transform);
1137        params.insert(
1138            "num_channels".to_string(),
1139            self.config.channel_gains.len() as f32,
1140        );
1141        params.insert("crosstalk_amount".to_string(), self.config.crosstalk_amount);
1142
1143        for (i, &factor) in self.channel_pitch_factors.iter().enumerate() {
1144            params.insert(format!("channel_{i}_pitch"), factor);
1145        }
1146
1147        params
1148    }
1149}
1150
1151impl MultiChannelPitchTransform {
1152    /// Apply channel-specific processing (gains, crosstalk)
1153    fn apply_channel_processing(
1154        &self,
1155        mut channels: Vec<Vec<f32>>,
1156        sample_rate: u32,
1157    ) -> Result<MultiChannelAudio> {
1158        // Apply channel gains
1159        for (ch_idx, channel) in channels.iter_mut().enumerate() {
1160            let gain = self
1161                .config
1162                .channel_gains
1163                .get(ch_idx)
1164                .copied()
1165                .unwrap_or(1.0);
1166            for sample in channel.iter_mut() {
1167                *sample *= gain;
1168            }
1169        }
1170
1171        // Apply crosstalk if enabled
1172        if self.config.enable_crosstalk && channels.len() > 1 {
1173            self.apply_crosstalk(&mut channels);
1174        }
1175
1176        MultiChannelAudio::new(channels, sample_rate)
1177    }
1178
1179    /// Apply crosstalk between channels
1180    fn apply_crosstalk(&self, channels: &mut [Vec<f32>]) {
1181        let num_channels = channels.len();
1182        let crosstalk = self.config.crosstalk_amount;
1183
1184        // Create a copy for crosstalk calculation
1185        let original_channels: Vec<Vec<f32>> = channels.to_vec();
1186
1187        for (ch_idx, channel) in channels.iter_mut().enumerate() {
1188            for (sample_idx, sample) in channel.iter_mut().enumerate() {
1189                let mut crosstalk_sum = 0.0;
1190                let mut count = 0;
1191
1192                // Add crosstalk from other channels
1193                for (other_idx, other_channel) in original_channels.iter().enumerate() {
1194                    if ch_idx != other_idx && sample_idx < other_channel.len() {
1195                        crosstalk_sum += other_channel[sample_idx];
1196                        count += 1;
1197                    }
1198                }
1199
1200                if count > 0 {
1201                    let avg_crosstalk = crosstalk_sum / count as f32;
1202                    *sample = *sample * (1.0 - crosstalk) + avg_crosstalk * crosstalk;
1203                }
1204            }
1205        }
1206    }
1207
1208    /// Calculate correlation matrix between channels
1209    fn calculate_channel_correlation(&self, input: &MultiChannelAudio) -> Vec<Vec<f32>> {
1210        let num_channels = input.num_channels();
1211
1212        (0..num_channels)
1213            .map(|i| {
1214                (0..num_channels)
1215                    .map(|j| {
1216                        if i == j {
1217                            1.0
1218                        } else {
1219                            self.calculate_correlation(&input.channels[i], &input.channels[j])
1220                        }
1221                    })
1222                    .collect()
1223            })
1224            .collect()
1225    }
1226
1227    /// Calculate correlation between two channels
1228    fn calculate_correlation(&self, ch1: &[f32], ch2: &[f32]) -> f32 {
1229        if ch1.len() != ch2.len() || ch1.is_empty() {
1230            return 0.0;
1231        }
1232
1233        let mean1 = ch1.iter().sum::<f32>() / ch1.len() as f32;
1234        let mean2 = ch2.iter().sum::<f32>() / ch2.len() as f32;
1235
1236        let mut numerator = 0.0;
1237        let mut sum_sq1 = 0.0;
1238        let mut sum_sq2 = 0.0;
1239
1240        for (s1, s2) in ch1.iter().zip(ch2.iter()) {
1241            let diff1 = s1 - mean1;
1242            let diff2 = s2 - mean2;
1243
1244            numerator += diff1 * diff2;
1245            sum_sq1 += diff1 * diff1;
1246            sum_sq2 += diff2 * diff2;
1247        }
1248
1249        let denominator = (sum_sq1 * sum_sq2).sqrt();
1250        if denominator == 0.0 {
1251            0.0
1252        } else {
1253            numerator / denominator
1254        }
1255    }
1256}
1257
1258// Implement MultiChannelTransform for existing transforms by wrapping them
1259
1260impl MultiChannelTransform for PitchTransform {
1261    fn apply_multichannel(&self, input: &MultiChannelAudio) -> Result<MultiChannelAudio> {
1262        let multichannel_transform = MultiChannelPitchTransform {
1263            base_transform: self.clone(),
1264            config: MultiChannelConfig {
1265                channel_gains: vec![1.0; input.num_channels()],
1266                ..Default::default()
1267            },
1268            channel_pitch_factors: vec![self.pitch_factor; input.num_channels()],
1269        };
1270
1271        multichannel_transform.apply_multichannel(input)
1272    }
1273
1274    fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
1275        Transform::get_parameters(self)
1276    }
1277}
1278
1279impl MultiChannelTransform for SpeedTransform {
1280    fn apply_multichannel(&self, input: &MultiChannelAudio) -> Result<MultiChannelAudio> {
1281        let mut output_channels = Vec::new();
1282
1283        for channel in &input.channels {
1284            let processed_channel = self.apply(channel)?;
1285            output_channels.push(processed_channel);
1286        }
1287
1288        MultiChannelAudio::new(output_channels, input.sample_rate)
1289    }
1290
1291    fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
1292        Transform::get_parameters(self)
1293    }
1294}
1295
1296impl MultiChannelTransform for AgeTransform {
1297    fn apply_multichannel(&self, input: &MultiChannelAudio) -> Result<MultiChannelAudio> {
1298        let mut output_channels = Vec::new();
1299
1300        for channel in &input.channels {
1301            let processed_channel = self.apply(channel)?;
1302            output_channels.push(processed_channel);
1303        }
1304
1305        MultiChannelAudio::new(output_channels, input.sample_rate)
1306    }
1307
1308    fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
1309        Transform::get_parameters(self)
1310    }
1311}
1312
1313impl MultiChannelTransform for GenderTransform {
1314    fn apply_multichannel(&self, input: &MultiChannelAudio) -> Result<MultiChannelAudio> {
1315        let mut output_channels = Vec::new();
1316
1317        for channel in &input.channels {
1318            let processed_channel = self.apply(channel)?;
1319            output_channels.push(processed_channel);
1320        }
1321
1322        MultiChannelAudio::new(output_channels, input.sample_rate)
1323    }
1324
1325    fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
1326        Transform::get_parameters(self)
1327    }
1328}
1329
#[cfg(test)]
mod multichannel_tests {
    //! Tests for the multi-channel audio container and the multi-channel
    //! pitch transform strategies.
    use super::*;

    // Construction from per-channel vectors preserves shape and sample data.
    #[test]
    fn test_multichannel_audio_creation() {
        let channels = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let audio = MultiChannelAudio::new(channels.clone(), 44100).unwrap();

        assert_eq!(audio.num_channels(), 2);
        assert_eq!(audio.num_samples(), 3);
        assert_eq!(audio.channels, channels);
    }

    // Interleaved -> planar -> interleaved must round-trip exactly.
    #[test]
    fn test_interleaved_conversion() {
        // Frame layout: [L0, R0, L1, R1, L2, R2]
        let interleaved = vec![0.1, 0.4, 0.2, 0.5, 0.3, 0.6];
        let audio = MultiChannelAudio::from_interleaved(&interleaved, 2, 44100).unwrap();

        assert_eq!(audio.num_channels(), 2);
        assert_eq!(audio.num_samples(), 3);

        let back_to_interleaved = audio.to_interleaved();
        assert_eq!(back_to_interleaved, interleaved);
    }

    // Downmix averages the channels sample-by-sample. The chosen values
    // average to sums that round exactly in f32, so EPSILON suffices here.
    #[test]
    fn test_mono_conversion() {
        let channels = vec![vec![0.2, 0.4, 0.6], vec![0.8, 1.0, 1.2]];
        let audio = MultiChannelAudio::new(channels, 44100).unwrap();
        let mono = audio.to_mono();

        assert_eq!(mono.len(), 3);
        assert!((mono[0] - 0.5).abs() < f32::EPSILON); // (0.2 + 0.8) / 2
        assert!((mono[1] - 0.7).abs() < f32::EPSILON); // (0.4 + 1.0) / 2
        assert!((mono[2] - 0.9).abs() < f32::EPSILON); // (0.6 + 1.2) / 2
    }

    // Default strategy (Independent) keeps channel count and sample count.
    #[test]
    fn test_multichannel_pitch_transform_independent() {
        let channels = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let audio = MultiChannelAudio::new(channels, 44100).unwrap();

        let transform = MultiChannelPitchTransform::new(2.0, 2);
        let result = transform.apply_multichannel(&audio).unwrap();

        assert_eq!(result.num_channels(), 2);
        assert_eq!(result.num_samples(), 3);
    }

    // `stereo` stores one factor per channel (left, right).
    #[test]
    fn test_multichannel_pitch_transform_stereo() {
        let channels = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let audio = MultiChannelAudio::new(channels, 44100).unwrap();

        let transform = MultiChannelPitchTransform::stereo(1.5, 2.5);
        let result = transform.apply_multichannel(&audio).unwrap();

        assert_eq!(result.num_channels(), 2);
        assert_eq!(transform.channel_pitch_factors, vec![1.5, 2.5]);
    }

    // Mid/Side requires stereo input and must produce stereo output.
    #[test]
    fn test_multichannel_mid_side_processing() {
        let channels = vec![
            vec![0.8, 0.6, 0.4], // Left
            vec![0.2, 0.4, 0.6], // Right
        ];
        let audio = MultiChannelAudio::new(channels, 44100).unwrap();

        let mut transform = MultiChannelPitchTransform::stereo(2.0, 2.0);
        transform.config.strategy = ChannelStrategy::MidSide;

        let result = transform.apply_multichannel(&audio).unwrap();
        assert_eq!(result.num_channels(), 2);
    }

    // Pitch factor 1.0 is a pass-through, so any change in the output must
    // come from the crosstalk blending between the two channels.
    #[test]
    fn test_multichannel_transform_with_crosstalk() {
        let channels = vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, 0.5]];
        let audio = MultiChannelAudio::new(channels, 44100).unwrap();

        let mut transform = MultiChannelPitchTransform::new(1.0, 2);
        transform.config.enable_crosstalk = true;
        transform.config.crosstalk_amount = 0.1;

        let result = transform.apply_multichannel(&audio).unwrap();

        // With crosstalk, channels should influence each other
        assert_eq!(result.num_channels(), 2);
        assert_ne!(result.channels[0], vec![1.0, 0.0, 0.5]);
        assert_ne!(result.channels[1], vec![0.0, 1.0, 0.5]);
    }

    // Identical channels must correlate at 1.0; the diagonal is fixed at 1.0.
    #[test]
    fn test_channel_correlation_calculation() {
        let channels = vec![
            vec![1.0, 2.0, 3.0],
            vec![1.0, 2.0, 3.0], // Perfect correlation
        ];
        let audio = MultiChannelAudio::new(channels, 44100).unwrap();

        let transform = MultiChannelPitchTransform::new(1.0, 2);
        let correlation_matrix = transform.calculate_channel_correlation(&audio);

        assert_eq!(correlation_matrix.len(), 2);
        assert_eq!(correlation_matrix[0].len(), 2);
        assert_eq!(correlation_matrix[0][0], 1.0); // Self-correlation
        assert!((correlation_matrix[0][1] - 1.0).abs() < f32::EPSILON); // Perfect correlation
    }
}