//! Callback chains for lifecycle events (`rustrails_support/callbacks.rs`).

1use std::any::Any;
2use std::fmt;
3
/// Result of executing a callback or wrapped action.
///
/// Before callbacks, around callbacks, and the wrapped action report progress with this
/// value; the value returned by an after callback is ignored by the chain.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CallbackResult {
    /// Continue running the callback chain.
    Continue,
    /// Stop running the callback chain.
    Halt,
}
12
/// When in the lifecycle a callback runs.
///
/// Before callbacks run first, around callbacks then wrap the action, and after callbacks
/// run last.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CallbackKind {
    /// Runs before the wrapped action.
    Before,
    /// Runs after the wrapped action.
    After,
    /// Wraps the wrapped action and any inner around callbacks.
    Around,
}
23
/// Boxed, thread-safe predicate that decides whether a callback is eligible to fire.
///
/// The callback target is passed type-erased as [`Any`]; use [`Any::downcast_ref`] to
/// inspect the concrete value.
pub type ConditionPredicate = Box<dyn Fn(&dyn Any) -> bool + Send + Sync>;
26
27/// The continuation passed to around callbacks.
28pub type AroundContinuation<'a, T> = Box<dyn FnOnce(&mut T) -> CallbackResult + 'a>;
29
30/// The function signature for around callbacks.
31pub type AroundFilter<T> = Box<
32    dyn for<'a> Fn(&mut T, AroundContinuation<'a, T>) -> CallbackResult + Send + Sync + 'static,
33>;
34
35/// Conditions that control when a callback is eligible to run.
36#[derive(Default)]
37pub struct CallbackConditions {
38    /// Optional allow-list of action names. When present, only listed actions match.
39    pub only: Option<Vec<String>>,
40    /// Optional deny-list of action names. When present, listed actions do not match.
41    pub except: Option<Vec<String>>,
42    /// Predicate that must return `true` for the callback to run.
43    pub if_cond: Option<ConditionPredicate>,
44    /// Predicate that must return `false` for the callback to run.
45    pub unless_cond: Option<ConditionPredicate>,
46}
47
48impl fmt::Debug for CallbackConditions {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        f.debug_struct("CallbackConditions")
51            .field("only", &self.only)
52            .field("except", &self.except)
53            .field("if_cond", &self.if_cond.as_ref().map(|_| "<predicate>"))
54            .field(
55                "unless_cond",
56                &self.unless_cond.as_ref().map(|_| "<predicate>"),
57            )
58            .finish()
59    }
60}
61
62impl CallbackConditions {
63    /// Creates an empty set of callback conditions.
64    #[must_use]
65    pub fn new() -> Self {
66        Self::default()
67    }
68
69    /// Restricts the callback to the provided action names.
70    #[must_use]
71    pub fn only<I, S>(mut self, actions: I) -> Self
72    where
73        I: IntoIterator<Item = S>,
74        S: Into<String>,
75    {
76        self.only = Some(actions.into_iter().map(Into::into).collect());
77        self
78    }
79
80    /// Prevents the callback from running for the provided action names.
81    #[must_use]
82    pub fn except<I, S>(mut self, actions: I) -> Self
83    where
84        I: IntoIterator<Item = S>,
85        S: Into<String>,
86    {
87        self.except = Some(actions.into_iter().map(Into::into).collect());
88        self
89    }
90
91    /// Adds a predicate that must return `true` for the callback to run.
92    #[must_use]
93    pub fn if_cond<F>(mut self, predicate: F) -> Self
94    where
95        F: Fn(&dyn Any) -> bool + Send + Sync + 'static,
96    {
97        self.if_cond = Some(Box::new(predicate));
98        self
99    }
100
101    /// Adds a predicate that must return `false` for the callback to run.
102    #[must_use]
103    pub fn unless_cond<F>(mut self, predicate: F) -> Self
104    where
105        F: Fn(&dyn Any) -> bool + Send + Sync + 'static,
106    {
107        self.unless_cond = Some(Box::new(predicate));
108        self
109    }
110
111    fn matches(&self, target: &dyn Any, action_name: &str) -> bool {
112        if let Some(only) = &self.only
113            && !only.iter().any(|candidate| candidate == action_name)
114        {
115            return false;
116        }
117
118        if let Some(except) = &self.except
119            && except.iter().any(|candidate| candidate == action_name)
120        {
121            return false;
122        }
123
124        if let Some(predicate) = &self.if_cond
125            && !predicate(target)
126        {
127            return false;
128        }
129
130        if let Some(predicate) = &self.unless_cond
131            && predicate(target)
132        {
133            return false;
134        }
135
136        true
137    }
138}
139
140/// The executable body of a callback.
141pub enum CallbackFilter<T: Send + Sync + 'static> {
142    /// A standard before or after callback.
143    Standard(Box<dyn Fn(&mut T) -> CallbackResult + Send + Sync + 'static>),
144    /// An around callback that decides whether and when to call the continuation.
145    Around(AroundFilter<T>),
146}
147
148impl<T: Send + Sync + 'static> fmt::Debug for CallbackFilter<T> {
149    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
150        match self {
151            Self::Standard(_) => f.write_str("Standard(<callback>)"),
152            Self::Around(_) => f.write_str("Around(<callback>)"),
153        }
154    }
155}
156
157/// A single callback entry in a callback chain.
158pub struct Callback<T: Send + Sync + 'static> {
159    /// Stable callback name used for replacement and skipping.
160    pub name: String,
161    /// The lifecycle phase where this callback runs.
162    pub kind: CallbackKind,
163    /// The executable body for the callback.
164    pub filter: CallbackFilter<T>,
165    /// Conditions that control whether this callback runs.
166    pub conditions: CallbackConditions,
167}
168
169impl<T: Send + Sync + 'static> fmt::Debug for Callback<T> {
170    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171        f.debug_struct("Callback")
172            .field("name", &self.name)
173            .field("kind", &self.kind)
174            .field("filter", &self.filter)
175            .field("conditions", &self.conditions)
176            .finish()
177    }
178}
179
180impl<T: Send + Sync + 'static> Callback<T> {
181    /// Creates a before callback.
182    #[must_use]
183    pub fn before<N, F>(name: N, filter: F) -> Self
184    where
185        N: Into<String>,
186        F: Fn(&mut T) -> CallbackResult + Send + Sync + 'static,
187    {
188        Self {
189            name: name.into(),
190            kind: CallbackKind::Before,
191            filter: CallbackFilter::Standard(Box::new(filter)),
192            conditions: CallbackConditions::default(),
193        }
194    }
195
196    /// Creates an after callback.
197    #[must_use]
198    pub fn after<N, F>(name: N, filter: F) -> Self
199    where
200        N: Into<String>,
201        F: Fn(&mut T) -> CallbackResult + Send + Sync + 'static,
202    {
203        Self {
204            name: name.into(),
205            kind: CallbackKind::After,
206            filter: CallbackFilter::Standard(Box::new(filter)),
207            conditions: CallbackConditions::default(),
208        }
209    }
210
211    /// Creates an around callback.
212    #[must_use]
213    pub fn around<N, F>(name: N, filter: F) -> Self
214    where
215        N: Into<String>,
216        F: for<'a> Fn(&mut T, AroundContinuation<'a, T>) -> CallbackResult + Send + Sync + 'static,
217    {
218        Self {
219            name: name.into(),
220            kind: CallbackKind::Around,
221            filter: CallbackFilter::Around(Box::new(filter)),
222            conditions: CallbackConditions::default(),
223        }
224    }
225
226    /// Replaces the callback conditions.
227    #[must_use]
228    pub fn with_conditions(mut self, conditions: CallbackConditions) -> Self {
229        self.conditions = conditions;
230        self
231    }
232}
233
234/// A named chain of callbacks for a lifecycle event.
235pub struct CallbackChain<T: Send + Sync + 'static> {
236    name: String,
237    callbacks: Vec<Callback<T>>,
238    run_after_on_halt: bool,
239}
240
241impl<T: Send + Sync + 'static> fmt::Debug for CallbackChain<T> {
242    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
243        f.debug_struct("CallbackChain")
244            .field("name", &self.name)
245            .field("callbacks", &self.callbacks)
246            .field("run_after_on_halt", &self.run_after_on_halt)
247            .finish()
248    }
249}
250
251impl<T: Send + Sync + 'static> CallbackChain<T> {
252    /// Creates an empty callback chain for the given lifecycle name.
253    #[must_use]
254    pub fn new(name: impl Into<String>) -> Self {
255        Self {
256            name: name.into(),
257            callbacks: Vec::new(),
258            run_after_on_halt: true,
259        }
260    }
261
262    /// Returns the chain name.
263    #[must_use]
264    pub fn name(&self) -> &str {
265        &self.name
266    }
267
268    /// Returns the number of registered callbacks.
269    #[must_use]
270    pub fn len(&self) -> usize {
271        self.callbacks.len()
272    }
273
274    /// Returns `true` when the chain has no callbacks.
275    #[must_use]
276    pub fn is_empty(&self) -> bool {
277        self.callbacks.is_empty()
278    }
279
280    /// Controls whether after callbacks still run when the chain halts.
281    pub fn set_run_after_on_halt(&mut self, run_after_on_halt: bool) {
282        self.run_after_on_halt = run_after_on_halt;
283    }
284
285    /// Adds a callback to the chain, replacing any existing callback with the same name.
286    pub fn add(&mut self, callback: Callback<T>) {
287        self.callbacks
288            .retain(|existing| existing.name != callback.name);
289        self.callbacks.push(callback);
290    }
291
292    /// Removes every callback with the provided name.
293    pub fn skip(&mut self, name: &str) {
294        self.callbacks.retain(|callback| callback.name != name);
295    }
296
297    /// Clears the callback chain.
298    pub fn reset(&mut self) {
299        self.callbacks.clear();
300    }
301
302    /// Runs the callback chain without a custom action.
303    pub fn run(&self, target: &mut T, action_name: &str) -> CallbackResult {
304        self.run_with(target, action_name, |_| CallbackResult::Continue)
305    }
306
307    /// Runs the callback chain around the provided action.
308    pub fn run_with<F>(&self, target: &mut T, action_name: &str, action: F) -> CallbackResult
309    where
310        F: FnOnce(&mut T) -> CallbackResult,
311    {
312        if self.callbacks.is_empty() {
313            return action(target);
314        }
315
316        let target_any = &*target as &dyn Any;
317        let applicable = self
318            .callbacks
319            .iter()
320            .filter(|callback| callback.conditions.matches(target_any, action_name))
321            .collect::<Vec<_>>();
322
323        let mut result = CallbackResult::Continue;
324
325        for callback in applicable
326            .iter()
327            .copied()
328            .filter(|callback| callback.kind == CallbackKind::Before)
329        {
330            result = match &callback.filter {
331                CallbackFilter::Standard(filter) => filter(target),
332                CallbackFilter::Around(_) => CallbackResult::Continue,
333            };
334
335            if result == CallbackResult::Halt {
336                break;
337            }
338        }
339
340        if result != CallbackResult::Halt {
341            let around = applicable
342                .iter()
343                .copied()
344                .filter(|callback| callback.kind == CallbackKind::Around)
345                .collect::<Vec<_>>();
346            result = invoke_around(&around, target, Box::new(action));
347        }
348
349        if self.run_after_on_halt || result != CallbackResult::Halt {
350            for callback in applicable
351                .iter()
352                .rev()
353                .copied()
354                .filter(|callback| callback.kind == CallbackKind::After)
355            {
356                if let CallbackFilter::Standard(filter) = &callback.filter {
357                    let _ = filter(target);
358                }
359            }
360        }
361
362        result
363    }
364}
365
366fn invoke_around<'a, T: Send + Sync + 'static>(
367    around_callbacks: &'a [&'a Callback<T>],
368    target: &mut T,
369    action: AroundContinuation<'a, T>,
370) -> CallbackResult {
371    match around_callbacks.split_first() {
372        Some((callback, rest)) => match &callback.filter {
373            CallbackFilter::Around(filter) => filter(
374                target,
375                Box::new(move |target| invoke_around(rest, target, action)),
376            ),
377            CallbackFilter::Standard(_) => invoke_around(rest, target, action),
378        },
379        None => action(target),
380    }
381}
382
383#[cfg(test)]
384mod tests {
385    use super::*;
386    use std::any::Any;
387
388    #[derive(Default)]
389    struct Recorder {
390        log: Vec<&'static str>,
391        allow: bool,
392        block: bool,
393        action_ran: bool,
394    }
395
396    impl Recorder {
397        fn push(&mut self, entry: &'static str) {
398            self.log.push(entry);
399        }
400    }
401
402    fn allow_predicate(target: &dyn Any) -> bool {
403        target
404            .downcast_ref::<Recorder>()
405            .map(|recorder| recorder.allow)
406            .unwrap_or(false)
407    }
408
409    fn block_predicate(target: &dyn Any) -> bool {
410        target
411            .downcast_ref::<Recorder>()
412            .map(|recorder| recorder.block)
413            .unwrap_or(false)
414    }
415
416    #[test]
417    fn test_empty_chain_returns_continue() {
418        let chain = CallbackChain::<Recorder>::new("save");
419        let mut recorder = Recorder::default();
420
421        let result = chain.run(&mut recorder, "save");
422
423        assert_eq!(result, CallbackResult::Continue);
424        assert!(recorder.log.is_empty());
425    }
426
427    #[test]
428    fn test_before_callback_runs() {
429        let mut chain = CallbackChain::new("save");
430        chain.add(Callback::before("audit", |record: &mut Recorder| {
431            record.push("before");
432            CallbackResult::Continue
433        }));
434        let mut recorder = Recorder::default();
435
436        let result = chain.run(&mut recorder, "save");
437
438        assert_eq!(result, CallbackResult::Continue);
439        assert_eq!(recorder.log, vec!["before"]);
440    }
441
442    #[test]
443    fn test_before_callback_halts_chain() {
444        let mut chain = CallbackChain::new("save");
445        chain.add(Callback::before("halt", |record: &mut Recorder| {
446            record.push("before");
447            CallbackResult::Halt
448        }));
449        let mut recorder = Recorder::default();
450
451        let result = chain.run_with(&mut recorder, "save", |record| {
452            record.push("action");
453            record.action_ran = true;
454            CallbackResult::Continue
455        });
456
457        assert_eq!(result, CallbackResult::Halt);
458        assert_eq!(recorder.log, vec!["before"]);
459        assert!(!recorder.action_ran);
460    }
461
462    #[test]
463    fn test_halting_before_prevents_subsequent_before_callbacks() {
464        let mut chain = CallbackChain::new("save");
465        chain.add(Callback::before("first", |record: &mut Recorder| {
466            record.push("first");
467            CallbackResult::Halt
468        }));
469        chain.add(Callback::before("second", |record: &mut Recorder| {
470            record.push("second");
471            CallbackResult::Continue
472        }));
473        let mut recorder = Recorder::default();
474
475        let result = chain.run(&mut recorder, "save");
476
477        assert_eq!(result, CallbackResult::Halt);
478        assert_eq!(recorder.log, vec!["first"]);
479    }
480
481    #[test]
482    fn test_after_callbacks_run_in_reverse_order() {
483        let mut chain = CallbackChain::new("save");
484        chain.add(Callback::after("first", |record: &mut Recorder| {
485            record.push("after-first");
486            CallbackResult::Continue
487        }));
488        chain.add(Callback::after("second", |record: &mut Recorder| {
489            record.push("after-second");
490            CallbackResult::Continue
491        }));
492        let mut recorder = Recorder::default();
493
494        let result = chain.run_with(&mut recorder, "save", |record| {
495            record.push("action");
496            record.action_ran = true;
497            CallbackResult::Continue
498        });
499
500        assert_eq!(result, CallbackResult::Continue);
501        assert_eq!(recorder.log, vec!["action", "after-second", "after-first"]);
502    }
503
504    #[test]
505    fn test_after_callbacks_ignore_their_own_halt_result() {
506        let mut chain = CallbackChain::new("save");
507        chain.add(Callback::after("first", |record: &mut Recorder| {
508            record.push("after-first");
509            CallbackResult::Halt
510        }));
511        chain.add(Callback::after("second", |record: &mut Recorder| {
512            record.push("after-second");
513            CallbackResult::Continue
514        }));
515        let mut recorder = Recorder::default();
516
517        let result = chain.run(&mut recorder, "save");
518
519        assert_eq!(result, CallbackResult::Continue);
520        assert_eq!(recorder.log, vec!["after-second", "after-first"]);
521    }
522
523    #[test]
524    fn test_around_callback_wraps_action() {
525        let mut chain = CallbackChain::new("save");
526        chain.add(Callback::around(
527            "wrapper",
528            |record: &mut Recorder, next| {
529                record.push("around-before");
530                let result = next(record);
531                record.push("around-after");
532                result
533            },
534        ));
535        let mut recorder = Recorder::default();
536
537        let result = chain.run_with(&mut recorder, "save", |record| {
538            record.push("action");
539            record.action_ran = true;
540            CallbackResult::Continue
541        });
542
543        assert_eq!(result, CallbackResult::Continue);
544        assert_eq!(
545            recorder.log,
546            vec!["around-before", "action", "around-after"]
547        );
548        assert!(recorder.action_ran);
549    }
550
551    #[test]
552    fn test_multiple_around_callbacks_nest_in_order() {
553        let mut chain = CallbackChain::new("save");
554        chain.add(Callback::around("outer", |record: &mut Recorder, next| {
555            record.push("outer-before");
556            let result = next(record);
557            record.push("outer-after");
558            result
559        }));
560        chain.add(Callback::around("inner", |record: &mut Recorder, next| {
561            record.push("inner-before");
562            let result = next(record);
563            record.push("inner-after");
564            result
565        }));
566        let mut recorder = Recorder::default();
567
568        chain.run_with(&mut recorder, "save", |record| {
569            record.push("action");
570            CallbackResult::Continue
571        });
572
573        assert_eq!(
574            recorder.log,
575            vec![
576                "outer-before",
577                "inner-before",
578                "action",
579                "inner-after",
580                "outer-after"
581            ]
582        );
583    }
584
585    #[test]
586    fn test_around_callback_can_prevent_inner_execution() {
587        let mut chain = CallbackChain::new("save");
588        chain.add(Callback::around(
589            "stopper",
590            |record: &mut Recorder, _next| {
591                record.push("around-before");
592                CallbackResult::Halt
593            },
594        ));
595        let mut recorder = Recorder::default();
596
597        let result = chain.run_with(&mut recorder, "save", |record| {
598            record.push("action");
599            record.action_ran = true;
600            CallbackResult::Continue
601        });
602
603        assert_eq!(result, CallbackResult::Halt);
604        assert_eq!(recorder.log, vec!["around-before"]);
605        assert!(!recorder.action_ran);
606    }
607
608    #[test]
609    fn test_multiple_callbacks_execute_in_order() {
610        let mut chain = CallbackChain::new("save");
611        chain.add(Callback::before("one", |record: &mut Recorder| {
612            record.push("before-one");
613            CallbackResult::Continue
614        }));
615        chain.add(Callback::before("two", |record: &mut Recorder| {
616            record.push("before-two");
617            CallbackResult::Continue
618        }));
619        chain.add(Callback::after("three", |record: &mut Recorder| {
620            record.push("after-one");
621            CallbackResult::Continue
622        }));
623        chain.add(Callback::after("four", |record: &mut Recorder| {
624            record.push("after-two");
625            CallbackResult::Continue
626        }));
627        let mut recorder = Recorder::default();
628
629        chain.run_with(&mut recorder, "save", |record| {
630            record.push("action");
631            CallbackResult::Continue
632        });
633
634        assert_eq!(
635            recorder.log,
636            vec![
637                "before-one",
638                "before-two",
639                "action",
640                "after-two",
641                "after-one"
642            ]
643        );
644    }
645
646    #[test]
647    fn test_mixed_kinds_execute_in_expected_order() {
648        let mut chain = CallbackChain::new("save");
649        chain.add(Callback::before("before", |record: &mut Recorder| {
650            record.push("before");
651            CallbackResult::Continue
652        }));
653        chain.add(Callback::around("around", |record: &mut Recorder, next| {
654            record.push("around-before");
655            let result = next(record);
656            record.push("around-after");
657            result
658        }));
659        chain.add(Callback::after("after", |record: &mut Recorder| {
660            record.push("after");
661            CallbackResult::Continue
662        }));
663        let mut recorder = Recorder::default();
664
665        let result = chain.run_with(&mut recorder, "save", |record| {
666            record.push("action");
667            CallbackResult::Continue
668        });
669
670        assert_eq!(result, CallbackResult::Continue);
671        assert_eq!(
672            recorder.log,
673            vec!["before", "around-before", "action", "around-after", "after"]
674        );
675    }
676
677    #[test]
678    fn test_only_condition_matches_listed_action() {
679        let mut chain = CallbackChain::new("save");
680        chain.add(
681            Callback::before("audit", |record: &mut Recorder| {
682                record.push("before");
683                CallbackResult::Continue
684            })
685            .with_conditions(CallbackConditions::default().only(["save"])),
686        );
687        let mut recorder = Recorder::default();
688
689        chain.run(&mut recorder, "save");
690
691        assert_eq!(recorder.log, vec!["before"]);
692    }
693
694    #[test]
695    fn test_only_condition_skips_unlisted_action() {
696        let mut chain = CallbackChain::new("save");
697        chain.add(
698            Callback::before("audit", |record: &mut Recorder| {
699                record.push("before");
700                CallbackResult::Continue
701            })
702            .with_conditions(CallbackConditions::default().only(["create"])),
703        );
704        let mut recorder = Recorder::default();
705
706        chain.run(&mut recorder, "save");
707
708        assert!(recorder.log.is_empty());
709    }
710
711    #[test]
712    fn test_only_empty_list_skips_callback() {
713        let mut chain = CallbackChain::new("save");
714        chain.add(
715            Callback::before("audit", |record: &mut Recorder| {
716                record.push("before");
717                CallbackResult::Continue
718            })
719            .with_conditions(CallbackConditions::default().only(Vec::<&str>::new())),
720        );
721        let mut recorder = Recorder::default();
722
723        chain.run(&mut recorder, "save");
724
725        assert!(recorder.log.is_empty());
726    }
727
728    #[test]
729    fn test_except_condition_skips_listed_action() {
730        let mut chain = CallbackChain::new("save");
731        chain.add(
732            Callback::before("audit", |record: &mut Recorder| {
733                record.push("before");
734                CallbackResult::Continue
735            })
736            .with_conditions(CallbackConditions::default().except(["save"])),
737        );
738        let mut recorder = Recorder::default();
739
740        chain.run(&mut recorder, "save");
741
742        assert!(recorder.log.is_empty());
743    }
744
745    #[test]
746    fn test_except_condition_allows_unlisted_action() {
747        let mut chain = CallbackChain::new("save");
748        chain.add(
749            Callback::before("audit", |record: &mut Recorder| {
750                record.push("before");
751                CallbackResult::Continue
752            })
753            .with_conditions(CallbackConditions::default().except(["destroy"])),
754        );
755        let mut recorder = Recorder::default();
756
757        chain.run(&mut recorder, "save");
758
759        assert_eq!(recorder.log, vec!["before"]);
760    }
761
762    #[test]
763    fn test_except_empty_list_allows_callback() {
764        let mut chain = CallbackChain::new("save");
765        chain.add(
766            Callback::before("audit", |record: &mut Recorder| {
767                record.push("before");
768                CallbackResult::Continue
769            })
770            .with_conditions(CallbackConditions::default().except(Vec::<&str>::new())),
771        );
772        let mut recorder = Recorder::default();
773
774        chain.run(&mut recorder, "save");
775
776        assert_eq!(recorder.log, vec!["before"]);
777    }
778
779    #[test]
780    fn test_if_condition_true_runs_callback() {
781        let mut chain = CallbackChain::new("save");
782        chain.add(
783            Callback::before("audit", |record: &mut Recorder| {
784                record.push("before");
785                CallbackResult::Continue
786            })
787            .with_conditions(CallbackConditions::default().if_cond(allow_predicate)),
788        );
789        let mut recorder = Recorder {
790            allow: true,
791            ..Recorder::default()
792        };
793
794        chain.run(&mut recorder, "save");
795
796        assert_eq!(recorder.log, vec!["before"]);
797    }
798
799    #[test]
800    fn test_if_condition_false_skips_callback() {
801        let mut chain = CallbackChain::new("save");
802        chain.add(
803            Callback::before("audit", |record: &mut Recorder| {
804                record.push("before");
805                CallbackResult::Continue
806            })
807            .with_conditions(CallbackConditions::default().if_cond(allow_predicate)),
808        );
809        let mut recorder = Recorder::default();
810
811        chain.run(&mut recorder, "save");
812
813        assert!(recorder.log.is_empty());
814    }
815
816    #[test]
817    fn test_unless_condition_true_skips_callback() {
818        let mut chain = CallbackChain::new("save");
819        chain.add(
820            Callback::before("audit", |record: &mut Recorder| {
821                record.push("before");
822                CallbackResult::Continue
823            })
824            .with_conditions(CallbackConditions::default().unless_cond(block_predicate)),
825        );
826        let mut recorder = Recorder {
827            block: true,
828            ..Recorder::default()
829        };
830
831        chain.run(&mut recorder, "save");
832
833        assert!(recorder.log.is_empty());
834    }
835
836    #[test]
837    fn test_unless_condition_false_runs_callback() {
838        let mut chain = CallbackChain::new("save");
839        chain.add(
840            Callback::before("audit", |record: &mut Recorder| {
841                record.push("before");
842                CallbackResult::Continue
843            })
844            .with_conditions(CallbackConditions::default().unless_cond(block_predicate)),
845        );
846        let mut recorder = Recorder::default();
847
848        chain.run(&mut recorder, "save");
849
850        assert_eq!(recorder.log, vec!["before"]);
851    }
852
853    #[test]
854    fn test_combined_conditions_require_all_checks_to_pass() {
855        let mut chain = CallbackChain::new("save");
856        chain.add(
857            Callback::before("audit", |record: &mut Recorder| {
858                record.push("before");
859                CallbackResult::Continue
860            })
861            .with_conditions(
862                CallbackConditions::default()
863                    .only(["save"])
864                    .if_cond(allow_predicate)
865                    .unless_cond(block_predicate),
866            ),
867        );
868        let mut allowed = Recorder {
869            allow: true,
870            block: false,
871            ..Recorder::default()
872        };
873        let mut blocked = Recorder {
874            allow: true,
875            block: true,
876            ..Recorder::default()
877        };
878
879        chain.run(&mut allowed, "save");
880        chain.run(&mut blocked, "save");
881
882        assert_eq!(allowed.log, vec!["before"]);
883        assert!(blocked.log.is_empty());
884    }
885
886    #[test]
887    fn test_skip_callback_removes_named_callback() {
888        let mut chain = CallbackChain::new("save");
889        chain.add(Callback::before("audit", |record: &mut Recorder| {
890            record.push("before");
891            CallbackResult::Continue
892        }));
893        chain.skip("audit");
894        let mut recorder = Recorder::default();
895
896        chain.run(&mut recorder, "save");
897
898        assert!(recorder.log.is_empty());
899    }
900
901    #[test]
902    fn test_skip_missing_callback_is_a_noop() {
903        let mut chain = CallbackChain::new("save");
904        chain.add(Callback::before("audit", |record: &mut Recorder| {
905            record.push("before");
906            CallbackResult::Continue
907        }));
908        chain.skip("missing");
909        let mut recorder = Recorder::default();
910
911        chain.run(&mut recorder, "save");
912
913        assert_eq!(recorder.log, vec!["before"]);
914    }
915
916    #[test]
917    fn test_reset_callbacks_clears_chain() {
918        let mut chain = CallbackChain::new("save");
919        chain.add(Callback::before("audit", |record: &mut Recorder| {
920            record.push("before");
921            CallbackResult::Continue
922        }));
923        chain.reset();
924        let mut recorder = Recorder::default();
925
926        let result = chain.run(&mut recorder, "save");
927
928        assert_eq!(result, CallbackResult::Continue);
929        assert!(recorder.log.is_empty());
930    }
931
932    #[test]
933    fn test_callbacks_with_same_name_replace_existing_callback() {
934        let mut chain = CallbackChain::new("save");
935        chain.add(Callback::before("audit", |record: &mut Recorder| {
936            record.push("first");
937            CallbackResult::Continue
938        }));
939        chain.add(Callback::before("audit", |record: &mut Recorder| {
940            record.push("second");
941            CallbackResult::Continue
942        }));
943        let mut recorder = Recorder::default();
944
945        chain.run(&mut recorder, "save");
946
947        assert_eq!(recorder.log, vec!["second"]);
948    }
949
950    #[test]
951    fn test_after_callbacks_run_when_before_halts_by_default() {
952        let mut chain = CallbackChain::new("save");
953        chain.add(Callback::before("halt", |record: &mut Recorder| {
954            record.push("before");
955            CallbackResult::Halt
956        }));
957        chain.add(Callback::after("cleanup", |record: &mut Recorder| {
958            record.push("after");
959            CallbackResult::Continue
960        }));
961        let mut recorder = Recorder::default();
962
963        let result = chain.run_with(&mut recorder, "save", |record| {
964            record.push("action");
965            CallbackResult::Continue
966        });
967
968        assert_eq!(result, CallbackResult::Halt);
969        assert_eq!(recorder.log, vec!["before", "after"]);
970    }
971
972    #[test]
973    fn test_configured_chain_can_skip_after_callbacks_when_halted() {
974        let mut chain = CallbackChain::new("save");
975        chain.set_run_after_on_halt(false);
976        chain.add(Callback::before("halt", |record: &mut Recorder| {
977            record.push("before");
978            CallbackResult::Halt
979        }));
980        chain.add(Callback::after("cleanup", |record: &mut Recorder| {
981            record.push("after");
982            CallbackResult::Continue
983        }));
984        let mut recorder = Recorder::default();
985
986        let result = chain.run_with(&mut recorder, "save", |record| {
987            record.push("action");
988            CallbackResult::Continue
989        });
990
991        assert_eq!(result, CallbackResult::Halt);
992        assert_eq!(recorder.log, vec!["before"]);
993    }
994
995    #[test]
996    fn test_run_with_propagates_action_result() {
997        let chain = CallbackChain::<Recorder>::new("save");
998        let mut recorder = Recorder::default();
999
1000        let result = chain.run_with(&mut recorder, "save", |record| {
1001            record.push("action");
1002            CallbackResult::Halt
1003        });
1004
1005        assert_eq!(result, CallbackResult::Halt);
1006        assert_eq!(recorder.log, vec!["action"]);
1007    }
1008
1009    #[test]
1010    fn test_around_callback_can_return_continue_without_running_inner_action() {
1011        let mut chain = CallbackChain::new("save");
1012        chain.add(Callback::around("skip", |record: &mut Recorder, _next| {
1013            record.push("around-only");
1014            CallbackResult::Continue
1015        }));
1016        chain.add(Callback::after("cleanup", |record: &mut Recorder| {
1017            record.push("after");
1018            CallbackResult::Continue
1019        }));
1020        let mut recorder = Recorder::default();
1021
1022        let result = chain.run_with(&mut recorder, "save", |record| {
1023            record.push("action");
1024            record.action_ran = true;
1025            CallbackResult::Continue
1026        });
1027
1028        assert_eq!(result, CallbackResult::Continue);
1029        assert_eq!(recorder.log, vec!["around-only", "after"]);
1030        assert!(!recorder.action_ran);
1031    }
1032    use std::sync::{
1033        Arc,
1034        atomic::{AtomicUsize, Ordering},
1035    };
1036
1037    #[test]
1038    fn test_chain_metadata_reports_name_and_empty_state() {
1039        let chain = CallbackChain::<Recorder>::new("save");
1040
1041        assert_eq!(chain.name(), "save");
1042        assert_eq!(chain.len(), 0);
1043        assert!(chain.is_empty());
1044    }
1045
1046    #[test]
1047    fn test_adding_callback_updates_length_and_empty_state() {
1048        let mut chain = CallbackChain::new("save");
1049        chain.add(Callback::before("audit", |record: &mut Recorder| {
1050            record.push("before");
1051            CallbackResult::Continue
1052        }));
1053
1054        assert_eq!(chain.len(), 1);
1055        assert!(!chain.is_empty());
1056    }
1057
1058    #[test]
1059    fn test_skip_removes_named_around_callback() {
1060        let mut chain = CallbackChain::new("save");
1061        chain.add(Callback::around(
1062            "wrapper",
1063            |record: &mut Recorder, next| {
1064                record.push("around-before");
1065                let result = next(record);
1066                record.push("around-after");
1067                result
1068            },
1069        ));
1070        chain.skip("wrapper");
1071        let mut recorder = Recorder::default();
1072
1073        let result = chain.run_with(&mut recorder, "save", |record| {
1074            record.push("action");
1075            CallbackResult::Continue
1076        });
1077
1078        assert_eq!(result, CallbackResult::Continue);
1079        assert_eq!(recorder.log, vec!["action"]);
1080    }
1081
1082    #[test]
1083    fn test_skip_preserves_order_of_remaining_callbacks() {
1084        let mut chain = CallbackChain::new("save");
1085        chain.add(Callback::before("first", |record: &mut Recorder| {
1086            record.push("first");
1087            CallbackResult::Continue
1088        }));
1089        chain.add(Callback::before("skip", |record: &mut Recorder| {
1090            record.push("skip");
1091            CallbackResult::Continue
1092        }));
1093        chain.add(Callback::before("third", |record: &mut Recorder| {
1094            record.push("third");
1095            CallbackResult::Continue
1096        }));
1097        chain.skip("skip");
1098        let mut recorder = Recorder::default();
1099
1100        chain.run(&mut recorder, "save");
1101
1102        assert_eq!(recorder.log, vec!["first", "third"]);
1103    }
1104
1105    #[test]
1106    fn test_reset_clears_all_callback_kinds_and_length() {
1107        let mut chain = CallbackChain::new("save");
1108        chain.add(Callback::before("before", |record: &mut Recorder| {
1109            record.push("before");
1110            CallbackResult::Continue
1111        }));
1112        chain.add(Callback::around("around", |record: &mut Recorder, next| {
1113            record.push("around-before");
1114            let result = next(record);
1115            record.push("around-after");
1116            result
1117        }));
1118        chain.add(Callback::after("after", |record: &mut Recorder| {
1119            record.push("after");
1120            CallbackResult::Continue
1121        }));
1122
1123        assert_eq!(chain.len(), 3);
1124        chain.reset();
1125
1126        let mut recorder = Recorder::default();
1127        let result = chain.run_with(&mut recorder, "save", |record| {
1128            record.push("action");
1129            CallbackResult::Continue
1130        });
1131
1132        assert_eq!(result, CallbackResult::Continue);
1133        assert_eq!(chain.len(), 0);
1134        assert!(chain.is_empty());
1135        assert_eq!(recorder.log, vec!["action"]);
1136    }
1137
1138    #[test]
1139    fn test_only_condition_matches_any_listed_action() {
1140        let mut chain = CallbackChain::new("save");
1141        chain.add(
1142            Callback::before("audit", |record: &mut Recorder| {
1143                record.push("before");
1144                CallbackResult::Continue
1145            })
1146            .with_conditions(CallbackConditions::new().only(["create", "save"])),
1147        );
1148
1149        let mut save_recorder = Recorder::default();
1150        let mut create_recorder = Recorder::default();
1151
1152        chain.run(&mut save_recorder, "save");
1153        chain.run(&mut create_recorder, "create");
1154
1155        assert_eq!(save_recorder.log, vec!["before"]);
1156        assert_eq!(create_recorder.log, vec!["before"]);
1157    }
1158
1159    #[test]
1160    fn test_except_condition_skips_each_listed_action() {
1161        let mut chain = CallbackChain::new("save");
1162        chain.add(
1163            Callback::before("audit", |record: &mut Recorder| {
1164                record.push("before");
1165                CallbackResult::Continue
1166            })
1167            .with_conditions(CallbackConditions::new().except(["create", "destroy"])),
1168        );
1169
1170        let mut create_recorder = Recorder::default();
1171        let mut destroy_recorder = Recorder::default();
1172        let mut save_recorder = Recorder::default();
1173
1174        chain.run(&mut create_recorder, "create");
1175        chain.run(&mut destroy_recorder, "destroy");
1176        chain.run(&mut save_recorder, "save");
1177
1178        assert!(create_recorder.log.is_empty());
1179        assert!(destroy_recorder.log.is_empty());
1180        assert_eq!(save_recorder.log, vec!["before"]);
1181    }
1182
1183    #[test]
1184    fn test_if_and_unless_conditions_must_both_allow_execution() {
1185        let mut chain = CallbackChain::new("save");
1186        chain.add(
1187            Callback::before("audit", |record: &mut Recorder| {
1188                record.push("before");
1189                CallbackResult::Continue
1190            })
1191            .with_conditions(
1192                CallbackConditions::new()
1193                    .if_cond(allow_predicate)
1194                    .unless_cond(block_predicate),
1195            ),
1196        );
1197
1198        let mut allowed = Recorder {
1199            allow: true,
1200            block: false,
1201            ..Recorder::default()
1202        };
1203        let mut blocked = Recorder {
1204            allow: true,
1205            block: true,
1206            ..Recorder::default()
1207        };
1208        let mut disallowed = Recorder {
1209            allow: false,
1210            block: false,
1211            ..Recorder::default()
1212        };
1213
1214        chain.run(&mut allowed, "save");
1215        chain.run(&mut blocked, "save");
1216        chain.run(&mut disallowed, "save");
1217
1218        assert_eq!(allowed.log, vec!["before"]);
1219        assert!(blocked.log.is_empty());
1220        assert!(disallowed.log.is_empty());
1221    }
1222
1223    #[test]
1224    fn test_before_halt_skips_around_callbacks_and_action() {
1225        let mut chain = CallbackChain::new("save");
1226        chain.add(Callback::before("halt", |record: &mut Recorder| {
1227            record.push("before");
1228            CallbackResult::Halt
1229        }));
1230        chain.add(Callback::around(
1231            "wrapper",
1232            |record: &mut Recorder, next| {
1233                record.push("around-before");
1234                let result = next(record);
1235                record.push("around-after");
1236                result
1237            },
1238        ));
1239        let mut recorder = Recorder::default();
1240
1241        let result = chain.run_with(&mut recorder, "save", |record| {
1242            record.push("action");
1243            record.action_ran = true;
1244            CallbackResult::Continue
1245        });
1246
1247        assert_eq!(result, CallbackResult::Halt);
1248        assert_eq!(recorder.log, vec!["before"]);
1249        assert!(!recorder.action_ran);
1250    }
1251
1252    #[test]
1253    fn test_around_halt_runs_after_callbacks_by_default() {
1254        let mut chain = CallbackChain::new("save");
1255        chain.add(Callback::around("halt", |record: &mut Recorder, _next| {
1256            record.push("around-before");
1257            CallbackResult::Halt
1258        }));
1259        chain.add(Callback::after("cleanup", |record: &mut Recorder| {
1260            record.push("after");
1261            CallbackResult::Continue
1262        }));
1263        let mut recorder = Recorder::default();
1264
1265        let result = chain.run_with(&mut recorder, "save", |record| {
1266            record.push("action");
1267            CallbackResult::Continue
1268        });
1269
1270        assert_eq!(result, CallbackResult::Halt);
1271        assert_eq!(recorder.log, vec!["around-before", "after"]);
1272    }
1273
1274    #[test]
1275    fn test_around_halt_can_skip_after_callbacks_when_configured() {
1276        let mut chain = CallbackChain::new("save");
1277        chain.set_run_after_on_halt(false);
1278        chain.add(Callback::around("halt", |record: &mut Recorder, _next| {
1279            record.push("around-before");
1280            CallbackResult::Halt
1281        }));
1282        chain.add(Callback::after("cleanup", |record: &mut Recorder| {
1283            record.push("after");
1284            CallbackResult::Continue
1285        }));
1286        let mut recorder = Recorder::default();
1287
1288        let result = chain.run_with(&mut recorder, "save", |record| {
1289            record.push("action");
1290            CallbackResult::Continue
1291        });
1292
1293        assert_eq!(result, CallbackResult::Halt);
1294        assert_eq!(recorder.log, vec!["around-before"]);
1295    }
1296
1297    #[test]
1298    fn test_replacing_callback_with_different_kind_moves_execution_phase() {
1299        let mut chain = CallbackChain::new("save");
1300        chain.add(Callback::before("audit", |record: &mut Recorder| {
1301            record.push("before");
1302            CallbackResult::Continue
1303        }));
1304        chain.add(Callback::after("audit", |record: &mut Recorder| {
1305            record.push("after");
1306            CallbackResult::Continue
1307        }));
1308        let mut recorder = Recorder::default();
1309
1310        let result = chain.run_with(&mut recorder, "save", |record| {
1311            record.push("action");
1312            CallbackResult::Continue
1313        });
1314
1315        assert_eq!(result, CallbackResult::Continue);
1316        assert_eq!(chain.len(), 1);
1317        assert_eq!(recorder.log, vec!["action", "after"]);
1318    }
1319
1320    #[test]
1321    fn test_callback_closure_can_capture_environment() {
1322        let counter = Arc::new(AtomicUsize::new(0));
1323        let captured = Arc::clone(&counter);
1324        let mut chain = CallbackChain::new("save");
1325        chain.add(Callback::before("audit", move |record: &mut Recorder| {
1326            captured.fetch_add(1, Ordering::SeqCst);
1327            record.push("before");
1328            CallbackResult::Continue
1329        }));
1330        let mut recorder = Recorder::default();
1331
1332        chain.run(&mut recorder, "save");
1333
1334        assert_eq!(counter.load(Ordering::SeqCst), 1);
1335        assert_eq!(recorder.log, vec!["before"]);
1336    }
1337
1338    #[test]
1339    fn test_nested_around_callbacks_propagate_halt_outward() {
1340        let mut chain = CallbackChain::new("save");
1341        chain.add(Callback::around("outer", |record: &mut Recorder, next| {
1342            record.push("outer-before");
1343            let result = next(record);
1344            record.push("outer-after");
1345            result
1346        }));
1347        chain.add(Callback::around("inner", |record: &mut Recorder, _next| {
1348            record.push("inner-before");
1349            CallbackResult::Halt
1350        }));
1351        let mut recorder = Recorder::default();
1352
1353        let result = chain.run_with(&mut recorder, "save", |record| {
1354            record.push("action");
1355            CallbackResult::Continue
1356        });
1357
1358        assert_eq!(result, CallbackResult::Halt);
1359        assert_eq!(
1360            recorder.log,
1361            vec!["outer-before", "inner-before", "outer-after"]
1362        );
1363    }
1364
1365    #[test]
1366    fn test_filtered_callbacks_still_run_action() {
1367        let mut chain = CallbackChain::new("save");
1368        chain.add(
1369            Callback::before("audit", |record: &mut Recorder| {
1370                record.push("before");
1371                CallbackResult::Continue
1372            })
1373            .with_conditions(CallbackConditions::new().only(["create"])),
1374        );
1375        let mut recorder = Recorder::default();
1376
1377        let result = chain.run_with(&mut recorder, "save", |record| {
1378            record.push("action");
1379            record.action_ran = true;
1380            CallbackResult::Continue
1381        });
1382
1383        assert_eq!(result, CallbackResult::Continue);
1384        assert_eq!(recorder.log, vec!["action"]);
1385        assert!(recorder.action_ran);
1386    }
1387
1388    #[test]
1389    fn test_reset_then_readd_callback_uses_new_chain_contents() {
1390        let mut chain = CallbackChain::new("save");
1391        chain.add(Callback::before("before", |record: &mut Recorder| {
1392            record.push("before");
1393            CallbackResult::Continue
1394        }));
1395        chain.reset();
1396        chain.add(Callback::after("after", |record: &mut Recorder| {
1397            record.push("after");
1398            CallbackResult::Continue
1399        }));
1400        let mut recorder = Recorder::default();
1401
1402        let result = chain.run_with(&mut recorder, "save", |record| {
1403            record.push("action");
1404            CallbackResult::Continue
1405        });
1406
1407        assert_eq!(result, CallbackResult::Continue);
1408        assert_eq!(chain.len(), 1);
1409        assert_eq!(recorder.log, vec!["action", "after"]);
1410    }
1411}