1use crate::dbus::{
2 ObjectInfo,
3 interfaces::SpatialRefProxy,
4 list_query::{ListQueryMapper, ObjectListQuery},
5 object_registry::{ObjectRegistry, Objects},
6};
7use std::{
8 collections::{HashMap, HashSet},
9 marker::PhantomData,
10 ops::{Deref, DerefMut},
11 sync::Arc,
12 time::{Duration, Instant},
13};
14use tokio::{
15 sync::{
16 RwLock, RwLockReadGuard,
17 broadcast::error::RecvError,
18 mpsc::{self, error::TryRecvError},
19 watch,
20 },
21 task::{AbortHandle, JoinSet},
22 time::timeout,
23};
24use variadics_please::all_tuples;
25use zbus::{
26 Connection, Proxy, fdo,
27 names::{BusName, InterfaceName, OwnedBusName},
28 proxy::{Defaults, ProxyImpl},
29 zvariant::{ObjectPath, OwnedObjectPath},
30};
31
32pub struct ObjectQuery<Q: Queryable<Ctx>, Ctx: QueryContext> {
33 update_task_handle: AbortHandle,
34 event_reader: mpsc::Receiver<QueryEvent<Q, Ctx>>,
35}
36
37pub trait Queryable<Ctx: QueryContext>: Sized + 'static + Send + Sync {
38 fn try_new(
39 connection: &Connection,
40 ctx: &Arc<Ctx>,
41 object: &ObjectInfo,
42 contains_interface: &(impl Fn(&InterfaceName) -> bool + Send + Sync),
43 ) -> impl std::future::Future<Output = Option<Self>> + Send;
44}
/// Shared context handed (behind an `Arc`) to every `Queryable::try_new` call of a query.
///
/// Use `()` when a query needs no context.
pub trait QueryContext: Send + Sync + Sized + 'static {}

impl QueryContext for () {}
47
48pub enum QueryEvent<Q: Queryable<Ctx> + Send + Sync, Ctx: QueryContext> {
49 NewMatch(ObjectInfo, Q),
50 MatchModified(ObjectInfo, Q),
51 MatchLost(ObjectInfo),
52 PhantomVariant(PhantomData<Ctx>),
53}
54impl<Q: Queryable<Ctx> + Send + Sync + std::fmt::Debug, Ctx: QueryContext + std::fmt::Debug>
55 std::fmt::Debug for QueryEvent<Q, Ctx>
56{
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 match self {
59 Self::NewMatch(arg0, arg1) => {
60 f.debug_tuple("NewMatch").field(arg0).field(arg1).finish()
61 }
62 Self::MatchModified(arg0, arg1) => f
63 .debug_tuple("MatchModified")
64 .field(arg0)
65 .field(arg1)
66 .finish(),
67 Self::MatchLost(arg0) => f.debug_tuple("MatchLost").field(arg0).finish(),
68 Self::PhantomVariant(arg0) => f.debug_tuple("PhantomVariant").field(arg0).finish(),
69 }
70 }
71}
72
/// Implements [`Queryable`](crate::dbus::query::Queryable) for zbus proxy types.
///
/// Each listed proxy `P` gets `impl<Ctx: QueryContext> Queryable<Ctx> for P<'static>`. An
/// object matches when it exposes the proxy's `INTERFACE`, and the match data is a typed
/// proxy for that object. Proxies without an `INTERFACE` never match, and a proxy that fails
/// to build is treated as "no match". A trailing comma is accepted.
///
/// ```ignore
/// impl_queryable_for_proxy!(FooProxy, BarProxy);
/// ```
#[macro_export]
macro_rules! impl_queryable_for_proxy {
    ($($T:ident),* $(,)?) => {
        $(impl<Ctx: $crate::dbus::query::QueryContext> $crate::dbus::query::Queryable<Ctx> for $T<'static> {
            async fn try_new(
                connection: &::zbus::Connection,
                _ctx: &::std::sync::Arc<Ctx>,
                object: &$crate::dbus::ObjectInfo,
                contains_interface: &(impl Fn(&::zbus::names::InterfaceName) -> bool + Send + Sync),
            ) -> Option<Self> {
                use ::zbus::proxy::Defaults;
                let interface = $T::INTERFACE.as_ref()?;
                // Cheap local check first so non-matching objects never cost a bus round trip.
                if !contains_interface(interface) {
                    return None;
                }
                object.to_typed_proxy::<Self>(connection).await.ok()
            }
        })*
    };
}
93
94impl<Q, Ctx> ObjectQuery<Q, Ctx>
95where
96 Ctx: QueryContext,
97 Q: Queryable<Ctx>,
98{
99 pub fn new(object_registry: Arc<ObjectRegistry>, context: impl Into<Arc<Ctx>>) -> Self {
100 let (tx, rx) = mpsc::channel(32);
101 let update_task_handle =
102 tokio::spawn(Self::update_task(context.into(), object_registry, tx)).abort_handle();
103 Self {
104 update_task_handle,
105 event_reader: rx,
106 }
107 }
108 pub async fn recv_event(&mut self) -> Option<QueryEvent<Q, Ctx>> {
109 self.event_reader.recv().await
110 }
111 pub fn try_recv_event(&mut self) -> Result<QueryEvent<Q, Ctx>, TryRecvError> {
112 self.event_reader.try_recv()
113 }
114 pub fn to_list_query<T: Send + Sync + 'static>(
115 self,
116 ) -> (ObjectListQuery<T>, ListQueryMapper<Q, T, Ctx>) {
117 ObjectListQuery::from_query(self)
118 }
119}
120
121impl<Q: Queryable<Ctx>, Ctx: QueryContext> Queryable<Ctx> for Option<Q> {
122 async fn try_new(
123 connection: &Connection,
124 ctx: &Arc<Ctx>,
125 object: &ObjectInfo,
126 contains_interface: &(impl Fn(&InterfaceName) -> bool + Send + Sync),
127 ) -> Option<Self> {
128 Some(Q::try_new(connection, ctx, object, contains_interface).await)
129 }
130}
131
132macro_rules! impl_queryable {
133 ($($T:ident),*) => {
134 impl<Ctx: QueryContext, $($T: Queryable<Ctx>),*> Queryable<Ctx> for ($($T,)*) {
135 #[allow(unused_variables)]
136 async fn try_new(
137 connection: &Connection,
138 ctx: &Arc<Ctx>,
139 object: &ObjectInfo,
140 contains_interface: &(impl Fn(&zbus::names::InterfaceName) -> bool + Send + Sync),
141 ) -> Option<Self> {
142 Some(($($T::try_new(connection, ctx, object, contains_interface).await?,)*))
143 }
144 }
145 };
146}
147
148all_tuples!(impl_queryable, 0, 15, T);
149
150impl<Q: Queryable<Ctx>, Ctx: QueryContext> ObjectQuery<Q, Ctx> {
151 async fn new_match(
152 tx: &mpsc::Sender<QueryEvent<Q, Ctx>>,
153 matching_objects: &mut HashSet<ObjectInfo>,
154 object: ObjectInfo,
155 data: Q,
156 ) {
157 matching_objects.insert(object.clone());
158 _ = tx.send(QueryEvent::NewMatch(object, data)).await;
159 }
160 async fn match_lost(
161 tx: &mpsc::Sender<QueryEvent<Q, Ctx>>,
162 object: ObjectInfo,
163 matching_objects: &mut HashSet<ObjectInfo>,
164 ) {
165 matching_objects.remove(&object);
166 _ = tx.send(QueryEvent::MatchLost(object)).await;
167 }
168 async fn update_task(
169 ctx: Arc<Ctx>,
170 object_registry: Arc<ObjectRegistry>,
171 tx: mpsc::Sender<QueryEvent<Q, Ctx>>,
172 ) -> zbus::Result<()> {
173 let mut recv = object_registry.get_object_events_receiver();
174 let mut matching_objects: HashSet<ObjectInfo> = HashSet::new();
175 let connection = object_registry.get_connection();
176 let watch = object_registry.get_watch();
177 let v = watch.borrow().object_to_interfaces.clone();
178
179 for (object, interfaces) in v {
180 let data = Q::try_new(connection, &ctx, &object, &|i| {
181 interfaces.iter().any(|f| i == f)
182 })
183 .await;
184 let already_matching = matching_objects.contains(&object);
185 match (data, already_matching) {
186 (None, true) => Self::match_lost(&tx, object, &mut matching_objects).await,
187 (None, false) => {}
188 (Some(data), true) => {
189 _ = tx.send(QueryEvent::MatchModified(object, data)).await;
190 }
191 (Some(data), false) => {
192 Self::new_match(&tx, &mut matching_objects, object, data).await
193 }
194 }
195 }
196
197 loop {
198 let object_event = match recv.recv().await {
199 Ok(objs) => objs,
200 Err(RecvError::Closed) => break,
201 Err(RecvError::Lagged(_)) => continue,
202 };
203 let Some(v) = watch
204 .borrow()
205 .object_to_interfaces
206 .get(&object_event.object)
207 .cloned()
208 else {
209 Self::match_lost(&tx, object_event.object, &mut matching_objects).await;
210 continue;
211 };
212 let data = Q::try_new(connection, &ctx, &object_event.object, &|i| {
213 v.iter().any(|j| i == j)
214 })
215 .await;
216
217 let already_matching = matching_objects.contains(&object_event.object);
218 match (data, already_matching) {
219 (None, true) => {
220 Self::match_lost(&tx, object_event.object, &mut matching_objects).await
221 }
222 (None, false) => {}
223 (Some(data), true) => {
224 _ = tx
225 .send(QueryEvent::MatchModified(object_event.object, data))
226 .await;
227 }
228 (Some(data), false) => {
229 Self::new_match(&tx, &mut matching_objects, object_event.object, data).await
230 }
231 }
232 }
233
234 Ok(())
235 }
236}
237
238#[tokio::test]
239async fn query_test() {
240 use crate::dbus::{
241 object_registry::ObjectRegistry,
242 query::{ObjectQuery, QueryContext},
243 };
244 use std::{
245 sync::{
246 Arc,
247 atomic::{AtomicBool, Ordering},
248 },
249 thread,
250 time::Duration,
251 };
252 use tokio::{sync::Notify, time::sleep};
253 use zbus::{Connection, fdo::ObjectManager, interface};
254
255 struct TestInterface;
256 #[interface(name = "org.stardustxr.Query.TestInterface", proxy())]
257 impl TestInterface {
258 fn hello(&self) {
259 println!("hello world");
260 }
261 }
262 impl_queryable_for_proxy!(TestInterfaceProxy);
263
264 tokio::task::spawn(async {
265 tokio::time::sleep(Duration::from_secs(10)).await;
266 panic!("Took too long to run");
267 });
268
269 let service_conn = zbus::conn::Builder::session()
270 .unwrap()
271 .serve_at("/", zbus::fdo::ObjectManager)
272 .unwrap()
273 .serve_at("/org/stardustxr/TestObject", TestInterface)
274 .unwrap()
275 .build()
276 .await
277 .unwrap();
278 println!("name: {:?}", service_conn.unique_name());
279
280 let scan_conn = Connection::session().await.unwrap();
281 let object_registry = ObjectRegistry::new(&scan_conn).await;
282
283 println!(
284 "Objects updated: {:#?}",
285 object_registry.get_watch().borrow().clone()
286 );
287 let mut query = ObjectQuery::<TestInterfaceProxy, ()>::new(object_registry, ());
295
296 while let Some(e) = query.recv_event().await {
297 println!("New event: {e:#?}");
298 if let QueryEvent::NewMatch(object_info, p) = e {
299 println!("New match to query, {object_info:#?}");
300 if object_info.object_path.as_str() == "/org/stardustxr/TestObject" {
301 p.hello().await.unwrap();
302 break;
303 }
304 }
305 }
306
307 drop(service_conn);
308 println!("Dropping the other connection");
309 while let Some(e) = query.recv_event().await {
310 println!("New event: {e:#?}");
311 if let QueryEvent::MatchLost(object_info) = e {
312 println!("Dropped match to query, {object_info:#?}");
313 if object_info.object_path.as_str() == "/org/stardustxr/TestObject" {
314 break;
315 }
316 }
317 }
318}