use std::convert::TryInto;
use std::env;
use std::io::BufRead;
use std::os::unix::io::AsRawFd;
use std::os::unix::net::UnixStream;
use std::str::FromStr;
use nix::poll::PollFlags;
use nix::unistd::Uid;
use crate::address::Address;
use crate::guid::Guid;
use crate::raw::{Connection, Socket};
use crate::utils::wait_on;
use crate::{Error, Result};
/// States of the client-side handshake state machine driven by
/// `ClientHandshake::advance_handshake`.
///
/// `Sending*` states flush a queued command to the server; `Wait*` states read
/// the server's reply line.
#[derive(Debug)]
enum ClientHandshakeStep {
    /// Nothing queued yet; the `AUTH EXTERNAL` command is built next.
    Init,
    /// Sending `\0AUTH EXTERNAL <hex-encoded uid>\r\n`.
    SendingOauth,
    /// Waiting for the server's `OK <guid>` reply.
    WaitOauth,
    /// Sending `NEGOTIATE_UNIX_FD`.
    SendingNegociateFd,
    /// Waiting for `AGREE_UNIX_FD` (fd passing enabled) or `ERROR` (disabled).
    WaitNegociateFd,
    /// Sending `BEGIN`.
    SendingBegin,
    /// Handshake complete; `try_finish` succeeds in this state.
    Done,
}
/// Client side of the D-Bus authentication handshake, as a resumable state machine.
///
/// Drive it with `advance_handshake` (repeatedly, if the socket is non-blocking) and
/// convert it into an [`Authenticated`] value with `try_finish` once it is done.
#[derive(Debug)]
pub struct ClientHandshake<S> {
    /// Socket the handshake runs over; handed to the resulting `Connection`.
    socket: S,
    /// Outgoing command still being flushed, or the reply line being assembled.
    buffer: Vec<u8>,
    /// Current state of the handshake.
    step: ClientHandshakeStep,
    /// GUID from the server's `OK` reply; `Some` once authentication succeeded.
    server_guid: Option<Guid>,
    /// Whether the server answered `AGREE_UNIX_FD`.
    cap_unix_fd: bool,
}
/// Result of a completed client or server handshake.
#[derive(Debug)]
pub struct Authenticated<S> {
    /// The handshake socket, wrapped in a `Connection`.
    pub(crate) conn: Connection<S>,
    /// Server GUID: the one received in `OK <guid>` on the client side, or the
    /// server's own GUID on the server side.
    pub(crate) server_guid: Guid,
    /// Whether Unix file-descriptor passing was negotiated.
    pub(crate) cap_unix_fd: bool,
}
impl<S: Socket> ClientHandshake<S> {
pub fn new(socket: S) -> ClientHandshake<S> {
ClientHandshake {
socket,
buffer: Vec::new(),
step: ClientHandshakeStep::Init,
server_guid: None,
cap_unix_fd: false,
}
}
fn flush_buffer(&mut self) -> Result<()> {
while !self.buffer.is_empty() {
let written = self.socket.sendmsg(&self.buffer, &[])?;
self.buffer.drain(..written);
}
Ok(())
}
fn read_command(&mut self) -> Result<()> {
while !self.buffer.ends_with(b"\r\n") {
let mut buf = [0; 40];
let (read, _) = self.socket.recvmsg(&mut buf)?;
self.buffer.extend(&buf[..read]);
}
Ok(())
}
pub fn advance_handshake(&mut self) -> Result<()> {
loop {
match self.step {
ClientHandshakeStep::Init => {
let uid_str = Uid::current()
.to_string()
.chars()
.map(|c| format!("{:x}", c as u32))
.collect::<String>();
self.buffer = format!("\0AUTH EXTERNAL {}\r\n", uid_str).into();
self.step = ClientHandshakeStep::SendingOauth;
}
ClientHandshakeStep::SendingOauth => {
self.flush_buffer()?;
self.step = ClientHandshakeStep::WaitOauth;
}
ClientHandshakeStep::WaitOauth => {
self.read_command()?;
let mut reply = String::new();
(&self.buffer[..]).read_line(&mut reply)?;
let mut words = reply.split_whitespace();
let guid = match (words.next(), words.next(), words.next()) {
(Some("OK"), Some(guid), None) => guid.try_into()?,
_ => {
return Err(Error::Handshake(
"Unexpected server AUTH reply".to_string(),
))
}
};
self.server_guid = Some(guid);
self.buffer = Vec::from(&b"NEGOTIATE_UNIX_FD\r\n"[..]);
self.step = ClientHandshakeStep::SendingNegociateFd;
}
ClientHandshakeStep::SendingNegociateFd => {
self.flush_buffer()?;
self.step = ClientHandshakeStep::WaitNegociateFd;
}
ClientHandshakeStep::WaitNegociateFd => {
self.read_command()?;
if self.buffer.starts_with(b"AGREE_UNIX_FD") {
self.cap_unix_fd = true;
} else if self.buffer.starts_with(b"ERROR") {
self.cap_unix_fd = false;
} else {
return Err(Error::Handshake(
"Unexpected server UNIX_FD reply".to_string(),
));
}
self.buffer = Vec::from(&b"BEGIN\r\n"[..]);
self.step = ClientHandshakeStep::SendingBegin;
}
ClientHandshakeStep::SendingBegin => {
self.flush_buffer()?;
self.step = ClientHandshakeStep::Done;
}
ClientHandshakeStep::Done => return Ok(()),
}
}
}
pub fn try_finish(self) -> std::result::Result<Authenticated<S>, Self> {
if let ClientHandshakeStep::Done = self.step {
Ok(Authenticated {
conn: Connection::wrap(self.socket),
server_guid: self.server_guid.unwrap(),
cap_unix_fd: self.cap_unix_fd,
})
} else {
Err(self)
}
}
pub fn socket(&self) -> &S {
&self.socket
}
}
impl ClientHandshake<UnixStream> {
pub fn new_session() -> Result<Self> {
session_socket().map(Self::new)
}
pub fn new_session_nonblock() -> Result<Self> {
let socket = session_socket()?;
socket.set_nonblocking(true)?;
Ok(Self::new(socket))
}
pub fn new_system() -> Result<Self> {
system_socket().map(Self::new)
}
pub fn new_system_nonblock() -> Result<Self> {
let socket = system_socket()?;
socket.set_nonblocking(true)?;
Ok(Self::new(socket))
}
pub fn new_for_address(address: &str) -> Result<Self> {
Address::from_str(address)?.connect().map(Self::new)
}
pub fn new_for_address_nonblock(address: &str) -> Result<Self> {
let socket = crate::address::Address::from_str(address)?.connect()?;
socket.set_nonblocking(true)?;
Ok(Self::new(socket))
}
pub fn blocking_finish(mut self) -> Result<Authenticated<UnixStream>> {
loop {
match self.advance_handshake() {
Ok(()) => return Ok(self.try_finish().unwrap_or_else(|_| unreachable!())),
Err(Error::Io(e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
let flags = match self.step {
ClientHandshakeStep::SendingOauth
| ClientHandshakeStep::SendingNegociateFd
| ClientHandshakeStep::SendingBegin => PollFlags::POLLOUT,
ClientHandshakeStep::WaitOauth | ClientHandshakeStep::WaitNegociateFd => {
PollFlags::POLLIN
}
ClientHandshakeStep::Init | ClientHandshakeStep::Done => unreachable!(),
};
wait_on(self.socket.as_raw_fd(), flags)?;
}
Err(e) => return Err(e),
}
}
}
}
/// States of the server-side handshake state machine driven by
/// `ServerHandshake::advance_handshake`.
#[derive(Debug)]
enum ServerHandshakeStep {
    /// Waiting for the single NUL byte the client sends before any command.
    WaitingForNull,
    /// Waiting for an `AUTH` command (also re-entered after a rejection).
    WaitingForAuth,
    /// Sending `OK <guid>` after a successful `AUTH EXTERNAL`.
    SendingAuthOK,
    /// Sending a `REJECTED` or `ERROR` reply; returns to `WaitingForAuth` afterwards.
    SendingAuthError,
    /// Authenticated; waiting for `BEGIN` (or `NEGOTIATE_UNIX_FD`, `CANCEL`, `ERROR`).
    WaitingForBegin,
    /// Sending the reply to a command received in `WaitingForBegin`
    /// (`AGREE_UNIX_FD` or `ERROR`), then returning to `WaitingForBegin`.
    SendingBeginMessage,
    /// Handshake complete; `try_finish` succeeds in this state.
    Done,
}
/// Server side of the D-Bus authentication handshake, as a resumable state machine.
///
/// Only `AUTH EXTERNAL` with the expected client UID is accepted.
#[derive(Debug)]
pub struct ServerHandshake<S> {
    /// Socket the handshake runs over; handed to the resulting `Connection`.
    socket: S,
    /// Outgoing reply still being flushed, or the command line being assembled.
    buffer: Vec<u8>,
    /// Current state of the handshake.
    step: ServerHandshakeStep,
    /// GUID sent to the client in the `OK` reply.
    server_guid: Guid,
    /// Set when the client sent `NEGOTIATE_UNIX_FD` and the server agreed.
    cap_unix_fd: bool,
    /// The only UID accepted in the client's `AUTH EXTERNAL` command.
    client_uid: u32,
}
impl<S: Socket> ServerHandshake<S> {
pub fn new(socket: S, guid: Guid, client_uid: u32) -> ServerHandshake<S> {
ServerHandshake {
socket,
buffer: Vec::new(),
step: ServerHandshakeStep::WaitingForNull,
server_guid: guid,
cap_unix_fd: false,
client_uid,
}
}
fn flush_buffer(&mut self) -> Result<()> {
while !self.buffer.is_empty() {
let written = self.socket.sendmsg(&self.buffer, &[])?;
self.buffer.drain(..written);
}
Ok(())
}
fn read_command(&mut self) -> Result<()> {
while !self.buffer.ends_with(b"\r\n") {
let mut buf = [0; 40];
let (read, _) = self.socket.recvmsg(&mut buf)?;
self.buffer.extend(&buf[..read]);
}
Ok(())
}
pub fn advance_handshake(&mut self) -> Result<()> {
loop {
match self.step {
ServerHandshakeStep::WaitingForNull => {
let mut buffer = [0; 1];
let (read, _) = self.socket.recvmsg(&mut buffer)?;
debug_assert!(read == 1);
if buffer[0] != 0 {
return Err(Error::Handshake(
"First client byte is not NUL!".to_string(),
));
}
self.step = ServerHandshakeStep::WaitingForAuth;
}
ServerHandshakeStep::WaitingForAuth => {
self.read_command()?;
let mut reply = String::new();
(&self.buffer[..]).read_line(&mut reply)?;
let mut words = reply.split_whitespace();
match (words.next(), words.next(), words.next(), words.next()) {
(Some("AUTH"), Some("EXTERNAL"), Some(uid), None) => {
let uid = id_from_str(uid)
.map_err(|e| Error::Handshake(format!("Invalid UID: {}", e)))?;
if uid == self.client_uid {
self.buffer = format!("OK {}\r\n", self.server_guid).into();
self.step = ServerHandshakeStep::SendingAuthOK;
} else {
self.buffer = Vec::from(&b"REJECTED EXTERNAL\r\n"[..]);
self.step = ServerHandshakeStep::SendingAuthError;
}
}
(Some("AUTH"), _, _, _) | (Some("ERROR"), _, _, _) => {
self.buffer = Vec::from(&b"REJECTED EXTERNAL\r\n"[..]);
self.step = ServerHandshakeStep::SendingAuthError;
}
(Some("BEGIN"), None, None, None) => {
return Err(Error::Handshake(
"Received BEGIN while not authenticated".to_string(),
));
}
_ => {
self.buffer = Vec::from(&b"ERROR Unsupported command\r\n"[..]);
self.step = ServerHandshakeStep::SendingAuthError;
}
}
}
ServerHandshakeStep::SendingAuthError => {
self.flush_buffer()?;
self.step = ServerHandshakeStep::WaitingForAuth;
}
ServerHandshakeStep::SendingAuthOK => {
self.flush_buffer()?;
self.step = ServerHandshakeStep::WaitingForBegin;
}
ServerHandshakeStep::WaitingForBegin => {
self.read_command()?;
let mut reply = String::new();
(&self.buffer[..]).read_line(&mut reply)?;
let mut words = reply.split_whitespace();
match (words.next(), words.next()) {
(Some("BEGIN"), None) => {
self.step = ServerHandshakeStep::Done;
}
(Some("CANCEL"), None) => {
self.buffer = Vec::from(&b"REJECTED EXTERNAL\r\n"[..]);
self.step = ServerHandshakeStep::SendingAuthError;
}
(Some("ERROR"), _) => {
self.buffer = Vec::from(&b"REJECTED EXTERNAL\r\n"[..]);
self.step = ServerHandshakeStep::SendingAuthError;
}
(Some("NEGOTIATE_UNIX_FD"), None) => {
self.cap_unix_fd = true;
self.buffer = Vec::from(&b"AGREE_UNIX_FD\r\n"[..]);
self.step = ServerHandshakeStep::SendingBeginMessage;
}
_ => {
self.buffer = Vec::from(&b"ERROR Unsupported command\r\n"[..]);
self.step = ServerHandshakeStep::SendingBeginMessage;
}
}
}
ServerHandshakeStep::SendingBeginMessage => {
self.flush_buffer()?;
self.step = ServerHandshakeStep::WaitingForBegin;
}
ServerHandshakeStep::Done => return Ok(()),
}
}
}
pub fn try_finish(self) -> std::result::Result<Authenticated<S>, Self> {
if let ServerHandshakeStep::Done = self.step {
Ok(Authenticated {
conn: Connection::wrap(self.socket),
server_guid: self.server_guid,
cap_unix_fd: self.cap_unix_fd,
})
} else {
Err(self)
}
}
pub fn socket(&self) -> &S {
&self.socket
}
}
impl ServerHandshake<UnixStream> {
pub fn blocking_finish(mut self) -> Result<Authenticated<UnixStream>> {
loop {
match self.advance_handshake() {
Ok(()) => return Ok(self.try_finish().unwrap_or_else(|_| unreachable!())),
Err(Error::Io(e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
let flags = match self.step {
ServerHandshakeStep::SendingAuthError
| ServerHandshakeStep::SendingAuthOK
| ServerHandshakeStep::SendingBeginMessage => PollFlags::POLLOUT,
ServerHandshakeStep::WaitingForNull
| ServerHandshakeStep::WaitingForBegin
| ServerHandshakeStep::WaitingForAuth => PollFlags::POLLIN,
ServerHandshakeStep::Done => unreachable!(),
};
wait_on(self.socket.as_raw_fd(), flags)?;
}
Err(e) => return Err(e),
}
}
}
}
/// Connect to the session bus.
///
/// Uses `DBUS_SESSION_BUS_ADDRESS` when it is set and valid Unicode; otherwise
/// connects to `/run/user/<current uid>/bus`.
fn session_socket() -> Result<UnixStream> {
    if let Ok(address) = env::var("DBUS_SESSION_BUS_ADDRESS") {
        return Address::from_str(&address)?.connect();
    }
    let fallback_path = format!("/run/user/{}/bus", Uid::current());
    Ok(UnixStream::connect(fallback_path)?)
}
/// Connect to the system bus.
///
/// Uses `DBUS_SYSTEM_BUS_ADDRESS` when it is set and valid Unicode; otherwise
/// connects to `/var/run/dbus/system_bus_socket`.
fn system_socket() -> Result<UnixStream> {
    if let Ok(address) = env::var("DBUS_SYSTEM_BUS_ADDRESS") {
        Address::from_str(&address)?.connect()
    } else {
        let stream = UnixStream::connect("/var/run/dbus/system_bus_socket")?;
        Ok(stream)
    }
}
fn id_from_str(s: &str) -> std::result::Result<u32, Box<dyn std::error::Error>> {
let mut id = String::new();
for s in s.as_bytes().chunks(2) {
let c = char::from(u8::from_str_radix(std::str::from_utf8(s)?, 16)?);
id.push(c);
}
Ok(id.parse::<u32>()?)
}
#[cfg(test)]
mod tests {
    use std::os::unix::net::UnixStream;

    use super::*;
    use crate::Guid;

    /// Classify one `advance_handshake` result from the non-blocking test.
    ///
    /// `true` means the handshake finished and `false` means it is waiting on the
    /// socket. Any other error fails the test.
    fn finished(outcome: std::result::Result<(), Error>) -> bool {
        match outcome {
            Ok(()) => true,
            Err(Error::Io(e)) => {
                assert!(e.kind() == std::io::ErrorKind::WouldBlock);
                false
            }
            Err(e) => panic!("Unexpected error: {:?}", e),
        }
    }

    /// Interleave client and server steps over a non-blocking socket pair until both
    /// sides complete, then check that they agree on the outcome.
    #[test]
    fn async_handshake() {
        let (client_sock, server_sock) = UnixStream::pair().unwrap();
        client_sock.set_nonblocking(true).unwrap();
        server_sock.set_nonblocking(true).unwrap();

        let mut client = ClientHandshake::new(client_sock);
        let mut server =
            ServerHandshake::new(server_sock, Guid::generate(), Uid::current().into());

        let (mut client_done, mut server_done) = (false, false);
        while !client_done || !server_done {
            if finished(client.advance_handshake()) {
                client_done = true;
            }
            if finished(server.advance_handshake()) {
                server_done = true;
            }
        }

        let client = client.try_finish().unwrap();
        let server = server.try_finish().unwrap();
        assert_eq!(client.server_guid, server.server_guid);
        assert_eq!(client.cap_unix_fd, server.cap_unix_fd);
    }
}