1use std::collections::HashMap;
6use std::io::{self};
7use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
8use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
9
/// Kind of socket a `TcpSession` tracks.
///
/// The variant selects the one-letter code returned by
/// `TcpSession::type_char` and the arrow returned by
/// `TcpSession::direction_str`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TcpSessionType {
    /// Connection opened from this side; `builtin_ztcp` records it after a
    /// successful `tcp_connect`. Shown as `'O'` / `"->"`.
    Outbound,
    /// Connection taken from a listener; `builtin_ztcp` records it after a
    /// successful `tcp_accept`. Shown as `'I'` / `"<-"`.
    Inbound,
    /// Listening socket; `builtin_ztcp` records it after a successful
    /// `tcp_listen`. Shown as `'L'` / `"-<"`.
    Listen,
}
17
18#[derive(Debug)]
20pub struct TcpSession {
21 pub fd: RawFd,
22 pub session_type: TcpSessionType,
23 pub local_addr: Option<SocketAddr>,
24 pub peer_addr: Option<SocketAddr>,
25 pub is_zftp: bool,
26}
27
28impl TcpSession {
29 pub fn new(fd: RawFd, session_type: TcpSessionType) -> Self {
30 Self {
31 fd,
32 session_type,
33 local_addr: None,
34 peer_addr: None,
35 is_zftp: false,
36 }
37 }
38
39 pub fn type_char(&self) -> char {
40 if self.is_zftp {
41 'Z'
42 } else {
43 match self.session_type {
44 TcpSessionType::Listen => 'L',
45 TcpSessionType::Inbound => 'I',
46 TcpSessionType::Outbound => 'O',
47 }
48 }
49 }
50
51 pub fn direction_str(&self) -> &'static str {
52 match self.session_type {
53 TcpSessionType::Listen => "-<",
54 TcpSessionType::Inbound => "<-",
55 TcpSessionType::Outbound => "->",
56 }
57 }
58}
59
60#[derive(Debug, Default)]
62pub struct TcpSessions {
63 sessions: HashMap<RawFd, TcpSession>,
64}
65
66impl TcpSessions {
67 pub fn new() -> Self {
68 Self::default()
69 }
70
71 pub fn add(&mut self, session: TcpSession) {
72 self.sessions.insert(session.fd, session);
73 }
74
75 pub fn get(&self, fd: RawFd) -> Option<&TcpSession> {
76 self.sessions.get(&fd)
77 }
78
79 pub fn get_by_ref(&self, fd: &RawFd) -> Option<&TcpSession> {
80 self.sessions.get(fd)
81 }
82
83 pub fn get_mut(&mut self, fd: RawFd) -> Option<&mut TcpSession> {
84 self.sessions.get_mut(&fd)
85 }
86
87 pub fn remove(&mut self, fd: RawFd) -> Option<TcpSession> {
88 self.sessions.remove(&fd)
89 }
90
91 pub fn iter(&self) -> impl Iterator<Item = (&RawFd, &TcpSession)> {
92 self.sessions.iter()
93 }
94
95 pub fn close_all(&mut self) {
96 for (fd, _) in self.sessions.drain() {
97 let _ = close_fd(fd);
98 }
99 }
100
101 pub fn len(&self) -> usize {
102 self.sessions.len()
103 }
104
105 pub fn is_empty(&self) -> bool {
106 self.sessions.is_empty()
107 }
108}
109
110fn close_fd(fd: RawFd) -> io::Result<()> {
111 #[cfg(unix)]
112 {
113 let result = unsafe { libc::close(fd) };
114 if result < 0 {
115 Err(io::Error::last_os_error())
116 } else {
117 Ok(())
118 }
119 }
120 #[cfg(not(unix))]
121 {
122 Ok(())
123 }
124}
125
/// Option flags consumed by `builtin_ztcp`.
///
/// `builtin_ztcp` checks `close`, then `listen`, then `accept`, and acts on
/// the first one that is set. With none of them set it lists the sessions
/// (no arguments) or opens an outbound connection (arguments given).
#[derive(Debug, Default)]
pub struct ZtcpOptions {
    /// Close mode (`-c` in the builtin's messages): close the given fd, or
    /// every session when no argument is supplied.
    pub close: bool,
    /// Listen mode (`-l`): open a listening socket on the given port or
    /// service name.
    pub listen: bool,
    /// Accept mode (`-a`): accept a connection on a registered listener fd.
    pub accept: bool,
    /// Force flag (`-f`): handed to `tcp_close`, which requires it before
    /// closing a zftp control connection.
    pub force: bool,
    /// When set, a successful listen/accept/connect reports the resulting
    /// fd in the returned output text.
    pub verbose: bool,
    /// In accept mode, poll the listener first and return status 1 without
    /// accepting when nothing is pending.
    pub test: bool,
    /// When listing, emit space-separated fields (fd, type char, local
    /// ip/port, peer ip/port) instead of the arrow-style line.
    pub list_format: bool,
    /// If set, the descriptor number a newly created session is registered
    /// under, instead of the one returned by the socket call.
    pub target_fd: Option<RawFd>,
}
138
139pub fn tcp_connect(host: &str, port: u16) -> io::Result<(RawFd, SocketAddr, SocketAddr)> {
142 tcp_connect_timeout(host, port, std::time::Duration::from_secs(10))
143}
144
145pub fn tcp_connect_timeout(
147 host: &str,
148 port: u16,
149 timeout: std::time::Duration,
150) -> io::Result<(RawFd, SocketAddr, SocketAddr)> {
151 let addr_str = format!("{}:{}", host, port);
153 let (tx, rx) = std::sync::mpsc::channel();
154 let dns_str = addr_str.clone();
155 std::thread::Builder::new()
156 .name("dns-resolve".to_string())
157 .spawn(move || {
158 let result: io::Result<Vec<SocketAddr>> =
159 dns_str.to_socket_addrs().map(|a| a.collect());
160 let _ = tx.send(result);
161 })
162 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
163
164 let addrs = rx
165 .recv_timeout(timeout)
166 .map_err(|_| io::Error::new(io::ErrorKind::TimedOut, "DNS resolution timed out"))?
167 .map_err(|e| {
168 tracing::warn!(host, error = %e, "DNS resolution failed");
169 e
170 })?;
171
172 if addrs.is_empty() {
173 return Err(io::Error::new(
174 io::ErrorKind::NotFound,
175 "host resolution failure",
176 ));
177 }
178
179 for addr in addrs {
180 match TcpStream::connect_timeout(&addr, timeout) {
181 Ok(stream) => {
182 tracing::debug!(%addr, "tcp: connected");
183 let local = stream.local_addr()?;
184 let peer = stream.peer_addr()?;
185 let fd = stream.as_raw_fd();
186 std::mem::forget(stream);
187 return Ok((fd, local, peer));
188 }
189 Err(e) => {
190 tracing::trace!(%addr, error = %e, "tcp: connect attempt failed");
191 continue;
192 }
193 }
194 }
195
196 Err(io::Error::new(
197 io::ErrorKind::ConnectionRefused,
198 "connection failed",
199 ))
200}
201
/// Opens a listening TCP socket on `port`, bound to the IPv4 wildcard
/// address `0.0.0.0`.
///
/// Returns the raw descriptor and the address actually bound. The
/// descriptor is deliberately not closed on return: ownership passes to the
/// caller.
pub fn tcp_listen(port: u16) -> io::Result<(RawFd, SocketAddr)> {
    let wildcard = SocketAddr::new(IpAddr::from(Ipv4Addr::new(0, 0, 0, 0)), port);
    let listener = TcpListener::bind(wildcard)?;
    let bound_addr = listener.local_addr()?;
    // Only suppress the destructor once nothing else can fail, so the error
    // paths above still close the socket.
    let listener = std::mem::ManuallyDrop::new(listener);
    Ok((listener.as_raw_fd(), bound_addr))
}
211
/// Accepts one connection on the listening descriptor `listen_fd`.
///
/// Returns `(fd, local_addr, peer_addr)` for the accepted connection.
/// Neither `listen_fd` nor the new descriptor is closed on return; the
/// caller keeps ownership of both.
pub fn tcp_accept(listen_fd: RawFd) -> io::Result<(RawFd, SocketAddr, SocketAddr)> {
    // SAFETY: assumes `listen_fd` is an open socket owned by the caller.
    // `ManuallyDrop` guarantees this temporary wrapper never closes it.
    let listener = std::mem::ManuallyDrop::new(unsafe { TcpListener::from_raw_fd(listen_fd) });

    let (stream, peer) = listener.accept()?;
    let local = stream.local_addr()?;
    // Release the accepted socket to the caller instead of closing it.
    let stream = std::mem::ManuallyDrop::new(stream);
    Ok((stream.as_raw_fd(), local, peer))
}
224
225pub fn tcp_test_accept(listen_fd: RawFd) -> io::Result<bool> {
227 #[cfg(unix)]
228 {
229 let mut pfd = libc::pollfd {
230 fd: listen_fd,
231 events: libc::POLLIN,
232 revents: 0,
233 };
234
235 let result = unsafe { libc::poll(&mut pfd, 1, 0) };
236 if result < 0 {
237 Err(io::Error::last_os_error())
238 } else {
239 Ok(result > 0)
240 }
241 }
242
243 #[cfg(not(unix))]
244 {
245 Ok(true)
246 }
247}
248
249pub fn tcp_close(sessions: &mut TcpSessions, fd: RawFd, force: bool) -> Result<(), String> {
251 if let Some(session) = sessions.get(fd) {
252 if session.is_zftp && !force {
253 return Err("use -f to force closure of a zftp control connection".to_string());
254 }
255 }
256
257 if let Some(_session) = sessions.remove(fd) {
258 close_fd(fd).map_err(|e| format!("connection close failed: {}", e))?;
259 Ok(())
260 } else {
261 Err(format!("fd {} not found in tcp table", fd))
262 }
263}
264
265pub fn resolve_port(service: &str) -> Option<u16> {
267 if let Ok(port) = service.parse::<u16>() {
268 return Some(port);
269 }
270
271 #[cfg(unix)]
272 {
273 use std::ffi::CString;
274 let service_c = CString::new(service).ok()?;
275 let proto_c = CString::new("tcp").ok()?;
276
277 unsafe {
278 let serv = libc::getservbyname(service_c.as_ptr(), proto_c.as_ptr());
279 if serv.is_null() {
280 None
281 } else {
282 Some(u16::from_be((*serv).s_port as u16))
283 }
284 }
285 }
286
287 #[cfg(not(unix))]
288 {
289 None
290 }
291}
292
/// Resolves `host` to a single IP address.
///
/// An IPv4 or IPv6 literal is parsed directly without any lookup. Anything
/// else goes through the system resolver and the first address returned is
/// used.
///
/// # Errors
///
/// Propagates the resolver's error, or returns `NotFound` when the resolver
/// yields no addresses.
pub fn resolve_host(host: &str) -> io::Result<IpAddr> {
    match host.parse::<IpAddr>() {
        Ok(literal) => Ok(literal),
        Err(_) => {
            // Port 0 is a placeholder: only the address part is wanted.
            let mut candidates = format!("{}:0", host).to_socket_addrs()?;
            let first: Option<SocketAddr> = candidates.next();
            match first {
                Some(sock) => Ok(sock.ip()),
                None => Err(io::Error::new(
                    io::ErrorKind::NotFound,
                    "host resolution failure",
                )),
            }
        }
    }
}
305
306pub fn reverse_lookup(addr: &IpAddr) -> Option<String> {
308 let socket_addr = SocketAddr::new(*addr, 0);
309 let hostname = dns_lookup_reverse(&socket_addr);
310 hostname
311}
312
/// Placeholder for reverse DNS resolution.
///
/// Ignores its argument and always returns `None`, so callers such as
/// `reverse_lookup` never obtain a hostname from it.
fn dns_lookup_reverse(_addr: &SocketAddr) -> Option<String> {
    None
}
316
317pub fn format_addr(addr: &SocketAddr, resolve: bool) -> String {
319 if resolve {
320 if let Some(hostname) = reverse_lookup(&addr.ip()) {
321 return format!("{}:{}", hostname, addr.port());
322 }
323 }
324 format!("{}:{}", addr.ip(), addr.port())
325}
326
327pub fn builtin_ztcp(
329 args: &[&str],
330 options: &ZtcpOptions,
331 sessions: &mut TcpSessions,
332) -> (i32, String) {
333 let mut output = String::new();
334
335 if options.close {
336 if args.is_empty() {
337 sessions.close_all();
338 return (0, output);
339 }
340
341 let fd: RawFd = match args[0].parse() {
342 Ok(fd) => fd,
343 Err(_) => {
344 return (
345 1,
346 format!("ztcp: {} is an invalid argument to -c\n", args[0]),
347 );
348 }
349 };
350
351 match tcp_close(sessions, fd, options.force) {
352 Ok(()) => (0, output),
353 Err(e) => (1, format!("ztcp: {}\n", e)),
354 }
355 } else if options.listen {
356 if args.is_empty() {
357 return (1, "ztcp: -l requires an argument\n".to_string());
358 }
359
360 let port = match resolve_port(args[0]) {
361 Some(p) => p,
362 None => {
363 return (1, "ztcp: bad service name or port number\n".to_string());
364 }
365 };
366
367 match tcp_listen(port) {
368 Ok((fd, local)) => {
369 let mut session = TcpSession::new(fd, TcpSessionType::Listen);
370 session.local_addr = Some(local);
371 let result_fd = options.target_fd.unwrap_or(fd);
372 session.fd = result_fd;
373 sessions.add(session);
374
375 if options.verbose {
376 output.push_str(&format!("{} listener is on fd {}\n", port, result_fd));
377 }
378 (0, output)
379 }
380 Err(e) => (1, format!("ztcp: could not listen: {}\n", e)),
381 }
382 } else if options.accept {
383 if args.is_empty() {
384 return (1, "ztcp: -a requires an argument\n".to_string());
385 }
386
387 let listen_fd: RawFd = match args[0].parse() {
388 Ok(fd) => fd,
389 Err(_) => {
390 return (1, "ztcp: invalid numerical argument\n".to_string());
391 }
392 };
393
394 if let Some(session) = sessions.get(listen_fd) {
395 if session.session_type != TcpSessionType::Listen {
396 return (1, "ztcp: tcp connection not a listener\n".to_string());
397 }
398 } else {
399 return (
400 1,
401 format!(
402 "ztcp: fd {} is not registered as a tcp connection\n",
403 args[0]
404 ),
405 );
406 }
407
408 if options.test {
409 match tcp_test_accept(listen_fd) {
410 Ok(true) => {}
411 Ok(false) => return (1, output),
412 Err(e) => return (1, format!("ztcp: poll error: {}\n", e)),
413 }
414 }
415
416 match tcp_accept(listen_fd) {
417 Ok((fd, local, peer)) => {
418 let mut session = TcpSession::new(fd, TcpSessionType::Inbound);
419 session.local_addr = Some(local);
420 session.peer_addr = Some(peer);
421 let result_fd = options.target_fd.unwrap_or(fd);
422 session.fd = result_fd;
423 sessions.add(session);
424
425 if options.verbose {
426 output.push_str(&format!("{} is on fd {}\n", peer.port(), result_fd));
427 }
428 (0, output)
429 }
430 Err(e) => (1, format!("ztcp: could not accept connection: {}\n", e)),
431 }
432 } else if args.is_empty() {
433 for (_, session) in sessions.iter() {
434 let local_str = session
435 .local_addr
436 .map(|a| format_addr(&a, true))
437 .unwrap_or_else(|| "?:?".to_string());
438 let peer_str = session
439 .peer_addr
440 .map(|a| format_addr(&a, true))
441 .unwrap_or_else(|| "?:?".to_string());
442
443 if options.list_format {
444 output.push_str(&format!(
445 "{} {} {} {} {} {}\n",
446 session.fd,
447 session.type_char(),
448 session
449 .local_addr
450 .map(|a| a.ip().to_string())
451 .unwrap_or_default(),
452 session.local_addr.map(|a| a.port()).unwrap_or(0),
453 session
454 .peer_addr
455 .map(|a| a.ip().to_string())
456 .unwrap_or_default(),
457 session.peer_addr.map(|a| a.port()).unwrap_or(0),
458 ));
459 } else {
460 let zftp_str = if session.is_zftp { " ZFTP" } else { "" };
461 output.push_str(&format!(
462 "{} {} {} is on fd {}{}\n",
463 local_str,
464 session.direction_str(),
465 peer_str,
466 session.fd,
467 zftp_str,
468 ));
469 }
470 }
471 (0, output)
472 } else {
473 let host = args[0];
474 let port = if args.len() > 1 {
475 resolve_port(args[1]).unwrap_or(23)
476 } else {
477 23
478 };
479
480 match tcp_connect(host, port) {
481 Ok((fd, local, peer)) => {
482 let mut session = TcpSession::new(fd, TcpSessionType::Outbound);
483 session.local_addr = Some(local);
484 session.peer_addr = Some(peer);
485 let result_fd = options.target_fd.unwrap_or(fd);
486 session.fd = result_fd;
487 sessions.add(session);
488
489 if options.verbose {
490 output.push_str(&format!("{}:{} is now on fd {}\n", host, port, result_fd));
491 }
492 (0, output)
493 }
494 Err(e) => (1, format!("ztcp: connection failed: {}\n", e)),
495 }
496 }
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv6Addr;

    /// Runs the builtin with no arguments against a fresh, empty table.
    fn run_without_args(options: ZtcpOptions) -> (i32, String) {
        let mut table = TcpSessions::new();
        builtin_ztcp(&[], &options, &mut table)
    }

    #[test]
    fn test_tcp_session_type_char() {
        let plain_cases = [
            (TcpSessionType::Outbound, 'O'),
            (TcpSessionType::Inbound, 'I'),
            (TcpSessionType::Listen, 'L'),
        ];
        for &(kind, expected) in plain_cases.iter() {
            assert_eq!(TcpSession::new(3, kind).type_char(), expected);
        }

        // The zftp flag overrides the per-type letter.
        let mut zftp = TcpSession::new(3, TcpSessionType::Outbound);
        zftp.is_zftp = true;
        assert_eq!(zftp.type_char(), 'Z');
    }

    #[test]
    fn test_tcp_session_direction() {
        let cases = [
            (TcpSessionType::Outbound, "->"),
            (TcpSessionType::Inbound, "<-"),
            (TcpSessionType::Listen, "-<"),
        ];
        for &(kind, arrow) in cases.iter() {
            assert_eq!(TcpSession::new(3, kind).direction_str(), arrow);
        }
    }

    #[test]
    fn test_tcp_sessions_manager() {
        let mut table = TcpSessions::new();
        assert!(table.is_empty());

        table.add(TcpSession::new(5, TcpSessionType::Outbound));
        assert_eq!(table.len(), 1);

        assert!(table.get(5).is_some());
        assert!(table.get(6).is_none());

        table.remove(5);
        assert!(table.is_empty());
    }

    #[test]
    fn test_resolve_port() {
        let cases = [("80", Some(80u16)), ("443", Some(443)), ("invalid", None)];
        for &(input, expected) in cases.iter() {
            assert_eq!(resolve_port(input), expected);
        }
    }

    #[test]
    fn test_resolve_host() {
        let v4 = resolve_host("127.0.0.1").unwrap();
        assert_eq!(v4, IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));

        let v6 = resolve_host("::1").unwrap();
        assert_eq!(v6, IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)));
    }

    #[test]
    fn test_format_addr() {
        let loopback = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
        assert_eq!(format_addr(&loopback, false), "127.0.0.1:8080");
    }

    #[test]
    fn test_builtin_ztcp_list_empty() {
        let (status, output) = run_without_args(ZtcpOptions::default());
        assert_eq!(status, 0);
        assert!(output.is_empty());
    }

    #[test]
    fn test_builtin_ztcp_close_all() {
        let (status, _) = run_without_args(ZtcpOptions {
            close: true,
            ..Default::default()
        });
        assert_eq!(status, 0);
    }

    #[test]
    fn test_builtin_ztcp_listen_no_arg() {
        let (status, output) = run_without_args(ZtcpOptions {
            listen: true,
            ..Default::default()
        });
        assert_eq!(status, 1);
        assert!(output.contains("requires an argument"));
    }

    #[test]
    fn test_builtin_ztcp_accept_no_arg() {
        let (status, output) = run_without_args(ZtcpOptions {
            accept: true,
            ..Default::default()
        });
        assert_eq!(status, 1);
        assert!(output.contains("requires an argument"));
    }
}