witty_actors/
registry.rs

1// Copyright (C) 2023 Quickwit, Inc.
2//
3// Quickwit is offered under the AGPL v3.0 and as commercial software.
4// For commercial licensing, contact us at hello@quickwit.io.
5//
6// AGPL:
7// This program is free software: you can redistribute it and/or modify
8// it under the terms of the GNU Affero General Public License as
9// published by the Free Software Foundation, either version 3 of the
10// License, or (at your option) any later version.
11//
12// This program is distributed in the hope that it will be useful,
13// but WITHOUT ANY WARRANTY; without even the implied warranty of
14// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15// GNU Affero General Public License for more details.
16//
17// You should have received a copy of the GNU Affero General Public License
18// along with this program. If not, see <http://www.gnu.org/licenses/>.
19
20use std::any::{Any, TypeId};
21use std::collections::HashMap;
22use std::pin::Pin;
23use std::sync::{Arc, RwLock};
24use std::time::Duration;
25
26use async_trait::async_trait;
27use futures::future::{self, Shared};
28use futures::{Future, FutureExt};
29use serde::Serialize;
30use serde_json::Value as JsonValue;
31use tokio::task::JoinHandle;
32
33use crate::command::Observe;
34use crate::mailbox::WeakMailbox;
35use crate::{Actor, ActorExitStatus, Command, Mailbox};
36
37struct TypedJsonObservable<A: Actor> {
38    actor_instance_id: String,
39    weak_mailbox: WeakMailbox<A>,
40    join_handle: ActorJoinHandle,
41}
42
43#[async_trait]
44trait JsonObservable: Sync + Send {
45    fn is_disconnected(&self) -> bool;
46    fn any(&self) -> &dyn Any;
47    fn actor_instance_id(&self) -> &str;
48    async fn observe(&self) -> Option<JsonValue>;
49    async fn quit(&self) -> ActorExitStatus;
50    async fn join(&self) -> ActorExitStatus;
51}
52
53#[async_trait]
54impl<A: Actor> JsonObservable for TypedJsonObservable<A> {
55    fn is_disconnected(&self) -> bool {
56        self.weak_mailbox
57            .upgrade()
58            .map(|mailbox| mailbox.is_disconnected())
59            .unwrap_or(true)
60    }
61    fn any(&self) -> &dyn Any {
62        &self.weak_mailbox
63    }
64    fn actor_instance_id(&self) -> &str {
65        self.actor_instance_id.as_str()
66    }
67    async fn observe(&self) -> Option<JsonValue> {
68        let mailbox = self.weak_mailbox.upgrade()?;
69        let oneshot_rx = mailbox.send_message_with_high_priority(Observe).ok()?;
70        let state: <A as Actor>::ObservableState = oneshot_rx.await.ok()?;
71        serde_json::to_value(&state).ok()
72    }
73
74    async fn quit(&self) -> ActorExitStatus {
75        if let Some(mailbox) = self.weak_mailbox.upgrade() {
76            let _ = mailbox.send_message_with_high_priority(Command::Quit);
77        }
78        self.join().await
79    }
80
81    async fn join(&self) -> ActorExitStatus {
82        self.join_handle.join().await
83    }
84}
85
86#[derive(Default, Clone)]
87pub(crate) struct ActorRegistry {
88    actors: Arc<RwLock<HashMap<TypeId, ActorRegistryForSpecificType>>>,
89}
90
91struct ActorRegistryForSpecificType {
92    type_name: &'static str,
93    observables: Vec<Arc<dyn JsonObservable>>,
94}
95
96impl ActorRegistryForSpecificType {
97    fn for_type<A>() -> ActorRegistryForSpecificType {
98        ActorRegistryForSpecificType {
99            type_name: std::any::type_name::<A>(),
100            observables: Vec::new(),
101        }
102    }
103
104    fn gc(&mut self) {
105        let mut i = 0;
106        while i < self.observables.len() {
107            if self.observables[i].is_disconnected() {
108                self.observables.swap_remove(i);
109            } else {
110                i += 1;
111            }
112        }
113    }
114}
115
116#[derive(Serialize, Debug)]
117pub struct ActorObservation {
118    pub type_name: &'static str,
119    pub instance_id: String,
120    pub obs: Option<JsonValue>,
121}
122
123impl ActorRegistry {
124    pub fn register<A: Actor>(&self, mailbox: &Mailbox<A>, join_handle: ActorJoinHandle) {
125        let typed_id = TypeId::of::<A>();
126        let actor_instance_id = mailbox.actor_instance_id().to_string();
127        let weak_mailbox = mailbox.downgrade();
128        self.actors
129            .write()
130            .unwrap()
131            .entry(typed_id)
132            .or_insert_with(|| ActorRegistryForSpecificType::for_type::<A>())
133            .observables
134            .push(Arc::new(TypedJsonObservable {
135                weak_mailbox,
136                actor_instance_id,
137                join_handle,
138            }));
139    }
140
141    pub async fn observe(&self, timeout: Duration) -> Vec<ActorObservation> {
142        self.gc();
143        let mut obs_futures = Vec::new();
144        for registry_for_type in self.actors.read().unwrap().values() {
145            for obs in &registry_for_type.observables {
146                if obs.is_disconnected() {
147                    continue;
148                }
149                let obs_clone = obs.clone();
150                let type_name = registry_for_type.type_name;
151                let instance_id = obs.actor_instance_id().to_string();
152                obs_futures.push(async move {
153                    let obs = tokio::time::timeout(timeout, obs_clone.observe())
154                        .await
155                        .unwrap_or(None);
156                    ActorObservation {
157                        type_name,
158                        instance_id,
159                        obs,
160                    }
161                });
162            }
163        }
164        future::join_all(obs_futures.into_iter()).await
165    }
166
167    pub fn get<A: Actor>(&self) -> Vec<Mailbox<A>> {
168        let mut lock = self.actors.write().unwrap();
169        get_iter::<A>(&mut lock).collect()
170    }
171
172    pub fn get_one<A: Actor>(&self) -> Option<Mailbox<A>> {
173        let mut lock = self.actors.write().unwrap();
174        let opt = get_iter::<A>(&mut lock).next();
175        opt
176    }
177
178    fn gc(&self) {
179        for registry_for_type in self.actors.write().unwrap().values_mut() {
180            registry_for_type.gc();
181        }
182    }
183
184    pub async fn quit(&self) -> HashMap<String, ActorExitStatus> {
185        let mut obs_futures = Vec::new();
186        let mut actor_ids = Vec::new();
187        for registry_for_type in self.actors.read().unwrap().values() {
188            for obs in &registry_for_type.observables {
189                let obs_clone = obs.clone();
190                obs_futures.push(async move { obs_clone.quit().await });
191                actor_ids.push(obs.actor_instance_id().to_string());
192            }
193        }
194        let res = future::join_all(obs_futures).await;
195        actor_ids.into_iter().zip(res).collect()
196    }
197
198    pub fn is_empty(&self) -> bool {
199        self.actors
200            .read()
201            .unwrap()
202            .values()
203            .all(|registry_for_type| {
204                registry_for_type
205                    .observables
206                    .iter()
207                    .all(|obs| obs.is_disconnected())
208            })
209    }
210}
211
212fn get_iter<A: Actor>(
213    actors: &mut HashMap<TypeId, ActorRegistryForSpecificType>,
214) -> impl Iterator<Item = Mailbox<A>> + '_ {
215    let typed_id = TypeId::of::<A>();
216    actors
217        .get(&typed_id)
218        .into_iter()
219        .flat_map(|registry_for_type| {
220            registry_for_type
221                .observables
222                .iter()
223                .flat_map(|box_any| box_any.any().downcast_ref::<WeakMailbox<A>>())
224                .flat_map(|weak_mailbox| weak_mailbox.upgrade())
225        })
226        .filter(|mailbox| !mailbox.is_disconnected())
227}
228
229/// This structure contains an optional exit handle. The handle is present
230/// until the join() method is called.
231#[derive(Clone)]
232pub(crate) struct ActorJoinHandle {
233    holder: Shared<Pin<Box<dyn Future<Output = ActorExitStatus> + Send>>>,
234}
235
236impl ActorJoinHandle {
237    pub(crate) fn new(join_handle: JoinHandle<ActorExitStatus>) -> Self {
238        ActorJoinHandle {
239            holder: Self::inner_join(join_handle).boxed().shared(),
240        }
241    }
242
243    async fn inner_join(join_handle: JoinHandle<ActorExitStatus>) -> ActorExitStatus {
244        join_handle.await.unwrap_or_else(|join_err| {
245            if join_err.is_panic() {
246                ActorExitStatus::Panicked
247            } else {
248                ActorExitStatus::Killed
249            }
250        })
251    }
252
253    /// Joins the actor and returns its exit status on the first invocation.
254    /// Returns None afterwards.
255    pub(crate) async fn join(&self) -> ActorExitStatus {
256        self.holder.clone().await
257    }
258}
259
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::tests::PingReceiverActor;
    use crate::Universe;

    /// A freshly spawned actor can be looked up by type.
    #[tokio::test]
    async fn test_registry() {
        let universe = Universe::with_accelerated_time();
        let (_ping_mailbox, _ping_handle) = universe
            .spawn_builder()
            .spawn(PingReceiverActor::default());
        // Keep the looked-up mailbox alive until the universe quits.
        let _registered_mailbox = universe.get_one::<PingReceiverActor>().unwrap();
        universe.assert_quit().await;
    }

    /// A killed actor is no longer returned by the registry.
    #[tokio::test]
    async fn test_registry_killed_actor() {
        let universe = Universe::with_accelerated_time();
        let (_ping_mailbox, ping_handle) = universe
            .spawn_builder()
            .spawn(PingReceiverActor::default());
        ping_handle.kill().await;
        let lookup = universe.get_one::<PingReceiverActor>();
        assert!(lookup.is_none());
    }

    /// After its only external mailbox is dropped and the actor is joined,
    /// the registry no longer returns it.
    #[tokio::test]
    async fn test_registry_last_mailbox_dropped_actor() {
        let universe = Universe::with_accelerated_time();
        let (ping_mailbox, ping_handle) = universe
            .spawn_builder()
            .spawn(PingReceiverActor::default());
        drop(ping_mailbox);
        ping_handle.join().await;
        let lookup = universe.get_one::<PingReceiverActor>();
        assert!(lookup.is_none());
    }

    /// With a single actor spawned, observing the universe yields exactly one observation.
    #[tokio::test]
    async fn test_get_actor_states() {
        let universe = Universe::with_accelerated_time();
        let (_ping_mailbox, _ping_handle) = universe
            .spawn_builder()
            .spawn(PingReceiverActor::default());
        let observations = universe.observe(Duration::from_secs(1)).await;
        assert_eq!(observations.len(), 1);
        universe.assert_quit().await;
    }
}