1use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
6use crate::{BridgeError, ParameterHost};
7use wavecraft_protocol::{AudioRuntimeStatus, MeterFrame, OscilloscopeFrame, ParameterInfo};
8
9pub trait MeterProvider: Send + Sync {
11 fn get_meter_frame(&self) -> Option<MeterFrame>;
13}
14
15pub trait OscilloscopeProvider: Send + Sync {
17 fn get_oscilloscope_frame(&self) -> Option<OscilloscopeFrame>;
19}
20
/// A [`ParameterHost`] that keeps parameter definitions and their current
/// values in memory behind `RwLock`s, optionally forwarding meter and
/// oscilloscope requests to pluggable providers.
pub struct InMemoryParameterHost {
    /// Parameter definitions, in the order they were supplied.
    parameters: RwLock<Vec<ParameterInfo>>,
    /// Current value keyed by parameter id.
    values: RwLock<HashMap<String, f32>>,
    /// Backs `get_meter_frame`; `None` means the host reports no meter data.
    meter_provider: Option<Arc<dyn MeterProvider>>,
    /// Backs `get_oscilloscope_frame`; `None` means the host reports no
    /// oscilloscope data.
    oscilloscope_provider: Option<Arc<dyn OscilloscopeProvider>>,
}
30
31impl InMemoryParameterHost {
32 pub fn new(parameters: Vec<ParameterInfo>) -> Self {
34 let values = parameters
35 .iter()
36 .map(|p| (p.id.clone(), p.default))
37 .collect();
38
39 Self {
40 parameters: RwLock::new(parameters),
41 values: RwLock::new(values),
42 meter_provider: None,
43 oscilloscope_provider: None,
44 }
45 }
46
47 pub fn with_meter_provider(
49 parameters: Vec<ParameterInfo>,
50 meter_provider: Arc<dyn MeterProvider>,
51 ) -> Self {
52 let mut host = Self::new(parameters);
53 host.meter_provider = Some(meter_provider);
54 host
55 }
56
57 pub fn with_oscilloscope_provider(
59 parameters: Vec<ParameterInfo>,
60 oscilloscope_provider: Arc<dyn OscilloscopeProvider>,
61 ) -> Self {
62 let mut host = Self::new(parameters);
63 host.oscilloscope_provider = Some(oscilloscope_provider);
64 host
65 }
66
67 pub fn with_providers(
69 parameters: Vec<ParameterInfo>,
70 meter_provider: Option<Arc<dyn MeterProvider>>,
71 oscilloscope_provider: Option<Arc<dyn OscilloscopeProvider>>,
72 ) -> Self {
73 let mut host = Self::new(parameters);
74 host.meter_provider = meter_provider;
75 host.oscilloscope_provider = oscilloscope_provider;
76 host
77 }
78
79 pub fn replace_parameters(&self, new_params: Vec<ParameterInfo>) -> Result<(), String> {
96 let mut values = match self.values.write() {
98 Ok(guard) => guard,
99 Err(poisoned) => {
100 eprintln!("⚠ Recovering from poisoned values lock");
101 poisoned.into_inner()
102 }
103 };
104
105 let mut new_values = HashMap::new();
107 for param in &new_params {
108 let value = values.get(¶m.id).copied().unwrap_or(param.default);
109 new_values.insert(param.id.clone(), value);
110 }
111
112 *values = new_values;
113 drop(values); let mut params = match self.parameters.write() {
117 Ok(guard) => guard,
118 Err(poisoned) => {
119 eprintln!("⚠ Recovering from poisoned parameters lock");
120 poisoned.into_inner()
121 }
122 };
123
124 *params = new_params;
125 Ok(())
126 }
127
128 fn current_value(&self, id: &str, default: f32) -> f32 {
129 self.values
130 .read()
131 .ok()
132 .and_then(|values| values.get(id).copied())
133 .unwrap_or(default)
134 }
135}
136
137impl ParameterHost for InMemoryParameterHost {
138 fn get_parameter(&self, id: &str) -> Option<ParameterInfo> {
139 let parameters = self.parameters.read().ok()?;
140 let param = parameters.iter().find(|p| p.id == id)?;
141
142 Some(ParameterInfo {
143 id: param.id.clone(),
144 name: param.name.clone(),
145 param_type: param.param_type,
146 value: self.current_value(¶m.id, param.default),
147 default: param.default,
148 min: param.min,
149 max: param.max,
150 unit: param.unit.clone(),
151 group: param.group.clone(),
152 variants: param.variants.clone(),
153 })
154 }
155
156 fn set_parameter(&self, id: &str, value: f32) -> Result<(), BridgeError> {
157 let parameters = self.parameters.read().ok();
158 let param_exists = parameters
159 .as_ref()
160 .map(|p| p.iter().any(|param| param.id == id))
161 .unwrap_or(false);
162
163 if !param_exists {
164 return Err(BridgeError::ParameterNotFound(id.to_string()));
165 }
166
167 let Some(param) = parameters
168 .as_ref()
169 .and_then(|p| p.iter().find(|param| param.id == id))
170 else {
171 return Err(BridgeError::ParameterNotFound(id.to_string()));
172 };
173
174 if !(param.min..=param.max).contains(&value) {
175 return Err(BridgeError::ParameterOutOfRange {
176 id: id.to_string(),
177 value,
178 });
179 }
180
181 if let Ok(mut values) = self.values.write() {
182 values.insert(id.to_string(), value);
183 }
184
185 Ok(())
186 }
187
188 fn get_all_parameters(&self) -> Vec<ParameterInfo> {
189 let parameters = match self.parameters.read() {
190 Ok(guard) => guard,
191 Err(_) => return Vec::new(), };
193
194 parameters
195 .iter()
196 .map(|param| ParameterInfo {
197 id: param.id.clone(),
198 name: param.name.clone(),
199 param_type: param.param_type,
200 value: self.current_value(¶m.id, param.default),
201 default: param.default,
202 min: param.min,
203 max: param.max,
204 unit: param.unit.clone(),
205 group: param.group.clone(),
206 variants: param.variants.clone(),
207 })
208 .collect()
209 }
210
211 fn get_meter_frame(&self) -> Option<MeterFrame> {
212 self.meter_provider
213 .as_ref()
214 .and_then(|provider| provider.get_meter_frame())
215 }
216
217 fn get_oscilloscope_frame(&self) -> Option<OscilloscopeFrame> {
218 self.oscilloscope_provider
219 .as_ref()
220 .and_then(|provider| provider.get_oscilloscope_frame())
221 }
222
223 fn request_resize(&self, _width: u32, _height: u32) -> bool {
224 false
225 }
226
227 fn get_audio_status(&self) -> Option<AudioRuntimeStatus> {
228 None
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use wavecraft_protocol::ParameterType;

    /// Meter provider that always hands back the same frame.
    struct StaticMeterProvider {
        frame: MeterFrame,
    }

    impl MeterProvider for StaticMeterProvider {
        fn get_meter_frame(&self) -> Option<MeterFrame> {
            Some(self.frame)
        }
    }

    /// Oscilloscope provider that always hands back a clone of the same frame.
    struct StaticOscilloscopeProvider {
        frame: OscilloscopeFrame,
    }

    impl OscilloscopeProvider for StaticOscilloscopeProvider {
        fn get_oscilloscope_frame(&self) -> Option<OscilloscopeFrame> {
            Some(self.frame.clone())
        }
    }

    /// Builds a float parameter whose `value` equals its `default`.
    fn float_param(
        id: &str,
        name: &str,
        default: f32,
        (min, max): (f32, f32),
        unit: &str,
        group: Option<&str>,
    ) -> ParameterInfo {
        ParameterInfo {
            id: id.to_string(),
            name: name.to_string(),
            param_type: ParameterType::Float,
            value: default,
            default,
            min,
            max,
            unit: Some(unit.to_string()),
            group: group.map(str::to_string),
            variants: None,
        }
    }

    fn gain_param() -> ParameterInfo {
        float_param("gain", "Gain", 0.5, (0.0, 1.0), "dB", Some("Input"))
    }

    fn mix_param() -> ParameterInfo {
        float_param("mix", "Mix", 1.0, (0.0, 1.0), "%", None)
    }

    fn test_params() -> Vec<ParameterInfo> {
        vec![gain_param(), mix_param()]
    }

    /// Float comparison within one machine epsilon.
    fn close_to(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < f32::EPSILON
    }

    #[test]
    fn test_get_parameter() {
        let host = InMemoryParameterHost::new(test_params());

        let gain = host.get_parameter("gain").expect("should find gain");
        assert_eq!(gain.id, "gain");
        assert_eq!(gain.name, "Gain");
        assert!(close_to(gain.value, 0.5));
    }

    #[test]
    fn test_set_parameter() {
        let host = InMemoryParameterHost::new(test_params());
        host.set_parameter("gain", 0.75).expect("should set gain");

        let gain = host.get_parameter("gain").expect("should find gain");
        assert!(close_to(gain.value, 0.75));
    }

    #[test]
    fn test_set_parameter_out_of_range() {
        let host = InMemoryParameterHost::new(test_params());

        for rejected in [1.5, -0.1] {
            assert!(host.set_parameter("gain", rejected).is_err());
        }
    }

    #[test]
    fn test_get_all_parameters() {
        let host = InMemoryParameterHost::new(test_params());

        let all = host.get_all_parameters();
        assert_eq!(all.len(), 2);
        for wanted in ["gain", "mix"] {
            assert!(all.iter().any(|p| p.id == wanted));
        }
    }

    #[test]
    fn test_get_meter_frame() {
        let provider = Arc::new(StaticMeterProvider {
            frame: MeterFrame {
                peak_l: 0.7,
                rms_l: 0.5,
                peak_r: 0.6,
                rms_r: 0.4,
                timestamp: 0,
            },
        });
        let host = InMemoryParameterHost::with_meter_provider(test_params(), provider);

        let frame = host.get_meter_frame().expect("should have meter frame");
        assert!(close_to(frame.peak_l, 0.7));
        assert!(close_to(frame.rms_r, 0.4));
    }

    #[test]
    fn test_get_oscilloscope_frame() {
        let provider = Arc::new(StaticOscilloscopeProvider {
            frame: OscilloscopeFrame {
                points_l: vec![0.1; 1024],
                points_r: vec![0.2; 1024],
                sample_rate: 48_000.0,
                timestamp: 99,
                no_signal: false,
                trigger_mode: wavecraft_protocol::OscilloscopeTriggerMode::RisingZeroCrossing,
            },
        });
        let host = InMemoryParameterHost::with_oscilloscope_provider(test_params(), provider);

        let frame = host
            .get_oscilloscope_frame()
            .expect("should have oscilloscope frame");
        assert_eq!(frame.points_l.len(), 1024);
        assert_eq!(frame.points_r.len(), 1024);
        assert_eq!(frame.timestamp, 99);
    }

    #[test]
    fn test_replace_parameters_preserves_values() {
        let host = InMemoryParameterHost::new(test_params());
        host.set_parameter("gain", 0.75).expect("should set gain");
        host.set_parameter("mix", 0.5).expect("should set mix");

        let mut replacement = test_params();
        replacement.push(float_param(
            "freq",
            "Frequency",
            440.0,
            (20.0, 5_000.0),
            "Hz",
            None,
        ));
        host.replace_parameters(replacement)
            .expect("should replace parameters");

        let value_of = |id: &str| host.get_parameter(id).map(|p| p.value);
        assert!(close_to(value_of("gain").expect("should find gain"), 0.75));
        assert!(close_to(value_of("mix").expect("should find mix"), 0.5));
        assert!(close_to(value_of("freq").expect("should find freq"), 440.0));
    }

    #[test]
    fn test_replace_parameters_removes_old() {
        let host = InMemoryParameterHost::new(test_params());

        host.replace_parameters(vec![gain_param()])
            .expect("should replace parameters");

        assert!(host.get_parameter("mix").is_none());
        assert!(host.get_parameter("gain").is_some());
    }

    #[test]
    fn test_set_parameter_uses_declared_range_not_normalized_range() {
        const ID: &str = "oscillator_frequency";
        let host = InMemoryParameterHost::new(vec![float_param(
            ID,
            "Frequency",
            440.0,
            (20.0, 5_000.0),
            "Hz",
            Some("Oscillator"),
        )]);

        host.set_parameter(ID, 2_000.0)
            .expect("frequency in declared range should be accepted");
        let freq = host.get_parameter(ID).expect("frequency should exist");
        assert!(close_to(freq.value, 2_000.0));

        assert!(
            host.set_parameter(ID, 10.0).is_err(),
            "value below min should be rejected"
        );
        assert!(
            host.set_parameter(ID, 10_000.0).is_err(),
            "value above max should be rejected"
        );
    }
}