stardust_xr_schemas/dbus/
query.rs

1use crate::dbus::{
2	ObjectInfo,
3	interfaces::SpatialRefProxy,
4	list_query::{ListQueryMapper, ObjectListQuery},
5	object_registry::{ObjectRegistry, Objects},
6};
7use std::{
8	collections::{HashMap, HashSet},
9	marker::PhantomData,
10	ops::{Deref, DerefMut},
11	sync::Arc,
12	time::{Duration, Instant},
13};
14use tokio::{
15	sync::{
16		RwLock, RwLockReadGuard,
17		broadcast::error::RecvError,
18		mpsc::{self, error::TryRecvError},
19		watch,
20	},
21	task::{AbortHandle, JoinSet},
22	time::timeout,
23};
24use variadics_please::all_tuples;
25use zbus::{
26	Connection, Proxy, fdo,
27	names::{BusName, InterfaceName, OwnedBusName},
28	proxy::{Defaults, ProxyImpl},
29	zvariant::{ObjectPath, OwnedObjectPath},
30};
31
32pub struct ObjectQuery<Q: Queryable<Ctx>, Ctx: QueryContext> {
33	update_task_handle: AbortHandle,
34	event_reader: mpsc::Receiver<QueryEvent<Q, Ctx>>,
35}
36
37pub trait Queryable<Ctx: QueryContext>: Sized + 'static + Send + Sync {
38	fn try_new(
39		connection: &Connection,
40		ctx: &Arc<Ctx>,
41		object: &ObjectInfo,
42		contains_interface: &(impl Fn(&InterfaceName) -> bool + Send + Sync),
43	) -> impl std::future::Future<Output = Option<Self>> + Send;
44}
45pub trait QueryContext: Sized + 'static + Send + Sync {}
46impl QueryContext for () {}
47
48pub enum QueryEvent<Q: Queryable<Ctx> + Send + Sync, Ctx: QueryContext> {
49	NewMatch(ObjectInfo, Q),
50	MatchModified(ObjectInfo, Q),
51	MatchLost(ObjectInfo),
52	PhantomVariant(PhantomData<Ctx>),
53}
54impl<Q: Queryable<Ctx> + Send + Sync + std::fmt::Debug, Ctx: QueryContext + std::fmt::Debug>
55	std::fmt::Debug for QueryEvent<Q, Ctx>
56{
57	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58		match self {
59			Self::NewMatch(arg0, arg1) => {
60				f.debug_tuple("NewMatch").field(arg0).field(arg1).finish()
61			}
62			Self::MatchModified(arg0, arg1) => f
63				.debug_tuple("MatchModified")
64				.field(arg0)
65				.field(arg1)
66				.finish(),
67			Self::MatchLost(arg0) => f.debug_tuple("MatchLost").field(arg0).finish(),
68			Self::PhantomVariant(arg0) => f.debug_tuple("PhantomVariant").field(arg0).finish(),
69		}
70	}
71}
72
/// Implements [`Queryable`](crate::dbus::query::Queryable) for zbus proxy types.
///
/// Each argument must name a proxy type taking one lifetime parameter (it is used as
/// `$T<'static>`) and implementing `zbus::proxy::Defaults`. The generated `try_new` matches an
/// object only when the object exposes the proxy's default interface, and then builds the
/// proxy for that object. The query context is ignored.
///
/// Every non-prelude path in the expansion is absolute (`::zbus`, `::std`, `$crate`) so the
/// macro is not broken by a local item named `zbus` or `std` at the call site.
#[macro_export]
macro_rules! impl_queryable_for_proxy {
	($($T:ident),* $(,)?) => {
		$(impl<Ctx: $crate::dbus::query::QueryContext> $crate::dbus::query::Queryable<Ctx> for $T<'static> {
			async fn try_new(
				connection: &::zbus::Connection,
				_ctx: &::std::sync::Arc<Ctx>,
				object: &$crate::dbus::ObjectInfo,
				contains_interface: &(impl Fn(&::zbus::names::InterfaceName) -> bool + Send + Sync),
			) -> Option<Self> {
				use ::zbus::proxy::Defaults;
				// A proxy without a default interface can never match an object.
				let interface = $T::INTERFACE.as_ref()?;
				if !contains_interface(interface) {
					return None;
				}
				object.to_typed_proxy::<Self>(connection).await.ok()
			}
		})*
	};
}
93
94impl<Q, Ctx> ObjectQuery<Q, Ctx>
95where
96	Ctx: QueryContext,
97	Q: Queryable<Ctx>,
98{
99	pub fn new(object_registry: Arc<ObjectRegistry>, context: impl Into<Arc<Ctx>>) -> Self {
100		let (tx, rx) = mpsc::channel(32);
101		let update_task_handle =
102			tokio::spawn(Self::update_task(context.into(), object_registry, tx)).abort_handle();
103		Self {
104			update_task_handle,
105			event_reader: rx,
106		}
107	}
108	pub async fn recv_event(&mut self) -> Option<QueryEvent<Q, Ctx>> {
109		self.event_reader.recv().await
110	}
111	pub fn try_recv_event(&mut self) -> Result<QueryEvent<Q, Ctx>, TryRecvError> {
112		self.event_reader.try_recv()
113	}
114	pub fn to_list_query<T: Send + Sync + 'static>(
115		self,
116	) -> (ObjectListQuery<T>, ListQueryMapper<Q, T, Ctx>) {
117		ObjectListQuery::from_query(self)
118	}
119}
120
121impl<Q: Queryable<Ctx>, Ctx: QueryContext> Queryable<Ctx> for Option<Q> {
122	async fn try_new(
123		connection: &Connection,
124		ctx: &Arc<Ctx>,
125		object: &ObjectInfo,
126		contains_interface: &(impl Fn(&InterfaceName) -> bool + Send + Sync),
127	) -> Option<Self> {
128		Some(Q::try_new(connection, ctx, object, contains_interface).await)
129	}
130}
131
132macro_rules! impl_queryable {
133    ($($T:ident),*) => {
134        impl<Ctx: QueryContext, $($T: Queryable<Ctx>),*> Queryable<Ctx> for ($($T,)*) {
135			#[allow(unused_variables)]
136			async fn try_new(
137				connection: &Connection,
138				ctx: &Arc<Ctx>,
139				object: &ObjectInfo,
140				contains_interface: &(impl Fn(&zbus::names::InterfaceName) -> bool + Send + Sync),
141			) -> Option<Self> {
142				Some(($($T::try_new(connection, ctx, object, contains_interface).await?,)*))
143			}
144        }
145    };
146}
147
148all_tuples!(impl_queryable, 0, 15, T);
149
150impl<Q: Queryable<Ctx>, Ctx: QueryContext> ObjectQuery<Q, Ctx> {
151	async fn new_match(
152		tx: &mpsc::Sender<QueryEvent<Q, Ctx>>,
153		matching_objects: &mut HashSet<ObjectInfo>,
154		object: ObjectInfo,
155		data: Q,
156	) {
157		matching_objects.insert(object.clone());
158		_ = tx.send(QueryEvent::NewMatch(object, data)).await;
159	}
160	async fn match_lost(
161		tx: &mpsc::Sender<QueryEvent<Q, Ctx>>,
162		object: ObjectInfo,
163		matching_objects: &mut HashSet<ObjectInfo>,
164	) {
165		matching_objects.remove(&object);
166		_ = tx.send(QueryEvent::MatchLost(object)).await;
167	}
168	async fn update_task(
169		ctx: Arc<Ctx>,
170		object_registry: Arc<ObjectRegistry>,
171		tx: mpsc::Sender<QueryEvent<Q, Ctx>>,
172	) -> zbus::Result<()> {
173		let mut recv = object_registry.get_object_events_receiver();
174		let mut matching_objects: HashSet<ObjectInfo> = HashSet::new();
175		let connection = object_registry.get_connection();
176		let watch = object_registry.get_watch();
177		let v = watch.borrow().object_to_interfaces.clone();
178
179		for (object, interfaces) in v {
180			let data = Q::try_new(connection, &ctx, &object, &|i| {
181				interfaces.iter().any(|f| i == f)
182			})
183			.await;
184			let already_matching = matching_objects.contains(&object);
185			match (data, already_matching) {
186				(None, true) => Self::match_lost(&tx, object, &mut matching_objects).await,
187				(None, false) => {}
188				(Some(data), true) => {
189					_ = tx.send(QueryEvent::MatchModified(object, data)).await;
190				}
191				(Some(data), false) => {
192					Self::new_match(&tx, &mut matching_objects, object, data).await
193				}
194			}
195		}
196
197		loop {
198			let object_event = match recv.recv().await {
199				Ok(objs) => objs,
200				Err(RecvError::Closed) => break,
201				Err(RecvError::Lagged(_)) => continue,
202			};
203			let Some(v) = watch
204				.borrow()
205				.object_to_interfaces
206				.get(&object_event.object)
207				.cloned()
208			else {
209				Self::match_lost(&tx, object_event.object, &mut matching_objects).await;
210				continue;
211			};
212			let data = Q::try_new(connection, &ctx, &object_event.object, &|i| {
213				v.iter().any(|j| i == j)
214			})
215			.await;
216
217			let already_matching = matching_objects.contains(&object_event.object);
218			match (data, already_matching) {
219				(None, true) => {
220					Self::match_lost(&tx, object_event.object, &mut matching_objects).await
221				}
222				(None, false) => {}
223				(Some(data), true) => {
224					_ = tx
225						.send(QueryEvent::MatchModified(object_event.object, data))
226						.await;
227				}
228				(Some(data), false) => {
229					Self::new_match(&tx, &mut matching_objects, object_event.object, data).await
230				}
231			}
232		}
233
234		Ok(())
235	}
236}
237
238#[tokio::test]
239async fn query_test() {
240	use crate::dbus::{
241		object_registry::ObjectRegistry,
242		query::{ObjectQuery, QueryContext},
243	};
244	use std::{
245		sync::{
246			Arc,
247			atomic::{AtomicBool, Ordering},
248		},
249		thread,
250		time::Duration,
251	};
252	use tokio::{sync::Notify, time::sleep};
253	use zbus::{Connection, fdo::ObjectManager, interface};
254
255	struct TestInterface;
256	#[interface(name = "org.stardustxr.Query.TestInterface", proxy())]
257	impl TestInterface {
258		fn hello(&self) {
259			println!("hello world");
260		}
261	}
262	impl_queryable_for_proxy!(TestInterfaceProxy);
263
264	tokio::task::spawn(async {
265		tokio::time::sleep(Duration::from_secs(10)).await;
266		panic!("Took too long to run");
267	});
268
269	let service_conn = zbus::conn::Builder::session()
270		.unwrap()
271		.serve_at("/", zbus::fdo::ObjectManager)
272		.unwrap()
273		.serve_at("/org/stardustxr/TestObject", TestInterface)
274		.unwrap()
275		.build()
276		.await
277		.unwrap();
278	println!("name: {:?}", service_conn.unique_name());
279
280	let scan_conn = Connection::session().await.unwrap();
281	let object_registry = ObjectRegistry::new(&scan_conn).await;
282
283	println!(
284		"Objects updated: {:#?}",
285		object_registry.get_watch().borrow().clone()
286	);
287	// let mut watch = object_registry.get_watch();
288	// tokio::task::spawn(async move {
289	// 	while !matches!(watch.changed().await, Ok(())) {
290	// 		println!("Object registry changed: {:#?}", watch.borrow());
291	// 	}
292	// });
293
294	let mut query = ObjectQuery::<TestInterfaceProxy, ()>::new(object_registry, ());
295
296	while let Some(e) = query.recv_event().await {
297		println!("New event: {e:#?}");
298		if let QueryEvent::NewMatch(object_info, p) = e {
299			println!("New match to query, {object_info:#?}");
300			if object_info.object_path.as_str() == "/org/stardustxr/TestObject" {
301				p.hello().await.unwrap();
302				break;
303			}
304		}
305	}
306
307	drop(service_conn);
308	println!("Dropping the other connection");
309	while let Some(e) = query.recv_event().await {
310		println!("New event: {e:#?}");
311		if let QueryEvent::MatchLost(object_info) = e {
312			println!("Dropped match to query, {object_info:#?}");
313			if object_info.object_path.as_str() == "/org/stardustxr/TestObject" {
314				break;
315			}
316		}
317	}
318}