1use spawned_rt::threads::{
2 self as rt, mpsc, oneshot, oneshot::RecvTimeoutError, CancellationToken,
3};
4use std::{
5 fmt::Debug,
6 panic::{catch_unwind, AssertUnwindSafe},
7 sync::{Arc, Condvar, Mutex},
8 time::Duration,
9};
10
11use crate::error::ActorError;
12use crate::message::Message;
13
14pub use crate::response::DEFAULT_REQUEST_TIMEOUT;
15
16pub trait Actor: Send + Sized + 'static {
27 fn started(&mut self, _ctx: &Context<Self>) {}
28 fn stopped(&mut self, _ctx: &Context<Self>) {}
29}
30
31pub trait Handler<M: Message>: Actor {
39 fn handle(&mut self, msg: M, ctx: &Context<Self>) -> M::Result;
40}
41
42trait Envelope<A: Actor>: Send {
47 fn handle(self: Box<Self>, actor: &mut A, ctx: &Context<A>);
48}
49
50struct MessageEnvelope<M: Message> {
51 msg: M,
52 tx: Option<oneshot::Sender<M::Result>>,
53}
54
55impl<A, M> Envelope<A> for MessageEnvelope<M>
56where
57 A: Actor + Handler<M>,
58 M: Message,
59{
60 fn handle(self: Box<Self>, actor: &mut A, ctx: &Context<A>) {
61 let result = actor.handle(self.msg, ctx);
62 if let Some(tx) = self.tx {
63 let _ = tx.send(result);
64 }
65 }
66}
67
68pub struct Context<A: Actor> {
77 sender: mpsc::Sender<Box<dyn Envelope<A> + Send>>,
78 cancellation_token: CancellationToken,
79 completion: Arc<(Mutex<bool>, Condvar)>,
80}
81
82impl<A: Actor> Clone for Context<A> {
83 fn clone(&self) -> Self {
84 Self {
85 sender: self.sender.clone(),
86 cancellation_token: self.cancellation_token.clone(),
87 completion: self.completion.clone(),
88 }
89 }
90}
91
92impl<A: Actor> Debug for Context<A> {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("Context").finish_non_exhaustive()
95 }
96}
97
98impl<A: Actor> Context<A> {
99 pub fn from_ref(actor_ref: &ActorRef<A>) -> Self {
102 Self {
103 sender: actor_ref.sender.clone(),
104 cancellation_token: actor_ref.cancellation_token.clone(),
105 completion: actor_ref.completion.clone(),
106 }
107 }
108
109 pub fn stop(&self) {
112 self.cancellation_token.cancel();
113 }
114
115 pub fn send<M>(&self, msg: M) -> Result<(), ActorError>
117 where
118 A: Handler<M>,
119 M: Message,
120 {
121 let envelope = MessageEnvelope { msg, tx: None };
122 self.sender
123 .send(Box::new(envelope))
124 .map_err(|_| ActorError::ActorStopped)
125 }
126
127 pub fn request_raw<M>(&self, msg: M) -> Result<oneshot::Receiver<M::Result>, ActorError>
129 where
130 A: Handler<M>,
131 M: Message,
132 {
133 let (tx, rx) = oneshot::channel();
134 let envelope = MessageEnvelope { msg, tx: Some(tx) };
135 self.sender
136 .send(Box::new(envelope))
137 .map_err(|_| ActorError::ActorStopped)?;
138 Ok(rx)
139 }
140
141 pub fn request<M>(&self, msg: M) -> Result<M::Result, ActorError>
143 where
144 A: Handler<M>,
145 M: Message,
146 {
147 self.request_with_timeout(msg, DEFAULT_REQUEST_TIMEOUT)
148 }
149
150 pub fn request_with_timeout<M>(
152 &self,
153 msg: M,
154 duration: Duration,
155 ) -> Result<M::Result, ActorError>
156 where
157 A: Handler<M>,
158 M: Message,
159 {
160 let rx = self.request_raw(msg)?;
161 match rx.recv_timeout(duration) {
162 Ok(result) => Ok(result),
163 Err(RecvTimeoutError::Timeout) => Err(ActorError::RequestTimeout),
164 Err(RecvTimeoutError::Disconnected) => Err(ActorError::ActorStopped),
165 }
166 }
167
168 pub fn recipient<M>(&self) -> Recipient<M>
171 where
172 A: Handler<M>,
173 M: Message,
174 {
175 Arc::new(self.clone())
176 }
177
178 pub fn actor_ref(&self) -> ActorRef<A> {
180 ActorRef {
181 sender: self.sender.clone(),
182 cancellation_token: self.cancellation_token.clone(),
183 completion: self.completion.clone(),
184 }
185 }
186
187 pub(crate) fn cancellation_token(&self) -> CancellationToken {
188 self.cancellation_token.clone()
189 }
190}
191
192impl<A, M> Receiver<M> for Context<A>
194where
195 A: Actor + Handler<M>,
196 M: Message,
197{
198 fn send(&self, msg: M) -> Result<(), ActorError> {
199 Context::send(self, msg)
200 }
201
202 fn request_raw(&self, msg: M) -> Result<oneshot::Receiver<M::Result>, ActorError> {
203 Context::request_raw(self, msg)
204 }
205}
206
207pub trait Receiver<M: Message>: Send + Sync {
216 fn send(&self, msg: M) -> Result<(), ActorError>;
217 fn request_raw(&self, msg: M) -> Result<oneshot::Receiver<M::Result>, ActorError>;
218}
219
220pub type Recipient<M> = Arc<dyn Receiver<M>>;
222
223pub fn request<M: Message>(
225 recipient: &dyn Receiver<M>,
226 msg: M,
227 timeout: Duration,
228) -> Result<M::Result, ActorError> {
229 let rx = recipient.request_raw(msg)?;
230 match rx.recv_timeout(timeout) {
231 Ok(result) => Ok(result),
232 Err(RecvTimeoutError::Timeout) => Err(ActorError::RequestTimeout),
233 Err(RecvTimeoutError::Disconnected) => Err(ActorError::ActorStopped),
234 }
235}
236
/// Sets the shared "actor finished" flag to `true` on drop and wakes every waiter.
///
/// Held on the actor thread, it fires on normal return and also during panic unwinding.
struct CompletionGuard(Arc<(Mutex<bool>, Condvar)>);

impl Drop for CompletionGuard {
    fn drop(&mut self) {
        let (flag, waiters) = &*self.0;
        // A poisoned lock still protects a valid `bool`; recover it instead of
        // panicking inside `drop`.
        let mut finished = match flag.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        *finished = true;
        waiters.notify_all();
    }
}
251
252pub struct ActorRef<A: Actor> {
258 sender: mpsc::Sender<Box<dyn Envelope<A> + Send>>,
259 cancellation_token: CancellationToken,
260 completion: Arc<(Mutex<bool>, Condvar)>,
261}
262
263impl<A: Actor> Debug for ActorRef<A> {
264 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
265 f.debug_struct("ActorRef").finish_non_exhaustive()
266 }
267}
268
269impl<A: Actor> Clone for ActorRef<A> {
270 fn clone(&self) -> Self {
271 Self {
272 sender: self.sender.clone(),
273 cancellation_token: self.cancellation_token.clone(),
274 completion: self.completion.clone(),
275 }
276 }
277}
278
279impl<A: Actor> ActorRef<A> {
280 pub fn send<M>(&self, msg: M) -> Result<(), ActorError>
282 where
283 A: Handler<M>,
284 M: Message,
285 {
286 let envelope = MessageEnvelope { msg, tx: None };
287 self.sender
288 .send(Box::new(envelope))
289 .map_err(|_| ActorError::ActorStopped)
290 }
291
292 pub fn request_raw<M>(&self, msg: M) -> Result<oneshot::Receiver<M::Result>, ActorError>
294 where
295 A: Handler<M>,
296 M: Message,
297 {
298 let (tx, rx) = oneshot::channel();
299 let envelope = MessageEnvelope { msg, tx: Some(tx) };
300 self.sender
301 .send(Box::new(envelope))
302 .map_err(|_| ActorError::ActorStopped)?;
303 Ok(rx)
304 }
305
306 pub fn request<M>(&self, msg: M) -> Result<M::Result, ActorError>
308 where
309 A: Handler<M>,
310 M: Message,
311 {
312 self.request_with_timeout(msg, DEFAULT_REQUEST_TIMEOUT)
313 }
314
315 pub fn request_with_timeout<M>(
317 &self,
318 msg: M,
319 duration: Duration,
320 ) -> Result<M::Result, ActorError>
321 where
322 A: Handler<M>,
323 M: Message,
324 {
325 let rx = self.request_raw(msg)?;
326 match rx.recv_timeout(duration) {
327 Ok(result) => Ok(result),
328 Err(RecvTimeoutError::Timeout) => Err(ActorError::RequestTimeout),
329 Err(RecvTimeoutError::Disconnected) => Err(ActorError::ActorStopped),
330 }
331 }
332
333 pub fn recipient<M>(&self) -> Recipient<M>
335 where
336 A: Handler<M>,
337 M: Message,
338 {
339 Arc::new(self.clone())
340 }
341
342 pub fn context(&self) -> Context<A> {
344 Context::from_ref(self)
345 }
346
347 pub fn join(&self) {
349 let (lock, cvar) = &*self.completion;
350 let mut completed = lock.lock().unwrap_or_else(|p| p.into_inner());
351 while !*completed {
352 completed = cvar.wait(completed).unwrap_or_else(|p| p.into_inner());
353 }
354 }
355}
356
357impl<A, M> Receiver<M> for ActorRef<A>
359where
360 A: Actor + Handler<M>,
361 M: Message,
362{
363 fn send(&self, msg: M) -> Result<(), ActorError> {
364 ActorRef::send(self, msg)
365 }
366
367 fn request_raw(&self, msg: M) -> Result<oneshot::Receiver<M::Result>, ActorError> {
368 ActorRef::request_raw(self, msg)
369 }
370}
371
372impl<A: Actor> ActorRef<A> {
377 fn spawn(actor: A) -> Self {
378 let (tx, rx) = mpsc::channel::<Box<dyn Envelope<A> + Send>>();
379 let cancellation_token = CancellationToken::new();
380 let completion = Arc::new((Mutex::new(false), Condvar::new()));
381
382 let actor_ref = ActorRef {
383 sender: tx.clone(),
384 cancellation_token: cancellation_token.clone(),
385 completion: completion.clone(),
386 };
387
388 let ctx = Context {
389 sender: tx,
390 cancellation_token: cancellation_token.clone(),
391 completion: actor_ref.completion.clone(),
392 };
393
394 let _thread_handle = rt::spawn(move || {
395 let _guard = CompletionGuard(completion);
396 run_actor(actor, ctx, rx, cancellation_token);
397 });
398
399 actor_ref
400 }
401}
402
403fn run_actor<A: Actor>(
404 mut actor: A,
405 ctx: Context<A>,
406 rx: mpsc::Receiver<Box<dyn Envelope<A> + Send>>,
407 cancellation_token: CancellationToken,
408) {
409 let start_result = catch_unwind(AssertUnwindSafe(|| {
410 actor.started(&ctx);
411 }));
412 if let Err(panic) = start_result {
413 tracing::error!("Panic in started() callback: {panic:?}");
414 cancellation_token.cancel();
415 return;
416 }
417
418 if cancellation_token.is_cancelled() {
419 let _ = catch_unwind(AssertUnwindSafe(|| actor.stopped(&ctx)));
420 return;
421 }
422
423 loop {
424 let msg = match rx.recv_timeout(Duration::from_millis(100)) {
425 Ok(msg) => Some(msg),
426 Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
427 if cancellation_token.is_cancelled() {
428 break;
429 }
430 continue;
431 }
432 Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => None,
433 };
434 match msg {
435 Some(envelope) => {
436 let result = catch_unwind(AssertUnwindSafe(|| {
437 envelope.handle(&mut actor, &ctx);
438 }));
439 if let Err(panic) = result {
440 tracing::error!("Panic in message handler: {panic:?}");
441 break;
442 }
443 if cancellation_token.is_cancelled() {
444 break;
445 }
446 }
447 None => break,
448 }
449 }
450
451 cancellation_token.cancel();
452 let stop_result = catch_unwind(AssertUnwindSafe(|| {
453 actor.stopped(&ctx);
454 }));
455 if let Err(panic) = stop_result {
456 tracing::error!("Panic in stopped() callback: {panic:?}");
457 }
458}
459
460pub trait ActorStart: Actor {
466 fn start(self) -> ActorRef<Self> {
468 ActorRef::spawn(self)
469 }
470}
471
472impl<A: Actor> ActorStart for A {}
473
474pub fn send_message_on<A, M, F>(ctx: Context<A>, f: F, msg: M) -> rt::JoinHandle<()>
483where
484 A: Actor + Handler<M>,
485 M: Message,
486 F: FnOnce() + Send + 'static,
487{
488 let cancellation_token = ctx.cancellation_token();
489 rt::spawn(move || {
490 f();
491 if !cancellation_token.is_cancelled() {
492 if let Err(e) = ctx.send(msg) {
493 tracing::error!("Failed to send message: {e:?}")
494 }
495 }
496 })
497}
498
#[cfg(test)]
mod tests {
    //! Behavioural tests for the thread-based actor runtime.
    //!
    //! Tests that need an actor to be fully shut down wait with `ActorRef::join()`
    //! rather than sleeping, so their outcome does not depend on scheduling.
    use super::*;
    use crate::error::ActorError;
    use crate::message::Message;
    use std::thread;
    use std::time::Instant;

    struct Counter {
        count: u64,
    }

    struct GetCount;
    impl Message for GetCount {
        type Result = u64;
    }

    struct Increment;
    impl Message for Increment {
        type Result = u64;
    }

    struct StopCounter;
    impl Message for StopCounter {
        type Result = u64;
    }

    impl Actor for Counter {}

    impl Handler<GetCount> for Counter {
        fn handle(&mut self, _msg: GetCount, _ctx: &Context<Self>) -> u64 {
            self.count
        }
    }

    impl Handler<Increment> for Counter {
        fn handle(&mut self, _msg: Increment, _ctx: &Context<Self>) -> u64 {
            self.count += 1;
            self.count
        }
    }

    impl Handler<StopCounter> for Counter {
        fn handle(&mut self, _msg: StopCounter, ctx: &Context<Self>) -> u64 {
            ctx.stop();
            self.count
        }
    }

    #[test]
    fn basic_send_and_request() {
        let actor = Counter { count: 0 }.start();
        assert_eq!(actor.request(GetCount).unwrap(), 0);
        assert_eq!(actor.request(Increment).unwrap(), 1);
        actor.send(Increment).unwrap();
        // No sleep needed: the mailbox is a single FIFO channel drained by one thread,
        // so the `Increment` above is handled before this `GetCount`.
        assert_eq!(actor.request(GetCount).unwrap(), 2);
        assert_eq!(actor.request(StopCounter).unwrap(), 2);
        actor.join();
    }

    #[test]
    fn join_waits_for_completion() {
        struct SlowStop;
        struct StopSlow;
        impl Message for StopSlow {
            type Result = ();
        }
        impl Actor for SlowStop {
            fn stopped(&mut self, _ctx: &Context<Self>) {
                rt::sleep(Duration::from_millis(300));
            }
        }
        impl Handler<StopSlow> for SlowStop {
            fn handle(&mut self, _msg: StopSlow, ctx: &Context<Self>) {
                ctx.stop();
            }
        }

        let actor = SlowStop.start();
        let started_at = Instant::now();
        actor.send(StopSlow).unwrap();
        actor.join();
        // `stopped()` sleeps for 300ms, so a `join()` that returned early shows up here.
        assert!(started_at.elapsed() >= Duration::from_millis(300));
    }

    #[test]
    fn join_multiple_callers() {
        struct SlowStop2;
        struct StopSlow2;
        impl Message for StopSlow2 {
            type Result = ();
        }
        impl Actor for SlowStop2 {
            fn stopped(&mut self, _ctx: &Context<Self>) {
                rt::sleep(Duration::from_millis(200));
            }
        }
        impl Handler<StopSlow2> for SlowStop2 {
            fn handle(&mut self, _msg: StopSlow2, ctx: &Context<Self>) {
                ctx.stop();
            }
        }

        let actor = SlowStop2.start();
        let a1 = actor.clone();
        let a2 = actor.clone();
        let t1 = thread::spawn(move || {
            a1.join();
            1u32
        });
        let t2 = thread::spawn(move || {
            a2.join();
            2u32
        });
        actor.send(StopSlow2).unwrap();
        assert_eq!(t1.join().unwrap(), 1);
        assert_eq!(t2.join().unwrap(), 2);
    }

    #[test]
    fn panic_in_started_stops_actor() {
        struct PanicOnStart;
        struct PingThread;
        impl Message for PingThread {
            type Result = ();
        }
        impl Actor for PanicOnStart {
            fn started(&mut self, _ctx: &Context<Self>) {
                panic!("boom in started");
            }
        }
        impl Handler<PingThread> for PanicOnStart {
            fn handle(&mut self, _msg: PingThread, _ctx: &Context<Self>) {}
        }

        let actor = PanicOnStart.start();
        // Once `join()` returns the actor thread has exited and dropped its mailbox
        // receiver, so the send below fails deterministically.
        actor.join();
        let result = actor.send(PingThread);
        assert!(matches!(result, Err(ActorError::ActorStopped)));
    }

    #[test]
    fn panic_in_handler_stops_actor() {
        struct PanicOnMsg;
        struct ExplodeThread;
        impl Message for ExplodeThread {
            type Result = ();
        }
        struct CheckThread;
        impl Message for CheckThread {
            type Result = u32;
        }
        impl Actor for PanicOnMsg {}
        impl Handler<ExplodeThread> for PanicOnMsg {
            fn handle(&mut self, _msg: ExplodeThread, _ctx: &Context<Self>) {
                panic!("boom in handler");
            }
        }
        impl Handler<CheckThread> for PanicOnMsg {
            fn handle(&mut self, _msg: CheckThread, _ctx: &Context<Self>) -> u32 {
                42
            }
        }

        let actor = PanicOnMsg.start();
        actor.send(ExplodeThread).unwrap();
        // Wait for the panic-triggered shutdown instead of guessing how long it takes.
        actor.join();
        let result = actor.request(CheckThread);
        assert!(matches!(result, Err(ActorError::ActorStopped)));
    }

    #[test]
    fn panic_in_stopped_still_completes() {
        struct PanicOnStop;
        struct StopMeThread;
        impl Message for StopMeThread {
            type Result = ();
        }
        impl Actor for PanicOnStop {
            fn stopped(&mut self, _ctx: &Context<Self>) {
                panic!("boom in stopped");
            }
        }
        impl Handler<StopMeThread> for PanicOnStop {
            fn handle(&mut self, _msg: StopMeThread, ctx: &Context<Self>) {
                ctx.stop();
            }
        }

        let actor = PanicOnStop.start();
        actor.send(StopMeThread).unwrap();
        actor.join();
    }

    #[test]
    fn recipient_type_erasure() {
        let actor = Counter { count: 42 }.start();
        let recipient: Recipient<GetCount> = actor.recipient();
        let result = request(&*recipient, GetCount, Duration::from_secs(5)).unwrap();
        assert_eq!(result, 42);
    }

    #[test]
    fn send_message_on_delivers() {
        let actor = Counter { count: 0 }.start();
        let ctx = actor.context();
        send_message_on(ctx, || rt::sleep(Duration::from_millis(10)), Increment);
        rt::sleep(Duration::from_millis(200));
        let count = actor.request(GetCount).unwrap();
        assert_eq!(count, 1);
    }
}