wavecraft_dev_server/
ws_server.rs1use futures_util::{SinkExt, StreamExt};
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::{RwLock, broadcast};
12use tokio_tungstenite::{accept_async, tungstenite::protocol::Message};
13use tracing::{debug, error, info, warn};
14use wavecraft_bridge::{IpcHandler, ParameterHost};
15
16struct ServerState {
18 browser_clients: Arc<RwLock<Vec<tokio::sync::mpsc::UnboundedSender<String>>>>,
20 audio_client: Arc<RwLock<Option<String>>>,
22}
23
24impl ServerState {
25 fn new() -> Self {
26 Self {
27 browser_clients: Arc::new(RwLock::new(Vec::new())),
28 audio_client: Arc::new(RwLock::new(None)),
29 }
30 }
31}
32
33#[derive(Clone)]
39#[allow(dead_code)] pub struct WsHandle {
41 state: Arc<ServerState>,
42}
43
44#[allow(dead_code)] impl WsHandle {
46 pub async fn broadcast(&self, json: &str) {
48 let clients = self.state.browser_clients.read().await;
49 for client in clients.iter() {
50 let _ = client.send(json.to_owned());
51 }
52 }
53}
54
55pub struct WsServer<H: ParameterHost + 'static> {
57 port: u16,
59 handler: Arc<IpcHandler<H>>,
61 shutdown_tx: broadcast::Sender<()>,
63 verbose: bool,
65 state: Arc<ServerState>,
67}
68
69impl<H: ParameterHost + 'static> WsServer<H> {
70 pub fn new(port: u16, handler: Arc<IpcHandler<H>>, verbose: bool) -> Self {
72 let (shutdown_tx, _) = broadcast::channel(1);
73 Self {
74 port,
75 handler,
76 shutdown_tx,
77 verbose,
78 state: Arc::new(ServerState::new()),
79 }
80 }
81
82 #[allow(dead_code)] pub fn handle(&self) -> WsHandle {
88 WsHandle {
89 state: Arc::clone(&self.state),
90 }
91 }
92
93 pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
95 let addr: SocketAddr = format!("127.0.0.1:{}", self.port).parse()?;
96 let listener = TcpListener::bind(&addr).await?;
97
98 info!("Server listening on ws://{}", addr);
99
100 let handler = Arc::clone(&self.handler);
101 let mut shutdown_rx = self.shutdown_tx.subscribe();
102 let verbose = self.verbose;
103 let state = Arc::clone(&self.state);
104
105 tokio::spawn(async move {
106 loop {
107 tokio::select! {
108 result = listener.accept() => {
109 match result {
110 Ok((stream, addr)) => {
111 info!("Client connected: {}", addr);
112 let handler = Arc::clone(&handler);
113 let state = Arc::clone(&state);
114 tokio::spawn(handle_connection(handler, stream, addr, verbose, state));
115 }
116 Err(e) => {
117 error!("Accept error: {}", e);
118 }
119 }
120 }
121 _ = shutdown_rx.recv() => {
122 info!("Server shutting down");
123 break;
124 }
125 }
126 }
127 });
128
129 Ok(())
130 }
131
132 #[allow(dead_code)]
136 pub fn shutdown(&self) {
137 let _ = self.shutdown_tx.send(());
138 }
139}
140
141async fn handle_connection<H: ParameterHost>(
143 handler: Arc<IpcHandler<H>>,
144 stream: TcpStream,
145 addr: SocketAddr,
146 verbose: bool,
147 state: Arc<ServerState>,
148) {
149 let ws_stream = match accept_async(stream).await {
150 Ok(ws) => ws,
151 Err(e) => {
152 error!("Error during handshake with {}: {}", addr, e);
153 return;
154 }
155 };
156
157 info!("WebSocket connection established: {}", addr);
158
159 let (mut write, mut read) = ws_stream.split();
160 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<String>();
161
162 let mut is_audio_client = false;
164 state.browser_clients.write().await.push(tx.clone());
165 let client_index = state.browser_clients.read().await.len() - 1;
166
167 let write_task = tokio::spawn(async move {
169 while let Some(msg) = rx.recv().await {
170 if let Err(e) = write.send(Message::Text(msg)).await {
171 error!("Error sending to {}: {}", addr, e);
172 break;
173 }
174 }
175 });
176
177 while let Some(msg) = read.next().await {
178 match msg {
179 Ok(Message::Text(json)) => {
180 if verbose {
182 debug!("Received from {}: {}", addr, json);
183 }
184
185 if json.contains("\"method\":\"registerAudio\"") {
187 is_audio_client = true;
188 info!("Audio client registered: {}", addr);
189
190 if let Ok(req) = serde_json::from_str::<wavecraft_protocol::IpcRequest>(&json)
192 && let Some(params) = req.params
193 && let Ok(audio_params) = serde_json::from_value::<
194 wavecraft_protocol::RegisterAudioParams,
195 >(params)
196 {
197 *state.audio_client.write().await = Some(audio_params.client_id.clone());
198 }
199
200 let response = wavecraft_protocol::IpcResponse::success(
202 wavecraft_protocol::RequestId::Number(1),
203 wavecraft_protocol::RegisterAudioResult {
204 status: "registered".to_string(),
205 },
206 );
207 let response_json = serde_json::to_string(&response).unwrap();
208 if let Err(e) = tx.send(response_json) {
209 error!("Error sending response: {}", e);
210 break;
211 }
212 continue;
213 }
214
215 if is_audio_client && json.contains("\"method\":\"meterUpdate\"") {
217 let clients = state.browser_clients.read().await;
219 for (idx, client) in clients.iter().enumerate() {
220 if idx != client_index {
221 let _ = client.send(json.clone());
223 }
224 }
225 continue;
226 }
227
228 let response = handler.handle_json(&json);
230
231 if verbose {
233 debug!("Sending to {}: {}", addr, response);
234 }
235
236 if let Err(e) = tx.send(response) {
238 error!("Error queueing response: {}", e);
239 break;
240 }
241 }
242 Ok(Message::Close(_)) => {
243 info!("Client closed connection: {}", addr);
244 break;
245 }
246 Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {
247 }
249 Ok(Message::Binary(_)) => {
250 warn!("Unexpected binary message from {}", addr);
251 }
252 Ok(Message::Frame(_)) => {
253 }
255 Err(e) => {
256 error!("Error receiving from {}: {}", addr, e);
257 break;
258 }
259 }
260 }
261
262 state
264 .browser_clients
265 .write()
266 .await
267 .retain(|c| !c.is_closed());
268 if is_audio_client {
269 *state.audio_client.write().await = None;
270 info!("Audio client disconnected: {}", addr);
271 }
272
273 write_task.abort();
274 info!("Connection closed: {}", addr);
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::AppState;

    /// Constructing a server records the requested port; `new` binds nothing.
    #[tokio::test]
    async fn test_server_creation() {
        let handler = Arc::new(IpcHandler::new(AppState::new()));
        let server = WsServer::new(9001, handler, false);
        assert_eq!(server.port, 9001);
    }
}