1use futures_util::{SinkExt, StreamExt};
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::{RwLock, broadcast};
12use tokio_tungstenite::{accept_async, tungstenite::protocol::Message};
13use tracing::{debug, error, info, warn};
14use wavecraft_bridge::{IpcHandler, ParameterHost};
15use wavecraft_protocol::{
16 AudioRuntimeStatus, IpcNotification, IpcResponse, NOTIFICATION_AUDIO_STATUS_CHANGED,
17};
18
19struct ServerState {
21 browser_clients: Arc<RwLock<Vec<tokio::sync::mpsc::Sender<String>>>>,
23 audio_client: Arc<RwLock<Option<String>>>,
25}
26
27impl ServerState {
28 fn new() -> Self {
29 Self {
30 browser_clients: Arc::new(RwLock::new(Vec::new())),
31 audio_client: Arc::new(RwLock::new(None)),
32 }
33 }
34}
35
36#[derive(Clone)]
42pub struct WsHandle {
43 state: Arc<ServerState>,
44}
45
46impl WsHandle {
47 pub async fn broadcast(&self, json: &str) {
49 broadcast_to_browser_clients(&self.state, json, None, "broadcast message").await;
50 }
51
52 pub async fn broadcast_audio_status_changed(
54 &self,
55 status: &AudioRuntimeStatus,
56 ) -> Result<(), serde_json::Error> {
57 let json = serde_json::to_string(&IpcNotification::new(
58 NOTIFICATION_AUDIO_STATUS_CHANGED,
59 status,
60 ))?;
61
62 self.broadcast(&json).await;
63 Ok(())
64 }
65}
66
67async fn broadcast_to_browser_clients(
68 state: &Arc<ServerState>,
69 json: &str,
70 exclude_client_index: Option<usize>,
71 warning_context: &str,
72) {
73 let clients = state.browser_clients.read().await;
74 for (index, client) in clients.iter().enumerate() {
75 if exclude_client_index.is_some_and(|excluded| index == excluded) {
76 continue;
77 }
78
79 if let Err(error) = client.try_send(json.to_owned()) {
80 warn!(
81 "Failed to {} (client {}): {}",
82 warning_context, index, error
83 );
84 }
85 }
86}
87
88pub struct WsServer<H: ParameterHost + 'static> {
90 port: u16,
92 handler: Arc<IpcHandler<H>>,
94 shutdown_tx: broadcast::Sender<()>,
96 state: Arc<ServerState>,
98}
99
100fn build_set_parameter_notification(
101 request: &wavecraft_protocol::IpcRequest,
102 response: &str,
103) -> Option<String> {
104 if request.method != wavecraft_protocol::METHOD_SET_PARAMETER {
105 return None;
106 }
107
108 let response_msg = serde_json::from_str::<IpcResponse>(response).ok()?;
109 if response_msg.error.is_some() {
110 return None;
111 }
112
113 let params = request.params.clone()?;
114 let set_params =
115 serde_json::from_value::<wavecraft_protocol::SetParameterParams>(params).ok()?;
116
117 serde_json::to_string(&IpcNotification::new(
118 wavecraft_protocol::NOTIFICATION_PARAMETER_CHANGED,
119 serde_json::json!({
120 "id": set_params.id,
121 "value": set_params.value,
122 }),
123 ))
124 .ok()
125}
126
127impl<H: ParameterHost + 'static> WsServer<H> {
128 pub fn new(port: u16, handler: Arc<IpcHandler<H>>) -> Self {
130 let (shutdown_tx, _) = broadcast::channel(1);
131 Self {
132 port,
133 handler,
134 shutdown_tx,
135 state: Arc::new(ServerState::new()),
136 }
137 }
138
139 pub fn handle(&self) -> WsHandle {
144 WsHandle {
145 state: Arc::clone(&self.state),
146 }
147 }
148
149 pub async fn broadcast_parameters_changed(&self) -> Result<(), serde_json::Error> {
154 use wavecraft_protocol::IpcNotification;
155
156 let notification = IpcNotification::new("parametersChanged", serde_json::json!({}));
157 let json = serde_json::to_string(¬ification)?;
158
159 broadcast_to_browser_clients(
160 &self.state,
161 &json,
162 None,
163 "send parametersChanged notification",
164 )
165 .await;
166
167 Ok(())
168 }
169
170 pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
172 let addr: SocketAddr = format!("127.0.0.1:{}", self.port).parse()?;
173 let listener = TcpListener::bind(&addr).await?;
174
175 info!("Server listening on ws://{}", addr);
176
177 let handler = Arc::clone(&self.handler);
178 let mut shutdown_rx = self.shutdown_tx.subscribe();
179 let state = Arc::clone(&self.state);
180
181 tokio::spawn(async move {
182 loop {
183 tokio::select! {
184 result = listener.accept() => {
185 match result {
186 Ok((stream, addr)) => {
187 info!("Client connected: {}", addr);
188 let handler = Arc::clone(&handler);
189 let state = Arc::clone(&state);
190 tokio::spawn(handle_connection(handler, stream, addr, state));
191 }
192 Err(e) => {
193 error!("Accept error: {}", e);
194 }
195 }
196 }
197 _ = shutdown_rx.recv() => {
198 info!("Server shutting down");
199 break;
200 }
201 }
202 }
203 });
204
205 Ok(())
206 }
207
208 #[allow(dead_code)]
212 pub fn shutdown(&self) {
213 let _ = self.shutdown_tx.send(());
214 }
215}
216
217async fn handle_connection<H: ParameterHost>(
219 handler: Arc<IpcHandler<H>>,
220 stream: TcpStream,
221 addr: SocketAddr,
222 state: Arc<ServerState>,
223) {
224 let ws_stream = match accept_async(stream).await {
225 Ok(ws) => ws,
226 Err(e) => {
227 error!("Error during handshake with {}: {}", addr, e);
228 return;
229 }
230 };
231
232 info!("WebSocket connection established: {}", addr);
233
234 let (mut write, mut read) = ws_stream.split();
235 let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(128);
236
237 let mut is_audio_client = false;
239 let client_index = {
240 let mut clients = state.browser_clients.write().await;
241 clients.push(tx.clone());
242 clients.len() - 1
243 };
244
245 let write_task = tokio::spawn(async move {
247 while let Some(msg) = rx.recv().await {
248 if let Err(e) = write.send(Message::Text(msg)).await {
249 error!("Error sending to {}: {}", addr, e);
250 break;
251 }
252 }
253 });
254
255 while let Some(msg) = read.next().await {
256 match msg {
257 Ok(Message::Text(json)) => {
258 debug!("Received from {}: {}", addr, json);
259
260 let parsed_req = serde_json::from_str::<wavecraft_protocol::IpcRequest>(&json);
262
263 if let Ok(ref req) = parsed_req {
264 if req.method == "registerAudio" {
266 is_audio_client = true;
267 info!("Audio client registered: {}", addr);
268
269 if let Some(params) = req.params.clone()
271 && let Ok(audio_params) = serde_json::from_value::<
272 wavecraft_protocol::RegisterAudioParams,
273 >(params)
274 {
275 *state.audio_client.write().await =
276 Some(audio_params.client_id.clone());
277 }
278
279 let response = wavecraft_protocol::IpcResponse::success(
281 req.id.clone(),
282 wavecraft_protocol::RegisterAudioResult {
283 status: "registered".to_string(),
284 },
285 );
286 let response_json = match serde_json::to_string(&response) {
287 Ok(json) => json,
288 Err(e) => {
289 error!("Failed to serialize registerAudio response: {}", e);
290 break;
291 }
292 };
293 if let Err(e) = tx.try_send(response_json) {
294 error!("Error sending response: {}", e);
295 break;
296 }
297 continue;
298 }
299
300 if is_audio_client && req.method == "meterUpdate" {
302 broadcast_to_browser_clients(
304 &state,
305 &json,
306 Some(client_index),
307 "broadcast meter update",
308 )
309 .await;
310 continue;
311 }
312 }
313
314 let response = handler.handle_json(&json);
316
317 if let Ok(req) = &parsed_req
321 && let Some(notification_json) =
322 build_set_parameter_notification(req, &response)
323 {
324 broadcast_to_browser_clients(
325 &state,
326 ¬ification_json,
327 None,
328 "send parameterChanged notification",
329 )
330 .await;
331 }
332
333 debug!("Sending to {}: {}", addr, response);
335
336 if let Err(e) = tx.try_send(response) {
338 error!("Error queueing response: {}", e);
339 break;
340 }
341 }
342 Ok(Message::Close(_)) => {
343 info!("Client closed connection: {}", addr);
344 break;
345 }
346 Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {
347 }
349 Ok(Message::Binary(_)) => {
350 warn!("Unexpected binary message from {}", addr);
351 }
352 Ok(Message::Frame(_)) => {
353 }
355 Err(e) => {
356 error!("Error receiving from {}: {}", addr, e);
357 break;
358 }
359 }
360 }
361
362 state
364 .browser_clients
365 .write()
366 .await
367 .retain(|c| !c.is_closed());
368 if is_audio_client {
369 *state.audio_client.write().await = None;
370 info!("Audio client disconnected: {}", addr);
371 }
372
373 write_task.abort();
374 info!("Connection closed: {}", addr);
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use wavecraft_bridge::InMemoryParameterHost;
    use wavecraft_protocol::{IpcRequest, IpcResponse, ParameterInfo, ParameterType, RequestId};

    /// Host exposing a single float `gain` parameter (0.0..=1.0, default 0.5).
    fn test_host() -> InMemoryParameterHost {
        InMemoryParameterHost::new(vec![ParameterInfo {
            id: "gain".to_string(),
            name: "Gain".to_string(),
            param_type: ParameterType::Float,
            value: 0.5,
            default: 0.5,
            min: 0.0,
            max: 1.0,
            unit: Some("dB".to_string()),
            group: Some("Input".to_string()),
            variants: None,
        }])
    }

    #[tokio::test]
    async fn test_server_creation() {
        let host = test_host();
        let handler = Arc::new(IpcHandler::new(host));
        let server = WsServer::new(9001, handler);

        assert_eq!(server.port, 9001);
    }

    #[test]
    fn build_set_parameter_notification_from_success_response() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            wavecraft_protocol::METHOD_SET_PARAMETER,
            Some(serde_json::json!({ "id": "gain", "value": 0.8 })),
        );
        let response = serde_json::to_string(&IpcResponse::success(
            RequestId::Number(1),
            serde_json::json!({}),
        ))
        .expect("serialize response");

        let notification = build_set_parameter_notification(&request, &response)
            .expect("should create parameterChanged notification");
        let json: serde_json::Value =
            serde_json::from_str(&notification).expect("notification should parse");

        assert_eq!(
            json.get("method"),
            Some(&serde_json::json!(
                wavecraft_protocol::NOTIFICATION_PARAMETER_CHANGED
            ))
        );
        assert_eq!(json.pointer("/params/id"), Some(&serde_json::json!("gain")));
        let Some(value) = json
            .pointer("/params/value")
            .and_then(serde_json::Value::as_f64)
        else {
            panic!("notification should contain numeric params.value");
        };
        assert!(
            (value - 0.8).abs() < 1e-5,
            "expected approx 0.8, got {value}"
        );
    }

    #[test]
    fn build_set_parameter_notification_ignores_error_response() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            wavecraft_protocol::METHOD_SET_PARAMETER,
            Some(serde_json::json!({ "id": "gain", "value": 10.0 })),
        );

        let response = serde_json::to_string(&IpcResponse::error(
            RequestId::Number(1),
            wavecraft_protocol::IpcError::invalid_params("out of range"),
        ))
        .expect("serialize error response");

        assert!(build_set_parameter_notification(&request, &response).is_none());
    }

    #[test]
    fn build_set_parameter_notification_ignores_other_methods() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            "getParameter",
            Some(serde_json::json!({ "id": "gain", "value": 0.8 })),
        );
        let response = serde_json::to_string(&IpcResponse::success(
            RequestId::Number(1),
            serde_json::json!({}),
        ))
        .expect("serialize response");

        assert!(build_set_parameter_notification(&request, &response).is_none());
    }

    #[test]
    fn build_set_parameter_notification_requires_params() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            wavecraft_protocol::METHOD_SET_PARAMETER,
            None,
        );
        let response = serde_json::to_string(&IpcResponse::success(
            RequestId::Number(1),
            serde_json::json!({}),
        ))
        .expect("serialize response");

        assert!(build_set_parameter_notification(&request, &response).is_none());
    }
}