1use crate::{Error, Result};
4use scirs2_core::Complex;
5use scirs2_fft::RealFftPlanner;
6use serde::{Deserialize, Serialize};
7use std::f32::consts::PI;
8
/// Common interface for mono (single-channel) audio transforms.
pub trait Transform {
    /// Processes `input` samples and returns the transformed buffer.
    /// The output length may differ from the input (e.g. speed changes).
    fn apply(&self, input: &[f32]) -> Result<Vec<f32>>;

    /// Reports the transform's configuration as name -> value pairs;
    /// boolean options are encoded as 0.0 / 1.0.
    fn get_parameters(&self) -> std::collections::HashMap<String, f32>;
}
17
/// Pitch-shifting transform.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PitchTransform {
    /// Multiplicative frequency scaling factor (1.0 = unchanged).
    pub pitch_factor: f32,
    /// When true, uses the phase-vocoder path intended to keep formants intact.
    pub preserve_formants: bool,
}
26
27impl PitchTransform {
28 pub fn new(pitch_factor: f32) -> Self {
30 Self {
31 pitch_factor,
32 preserve_formants: true,
33 }
34 }
35}
36
37impl Transform for PitchTransform {
38 fn apply(&self, input: &[f32]) -> Result<Vec<f32>> {
39 if input.is_empty() {
40 return Ok(input.to_vec());
41 }
42
43 if (self.pitch_factor - 1.0).abs() < f32::EPSILON {
44 return Ok(input.to_vec());
45 }
46
47 if self.preserve_formants {
49 self.apply_phase_vocoder_pitch_shift(input)
50 } else {
51 self.apply_simple_pitch_shift(input)
52 }
53 }
54
55 fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
56 let mut params = std::collections::HashMap::new();
57 params.insert("pitch_factor".to_string(), self.pitch_factor);
58 params.insert(
59 "preserve_formants".to_string(),
60 if self.preserve_formants { 1.0 } else { 0.0 },
61 );
62 params
63 }
64}
65
/// Playback-speed transform.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SpeedTransform {
    /// Speed multiplier; 2.0 roughly halves the number of output samples.
    pub speed_factor: f32,
    /// When true, time-stretches (PSOLA-style) instead of plain resampling.
    pub preserve_pitch: bool,
}
74
75impl SpeedTransform {
76 pub fn new(speed_factor: f32) -> Self {
78 Self {
79 speed_factor,
80 preserve_pitch: true,
81 }
82 }
83}
84
85impl Transform for SpeedTransform {
86 fn apply(&self, input: &[f32]) -> Result<Vec<f32>> {
87 if input.is_empty() {
88 return Ok(input.to_vec());
89 }
90
91 if (self.speed_factor - 1.0).abs() < f32::EPSILON {
92 return Ok(input.to_vec());
93 }
94
95 if self.preserve_pitch {
96 self.apply_psola_time_stretch(input)
98 } else {
99 self.apply_linear_interpolation(input)
101 }
102 }
103
104 fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
105 let mut params = std::collections::HashMap::new();
106 params.insert("speed_factor".to_string(), self.speed_factor);
107 params.insert(
108 "preserve_pitch".to_string(),
109 if self.preserve_pitch { 1.0 } else { 0.0 },
110 );
111 params
112 }
113}
114
/// Voice-age modification transform.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AgeTransform {
    /// Desired apparent age of the output voice, in years.
    pub target_age: f32,
    /// Age of the input voice, in years.
    pub source_age: f32,
}
123
124impl AgeTransform {
125 pub fn new(source_age: f32, target_age: f32) -> Self {
127 Self {
128 target_age,
129 source_age,
130 }
131 }
132}
133
134impl Transform for AgeTransform {
135 fn apply(&self, input: &[f32]) -> Result<Vec<f32>> {
136 if input.is_empty() {
137 return Ok(input.to_vec());
138 }
139
140 if (self.target_age - self.source_age).abs() < 1.0 {
141 return Ok(input.to_vec());
142 }
143
144 self.apply_age_related_modifications(input)
146 }
147
148 fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
149 let mut params = std::collections::HashMap::new();
150 params.insert("target_age".to_string(), self.target_age);
151 params.insert("source_age".to_string(), self.source_age);
152 params
153 }
154}
155
/// Voice-gender modification transform.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GenderTransform {
    /// Target in [-1.0, 1.0]: negative = more masculine, positive = more
    /// feminine, 0.0 = unchanged.
    pub target_gender: f32,
    /// Strength of the formant shift applied (constructor default is 0.5).
    pub formant_shift_strength: f32,
}
164
165impl GenderTransform {
166 pub fn new(target_gender: f32) -> Self {
168 Self {
169 target_gender: target_gender.clamp(-1.0, 1.0),
170 formant_shift_strength: 0.5,
171 }
172 }
173}
174
175impl Transform for GenderTransform {
176 fn apply(&self, input: &[f32]) -> Result<Vec<f32>> {
177 if input.is_empty() {
178 return Ok(input.to_vec());
179 }
180
181 if self.target_gender.abs() < f32::EPSILON {
182 return Ok(input.to_vec());
183 }
184
185 self.apply_gender_modifications(input)
187 }
188
189 fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
190 let mut params = std::collections::HashMap::new();
191 params.insert("target_gender".to_string(), self.target_gender);
192 params.insert(
193 "formant_shift_strength".to_string(),
194 self.formant_shift_strength,
195 );
196 params
197 }
198}
199
/// Blends multiple source voices into a single output voice.
#[derive(Debug, Clone)]
pub struct VoiceMorpher {
    /// Per-voice blend weights (normalized internally when they sum to > 0).
    pub blend_weights: Vec<f32>,
    /// Identifiers of the source voices being morphed.
    pub source_voices: Vec<String>,
    /// Algorithm used to combine the voices.
    pub method: MorphingMethod,
    /// Spectral morphing intensity in [0.0, 1.0].
    pub spectral_strength: f32,
}
212
/// Strategy used by [`VoiceMorpher`] to combine several voices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MorphingMethod {
    /// Weighted sample-wise average of all inputs.
    LinearBlend,
    /// Magnitude/phase interpolation in the frequency domain (two inputs).
    SpectralInterpolation,
    /// Time-domain fade from the first input to the second (two inputs).
    CrossFade,
    /// Feature-driven morphing; currently falls back to linear blending.
    FeatureBased,
}
225
impl VoiceMorpher {
    /// Creates a morpher over `source_voices` with the given per-voice
    /// `blend_weights`; defaults to linear blending, spectral strength 0.5.
    pub fn new(source_voices: Vec<String>, blend_weights: Vec<f32>) -> Self {
        Self {
            blend_weights,
            source_voices,
            method: MorphingMethod::LinearBlend,
            spectral_strength: 0.5,
        }
    }

    /// Builder-style: selects the morphing algorithm.
    pub fn with_method(mut self, method: MorphingMethod) -> Self {
        self.method = method;
        self
    }

    /// Builder-style: sets the spectral strength, clamped to [0.0, 1.0].
    pub fn with_spectral_strength(mut self, strength: f32) -> Self {
        self.spectral_strength = strength.clamp(0.0, 1.0);
        self
    }

    /// Morphs the given voices into one buffer using the configured method.
    ///
    /// # Errors
    /// Fails when `inputs` is empty. A single input is returned unchanged.
    pub fn morph(&self, inputs: &[Vec<f32>]) -> Result<Vec<f32>> {
        if inputs.is_empty() {
            return Err(Error::transform("No input voices for morphing".to_string()));
        }

        if inputs.len() == 1 {
            return Ok(inputs[0].clone());
        }

        match self.method {
            MorphingMethod::LinearBlend => self.linear_blend(inputs),
            MorphingMethod::SpectralInterpolation => self.spectral_interpolation(inputs),
            MorphingMethod::CrossFade => self.cross_fade(inputs),
            MorphingMethod::FeatureBased => self.feature_based_morph(inputs),
        }
    }

    /// Weighted sample-wise average. Weights are normalized to sum to 1.0;
    /// when they sum to zero, uniform weights are used instead. The output
    /// length is the longest input; shorter inputs contribute silence at
    /// their tail.
    fn linear_blend(&self, inputs: &[Vec<f32>]) -> Result<Vec<f32>> {
        let output_len = inputs.iter().map(|v| v.len()).max().unwrap_or(0);
        let mut output = vec![0.0; output_len];

        let total_weight: f32 = self.blend_weights.iter().sum();
        let normalized_weights: Vec<f32> = if total_weight > 0.0 {
            self.blend_weights
                .iter()
                .map(|w| w / total_weight)
                .collect()
        } else {
            vec![1.0 / inputs.len() as f32; inputs.len()]
        };

        for (i, input) in inputs.iter().enumerate() {
            // Inputs beyond the weight list contribute nothing (weight 0.0).
            let weight = normalized_weights.get(i).copied().unwrap_or(0.0);
            for (j, &sample) in input.iter().enumerate() {
                if j < output_len {
                    output[j] += sample * weight;
                }
            }
        }

        Ok(output)
    }

    /// Frequency-domain interpolation between exactly two voices; any other
    /// input count falls back to linear blending. The blend factor is taken
    /// from the second blend weight (default 0.5).
    fn spectral_interpolation(&self, inputs: &[Vec<f32>]) -> Result<Vec<f32>> {
        if inputs.len() != 2 {
            return self.linear_blend(inputs);
        }

        let input1 = &inputs[0];
        let input2 = &inputs[1];
        let blend_factor = self.blend_weights.get(1).copied().unwrap_or(0.5);

        self.spectral_blend(input1, input2, blend_factor)
    }

    /// STFT-domain blend: interpolates magnitude and phase of Hann-windowed
    /// 1024-sample frames (75% overlap) and overlap-adds the result.
    /// Inputs shorter than one window are blended sample-wise instead.
    fn spectral_blend(
        &self,
        input1: &[f32],
        input2: &[f32],
        blend_factor: f32,
    ) -> Result<Vec<f32>> {
        let window_size = 1024;
        let min_len = input1.len().min(input2.len());

        if min_len < window_size {
            // Too short for an FFT frame: plain linear crossmix.
            let mut output = vec![0.0; min_len];
            for i in 0..min_len {
                output[i] = input1[i] * (1.0 - blend_factor) + input2[i] * blend_factor;
            }
            return Ok(output);
        }

        let mut planner = RealFftPlanner::<f32>::new();
        let fft = planner.plan_fft_forward(window_size);
        let ifft = planner.plan_fft_inverse(window_size);

        let mut output = Vec::new();
        let hop_size = window_size / 4;

        // NOTE(review): the exclusive range skips a final frame starting at
        // exactly min_len - window_size, so up to window_size trailing
        // samples are dropped from the output — confirm this is intended.
        for window_start in (0..min_len.saturating_sub(window_size)).step_by(hop_size) {
            let window_end = (window_start + window_size).min(min_len);

            let mut window1 = vec![0.0; window_size];
            let mut window2 = vec![0.0; window_size];

            // Hann-window both inputs over the same span.
            for (i, (&s1, &s2)) in input1[window_start..window_end]
                .iter()
                .zip(input2[window_start..window_end].iter())
                .enumerate()
            {
                let hann = 0.5 - 0.5 * (2.0 * PI * i as f32 / (window_size - 1) as f32).cos();
                window1[i] = s1 * hann;
                window2[i] = s2 * hann;
            }

            let mut spectrum1 = vec![Complex::new(0.0, 0.0); window_size / 2 + 1];
            let mut spectrum2 = vec![Complex::new(0.0, 0.0); window_size / 2 + 1];

            fft.process(&window1, &mut spectrum1);
            fft.process(&window2, &mut spectrum2);

            // Interpolate polar components bin by bin.
            let mut blended_spectrum = vec![Complex::new(0.0, 0.0); window_size / 2 + 1];
            for (i, (s1, s2)) in spectrum1.iter().zip(spectrum2.iter()).enumerate() {
                let mag1 = s1.norm();
                let mag2 = s2.norm();
                let phase1 = s1.arg();
                let phase2 = s2.arg();

                let blended_mag = mag1 * (1.0 - blend_factor) + mag2 * blend_factor;
                // NOTE(review): naive phase interpolation — no wrap-around
                // handling, so phases near ±PI may blend incorrectly.
                let blended_phase = phase1 * (1.0 - blend_factor) + phase2 * blend_factor;

                if i == 0 || i == blended_spectrum.len() - 1 {
                    // DC and Nyquist bins of a real FFT must stay purely real.
                    blended_spectrum[i] = Complex::new(blended_mag, 0.0);
                } else {
                    blended_spectrum[i] = Complex::new(
                        blended_mag * blended_phase.cos(),
                        blended_mag * blended_phase.sin(),
                    );
                }
            }

            let mut time_domain = vec![0.0; window_size];
            ifft.process(&blended_spectrum, &mut time_domain);

            // Overlap-add the resynthesised frame into the output buffer.
            for (i, &sample) in time_domain.iter().enumerate() {
                let output_idx = window_start + i;
                if output_idx >= output.len() {
                    output.resize(output_idx + 1, 0.0);
                }
                output[output_idx] += sample;
            }
        }

        Ok(output)
    }

    /// Time-domain crossfade from voice 1 to voice 2 over the shorter of the
    /// two; any other input count falls back to linear blending.
    fn cross_fade(&self, inputs: &[Vec<f32>]) -> Result<Vec<f32>> {
        if inputs.len() != 2 {
            return self.linear_blend(inputs);
        }

        let input1 = &inputs[0];
        let input2 = &inputs[1];
        let min_len = input1.len().min(input2.len());
        let mut output = vec![0.0; min_len];

        for i in 0..min_len {
            let fade_factor = i as f32 / min_len as f32;
            output[i] = input1[i] * (1.0 - fade_factor) + input2[i] * fade_factor;
        }

        Ok(output)
    }

    /// Feature-based morphing is not implemented yet; delegates to linear
    /// blending.
    fn feature_based_morph(&self, inputs: &[Vec<f32>]) -> Result<Vec<f32>> {
        self.linear_blend(inputs)
    }
}
422
423impl Transform for VoiceMorpher {
424 fn apply(&self, input: &[f32]) -> Result<Vec<f32>> {
425 let weight = self.blend_weights.first().copied().unwrap_or(1.0);
427 Ok(input.iter().map(|x| x * weight).collect())
428 }
429
430 fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
431 let mut params = std::collections::HashMap::new();
432 for (i, &weight) in self.blend_weights.iter().enumerate() {
433 params.insert(format!("weight_{i}"), weight);
434 }
435 params.insert("spectral_strength".to_string(), self.spectral_strength);
436 params.insert("method".to_string(), self.method as u8 as f32);
437 params
438 }
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pitch_transform() {
        // Short input (< FFT window) takes the simple amplitude-scaling path.
        let transform = PitchTransform::new(2.0);
        let input = vec![0.1, 0.2, 0.3];
        let output = transform.apply(&input).unwrap();

        assert_eq!(output.len(), input.len());
        assert_eq!(output[0], 0.2);
        assert_eq!(output[1], 0.4);
    }

    #[test]
    fn test_speed_transform() {
        // Doubling speed halves the output length.
        let transform = SpeedTransform::new(2.0);
        let input = vec![0.1, 0.2, 0.3, 0.4];
        let output = transform.apply(&input).unwrap();

        assert_eq!(output.len(), 2);
    }

    #[test]
    fn test_age_transform() {
        let transform = AgeTransform::new(30.0, 60.0);
        let input = vec![0.1, 0.2, 0.3];
        let output = transform.apply(&input).unwrap();

        assert_eq!(output.len(), input.len());
    }

    #[test]
    fn test_gender_transform() {
        // A fully feminine target applies a gain above unity.
        let transform = GenderTransform::new(1.0);
        let input = vec![0.1, 0.2, 0.3];
        let output = transform.apply(&input).unwrap();

        assert_eq!(output.len(), input.len());
        assert!(output[0] > input[0]);
    }

    #[test]
    fn test_voice_morpher() {
        let morpher = VoiceMorpher::new(
            vec!["voice1".to_string(), "voice2".to_string()],
            vec![0.5, 0.5],
        );
        let inputs = vec![vec![0.1, 0.2], vec![0.3, 0.4]];

        let output = morpher.morph(&inputs).unwrap();
        assert_eq!(output.len(), 2);
        assert_eq!(output[0], 0.2);
        assert_eq!(output[1], 0.3);
    }
}
499
500impl PitchTransform {
503 fn apply_phase_vocoder_pitch_shift(&self, input: &[f32]) -> Result<Vec<f32>> {
505 let window_size = 1024;
506 let hop_size = window_size / 4;
507 let overlap = window_size - hop_size;
508
509 if input.len() < window_size {
510 return self.apply_simple_pitch_shift(input);
512 }
513
514 let mut planner = RealFftPlanner::<f32>::new();
515 let fft = planner.plan_fft_forward(window_size);
516 let ifft = planner.plan_fft_inverse(window_size);
517
518 let mut output = Vec::new();
519 let mut phase_accum = vec![0.0; window_size / 2 + 1];
520 let mut last_phase = vec![0.0; window_size / 2 + 1];
521
522 for window_start in (0..input.len().saturating_sub(window_size)).step_by(hop_size) {
524 let window_end = (window_start + window_size).min(input.len());
525 let mut window = vec![0.0; window_size];
526
527 for (i, &sample) in input[window_start..window_end].iter().enumerate() {
529 let hann = 0.5 - 0.5 * (2.0 * PI * i as f32 / (window_size - 1) as f32).cos();
530 window[i] = sample * hann;
531 }
532
533 let mut spectrum = vec![Complex::new(0.0, 0.0); window_size / 2 + 1];
535 fft.process(&window, &mut spectrum);
536
537 let mut modified_spectrum = vec![Complex::new(0.0, 0.0); window_size / 2 + 1];
539
540 for (k, &bin) in spectrum.iter().enumerate() {
541 let magnitude = bin.norm();
542 let phase = bin.arg();
543
544 let expected_phase_advance =
546 2.0 * PI * k as f32 * hop_size as f32 / window_size as f32;
547 let phase_diff = phase - last_phase[k] - expected_phase_advance;
548
549 let wrapped_phase_diff = ((phase_diff + PI) % (2.0 * PI)) - PI;
551
552 let inst_freq = (k as f32 + wrapped_phase_diff / (2.0 * PI)) * self.pitch_factor;
554
555 phase_accum[k] += inst_freq * 2.0 * PI * hop_size as f32 / window_size as f32;
557
558 let new_k = (inst_freq.round() as usize).min(spectrum.len() - 1);
560 if new_k < modified_spectrum.len() {
561 if new_k == 0 || new_k == modified_spectrum.len() - 1 {
563 modified_spectrum[new_k] = Complex::new(magnitude, 0.0);
565 } else {
566 modified_spectrum[new_k] = Complex::new(
567 magnitude * phase_accum[k].cos(),
568 magnitude * phase_accum[k].sin(),
569 );
570 }
571 }
572
573 last_phase[k] = phase;
574 }
575
576 let mut time_domain = vec![0.0; window_size];
578 ifft.process(&modified_spectrum, &mut time_domain);
579
580 for (i, &sample) in time_domain.iter().enumerate() {
582 let hann = 0.5 - 0.5 * (2.0 * PI * i as f32 / (window_size - 1) as f32).cos();
583 let windowed_sample = sample * hann;
584
585 let output_idx = window_start + i;
586 if output_idx >= output.len() {
587 output.resize(output_idx + 1, 0.0);
588 }
589 output[output_idx] += windowed_sample;
590 }
591 }
592
593 Ok(output)
594 }
595
596 fn apply_simple_pitch_shift(&self, input: &[f32]) -> Result<Vec<f32>> {
598 if self.pitch_factor == 1.0 {
599 return Ok(input.to_vec());
600 }
601
602 let mut output = Vec::with_capacity(input.len());
605
606 for &sample in input {
607 let scaled_sample = sample * self.pitch_factor;
609 output.push(scaled_sample);
610 }
611
612 Ok(output)
613 }
614}
615
impl SpeedTransform {
    /// Simplified PSOLA-style time stretch: overlap-adds Hann-windowed
    /// segments of a fixed 100-sample "pitch period", advancing the input
    /// read position `speed_factor` times faster than the output position.
    fn apply_psola_time_stretch(&self, input: &[f32]) -> Result<Vec<f32>> {
        // Fixed analysis period; a full PSOLA would track real pitch marks.
        let pitch_period = 100;
        let output_len = (input.len() as f32 / self.speed_factor) as usize;
        let mut output = vec![0.0; output_len];

        let mut input_pos = 0;
        let mut output_pos = 0;

        while input_pos + pitch_period < input.len() && output_pos + pitch_period < output.len() {
            let period_start = input_pos;
            let period_end = (input_pos + pitch_period).min(input.len());

            // Hann-window the segment and overlap-add it at the output cursor.
            for i in 0..(period_end - period_start) {
                let hann = 0.5 - 0.5 * (2.0 * PI * i as f32 / pitch_period as f32).cos();
                let sample = input[period_start + i] * hann;

                if output_pos + i < output.len() {
                    output[output_pos + i] += sample;
                }
            }

            // Read faster (or slower) than we write to realise the stretch.
            input_pos += (pitch_period as f32 * self.speed_factor) as usize;
            output_pos += pitch_period;
        }

        Ok(output)
    }

    /// Resamples by linear interpolation; changes pitch along with speed.
    fn apply_linear_interpolation(&self, input: &[f32]) -> Result<Vec<f32>> {
        let output_len = (input.len() as f32 / self.speed_factor) as usize;
        let mut output = Vec::with_capacity(output_len);

        for i in 0..output_len {
            // Fractional source position for this output sample.
            let src_idx = i as f32 * self.speed_factor;
            let idx = src_idx as usize;

            if idx + 1 < input.len() {
                let frac = src_idx - idx as f32;
                let sample = input[idx] * (1.0 - frac) + input[idx + 1] * frac;
                output.push(sample);
            } else if idx < input.len() {
                // Last valid sample: no right neighbour to interpolate with.
                output.push(input[idx]);
            } else {
                // Past the end (rounding): pad with silence.
                output.push(0.0);
            }
        }

        Ok(output)
    }
}
676
677impl AgeTransform {
678 fn apply_age_related_modifications(&self, input: &[f32]) -> Result<Vec<f32>> {
680 let mut output = input.to_vec();
681
682 let age_ratio = self.target_age / self.source_age.max(1.0);
684
685 let formant_shift = if self.target_age < 18.0 {
688 1.0 + (18.0 - self.target_age) * 0.02
690 } else if self.target_age > 60.0 {
691 1.0 - (self.target_age - 60.0) * 0.01
693 } else {
694 age_ratio.sqrt()
696 };
697
698 output = self.apply_spectral_scaling(&output, formant_shift)?;
700
701 if self.target_age > 60.0 {
703 output = self.apply_age_tremor(&output)?;
705 } else if self.target_age < 12.0 {
706 output = self.apply_child_characteristics(&output)?;
708 }
709
710 Ok(output)
711 }
712
713 fn apply_spectral_scaling(&self, input: &[f32], scale_factor: f32) -> Result<Vec<f32>> {
714 Ok(input.iter().map(|&x| x * scale_factor).collect())
716 }
717
718 fn apply_age_tremor(&self, input: &[f32]) -> Result<Vec<f32>> {
719 let tremor_freq = 6.0; let tremor_depth = 0.05;
722
723 Ok(input
724 .iter()
725 .enumerate()
726 .map(|(i, &x)| {
727 let tremor =
728 1.0 + tremor_depth * (2.0 * PI * tremor_freq * i as f32 / 22050.0).sin();
729 x * tremor
730 })
731 .collect())
732 }
733
734 fn apply_child_characteristics(&self, input: &[f32]) -> Result<Vec<f32>> {
735 Ok(input.iter().map(|&x| x * 1.1).collect()) }
738}
739
740impl GenderTransform {
741 fn apply_gender_modifications(&self, input: &[f32]) -> Result<Vec<f32>> {
743 let mut output = input.to_vec();
744
745 if self.target_gender > 0.0 {
746 output = self.apply_feminization(&output)?;
748 } else if self.target_gender < 0.0 {
749 output = self.apply_masculinization(&output)?;
751 }
752
753 Ok(output)
754 }
755
756 fn apply_feminization(&self, input: &[f32]) -> Result<Vec<f32>> {
757 let formant_shift = 1.0 + (self.target_gender * self.formant_shift_strength * 0.15);
759
760 let mut output = input
762 .iter()
763 .map(|&x| x * formant_shift)
764 .collect::<Vec<f32>>();
765
766 for (i, sample) in output.iter_mut().enumerate() {
768 let breathiness = 0.02 * (i as f32 * 0.01).sin();
769 *sample += breathiness * self.target_gender;
770 }
771
772 Ok(output)
773 }
774
775 fn apply_masculinization(&self, input: &[f32]) -> Result<Vec<f32>> {
776 let formant_shift = 1.0 + (self.target_gender * self.formant_shift_strength * 0.15);
778
779 let output = input
781 .iter()
782 .map(|&x| x * formant_shift)
783 .collect::<Vec<f32>>();
784
785 Ok(output)
786 }
787}
788
/// Multi-channel audio stored as one sample buffer per channel.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MultiChannelAudio {
    /// Channel buffers; all channels are expected to share the same length.
    pub channels: Vec<Vec<f32>>,
    /// Sampling rate in Hz.
    pub sample_rate: u32,
}
797
798impl MultiChannelAudio {
799 pub fn new(channels: Vec<Vec<f32>>, sample_rate: u32) -> Result<Self> {
801 if channels.is_empty() {
802 return Err(Error::transform("No channels provided".to_string()));
803 }
804
805 let first_len = channels[0].len();
807 if !channels.iter().all(|ch| ch.len() == first_len) {
808 return Err(Error::transform(
809 "All channels must have the same length".to_string(),
810 ));
811 }
812
813 Ok(Self {
814 channels,
815 sample_rate,
816 })
817 }
818
819 pub fn from_interleaved(data: &[f32], num_channels: usize, sample_rate: u32) -> Result<Self> {
821 if num_channels == 0 {
822 return Err(Error::Transform {
823 transform_type: "channel_validation".to_string(),
824 message: "Number of channels must be greater than 0".to_string(),
825 context: None,
826 recovery_suggestions: Box::new(vec![
827 "Ensure num_channels parameter is greater than 0".to_string(),
828 "Check audio format specification".to_string(),
829 ]),
830 });
831 }
832
833 if !data.len().is_multiple_of(num_channels) {
834 return Err(Error::Transform {
835 transform_type: "interleaved_validation".to_string(),
836 message: "Data length must be divisible by number of channels".to_string(),
837 context: None,
838 recovery_suggestions: Box::new(vec![
839 "Ensure audio data length matches channel count".to_string(),
840 "Verify audio format is properly structured".to_string(),
841 ]),
842 });
843 }
844
845 let samples_per_channel = data.len() / num_channels;
846 let mut channels = vec![Vec::with_capacity(samples_per_channel); num_channels];
847
848 for (i, &sample) in data.iter().enumerate() {
849 channels[i % num_channels].push(sample);
850 }
851
852 Ok(Self {
853 channels,
854 sample_rate,
855 })
856 }
857
858 pub fn to_interleaved(&self) -> Vec<f32> {
860 let num_channels = self.channels.len();
861 let samples_per_channel = self.channels[0].len();
862 let mut interleaved = Vec::with_capacity(num_channels * samples_per_channel);
863
864 for sample_idx in 0..samples_per_channel {
865 for channel in &self.channels {
866 interleaved.push(channel[sample_idx]);
867 }
868 }
869
870 interleaved
871 }
872
873 pub fn num_channels(&self) -> usize {
875 self.channels.len()
876 }
877
878 pub fn num_samples(&self) -> usize {
880 self.channels.first().map(|ch| ch.len()).unwrap_or(0)
881 }
882
883 pub fn to_mono(&self) -> Vec<f32> {
885 let samples_per_channel = self.num_samples();
886 let num_channels = self.num_channels() as f32;
887
888 let mut mono = Vec::with_capacity(samples_per_channel);
889
890 for sample_idx in 0..samples_per_channel {
891 let sum: f32 = self.channels.iter().map(|ch| ch[sample_idx]).sum();
892 mono.push(sum / num_channels);
893 }
894
895 mono
896 }
897}
898
/// Interface for transforms that operate on multi-channel audio.
pub trait MultiChannelTransform {
    /// Processes every channel of `input` and returns a new multi-channel
    /// buffer at the same sample rate.
    fn apply_multichannel(&self, input: &MultiChannelAudio) -> Result<MultiChannelAudio>;

    /// Reports the transform's configuration as name -> value pairs.
    fn get_parameters(&self) -> std::collections::HashMap<String, f32>;
}
907
/// How a multi-channel transform distributes work across channels.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ChannelStrategy {
    /// Each channel is processed on its own.
    Independent,
    /// Channels are pre-mixed with a fraction of correlated neighbours first.
    Correlated,
    /// Downmix to mono, process once, then fan out with per-channel gains.
    MonoExpanded,
    /// Stereo mid/side encode, process, decode (requires exactly 2 channels).
    MidSide,
}
920
/// Configuration shared by multi-channel transforms.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MultiChannelConfig {
    /// Channel processing strategy.
    pub strategy: ChannelStrategy,
    /// Per-channel output gains (missing entries default to 1.0).
    pub channel_gains: Vec<f32>,
    /// Whether to simulate inter-channel crosstalk after processing.
    pub enable_crosstalk: bool,
    /// Crosstalk blend amount applied when crosstalk is enabled.
    pub crosstalk_amount: f32,
}
933
934impl Default for MultiChannelConfig {
935 fn default() -> Self {
936 Self {
937 strategy: ChannelStrategy::Independent,
938 channel_gains: vec![1.0, 1.0], enable_crosstalk: false,
940 crosstalk_amount: 0.05,
941 }
942 }
943}
944
/// Pitch transform with per-channel factors and channel strategies.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MultiChannelPitchTransform {
    /// Mono pitch transform used as the template for every channel.
    pub base_transform: PitchTransform,
    /// Channel routing / gain / crosstalk configuration.
    pub config: MultiChannelConfig,
    /// Per-channel pitch factors; missing entries fall back to the base factor.
    pub channel_pitch_factors: Vec<f32>,
}
955
956impl MultiChannelPitchTransform {
957 pub fn new(pitch_factor: f32, num_channels: usize) -> Self {
959 Self {
960 base_transform: PitchTransform::new(pitch_factor),
961 config: MultiChannelConfig {
962 channel_gains: vec![1.0; num_channels],
963 ..Default::default()
964 },
965 channel_pitch_factors: vec![pitch_factor; num_channels],
966 }
967 }
968
969 pub fn stereo(left_pitch: f32, right_pitch: f32) -> Self {
971 Self {
972 base_transform: PitchTransform::new((left_pitch + right_pitch) / 2.0),
973 config: MultiChannelConfig::default(),
974 channel_pitch_factors: vec![left_pitch, right_pitch],
975 }
976 }
977
978 pub fn set_channel_pitch_factors(&mut self, factors: Vec<f32>) {
980 self.channel_pitch_factors = factors;
981 }
982}
983
impl MultiChannelTransform for MultiChannelPitchTransform {
    /// Pitch-shifts multi-channel audio according to the configured
    /// [`ChannelStrategy`], then applies per-channel gains and optional
    /// crosstalk via `apply_channel_processing`.
    fn apply_multichannel(&self, input: &MultiChannelAudio) -> Result<MultiChannelAudio> {
        match self.config.strategy {
            // Each channel gets its own clone of the base transform with its
            // per-channel pitch factor (base factor when none is configured).
            ChannelStrategy::Independent => {
                let mut output_channels = Vec::new();

                for (ch_idx, channel) in input.channels.iter().enumerate() {
                    let pitch_factor = self
                        .channel_pitch_factors
                        .get(ch_idx)
                        .copied()
                        .unwrap_or(self.base_transform.pitch_factor);

                    let mut channel_transform = self.base_transform.clone();
                    channel_transform.pitch_factor = pitch_factor;

                    let processed_channel = channel_transform.apply(channel)?;
                    output_channels.push(processed_channel);
                }

                self.apply_channel_processing(output_channels, input.sample_rate)
            }

            // Encode stereo to mid/side, pitch-shift each component, decode.
            ChannelStrategy::MidSide => {
                if input.num_channels() != 2 {
                    return Err(Error::Transform {
                        transform_type: "mid_side_validation".to_string(),
                        message: "Mid/Side processing requires exactly 2 channels".to_string(),
                        context: None,
                        recovery_suggestions: Box::new(vec![
                            "Convert audio to stereo format before Mid/Side processing".to_string(),
                            "Use a different transform for non-stereo audio".to_string(),
                        ]),
                    });
                }

                let left = &input.channels[0];
                let right = &input.channels[1];

                // mid = (L + R) / 2, side = (L - R) / 2
                let mid: Vec<f32> = left
                    .iter()
                    .zip(right.iter())
                    .map(|(&l, &r)| (l + r) / 2.0)
                    .collect();

                let side: Vec<f32> = left
                    .iter()
                    .zip(right.iter())
                    .map(|(&l, &r)| (l - r) / 2.0)
                    .collect();

                // Factor 0 drives the mid component, factor 1 the side.
                let mid_factor = self
                    .channel_pitch_factors
                    .first()
                    .copied()
                    .unwrap_or(self.base_transform.pitch_factor);
                let side_factor = self
                    .channel_pitch_factors
                    .get(1)
                    .copied()
                    .unwrap_or(self.base_transform.pitch_factor);

                let mut mid_transform = self.base_transform.clone();
                mid_transform.pitch_factor = mid_factor;
                let processed_mid = mid_transform.apply(&mid)?;

                let mut side_transform = self.base_transform.clone();
                side_transform.pitch_factor = side_factor;
                let processed_side = side_transform.apply(&side)?;

                // Decode back: L = mid + side, R = mid - side.
                let processed_left: Vec<f32> = processed_mid
                    .iter()
                    .zip(processed_side.iter())
                    .map(|(&m, &s)| m + s)
                    .collect();

                let processed_right: Vec<f32> = processed_mid
                    .iter()
                    .zip(processed_side.iter())
                    .map(|(&m, &s)| m - s)
                    .collect();

                self.apply_channel_processing(
                    vec![processed_left, processed_right],
                    input.sample_rate,
                )
            }

            // Process the mono downmix once, then fan it back out applying
            // each channel's configured gain.
            ChannelStrategy::MonoExpanded => {
                let mono = input.to_mono();
                let processed_mono = self.base_transform.apply(&mono)?;

                let mut output_channels = Vec::new();
                for ch_idx in 0..input.num_channels() {
                    let gain = self
                        .config
                        .channel_gains
                        .get(ch_idx)
                        .copied()
                        .unwrap_or(1.0);
                    let channel = processed_mono.iter().map(|&s| s * gain).collect();
                    output_channels.push(channel);
                }

                self.apply_channel_processing(output_channels, input.sample_rate)
            }

            // Pre-mix each channel with 10% of its correlation-weighted
            // neighbours, then pitch-shift per channel as in Independent.
            ChannelStrategy::Correlated => {
                let mut output_channels = Vec::new();
                let correlation_matrix = self.calculate_channel_correlation(input);

                for (ch_idx, channel) in input.channels.iter().enumerate() {
                    let mut correlated_channel = channel.clone();

                    for (other_idx, other_channel) in input.channels.iter().enumerate() {
                        if ch_idx != other_idx {
                            let correlation = correlation_matrix[ch_idx][other_idx];
                            let influence = correlation * 0.1;
                            for (i, &other_sample) in other_channel.iter().enumerate() {
                                if i < correlated_channel.len() {
                                    correlated_channel[i] += other_sample * influence;
                                }
                            }
                        }
                    }

                    let pitch_factor = self
                        .channel_pitch_factors
                        .get(ch_idx)
                        .copied()
                        .unwrap_or(self.base_transform.pitch_factor);

                    let mut channel_transform = self.base_transform.clone();
                    channel_transform.pitch_factor = pitch_factor;

                    let processed_channel = channel_transform.apply(&correlated_channel)?;
                    output_channels.push(processed_channel);
                }

                self.apply_channel_processing(output_channels, input.sample_rate)
            }
        }
    }

    /// Base-transform parameters plus channel count, crosstalk amount, and
    /// each per-channel pitch factor.
    fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
        let mut params = Transform::get_parameters(&self.base_transform);
        params.insert(
            "num_channels".to_string(),
            self.config.channel_gains.len() as f32,
        );
        params.insert("crosstalk_amount".to_string(), self.config.crosstalk_amount);

        for (i, &factor) in self.channel_pitch_factors.iter().enumerate() {
            params.insert(format!("channel_{i}_pitch"), factor);
        }

        params
    }
}
1150
1151impl MultiChannelPitchTransform {
1152 fn apply_channel_processing(
1154 &self,
1155 mut channels: Vec<Vec<f32>>,
1156 sample_rate: u32,
1157 ) -> Result<MultiChannelAudio> {
1158 for (ch_idx, channel) in channels.iter_mut().enumerate() {
1160 let gain = self
1161 .config
1162 .channel_gains
1163 .get(ch_idx)
1164 .copied()
1165 .unwrap_or(1.0);
1166 for sample in channel.iter_mut() {
1167 *sample *= gain;
1168 }
1169 }
1170
1171 if self.config.enable_crosstalk && channels.len() > 1 {
1173 self.apply_crosstalk(&mut channels);
1174 }
1175
1176 MultiChannelAudio::new(channels, sample_rate)
1177 }
1178
1179 fn apply_crosstalk(&self, channels: &mut [Vec<f32>]) {
1181 let num_channels = channels.len();
1182 let crosstalk = self.config.crosstalk_amount;
1183
1184 let original_channels: Vec<Vec<f32>> = channels.to_vec();
1186
1187 for (ch_idx, channel) in channels.iter_mut().enumerate() {
1188 for (sample_idx, sample) in channel.iter_mut().enumerate() {
1189 let mut crosstalk_sum = 0.0;
1190 let mut count = 0;
1191
1192 for (other_idx, other_channel) in original_channels.iter().enumerate() {
1194 if ch_idx != other_idx && sample_idx < other_channel.len() {
1195 crosstalk_sum += other_channel[sample_idx];
1196 count += 1;
1197 }
1198 }
1199
1200 if count > 0 {
1201 let avg_crosstalk = crosstalk_sum / count as f32;
1202 *sample = *sample * (1.0 - crosstalk) + avg_crosstalk * crosstalk;
1203 }
1204 }
1205 }
1206 }
1207
1208 fn calculate_channel_correlation(&self, input: &MultiChannelAudio) -> Vec<Vec<f32>> {
1210 let num_channels = input.num_channels();
1211
1212 (0..num_channels)
1213 .map(|i| {
1214 (0..num_channels)
1215 .map(|j| {
1216 if i == j {
1217 1.0
1218 } else {
1219 self.calculate_correlation(&input.channels[i], &input.channels[j])
1220 }
1221 })
1222 .collect()
1223 })
1224 .collect()
1225 }
1226
1227 fn calculate_correlation(&self, ch1: &[f32], ch2: &[f32]) -> f32 {
1229 if ch1.len() != ch2.len() || ch1.is_empty() {
1230 return 0.0;
1231 }
1232
1233 let mean1 = ch1.iter().sum::<f32>() / ch1.len() as f32;
1234 let mean2 = ch2.iter().sum::<f32>() / ch2.len() as f32;
1235
1236 let mut numerator = 0.0;
1237 let mut sum_sq1 = 0.0;
1238 let mut sum_sq2 = 0.0;
1239
1240 for (s1, s2) in ch1.iter().zip(ch2.iter()) {
1241 let diff1 = s1 - mean1;
1242 let diff2 = s2 - mean2;
1243
1244 numerator += diff1 * diff2;
1245 sum_sq1 += diff1 * diff1;
1246 sum_sq2 += diff2 * diff2;
1247 }
1248
1249 let denominator = (sum_sq1 * sum_sq2).sqrt();
1250 if denominator == 0.0 {
1251 0.0
1252 } else {
1253 numerator / denominator
1254 }
1255 }
1256}
1257
1258impl MultiChannelTransform for PitchTransform {
1261 fn apply_multichannel(&self, input: &MultiChannelAudio) -> Result<MultiChannelAudio> {
1262 let multichannel_transform = MultiChannelPitchTransform {
1263 base_transform: self.clone(),
1264 config: MultiChannelConfig {
1265 channel_gains: vec![1.0; input.num_channels()],
1266 ..Default::default()
1267 },
1268 channel_pitch_factors: vec![self.pitch_factor; input.num_channels()],
1269 };
1270
1271 multichannel_transform.apply_multichannel(input)
1272 }
1273
1274 fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
1275 Transform::get_parameters(self)
1276 }
1277}
1278
1279impl MultiChannelTransform for SpeedTransform {
1280 fn apply_multichannel(&self, input: &MultiChannelAudio) -> Result<MultiChannelAudio> {
1281 let mut output_channels = Vec::new();
1282
1283 for channel in &input.channels {
1284 let processed_channel = self.apply(channel)?;
1285 output_channels.push(processed_channel);
1286 }
1287
1288 MultiChannelAudio::new(output_channels, input.sample_rate)
1289 }
1290
1291 fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
1292 Transform::get_parameters(self)
1293 }
1294}
1295
1296impl MultiChannelTransform for AgeTransform {
1297 fn apply_multichannel(&self, input: &MultiChannelAudio) -> Result<MultiChannelAudio> {
1298 let mut output_channels = Vec::new();
1299
1300 for channel in &input.channels {
1301 let processed_channel = self.apply(channel)?;
1302 output_channels.push(processed_channel);
1303 }
1304
1305 MultiChannelAudio::new(output_channels, input.sample_rate)
1306 }
1307
1308 fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
1309 Transform::get_parameters(self)
1310 }
1311}
1312
1313impl MultiChannelTransform for GenderTransform {
1314 fn apply_multichannel(&self, input: &MultiChannelAudio) -> Result<MultiChannelAudio> {
1315 let mut output_channels = Vec::new();
1316
1317 for channel in &input.channels {
1318 let processed_channel = self.apply(channel)?;
1319 output_channels.push(processed_channel);
1320 }
1321
1322 MultiChannelAudio::new(output_channels, input.sample_rate)
1323 }
1324
1325 fn get_parameters(&self) -> std::collections::HashMap<String, f32> {
1326 Transform::get_parameters(self)
1327 }
1328}
1329
#[cfg(test)]
mod multichannel_tests {
    use super::*;

    #[test]
    fn test_multichannel_audio_creation() {
        let channels = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let audio = MultiChannelAudio::new(channels.clone(), 44100).unwrap();

        assert_eq!(audio.num_channels(), 2);
        assert_eq!(audio.num_samples(), 3);
        assert_eq!(audio.channels, channels);
    }

    #[test]
    fn test_interleaved_conversion() {
        // Interleaved stereo: L0 R0 L1 R1 L2 R2.
        let interleaved = vec![0.1, 0.4, 0.2, 0.5, 0.3, 0.6];
        let audio = MultiChannelAudio::from_interleaved(&interleaved, 2, 44100).unwrap();

        assert_eq!(audio.num_channels(), 2);
        assert_eq!(audio.num_samples(), 3);

        // Round trip back to interleaved must be lossless.
        let round_trip = audio.to_interleaved();
        assert_eq!(round_trip, interleaved);
    }

    #[test]
    fn test_mono_conversion() {
        let channels = vec![vec![0.2, 0.4, 0.6], vec![0.8, 1.0, 1.2]];
        let audio = MultiChannelAudio::new(channels, 44100).unwrap();
        let mono = audio.to_mono();

        // Mono is the per-sample average of both channels.
        assert_eq!(mono.len(), 3);
        assert!((mono[0] - 0.5).abs() < f32::EPSILON);
        assert!((mono[1] - 0.7).abs() < f32::EPSILON);
        assert!((mono[2] - 0.9).abs() < f32::EPSILON);
    }

    #[test]
    fn test_multichannel_pitch_transform_independent() {
        let channels = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let audio = MultiChannelAudio::new(channels, 44100).unwrap();

        let transform = MultiChannelPitchTransform::new(2.0, 2);
        let result = transform.apply_multichannel(&audio).unwrap();

        assert_eq!(result.num_channels(), 2);
        assert_eq!(result.num_samples(), 3);
    }

    #[test]
    fn test_multichannel_pitch_transform_stereo() {
        let channels = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let audio = MultiChannelAudio::new(channels, 44100).unwrap();

        // Distinct pitch factors for the left and right channels.
        let transform = MultiChannelPitchTransform::stereo(1.5, 2.5);
        let result = transform.apply_multichannel(&audio).unwrap();

        assert_eq!(result.num_channels(), 2);
        assert_eq!(transform.channel_pitch_factors, vec![1.5, 2.5]);
    }

    #[test]
    fn test_multichannel_mid_side_processing() {
        let channels = vec![
            vec![0.8, 0.6, 0.4],
            vec![0.2, 0.4, 0.6],
        ];
        let audio = MultiChannelAudio::new(channels, 44100).unwrap();

        let mut transform = MultiChannelPitchTransform::stereo(2.0, 2.0);
        transform.config.strategy = ChannelStrategy::MidSide;

        let result = transform.apply_multichannel(&audio).unwrap();
        assert_eq!(result.num_channels(), 2);
    }

    #[test]
    fn test_multichannel_transform_with_crosstalk() {
        let channels = vec![vec![1.0, 0.0, 0.5], vec![0.0, 1.0, 0.5]];
        let audio = MultiChannelAudio::new(channels, 44100).unwrap();

        let mut transform = MultiChannelPitchTransform::new(1.0, 2);
        transform.config.enable_crosstalk = true;
        transform.config.crosstalk_amount = 0.1;

        let result = transform.apply_multichannel(&audio).unwrap();

        // Crosstalk bleeds each channel into the other, so neither output
        // channel can equal its untouched input.
        assert_eq!(result.num_channels(), 2);
        assert_ne!(result.channels[0], vec![1.0, 0.0, 0.5]);
        assert_ne!(result.channels[1], vec![0.0, 1.0, 0.5]);
    }

    #[test]
    fn test_channel_correlation_calculation() {
        // Identical channels must be perfectly correlated.
        let channels = vec![
            vec![1.0, 2.0, 3.0],
            vec![1.0, 2.0, 3.0],
        ];
        let audio = MultiChannelAudio::new(channels, 44100).unwrap();

        let transform = MultiChannelPitchTransform::new(1.0, 2);
        let correlation_matrix = transform.calculate_channel_correlation(&audio);

        assert_eq!(correlation_matrix.len(), 2);
        assert_eq!(correlation_matrix[0].len(), 2);
        assert_eq!(correlation_matrix[0][0], 1.0);
        assert!((correlation_matrix[0][1] - 1.0).abs() < f32::EPSILON);
    }
}