use std::convert::TryInto;
use std::io::IoSlice;
use std::sync::{Condvar, Mutex, MutexGuard, TryLockError};
use crate::connection::{
compute_length_field, Connection, DiscardMode, ReplyOrError, RequestConnection, RequestKind,
SequenceNumber,
};
use crate::cookie::{Cookie, CookieWithFds, VoidCookie};
pub use crate::errors::{ConnectError, ConnectionError, ParseError, ReplyError, ReplyOrIdError};
use crate::extension_manager::ExtensionManager;
use crate::protocol::bigreq::{ConnectionExt as _, EnableReply};
use crate::protocol::xproto::{Setup, SetupRequest, GET_INPUT_FOCUS_REQUEST};
use crate::utils::RawFdContainer;
use crate::x11_utils::{ExtensionInformation, Serialize, TryParse, TryParseFd};
mod id_allocator;
mod inner;
mod packet_reader;
mod parse_display;
mod stream;
mod write_buffer;
mod xauth;
use inner::PollReply;
use packet_reader::PacketReader;
pub use stream::{DefaultStream, PollMode, Stream};
use write_buffer::WriteBuffer;
/// Buffer type used by [`RustConnection`] for replies, errors and events (its
/// `RequestConnection::Buf`, i.e. `Vec<u8>`).
type Buffer = <RustConnection as RequestConnection>::Buf;
/// An event together with its sequence number, using this connection's buffer type.
pub type RawEventAndSeqNumber = crate::connection::RawEventAndSeqNumber<Buffer>;
/// A reply buffer together with the file descriptors that came with it.
pub type BufWithFds = crate::connection::BufWithFds<Buffer>;
/// Lazily determined maximum request length of the connection.
///
/// Moves from `Unknown` to `Requested` when the BIG-REQUESTS `Enable` request is
/// prefetched, and from `Requested` to `Known` once the answer has been waited for.
#[derive(Debug)]
enum MaxRequestBytes {
    /// Nothing has been requested yet.
    Unknown,
    /// The `Enable` request was sent; holds its sequence number, or `None` if sending it
    /// failed.
    Requested(Option<SequenceNumber>),
    /// The final maximum request length, in bytes.
    Known(usize),
}
/// Guard for the `inner` mutex; the internal helpers take and hand back this guard so the
/// caller controls exactly when the lock is held.
type MutexGuardInner<'a> = MutexGuard<'a, inner::ConnectionInner>;
/// What kind of response a request expects; passed to `inner.send_request` when a request
/// is sent.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) enum ReplyFDKind {
    /// The request has no reply (used by `send_request_without_reply`).
    NoReply,
    /// The request has a reply without file descriptors (used by `send_request_with_reply`
    /// and by the sync request in `send_sync`).
    ReplyWithoutFDs,
    /// The request has a reply that may carry file descriptors (used by
    /// `send_request_with_reply_with_fds`).
    ReplyWithFDs,
}
/// Whether `read_packet_and_enqueue` may block.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) enum BlockingMode {
    /// Wait for another reader to finish, or poll the stream for readability (releasing the
    /// `inner` lock while polling) before reading.
    Blocking,
    /// Never wait and never release the `inner` lock: return immediately if another thread
    /// is already reading, otherwise read without polling first.
    NonBlocking,
}
/// A connection to an X11 server implemented in pure Rust, generic over the transport `S`.
#[derive(Debug)]
pub struct RustConnection<S: Stream = DefaultStream> {
    // Mutable connection state (write buffer, sequence numbers, queued packets and FDs).
    inner: Mutex<inner::ConnectionInner>,
    // The underlying transport to the X11 server.
    stream: S,
    // Held by whichever thread is currently reading from `stream`; see
    // `read_packet_and_enqueue`.
    packet_reader: Mutex<PacketReader>,
    // Notified when the thread holding `packet_reader` is done reading.
    reader_condition: Condvar,
    // Hands out X11 resource IDs, built from the setup's resource id base/mask.
    id_allocator: Mutex<id_allocator::IDAllocator>,
    // The server's response to the connection setup.
    setup: Setup,
    // Cache of queried extension information.
    extension_manager: Mutex<ExtensionManager>,
    // Lazily determined maximum request length (see `MaxRequestBytes`).
    maximum_request_bytes: Mutex<MaxRequestBytes>,
}
impl RustConnection<DefaultStream> {
    /// Connect to the X11 server named by `dpy_name` (or the default display if `None`).
    ///
    /// The display string is parsed, a [`DefaultStream`] is opened to the server, and
    /// authorization data is looked up for the peer address. Any failure or missing entry
    /// during the auth lookup is ignored: the connection is then attempted without auth.
    ///
    /// Returns the connection together with the screen number from the display string.
    ///
    /// # Errors
    ///
    /// [`ConnectError::DisplayParsingError`] if the display string cannot be parsed, plus
    /// any error from opening the stream or performing the connection setup.
    pub fn connect(dpy_name: Option<&str>) -> Result<(Self, usize), ConnectError> {
        let parsed = match parse_display::parse_display(dpy_name) {
            Some(parsed) => parsed,
            None => return Err(ConnectError::DisplayParsingError),
        };
        let stream = DefaultStream::connect(
            &*parsed.host,
            parsed.protocol.as_deref(),
            parsed.display,
        )?;
        let screen: usize = parsed.screen.into();

        let (family, address) = stream.peer_addr()?;
        let (auth_name, auth_data) = match xauth::get_auth(family, &address, parsed.display) {
            Ok(Some(credentials)) => credentials,
            // Lookup errors and "no matching entry" both mean: try without authorization.
            Ok(None) | Err(_) => (Vec::new(), Vec::new()),
        };

        let connection =
            Self::connect_to_stream_with_auth_info(stream, screen, auth_name, auth_data)?;
        Ok((connection, screen))
    }
}
impl<S: Stream> RustConnection<S> {
    /// Establish a new connection over an already-connected `stream` without sending any
    /// authorization information.
    ///
    /// This is [`Self::connect_to_stream_with_auth_info`] with an empty authorization
    /// protocol name and empty authorization data.
    pub fn connect_to_stream(stream: S, screen: usize) -> Result<Self, ConnectError> {
        Self::connect_to_stream_with_auth_info(stream, screen, Vec::new(), Vec::new())
    }

    /// Establish a new connection over `stream`, authorizing with `auth_name`/`auth_data`.
    ///
    /// Writes the connection setup request, reads the server's setup response, checks that
    /// `screen` exists and then builds the connection via [`Self::for_connected_stream`].
    ///
    /// # Errors
    ///
    /// Propagates failures from writing or reading the setup (including the server refusing
    /// the connection), and returns [`ConnectError::InvalidScreen`] if `screen` is not an
    /// index into the server's `roots`.
    pub fn connect_to_stream_with_auth_info(
        stream: S,
        screen: usize,
        auth_name: Vec<u8>,
        auth_data: Vec<u8>,
    ) -> Result<Self, ConnectError> {
        write_setup(&stream, auth_name, auth_data)?;
        let setup = read_setup(&stream)?;
        if screen >= setup.roots.len() {
            return Err(ConnectError::InvalidScreen);
        }
        Self::for_connected_stream(stream, setup)
    }

    /// Create a connection for a `stream` whose setup exchange has already been completed.
    ///
    /// `setup` is the server's setup response. Nothing is written to or read from `stream`
    /// here.
    pub fn for_connected_stream(stream: S, setup: Setup) -> Result<Self, ConnectError> {
        Self::for_inner(stream, inner::ConnectionInner::new(), setup)
    }

    /// Assemble a connection from its parts.
    ///
    /// # Errors
    ///
    /// Fails if the ID allocator rejects the setup's resource id base/mask.
    fn for_inner(
        stream: S,
        inner: inner::ConnectionInner,
        setup: Setup,
    ) -> Result<Self, ConnectError> {
        let allocator =
            id_allocator::IDAllocator::new(setup.resource_id_base, setup.resource_id_mask)?;
        Ok(RustConnection {
            inner: Mutex::new(inner),
            stream,
            packet_reader: Mutex::new(PacketReader::new()),
            reader_condition: Condvar::new(),
            id_allocator: Mutex::new(allocator),
            setup,
            extension_manager: Default::default(),
            maximum_request_bytes: Mutex::new(MaxRequestBytes::Unknown),
        })
    }

    /// Send the request made of `bufs` (plus `fds`) and return its sequence number.
    ///
    /// The length field is computed via `compute_length_field`. `inner` is locked before a
    /// sequence number is allocated and stays locked until the request has been handed to
    /// the write buffer. This keeps sequence numbers and request bytes in the same order
    /// across threads.
    fn send_request(
        &self,
        bufs: &[IoSlice<'_>],
        fds: Vec<RawFdContainer>,
        kind: ReplyFDKind,
    ) -> Result<SequenceNumber, ConnectionError> {
        let mut storage = Default::default();
        let bufs = compute_length_field(self, bufs, &mut storage)?;
        let mut inner = self.inner.lock().unwrap();
        loop {
            match inner.send_request(kind) {
                Some(seqno) => {
                    let _inner = self.write_all_vectored(inner, bufs, fds)?;
                    return Ok(seqno);
                }
                None => {
                    // `inner` refused to allocate a sequence number; send a sync request
                    // and retry. NOTE(review): presumably `None` means too many requests
                    // are outstanding without a reply — confirm in `ConnectionInner`.
                    inner = self.send_sync(inner)?;
                }
            }
        }
    }

    /// Send a `GetInputFocus` request whose reply and error are discarded.
    ///
    /// The request is a bare 4-byte header: opcode, a zero byte, and a length of one 4-byte
    /// unit. It is registered with `inner` as a request with a reply. `send_request` and
    /// `check_for_raw_error` use it to "sync" with the server.
    fn send_sync<'a>(
        &'a self,
        mut inner: MutexGuardInner<'a>,
    ) -> Result<MutexGuardInner<'a>, std::io::Error> {
        let length = 1u16.to_ne_bytes();
        let request = [
            GET_INPUT_FOCUS_REQUEST,
            0,
            length[0],
            length[1],
        ];
        let seqno = inner
            .send_request(ReplyFDKind::ReplyWithoutFDs)
            .expect("Sending a HasResponse request should not be blocked by syncs");
        // Nobody will ever wait for this reply, so drop it (or an error) when it arrives.
        inner.discard_reply(seqno, DiscardMode::DiscardReplyAndError);
        let inner = self.write_all_vectored(inner, &[IoSlice::new(&request)], Vec::new())?;
        Ok(inner)
    }

    /// Write all of `bufs` (and hand `fds` to the write buffer), keeping `inner` locked.
    ///
    /// Data goes through `inner.write_buffer`. Progress through a partial write is tracked
    /// in two ways: `partial_buf` holds the unwritten tail of a slice that was only partly
    /// written, and `bufs` is advanced past fully written slices, skipping empty ones.
    ///
    /// When a write would block, pending input is read with
    /// `read_packet_and_enqueue(.., BlockingMode::NonBlocking)`. That call never releases
    /// `inner`, so the lock is held for the whole write. NOTE(review): presumably input is
    /// read here because the server may stop accepting data while its own output is not
    /// consumed — confirm.
    ///
    /// # Errors
    ///
    /// `ErrorKind::WriteZero` if a write reports zero bytes written; any other I/O error
    /// except `WouldBlock` is propagated.
    fn write_all_vectored<'a>(
        &'a self,
        mut inner: MutexGuardInner<'a>,
        mut bufs: &[IoSlice<'_>],
        mut fds: Vec<RawFdContainer>,
    ) -> std::io::Result<MutexGuardInner<'a>> {
        // Tail of a slice that was only partially written by the previous iteration.
        let mut partial_buf: &[u8] = &[];
        while !partial_buf.is_empty() || !bufs.is_empty() || !fds.is_empty() {
            self.stream.poll(PollMode::ReadAndWritable)?;
            // Finish a partially written slice before moving on to the remaining slices.
            let write_result = if !partial_buf.is_empty() {
                inner
                    .write_buffer
                    .write(&self.stream, partial_buf, &mut fds)
            } else {
                inner
                    .write_buffer
                    .write_vectored(&self.stream, bufs, &mut fds)
            };
            match write_result {
                Ok(0) => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::WriteZero,
                        "failed to write anything",
                    ));
                }
                Ok(mut count) => {
                    // First account for the bytes taken from `partial_buf`...
                    if count >= partial_buf.len() {
                        count -= partial_buf.len();
                        partial_buf = &[];
                    } else {
                        partial_buf = &partial_buf[count..];
                        count = 0;
                    }
                    // ...then advance through `bufs`. A slice that was only partly written
                    // becomes the new `partial_buf`.
                    while count > 0 {
                        if count >= bufs[0].len() {
                            count -= bufs[0].len();
                        } else {
                            partial_buf = &bufs[0][count..];
                            count = 0;
                        }
                        bufs = &bufs[1..];
                        // Skip empty slices.
                        while bufs.first().map(|s| s.len()) == Some(0) {
                            bufs = &bufs[1..];
                        }
                    }
                }
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    // Non-blocking read: does not release `inner` (see
                    // `read_packet_and_enqueue`).
                    inner = self.read_packet_and_enqueue(inner, BlockingMode::NonBlocking)?;
                }
                Err(e) => return Err(e),
            }
        }
        Ok(inner)
    }

    /// Flush `inner.write_buffer` to the stream, holding `inner` throughout.
    ///
    /// A `WouldBlock` from the flush is answered with a non-blocking
    /// `read_packet_and_enqueue`, which does not release `inner`. The loop stops after the
    /// first successful `flush` or when nothing needs flushing.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error other than `WouldBlock`.
    fn flush_impl<'a>(
        &'a self,
        mut inner: MutexGuardInner<'a>,
    ) -> std::io::Result<MutexGuardInner<'a>> {
        while inner.write_buffer.needs_flush() {
            self.stream.poll(PollMode::ReadAndWritable)?;
            match inner.write_buffer.flush(&self.stream) {
                Ok(()) => break,
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    inner = self.read_packet_and_enqueue(inner, BlockingMode::NonBlocking)?;
                }
                Err(e) => return Err(e),
            }
        }
        Ok(inner)
    }

    /// Read the packets that are available from the stream and enqueue them in `inner`.
    ///
    /// Only one thread reads from the stream at a time. This is arbitrated with a `try_lock`
    /// on `packet_reader`:
    ///
    /// * If another thread holds `packet_reader`, `NonBlocking` mode returns immediately
    ///   with `inner` still locked. `Blocking` mode instead waits on `reader_condition`;
    ///   `Condvar::wait` releases `inner` while waiting and re-acquires it before returning.
    /// * Otherwise this thread becomes the reader. In `Blocking` mode it releases `inner`
    ///   while polling the stream for readability and re-locks it afterwards.
    ///   `NonBlocking` mode never releases `inner`.
    ///
    /// On `Ok`, `inner` is locked again. A successful return does not by itself prove that
    /// the packet the caller wants has arrived, so callers in this file re-check their
    /// condition in a loop.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from polling or reading the stream.
    ///
    /// # Panics
    ///
    /// Panics if `packet_reader` or `inner` is poisoned.
    fn read_packet_and_enqueue<'a>(
        &'a self,
        mut inner: MutexGuardInner<'a>,
        mode: BlockingMode,
    ) -> Result<MutexGuardInner<'a>, std::io::Error> {
        match self.packet_reader.try_lock() {
            Err(TryLockError::WouldBlock) => {
                // Another thread is currently reading.
                match mode {
                    BlockingMode::NonBlocking => return Ok(inner),
                    BlockingMode::Blocking => {}
                }
                // Sleep (with `inner` released) until the reader notifies on drop of its
                // `NotifyOnDrop`; `wait` hands back `inner` locked.
                Ok(self.reader_condition.wait(inner).unwrap())
            }
            Err(TryLockError::Poisoned(e)) => panic!("{}", e),
            Ok(mut packet_reader) => {
                // Wakes threads waiting in the branch above when dropped — on every exit
                // path, including early `?` returns.
                let notify_on_drop = NotifyOnDrop(&self.reader_condition);
                if mode == BlockingMode::Blocking {
                    // Release `inner` while blocked in `poll` so other threads can keep
                    // using the connection, then re-lock it.
                    drop(inner);
                    self.stream.poll(PollMode::Readable)?;
                    inner = self.inner.lock().unwrap();
                }
                let mut fds = Vec::new();
                let mut packets = Vec::new();
                packet_reader.try_read_packets(&self.stream, &mut packets, &mut fds)?;
                // Release the reader lock while `inner` is still held. Any other thread
                // must hold `inner` to get into this function, so none can see the reader
                // as free before the packets read above are enqueued below.
                drop(packet_reader);
                inner.enqueue_fds(fds);
                packets
                    .into_iter()
                    .for_each(|packet| inner.enqueue_packet(packet));
                // Explicit for clarity; waiters are notified only after enqueueing.
                drop(notify_on_drop);
                Ok(inner)
            }
        }
    }

    /// Send a BIG-REQUESTS `Enable` request if the maximum request length is still
    /// `Unknown`.
    ///
    /// Stores the request's sequence number (or `None` if sending failed) as
    /// `MaxRequestBytes::Requested`. The reply is only waited for later, in
    /// `maximum_request_bytes`. Does nothing in the `Requested` and `Known` states.
    fn prefetch_maximum_request_bytes_impl(&self, max_bytes: &mut MutexGuard<'_, MaxRequestBytes>) {
        if let MaxRequestBytes::Unknown = **max_bytes {
            let request = self
                .bigreq_enable()
                .map(|cookie| cookie.into_sequence_number())
                .ok();
            **max_bytes = MaxRequestBytes::Requested(request);
        }
    }

    /// Access the underlying stream.
    pub fn stream(&self) -> &S {
        &self.stream
    }
}
impl<S: Stream> RequestConnection for RustConnection<S> {
    /// Replies, errors and events are handed out as owned byte vectors.
    type Buf = Vec<u8>;

    /// Send a request that has a reply (without file descriptors) and return its cookie.
    fn send_request_with_reply<Reply>(
        &self,
        bufs: &[IoSlice<'_>],
        fds: Vec<RawFdContainer>,
    ) -> Result<Cookie<'_, Self, Reply>, ConnectionError>
    where
        Reply: TryParse,
    {
        Ok(Cookie::new(
            self,
            self.send_request(bufs, fds, ReplyFDKind::ReplyWithoutFDs)?,
        ))
    }

    /// Send a request whose reply may carry file descriptors and return its cookie.
    fn send_request_with_reply_with_fds<Reply>(
        &self,
        bufs: &[IoSlice<'_>],
        fds: Vec<RawFdContainer>,
    ) -> Result<CookieWithFds<'_, Self, Reply>, ConnectionError>
    where
        Reply: TryParseFd,
    {
        Ok(CookieWithFds::new(
            self,
            self.send_request(bufs, fds, ReplyFDKind::ReplyWithFDs)?,
        ))
    }

    /// Send a request without a reply and return its cookie.
    fn send_request_without_reply(
        &self,
        bufs: &[IoSlice<'_>],
        fds: Vec<RawFdContainer>,
    ) -> Result<VoidCookie<'_, Self>, ConnectionError> {
        Ok(VoidCookie::new(
            self,
            self.send_request(bufs, fds, ReplyFDKind::NoReply)?,
        ))
    }

    /// Ask `inner` to discard the reply and/or error of `sequence` according to `mode`.
    ///
    /// The request kind is not needed by this implementation.
    fn discard_reply(&self, sequence: SequenceNumber, _kind: RequestKind, mode: DiscardMode) {
        self.inner.lock().unwrap().discard_reply(sequence, mode);
    }

    /// Start querying information about `extension_name`, via the extension manager.
    fn prefetch_extension_information(
        &self,
        extension_name: &'static str,
    ) -> Result<(), ConnectionError> {
        self.extension_manager
            .lock()
            .unwrap()
            .prefetch_extension_information(self, extension_name)
    }

    /// Get information about `extension_name` (`None` if the server does not support it),
    /// via the extension manager.
    fn extension_information(
        &self,
        extension_name: &'static str,
    ) -> Result<Option<ExtensionInformation>, ConnectionError> {
        self.extension_manager
            .lock()
            .unwrap()
            .extension_information(self, extension_name)
    }

    /// Wait for the reply or error of `sequence`, dropping any file descriptors that came
    /// with a reply.
    fn wait_for_reply_or_raw_error(
        &self,
        sequence: SequenceNumber,
    ) -> Result<ReplyOrError<Vec<u8>>, ConnectionError> {
        match self.wait_for_reply_with_fds_raw(sequence)? {
            ReplyOrError::Reply((reply, _fds)) => Ok(ReplyOrError::Reply(reply)),
            ReplyOrError::Error(e) => Ok(ReplyOrError::Error(e)),
        }
    }

    /// Flush pending output, then read packets until `inner` reports a reply for
    /// `sequence` (`Some`) or that none will come (`None`).
    fn wait_for_reply(&self, sequence: SequenceNumber) -> Result<Option<Vec<u8>>, ConnectionError> {
        let mut inner = self.inner.lock().unwrap();
        inner = self.flush_impl(inner)?;
        loop {
            match inner.poll_for_reply(sequence) {
                PollReply::TryAgain => {}
                PollReply::NoReply => return Ok(None),
                PollReply::Reply(buffer) => return Ok(Some(buffer)),
            }
            inner = self.read_packet_and_enqueue(inner, BlockingMode::Blocking)?;
        }
    }

    /// Check whether request `sequence` produced an error, returning the raw error if so.
    ///
    /// If `inner` reports that more information is needed
    /// (`prepare_check_for_reply_or_error` returns `true`), a sync request is sent first.
    /// NOTE(review): presumably the server's answer to that sync settles the question.
    /// Afterwards, flushes and reads packets until `inner` has an answer.
    fn check_for_raw_error(
        &self,
        sequence: SequenceNumber,
    ) -> Result<Option<Buffer>, ConnectionError> {
        let mut inner = self.inner.lock().unwrap();
        if inner.prepare_check_for_reply_or_error(sequence) {
            inner = self.send_sync(inner)?;
            assert!(!inner.prepare_check_for_reply_or_error(sequence));
        }
        inner = self.flush_impl(inner)?;
        loop {
            match inner.poll_check_for_reply_or_error(sequence) {
                PollReply::TryAgain => {}
                PollReply::NoReply => return Ok(None),
                PollReply::Reply(buffer) => return Ok(Some(buffer)),
            }
            inner = self.read_packet_and_enqueue(inner, BlockingMode::Blocking)?;
        }
    }

    /// Flush pending output, then read packets until a reply or error for `sequence` is
    /// queued.
    ///
    /// A packet whose first byte is 0 is returned as an error (without FDs); anything else
    /// is returned as a reply together with its file descriptors.
    fn wait_for_reply_with_fds_raw(
        &self,
        sequence: SequenceNumber,
    ) -> Result<ReplyOrError<BufWithFds, Buffer>, ConnectionError> {
        let mut inner = self.inner.lock().unwrap();
        inner = self.flush_impl(inner)?;
        loop {
            if let Some(reply) = inner.poll_for_reply_or_error(sequence) {
                if reply.0[0] == 0 {
                    return Ok(ReplyOrError::Error(reply.0));
                } else {
                    return Ok(ReplyOrError::Reply(reply));
                }
            }
            inner = self.read_packet_and_enqueue(inner, BlockingMode::Blocking)?;
        }
    }

    /// Return the maximum size of a single request, in bytes.
    ///
    /// The first call waits for the reply to the BIG-REQUESTS `Enable` request, sending that
    /// request first if needed. If the request could not be sent or its reply failed, the
    /// `maximum_request_length` from the server's setup is used instead. The result is cached
    /// in `maximum_request_bytes`.
    ///
    /// The server reports the limit in 4-byte units. Converting to bytes saturates at
    /// `usize::MAX` instead of overflowing. Without saturation, on targets where `usize` is
    /// at most 32 bits wide, a large limit (or the `usize::max_value()` fallback of the
    /// `u32` -> `usize` conversion) times 4 would panic in debug builds. In release builds it
    /// would wrap around to a bogus, far too small limit.
    fn maximum_request_bytes(&self) -> usize {
        let mut max_bytes = self.maximum_request_bytes.lock().unwrap();
        self.prefetch_maximum_request_bytes_impl(&mut max_bytes);
        use MaxRequestBytes::*;
        match *max_bytes {
            Unknown => unreachable!("We just prefetched this"),
            Requested(seqno) => {
                let length_in_units: usize = seqno
                    .and_then(|seqno| {
                        Cookie::<_, EnableReply>::new(self, seqno)
                            .reply()
                            .map(|reply| reply.maximum_request_length)
                            .ok()
                    })
                    .unwrap_or_else(|| self.setup.maximum_request_length.into())
                    .try_into()
                    .unwrap_or(usize::max_value());
                let length = length_in_units.saturating_mul(4);
                *max_bytes = Known(length);
                length
            }
            Known(length) => length,
        }
    }

    /// Send the BIG-REQUESTS `Enable` request early without waiting for its reply.
    fn prefetch_maximum_request_bytes(&self) {
        let mut max_bytes = self.maximum_request_bytes.lock().unwrap();
        self.prefetch_maximum_request_bytes_impl(&mut max_bytes);
    }

    /// Parse a raw X11 error, using the known extensions to decode extension errors.
    fn parse_error(&self, error: &[u8]) -> Result<crate::x11_utils::X11Error, ParseError> {
        let ext_mgr = self.extension_manager.lock().unwrap();
        crate::x11_utils::X11Error::try_parse(error, &*ext_mgr)
    }

    /// Parse a raw X11 event, using the known extensions to decode extension events.
    fn parse_event(&self, event: &[u8]) -> Result<crate::protocol::Event, ParseError> {
        let ext_mgr = self.extension_manager.lock().unwrap();
        crate::protocol::Event::parse(event, &*ext_mgr)
    }
}
impl<S: Stream> Connection for RustConnection<S> {
    /// Block until an event is queued and return it with its sequence number.
    fn wait_for_raw_event_with_sequence(&self) -> Result<RawEventAndSeqNumber, ConnectionError> {
        let mut guard = self.inner.lock().unwrap();
        let event = loop {
            if let Some(found) = guard.poll_for_event_with_sequence() {
                break found;
            }
            // Nothing queued yet: wait for more data from the server and enqueue it.
            guard = self.read_packet_and_enqueue(guard, BlockingMode::Blocking)?;
        };
        Ok(event)
    }

    /// Return a queued event if there is one, otherwise read whatever is available without
    /// blocking and look once more.
    fn poll_for_raw_event_with_sequence(
        &self,
    ) -> Result<Option<RawEventAndSeqNumber>, ConnectionError> {
        let mut guard = self.inner.lock().unwrap();
        if let Some(queued) = guard.poll_for_event_with_sequence() {
            return Ok(Some(queued));
        }
        guard = self.read_packet_and_enqueue(guard, BlockingMode::NonBlocking)?;
        Ok(guard.poll_for_event_with_sequence())
    }

    /// Write all buffered output to the server.
    fn flush(&self) -> Result<(), ConnectionError> {
        let guard = self.inner.lock().unwrap();
        self.flush_impl(guard)?;
        Ok(())
    }

    /// The setup information the server sent when the connection was established.
    fn setup(&self) -> &Setup {
        &self.setup
    }

    /// Allocate a fresh X11 resource ID.
    fn generate_id(&self) -> Result<u32, ReplyOrIdError> {
        let mut allocator = self.id_allocator.lock().unwrap();
        allocator.generate_id(self)
    }
}
/// Byte-order marker for the connection setup request: `b'l'` (0x6c) on little-endian
/// hosts and `b'B'` (0x42) on big-endian hosts, i.e. the host's native order.
fn byte_order() -> u8 {
    if cfg!(target_endian = "little") {
        b'l'
    } else {
        b'B'
    }
}
/// Serialize and send the X11 connection setup request (protocol 11.0, native byte order)
/// carrying the given authorization protocol name and data.
///
/// Keeps writing until the whole request is sent. A `WouldBlock` from the stream just
/// retries after polling again.
///
/// # Errors
///
/// `ErrorKind::WriteZero` if the stream accepts no bytes; other I/O errors are propagated.
fn write_setup(
    write: &impl Stream,
    auth_name: Vec<u8>,
    auth_data: Vec<u8>,
) -> Result<(), std::io::Error> {
    let data = SetupRequest {
        byte_order: byte_order(),
        protocol_major_version: 11,
        protocol_minor_version: 0,
        authorization_protocol_name: auth_name,
        authorization_protocol_data: auth_data,
    }
    .serialize();

    let mut remaining: &[u8] = &data;
    while !remaining.is_empty() {
        write.poll(PollMode::Writable)?;
        match write.write(remaining, &mut Vec::new()) {
            Ok(0) => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }
            Ok(written) => remaining = &remaining[written..],
            // Spurious readiness from `poll`; just try again.
            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
            Err(e) => return Err(e),
        }
    }
    Ok(())
}
/// Read and decode the server's response to the connection setup request.
///
/// The response begins with an 8-byte header whose bytes 6..8 give the length of the
/// rest in 4-byte units, in native byte order. The first byte selects the outcome:
/// 0 = failed, 1 = success, 2 = authenticate.
///
/// # Errors
///
/// `SetupFailed`/`SetupAuthenticate` for the corresponding server answers, an I/O error if
/// reading fails or file descriptors are received, and a parse error for malformed data or
/// an unknown status byte.
fn read_setup(stream: &impl Stream) -> Result<Setup, ConnectError> {
    let mut received_fds = Vec::new();
    let mut response = vec![0; 8];
    stream.read_exact(&mut response, &mut received_fds)?;

    let length_in_units = u16::from_ne_bytes([response[6], response[7]]);
    let remaining_bytes = usize::from(length_in_units) * 4;
    // This is the final size of the buffer, so reserve exactly that much.
    response.reserve_exact(remaining_bytes);
    response.resize(8 + remaining_bytes, 0);
    stream.read_exact(&mut response[8..], &mut received_fds)?;

    if !received_fds.is_empty() {
        let unexpected = std::io::Error::new(
            std::io::ErrorKind::Other,
            "unexpectedly received FDs in connection setup",
        );
        return Err(unexpected.into());
    }

    let bytes = &response[..];
    match bytes[0] {
        0 => Err(ConnectError::SetupFailed(TryParse::try_parse(bytes)?.0)),
        1 => Ok(Setup::try_parse(bytes)?.0),
        2 => Err(ConnectError::SetupAuthenticate(TryParse::try_parse(bytes)?.0)),
        _ => Err(ParseError::InvalidValue.into()),
    }
}
/// Guard that calls `notify_all` on the wrapped [`Condvar`] when it goes out of scope.
///
/// `read_packet_and_enqueue` relies on this so that waiting threads are woken on every
/// exit path, including early returns through `?`.
#[derive(Debug)]
struct NotifyOnDrop<'a>(&'a Condvar);

impl Drop for NotifyOnDrop<'_> {
    fn drop(&mut self) {
        let NotifyOnDrop(condvar) = self;
        condvar.notify_all();
    }
}
#[cfg(test)]
mod test {
    use std::cell::RefCell;
    use std::io::{Read, Result, Write};

    use super::{read_setup, PollMode, Stream};
    use crate::errors::ConnectError;
    use crate::protocol::xproto::{ImageOrder, Setup, SetupAuthenticate, SetupFailed};
    use crate::utils::RawFdContainer;
    use crate::x11_utils::Serialize;

    /// In-memory `Stream`: reads are served from `input`, writes go into `output`.
    struct SliceStream<'a, 'b> {
        input: RefCell<&'a [u8]>,
        output: RefCell<&'b mut [u8]>,
    }

    impl<'a, 'b> Stream for SliceStream<'a, 'b> {
        fn poll(&self, _mode: PollMode) -> Result<()> {
            // In-memory data is always ready.
            Ok(())
        }

        fn read(&self, buf: &mut [u8], _fd_storage: &mut Vec<RawFdContainer>) -> Result<usize> {
            let mut source = self.input.borrow_mut();
            source.read(buf)
        }

        fn write(&self, buf: &[u8], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
            assert!(fds.is_empty());
            let mut sink = self.output.borrow_mut();
            sink.write(buf)
        }
    }

    /// A stream that yields `bytes` and has no room for writes.
    fn stream_over(bytes: &[u8]) -> SliceStream<'_, 'static> {
        SliceStream {
            input: RefCell::new(bytes),
            output: RefCell::new(&mut []),
        }
    }

    #[test]
    fn read_setup_success() {
        let mut expected = Setup {
            status: 1,
            protocol_major_version: 11,
            protocol_minor_version: 0,
            length: 0,
            release_number: 0,
            resource_id_base: 0,
            resource_id_mask: 0,
            motion_buffer_size: 0,
            maximum_request_length: 0,
            image_byte_order: ImageOrder::LSB_FIRST,
            bitmap_format_bit_order: ImageOrder::LSB_FIRST,
            bitmap_format_scanline_unit: 0,
            bitmap_format_scanline_pad: 0,
            min_keycode: 0,
            max_keycode: 0,
            vendor: vec![],
            pixmap_formats: vec![],
            roots: vec![],
        };
        // The length field counts the 4-byte units after the 8-byte header.
        expected.length = ((expected.serialize().len() - 8) / 4) as _;
        let bytes = expected.serialize();
        let stream = stream_over(&bytes);
        assert_eq!(expected, read_setup(&stream).unwrap());
    }

    #[test]
    fn read_setup_failed() {
        let mut expected = SetupFailed {
            status: 0,
            protocol_major_version: 11,
            protocol_minor_version: 0,
            length: 0,
            reason: b"whatever".to_vec(),
        };
        expected.length = ((expected.serialize().len() - 8) / 4) as _;
        let bytes = expected.serialize();
        let stream = stream_over(&bytes);
        match read_setup(&stream) {
            Err(ConnectError::SetupFailed(parsed)) => assert_eq!(expected, parsed),
            other => panic!("Unexpected value {:?}", other),
        }
    }

    #[test]
    fn read_setup_authenticate() {
        let expected = SetupAuthenticate {
            status: 2,
            reason: b"12345678".to_vec(),
        };
        let bytes = expected.serialize();
        let stream = stream_over(&bytes);
        match read_setup(&stream) {
            Err(ConnectError::SetupAuthenticate(parsed)) => assert_eq!(expected, parsed),
            other => panic!("Unexpected value {:?}", other),
        }
    }
}