Skip to main content

wavecraft_dev_server/ws/
mod.rs

1//! WebSocket server for browser-based UI development
2//!
3//! This module provides a WebSocket server that exposes the same IPC protocol
4//! used by the native WKWebView transport, enabling real-time communication
5//! between a browser-based UI and the Rust engine during development.
6
7use futures_util::{SinkExt, StreamExt};
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::{RwLock, broadcast};
12use tokio_tungstenite::{accept_async, tungstenite::protocol::Message};
13use tracing::{debug, error, info, warn};
14use wavecraft_bridge::{IpcHandler, ParameterHost};
15use wavecraft_protocol::{
16    AudioRuntimeStatus, IpcNotification, IpcRequest, IpcResponse, METHOD_REGISTER_AUDIO,
17    NOTIFICATION_AUDIO_STATUS_CHANGED, NOTIFICATION_METER_UPDATE, NOTIFICATION_PARAMETER_CHANGED,
18    SetParameterParams,
19};
20
/// Method name of the notification that tells the UI the parameter set was
/// updated (e.g. by hot reload) and should be re-fetched.
///
/// NOTE(review): the other notification names come from `wavecraft_protocol`;
/// consider moving this one there too so UI and server share one definition.
const NOTIFICATION_PARAMETERS_CHANGED: &str = "parametersChanged";
22
23type BrowserClientTx = tokio::sync::mpsc::Sender<String>;
24
25/// Shared state for tracking connected clients
26struct ServerState {
27    /// Connected browser clients (for broadcasting meter updates)
28    browser_clients: Arc<RwLock<Vec<BrowserClientTx>>>,
29    /// Audio client ID (if connected)
30    audio_client: Arc<RwLock<Option<String>>>,
31}
32
33impl ServerState {
34    fn new() -> Self {
35        Self {
36            browser_clients: Arc::new(RwLock::new(Vec::new())),
37            audio_client: Arc::new(RwLock::new(None)),
38        }
39    }
40}
41
42/// A lightweight, cloneable handle to the WebSocket server's broadcast
43/// capability. Non-generic — can be passed across async task boundaries.
44///
45/// Constructed via [`WsServer::handle()`]. Used by the CLI to forward
46/// meter updates from the in-process audio callback to browser clients.
47#[derive(Clone)]
48pub struct WsHandle {
49    state: Arc<ServerState>,
50}
51
52impl WsHandle {
53    /// Broadcast a JSON string to all connected browser clients.
54    pub async fn broadcast(&self, json: &str) {
55        broadcast_to_browser_clients(&self.state, json, None, "broadcast message").await;
56    }
57
58    /// Broadcast an audioStatusChanged notification to connected clients.
59    pub async fn broadcast_audio_status_changed(
60        &self,
61        status: &AudioRuntimeStatus,
62    ) -> Result<(), serde_json::Error> {
63        let json = serde_json::to_string(&IpcNotification::new(
64            NOTIFICATION_AUDIO_STATUS_CHANGED,
65            status,
66        ))?;
67
68        self.broadcast(&json).await;
69        Ok(())
70    }
71}
72
73async fn broadcast_to_browser_clients(
74    state: &Arc<ServerState>,
75    json: &str,
76    exclude_client_index: Option<usize>,
77    warning_context: &str,
78) {
79    let clients = state.browser_clients.read().await;
80    for (index, client) in clients.iter().enumerate() {
81        if exclude_client_index.is_some_and(|excluded| index == excluded) {
82            continue;
83        }
84
85        if let Err(error) = client.try_send(json.to_owned()) {
86            warn!(
87                "Failed to {} (client {}): {}",
88                warning_context, index, error
89            );
90        }
91    }
92}
93
94/// WebSocket server for browser-based UI development
95pub struct WsServer<H: ParameterHost + 'static> {
96    /// Port the server listens on
97    port: u16,
98    /// Shared IPC handler
99    handler: Arc<IpcHandler<H>>,
100    /// Shutdown signal
101    shutdown_tx: broadcast::Sender<()>,
102    /// Shared server state
103    state: Arc<ServerState>,
104}
105
106fn build_set_parameter_notification(request: &IpcRequest, response: &str) -> Option<String> {
107    if request.method != wavecraft_protocol::METHOD_SET_PARAMETER {
108        return None;
109    }
110
111    let response_msg = serde_json::from_str::<IpcResponse>(response).ok()?;
112    if response_msg.error.is_some() {
113        return None;
114    }
115
116    let params = request.params.clone()?;
117    let set_params = serde_json::from_value::<SetParameterParams>(params).ok()?;
118
119    serde_json::to_string(&IpcNotification::new(
120        NOTIFICATION_PARAMETER_CHANGED,
121        serde_json::json!({
122            "id": set_params.id,
123            "value": set_params.value,
124        }),
125    ))
126    .ok()
127}
128
129impl<H: ParameterHost + 'static> WsServer<H> {
130    /// Create a new WebSocket server
131    pub fn new(port: u16, handler: Arc<IpcHandler<H>>) -> Self {
132        let (shutdown_tx, _) = broadcast::channel(1);
133        Self {
134            port,
135            handler,
136            shutdown_tx,
137            state: Arc::new(ServerState::new()),
138        }
139    }
140
141    /// Get a lightweight handle for broadcasting to connected clients.
142    ///
143    /// The returned `WsHandle` is non-generic, `Clone`, and can be moved
144    /// into async tasks (e.g., for forwarding meter updates from audio).
145    pub fn handle(&self) -> WsHandle {
146        WsHandle {
147            state: Arc::clone(&self.state),
148        }
149    }
150
151    /// Broadcast a parametersChanged notification to all connected clients.
152    ///
153    /// This is used by the hot-reload pipeline to notify the UI that
154    /// parameters have been updated and should be re-fetched.
155    pub async fn broadcast_parameters_changed(&self) -> Result<(), serde_json::Error> {
156        let notification =
157            IpcNotification::new(NOTIFICATION_PARAMETERS_CHANGED, serde_json::json!({}));
158        let json = serde_json::to_string(&notification)?;
159
160        broadcast_to_browser_clients(
161            &self.state,
162            &json,
163            None,
164            "send parametersChanged notification",
165        )
166        .await;
167
168        Ok(())
169    }
170
171    /// Start the server (spawns async tasks, returns immediately)
172    pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
173        let addr: SocketAddr = format!("127.0.0.1:{}", self.port).parse()?;
174        let listener = TcpListener::bind(&addr).await?;
175
176        info!("Server listening on ws://{}", addr);
177
178        let handler = Arc::clone(&self.handler);
179        let mut shutdown_rx = self.shutdown_tx.subscribe();
180        let state = Arc::clone(&self.state);
181
182        tokio::spawn(async move {
183            loop {
184                tokio::select! {
185                    result = listener.accept() => {
186                        match result {
187                            Ok((stream, addr)) => {
188                                info!("Client connected: {}", addr);
189                                let handler = Arc::clone(&handler);
190                                let state = Arc::clone(&state);
191                                tokio::spawn(handle_connection(handler, stream, addr, state));
192                            }
193                            Err(e) => {
194                                error!("Accept error: {}", e);
195                            }
196                        }
197                    }
198                    _ = shutdown_rx.recv() => {
199                        info!("Server shutting down");
200                        break;
201                    }
202                }
203            }
204        });
205
206        Ok(())
207    }
208
209    /// Shutdown the server gracefully.
210    ///
211    /// Note: Not currently called but kept for future graceful shutdown support.
212    #[allow(dead_code)]
213    pub fn shutdown(&self) {
214        let _ = self.shutdown_tx.send(());
215    }
216}
217
218/// Handle a single WebSocket connection
219async fn handle_connection<H: ParameterHost>(
220    handler: Arc<IpcHandler<H>>,
221    stream: TcpStream,
222    addr: SocketAddr,
223    state: Arc<ServerState>,
224) {
225    let ws_stream = match accept_async(stream).await {
226        Ok(ws) => ws,
227        Err(e) => {
228            error!("Error during handshake with {}: {}", addr, e);
229            return;
230        }
231    };
232
233    info!("WebSocket connection established: {}", addr);
234
235    let (mut write, mut read) = ws_stream.split();
236    let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(128);
237
238    // Track this client for broadcasting
239    let mut is_audio_client = false;
240    let client_index = {
241        let mut clients = state.browser_clients.write().await;
242        clients.push(tx.clone());
243        clients.len() - 1
244    };
245
246    // Spawn task to send messages from channel to WebSocket
247    let write_task = tokio::spawn(async move {
248        while let Some(msg) = rx.recv().await {
249            if let Err(e) = write.send(Message::Text(msg)).await {
250                error!("Error sending to {}: {}", addr, e);
251                break;
252            }
253        }
254    });
255
256    while let Some(msg) = read.next().await {
257        match msg {
258            Ok(Message::Text(json)) => {
259                debug!("Received from {}: {}", addr, json);
260
261                // Try to parse as IPC request for structured routing
262                let parsed_req = serde_json::from_str::<IpcRequest>(&json);
263
264                if let Ok(ref req) = parsed_req {
265                    // Handle registerAudio
266                    if req.method == METHOD_REGISTER_AUDIO {
267                        is_audio_client = true;
268                        info!("Audio client registered: {}", addr);
269
270                        // Parse to extract client_id
271                        if let Some(params) = req.params.clone()
272                            && let Ok(audio_params) = serde_json::from_value::<
273                                wavecraft_protocol::RegisterAudioParams,
274                            >(params)
275                        {
276                            *state.audio_client.write().await =
277                                Some(audio_params.client_id.clone());
278                        }
279
280                        // Send success response using the request's id
281                        let response = wavecraft_protocol::IpcResponse::success(
282                            req.id.clone(),
283                            wavecraft_protocol::RegisterAudioResult {
284                                status: "registered".to_string(),
285                            },
286                        );
287                        let response_json = match serde_json::to_string(&response) {
288                            Ok(json) => json,
289                            Err(e) => {
290                                error!("Failed to serialize registerAudio response: {}", e);
291                                break;
292                            }
293                        };
294                        if let Err(e) = tx.try_send(response_json) {
295                            error!("Error sending response: {}", e);
296                            break;
297                        }
298                        continue;
299                    }
300
301                    // Handle meterUpdate from audio client
302                    if is_audio_client && req.method == NOTIFICATION_METER_UPDATE {
303                        // Broadcast to all browser clients
304                        broadcast_to_browser_clients(
305                            &state,
306                            &json,
307                            Some(client_index),
308                            "broadcast meter update",
309                        )
310                        .await;
311                        continue;
312                    }
313                }
314
315                // Route through existing IpcHandler
316                let response = handler.handle_json(&json);
317
318                // Mirror native editor behavior in dev mode: after successful
319                // setParameter, emit parameterChanged so hooks relying on
320                // notifications stay in sync with backend-confirmed state.
321                if let Ok(req) = &parsed_req
322                    && let Some(notification_json) =
323                        build_set_parameter_notification(req, &response)
324                {
325                    broadcast_to_browser_clients(
326                        &state,
327                        &notification_json,
328                        None,
329                        "send parameterChanged notification",
330                    )
331                    .await;
332                }
333
334                // Log outgoing response
335                debug!("Sending to {}: {}", addr, response);
336
337                // Send response
338                if let Err(e) = tx.try_send(response) {
339                    error!("Error queueing response: {}", e);
340                    break;
341                }
342            }
343            Ok(Message::Close(_)) => {
344                info!("Client closed connection: {}", addr);
345                break;
346            }
347            Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {
348                // Ignore ping/pong frames (automatically handled)
349            }
350            Ok(Message::Binary(_)) => {
351                warn!("Unexpected binary message from {}", addr);
352            }
353            Ok(Message::Frame(_)) => {
354                // Raw frames shouldn't appear at this level
355            }
356            Err(e) => {
357                error!("Error receiving from {}: {}", addr, e);
358                break;
359            }
360        }
361    }
362
363    // Cleanup: remove client from broadcast list
364    state
365        .browser_clients
366        .write()
367        .await
368        .retain(|c| !c.is_closed());
369    if is_audio_client {
370        *state.audio_client.write().await = None;
371        info!("Audio client disconnected: {}", addr);
372    }
373
374    write_task.abort();
375    info!("Connection closed: {}", addr);
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use wavecraft_bridge::InMemoryParameterHost;
    use wavecraft_protocol::{IpcRequest, IpcResponse, ParameterInfo, ParameterType, RequestId};

    /// Simple test host for unit tests: a single `gain` float parameter.
    fn test_host() -> InMemoryParameterHost {
        InMemoryParameterHost::new(vec![ParameterInfo {
            id: "gain".to_string(),
            name: "Gain".to_string(),
            param_type: ParameterType::Float,
            value: 0.5,
            default: 0.5,
            min: 0.0,
            max: 1.0,
            unit: Some("dB".to_string()),
            group: Some("Input".to_string()),
            variants: None,
        }])
    }

    /// Serialized error-free response for request id 1.
    fn success_response() -> String {
        serde_json::to_string(&IpcResponse::success(
            RequestId::Number(1),
            serde_json::json!({}),
        ))
        .expect("serialize response")
    }

    #[tokio::test]
    async fn test_server_creation() {
        let host = test_host();
        let handler = Arc::new(IpcHandler::new(host));
        let server = WsServer::new(9001, handler);

        // Just verify we can create a server without panicking
        assert_eq!(server.port, 9001);
    }

    #[test]
    fn build_set_parameter_notification_from_success_response() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            wavecraft_protocol::METHOD_SET_PARAMETER,
            Some(serde_json::json!({ "id": "gain", "value": 0.8 })),
        );
        let response = serde_json::to_string(&IpcResponse::success(
            RequestId::Number(1),
            serde_json::json!({}),
        ))
        .expect("serialize response");

        let notification = build_set_parameter_notification(&request, &response)
            .expect("should create parameterChanged notification");
        let json: serde_json::Value =
            serde_json::from_str(&notification).expect("notification should parse");

        assert_eq!(
            json.get("method"),
            Some(&serde_json::json!(
                wavecraft_protocol::NOTIFICATION_PARAMETER_CHANGED
            ))
        );
        assert_eq!(json.pointer("/params/id"), Some(&serde_json::json!("gain")));
        let Some(value) = json
            .pointer("/params/value")
            .and_then(serde_json::Value::as_f64)
        else {
            panic!("notification should contain numeric params.value");
        };
        assert!(
            (value - 0.8).abs() < 1e-5,
            "expected approx 0.8, got {value}"
        );
    }

    #[test]
    fn build_set_parameter_notification_ignores_error_response() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            wavecraft_protocol::METHOD_SET_PARAMETER,
            Some(serde_json::json!({ "id": "gain", "value": 10.0 })),
        );

        let response = serde_json::to_string(&IpcResponse::error(
            RequestId::Number(1),
            wavecraft_protocol::IpcError::invalid_params("out of range"),
        ))
        .expect("serialize error response");

        assert!(build_set_parameter_notification(&request, &response).is_none());
    }

    #[test]
    fn build_set_parameter_notification_ignores_other_methods() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            "getParameter",
            Some(serde_json::json!({ "id": "gain" })),
        );

        assert!(build_set_parameter_notification(&request, &success_response()).is_none());
    }

    #[test]
    fn build_set_parameter_notification_ignores_missing_params() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            wavecraft_protocol::METHOD_SET_PARAMETER,
            None::<serde_json::Value>,
        );

        assert!(build_set_parameter_notification(&request, &success_response()).is_none());
    }

    #[test]
    fn build_set_parameter_notification_ignores_unparseable_response() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            wavecraft_protocol::METHOD_SET_PARAMETER,
            Some(serde_json::json!({ "id": "gain", "value": 0.8 })),
        );

        assert!(build_set_parameter_notification(&request, "not json").is_none());
    }
}