taskers_control/
socket.rs1use std::{
2 fs,
3 future::Future,
4 io,
5 os::unix::{fs::PermissionsExt, net::UnixListener as StdUnixListener},
6 path::Path,
7};
8
9use serde_json::{from_slice, to_vec};
10use tokio::{
11 io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
12 net::{UnixListener, UnixStream},
13};
14
15use crate::{
16 RequestFrame,
17 controller::InMemoryController,
18 protocol::{ControlCommand, ControlError, ControlResponse, ResponseFrame},
19};
20
21pub fn bind_socket(path: impl AsRef<Path>) -> io::Result<UnixListener> {
22 let path = path.as_ref();
23 if path.exists() {
24 std::fs::remove_file(path)?;
25 }
26 let listener = StdUnixListener::bind(path)?;
27 set_private_socket_permissions(path)?;
28 listener.set_nonblocking(true)?;
29 UnixListener::from_std(listener)
30}
31
32pub async fn serve<S>(
33 listener: UnixListener,
34 controller: InMemoryController,
35 shutdown: S,
36) -> io::Result<()>
37where
38 S: Future<Output = ()> + Send,
39{
40 serve_with_handler(
41 listener,
42 move |command| {
43 let controller = controller.clone();
44 async move {
45 controller
46 .handle(command)
47 .map_err(|error| ControlError::internal(error.to_string()))
48 }
49 },
50 shutdown,
51 )
52 .await
53}
54
55pub async fn serve_with_handler<S, H, F>(
56 listener: UnixListener,
57 handler: H,
58 shutdown: S,
59) -> io::Result<()>
60where
61 S: Future<Output = ()> + Send,
62 H: Fn(ControlCommand) -> F + Clone + Send + Sync + 'static,
63 F: Future<Output = Result<ControlResponse, ControlError>> + Send + 'static,
64{
65 tokio::pin!(shutdown);
66
67 loop {
68 tokio::select! {
69 _ = &mut shutdown => break,
70 accepted = listener.accept() => {
71 let (stream, _) = accepted?;
72 let handler = handler.clone();
73 tokio::spawn(async move {
74 let _ = handle_connection_with_handler(stream, handler).await;
75 });
76 }
77 }
78 }
79
80 Ok(())
81}
82
83async fn handle_connection_with_handler<H, F>(stream: UnixStream, handler: H) -> io::Result<()>
84where
85 H: Fn(ControlCommand) -> F + Clone + Send + Sync + 'static,
86 F: Future<Output = Result<ControlResponse, ControlError>> + Send + 'static,
87{
88 ensure_peer_is_owner(&stream)?;
89 let (read_half, mut write_half) = stream.into_split();
90 let mut reader = BufReader::new(read_half);
91 let mut line = String::new();
92 reader.read_line(&mut line).await?;
93
94 let request: RequestFrame = from_slice(line.trim_end().as_bytes()).map_err(invalid_data)?;
95 let result = handler(request.command).await;
96 let response = ResponseFrame {
97 request_id: request.request_id,
98 response: result,
99 };
100 let payload = to_vec(&response).map_err(invalid_data)?;
101 write_half.write_all(&payload).await?;
102 write_half.write_all(b"\n").await?;
103 write_half.flush().await?;
104
105 Ok(())
106}
107
/// Wraps `error`'s string form in an [`io::Error`] of kind [`io::ErrorKind::InvalidData`].
fn invalid_data(error: impl ToString) -> io::Error {
    let message = error.to_string();
    io::Error::new(io::ErrorKind::InvalidData, message)
}
111
/// Restricts the filesystem entry at `path` to owner read/write only (mode `0o600`).
///
/// # Errors
///
/// Returns the error from [`fs::set_permissions`], e.g. when `path` does not exist.
fn set_private_socket_permissions(path: &Path) -> io::Result<()> {
    const OWNER_READ_WRITE: u32 = 0o600;
    fs::set_permissions(path, fs::Permissions::from_mode(OWNER_READ_WRITE))
}
116
117fn ensure_peer_is_owner(stream: &UnixStream) -> io::Result<()> {
118 #[cfg(not(target_os = "linux"))]
119 {
120 let _ = stream;
121 Ok(())
122 }
123
124 #[cfg(target_os = "linux")]
125 {
126 use std::os::fd::AsRawFd;
127
128 let expected_uid = unsafe { libc::geteuid() };
129 let mut credentials = libc::ucred {
130 pid: 0,
131 uid: 0,
132 gid: 0,
133 };
134 let mut len = std::mem::size_of::<libc::ucred>() as libc::socklen_t;
135 let result = unsafe {
136 libc::getsockopt(
137 stream.as_raw_fd(),
138 libc::SOL_SOCKET,
139 libc::SO_PEERCRED,
140 (&mut credentials as *mut libc::ucred).cast(),
141 &mut len,
142 )
143 };
144 if result != 0 {
145 return Err(io::Error::last_os_error());
146 }
147 if credentials.uid != expected_uid {
148 return Err(io::Error::new(
149 io::ErrorKind::PermissionDenied,
150 format!(
151 "rejecting control client from uid {} (expected {})",
152 credentials.uid, expected_uid
153 ),
154 ));
155 }
156 Ok(())
157 }
158}
159
#[cfg(test)]
mod tests {
    use std::os::unix::fs::PermissionsExt as _;
    use std::path::PathBuf;

    use tempfile::tempdir;
    use tokio::sync::oneshot;

    use taskers_domain::AppModel;

    use crate::{
        client::ControlClient,
        controller::InMemoryController,
        protocol::{ControlCommand, ControlQuery, ControlResponse},
    };

    use super::{bind_socket, serve};

    /// End-to-end check: a client creates a workspace over the socket, queries status,
    /// and the server then shuts down cleanly when signalled.
    #[tokio::test]
    async fn client_and_server_roundtrip() {
        let tempdir = tempdir().expect("tempdir");
        let socket_path = PathBuf::from(tempdir.path()).join("taskers.sock");
        let listener = bind_socket(&socket_path).expect("listener");
        let controller = InMemoryController::new(AppModel::new("Main"));
        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();

        let server = tokio::spawn(serve(listener, controller.clone(), async move {
            // Either a sent value or a dropped sender ends the server.
            let _ = shutdown_rx.await;
        }));

        let client = ControlClient::new(&socket_path);
        let created = client
            .send(ControlCommand::CreateWorkspace {
                label: "Docs".into(),
            })
            .await
            .expect("create workspace request");
        assert!(matches!(
            created.response,
            Ok(ControlResponse::WorkspaceCreated { .. })
        ));

        let status = client
            .send(ControlCommand::QueryStatus {
                query: ControlQuery::All,
            })
            .await
            .expect("query request");
        match status.response {
            Ok(ControlResponse::Status { session }) => {
                // Presumably the initial "Main" workspace plus the "Docs" one created
                // above; confirm against `AppModel::new`.
                assert_eq!(session.model.workspaces.len(), 2);
            }
            other => panic!("unexpected response: {other:?}"),
        }

        shutdown_tx.send(()).expect("shutdown");
        server.await.expect("server task").expect("serve cleanly");
    }

    /// The socket file created by `bind_socket` must be readable and writable by its
    /// owner only.
    #[tokio::test]
    async fn bound_socket_is_private() {
        let tempdir = tempdir().expect("tempdir");
        let socket_path = PathBuf::from(tempdir.path()).join("taskers.sock");
        let listener = bind_socket(&socket_path).expect("listener");

        // Mask off the file-type bits so only the permission bits are compared.
        let mode = std::fs::metadata(&socket_path)
            .expect("socket metadata")
            .permissions()
            .mode()
            & 0o777;
        assert_eq!(mode, 0o600);

        drop(listener);
    }
}
235}