1use std::borrow::Borrow;
2use std::collections::HashMap;
3use std::fmt;
4use std::future::Future;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use futures::future;
9use parking_lot::{Mutex, MutexGuard};
10use tokio::task::{JoinError, JoinHandle};
11
12use crate::{Channel, Message, System, TaskId};
13
14#[derive(Clone, Default)]
15pub struct TaskHandles {
16    handles: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
17}
18
19impl TaskHandles {
20    pub fn add(&self, id: TaskId, handle: JoinHandle<()>) {
21        self.handles.lock().insert(id, handle);
22    }
23
24    pub fn remove(&self, id: &TaskId) -> Option<JoinHandle<()>> {
25        self.handles.lock().remove(id)
26    }
27
28    pub fn is_empty(&self) -> bool {
29        self.handles.lock().is_empty()
30    }
31
32    pub fn len(&self) -> usize {
33        self.handles.lock().len()
34    }
35
36    pub async fn join(&self, id: &TaskId) -> Result<(), JoinError> {
37        if let Some(handle) = self.remove(id) {
38            handle.await
39        } else {
40            Ok(())
41        }
42    }
43
44    pub async fn join_all(&self) {
45        let handles: Vec<_> = self.handles.lock().drain().map(|(_, handle)| handle).collect();
46        future::join_all(handles).await;
47    }
48}
49
50pub type DefaultActorId = String;
51pub type DefaultContext = Context<DefaultActorId>;
52
53#[derive(Clone)]
54pub struct Context<ActorId = DefaultActorId> {
55    system: Arc<Mutex<System<ActorId>>>,
56    handles: TaskHandles,
57}
58
59impl<ActorId> Default for Context<ActorId> {
60    fn default() -> Self {
61        Self {
62            system: Default::default(),
63            handles: Default::default(),
64        }
65    }
66}
67
68impl<ActorId> Context<ActorId> {
69    pub fn new() -> Self {
70        Self::default()
71    }
72
73    pub fn from_system(system: System<ActorId>) -> Self {
74        Self {
75            system: Arc::new(Mutex::new(system)),
76            handles: Default::default(),
77        }
78    }
79
80    pub fn spawn<T>(&self, future: T) -> TaskId
81    where
82        T: Future<Output = ()> + Send + 'static,
83    {
84        let task_id = self.system().next_task_id();
85        let handles = self.handles.clone();
86        let handle = tokio::spawn(async move {
87            future.await;
88            handles.remove(&task_id);
89        });
90        self.handles.add(task_id, handle);
91        task_id
92    }
93
94    pub fn extract_channel<M: Message>(&self) -> Option<M::Channel> {
95        self.system().extract_channel::<M>()
96    }
97
98    pub fn get_sender<M: Message>(&self) -> Option<<M::Channel as Channel>::Sender> {
99        self.system().get_channel::<M>().map(|channel| channel.sender())
100    }
101
102    pub fn sender_of_custom_channel<M: Message>(
103        &self,
104        constructor: impl FnOnce() -> M::Channel,
105    ) -> <M::Channel as Channel>::Sender {
106        self.system().sender_of_custom_channel::<M>(constructor)
107    }
108
109    pub fn receiver_of_custom_channel<M: Message>(
110        &self,
111        constructor: impl FnOnce() -> M::Channel,
112    ) -> <M::Channel as Channel>::Receiver {
113        self.system().receiver_of_custom_channel::<M>(constructor)
114    }
115
116    pub fn sender<M: Message>(&self) -> <M::Channel as Channel>::Sender {
117        self.system().sender::<M>()
118    }
119
120    pub fn receiver<M: Message>(&self) -> <M::Channel as Channel>::Receiver {
121        self.system().receiver::<M>()
122    }
123
124    pub fn is_channel_closed<M: Message>(&self) -> Option<bool> {
125        self.system().get_channel::<M>().map(|channel| channel.is_closed())
126    }
127
128    pub fn system(&self) -> MutexGuard<'_, System<ActorId>> {
129        self.system.lock()
130    }
131
132    pub async fn shutdown(&self) {
133        self.system().shutdown();
134        self.join_all().await
135    }
136
137    pub async fn join_all(&self) {
138        self.handles.join_all().await
139    }
140
141    pub async fn join(&self, id: &TaskId) -> Result<(), JoinError> {
142        self.handles.join(id).await
143    }
144
145    pub fn handles(&self) -> &TaskHandles {
146        &self.handles
147    }
148}
149
150impl<ActorId: Eq + Hash> Context<ActorId> {
151    pub fn extract_actor_channel<M: Message>(&self, actor_id: &ActorId) -> Option<M::Channel> {
152        self.system().extract_actor_channel::<M>(actor_id)
153    }
154
155    pub fn get_actor_sender<M: Message>(&self, actor_id: &ActorId) -> Option<<M::Channel as Channel>::Sender> {
156        self.system()
157            .get_actor_channel::<M>(actor_id)
158            .map(|channel| channel.sender())
159    }
160
161    pub fn actor_sender_of_custom_channel<M: Message>(
162        &self,
163        actor_id: ActorId,
164        constructor: impl FnOnce() -> M::Channel,
165    ) -> <M::Channel as Channel>::Sender {
166        self.system().actor_sender_of_custom_channel::<M>(actor_id, constructor)
167    }
168
169    pub fn actor_receiver_of_custom_channel<M: Message>(
170        &self,
171        actor_id: ActorId,
172        constructor: impl FnOnce() -> M::Channel,
173    ) -> <M::Channel as Channel>::Receiver {
174        self.system()
175            .actor_receiver_of_custom_channel::<M>(actor_id, constructor)
176    }
177}
178
179impl<ActorId: Eq + Hash + fmt::Display> Context<ActorId> {
180    pub fn actor_sender<M: Message>(&self, actor_id: impl Into<ActorId>) -> <M::Channel as Channel>::Sender {
181        self.system().actor_sender::<M>(actor_id.into())
182    }
183
184    pub fn actor_receiver<M: Message>(&self, actor_id: impl Into<ActorId>) -> <M::Channel as Channel>::Receiver {
185        self.system().actor_receiver::<M>(actor_id.into())
186    }
187
188    pub fn is_actor_channel_closed<M: Message>(&self, actor_id: impl Borrow<ActorId>) -> Option<bool> {
189        self.system()
190            .get_actor_channel::<M>(actor_id.borrow())
191            .map(|channel| channel.is_closed())
192    }
193}
194
195impl<ActorId> From<System<ActorId>> for Context<ActorId> {
196    fn from(system: System<ActorId>) -> Self {
197        Self::from_system(system)
198    }
199}
200
#[cfg(test)]
mod tests {
    use tokio::sync::mpsc::error::SendError;

    use crate::{Context, Message, MpscChannel};

    /// Minimal message type carrying a static label.
    struct Value(&'static str);

    impl Message for Value {
        type Channel = MpscChannel<Self>;
    }

    /// The common channel and a per-actor channel for the same message type
    /// must be independent of each other.
    #[tokio::test]
    async fn actor_channels() {
        let context = Context::<i32>::new();

        // Common (non-actor) channel round trip.
        let common_tx = context.sender::<Value>();
        let mut common_rx = context.receiver::<Value>();

        common_tx.send(Value("common")).await.ok().unwrap();
        assert_eq!(common_rx.recv().await.unwrap().0, "common");

        // Channel dedicated to actor `1`.
        let actor_tx = context.actor_sender::<Value>(1);
        let mut actor_rx = context.actor_receiver::<Value>(1);

        common_tx.send(Value("common")).await.ok().unwrap();
        actor_tx.send(Value("actor")).await.ok().unwrap();

        assert_eq!(common_rx.recv().await.unwrap().0, "common");
        assert_eq!(actor_rx.recv().await.unwrap().0, "actor");

        // A sender taken from the extracted channel still feeds the receiver
        // obtained above.
        let (extracted_tx, _) = context
            .extract_actor_channel::<Value>(&1)
            .unwrap()
            .into_inner();

        extracted_tx.send(Value("extracted actor")).await.ok().unwrap();

        assert!(common_rx.try_recv().is_err());
        assert_eq!(actor_rx.recv().await.unwrap().0, "extracted actor");

        // After the receiver is dropped, both actor senders hand the rejected
        // message back.
        drop(actor_rx);

        let rejected = actor_tx.send(Value("actor closed")).await;
        assert!(matches!(rejected, Err(SendError(Value("actor closed")))));

        let rejected = extracted_tx.send(Value("actor closed")).await;
        assert!(matches!(rejected, Err(SendError(Value("actor closed")))));

        // The common channel keeps working.
        common_tx.send(Value("common")).await.ok().unwrap();
        assert_eq!(common_rx.recv().await.unwrap().0, "common");
    }

    /// Dropping the receiver must make the actor channel report itself closed.
    #[tokio::test]
    async fn close_actor_channel_by_drop() {
        let context = Context::<i32>::new();
        let actor_tx = context.actor_sender::<Value>(1);
        let mut actor_rx = context.actor_receiver::<Value>(1);

        actor_tx.send(Value("test")).await.ok().unwrap();
        assert_eq!(actor_rx.recv().await.unwrap().0, "test");

        let closed_before = context.is_actor_channel_closed::<Value>(1_i32);
        assert!(!closed_before.unwrap_or(true));

        drop(actor_rx);

        let closed_after = context.is_actor_channel_closed::<Value>(1_i32);
        assert!(closed_after.unwrap_or(true));
    }
}