Skip to main content

wavecraft_dev_server/ws/
mod.rs

1//! WebSocket server for browser-based UI development
2//!
3//! This module provides a WebSocket server that exposes the same IPC protocol
4//! used by the native WKWebView transport, enabling real-time communication
5//! between a browser-based UI and the Rust engine during development.
6
7use futures_util::{SinkExt, StreamExt};
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::{RwLock, broadcast};
12use tokio_tungstenite::{accept_async, tungstenite::protocol::Message};
13use tracing::{debug, error, info, warn};
14use wavecraft_bridge::{IpcHandler, ParameterHost};
15use wavecraft_protocol::{
16    AudioRuntimeStatus, IpcNotification, IpcResponse, NOTIFICATION_AUDIO_STATUS_CHANGED,
17};
18
19/// Shared state for tracking connected clients
20struct ServerState {
21    /// Connected browser clients (for broadcasting meter updates)
22    browser_clients: Arc<RwLock<Vec<tokio::sync::mpsc::Sender<String>>>>,
23    /// Audio client ID (if connected)
24    audio_client: Arc<RwLock<Option<String>>>,
25}
26
27impl ServerState {
28    fn new() -> Self {
29        Self {
30            browser_clients: Arc::new(RwLock::new(Vec::new())),
31            audio_client: Arc::new(RwLock::new(None)),
32        }
33    }
34}
35
36/// A lightweight, cloneable handle to the WebSocket server's broadcast
37/// capability. Non-generic — can be passed across async task boundaries.
38///
39/// Constructed via [`WsServer::handle()`]. Used by the CLI to forward
40/// meter updates from the in-process audio callback to browser clients.
41#[derive(Clone)]
42pub struct WsHandle {
43    state: Arc<ServerState>,
44}
45
46impl WsHandle {
47    /// Broadcast a JSON string to all connected browser clients.
48    pub async fn broadcast(&self, json: &str) {
49        broadcast_to_browser_clients(&self.state, json, None, "broadcast message").await;
50    }
51
52    /// Broadcast an audioStatusChanged notification to connected clients.
53    pub async fn broadcast_audio_status_changed(
54        &self,
55        status: &AudioRuntimeStatus,
56    ) -> Result<(), serde_json::Error> {
57        let json = serde_json::to_string(&IpcNotification::new(
58            NOTIFICATION_AUDIO_STATUS_CHANGED,
59            status,
60        ))?;
61
62        self.broadcast(&json).await;
63        Ok(())
64    }
65}
66
67async fn broadcast_to_browser_clients(
68    state: &Arc<ServerState>,
69    json: &str,
70    exclude_client_index: Option<usize>,
71    warning_context: &str,
72) {
73    let clients = state.browser_clients.read().await;
74    for (index, client) in clients.iter().enumerate() {
75        if exclude_client_index.is_some_and(|excluded| index == excluded) {
76            continue;
77        }
78
79        if let Err(error) = client.try_send(json.to_owned()) {
80            warn!(
81                "Failed to {} (client {}): {}",
82                warning_context, index, error
83            );
84        }
85    }
86}
87
88/// WebSocket server for browser-based UI development
89pub struct WsServer<H: ParameterHost + 'static> {
90    /// Port the server listens on
91    port: u16,
92    /// Shared IPC handler
93    handler: Arc<IpcHandler<H>>,
94    /// Shutdown signal
95    shutdown_tx: broadcast::Sender<()>,
96    /// Shared server state
97    state: Arc<ServerState>,
98}
99
100fn build_set_parameter_notification(
101    request: &wavecraft_protocol::IpcRequest,
102    response: &str,
103) -> Option<String> {
104    if request.method != wavecraft_protocol::METHOD_SET_PARAMETER {
105        return None;
106    }
107
108    let response_msg = serde_json::from_str::<IpcResponse>(response).ok()?;
109    if response_msg.error.is_some() {
110        return None;
111    }
112
113    let params = request.params.clone()?;
114    let set_params =
115        serde_json::from_value::<wavecraft_protocol::SetParameterParams>(params).ok()?;
116
117    serde_json::to_string(&IpcNotification::new(
118        wavecraft_protocol::NOTIFICATION_PARAMETER_CHANGED,
119        serde_json::json!({
120            "id": set_params.id,
121            "value": set_params.value,
122        }),
123    ))
124    .ok()
125}
126
127impl<H: ParameterHost + 'static> WsServer<H> {
128    /// Create a new WebSocket server
129    pub fn new(port: u16, handler: Arc<IpcHandler<H>>) -> Self {
130        let (shutdown_tx, _) = broadcast::channel(1);
131        Self {
132            port,
133            handler,
134            shutdown_tx,
135            state: Arc::new(ServerState::new()),
136        }
137    }
138
139    /// Get a lightweight handle for broadcasting to connected clients.
140    ///
141    /// The returned `WsHandle` is non-generic, `Clone`, and can be moved
142    /// into async tasks (e.g., for forwarding meter updates from audio).
143    pub fn handle(&self) -> WsHandle {
144        WsHandle {
145            state: Arc::clone(&self.state),
146        }
147    }
148
149    /// Broadcast a parametersChanged notification to all connected clients.
150    ///
151    /// This is used by the hot-reload pipeline to notify the UI that
152    /// parameters have been updated and should be re-fetched.
153    pub async fn broadcast_parameters_changed(&self) -> Result<(), serde_json::Error> {
154        use wavecraft_protocol::IpcNotification;
155
156        let notification = IpcNotification::new("parametersChanged", serde_json::json!({}));
157        let json = serde_json::to_string(&notification)?;
158
159        broadcast_to_browser_clients(
160            &self.state,
161            &json,
162            None,
163            "send parametersChanged notification",
164        )
165        .await;
166
167        Ok(())
168    }
169
170    /// Start the server (spawns async tasks, returns immediately)
171    pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
172        let addr: SocketAddr = format!("127.0.0.1:{}", self.port).parse()?;
173        let listener = TcpListener::bind(&addr).await?;
174
175        info!("Server listening on ws://{}", addr);
176
177        let handler = Arc::clone(&self.handler);
178        let mut shutdown_rx = self.shutdown_tx.subscribe();
179        let state = Arc::clone(&self.state);
180
181        tokio::spawn(async move {
182            loop {
183                tokio::select! {
184                    result = listener.accept() => {
185                        match result {
186                            Ok((stream, addr)) => {
187                                info!("Client connected: {}", addr);
188                                let handler = Arc::clone(&handler);
189                                let state = Arc::clone(&state);
190                                tokio::spawn(handle_connection(handler, stream, addr, state));
191                            }
192                            Err(e) => {
193                                error!("Accept error: {}", e);
194                            }
195                        }
196                    }
197                    _ = shutdown_rx.recv() => {
198                        info!("Server shutting down");
199                        break;
200                    }
201                }
202            }
203        });
204
205        Ok(())
206    }
207
208    /// Shutdown the server gracefully.
209    ///
210    /// Note: Not currently called but kept for future graceful shutdown support.
211    #[allow(dead_code)]
212    pub fn shutdown(&self) {
213        let _ = self.shutdown_tx.send(());
214    }
215}
216
217/// Handle a single WebSocket connection
218async fn handle_connection<H: ParameterHost>(
219    handler: Arc<IpcHandler<H>>,
220    stream: TcpStream,
221    addr: SocketAddr,
222    state: Arc<ServerState>,
223) {
224    let ws_stream = match accept_async(stream).await {
225        Ok(ws) => ws,
226        Err(e) => {
227            error!("Error during handshake with {}: {}", addr, e);
228            return;
229        }
230    };
231
232    info!("WebSocket connection established: {}", addr);
233
234    let (mut write, mut read) = ws_stream.split();
235    let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(128);
236
237    // Track this client for broadcasting
238    let mut is_audio_client = false;
239    let client_index = {
240        let mut clients = state.browser_clients.write().await;
241        clients.push(tx.clone());
242        clients.len() - 1
243    };
244
245    // Spawn task to send messages from channel to WebSocket
246    let write_task = tokio::spawn(async move {
247        while let Some(msg) = rx.recv().await {
248            if let Err(e) = write.send(Message::Text(msg)).await {
249                error!("Error sending to {}: {}", addr, e);
250                break;
251            }
252        }
253    });
254
255    while let Some(msg) = read.next().await {
256        match msg {
257            Ok(Message::Text(json)) => {
258                debug!("Received from {}: {}", addr, json);
259
260                // Try to parse as IPC request for structured routing
261                let parsed_req = serde_json::from_str::<wavecraft_protocol::IpcRequest>(&json);
262
263                if let Ok(ref req) = parsed_req {
264                    // Handle registerAudio
265                    if req.method == "registerAudio" {
266                        is_audio_client = true;
267                        info!("Audio client registered: {}", addr);
268
269                        // Parse to extract client_id
270                        if let Some(params) = req.params.clone()
271                            && let Ok(audio_params) = serde_json::from_value::<
272                                wavecraft_protocol::RegisterAudioParams,
273                            >(params)
274                        {
275                            *state.audio_client.write().await =
276                                Some(audio_params.client_id.clone());
277                        }
278
279                        // Send success response using the request's id
280                        let response = wavecraft_protocol::IpcResponse::success(
281                            req.id.clone(),
282                            wavecraft_protocol::RegisterAudioResult {
283                                status: "registered".to_string(),
284                            },
285                        );
286                        let response_json = match serde_json::to_string(&response) {
287                            Ok(json) => json,
288                            Err(e) => {
289                                error!("Failed to serialize registerAudio response: {}", e);
290                                break;
291                            }
292                        };
293                        if let Err(e) = tx.try_send(response_json) {
294                            error!("Error sending response: {}", e);
295                            break;
296                        }
297                        continue;
298                    }
299
300                    // Handle meterUpdate from audio client
301                    if is_audio_client && req.method == "meterUpdate" {
302                        // Broadcast to all browser clients
303                        broadcast_to_browser_clients(
304                            &state,
305                            &json,
306                            Some(client_index),
307                            "broadcast meter update",
308                        )
309                        .await;
310                        continue;
311                    }
312                }
313
314                // Route through existing IpcHandler
315                let response = handler.handle_json(&json);
316
317                // Mirror native editor behavior in dev mode: after successful
318                // setParameter, emit parameterChanged so hooks relying on
319                // notifications stay in sync with backend-confirmed state.
320                if let Ok(req) = &parsed_req
321                    && let Some(notification_json) =
322                        build_set_parameter_notification(req, &response)
323                {
324                    broadcast_to_browser_clients(
325                        &state,
326                        &notification_json,
327                        None,
328                        "send parameterChanged notification",
329                    )
330                    .await;
331                }
332
333                // Log outgoing response
334                debug!("Sending to {}: {}", addr, response);
335
336                // Send response
337                if let Err(e) = tx.try_send(response) {
338                    error!("Error queueing response: {}", e);
339                    break;
340                }
341            }
342            Ok(Message::Close(_)) => {
343                info!("Client closed connection: {}", addr);
344                break;
345            }
346            Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {
347                // Ignore ping/pong frames (automatically handled)
348            }
349            Ok(Message::Binary(_)) => {
350                warn!("Unexpected binary message from {}", addr);
351            }
352            Ok(Message::Frame(_)) => {
353                // Raw frames shouldn't appear at this level
354            }
355            Err(e) => {
356                error!("Error receiving from {}: {}", addr, e);
357                break;
358            }
359        }
360    }
361
362    // Cleanup: remove client from broadcast list
363    state
364        .browser_clients
365        .write()
366        .await
367        .retain(|c| !c.is_closed());
368    if is_audio_client {
369        *state.audio_client.write().await = None;
370        info!("Audio client disconnected: {}", addr);
371    }
372
373    write_task.abort();
374    info!("Connection closed: {}", addr);
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use wavecraft_bridge::InMemoryParameterHost;
    use wavecraft_protocol::{IpcRequest, IpcResponse, ParameterInfo, ParameterType, RequestId};

    /// Simple test host for unit tests: a single float "gain" parameter.
    fn test_host() -> InMemoryParameterHost {
        InMemoryParameterHost::new(vec![ParameterInfo {
            id: "gain".to_string(),
            name: "Gain".to_string(),
            param_type: ParameterType::Float,
            value: 0.5,
            default: 0.5,
            min: 0.0,
            max: 1.0,
            unit: Some("dB".to_string()),
            group: Some("Input".to_string()),
            variants: None,
        }])
    }

    /// Serialized successful response (empty result) for request id 1.
    fn success_response() -> String {
        serde_json::to_string(&IpcResponse::success(
            RequestId::Number(1),
            serde_json::json!({}),
        ))
        .expect("serialize response")
    }

    #[tokio::test]
    async fn test_server_creation() {
        let host = test_host();
        let handler = Arc::new(IpcHandler::new(host));
        let server = WsServer::new(9001, handler);

        // Just verify we can create a server without panicking
        assert_eq!(server.port, 9001);
    }

    #[test]
    fn build_set_parameter_notification_from_success_response() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            wavecraft_protocol::METHOD_SET_PARAMETER,
            Some(serde_json::json!({ "id": "gain", "value": 0.8 })),
        );
        let response = success_response();

        let notification = build_set_parameter_notification(&request, &response)
            .expect("should create parameterChanged notification");
        let json: serde_json::Value =
            serde_json::from_str(&notification).expect("notification should parse");

        assert_eq!(
            json.get("method"),
            Some(&serde_json::json!(
                wavecraft_protocol::NOTIFICATION_PARAMETER_CHANGED
            ))
        );
        assert_eq!(json.pointer("/params/id"), Some(&serde_json::json!("gain")));
        let Some(value) = json
            .pointer("/params/value")
            .and_then(serde_json::Value::as_f64)
        else {
            panic!("notification should contain numeric params.value");
        };
        assert!(
            (value - 0.8).abs() < 1e-5,
            "expected approx 0.8, got {value}"
        );
    }

    #[test]
    fn build_set_parameter_notification_ignores_error_response() {
        let request = IpcRequest::new(
            RequestId::Number(1),
            wavecraft_protocol::METHOD_SET_PARAMETER,
            Some(serde_json::json!({ "id": "gain", "value": 10.0 })),
        );

        let response = serde_json::to_string(&IpcResponse::error(
            RequestId::Number(1),
            wavecraft_protocol::IpcError::invalid_params("out of range"),
        ))
        .expect("serialize error response");

        assert!(build_set_parameter_notification(&request, &response).is_none());
    }

    #[test]
    fn build_set_parameter_notification_ignores_other_methods() {
        // Even with a successful response, only setParameter triggers a
        // parameterChanged notification.
        let request = IpcRequest::new(
            RequestId::Number(1),
            "getParameter",
            Some(serde_json::json!({ "id": "gain" })),
        );

        assert!(build_set_parameter_notification(&request, &success_response()).is_none());
    }

    #[test]
    fn build_set_parameter_notification_requires_params() {
        // Without params there is no id/value to echo, so nothing is emitted.
        let request = IpcRequest::new(
            RequestId::Number(1),
            wavecraft_protocol::METHOD_SET_PARAMETER,
            None::<serde_json::Value>,
        );

        assert!(build_set_parameter_notification(&request, &success_response()).is_none());
    }
}