#![deny(missing_docs)]
pub mod error;
mod resolv_addr;
use std::convert::{TryFrom, TryInto};
use std::fmt;
use std::ffi::{OsStr, OsString};
use crate::error::*;
use crate::resolv_addr::ResolvAddr;
#[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
use std::convert::Infallible as Never;
#[cfg(all(target_os = "linux", feature = "enable_systemd"))]
pub(crate) mod systemd_sockets {
    //! Process-wide storage for the file descriptors received through systemd socket
    //! activation, looked up by the name systemd associated with each of them.
    use std::fmt;
    use std::sync::Mutex;
    use libsystemd::activation::FileDescriptor;
    use libsystemd::errors::Error as LibSystemdError;
    use libsystemd::errors::Result as LibSystemdResult;

    /// Failure to receive the descriptors from systemd.
    ///
    /// It borrows the error stored once in the global `SYSTEMD_SOCKETS`, so every caller
    /// reports the same failure. NOTE(review): the `Mutex` presumably exists only to make
    /// the stored error `Sync` for the static — confirm.
    #[derive(Debug)]
    pub(crate) struct Error(&'static Mutex<LibSystemdError>);

    impl fmt::Display for Error {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            let inner = self.0.lock().expect("mutex poisoned");
            fmt::Display::fmt(&*inner, f)
        }
    }

    impl std::error::Error for Error {}

    /// Removes and returns the descriptor called `name`.
    ///
    /// Returns `Ok(None)` when no descriptor has that name, including when it was already
    /// taken by an earlier call, because taking removes it from the map.
    pub(crate) fn take(name: &str) -> Result<Option<FileDescriptor>, Error> {
        let sockets = &*SYSTEMD_SOCKETS;
        sockets.as_ref().map(|store| store.take(name)).map_err(Error)
    }

    /// Received descriptors, keyed by the names libsystemd reported for them.
    struct SystemdSockets(Mutex<std::collections::HashMap<String, FileDescriptor>>);

    impl SystemdSockets {
        /// Receives every passed descriptor. The `true` argument is libsystemd's
        /// `unset_env` flag, so the activation variables are cleared after reading.
        fn new() -> LibSystemdResult<Self> {
            let received = libsystemd::activation::receive_descriptors_with_names(true)?;
            let by_name = received
                .into_iter()
                .map(|(fd, name)| (name, fd))
                .collect();
            Ok(SystemdSockets(Mutex::new(by_name)))
        }

        /// Moves the descriptor called `name` out of the map, if present.
        fn take(&self, name: &str) -> Option<FileDescriptor> {
            let mut map = self.0.lock().expect("poisoned mutex");
            map.remove(name)
        }
    }

    // Initialised on first use; a failure is stored and reported to every caller.
    lazy_static::lazy_static! {
        static ref SYSTEMD_SOCKETS: Result<SystemdSockets, Mutex<LibSystemdError>> = SystemdSockets::new().map_err(Mutex::new);
    }
}
/// A socket address that can be bound to obtain a TCP listener.
///
/// Values are created by parsing a string (`FromStr`, `TryFrom`, or serde when the
/// `serde` feature is enabled) or by converting from the `std::net` address types.
/// Three textual forms are accepted:
///
/// * a numeric socket address such as `127.0.0.1:42`, bound directly;
/// * a host name with a port, resolved only when the address is bound;
/// * `systemd://name`, naming a socket passed by systemd socket activation (available
///   only on Linux with the `enable_systemd` feature).
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde_crate::Deserialize), serde(crate = "serde_crate", try_from = "serde_str_helpers::DeserBorrowStr"))]
pub struct SocketAddr(SocketAddrInner);
impl SocketAddr {
pub fn from_systemd_name<T: Into<String>>(name: T) -> Result<Self, ParseError> {
Self::inner_from_systemd_name(name.into(), false)
}
#[cfg(all(target_os = "linux", feature = "enable_systemd"))]
fn inner_from_systemd_name(name: String, prefixed: bool) -> Result<Self, ParseError> {
let real_systemd_name = if prefixed {
&name[SYSTEMD_PREFIX.len()..]
} else {
&name
};
let name_len = real_systemd_name.len();
match real_systemd_name.chars().enumerate().find(|(_, c)| !c.is_ascii() || *c < ' ' || *c == ':') {
None if name_len <= 255 && prefixed => Ok(SocketAddr(SocketAddrInner::Systemd(name))),
None if name_len <= 255 && !prefixed => Ok(SocketAddr(SocketAddrInner::SystemdNoPrefix(name))),
None => Err(ParseErrorInner::LongSocketName { string: name, len: name_len }.into()),
Some((pos, c)) => Err(ParseErrorInner::InvalidCharacter { string: name, c, pos, }.into()),
}
}
#[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
fn inner_from_systemd_name(name: String, _prefixed: bool) -> Result<Self, ParseError> {
Err(ParseError(ParseErrorInner::SystemdUnsupported(name)))
}
pub fn bind(self) -> Result<std::net::TcpListener, BindError> {
match self.0 {
SocketAddrInner::Ordinary(addr) => match std::net::TcpListener::bind(addr) {
Ok(socket) => Ok(socket),
Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()),
},
SocketAddrInner::WithHostname(addr) => match std::net::TcpListener::bind(addr.as_str()) {
Ok(socket) => Ok(socket),
Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()),
},
SocketAddrInner::Systemd(socket_name) => Self::get_systemd(socket_name, true).map(|(socket, _)| socket),
SocketAddrInner::SystemdNoPrefix(socket_name) => Self::get_systemd(socket_name, false).map(|(socket, _)| socket),
}
}
#[cfg(feature = "tokio_0_2")]
pub async fn bind_tokio_0_2(self) -> Result<tokio_0_2::net::TcpListener, TokioBindError> {
match self.0 {
SocketAddrInner::Ordinary(addr) => match tokio_0_2::net::TcpListener::bind(addr).await {
Ok(socket) => Ok(socket),
Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())),
},
SocketAddrInner::WithHostname(addr) => match tokio_0_2::net::TcpListener::bind(addr.as_str()).await {
Ok(socket) => Ok(socket),
Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())),
},
SocketAddrInner::Systemd(socket_name) => {
let (socket, addr) = Self::get_systemd(socket_name, true)?;
socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
},
SocketAddrInner::SystemdNoPrefix(socket_name) => {
let (socket, addr) = Self::get_systemd(socket_name, false)?;
socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
},
}
}
#[cfg(feature = "tokio_0_3")]
pub async fn bind_tokio_0_3(self) -> Result<tokio_0_3::net::TcpListener, TokioBindError> {
match self.0 {
SocketAddrInner::Ordinary(addr) => match tokio_0_3::net::TcpListener::bind(addr).await {
Ok(socket) => Ok(socket),
Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindFailed { addr, error, }.into())),
},
SocketAddrInner::WithHostname(addr) => match tokio_0_3::net::TcpListener::bind(addr.as_str()).await {
Ok(socket) => Ok(socket),
Err(error) => Err(TokioBindError::Bind(BindErrorInner::BindOrResolvFailed { addr, error, }.into())),
},
SocketAddrInner::Systemd(socket_name) => {
let (socket, addr) = Self::get_systemd(socket_name, true)?;
socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
},
SocketAddrInner::SystemdNoPrefix(socket_name) => {
let (socket, addr) = Self::get_systemd(socket_name, false)?;
socket.try_into().map_err(|error| TokioConversionError { addr, error, }.into())
},
}
}
#[cfg(feature = "async-std")]
pub async fn bind_async_std(self) -> Result<async_std::net::TcpListener, BindError> {
match self.0 {
SocketAddrInner::Ordinary(addr) => match async_std::net::TcpListener::bind(addr).await {
Ok(socket) => Ok(socket),
Err(error) => Err(BindErrorInner::BindFailed { addr, error, }.into()),
},
SocketAddrInner::WithHostname(addr) => match async_std::net::TcpListener::bind(addr.as_str()).await {
Ok(socket) => Ok(socket),
Err(error) => Err(BindErrorInner::BindOrResolvFailed { addr, error, }.into()),
},
SocketAddrInner::Systemd(socket_name) => {
let (socket, _) = Self::get_systemd(socket_name, true)?;
Ok(socket.into())
},
SocketAddrInner::SystemdNoPrefix(socket_name) => {
let (socket, _) = Self::get_systemd(socket_name, false)?;
Ok(socket.into())
},
}
}
fn try_from_generic<'a, T>(string: T) -> Result<Self, ParseError> where T: 'a + std::ops::Deref<Target=str> + Into<String> {
if string.starts_with(SYSTEMD_PREFIX) {
Self::inner_from_systemd_name(string.into(), true)
} else {
match string.parse() {
Ok(addr) => Ok(SocketAddr(SocketAddrInner::Ordinary(addr))),
Err(_) => Ok(SocketAddr(SocketAddrInner::WithHostname(ResolvAddr::try_from_generic(string).map_err(ParseErrorInner::ResolvAddr)?))),
}
}
}
#[cfg(all(target_os = "linux", feature = "enable_systemd"))]
fn get_systemd(socket_name: String, prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
use libsystemd::activation::IsType;
use std::os::unix::io::{FromRawFd, IntoRawFd};
let real_systemd_name = if prefixed {
&socket_name[SYSTEMD_PREFIX.len()..]
} else {
&socket_name
};
let socket = systemd_sockets::take(real_systemd_name).map_err(BindErrorInner::ReceiveDescriptors)?;
unsafe {
match socket {
Some(socket) if socket.is_inet() => Ok((std::net::TcpListener::from_raw_fd(socket.into_raw_fd()), SocketAddrInner::Systemd(socket_name))),
Some(_) => Err(BindErrorInner::NotInetSocket(socket_name).into()),
None => Err(BindErrorInner::MissingDescriptor(socket_name).into())
}
}
}
#[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
fn get_systemd(socket_name: Never, _prefixed: bool) -> Result<(std::net::TcpListener, SocketAddrInner), BindError> {
match socket_name {}
}
}
impl fmt::Display for SocketAddr {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&self.0, f)
}
}
impl fmt::Display for SocketAddrInner {
    /// Non-systemd addresses delegate to their own `Display`. Systemd names are always
    /// written with the `systemd://` prefix.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            // Stored without the prefix, so it is added back here.
            Self::SystemdNoPrefix(name) => write!(f, "{}{}", SYSTEMD_PREFIX, name),
            // Stored with the prefix already included.
            Self::Systemd(name) => fmt::Display::fmt(name, f),
            Self::WithHostname(addr) => fmt::Display::fmt(addr, f),
            Self::Ordinary(addr) => fmt::Display::fmt(addr, f),
        }
    }
}
/// Internal representation of a `SocketAddr`.
#[derive(Debug, PartialEq)]
enum SocketAddrInner {
    /// A numeric socket address, bound directly.
    Ordinary(std::net::SocketAddr),
    /// A host name with a port, resolved when the address is bound.
    WithHostname(resolv_addr::ResolvAddr),
    /// A systemd socket name, stored *including* the `systemd://` prefix.
    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
    Systemd(String),
    /// Uninhabited stand-in when systemd support is compiled out. Arms matching it are
    /// statically unreachable, and because it is never constructed the `dead_code` lint
    /// has to be silenced.
    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
    #[allow(dead_code)]
    Systemd(Never),
    /// A systemd socket name stored *without* the prefix, as given to
    /// `SocketAddr::from_systemd_name`.
    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
    SystemdNoPrefix(String),
    /// Uninhabited stand-in when systemd support is compiled out. It is never
    /// constructed, hence the `dead_code` allowance.
    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
    #[allow(dead_code)]
    SystemdNoPrefix(Never),
}
/// Prefix marking a string as the name of a systemd-activated socket (`systemd://name`).
const SYSTEMD_PREFIX: &str = "systemd://";
impl<I> From<(I, u16)> for SocketAddr
where
    I: Into<std::net::IpAddr>,
{
    /// Builds an ordinary address from an IP address and a port.
    fn from(value: (I, u16)) -> Self {
        let (ip, port) = value;
        let addr = std::net::SocketAddr::new(ip.into(), port);
        SocketAddr(SocketAddrInner::Ordinary(addr))
    }
}
impl From<std::net::SocketAddr> for SocketAddr {
fn from(value: std::net::SocketAddr) -> Self {
SocketAddr(SocketAddrInner::Ordinary(value))
}
}
impl From<std::net::SocketAddrV4> for SocketAddr {
fn from(value: std::net::SocketAddrV4) -> Self {
SocketAddr(SocketAddrInner::Ordinary(value.into()))
}
}
impl From<std::net::SocketAddrV6> for SocketAddr {
fn from(value: std::net::SocketAddrV6) -> Self {
SocketAddr(SocketAddrInner::Ordinary(value.into()))
}
}
impl std::str::FromStr for SocketAddr {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
SocketAddr::try_from_generic(s)
}
}
impl<'a> TryFrom<&'a str> for SocketAddr {
type Error = ParseError;
fn try_from(s: &'a str) -> Result<Self, Self::Error> {
SocketAddr::try_from_generic(s)
}
}
impl TryFrom<String> for SocketAddr {
type Error = ParseError;
fn try_from(s: String) -> Result<Self, Self::Error> {
SocketAddr::try_from_generic(s)
}
}
impl<'a> TryFrom<&'a OsStr> for SocketAddr {
type Error = ParseOsStrError;
fn try_from(s: &'a OsStr) -> Result<Self, Self::Error> {
s.to_str().ok_or(ParseOsStrError::InvalidUtf8)?.try_into().map_err(Into::into)
}
}
impl TryFrom<OsString> for SocketAddr {
type Error = ParseOsStrError;
fn try_from(s: OsString) -> Result<Self, Self::Error> {
s.into_string().map_err(|_| ParseOsStrError::InvalidUtf8)?.try_into().map_err(Into::into)
}
}
#[cfg(feature = "serde")]
impl<'a> TryFrom<serde_str_helpers::DeserBorrowStr<'a>> for SocketAddr {
    type Error = ParseError;

    /// Used by the serde `try_from` attribute on `SocketAddr`.
    fn try_from(input: serde_str_helpers::DeserBorrowStr<'a>) -> Result<Self, Self::Error> {
        Self::try_from_generic(input)
    }
}
#[cfg(feature = "parse_arg")]
impl parse_arg::ParseArg for SocketAddr {
    type Error = ParseOsStrError;

    /// Describes the accepted input: whatever `std::net::SocketAddr` accepts, plus
    /// `systemd://` names.
    fn describe_type<W: fmt::Write>(mut writer: W) -> fmt::Result {
        <std::net::SocketAddr as parse_arg::ParseArg>::describe_type(&mut writer)?;
        write!(writer, " or a systemd socket name prefixed with systemd://")
    }

    /// Parses a borrowed command-line argument.
    fn parse_arg(arg: &OsStr) -> Result<Self, Self::Error> {
        Self::try_from(arg)
    }

    /// Parses an owned command-line argument without copying it.
    fn parse_owned_arg(arg: OsString) -> Result<Self, Self::Error> {
        Self::try_from(arg)
    }
}
#[cfg(test)]
mod tests {
    use super::{SocketAddr, SocketAddrInner};

    #[test]
    fn parse_ordinary() {
        assert_eq!("127.0.0.1:42".parse::<SocketAddr>().unwrap().0, SocketAddrInner::Ordinary(([127, 0, 0, 1], 42).into()));
    }

    #[test]
    fn display_ordinary() {
        assert_eq!("127.0.0.1:42".parse::<SocketAddr>().unwrap().to_string(), "127.0.0.1:42");
    }

    #[test]
    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
    fn parse_systemd() {
        assert_eq!("systemd://foo".parse::<SocketAddr>().unwrap().0, SocketAddrInner::Systemd("systemd://foo".to_owned()));
    }

    #[test]
    #[cfg(not(all(target_os = "linux", feature = "enable_systemd")))]
    #[should_panic]
    fn parse_systemd() {
        "systemd://foo".parse::<SocketAddr>().unwrap();
    }

    // 255 characters is the longest name systemd accepts.
    #[test]
    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
    fn parse_systemd_max_len() {
        let name = format!("systemd://{}", "x".repeat(255));
        assert_eq!(name.parse::<SocketAddr>().unwrap().0, SocketAddrInner::Systemd(name.clone()));
    }

    // Both the prefixed and the prefix-less forms display with the prefix.
    #[test]
    #[cfg(all(target_os = "linux", feature = "enable_systemd"))]
    fn display_systemd() {
        assert_eq!("systemd://foo".parse::<SocketAddr>().unwrap().to_string(), "systemd://foo");
        assert_eq!(SocketAddr::from_systemd_name("foo").unwrap().to_string(), "systemd://foo");
    }

    #[test]
    #[should_panic]
    fn parse_systemd_fail_control() {
        "systemd://foo\n".parse::<SocketAddr>().unwrap();
    }

    #[test]
    #[should_panic]
    fn parse_systemd_fail_colon() {
        "systemd://foo:".parse::<SocketAddr>().unwrap();
    }

    #[test]
    #[should_panic]
    fn parse_systemd_fail_non_ascii() {
        "systemd://fooá".parse::<SocketAddr>().unwrap();
    }

    // One character over the 255-character limit.
    #[test]
    #[should_panic]
    fn parse_systemd_fail_too_long() {
        format!("systemd://{}", "x".repeat(256)).parse::<SocketAddr>().unwrap();
    }

    #[test]
    #[cfg_attr(not(all(target_os = "linux", feature = "enable_systemd")), should_panic)]
    fn no_prefix_parse_systemd() {
        SocketAddr::from_systemd_name("foo").unwrap();
    }

    #[test]
    #[should_panic]
    fn no_prefix_parse_systemd_fail_non_ascii() {
        SocketAddr::from_systemd_name("fooá").unwrap();
    }

    #[test]
    #[should_panic]
    fn no_prefix_parse_systemd_fail_colon() {
        SocketAddr::from_systemd_name("foo:").unwrap();
    }

    // One character over the 255-character limit.
    #[test]
    #[should_panic]
    fn no_prefix_parse_systemd_fail_too_long() {
        SocketAddr::from_systemd_name("x".repeat(256)).unwrap();
    }
}