1use anyhow::Result;
37use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
38use axum::extract::State;
39use axum::response::Response;
40use axum::routing::get;
41use axum::Router;
42use std::net::SocketAddr;
43use std::sync::{Arc, Mutex};
44use tokio::sync::broadcast;
45
46use crate::Event;
47
/// A hot-patch to push to every connected dev client.
///
/// The tokio broadcast channel hands each subscriber its own clone of a
/// `Patch`; the dylib image sits behind an `Arc` so those clones share the
/// bytes instead of copying them.
#[derive(Debug, Clone)]
pub struct Patch {
    /// Jump table for this patch; written into the frame's JSON header.
    pub table: subsecond_types::JumpTable,
    /// Raw patch library bytes, appended verbatim after the JSON header.
    pub dylib_bytes: Arc<Vec<u8>>,
}
61
/// JSON header that precedes the dylib bytes in a binary patch frame.
///
/// Internally tagged, so the single variant serializes as
/// `{"kind":"patch","table":{...}}`, with the table written by
/// [`wire_jump_table::serialize`].
#[derive(Debug, Clone, serde::Serialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
enum PatchHeader<'a> {
    Patch {
        /// Borrowed from the `Patch` so encoding a frame does not clone the table.
        #[serde(serialize_with = "wire_jump_table::serialize")]
        table: &'a subsecond_types::JumpTable,
    },
}
80
81pub mod wire_jump_table {
86 use serde::ser::SerializeStruct;
87 use serde::Serializer;
88 use subsecond_types::JumpTable;
89
90 pub fn serialize<S: Serializer>(t: &JumpTable, s: S) -> Result<S::Ok, S::Error> {
91 let pairs: Vec<(u64, u64)> = t.map.iter().map(|(k, v)| (*k, *v)).collect();
92 let mut st = s.serialize_struct("JumpTable", 5)?;
93 st.serialize_field("lib", &t.lib)?;
94 st.serialize_field("map", &pairs)?;
95 st.serialize_field("aslr_reference", &t.aslr_reference)?;
96 st.serialize_field("new_base_address", &t.new_base_address)?;
97 st.serialize_field("ifunc_count", &t.ifunc_count)?;
98 st.end()
99 }
100}
101
/// Cloneable handle for broadcasting patches to clients of the server
/// started by [`serve`], and for reading the ASLR reference they report.
#[derive(Clone)]
pub struct PatchSender {
    // Sending half of the broadcast channel every client connection subscribes to.
    tx: broadcast::Sender<Patch>,
    // ASLR reference from the most recent client `hello`. The same `Arc` is held
    // by the connection handlers, which write it.
    aslr_reference: Arc<Mutex<Option<u64>>>,
}
116
117impl PatchSender {
118 pub fn send(&self, patch: Patch) -> usize {
122 self.tx.send(patch).unwrap_or(0)
123 }
124
125 pub fn client_count(&self) -> usize {
127 self.tx.receiver_count()
128 }
129
130 pub fn latest_aslr_reference(&self) -> Option<u64> {
136 self.aslr_reference.lock().ok().and_then(|g| *g)
137 }
138}
139
/// Router state handed to every WebSocket connection handler.
#[derive(Clone)]
struct AppState {
    // Broadcast sender; each connection calls `subscribe()` on it.
    tx: broadcast::Sender<Patch>,
    // Optional observer for connect, disconnect and device-log events.
    on_event: Option<Arc<dyn Fn(Event) + Send + Sync>>,
    // Shared with `PatchSender`; overwritten when a client sends `hello`.
    aslr_reference: Arc<Mutex<Option<u64>>>,
}
146
147pub async fn serve(
156 addr: SocketAddr,
157 on_event: Option<Arc<dyn Fn(Event) + Send + Sync>>,
158) -> Result<(PatchSender, SocketAddr, tokio::task::JoinHandle<()>)> {
159 let (tx, _rx) = broadcast::channel::<Patch>(16);
160 let aslr_reference: Arc<Mutex<Option<u64>>> = Arc::new(Mutex::new(None));
161 let state = AppState {
162 tx: tx.clone(),
163 on_event,
164 aslr_reference: Arc::clone(&aslr_reference),
165 };
166
167 let app = Router::new()
168 .route("/whisker-dev", get(ws_handler))
169 .with_state(state);
170
171 let listener = tokio::net::TcpListener::bind(addr).await?;
172 let bound = listener.local_addr()?;
173
174 let handle = tokio::spawn(async move {
175 if let Err(e) = axum::serve(listener, app).await {
176 whisker_build::ui::error(format!("axum serve error: {e}"));
177 }
178 });
179
180 Ok((PatchSender { tx, aslr_reference }, bound, handle))
181}
182
183async fn ws_handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
184 ws.on_upgrade(move |socket| handle_socket(socket, state))
185}
186
187async fn handle_socket(socket: WebSocket, state: AppState) {
188 use futures_util::{SinkExt, StreamExt};
189
190 let (mut tx_ws, mut rx_ws) = socket.split();
191 let mut bcast_rx = state.tx.subscribe();
192 whisker_build::ui::set_status(format!("{} client(s) connected", state.tx.receiver_count(),));
193 if let Some(cb) = &state.on_event {
196 cb(Event::ClientConnected);
197 }
198
199 loop {
200 tokio::select! {
201 recv = bcast_rx.recv() => {
203 let patch = match recv {
204 Ok(p) => p,
205 Err(broadcast::error::RecvError::Lagged(_)) => continue,
206 Err(broadcast::error::RecvError::Closed) => break,
207 };
208 let frame = match encode_patch_frame(&patch) {
209 Ok(b) => b,
210 Err(e) => {
211 whisker_build::ui::warn(format!("encode patch frame: {e}"));
212 continue;
213 }
214 };
215 if tx_ws.send(Message::Binary(frame.into())).await.is_err() {
216 break;
217 }
218 }
219 msg = rx_ws.next() => {
224 match msg {
225 Some(Ok(Message::Close(_))) | None => break,
226 Some(Err(_)) => break,
227 Some(Ok(Message::Text(t))) => {
228 if let Some(aslr) = parse_client_aslr_reference(&t) {
229 whisker_build::ui::debug(format!(
230 "client hello · aslr_reference={aslr:#x}"
231 ));
232 if let Ok(mut g) = state.aslr_reference.lock() {
233 *g = Some(aslr);
234 }
235 } else if let Some(log) = parse_client_log(&t) {
236 if let Some(cb) = &state.on_event {
237 cb(Event::DeviceLog {
238 stream: log.stream,
239 line: log.line,
240 ts_micros: log.ts_micros,
241 });
242 }
243 }
244 }
245 _ => {}
246 }
247 }
248 }
249 }
250
251 if let Some(cb) = &state.on_event {
252 cb(Event::ClientDisconnected);
253 }
254}
255
256fn encode_patch_frame(patch: &Patch) -> Result<Vec<u8>> {
259 let header = PatchHeader::Patch {
260 table: &patch.table,
261 };
262 let json = serde_json::to_vec(&header)?;
263 let json_len = json.len() as u64;
264 let dylib = patch.dylib_bytes.as_slice();
265 let mut frame = Vec::with_capacity(8 + json.len() + dylib.len());
266 frame.extend_from_slice(&json_len.to_be_bytes());
267 frame.extend_from_slice(&json);
268 frame.extend_from_slice(dylib);
269 Ok(frame)
270}
271
272fn parse_client_aslr_reference(text: &str) -> Option<u64> {
277 #[derive(serde::Deserialize)]
278 struct Hello {
279 kind: String,
280 aslr_reference: u64,
281 }
282 let h: Hello = serde_json::from_str(text).ok()?;
283 if h.kind == "hello" {
284 Some(h.aslr_reference)
285 } else {
286 None
287 }
288}
289
/// A device log line decoded from a client `log` envelope.
struct ClientLog {
    // Stream name as sent by the client (the tests use "stdout" / "stderr").
    stream: String,
    // The log line text.
    line: String,
    // Client-reported timestamp (microseconds, per the wire field name);
    // 0 when the client omitted it or it could not be parsed.
    ts_micros: u128,
}
297
298fn parse_client_log(text: &str) -> Option<ClientLog> {
309 #[derive(serde::Deserialize)]
310 struct Log {
311 kind: String,
312 stream: String,
313 line: String,
314 #[serde(default)]
315 ts_micros: Option<String>,
316 }
317 let h: Log = serde_json::from_str(text).ok()?;
318 if h.kind != "log" {
319 return None;
320 }
321 let ts_micros = h
322 .ts_micros
323 .as_deref()
324 .and_then(|s| s.parse::<u128>().ok())
325 .unwrap_or(0);
326 Some(ClientLog {
327 stream: h.stream,
328 line: h.line,
329 ts_micros,
330 })
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{SinkExt, StreamExt};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a `JumpTable` fixture with an empty address map by deserializing
    /// a JSON literal. `aslr_reference` is 0x1_0000_0000 and
    /// `new_base_address` is 0x2_0000_0000.
    fn make_dummy_jump_table() -> subsecond_types::JumpTable {
        let json = r#"{
            "lib": "/tmp/dummy.dylib",
            "map": {},
            "aslr_reference": 4294967296,
            "new_base_address": 8589934592,
            "ifunc_count": 0
        }"#;
        serde_json::from_str(json).expect("dummy JumpTable")
    }

    /// Starts a real server on an OS-assigned loopback port (`127.0.0.1:0`)
    /// and returns its patch sender plus the address actually bound.
    ///
    /// The server task's `JoinHandle` is dropped, which detaches the task
    /// rather than aborting it.
    async fn spawn_test_server(
        on_event: Option<Arc<dyn Fn(Event) + Send + Sync>>,
    ) -> (PatchSender, SocketAddr) {
        let any: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let (sender, addr, _handle) = serve(any, on_event).await.expect("serve");
        (sender, addr)
    }

    /// Opens a `tokio_tungstenite` client connection to `/whisker-dev` on `addr`.
    async fn connect(
        addr: SocketAddr,
    ) -> tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>
    {
        let url = format!("ws://{addr}/whisker-dev");
        let (ws, _) = tokio_tungstenite::connect_async(&url)
            .await
            .expect("connect");
        ws
    }

    /// Test-side decoder for the patch frame layout.
    ///
    /// The layout is an 8-byte big-endian JSON length, then the JSON header,
    /// then the dylib bytes. Panics if the frame is shorter than its length
    /// prefix claims.
    fn decode_patch_frame(bytes: &[u8]) -> (serde_json::Value, Vec<u8>) {
        assert!(bytes.len() >= 8, "frame too short");
        let json_len = u64::from_be_bytes(bytes[..8].try_into().unwrap()) as usize;
        assert!(bytes.len() >= 8 + json_len, "frame truncated");
        let header: serde_json::Value =
            serde_json::from_slice(&bytes[8..8 + json_len]).expect("parse header");
        let dylib = bytes[8 + json_len..].to_vec();
        (header, dylib)
    }

    #[tokio::test]
    async fn client_can_connect_and_receive_a_broadcast_patch() {
        let (sender, addr) = spawn_test_server(None).await;
        let mut client = connect(addr).await;

        // The server subscribes this client inside `handle_socket`, which may
        // run after `connect` returns; poll (up to ~500ms) until it shows up.
        for _ in 0..100 {
            if sender.client_count() > 0 {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
        assert_eq!(sender.client_count(), 1);

        let table = make_dummy_jump_table();
        let n = sender.send(Patch {
            table: table.clone(),
            dylib_bytes: Arc::new(b"FAKE_DYLIB_BYTES".to_vec()),
        });
        // `send` reports how many subscribers the patch was queued for.
        assert_eq!(n, 1);

        let msg = tokio::time::timeout(std::time::Duration::from_secs(2), client.next())
            .await
            .expect("recv timed out")
            .expect("stream ended")
            .expect("ws error");
        let bytes = match msg {
            tokio_tungstenite::tungstenite::Message::Binary(b) => b,
            other => panic!("expected binary, got {other:?}"),
        };
        // The header must carry the table fields and the trailing bytes must be
        // the dylib payload, byte for byte.
        let (header, dylib) = decode_patch_frame(&bytes);
        assert_eq!(header["kind"], "patch");
        assert_eq!(header["table"]["lib"], "/tmp/dummy.dylib");
        assert_eq!(header["table"]["aslr_reference"], 4294967296_u64);
        assert_eq!(dylib, b"FAKE_DYLIB_BYTES");
    }

    #[tokio::test]
    async fn send_with_no_clients_returns_zero_and_does_not_error() {
        let (sender, _addr) = spawn_test_server(None).await;
        assert_eq!(sender.client_count(), 0);
        // With no subscribers the broadcast send fails; `send` maps that to 0.
        let n = sender.send(Patch {
            table: make_dummy_jump_table(),
            dylib_bytes: Arc::new(Vec::new()),
        });
        assert_eq!(n, 0);
    }

    #[tokio::test]
    async fn multiple_clients_each_receive_the_same_patch() {
        let (sender, addr) = spawn_test_server(None).await;
        let mut a = connect(addr).await;
        let mut b = connect(addr).await;

        // Wait (up to ~500ms) until both connections have subscribed server-side.
        for _ in 0..100 {
            if sender.client_count() == 2 {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
        assert_eq!(sender.client_count(), 2);

        let n = sender.send(Patch {
            table: make_dummy_jump_table(),
            dylib_bytes: Arc::new(b"SHARED".to_vec()),
        });
        assert_eq!(n, 2);

        // Each client must get its own binary frame within the timeout.
        for client in [&mut a, &mut b] {
            let msg = tokio::time::timeout(std::time::Duration::from_secs(2), client.next())
                .await
                .expect("timeout")
                .expect("stream end")
                .expect("ws err");
            assert!(matches!(
                msg,
                tokio_tungstenite::tungstenite::Message::Binary(_)
            ));
        }
    }

    #[tokio::test]
    async fn on_event_callback_fires_for_connect_and_disconnect() {
        let connect_count = Arc::new(AtomicUsize::new(0));
        let disconnect_count = Arc::new(AtomicUsize::new(0));

        // Count only the connect/disconnect events; anything else is ignored.
        let cc = connect_count.clone();
        let dc = disconnect_count.clone();
        let on_event: Arc<dyn Fn(Event) + Send + Sync> = Arc::new(move |e| match e {
            Event::ClientConnected => {
                cc.fetch_add(1, Ordering::SeqCst);
            }
            Event::ClientDisconnected => {
                dc.fetch_add(1, Ordering::SeqCst);
            }
            _ => {}
        });

        let (sender, addr) = spawn_test_server(Some(on_event)).await;

        let mut client = connect(addr).await;
        // The connect event fires from the server task; poll up to ~500ms.
        for _ in 0..100 {
            if connect_count.load(Ordering::SeqCst) == 1 {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
        assert_eq!(connect_count.load(Ordering::SeqCst), 1);

        // Send a Close frame, then drop the stream so the socket goes away too.
        client
            .send(tokio_tungstenite::tungstenite::Message::Close(None))
            .await
            .expect("send close");
        drop(client);

        // Poll up to ~1s for the server to notice and fire the disconnect event.
        for _ in 0..200 {
            if disconnect_count.load(Ordering::SeqCst) == 1 {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
        assert_eq!(disconnect_count.load(Ordering::SeqCst), 1);

        // The disconnected client's broadcast subscription must be gone as well.
        assert_eq!(sender.client_count(), 0);
    }

    #[test]
    fn parse_client_log_decodes_a_well_formed_frame() {
        let log = parse_client_log(
            r#"{"kind":"log","stream":"stdout","line":"hello world","ts_micros":"12345"}"#,
        )
        .expect("valid log envelope");
        assert_eq!(log.stream, "stdout");
        assert_eq!(log.line, "hello world");
        assert_eq!(log.ts_micros, 12345);
    }

    #[test]
    fn parse_client_log_falls_back_to_zero_ts_when_missing() {
        let log =
            parse_client_log(r#"{"kind":"log","stream":"stderr","line":"oops"}"#).expect("valid");
        assert_eq!(log.stream, "stderr");
        assert_eq!(log.line, "oops");
        assert_eq!(log.ts_micros, 0);
    }

    #[test]
    fn parse_client_log_rejects_other_kinds() {
        // A `hello` envelope must not be mistaken for a log line.
        assert!(parse_client_log(r#"{"kind":"hello","aslr_reference":42}"#,).is_none());
    }

    #[tokio::test]
    async fn on_event_callback_fires_with_device_log_lines() {
        use std::sync::Mutex;
        // Collects every `DeviceLog` event as (stream, line, ts_micros).
        let captured: Arc<Mutex<Vec<(String, String, u128)>>> = Arc::new(Mutex::new(Vec::new()));
        let captured_clone = Arc::clone(&captured);
        let on_event: Arc<dyn Fn(Event) + Send + Sync> = Arc::new(move |e| {
            if let Event::DeviceLog {
                stream,
                line,
                ts_micros,
            } = e
            {
                captured_clone
                    .lock()
                    .unwrap()
                    .push((stream, line, ts_micros));
            }
        });

        let (sender, addr) = spawn_test_server(Some(on_event)).await;
        let mut client = connect(addr).await;
        // Wait (up to ~500ms) until the server has registered the connection.
        for _ in 0..100 {
            if sender.client_count() > 0 {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
        assert_eq!(sender.client_count(), 1);

        // Send a `log` envelope the way a device would; `ts_micros` is a string.
        client
            .send(tokio_tungstenite::tungstenite::Message::Text(
                r#"{"kind":"log","stream":"stdout","line":"hi from device","ts_micros":"42"}"#
                    .into(),
            ))
            .await
            .expect("send log frame");

        // The callback runs on the server task; poll (up to ~500ms) for it.
        for _ in 0..100 {
            if !captured.lock().unwrap().is_empty() {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
        let g = captured.lock().unwrap();
        assert_eq!(g.len(), 1);
        assert_eq!(g[0].0, "stdout");
        assert_eq!(g[0].1, "hi from device");
        assert_eq!(g[0].2, 42);
    }
}