Skip to main content

voirs_spatial/
plugins.rs

1//! Plugin System for VoiRS Spatial Audio
2//!
3//! This module provides an extensible plugin architecture for the spatial audio system,
4//! allowing third-party developers to create custom spatial processing effects,
5//! HRTF implementations, room simulation algorithms, and audio processing pipelines.
6
7use crate::types::{AudioChannel, BinauraAudio, Position3D};
8use crate::{Error, Result};
9use async_trait::async_trait;
10use serde::{Deserialize, Serialize};
11use std::any::Any;
12use std::collections::HashMap;
13use std::fmt::Debug;
14use std::sync::{Arc, RwLock};
15
16/// Plugin trait that all spatial audio plugins must implement
17#[async_trait]
18pub trait SpatialPlugin: Send + Sync + Debug {
19    /// Get the plugin name
20    fn name(&self) -> &str;
21
22    /// Get the plugin version
23    fn version(&self) -> &str;
24
25    /// Get the plugin description
26    fn description(&self) -> &str;
27
28    /// Get the plugin author
29    fn author(&self) -> &str;
30
31    /// Get the plugin capabilities
32    fn capabilities(&self) -> PluginCapabilities;
33
34    /// Initialize the plugin
35    async fn initialize(&mut self, config: PluginConfig) -> Result<()>;
36
37    /// Process audio with spatial effects
38    async fn process_audio(
39        &self,
40        audio: &[f32],
41        listener_position: Position3D,
42        source_position: Position3D,
43        context: &ProcessingContext,
44    ) -> Result<Vec<f32>>;
45
46    /// Process binaural audio
47    async fn process_binaural(
48        &self,
49        audio: &BinauraAudio,
50        context: &ProcessingContext,
51    ) -> Result<BinauraAudio> {
52        // Default implementation just returns the input unchanged
53        Ok(audio.clone())
54    }
55
56    /// Update plugin parameters
57    async fn update_parameters(&mut self, parameters: PluginParameters) -> Result<()>;
58
59    /// Get current plugin state
60    fn get_state(&self) -> PluginState;
61
62    /// Cleanup plugin resources
63    async fn cleanup(&mut self) -> Result<()>;
64
65    /// Cast to Any for downcasting
66    fn as_any(&self) -> &dyn Any;
67
68    /// Cast to mutable Any for downcasting
69    fn as_any_mut(&mut self) -> &mut dyn Any;
70}
71
/// Capability flags a plugin advertises through [`SpatialPlugin::capabilities`].
///
/// Each feature is a separate `bool` field. Despite the "flags" wording, this is
/// a plain struct, not a packed bitmask.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PluginCapabilities {
    /// Can process mono audio
    pub supports_mono: bool,
    /// Can process stereo audio
    pub supports_stereo: bool,
    /// Can process multi-channel audio
    pub supports_multichannel: bool,
    /// Can process binaural audio (see [`SpatialPlugin::process_binaural`])
    pub supports_binaural: bool,
    /// Can process real-time streams
    pub supports_realtime: bool,
    /// Can process batch audio
    pub supports_batch: bool,
    /// Has configurable parameters (see [`SpatialPlugin::update_parameters`])
    pub has_parameters: bool,
    /// Supports state serialization
    pub supports_serialization: bool,
    /// Requires GPU acceleration to run
    pub requires_gpu: bool,
    /// Supports 3D positioning of sources relative to the listener
    pub supports_3d_positioning: bool,
    /// Supports HRTF processing
    pub supports_hrtf: bool,
    /// Supports room simulation
    pub supports_room_simulation: bool,
}
100
101impl Default for PluginCapabilities {
102    fn default() -> Self {
103        Self {
104            supports_mono: true,
105            supports_stereo: true,
106            supports_multichannel: false,
107            supports_binaural: false,
108            supports_realtime: true,
109            supports_batch: true,
110            has_parameters: false,
111            supports_serialization: false,
112            requires_gpu: false,
113            supports_3d_positioning: false,
114            supports_hrtf: false,
115            supports_room_simulation: false,
116        }
117    }
118}
119
/// Plugin configuration.
///
/// Passed to [`SpatialPlugin::initialize`]; [`PluginManager::register_plugin`]
/// also stores one per registered plugin.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginConfig {
    /// Plugin-specific configuration parameters, keyed by parameter name
    /// (e.g. `room_size` for [`ReverbPlugin`])
    pub parameters: HashMap<String, PluginParameter>,
    /// Sample rate for audio processing (NOTE(review): presumably Hz — confirm)
    pub sample_rate: f32,
    /// Buffer size for audio processing
    pub buffer_size: usize,
    /// Number of audio channels
    pub channels: usize,
    /// Enable GPU acceleration if available
    pub use_gpu: bool,
    /// Real-time processing mode
    pub realtime_mode: bool,
    /// Quality level (0.0 = lowest, 1.0 = highest)
    pub quality_level: f32,
}
138
139impl Default for PluginConfig {
140    fn default() -> Self {
141        Self {
142            parameters: HashMap::new(),
143            sample_rate: 44100.0,
144            buffer_size: 1024,
145            channels: 2,
146            use_gpu: false,
147            realtime_mode: true,
148            quality_level: 0.8,
149        }
150    }
151}
152
/// A single dynamically typed plugin parameter value.
///
/// Used as the value type of [`PluginConfig::parameters`],
/// [`PluginParameters::parameters`] and [`ProcessingContext::context_data`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PluginParameter {
    /// Boolean parameter
    Bool(bool),
    /// Integer parameter
    Int(i32),
    /// Float parameter
    Float(f32),
    /// String parameter
    String(String),
    /// Array of floats
    FloatArray(Vec<f32>),
    /// Nested parameters, keyed by name
    Object(HashMap<String, PluginParameter>),
}
169
/// Plugin parameters collection.
///
/// A name → value map passed to [`SpatialPlugin::update_parameters`]. The typed
/// `set_*` / `get_*` helpers avoid matching on [`PluginParameter`] by hand.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PluginParameters {
    /// Parameter map, keyed by parameter name
    pub parameters: HashMap<String, PluginParameter>,
}
176
177impl PluginParameters {
178    /// Create new empty parameters
179    pub fn new() -> Self {
180        Self::default()
181    }
182
183    /// Set a boolean parameter
184    pub fn set_bool(&mut self, key: &str, value: bool) {
185        self.parameters
186            .insert(key.to_string(), PluginParameter::Bool(value));
187    }
188
189    /// Set an integer parameter
190    pub fn set_int(&mut self, key: &str, value: i32) {
191        self.parameters
192            .insert(key.to_string(), PluginParameter::Int(value));
193    }
194
195    /// Set a float parameter
196    pub fn set_float(&mut self, key: &str, value: f32) {
197        self.parameters
198            .insert(key.to_string(), PluginParameter::Float(value));
199    }
200
201    /// Set a string parameter
202    pub fn set_string(&mut self, key: &str, value: String) {
203        self.parameters
204            .insert(key.to_string(), PluginParameter::String(value));
205    }
206
207    /// Get a boolean parameter
208    pub fn get_bool(&self, key: &str) -> Option<bool> {
209        match self.parameters.get(key)? {
210            PluginParameter::Bool(value) => Some(*value),
211            _ => None,
212        }
213    }
214
215    /// Get an integer parameter
216    pub fn get_int(&self, key: &str) -> Option<i32> {
217        match self.parameters.get(key)? {
218            PluginParameter::Int(value) => Some(*value),
219            _ => None,
220        }
221    }
222
223    /// Get a float parameter
224    pub fn get_float(&self, key: &str) -> Option<f32> {
225        match self.parameters.get(key)? {
226            PluginParameter::Float(value) => Some(*value),
227            _ => None,
228        }
229    }
230
231    /// Get a string parameter
232    pub fn get_string(&self, key: &str) -> Option<&str> {
233        match self.parameters.get(key)? {
234            PluginParameter::String(value) => Some(value),
235            _ => None,
236        }
237    }
238}
239
/// Processing context for plugins.
///
/// Passed by reference to every [`SpatialPlugin::process_audio`] and
/// [`SpatialPlugin::process_binaural`] call.
#[derive(Debug, Clone)]
pub struct ProcessingContext {
    /// Current sample rate
    pub sample_rate: f32,
    /// Current buffer size ([`ReverbPlugin`] derives its delay length from this)
    pub buffer_size: usize,
    /// Number of channels
    pub channels: usize,
    /// Processing timestamp (a monotonic `std::time::Instant`)
    pub timestamp: std::time::Instant,
    /// Quality level (0.0 = lowest, 1.0 = highest)
    pub quality_level: f32,
    /// Real-time processing mode
    pub realtime_mode: bool,
    /// Additional context data, keyed by name
    pub context_data: HashMap<String, PluginParameter>,
}
258
259impl Default for ProcessingContext {
260    fn default() -> Self {
261        Self {
262            sample_rate: 44100.0,
263            buffer_size: 1024,
264            channels: 2,
265            timestamp: std::time::Instant::now(),
266            quality_level: 0.8,
267            realtime_mode: true,
268            context_data: HashMap::new(),
269        }
270    }
271}
272
/// Plugin lifecycle state, as reported by [`SpatialPlugin::get_state`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PluginState {
    /// Plugin is uninitialized
    Uninitialized,
    /// Plugin is initialized and ready
    Ready,
    /// Plugin is currently processing
    Processing,
    /// Plugin is paused
    Paused,
    /// Plugin has an error (NOTE(review): payload is presumably a human-readable
    /// description — confirm)
    Error(String),
    /// Plugin is being cleaned up
    Cleanup,
}
289
/// Plugin manager for loading and managing spatial audio plugins.
///
/// Holds registered plugins, their configurations and named processing chains.
/// Each map sits behind an `Arc<std::sync::RwLock<…>>`, so all manager methods
/// take `&self`.
#[derive(Debug)]
pub struct PluginManager {
    /// Loaded plugins, keyed by the plugin's `name()` at registration time
    plugins: Arc<RwLock<HashMap<String, Box<dyn SpatialPlugin>>>>,
    /// Configuration supplied at registration, keyed by plugin name
    configs: Arc<RwLock<HashMap<String, PluginConfig>>>,
    /// Processing chains, keyed by chain name
    chains: Arc<RwLock<HashMap<String, ProcessingChain>>>,
}
300
301impl PluginManager {
302    /// Create a new plugin manager
303    pub fn new() -> Self {
304        Self {
305            plugins: Arc::new(RwLock::new(HashMap::new())),
306            configs: Arc::new(RwLock::new(HashMap::new())),
307            chains: Arc::new(RwLock::new(HashMap::new())),
308        }
309    }
310
311    /// Register a plugin
312    pub async fn register_plugin(
313        &self,
314        plugin: Box<dyn SpatialPlugin>,
315        config: PluginConfig,
316    ) -> Result<()> {
317        let name = plugin.name().to_string();
318
319        // Store plugin configuration
320        {
321            let mut configs = self
322                .configs
323                .write()
324                .map_err(|_| Error::LegacyAudio("Plugin config lock poisoned".to_string()))?;
325            configs.insert(name.clone(), config);
326        }
327
328        // Store plugin
329        {
330            let mut plugins = self
331                .plugins
332                .write()
333                .map_err(|_| Error::LegacyAudio("Plugin lock poisoned".to_string()))?;
334            plugins.insert(name, plugin);
335        }
336
337        Ok(())
338    }
339
340    /// Unregister a plugin
341    pub async fn unregister_plugin(&self, name: &str) -> Result<()> {
342        // Cleanup plugin
343        let plugin_to_cleanup = {
344            let mut plugins = self
345                .plugins
346                .write()
347                .map_err(|_| Error::LegacyAudio("Plugin lock poisoned".to_string()))?;
348            plugins.remove(name)
349        };
350
351        if let Some(mut plugin) = plugin_to_cleanup {
352            plugin.cleanup().await?;
353        }
354
355        // Remove configuration
356        {
357            let mut configs = self
358                .configs
359                .write()
360                .map_err(|_| Error::LegacyAudio("Plugin config lock poisoned".to_string()))?;
361            configs.remove(name);
362        }
363
364        Ok(())
365    }
366
367    /// Get plugin names
368    pub fn get_plugin_names(&self) -> Vec<String> {
369        match self.plugins.read() {
370            Ok(plugins) => plugins.keys().cloned().collect(),
371            Err(_) => {
372                tracing::warn!("Plugin lock poisoned, returning empty list");
373                Vec::new()
374            }
375        }
376    }
377
378    /// Check if plugin exists
379    pub fn has_plugin(&self, name: &str) -> bool {
380        match self.plugins.read() {
381            Ok(plugins) => plugins.contains_key(name),
382            Err(_) => {
383                tracing::warn!("Plugin lock poisoned, assuming plugin doesn't exist");
384                false
385            }
386        }
387    }
388
389    /// Process audio through a specific plugin
390    #[allow(clippy::await_holding_lock)]
391    pub async fn process_with_plugin(
392        &self,
393        plugin_name: &str,
394        audio: &[f32],
395        listener_position: Position3D,
396        source_position: Position3D,
397        context: &ProcessingContext,
398    ) -> Result<Vec<f32>> {
399        let plugins = self
400            .plugins
401            .read()
402            .map_err(|_| Error::LegacyAudio("Plugin lock poisoned".to_string()))?;
403        let plugin = plugins
404            .get(plugin_name)
405            .ok_or_else(|| Error::LegacyAudio(format!("Plugin {plugin_name} not found")))?;
406
407        plugin
408            .process_audio(audio, listener_position, source_position, context)
409            .await
410    }
411
412    /// Process audio through a processing chain
413    #[allow(clippy::await_holding_lock)]
414    pub async fn process_with_chain(
415        &self,
416        chain_name: &str,
417        audio: &[f32],
418        listener_position: Position3D,
419        source_position: Position3D,
420        context: &ProcessingContext,
421    ) -> Result<Vec<f32>> {
422        let chains = self
423            .chains
424            .read()
425            .map_err(|_| Error::LegacyAudio("Chain lock poisoned".to_string()))?;
426        let chain = chains.get(chain_name).ok_or_else(|| {
427            Error::LegacyAudio(format!("Processing chain {chain_name} not found"))
428        })?;
429
430        self.process_chain(chain, audio, listener_position, source_position, context)
431            .await
432    }
433
434    /// Create a processing chain
435    pub async fn create_chain(&self, name: &str, plugin_names: Vec<String>) -> Result<()> {
436        let chain = ProcessingChain {
437            name: name.to_string(),
438            plugins: plugin_names,
439            enabled: true,
440        };
441
442        let mut chains = self
443            .chains
444            .write()
445            .map_err(|_| Error::LegacyAudio("Chain lock poisoned".to_string()))?;
446        chains.insert(name.to_string(), chain);
447
448        Ok(())
449    }
450
451    /// Remove a processing chain
452    pub async fn remove_chain(&self, name: &str) -> Result<()> {
453        let mut chains = self
454            .chains
455            .write()
456            .map_err(|_| Error::LegacyAudio("Chain lock poisoned".to_string()))?;
457        chains.remove(name);
458        Ok(())
459    }
460
461    /// Process through a processing chain
462    async fn process_chain(
463        &self,
464        chain: &ProcessingChain,
465        mut audio: &[f32],
466        listener_position: Position3D,
467        source_position: Position3D,
468        context: &ProcessingContext,
469    ) -> Result<Vec<f32>> {
470        if !chain.enabled {
471            return Ok(audio.to_vec());
472        }
473
474        let mut result = audio.to_vec();
475
476        for plugin_name in &chain.plugins {
477            result = self
478                .process_with_plugin(
479                    plugin_name,
480                    &result,
481                    listener_position,
482                    source_position,
483                    context,
484                )
485                .await?;
486        }
487
488        Ok(result)
489    }
490
491    /// Get plugin capabilities
492    pub fn get_plugin_capabilities(&self, name: &str) -> Option<PluginCapabilities> {
493        match self.plugins.read() {
494            Ok(plugins) => plugins.get(name).map(|plugin| plugin.capabilities()),
495            Err(_) => {
496                tracing::warn!("Plugin lock poisoned, returning None for capabilities");
497                None
498            }
499        }
500    }
501
502    /// Update plugin parameters
503    #[allow(clippy::await_holding_lock)]
504    pub async fn update_plugin_parameters(
505        &self,
506        plugin_name: &str,
507        parameters: PluginParameters,
508    ) -> Result<()> {
509        let mut plugins = self
510            .plugins
511            .write()
512            .map_err(|_| Error::LegacyAudio("Plugin lock poisoned".to_string()))?;
513        let plugin = plugins
514            .get_mut(plugin_name)
515            .ok_or_else(|| Error::LegacyAudio(format!("Plugin {plugin_name} not found")))?;
516
517        plugin.update_parameters(parameters).await
518    }
519}
520
521impl Default for PluginManager {
522    fn default() -> Self {
523        Self::new()
524    }
525}
526
/// Processing chain definition.
///
/// When the chain is processed through [`PluginManager::process_with_chain`],
/// its plugins run in order and each plugin's output feeds the next plugin.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProcessingChain {
    /// Chain name
    pub name: String,
    /// Ordered list of plugin names (registry keys in the plugin manager)
    pub plugins: Vec<String>,
    /// Whether the chain is enabled; a disabled chain returns its input unchanged
    pub enabled: bool,
}
537
/// Example reverb plugin implementation.
///
/// A deliberately simple delay-based effect intended as a reference for plugin
/// authors. It is a placeholder, not a production-quality reverberator.
#[derive(Debug)]
pub struct ReverbPlugin {
    // Name returned by `SpatialPlugin::name` (also the manager's registry key)
    name: String,
    // Version returned by `SpatialPlugin::version`
    version: String,
    // Gain applied to the delayed signal; `update_parameters` clamps it to [0, 1]
    room_size: f32,
    // Clamped to [0, 1] by `update_parameters`.
    // NOTE(review): stored but not currently read by `process_audio`.
    damping: f32,
    // Gain of the wet (delayed) signal in the output mix
    wet_level: f32,
    // Gain of the dry (input) signal in the output mix
    dry_level: f32,
    // Lifecycle state reported by `SpatialPlugin::get_state`
    state: PluginState,
}
549
550impl Default for ReverbPlugin {
551    fn default() -> Self {
552        Self::new()
553    }
554}
555
556impl ReverbPlugin {
557    /// Create a new reverb plugin
558    pub fn new() -> Self {
559        Self {
560            name: "Spatial Reverb".to_string(),
561            version: "1.0.0".to_string(),
562            room_size: 0.5,
563            damping: 0.5,
564            wet_level: 0.3,
565            dry_level: 0.7,
566            state: PluginState::Uninitialized,
567        }
568    }
569}
570
571#[async_trait]
572impl SpatialPlugin for ReverbPlugin {
573    fn name(&self) -> &str {
574        &self.name
575    }
576    fn version(&self) -> &str {
577        &self.version
578    }
579    fn description(&self) -> &str {
580        "Spatial reverb effect for room simulation"
581    }
582    fn author(&self) -> &str {
583        "VoiRS Team"
584    }
585
586    fn capabilities(&self) -> PluginCapabilities {
587        PluginCapabilities {
588            supports_mono: true,
589            supports_stereo: true,
590            supports_multichannel: true,
591            supports_binaural: true,
592            supports_realtime: true,
593            supports_batch: true,
594            has_parameters: true,
595            supports_serialization: true,
596            requires_gpu: false,
597            supports_3d_positioning: true,
598            supports_hrtf: false,
599            supports_room_simulation: true,
600        }
601    }
602
603    async fn initialize(&mut self, config: PluginConfig) -> Result<()> {
604        // Initialize reverb parameters from config
605        if let Some(PluginParameter::Float(size)) = config.parameters.get("room_size") {
606            self.room_size = *size;
607        }
608        if let Some(PluginParameter::Float(damping)) = config.parameters.get("damping") {
609            self.damping = *damping;
610        }
611        if let Some(PluginParameter::Float(wet)) = config.parameters.get("wet_level") {
612            self.wet_level = *wet;
613        }
614        if let Some(PluginParameter::Float(dry)) = config.parameters.get("dry_level") {
615            self.dry_level = *dry;
616        }
617
618        self.state = PluginState::Ready;
619        Ok(())
620    }
621
622    async fn process_audio(
623        &self,
624        audio: &[f32],
625        listener_position: Position3D,
626        source_position: Position3D,
627        context: &ProcessingContext,
628    ) -> Result<Vec<f32>> {
629        if matches!(self.state, PluginState::Error(_)) {
630            return Err(Error::LegacyAudio("Plugin is in error state".to_string()));
631        }
632
633        // Calculate distance for reverb scaling
634        let distance = listener_position.distance_to(&source_position);
635        let reverb_scale = (distance / 10.0).min(1.0); // Scale reverb with distance
636
637        // Simple reverb simulation (placeholder for real implementation)
638        let mut output = Vec::with_capacity(audio.len());
639        for (i, &sample) in audio.iter().enumerate() {
640            // Simple delay-based reverb
641            let delayed_sample = if i >= context.buffer_size / 4 {
642                audio[i - context.buffer_size / 4] * self.room_size * reverb_scale
643            } else {
644                0.0
645            };
646
647            let wet = delayed_sample * self.wet_level * reverb_scale;
648            let dry = sample * self.dry_level;
649            output.push(dry + wet);
650        }
651
652        Ok(output)
653    }
654
655    async fn update_parameters(&mut self, parameters: PluginParameters) -> Result<()> {
656        if let Some(size) = parameters.get_float("room_size") {
657            self.room_size = size.clamp(0.0, 1.0);
658        }
659        if let Some(damping) = parameters.get_float("damping") {
660            self.damping = damping.clamp(0.0, 1.0);
661        }
662        if let Some(wet) = parameters.get_float("wet_level") {
663            self.wet_level = wet.clamp(0.0, 1.0);
664        }
665        if let Some(dry) = parameters.get_float("dry_level") {
666            self.dry_level = dry.clamp(0.0, 1.0);
667        }
668        Ok(())
669    }
670
671    fn get_state(&self) -> PluginState {
672        self.state.clone()
673    }
674
675    async fn cleanup(&mut self) -> Result<()> {
676        self.state = PluginState::Cleanup;
677        Ok(())
678    }
679
680    fn as_any(&self) -> &dyn Any {
681        self
682    }
683    fn as_any_mut(&mut self) -> &mut dyn Any {
684        self
685    }
686}
687
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a manager with one default `ReverbPlugin` registered as "Spatial Reverb".
    async fn manager_with_reverb() -> PluginManager {
        let manager = PluginManager::new();
        manager
            .register_plugin(Box::new(ReverbPlugin::new()), PluginConfig::default())
            .await
            .unwrap();
        manager
    }

    #[tokio::test]
    async fn test_plugin_manager_creation() {
        let manager = PluginManager::new();
        assert_eq!(manager.get_plugin_names().len(), 0);
    }

    #[tokio::test]
    async fn test_plugin_registration() {
        let manager = manager_with_reverb().await;
        assert_eq!(manager.get_plugin_names().len(), 1);
        assert!(manager.has_plugin("Spatial Reverb"));
    }

    #[tokio::test]
    async fn test_duplicate_registration_replaces_plugin() {
        let manager = manager_with_reverb().await;
        manager
            .register_plugin(Box::new(ReverbPlugin::new()), PluginConfig::default())
            .await
            .unwrap();
        assert_eq!(manager.get_plugin_names(), vec!["Spatial Reverb".to_string()]);
    }

    #[tokio::test]
    async fn test_plugin_capabilities() {
        let plugin = ReverbPlugin::new();
        let caps = plugin.capabilities();

        assert!(caps.supports_mono);
        assert!(caps.supports_stereo);
        assert!(caps.supports_room_simulation);
        assert!(caps.has_parameters);
    }

    #[tokio::test]
    async fn test_manager_reports_capabilities() {
        let manager = manager_with_reverb().await;
        assert!(manager.get_plugin_capabilities("Spatial Reverb").is_some());
        assert!(manager.get_plugin_capabilities("missing").is_none());
    }

    #[tokio::test]
    async fn test_plugin_parameters() {
        let mut params = PluginParameters::new();
        params.set_float("room_size", 0.8);
        params.set_bool("enabled", true);
        params.set_string("preset", "Hall".to_string());

        assert_eq!(params.get_float("room_size"), Some(0.8));
        assert_eq!(params.get_bool("enabled"), Some(true));
        assert_eq!(params.get_string("preset"), Some("Hall"));
        // Typed getters never convert between variants.
        assert_eq!(params.get_int("room_size"), None);
        assert_eq!(params.get_float("missing"), None);
    }

    #[tokio::test]
    async fn test_update_parameters_clamps_values() {
        let mut plugin = ReverbPlugin::new();
        let mut params = PluginParameters::new();
        params.set_float("room_size", 2.0);
        params.set_float("wet_level", -1.0);

        plugin.update_parameters(params).await.unwrap();

        assert_eq!(plugin.room_size, 1.0);
        assert_eq!(plugin.wet_level, 0.0);
    }

    #[tokio::test]
    async fn test_plugin_audio_processing() {
        let mut plugin = ReverbPlugin::new();
        plugin.initialize(PluginConfig::default()).await.unwrap();

        let audio = vec![0.5; 1024];
        let listener_pos = Position3D::new(0.0, 0.0, 0.0);
        let source_pos = Position3D::new(1.0, 0.0, 0.0);
        let context = ProcessingContext::default();

        let result = plugin
            .process_audio(&audio, listener_pos, source_pos, &context)
            .await
            .unwrap();
        assert_eq!(result.len(), audio.len());
    }

    #[tokio::test]
    async fn test_colocated_source_gets_dry_signal_only() {
        let mut plugin = ReverbPlugin::new();
        plugin.initialize(PluginConfig::default()).await.unwrap();

        let audio = vec![0.5; 1024];
        let origin = Position3D::new(0.0, 0.0, 0.0);
        let context = ProcessingContext::default();

        let result = plugin
            .process_audio(&audio, origin, origin, &context)
            .await
            .unwrap();
        // Zero distance gives zero reverb, so only the dry path remains.
        assert!(result.iter().all(|&s| s == 0.5 * plugin.dry_level));
    }

    #[tokio::test]
    async fn test_processing_chain() {
        let manager = manager_with_reverb().await;
        manager
            .create_chain("test_chain", vec!["Spatial Reverb".to_string()])
            .await
            .unwrap();

        let audio = vec![0.5; 1024];
        let listener_pos = Position3D::new(0.0, 0.0, 0.0);
        let source_pos = Position3D::new(1.0, 0.0, 0.0);
        let context = ProcessingContext::default();

        let result = manager
            .process_with_chain("test_chain", &audio, listener_pos, source_pos, &context)
            .await
            .unwrap();
        assert_eq!(result.len(), audio.len());
    }

    #[tokio::test]
    async fn test_disabled_chain_passes_audio_through() {
        let manager = manager_with_reverb().await;
        manager
            .create_chain("off", vec!["Spatial Reverb".to_string()])
            .await
            .unwrap();
        manager
            .chains
            .write()
            .unwrap()
            .get_mut("off")
            .unwrap()
            .enabled = false;

        let audio = vec![0.25, -0.5, 1.0];
        let listener_pos = Position3D::new(0.0, 0.0, 0.0);
        let source_pos = Position3D::new(5.0, 0.0, 0.0);
        let context = ProcessingContext::default();

        let result = manager
            .process_with_chain("off", &audio, listener_pos, source_pos, &context)
            .await
            .unwrap();
        assert_eq!(result, audio);
    }

    #[tokio::test]
    async fn test_missing_plugin_and_chain_are_errors() {
        let manager = PluginManager::new();
        let pos = Position3D::new(0.0, 0.0, 0.0);
        let context = ProcessingContext::default();

        assert!(manager
            .process_with_plugin("missing", &[0.0], pos, pos, &context)
            .await
            .is_err());
        assert!(manager
            .process_with_chain("missing", &[0.0], pos, pos, &context)
            .await
            .is_err());
        assert!(manager
            .update_plugin_parameters("missing", PluginParameters::new())
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_plugin_cleanup() {
        let manager = manager_with_reverb().await;
        assert!(manager.has_plugin("Spatial Reverb"));

        manager.unregister_plugin("Spatial Reverb").await.unwrap();
        assert!(!manager.has_plugin("Spatial Reverb"));
    }

    #[tokio::test]
    async fn test_unregister_unknown_plugin_is_ok() {
        let manager = PluginManager::new();
        assert!(manager.unregister_plugin("missing").await.is_ok());
    }
}