// taskers_control/socket.rs

use std::{
    future::Future,
    io,
    os::unix::{fs::FileTypeExt, net::UnixListener as StdUnixListener},
    path::Path,
};

use serde_json::{from_slice, to_vec};
use tokio::{
    io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
    net::{UnixListener, UnixStream},
};

use crate::{RequestFrame, controller::InMemoryController, protocol::ResponseFrame};

11pub fn bind_socket(path: impl AsRef<Path>) -> io::Result<UnixListener> {
12    let path = path.as_ref();
13    if path.exists() {
14        std::fs::remove_file(path)?;
15    }
16    let listener = StdUnixListener::bind(path)?;
17    listener.set_nonblocking(true)?;
18    UnixListener::from_std(listener)
19}
20
21pub async fn serve<S>(
22    listener: UnixListener,
23    controller: InMemoryController,
24    shutdown: S,
25) -> io::Result<()>
26where
27    S: Future<Output = ()> + Send,
28{
29    tokio::pin!(shutdown);
30
31    loop {
32        tokio::select! {
33            _ = &mut shutdown => break,
34            accepted = listener.accept() => {
35                let (stream, _) = accepted?;
36                let controller = controller.clone();
37                tokio::spawn(async move {
38                    let _ = handle_connection(stream, controller).await;
39                });
40            }
41        }
42    }
43
44    Ok(())
45}
46
47async fn handle_connection(stream: UnixStream, controller: InMemoryController) -> io::Result<()> {
48    let (read_half, mut write_half) = stream.into_split();
49    let mut reader = BufReader::new(read_half);
50    let mut line = String::new();
51    reader.read_line(&mut line).await?;
52
53    let request: RequestFrame = from_slice(line.trim_end().as_bytes()).map_err(invalid_data)?;
54    let response = ResponseFrame {
55        request_id: request.request_id,
56        response: controller
57            .handle(request.command)
58            .map_err(|error| error.to_string()),
59    };
60    let payload = to_vec(&response).map_err(invalid_data)?;
61    write_half.write_all(&payload).await?;
62    write_half.write_all(b"\n").await?;
63    write_half.flush().await?;
64
65    Ok(())
66}
67
/// Converts any displayable error into an [`io::Error`] of kind `InvalidData`.
///
/// Only the error's rendered message is kept; the original value is dropped.
fn invalid_data(error: impl ToString) -> io::Error {
    let message = error.to_string();
    io::Error::new(io::ErrorKind::InvalidData, message)
}

#[cfg(test)]
mod tests {
    use tempfile::tempdir;
    use tokio::sync::oneshot;

    use taskers_domain::AppModel;

    use crate::{
        client::ControlClient,
        controller::InMemoryController,
        protocol::{ControlCommand, ControlQuery, ControlResponse},
    };

    use super::{bind_socket, serve};

    /// End-to-end check of the socket round trip.
    ///
    /// A real client talks to `serve` over a temporary Unix socket. It creates a workspace,
    /// sees it reflected in a status query, and the server then stops cleanly once the
    /// shutdown signal fires.
    #[tokio::test]
    async fn client_and_server_roundtrip() {
        let tempdir = tempdir().expect("tempdir");
        let socket_path = tempdir.path().join("taskers.sock");
        let listener = bind_socket(&socket_path).expect("listener");
        let controller = InMemoryController::new(AppModel::new("Main"));
        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();

        let server = tokio::spawn(serve(listener, controller, async move {
            let _ = shutdown_rx.await;
        }));

        let client = ControlClient::new(&socket_path);
        let created = client
            .send(ControlCommand::CreateWorkspace {
                label: "Docs".into(),
            })
            .await
            .expect("create workspace request");
        assert!(matches!(
            created.response,
            Ok(ControlResponse::WorkspaceCreated { .. })
        ));

        let status = client
            .send(ControlCommand::QueryStatus {
                query: ControlQuery::All,
            })
            .await
            .expect("query request");
        match status.response {
            Ok(ControlResponse::Status { session }) => {
                // Presumably the initial "Main" workspace plus the newly created "Docs".
                assert_eq!(session.model.workspaces.len(), 2);
            }
            other => panic!("unexpected response: {other:?}"),
        }

        shutdown_tx.send(()).expect("shutdown");
        server.await.expect("server task").expect("serve cleanly");
    }
}