1use async_trait::async_trait;
21use serde::Serialize;
22use tracing::{info, warn};
23
24use crate::mailbox::Inbox;
25use crate::{
26 Actor, ActorContext, ActorExitStatus, ActorHandle, ActorState, Handler, Health, Supervisable,
27};
28
29#[derive(Debug, Clone, Copy, Default, Eq, PartialEq, Serialize)]
30pub struct SupervisorState {
31 pub num_panics: usize,
32 pub num_errors: usize,
33 pub num_kills: usize,
34}
35
36pub struct Supervisor<A: Actor> {
37 actor_name: String,
38 actor_factory: Box<dyn Fn() -> A + Sync + Send>,
39 inbox: Inbox<A>,
40 handle_opt: Option<ActorHandle<A>>,
41 state: SupervisorState,
42}
43
/// Message the supervisor sends to itself to trigger one supervision pass.
///
/// It is first scheduled from `initialize` and rescheduled by the handler
/// after every successful pass.
#[derive(Clone, Copy, Debug)]
struct SuperviseLoop;
46
47#[async_trait]
48impl<A: Actor> Actor for Supervisor<A> {
49 type ObservableState = SupervisorState;
50
51 fn observable_state(&self) -> Self::ObservableState {
52 self.state
53 }
54
55 fn name(&self) -> String {
56 format!("Supervisor({})", self.actor_name)
57 }
58
59 fn queue_capacity(&self) -> crate::QueueCapacity {
60 crate::QueueCapacity::Unbounded
61 }
62
63 async fn initialize(&mut self, ctx: &ActorContext<Self>) -> Result<(), ActorExitStatus> {
64 ctx.schedule_self_msg(crate::HEARTBEAT, SuperviseLoop).await;
65 Ok(())
66 }
67
68 async fn finalize(
69 &mut self,
70 exit_status: &ActorExitStatus,
71 _ctx: &ActorContext<Self>,
72 ) -> anyhow::Result<()> {
73 match exit_status {
74 ActorExitStatus::Quit => {
75 if let Some(handle) = self.handle_opt.take() {
76 handle.quit().await;
77 }
78 }
79 ActorExitStatus::Killed => {
80 if let Some(handle) = self.handle_opt.take() {
81 handle.kill().await;
82 }
83 }
84 ActorExitStatus::Failure(_)
85 | ActorExitStatus::Success
86 | ActorExitStatus::DownstreamClosed => {}
87 ActorExitStatus::Panicked => {}
88 }
89
90 Ok(())
91 }
92}
93
94impl<A: Actor> Supervisor<A> {
95 pub(crate) fn new(
96 actor_name: String,
97 actor_factory: Box<dyn Fn() -> A + Sync + Send>,
98 inbox: Inbox<A>,
99 handle: ActorHandle<A>,
100 ) -> Self {
101 let state = Default::default();
102 Supervisor {
103 actor_name,
104 actor_factory,
105 inbox,
106 handle_opt: Some(handle),
107 state,
108 }
109 }
110
111 async fn supervise(
112 &mut self,
113 ctx: &ActorContext<Supervisor<A>>,
114 ) -> Result<(), ActorExitStatus> {
115 match self
116 .handle_opt
117 .as_ref()
118 .expect("The actor handle should always be set.")
119 .harvest_health()
120 {
121 Health::Healthy => {
122 return Ok(());
123 }
124 Health::FailureOrUnhealthy => {}
125 Health::Success => {
126 return Err(ActorExitStatus::Success);
127 }
128 }
129 warn!("unhealthy-actor");
130 let actor_handle = self.handle_opt.take().unwrap();
132 let actor_mailbox = actor_handle.mailbox().clone();
133 let (actor_exit_status, _last_state) = if actor_handle.state() == ActorState::Processing {
134 warn!("killing");
137 actor_handle.kill().await
138 } else {
139 actor_handle.join().await
140 };
141 match actor_exit_status {
142 ActorExitStatus::Success => {
143 return Err(ActorExitStatus::Success);
144 }
145 ActorExitStatus::Quit => {
146 return Err(ActorExitStatus::Quit);
147 }
148 ActorExitStatus::DownstreamClosed => {
149 return Err(ActorExitStatus::DownstreamClosed);
150 }
151 ActorExitStatus::Killed => {
152 self.state.num_kills += 1;
153 }
154 ActorExitStatus::Failure(_err) => {
155 self.state.num_errors += 1;
156 }
157 ActorExitStatus::Panicked => {
158 self.state.num_panics += 1;
159 }
160 }
161 info!("respawning-actor");
162 let (_, actor_handle) = ctx
163 .spawn_actor()
164 .set_mailboxes(actor_mailbox, self.inbox.clone())
165 .set_kill_switch(ctx.kill_switch().child())
166 .spawn((*self.actor_factory)());
167 self.handle_opt = Some(actor_handle);
168 Ok(())
169 }
170}
171
172#[async_trait]
173impl<A: Actor> Handler<SuperviseLoop> for Supervisor<A> {
174 type Reply = ();
175
176 async fn handle(
177 &mut self,
178 _msg: SuperviseLoop,
179 ctx: &ActorContext<Self>,
180 ) -> Result<Self::Reply, ActorExitStatus> {
181 self.supervise(ctx).await?;
182 ctx.schedule_self_msg(crate::HEARTBEAT, SuperviseLoop).await;
183 Ok(())
184 }
185}
186
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use async_trait::async_trait;
    use tracing::info;

    use super::SupervisorState;
    use crate::{Actor, ActorContext, ActorExitStatus, AskError, Handler, Universe};

    /// Commands understood by [`FailingActor`].
    #[derive(Clone, Copy, Debug)]
    enum FailingActorMessage {
        /// Makes the handler panic.
        Panic,
        /// Makes the handler return an error exit status.
        ReturnError,
        /// Increments the counter and replies with its new value.
        Increment,
        /// Makes the handler sleep (through the actor context) for the given duration.
        Freeze(Duration),
    }

    /// Test actor holding a counter, able to fail in several ways on demand.
    #[derive(Clone, Default)]
    struct FailingActor {
        counter: usize,
    }

    #[async_trait]
    impl Actor for FailingActor {
        type ObservableState = usize;

        fn name(&self) -> String {
            "FailingActor".to_string()
        }

        fn observable_state(&self) -> Self::ObservableState {
            self.counter
        }

        async fn finalize(
            &mut self,
            _exit_status: &ActorExitStatus,
            _ctx: &ActorContext<Self>,
        ) -> anyhow::Result<()> {
            info!("finalize-failing-actor");
            Ok(())
        }
    }

    #[async_trait]
    impl Handler<FailingActorMessage> for FailingActor {
        type Reply = usize;

        /// Replies with the current counter value unless the message makes the
        /// actor panic or fail.
        async fn handle(
            &mut self,
            msg: FailingActorMessage,
            ctx: &ActorContext<Self>,
        ) -> Result<Self::Reply, ActorExitStatus> {
            match msg {
                FailingActorMessage::Panic => panic!("Failing actor panicked"),
                FailingActorMessage::ReturnError => {
                    let failure = anyhow::anyhow!("Failing actor error");
                    return Err(ActorExitStatus::from(failure));
                }
                FailingActorMessage::Increment => self.counter += 1,
                FailingActorMessage::Freeze(wait_duration) => {
                    ctx.sleep(wait_duration).await;
                }
            }
            Ok(self.counter)
        }
    }

    /// Builds the supervisor state expected by an assertion.
    fn expected_state(num_panics: usize, num_errors: usize, num_kills: usize) -> SupervisorState {
        SupervisorState {
            num_panics,
            num_errors,
            num_kills,
        }
    }

    #[tokio::test]
    async fn test_supervisor_restart_on_panic() {
        let universe = Universe::with_accelerated_time();
        let failing_actor = FailingActor::default();
        let (mailbox, supervisor_handle) = universe.spawn_builder().supervise(failing_actor);

        let first_count = mailbox.ask(FailingActorMessage::Increment).await.unwrap();
        assert_eq!(first_count, 1);
        let second_count = mailbox.ask(FailingActorMessage::Increment).await.unwrap();
        assert_eq!(second_count, 2);

        let panic_reply = mailbox.ask(FailingActorMessage::Panic).await;
        assert!(panic_reply.is_err());

        // The respawned actor starts again from a fresh counter.
        let count_after_respawn = mailbox.ask(FailingActorMessage::Increment).await.unwrap();
        assert_eq!(count_after_respawn, 1);
        assert_eq!(*supervisor_handle.observe().await, expected_state(1, 0, 0));

        let (supervisor_exit_status, _) = supervisor_handle.quit().await;
        assert!(!matches!(supervisor_exit_status, ActorExitStatus::Panicked));
    }

    #[tokio::test]
    async fn test_supervisor_restart_on_error() {
        let universe = Universe::with_accelerated_time();
        let failing_actor = FailingActor::default();
        let (mailbox, supervisor_handle) = universe.spawn_builder().supervise(failing_actor);

        let first_count = mailbox.ask(FailingActorMessage::Increment).await.unwrap();
        assert_eq!(first_count, 1);
        let second_count = mailbox.ask(FailingActorMessage::Increment).await.unwrap();
        assert_eq!(second_count, 2);

        let error_reply = mailbox.ask(FailingActorMessage::ReturnError).await;
        assert!(error_reply.is_err());

        // The respawned actor starts again from a fresh counter.
        let count_after_respawn = mailbox.ask(FailingActorMessage::Increment).await.unwrap();
        assert_eq!(count_after_respawn, 1);
        assert_eq!(*supervisor_handle.observe().await, expected_state(0, 1, 0));

        let (supervisor_exit_status, _) = supervisor_handle.quit().await;
        assert!(!matches!(supervisor_exit_status, ActorExitStatus::Panicked));
    }

    #[tokio::test]
    async fn test_supervisor_kills_and_restart_frozen_actor() {
        let universe = Universe::with_accelerated_time();
        let failing_actor = FailingActor::default();
        let (mailbox, supervisor_handle) = universe.spawn_builder().supervise(failing_actor);

        let first_count = mailbox.ask(FailingActorMessage::Increment).await.unwrap();
        assert_eq!(first_count, 1);
        let second_count = mailbox.ask(FailingActorMessage::Increment).await.unwrap();
        assert_eq!(second_count, 2);
        assert_eq!(*supervisor_handle.observe().await, expected_state(0, 0, 0));

        // Keep the actor busy for three heartbeats.
        let freeze_duration = crate::HEARTBEAT.mul_f32(3.0f32);
        mailbox
            .send_message(FailingActorMessage::Freeze(freeze_duration))
            .await
            .unwrap();

        // The reply comes from a respawned actor, hence the fresh counter.
        let count_after_respawn = mailbox.ask(FailingActorMessage::Increment).await.unwrap();
        assert_eq!(count_after_respawn, 1);
        assert_eq!(*supervisor_handle.observe().await, expected_state(0, 0, 1));

        let (supervisor_exit_status, _) = supervisor_handle.quit().await;
        assert!(!matches!(supervisor_exit_status, ActorExitStatus::Panicked));
    }

    #[tokio::test]
    async fn test_supervisor_forwards_quit_commands() {
        let universe = Universe::with_accelerated_time();
        let failing_actor = FailingActor::default();
        let (mailbox, supervisor_handle) = universe.spawn_builder().supervise(failing_actor);

        let first_count = mailbox.ask(FailingActorMessage::Increment).await.unwrap();
        assert_eq!(first_count, 1);

        let (exit_status, _last_state) = supervisor_handle.quit().await;

        // Once the supervisor has quit, the supervised actor no longer receives messages.
        let ask_error = mailbox
            .ask(FailingActorMessage::Increment)
            .await
            .unwrap_err();
        assert!(matches!(ask_error, AskError::MessageNotDelivered));
        assert!(matches!(exit_status, ActorExitStatus::Quit));
    }

    #[tokio::test]
    async fn test_supervisor_forwards_kill_command() {
        let universe = Universe::with_accelerated_time();
        let failing_actor = FailingActor::default();
        let (mailbox, supervisor_handle) = universe.spawn_builder().supervise(failing_actor);

        let first_count = mailbox.ask(FailingActorMessage::Increment).await.unwrap();
        assert_eq!(first_count, 1);

        let (exit_status, _last_state) = supervisor_handle.kill().await;

        // Once the supervisor has been killed, the supervised actor no longer receives messages.
        let first_ask_after_kill = mailbox.ask(FailingActorMessage::Increment).await;
        assert!(first_ask_after_kill.is_err());
        let ask_error = mailbox
            .ask(FailingActorMessage::Increment)
            .await
            .unwrap_err();
        assert!(matches!(ask_error, AskError::MessageNotDelivered));
        assert!(matches!(exit_status, ActorExitStatus::Killed));
    }

    #[tokio::test]
    async fn test_supervisor_exits_successfully_when_supervised_actor_mailbox_is_dropped() {
        let universe = Universe::with_accelerated_time();
        let failing_actor = FailingActor::default();
        // The `_` pattern drops the supervised actor's mailbox right away.
        let (_, supervisor_handle) = universe.spawn_builder().supervise(failing_actor);

        let (exit_status, _last_state) = supervisor_handle.join().await;
        assert!(matches!(exit_status, ActorExitStatus::Success));
    }
}