use std::{
convert::{Into, TryFrom, TryInto},
fmt,
io::{Cursor, Write},
};
#[cfg(unix)]
use std::{
os::unix::io::{AsRawFd, RawFd},
sync::{Arc, RwLock},
};
use enumflags2::BitFlags;
use static_assertions::assert_impl_all;
use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, UniqueName};
#[cfg(unix)]
use crate::OwnedFd;
use crate::{
utils::padding_for_8_bytes,
zvariant::{DynamicType, EncodingContext, ObjectPath, Signature, Type},
EndianSig, Error, MessageField, MessageFieldCode, MessageFields, MessageFlags, MessageHeader,
MessagePrimaryHeader, MessageType, QuickMessageFields, Result, MAX_MESSAGE_SIZE,
MIN_MESSAGE_SIZE, NATIVE_ENDIAN_SIG,
};
/// Panic message used with `expect` when the lock guarding a message's file descriptors
/// cannot be acquired.
#[cfg(unix)]
const LOCK_PANIC_MSG: &str = "lock poisoned";
/// Success value of the body-writing callback passed to `MessageBuilder::build_generic`:
/// on Unix, the raw file descriptors that accompany the body.
#[cfg(unix)]
type BuildGenericResult = Vec<RawFd>;
/// Success value of the body-writing callback passed to `MessageBuilder::build_generic`:
/// nothing on non-Unix platforms.
#[cfg(not(unix))]
type BuildGenericResult = ();
// Creates a native-endian D-Bus `EncodingContext`; the single argument is forwarded to
// `EncodingContext::new_dbus` as the number of bytes preceding the encoded data.
macro_rules! dbus_context {
($n_bytes_before: expr) => {
EncodingContext::<byteorder::NativeEndian>::new_dbus($n_bytes_before)
};
}
/// A builder for [`Message`].
///
/// The builder accumulates the message header (primary header and header fields); one of the
/// `build*` methods then consumes it, serializing the header followed by the body.
#[derive(Debug)]
pub struct MessageBuilder<'a> {
// Header being assembled; moved out when the message is built.
header: MessageHeader<'a>,
}
impl<'a> MessageBuilder<'a> {
/// Starts a builder for a message of type `msg_type` with an empty set of header fields.
///
/// The primary header is created with `0` as its second argument (presumably the initial body
/// length, which `build_generic` later overwrites -- confirm against `MessagePrimaryHeader`).
fn new(msg_type: MessageType) -> Self {
    Self {
        header: MessageHeader::new(
            MessagePrimaryHeader::new(msg_type, 0),
            MessageFields::new(),
        ),
    }
}
/// Creates a builder for a method-call message addressed to `path` that invokes `method_name`.
///
/// # Errors
///
/// Returns an error if `path` or `method_name` cannot be converted to an object path or a
/// member name respectively.
pub fn method_call<'p: 'a, 'm: 'a, P, M>(path: P, method_name: M) -> Result<Self>
where
    P: TryInto<ObjectPath<'p>>,
    M: TryInto<MemberName<'m>>,
    P::Error: Into<Error>,
    M::Error: Into<Error>,
{
    let builder = Self::new(MessageType::MethodCall).path(path)?;
    builder.member(method_name)
}
/// Creates a builder for a signal message emitted from `path`, on `interface`, named `name`.
///
/// # Errors
///
/// Returns an error if any of the three arguments fails to convert to its typed name.
pub fn signal<'p: 'a, 'i: 'a, 'm: 'a, P, I, M>(path: P, interface: I, name: M) -> Result<Self>
where
    P: TryInto<ObjectPath<'p>>,
    I: TryInto<InterfaceName<'i>>,
    M: TryInto<MemberName<'m>>,
    P::Error: Into<Error>,
    I::Error: Into<Error>,
    M::Error: Into<Error>,
{
    let builder = Self::new(MessageType::Signal).path(path)?;
    let builder = builder.interface(interface)?;
    builder.member(name)
}
/// Creates a builder for a method-return message replying to the message whose header is
/// `reply_to`.
///
/// # Errors
///
/// Returns an error if the reply information cannot be taken from `reply_to` (for example
/// when it carries no serial number).
pub fn method_return(reply_to: &MessageHeader<'_>) -> Result<Self> {
    let builder = Self::new(MessageType::MethodReturn);
    builder.reply_to(reply_to)
}
pub fn error<'e: 'a, E>(reply_to: &MessageHeader<'_>, name: E) -> Result<Self>
where
E: TryInto<ErrorName<'e>>,
E::Error: Into<Error>,
{
Self::new(MessageType::Error)
.error_name(name)?
.reply_to(reply_to)
}
/// Adds `flag` to the flags already set on the message.
///
/// # Errors
///
/// Returns [`Error::InvalidField`] when `flag` contains `MessageFlags::NoReplyExpected` and
/// the message is not a method call. Also fails if the message type cannot be read from the
/// header.
pub fn with_flags(mut self, flag: MessageFlags) -> Result<Self> {
    let is_method_call = self.header.message_type()? == MessageType::MethodCall;
    if !is_method_call && BitFlags::from_flag(flag).contains(MessageFlags::NoReplyExpected) {
        return Err(Error::InvalidField);
    }

    // Merge with the existing flags rather than overwriting them.
    let combined = self.header.primary().flags() | flag;
    self.header.primary_mut().set_flags(combined);
    Ok(self)
}
/// Sets the `Sender` header field to the given unique name, replacing any previous value.
///
/// # Errors
///
/// Returns an error if `sender` cannot be converted to a [`UniqueName`].
pub fn sender<'s: 'a, S>(mut self, sender: S) -> Result<Self>
where
    S: TryInto<UniqueName<'s>>,
    S::Error: Into<Error>,
{
    let sender: UniqueName<'s> = sender.try_into().map_err(Into::into)?;
    self.header.fields_mut().replace(MessageField::Sender(sender));
    Ok(self)
}
/// Sets the `Path` header field to the given object path, replacing any previous value.
///
/// # Errors
///
/// Returns an error if `path` cannot be converted to an [`ObjectPath`].
pub fn path<'p: 'a, P>(mut self, path: P) -> Result<Self>
where
    P: TryInto<ObjectPath<'p>>,
    P::Error: Into<Error>,
{
    let path: ObjectPath<'p> = path.try_into().map_err(Into::into)?;
    self.header.fields_mut().replace(MessageField::Path(path));
    Ok(self)
}
/// Sets the `Interface` header field to the given interface name, replacing any previous
/// value.
///
/// # Errors
///
/// Returns an error if `interface` cannot be converted to an [`InterfaceName`].
pub fn interface<'i: 'a, I>(mut self, interface: I) -> Result<Self>
where
    I: TryInto<InterfaceName<'i>>,
    I::Error: Into<Error>,
{
    let interface: InterfaceName<'i> = interface.try_into().map_err(Into::into)?;
    self.header
        .fields_mut()
        .replace(MessageField::Interface(interface));
    Ok(self)
}
/// Sets the `Member` header field (the method or signal name), replacing any previous value.
///
/// # Errors
///
/// Returns an error if `member` cannot be converted to a [`MemberName`].
pub fn member<'m: 'a, M>(mut self, member: M) -> Result<Self>
where
    M: TryInto<MemberName<'m>>,
    M::Error: Into<Error>,
{
    let member: MemberName<'m> = member.try_into().map_err(Into::into)?;
    self.header.fields_mut().replace(MessageField::Member(member));
    Ok(self)
}
fn error_name<'e: 'a, E>(mut self, error: E) -> Result<Self>
where
E: TryInto<ErrorName<'e>>,
E::Error: Into<Error>,
{
self.header.fields_mut().replace(MessageField::ErrorName(
error.try_into().map_err(Into::into)?,
));
Ok(self)
}
/// Sets the `Destination` header field to the given bus name, replacing any previous value.
///
/// # Errors
///
/// Returns an error if `destination` cannot be converted to a [`BusName`].
pub fn destination<'d: 'a, D>(mut self, destination: D) -> Result<Self>
where
    D: TryInto<BusName<'d>>,
    D::Error: Into<Error>,
{
    let destination: BusName<'d> = destination.try_into().map_err(Into::into)?;
    self.header
        .fields_mut()
        .replace(MessageField::Destination(destination));
    Ok(self)
}
fn reply_to(mut self, reply_to: &MessageHeader<'_>) -> Result<Self> {
let serial = reply_to.primary().serial_num().ok_or(Error::MissingField)?;
self.header
.fields_mut()
.replace(MessageField::ReplySerial(*serial));
if let Some(sender) = reply_to.sender()? {
self.destination(sender.to_owned())
} else {
Ok(self)
}
}
/// Builds the [`Message`] by serializing `body` after the accumulated header.
///
/// The serialized size of the body (and, on Unix, the number of file descriptors it carries)
/// is measured first, the body signature is obtained from `body.dynamic_signature()`, and the
/// assembly itself is delegated to `build_generic`.
///
/// # Errors
///
/// Returns an error if measuring or serializing the body fails, or if `build_generic`
/// rejects the message (e.g. because it is too large).
pub fn build<B>(self, body: &B) -> Result<Message>
where
B: serde::ser::Serialize + DynamicType,
{
let ctxt = dbus_context!(0);
// First pass: only measure the body so the header can record its length up front.
#[cfg(unix)]
let (body_len, fds_len) = zvariant::serialized_size_fds(ctxt, body)?;
#[cfg(not(unix))]
let body_len = zvariant::serialized_size(ctxt, body)?;
let signature = body.dynamic_signature();
self.build_generic(
signature,
body_len,
// Second pass: actually serialize the body into the message buffer.
move |cursor| {
#[cfg(unix)]
{
let (_, fds) = zvariant::to_writer_fds(cursor, ctxt, body)?;
Ok::<Vec<RawFd>, Error>(fds)
}
#[cfg(not(unix))]
{
zvariant::to_writer(cursor, ctxt, body)?;
Ok::<(), Error>(())
}
},
#[cfg(unix)]
fds_len,
)
}
/// Builds the [`Message`] from an already-encoded body and its signature.
///
/// `body_bytes` is copied verbatim after the header; on Unix, `fds` are the raw file
/// descriptors associated with the body.
///
/// # Safety
///
/// The bytes are written without any check that they match `signature` (or the number of
/// `fds`), so this method can produce an invalid message. The caller must ensure
/// `body_bytes` is a valid encoding for `signature`.
///
/// # Errors
///
/// Returns an error if `signature` cannot be converted to a [`Signature`], if writing the
/// body fails, or if `build_generic` rejects the message.
pub unsafe fn build_raw_body<'b, S>(
self,
body_bytes: &[u8],
signature: S,
#[cfg(unix)] fds: Vec<RawFd>,
) -> Result<Message>
where
S: TryInto<Signature<'b>>,
S::Error: Into<Error>,
{
let signature: Signature<'b> = signature.try_into().map_err(Into::into)?;
// Capture the count before `fds` is moved into the closure below.
#[cfg(unix)]
let fds_len = fds.len();
self.build_generic(
signature,
body_bytes.len(),
move |cursor: &mut Cursor<&mut Vec<u8>>| {
cursor.write_all(body_bytes)?;
#[cfg(unix)]
return Ok::<Vec<RawFd>, Error>(fds);
#[cfg(not(unix))]
return Ok::<(), Error>(());
},
#[cfg(unix)]
fds_len,
)
}
/// Shared implementation behind `build` and `build_raw_body`.
///
/// Finalizes the header (body signature, body length and, on Unix, the file-descriptor count),
/// serializes it, pads up to the body offset, and then lets `write_body` append the body.
///
/// * `signature` - signature of the body; an enclosing struct signature is unwrapped before it
///   is stored, and an empty signature is not stored at all.
/// * `body_len` - size in bytes that `write_body` is going to write.
/// * `write_body` - callback that writes the body into the message buffer and (on Unix)
///   returns the raw file descriptors accompanying it.
/// * `fds_len` - (Unix only) number of file descriptors accompanying the body.
///
/// # Errors
///
/// Returns [`Error::ExcessData`] if `body_len` or `fds_len` does not fit the header's integer
/// type or the whole message would exceed `MAX_MESSAGE_SIZE`; other errors come from
/// (de)serializing the header or from `write_body`.
fn build_generic<WriteFunc>(
self,
mut signature: Signature<'_>,
body_len: usize,
write_body: WriteFunc,
#[cfg(unix)] fds_len: usize,
) -> Result<Message>
where
WriteFunc: FnOnce(&mut Cursor<&mut Vec<u8>>) -> Result<BuildGenericResult>,
{
let ctxt = dbus_context!(0);
let mut header = self.header;
if !signature.is_empty() {
// Drop the first and last character (the struct delimiters) so that only the inner
// signature ends up in the header field.
if signature.starts_with(zvariant::STRUCT_SIG_START_STR) {
signature = signature.slice(1..signature.len() - 1);
}
header.fields_mut().add(MessageField::Signature(signature));
}
let body_len_u32 = body_len.try_into().map_err(|_| Error::ExcessData)?;
header.primary_mut().set_body_len(body_len_u32);
#[cfg(unix)]
{
let fds_len_u32 = fds_len.try_into().map_err(|_| Error::ExcessData)?;
// The field is only added when there is at least one file descriptor.
if fds_len != 0 {
header.fields_mut().add(MessageField::UnixFDs(fds_len_u32));
}
}
// Work out the layout: header, padding up to the body offset, then the body.
let hdr_len = zvariant::serialized_size(ctxt, &header)?;
let body_padding = padding_for_8_bytes(hdr_len);
let body_offset = hdr_len + body_padding;
let total_len = body_offset + body_len;
if total_len > MAX_MESSAGE_SIZE {
return Err(Error::ExcessData);
}
let mut bytes: Vec<u8> = Vec::with_capacity(total_len);
let mut cursor = Cursor::new(&mut bytes);
zvariant::to_writer(&mut cursor, ctxt, &header)?;
// Zero-fill the gap between the end of the header and the body offset.
for _ in 0..body_padding {
cursor.write_all(&[0u8])?;
}
#[cfg(unix)]
let fds = write_body(&mut cursor)?;
#[cfg(not(unix))]
write_body(&mut cursor)?;
let primary_header = header.into_primary();
// Re-parse the header from the serialized bytes; `QuickMessageFields::new` is given both the
// buffer and the header parsed from it.
let header: MessageHeader<'_> = zvariant::from_slice(&bytes, ctxt)?;
let quick_fields = QuickMessageFields::new(&bytes, &header)?;
Ok(Message {
primary_header,
quick_fields,
bytes,
body_offset,
// Locally built messages only reference their file descriptors; they do not own them.
#[cfg(unix)]
fds: Arc::new(RwLock::new(Fds::Raw(fds))),
recv_seq: MessageSequence::default(),
})
}
}
/// File descriptors associated with a [`Message`].
#[cfg(unix)]
#[derive(Debug, Eq, PartialEq)]
enum Fds {
/// Descriptors held as `OwnedFd`s; this is what `Message::from_raw_parts` stores.
Owned(Vec<OwnedFd>),
/// Plain raw descriptor numbers; used for locally built messages and after
/// `Message::take_fds` has handed out the owned descriptors.
Raw(Vec<RawFd>),
}
/// Sequence value attached to a [`Message`], as returned by `Message::recv_position`.
///
/// Values are ordered by the wrapped counter. Messages built locally carry the default value,
/// while `Message::from_raw_parts` stores the counter it is given.
#[derive(Debug, Default, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct MessageSequence {
// Wrapped counter; the derived comparisons and hash operate on it.
recv_seq: u64,
}
impl MessageSequence {
/// The greatest possible sequence value (`u64::MAX`); no other value compares greater.
pub(crate) const LAST: Self = Self { recv_seq: u64::MAX };
}
/// A D-Bus message.
///
/// The message is kept in its serialized form (`bytes`), together with a parsed copy of the
/// primary header, cached header fields and the offset at which the body starts. Use
/// [`MessageBuilder`] or the convenience constructors on this type to create one.
///
/// Cloning shares the file-descriptor storage between the clones (it is behind an `Arc`).
#[derive(Clone)]
pub struct Message {
/// Parsed copy of the primary header.
primary_header: MessagePrimaryHeader,
/// Cached header fields backing the `path`, `interface`, `member` and `reply_serial` getters.
quick_fields: QuickMessageFields,
/// The complete serialized message: header, padding and body.
bytes: Vec<u8>,
/// Offset into `bytes` at which the body starts.
body_offset: usize,
/// File descriptors associated with the message, shared between clones.
#[cfg(unix)]
fds: Arc<RwLock<Fds>>,
/// Sequence value reported by `recv_position`.
recv_seq: MessageSequence,
}
// Compile-time check that `Message` implements `Send`, `Sync` and `Unpin`.
assert_impl_all!(Message: Send, Sync, Unpin);
impl Message {
/// Creates a method-call message.
///
/// `path` and `method_name` are mandatory; `sender`, `destination` and `iface` are only set
/// on the message when they are `Some`. `body` is serialized as the message body.
///
/// # Errors
///
/// Returns an error if any of the names fails to convert, or if building the message with
/// `body` fails.
pub fn method<'s, 'd, 'p, 'i, 'm, S, D, P, I, M, B>(
    sender: Option<S>,
    destination: Option<D>,
    path: P,
    iface: Option<I>,
    method_name: M,
    body: &B,
) -> Result<Self>
where
    S: TryInto<UniqueName<'s>>,
    D: TryInto<BusName<'d>>,
    P: TryInto<ObjectPath<'p>>,
    I: TryInto<InterfaceName<'i>>,
    M: TryInto<MemberName<'m>>,
    S::Error: Into<Error>,
    D::Error: Into<Error>,
    P::Error: Into<Error>,
    I::Error: Into<Error>,
    M::Error: Into<Error>,
    B: serde::ser::Serialize + DynamicType,
{
    let builder = MessageBuilder::method_call(path, method_name)?;
    let builder = match sender {
        Some(sender) => builder.sender(sender)?,
        None => builder,
    };
    let builder = match destination {
        Some(destination) => builder.destination(destination)?,
        None => builder,
    };
    let builder = match iface {
        Some(iface) => builder.interface(iface)?,
        None => builder,
    };
    builder.build(body)
}
/// Creates a signal message.
///
/// `path`, `iface` and `signal_name` are mandatory; `sender` and `destination` are only set
/// on the message when they are `Some`. `body` is serialized as the message body.
///
/// # Errors
///
/// Returns an error if any of the names fails to convert, or if building the message with
/// `body` fails.
pub fn signal<'s, 'd, 'p, 'i, 'm, S, D, P, I, M, B>(
    sender: Option<S>,
    destination: Option<D>,
    path: P,
    iface: I,
    signal_name: M,
    body: &B,
) -> Result<Self>
where
    S: TryInto<UniqueName<'s>>,
    D: TryInto<BusName<'d>>,
    P: TryInto<ObjectPath<'p>>,
    I: TryInto<InterfaceName<'i>>,
    M: TryInto<MemberName<'m>>,
    S::Error: Into<Error>,
    D::Error: Into<Error>,
    P::Error: Into<Error>,
    I::Error: Into<Error>,
    M::Error: Into<Error>,
    B: serde::ser::Serialize + DynamicType,
{
    let builder = MessageBuilder::signal(path, iface, signal_name)?;
    let builder = match sender {
        Some(sender) => builder.sender(sender)?,
        None => builder,
    };
    let builder = match destination {
        Some(destination) => builder.destination(destination)?,
        None => builder,
    };
    builder.build(body)
}
/// Creates a method-return message in reply to `call`, carrying `body`.
///
/// `sender` is only set on the reply when it is `Some`.
///
/// # Errors
///
/// Returns an error if the header of `call` cannot be parsed or lacks the information needed
/// for a reply, if `sender` fails to convert, or if building the message fails.
pub fn method_reply<'s, S, B>(sender: Option<S>, call: &Self, body: &B) -> Result<Self>
where
    S: TryInto<UniqueName<'s>>,
    S::Error: Into<Error>,
    B: serde::ser::Serialize + DynamicType,
{
    let call_header = call.header()?;
    let builder = MessageBuilder::method_return(&call_header)?;
    let builder = match sender {
        Some(sender) => builder.sender(sender)?,
        None => builder,
    };
    builder.build(body)
}
pub fn method_error<'s, 'e, S, E, B>(
sender: Option<S>,
call: &Self,
name: E,
body: &B,
) -> Result<Self>
where
S: TryInto<UniqueName<'s>>,
S::Error: Into<Error>,
E: TryInto<ErrorName<'e>>,
E::Error: Into<Error>,
B: serde::ser::Serialize + DynamicType,
{
let mut b = MessageBuilder::error(&call.header()?, name)?;
if let Some(sender) = sender {
b = b.sender(sender)?;
}
b.build(body)
}
pub(crate) fn from_raw_parts(
bytes: Vec<u8>,
#[cfg(unix)] fds: Vec<OwnedFd>,
recv_seq: u64,
) -> Result<Self> {
if EndianSig::try_from(bytes[0])? != NATIVE_ENDIAN_SIG {
return Err(Error::IncorrectEndian);
}
let (primary_header, fields_len) = MessagePrimaryHeader::read(&bytes)?;
let header = zvariant::from_slice(&bytes, dbus_context!(0))?;
#[cfg(unix)]
let fds = Arc::new(RwLock::new(Fds::Owned(fds)));
let header_len = MIN_MESSAGE_SIZE + fields_len as usize;
let body_offset = header_len + padding_for_8_bytes(header_len);
let quick_fields = QuickMessageFields::new(&bytes, &header)?;
Ok(Self {
primary_header,
quick_fields,
bytes,
body_offset,
#[cfg(unix)]
fds,
recv_seq: MessageSequence { recv_seq },
})
}
/// Takes ownership of the file descriptors associated with the message.
///
/// If the message currently owns its descriptors they are returned, and the message keeps
/// only their raw numbers from then on; the caller becomes responsible for keeping them open
/// for as long as the message may still refer to them. If the message does not own any
/// descriptors (it was built locally, or they were already taken), an empty vector is
/// returned.
///
/// # Panics
///
/// Panics if the internal lock is poisoned.
#[cfg(unix)]
pub fn take_fds(&self) -> Vec<OwnedFd> {
    let mut guard = self.fds.write().expect(LOCK_PANIC_MSG);
    match *guard {
        Fds::Owned(ref mut owned) => {
            let taken = std::mem::take(owned);
            // Keep the raw numbers so the message can still resolve descriptor indices.
            *guard = Fds::Raw(taken.iter().map(|fd| fd.as_raw_fd()).collect());
            taken
        }
        Fds::Raw(_) => vec![],
    }
}
pub fn body_signature(&self) -> Result<Signature<'_>> {
match self
.header()?
.into_fields()
.into_field(MessageFieldCode::Signature)
.ok_or(Error::NoBodySignature)?
{
MessageField::Signature(signature) => Ok(signature),
_ => Err(Error::InvalidField),
}
}
/// Returns a reference to the parsed primary header kept alongside the serialized bytes.
pub fn primary_header(&self) -> &MessagePrimaryHeader {
&self.primary_header
}
/// Applies `modifier` to the primary header and writes the result back into the serialized
/// message.
///
/// The updated primary header is re-serialized at the start of the byte buffer so that the
/// parsed copy and the bytes stay in sync.
///
/// # Errors
///
/// Returns the error from `modifier` (in which case the bytes are left untouched), or any
/// error raised while serializing the primary header.
pub(crate) fn modify_primary_header<F>(&mut self, mut modifier: F) -> Result<()>
where
    F: FnMut(&mut MessagePrimaryHeader) -> Result<()>,
{
    modifier(&mut self.primary_header)?;

    let mut cursor = Cursor::new(&mut self.bytes);
    zvariant::to_writer(&mut cursor, dbus_context!(0), &self.primary_header)?;
    Ok(())
}
/// Deserializes the full message header from the serialized bytes.
///
/// The header is parsed afresh on every call.
///
/// # Errors
///
/// Returns an error if the header cannot be deserialized.
pub fn header(&self) -> Result<MessageHeader<'_>> {
    let header = zvariant::from_slice(&self.bytes, dbus_context!(0))?;
    Ok(header)
}
/// Deserializes the header fields from the serialized bytes.
///
/// Parsing starts right after the primary header (`PRIMARY_HEADER_SIZE` bytes in), and the
/// encoding context is told about those preceding bytes.
///
/// # Errors
///
/// Returns an error if the fields cannot be deserialized.
pub fn fields(&self) -> Result<MessageFields<'_>> {
    let ctxt = dbus_context!(crate::PRIMARY_HEADER_SIZE);
    let field_bytes = &self.bytes[crate::PRIMARY_HEADER_SIZE..];
    let fields = zvariant::from_slice(field_bytes, ctxt)?;
    Ok(fields)
}
/// Returns the message type recorded in the primary header.
pub fn message_type(&self) -> MessageType {
self.primary_header.msg_type()
}
/// Returns the object path of the message, looked up through the cached quick fields, or
/// `None` if that lookup yields nothing.
pub fn path(&self) -> Option<ObjectPath<'_>> {
self.quick_fields.path(self)
}
/// Returns the interface name of the message, looked up through the cached quick fields, or
/// `None` if that lookup yields nothing.
pub fn interface(&self) -> Option<InterfaceName<'_>> {
self.quick_fields.interface(self)
}
/// Returns the member (method or signal) name of the message, looked up through the cached
/// quick fields, or `None` if that lookup yields nothing.
pub fn member(&self) -> Option<MemberName<'_>> {
self.quick_fields.member(self)
}
/// Returns the reply serial stored in the cached quick fields, or `None` if there is none.
pub fn reply_serial(&self) -> Option<u32> {
self.quick_fields.reply_serial()
}
/// Deserializes the body as `B` without consulting the body signature in the header.
///
/// Unlike [`Message::body`], the signature used for decoding comes solely from the type `B`.
/// On Unix, the message's file descriptors are made available to the deserializer.
///
/// # Errors
///
/// Returns an error if the body cannot be deserialized as `B`.
pub fn body_unchecked<'d, 'm: 'd, B>(&'m self) -> Result<B>
where
B: serde::de::Deserialize<'d> + Type,
{
// The cfg-selected block evaluates to the deserialization result, which is then converted
// to this crate's error type.
{
#[cfg(unix)]
{
zvariant::from_slice_fds(
&self.bytes[self.body_offset..],
Some(&self.fds()),
dbus_context!(0),
)
}
#[cfg(not(unix))]
{
zvariant::from_slice(&self.bytes[self.body_offset..], dbus_context!(0))
}
}
.map_err(Error::from)
}
/// Deserializes the body as `B`, using the body signature recorded in the header.
///
/// A message without a signature field is treated as having an empty signature. On Unix,
/// the message's file descriptors are made available to the deserializer.
///
/// # Errors
///
/// Returns an error if the body signature cannot be read from the header, or if the body
/// cannot be deserialized as `B` with that signature (for instance on a signature mismatch).
pub fn body<'d, 'm: 'd, B>(&'m self) -> Result<B>
where
B: zvariant::DynamicDeserialize<'d>,
{
let body_sig = match self.body_signature() {
Ok(sig) => sig,
// No signature field: fall back to the empty signature.
Err(Error::NoBodySignature) => Signature::from_static_str_unchecked(""),
Err(e) => return Err(e),
};
// The cfg-selected block evaluates to the deserialization result, which is then converted
// to this crate's error type.
{
#[cfg(unix)]
{
zvariant::from_slice_fds_for_dynamic_signature(
&self.bytes[self.body_offset..],
Some(&self.fds()),
dbus_context!(0),
&body_sig,
)
}
#[cfg(not(unix))]
{
zvariant::from_slice_for_dynamic_signature(
&self.bytes[self.body_offset..],
dbus_context!(0),
&body_sig,
)
}
}
.map_err(Error::from)
}
/// Returns the raw file descriptors associated with the message, whether or not the message
/// owns them.
///
/// # Panics
///
/// Panics if the internal lock is poisoned.
#[cfg(unix)]
pub(crate) fn fds(&self) -> Vec<RawFd> {
    let guard = self.fds.read().expect(LOCK_PANIC_MSG);
    match &*guard {
        Fds::Owned(owned) => owned.iter().map(|fd| fd.as_raw_fd()).collect(),
        Fds::Raw(raw) => raw.clone(),
    }
}
/// Returns the complete serialized message (header, padding and body) as a byte slice.
pub fn as_bytes(&self) -> &[u8] {
&self.bytes
}
pub fn body_as_bytes(&self) -> Result<&[u8]> {
Ok(&self.bytes[self.body_offset..])
}
/// Returns the [`MessageSequence`] stored in the message when it was created.
pub fn recv_position(&self) -> MessageSequence {
self.recv_seq
}
}
/// Debug output lists the header fields that could be read, the body signature and (on Unix)
/// any associated file descriptors; anything that fails to parse is silently left out.
impl fmt::Debug for Message {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Msg");

        // A header that fails to parse simply contributes no fields.
        if let Ok(hdr) = self.header() {
            if let Ok(ty) = hdr.message_type() {
                out.field("type", &ty);
            }
            if let Ok(Some(sender)) = hdr.sender() {
                out.field("sender", &sender);
            }
            if let Ok(Some(serial)) = hdr.reply_serial() {
                out.field("reply-serial", &serial);
            }
            if let Ok(Some(path)) = hdr.path() {
                out.field("path", &path);
            }
            if let Ok(Some(iface)) = hdr.interface() {
                out.field("iface", &iface);
            }
            if let Ok(Some(member)) = hdr.member() {
                out.field("member", &member);
            }
        }

        if let Ok(signature) = self.body_signature() {
            out.field("body", &signature);
        }

        #[cfg(unix)]
        {
            let raw_fds = self.fds();
            if !raw_fds.is_empty() {
                out.field("fds", &raw_fds);
            }
        }

        out.finish()
    }
}
/// Human-readable one-line summary: the kind of message, its member or error name where
/// applicable (plus the error text for errors), followed by the sender if known.
impl fmt::Display for Message {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let parsed = self.header();
        // Pull out everything needed up front; an unparsable header yields no details at all.
        let (kind, error_name, sender, member) = match parsed.as_ref() {
            Ok(hdr) => (
                hdr.message_type().ok(),
                hdr.error_name().ok().flatten(),
                hdr.sender().ok().flatten(),
                hdr.member().ok().flatten(),
            ),
            Err(_) => (None, None, None, None),
        };

        match kind {
            Some(MessageType::MethodCall) => {
                f.write_str("Method call")?;
                if let Some(name) = member {
                    write!(f, " {name}")?;
                }
            }
            Some(MessageType::MethodReturn) => {
                f.write_str("Method return")?;
            }
            Some(MessageType::Error) => {
                f.write_str("Error")?;
                if let Some(name) = error_name {
                    write!(f, " {name}")?;
                }
                // Show the error text when the body can be read as a string.
                if let Ok(text) = self.body_unchecked::<&str>() {
                    write!(f, ": {text}")?;
                }
            }
            Some(MessageType::Signal) => {
                f.write_str("Signal")?;
                if let Some(name) = member {
                    write!(f, " {name}")?;
                }
            }
            _ => {
                f.write_str("Unknown message")?;
            }
        }

        if let Some(name) = sender {
            write!(f, " from {name}")?;
        }
        Ok(())
    }
}
// Unit tests for building messages and reading them back.
#[cfg(test)]
mod tests {
#[cfg(unix)]
use std::os::unix::io::AsRawFd;
use test_log::test;
#[cfg(unix)]
use zvariant::Fd;
#[cfg(unix)]
use super::Fds;
use super::{Message, MessageBuilder};
use crate::Error;
// Builds a method call (carrying stdout's file descriptor on Unix), then a reply and an error
// for it, checking the body signature, fd bookkeeping and `Display` output along the way.
#[test]
fn test() {
#[cfg(unix)]
let stdout = std::io::stdout();
let m = Message::method(
Some(":1.72"),
None::<()>,
"/",
None::<()>,
"do",
&(
#[cfg(unix)]
Fd::from(&stdout),
"foo",
),
)
.unwrap();
// The enclosing struct delimiters are not part of the stored body signature.
assert_eq!(
m.body_signature().unwrap().to_string(),
if cfg!(unix) { "hs" } else { "s" }
);
// A locally built message only references its file descriptors.
#[cfg(unix)]
assert_eq!(*m.fds.read().unwrap(), Fds::Raw(vec![stdout.as_raw_fd()]));
// Asking for a body type that does not match the signature must fail.
let body: Result<u32, Error> = m.body();
assert!(matches!(
body.unwrap_err(),
Error::Variant(zvariant::Error::SignatureMismatch { .. })
));
assert_eq!(m.to_string(), "Method call do from :1.72");
let r = Message::method_reply(None::<()>, &m, &("all fine!")).unwrap();
assert_eq!(r.to_string(), "Method return");
let e = Message::method_error(
None::<()>,
&m,
"org.freedesktop.zbus.Error",
&("kaboom!", 32),
)
.unwrap();
// `Display` shows the leading string of the error body as the error text.
assert_eq!(e.to_string(), "Error org.freedesktop.zbus.Error: kaboom!");
}
// Builds a signal from a pre-encoded body and checks that it deserializes as expected.
#[test]
fn test_raw() -> Result<(), Error> {
// Encoded `ai` value: a 16-byte array length followed by the four elements 1, 2, 3, 4.
let raw_body: &[u8] = &[16, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0];
let message_builder = MessageBuilder::signal("/", "test.test", "test")?;
let message = unsafe {
message_builder.build_raw_body(
raw_body,
"ai",
#[cfg(unix)]
vec![],
)?
};
let output: Vec<i32> = message.body()?;
assert_eq!(output, vec![1, 2, 3, 4]);
Ok(())
}
}