Skip to main content

wavecraft_bridge/
in_memory_host.rs

//! In-memory ParameterHost implementation for dev tools and tests.
2
3use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
6use crate::{BridgeError, ParameterHost};
7use wavecraft_protocol::{AudioRuntimeStatus, MeterFrame, ParameterInfo};
8
9/// Provides metering data for an in-memory host.
10pub trait MeterProvider: Send + Sync {
11    /// Return the latest meter frame, if available.
12    fn get_meter_frame(&self) -> Option<MeterFrame>;
13}
14
15/// In-memory host for storing parameter values and optional meter data.
16///
17/// This is intended for development tools (like the CLI dev server) and tests.
18pub struct InMemoryParameterHost {
19    parameters: RwLock<Vec<ParameterInfo>>,
20    values: RwLock<HashMap<String, f32>>,
21    meter_provider: Option<Arc<dyn MeterProvider>>,
22}
23
24impl InMemoryParameterHost {
25    /// Create a new in-memory host with the given parameter metadata.
26    pub fn new(parameters: Vec<ParameterInfo>) -> Self {
27        let values = parameters
28            .iter()
29            .map(|p| (p.id.clone(), p.default))
30            .collect();
31
32        Self {
33            parameters: RwLock::new(parameters),
34            values: RwLock::new(values),
35            meter_provider: None,
36        }
37    }
38
39    /// Create a new in-memory host with a meter provider.
40    pub fn with_meter_provider(
41        parameters: Vec<ParameterInfo>,
42        meter_provider: Arc<dyn MeterProvider>,
43    ) -> Self {
44        let mut host = Self::new(parameters);
45        host.meter_provider = Some(meter_provider);
46        host
47    }
48
49    /// Replace all parameters with new metadata from a fresh build.
50    ///
51    /// This method is used during hot-reload to update parameter definitions
52    /// while preserving existing parameter values where possible. Parameters
53    /// with matching IDs retain their current values; new parameters get
54    /// their default values; removed parameters are dropped.
55    ///
56    /// # Thread Safety
57    ///
58    /// This method acquires write locks on both the parameters and values maps.
59    /// If a lock is poisoned (from a previous panic), it recovers gracefully
60    /// by clearing the poisoned lock and continuing.
61    ///
62    /// # Errors
63    ///
64    /// Returns an error if both lock recovery attempts fail.
65    pub fn replace_parameters(&self, new_params: Vec<ParameterInfo>) -> Result<(), String> {
66        // Acquire values lock with poison recovery
67        let mut values = match self.values.write() {
68            Ok(guard) => guard,
69            Err(poisoned) => {
70                eprintln!("⚠ Recovering from poisoned values lock");
71                poisoned.into_inner()
72            }
73        };
74
75        // Build new values map, preserving existing values where IDs match
76        let mut new_values = HashMap::new();
77        for param in &new_params {
78            let value = values.get(&param.id).copied().unwrap_or(param.default);
79            new_values.insert(param.id.clone(), value);
80        }
81
82        *values = new_values;
83        drop(values); // Release values lock before acquiring parameters lock
84
85        // Acquire parameters lock with poison recovery
86        let mut params = match self.parameters.write() {
87            Ok(guard) => guard,
88            Err(poisoned) => {
89                eprintln!("⚠ Recovering from poisoned parameters lock");
90                poisoned.into_inner()
91            }
92        };
93
94        *params = new_params;
95        Ok(())
96    }
97
98    fn current_value(&self, id: &str, default: f32) -> f32 {
99        self.values
100            .read()
101            .ok()
102            .and_then(|values| values.get(id).copied())
103            .unwrap_or(default)
104    }
105}
106
107impl ParameterHost for InMemoryParameterHost {
108    fn get_parameter(&self, id: &str) -> Option<ParameterInfo> {
109        let parameters = self.parameters.read().ok()?;
110        let param = parameters.iter().find(|p| p.id == id)?;
111
112        Some(ParameterInfo {
113            id: param.id.clone(),
114            name: param.name.clone(),
115            param_type: param.param_type,
116            value: self.current_value(&param.id, param.default),
117            default: param.default,
118            min: param.min,
119            max: param.max,
120            unit: param.unit.clone(),
121            group: param.group.clone(),
122        })
123    }
124
125    fn set_parameter(&self, id: &str, value: f32) -> Result<(), BridgeError> {
126        let parameters = self.parameters.read().ok();
127        let param_exists = parameters
128            .as_ref()
129            .map(|p| p.iter().any(|param| param.id == id))
130            .unwrap_or(false);
131
132        if !param_exists {
133            return Err(BridgeError::ParameterNotFound(id.to_string()));
134        }
135
136        let Some(param) = parameters
137            .as_ref()
138            .and_then(|p| p.iter().find(|param| param.id == id))
139        else {
140            return Err(BridgeError::ParameterNotFound(id.to_string()));
141        };
142
143        if !(param.min..=param.max).contains(&value) {
144            return Err(BridgeError::ParameterOutOfRange {
145                id: id.to_string(),
146                value,
147            });
148        }
149
150        if let Ok(mut values) = self.values.write() {
151            values.insert(id.to_string(), value);
152        }
153
154        Ok(())
155    }
156
157    fn get_all_parameters(&self) -> Vec<ParameterInfo> {
158        let parameters = match self.parameters.read() {
159            Ok(guard) => guard,
160            Err(_) => return Vec::new(), // Return empty on poisoned lock
161        };
162
163        parameters
164            .iter()
165            .map(|param| ParameterInfo {
166                id: param.id.clone(),
167                name: param.name.clone(),
168                param_type: param.param_type,
169                value: self.current_value(&param.id, param.default),
170                default: param.default,
171                min: param.min,
172                max: param.max,
173                unit: param.unit.clone(),
174                group: param.group.clone(),
175            })
176            .collect()
177    }
178
179    fn get_meter_frame(&self) -> Option<MeterFrame> {
180        self.meter_provider
181            .as_ref()
182            .and_then(|provider| provider.get_meter_frame())
183    }
184
185    fn request_resize(&self, _width: u32, _height: u32) -> bool {
186        false
187    }
188
189    fn get_audio_status(&self) -> Option<AudioRuntimeStatus> {
190        None
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use wavecraft_protocol::ParameterType;

    /// Meter provider that always reports the same frame.
    struct StaticMeterProvider {
        frame: MeterFrame,
    }

    impl MeterProvider for StaticMeterProvider {
        fn get_meter_frame(&self) -> Option<MeterFrame> {
            Some(self.frame)
        }
    }

    /// Build a float parameter whose current value equals its default.
    fn float_param(
        id: &str,
        name: &str,
        default: f32,
        (min, max): (f32, f32),
        unit: &str,
        group: Option<&str>,
    ) -> ParameterInfo {
        ParameterInfo {
            id: id.to_string(),
            name: name.to_string(),
            param_type: ParameterType::Float,
            value: default,
            default,
            min,
            max,
            unit: Some(unit.to_string()),
            group: group.map(String::from),
        }
    }

    /// The "gain" fixture: 0.0..=1.0, default 0.5, unit dB, group "Input".
    fn gain_param() -> ParameterInfo {
        float_param("gain", "Gain", 0.5, (0.0, 1.0), "dB", Some("Input"))
    }

    /// The "mix" fixture: 0.0..=1.0, default 1.0, unit %, no group.
    fn mix_param() -> ParameterInfo {
        float_param("mix", "Mix", 1.0, (0.0, 1.0), "%", None)
    }

    /// The default two-parameter fixture used by most tests.
    fn test_params() -> Vec<ParameterInfo> {
        vec![gain_param(), mix_param()]
    }

    /// Assert that two floats differ by less than `f32::EPSILON`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < f32::EPSILON);
    }

    #[test]
    fn test_get_parameter() {
        let host = InMemoryParameterHost::new(test_params());

        let gain = host.get_parameter("gain").expect("should find gain");
        assert_eq!(gain.id, "gain");
        assert_eq!(gain.name, "Gain");
        assert_close(gain.value, 0.5);
    }

    #[test]
    fn test_set_parameter() {
        let host = InMemoryParameterHost::new(test_params());

        host.set_parameter("gain", 0.75).expect("should set gain");

        let gain = host.get_parameter("gain").expect("should find gain");
        assert_close(gain.value, 0.75);
    }

    #[test]
    fn test_set_parameter_out_of_range() {
        let host = InMemoryParameterHost::new(test_params());

        // Above the maximum.
        assert!(host.set_parameter("gain", 1.5).is_err());

        // Below the minimum.
        assert!(host.set_parameter("gain", -0.1).is_err());
    }

    #[test]
    fn test_get_all_parameters() {
        let host = InMemoryParameterHost::new(test_params());

        let all = host.get_all_parameters();
        assert_eq!(all.len(), 2);
        for expected_id in ["gain", "mix"] {
            assert!(all.iter().any(|p| p.id == expected_id));
        }
    }

    #[test]
    fn test_get_meter_frame() {
        let provider = Arc::new(StaticMeterProvider {
            frame: MeterFrame {
                peak_l: 0.7,
                rms_l: 0.5,
                peak_r: 0.6,
                rms_r: 0.4,
                timestamp: 0,
            },
        });
        let host = InMemoryParameterHost::with_meter_provider(test_params(), provider);

        let read = host.get_meter_frame().expect("should have meter frame");
        assert_close(read.peak_l, 0.7);
        assert_close(read.rms_r, 0.4);
    }

    #[test]
    fn test_replace_parameters_preserves_values() {
        let host = InMemoryParameterHost::new(test_params());

        // Move both parameters away from their defaults.
        host.set_parameter("gain", 0.75).expect("should set gain");
        host.set_parameter("mix", 0.5).expect("should set mix");

        // Same two parameters plus a brand-new one.
        let new_params = vec![
            gain_param(),
            mix_param(),
            float_param("freq", "Frequency", 440.0, (20.0, 5_000.0), "Hz", None),
        ];

        host.replace_parameters(new_params)
            .expect("should replace parameters");

        // Parameters that already existed keep the values set above.
        let gain = host.get_parameter("gain").expect("should find gain");
        assert_close(gain.value, 0.75);

        let mix = host.get_parameter("mix").expect("should find mix");
        assert_close(mix.value, 0.5);

        // The new parameter starts at its default.
        let freq = host.get_parameter("freq").expect("should find freq");
        assert_close(freq.value, 440.0);
    }

    #[test]
    fn test_replace_parameters_removes_old() {
        let host = InMemoryParameterHost::new(test_params());

        // Replace with a shorter list that no longer contains "mix".
        host.replace_parameters(vec![gain_param()])
            .expect("should replace parameters");

        // The dropped parameter is no longer visible.
        assert!(host.get_parameter("mix").is_none());

        // The retained parameter is still accessible.
        assert!(host.get_parameter("gain").is_some());
    }

    #[test]
    fn test_set_parameter_uses_declared_range_not_normalized_range() {
        let host = InMemoryParameterHost::new(vec![float_param(
            "oscillator_frequency",
            "Frequency",
            440.0,
            (20.0, 5_000.0),
            "Hz",
            Some("Oscillator"),
        )]);

        host.set_parameter("oscillator_frequency", 2_000.0)
            .expect("frequency in declared range should be accepted");

        let freq = host
            .get_parameter("oscillator_frequency")
            .expect("frequency should exist");
        assert_close(freq.value, 2_000.0);

        let too_low = host.set_parameter("oscillator_frequency", 10.0);
        assert!(too_low.is_err(), "value below min should be rejected");

        let too_high = host.set_parameter("oscillator_frequency", 10_000.0);
        assert!(too_high.is_err(), "value above max should be rejected");
    }
}