1use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
6use crate::{BridgeError, ParameterHost};
7use wavecraft_protocol::{AudioRuntimeStatus, MeterFrame, ParameterInfo};
8
9pub trait MeterProvider: Send + Sync {
11 fn get_meter_frame(&self) -> Option<MeterFrame>;
13}
14
15pub struct InMemoryParameterHost {
19 parameters: RwLock<Vec<ParameterInfo>>,
20 values: RwLock<HashMap<String, f32>>,
21 meter_provider: Option<Arc<dyn MeterProvider>>,
22}
23
24impl InMemoryParameterHost {
25 pub fn new(parameters: Vec<ParameterInfo>) -> Self {
27 let values = parameters
28 .iter()
29 .map(|p| (p.id.clone(), p.default))
30 .collect();
31
32 Self {
33 parameters: RwLock::new(parameters),
34 values: RwLock::new(values),
35 meter_provider: None,
36 }
37 }
38
39 pub fn with_meter_provider(
41 parameters: Vec<ParameterInfo>,
42 meter_provider: Arc<dyn MeterProvider>,
43 ) -> Self {
44 let mut host = Self::new(parameters);
45 host.meter_provider = Some(meter_provider);
46 host
47 }
48
49 pub fn replace_parameters(&self, new_params: Vec<ParameterInfo>) -> Result<(), String> {
66 let mut values = match self.values.write() {
68 Ok(guard) => guard,
69 Err(poisoned) => {
70 eprintln!("⚠ Recovering from poisoned values lock");
71 poisoned.into_inner()
72 }
73 };
74
75 let mut new_values = HashMap::new();
77 for param in &new_params {
78 let value = values.get(¶m.id).copied().unwrap_or(param.default);
79 new_values.insert(param.id.clone(), value);
80 }
81
82 *values = new_values;
83 drop(values); let mut params = match self.parameters.write() {
87 Ok(guard) => guard,
88 Err(poisoned) => {
89 eprintln!("⚠ Recovering from poisoned parameters lock");
90 poisoned.into_inner()
91 }
92 };
93
94 *params = new_params;
95 Ok(())
96 }
97
98 fn current_value(&self, id: &str, default: f32) -> f32 {
99 self.values
100 .read()
101 .ok()
102 .and_then(|values| values.get(id).copied())
103 .unwrap_or(default)
104 }
105}
106
107impl ParameterHost for InMemoryParameterHost {
108 fn get_parameter(&self, id: &str) -> Option<ParameterInfo> {
109 let parameters = self.parameters.read().ok()?;
110 let param = parameters.iter().find(|p| p.id == id)?;
111
112 Some(ParameterInfo {
113 id: param.id.clone(),
114 name: param.name.clone(),
115 param_type: param.param_type,
116 value: self.current_value(¶m.id, param.default),
117 default: param.default,
118 min: param.min,
119 max: param.max,
120 unit: param.unit.clone(),
121 group: param.group.clone(),
122 })
123 }
124
125 fn set_parameter(&self, id: &str, value: f32) -> Result<(), BridgeError> {
126 let parameters = self.parameters.read().ok();
127 let param_exists = parameters
128 .as_ref()
129 .map(|p| p.iter().any(|param| param.id == id))
130 .unwrap_or(false);
131
132 if !param_exists {
133 return Err(BridgeError::ParameterNotFound(id.to_string()));
134 }
135
136 let Some(param) = parameters
137 .as_ref()
138 .and_then(|p| p.iter().find(|param| param.id == id))
139 else {
140 return Err(BridgeError::ParameterNotFound(id.to_string()));
141 };
142
143 if !(param.min..=param.max).contains(&value) {
144 return Err(BridgeError::ParameterOutOfRange {
145 id: id.to_string(),
146 value,
147 });
148 }
149
150 if let Ok(mut values) = self.values.write() {
151 values.insert(id.to_string(), value);
152 }
153
154 Ok(())
155 }
156
157 fn get_all_parameters(&self) -> Vec<ParameterInfo> {
158 let parameters = match self.parameters.read() {
159 Ok(guard) => guard,
160 Err(_) => return Vec::new(), };
162
163 parameters
164 .iter()
165 .map(|param| ParameterInfo {
166 id: param.id.clone(),
167 name: param.name.clone(),
168 param_type: param.param_type,
169 value: self.current_value(¶m.id, param.default),
170 default: param.default,
171 min: param.min,
172 max: param.max,
173 unit: param.unit.clone(),
174 group: param.group.clone(),
175 })
176 .collect()
177 }
178
179 fn get_meter_frame(&self) -> Option<MeterFrame> {
180 self.meter_provider
181 .as_ref()
182 .and_then(|provider| provider.get_meter_frame())
183 }
184
185 fn request_resize(&self, _width: u32, _height: u32) -> bool {
186 false
187 }
188
189 fn get_audio_status(&self) -> Option<AudioRuntimeStatus> {
190 None
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use wavecraft_protocol::ParameterType;

    /// Meter provider that always reports the same frame.
    struct StaticMeterProvider {
        frame: MeterFrame,
    }

    impl MeterProvider for StaticMeterProvider {
        fn get_meter_frame(&self) -> Option<MeterFrame> {
            Some(self.frame)
        }
    }

    /// Builds a float parameter whose initial `value` equals its `default`.
    fn float_param(
        id: &str,
        name: &str,
        default: f32,
        (min, max): (f32, f32),
        unit: Option<&str>,
        group: Option<&str>,
    ) -> ParameterInfo {
        ParameterInfo {
            id: id.to_string(),
            name: name.to_string(),
            param_type: ParameterType::Float,
            value: default,
            default,
            min,
            max,
            unit: unit.map(String::from),
            group: group.map(String::from),
        }
    }

    fn gain_param() -> ParameterInfo {
        float_param("gain", "Gain", 0.5, (0.0, 1.0), Some("dB"), Some("Input"))
    }

    fn mix_param() -> ParameterInfo {
        float_param("mix", "Mix", 1.0, (0.0, 1.0), Some("%"), None)
    }

    fn test_params() -> Vec<ParameterInfo> {
        vec![gain_param(), mix_param()]
    }

    /// Asserts that `actual` is within `f32::EPSILON` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < f32::EPSILON,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_get_parameter() {
        let host = InMemoryParameterHost::new(test_params());

        let gain = host.get_parameter("gain").expect("should find gain");
        assert_eq!(gain.id, "gain");
        assert_eq!(gain.name, "Gain");
        assert_close(gain.value, 0.5);
    }

    #[test]
    fn test_set_parameter() {
        let host = InMemoryParameterHost::new(test_params());
        host.set_parameter("gain", 0.75).expect("should set gain");

        let gain = host.get_parameter("gain").expect("should find gain");
        assert_close(gain.value, 0.75);
    }

    #[test]
    fn test_set_parameter_out_of_range() {
        let host = InMemoryParameterHost::new(test_params());

        for rejected in [1.5, -0.1] {
            assert!(host.set_parameter("gain", rejected).is_err());
        }
    }

    #[test]
    fn test_get_all_parameters() {
        let host = InMemoryParameterHost::new(test_params());

        let ids: Vec<String> = host
            .get_all_parameters()
            .into_iter()
            .map(|p| p.id)
            .collect();
        assert_eq!(ids.len(), 2);
        assert!(ids.iter().any(|id| id == "gain"));
        assert!(ids.iter().any(|id| id == "mix"));
    }

    #[test]
    fn test_get_meter_frame() {
        let provider = Arc::new(StaticMeterProvider {
            frame: MeterFrame {
                peak_l: 0.7,
                rms_l: 0.5,
                peak_r: 0.6,
                rms_r: 0.4,
                timestamp: 0,
            },
        });
        let host = InMemoryParameterHost::with_meter_provider(test_params(), provider);

        let frame = host.get_meter_frame().expect("should have meter frame");
        assert_close(frame.peak_l, 0.7);
        assert_close(frame.rms_r, 0.4);
    }

    #[test]
    fn test_replace_parameters_preserves_values() {
        let host = InMemoryParameterHost::new(test_params());
        host.set_parameter("gain", 0.75).expect("should set gain");
        host.set_parameter("mix", 0.5).expect("should set mix");

        let frequency = float_param("freq", "Frequency", 440.0, (20.0, 5_000.0), Some("Hz"), None);
        host.replace_parameters(vec![gain_param(), mix_param(), frequency])
            .expect("should replace parameters");

        let expectations = [("gain", 0.75), ("mix", 0.5), ("freq", 440.0)];
        for (id, expected) in expectations {
            let param = host
                .get_parameter(id)
                .unwrap_or_else(|| panic!("should find {id}"));
            assert_close(param.value, expected);
        }
    }

    #[test]
    fn test_replace_parameters_removes_old() {
        let host = InMemoryParameterHost::new(test_params());

        host.replace_parameters(vec![gain_param()])
            .expect("should replace parameters");

        assert!(host.get_parameter("mix").is_none());
        assert!(host.get_parameter("gain").is_some());
    }

    #[test]
    fn test_set_parameter_uses_declared_range_not_normalized_range() {
        let host = InMemoryParameterHost::new(vec![float_param(
            "oscillator_frequency",
            "Frequency",
            440.0,
            (20.0, 5_000.0),
            Some("Hz"),
            Some("Oscillator"),
        )]);

        host.set_parameter("oscillator_frequency", 2_000.0)
            .expect("frequency in declared range should be accepted");
        let frequency = host
            .get_parameter("oscillator_frequency")
            .expect("frequency should exist");
        assert_close(frequency.value, 2_000.0);

        let too_low = host.set_parameter("oscillator_frequency", 10.0);
        assert!(too_low.is_err(), "value below min should be rejected");

        let too_high = host.set_parameter("oscillator_frequency", 10_000.0);
        assert!(too_high.is_err(), "value above max should be rejected");
    }
}