1use crate::types::Position3D;
8use crate::{Error, Result};
9use scirs2_core::ndarray::{s, Array1, Array2, Axis};
10use scirs2_core::Complex32;
11use scirs2_fft::{irfft, rfft, FftPlanner, RealFftPlanner, RealToComplex};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::f32::consts::PI;
15use std::sync::Arc;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct WfsConfig {
20 pub sample_rate: f32,
22 pub speaker_count: usize,
24 pub array_geometry: ArrayGeometry,
26 pub speaker_positions: Vec<Position3D>,
28 pub max_distance: f32,
30 pub frequency_range: (f32, f32),
32 pub reference_distance: f32,
34 pub pre_emphasis: PreEmphasisConfig,
36 pub aliasing_compensation: bool,
38}
39
40#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
42pub enum ArrayGeometry {
43 Linear,
45 Circular,
47 Rectangular,
49 Custom,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct PreEmphasisConfig {
56 pub enabled: bool,
58 pub cutoff_frequency: f32,
60 pub filter_order: usize,
62}
63
64#[derive(Debug, Clone)]
66pub struct WfsSource {
67 pub id: String,
69 pub position: Position3D,
71 pub audio_data: Array1<f32>,
73 pub source_type: WfsSourceType,
75 pub gain: f32,
77 pub distance: f32,
79}
80
/// Radiation model of a virtual source.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WfsSourceType {
    /// Point source radiating from `WfsSource::position`.
    Point,
    /// Plane wave travelling along the direction of `WfsSource::position`.
    PlaneWave,
    /// Spatially extended source; currently rendered like a point source.
    Extended,
}
91
92#[derive(Debug, Clone)]
94pub struct WfsDrivingFunction {
95 pub speaker_index: usize,
97 pub frequency_response: Array1<Complex32>,
99 pub delay_samples: f32,
101 pub amplitude: f32,
103}
104
105pub struct WfsProcessor {
107 config: WfsConfig,
109 fft_planner: Arc<RealFftPlanner<f32>>,
111 forward_fft: Arc<dyn RealToComplex<f32>>,
113 inverse_fft: Arc<dyn scirs2_fft::ComplexToReal<f32>>,
115 driving_functions_cache: HashMap<String, Vec<WfsDrivingFunction>>,
117 frequency_buffer: Array2<Complex32>,
119 time_buffer: Array2<f32>,
120 speed_of_sound: f32,
122}
123
124impl Default for WfsConfig {
125 fn default() -> Self {
126 let speaker_count = 16;
128 let speaker_spacing = 0.2; let speaker_positions: Vec<Position3D> = (0..speaker_count)
130 .map(|i| Position3D {
131 x: (i as f32 - speaker_count as f32 / 2.0) * speaker_spacing,
132 y: 0.0,
133 z: 0.0,
134 })
135 .collect();
136
137 Self {
138 sample_rate: 48000.0,
139 speaker_count,
140 array_geometry: ArrayGeometry::Linear,
141 speaker_positions,
142 max_distance: 10.0,
143 frequency_range: (20.0, 20000.0),
144 reference_distance: 1.0,
145 pre_emphasis: PreEmphasisConfig {
146 enabled: true,
147 cutoff_frequency: 100.0,
148 filter_order: 2,
149 },
150 aliasing_compensation: true,
151 }
152 }
153}
154
155impl WfsProcessor {
156 pub fn new(config: WfsConfig) -> Result<Self> {
158 if config.speaker_count == 0 {
159 return Err(Error::LegacyConfig(
160 "Speaker count must be greater than 0".to_string(),
161 ));
162 }
163
164 if config.speaker_positions.len() != config.speaker_count {
165 return Err(Error::LegacyConfig(
166 "Number of speaker positions must match speaker count".to_string(),
167 ));
168 }
169
170 let mut planner = RealFftPlanner::<f32>::new();
171 let buffer_size = 1024; let forward_fft = planner.plan_fft_forward(buffer_size);
174 let inverse_fft = planner.plan_fft_inverse(buffer_size);
175
176 let frequency_buffer = Array2::zeros((config.speaker_count, buffer_size / 2 + 1));
177 let time_buffer = Array2::zeros((config.speaker_count, buffer_size));
178
179 Ok(Self {
180 config,
181 fft_planner: Arc::new(planner),
182 forward_fft,
183 inverse_fft,
184 driving_functions_cache: HashMap::new(),
185 frequency_buffer,
186 time_buffer,
187 speed_of_sound: 343.0, })
189 }
190
191 pub fn process_source(&mut self, source: &WfsSource) -> Result<Array2<f32>> {
193 let driving_functions = self.compute_driving_functions(source)?;
194 self.apply_driving_functions(&driving_functions, &source.audio_data)
195 }
196
197 fn compute_driving_functions(&mut self, source: &WfsSource) -> Result<Vec<WfsDrivingFunction>> {
199 if let Some(cached) = self.driving_functions_cache.get(&source.id) {
201 return Ok(cached.clone());
202 }
203
204 let mut driving_functions = Vec::with_capacity(self.config.speaker_count);
205
206 for (speaker_idx, speaker_pos) in self.config.speaker_positions.iter().enumerate() {
207 let driving_function = match source.source_type {
208 WfsSourceType::Point => {
209 self.compute_point_source_driving_function(source, speaker_pos, speaker_idx)?
210 }
211 WfsSourceType::PlaneWave => {
212 self.compute_plane_wave_driving_function(source, speaker_pos, speaker_idx)?
213 }
214 WfsSourceType::Extended => {
215 self.compute_extended_source_driving_function(source, speaker_pos, speaker_idx)?
216 }
217 };
218 driving_functions.push(driving_function);
219 }
220
221 self.driving_functions_cache
223 .insert(source.id.clone(), driving_functions.clone());
224
225 Ok(driving_functions)
226 }
227
228 fn compute_point_source_driving_function(
230 &self,
231 source: &WfsSource,
232 speaker_pos: &Position3D,
233 speaker_idx: usize,
234 ) -> Result<WfsDrivingFunction> {
235 let distance = source.position.distance_to(speaker_pos);
237
238 let delay_time = distance / self.speed_of_sound;
240 let delay_samples = delay_time * self.config.sample_rate;
241
242 let amplitude = source.gain * (self.config.reference_distance / distance).sqrt();
244
245 let buffer_size = self.frequency_buffer.ncols();
247 let mut frequency_response = Array1::zeros(buffer_size);
248
249 for (freq_idx, response) in frequency_response.iter_mut().enumerate() {
251 let frequency = freq_idx as f32 * self.config.sample_rate / (2.0 * buffer_size as f32);
252
253 if frequency >= self.config.frequency_range.0
254 && frequency <= self.config.frequency_range.1
255 {
256 let omega = 2.0 * PI * frequency;
258 let wave_number = omega / self.speed_of_sound;
259
260 let phase = -wave_number * distance;
262 *response = Complex32::from_polar(amplitude, phase);
263
264 if self.config.pre_emphasis.enabled {
266 let pre_emphasis_gain = self.compute_pre_emphasis_gain(frequency);
267 *response *= pre_emphasis_gain;
268 }
269 }
270 }
271
272 Ok(WfsDrivingFunction {
273 speaker_index: speaker_idx,
274 frequency_response,
275 delay_samples,
276 amplitude,
277 })
278 }
279
280 fn compute_plane_wave_driving_function(
282 &self,
283 source: &WfsSource,
284 speaker_pos: &Position3D,
285 speaker_idx: usize,
286 ) -> Result<WfsDrivingFunction> {
287 let wave_direction = source.position.normalized();
290 let projection = speaker_pos.dot(&wave_direction);
291
292 let delay_time = projection / self.speed_of_sound;
293 let delay_samples = delay_time * self.config.sample_rate;
294
295 let amplitude = source.gain;
297
298 let buffer_size = self.frequency_buffer.ncols();
300 let mut frequency_response = Array1::zeros(buffer_size);
301
302 for (freq_idx, response) in frequency_response.iter_mut().enumerate() {
303 let frequency = freq_idx as f32 * self.config.sample_rate / (2.0 * buffer_size as f32);
304
305 if frequency >= self.config.frequency_range.0
306 && frequency <= self.config.frequency_range.1
307 {
308 let omega = 2.0 * PI * frequency;
309 let wave_number = omega / self.speed_of_sound;
310 let phase = -wave_number * projection;
311
312 *response = Complex32::from_polar(amplitude, phase);
313 }
314 }
315
316 Ok(WfsDrivingFunction {
317 speaker_index: speaker_idx,
318 frequency_response,
319 delay_samples,
320 amplitude,
321 })
322 }
323
324 fn compute_extended_source_driving_function(
326 &self,
327 source: &WfsSource,
328 speaker_pos: &Position3D,
329 speaker_idx: usize,
330 ) -> Result<WfsDrivingFunction> {
331 self.compute_point_source_driving_function(source, speaker_pos, speaker_idx)
334 }
335
336 fn compute_pre_emphasis_gain(&self, frequency: f32) -> f32 {
338 if !self.config.pre_emphasis.enabled
339 || frequency < self.config.pre_emphasis.cutoff_frequency
340 {
341 return 1.0;
342 }
343
344 let normalized_freq = frequency / self.config.pre_emphasis.cutoff_frequency;
346 normalized_freq.sqrt() }
348
349 fn apply_driving_functions(
351 &mut self,
352 driving_functions: &[WfsDrivingFunction],
353 audio_data: &Array1<f32>,
354 ) -> Result<Array2<f32>> {
355 let output_length = audio_data.len();
356 let mut output = Array2::zeros((self.config.speaker_count, output_length));
357
358 for (speaker_idx, driving_function) in driving_functions.iter().enumerate() {
360 let delayed_signal = self.apply_delay_and_amplitude(
361 audio_data,
362 driving_function.delay_samples,
363 driving_function.amplitude,
364 )?;
365
366 let processed_signal = if self.should_apply_frequency_processing(driving_function) {
368 self.apply_frequency_response(
369 &delayed_signal,
370 &driving_function.frequency_response,
371 )?
372 } else {
373 delayed_signal
374 };
375
376 let output_length = output_length.min(processed_signal.len());
378 output
379 .row_mut(speaker_idx)
380 .slice_mut(s![..output_length])
381 .assign(&processed_signal.slice(s![..output_length]));
382 }
383
384 Ok(output)
385 }
386
387 fn apply_delay_and_amplitude(
389 &self,
390 signal: &Array1<f32>,
391 delay_samples: f32,
392 amplitude: f32,
393 ) -> Result<Array1<f32>> {
394 let signal_length = signal.len();
395 let delay_int = delay_samples.floor() as isize;
396 let delay_frac = delay_samples - delay_int as f32;
397
398 let mut output = Array1::zeros(signal_length);
399
400 if delay_int >= 0 {
402 let start_idx = delay_int as usize;
403 if start_idx < signal_length {
404 let copy_length = signal_length - start_idx;
405 output
406 .slice_mut(s![start_idx..])
407 .assign(&signal.slice(s![..copy_length]));
408 }
409 }
410
411 if delay_frac > 0.001 {
413 for i in 1..signal_length {
414 output[i] = output[i] * (1.0 - delay_frac) + output[i - 1] * delay_frac;
415 }
416 }
417
418 output *= amplitude;
420
421 Ok(output)
422 }
423
424 fn should_apply_frequency_processing(&self, driving_function: &WfsDrivingFunction) -> bool {
426 driving_function
428 .frequency_response
429 .iter()
430 .any(|&response| (response.norm() - 1.0).abs() > 0.1 || response.arg().abs() > 0.1)
431 }
432
433 fn apply_frequency_response(
435 &mut self,
436 signal: &Array1<f32>,
437 frequency_response: &Array1<Complex32>,
438 ) -> Result<Array1<f32>> {
439 let buffer_size = self.frequency_buffer.ncols() * 2 - 2;
440 let mut padded_signal = Array1::zeros(buffer_size);
441
442 let copy_length = signal.len().min(buffer_size);
444 padded_signal
445 .slice_mut(s![..copy_length])
446 .assign(&signal.slice(s![..copy_length]));
447
448 let mut spectrum = Array1::zeros(frequency_response.len());
450 self.forward_fft.process(
451 padded_signal.as_slice().expect("contiguous array"),
452 spectrum.as_slice_mut().expect("contiguous array"),
453 );
454
455 for (spectrum_bin, &response) in spectrum.iter_mut().zip(frequency_response.iter()) {
457 *spectrum_bin *= response;
458 }
459
460 let mut result = Array1::zeros(buffer_size);
462 self.inverse_fft.process(
463 spectrum.as_slice().expect("contiguous array"),
464 result.as_slice_mut().expect("contiguous array"),
465 );
466
467 Ok(result.slice(s![..signal.len()]).to_owned())
469 }
470
471 pub fn update_source_position(&mut self, source_id: &str, new_position: Position3D) {
473 self.driving_functions_cache.remove(source_id);
474 }
475
476 pub fn clear_cache(&mut self) {
478 self.driving_functions_cache.clear();
479 }
480
481 pub fn config(&self) -> &WfsConfig {
483 &self.config
484 }
485
486 pub fn set_speed_of_sound(&mut self, speed: f32) {
488 if speed > 0.0 {
489 self.speed_of_sound = speed;
490 self.clear_cache(); }
492 }
493}
494
495pub struct WfsArrayBuilder {
497 geometry: ArrayGeometry,
498 speaker_count: usize,
499 dimensions: (f32, f32, f32), }
501
502impl WfsArrayBuilder {
503 pub fn new(geometry: ArrayGeometry) -> Self {
505 Self {
506 geometry,
507 speaker_count: 16,
508 dimensions: (3.0, 0.0, 0.0), }
510 }
511
512 pub fn speaker_count(mut self, count: usize) -> Self {
514 self.speaker_count = count;
515 self
516 }
517
518 pub fn dimensions(mut self, width: f32, height: f32, depth: f32) -> Self {
520 self.dimensions = (width, height, depth);
521 self
522 }
523
524 pub fn build_positions(self) -> Vec<Position3D> {
526 match self.geometry {
527 ArrayGeometry::Linear => self.build_linear_array(),
528 ArrayGeometry::Circular => self.build_circular_array(),
529 ArrayGeometry::Rectangular => self.build_rectangular_array(),
530 ArrayGeometry::Custom => vec![], }
532 }
533
534 fn build_linear_array(&self) -> Vec<Position3D> {
535 let spacing = self.dimensions.0 / (self.speaker_count - 1) as f32;
536 let start_x = -self.dimensions.0 / 2.0;
537
538 (0..self.speaker_count)
539 .map(|i| Position3D {
540 x: start_x + i as f32 * spacing,
541 y: 0.0,
542 z: 0.0,
543 })
544 .collect()
545 }
546
547 fn build_circular_array(&self) -> Vec<Position3D> {
548 let radius = self.dimensions.0 / 2.0;
549 let angle_step = 2.0 * PI / self.speaker_count as f32;
550
551 (0..self.speaker_count)
552 .map(|i| {
553 let angle = i as f32 * angle_step;
554 Position3D {
555 x: radius * angle.cos(),
556 y: radius * angle.sin(),
557 z: 0.0,
558 }
559 })
560 .collect()
561 }
562
563 fn build_rectangular_array(&self) -> Vec<Position3D> {
564 let cols = (self.speaker_count as f32).sqrt().ceil() as usize;
566 let rows = self.speaker_count.div_ceil(cols);
567
568 let x_spacing = self.dimensions.0 / (cols - 1) as f32;
569 let y_spacing = self.dimensions.1 / (rows - 1) as f32;
570
571 let start_x = -self.dimensions.0 / 2.0;
572 let start_y = -self.dimensions.1 / 2.0;
573
574 let mut positions = Vec::new();
575 for row in 0..rows {
576 for col in 0..cols {
577 if positions.len() < self.speaker_count {
578 positions.push(Position3D {
579 x: start_x + col as f32 * x_spacing,
580 y: start_y + row as f32 * y_spacing,
581 z: 0.0,
582 });
583 }
584 }
585 }
586 positions
587 }
588}
589
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a unit-gain point source on the x axis whose `distance` field
    /// equals its x coordinate.
    fn point_source_on_x_axis(id: &str, x: f32, audio_data: Array1<f32>) -> WfsSource {
        WfsSource {
            id: id.to_string(),
            position: Position3D { x, y: 0.0, z: 0.0 },
            audio_data,
            source_type: WfsSourceType::Point,
            gain: 1.0,
            distance: x,
        }
    }

    /// The default configuration describes a 16-speaker linear array.
    #[test]
    fn test_wfs_config_default() {
        let defaults = WfsConfig::default();
        assert_eq!(defaults.speaker_count, 16);
        assert_eq!(defaults.array_geometry, ArrayGeometry::Linear);
        assert_eq!(defaults.speaker_positions.len(), 16);
    }

    /// A processor can be built from the default configuration.
    #[test]
    fn test_wfs_processor_creation() {
        assert!(WfsProcessor::new(WfsConfig::default()).is_ok());
    }

    /// A linear array spans the requested width, centred on the origin.
    #[test]
    fn test_array_builder_linear() {
        let builder = WfsArrayBuilder::new(ArrayGeometry::Linear)
            .speaker_count(8)
            .dimensions(2.0, 0.0, 0.0);
        let line = builder.build_positions();

        assert_eq!(line.len(), 8);
        assert_eq!(line[0].x, -1.0);
        assert_eq!(line[7].x, 1.0);
    }

    /// The first speaker of a circular array sits on the positive x axis at
    /// half the width.
    #[test]
    fn test_array_builder_circular() {
        let builder = WfsArrayBuilder::new(ArrayGeometry::Circular)
            .speaker_count(4)
            .dimensions(2.0, 0.0, 0.0);
        let ring = builder.build_positions();

        assert_eq!(ring.len(), 4);
        let first = &ring[0];
        assert!((first.x - 1.0).abs() < 0.001);
        assert!(first.y.abs() < 0.001);
    }

    /// A source keeps the identifier and type it was constructed with.
    #[test]
    fn test_wfs_source_creation() {
        let source = point_source_on_x_axis("test_source", 1.0, Array1::zeros(1024));

        assert_eq!(source.id, "test_source");
        assert_eq!(source.source_type, WfsSourceType::Point);
    }

    /// Rendering yields one row per speaker and one column per input sample.
    #[test]
    fn test_processor_source_processing() {
        let mut processor = WfsProcessor::new(WfsConfig::default()).unwrap();
        let source = point_source_on_x_axis("test", 2.0, Array1::ones(512));

        let rendered = processor.process_source(&source);
        assert!(rendered.is_ok());

        let channels = rendered.unwrap();
        assert_eq!(channels.nrows(), 16);
        assert_eq!(channels.ncols(), 512);
    }
}