//! Unix-domain socket transport for the `taskers_control` control protocol
//! (`taskers_control/socket.rs`).

1use std::{
2    fs,
3    future::Future,
4    io,
5    os::unix::{fs::PermissionsExt, net::UnixListener as StdUnixListener},
6    path::Path,
7};
8
9use serde_json::{from_slice, to_vec};
10use tokio::{
11    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
12    net::{UnixListener, UnixStream},
13};
14
15use crate::{
16    RequestFrame,
17    controller::InMemoryController,
18    protocol::{ControlCommand, ControlError, ControlResponse, ResponseFrame},
19};
20
21pub fn bind_socket(path: impl AsRef<Path>) -> io::Result<UnixListener> {
22    let path = path.as_ref();
23    if path.exists() {
24        std::fs::remove_file(path)?;
25    }
26    let listener = StdUnixListener::bind(path)?;
27    set_private_socket_permissions(path)?;
28    listener.set_nonblocking(true)?;
29    UnixListener::from_std(listener)
30}
31
32pub async fn serve<S>(
33    listener: UnixListener,
34    controller: InMemoryController,
35    shutdown: S,
36) -> io::Result<()>
37where
38    S: Future<Output = ()> + Send,
39{
40    serve_with_handler(
41        listener,
42        move |command| {
43            let controller = controller.clone();
44            async move {
45                controller
46                    .handle(command)
47                    .map_err(|error| ControlError::internal(error.to_string()))
48            }
49        },
50        shutdown,
51    )
52    .await
53}
54
55pub async fn serve_with_handler<S, H, F>(
56    listener: UnixListener,
57    handler: H,
58    shutdown: S,
59) -> io::Result<()>
60where
61    S: Future<Output = ()> + Send,
62    H: Fn(ControlCommand) -> F + Clone + Send + Sync + 'static,
63    F: Future<Output = Result<ControlResponse, ControlError>> + Send + 'static,
64{
65    tokio::pin!(shutdown);
66
67    loop {
68        tokio::select! {
69            _ = &mut shutdown => break,
70            accepted = listener.accept() => {
71                let (stream, _) = accepted?;
72                let handler = handler.clone();
73                tokio::spawn(async move {
74                    let _ = handle_connection_with_handler(stream, handler).await;
75                });
76            }
77        }
78    }
79
80    Ok(())
81}
82
83async fn handle_connection_with_handler<H, F>(stream: UnixStream, handler: H) -> io::Result<()>
84where
85    H: Fn(ControlCommand) -> F + Clone + Send + Sync + 'static,
86    F: Future<Output = Result<ControlResponse, ControlError>> + Send + 'static,
87{
88    ensure_peer_is_owner(&stream)?;
89    let (read_half, mut write_half) = stream.into_split();
90    let mut reader = BufReader::new(read_half);
91    let mut line = String::new();
92    reader.read_line(&mut line).await?;
93
94    let request: RequestFrame = from_slice(line.trim_end().as_bytes()).map_err(invalid_data)?;
95    let result = handler(request.command).await;
96    let response = ResponseFrame {
97        request_id: request.request_id,
98        response: result,
99    };
100    let payload = to_vec(&response).map_err(invalid_data)?;
101    write_half.write_all(&payload).await?;
102    write_half.write_all(b"\n").await?;
103    write_half.flush().await?;
104
105    Ok(())
106}
107
/// Converts any displayable error into an [`io::Error`] of kind
/// [`io::ErrorKind::InvalidData`], keeping the original error's message as its text.
fn invalid_data(error: impl ToString) -> io::Error {
    let message = error.to_string();
    io::Error::new(io::ErrorKind::InvalidData, message)
}
111
/// Restricts the filesystem entry at `path` to owner read/write only (mode `0o600`).
///
/// # Errors
///
/// Propagates the error from [`fs::set_permissions`], for example when `path` does not exist.
fn set_private_socket_permissions(path: &Path) -> io::Result<()> {
    fs::set_permissions(path, fs::Permissions::from_mode(0o600))
}
116
117fn ensure_peer_is_owner(stream: &UnixStream) -> io::Result<()> {
118    #[cfg(not(target_os = "linux"))]
119    {
120        let _ = stream;
121        Ok(())
122    }
123
124    #[cfg(target_os = "linux")]
125    {
126        use std::os::fd::AsRawFd;
127
128        let expected_uid = unsafe { libc::geteuid() };
129        let mut credentials = libc::ucred {
130            pid: 0,
131            uid: 0,
132            gid: 0,
133        };
134        let mut len = std::mem::size_of::<libc::ucred>() as libc::socklen_t;
135        let result = unsafe {
136            libc::getsockopt(
137                stream.as_raw_fd(),
138                libc::SOL_SOCKET,
139                libc::SO_PEERCRED,
140                (&mut credentials as *mut libc::ucred).cast(),
141                &mut len,
142            )
143        };
144        if result != 0 {
145            return Err(io::Error::last_os_error());
146        }
147        if credentials.uid != expected_uid {
148            return Err(io::Error::new(
149                io::ErrorKind::PermissionDenied,
150                format!(
151                    "rejecting control client from uid {} (expected {})",
152                    credentials.uid, expected_uid
153                ),
154            ));
155        }
156        Ok(())
157    }
158}
159
#[cfg(test)]
mod tests {
    use std::os::unix::fs::PermissionsExt as _;

    use tempfile::tempdir;
    use tokio::sync::oneshot;

    use taskers_domain::AppModel;

    use crate::{
        client::ControlClient,
        controller::InMemoryController,
        protocol::{ControlCommand, ControlQuery, ControlResponse},
    };

    use super::{bind_socket, serve};

    /// Creates a workspace and queries status over a real socket, then shuts the server down.
    #[tokio::test]
    async fn client_and_server_roundtrip() {
        let tempdir = tempdir().expect("tempdir");
        let socket_path = tempdir.path().join("taskers.sock");
        let listener = bind_socket(&socket_path).expect("listener");
        let controller = InMemoryController::new(AppModel::new("Main"));
        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();

        let server = tokio::spawn(serve(listener, controller.clone(), async move {
            let _ = shutdown_rx.await;
        }));

        let client = ControlClient::new(&socket_path);
        let created = client
            .send(ControlCommand::CreateWorkspace {
                label: "Docs".into(),
            })
            .await
            .expect("create workspace request");
        assert!(matches!(
            created.response,
            Ok(ControlResponse::WorkspaceCreated { .. })
        ));

        let status = client
            .send(ControlCommand::QueryStatus {
                query: ControlQuery::All,
            })
            .await
            .expect("query request");
        match status.response {
            Ok(ControlResponse::Status { session }) => {
                assert_eq!(session.model.workspaces.len(), 2);
            }
            other => panic!("unexpected response: {other:?}"),
        }

        shutdown_tx.send(()).expect("shutdown");
        server.await.expect("server task").expect("serve cleanly");
    }

    /// The socket file must only be accessible to its owner.
    #[tokio::test]
    async fn bound_socket_is_private() {
        let tempdir = tempdir().expect("tempdir");
        let socket_path = tempdir.path().join("taskers.sock");
        let listener = bind_socket(&socket_path).expect("listener");

        let mode = std::fs::metadata(&socket_path)
            .expect("socket metadata")
            .permissions()
            .mode()
            & 0o777;
        assert_eq!(mode, 0o600);

        drop(listener);
    }

    /// A socket file left behind by a previous listener must not prevent rebinding.
    #[tokio::test]
    async fn rebinding_replaces_stale_socket() {
        let tempdir = tempdir().expect("tempdir");
        let socket_path = tempdir.path().join("taskers.sock");

        let first = bind_socket(&socket_path).expect("first listener");
        drop(first);

        let second = bind_socket(&socket_path).expect("rebind over stale socket");
        drop(second);
    }
}