1use futures_util::{SinkExt, StreamExt};
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::{RwLock, broadcast};
12use tokio_tungstenite::{accept_async, tungstenite::protocol::Message};
13use tracing::{debug, error, info, warn};
14use wavecraft_bridge::{IpcHandler, ParameterHost};
15use wavecraft_protocol::{
16 AudioRuntimeStatus, IpcNotification, IpcResponse, NOTIFICATION_AUDIO_STATUS_CHANGED,
17};
18
19struct ServerState {
21 browser_clients: Arc<RwLock<Vec<tokio::sync::mpsc::Sender<String>>>>,
23 audio_client: Arc<RwLock<Option<String>>>,
25}
26
27impl ServerState {
28 fn new() -> Self {
29 Self {
30 browser_clients: Arc::new(RwLock::new(Vec::new())),
31 audio_client: Arc::new(RwLock::new(None)),
32 }
33 }
34}
35
36#[derive(Clone)]
42pub struct WsHandle {
43 state: Arc<ServerState>,
44}
45
46impl WsHandle {
47 pub async fn broadcast(&self, json: &str) {
49 broadcast_to_browser_clients(&self.state, json, None, "broadcast message").await;
50 }
51
52 pub async fn broadcast_audio_status_changed(
54 &self,
55 status: &AudioRuntimeStatus,
56 ) -> Result<(), serde_json::Error> {
57 let json = serde_json::to_string(&IpcNotification::new(
58 NOTIFICATION_AUDIO_STATUS_CHANGED,
59 status,
60 ))?;
61
62 self.broadcast(&json).await;
63 Ok(())
64 }
65}
66
67async fn broadcast_to_browser_clients(
68 state: &Arc<ServerState>,
69 json: &str,
70 exclude_client_index: Option<usize>,
71 warning_context: &str,
72) {
73 let clients = state.browser_clients.read().await;
74 for (index, client) in clients.iter().enumerate() {
75 if exclude_client_index.is_some_and(|excluded| index == excluded) {
76 continue;
77 }
78
79 if let Err(error) = client.try_send(json.to_owned()) {
80 warn!(
81 "Failed to {} (client {}): {}",
82 warning_context, index, error
83 );
84 }
85 }
86}
87
88pub struct WsServer<H: ParameterHost + 'static> {
90 port: u16,
92 handler: Arc<IpcHandler<H>>,
94 shutdown_tx: broadcast::Sender<()>,
96 verbose: bool,
98 state: Arc<ServerState>,
100}
101
102fn build_set_parameter_notification(
103 request: &wavecraft_protocol::IpcRequest,
104 response: &str,
105) -> Option<String> {
106 if request.method != wavecraft_protocol::METHOD_SET_PARAMETER {
107 return None;
108 }
109
110 let response_msg = serde_json::from_str::<IpcResponse>(response).ok()?;
111 if response_msg.error.is_some() {
112 return None;
113 }
114
115 let params = request.params.clone()?;
116 let set_params =
117 serde_json::from_value::<wavecraft_protocol::SetParameterParams>(params).ok()?;
118
119 serde_json::to_string(&IpcNotification::new(
120 wavecraft_protocol::NOTIFICATION_PARAMETER_CHANGED,
121 serde_json::json!({
122 "id": set_params.id,
123 "value": set_params.value,
124 }),
125 ))
126 .ok()
127}
128
129impl<H: ParameterHost + 'static> WsServer<H> {
130 pub fn new(port: u16, handler: Arc<IpcHandler<H>>, verbose: bool) -> Self {
132 let (shutdown_tx, _) = broadcast::channel(1);
133 Self {
134 port,
135 handler,
136 shutdown_tx,
137 verbose,
138 state: Arc::new(ServerState::new()),
139 }
140 }
141
142 pub fn handle(&self) -> WsHandle {
147 WsHandle {
148 state: Arc::clone(&self.state),
149 }
150 }
151
152 pub async fn broadcast_parameters_changed(&self) -> Result<(), serde_json::Error> {
157 use wavecraft_protocol::IpcNotification;
158
159 let notification = IpcNotification::new("parametersChanged", serde_json::json!({}));
160 let json = serde_json::to_string(¬ification)?;
161
162 broadcast_to_browser_clients(
163 &self.state,
164 &json,
165 None,
166 "send parametersChanged notification",
167 )
168 .await;
169
170 Ok(())
171 }
172
173 pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
175 let addr: SocketAddr = format!("127.0.0.1:{}", self.port).parse()?;
176 let listener = TcpListener::bind(&addr).await?;
177
178 info!("Server listening on ws://{}", addr);
179
180 let handler = Arc::clone(&self.handler);
181 let mut shutdown_rx = self.shutdown_tx.subscribe();
182 let verbose = self.verbose;
183 let state = Arc::clone(&self.state);
184
185 tokio::spawn(async move {
186 loop {
187 tokio::select! {
188 result = listener.accept() => {
189 match result {
190 Ok((stream, addr)) => {
191 info!("Client connected: {}", addr);
192 let handler = Arc::clone(&handler);
193 let state = Arc::clone(&state);
194 tokio::spawn(handle_connection(handler, stream, addr, verbose, state));
195 }
196 Err(e) => {
197 error!("Accept error: {}", e);
198 }
199 }
200 }
201 _ = shutdown_rx.recv() => {
202 info!("Server shutting down");
203 break;
204 }
205 }
206 }
207 });
208
209 Ok(())
210 }
211
212 #[allow(dead_code)]
216 pub fn shutdown(&self) {
217 let _ = self.shutdown_tx.send(());
218 }
219}
220
221async fn handle_connection<H: ParameterHost>(
223 handler: Arc<IpcHandler<H>>,
224 stream: TcpStream,
225 addr: SocketAddr,
226 verbose: bool,
227 state: Arc<ServerState>,
228) {
229 let ws_stream = match accept_async(stream).await {
230 Ok(ws) => ws,
231 Err(e) => {
232 error!("Error during handshake with {}: {}", addr, e);
233 return;
234 }
235 };
236
237 info!("WebSocket connection established: {}", addr);
238
239 let (mut write, mut read) = ws_stream.split();
240 let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(128);
241
242 let mut is_audio_client = false;
244 let client_index = {
245 let mut clients = state.browser_clients.write().await;
246 clients.push(tx.clone());
247 clients.len() - 1
248 };
249
250 let write_task = tokio::spawn(async move {
252 while let Some(msg) = rx.recv().await {
253 if let Err(e) = write.send(Message::Text(msg)).await {
254 error!("Error sending to {}: {}", addr, e);
255 break;
256 }
257 }
258 });
259
260 while let Some(msg) = read.next().await {
261 match msg {
262 Ok(Message::Text(json)) => {
263 if verbose {
265 debug!("Received from {}: {}", addr, json);
266 }
267
268 let parsed_req = serde_json::from_str::<wavecraft_protocol::IpcRequest>(&json);
270
271 if let Ok(ref req) = parsed_req {
272 if req.method == "registerAudio" {
274 is_audio_client = true;
275 info!("Audio client registered: {}", addr);
276
277 if let Some(params) = req.params.clone()
279 && let Ok(audio_params) = serde_json::from_value::<
280 wavecraft_protocol::RegisterAudioParams,
281 >(params)
282 {
283 *state.audio_client.write().await =
284 Some(audio_params.client_id.clone());
285 }
286
287 let response = wavecraft_protocol::IpcResponse::success(
289 req.id.clone(),
290 wavecraft_protocol::RegisterAudioResult {
291 status: "registered".to_string(),
292 },
293 );
294 let response_json = match serde_json::to_string(&response) {
295 Ok(json) => json,
296 Err(e) => {
297 error!("Failed to serialize registerAudio response: {}", e);
298 break;
299 }
300 };
301 if let Err(e) = tx.try_send(response_json) {
302 error!("Error sending response: {}", e);
303 break;
304 }
305 continue;
306 }
307
308 if is_audio_client && req.method == "meterUpdate" {
310 broadcast_to_browser_clients(
312 &state,
313 &json,
314 Some(client_index),
315 "broadcast meter update",
316 )
317 .await;
318 continue;
319 }
320 }
321
322 let response = handler.handle_json(&json);
324
325 if let Ok(req) = &parsed_req
329 && let Some(notification_json) =
330 build_set_parameter_notification(req, &response)
331 {
332 broadcast_to_browser_clients(
333 &state,
334 ¬ification_json,
335 None,
336 "send parameterChanged notification",
337 )
338 .await;
339 }
340
341 if verbose {
343 debug!("Sending to {}: {}", addr, response);
344 }
345
346 if let Err(e) = tx.try_send(response) {
348 error!("Error queueing response: {}", e);
349 break;
350 }
351 }
352 Ok(Message::Close(_)) => {
353 info!("Client closed connection: {}", addr);
354 break;
355 }
356 Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {
357 }
359 Ok(Message::Binary(_)) => {
360 warn!("Unexpected binary message from {}", addr);
361 }
362 Ok(Message::Frame(_)) => {
363 }
365 Err(e) => {
366 error!("Error receiving from {}: {}", addr, e);
367 break;
368 }
369 }
370 }
371
372 state
374 .browser_clients
375 .write()
376 .await
377 .retain(|c| !c.is_closed());
378 if is_audio_client {
379 *state.audio_client.write().await = None;
380 info!("Audio client disconnected: {}", addr);
381 }
382
383 write_task.abort();
384 info!("Connection closed: {}", addr);
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use wavecraft_bridge::InMemoryParameterHost;
    use wavecraft_protocol::{IpcRequest, IpcResponse, ParameterInfo, ParameterType, RequestId};

    /// Parameter host exposing a single float `gain` parameter in `[0, 1]`.
    fn test_host() -> InMemoryParameterHost {
        InMemoryParameterHost::new(vec![ParameterInfo {
            id: "gain".to_string(),
            name: "Gain".to_string(),
            param_type: ParameterType::Float,
            value: 0.5,
            default: 0.5,
            min: 0.0,
            max: 1.0,
            unit: Some("dB".to_string()),
            group: Some("Input".to_string()),
        }])
    }

    /// Serialized success response for request id 1.
    fn success_response() -> String {
        serde_json::to_string(&IpcResponse::success(
            RequestId::Number(1),
            serde_json::json!({}),
        ))
        .expect("serialize response")
    }

    // Plain `#[test]`: construction awaits nothing, so no async runtime is needed.
    #[test]
    fn test_server_creation() {
        let host = test_host();
        let handler = Arc::new(IpcHandler::new(host));
        let server = WsServer::new(9001, handler, false);

        assert_eq!(server.port, 9001);
    }

    #[test]
    fn build_set_parameter_notification_from_success_response() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            wavecraft_protocol::METHOD_SET_PARAMETER,
            Some(serde_json::json!({ "id": "gain", "value": 0.8 })),
        );
        let response = success_response();

        let notification = build_set_parameter_notification(&request, &response)
            .expect("should create parameterChanged notification");
        let json: serde_json::Value =
            serde_json::from_str(&notification).expect("notification should parse");

        assert_eq!(
            json.get("method"),
            Some(&serde_json::json!(
                wavecraft_protocol::NOTIFICATION_PARAMETER_CHANGED
            ))
        );
        assert_eq!(json.pointer("/params/id"), Some(&serde_json::json!("gain")));
        let Some(value) = json
            .pointer("/params/value")
            .and_then(serde_json::Value::as_f64)
        else {
            panic!("notification should contain numeric params.value");
        };
        assert!(
            (value - 0.8).abs() < 1e-5,
            "expected approx 0.8, got {value}"
        );
    }

    #[test]
    fn build_set_parameter_notification_ignores_error_response() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            wavecraft_protocol::METHOD_SET_PARAMETER,
            Some(serde_json::json!({ "id": "gain", "value": 10.0 })),
        );

        let response = serde_json::to_string(&IpcResponse::error(
            RequestId::Number(1),
            wavecraft_protocol::IpcError::invalid_params("out of range"),
        ))
        .expect("serialize error response");

        assert!(build_set_parameter_notification(&request, &response).is_none());
    }

    #[test]
    fn build_set_parameter_notification_ignores_other_methods() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            "meterUpdate",
            Some(serde_json::json!({ "id": "gain", "value": 0.8 })),
        );

        assert!(build_set_parameter_notification(&request, &success_response()).is_none());
    }

    #[test]
    fn build_set_parameter_notification_requires_params() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            wavecraft_protocol::METHOD_SET_PARAMETER,
            None,
        );

        assert!(build_set_parameter_notification(&request, &success_response()).is_none());
    }

    #[test]
    fn build_set_parameter_notification_ignores_unparseable_response() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            wavecraft_protocol::METHOD_SET_PARAMETER,
            Some(serde_json::json!({ "id": "gain", "value": 0.8 })),
        );

        assert!(build_set_parameter_notification(&request, "not json").is_none());
    }

    #[tokio::test]
    async fn handle_broadcast_reaches_every_registered_client() {
        let server = WsServer::new(9002, Arc::new(IpcHandler::new(test_host())), false);
        let (first_tx, mut first_rx) = tokio::sync::mpsc::channel(4);
        let (second_tx, mut second_rx) = tokio::sync::mpsc::channel(4);
        server
            .state
            .browser_clients
            .write()
            .await
            .extend([first_tx, second_tx]);

        server.handle().broadcast("{\"ping\":1}").await;

        assert_eq!(first_rx.recv().await.as_deref(), Some("{\"ping\":1}"));
        assert_eq!(second_rx.recv().await.as_deref(), Some("{\"ping\":1}"));
    }
}