1use futures_util::{SinkExt, StreamExt};
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::{RwLock, broadcast};
12use tokio_tungstenite::{accept_async, tungstenite::protocol::Message};
13use tracing::{debug, error, info, warn};
14use wavecraft_bridge::{IpcHandler, ParameterHost};
15use wavecraft_protocol::{
16 AudioRuntimeStatus, IpcNotification, IpcRequest, IpcResponse, METHOD_REGISTER_AUDIO,
17 NOTIFICATION_AUDIO_STATUS_CHANGED, NOTIFICATION_METER_UPDATE, NOTIFICATION_PARAMETER_CHANGED,
18 SetParameterParams,
19};
20
/// IPC notification method broadcast (with empty params) by
/// `WsServer::broadcast_parameters_changed` when the parameter set changes.
const NOTIFICATION_PARAMETERS_CHANGED: &str = "parametersChanged";
22
23type BrowserClientTx = tokio::sync::mpsc::Sender<String>;
24
25struct ServerState {
27 browser_clients: Arc<RwLock<Vec<BrowserClientTx>>>,
29 audio_client: Arc<RwLock<Option<String>>>,
31}
32
33impl ServerState {
34 fn new() -> Self {
35 Self {
36 browser_clients: Arc::new(RwLock::new(Vec::new())),
37 audio_client: Arc::new(RwLock::new(None)),
38 }
39 }
40}
41
42#[derive(Clone)]
48pub struct WsHandle {
49 state: Arc<ServerState>,
50}
51
52impl WsHandle {
53 pub async fn broadcast(&self, json: &str) {
55 broadcast_to_browser_clients(&self.state, json, None, "broadcast message").await;
56 }
57
58 pub async fn broadcast_audio_status_changed(
60 &self,
61 status: &AudioRuntimeStatus,
62 ) -> Result<(), serde_json::Error> {
63 let json = serde_json::to_string(&IpcNotification::new(
64 NOTIFICATION_AUDIO_STATUS_CHANGED,
65 status,
66 ))?;
67
68 self.broadcast(&json).await;
69 Ok(())
70 }
71}
72
73async fn broadcast_to_browser_clients(
74 state: &Arc<ServerState>,
75 json: &str,
76 exclude_client_index: Option<usize>,
77 warning_context: &str,
78) {
79 let clients = state.browser_clients.read().await;
80 for (index, client) in clients.iter().enumerate() {
81 if exclude_client_index.is_some_and(|excluded| index == excluded) {
82 continue;
83 }
84
85 if let Err(error) = client.try_send(json.to_owned()) {
86 warn!(
87 "Failed to {} (client {}): {}",
88 warning_context, index, error
89 );
90 }
91 }
92}
93
94pub struct WsServer<H: ParameterHost + 'static> {
96 port: u16,
98 handler: Arc<IpcHandler<H>>,
100 shutdown_tx: broadcast::Sender<()>,
102 state: Arc<ServerState>,
104}
105
106fn build_set_parameter_notification(request: &IpcRequest, response: &str) -> Option<String> {
107 if request.method != wavecraft_protocol::METHOD_SET_PARAMETER {
108 return None;
109 }
110
111 let response_msg = serde_json::from_str::<IpcResponse>(response).ok()?;
112 if response_msg.error.is_some() {
113 return None;
114 }
115
116 let params = request.params.clone()?;
117 let set_params = serde_json::from_value::<SetParameterParams>(params).ok()?;
118
119 serde_json::to_string(&IpcNotification::new(
120 NOTIFICATION_PARAMETER_CHANGED,
121 serde_json::json!({
122 "id": set_params.id,
123 "value": set_params.value,
124 }),
125 ))
126 .ok()
127}
128
129impl<H: ParameterHost + 'static> WsServer<H> {
130 pub fn new(port: u16, handler: Arc<IpcHandler<H>>) -> Self {
132 let (shutdown_tx, _) = broadcast::channel(1);
133 Self {
134 port,
135 handler,
136 shutdown_tx,
137 state: Arc::new(ServerState::new()),
138 }
139 }
140
141 pub fn handle(&self) -> WsHandle {
146 WsHandle {
147 state: Arc::clone(&self.state),
148 }
149 }
150
151 pub async fn broadcast_parameters_changed(&self) -> Result<(), serde_json::Error> {
156 let notification =
157 IpcNotification::new(NOTIFICATION_PARAMETERS_CHANGED, serde_json::json!({}));
158 let json = serde_json::to_string(¬ification)?;
159
160 broadcast_to_browser_clients(
161 &self.state,
162 &json,
163 None,
164 "send parametersChanged notification",
165 )
166 .await;
167
168 Ok(())
169 }
170
171 pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
173 let addr: SocketAddr = format!("127.0.0.1:{}", self.port).parse()?;
174 let listener = TcpListener::bind(&addr).await?;
175
176 info!("Server listening on ws://{}", addr);
177
178 let handler = Arc::clone(&self.handler);
179 let mut shutdown_rx = self.shutdown_tx.subscribe();
180 let state = Arc::clone(&self.state);
181
182 tokio::spawn(async move {
183 loop {
184 tokio::select! {
185 result = listener.accept() => {
186 match result {
187 Ok((stream, addr)) => {
188 info!("Client connected: {}", addr);
189 let handler = Arc::clone(&handler);
190 let state = Arc::clone(&state);
191 tokio::spawn(handle_connection(handler, stream, addr, state));
192 }
193 Err(e) => {
194 error!("Accept error: {}", e);
195 }
196 }
197 }
198 _ = shutdown_rx.recv() => {
199 info!("Server shutting down");
200 break;
201 }
202 }
203 }
204 });
205
206 Ok(())
207 }
208
209 #[allow(dead_code)]
213 pub fn shutdown(&self) {
214 let _ = self.shutdown_tx.send(());
215 }
216}
217
218async fn handle_connection<H: ParameterHost>(
220 handler: Arc<IpcHandler<H>>,
221 stream: TcpStream,
222 addr: SocketAddr,
223 state: Arc<ServerState>,
224) {
225 let ws_stream = match accept_async(stream).await {
226 Ok(ws) => ws,
227 Err(e) => {
228 error!("Error during handshake with {}: {}", addr, e);
229 return;
230 }
231 };
232
233 info!("WebSocket connection established: {}", addr);
234
235 let (mut write, mut read) = ws_stream.split();
236 let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(128);
237
238 let mut is_audio_client = false;
240 let client_index = {
241 let mut clients = state.browser_clients.write().await;
242 clients.push(tx.clone());
243 clients.len() - 1
244 };
245
246 let write_task = tokio::spawn(async move {
248 while let Some(msg) = rx.recv().await {
249 if let Err(e) = write.send(Message::Text(msg)).await {
250 error!("Error sending to {}: {}", addr, e);
251 break;
252 }
253 }
254 });
255
256 while let Some(msg) = read.next().await {
257 match msg {
258 Ok(Message::Text(json)) => {
259 debug!("Received from {}: {}", addr, json);
260
261 let parsed_req = serde_json::from_str::<IpcRequest>(&json);
263
264 if let Ok(ref req) = parsed_req {
265 if req.method == METHOD_REGISTER_AUDIO {
267 is_audio_client = true;
268 info!("Audio client registered: {}", addr);
269
270 if let Some(params) = req.params.clone()
272 && let Ok(audio_params) = serde_json::from_value::<
273 wavecraft_protocol::RegisterAudioParams,
274 >(params)
275 {
276 *state.audio_client.write().await =
277 Some(audio_params.client_id.clone());
278 }
279
280 let response = wavecraft_protocol::IpcResponse::success(
282 req.id.clone(),
283 wavecraft_protocol::RegisterAudioResult {
284 status: "registered".to_string(),
285 },
286 );
287 let response_json = match serde_json::to_string(&response) {
288 Ok(json) => json,
289 Err(e) => {
290 error!("Failed to serialize registerAudio response: {}", e);
291 break;
292 }
293 };
294 if let Err(e) = tx.try_send(response_json) {
295 error!("Error sending response: {}", e);
296 break;
297 }
298 continue;
299 }
300
301 if is_audio_client && req.method == NOTIFICATION_METER_UPDATE {
303 broadcast_to_browser_clients(
305 &state,
306 &json,
307 Some(client_index),
308 "broadcast meter update",
309 )
310 .await;
311 continue;
312 }
313 }
314
315 let response = handler.handle_json(&json);
317
318 if let Ok(req) = &parsed_req
322 && let Some(notification_json) =
323 build_set_parameter_notification(req, &response)
324 {
325 broadcast_to_browser_clients(
326 &state,
327 ¬ification_json,
328 None,
329 "send parameterChanged notification",
330 )
331 .await;
332 }
333
334 debug!("Sending to {}: {}", addr, response);
336
337 if let Err(e) = tx.try_send(response) {
339 error!("Error queueing response: {}", e);
340 break;
341 }
342 }
343 Ok(Message::Close(_)) => {
344 info!("Client closed connection: {}", addr);
345 break;
346 }
347 Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {
348 }
350 Ok(Message::Binary(_)) => {
351 warn!("Unexpected binary message from {}", addr);
352 }
353 Ok(Message::Frame(_)) => {
354 }
356 Err(e) => {
357 error!("Error receiving from {}: {}", addr, e);
358 break;
359 }
360 }
361 }
362
363 state
365 .browser_clients
366 .write()
367 .await
368 .retain(|c| !c.is_closed());
369 if is_audio_client {
370 *state.audio_client.write().await = None;
371 info!("Audio client disconnected: {}", addr);
372 }
373
374 write_task.abort();
375 info!("Connection closed: {}", addr);
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use wavecraft_bridge::InMemoryParameterHost;
    use wavecraft_protocol::{IpcRequest, IpcResponse, ParameterInfo, ParameterType, RequestId};

    /// Parameter host exposing a single float `gain` parameter in `[0.0, 1.0]`.
    fn test_host() -> InMemoryParameterHost {
        InMemoryParameterHost::new(vec![ParameterInfo {
            id: "gain".to_string(),
            name: "Gain".to_string(),
            param_type: ParameterType::Float,
            value: 0.5,
            default: 0.5,
            min: 0.0,
            max: 1.0,
            unit: Some("dB".to_string()),
            group: Some("Input".to_string()),
            variants: None,
        }])
    }

    // Construction is synchronous and needs no runtime.
    #[test]
    fn test_server_creation() {
        let host = test_host();
        let handler = Arc::new(IpcHandler::new(host));
        let server = WsServer::new(9001, handler);

        assert_eq!(server.port, 9001);
    }

    #[tokio::test]
    async fn broadcast_to_browser_clients_skips_excluded_index() {
        let state = Arc::new(ServerState::new());
        let (first_tx, mut first_rx) = tokio::sync::mpsc::channel::<String>(4);
        let (second_tx, mut second_rx) = tokio::sync::mpsc::channel::<String>(4);
        state
            .browser_clients
            .write()
            .await
            .extend([first_tx, second_tx]);

        broadcast_to_browser_clients(&state, "hello", Some(1), "test broadcast").await;

        assert_eq!(first_rx.try_recv().ok().as_deref(), Some("hello"));
        assert!(second_rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn ws_handle_broadcast_reaches_registered_client() {
        let server = WsServer::new(9002, Arc::new(IpcHandler::new(test_host())));
        let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(4);
        server.state.browser_clients.write().await.push(tx);

        server.handle().broadcast("{\"ping\":true}").await;

        assert_eq!(rx.try_recv().ok().as_deref(), Some("{\"ping\":true}"));
    }

    #[test]
    fn build_set_parameter_notification_from_success_response() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            wavecraft_protocol::METHOD_SET_PARAMETER,
            Some(serde_json::json!({ "id": "gain", "value": 0.8 })),
        );
        let response = serde_json::to_string(&IpcResponse::success(
            RequestId::Number(1),
            serde_json::json!({}),
        ))
        .expect("serialize response");

        let notification = build_set_parameter_notification(&request, &response)
            .expect("should create parameterChanged notification");
        let json: serde_json::Value =
            serde_json::from_str(&notification).expect("notification should parse");

        assert_eq!(
            json.get("method"),
            Some(&serde_json::json!(
                wavecraft_protocol::NOTIFICATION_PARAMETER_CHANGED
            ))
        );
        assert_eq!(json.pointer("/params/id"), Some(&serde_json::json!("gain")));
        let Some(value) = json
            .pointer("/params/value")
            .and_then(serde_json::Value::as_f64)
        else {
            panic!("notification should contain numeric params.value");
        };
        assert!(
            (value - 0.8).abs() < 1e-5,
            "expected approx 0.8, got {value}"
        );
    }

    #[test]
    fn build_set_parameter_notification_ignores_error_response() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            wavecraft_protocol::METHOD_SET_PARAMETER,
            Some(serde_json::json!({ "id": "gain", "value": 10.0 })),
        );

        let response = serde_json::to_string(&IpcResponse::error(
            RequestId::Number(1),
            wavecraft_protocol::IpcError::invalid_params("out of range"),
        ))
        .expect("serialize error response");

        assert!(build_set_parameter_notification(&request, &response).is_none());
    }
}