Skip to main content

wavecraft_dev_server/
ws_server.rs

1//! WebSocket server for browser-based UI development
2//!
3//! This module provides a WebSocket server that exposes the same IPC protocol
4//! used by the native WKWebView transport, enabling real-time communication
5//! between a browser-based UI and the Rust engine during development.
6
7use futures_util::{SinkExt, StreamExt};
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::{RwLock, broadcast};
12use tokio_tungstenite::{accept_async, tungstenite::protocol::Message};
13use tracing::{debug, error, info, warn};
14use wavecraft_bridge::{IpcHandler, ParameterHost};
15
16/// Shared state for tracking connected clients
17struct ServerState {
18    /// Connected browser clients (for broadcasting meter updates)
19    browser_clients: Arc<RwLock<Vec<tokio::sync::mpsc::UnboundedSender<String>>>>,
20    /// Audio client ID (if connected)
21    audio_client: Arc<RwLock<Option<String>>>,
22}
23
24impl ServerState {
25    fn new() -> Self {
26        Self {
27            browser_clients: Arc::new(RwLock::new(Vec::new())),
28            audio_client: Arc::new(RwLock::new(None)),
29        }
30    }
31}
32
33/// A lightweight, cloneable handle to the WebSocket server's broadcast
34/// capability. Non-generic — can be passed across async task boundaries.
35///
36/// Constructed via [`WsServer::handle()`]. Used by the CLI to forward
37/// meter updates from the in-process audio callback to browser clients.
38#[derive(Clone)]
39#[allow(dead_code)] // Used by CLI crate (outside engine workspace)
40pub struct WsHandle {
41    state: Arc<ServerState>,
42}
43
44#[allow(dead_code)] // Used by CLI crate (outside engine workspace)
45impl WsHandle {
46    /// Broadcast a JSON string to all connected browser clients.
47    pub async fn broadcast(&self, json: &str) {
48        let clients = self.state.browser_clients.read().await;
49        for client in clients.iter() {
50            let _ = client.send(json.to_owned());
51        }
52    }
53}
54
55/// WebSocket server for browser-based UI development
56pub struct WsServer<H: ParameterHost + 'static> {
57    /// Port the server listens on
58    port: u16,
59    /// Shared IPC handler
60    handler: Arc<IpcHandler<H>>,
61    /// Shutdown signal
62    shutdown_tx: broadcast::Sender<()>,
63    /// Enable verbose logging (all JSON-RPC messages)
64    verbose: bool,
65    /// Shared server state
66    state: Arc<ServerState>,
67}
68
69impl<H: ParameterHost + 'static> WsServer<H> {
70    /// Create a new WebSocket server
71    pub fn new(port: u16, handler: Arc<IpcHandler<H>>, verbose: bool) -> Self {
72        let (shutdown_tx, _) = broadcast::channel(1);
73        Self {
74            port,
75            handler,
76            shutdown_tx,
77            verbose,
78            state: Arc::new(ServerState::new()),
79        }
80    }
81
82    /// Get a lightweight handle for broadcasting to connected clients.
83    ///
84    /// The returned `WsHandle` is non-generic, `Clone`, and can be moved
85    /// into async tasks (e.g., for forwarding meter updates from audio).
86    #[allow(dead_code)] // Used by CLI crate (outside engine workspace)
87    pub fn handle(&self) -> WsHandle {
88        WsHandle {
89            state: Arc::clone(&self.state),
90        }
91    }
92
93    /// Start the server (spawns async tasks, returns immediately)
94    pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
95        let addr: SocketAddr = format!("127.0.0.1:{}", self.port).parse()?;
96        let listener = TcpListener::bind(&addr).await?;
97
98        info!("Server listening on ws://{}", addr);
99
100        let handler = Arc::clone(&self.handler);
101        let mut shutdown_rx = self.shutdown_tx.subscribe();
102        let verbose = self.verbose;
103        let state = Arc::clone(&self.state);
104
105        tokio::spawn(async move {
106            loop {
107                tokio::select! {
108                    result = listener.accept() => {
109                        match result {
110                            Ok((stream, addr)) => {
111                                info!("Client connected: {}", addr);
112                                let handler = Arc::clone(&handler);
113                                let state = Arc::clone(&state);
114                                tokio::spawn(handle_connection(handler, stream, addr, verbose, state));
115                            }
116                            Err(e) => {
117                                error!("Accept error: {}", e);
118                            }
119                        }
120                    }
121                    _ = shutdown_rx.recv() => {
122                        info!("Server shutting down");
123                        break;
124                    }
125                }
126            }
127        });
128
129        Ok(())
130    }
131
132    /// Shutdown the server gracefully.
133    ///
134    /// Note: Not currently called but kept for future graceful shutdown support.
135    #[allow(dead_code)]
136    pub fn shutdown(&self) {
137        let _ = self.shutdown_tx.send(());
138    }
139}
140
141/// Handle a single WebSocket connection
142async fn handle_connection<H: ParameterHost>(
143    handler: Arc<IpcHandler<H>>,
144    stream: TcpStream,
145    addr: SocketAddr,
146    verbose: bool,
147    state: Arc<ServerState>,
148) {
149    let ws_stream = match accept_async(stream).await {
150        Ok(ws) => ws,
151        Err(e) => {
152            error!("Error during handshake with {}: {}", addr, e);
153            return;
154        }
155    };
156
157    info!("WebSocket connection established: {}", addr);
158
159    let (mut write, mut read) = ws_stream.split();
160    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();
161
162    // Track this client for broadcasting
163    let mut is_audio_client = false;
164    state.browser_clients.write().await.push(tx.clone());
165    let client_index = state.browser_clients.read().await.len() - 1;
166
167    // Spawn task to send messages from channel to WebSocket
168    let write_task = tokio::spawn(async move {
169        while let Some(msg) = rx.recv().await {
170            if let Err(e) = write.send(Message::Text(msg)).await {
171                error!("Error sending to {}: {}", addr, e);
172                break;
173            }
174        }
175    });
176
177    while let Some(msg) = read.next().await {
178        match msg {
179            Ok(Message::Text(json)) => {
180                // Log incoming message (verbose only)
181                if verbose {
182                    debug!("Received from {}: {}", addr, json);
183                }
184
185                // Check if this is an audio client registration
186                if json.contains("\"method\":\"registerAudio\"") {
187                    is_audio_client = true;
188                    info!("Audio client registered: {}", addr);
189
190                    // Parse to extract client_id
191                    if let Ok(req) = serde_json::from_str::<wavecraft_protocol::IpcRequest>(&json)
192                        && let Some(params) = req.params
193                        && let Ok(audio_params) = serde_json::from_value::<
194                            wavecraft_protocol::RegisterAudioParams,
195                        >(params)
196                    {
197                        *state.audio_client.write().await = Some(audio_params.client_id.clone());
198                    }
199
200                    // Send success response
201                    let response = wavecraft_protocol::IpcResponse::success(
202                        wavecraft_protocol::RequestId::Number(1),
203                        wavecraft_protocol::RegisterAudioResult {
204                            status: "registered".to_string(),
205                        },
206                    );
207                    let response_json = serde_json::to_string(&response).unwrap();
208                    if let Err(e) = tx.send(response_json) {
209                        error!("Error sending response: {}", e);
210                        break;
211                    }
212                    continue;
213                }
214
215                // Check if this is a meter update notification from audio client
216                if is_audio_client && json.contains("\"method\":\"meterUpdate\"") {
217                    // Broadcast to all browser clients
218                    let clients = state.browser_clients.read().await;
219                    for (idx, client) in clients.iter().enumerate() {
220                        if idx != client_index {
221                            // Don't send back to audio client
222                            let _ = client.send(json.clone());
223                        }
224                    }
225                    continue;
226                }
227
228                // Route through existing IpcHandler
229                let response = handler.handle_json(&json);
230
231                // Log outgoing response (verbose only)
232                if verbose {
233                    debug!("Sending to {}: {}", addr, response);
234                }
235
236                // Send response
237                if let Err(e) = tx.send(response) {
238                    error!("Error queueing response: {}", e);
239                    break;
240                }
241            }
242            Ok(Message::Close(_)) => {
243                info!("Client closed connection: {}", addr);
244                break;
245            }
246            Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {
247                // Ignore ping/pong frames (automatically handled)
248            }
249            Ok(Message::Binary(_)) => {
250                warn!("Unexpected binary message from {}", addr);
251            }
252            Ok(Message::Frame(_)) => {
253                // Raw frames shouldn't appear at this level
254            }
255            Err(e) => {
256                error!("Error receiving from {}: {}", addr, e);
257                break;
258            }
259        }
260    }
261
262    // Cleanup: remove client from broadcast list
263    state
264        .browser_clients
265        .write()
266        .await
267        .retain(|c| !c.is_closed());
268    if is_audio_client {
269        *state.audio_client.write().await = None;
270        info!("Audio client disconnected: {}", addr);
271    }
272
273    write_task.abort();
274    info!("Connection closed: {}", addr);
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::AppState;

    /// Smoke test: constructing a server only stores its configuration; it
    /// must not panic.
    #[tokio::test]
    async fn test_server_creation() {
        let server = WsServer::new(9001, Arc::new(IpcHandler::new(AppState::new())), false);

        assert_eq!(server.port, 9001);
    }
}