1use std::collections::HashMap;
4use std::sync::{Arc, RwLock};
5
6use crate::{BridgeError, ParameterHost};
7use wavecraft_protocol::{MeterFrame, ParameterInfo};
8
9pub trait MeterProvider: Send + Sync {
11 fn get_meter_frame(&self) -> Option<MeterFrame>;
13}
14
15pub struct InMemoryParameterHost {
19 parameters: RwLock<Vec<ParameterInfo>>,
20 values: RwLock<HashMap<String, f32>>,
21 meter_provider: Option<Arc<dyn MeterProvider>>,
22}
23
24impl InMemoryParameterHost {
25 pub fn new(parameters: Vec<ParameterInfo>) -> Self {
27 let values = parameters
28 .iter()
29 .map(|p| (p.id.clone(), p.default))
30 .collect();
31
32 Self {
33 parameters: RwLock::new(parameters),
34 values: RwLock::new(values),
35 meter_provider: None,
36 }
37 }
38
39 pub fn with_meter_provider(
41 parameters: Vec<ParameterInfo>,
42 meter_provider: Arc<dyn MeterProvider>,
43 ) -> Self {
44 let mut host = Self::new(parameters);
45 host.meter_provider = Some(meter_provider);
46 host
47 }
48
49 pub fn replace_parameters(&self, new_params: Vec<ParameterInfo>) -> Result<(), String> {
66 let mut values = match self.values.write() {
68 Ok(guard) => guard,
69 Err(poisoned) => {
70 eprintln!("⚠ Recovering from poisoned values lock");
71 poisoned.into_inner()
72 }
73 };
74
75 let mut new_values = HashMap::new();
77 for param in &new_params {
78 let value = values.get(¶m.id).copied().unwrap_or(param.default);
79 new_values.insert(param.id.clone(), value);
80 }
81
82 *values = new_values;
83 drop(values); let mut params = match self.parameters.write() {
87 Ok(guard) => guard,
88 Err(poisoned) => {
89 eprintln!("⚠ Recovering from poisoned parameters lock");
90 poisoned.into_inner()
91 }
92 };
93
94 *params = new_params;
95 Ok(())
96 }
97
98 fn current_value(&self, id: &str, default: f32) -> f32 {
99 self.values
100 .read()
101 .ok()
102 .and_then(|values| values.get(id).copied())
103 .unwrap_or(default)
104 }
105}
106
107impl ParameterHost for InMemoryParameterHost {
108 fn get_parameter(&self, id: &str) -> Option<ParameterInfo> {
109 let parameters = self.parameters.read().ok()?;
110 let param = parameters.iter().find(|p| p.id == id)?;
111
112 Some(ParameterInfo {
113 id: param.id.clone(),
114 name: param.name.clone(),
115 param_type: param.param_type,
116 value: self.current_value(¶m.id, param.default),
117 default: param.default,
118 unit: param.unit.clone(),
119 group: param.group.clone(),
120 })
121 }
122
123 fn set_parameter(&self, id: &str, value: f32) -> Result<(), BridgeError> {
124 let parameters = self.parameters.read().ok();
125 let param_exists = parameters
126 .as_ref()
127 .map(|p| p.iter().any(|param| param.id == id))
128 .unwrap_or(false);
129
130 if !param_exists {
131 return Err(BridgeError::ParameterNotFound(id.to_string()));
132 }
133
134 if !(0.0..=1.0).contains(&value) {
135 return Err(BridgeError::ParameterOutOfRange {
136 id: id.to_string(),
137 value,
138 });
139 }
140
141 if let Ok(mut values) = self.values.write() {
142 values.insert(id.to_string(), value);
143 }
144
145 Ok(())
146 }
147
148 fn get_all_parameters(&self) -> Vec<ParameterInfo> {
149 let parameters = match self.parameters.read() {
150 Ok(guard) => guard,
151 Err(_) => return Vec::new(), };
153
154 parameters
155 .iter()
156 .map(|param| ParameterInfo {
157 id: param.id.clone(),
158 name: param.name.clone(),
159 param_type: param.param_type,
160 value: self.current_value(¶m.id, param.default),
161 default: param.default,
162 unit: param.unit.clone(),
163 group: param.group.clone(),
164 })
165 .collect()
166 }
167
168 fn get_meter_frame(&self) -> Option<MeterFrame> {
169 self.meter_provider
170 .as_ref()
171 .and_then(|provider| provider.get_meter_frame())
172 }
173
174 fn request_resize(&self, _width: u32, _height: u32) -> bool {
175 false
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use wavecraft_protocol::ParameterType;

    struct StaticMeterProvider {
        frame: MeterFrame,
    }

    impl MeterProvider for StaticMeterProvider {
        fn get_meter_frame(&self) -> Option<MeterFrame> {
            Some(self.frame)
        }
    }

    /// Builds a float parameter whose current value equals its default.
    fn float_param(
        id: &str,
        name: &str,
        default: f32,
        unit: Option<&str>,
        group: Option<&str>,
    ) -> ParameterInfo {
        ParameterInfo {
            id: id.to_string(),
            name: name.to_string(),
            param_type: ParameterType::Float,
            value: default,
            default,
            unit: unit.map(str::to_string),
            group: group.map(str::to_string),
        }
    }

    fn test_params() -> Vec<ParameterInfo> {
        vec![
            float_param("gain", "Gain", 0.5, Some("dB"), Some("Input")),
            float_param("mix", "Mix", 1.0, Some("%"), None),
        ]
    }

    #[test]
    fn test_get_parameter() {
        let host = InMemoryParameterHost::new(test_params());

        let param = host.get_parameter("gain").expect("should find gain");
        assert_eq!(param.id, "gain");
        assert_eq!(param.name, "Gain");
        assert!((param.value - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_get_parameter_unknown_id() {
        let host = InMemoryParameterHost::new(test_params());

        assert!(host.get_parameter("does-not-exist").is_none());
    }

    #[test]
    fn test_set_parameter() {
        let host = InMemoryParameterHost::new(test_params());

        host.set_parameter("gain", 0.75).expect("should set gain");

        let param = host.get_parameter("gain").expect("should find gain");
        assert!((param.value - 0.75).abs() < f32::EPSILON);
    }

    #[test]
    fn test_set_parameter_unknown_id() {
        let host = InMemoryParameterHost::new(test_params());

        let result = host.set_parameter("does-not-exist", 0.5);
        assert!(matches!(
            result,
            Err(BridgeError::ParameterNotFound(id)) if id == "does-not-exist"
        ));
    }

    #[test]
    fn test_set_parameter_out_of_range() {
        let host = InMemoryParameterHost::new(test_params());

        assert!(host.set_parameter("gain", 1.5).is_err());
        assert!(host.set_parameter("gain", -0.1).is_err());
        assert!(host.set_parameter("gain", f32::NAN).is_err());

        // Rejected writes must leave the stored value untouched.
        let param = host.get_parameter("gain").expect("should find gain");
        assert!((param.value - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_get_all_parameters() {
        let host = InMemoryParameterHost::new(test_params());

        let params = host.get_all_parameters();
        assert_eq!(params.len(), 2);
        assert!(params.iter().any(|p| p.id == "gain"));
        assert!(params.iter().any(|p| p.id == "mix"));
    }

    #[test]
    fn test_get_meter_frame() {
        let frame = MeterFrame {
            peak_l: 0.7,
            rms_l: 0.5,
            peak_r: 0.6,
            rms_r: 0.4,
            timestamp: 0,
        };
        let provider = Arc::new(StaticMeterProvider { frame });
        let host = InMemoryParameterHost::with_meter_provider(test_params(), provider);

        let read = host.get_meter_frame().expect("should have meter frame");
        assert!((read.peak_l - 0.7).abs() < f32::EPSILON);
        assert!((read.rms_r - 0.4).abs() < f32::EPSILON);
    }

    #[test]
    fn test_get_meter_frame_without_provider() {
        let host = InMemoryParameterHost::new(test_params());

        assert!(host.get_meter_frame().is_none());
    }

    #[test]
    fn test_request_resize_is_unsupported() {
        let host = InMemoryParameterHost::new(test_params());

        assert!(!host.request_resize(800, 600));
    }

    #[test]
    fn test_replace_parameters_preserves_values() {
        let host = InMemoryParameterHost::new(test_params());

        host.set_parameter("gain", 0.75).expect("should set gain");
        host.set_parameter("mix", 0.5).expect("should set mix");

        let new_params = vec![
            float_param("gain", "Gain", 0.5, Some("dB"), Some("Input")),
            float_param("mix", "Mix", 1.0, Some("%"), None),
            float_param("freq", "Frequency", 440.0, Some("Hz"), None),
        ];

        host.replace_parameters(new_params)
            .expect("should replace parameters");

        // Existing parameters keep the values set before the replacement.
        let gain = host.get_parameter("gain").expect("should find gain");
        assert!((gain.value - 0.75).abs() < f32::EPSILON);

        let mix = host.get_parameter("mix").expect("should find mix");
        assert!((mix.value - 0.5).abs() < f32::EPSILON);

        // Newly added parameters start at their default.
        let freq = host.get_parameter("freq").expect("should find freq");
        assert!((freq.value - 440.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_replace_parameters_removes_old() {
        let host = InMemoryParameterHost::new(test_params());

        let new_params = vec![float_param("gain", "Gain", 0.5, Some("dB"), Some("Input"))];

        host.replace_parameters(new_params)
            .expect("should replace parameters");

        // Removed parameters can neither be read nor written.
        assert!(host.get_parameter("mix").is_none());
        assert!(host.set_parameter("mix", 0.5).is_err());

        assert!(host.get_parameter("gain").is_some());
        assert_eq!(host.get_all_parameters().len(), 1);
    }
}