zlayer_overlayd/
transport.rs1use std::path::Path;
14
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use zlayer_types::overlayd::OverlaydFrame;
17
18use crate::error::{OverlaydError, Result, MAX_FRAME_BYTES};
19
/// Length-prefixed JSON framing on top of an arbitrary byte stream.
///
/// On the wire every frame is a little-endian `u32` byte count immediately followed
/// by that many bytes of JSON-encoded [`OverlaydFrame`]. The wrapper owns nothing but
/// the stream itself; it keeps no read or write buffer of its own.
pub struct FramedConn<S> {
    /// Underlying transport: a Unix socket, a Windows named pipe, or (in tests) an
    /// in-memory duplex stream.
    stream: S,
}
24
25impl<S: AsyncReadExt + AsyncWriteExt + Unpin> FramedConn<S> {
26 pub fn new(stream: S) -> Self {
28 Self { stream }
29 }
30
31 pub fn into_inner(self) -> S {
33 self.stream
34 }
35
36 pub async fn send(&mut self, frame: &OverlaydFrame) -> Result<()> {
43 let body = serde_json::to_vec(frame)?;
44 if body.len() > MAX_FRAME_BYTES {
45 return Err(OverlaydError::FrameTooLarge(body.len()));
46 }
47 let len =
48 u32::try_from(body.len()).map_err(|_| OverlaydError::FrameTooLarge(body.len()))?;
49 self.stream.write_all(&len.to_le_bytes()).await?;
50 self.stream.write_all(&body).await?;
51 self.stream.flush().await?;
52 Ok(())
53 }
54
55 pub async fn recv(&mut self) -> Result<OverlaydFrame> {
63 let mut len_buf = [0u8; 4];
64 match self.stream.read_exact(&mut len_buf).await {
65 Ok(_) => {}
66 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
67 return Err(OverlaydError::Closed);
68 }
69 Err(e) => return Err(e.into()),
70 }
71 let len = u32::from_le_bytes(len_buf) as usize;
72 if len > MAX_FRAME_BYTES {
73 return Err(OverlaydError::FrameTooLarge(len));
74 }
75 let mut body = vec![0u8; len];
76 self.stream.read_exact(&mut body).await?;
77 Ok(serde_json::from_slice(&body)?)
78 }
79}
80
81#[cfg(unix)]
96pub async fn serve<F, Fut>(endpoint: &Path, handler: F) -> Result<()>
97where
98 F: Fn(FramedConn<tokio::net::UnixStream>) -> Fut + Send + Sync + 'static,
99 Fut: std::future::Future<Output = ()> + Send + 'static,
100{
101 use std::sync::Arc;
102
103 if tokio::fs::try_exists(endpoint).await.unwrap_or(false) {
104 let _ = tokio::fs::remove_file(endpoint).await;
106 }
107 if let Some(parent) = endpoint.parent() {
108 tokio::fs::create_dir_all(parent).await?;
109 }
110 let listener = tokio::net::UnixListener::bind(endpoint)?;
111
112 #[allow(unsafe_code)]
119 {
120 use std::os::unix::ffi::OsStrExt as _;
121 use std::os::unix::fs::PermissionsExt as _;
122 if let Err(e) = std::fs::set_permissions(endpoint, std::fs::Permissions::from_mode(0o660)) {
123 tracing::debug!(error = %e, socket = %endpoint.display(), "failed to set overlayd socket perms 0o660");
124 }
125 if let (Ok(path_c), Ok(gname)) = (
129 std::ffi::CString::new(endpoint.as_os_str().as_bytes()),
130 std::ffi::CString::new("zlayer"),
131 ) {
132 unsafe {
136 let grp = libc::getgrnam(gname.as_ptr());
137 if grp.is_null() {
138 tracing::debug!(socket = %endpoint.display(), "group 'zlayer' not present; skipping overlayd socket chown");
139 } else {
140 let gid = (*grp).gr_gid;
141 if libc::chown(path_c.as_ptr(), u32::MAX, gid) == 0 {
142 tracing::info!(socket = %endpoint.display(), "overlayd socket chowned to <owner>:zlayer 0o660");
143 } else {
144 tracing::debug!(error = %std::io::Error::last_os_error(), socket = %endpoint.display(), "failed to chown overlayd socket to zlayer group");
145 }
146 }
147 }
148 }
149 }
150
151 tracing::info!(endpoint = %endpoint.display(), "overlayd IPC listening (unix socket)");
152 let handler = Arc::new(handler);
153 loop {
154 let (stream, _addr) = listener.accept().await?;
155 let handler = Arc::clone(&handler);
156 tokio::spawn(async move {
157 handler(FramedConn::new(stream)).await;
158 });
159 }
160}
161
/// Serve overlayd IPC on the Windows named pipe `endpoint`, running `handler` on its
/// own task for every connected client.
///
/// A new, unconnected pipe instance is always created *before* a connected instance
/// is handed to its task, so an instance exists for the next client to open. A client
/// that opens and closes the pipe before `connect` completes (`ERROR_NO_DATA`) only
/// causes the instance to be recycled; it does not stop the server.
///
/// # Errors
/// Returns an error if `endpoint` is not valid UTF-8, if a pipe instance cannot be
/// created, or if `connect` fails for any other reason. Never returns `Ok`.
#[cfg(windows)]
pub async fn serve<F, Fut>(endpoint: &Path, handler: F) -> Result<()>
where
    F: Fn(FramedConn<tokio::net::windows::named_pipe::NamedPipeServer>) -> Fut
        + Send
        + Sync
        + 'static,
    Fut: std::future::Future<Output = ()> + Send + 'static,
{
    use std::sync::Arc;
    use tokio::net::windows::named_pipe::ServerOptions;

    // Win32 ERROR_NO_DATA: the client disconnected before the server finished connecting.
    const ERROR_NO_DATA: i32 = 232;

    let pipe_name = endpoint
        .to_str()
        .ok_or_else(|| OverlaydError::Other("named-pipe path is not valid UTF-8".to_string()))?;
    // NOTE(review): tokio recommends `first_pipe_instance(true)` for the very first
    // instance so no other process can pre-create this pipe name. `false` is kept to
    // preserve the existing behaviour — confirm whether that is intentional.
    let new_instance = || {
        ServerOptions::new()
            .first_pipe_instance(false)
            .create(pipe_name)
    };

    let mut server = new_instance()?;
    tracing::info!(endpoint = %pipe_name, "overlayd IPC listening (named pipe)");
    let handler = Arc::new(handler);
    loop {
        if let Err(e) = server.connect().await {
            if e.kind() == std::io::ErrorKind::BrokenPipe
                || e.raw_os_error() == Some(ERROR_NO_DATA)
            {
                tracing::debug!(error = %e, "named-pipe client left before connect completed; recycling instance");
                server = new_instance()?;
                continue;
            }
            return Err(e.into());
        }
        // Put the next listening instance in place before giving this one away.
        let connected = std::mem::replace(&mut server, new_instance()?);
        let handler = Arc::clone(&handler);
        tokio::spawn(async move {
            handler(FramedConn::new(connected)).await;
        });
    }
}
197
198#[cfg(unix)]
205pub type ClientConn = FramedConn<tokio::net::UnixStream>;
206#[cfg(windows)]
209pub type ClientConn = FramedConn<tokio::net::windows::named_pipe::NamedPipeClient>;
210
211#[cfg(unix)]
216pub async fn connect(endpoint: &Path) -> Result<ClientConn> {
217 let stream = tokio::net::UnixStream::connect(endpoint).await?;
218 Ok(FramedConn::new(stream))
219}
220
/// Open a client connection to the overlayd named pipe `endpoint`.
///
/// NOTE(review): `ClientOptions::open` fails at once if every server instance is busy
/// (ERROR_PIPE_BUSY). Callers may need to retry — confirm.
///
/// # Errors
/// Returns [`OverlaydError::Other`] if `endpoint` is not valid UTF-8, or the I/O error
/// from opening the pipe.
#[cfg(windows)]
// Nothing is awaited here; presumably `async` is kept for signature parity with the
// Unix `connect`.
#[allow(clippy::unused_async)]
pub async fn connect(endpoint: &Path) -> Result<ClientConn> {
    use tokio::net::windows::named_pipe::ClientOptions;

    let pipe_name = match endpoint.to_str() {
        Some(name) => name,
        None => {
            return Err(OverlaydError::Other(
                "named-pipe path is not valid UTF-8".to_string(),
            ));
        }
    };
    ClientOptions::new()
        .open(pipe_name)
        .map(FramedConn::new)
        .map_err(Into::into)
}
237
#[cfg(test)]
mod tests {
    use super::*;
    use zlayer_types::overlayd::{OverlaydRequest, OverlaydResponse};

    /// A request and a response each survive a send/recv round trip.
    #[tokio::test]
    async fn frames_round_trip_over_duplex() {
        let (a, b) = tokio::io::duplex(64 * 1024);
        let mut client = FramedConn::new(a);
        let mut server = FramedConn::new(b);

        let req = OverlaydFrame::Request {
            id: 7,
            request: OverlaydRequest::Status,
        };
        client.send(&req).await.unwrap();
        let got = server.recv().await.unwrap();
        assert_eq!(got, req);

        let resp = OverlaydFrame::Response {
            id: 7,
            response: OverlaydResponse::Ok,
        };
        server.send(&resp).await.unwrap();
        assert_eq!(client.recv().await.unwrap(), resp);
    }

    /// Dropping the peer before it sends anything is reported as a clean close.
    #[tokio::test]
    async fn clean_eof_maps_to_closed() {
        let (a, b) = tokio::io::duplex(1024);
        drop(a);
        let mut server = FramedConn::new(b);
        assert!(matches!(server.recv().await, Err(OverlaydError::Closed)));
    }

    /// A length header above `MAX_FRAME_BYTES` is rejected with `FrameTooLarge`.
    #[tokio::test]
    async fn oversized_length_prefix_is_rejected() {
        use tokio::io::AsyncWriteExt as _;

        let (mut raw, b) = tokio::io::duplex(64);
        let mut server = FramedConn::new(b);
        let too_big = MAX_FRAME_BYTES + 1;
        let header =
            u32::try_from(too_big).expect("MAX_FRAME_BYTES + 1 must fit the u32 length prefix");
        raw.write_all(&header.to_le_bytes()).await.unwrap();

        assert!(matches!(
            server.recv().await,
            Err(OverlaydError::FrameTooLarge(n)) if n == too_big
        ));
    }

    /// EOF part-way through a frame body is an error, not a clean close.
    #[tokio::test]
    async fn truncated_body_is_not_a_clean_close() {
        use tokio::io::AsyncWriteExt as _;

        let (mut raw, b) = tokio::io::duplex(64);
        raw.write_all(&10u32.to_le_bytes()).await.unwrap();
        raw.write_all(b"{\"x\"").await.unwrap();
        drop(raw);

        let mut server = FramedConn::new(b);
        let got = server.recv().await;
        assert!(got.is_err());
        assert!(!matches!(got, Err(OverlaydError::Closed)));
    }
}