1use crate::types::Position3D;
9use crate::{Error, Result};
10use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
11use scirs2_core::Complex32;
12use scirs2_fft::{irfft, rfft, FftPlanner, RealFftPlanner, RealToComplex};
13use serde::{Deserialize, Serialize};
14use std::f32::consts::PI;
15use std::sync::Arc;
16
/// Beamforming algorithm selector.
///
/// Every variant starts from delay-and-sum weights (see
/// `BeamformingProcessor::compute_initial_weights`). When
/// `AdaptationConfig::enabled` is set, `Mvdr`, `Capon` and `Frost` receive a
/// per-block adaptive weight update; `DelayAndSum`, `Gsc` and `Music` have no
/// adaptive update in this module and keep their initial weights.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BeamformingAlgorithm {
    /// Fixed delay-and-sum beamformer.
    DelayAndSum,
    /// Minimum-variance distortionless-response beamformer.
    Mvdr,
    /// Generalized sidelobe canceller (no adaptive update implemented here).
    Gsc,
    /// MUSIC (no adaptive update implemented here).
    Music,
    /// Capon beamformer; shares the MVDR weight update in this module.
    Capon,
    /// Frost constrained adaptive beamformer.
    Frost,
}
33
/// Configuration for a [`BeamformingProcessor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BeamformingConfig {
    /// Sample rate of every input channel, in Hz.
    pub sample_rate: f32,
    /// Number of array elements. Must equal `element_positions.len()`
    /// (checked by `BeamformingProcessor::new`) and the number of input rows.
    pub array_size: usize,
    /// Position of each element.
    /// NOTE(review): units presumably metres, since positions are combined
    /// with a 343.0 speed of sound — confirm.
    pub element_positions: Vec<Position3D>,
    /// Algorithm used for weight adaptation.
    pub algorithm: BeamformingAlgorithm,
    /// Look direction as `(azimuth, elevation)` in radians.
    pub target_direction: (f32, f32),
    /// Array aperture; not read by `BeamformingProcessor`.
    pub array_aperture: f32,
    /// `(low, high)` band in Hz scanned by the DOA spatial spectrum.
    pub frequency_range: (f32, f32),
    /// Adaptive weight-update settings.
    pub adaptation: AdaptationConfig,
    /// Spatial-smoothing settings; not read by `BeamformingProcessor`.
    pub spatial_smoothing: SpatialSmoothingConfig,
}
56
/// Settings for the adaptive weight update that `process` runs once per block.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptationConfig {
    /// Run the algorithm-specific update for every processed block when `true`.
    pub enabled: bool,
    /// Step size of the Frost gradient update.
    pub step_size: f32,
    /// Exponential forgetting factor of the MVDR/Capon covariance estimate.
    pub forgetting_factor: f32,
    /// Diagonal loading used by the MVDR/Capon update.
    pub regularization: f32,
    /// Snapshot count; not read by `BeamformingProcessor`.
    pub snapshots: usize,
}
71
/// Spatial-smoothing settings.
///
/// NOTE(review): not read by `BeamformingProcessor`; the field descriptions
/// below are inferred from their names only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpatialSmoothingConfig {
    /// Whether smoothing is requested.
    pub enabled: bool,
    /// Presumably the number of subarrays to average — confirm.
    pub subarrays: usize,
    /// Presumably the element overlap between subarrays — confirm.
    pub overlap: usize,
}
82
/// Per-frequency complex beamformer weights.
#[derive(Debug, Clone)]
pub struct BeamformerWeights {
    /// Weights indexed `[frequency_bin, element]`; applied per bin as
    /// `y = sum_i x_i * conj(w_i)`.
    pub weights: Array2<Complex32>,
    /// Frequency in Hz associated with each row of `weights`.
    pub frequencies: Array1<f32>,
    /// `(azimuth, elevation)` in radians the weights were initialised for.
    pub target_direction: (f32, f32),
}
93
/// Horizontal-plane (elevation 0) beam pattern from `compute_beam_pattern`.
#[derive(Debug, Clone)]
pub struct BeamPattern {
    /// Azimuth of each sample, in radians.
    pub angles: Array1<f32>,
    /// Array response at each azimuth, in dB (`20 * log10|gain|`).
    pub response: Array1<f32>,
    /// Width of the -3 dB main lobe, in radians.
    pub main_lobe_width: f32,
    /// Highest side-lobe level relative to the peak, in dB (<= 0).
    pub side_lobe_level: f32,
    /// Directivity index in dB.
    pub directivity_index: f32,
}
108
/// Direction-of-arrival estimate from `estimate_doa`.
#[derive(Debug, Clone)]
pub struct DoaResult {
    /// `(azimuth, elevation)` in radians of every scan-grid point whose power
    /// exceeds `threshold`.
    pub directions: Vec<(f32, f32)>,
    /// Power divided by `threshold` for each entry of `directions`.
    pub confidence: Vec<f32>,
    /// Spatial power spectrum indexed `[azimuth_step, elevation_step]`.
    pub spectrum: Array2<f32>,
    /// Detection threshold (70 % of the spectrum maximum).
    pub threshold: f32,
}
121
/// Frequency-domain beamformer for a multi-element array.
pub struct BeamformingProcessor {
    /// Validated configuration.
    config: BeamformingConfig,
    /// Current per-bin weights.
    weights: BeamformerWeights,
    /// Planner that created the transforms below; kept but not used afterwards.
    fft_planner: Arc<RealFftPlanner<f32>>,
    /// Real-to-complex transform applied to each channel of a block.
    forward_fft: Arc<dyn RealToComplex<f32>>,
    /// Complex-to-real transform applied to the beamformer output spectrum.
    inverse_fft: Arc<dyn scirs2_fft::ComplexToReal<f32>>,
    /// Running covariance estimate indexed `[bin, element, element]`.
    covariance_matrix: Array3<Complex32>,
    /// Spectrum of the latest block indexed `[element, bin]`.
    input_buffer: Array2<Complex32>,
    /// Beamformer output spectrum of the latest block, one value per bin.
    output_buffer: Array1<Complex32>,
    /// Speed of sound (set to 343.0 in `new`; presumably m/s).
    speed_of_sound: f32,
    /// Adaptive-algorithm state; only initialised, never read.
    adaptation_state: AdaptationState,
}
145
/// Auxiliary state for adaptive algorithms.
///
/// NOTE(review): all fields are zero-initialised in `BeamformingProcessor::new`
/// and never read or updated afterwards.
#[derive(Debug)]
struct AdaptationState {
    /// Number of snapshots seen (always 0 at present).
    snapshot_count: usize,
    /// Shape `[bin, element, element]`.
    covariance_inverse: Array3<Complex32>,
    /// Shape `[element, bin]`.
    constraint_vectors: Array2<Complex32>,
}
156
157impl Default for BeamformingConfig {
158 fn default() -> Self {
159 let array_size = 8;
161 let element_spacing = 0.05; let element_positions: Vec<Position3D> = (0..array_size)
163 .map(|i| Position3D {
164 x: (i as f32 - array_size as f32 / 2.0) * element_spacing,
165 y: 0.0,
166 z: 0.0,
167 })
168 .collect();
169
170 Self {
171 sample_rate: 48000.0,
172 array_size,
173 element_positions,
174 algorithm: BeamformingAlgorithm::DelayAndSum,
175 target_direction: (0.0, 0.0), array_aperture: (array_size - 1) as f32 * element_spacing,
177 frequency_range: (200.0, 8000.0),
178 adaptation: AdaptationConfig {
179 enabled: false,
180 step_size: 0.01,
181 forgetting_factor: 0.99,
182 regularization: 0.01,
183 snapshots: 100,
184 },
185 spatial_smoothing: SpatialSmoothingConfig {
186 enabled: false,
187 subarrays: 4,
188 overlap: 2,
189 },
190 }
191 }
192}
193
194impl BeamformingProcessor {
195 pub fn new(config: BeamformingConfig) -> Result<Self> {
197 if config.array_size == 0 {
198 return Err(Error::LegacyConfig(
199 "Array size must be greater than 0".to_string(),
200 ));
201 }
202
203 if config.element_positions.len() != config.array_size {
204 return Err(Error::LegacyConfig(
205 "Number of element positions must match array size".to_string(),
206 ));
207 }
208
209 let mut planner = RealFftPlanner::<f32>::new();
210 let buffer_size = 1024; let frequency_bins = buffer_size / 2 + 1;
212
213 let forward_fft = planner.plan_fft_forward(buffer_size);
214 let inverse_fft = planner.plan_fft_inverse(buffer_size);
215
216 let weights = Self::compute_initial_weights(&config, frequency_bins)?;
218
219 let adaptation_state = AdaptationState {
221 snapshot_count: 0,
222 covariance_inverse: Array3::zeros((
223 frequency_bins,
224 config.array_size,
225 config.array_size,
226 )),
227 constraint_vectors: Array2::zeros((config.array_size, frequency_bins)),
228 };
229
230 let covariance_matrix =
231 Array3::zeros((frequency_bins, config.array_size, config.array_size));
232 let input_buffer = Array2::zeros((config.array_size, frequency_bins));
233 let output_buffer = Array1::zeros(frequency_bins);
234
235 Ok(Self {
236 config,
237 weights,
238 fft_planner: Arc::new(planner),
239 forward_fft,
240 inverse_fft,
241 covariance_matrix,
242 input_buffer,
243 output_buffer,
244 speed_of_sound: 343.0,
245 adaptation_state,
246 })
247 }
248
249 pub fn process(&mut self, input: &Array2<f32>) -> Result<Array1<f32>> {
251 if input.nrows() != self.config.array_size {
252 return Err(Error::LegacyProcessing(format!(
253 "Input must have {} channels, got {}",
254 self.config.array_size,
255 input.nrows()
256 )));
257 }
258
259 let frame_size = input.ncols();
260 let mut output = Array1::zeros(frame_size);
261
262 let block_size = 512;
264 let overlap = block_size / 2;
265 let hop_size = block_size - overlap;
266
267 for block_start in (0..frame_size).step_by(hop_size) {
268 let block_end = (block_start + block_size).min(frame_size);
269 let current_block_size = block_end - block_start;
270
271 if current_block_size < block_size / 2 {
272 break; }
274
275 let mut block = Array2::zeros((self.config.array_size, block_size));
277 for (ch_idx, mut block_row) in block.rows_mut().into_iter().enumerate() {
278 let input_row = input.row(ch_idx);
279 let input_slice = input_row.slice(s![block_start..block_end]);
280 block_row
281 .slice_mut(s![..current_block_size])
282 .assign(&input_slice);
283 }
284
285 let block_output = self.process_block(&block)?;
287
288 let output_start = block_start;
290 let output_end = (output_start + block_output.len()).min(frame_size);
291 let copy_length = output_end - output_start;
292
293 for (i, &value) in block_output.slice(s![..copy_length]).iter().enumerate() {
294 if output_start + i < output.len() {
295 output[output_start + i] += value;
296 }
297 }
298 }
299
300 Ok(output)
301 }
302
303 fn process_block(&mut self, block: &Array2<f32>) -> Result<Array1<f32>> {
305 self.transform_to_frequency_domain(block)?;
307
308 if self.config.adaptation.enabled {
310 self.update_adaptation()?;
311 }
312
313 self.apply_beamforming_weights()?;
315
316 let output = self.transform_to_time_domain()?;
318
319 Ok(output)
320 }
321
322 fn transform_to_frequency_domain(&mut self, input: &Array2<f32>) -> Result<()> {
324 let frequency_bins = self.input_buffer.ncols();
325
326 for (ch_idx, input_row) in input.rows().into_iter().enumerate() {
327 let mut padded_input = input_row.to_vec();
328 padded_input.resize(frequency_bins * 2 - 2, 0.0);
329
330 let mut spectrum = vec![Complex32::new(0.0, 0.0); frequency_bins];
331 self.forward_fft.process(&padded_input, &mut spectrum);
332
333 for (freq_idx, &spectrum_value) in spectrum.iter().enumerate() {
334 self.input_buffer[[ch_idx, freq_idx]] = spectrum_value;
335 }
336 }
337
338 Ok(())
339 }
340
341 fn apply_beamforming_weights(&mut self) -> Result<()> {
343 let frequency_bins = self.output_buffer.len();
344
345 for freq_idx in 0..frequency_bins {
346 let mut output_value = Complex32::new(0.0, 0.0);
347
348 for ch_idx in 0..self.config.array_size {
349 let input_value = self.input_buffer[[ch_idx, freq_idx]];
350 let weight = self.weights.weights[[freq_idx, ch_idx]];
351 output_value += input_value * weight.conj(); }
353
354 self.output_buffer[freq_idx] = output_value;
355 }
356
357 Ok(())
358 }
359
360 fn transform_to_time_domain(&mut self) -> Result<Array1<f32>> {
362 let buffer_size = (self.output_buffer.len() - 1) * 2;
363 let mut spectrum = self.output_buffer.to_vec();
364 let mut output = vec![0.0; buffer_size];
365
366 self.inverse_fft.process(&spectrum, &mut output);
367
368 Ok(Array1::from_vec(output))
369 }
370
371 fn update_adaptation(&mut self) -> Result<()> {
373 match self.config.algorithm {
374 BeamformingAlgorithm::Mvdr => self.update_mvdr_weights(),
375 BeamformingAlgorithm::Capon => self.update_capon_weights(),
376 BeamformingAlgorithm::Frost => self.update_frost_weights(),
377 _ => Ok(()), }
379 }
380
381 fn update_mvdr_weights(&mut self) -> Result<()> {
383 let frequency_bins = self.input_buffer.ncols();
384 let array_size = self.config.array_size;
385
386 for freq_idx in 0..frequency_bins {
387 let input_vector = self.input_buffer.column(freq_idx);
389
390 let forgetting = self.config.adaptation.forgetting_factor;
392 for i in 0..array_size {
393 for j in 0..array_size {
394 let new_value = forgetting * self.covariance_matrix[[freq_idx, i, j]]
395 + (1.0 - forgetting) * input_vector[i] * input_vector[j].conj();
396 self.covariance_matrix[[freq_idx, i, j]] = new_value;
397 }
398 }
399
400 let frequency =
402 freq_idx as f32 * self.config.sample_rate / (2.0 * frequency_bins as f32);
403 let steering_vector =
404 self.compute_steering_vector(frequency, self.config.target_direction);
405
406 let regularization = Complex32::new(self.config.adaptation.regularization, 0.0);
408
409 for i in 0..array_size {
411 self.covariance_matrix[[freq_idx, i, i]] += regularization;
412 }
413
414 for ch_idx in 0..array_size {
416 let weight = steering_vector[ch_idx] / (array_size as f32);
417 self.weights.weights[[freq_idx, ch_idx]] = weight;
418 }
419 }
420
421 Ok(())
422 }
423
424 fn update_capon_weights(&mut self) -> Result<()> {
426 self.update_mvdr_weights()
428 }
429
430 fn update_frost_weights(&mut self) -> Result<()> {
432 let frequency_bins = self.input_buffer.ncols();
433 let step_size = self.config.adaptation.step_size;
434
435 for freq_idx in 0..frequency_bins {
436 let input_vector = self.input_buffer.column(freq_idx);
437 let current_output = self.output_buffer[freq_idx];
438
439 for ch_idx in 0..self.config.array_size {
441 let gradient = input_vector[ch_idx] * current_output.conj();
442 let weight_update = Complex32::new(step_size, 0.0) * gradient;
443 self.weights.weights[[freq_idx, ch_idx]] -= weight_update;
444 }
445 }
446
447 Ok(())
448 }
449
450 fn compute_steering_vector(&self, frequency: f32, direction: (f32, f32)) -> Array1<Complex32> {
452 let (azimuth, elevation) = direction;
453 let wave_number = 2.0 * PI * frequency / self.speed_of_sound;
454
455 let direction_vector = Position3D {
456 x: azimuth.cos() * elevation.cos(),
457 y: azimuth.sin() * elevation.cos(),
458 z: elevation.sin(),
459 };
460
461 let mut steering_vector = Array1::zeros(self.config.array_size);
462
463 for (idx, element_pos) in self.config.element_positions.iter().enumerate() {
464 let delay = element_pos.dot(&direction_vector) / self.speed_of_sound;
466 let phase = wave_number * delay * self.speed_of_sound / frequency;
467 steering_vector[idx] = Complex32::from_polar(1.0, -phase);
468 }
469
470 steering_vector
471 }
472
473 fn compute_initial_weights(
475 config: &BeamformingConfig,
476 frequency_bins: usize,
477 ) -> Result<BeamformerWeights> {
478 let mut weights = Array2::zeros((frequency_bins, config.array_size));
479 let frequencies = Array1::from_shape_fn(frequency_bins, |i| {
480 i as f32 * config.sample_rate / (2.0 * frequency_bins as f32)
481 });
482
483 match config.algorithm {
484 BeamformingAlgorithm::DelayAndSum => {
485 for (freq_idx, &frequency) in frequencies.iter().enumerate() {
487 let steering_vector = Self::compute_delay_and_sum_weights(config, frequency);
488 weights.row_mut(freq_idx).assign(&steering_vector);
489 }
490 }
491 _ => {
492 for (freq_idx, &frequency) in frequencies.iter().enumerate() {
494 let steering_vector = Self::compute_delay_and_sum_weights(config, frequency);
495 weights.row_mut(freq_idx).assign(&steering_vector);
496 }
497 }
498 }
499
500 Ok(BeamformerWeights {
501 weights,
502 frequencies,
503 target_direction: config.target_direction,
504 })
505 }
506
507 fn compute_delay_and_sum_weights(
509 config: &BeamformingConfig,
510 frequency: f32,
511 ) -> Array1<Complex32> {
512 let (azimuth, elevation) = config.target_direction;
513 let wave_number = 2.0 * PI * frequency / 343.0; let direction_vector = Position3D {
516 x: azimuth.cos() * elevation.cos(),
517 y: azimuth.sin() * elevation.cos(),
518 z: elevation.sin(),
519 };
520
521 let mut weights = Array1::zeros(config.array_size);
522
523 for (idx, element_pos) in config.element_positions.iter().enumerate() {
524 let delay = element_pos.dot(&direction_vector) / 343.0;
525 let phase = wave_number * delay;
526 weights[idx] = Complex32::from_polar(1.0 / config.array_size as f32, -phase);
527 }
528
529 weights
530 }
531
532 pub fn estimate_doa(&mut self, input: &Array2<f32>) -> Result<DoaResult> {
534 self.transform_to_frequency_domain(input)?;
536
537 let azimuth_resolution = 1.0; let elevation_resolution = 5.0; let azimuth_range = (-180.0_f32)..180.0_f32;
541 let elevation_range = (-90.0_f32)..90.0_f32;
542
543 let azimuth_steps =
544 ((azimuth_range.end - azimuth_range.start) / azimuth_resolution) as usize;
545 let elevation_steps =
546 ((elevation_range.end - elevation_range.start) / elevation_resolution) as usize;
547
548 let mut spectrum = Array2::zeros((azimuth_steps, elevation_steps));
549 let mut angles = Array1::zeros(azimuth_steps);
550
551 for (az_idx, azimuth_deg) in (0..azimuth_steps).enumerate() {
553 let azimuth = azimuth_range.start + az_idx as f32 * azimuth_resolution;
554 let azimuth_rad = azimuth.to_radians();
555 angles[az_idx] = azimuth_rad;
556
557 for (el_idx, elevation_deg) in (0..elevation_steps).enumerate() {
558 let elevation = elevation_range.start + el_idx as f32 * elevation_resolution;
559 let elevation_rad = elevation.to_radians();
560
561 let power = self.compute_spatial_spectrum_value((azimuth_rad, elevation_rad))?;
563 spectrum[[az_idx, el_idx]] = power;
564 }
565 }
566
567 let threshold = spectrum.fold(0.0f32, |a, &b| a.max(b)) * 0.7; let mut directions = Vec::new();
570 let mut confidence = Vec::new();
571
572 for (az_idx, el_idx) in scirs2_core::ndarray::indices_of(&spectrum) {
573 if spectrum[[az_idx, el_idx]] > threshold {
574 let azimuth = angles[az_idx];
575 let elevation = elevation_range.start + el_idx as f32 * elevation_resolution;
576 directions.push((azimuth, elevation.to_radians()));
577 confidence.push(spectrum[[az_idx, el_idx]] / threshold);
578 }
579 }
580
581 Ok(DoaResult {
582 directions,
583 confidence,
584 spectrum,
585 threshold,
586 })
587 }
588
589 fn compute_spatial_spectrum_value(&self, direction: (f32, f32)) -> Result<f32> {
591 let frequency_bins = self.input_buffer.ncols();
592 let mut total_power = 0.0;
593
594 for freq_idx in 0..frequency_bins {
595 let frequency =
596 freq_idx as f32 * self.config.sample_rate / (2.0 * frequency_bins as f32);
597
598 if frequency < self.config.frequency_range.0
599 || frequency > self.config.frequency_range.1
600 {
601 continue;
602 }
603
604 let steering_vector = self.compute_steering_vector(frequency, direction);
605 let input_vector = self.input_buffer.column(freq_idx);
606
607 let mut power = Complex32::new(0.0, 0.0);
609 for i in 0..self.config.array_size {
610 power += input_vector[i] * steering_vector[i].conj();
611 }
612
613 total_power += power.norm_sqr();
614 }
615
616 Ok(total_power)
617 }
618
619 pub fn compute_beam_pattern(&self, frequency: f32) -> Result<BeamPattern> {
621 let angle_resolution = 1.0; let angle_range = -180.0..180.0;
623 let angle_steps = ((angle_range.end - angle_range.start) / angle_resolution) as usize;
624
625 let mut angles = Array1::zeros(angle_steps);
626 let mut response = Array1::zeros(angle_steps);
627
628 for (idx, angle_deg) in (0..angle_steps).enumerate() {
629 let angle = angle_range.start + idx as f32 * angle_resolution;
630 let angle_rad = angle.to_radians();
631 angles[idx] = angle_rad;
632
633 let direction = (angle_rad, 0.0); let steering_vector = self.compute_steering_vector(frequency, direction);
636
637 let freq_idx = (frequency * 2.0 * self.weights.frequencies.len() as f32
639 / self.config.sample_rate) as usize;
640 let freq_idx = freq_idx.min(self.weights.frequencies.len() - 1);
641
642 let weights_row = self.weights.weights.row(freq_idx);
643
644 let mut pattern_value = Complex32::new(0.0, 0.0);
645 for i in 0..self.config.array_size {
646 pattern_value += weights_row[i] * steering_vector[i];
647 }
648
649 response[idx] = 20.0 * pattern_value.norm().log10(); }
651
652 let max_response = response.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
654 let main_lobe_width = self.compute_main_lobe_width(&angles, &response, max_response);
655 let side_lobe_level = self.compute_side_lobe_level(&response, max_response);
656 let directivity_index = self.compute_directivity_index(&response);
657
658 Ok(BeamPattern {
659 angles,
660 response,
661 main_lobe_width,
662 side_lobe_level,
663 directivity_index,
664 })
665 }
666
667 fn compute_main_lobe_width(
669 &self,
670 angles: &Array1<f32>,
671 response: &Array1<f32>,
672 max_response: f32,
673 ) -> f32 {
674 let threshold = max_response - 3.0; let max_idx = response
678 .iter()
679 .position(|&x| x == max_response)
680 .unwrap_or(0);
681
682 let mut left_idx = max_idx;
683 while left_idx > 0 && response[left_idx] > threshold {
684 left_idx -= 1;
685 }
686
687 let mut right_idx = max_idx;
688 while right_idx < response.len() - 1 && response[right_idx] > threshold {
689 right_idx += 1;
690 }
691
692 if right_idx > left_idx {
693 angles[right_idx] - angles[left_idx]
694 } else {
695 0.0
696 }
697 }
698
699 fn compute_side_lobe_level(&self, response: &Array1<f32>, max_response: f32) -> f32 {
701 let center = response.len() / 2;
703 let quarter = response.len() / 4;
704
705 let left_side_max = response
706 .slice(s![..center - quarter])
707 .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
708 let right_side_max = response
709 .slice(s![center + quarter..])
710 .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
711
712 left_side_max.max(right_side_max) - max_response
713 }
714
715 fn compute_directivity_index(&self, response: &Array1<f32>) -> f32 {
717 let max_response = response.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
718
719 let average_response = response.sum() / response.len() as f32;
721
722 max_response - average_response
723 }
724
725 pub fn set_target_direction(&mut self, azimuth: f32, elevation: f32) {
727 self.config.target_direction = (azimuth, elevation);
728 if let Ok(new_weights) =
730 Self::compute_initial_weights(&self.config, self.weights.frequencies.len())
731 {
732 self.weights = new_weights;
733 }
734 }
735
/// Returns the processor's configuration, including any target direction
/// set through `set_target_direction`.
pub fn config(&self) -> &BeamformingConfig {
    &self.config
}

/// Returns the current per-frequency beamformer weights.
pub fn weights(&self) -> &BeamformerWeights {
    &self.weights
}
745}
746
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_beamforming_config_default() {
        let config = BeamformingConfig::default();
        assert_eq!(config.array_size, 8);
        assert_eq!(config.algorithm, BeamformingAlgorithm::DelayAndSum);
        assert_eq!(config.element_positions.len(), 8);
    }

    #[test]
    fn test_beamforming_processor_creation() {
        let config = BeamformingConfig::default();
        let processor = BeamformingProcessor::new(config);
        assert!(processor.is_ok());
    }

    #[test]
    fn test_rejects_zero_array_size() {
        let config = BeamformingConfig {
            array_size: 0,
            element_positions: Vec::new(),
            ..BeamformingConfig::default()
        };
        assert!(BeamformingProcessor::new(config).is_err());
    }

    #[test]
    fn test_rejects_position_count_mismatch() {
        let mut config = BeamformingConfig::default();
        config.element_positions.pop();
        assert!(BeamformingProcessor::new(config).is_err());
    }

    #[test]
    fn test_steering_vector_computation() {
        let config = BeamformingConfig::default();
        let processor = BeamformingProcessor::new(config).unwrap();

        let steering_vector = processor.compute_steering_vector(1000.0, (0.0, 0.0));
        assert_eq!(steering_vector.len(), 8);

        // A steering vector only rotates phase: every element has the same magnitude.
        let magnitude = steering_vector[0].norm();
        for &element in steering_vector.iter() {
            assert!((element.norm() - magnitude).abs() < 0.001);
        }
    }

    #[test]
    fn test_beam_pattern_computation() {
        let config = BeamformingConfig::default();
        let processor = BeamformingProcessor::new(config).unwrap();

        let pattern = processor.compute_beam_pattern(1000.0);
        assert!(pattern.is_ok());

        let pattern = pattern.unwrap();
        assert!(!pattern.angles.is_empty());
        assert_eq!(pattern.angles.len(), pattern.response.len());
        assert!(pattern.main_lobe_width > 0.0);
    }

    #[test]
    fn test_delay_and_sum_weights() {
        let config = BeamformingConfig::default();
        let weights = BeamformingProcessor::compute_delay_and_sum_weights(&config, 1000.0);

        assert_eq!(weights.len(), 8);

        // Delay-and-sum weights all have magnitude 1/N.
        let expected_magnitude = 1.0 / 8.0;
        for &weight in weights.iter() {
            assert!((weight.norm() - expected_magnitude).abs() < 0.001);
        }
    }

    #[test]
    fn test_input_processing() {
        let config = BeamformingConfig::default();
        let mut processor = BeamformingProcessor::new(config).unwrap();

        let input = Array2::ones((8, 1024));

        let result = processor.process(&input);
        assert!(result.is_ok());

        let output = result.unwrap();
        assert_eq!(output.len(), 1024);
    }

    #[test]
    fn test_process_rejects_wrong_channel_count() {
        let mut processor = BeamformingProcessor::new(BeamformingConfig::default()).unwrap();
        let input = Array2::<f32>::zeros((4, 512));
        assert!(processor.process(&input).is_err());
    }

    #[test]
    fn test_process_empty_frame() {
        let mut processor = BeamformingProcessor::new(BeamformingConfig::default()).unwrap();
        let input = Array2::<f32>::zeros((8, 0));
        let output = processor.process(&input).unwrap();
        assert!(output.is_empty());
    }

    #[test]
    fn test_target_direction_update() {
        let config = BeamformingConfig::default();
        let mut processor = BeamformingProcessor::new(config).unwrap();

        let original_direction = processor.config.target_direction;
        processor.set_target_direction(PI / 4.0, 0.0);

        assert_ne!(processor.config.target_direction, original_direction);
        assert_eq!(processor.config.target_direction, (PI / 4.0, 0.0));
    }
}