use std::any::{Any, TypeId};
use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::fmt::Write;
use std::marker::PhantomData;
use std::rc::Rc;
use scoped_tls::scoped_thread_local;
use zvariant::{ObjectPath, OwnedValue, Value};
use crate::{dbus_interface, fdo, Connection, Error, Message, MessageHeader, MessageType, Result};
// Scoped thread-local slot for the `Node` whose interface is currently being invoked. It is
// only set for the duration of `ObjectServer::with` and of method-call dispatch.
scoped_thread_local!(static LOCAL_NODE: Node);
// Scoped thread-local slot for the `Connection` of the `ObjectServer` that is currently
// invoking an interface. It is set together with `LOCAL_NODE`.
scoped_thread_local!(static LOCAL_CONNECTION: Connection);
/// Behaviour an object must provide to be registered at an object path of the object server.
///
/// The types in this file that are used as interfaces (`Peer`, `Introspectable`, `Properties`)
/// have no hand-written implementation; presumably the `dbus_interface` attribute macro
/// generates it.
pub trait Interface: Any {
/// Returns the name of the interface. `Node` uses it as the key under which an implementation
/// is registered and looked up.
fn name() -> &'static str
where
Self: Sized;
/// Reads the property `property_name`.
///
/// `Properties::get` reports a `None` return value as an unknown property.
fn get(&self, property_name: &str) -> Option<fdo::Result<OwnedValue>>;
/// Returns a map of values that `Properties::get_all` hands back unchanged (presumably every
/// property of the interface keyed by property name).
fn get_all(&self) -> HashMap<String, OwnedValue>;
/// Writes `value` to the property `property_name`.
///
/// `Properties::set` reports a `None` return value as an unknown property.
fn set(&mut self, property_name: &str, value: &Value) -> Option<fdo::Result<()>>;
/// Handles the method `name` of the message `msg` with shared access to the interface.
///
/// When this returns `None`, dispatch retries with `call_mut`.
// NOTE(review): the `u32` in the result is presumably the serial of the reply that was sent;
// confirm against the code generated by `dbus_interface`.
fn call(&self, connection: &Connection, msg: &Message, name: &str) -> Option<Result<u32>>;
/// Handles the method `name` of the message `msg` with exclusive access to the interface.
///
/// When this returns `None` as well, dispatch reports an unknown method.
fn call_mut(
&mut self,
connection: &Connection,
msg: &Message,
name: &str,
) -> Option<Result<u32>>;
/// Writes the introspection data of the interface to `writer`.
///
/// `Node` passes its own level plus two as `level` (presumably the indentation in spaces).
fn introspect_to_writer(&self, writer: &mut dyn Write, level: usize);
}
impl dyn Interface {
    /// Returns the interface as a reference to the concrete type `T`, or `None` when the object
    /// behind the trait object is of a different type.
    fn downcast_ref<T: Any>(&self) -> Option<&T> {
        let actual = <dyn Interface as Any>::type_id(self);
        if actual != TypeId::of::<T>() {
            return None;
        }

        let concrete = self as *const dyn Interface as *const T;
        // SAFETY: the `TypeId` comparison above established that the value behind the trait
        // object is a `T`, so its data pointer is a valid `*const T`. The pointer is derived from
        // `&self`, hence the returned reference is valid for the same lifetime.
        Some(unsafe { &*concrete })
    }
}
struct Introspectable;
#[dbus_interface(name = "org.freedesktop.DBus.Introspectable")]
impl Introspectable {
fn introspect(&self) -> String {
LOCAL_NODE.with(|node| node.introspect())
}
}
struct Peer;
#[dbus_interface(name = "org.freedesktop.DBus.Peer")]
impl Peer {
fn ping(&self) {}
fn get_machine_id(&self) -> fdo::Result<String> {
let mut id = match std::fs::read_to_string("/var/lib/dbus/machine-id") {
Ok(id) => id,
Err(e) => {
if let Ok(id) = std::fs::read_to_string("/etc/machine-id") {
id
} else {
return Err(fdo::Error::IOError(format!(
"Failed to read from /var/lib/dbus/machine-id or /etc/machine-id: {}",
e
)));
}
}
};
let len = id.trim_end().len();
id.truncate(len);
Ok(id)
}
}
struct Properties;
#[dbus_interface(name = "org.freedesktop.DBus.Properties")]
impl Properties {
fn get(&self, interface_name: &str, property_name: &str) -> fdo::Result<OwnedValue> {
LOCAL_NODE.with(|node| {
let iface = node.get_interface(interface_name).ok_or_else(|| {
fdo::Error::UnknownInterface(format!("Unknown interface '{}'", interface_name))
})?;
let res = iface.borrow().get(property_name);
res.ok_or_else(|| {
fdo::Error::UnknownProperty(format!("Unknown property '{}'", property_name))
})?
})
}
fn set(
&mut self,
interface_name: &str,
property_name: &str,
value: OwnedValue,
) -> fdo::Result<()> {
LOCAL_NODE.with(|node| {
let iface = node.get_interface(interface_name).ok_or_else(|| {
fdo::Error::UnknownInterface(format!("Unknown interface '{}'", interface_name))
})?;
let res = iface.borrow_mut().set(property_name, &value);
res.ok_or_else(|| {
fdo::Error::UnknownProperty(format!("Unknown property '{}'", property_name))
})?
})
}
fn get_all(&self, interface_name: &str) -> fdo::Result<HashMap<String, OwnedValue>> {
LOCAL_NODE.with(|node| {
let iface = node.get_interface(interface_name).ok_or_else(|| {
fdo::Error::UnknownInterface(format!("Unknown interface '{}'", interface_name))
})?;
let res = iface.borrow().get_all();
Ok(res)
})
}
#[dbus_interface(signal)]
fn properties_changed(
&self,
interface_name: &str,
changed_properties: &HashMap<&str, &Value>,
invalidated_properties: &[&str],
) -> Result<()>;
}
/// One node of the object tree: the interfaces registered at an object path plus the child
/// nodes below it.
#[derive(Default, derivative::Derivative)]
#[derivative(Debug)]
struct Node {
/// Object path of this node; `emit_signal` uses it as the path of emitted signals.
path: String,
/// Child nodes, keyed by the path segment that leads to them.
children: HashMap<String, Node>,
/// Interfaces registered on this node, keyed by interface name. Excluded from the `Debug`
/// output.
#[derivative(Debug = "ignore")]
interfaces: HashMap<&'static str, Rc<RefCell<dyn Interface>>>,
}
impl Node {
fn new(path: &str) -> Self {
let mut node = Self {
path: path.to_string(),
..Default::default()
};
node.at(Peer::name(), Peer);
node.at(Introspectable::name(), Introspectable);
node.at(Properties::name(), Properties);
node
}
fn get_interface(&self, iface: &str) -> Option<Rc<RefCell<dyn Interface>>> {
self.interfaces.get(iface).cloned()
}
fn at<I>(&mut self, name: &'static str, iface: I) -> bool
where
I: Interface,
{
match self.interfaces.entry(name) {
Entry::Vacant(e) => e.insert(Rc::new(RefCell::new(iface))),
Entry::Occupied(_) => return false,
};
true
}
fn with_iface_func<F, I>(&self, func: F) -> Result<()>
where
F: Fn(&I) -> Result<()>,
I: Interface,
{
let iface = self
.interfaces
.get(I::name())
.ok_or(Error::InterfaceNotFound)?
.borrow();
let iface = iface.downcast_ref::<I>().ok_or(Error::InterfaceNotFound)?;
func(iface)
}
fn introspect_to_writer<W: Write>(&self, writer: &mut W, level: usize) {
if level == 0 {
writeln!(
writer,
r#"
<!DOCTYPE node PUBLIC "-//freedesktop//DTD D-BUS Object Introspection 1.0//EN"
"http://www.freedesktop.org/standards/dbus/1.0/introspect.dtd">
<node>"#
)
.unwrap();
}
for iface in self.interfaces.values() {
iface.borrow().introspect_to_writer(writer, level + 2);
}
for (path, node) in &self.children {
let level = level + 2;
writeln!(
writer,
"{:indent$}<node name=\"{}\">",
"",
path,
indent = level
)
.unwrap();
node.introspect_to_writer(writer, level);
writeln!(writer, "{:indent$}</node>", "", indent = level).unwrap();
}
if level == 0 {
writeln!(writer, "</node>").unwrap();
}
}
fn introspect(&self) -> String {
let mut xml = String::with_capacity(1024);
self.introspect_to_writer(&mut xml, 0);
xml
}
fn emit_signal<B>(
&self,
dest: Option<&str>,
iface: &str,
signal_name: &str,
body: &B,
) -> Result<()>
where
B: serde::ser::Serialize + zvariant::Type,
{
if !LOCAL_CONNECTION.is_set() {
panic!("emit_signal: Connection TLS not set");
}
LOCAL_CONNECTION.with(|conn| conn.emit_signal(dest, &self.path, iface, signal_name, body))
}
}
/// Dispatches method calls received on a `Connection` to the interfaces registered at object
/// paths.
#[derive(Debug)]
pub struct ObjectServer<'a> {
// Clone of the connection handed to `new`.
conn: Connection,
// Root of the object tree, created for the path "/".
root: Node,
// Makes the struct use the lifetime parameter `'a`, which no other field refers to.
phantom: PhantomData<&'a ()>,
}
impl<'a> ObjectServer<'a> {
    /// Creates an object server that works on a clone of `connection` and starts out with only
    /// the root node `/`.
    pub fn new(connection: &Connection) -> Self {
        Self {
            conn: connection.clone(),
            root: Node::new("/"),
            phantom: PhantomData,
        }
    }

    /// Returns the node at `path`, or `None` when any node along the path does not exist.
    ///
    /// Empty path segments are skipped, so `/` resolves to the root node.
    fn get_node(&self, path: &ObjectPath) -> Option<&Node> {
        path.split('/')
            .skip(1)
            .filter(|segment| !segment.is_empty())
            .try_fold(&self.root, |node, segment| node.children.get(segment))
    }

    /// Returns the node at `path` for modification.
    ///
    /// When `create` is `true`, missing nodes along the path are created with `Node::new`;
    /// otherwise a missing node makes the lookup return `None`.
    fn get_node_mut(&mut self, path: &ObjectPath, create: bool) -> Option<&mut Node> {
        let mut node = &mut self.root;
        // Path of the node reached so far; newly created nodes are given it as their own path.
        let mut node_path = String::new();

        for segment in path.split('/').skip(1).filter(|s| !s.is_empty()) {
            write!(&mut node_path, "/{}", segment).unwrap();
            node = match node.children.entry(segment.into()) {
                Entry::Occupied(e) => e.into_mut(),
                Entry::Vacant(e) if create => e.insert(Node::new(&node_path)),
                Entry::Vacant(_) => return None,
            };
        }

        Some(node)
    }

    /// Registers `iface` at `path`, creating the nodes leading to it as needed.
    ///
    /// Returns `Ok(false)` when an interface with the same name is already registered at
    /// `path`, in which case `iface` is dropped, and `Ok(true)` otherwise.
    pub fn at<I>(&mut self, path: &ObjectPath, iface: I) -> Result<bool>
    where
        I: Interface,
    {
        let node = self
            .get_node_mut(path, true)
            .expect("get_node_mut always yields a node when `create` is true");
        Ok(node.at(I::name(), iface))
    }

    /// Runs `func` on the interface of type `I` registered at `path`.
    ///
    /// `LOCAL_CONNECTION` and `LOCAL_NODE` are set while `func` runs, so the interface can emit
    /// signals.
    ///
    /// # Errors
    ///
    /// Returns `Error::InterfaceNotFound` when there is no node at `path` or no interface of
    /// type `I` on it, and otherwise whatever `func` returns.
    pub fn with<F, I>(&self, path: &ObjectPath, func: F) -> Result<()>
    where
        F: Fn(&I) -> Result<()>,
        I: Interface,
    {
        let node = self.get_node(path).ok_or(Error::InterfaceNotFound)?;
        LOCAL_CONNECTION.set(&self.conn, || {
            LOCAL_NODE.set(node, || node.with_iface_func(func))
        })
    }

    /// Emits the signal `signal_name` of interface `iface` from the node stored in
    /// `LOCAL_NODE`.
    ///
    /// # Panics
    ///
    /// Panics when `LOCAL_NODE` is not set, i.e. when called outside of `with` and of method
    /// dispatch.
    pub fn local_node_emit_signal<B>(
        destination: Option<&str>,
        iface: &str,
        signal_name: &str,
        body: &B,
    ) -> Result<()>
    where
        B: serde::ser::Serialize + zvariant::Type,
    {
        if !LOCAL_NODE.is_set() {
            panic!("emit_signal: Node TLS not set");
        }
        LOCAL_NODE.with(|n| n.emit_signal(destination, iface, signal_name, body))
    }

    /// Resolves the path, interface and member named in `msg_header` and invokes the method.
    ///
    /// The outer error covers everything that prevents the call from reaching a method (missing
    /// header field, unknown object, interface or method); the inner result is what the
    /// interface returned.
    fn dispatch_method_call_try(
        &mut self,
        msg_header: &MessageHeader,
        msg: &Message,
    ) -> fdo::Result<Result<u32>> {
        let path = msg_header
            .path()
            .ok()
            .flatten()
            .ok_or_else(|| fdo::Error::Failed("Missing object path".into()))?;
        let iface = msg_header
            .interface()
            .ok()
            .flatten()
            .ok_or_else(|| fdo::Error::Failed("Missing interface".into()))?;
        let member = msg_header
            .member()
            .ok()
            .flatten()
            .ok_or_else(|| fdo::Error::Failed("Missing member".into()))?;

        // The lookup only reads the tree, so the shared accessor is enough. It neither allocates
        // per path segment nor requires cloning the connection to get around a mutable borrow.
        let conn = &self.conn;
        let node = self
            .get_node(&path)
            .ok_or_else(|| fdo::Error::UnknownObject(format!("Unknown object '{}'", path)))?;
        let iface = node.get_interface(iface).ok_or_else(|| {
            fdo::Error::UnknownInterface(format!("Unknown interface '{}'", iface))
        })?;

        LOCAL_CONNECTION.set(conn, || {
            LOCAL_NODE.set(node, || {
                // Bind the result first so that the shared borrow is released before `call_mut`
                // takes the exclusive one.
                let res = iface.borrow().call(conn, msg, member);
                res.or_else(|| iface.borrow_mut().call_mut(conn, msg, member))
                    .ok_or_else(|| {
                        fdo::Error::UnknownMethod(format!("Unknown method '{}'", member))
                    })
            })
        })
    }

    /// Dispatches a method call; a failure to reach a method is turned into an error reply to
    /// `msg` on the server's connection.
    fn dispatch_method_call(&mut self, msg_header: &MessageHeader, msg: &Message) -> Result<u32> {
        match self.dispatch_method_call_try(msg_header, msg) {
            Ok(outcome) => outcome,
            Err(e) => e.reply(&self.conn, msg),
        }
    }

    /// Dispatches `msg` if it is a method call.
    ///
    /// Returns `Ok(true)` when the message was a method call and was handled, and `Ok(false)`
    /// for every other message type.
    ///
    /// # Errors
    ///
    /// Fails when the header cannot be read or when handling the method call fails.
    pub fn dispatch_message(&mut self, msg: &Message) -> Result<bool> {
        let msg_header = msg.header()?;
        match msg_header.message_type()? {
            MessageType::MethodCall => {
                self.dispatch_method_call(&msg_header, msg)?;
                Ok(true)
            }
            _ => Ok(false),
        }
    }

    /// Receives the next message from the connection and dispatches it.
    ///
    /// Returns `Ok(None)` when the message was handled as a method call, and hands the message
    /// back as `Ok(Some(msg))` otherwise.
    pub fn try_handle_next(&mut self) -> Result<Option<Message>> {
        let msg = self.conn.receive_message()?;
        if !self.dispatch_message(&msg)? {
            return Ok(Some(msg));
        }
        Ok(None)
    }
}
// Integration test of the object server: a service is exported on the session bus and driven
// from a second thread through a generated proxy.
#[cfg(test)]
mod tests {
use std::cell::RefCell;
use std::convert::TryInto;
use std::error::Error;
use std::rc::Rc;
use std::thread;
use ntest::timeout;
use serde::{Deserialize, Serialize};
use zvariant::derive::Type;
use crate::fdo;
use crate::{dbus_interface, dbus_proxy, Connection, MessageHeader, MessageType, ObjectServer};
// Struct passed as the single argument of `test_single_struct_arg`.
#[derive(Deserialize, Serialize, Type)]
pub struct ArgStructTest {
foo: i32,
bar: String,
}
// Client-side view of the test interface; `my_iface_test` uses it as `MyIfaceProxy`.
#[dbus_proxy]
trait MyIface {
fn ping(&self) -> zbus::Result<u32>;
fn quit(&self, val: bool) -> zbus::Result<()>;
fn test_header(&self) -> zbus::Result<()>;
fn test_error(&self) -> zbus::Result<()>;
fn test_single_struct_arg(&self, arg: ArgStructTest) -> zbus::Result<()>;
#[dbus_proxy(property)]
fn count(&self) -> fdo::Result<u32>;
#[dbus_proxy(property)]
fn set_count(&self, count: u32) -> fdo::Result<()>;
}
// Service-side implementation of the test interface.
#[derive(Debug)]
struct MyIfaceImpl {
// Flag shared with the server loop in `basic_iface`, which stops once it is `true`.
quit: Rc<RefCell<bool>>,
// Number of `ping` calls so far, unless overwritten through the `count` property.
count: u32,
}
impl MyIfaceImpl {
fn new(quit: Rc<RefCell<bool>>) -> Self {
Self { quit, count: 0 }
}
}
#[dbus_interface(interface = "org.freedesktop.MyIface")]
impl MyIfaceImpl {
// Increments the counter and returns it; every third call also emits `alert_count`.
fn ping(&mut self) -> u32 {
self.count += 1;
if self.count % 3 == 0 {
self.alert_count(self.count).expect("Failed to emit signal");
}
self.count
}
// Stores `val` in the shared quit flag.
fn quit(&mut self, val: bool) {
*self.quit.borrow_mut() = val;
}
// Checks that the header of the message being handled is handed to the method.
fn test_header(&self, #[zbus(header)] header: MessageHeader<'_>) {
assert_eq!(header.message_type().unwrap(), MessageType::MethodCall);
assert_eq!(header.member().unwrap(), Some("TestHeader"));
}
// Always fails with `fdo::Error::Failed`.
fn test_error(&self) -> zbus::fdo::Result<()> {
Err(zbus::fdo::Error::Failed("error raised".to_string()))
}
// Checks the field values sent by `my_iface_test`.
fn test_single_struct_arg(&self, arg: ArgStructTest) {
assert_eq!(arg.foo, 1);
assert_eq!(arg.bar, "TestString");
}
// Property setter that rejects the value 42 with `InvalidArgs`.
#[dbus_interface(property)]
fn set_count(&mut self, val: u32) -> zbus::fdo::Result<()> {
if val == 42 {
return Err(zbus::fdo::Error::InvalidArgs("Tsss tsss!".to_string()));
}
self.count = val;
Ok(())
}
#[dbus_interface(property)]
fn count(&self) -> u32 {
self.count
}
// Signal declaration without a body.
#[dbus_interface(signal)]
fn alert_count(&self, val: u32) -> zbus::Result<()>;
}
// Client side of the test, run on a separate thread by `basic_iface`. Returns the value of the
// second `ping` call.
fn my_iface_test() -> std::result::Result<u32, Box<dyn Error>> {
let conn = Connection::new_session()?;
let proxy = MyIfaceProxy::new_for(
&conn,
"org.freedesktop.MyService",
"/org/freedesktop/MyService",
)?;
proxy.ping()?;
// One `ping` so far, so the property must read 1.
assert_eq!(proxy.count()?, 1);
proxy.test_header()?;
proxy.test_single_struct_arg(ArgStructTest {
foo: 1,
bar: "TestString".into(),
})?;
proxy.introspect()?;
let val = proxy.ping()?;
// Makes the server loop in `basic_iface` terminate.
proxy.quit(true)?;
Ok(val)
}
// Server side of the test: exports `MyIfaceImpl` and serves messages until the client asks it
// to quit.
#[test]
#[timeout(2000)]
fn basic_iface() {
let conn = Connection::new_session().unwrap();
let mut object_server = ObjectServer::new(&conn);
let quit = Rc::new(RefCell::new(false));
fdo::DBusProxy::new(&conn)
.unwrap()
.request_name(
"org.freedesktop.MyService",
fdo::RequestNameFlags::ReplaceExisting.into(),
)
.unwrap();
let iface = MyIfaceImpl::new(quit.clone());
object_server
.at(&"/org/freedesktop/MyService".try_into().unwrap(), iface)
.unwrap();
let child = thread::spawn(|| my_iface_test().expect("child failed"));
loop {
let m = conn.receive_message().unwrap();
// A dispatch failure is only logged; the loop keeps serving.
if let Err(e) = object_server.dispatch_message(&m) {
eprintln!("{}", e);
}
// Emits a signal through `ObjectServer::with` after every received message.
object_server
.with(
&"/org/freedesktop/MyService".try_into().unwrap(),
|iface: &MyIfaceImpl| iface.alert_count(51),
)
.unwrap();
if *quit.borrow() {
break;
}
}
let val = child.join().expect("failed to join");
// The client called `ping` twice.
assert_eq!(val, 2);
}
}