1use futures_util::{SinkExt, StreamExt};
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::{RwLock, broadcast};
12use tokio_tungstenite::{accept_async, tungstenite::protocol::Message};
13use tracing::{debug, error, info, warn};
14use wavecraft_bridge::{IpcHandler, ParameterHost};
15
16struct ServerState {
18 browser_clients: Arc<RwLock<Vec<tokio::sync::mpsc::Sender<String>>>>,
20 audio_client: Arc<RwLock<Option<String>>>,
22}
23
24impl ServerState {
25 fn new() -> Self {
26 Self {
27 browser_clients: Arc::new(RwLock::new(Vec::new())),
28 audio_client: Arc::new(RwLock::new(None)),
29 }
30 }
31}
32
33#[derive(Clone)]
39pub struct WsHandle {
40 state: Arc<ServerState>,
41}
42
43impl WsHandle {
44 pub async fn broadcast(&self, json: &str) {
46 let clients = self.state.browser_clients.read().await;
47 for client in clients.iter() {
48 if let Err(e) = client.try_send(json.to_owned()) {
49 warn!("Failed to broadcast message to client: {}", e);
50 }
51 }
52 }
53}
54
55pub struct WsServer<H: ParameterHost + 'static> {
57 port: u16,
59 handler: Arc<IpcHandler<H>>,
61 shutdown_tx: broadcast::Sender<()>,
63 verbose: bool,
65 state: Arc<ServerState>,
67}
68
69impl<H: ParameterHost + 'static> WsServer<H> {
70 pub fn new(port: u16, handler: Arc<IpcHandler<H>>, verbose: bool) -> Self {
72 let (shutdown_tx, _) = broadcast::channel(1);
73 Self {
74 port,
75 handler,
76 shutdown_tx,
77 verbose,
78 state: Arc::new(ServerState::new()),
79 }
80 }
81
82 pub fn handle(&self) -> WsHandle {
87 WsHandle {
88 state: Arc::clone(&self.state),
89 }
90 }
91
92 pub async fn broadcast_parameters_changed(&self) -> Result<(), serde_json::Error> {
97 use wavecraft_protocol::IpcNotification;
98
99 let notification = IpcNotification::new("parametersChanged", serde_json::json!({}));
100 let json = serde_json::to_string(¬ification)?;
101
102 let clients = self.state.browser_clients.read().await;
103 for client in clients.iter() {
104 if let Err(e) = client.try_send(json.clone()) {
105 warn!("Failed to send parametersChanged notification to client: {}", e);
106 }
107 }
108
109 Ok(())
110 }
111
112 pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
114 let addr: SocketAddr = format!("127.0.0.1:{}", self.port).parse()?;
115 let listener = TcpListener::bind(&addr).await?;
116
117 info!("Server listening on ws://{}", addr);
118
119 let handler = Arc::clone(&self.handler);
120 let mut shutdown_rx = self.shutdown_tx.subscribe();
121 let verbose = self.verbose;
122 let state = Arc::clone(&self.state);
123
124 tokio::spawn(async move {
125 loop {
126 tokio::select! {
127 result = listener.accept() => {
128 match result {
129 Ok((stream, addr)) => {
130 info!("Client connected: {}", addr);
131 let handler = Arc::clone(&handler);
132 let state = Arc::clone(&state);
133 tokio::spawn(handle_connection(handler, stream, addr, verbose, state));
134 }
135 Err(e) => {
136 error!("Accept error: {}", e);
137 }
138 }
139 }
140 _ = shutdown_rx.recv() => {
141 info!("Server shutting down");
142 break;
143 }
144 }
145 }
146 });
147
148 Ok(())
149 }
150
151 #[allow(dead_code)]
155 pub fn shutdown(&self) {
156 let _ = self.shutdown_tx.send(());
157 }
158}
159
160async fn handle_connection<H: ParameterHost>(
162 handler: Arc<IpcHandler<H>>,
163 stream: TcpStream,
164 addr: SocketAddr,
165 verbose: bool,
166 state: Arc<ServerState>,
167) {
168 let ws_stream = match accept_async(stream).await {
169 Ok(ws) => ws,
170 Err(e) => {
171 error!("Error during handshake with {}: {}", addr, e);
172 return;
173 }
174 };
175
176 info!("WebSocket connection established: {}", addr);
177
178 let (mut write, mut read) = ws_stream.split();
179 let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(128);
180
181 let mut is_audio_client = false;
183 let client_index = {
184 let mut clients = state.browser_clients.write().await;
185 clients.push(tx.clone());
186 clients.len() - 1
187 };
188
189 let write_task = tokio::spawn(async move {
191 while let Some(msg) = rx.recv().await {
192 if let Err(e) = write.send(Message::Text(msg)).await {
193 error!("Error sending to {}: {}", addr, e);
194 break;
195 }
196 }
197 });
198
199 while let Some(msg) = read.next().await {
200 match msg {
201 Ok(Message::Text(json)) => {
202 if verbose {
204 debug!("Received from {}: {}", addr, json);
205 }
206
207 let parsed_req = serde_json::from_str::<wavecraft_protocol::IpcRequest>(&json);
209
210 if let Ok(ref req) = parsed_req {
211 if req.method == "registerAudio" {
213 is_audio_client = true;
214 info!("Audio client registered: {}", addr);
215
216 if let Some(params) = req.params.clone()
218 && let Ok(audio_params) = serde_json::from_value::<
219 wavecraft_protocol::RegisterAudioParams,
220 >(params)
221 {
222 *state.audio_client.write().await = Some(audio_params.client_id.clone());
223 }
224
225 let response = wavecraft_protocol::IpcResponse::success(
227 req.id.clone(),
228 wavecraft_protocol::RegisterAudioResult {
229 status: "registered".to_string(),
230 },
231 );
232 let response_json = match serde_json::to_string(&response) {
233 Ok(json) => json,
234 Err(e) => {
235 error!("Failed to serialize registerAudio response: {}", e);
236 break;
237 }
238 };
239 if let Err(e) = tx.try_send(response_json) {
240 error!("Error sending response: {}", e);
241 break;
242 }
243 continue;
244 }
245
246 if is_audio_client && req.method == "meterUpdate" {
248 let clients = state.browser_clients.read().await;
250 for (idx, client) in clients.iter().enumerate() {
251 if idx != client_index {
252 if let Err(e) = client.try_send(json.clone()) {
254 warn!("Failed to broadcast meter update to client {}: {}", idx, e);
255 }
256 }
257 }
258 continue;
259 }
260 }
261
262 let response = handler.handle_json(&json);
264
265 if verbose {
267 debug!("Sending to {}: {}", addr, response);
268 }
269
270 if let Err(e) = tx.try_send(response) {
272 error!("Error queueing response: {}", e);
273 break;
274 }
275 }
276 Ok(Message::Close(_)) => {
277 info!("Client closed connection: {}", addr);
278 break;
279 }
280 Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {
281 }
283 Ok(Message::Binary(_)) => {
284 warn!("Unexpected binary message from {}", addr);
285 }
286 Ok(Message::Frame(_)) => {
287 }
289 Err(e) => {
290 error!("Error receiving from {}: {}", addr, e);
291 break;
292 }
293 }
294 }
295
296 state
298 .browser_clients
299 .write()
300 .await
301 .retain(|c| !c.is_closed());
302 if is_audio_client {
303 *state.audio_client.write().await = None;
304 info!("Audio client disconnected: {}", addr);
305 }
306
307 write_task.abort();
308 info!("Connection closed: {}", addr);
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use wavecraft_bridge::InMemoryParameterHost;
    use wavecraft_protocol::{ParameterInfo, ParameterType};

    /// Builds an in-memory host exposing a single float `gain` parameter.
    fn test_host() -> InMemoryParameterHost {
        let gain = ParameterInfo {
            id: "gain".to_string(),
            name: "Gain".to_string(),
            param_type: ParameterType::Float,
            value: 0.5,
            default: 0.5,
            unit: Some("dB".to_string()),
            group: Some("Input".to_string()),
        };
        InMemoryParameterHost::new(vec![gain])
    }

    /// Constructing a server stores the requested port without binding it.
    #[tokio::test]
    async fn test_server_creation() {
        let ipc = Arc::new(IpcHandler::new(test_host()));
        let server = WsServer::new(9001, ipc, false);

        assert_eq!(server.port, 9001);
    }
}