Skip to main content

wavecraft_bridge/
in_memory_host.rs

1//! In-memory ParameterHost implementation for dev tools and tests.
2
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
6use crate::{BridgeError, ParameterHost};
7use wavecraft_protocol::{AudioRuntimeStatus, MeterFrame, OscilloscopeFrame, ParameterInfo};
8
9/// Provides metering data for an in-memory host.
10pub trait MeterProvider: Send + Sync {
11    /// Return the latest meter frame, if available.
12    fn get_meter_frame(&self) -> Option<MeterFrame>;
13}
14
15/// Provides oscilloscope frame data for an in-memory host.
16pub trait OscilloscopeProvider: Send + Sync {
17    /// Return the latest oscilloscope frame, if available.
18    fn get_oscilloscope_frame(&self) -> Option<OscilloscopeFrame>;
19}
20
21/// In-memory host for storing parameter values and optional meter data.
22///
23/// This is intended for development tools (like the CLI dev server) and tests.
24pub struct InMemoryParameterHost {
25    parameters: RwLock<Vec<ParameterInfo>>,
26    values: RwLock<HashMap<String, f32>>,
27    meter_provider: Option<Arc<dyn MeterProvider>>,
28    oscilloscope_provider: Option<Arc<dyn OscilloscopeProvider>>,
29}
30
31impl InMemoryParameterHost {
32    /// Create a new in-memory host with the given parameter metadata.
33    pub fn new(parameters: Vec<ParameterInfo>) -> Self {
34        let values = parameters
35            .iter()
36            .map(|p| (p.id.clone(), p.default))
37            .collect();
38
39        Self {
40            parameters: RwLock::new(parameters),
41            values: RwLock::new(values),
42            meter_provider: None,
43            oscilloscope_provider: None,
44        }
45    }
46
47    /// Create a new in-memory host with a meter provider.
48    pub fn with_meter_provider(
49        parameters: Vec<ParameterInfo>,
50        meter_provider: Arc<dyn MeterProvider>,
51    ) -> Self {
52        let mut host = Self::new(parameters);
53        host.meter_provider = Some(meter_provider);
54        host
55    }
56
57    /// Create a new in-memory host with an oscilloscope provider.
58    pub fn with_oscilloscope_provider(
59        parameters: Vec<ParameterInfo>,
60        oscilloscope_provider: Arc<dyn OscilloscopeProvider>,
61    ) -> Self {
62        let mut host = Self::new(parameters);
63        host.oscilloscope_provider = Some(oscilloscope_provider);
64        host
65    }
66
67    /// Create a new in-memory host with both meter and oscilloscope providers.
68    pub fn with_providers(
69        parameters: Vec<ParameterInfo>,
70        meter_provider: Option<Arc<dyn MeterProvider>>,
71        oscilloscope_provider: Option<Arc<dyn OscilloscopeProvider>>,
72    ) -> Self {
73        let mut host = Self::new(parameters);
74        host.meter_provider = meter_provider;
75        host.oscilloscope_provider = oscilloscope_provider;
76        host
77    }
78
79    /// Replace all parameters with new metadata from a fresh build.
80    ///
81    /// This method is used during hot-reload to update parameter definitions
82    /// while preserving existing parameter values where possible. Parameters
83    /// with matching IDs retain their current values; new parameters get
84    /// their default values; removed parameters are dropped.
85    ///
86    /// # Thread Safety
87    ///
88    /// This method acquires write locks on both the parameters and values maps.
89    /// If a lock is poisoned (from a previous panic), it recovers gracefully
90    /// by clearing the poisoned lock and continuing.
91    ///
92    /// # Errors
93    ///
94    /// Returns an error if both lock recovery attempts fail.
95    pub fn replace_parameters(&self, new_params: Vec<ParameterInfo>) -> Result<(), String> {
96        // Acquire values lock with poison recovery
97        let mut values = match self.values.write() {
98            Ok(guard) => guard,
99            Err(poisoned) => {
100                eprintln!("⚠ Recovering from poisoned values lock");
101                poisoned.into_inner()
102            }
103        };
104
105        // Build new values map, preserving existing values where IDs match
106        let mut new_values = HashMap::new();
107        for param in &new_params {
108            let value = values.get(&param.id).copied().unwrap_or(param.default);
109            new_values.insert(param.id.clone(), value);
110        }
111
112        *values = new_values;
113        drop(values); // Release values lock before acquiring parameters lock
114
115        // Acquire parameters lock with poison recovery
116        let mut params = match self.parameters.write() {
117            Ok(guard) => guard,
118            Err(poisoned) => {
119                eprintln!("⚠ Recovering from poisoned parameters lock");
120                poisoned.into_inner()
121            }
122        };
123
124        *params = new_params;
125        Ok(())
126    }
127
128    fn current_value(&self, id: &str, default: f32) -> f32 {
129        self.values
130            .read()
131            .ok()
132            .and_then(|values| values.get(id).copied())
133            .unwrap_or(default)
134    }
135}
136
137impl ParameterHost for InMemoryParameterHost {
138    fn get_parameter(&self, id: &str) -> Option<ParameterInfo> {
139        let parameters = self.parameters.read().ok()?;
140        let param = parameters.iter().find(|p| p.id == id)?;
141
142        Some(ParameterInfo {
143            id: param.id.clone(),
144            name: param.name.clone(),
145            param_type: param.param_type,
146            value: self.current_value(&param.id, param.default),
147            default: param.default,
148            min: param.min,
149            max: param.max,
150            unit: param.unit.clone(),
151            group: param.group.clone(),
152            variants: param.variants.clone(),
153        })
154    }
155
156    fn set_parameter(&self, id: &str, value: f32) -> Result<(), BridgeError> {
157        let parameters = self.parameters.read().ok();
158        let param_exists = parameters
159            .as_ref()
160            .map(|p| p.iter().any(|param| param.id == id))
161            .unwrap_or(false);
162
163        if !param_exists {
164            return Err(BridgeError::ParameterNotFound(id.to_string()));
165        }
166
167        let Some(param) = parameters
168            .as_ref()
169            .and_then(|p| p.iter().find(|param| param.id == id))
170        else {
171            return Err(BridgeError::ParameterNotFound(id.to_string()));
172        };
173
174        if !(param.min..=param.max).contains(&value) {
175            return Err(BridgeError::ParameterOutOfRange {
176                id: id.to_string(),
177                value,
178            });
179        }
180
181        if let Ok(mut values) = self.values.write() {
182            values.insert(id.to_string(), value);
183        }
184
185        Ok(())
186    }
187
188    fn get_all_parameters(&self) -> Vec<ParameterInfo> {
189        let parameters = match self.parameters.read() {
190            Ok(guard) => guard,
191            Err(_) => return Vec::new(), // Return empty on poisoned lock
192        };
193
194        parameters
195            .iter()
196            .map(|param| ParameterInfo {
197                id: param.id.clone(),
198                name: param.name.clone(),
199                param_type: param.param_type,
200                value: self.current_value(&param.id, param.default),
201                default: param.default,
202                min: param.min,
203                max: param.max,
204                unit: param.unit.clone(),
205                group: param.group.clone(),
206                variants: param.variants.clone(),
207            })
208            .collect()
209    }
210
211    fn get_meter_frame(&self) -> Option<MeterFrame> {
212        self.meter_provider
213            .as_ref()
214            .and_then(|provider| provider.get_meter_frame())
215    }
216
217    fn get_oscilloscope_frame(&self) -> Option<OscilloscopeFrame> {
218        self.oscilloscope_provider
219            .as_ref()
220            .and_then(|provider| provider.get_oscilloscope_frame())
221    }
222
223    fn request_resize(&self, _width: u32, _height: u32) -> bool {
224        false
225    }
226
227    fn get_audio_status(&self) -> Option<AudioRuntimeStatus> {
228        None
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use wavecraft_protocol::ParameterType;

    struct StaticMeterProvider {
        frame: MeterFrame,
    }

    struct StaticOscilloscopeProvider {
        frame: OscilloscopeFrame,
    }

    impl MeterProvider for StaticMeterProvider {
        fn get_meter_frame(&self) -> Option<MeterFrame> {
            Some(self.frame)
        }
    }

    impl OscilloscopeProvider for StaticOscilloscopeProvider {
        fn get_oscilloscope_frame(&self) -> Option<OscilloscopeFrame> {
            Some(self.frame.clone())
        }
    }

    /// Build a float parameter whose value equals its default, with no unit or group.
    fn float_param(id: &str, name: &str, default: f32, min: f32, max: f32) -> ParameterInfo {
        ParameterInfo {
            id: id.to_string(),
            name: name.to_string(),
            param_type: ParameterType::Float,
            value: default,
            default,
            min,
            max,
            unit: None,
            group: None,
            variants: None,
        }
    }

    fn gain_param() -> ParameterInfo {
        ParameterInfo {
            unit: Some("dB".to_string()),
            group: Some("Input".to_string()),
            ..float_param("gain", "Gain", 0.5, 0.0, 1.0)
        }
    }

    fn mix_param() -> ParameterInfo {
        ParameterInfo {
            unit: Some("%".to_string()),
            ..float_param("mix", "Mix", 1.0, 0.0, 1.0)
        }
    }

    fn freq_param() -> ParameterInfo {
        ParameterInfo {
            unit: Some("Hz".to_string()),
            ..float_param("freq", "Frequency", 440.0, 20.0, 5_000.0)
        }
    }

    fn test_params() -> Vec<ParameterInfo> {
        vec![gain_param(), mix_param()]
    }

    fn sample_meter_frame() -> MeterFrame {
        MeterFrame {
            peak_l: 0.7,
            rms_l: 0.5,
            peak_r: 0.6,
            rms_r: 0.4,
            timestamp: 0,
        }
    }

    fn sample_oscilloscope_frame() -> OscilloscopeFrame {
        OscilloscopeFrame {
            points_l: vec![0.1; 1024],
            points_r: vec![0.2; 1024],
            sample_rate: 48_000.0,
            timestamp: 99,
            no_signal: false,
            trigger_mode: wavecraft_protocol::OscilloscopeTriggerMode::RisingZeroCrossing,
        }
    }

    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < f32::EPSILON,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_get_parameter() {
        let host = InMemoryParameterHost::new(test_params());

        let param = host.get_parameter("gain").expect("should find gain");
        assert_eq!(param.id, "gain");
        assert_eq!(param.name, "Gain");
        assert_close(param.value, 0.5);
    }

    #[test]
    fn test_get_unknown_parameter_returns_none() {
        let host = InMemoryParameterHost::new(test_params());
        assert!(host.get_parameter("missing").is_none());
    }

    #[test]
    fn test_set_parameter() {
        let host = InMemoryParameterHost::new(test_params());

        host.set_parameter("gain", 0.75).expect("should set gain");

        let param = host.get_parameter("gain").expect("should find gain");
        assert_close(param.value, 0.75);
    }

    #[test]
    fn test_set_unknown_parameter_is_rejected() {
        let host = InMemoryParameterHost::new(test_params());

        let result = host.set_parameter("missing", 0.5);
        assert!(matches!(
            &result,
            Err(BridgeError::ParameterNotFound(id)) if id.as_str() == "missing"
        ));
    }

    #[test]
    fn test_set_parameter_out_of_range() {
        let host = InMemoryParameterHost::new(test_params());

        let result = host.set_parameter("gain", 1.5);
        assert!(matches!(
            result,
            Err(BridgeError::ParameterOutOfRange { .. })
        ));

        let result = host.set_parameter("gain", -0.1);
        assert!(matches!(
            result,
            Err(BridgeError::ParameterOutOfRange { .. })
        ));

        let result = host.set_parameter("gain", f32::NAN);
        assert!(result.is_err(), "NaN must never be accepted");
    }

    #[test]
    fn test_rejected_set_keeps_previous_value() {
        let host = InMemoryParameterHost::new(test_params());

        host.set_parameter("gain", 0.75).expect("should set gain");
        assert!(host.set_parameter("gain", 2.0).is_err());

        let gain = host.get_parameter("gain").expect("should find gain");
        assert_close(gain.value, 0.75);
    }

    #[test]
    fn test_get_all_parameters() {
        let host = InMemoryParameterHost::new(test_params());

        let params = host.get_all_parameters();
        assert_eq!(params.len(), 2);
        assert!(params.iter().any(|p| p.id == "gain"));
        assert!(params.iter().any(|p| p.id == "mix"));
    }

    #[test]
    fn test_get_meter_frame() {
        let provider = Arc::new(StaticMeterProvider {
            frame: sample_meter_frame(),
        });
        let host = InMemoryParameterHost::with_meter_provider(test_params(), provider);

        let read = host.get_meter_frame().expect("should have meter frame");
        assert_close(read.peak_l, 0.7);
        assert_close(read.rms_r, 0.4);
    }

    #[test]
    fn test_get_oscilloscope_frame() {
        let provider = Arc::new(StaticOscilloscopeProvider {
            frame: sample_oscilloscope_frame(),
        });
        let host = InMemoryParameterHost::with_oscilloscope_provider(test_params(), provider);

        let read = host
            .get_oscilloscope_frame()
            .expect("should have oscilloscope frame");
        assert_eq!(read.points_l.len(), 1024);
        assert_eq!(read.points_r.len(), 1024);
        assert_eq!(read.timestamp, 99);
    }

    #[test]
    fn test_with_providers_attaches_both() {
        let meter: Arc<dyn MeterProvider> = Arc::new(StaticMeterProvider {
            frame: sample_meter_frame(),
        });
        let scope: Arc<dyn OscilloscopeProvider> = Arc::new(StaticOscilloscopeProvider {
            frame: sample_oscilloscope_frame(),
        });
        let host = InMemoryParameterHost::with_providers(test_params(), Some(meter), Some(scope));

        assert!(host.get_meter_frame().is_some());
        assert!(host.get_oscilloscope_frame().is_some());
    }

    #[test]
    fn test_defaults_without_providers() {
        let host = InMemoryParameterHost::new(test_params());

        assert!(host.get_meter_frame().is_none());
        assert!(host.get_oscilloscope_frame().is_none());
        assert!(!host.request_resize(800, 600));
        assert!(host.get_audio_status().is_none());
    }

    #[test]
    fn test_replace_parameters_preserves_values() {
        let host = InMemoryParameterHost::new(test_params());

        // Set custom values
        host.set_parameter("gain", 0.75).expect("should set gain");
        host.set_parameter("mix", 0.5).expect("should set mix");

        // Same parameters plus a new one
        let mut new_params = test_params();
        new_params.push(freq_param());

        host.replace_parameters(new_params)
            .expect("should replace parameters");

        // Existing parameters should preserve their values
        let gain = host.get_parameter("gain").expect("should find gain");
        assert_close(gain.value, 0.75);

        let mix = host.get_parameter("mix").expect("should find mix");
        assert_close(mix.value, 0.5);

        // New parameter should have default value
        let freq = host.get_parameter("freq").expect("should find freq");
        assert_close(freq.value, 440.0);
    }

    #[test]
    fn test_replace_parameters_removes_old() {
        let host = InMemoryParameterHost::new(test_params());

        // Replace with fewer parameters
        host.replace_parameters(vec![gain_param()])
            .expect("should replace parameters");

        // Old parameter should be gone
        assert!(host.get_parameter("mix").is_none());

        // Kept parameter should still be accessible
        assert!(host.get_parameter("gain").is_some());
    }

    #[test]
    fn test_set_parameter_uses_declared_range_not_normalized_range() {
        let host = InMemoryParameterHost::new(vec![ParameterInfo {
            unit: Some("Hz".to_string()),
            group: Some("Oscillator".to_string()),
            ..float_param("oscillator_frequency", "Frequency", 440.0, 20.0, 5_000.0)
        }]);

        host.set_parameter("oscillator_frequency", 2_000.0)
            .expect("frequency in declared range should be accepted");

        let freq = host
            .get_parameter("oscillator_frequency")
            .expect("frequency should exist");
        assert_close(freq.value, 2_000.0);

        let too_low = host.set_parameter("oscillator_frequency", 10.0);
        assert!(too_low.is_err(), "value below min should be rejected");

        let too_high = host.set_parameter("oscillator_frequency", 10_000.0);
        assert!(too_high.is_err(), "value above max should be rejected");
    }
}