1use std::io;
6use std::os::unix::io::RawFd;
7
/// Flags that select and tune the operation performed by `builtin_zsocket`.
///
/// All fields default to "off" / `None`, which makes `builtin_zsocket` take its
/// connect branch.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ZsocketOptions {
    /// Bind to the given path and listen on it.
    pub listen: bool,
    /// Accept a connection on an already-listening descriptor.
    pub accept: bool,
    /// Emit a one-line message naming the descriptor that was produced.
    pub verbose: bool,
    /// With `accept`: poll first and fail with status 1 instead of blocking
    /// when nothing is pending.
    pub test: bool,
    /// Requested destination descriptor.
    /// NOTE(review): not consulted by `builtin_zsocket` as written here;
    /// presumably mirrors zsh's `-d` option. Confirm against callers.
    pub target_fd: Option<RawFd>,
}
17
/// Record pairing a raw socket descriptor with the path it is associated with.
///
/// This is a plain data holder: constructing it performs no system calls, and
/// no `Drop` implementation is visible here, so dropping it does not close `fd`.
#[derive(Debug)]
pub struct UnixSocket {
    /// Raw descriptor number of the socket.
    pub fd: RawFd,
    /// Filesystem path associated with the socket.
    pub path: String,
    /// `true` when the descriptor is a listening socket.
    pub is_listener: bool,
}

impl UnixSocket {
    /// Builds a `UnixSocket`, copying `path` into an owned `String`.
    pub fn new(fd: RawFd, path: &str, is_listener: bool) -> Self {
        let path = String::from(path);
        UnixSocket {
            fd,
            path,
            is_listener,
        }
    }
}
35
36#[cfg(unix)]
38pub fn socket_listen(path: &str) -> io::Result<RawFd> {
39 let fd = unsafe { libc::socket(libc::PF_UNIX, libc::SOCK_STREAM, 0) };
40 if fd < 0 {
41 return Err(io::Error::last_os_error());
42 }
43
44 let mut addr: libc::sockaddr_un = unsafe { std::mem::zeroed() };
45 addr.sun_family = libc::AF_UNIX as libc::sa_family_t;
46
47 let path_bytes = path.as_bytes();
48 let max_len = addr.sun_path.len() - 1;
49 let copy_len = path_bytes.len().min(max_len);
50
51 for (i, &byte) in path_bytes[..copy_len].iter().enumerate() {
52 addr.sun_path[i] = byte as libc::c_char;
53 }
54
55 let result = unsafe {
56 libc::bind(
57 fd,
58 &addr as *const _ as *const libc::sockaddr,
59 std::mem::size_of::<libc::sockaddr_un>() as libc::socklen_t,
60 )
61 };
62
63 if result < 0 {
64 let err = io::Error::last_os_error();
65 unsafe { libc::close(fd) };
66 return Err(err);
67 }
68
69 let result = unsafe { libc::listen(fd, 1) };
70 if result < 0 {
71 let err = io::Error::last_os_error();
72 unsafe { libc::close(fd) };
73 return Err(err);
74 }
75
76 Ok(fd)
77}
78
79#[cfg(unix)]
81pub fn socket_accept(listen_fd: RawFd) -> io::Result<(RawFd, String)> {
82 let mut addr: libc::sockaddr_un = unsafe { std::mem::zeroed() };
83 let mut len: libc::socklen_t = std::mem::size_of::<libc::sockaddr_un>() as libc::socklen_t;
84
85 let fd = loop {
86 let result = unsafe {
87 libc::accept(
88 listen_fd,
89 &mut addr as *mut _ as *mut libc::sockaddr,
90 &mut len,
91 )
92 };
93
94 if result < 0 {
95 let err = io::Error::last_os_error();
96 if err.kind() == io::ErrorKind::Interrupted {
97 continue;
98 }
99 return Err(err);
100 }
101
102 break result;
103 };
104
105 let path = addr
106 .sun_path
107 .iter()
108 .take_while(|&&c| c != 0)
109 .map(|&c| c as u8 as char)
110 .collect::<String>();
111
112 Ok((fd, path))
113}
114
115#[cfg(unix)]
117pub fn socket_test(fd: RawFd) -> io::Result<bool> {
118 let mut pfd = libc::pollfd {
119 fd,
120 events: libc::POLLIN,
121 revents: 0,
122 };
123
124 let result = unsafe { libc::poll(&mut pfd, 1, 0) };
125 if result < 0 {
126 return Err(io::Error::last_os_error());
127 }
128
129 Ok(result > 0)
130}
131
132#[cfg(unix)]
134pub fn socket_connect(path: &str) -> io::Result<RawFd> {
135 let fd = unsafe { libc::socket(libc::PF_UNIX, libc::SOCK_STREAM, 0) };
136 if fd < 0 {
137 return Err(io::Error::last_os_error());
138 }
139
140 let mut addr: libc::sockaddr_un = unsafe { std::mem::zeroed() };
141 addr.sun_family = libc::AF_UNIX as libc::sa_family_t;
142
143 let path_bytes = path.as_bytes();
144 let max_len = addr.sun_path.len() - 1;
145 let copy_len = path_bytes.len().min(max_len);
146
147 for (i, &byte) in path_bytes[..copy_len].iter().enumerate() {
148 addr.sun_path[i] = byte as libc::c_char;
149 }
150
151 let result = unsafe {
152 libc::connect(
153 fd,
154 &addr as *const _ as *const libc::sockaddr,
155 std::mem::size_of::<libc::sockaddr_un>() as libc::socklen_t,
156 )
157 };
158
159 if result < 0 {
160 let err = io::Error::last_os_error();
161 unsafe { libc::close(fd) };
162 return Err(err);
163 }
164
165 Ok(fd)
166}
167
168#[cfg(unix)]
170pub fn socket_close(fd: RawFd) -> io::Result<()> {
171 let result = unsafe { libc::close(fd) };
172 if result < 0 {
173 return Err(io::Error::last_os_error());
174 }
175 Ok(())
176}
177
178pub fn builtin_zsocket(args: &[&str], options: &ZsocketOptions) -> (i32, String, Option<RawFd>) {
180 let mut output = String::new();
181
182 if options.listen {
183 if args.is_empty() {
184 return (1, "zsocket: -l requires an argument\n".to_string(), None);
185 }
186
187 let path = args[0];
188
189 match socket_listen(path) {
190 Ok(fd) => {
191 if options.verbose {
192 output.push_str(&format!("{} listener is on fd {}\n", path, fd));
193 }
194 (0, output, Some(fd))
195 }
196 Err(e) => (
197 1,
198 format!("zsocket: could not bind to {}: {}\n", path, e),
199 None,
200 ),
201 }
202 } else if options.accept {
203 if args.is_empty() {
204 return (1, "zsocket: -a requires an argument\n".to_string(), None);
205 }
206
207 let listen_fd: RawFd = match args[0].parse() {
208 Ok(fd) => fd,
209 Err(_) => {
210 return (1, "zsocket: invalid numerical argument\n".to_string(), None);
211 }
212 };
213
214 if options.test {
215 match socket_test(listen_fd) {
216 Ok(true) => {}
217 Ok(false) => return (1, output, None),
218 Err(e) => return (1, format!("zsocket: poll error: {}\n", e), None),
219 }
220 }
221
222 match socket_accept(listen_fd) {
223 Ok((fd, path)) => {
224 if options.verbose {
225 output.push_str(&format!("new connection from {} is on fd {}\n", path, fd));
226 }
227 (0, output, Some(fd))
228 }
229 Err(e) => (
230 1,
231 format!("zsocket: could not accept connection: {}\n", e),
232 None,
233 ),
234 }
235 } else {
236 if args.is_empty() {
237 return (1, "zsocket: requires an argument\n".to_string(), None);
238 }
239
240 let path = args[0];
241
242 match socket_connect(path) {
243 Ok(fd) => {
244 if options.verbose {
245 output.push_str(&format!("{} is now on fd {}\n", path, fd));
246 }
247 (0, output, Some(fd))
248 }
249 Err(e) => (1, format!("zsocket: connection failed: {}\n", e), None),
250 }
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the builtin and returns only the status and message; the
    /// descriptor slot is discarded because these tests cover failure paths.
    fn run(args: &[&str], options: ZsocketOptions) -> (i32, String) {
        let (status, output, _) = builtin_zsocket(args, &options);
        (status, output)
    }

    // A default option set has every flag off and no target descriptor.
    #[test]
    fn test_zsocket_options_default() {
        let ZsocketOptions {
            listen,
            accept,
            verbose,
            test,
            target_fd,
        } = ZsocketOptions::default();

        assert!(!listen);
        assert!(!accept);
        assert!(!verbose);
        assert!(!test);
        assert!(target_fd.is_none());
    }

    // The constructor stores its three arguments unchanged.
    #[test]
    fn test_unix_socket_new() {
        let UnixSocket {
            fd,
            path,
            is_listener,
        } = UnixSocket::new(5, "/tmp/test.sock", true);

        assert_eq!(fd, 5);
        assert_eq!(path, "/tmp/test.sock");
        assert!(is_listener);
    }

    // Listen mode without a path argument is rejected.
    #[test]
    fn test_builtin_zsocket_listen_no_arg() {
        let (status, output) = run(
            &[],
            ZsocketOptions {
                listen: true,
                ..ZsocketOptions::default()
            },
        );
        assert_eq!(status, 1);
        assert!(output.contains("requires"));
    }

    // Accept mode without a descriptor argument is rejected.
    #[test]
    fn test_builtin_zsocket_accept_no_arg() {
        let (status, output) = run(
            &[],
            ZsocketOptions {
                accept: true,
                ..ZsocketOptions::default()
            },
        );
        assert_eq!(status, 1);
        assert!(output.contains("requires"));
    }

    // Connect mode (the default) without a path argument is rejected.
    #[test]
    fn test_builtin_zsocket_connect_no_arg() {
        let (status, output) = run(&[], ZsocketOptions::default());
        assert_eq!(status, 1);
        assert!(output.contains("requires"));
    }

    // Accept mode rejects an argument that is not a number.
    #[test]
    fn test_builtin_zsocket_accept_invalid_fd() {
        let (status, output) = run(
            &["abc"],
            ZsocketOptions {
                accept: true,
                ..ZsocketOptions::default()
            },
        );
        assert_eq!(status, 1);
        assert!(output.contains("invalid"));
    }
}