Skip to main content

wavecraft_dev_server/ws/
mod.rs

1//! WebSocket server for browser-based UI development
2//!
3//! This module provides a WebSocket server that exposes the same IPC protocol
4//! used by the native WKWebView transport, enabling real-time communication
5//! between a browser-based UI and the Rust engine during development.
6
7use futures_util::{SinkExt, StreamExt};
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::{RwLock, broadcast};
12use tokio_tungstenite::{accept_async, tungstenite::protocol::Message};
13use tracing::{debug, error, info, warn};
14use wavecraft_bridge::{IpcHandler, ParameterHost};
15use wavecraft_protocol::{
16    AudioRuntimeStatus, IpcNotification, IpcResponse, NOTIFICATION_AUDIO_STATUS_CHANGED,
17};
18
19/// Shared state for tracking connected clients
20struct ServerState {
21    /// Connected browser clients (for broadcasting meter updates)
22    browser_clients: Arc<RwLock<Vec<tokio::sync::mpsc::Sender<String>>>>,
23    /// Audio client ID (if connected)
24    audio_client: Arc<RwLock<Option<String>>>,
25}
26
27impl ServerState {
28    fn new() -> Self {
29        Self {
30            browser_clients: Arc::new(RwLock::new(Vec::new())),
31            audio_client: Arc::new(RwLock::new(None)),
32        }
33    }
34}
35
36/// A lightweight, cloneable handle to the WebSocket server's broadcast
37/// capability. Non-generic — can be passed across async task boundaries.
38///
39/// Constructed via [`WsServer::handle()`]. Used by the CLI to forward
40/// meter updates from the in-process audio callback to browser clients.
41#[derive(Clone)]
42pub struct WsHandle {
43    state: Arc<ServerState>,
44}
45
46impl WsHandle {
47    /// Broadcast a JSON string to all connected browser clients.
48    pub async fn broadcast(&self, json: &str) {
49        broadcast_to_browser_clients(&self.state, json, None, "broadcast message").await;
50    }
51
52    /// Broadcast an audioStatusChanged notification to connected clients.
53    pub async fn broadcast_audio_status_changed(
54        &self,
55        status: &AudioRuntimeStatus,
56    ) -> Result<(), serde_json::Error> {
57        let json = serde_json::to_string(&IpcNotification::new(
58            NOTIFICATION_AUDIO_STATUS_CHANGED,
59            status,
60        ))?;
61
62        self.broadcast(&json).await;
63        Ok(())
64    }
65}
66
67async fn broadcast_to_browser_clients(
68    state: &Arc<ServerState>,
69    json: &str,
70    exclude_client_index: Option<usize>,
71    warning_context: &str,
72) {
73    let clients = state.browser_clients.read().await;
74    for (index, client) in clients.iter().enumerate() {
75        if exclude_client_index.is_some_and(|excluded| index == excluded) {
76            continue;
77        }
78
79        if let Err(error) = client.try_send(json.to_owned()) {
80            warn!(
81                "Failed to {} (client {}): {}",
82                warning_context, index, error
83            );
84        }
85    }
86}
87
88/// WebSocket server for browser-based UI development
89pub struct WsServer<H: ParameterHost + 'static> {
90    /// Port the server listens on
91    port: u16,
92    /// Shared IPC handler
93    handler: Arc<IpcHandler<H>>,
94    /// Shutdown signal
95    shutdown_tx: broadcast::Sender<()>,
96    /// Enable verbose logging (all JSON-RPC messages)
97    verbose: bool,
98    /// Shared server state
99    state: Arc<ServerState>,
100}
101
102fn build_set_parameter_notification(
103    request: &wavecraft_protocol::IpcRequest,
104    response: &str,
105) -> Option<String> {
106    if request.method != wavecraft_protocol::METHOD_SET_PARAMETER {
107        return None;
108    }
109
110    let response_msg = serde_json::from_str::<IpcResponse>(response).ok()?;
111    if response_msg.error.is_some() {
112        return None;
113    }
114
115    let params = request.params.clone()?;
116    let set_params =
117        serde_json::from_value::<wavecraft_protocol::SetParameterParams>(params).ok()?;
118
119    serde_json::to_string(&IpcNotification::new(
120        wavecraft_protocol::NOTIFICATION_PARAMETER_CHANGED,
121        serde_json::json!({
122            "id": set_params.id,
123            "value": set_params.value,
124        }),
125    ))
126    .ok()
127}
128
129impl<H: ParameterHost + 'static> WsServer<H> {
130    /// Create a new WebSocket server
131    pub fn new(port: u16, handler: Arc<IpcHandler<H>>, verbose: bool) -> Self {
132        let (shutdown_tx, _) = broadcast::channel(1);
133        Self {
134            port,
135            handler,
136            shutdown_tx,
137            verbose,
138            state: Arc::new(ServerState::new()),
139        }
140    }
141
142    /// Get a lightweight handle for broadcasting to connected clients.
143    ///
144    /// The returned `WsHandle` is non-generic, `Clone`, and can be moved
145    /// into async tasks (e.g., for forwarding meter updates from audio).
146    pub fn handle(&self) -> WsHandle {
147        WsHandle {
148            state: Arc::clone(&self.state),
149        }
150    }
151
152    /// Broadcast a parametersChanged notification to all connected clients.
153    ///
154    /// This is used by the hot-reload pipeline to notify the UI that
155    /// parameters have been updated and should be re-fetched.
156    pub async fn broadcast_parameters_changed(&self) -> Result<(), serde_json::Error> {
157        use wavecraft_protocol::IpcNotification;
158
159        let notification = IpcNotification::new("parametersChanged", serde_json::json!({}));
160        let json = serde_json::to_string(&notification)?;
161
162        broadcast_to_browser_clients(
163            &self.state,
164            &json,
165            None,
166            "send parametersChanged notification",
167        )
168        .await;
169
170        Ok(())
171    }
172
173    /// Start the server (spawns async tasks, returns immediately)
174    pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
175        let addr: SocketAddr = format!("127.0.0.1:{}", self.port).parse()?;
176        let listener = TcpListener::bind(&addr).await?;
177
178        info!("Server listening on ws://{}", addr);
179
180        let handler = Arc::clone(&self.handler);
181        let mut shutdown_rx = self.shutdown_tx.subscribe();
182        let verbose = self.verbose;
183        let state = Arc::clone(&self.state);
184
185        tokio::spawn(async move {
186            loop {
187                tokio::select! {
188                    result = listener.accept() => {
189                        match result {
190                            Ok((stream, addr)) => {
191                                info!("Client connected: {}", addr);
192                                let handler = Arc::clone(&handler);
193                                let state = Arc::clone(&state);
194                                tokio::spawn(handle_connection(handler, stream, addr, verbose, state));
195                            }
196                            Err(e) => {
197                                error!("Accept error: {}", e);
198                            }
199                        }
200                    }
201                    _ = shutdown_rx.recv() => {
202                        info!("Server shutting down");
203                        break;
204                    }
205                }
206            }
207        });
208
209        Ok(())
210    }
211
212    /// Shutdown the server gracefully.
213    ///
214    /// Note: Not currently called but kept for future graceful shutdown support.
215    #[allow(dead_code)]
216    pub fn shutdown(&self) {
217        let _ = self.shutdown_tx.send(());
218    }
219}
220
221/// Handle a single WebSocket connection
222async fn handle_connection<H: ParameterHost>(
223    handler: Arc<IpcHandler<H>>,
224    stream: TcpStream,
225    addr: SocketAddr,
226    verbose: bool,
227    state: Arc<ServerState>,
228) {
229    let ws_stream = match accept_async(stream).await {
230        Ok(ws) => ws,
231        Err(e) => {
232            error!("Error during handshake with {}: {}", addr, e);
233            return;
234        }
235    };
236
237    info!("WebSocket connection established: {}", addr);
238
239    let (mut write, mut read) = ws_stream.split();
240    let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(128);
241
242    // Track this client for broadcasting
243    let mut is_audio_client = false;
244    let client_index = {
245        let mut clients = state.browser_clients.write().await;
246        clients.push(tx.clone());
247        clients.len() - 1
248    };
249
250    // Spawn task to send messages from channel to WebSocket
251    let write_task = tokio::spawn(async move {
252        while let Some(msg) = rx.recv().await {
253            if let Err(e) = write.send(Message::Text(msg)).await {
254                error!("Error sending to {}: {}", addr, e);
255                break;
256            }
257        }
258    });
259
260    while let Some(msg) = read.next().await {
261        match msg {
262            Ok(Message::Text(json)) => {
263                // Log incoming message (verbose only)
264                if verbose {
265                    debug!("Received from {}: {}", addr, json);
266                }
267
268                // Try to parse as IPC request for structured routing
269                let parsed_req = serde_json::from_str::<wavecraft_protocol::IpcRequest>(&json);
270
271                if let Ok(ref req) = parsed_req {
272                    // Handle registerAudio
273                    if req.method == "registerAudio" {
274                        is_audio_client = true;
275                        info!("Audio client registered: {}", addr);
276
277                        // Parse to extract client_id
278                        if let Some(params) = req.params.clone()
279                            && let Ok(audio_params) = serde_json::from_value::<
280                                wavecraft_protocol::RegisterAudioParams,
281                            >(params)
282                        {
283                            *state.audio_client.write().await =
284                                Some(audio_params.client_id.clone());
285                        }
286
287                        // Send success response using the request's id
288                        let response = wavecraft_protocol::IpcResponse::success(
289                            req.id.clone(),
290                            wavecraft_protocol::RegisterAudioResult {
291                                status: "registered".to_string(),
292                            },
293                        );
294                        let response_json = match serde_json::to_string(&response) {
295                            Ok(json) => json,
296                            Err(e) => {
297                                error!("Failed to serialize registerAudio response: {}", e);
298                                break;
299                            }
300                        };
301                        if let Err(e) = tx.try_send(response_json) {
302                            error!("Error sending response: {}", e);
303                            break;
304                        }
305                        continue;
306                    }
307
308                    // Handle meterUpdate from audio client
309                    if is_audio_client && req.method == "meterUpdate" {
310                        // Broadcast to all browser clients
311                        broadcast_to_browser_clients(
312                            &state,
313                            &json,
314                            Some(client_index),
315                            "broadcast meter update",
316                        )
317                        .await;
318                        continue;
319                    }
320                }
321
322                // Route through existing IpcHandler
323                let response = handler.handle_json(&json);
324
325                // Mirror native editor behavior in dev mode: after successful
326                // setParameter, emit parameterChanged so hooks relying on
327                // notifications stay in sync with backend-confirmed state.
328                if let Ok(req) = &parsed_req
329                    && let Some(notification_json) =
330                        build_set_parameter_notification(req, &response)
331                {
332                    broadcast_to_browser_clients(
333                        &state,
334                        &notification_json,
335                        None,
336                        "send parameterChanged notification",
337                    )
338                    .await;
339                }
340
341                // Log outgoing response (verbose only)
342                if verbose {
343                    debug!("Sending to {}: {}", addr, response);
344                }
345
346                // Send response
347                if let Err(e) = tx.try_send(response) {
348                    error!("Error queueing response: {}", e);
349                    break;
350                }
351            }
352            Ok(Message::Close(_)) => {
353                info!("Client closed connection: {}", addr);
354                break;
355            }
356            Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {
357                // Ignore ping/pong frames (automatically handled)
358            }
359            Ok(Message::Binary(_)) => {
360                warn!("Unexpected binary message from {}", addr);
361            }
362            Ok(Message::Frame(_)) => {
363                // Raw frames shouldn't appear at this level
364            }
365            Err(e) => {
366                error!("Error receiving from {}: {}", addr, e);
367                break;
368            }
369        }
370    }
371
372    // Cleanup: remove client from broadcast list
373    state
374        .browser_clients
375        .write()
376        .await
377        .retain(|c| !c.is_closed());
378    if is_audio_client {
379        *state.audio_client.write().await = None;
380        info!("Audio client disconnected: {}", addr);
381    }
382
383    write_task.abort();
384    info!("Connection closed: {}", addr);
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use wavecraft_bridge::InMemoryParameterHost;
    use wavecraft_protocol::{IpcRequest, IpcResponse, ParameterInfo, ParameterType, RequestId};

    /// Parameter host exposing a single float `gain` parameter.
    fn test_host() -> InMemoryParameterHost {
        let gain = ParameterInfo {
            id: "gain".to_string(),
            name: "Gain".to_string(),
            param_type: ParameterType::Float,
            value: 0.5,
            default: 0.5,
            min: 0.0,
            max: 1.0,
            unit: Some("dB".to_string()),
            group: Some("Input".to_string()),
        };
        InMemoryParameterHost::new(vec![gain])
    }

    /// `setParameter` request (id 1) targeting `gain` with the given value.
    fn set_gain_request(value: f64) -> IpcRequest {
        IpcRequest::new(
            RequestId::Number(1),
            wavecraft_protocol::METHOD_SET_PARAMETER,
            Some(serde_json::json!({ "id": "gain", "value": value })),
        )
    }

    #[tokio::test]
    async fn test_server_creation() {
        let handler = Arc::new(IpcHandler::new(test_host()));
        let server = WsServer::new(9001, handler, false);

        // Construction alone must succeed and remember the requested port.
        assert_eq!(server.port, 9001);
    }

    #[test]
    fn build_set_parameter_notification_from_success_response() {
        let request = set_gain_request(0.8);
        let ok_response = IpcResponse::success(RequestId::Number(1), serde_json::json!({}));
        let response = serde_json::to_string(&ok_response).expect("serialize response");

        let notification = build_set_parameter_notification(&request, &response)
            .expect("should create parameterChanged notification");
        let parsed: serde_json::Value =
            serde_json::from_str(&notification).expect("notification should parse");

        let expected_method =
            serde_json::json!(wavecraft_protocol::NOTIFICATION_PARAMETER_CHANGED);
        assert_eq!(parsed.get("method"), Some(&expected_method));
        assert_eq!(parsed.pointer("/params/id"), Some(&serde_json::json!("gain")));

        let value = parsed
            .pointer("/params/value")
            .and_then(serde_json::Value::as_f64)
            .expect("notification should contain numeric params.value");
        assert!(
            (value - 0.8).abs() < 1e-5,
            "expected approx 0.8, got {value}"
        );
    }

    #[test]
    fn build_set_parameter_notification_ignores_error_response() {
        let request = set_gain_request(10.0);
        let err_response = IpcResponse::error(
            RequestId::Number(1),
            wavecraft_protocol::IpcError::invalid_params("out of range"),
        );
        let response =
            serde_json::to_string(&err_response).expect("serialize error response");

        assert!(build_set_parameter_notification(&request, &response).is_none());
    }
}