1use anyhow::Result;
37use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
38use axum::extract::State;
39use axum::response::Response;
40use axum::routing::get;
41use axum::Router;
42use std::net::SocketAddr;
43use std::sync::{Arc, Mutex};
44use tokio::sync::broadcast;
45
46use crate::Event;
47
48#[derive(Debug, Clone)]
52pub struct Patch {
53 pub table: subsecond_types::JumpTable,
56 pub dylib_bytes: Arc<Vec<u8>>,
60}
61
62#[derive(Debug, Clone, serde::Serialize)]
73#[serde(tag = "kind", rename_all = "snake_case")]
74enum PatchHeader<'a> {
75 Patch {
76 #[serde(serialize_with = "wire_jump_table::serialize")]
77 table: &'a subsecond_types::JumpTable,
78 },
79}
80
81pub mod wire_jump_table {
86 use serde::ser::SerializeStruct;
87 use serde::Serializer;
88 use subsecond_types::JumpTable;
89
90 pub fn serialize<S: Serializer>(t: &JumpTable, s: S) -> Result<S::Ok, S::Error> {
91 let pairs: Vec<(u64, u64)> = t.map.iter().map(|(k, v)| (*k, *v)).collect();
92 let mut st = s.serialize_struct("JumpTable", 5)?;
93 st.serialize_field("lib", &t.lib)?;
94 st.serialize_field("map", &pairs)?;
95 st.serialize_field("aslr_reference", &t.aslr_reference)?;
96 st.serialize_field("new_base_address", &t.new_base_address)?;
97 st.serialize_field("ifunc_count", &t.ifunc_count)?;
98 st.end()
99 }
100}
101
102#[derive(Clone)]
106pub struct PatchSender {
107 tx: broadcast::Sender<Patch>,
108 aslr_reference: Arc<Mutex<Option<u64>>>,
115}
116
117impl PatchSender {
118 pub fn send(&self, patch: Patch) -> usize {
122 self.tx.send(patch).unwrap_or(0)
123 }
124
125 pub fn client_count(&self) -> usize {
127 self.tx.receiver_count()
128 }
129
130 pub fn latest_aslr_reference(&self) -> Option<u64> {
136 self.aslr_reference.lock().ok().and_then(|g| *g)
137 }
138}
139
140#[derive(Clone)]
141struct AppState {
142 tx: broadcast::Sender<Patch>,
143 on_event: Option<Arc<dyn Fn(Event) + Send + Sync>>,
144 aslr_reference: Arc<Mutex<Option<u64>>>,
145 expected_token: Option<Arc<str>>,
153}
154
155pub async fn serve(
164 addr: SocketAddr,
165 on_event: Option<Arc<dyn Fn(Event) + Send + Sync>>,
166 expected_token: Option<String>,
167) -> Result<(PatchSender, SocketAddr, tokio::task::JoinHandle<()>)> {
168 let (tx, _rx) = broadcast::channel::<Patch>(16);
169 let aslr_reference: Arc<Mutex<Option<u64>>> = Arc::new(Mutex::new(None));
170 let state = AppState {
171 tx: tx.clone(),
172 on_event,
173 aslr_reference: Arc::clone(&aslr_reference),
174 expected_token: expected_token.map(Arc::from),
175 };
176
177 let app = Router::new()
178 .route("/whisker-dev", get(ws_handler))
179 .with_state(state);
180
181 let listener = tokio::net::TcpListener::bind(addr).await?;
182 let bound = listener.local_addr()?;
183
184 let handle = tokio::spawn(async move {
185 if let Err(e) = axum::serve(listener, app).await {
186 whisker_build::ui::error(format!("axum serve error: {e}"));
187 }
188 });
189
190 Ok((PatchSender { tx, aslr_reference }, bound, handle))
191}
192
193async fn ws_handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
194 ws.on_upgrade(move |socket| handle_socket(socket, state))
195}
196
197async fn handle_socket(socket: WebSocket, state: AppState) {
198 use futures_util::{SinkExt, StreamExt};
199
200 let (mut tx_ws, mut rx_ws) = socket.split();
201 let mut bcast_rx = state.tx.subscribe();
202 whisker_build::ui::set_status(format!("{} client(s) connected", state.tx.receiver_count(),));
203 if let Some(cb) = &state.on_event {
206 cb(Event::ClientConnected);
207 }
208
209 let mut authed = state.expected_token.is_none();
214
215 loop {
216 tokio::select! {
217 recv = bcast_rx.recv() => {
219 let patch = match recv {
220 Ok(p) => p,
221 Err(broadcast::error::RecvError::Lagged(_)) => continue,
222 Err(broadcast::error::RecvError::Closed) => break,
223 };
224 if !authed {
228 continue;
229 }
230 let frame = match encode_patch_frame(&patch) {
231 Ok(b) => b,
232 Err(e) => {
233 whisker_build::ui::warn(format!("encode patch frame: {e}"));
234 continue;
235 }
236 };
237 if tx_ws.send(Message::Binary(frame.into())).await.is_err() {
238 break;
239 }
240 }
241 msg = rx_ws.next() => {
246 match msg {
247 Some(Ok(Message::Close(_))) | None => break,
248 Some(Err(_)) => break,
249 Some(Ok(Message::Text(t))) => {
250 if let Some(hello) = parse_client_hello(&t) {
251 if let Some(expected) = &state.expected_token {
256 if hello.token.as_deref() != Some(expected.as_ref()) {
257 whisker_build::ui::warn(
258 "rejecting hot-reload client: missing/invalid dev token",
259 );
260 break;
261 }
262 authed = true;
263 }
264 let aslr = hello.aslr_reference;
265 whisker_build::ui::debug(format!(
266 "client hello · aslr_reference={aslr:#x}"
267 ));
268 if let Ok(mut g) = state.aslr_reference.lock() {
269 *g = Some(aslr);
270 }
271 } else if let Some(log) = parse_client_log(&t) {
272 if let Some(cb) = &state.on_event {
273 cb(Event::DeviceLog {
274 stream: log.stream,
275 line: log.line,
276 ts_micros: log.ts_micros,
277 });
278 }
279 }
280 }
281 _ => {}
282 }
283 }
284 }
285 }
286
287 if let Ok(mut g) = state.aslr_reference.lock() {
295 *g = None;
296 }
297
298 if let Some(cb) = &state.on_event {
299 cb(Event::ClientDisconnected);
300 }
301}
302
303fn encode_patch_frame(patch: &Patch) -> Result<Vec<u8>> {
306 let header = PatchHeader::Patch {
307 table: &patch.table,
308 };
309 let json = serde_json::to_vec(&header)?;
310 let json_len = json.len() as u64;
311 let dylib = patch.dylib_bytes.as_slice();
312 let mut frame = Vec::with_capacity(8 + json.len() + dylib.len());
313 frame.extend_from_slice(&json_len.to_be_bytes());
314 frame.extend_from_slice(&json);
315 frame.extend_from_slice(dylib);
316 Ok(frame)
317}
318
/// Contents of a client `hello` message, as produced by [`parse_client_hello`].
struct ClientHello {
    /// ASLR reference reported by the client.
    aslr_reference: u64,
    /// Dev token presented by the client, if it sent one.
    token: Option<String>,
}
325
326fn parse_client_hello(text: &str) -> Option<ClientHello> {
330 #[derive(serde::Deserialize)]
331 struct Hello {
332 kind: String,
333 aslr_reference: u64,
334 #[serde(default)]
335 token: Option<String>,
336 }
337 let h: Hello = serde_json::from_str(text).ok()?;
338 if h.kind == "hello" {
339 Some(ClientHello {
340 aslr_reference: h.aslr_reference,
341 token: h.token,
342 })
343 } else {
344 None
345 }
346}
347
/// Contents of a client `log` message, as produced by [`parse_client_log`].
struct ClientLog {
    /// Name of the stream the line came from, exactly as sent by the client.
    stream: String,
    /// The log line itself.
    line: String,
    /// Client-supplied timestamp; `0` when absent or unparsable.
    ts_micros: u128,
}
355
356fn parse_client_log(text: &str) -> Option<ClientLog> {
367 #[derive(serde::Deserialize)]
368 struct Log {
369 kind: String,
370 stream: String,
371 line: String,
372 #[serde(default)]
373 ts_micros: Option<String>,
374 }
375 let h: Log = serde_json::from_str(text).ok()?;
376 if h.kind != "log" {
377 return None;
378 }
379 let ts_micros = h
380 .ts_micros
381 .as_deref()
382 .and_then(|s| s.parse::<u128>().ok())
383 .unwrap_or(0);
384 Some(ClientLog {
385 stream: h.stream,
386 line: h.line,
387 ts_micros,
388 })
389}
390
/// End-to-end tests that start a real server on an ephemeral loopback port and
/// talk to it with a `tokio_tungstenite` client, plus unit tests for the log
/// parser.
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{SinkExt, StreamExt};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tokio_tungstenite::tungstenite::Message as TMsg;

    /// Client-side WebSocket stream type returned by `connect_async`.
    type TestClient =
        tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;

    /// Builds a minimal `JumpTable` by deserializing a JSON literal.
    fn make_dummy_jump_table() -> subsecond_types::JumpTable {
        let json = r#"{
            "lib": "/tmp/dummy.dylib",
            "map": {},
            "aslr_reference": 4294967296,
            "new_base_address": 8589934592,
            "ifunc_count": 0
        }"#;
        serde_json::from_str(json).expect("dummy JumpTable")
    }

    /// Starts a server without token checking on an OS-assigned loopback port.
    async fn spawn_test_server(
        on_event: Option<Arc<dyn Fn(Event) + Send + Sync>>,
    ) -> (PatchSender, SocketAddr) {
        let any: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let (sender, addr, _handle) = serve(any, on_event, None).await.expect("serve");
        (sender, addr)
    }

    /// Starts a server without an observer, optionally requiring `token`.
    async fn spawn_test_server_with_token(token: Option<String>) -> (PatchSender, SocketAddr) {
        let any: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let (sender, addr, _handle) = serve(any, None, token).await.expect("serve");
        (sender, addr)
    }

    /// Opens a WebSocket connection to the server's `/whisker-dev` route.
    async fn connect(addr: SocketAddr) -> TestClient {
        let url = format!("ws://{addr}/whisker-dev");
        let (ws, _) = tokio_tungstenite::connect_async(&url)
            .await
            .expect("connect");
        ws
    }

    /// Polls `cond` every 5 ms, at most `attempts` times, and returns as soon
    /// as it holds. Callers assert the condition afterwards, so a timeout shows
    /// up as a normal assertion failure.
    async fn poll_until(attempts: usize, mut cond: impl FnMut() -> bool) {
        for _ in 0..attempts {
            if cond() {
                return;
            }
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
    }

    /// Waits up to two seconds for the next message, panicking on timeout, end
    /// of stream, or a transport error.
    async fn next_message(client: &mut TestClient) -> TMsg {
        tokio::time::timeout(Duration::from_secs(2), client.next())
            .await
            .expect("recv timed out")
            .expect("stream ended")
            .expect("ws error")
    }

    /// Splits a patch frame into its JSON header and the trailing library bytes.
    fn decode_patch_frame(bytes: &[u8]) -> (serde_json::Value, Vec<u8>) {
        assert!(bytes.len() >= 8, "frame too short");
        let json_len = u64::from_be_bytes(bytes[..8].try_into().unwrap()) as usize;
        assert!(bytes.len() >= 8 + json_len, "frame truncated");
        let header: serde_json::Value =
            serde_json::from_slice(&bytes[8..8 + json_len]).expect("parse header");
        let dylib = bytes[8 + json_len..].to_vec();
        (header, dylib)
    }

    #[tokio::test]
    async fn client_can_connect_and_receive_a_broadcast_patch() {
        let (sender, addr) = spawn_test_server(None).await;
        let mut client = connect(addr).await;

        // The socket task subscribes asynchronously after the upgrade.
        poll_until(100, || sender.client_count() > 0).await;
        assert_eq!(sender.client_count(), 1);

        let n = sender.send(Patch {
            table: make_dummy_jump_table(),
            dylib_bytes: Arc::new(b"FAKE_DYLIB_BYTES".to_vec()),
        });
        assert_eq!(n, 1);

        let bytes = match next_message(&mut client).await {
            TMsg::Binary(b) => b,
            other => panic!("expected binary, got {other:?}"),
        };
        let (header, dylib) = decode_patch_frame(&bytes);
        assert_eq!(header["kind"], "patch");
        assert_eq!(header["table"]["lib"], "/tmp/dummy.dylib");
        assert_eq!(header["table"]["aslr_reference"], 4294967296_u64);
        assert_eq!(dylib, b"FAKE_DYLIB_BYTES");
    }

    #[tokio::test]
    async fn client_with_valid_token_is_armed_and_receives_patches() {
        let (sender, addr) = spawn_test_server_with_token(Some("s3kret".into())).await;
        let mut client = connect(addr).await;

        client
            .send(TMsg::Text(
                r#"{"kind":"hello","aslr_reference":4294967296,"token":"s3kret"}"#.into(),
            ))
            .await
            .expect("send hello");

        // The ASLR slot is written in the same step that marks the client as
        // authenticated, so it doubles as the "hello was accepted" signal.
        poll_until(200, || sender.latest_aslr_reference().is_some()).await;
        assert_eq!(sender.latest_aslr_reference(), Some(0x1_0000_0000));

        let n = sender.send(Patch {
            table: make_dummy_jump_table(),
            dylib_bytes: Arc::new(b"OK".to_vec()),
        });
        assert_eq!(n, 1);

        let msg = next_message(&mut client).await;
        assert!(
            matches!(msg, TMsg::Binary(_)),
            "authed client should receive the patch frame"
        );
    }

    #[tokio::test]
    async fn client_with_invalid_token_is_disconnected_and_gets_no_patch() {
        let (sender, addr) = spawn_test_server_with_token(Some("s3kret".into())).await;
        let mut client = connect(addr).await;

        // Broadcast a patch while the socket is subscribed but has not
        // authenticated. Without this, no patch exists that could leak and the
        // "gets no patch" half of the test would pass vacuously.
        poll_until(200, || sender.client_count() == 1).await;
        assert_eq!(sender.client_count(), 1);
        let n = sender.send(Patch {
            table: make_dummy_jump_table(),
            dylib_bytes: Arc::new(b"MUST_NOT_LEAK".to_vec()),
        });
        assert_eq!(n, 1, "the unauthenticated socket is still a subscriber");

        client
            .send(TMsg::Text(
                r#"{"kind":"hello","aslr_reference":1,"token":"WRONG"}"#.into(),
            ))
            .await
            .expect("send hello");

        // Drain the stream: it must end without a single binary frame.
        let ended = tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                match client.next().await {
                    Some(Ok(TMsg::Binary(_))) => return false,
                    None | Some(Ok(TMsg::Close(_))) | Some(Err(_)) => return true,
                    _ => continue,
                }
            }
        })
        .await
        .expect("disconnect timed out");
        assert!(
            ended,
            "unauthenticated client must be disconnected, not fed patches"
        );

        poll_until(200, || sender.client_count() == 0).await;
        assert_eq!(sender.client_count(), 0);
    }

    #[tokio::test]
    async fn send_with_no_clients_returns_zero_and_does_not_error() {
        let (sender, _addr) = spawn_test_server(None).await;
        assert_eq!(sender.client_count(), 0);
        let n = sender.send(Patch {
            table: make_dummy_jump_table(),
            dylib_bytes: Arc::new(Vec::new()),
        });
        assert_eq!(n, 0);
    }

    #[tokio::test]
    async fn multiple_clients_each_receive_the_same_patch() {
        let (sender, addr) = spawn_test_server(None).await;
        let mut a = connect(addr).await;
        let mut b = connect(addr).await;

        poll_until(100, || sender.client_count() == 2).await;
        assert_eq!(sender.client_count(), 2);

        let n = sender.send(Patch {
            table: make_dummy_jump_table(),
            dylib_bytes: Arc::new(b"SHARED".to_vec()),
        });
        assert_eq!(n, 2);

        for client in [&mut a, &mut b] {
            let msg = next_message(client).await;
            assert!(matches!(msg, TMsg::Binary(_)));
        }
    }

    #[tokio::test]
    async fn on_event_callback_fires_for_connect_and_disconnect() {
        let connect_count = Arc::new(AtomicUsize::new(0));
        let disconnect_count = Arc::new(AtomicUsize::new(0));

        let cc = Arc::clone(&connect_count);
        let dc = Arc::clone(&disconnect_count);
        let on_event: Arc<dyn Fn(Event) + Send + Sync> = Arc::new(move |e| match e {
            Event::ClientConnected => {
                cc.fetch_add(1, Ordering::SeqCst);
            }
            Event::ClientDisconnected => {
                dc.fetch_add(1, Ordering::SeqCst);
            }
            _ => {}
        });

        let (sender, addr) = spawn_test_server(Some(on_event)).await;

        let mut client = connect(addr).await;
        poll_until(100, || connect_count.load(Ordering::SeqCst) == 1).await;
        assert_eq!(connect_count.load(Ordering::SeqCst), 1);

        client
            .send(TMsg::Close(None))
            .await
            .expect("send close");
        drop(client);

        poll_until(200, || disconnect_count.load(Ordering::SeqCst) == 1).await;
        assert_eq!(disconnect_count.load(Ordering::SeqCst), 1);

        // The subscription is released by the socket task around the time the
        // callback fires; poll instead of assuming a particular ordering.
        poll_until(200, || sender.client_count() == 0).await;
        assert_eq!(sender.client_count(), 0);
    }

    #[test]
    fn parse_client_log_decodes_a_well_formed_frame() {
        let log = parse_client_log(
            r#"{"kind":"log","stream":"stdout","line":"hello world","ts_micros":"12345"}"#,
        )
        .expect("valid log envelope");
        assert_eq!(log.stream, "stdout");
        assert_eq!(log.line, "hello world");
        assert_eq!(log.ts_micros, 12345);
    }

    #[test]
    fn parse_client_log_falls_back_to_zero_ts_when_missing() {
        let log =
            parse_client_log(r#"{"kind":"log","stream":"stderr","line":"oops"}"#).expect("valid");
        assert_eq!(log.stream, "stderr");
        assert_eq!(log.line, "oops");
        assert_eq!(log.ts_micros, 0);
    }

    #[test]
    fn parse_client_log_rejects_other_kinds() {
        assert!(parse_client_log(r#"{"kind":"hello","aslr_reference":42}"#).is_none());
    }

    #[tokio::test]
    async fn on_event_callback_fires_with_device_log_lines() {
        let captured: Arc<Mutex<Vec<(String, String, u128)>>> = Arc::new(Mutex::new(Vec::new()));
        let captured_clone = Arc::clone(&captured);
        let on_event: Arc<dyn Fn(Event) + Send + Sync> = Arc::new(move |e| {
            if let Event::DeviceLog {
                stream,
                line,
                ts_micros,
            } = e
            {
                captured_clone
                    .lock()
                    .unwrap()
                    .push((stream, line, ts_micros));
            }
        });

        let (sender, addr) = spawn_test_server(Some(on_event)).await;
        let mut client = connect(addr).await;
        poll_until(100, || sender.client_count() > 0).await;
        assert_eq!(sender.client_count(), 1);

        client
            .send(TMsg::Text(
                r#"{"kind":"log","stream":"stdout","line":"hi from device","ts_micros":"42"}"#
                    .into(),
            ))
            .await
            .expect("send log frame");

        poll_until(100, || !captured.lock().unwrap().is_empty()).await;
        let g = captured.lock().unwrap();
        assert_eq!(g.len(), 1);
        assert_eq!(g[0].0, "stdout");
        assert_eq!(g[0].1, "hi from device");
        assert_eq!(g[0].2, 42);
    }
}