1use crate::error::ActorError;
2use crate::message::Message;
3use core::pin::pin;
4use futures::future::{self, FutureExt as _};
5use spawned_rt::{
6 tasks::{self as rt, mpsc, oneshot, timeout, watch, CancellationToken, JoinHandle},
7 threads,
8};
9use std::{
10 fmt::Debug, future::Future, panic::AssertUnwindSafe, pin::Pin, sync::Arc, time::Duration,
11};
12
13pub use crate::response::DEFAULT_REQUEST_TIMEOUT;
14
/// Execution strategy used to drive an actor's event loop.
///
/// Chosen with [`ActorStart::start_with_backend`]; [`ActorStart::start`] uses
/// the default, [`Backend::Async`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum Backend {
    /// Run the event loop as a task on the async runtime (`rt::spawn`).
    /// In debug builds the task is wrapped so that any single poll taking
    /// longer than 10ms is logged as a blocking operation.
    #[default]
    Async,
    /// Run the event loop on the runtime's blocking pool (`rt::spawn_blocking`),
    /// driving the future with `rt::block_on`.
    Blocking,
    /// Run the event loop via `threads::spawn`, driving the future with
    /// `threads::block_on`.
    Thread,
}
31
32pub trait Actor: Send + Sized + 'static {
43 fn started(&mut self, _ctx: &Context<Self>) -> impl Future<Output = ()> + Send {
44 async {}
45 }
46
47 fn stopped(&mut self, _ctx: &Context<Self>) -> impl Future<Output = ()> + Send {
48 async {}
49 }
50}
51
52pub trait Handler<M: Message>: Actor {
61 fn handle(&mut self, msg: M, ctx: &Context<Self>) -> impl Future<Output = M::Result> + Send;
62}
63
64trait Envelope<A: Actor>: Send {
69 fn handle<'a>(
70 self: Box<Self>,
71 actor: &'a mut A,
72 ctx: &'a Context<A>,
73 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
74}
75
76struct MessageEnvelope<M: Message> {
77 msg: M,
78 tx: Option<oneshot::Sender<M::Result>>,
79}
80
81impl<A, M> Envelope<A> for MessageEnvelope<M>
82where
83 A: Actor + Handler<M>,
84 M: Message,
85{
86 fn handle<'a>(
87 self: Box<Self>,
88 actor: &'a mut A,
89 ctx: &'a Context<A>,
90 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
91 Box::pin(async move {
92 let result = actor.handle(self.msg, ctx).await;
93 if let Some(tx) = self.tx {
94 let _ = tx.send(result);
95 }
96 })
97 }
98}
99
100pub struct Context<A: Actor> {
109 sender: mpsc::Sender<Box<dyn Envelope<A> + Send>>,
110 cancellation_token: CancellationToken,
111 completion_rx: watch::Receiver<bool>,
112}
113
114impl<A: Actor> Clone for Context<A> {
115 fn clone(&self) -> Self {
116 Self {
117 sender: self.sender.clone(),
118 cancellation_token: self.cancellation_token.clone(),
119 completion_rx: self.completion_rx.clone(),
120 }
121 }
122}
123
124impl<A: Actor> Debug for Context<A> {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 f.debug_struct("Context").finish_non_exhaustive()
127 }
128}
129
130impl<A: Actor> Context<A> {
131 pub fn from_ref(actor_ref: &ActorRef<A>) -> Self {
134 Self {
135 sender: actor_ref.sender.clone(),
136 cancellation_token: actor_ref.cancellation_token.clone(),
137 completion_rx: actor_ref.completion_rx.clone(),
138 }
139 }
140
141 pub fn stop(&self) {
144 self.cancellation_token.cancel();
145 }
146
147 pub fn send<M>(&self, msg: M) -> Result<(), ActorError>
149 where
150 A: Handler<M>,
151 M: Message,
152 {
153 let envelope = MessageEnvelope { msg, tx: None };
154 self.sender
155 .send(Box::new(envelope))
156 .map_err(|_| ActorError::ActorStopped)
157 }
158
159 pub fn request_raw<M>(&self, msg: M) -> Result<oneshot::Receiver<M::Result>, ActorError>
161 where
162 A: Handler<M>,
163 M: Message,
164 {
165 let (tx, rx) = oneshot::channel();
166 let envelope = MessageEnvelope { msg, tx: Some(tx) };
167 self.sender
168 .send(Box::new(envelope))
169 .map_err(|_| ActorError::ActorStopped)?;
170 Ok(rx)
171 }
172
173 pub async fn request<M>(&self, msg: M) -> Result<M::Result, ActorError>
175 where
176 A: Handler<M>,
177 M: Message,
178 {
179 self.request_with_timeout(msg, DEFAULT_REQUEST_TIMEOUT)
180 .await
181 }
182
183 pub async fn request_with_timeout<M>(
185 &self,
186 msg: M,
187 duration: Duration,
188 ) -> Result<M::Result, ActorError>
189 where
190 A: Handler<M>,
191 M: Message,
192 {
193 let rx = self.request_raw(msg)?;
194 match timeout(duration, rx).await {
195 Ok(Ok(result)) => Ok(result),
196 Ok(Err(_)) => Err(ActorError::ActorStopped),
197 Err(_) => Err(ActorError::RequestTimeout),
198 }
199 }
200
201 pub fn recipient<M>(&self) -> Recipient<M>
204 where
205 A: Handler<M>,
206 M: Message,
207 {
208 Arc::new(self.clone())
209 }
210
211 pub fn actor_ref(&self) -> ActorRef<A> {
213 ActorRef {
214 sender: self.sender.clone(),
215 cancellation_token: self.cancellation_token.clone(),
216 completion_rx: self.completion_rx.clone(),
217 }
218 }
219
220 pub(crate) fn cancellation_token(&self) -> CancellationToken {
221 self.cancellation_token.clone()
222 }
223}
224
225impl<A, M> Receiver<M> for Context<A>
227where
228 A: Actor + Handler<M>,
229 M: Message,
230{
231 fn send(&self, msg: M) -> Result<(), ActorError> {
232 Context::send(self, msg)
233 }
234
235 fn request_raw(&self, msg: M) -> Result<oneshot::Receiver<M::Result>, ActorError> {
236 Context::request_raw(self, msg)
237 }
238}
239
240pub trait Receiver<M: Message>: Send + Sync {
249 fn send(&self, msg: M) -> Result<(), ActorError>;
250 fn request_raw(&self, msg: M) -> Result<oneshot::Receiver<M::Result>, ActorError>;
251}
252
/// Shared, type-erased handle to an actor that handles `M`.
///
/// Obtained from `ActorRef::recipient` or `Context::recipient`.
pub type Recipient<M> = Arc<dyn Receiver<M>>;
257
258pub async fn request<M: Message>(
260 recipient: &dyn Receiver<M>,
261 msg: M,
262 timeout_duration: Duration,
263) -> Result<M::Result, ActorError> {
264 let rx = recipient.request_raw(msg)?;
265 match timeout(timeout_duration, rx).await {
266 Ok(Ok(result)) => Ok(result),
267 Ok(Err(_)) => Err(ActorError::ActorStopped),
268 Err(_) => Err(ActorError::RequestTimeout),
269 }
270}
271
272pub struct ActorRef<A: Actor> {
282 sender: mpsc::Sender<Box<dyn Envelope<A> + Send>>,
283 cancellation_token: CancellationToken,
284 completion_rx: watch::Receiver<bool>,
285}
286
287impl<A: Actor> Debug for ActorRef<A> {
288 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289 f.debug_struct("ActorRef").finish_non_exhaustive()
290 }
291}
292
293impl<A: Actor> Clone for ActorRef<A> {
294 fn clone(&self) -> Self {
295 Self {
296 sender: self.sender.clone(),
297 cancellation_token: self.cancellation_token.clone(),
298 completion_rx: self.completion_rx.clone(),
299 }
300 }
301}
302
303impl<A: Actor> ActorRef<A> {
304 pub fn send<M>(&self, msg: M) -> Result<(), ActorError>
306 where
307 A: Handler<M>,
308 M: Message,
309 {
310 let envelope = MessageEnvelope { msg, tx: None };
311 self.sender
312 .send(Box::new(envelope))
313 .map_err(|_| ActorError::ActorStopped)
314 }
315
316 pub fn request_raw<M>(&self, msg: M) -> Result<oneshot::Receiver<M::Result>, ActorError>
318 where
319 A: Handler<M>,
320 M: Message,
321 {
322 let (tx, rx) = oneshot::channel();
323 let envelope = MessageEnvelope { msg, tx: Some(tx) };
324 self.sender
325 .send(Box::new(envelope))
326 .map_err(|_| ActorError::ActorStopped)?;
327 Ok(rx)
328 }
329
330 pub async fn request<M>(&self, msg: M) -> Result<M::Result, ActorError>
332 where
333 A: Handler<M>,
334 M: Message,
335 {
336 self.request_with_timeout(msg, DEFAULT_REQUEST_TIMEOUT)
337 .await
338 }
339
340 pub async fn request_with_timeout<M>(
342 &self,
343 msg: M,
344 duration: Duration,
345 ) -> Result<M::Result, ActorError>
346 where
347 A: Handler<M>,
348 M: Message,
349 {
350 let rx = self.request_raw(msg)?;
351 match timeout(duration, rx).await {
352 Ok(Ok(result)) => Ok(result),
353 Ok(Err(_)) => Err(ActorError::ActorStopped),
354 Err(_) => Err(ActorError::RequestTimeout),
355 }
356 }
357
358 pub fn recipient<M>(&self) -> Recipient<M>
360 where
361 A: Handler<M>,
362 M: Message,
363 {
364 Arc::new(self.clone())
365 }
366
367 pub fn context(&self) -> Context<A> {
369 Context::from_ref(self)
370 }
371
372 pub async fn join(&self) {
374 let mut rx = self.completion_rx.clone();
375 while !*rx.borrow_and_update() {
376 if rx.changed().await.is_err() {
377 break;
378 }
379 }
380 }
381}
382
383impl<A, M> Receiver<M> for ActorRef<A>
385where
386 A: Actor + Handler<M>,
387 M: Message,
388{
389 fn send(&self, msg: M) -> Result<(), ActorError> {
390 ActorRef::send(self, msg)
391 }
392
393 fn request_raw(&self, msg: M) -> Result<oneshot::Receiver<M::Result>, ActorError> {
394 ActorRef::request_raw(self, msg)
395 }
396}
397
398impl<A: Actor> ActorRef<A> {
403 fn spawn(actor: A, backend: Backend) -> Self {
404 let (tx, rx) = mpsc::channel::<Box<dyn Envelope<A> + Send>>();
405 let cancellation_token = CancellationToken::new();
406 let (completion_tx, completion_rx) = watch::channel(false);
407
408 let actor_ref = ActorRef {
409 sender: tx.clone(),
410 cancellation_token: cancellation_token.clone(),
411 completion_rx,
412 };
413
414 let ctx = Context {
415 sender: tx,
416 cancellation_token: cancellation_token.clone(),
417 completion_rx: actor_ref.completion_rx.clone(),
418 };
419
420 let inner_future = async move {
421 run_actor(actor, ctx, rx, cancellation_token).await;
422 let _ = completion_tx.send(true);
423 };
424
425 match backend {
426 Backend::Async => {
427 #[cfg(debug_assertions)]
428 let inner_future = warn_on_block::WarnOnBlocking::new(inner_future);
429 let _handle = rt::spawn(inner_future);
430 }
431 Backend::Blocking => {
432 let _handle = rt::spawn_blocking(move || rt::block_on(inner_future));
433 }
434 Backend::Thread => {
435 let _handle = threads::spawn(move || threads::block_on(inner_future));
436 }
437 }
438
439 actor_ref
440 }
441}
442
443async fn run_actor<A: Actor>(
444 mut actor: A,
445 ctx: Context<A>,
446 mut rx: mpsc::Receiver<Box<dyn Envelope<A> + Send>>,
447 cancellation_token: CancellationToken,
448) {
449 let start_result = AssertUnwindSafe(actor.started(&ctx)).catch_unwind().await;
450 if let Err(panic) = start_result {
451 tracing::error!("Panic in started() callback: {panic:?}");
452 cancellation_token.cancel();
453 return;
454 }
455
456 if cancellation_token.is_cancelled() {
457 let _ = AssertUnwindSafe(actor.stopped(&ctx)).catch_unwind().await;
458 return;
459 }
460
461 loop {
462 let msg = {
463 let recv = pin!(rx.recv());
464 let cancel = pin!(cancellation_token.cancelled());
465 match future::select(recv, cancel).await {
466 future::Either::Left((msg, _)) => msg,
467 future::Either::Right(_) => None,
468 }
469 };
470 match msg {
471 Some(envelope) => {
472 let result = AssertUnwindSafe(envelope.handle(&mut actor, &ctx))
473 .catch_unwind()
474 .await;
475 if let Err(panic) = result {
476 tracing::error!("Panic in message handler: {panic:?}");
477 break;
478 }
479 if cancellation_token.is_cancelled() {
480 break;
481 }
482 }
483 None => break,
484 }
485 }
486
487 cancellation_token.cancel();
488 let stop_result = AssertUnwindSafe(actor.stopped(&ctx)).catch_unwind().await;
489 if let Err(panic) = stop_result {
490 tracing::error!("Panic in stopped() callback: {panic:?}");
491 }
492}
493
494pub trait ActorStart: Actor {
500 fn start(self) -> ActorRef<Self> {
502 self.start_with_backend(Backend::default())
503 }
504
505 fn start_with_backend(self, backend: Backend) -> ActorRef<Self> {
507 ActorRef::spawn(self, backend)
508 }
509}
510
511impl<A: Actor> ActorStart for A {}
512
513pub fn send_message_on<A, M, U>(ctx: Context<A>, future: U, msg: M) -> JoinHandle<()>
522where
523 A: Actor + Handler<M>,
524 M: Message,
525 U: Future + Send + 'static,
526 <U as Future>::Output: Send,
527{
528 let cancellation_token = ctx.cancellation_token();
529 let join_handle = rt::spawn(async move {
530 let is_cancelled = pin!(cancellation_token.cancelled());
531 let signal = pin!(future);
532 match future::select(is_cancelled, signal).await {
533 future::Either::Left(_) => tracing::debug!("Actor stopped"),
534 future::Either::Right(_) => {
535 if let Err(e) = ctx.send(msg) {
536 tracing::error!("Failed to send message: {e:?}")
537 }
538 }
539 }
540 });
541 join_handle
542}
543
544#[cfg(debug_assertions)]
549mod warn_on_block {
550 use super::*;
551 use std::time::Instant;
552 use tracing::warn;
553
554 pin_project_lite::pin_project! {
555 pub struct WarnOnBlocking<F: Future>{
556 #[pin]
557 inner: F
558 }
559 }
560
561 impl<F: Future> WarnOnBlocking<F> {
562 pub fn new(inner: F) -> Self {
563 Self { inner }
564 }
565 }
566
567 impl<F: Future> Future for WarnOnBlocking<F> {
568 type Output = F::Output;
569
570 fn poll(
571 self: std::pin::Pin<&mut Self>,
572 cx: &mut std::task::Context<'_>,
573 ) -> std::task::Poll<Self::Output> {
574 let type_id = std::any::type_name::<F>();
575 let task_id = rt::task_id();
576 let this = self.project();
577 let now = Instant::now();
578 let res = this.inner.poll(cx);
579 let elapsed = now.elapsed();
580 if elapsed > Duration::from_millis(10) {
581 warn!(task = ?task_id, future = ?type_id, elapsed = ?elapsed, "Blocking operation detected");
582 }
583 res
584 }
585 }
586}
587
// Unit tests for the actor runtime. They cover:
// - the `Backend` enum;
// - send/request on each backend;
// - request timeouts and type-erased recipients;
// - `join()`;
// - panics in lifecycle hooks and handlers;
// - `send_message_on`.
//
// NOTE(review): many tests rely on wall-clock sleeps (`rt::sleep`,
// `thread::sleep`) and exact counts, so they are timing-sensitive.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::Message;
    use std::{
        sync::{atomic, Arc},
        thread,
        time::Duration,
    };

    // Minimal counting actor used by most tests.
    struct Counter {
        count: u64,
    }

    // Replies with the current count.
    struct GetCount;
    impl Message for GetCount {
        type Result = u64;
    }

    // Increments the count and replies with the new value.
    struct Increment;
    impl Message for Increment {
        type Result = u64;
    }

    // Calls `ctx.stop()` and replies with the final count.
    struct StopCounter;
    impl Message for StopCounter {
        type Result = u64;
    }

    impl Actor for Counter {}

    impl Handler<GetCount> for Counter {
        async fn handle(&mut self, _msg: GetCount, _ctx: &Context<Self>) -> u64 {
            self.count
        }
    }

    impl Handler<Increment> for Counter {
        async fn handle(&mut self, _msg: Increment, _ctx: &Context<Self>) -> u64 {
            self.count += 1;
            self.count
        }
    }

    impl Handler<StopCounter> for Counter {
        async fn handle(&mut self, _msg: StopCounter, ctx: &Context<Self>) -> u64 {
            ctx.stop();
            self.count
        }
    }

    #[test]
    pub fn backend_default_is_async() {
        assert_eq!(Backend::default(), Backend::Async);
    }

    #[test]
    // `.clone()` on a `Copy` type is deliberate: the test checks that `Clone`
    // is implemented alongside `Copy`.
    #[allow(clippy::clone_on_copy)]
    pub fn backend_enum_is_copy_and_clone() {
        let backend = Backend::Async;
        let copied = backend;
        let cloned = backend.clone();
        assert_eq!(backend, copied);
        assert_eq!(backend, cloned);
    }

    #[test]
    pub fn backend_enum_debug_format() {
        assert_eq!(format!("{:?}", Backend::Async), "Async");
        assert_eq!(format!("{:?}", Backend::Blocking), "Blocking");
        assert_eq!(format!("{:?}", Backend::Thread), "Thread");
    }

    #[test]
    pub fn backend_enum_equality() {
        assert_eq!(Backend::Async, Backend::Async);
        assert_eq!(Backend::Blocking, Backend::Blocking);
        assert_eq!(Backend::Thread, Backend::Thread);
        assert_ne!(Backend::Async, Backend::Blocking);
        assert_ne!(Backend::Async, Backend::Thread);
        assert_ne!(Backend::Blocking, Backend::Thread);
    }

    // Request, fire-and-forget send, then stop, on the default (async) backend.
    #[test]
    pub fn backend_async_handles_send_and_request() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let counter = Counter { count: 0 }.start();

            let result = counter.request(GetCount).await.unwrap();
            assert_eq!(result, 0);

            let result = counter.request(Increment).await.unwrap();
            assert_eq!(result, 1);

            counter.send(Increment).unwrap();
            rt::sleep(Duration::from_millis(10)).await;

            let result = counter.request(GetCount).await.unwrap();
            assert_eq!(result, 2);

            let final_count = counter.request(StopCounter).await.unwrap();
            assert_eq!(final_count, 2);
        });
    }

    // Same scenario as above on `Backend::Blocking`.
    #[test]
    pub fn backend_blocking_handles_send_and_request() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let counter = Counter { count: 0 }.start_with_backend(Backend::Blocking);

            let result = counter.request(GetCount).await.unwrap();
            assert_eq!(result, 0);

            let result = counter.request(Increment).await.unwrap();
            assert_eq!(result, 1);

            counter.send(Increment).unwrap();
            rt::sleep(Duration::from_millis(50)).await;

            let result = counter.request(GetCount).await.unwrap();
            assert_eq!(result, 2);

            let final_count = counter.request(StopCounter).await.unwrap();
            assert_eq!(final_count, 2);
        });
    }

    // Same scenario as above on `Backend::Thread`.
    #[test]
    pub fn backend_thread_handles_send_and_request() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let counter = Counter { count: 0 }.start_with_backend(Backend::Thread);

            let result = counter.request(GetCount).await.unwrap();
            assert_eq!(result, 0);

            let result = counter.request(Increment).await.unwrap();
            assert_eq!(result, 1);

            counter.send(Increment).unwrap();
            rt::sleep(Duration::from_millis(50)).await;

            let result = counter.request(GetCount).await.unwrap();
            assert_eq!(result, 2);

            let final_count = counter.request(StopCounter).await.unwrap();
            assert_eq!(final_count, 2);
        });
    }

    // Actors on all three backends coexist and keep independent state.
    #[test]
    pub fn multiple_backends_concurrent() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let async_counter = Counter { count: 0 }.start();
            let blocking_counter = Counter { count: 100 }.start_with_backend(Backend::Blocking);
            let thread_counter = Counter { count: 200 }.start_with_backend(Backend::Thread);

            async_counter.request(Increment).await.unwrap();
            blocking_counter.request(Increment).await.unwrap();
            thread_counter.request(Increment).await.unwrap();

            let async_val = async_counter.request(GetCount).await.unwrap();
            let blocking_val = blocking_counter.request(GetCount).await.unwrap();
            let thread_val = thread_counter.request(GetCount).await.unwrap();

            assert_eq!(async_val, 1);
            assert_eq!(blocking_val, 101);
            assert_eq!(thread_val, 201);

            async_counter.request(StopCounter).await.unwrap();
            blocking_counter.request(StopCounter).await.unwrap();
            thread_counter.request(StopCounter).await.unwrap();
        });
    }

    // A 200ms handler with a 50ms request timeout yields `RequestTimeout`.
    #[test]
    pub fn request_timeout() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            struct SlowActor;
            struct SlowOp;
            impl Message for SlowOp {
                type Result = ();
            }
            impl Actor for SlowActor {}
            impl Handler<SlowOp> for SlowActor {
                async fn handle(&mut self, _msg: SlowOp, _ctx: &Context<Self>) {
                    rt::sleep(Duration::from_millis(200)).await;
                }
            }

            let actor = SlowActor.start();
            let result = actor
                .request_with_timeout(SlowOp, Duration::from_millis(50))
                .await;
            assert!(matches!(result, Err(ActorError::RequestTimeout)));
        });
    }

    // A `Recipient` works both through `request_raw` and the free `request` fn.
    #[test]
    pub fn recipient_type_erasure() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let counter = Counter { count: 42 }.start();
            let recipient: Recipient<GetCount> = counter.recipient();

            let rx = recipient.request_raw(GetCount).unwrap();
            let result = rx.await.unwrap();
            assert_eq!(result, 42);

            let result = request(&*recipient, GetCount, Duration::from_secs(5))
                .await
                .unwrap();
            assert_eq!(result, 42);
        });
    }

    // Actor whose `stopped()` blocks its thread for 500ms (`thread::sleep`).
    struct SlowShutdownActor;

    struct StopSlow;
    impl Message for StopSlow {
        type Result = ();
    }

    impl Actor for SlowShutdownActor {
        async fn stopped(&mut self, _ctx: &Context<Self>) {
            thread::sleep(Duration::from_millis(500));
        }
    }

    impl Handler<StopSlow> for SlowShutdownActor {
        async fn handle(&mut self, _msg: StopSlow, ctx: &Context<Self>) {
            ctx.stop();
        }
    }

    // Uses a single-threaded runtime, so the ticker task can only advance
    // while `join()` is suspended. A low tick count therefore means `join()`
    // blocked the runtime while the Thread-backend actor ran its 500ms
    // `stopped()`.
    #[test]
    pub fn thread_backend_join_does_not_block_runtime() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        runtime.block_on(async move {
            let slow_actor = SlowShutdownActor.start_with_backend(Backend::Thread);

            let tick_count = Arc::new(atomic::AtomicU64::new(0));
            let tick_count_clone = tick_count.clone();
            let _ticker = rt::spawn(async move {
                for _ in 0..20 {
                    rt::sleep(Duration::from_millis(50)).await;
                    tick_count_clone.fetch_add(1, atomic::Ordering::SeqCst);
                }
            });

            slow_actor.send(StopSlow).unwrap();
            rt::sleep(Duration::from_millis(10)).await;

            slow_actor.join().await;

            let count_after_join = tick_count.load(atomic::Ordering::SeqCst);
            assert!(
                count_after_join >= 8,
                "Ticker should have completed ~10 ticks during the 500ms join(), but only got {count_after_join}. \
                 This suggests join() blocked the runtime."
            );
        });
    }

    // Two concurrent `join()` callers both return once the actor finishes. A
    // further `join()` after completion returns as well.
    #[test]
    pub fn multiple_join_callers_all_notified() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let actor = SlowShutdownActor.start();
            let actor_clone1 = actor.clone();
            let actor_clone2 = actor.clone();

            let join1 = rt::spawn(async move {
                actor_clone1.join().await;
                1u32
            });
            let join2 = rt::spawn(async move {
                actor_clone2.join().await;
                2u32
            });

            rt::sleep(Duration::from_millis(10)).await;

            actor.send(StopSlow).unwrap();

            let (r1, r2) = tokio::join!(join1, join2);
            assert_eq!(r1.unwrap(), 1);
            assert_eq!(r2.unwrap(), 2);

            actor.join().await;
        });
    }

    // Actor whose handler blocks its thread for 2s (`thread::sleep`) before
    // stopping itself.
    struct BadlyBehavedTask;

    struct DoBlock;
    impl Message for DoBlock {
        type Result = ();
    }

    impl Actor for BadlyBehavedTask {}

    impl Handler<DoBlock> for BadlyBehavedTask {
        async fn handle(&mut self, _msg: DoBlock, ctx: &Context<Self>) {
            rt::sleep(Duration::from_millis(20)).await;
            thread::sleep(Duration::from_secs(2));
            ctx.stop();
        }
    }

    // Increments the count and re-sends itself after 100ms via `send_after`.
    struct IncrementWell;
    impl Message for IncrementWell {
        type Result = ();
    }

    // Actor that ticks every 100ms; used to detect runtime starvation.
    struct WellBehavedTask {
        pub count: u64,
    }

    impl Actor for WellBehavedTask {}

    impl Handler<GetCount> for WellBehavedTask {
        async fn handle(&mut self, _msg: GetCount, _ctx: &Context<Self>) -> u64 {
            self.count
        }
    }

    impl Handler<StopCounter> for WellBehavedTask {
        async fn handle(&mut self, _msg: StopCounter, ctx: &Context<Self>) -> u64 {
            ctx.stop();
            self.count
        }
    }

    impl Handler<IncrementWell> for WellBehavedTask {
        async fn handle(&mut self, _msg: IncrementWell, ctx: &Context<Self>) {
            self.count += 1;
            use crate::tasks::send_after;
            send_after(Duration::from_millis(100), ctx.clone(), IncrementWell);
        }
    }

    // With the blocking actor on the async backend, the well-behaved actor is
    // expected to fall short of 10 ticks in 1s.
    // NOTE(review): depends on how the runtime schedules around the blocked
    // worker; timing-sensitive.
    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let badboy = BadlyBehavedTask.start();
            badboy.send(DoBlock).unwrap();
            let goodboy = WellBehavedTask { count: 0 }.start();
            goodboy.send(IncrementWell).unwrap();
            rt::sleep(Duration::from_secs(1)).await;
            let count = goodboy.request(GetCount).await.unwrap();
            assert_ne!(count, 10);
            goodboy.request(StopCounter).await.unwrap();
        });
    }

    // With the blocking actor on `Backend::Blocking`, the well-behaved actor
    // is unaffected and reaches 10 ticks (one immediately, then every 100ms)
    // within ~1s. NOTE(review): timing-sensitive.
    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let badboy = BadlyBehavedTask.start_with_backend(Backend::Blocking);
            badboy.send(DoBlock).unwrap();
            let goodboy = WellBehavedTask { count: 0 }.start();
            goodboy.send(IncrementWell).unwrap();
            rt::sleep(Duration::from_secs(1)).await;
            let count = goodboy.request(GetCount).await.unwrap();
            assert_eq!(count, 10);
            goodboy.request(StopCounter).await.unwrap();
        });
    }

    // Same as above with `Backend::Thread`. NOTE(review): timing-sensitive.
    #[test]
    pub fn backend_thread_isolates_blocking_work() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let badboy = BadlyBehavedTask.start_with_backend(Backend::Thread);
            badboy.send(DoBlock).unwrap();
            let goodboy = WellBehavedTask { count: 0 }.start();
            goodboy.send(IncrementWell).unwrap();
            rt::sleep(Duration::from_secs(1)).await;
            let count = goodboy.request(GetCount).await.unwrap();
            assert_eq!(count, 10);
            goodboy.request(StopCounter).await.unwrap();
        });
    }

    // After `started()` panics the event loop exits and drops the mailbox
    // receiver, so a later `send` fails.
    #[test]
    pub fn panic_in_started_stops_actor() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            struct PanicOnStart;
            struct Ping;
            impl Message for Ping {
                type Result = ();
            }
            impl Actor for PanicOnStart {
                async fn started(&mut self, _ctx: &Context<Self>) {
                    panic!("boom in started");
                }
            }
            impl Handler<Ping> for PanicOnStart {
                async fn handle(&mut self, _msg: Ping, _ctx: &Context<Self>) {}
            }

            let actor = PanicOnStart.start();
            rt::sleep(Duration::from_millis(50)).await;
            let result = actor.send(Ping);
            assert!(result.is_err());
        });
    }

    // A panicking handler ends the event loop, so a subsequent request fails.
    #[test]
    pub fn panic_in_handler_stops_actor() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            struct PanicOnMsg;
            struct Explode;
            impl Message for Explode {
                type Result = ();
            }
            struct Check;
            impl Message for Check {
                type Result = u32;
            }
            impl Actor for PanicOnMsg {}
            impl Handler<Explode> for PanicOnMsg {
                async fn handle(&mut self, _msg: Explode, _ctx: &Context<Self>) {
                    panic!("boom in handler");
                }
            }
            impl Handler<Check> for PanicOnMsg {
                async fn handle(&mut self, _msg: Check, _ctx: &Context<Self>) -> u32 {
                    42
                }
            }

            let actor = PanicOnMsg.start();
            actor.send(Explode).unwrap();
            rt::sleep(Duration::from_millis(50)).await;
            let result = actor.request(Check).await;
            assert!(result.is_err());
        });
    }

    // A panic in `stopped()` is caught, so completion is still published and
    // `join()` returns.
    #[test]
    pub fn panic_in_stopped_still_completes() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            struct PanicOnStop;
            struct StopMe;
            impl Message for StopMe {
                type Result = ();
            }
            impl Actor for PanicOnStop {
                async fn stopped(&mut self, _ctx: &Context<Self>) {
                    panic!("boom in stopped");
                }
            }
            impl Handler<StopMe> for PanicOnStop {
                async fn handle(&mut self, _msg: StopMe, ctx: &Context<Self>) {
                    ctx.stop();
                }
            }

            let actor = PanicOnStop.start();
            actor.send(StopMe).unwrap();
            actor.join().await;
        });
    }

    // The message is delivered once the 10ms sleep future completes.
    #[test]
    pub fn send_message_on_delivers() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let counter = Counter { count: 0 }.start();
            let ctx = counter.context();
            send_message_on(ctx, rt::sleep(Duration::from_millis(200 / 20)), Increment);
            rt::sleep(Duration::from_millis(100)).await;
            let count = counter.request(GetCount).await.unwrap();
            assert_eq!(count, 1);
        });
    }

    // Stopping the actor before the 200ms future completes cancels delivery.
    #[test]
    pub fn send_message_on_cancelled() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let counter = Counter { count: 0 }.start();
            let ctx = counter.context();
            send_message_on(ctx, rt::sleep(Duration::from_millis(200)), Increment);
            let final_count = counter.request(StopCounter).await.unwrap();
            assert_eq!(final_count, 0, "message should not have been delivered");
            counter.join().await;
        });
    }
}