// zsh/tcp.rs

//! TCP networking module - port of Modules/tcp.c
//!
//! Provides the `ztcp` builtin for TCP socket operations.

use std::collections::HashMap;
use std::io::{self};
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};

/// Kind of TCP session tracked by `ztcp`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TcpSessionType {
    /// Connection opened by this shell to a remote host.
    Outbound,
    /// Connection accepted on one of our listening sockets.
    Inbound,
    /// Listening socket waiting for incoming connections.
    Listen,
}

18/// A TCP session
19#[derive(Debug)]
20pub struct TcpSession {
21    pub fd: RawFd,
22    pub session_type: TcpSessionType,
23    pub local_addr: Option<SocketAddr>,
24    pub peer_addr: Option<SocketAddr>,
25    pub is_zftp: bool,
26}
27
28impl TcpSession {
29    pub fn new(fd: RawFd, session_type: TcpSessionType) -> Self {
30        Self {
31            fd,
32            session_type,
33            local_addr: None,
34            peer_addr: None,
35            is_zftp: false,
36        }
37    }
38
39    pub fn type_char(&self) -> char {
40        if self.is_zftp {
41            'Z'
42        } else {
43            match self.session_type {
44                TcpSessionType::Listen => 'L',
45                TcpSessionType::Inbound => 'I',
46                TcpSessionType::Outbound => 'O',
47            }
48        }
49    }
50
51    pub fn direction_str(&self) -> &'static str {
52        match self.session_type {
53            TcpSessionType::Listen => "-<",
54            TcpSessionType::Inbound => "<-",
55            TcpSessionType::Outbound => "->",
56        }
57    }
58}
59
60/// TCP sessions manager
61#[derive(Debug, Default)]
62pub struct TcpSessions {
63    sessions: HashMap<RawFd, TcpSession>,
64}
65
66impl TcpSessions {
67    pub fn new() -> Self {
68        Self::default()
69    }
70
71    pub fn add(&mut self, session: TcpSession) {
72        self.sessions.insert(session.fd, session);
73    }
74
75    pub fn get(&self, fd: RawFd) -> Option<&TcpSession> {
76        self.sessions.get(&fd)
77    }
78
79    pub fn get_by_ref(&self, fd: &RawFd) -> Option<&TcpSession> {
80        self.sessions.get(fd)
81    }
82
83    pub fn get_mut(&mut self, fd: RawFd) -> Option<&mut TcpSession> {
84        self.sessions.get_mut(&fd)
85    }
86
87    pub fn remove(&mut self, fd: RawFd) -> Option<TcpSession> {
88        self.sessions.remove(&fd)
89    }
90
91    pub fn iter(&self) -> impl Iterator<Item = (&RawFd, &TcpSession)> {
92        self.sessions.iter()
93    }
94
95    pub fn close_all(&mut self) {
96        for (fd, _) in self.sessions.drain() {
97            let _ = close_fd(fd);
98        }
99    }
100
101    pub fn len(&self) -> usize {
102        self.sessions.len()
103    }
104
105    pub fn is_empty(&self) -> bool {
106        self.sessions.is_empty()
107    }
108}
109
110fn close_fd(fd: RawFd) -> io::Result<()> {
111    #[cfg(unix)]
112    {
113        let result = unsafe { libc::close(fd) };
114        if result < 0 {
115            Err(io::Error::last_os_error())
116        } else {
117            Ok(())
118        }
119    }
120    #[cfg(not(unix))]
121    {
122        Ok(())
123    }
124}
125
/// Parsed option flags for the `ztcp` builtin.
#[derive(Debug, Default)]
pub struct ZtcpOptions {
    /// Close one session, or every session when no fd argument is given.
    pub close: bool,
    /// Open a listening socket on the given port or service name.
    pub listen: bool,
    /// Accept a connection on the given listening fd.
    pub accept: bool,
    /// Allow closing a zftp control connection.
    pub force: bool,
    /// Report the fd of a newly created session.
    pub verbose: bool,
    /// With `accept`: only accept if a connection is already pending.
    pub test: bool,
    /// List sessions in the compact, space-separated format.
    pub list_format: bool,
    /// Descriptor number a newly created session should end up on.
    pub target_fd: Option<RawFd>,
}

139/// Connect to a TCP host with timeout on DNS and connect (default 10s).
140/// DNS resolution runs on a background thread so it can't hang the shell.
141pub fn tcp_connect(host: &str, port: u16) -> io::Result<(RawFd, SocketAddr, SocketAddr)> {
142    tcp_connect_timeout(host, port, std::time::Duration::from_secs(10))
143}
144
145/// Connect with explicit timeout.
146pub fn tcp_connect_timeout(
147    host: &str,
148    port: u16,
149    timeout: std::time::Duration,
150) -> io::Result<(RawFd, SocketAddr, SocketAddr)> {
151    // DNS resolution on a background thread — can hang for seconds on bad DNS
152    let addr_str = format!("{}:{}", host, port);
153    let (tx, rx) = std::sync::mpsc::channel();
154    let dns_str = addr_str.clone();
155    std::thread::Builder::new()
156        .name("dns-resolve".to_string())
157        .spawn(move || {
158            let result: io::Result<Vec<SocketAddr>> =
159                dns_str.to_socket_addrs().map(|a| a.collect());
160            let _ = tx.send(result);
161        })
162        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
163
164    let addrs = rx
165        .recv_timeout(timeout)
166        .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "DNS resolution timed out"))?
167        .map_err(|e| {
168            tracing::warn!(host, error = %e, "DNS resolution failed");
169            e
170        })?;
171
172    if addrs.is_empty() {
173        return Err(io::Error::new(
174            io::ErrorKind::NotFound,
175            "host resolution failure",
176        ));
177    }
178
179    for addr in addrs {
180        match TcpStream::connect_timeout(&addr, timeout) {
181            Ok(stream) => {
182                tracing::debug!(%addr, "tcp: connected");
183                let local = stream.local_addr()?;
184                let peer = stream.peer_addr()?;
185                let fd = stream.as_raw_fd();
186                std::mem::forget(stream);
187                return Ok((fd, local, peer));
188            }
189            Err(e) => {
190                tracing::trace!(%addr, error = %e, "tcp: connect attempt failed");
191                continue;
192            }
193        }
194    }
195
196    Err(io::Error::new(
197        io::ErrorKind::ConnectionRefused,
198        "connection failed",
199    ))
200}
201
/// Opens a listening socket on `port` on every IPv4 interface (`0.0.0.0`).
///
/// Returns the raw descriptor and the bound local address. The caller now owns the
/// descriptor and must close it. Passing port 0 lets the OS choose a free port, and
/// the chosen port is reported in the returned address.
pub fn tcp_listen(port: u16) -> io::Result<(RawFd, SocketAddr)> {
    let wildcard = SocketAddr::from((Ipv4Addr::UNSPECIFIED, port));
    let listener = TcpListener::bind(wildcard)?;
    let bound = listener.local_addr()?;
    let raw = listener.as_raw_fd();
    // Leak the listener so that dropping it does not close `raw`.
    std::mem::forget(listener);
    Ok((raw, bound))
}

/// Accepts one connection on the listening socket `listen_fd`.
///
/// `listen_fd` stays owned by the caller and remains open afterwards. For a blocking
/// listener this waits until a connection is pending. The accepted socket is
/// detached from Rust ownership: its raw fd is returned together with the local and
/// peer addresses, and the caller must close it.
pub fn tcp_accept(listen_fd: RawFd) -> io::Result<(RawFd, SocketAddr, SocketAddr)> {
    // SAFETY: the caller guarantees `listen_fd` is an open listening socket. The
    // temporary `TcpListener` only borrows it: `ManuallyDrop` ensures it is never
    // closed here, even if something below unwinds.
    let listener = std::mem::ManuallyDrop::new(unsafe { TcpListener::from_raw_fd(listen_fd) });
    let (stream, peer) = listener.accept()?;
    // If this fails, `stream` is dropped and the accepted socket is closed.
    let local = stream.local_addr()?;
    Ok((stream.into_raw_fd(), local, peer))
}

225/// Check if a socket has pending connections (for -t option)
226pub fn tcp_test_accept(listen_fd: RawFd) -> io::Result<bool> {
227    #[cfg(unix)]
228    {
229        let mut pfd = libc::pollfd {
230            fd: listen_fd,
231            events: libc::POLLIN,
232            revents: 0,
233        };
234
235        let result = unsafe { libc::poll(&mut pfd, 1, 0) };
236        if result < 0 {
237            Err(io::Error::last_os_error())
238        } else {
239            Ok(result > 0)
240        }
241    }
242
243    #[cfg(not(unix))]
244    {
245        Ok(true)
246    }
247}
248
249/// Close a TCP session
250pub fn tcp_close(sessions: &mut TcpSessions, fd: RawFd, force: bool) -> Result<(), String> {
251    if let Some(session) = sessions.get(fd) {
252        if session.is_zftp && !force {
253            return Err("use -f to force closure of a zftp control connection".to_string());
254        }
255    }
256
257    if let Some(_session) = sessions.remove(fd) {
258        close_fd(fd).map_err(|e| format!("connection close failed: {}", e))?;
259        Ok(())
260    } else {
261        Err(format!("fd {} not found in tcp table", fd))
262    }
263}
264
265/// Resolve a service name to port number
266pub fn resolve_port(service: &str) -> Option<u16> {
267    if let Ok(port) = service.parse::<u16>() {
268        return Some(port);
269    }
270
271    #[cfg(unix)]
272    {
273        use std::ffi::CString;
274        let service_c = CString::new(service).ok()?;
275        let proto_c = CString::new("tcp").ok()?;
276
277        unsafe {
278            let serv = libc::getservbyname(service_c.as_ptr(), proto_c.as_ptr());
279            if serv.is_null() {
280                None
281            } else {
282                Some(u16::from_be((*serv).s_port as u16))
283            }
284        }
285    }
286
287    #[cfg(not(unix))]
288    {
289        None
290    }
291}
292
/// Resolves `host` to a single IP address.
///
/// IP literals are returned as-is without a lookup. Otherwise the first address the
/// system resolver yields is used. The lookup blocks the calling thread and has no
/// timeout.
///
/// # Errors
/// Propagates resolver errors, or returns `NotFound` when the lookup yields nothing.
pub fn resolve_host(host: &str) -> io::Result<IpAddr> {
    match host.parse::<IpAddr>() {
        Ok(literal) => Ok(literal),
        Err(_) => format!("{}:0", host)
            .to_socket_addrs()?
            .next()
            .map(|sock| sock.ip())
            .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, "host resolution failure")),
    }
}

/// Reverse DNS lookup of `addr`.
///
/// Currently always returns `None`, because the underlying resolver is a stub.
/// Callers such as `format_addr` then fall back to the numeric address.
pub fn reverse_lookup(addr: &IpAddr) -> Option<String> {
    dns_lookup_reverse(&SocketAddr::new(*addr, 0))
}

/// Placeholder for a real reverse resolver; performs no lookup.
// NOTE(review): a real implementation would run while listing sessions, so it likely
// needs a timeout like the forward lookup in `tcp_connect_timeout`.
fn dns_lookup_reverse(_addr: &SocketAddr) -> Option<String> {
    None
}

317/// Format a socket address for display
318pub fn format_addr(addr: &SocketAddr, resolve: bool) -> String {
319    if resolve {
320        if let Some(hostname) = reverse_lookup(&addr.ip()) {
321            return format!("{}:{}", hostname, addr.port());
322        }
323    }
324    format!("{}:{}", addr.ip(), addr.port())
325}
326
327/// Execute ztcp builtin
328pub fn builtin_ztcp(
329    args: &[&str],
330    options: &ZtcpOptions,
331    sessions: &mut TcpSessions,
332) -> (i32, String) {
333    let mut output = String::new();
334
335    if options.close {
336        if args.is_empty() {
337            sessions.close_all();
338            return (0, output);
339        }
340
341        let fd: RawFd = match args[0].parse() {
342            Ok(fd) => fd,
343            Err(_) => {
344                return (
345                    1,
346                    format!("ztcp: {} is an invalid argument to -c\n", args[0]),
347                );
348            }
349        };
350
351        match tcp_close(sessions, fd, options.force) {
352            Ok(()) => (0, output),
353            Err(e) => (1, format!("ztcp: {}\n", e)),
354        }
355    } else if options.listen {
356        if args.is_empty() {
357            return (1, "ztcp: -l requires an argument\n".to_string());
358        }
359
360        let port = match resolve_port(args[0]) {
361            Some(p) => p,
362            None => {
363                return (1, "ztcp: bad service name or port number\n".to_string());
364            }
365        };
366
367        match tcp_listen(port) {
368            Ok((fd, local)) => {
369                let mut session = TcpSession::new(fd, TcpSessionType::Listen);
370                session.local_addr = Some(local);
371                let result_fd = options.target_fd.unwrap_or(fd);
372                session.fd = result_fd;
373                sessions.add(session);
374
375                if options.verbose {
376                    output.push_str(&format!("{} listener is on fd {}\n", port, result_fd));
377                }
378                (0, output)
379            }
380            Err(e) => (1, format!("ztcp: could not listen: {}\n", e)),
381        }
382    } else if options.accept {
383        if args.is_empty() {
384            return (1, "ztcp: -a requires an argument\n".to_string());
385        }
386
387        let listen_fd: RawFd = match args[0].parse() {
388            Ok(fd) => fd,
389            Err(_) => {
390                return (1, "ztcp: invalid numerical argument\n".to_string());
391            }
392        };
393
394        if let Some(session) = sessions.get(listen_fd) {
395            if session.session_type != TcpSessionType::Listen {
396                return (1, "ztcp: tcp connection not a listener\n".to_string());
397            }
398        } else {
399            return (
400                1,
401                format!(
402                    "ztcp: fd {} is not registered as a tcp connection\n",
403                    args[0]
404                ),
405            );
406        }
407
408        if options.test {
409            match tcp_test_accept(listen_fd) {
410                Ok(true) => {}
411                Ok(false) => return (1, output),
412                Err(e) => return (1, format!("ztcp: poll error: {}\n", e)),
413            }
414        }
415
416        match tcp_accept(listen_fd) {
417            Ok((fd, local, peer)) => {
418                let mut session = TcpSession::new(fd, TcpSessionType::Inbound);
419                session.local_addr = Some(local);
420                session.peer_addr = Some(peer);
421                let result_fd = options.target_fd.unwrap_or(fd);
422                session.fd = result_fd;
423                sessions.add(session);
424
425                if options.verbose {
426                    output.push_str(&format!("{} is on fd {}\n", peer.port(), result_fd));
427                }
428                (0, output)
429            }
430            Err(e) => (1, format!("ztcp: could not accept connection: {}\n", e)),
431        }
432    } else if args.is_empty() {
433        for (_, session) in sessions.iter() {
434            let local_str = session
435                .local_addr
436                .map(|a| format_addr(&a, true))
437                .unwrap_or_else(|| "?:?".to_string());
438            let peer_str = session
439                .peer_addr
440                .map(|a| format_addr(&a, true))
441                .unwrap_or_else(|| "?:?".to_string());
442
443            if options.list_format {
444                output.push_str(&format!(
445                    "{} {} {} {} {} {}\n",
446                    session.fd,
447                    session.type_char(),
448                    session
449                        .local_addr
450                        .map(|a| a.ip().to_string())
451                        .unwrap_or_default(),
452                    session.local_addr.map(|a| a.port()).unwrap_or(0),
453                    session
454                        .peer_addr
455                        .map(|a| a.ip().to_string())
456                        .unwrap_or_default(),
457                    session.peer_addr.map(|a| a.port()).unwrap_or(0),
458                ));
459            } else {
460                let zftp_str = if session.is_zftp { " ZFTP" } else { "" };
461                output.push_str(&format!(
462                    "{} {} {} is on fd {}{}\n",
463                    local_str,
464                    session.direction_str(),
465                    peer_str,
466                    session.fd,
467                    zftp_str,
468                ));
469            }
470        }
471        (0, output)
472    } else {
473        let host = args[0];
474        let port = if args.len() > 1 {
475            resolve_port(args[1]).unwrap_or(23)
476        } else {
477            23
478        };
479
480        match tcp_connect(host, port) {
481            Ok((fd, local, peer)) => {
482                let mut session = TcpSession::new(fd, TcpSessionType::Outbound);
483                session.local_addr = Some(local);
484                session.peer_addr = Some(peer);
485                let result_fd = options.target_fd.unwrap_or(fd);
486                session.fd = result_fd;
487                sessions.add(session);
488
489                if options.verbose {
490                    output.push_str(&format!("{}:{} is now on fd {}\n", host, port, result_fd));
491                }
492                (0, output)
493            }
494            Err(e) => (1, format!("ztcp: connection failed: {}\n", e)),
495        }
496    }
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv6Addr;

    /// Runs the builtin with no positional arguments on a fresh session table.
    fn run_without_args(options: ZtcpOptions) -> (i32, String) {
        let mut table = TcpSessions::new();
        builtin_ztcp(&[], &options, &mut table)
    }

    #[test]
    fn test_tcp_session_type_char() {
        let expectations = [
            (TcpSessionType::Outbound, 'O'),
            (TcpSessionType::Inbound, 'I'),
            (TcpSessionType::Listen, 'L'),
        ];
        for &(kind, letter) in expectations.iter() {
            assert_eq!(TcpSession::new(3, kind).type_char(), letter);
        }

        // The zftp marker takes precedence over the session kind.
        let mut control = TcpSession::new(3, TcpSessionType::Outbound);
        control.is_zftp = true;
        assert_eq!(control.type_char(), 'Z');
    }

    #[test]
    fn test_tcp_session_direction() {
        let expectations = [
            (TcpSessionType::Outbound, "->"),
            (TcpSessionType::Inbound, "<-"),
            (TcpSessionType::Listen, "-<"),
        ];
        for &(kind, arrow) in expectations.iter() {
            assert_eq!(TcpSession::new(3, kind).direction_str(), arrow);
        }
    }

    #[test]
    fn test_tcp_sessions_manager() {
        let mut table = TcpSessions::new();
        assert!(table.is_empty());

        table.add(TcpSession::new(5, TcpSessionType::Outbound));
        assert_eq!(table.len(), 1);
        assert!(table.get(5).is_some());
        assert!(table.get(6).is_none());

        table.remove(5);
        assert!(table.is_empty());
    }

    #[test]
    fn test_resolve_port() {
        for &(text, port) in [("80", 80u16), ("443", 443)].iter() {
            assert_eq!(resolve_port(text), Some(port));
        }
        assert_eq!(resolve_port("invalid"), None);
    }

    #[test]
    fn test_resolve_host() {
        assert_eq!(
            resolve_host("127.0.0.1").unwrap(),
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
        );
        assert_eq!(
            resolve_host("::1").unwrap(),
            IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))
        );
    }

    #[test]
    fn test_format_addr() {
        let loopback = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
        assert_eq!(format_addr(&loopback, false), "127.0.0.1:8080");
    }

    #[test]
    fn test_builtin_ztcp_list_empty() {
        let (status, output) = run_without_args(ZtcpOptions::default());
        assert_eq!(status, 0);
        assert!(output.is_empty());
    }

    #[test]
    fn test_builtin_ztcp_close_all() {
        let (status, _) = run_without_args(ZtcpOptions {
            close: true,
            ..Default::default()
        });
        assert_eq!(status, 0);
    }

    #[test]
    fn test_builtin_ztcp_listen_no_arg() {
        let (status, output) = run_without_args(ZtcpOptions {
            listen: true,
            ..Default::default()
        });
        assert_eq!(status, 1);
        assert!(output.contains("requires an argument"));
    }

    #[test]
    fn test_builtin_ztcp_accept_no_arg() {
        let (status, output) = run_without_args(ZtcpOptions {
            accept: true,
            ..Default::default()
        });
        assert_eq!(status, 1);
        assert!(output.contains("requires an argument"));
    }
}