Skip to main content

walrus_socket/
server.rs

1//! Unix domain socket server — accept loop and per-connection message handler.
2
3use tokio::{
4    net::UnixListener,
5    sync::{mpsc, oneshot},
6};
7use wcore::protocol::{
8    codec,
9    message::{ClientMessage, ServerMessage},
10};
11
12/// Accept connections on the given `UnixListener` until shutdown is signalled.
13///
14/// Each connection is handled in a separate task. For each incoming
15/// `ClientMessage`, calls `on_message(msg, reply_tx)` where `reply_tx` is
16/// the per-connection sender for streaming `ServerMessage`s back.
17pub async fn accept_loop<F>(
18    listener: UnixListener,
19    on_message: F,
20    mut shutdown: oneshot::Receiver<()>,
21) where
22    F: Fn(ClientMessage, mpsc::UnboundedSender<ServerMessage>) + Clone + Send + 'static,
23{
24    loop {
25        tokio::select! {
26            result = listener.accept() => {
27                match result {
28                    Ok((stream, _addr)) => {
29                        let cb = on_message.clone();
30                        tokio::spawn(async move {
31                            let (mut reader, mut writer) = stream.into_split();
32                            let (tx, mut rx) = mpsc::unbounded_channel::<ServerMessage>();
33                            let send_task = tokio::spawn(async move {
34                                while let Some(msg) = rx.recv().await {
35                                    if let Err(e) = codec::write_message(&mut writer, &msg).await {
36                                        tracing::error!("failed to write message: {e}");
37                                        break;
38                                    }
39                                }
40                            });
41
42                            loop {
43                                let client_msg: ClientMessage = match codec::read_message(&mut reader).await {
44                                    Ok(msg) => msg,
45                                    Err(codec::FrameError::ConnectionClosed) => break,
46                                    Err(e) => { tracing::debug!("read error: {e}"); break; }
47                                };
48                                cb(client_msg, tx.clone());
49                            }
50
51                            drop(tx);
52                            let _ = send_task.await;
53                        });
54                    }
55                    Err(e) => tracing::error!("failed to accept connection: {e}"),
56                }
57            }
58            _ = &mut shutdown => {
59                tracing::info!("accept loop shutting down");
60                break;
61            }
62        }
63    }
64}