Skip to main content

zsh/
socket.rs

//! Unix domain socket module — a port of zsh's `Modules/socket.c`.
//!
//! Provides the `zsocket` builtin, which can listen on, accept connections on,
//! or connect to Unix domain sockets.
4
5use std::io;
6use std::os::unix::io::RawFd;
7
/// Parsed option flags for the `zsocket` builtin.
///
/// All flags default to off and `target_fd` defaults to `None`.
#[derive(Debug, Default)]
pub struct ZsocketOptions {
    /// Listen mode (`-l`): bind a listening socket to the path argument.
    pub listen: bool,
    /// Accept mode (`-a`): accept a connection on the descriptor number given as the argument.
    pub accept: bool,
    /// When set, a successful call reports the resulting descriptor in its output text.
    pub verbose: bool,
    /// In accept mode, poll the listener first and fail if nothing is pending.
    pub test: bool,
    /// Optional target descriptor number; not consulted by `builtin_zsocket` in this file.
    pub target_fd: Option<RawFd>,
}
17
/// A Unix socket descriptor paired with the path string it is associated with.
#[derive(Debug)]
pub struct UnixSocket {
    /// Raw file descriptor of the socket.
    pub fd: RawFd,
    /// Path string supplied to [`UnixSocket::new`].
    pub path: String,
    /// `true` when this record describes a listening socket.
    pub is_listener: bool,
}

impl UnixSocket {
    /// Builds a new record, copying `path` into an owned `String`.
    pub fn new(fd: RawFd, path: &str, is_listener: bool) -> Self {
        let owned_path = String::from(path);
        UnixSocket {
            fd,
            path: owned_path,
            is_listener,
        }
    }
}
35
36/// Create a listening Unix socket
37#[cfg(unix)]
38pub fn socket_listen(path: &str) -> io::Result<RawFd> {
39    let fd = unsafe { libc::socket(libc::PF_UNIX, libc::SOCK_STREAM, 0) };
40    if fd < 0 {
41        return Err(io::Error::last_os_error());
42    }
43
44    let mut addr: libc::sockaddr_un = unsafe { std::mem::zeroed() };
45    addr.sun_family = libc::AF_UNIX as libc::sa_family_t;
46
47    let path_bytes = path.as_bytes();
48    let max_len = addr.sun_path.len() - 1;
49    let copy_len = path_bytes.len().min(max_len);
50
51    for (i, &byte) in path_bytes[..copy_len].iter().enumerate() {
52        addr.sun_path[i] = byte as libc::c_char;
53    }
54
55    let result = unsafe {
56        libc::bind(
57            fd,
58            &addr as *const _ as *const libc::sockaddr,
59            std::mem::size_of::<libc::sockaddr_un>() as libc::socklen_t,
60        )
61    };
62
63    if result < 0 {
64        let err = io::Error::last_os_error();
65        unsafe { libc::close(fd) };
66        return Err(err);
67    }
68
69    let result = unsafe { libc::listen(fd, 1) };
70    if result < 0 {
71        let err = io::Error::last_os_error();
72        unsafe { libc::close(fd) };
73        return Err(err);
74    }
75
76    Ok(fd)
77}
78
79/// Accept a connection on a listening Unix socket
80#[cfg(unix)]
81pub fn socket_accept(listen_fd: RawFd) -> io::Result<(RawFd, String)> {
82    let mut addr: libc::sockaddr_un = unsafe { std::mem::zeroed() };
83    let mut len: libc::socklen_t = std::mem::size_of::<libc::sockaddr_un>() as libc::socklen_t;
84
85    let fd = loop {
86        let result = unsafe {
87            libc::accept(
88                listen_fd,
89                &mut addr as *mut _ as *mut libc::sockaddr,
90                &mut len,
91            )
92        };
93
94        if result < 0 {
95            let err = io::Error::last_os_error();
96            if err.kind() == io::ErrorKind::Interrupted {
97                continue;
98            }
99            return Err(err);
100        }
101
102        break result;
103    };
104
105    let path = addr
106        .sun_path
107        .iter()
108        .take_while(|&&c| c != 0)
109        .map(|&c| c as u8 as char)
110        .collect::<String>();
111
112    Ok((fd, path))
113}
114
115/// Test if a socket has pending connections
116#[cfg(unix)]
117pub fn socket_test(fd: RawFd) -> io::Result<bool> {
118    let mut pfd = libc::pollfd {
119        fd,
120        events: libc::POLLIN,
121        revents: 0,
122    };
123
124    let result = unsafe { libc::poll(&mut pfd, 1, 0) };
125    if result < 0 {
126        return Err(io::Error::last_os_error());
127    }
128
129    Ok(result > 0)
130}
131
132/// Connect to a Unix socket
133#[cfg(unix)]
134pub fn socket_connect(path: &str) -> io::Result<RawFd> {
135    let fd = unsafe { libc::socket(libc::PF_UNIX, libc::SOCK_STREAM, 0) };
136    if fd < 0 {
137        return Err(io::Error::last_os_error());
138    }
139
140    let mut addr: libc::sockaddr_un = unsafe { std::mem::zeroed() };
141    addr.sun_family = libc::AF_UNIX as libc::sa_family_t;
142
143    let path_bytes = path.as_bytes();
144    let max_len = addr.sun_path.len() - 1;
145    let copy_len = path_bytes.len().min(max_len);
146
147    for (i, &byte) in path_bytes[..copy_len].iter().enumerate() {
148        addr.sun_path[i] = byte as libc::c_char;
149    }
150
151    let result = unsafe {
152        libc::connect(
153            fd,
154            &addr as *const _ as *const libc::sockaddr,
155            std::mem::size_of::<libc::sockaddr_un>() as libc::socklen_t,
156        )
157    };
158
159    if result < 0 {
160        let err = io::Error::last_os_error();
161        unsafe { libc::close(fd) };
162        return Err(err);
163    }
164
165    Ok(fd)
166}
167
168/// Close a socket
169#[cfg(unix)]
170pub fn socket_close(fd: RawFd) -> io::Result<()> {
171    let result = unsafe { libc::close(fd) };
172    if result < 0 {
173        return Err(io::Error::last_os_error());
174    }
175    Ok(())
176}
177
178/// Execute zsocket builtin
179pub fn builtin_zsocket(args: &[&str], options: &ZsocketOptions) -> (i32, String, Option<RawFd>) {
180    let mut output = String::new();
181
182    if options.listen {
183        if args.is_empty() {
184            return (1, "zsocket: -l requires an argument\n".to_string(), None);
185        }
186
187        let path = args[0];
188
189        match socket_listen(path) {
190            Ok(fd) => {
191                if options.verbose {
192                    output.push_str(&format!("{} listener is on fd {}\n", path, fd));
193                }
194                (0, output, Some(fd))
195            }
196            Err(e) => (
197                1,
198                format!("zsocket: could not bind to {}: {}\n", path, e),
199                None,
200            ),
201        }
202    } else if options.accept {
203        if args.is_empty() {
204            return (1, "zsocket: -a requires an argument\n".to_string(), None);
205        }
206
207        let listen_fd: RawFd = match args[0].parse() {
208            Ok(fd) => fd,
209            Err(_) => {
210                return (1, "zsocket: invalid numerical argument\n".to_string(), None);
211            }
212        };
213
214        if options.test {
215            match socket_test(listen_fd) {
216                Ok(true) => {}
217                Ok(false) => return (1, output, None),
218                Err(e) => return (1, format!("zsocket: poll error: {}\n", e), None),
219            }
220        }
221
222        match socket_accept(listen_fd) {
223            Ok((fd, path)) => {
224                if options.verbose {
225                    output.push_str(&format!("new connection from {} is on fd {}\n", path, fd));
226                }
227                (0, output, Some(fd))
228            }
229            Err(e) => (
230                1,
231                format!("zsocket: could not accept connection: {}\n", e),
232                None,
233            ),
234        }
235    } else {
236        if args.is_empty() {
237            return (1, "zsocket: requires an argument\n".to_string(), None);
238        }
239
240        let path = args[0];
241
242        match socket_connect(path) {
243            Ok(fd) => {
244                if options.verbose {
245                    output.push_str(&format!("{} is now on fd {}\n", path, fd));
246                }
247                (0, output, Some(fd))
248            }
249            Err(e) => (1, format!("zsocket: connection failed: {}\n", e), None),
250        }
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `builtin_zsocket` with no arguments and check that it fails with a
    /// "requires ..." diagnostic.
    fn expect_missing_argument(options: ZsocketOptions) {
        let (status, message, _fd) = builtin_zsocket(&[], &options);
        assert_eq!(status, 1, "a missing argument must fail");
        assert!(
            message.contains("requires"),
            "unexpected message: {:?}",
            message
        );
    }

    #[test]
    fn test_zsocket_options_default() {
        let ZsocketOptions {
            listen,
            accept,
            verbose,
            test,
            target_fd,
        } = ZsocketOptions::default();
        assert!(!(listen || accept || verbose || test));
        assert!(target_fd.is_none());
    }

    #[test]
    fn test_unix_socket_new() {
        let socket = UnixSocket::new(5, "/tmp/test.sock", true);
        assert_eq!(
            (socket.fd, socket.path.as_str(), socket.is_listener),
            (5, "/tmp/test.sock", true)
        );
    }

    #[test]
    fn test_builtin_zsocket_listen_no_arg() {
        expect_missing_argument(ZsocketOptions {
            listen: true,
            ..Default::default()
        });
    }

    #[test]
    fn test_builtin_zsocket_accept_no_arg() {
        expect_missing_argument(ZsocketOptions {
            accept: true,
            ..Default::default()
        });
    }

    #[test]
    fn test_builtin_zsocket_connect_no_arg() {
        expect_missing_argument(ZsocketOptions::default());
    }

    #[test]
    fn test_builtin_zsocket_accept_invalid_fd() {
        let options = ZsocketOptions {
            accept: true,
            ..Default::default()
        };
        let (status, message, _fd) = builtin_zsocket(&["abc"], &options);
        assert_eq!(status, 1);
        assert!(message.contains("invalid"));
    }
}