tokio_unix_tcp/
listener.rs

1/*
2 * Copyright (c) 2023, networkException <git@nwex.de>
3 *
4 * SPDX-License-Identifier: BSD-2-Clause OR MIT
5 */
6
7use std::io;
8
9#[cfg(unix)]
10use std::{
11    os::unix::prelude::PermissionsExt,
12    fs::{self, Permissions}
13};
14
15use tokio::net::TcpListener;
16
17use crate::{SocketAddr, Stream, NamedSocketAddr};
18
19#[cfg(unix)]
20use tokio::net::UnixListener;
21
22#[derive(Debug)]
23pub enum Listener {
24    Tcp(TcpListener),
25    #[cfg(unix)]
26    Unix(UnixListener),
27}
28
29impl From<TcpListener> for Listener {
30    fn from(listener: TcpListener) -> Listener {
31        Listener::Tcp(listener)
32    }
33}
34
35#[cfg(unix)]
36impl From<UnixListener> for Listener {
37    fn from(listener: UnixListener) -> Listener {
38        Listener::Unix(listener)
39    }
40}
41
42impl Listener {
43    // On non unix systems, remove and mode are not used.
44    #[cfg_attr(not(unix), allow(unused_variables))]
45    pub async fn bind_and_prepare_unix(named_socket_addr: &NamedSocketAddr, remove: bool, mode: Option<u32>) -> io::Result<Listener> {
46        match named_socket_addr {
47            NamedSocketAddr::Inet(inet_socket_addr) => {
48                TcpListener::bind(inet_socket_addr).await.map(Listener::Tcp)
49            }
50            #[cfg(unix)]
51            NamedSocketAddr::Unix(path) => {
52                if remove && path.exists() {
53                    fs::remove_file(path)?
54                }
55
56                let bound = UnixListener::bind(path)?;
57
58                fs::set_permissions(
59                    path,
60                    Permissions::from_mode(mode.unwrap_or(0o222)),
61                )?;
62
63                Ok(Listener::Unix(bound))
64            }
65        }
66    }
67
68    pub async fn bind(named_socket_addr: &NamedSocketAddr) -> io::Result<Listener> {
69        match named_socket_addr {
70            NamedSocketAddr::Inet(inet_socket_addr) => {
71                TcpListener::bind(inet_socket_addr).await.map(Listener::Tcp)
72            }
73            #[cfg(unix)]
74            NamedSocketAddr::Unix(path) => UnixListener::bind(path).map(Listener::Unix),
75        }
76    }
77
78    pub async fn accept(&self) -> io::Result<(Stream, SocketAddr)> {
79        match self {
80            Listener::Tcp(listener) => listener
81                .accept()
82                .await
83                .map(|(tcp_stream, inet_socket_addr)| (Stream::Tcp(tcp_stream), SocketAddr::Inet(inet_socket_addr))),
84            #[cfg(unix)]
85            Listener::Unix(listener) => listener
86                .accept()
87                .await
88                .map(|(unix_stream, unix_socket_addr)| (Stream::Unix(unix_stream), SocketAddr::Unix(unix_socket_addr.into()))),
89        }
90    }
91}