Skip to main content

voirs_spatial/
wfs.rs

1//! Wave Field Synthesis (WFS) implementation for advanced spatial audio reproduction
2//!
3//! This module provides Wave Field Synthesis capabilities for creating virtual sound fields
4//! using arrays of loudspeakers. WFS enables reproduction of spatial audio with high
5//! localization accuracy and extended listening area compared to traditional stereo systems.
6
7use crate::types::Position3D;
8use crate::{Error, Result};
9use scirs2_core::ndarray::{s, Array1, Array2, Axis};
10use scirs2_core::Complex32;
11use scirs2_fft::{irfft, rfft, FftPlanner, RealFftPlanner, RealToComplex};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::f32::consts::PI;
15use std::sync::Arc;
16
/// Wave Field Synthesis processor configuration.
///
/// `WfsProcessor::new` rejects a configuration whose `speaker_count` is zero or
/// does not equal `speaker_positions.len()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WfsConfig {
    /// Sample rate in Hz. Used to convert propagation delays from seconds to
    /// samples and to map FFT bins to frequencies.
    pub sample_rate: f32,
    /// Number of speakers in the WFS array; also the number of output channels
    /// produced by the processor.
    pub speaker_count: usize,
    /// Array geometry type. Descriptive only as far as this file shows: the
    /// processor works from `speaker_positions`, not from this tag.
    pub array_geometry: ArrayGeometry,
    /// Speaker positions in 3D space, one entry per speaker.
    pub speaker_positions: Vec<Position3D>,
    /// Maximum processing distance (meters).
    // NOTE(review): not read by any code visible in this file.
    pub max_distance: f32,
    /// Frequency range `(low, high)` for WFS processing (Hz). Driving-function
    /// bins whose frequency falls outside this inclusive range are left at zero.
    pub frequency_range: (f32, f32),
    /// Reference distance for amplitude scaling (meters). Point-source gain
    /// scales with `sqrt(reference_distance / distance)`.
    pub reference_distance: f32,
    /// Pre-emphasis filter parameters
    pub pre_emphasis: PreEmphasisConfig,
    /// Spatial aliasing compensation
    // NOTE(review): not read by any code visible in this file.
    pub aliasing_compensation: bool,
}
39
/// WFS array geometry types.
///
/// `WfsArrayBuilder::build_positions` generates speaker positions for every
/// variant except `Custom`, for which it returns an empty list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ArrayGeometry {
    /// Linear array of speakers (laid out along the x axis by the builder)
    Linear,
    /// Circular array of speakers (laid out in the x/y plane by the builder)
    Circular,
    /// Rectangular array of speakers (a row-major grid in the x/y plane)
    Rectangular,
    /// Custom arrangement; the caller supplies the positions
    Custom,
}
52
/// Pre-emphasis filter configuration.
///
/// When enabled, point-source driving functions are boosted by
/// `sqrt(frequency / cutoff_frequency)` at and above the cutoff and left at
/// unity gain below it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreEmphasisConfig {
    /// Enable pre-emphasis filtering
    pub enabled: bool,
    /// High-pass cutoff frequency (Hz)
    pub cutoff_frequency: f32,
    /// Filter order
    // NOTE(review): not read by any code visible in this file.
    pub filter_order: usize,
}
63
/// Virtual sound source for WFS reproduction.
#[derive(Debug, Clone)]
pub struct WfsSource {
    /// Unique identifier. Also the key under which the processor caches the
    /// driving functions computed for this source.
    pub id: String,
    /// 3D position. For `WfsSourceType::PlaneWave` the normalized position is
    /// used as the wave direction.
    pub position: Position3D,
    /// Audio signal (mono); every speaker channel is derived from it.
    pub audio_data: Array1<f32>,
    /// Source type; selects which driving function is computed.
    pub source_type: WfsSourceType,
    /// Gain factor (linear) applied to every speaker signal.
    pub gain: f32,
    /// Distance from reference point
    // NOTE(review): not read by any code visible in this file; the processor
    // derives distances from `position` instead.
    pub distance: f32,
}
80
/// WFS source types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WfsSourceType {
    /// Point source: delay and attenuation follow the source-to-speaker distance
    Point,
    /// Plane wave: constant amplitude, delay from the projection of the speaker
    /// position onto the wave direction
    PlaneWave,
    /// Extended source (currently handled with the point-source approximation)
    Extended,
}
91
/// WFS driving function for a single speaker.
///
/// Holds both a time-domain description (`delay_samples`, `amplitude`) and a
/// per-bin complex response for the same speaker.
#[derive(Debug, Clone)]
pub struct WfsDrivingFunction {
    /// Speaker index (position of the speaker in `WfsConfig::speaker_positions`)
    pub speaker_index: usize,
    /// Complex frequency response, one entry per real-FFT bin. Bins outside the
    /// configured frequency range are zero.
    pub frequency_response: Array1<Complex32>,
    /// Delay in samples (may be fractional)
    pub delay_samples: f32,
    /// Amplitude scaling factor (linear)
    pub amplitude: f32,
}
104
/// Main Wave Field Synthesis processor.
///
/// Turns a mono `WfsSource` into one signal per speaker by computing (and
/// caching, keyed by source id) a driving function for every speaker.
pub struct WfsProcessor {
    /// Configuration
    config: WfsConfig,
    /// FFT planner for frequency domain processing (kept alive after the plans
    /// below have been created)
    fft_planner: Arc<RealFftPlanner<f32>>,
    /// Forward FFT
    forward_fft: Arc<dyn RealToComplex<f32>>,
    /// Inverse FFT
    inverse_fft: Arc<dyn scirs2_fft::ComplexToReal<f32>>,
    /// Speaker driving functions cache, keyed by `WfsSource::id`
    driving_functions_cache: HashMap<String, Vec<WfsDrivingFunction>>,
    /// Processing buffer of shape `(speaker_count, fft_size / 2 + 1)`; its column
    /// count defines the number of frequency bins used throughout the processor
    frequency_buffer: Array2<Complex32>,
    /// Processing buffer of shape `(speaker_count, fft_size)`
    time_buffer: Array2<f32>,
    /// Speed of sound (m/s)
    speed_of_sound: f32,
}
123
124impl Default for WfsConfig {
125    fn default() -> Self {
126        // Default linear array configuration
127        let speaker_count = 16;
128        let speaker_spacing = 0.2; // 20cm spacing
129        let speaker_positions: Vec<Position3D> = (0..speaker_count)
130            .map(|i| Position3D {
131                x: (i as f32 - speaker_count as f32 / 2.0) * speaker_spacing,
132                y: 0.0,
133                z: 0.0,
134            })
135            .collect();
136
137        Self {
138            sample_rate: 48000.0,
139            speaker_count,
140            array_geometry: ArrayGeometry::Linear,
141            speaker_positions,
142            max_distance: 10.0,
143            frequency_range: (20.0, 20000.0),
144            reference_distance: 1.0,
145            pre_emphasis: PreEmphasisConfig {
146                enabled: true,
147                cutoff_frequency: 100.0,
148                filter_order: 2,
149            },
150            aliasing_compensation: true,
151        }
152    }
153}
154
155impl WfsProcessor {
156    /// Create a new WFS processor
157    pub fn new(config: WfsConfig) -> Result<Self> {
158        if config.speaker_count == 0 {
159            return Err(Error::LegacyConfig(
160                "Speaker count must be greater than 0".to_string(),
161            ));
162        }
163
164        if config.speaker_positions.len() != config.speaker_count {
165            return Err(Error::LegacyConfig(
166                "Number of speaker positions must match speaker count".to_string(),
167            ));
168        }
169
170        let mut planner = RealFftPlanner::<f32>::new();
171        let buffer_size = 1024; // Default FFT size
172
173        let forward_fft = planner.plan_fft_forward(buffer_size);
174        let inverse_fft = planner.plan_fft_inverse(buffer_size);
175
176        let frequency_buffer = Array2::zeros((config.speaker_count, buffer_size / 2 + 1));
177        let time_buffer = Array2::zeros((config.speaker_count, buffer_size));
178
179        Ok(Self {
180            config,
181            fft_planner: Arc::new(planner),
182            forward_fft,
183            inverse_fft,
184            driving_functions_cache: HashMap::new(),
185            frequency_buffer,
186            time_buffer,
187            speed_of_sound: 343.0, // Standard speed of sound at 20°C
188        })
189    }
190
191    /// Process a virtual source using WFS
192    pub fn process_source(&mut self, source: &WfsSource) -> Result<Array2<f32>> {
193        let driving_functions = self.compute_driving_functions(source)?;
194        self.apply_driving_functions(&driving_functions, &source.audio_data)
195    }
196
197    /// Compute WFS driving functions for a source
198    fn compute_driving_functions(&mut self, source: &WfsSource) -> Result<Vec<WfsDrivingFunction>> {
199        // Check cache first
200        if let Some(cached) = self.driving_functions_cache.get(&source.id) {
201            return Ok(cached.clone());
202        }
203
204        let mut driving_functions = Vec::with_capacity(self.config.speaker_count);
205
206        for (speaker_idx, speaker_pos) in self.config.speaker_positions.iter().enumerate() {
207            let driving_function = match source.source_type {
208                WfsSourceType::Point => {
209                    self.compute_point_source_driving_function(source, speaker_pos, speaker_idx)?
210                }
211                WfsSourceType::PlaneWave => {
212                    self.compute_plane_wave_driving_function(source, speaker_pos, speaker_idx)?
213                }
214                WfsSourceType::Extended => {
215                    self.compute_extended_source_driving_function(source, speaker_pos, speaker_idx)?
216                }
217            };
218            driving_functions.push(driving_function);
219        }
220
221        // Cache the result
222        self.driving_functions_cache
223            .insert(source.id.clone(), driving_functions.clone());
224
225        Ok(driving_functions)
226    }
227
228    /// Compute driving function for point source
229    fn compute_point_source_driving_function(
230        &self,
231        source: &WfsSource,
232        speaker_pos: &Position3D,
233        speaker_idx: usize,
234    ) -> Result<WfsDrivingFunction> {
235        // Distance from source to speaker
236        let distance = source.position.distance_to(speaker_pos);
237
238        // Delay calculation
239        let delay_time = distance / self.speed_of_sound;
240        let delay_samples = delay_time * self.config.sample_rate;
241
242        // Amplitude calculation with distance attenuation
243        let amplitude = source.gain * (self.config.reference_distance / distance).sqrt();
244
245        // Frequency response (simplified - can be extended with directivity patterns)
246        let buffer_size = self.frequency_buffer.ncols();
247        let mut frequency_response = Array1::zeros(buffer_size);
248
249        // Apply frequency-dependent processing
250        for (freq_idx, response) in frequency_response.iter_mut().enumerate() {
251            let frequency = freq_idx as f32 * self.config.sample_rate / (2.0 * buffer_size as f32);
252
253            if frequency >= self.config.frequency_range.0
254                && frequency <= self.config.frequency_range.1
255            {
256                // Basic WFS frequency response
257                let omega = 2.0 * PI * frequency;
258                let wave_number = omega / self.speed_of_sound;
259
260                // Phase shift due to distance
261                let phase = -wave_number * distance;
262                *response = Complex32::from_polar(amplitude, phase);
263
264                // Apply pre-emphasis if enabled
265                if self.config.pre_emphasis.enabled {
266                    let pre_emphasis_gain = self.compute_pre_emphasis_gain(frequency);
267                    *response *= pre_emphasis_gain;
268                }
269            }
270        }
271
272        Ok(WfsDrivingFunction {
273            speaker_index: speaker_idx,
274            frequency_response,
275            delay_samples,
276            amplitude,
277        })
278    }
279
280    /// Compute driving function for plane wave
281    fn compute_plane_wave_driving_function(
282        &self,
283        source: &WfsSource,
284        speaker_pos: &Position3D,
285        speaker_idx: usize,
286    ) -> Result<WfsDrivingFunction> {
287        // For plane waves, the delay is based on the projection of speaker position
288        // onto the wave direction
289        let wave_direction = source.position.normalized();
290        let projection = speaker_pos.dot(&wave_direction);
291
292        let delay_time = projection / self.speed_of_sound;
293        let delay_samples = delay_time * self.config.sample_rate;
294
295        // Constant amplitude for plane waves
296        let amplitude = source.gain;
297
298        // Frequency response for plane wave
299        let buffer_size = self.frequency_buffer.ncols();
300        let mut frequency_response = Array1::zeros(buffer_size);
301
302        for (freq_idx, response) in frequency_response.iter_mut().enumerate() {
303            let frequency = freq_idx as f32 * self.config.sample_rate / (2.0 * buffer_size as f32);
304
305            if frequency >= self.config.frequency_range.0
306                && frequency <= self.config.frequency_range.1
307            {
308                let omega = 2.0 * PI * frequency;
309                let wave_number = omega / self.speed_of_sound;
310                let phase = -wave_number * projection;
311
312                *response = Complex32::from_polar(amplitude, phase);
313            }
314        }
315
316        Ok(WfsDrivingFunction {
317            speaker_index: speaker_idx,
318            frequency_response,
319            delay_samples,
320            amplitude,
321        })
322    }
323
324    /// Compute driving function for extended source (simplified implementation)
325    fn compute_extended_source_driving_function(
326        &self,
327        source: &WfsSource,
328        speaker_pos: &Position3D,
329        speaker_idx: usize,
330    ) -> Result<WfsDrivingFunction> {
331        // For extended sources, use point source approximation
332        // In a full implementation, this would integrate over the source extent
333        self.compute_point_source_driving_function(source, speaker_pos, speaker_idx)
334    }
335
336    /// Compute pre-emphasis filter gain
337    fn compute_pre_emphasis_gain(&self, frequency: f32) -> f32 {
338        if !self.config.pre_emphasis.enabled
339            || frequency < self.config.pre_emphasis.cutoff_frequency
340        {
341            return 1.0;
342        }
343
344        // Simple high-pass filter response
345        let normalized_freq = frequency / self.config.pre_emphasis.cutoff_frequency;
346        normalized_freq.sqrt() // Square root frequency response
347    }
348
349    /// Apply driving functions to generate speaker signals
350    fn apply_driving_functions(
351        &mut self,
352        driving_functions: &[WfsDrivingFunction],
353        audio_data: &Array1<f32>,
354    ) -> Result<Array2<f32>> {
355        let output_length = audio_data.len();
356        let mut output = Array2::zeros((self.config.speaker_count, output_length));
357
358        // Process each speaker channel
359        for (speaker_idx, driving_function) in driving_functions.iter().enumerate() {
360            let delayed_signal = self.apply_delay_and_amplitude(
361                audio_data,
362                driving_function.delay_samples,
363                driving_function.amplitude,
364            )?;
365
366            // Apply frequency domain filtering if needed
367            let processed_signal = if self.should_apply_frequency_processing(driving_function) {
368                self.apply_frequency_response(
369                    &delayed_signal,
370                    &driving_function.frequency_response,
371                )?
372            } else {
373                delayed_signal
374            };
375
376            // Copy to output
377            let output_length = output_length.min(processed_signal.len());
378            output
379                .row_mut(speaker_idx)
380                .slice_mut(s![..output_length])
381                .assign(&processed_signal.slice(s![..output_length]));
382        }
383
384        Ok(output)
385    }
386
387    /// Apply delay and amplitude scaling to signal
388    fn apply_delay_and_amplitude(
389        &self,
390        signal: &Array1<f32>,
391        delay_samples: f32,
392        amplitude: f32,
393    ) -> Result<Array1<f32>> {
394        let signal_length = signal.len();
395        let delay_int = delay_samples.floor() as isize;
396        let delay_frac = delay_samples - delay_int as f32;
397
398        let mut output = Array1::zeros(signal_length);
399
400        // Apply integer delay
401        if delay_int >= 0 {
402            let start_idx = delay_int as usize;
403            if start_idx < signal_length {
404                let copy_length = signal_length - start_idx;
405                output
406                    .slice_mut(s![start_idx..])
407                    .assign(&signal.slice(s![..copy_length]));
408            }
409        }
410
411        // Apply fractional delay using linear interpolation
412        if delay_frac > 0.001 {
413            for i in 1..signal_length {
414                output[i] = output[i] * (1.0 - delay_frac) + output[i - 1] * delay_frac;
415            }
416        }
417
418        // Apply amplitude scaling
419        output *= amplitude;
420
421        Ok(output)
422    }
423
424    /// Check if frequency domain processing should be applied
425    fn should_apply_frequency_processing(&self, driving_function: &WfsDrivingFunction) -> bool {
426        // Apply frequency processing if the response is not flat
427        driving_function
428            .frequency_response
429            .iter()
430            .any(|&response| (response.norm() - 1.0).abs() > 0.1 || response.arg().abs() > 0.1)
431    }
432
433    /// Apply frequency response using FFT
434    fn apply_frequency_response(
435        &mut self,
436        signal: &Array1<f32>,
437        frequency_response: &Array1<Complex32>,
438    ) -> Result<Array1<f32>> {
439        let buffer_size = self.frequency_buffer.ncols() * 2 - 2;
440        let mut padded_signal = Array1::zeros(buffer_size);
441
442        // Copy signal to padded buffer
443        let copy_length = signal.len().min(buffer_size);
444        padded_signal
445            .slice_mut(s![..copy_length])
446            .assign(&signal.slice(s![..copy_length]));
447
448        // Transform to frequency domain
449        let mut spectrum = Array1::zeros(frequency_response.len());
450        self.forward_fft.process(
451            padded_signal.as_slice().expect("contiguous array"),
452            spectrum.as_slice_mut().expect("contiguous array"),
453        );
454
455        // Apply frequency response
456        for (spectrum_bin, &response) in spectrum.iter_mut().zip(frequency_response.iter()) {
457            *spectrum_bin *= response;
458        }
459
460        // Transform back to time domain
461        let mut result = Array1::zeros(buffer_size);
462        self.inverse_fft.process(
463            spectrum.as_slice().expect("contiguous array"),
464            result.as_slice_mut().expect("contiguous array"),
465        );
466
467        // Return original length
468        Ok(result.slice(s![..signal.len()]).to_owned())
469    }
470
471    /// Update source position (invalidates cache for that source)
472    pub fn update_source_position(&mut self, source_id: &str, new_position: Position3D) {
473        self.driving_functions_cache.remove(source_id);
474    }
475
476    /// Clear all cached driving functions
477    pub fn clear_cache(&mut self) {
478        self.driving_functions_cache.clear();
479    }
480
481    /// Get configuration
482    pub fn config(&self) -> &WfsConfig {
483        &self.config
484    }
485
486    /// Set speed of sound (for different environmental conditions)
487    pub fn set_speed_of_sound(&mut self, speed: f32) {
488        if speed > 0.0 {
489            self.speed_of_sound = speed;
490            self.clear_cache(); // Invalidate cache since speed affects calculations
491        }
492    }
493}
494
495/// WFS array builder for different geometries
496pub struct WfsArrayBuilder {
497    geometry: ArrayGeometry,
498    speaker_count: usize,
499    dimensions: (f32, f32, f32), // width, height, depth
500}
501
502impl WfsArrayBuilder {
503    /// Create a new array builder
504    pub fn new(geometry: ArrayGeometry) -> Self {
505        Self {
506            geometry,
507            speaker_count: 16,
508            dimensions: (3.0, 0.0, 0.0), // 3m wide linear array by default
509        }
510    }
511
512    /// Set the number of speakers
513    pub fn speaker_count(mut self, count: usize) -> Self {
514        self.speaker_count = count;
515        self
516    }
517
518    /// Set array dimensions
519    pub fn dimensions(mut self, width: f32, height: f32, depth: f32) -> Self {
520        self.dimensions = (width, height, depth);
521        self
522    }
523
524    /// Build speaker positions based on geometry
525    pub fn build_positions(self) -> Vec<Position3D> {
526        match self.geometry {
527            ArrayGeometry::Linear => self.build_linear_array(),
528            ArrayGeometry::Circular => self.build_circular_array(),
529            ArrayGeometry::Rectangular => self.build_rectangular_array(),
530            ArrayGeometry::Custom => vec![], // User must provide custom positions
531        }
532    }
533
534    fn build_linear_array(&self) -> Vec<Position3D> {
535        let spacing = self.dimensions.0 / (self.speaker_count - 1) as f32;
536        let start_x = -self.dimensions.0 / 2.0;
537
538        (0..self.speaker_count)
539            .map(|i| Position3D {
540                x: start_x + i as f32 * spacing,
541                y: 0.0,
542                z: 0.0,
543            })
544            .collect()
545    }
546
547    fn build_circular_array(&self) -> Vec<Position3D> {
548        let radius = self.dimensions.0 / 2.0;
549        let angle_step = 2.0 * PI / self.speaker_count as f32;
550
551        (0..self.speaker_count)
552            .map(|i| {
553                let angle = i as f32 * angle_step;
554                Position3D {
555                    x: radius * angle.cos(),
556                    y: radius * angle.sin(),
557                    z: 0.0,
558                }
559            })
560            .collect()
561    }
562
563    fn build_rectangular_array(&self) -> Vec<Position3D> {
564        // Simple rectangular grid
565        let cols = (self.speaker_count as f32).sqrt().ceil() as usize;
566        let rows = self.speaker_count.div_ceil(cols);
567
568        let x_spacing = self.dimensions.0 / (cols - 1) as f32;
569        let y_spacing = self.dimensions.1 / (rows - 1) as f32;
570
571        let start_x = -self.dimensions.0 / 2.0;
572        let start_y = -self.dimensions.1 / 2.0;
573
574        let mut positions = Vec::new();
575        for row in 0..rows {
576            for col in 0..cols {
577                if positions.len() < self.speaker_count {
578                    positions.push(Position3D {
579                        x: start_x + col as f32 * x_spacing,
580                        y: start_y + row as f32 * y_spacing,
581                        z: 0.0,
582                    });
583                }
584            }
585        }
586        positions
587    }
588}
589
// Unit tests covering configuration defaults, processor construction, the
// array builder and a shape-only smoke test of source processing.
#[cfg(test)]
mod tests {
    use super::*;

    // The default configuration describes a 16-speaker linear array.
    #[test]
    fn test_wfs_config_default() {
        let config = WfsConfig::default();
        assert_eq!(config.speaker_count, 16);
        assert_eq!(config.array_geometry, ArrayGeometry::Linear);
        assert_eq!(config.speaker_positions.len(), 16);
    }

    // A processor can be built from the default configuration.
    #[test]
    fn test_wfs_processor_creation() {
        let config = WfsConfig::default();
        let processor = WfsProcessor::new(config);
        assert!(processor.is_ok());
    }

    // A 2 m wide linear array spans x = -1 .. 1.
    #[test]
    fn test_array_builder_linear() {
        let positions = WfsArrayBuilder::new(ArrayGeometry::Linear)
            .speaker_count(8)
            .dimensions(2.0, 0.0, 0.0)
            .build_positions();

        assert_eq!(positions.len(), 8);
        assert_eq!(positions[0].x, -1.0);
        assert_eq!(positions[7].x, 1.0);
    }

    // The width passed to a circular array is its diameter.
    #[test]
    fn test_array_builder_circular() {
        let positions = WfsArrayBuilder::new(ArrayGeometry::Circular)
            .speaker_count(4)
            .dimensions(2.0, 0.0, 0.0) // diameter = 2.0, so radius = 1.0
            .build_positions();

        assert_eq!(positions.len(), 4);
        // First speaker should be at (1, 0, 0)
        assert!((positions[0].x - 1.0).abs() < 0.001);
        assert!(positions[0].y.abs() < 0.001);
    }

    // A source can be constructed field by field and keeps its values.
    #[test]
    fn test_wfs_source_creation() {
        let source = WfsSource {
            id: "test_source".to_string(),
            position: Position3D {
                x: 1.0,
                y: 0.0,
                z: 0.0,
            },
            audio_data: Array1::zeros(1024),
            source_type: WfsSourceType::Point,
            gain: 1.0,
            distance: 1.0,
        };

        assert_eq!(source.id, "test_source");
        assert_eq!(source.source_type, WfsSourceType::Point);
    }

    // Smoke test: processing yields one row per speaker and keeps the input
    // length. Only the output shape is checked, not the sample values.
    #[test]
    fn test_processor_source_processing() {
        let config = WfsConfig::default();
        let mut processor = WfsProcessor::new(config).unwrap();

        let source = WfsSource {
            id: "test".to_string(),
            position: Position3D {
                x: 2.0,
                y: 0.0,
                z: 0.0,
            },
            audio_data: Array1::ones(512),
            source_type: WfsSourceType::Point,
            gain: 1.0,
            distance: 2.0,
        };

        let result = processor.process_source(&source);
        assert!(result.is_ok());

        let output = result.unwrap();
        assert_eq!(output.nrows(), 16); // 16 speakers
        assert_eq!(output.ncols(), 512); // Same length as input
    }
}