Skip to main content

wavecraft_dev_server/ws/
mod.rs

1//! WebSocket server for browser-based UI development
2//!
3//! This module provides a WebSocket server that exposes the same IPC protocol
4//! used by the native WKWebView transport, enabling real-time communication
5//! between a browser-based UI and the Rust engine during development.
6
7use futures_util::{SinkExt, StreamExt};
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::{RwLock, broadcast};
12use tokio_tungstenite::{accept_async, tungstenite::protocol::Message};
13use tracing::{debug, error, info, warn};
14use wavecraft_bridge::{IpcHandler, ParameterHost};
15
16/// Shared state for tracking connected clients
17struct ServerState {
18    /// Connected browser clients (for broadcasting meter updates)
19    browser_clients: Arc<RwLock<Vec<tokio::sync::mpsc::Sender<String>>>>,
20    /// Audio client ID (if connected)
21    audio_client: Arc<RwLock<Option<String>>>,
22}
23
24impl ServerState {
25    fn new() -> Self {
26        Self {
27            browser_clients: Arc::new(RwLock::new(Vec::new())),
28            audio_client: Arc::new(RwLock::new(None)),
29        }
30    }
31}
32
33/// A lightweight, cloneable handle to the WebSocket server's broadcast
34/// capability. Non-generic — can be passed across async task boundaries.
35///
36/// Constructed via [`WsServer::handle()`]. Used by the CLI to forward
37/// meter updates from the in-process audio callback to browser clients.
38#[derive(Clone)]
39pub struct WsHandle {
40    state: Arc<ServerState>,
41}
42
43impl WsHandle {
44    /// Broadcast a JSON string to all connected browser clients.
45    pub async fn broadcast(&self, json: &str) {
46        let clients = self.state.browser_clients.read().await;
47        for client in clients.iter() {
48            if let Err(e) = client.try_send(json.to_owned()) {
49                warn!("Failed to broadcast message to client: {}", e);
50            }
51        }
52    }
53}
54
55/// WebSocket server for browser-based UI development
56pub struct WsServer<H: ParameterHost + 'static> {
57    /// Port the server listens on
58    port: u16,
59    /// Shared IPC handler
60    handler: Arc<IpcHandler<H>>,
61    /// Shutdown signal
62    shutdown_tx: broadcast::Sender<()>,
63    /// Enable verbose logging (all JSON-RPC messages)
64    verbose: bool,
65    /// Shared server state
66    state: Arc<ServerState>,
67}
68
69impl<H: ParameterHost + 'static> WsServer<H> {
70    /// Create a new WebSocket server
71    pub fn new(port: u16, handler: Arc<IpcHandler<H>>, verbose: bool) -> Self {
72        let (shutdown_tx, _) = broadcast::channel(1);
73        Self {
74            port,
75            handler,
76            shutdown_tx,
77            verbose,
78            state: Arc::new(ServerState::new()),
79        }
80    }
81
82    /// Get a lightweight handle for broadcasting to connected clients.
83    ///
84    /// The returned `WsHandle` is non-generic, `Clone`, and can be moved
85    /// into async tasks (e.g., for forwarding meter updates from audio).
86    pub fn handle(&self) -> WsHandle {
87        WsHandle {
88            state: Arc::clone(&self.state),
89        }
90    }
91
92    /// Broadcast a parametersChanged notification to all connected clients.
93    ///
94    /// This is used by the hot-reload pipeline to notify the UI that
95    /// parameters have been updated and should be re-fetched.
96    pub async fn broadcast_parameters_changed(&self) -> Result<(), serde_json::Error> {
97        use wavecraft_protocol::IpcNotification;
98
99        let notification = IpcNotification::new("parametersChanged", serde_json::json!({}));
100        let json = serde_json::to_string(&notification)?;
101
102        let clients = self.state.browser_clients.read().await;
103        for client in clients.iter() {
104            if let Err(e) = client.try_send(json.clone()) {
105                warn!("Failed to send parametersChanged notification to client: {}", e);
106            }
107        }
108
109        Ok(())
110    }
111
112    /// Start the server (spawns async tasks, returns immediately)
113    pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
114        let addr: SocketAddr = format!("127.0.0.1:{}", self.port).parse()?;
115        let listener = TcpListener::bind(&addr).await?;
116
117        info!("Server listening on ws://{}", addr);
118
119        let handler = Arc::clone(&self.handler);
120        let mut shutdown_rx = self.shutdown_tx.subscribe();
121        let verbose = self.verbose;
122        let state = Arc::clone(&self.state);
123
124        tokio::spawn(async move {
125            loop {
126                tokio::select! {
127                    result = listener.accept() => {
128                        match result {
129                            Ok((stream, addr)) => {
130                                info!("Client connected: {}", addr);
131                                let handler = Arc::clone(&handler);
132                                let state = Arc::clone(&state);
133                                tokio::spawn(handle_connection(handler, stream, addr, verbose, state));
134                            }
135                            Err(e) => {
136                                error!("Accept error: {}", e);
137                            }
138                        }
139                    }
140                    _ = shutdown_rx.recv() => {
141                        info!("Server shutting down");
142                        break;
143                    }
144                }
145            }
146        });
147
148        Ok(())
149    }
150
151    /// Shutdown the server gracefully.
152    ///
153    /// Note: Not currently called but kept for future graceful shutdown support.
154    #[allow(dead_code)]
155    pub fn shutdown(&self) {
156        let _ = self.shutdown_tx.send(());
157    }
158}
159
160/// Handle a single WebSocket connection
161async fn handle_connection<H: ParameterHost>(
162    handler: Arc<IpcHandler<H>>,
163    stream: TcpStream,
164    addr: SocketAddr,
165    verbose: bool,
166    state: Arc<ServerState>,
167) {
168    let ws_stream = match accept_async(stream).await {
169        Ok(ws) => ws,
170        Err(e) => {
171            error!("Error during handshake with {}: {}", addr, e);
172            return;
173        }
174    };
175
176    info!("WebSocket connection established: {}", addr);
177
178    let (mut write, mut read) = ws_stream.split();
179    let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(128);
180
181    // Track this client for broadcasting
182    let mut is_audio_client = false;
183    let client_index = {
184        let mut clients = state.browser_clients.write().await;
185        clients.push(tx.clone());
186        clients.len() - 1
187    };
188
189    // Spawn task to send messages from channel to WebSocket
190    let write_task = tokio::spawn(async move {
191        while let Some(msg) = rx.recv().await {
192            if let Err(e) = write.send(Message::Text(msg)).await {
193                error!("Error sending to {}: {}", addr, e);
194                break;
195            }
196        }
197    });
198
199    while let Some(msg) = read.next().await {
200        match msg {
201            Ok(Message::Text(json)) => {
202                // Log incoming message (verbose only)
203                if verbose {
204                    debug!("Received from {}: {}", addr, json);
205                }
206
207                // Try to parse as IPC request for structured routing
208                let parsed_req = serde_json::from_str::<wavecraft_protocol::IpcRequest>(&json);
209                
210                if let Ok(ref req) = parsed_req {
211                    // Handle registerAudio
212                    if req.method == "registerAudio" {
213                        is_audio_client = true;
214                        info!("Audio client registered: {}", addr);
215
216                        // Parse to extract client_id
217                        if let Some(params) = req.params.clone()
218                            && let Ok(audio_params) = serde_json::from_value::<
219                                wavecraft_protocol::RegisterAudioParams,
220                            >(params)
221                        {
222                            *state.audio_client.write().await = Some(audio_params.client_id.clone());
223                        }
224
225                        // Send success response using the request's id
226                        let response = wavecraft_protocol::IpcResponse::success(
227                            req.id.clone(),
228                            wavecraft_protocol::RegisterAudioResult {
229                                status: "registered".to_string(),
230                            },
231                        );
232                        let response_json = match serde_json::to_string(&response) {
233                            Ok(json) => json,
234                            Err(e) => {
235                                error!("Failed to serialize registerAudio response: {}", e);
236                                break;
237                            }
238                        };
239                        if let Err(e) = tx.try_send(response_json) {
240                            error!("Error sending response: {}", e);
241                            break;
242                        }
243                        continue;
244                    }
245                    
246                    // Handle meterUpdate from audio client
247                    if is_audio_client && req.method == "meterUpdate" {
248                        // Broadcast to all browser clients
249                        let clients = state.browser_clients.read().await;
250                        for (idx, client) in clients.iter().enumerate() {
251                            if idx != client_index {
252                                // Don't send back to audio client
253                                if let Err(e) = client.try_send(json.clone()) {
254                                    warn!("Failed to broadcast meter update to client {}: {}", idx, e);
255                                }
256                            }
257                        }
258                        continue;
259                    }
260                }
261
262                // Route through existing IpcHandler
263                let response = handler.handle_json(&json);
264
265                // Log outgoing response (verbose only)
266                if verbose {
267                    debug!("Sending to {}: {}", addr, response);
268                }
269
270                // Send response
271                if let Err(e) = tx.try_send(response) {
272                    error!("Error queueing response: {}", e);
273                    break;
274                }
275            }
276            Ok(Message::Close(_)) => {
277                info!("Client closed connection: {}", addr);
278                break;
279            }
280            Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {
281                // Ignore ping/pong frames (automatically handled)
282            }
283            Ok(Message::Binary(_)) => {
284                warn!("Unexpected binary message from {}", addr);
285            }
286            Ok(Message::Frame(_)) => {
287                // Raw frames shouldn't appear at this level
288            }
289            Err(e) => {
290                error!("Error receiving from {}: {}", addr, e);
291                break;
292            }
293        }
294    }
295
296    // Cleanup: remove client from broadcast list
297    state
298        .browser_clients
299        .write()
300        .await
301        .retain(|c| !c.is_closed());
302    if is_audio_client {
303        *state.audio_client.write().await = None;
304        info!("Audio client disconnected: {}", addr);
305    }
306
307    write_task.abort();
308    info!("Connection closed: {}", addr);
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use wavecraft_bridge::InMemoryParameterHost;
    use wavecraft_protocol::{ParameterInfo, ParameterType};

    /// In-memory host exposing a single float `gain` parameter.
    fn test_host() -> InMemoryParameterHost {
        let gain = ParameterInfo {
            id: "gain".to_string(),
            name: "Gain".to_string(),
            param_type: ParameterType::Float,
            value: 0.5,
            default: 0.5,
            unit: Some("dB".to_string()),
            group: Some("Input".to_string()),
        };
        InMemoryParameterHost::new(vec![gain])
    }

    #[tokio::test]
    async fn test_server_creation() {
        let port = 9001;
        let shared_handler = Arc::new(IpcHandler::new(test_host()));
        let server = WsServer::new(port, shared_handler, false);

        // Construction must not panic and must remember the requested port.
        assert_eq!(server.port, port);
    }
}