//! spawned_concurrency/threads/actor.rs
//!
//! Thread-based actors: each started actor runs on its own OS thread.

1use spawned_rt::threads::{
2    self as rt, mpsc, oneshot, oneshot::RecvTimeoutError, CancellationToken,
3};
4use std::{
5    fmt::Debug,
6    panic::{catch_unwind, AssertUnwindSafe},
7    sync::{Arc, Condvar, Mutex},
8    time::Duration,
9};
10
11use crate::error::ActorError;
12use crate::message::Message;
13
14pub use crate::response::DEFAULT_REQUEST_TIMEOUT;
15
16// ---------------------------------------------------------------------------
17// Actor trait
18// ---------------------------------------------------------------------------
19
20/// Trait for defining an actor's lifecycle hooks.
21///
22/// Implement this trait (typically via `#[actor]`) to define `started()` and
23/// `stopped()` callbacks. Message handling is defined separately via [`Handler<M>`].
24///
25/// Actors must be `Send + Sized + 'static` so they can be moved to a spawned thread.
26pub trait Actor: Send + Sized + 'static {
27    fn started(&mut self, _ctx: &Context<Self>) {}
28    fn stopped(&mut self, _ctx: &Context<Self>) {}
29}
30
31// ---------------------------------------------------------------------------
32// Handler trait (per-message, sync version)
33// ---------------------------------------------------------------------------
34
35/// Per-message handler trait. Implement once for each message type the actor handles.
36///
37/// Unlike the `tasks` version, handlers are synchronous — no `async`/`.await`.
38pub trait Handler<M: Message>: Actor {
39    fn handle(&mut self, msg: M, ctx: &Context<Self>) -> M::Result;
40}
41
42// ---------------------------------------------------------------------------
43// Envelope (type-erasure)
44// ---------------------------------------------------------------------------
45
46trait Envelope<A: Actor>: Send {
47    fn handle(self: Box<Self>, actor: &mut A, ctx: &Context<A>);
48}
49
50struct MessageEnvelope<M: Message> {
51    msg: M,
52    tx: Option<oneshot::Sender<M::Result>>,
53}
54
55impl<A, M> Envelope<A> for MessageEnvelope<M>
56where
57    A: Actor + Handler<M>,
58    M: Message,
59{
60    fn handle(self: Box<Self>, actor: &mut A, ctx: &Context<A>) {
61        let result = actor.handle(self.msg, ctx);
62        if let Some(tx) = self.tx {
63            let _ = tx.send(result);
64        }
65    }
66}
67
68// ---------------------------------------------------------------------------
69// Context
70// ---------------------------------------------------------------------------
71
72/// Handle passed to every handler and lifecycle hook, providing access to the
73/// actor's mailbox and lifecycle controls.
74///
75/// Clone is cheap — it clones the inner channel sender and cancellation token.
76pub struct Context<A: Actor> {
77    sender: mpsc::Sender<Box<dyn Envelope<A> + Send>>,
78    cancellation_token: CancellationToken,
79    completion: Arc<(Mutex<bool>, Condvar)>,
80}
81
82impl<A: Actor> Clone for Context<A> {
83    fn clone(&self) -> Self {
84        Self {
85            sender: self.sender.clone(),
86            cancellation_token: self.cancellation_token.clone(),
87            completion: self.completion.clone(),
88        }
89    }
90}
91
92impl<A: Actor> Debug for Context<A> {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        f.debug_struct("Context").finish_non_exhaustive()
95    }
96}
97
98impl<A: Actor> Context<A> {
99    /// Create a `Context` from an `ActorRef`. Useful for setting up timers
100    /// or stream listeners from outside the actor.
101    pub fn from_ref(actor_ref: &ActorRef<A>) -> Self {
102        Self {
103            sender: actor_ref.sender.clone(),
104            cancellation_token: actor_ref.cancellation_token.clone(),
105            completion: actor_ref.completion.clone(),
106        }
107    }
108
109    /// Signal the actor to stop. The current handler will finish, then
110    /// `stopped()` is called and the actor exits.
111    pub fn stop(&self) {
112        self.cancellation_token.cancel();
113    }
114
115    /// Send a fire-and-forget message to this actor.
116    pub fn send<M>(&self, msg: M) -> Result<(), ActorError>
117    where
118        A: Handler<M>,
119        M: Message,
120    {
121        let envelope = MessageEnvelope { msg, tx: None };
122        self.sender
123            .send(Box::new(envelope))
124            .map_err(|_| ActorError::ActorStopped)
125    }
126
127    /// Send a request and get a raw oneshot receiver for the reply.
128    pub fn request_raw<M>(&self, msg: M) -> Result<oneshot::Receiver<M::Result>, ActorError>
129    where
130        A: Handler<M>,
131        M: Message,
132    {
133        let (tx, rx) = oneshot::channel();
134        let envelope = MessageEnvelope { msg, tx: Some(tx) };
135        self.sender
136            .send(Box::new(envelope))
137            .map_err(|_| ActorError::ActorStopped)?;
138        Ok(rx)
139    }
140
141    /// Send a request and block until the reply arrives (default 5s timeout).
142    pub fn request<M>(&self, msg: M) -> Result<M::Result, ActorError>
143    where
144        A: Handler<M>,
145        M: Message,
146    {
147        self.request_with_timeout(msg, DEFAULT_REQUEST_TIMEOUT)
148    }
149
150    /// Send a request and block until the reply arrives, with a custom timeout.
151    pub fn request_with_timeout<M>(
152        &self,
153        msg: M,
154        duration: Duration,
155    ) -> Result<M::Result, ActorError>
156    where
157        A: Handler<M>,
158        M: Message,
159    {
160        let rx = self.request_raw(msg)?;
161        match rx.recv_timeout(duration) {
162            Ok(result) => Ok(result),
163            Err(RecvTimeoutError::Timeout) => Err(ActorError::RequestTimeout),
164            Err(RecvTimeoutError::Disconnected) => Err(ActorError::ActorStopped),
165        }
166    }
167
168    /// Get a type-erased `Recipient<M>` for sending a single message type
169    /// to this actor.
170    pub fn recipient<M>(&self) -> Recipient<M>
171    where
172        A: Handler<M>,
173        M: Message,
174    {
175        Arc::new(self.clone())
176    }
177
178    /// Get an `ActorRef<A>` from this context.
179    pub fn actor_ref(&self) -> ActorRef<A> {
180        ActorRef {
181            sender: self.sender.clone(),
182            cancellation_token: self.cancellation_token.clone(),
183            completion: self.completion.clone(),
184        }
185    }
186
187    pub(crate) fn cancellation_token(&self) -> CancellationToken {
188        self.cancellation_token.clone()
189    }
190}
191
192// Bridge: Context<A> implements Receiver<M> for any M that A handles
193impl<A, M> Receiver<M> for Context<A>
194where
195    A: Actor + Handler<M>,
196    M: Message,
197{
198    fn send(&self, msg: M) -> Result<(), ActorError> {
199        Context::send(self, msg)
200    }
201
202    fn request_raw(&self, msg: M) -> Result<oneshot::Receiver<M::Result>, ActorError> {
203        Context::request_raw(self, msg)
204    }
205}
206
207// ---------------------------------------------------------------------------
208// Receiver trait (object-safe) + Recipient alias
209// ---------------------------------------------------------------------------
210
211/// Object-safe trait for sending a single message type to an actor.
212///
213/// Implemented automatically by `ActorRef<A>` and `Context<A>` for any
214/// message type that `A` handles.
215pub trait Receiver<M: Message>: Send + Sync {
216    fn send(&self, msg: M) -> Result<(), ActorError>;
217    fn request_raw(&self, msg: M) -> Result<oneshot::Receiver<M::Result>, ActorError>;
218}
219
220/// Type-erased reference for sending a single message type.
221pub type Recipient<M> = Arc<dyn Receiver<M>>;
222
223/// Send a request through a type-erased `Receiver` with a custom timeout.
224pub fn request<M: Message>(
225    recipient: &dyn Receiver<M>,
226    msg: M,
227    timeout: Duration,
228) -> Result<M::Result, ActorError> {
229    let rx = recipient.request_raw(msg)?;
230    match rx.recv_timeout(timeout) {
231        Ok(result) => Ok(result),
232        Err(RecvTimeoutError::Timeout) => Err(ActorError::RequestTimeout),
233        Err(RecvTimeoutError::Disconnected) => Err(ActorError::ActorStopped),
234    }
235}
236
237// ---------------------------------------------------------------------------
238// ActorRef
239// ---------------------------------------------------------------------------
240
/// Marks an actor as finished when dropped.
///
/// Dropping it sets the shared `bool` to `true` and wakes every thread waiting
/// on the condition variable. Because this happens in `Drop`, it also runs when
/// the owning thread unwinds.
struct CompletionGuard(Arc<(Mutex<bool>, Condvar)>);

impl Drop for CompletionGuard {
    fn drop(&mut self) {
        let (done, wakeup) = self.0.as_ref();
        // A poisoned lock still holds a usable flag. Completion must be
        // signalled regardless, otherwise waiters would block forever.
        let mut finished = match done.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        *finished = true;
        wakeup.notify_all();
    }
}
251
252/// External handle to a running actor. Cloneable, `Send + Sync`.
253///
254/// Use this to send messages, make requests, or wait for the actor to stop.
255/// To stop the actor, send an explicit shutdown message through your protocol,
256/// or call [`Context::stop`] from within a handler.
257pub struct ActorRef<A: Actor> {
258    sender: mpsc::Sender<Box<dyn Envelope<A> + Send>>,
259    cancellation_token: CancellationToken,
260    completion: Arc<(Mutex<bool>, Condvar)>,
261}
262
263impl<A: Actor> Debug for ActorRef<A> {
264    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
265        f.debug_struct("ActorRef").finish_non_exhaustive()
266    }
267}
268
269impl<A: Actor> Clone for ActorRef<A> {
270    fn clone(&self) -> Self {
271        Self {
272            sender: self.sender.clone(),
273            cancellation_token: self.cancellation_token.clone(),
274            completion: self.completion.clone(),
275        }
276    }
277}
278
279impl<A: Actor> ActorRef<A> {
280    /// Send a fire-and-forget message to the actor.
281    pub fn send<M>(&self, msg: M) -> Result<(), ActorError>
282    where
283        A: Handler<M>,
284        M: Message,
285    {
286        let envelope = MessageEnvelope { msg, tx: None };
287        self.sender
288            .send(Box::new(envelope))
289            .map_err(|_| ActorError::ActorStopped)
290    }
291
292    /// Send a request and get a raw oneshot receiver for the reply.
293    pub fn request_raw<M>(&self, msg: M) -> Result<oneshot::Receiver<M::Result>, ActorError>
294    where
295        A: Handler<M>,
296        M: Message,
297    {
298        let (tx, rx) = oneshot::channel();
299        let envelope = MessageEnvelope { msg, tx: Some(tx) };
300        self.sender
301            .send(Box::new(envelope))
302            .map_err(|_| ActorError::ActorStopped)?;
303        Ok(rx)
304    }
305
306    /// Send a request and block until the reply arrives (default 5s timeout).
307    pub fn request<M>(&self, msg: M) -> Result<M::Result, ActorError>
308    where
309        A: Handler<M>,
310        M: Message,
311    {
312        self.request_with_timeout(msg, DEFAULT_REQUEST_TIMEOUT)
313    }
314
315    /// Send a request and block until the reply arrives, with a custom timeout.
316    pub fn request_with_timeout<M>(
317        &self,
318        msg: M,
319        duration: Duration,
320    ) -> Result<M::Result, ActorError>
321    where
322        A: Handler<M>,
323        M: Message,
324    {
325        let rx = self.request_raw(msg)?;
326        match rx.recv_timeout(duration) {
327            Ok(result) => Ok(result),
328            Err(RecvTimeoutError::Timeout) => Err(ActorError::RequestTimeout),
329            Err(RecvTimeoutError::Disconnected) => Err(ActorError::ActorStopped),
330        }
331    }
332
333    /// Get a type-erased `Recipient<M>` for this actor.
334    pub fn recipient<M>(&self) -> Recipient<M>
335    where
336        A: Handler<M>,
337        M: Message,
338    {
339        Arc::new(self.clone())
340    }
341
342    /// Get a `Context<A>` from this ref, for timer setup or stream listeners.
343    pub fn context(&self) -> Context<A> {
344        Context::from_ref(self)
345    }
346
347    /// Block until the actor has fully stopped (including `stopped()` callback).
348    pub fn join(&self) {
349        let (lock, cvar) = &*self.completion;
350        let mut completed = lock.lock().unwrap_or_else(|p| p.into_inner());
351        while !*completed {
352            completed = cvar.wait(completed).unwrap_or_else(|p| p.into_inner());
353        }
354    }
355}
356
357// Bridge: ActorRef<A> implements Receiver<M> for any M that A handles
358impl<A, M> Receiver<M> for ActorRef<A>
359where
360    A: Actor + Handler<M>,
361    M: Message,
362{
363    fn send(&self, msg: M) -> Result<(), ActorError> {
364        ActorRef::send(self, msg)
365    }
366
367    fn request_raw(&self, msg: M) -> Result<oneshot::Receiver<M::Result>, ActorError> {
368        ActorRef::request_raw(self, msg)
369    }
370}
371
372// ---------------------------------------------------------------------------
373// Actor startup + main loop
374// ---------------------------------------------------------------------------
375
376impl<A: Actor> ActorRef<A> {
377    fn spawn(actor: A) -> Self {
378        let (tx, rx) = mpsc::channel::<Box<dyn Envelope<A> + Send>>();
379        let cancellation_token = CancellationToken::new();
380        let completion = Arc::new((Mutex::new(false), Condvar::new()));
381
382        let actor_ref = ActorRef {
383            sender: tx.clone(),
384            cancellation_token: cancellation_token.clone(),
385            completion: completion.clone(),
386        };
387
388        let ctx = Context {
389            sender: tx,
390            cancellation_token: cancellation_token.clone(),
391            completion: actor_ref.completion.clone(),
392        };
393
394        let _thread_handle = rt::spawn(move || {
395            let _guard = CompletionGuard(completion);
396            run_actor(actor, ctx, rx, cancellation_token);
397        });
398
399        actor_ref
400    }
401}
402
403fn run_actor<A: Actor>(
404    mut actor: A,
405    ctx: Context<A>,
406    rx: mpsc::Receiver<Box<dyn Envelope<A> + Send>>,
407    cancellation_token: CancellationToken,
408) {
409    let start_result = catch_unwind(AssertUnwindSafe(|| {
410        actor.started(&ctx);
411    }));
412    if let Err(panic) = start_result {
413        tracing::error!("Panic in started() callback: {panic:?}");
414        cancellation_token.cancel();
415        return;
416    }
417
418    if cancellation_token.is_cancelled() {
419        let _ = catch_unwind(AssertUnwindSafe(|| actor.stopped(&ctx)));
420        return;
421    }
422
423    loop {
424        let msg = match rx.recv_timeout(Duration::from_millis(100)) {
425            Ok(msg) => Some(msg),
426            Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
427                if cancellation_token.is_cancelled() {
428                    break;
429                }
430                continue;
431            }
432            Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => None,
433        };
434        match msg {
435            Some(envelope) => {
436                let result = catch_unwind(AssertUnwindSafe(|| {
437                    envelope.handle(&mut actor, &ctx);
438                }));
439                if let Err(panic) = result {
440                    tracing::error!("Panic in message handler: {panic:?}");
441                    break;
442                }
443                if cancellation_token.is_cancelled() {
444                    break;
445                }
446            }
447            None => break,
448        }
449    }
450
451    cancellation_token.cancel();
452    let stop_result = catch_unwind(AssertUnwindSafe(|| {
453        actor.stopped(&ctx);
454    }));
455    if let Err(panic) = stop_result {
456        tracing::error!("Panic in stopped() callback: {panic:?}");
457    }
458}
459
460// ---------------------------------------------------------------------------
461// Actor::start
462// ---------------------------------------------------------------------------
463
464/// Extension trait for starting an actor. Automatically implemented for all [`Actor`] types.
465pub trait ActorStart: Actor {
466    /// Start the actor on a dedicated OS thread.
467    fn start(self) -> ActorRef<Self> {
468        ActorRef::spawn(self)
469    }
470}
471
472impl<A: Actor> ActorStart for A {}
473
474// ---------------------------------------------------------------------------
475// send_message_on (utility)
476// ---------------------------------------------------------------------------
477
478/// Send a message to an actor when a blocking closure completes.
479///
480/// Spawns a thread that runs `f()`, then sends `msg` to the actor.
481/// If the actor stops before `f()` returns, the message is not sent.
482pub fn send_message_on<A, M, F>(ctx: Context<A>, f: F, msg: M) -> rt::JoinHandle<()>
483where
484    A: Actor + Handler<M>,
485    M: Message,
486    F: FnOnce() + Send + 'static,
487{
488    let cancellation_token = ctx.cancellation_token();
489    rt::spawn(move || {
490        f();
491        if !cancellation_token.is_cancelled() {
492            if let Err(e) = ctx.send(msg) {
493                tracing::error!("Failed to send message: {e:?}")
494            }
495        }
496    })
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::Message;
    use std::thread;

    struct Counter {
        count: u64,
    }

    struct GetCount;
    impl Message for GetCount {
        type Result = u64;
    }

    struct Increment;
    impl Message for Increment {
        type Result = u64;
    }

    struct StopCounter;
    impl Message for StopCounter {
        type Result = u64;
    }

    impl Actor for Counter {}

    impl Handler<GetCount> for Counter {
        fn handle(&mut self, _msg: GetCount, _ctx: &Context<Self>) -> u64 {
            self.count
        }
    }

    impl Handler<Increment> for Counter {
        fn handle(&mut self, _msg: Increment, _ctx: &Context<Self>) -> u64 {
            self.count += 1;
            self.count
        }
    }

    impl Handler<StopCounter> for Counter {
        fn handle(&mut self, _msg: StopCounter, ctx: &Context<Self>) -> u64 {
            ctx.stop();
            self.count
        }
    }

    #[test]
    fn basic_send_and_request() {
        let actor = Counter { count: 0 }.start();
        assert_eq!(actor.request(GetCount).unwrap(), 0);
        assert_eq!(actor.request(Increment).unwrap(), 1);
        // No sleep needed: the mailbox is FIFO, so this fire-and-forget
        // increment is handled before the GetCount request that follows it.
        actor.send(Increment).unwrap();
        assert_eq!(actor.request(GetCount).unwrap(), 2);
        actor.request(StopCounter).unwrap();
        actor.join();
    }

    #[test]
    fn join_waits_for_completion() {
        use std::sync::atomic::{AtomicBool, Ordering};

        struct SlowStop(Arc<AtomicBool>);
        struct StopSlow;
        impl Message for StopSlow {
            type Result = ();
        }
        impl Actor for SlowStop {
            fn stopped(&mut self, _ctx: &Context<Self>) {
                rt::sleep(Duration::from_millis(300));
                self.0.store(true, Ordering::SeqCst);
            }
        }
        impl Handler<StopSlow> for SlowStop {
            fn handle(&mut self, _msg: StopSlow, ctx: &Context<Self>) {
                ctx.stop();
            }
        }

        let stopped_done = Arc::new(AtomicBool::new(false));
        let actor = SlowStop(Arc::clone(&stopped_done)).start();
        actor.send(StopSlow).unwrap();
        actor.join();
        assert!(
            stopped_done.load(Ordering::SeqCst),
            "join() returned before stopped() finished"
        );
    }

    #[test]
    fn join_multiple_callers() {
        struct SlowStop2;
        struct StopSlow2;
        impl Message for StopSlow2 {
            type Result = ();
        }
        impl Actor for SlowStop2 {
            fn stopped(&mut self, _ctx: &Context<Self>) {
                rt::sleep(Duration::from_millis(200));
            }
        }
        impl Handler<StopSlow2> for SlowStop2 {
            fn handle(&mut self, _msg: StopSlow2, ctx: &Context<Self>) {
                ctx.stop();
            }
        }

        let actor = SlowStop2.start();
        let a1 = actor.clone();
        let a2 = actor.clone();
        let t1 = thread::spawn(move || {
            a1.join();
            1u32
        });
        let t2 = thread::spawn(move || {
            a2.join();
            2u32
        });
        actor.send(StopSlow2).unwrap();
        assert_eq!(t1.join().unwrap(), 1);
        assert_eq!(t2.join().unwrap(), 2);
    }

    #[test]
    fn panic_in_started_stops_actor() {
        struct PanicOnStart;
        struct PingThread;
        impl Message for PingThread {
            type Result = ();
        }
        impl Actor for PanicOnStart {
            fn started(&mut self, _ctx: &Context<Self>) {
                panic!("boom in started");
            }
        }
        impl Handler<PingThread> for PanicOnStart {
            fn handle(&mut self, _msg: PingThread, _ctx: &Context<Self>) {}
        }

        let actor = PanicOnStart.start();
        // join() returns only after the actor thread has dropped its mailbox
        // receiver, so the send below fails deterministically (no sleep race).
        actor.join();
        let result = actor.send(PingThread);
        assert!(result.is_err());
    }

    #[test]
    fn panic_in_handler_stops_actor() {
        struct PanicOnMsg;
        struct ExplodeThread;
        impl Message for ExplodeThread {
            type Result = ();
        }
        struct CheckThread;
        impl Message for CheckThread {
            type Result = u32;
        }
        impl Actor for PanicOnMsg {}
        impl Handler<ExplodeThread> for PanicOnMsg {
            fn handle(&mut self, _msg: ExplodeThread, _ctx: &Context<Self>) {
                panic!("boom in handler");
            }
        }
        impl Handler<CheckThread> for PanicOnMsg {
            fn handle(&mut self, _msg: CheckThread, _ctx: &Context<Self>) -> u32 {
                42
            }
        }

        let actor = PanicOnMsg.start();
        actor.send(ExplodeThread).unwrap();
        // Wait for the actor thread to exit instead of sleeping.
        actor.join();
        let result = actor.request(CheckThread);
        assert!(matches!(result, Err(ActorError::ActorStopped)));
    }

    #[test]
    fn panic_in_stopped_still_completes() {
        struct PanicOnStop;
        struct StopMeThread;
        impl Message for StopMeThread {
            type Result = ();
        }
        impl Actor for PanicOnStop {
            fn stopped(&mut self, _ctx: &Context<Self>) {
                panic!("boom in stopped");
            }
        }
        impl Handler<StopMeThread> for PanicOnStop {
            fn handle(&mut self, _msg: StopMeThread, ctx: &Context<Self>) {
                ctx.stop();
            }
        }

        let actor = PanicOnStop.start();
        actor.send(StopMeThread).unwrap();
        actor.join();
    }

    #[test]
    fn recipient_type_erasure() {
        let actor = Counter { count: 42 }.start();
        let recipient: Recipient<GetCount> = actor.recipient();
        let result = request(&*recipient, GetCount, Duration::from_secs(5)).unwrap();
        assert_eq!(result, 42);
    }

    #[test]
    fn send_message_on_delivers() {
        let actor = Counter { count: 0 }.start();
        let ctx = actor.context();
        send_message_on(ctx, || rt::sleep(Duration::from_millis(10)), Increment);
        rt::sleep(Duration::from_millis(200));
        let count = actor.request(GetCount).unwrap();
        assert_eq!(count, 1);
    }
}