tracing_fluent_assertions/
layer.rs

1use std::{any::TypeId, marker::PhantomData, sync::Arc};
2
3use tracing::{span::Attributes, Id, Subscriber};
4use tracing_subscriber::{layer::Context, registry::LookupSpan, Layer};
5
6use crate::{state::State, AssertionRegistry};
7
8/// A [`tracing_subscriber::Layer`] that tracks the lifecycle changes of certain spans based on span
9/// matchers which define which spans to track.
10pub struct AssertionsLayer<S> {
11    state: Arc<State>,
12    _subscriber: PhantomData<fn(S)>,
13}
14
15impl<S> AssertionsLayer<S>
16where
17    S: Subscriber,
18{
19    /// Create a new [`AssertionsLayer`] tied to the given [`AssertionRegistry`].
20    pub fn new(controller: &AssertionRegistry) -> Self {
21        Self {
22            state: Arc::clone(controller.state()),
23            _subscriber: PhantomData,
24        }
25    }
26}
27
28impl<S> Layer<S> for AssertionsLayer<S>
29where
30    S: Subscriber + for<'a> LookupSpan<'a>,
31{
32    fn on_new_span(&self, _attributes: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) {
33        let span = ctx.span(id).expect("span must already exist!");
34        if let Some(entry) = self.state.get_entry(span) {
35            entry.track_created();
36        }
37    }
38
39    fn on_enter(&self, id: &Id, ctx: Context<'_, S>) {
40        let span = ctx.span(id).expect("span must already exist!");
41        if let Some(entry) = self.state.get_entry(span) {
42            entry.track_entered();
43        }
44    }
45
46    fn on_exit(&self, id: &Id, ctx: Context<'_, S>) {
47        let span = ctx.span(id).expect("span must already exist!");
48        if let Some(entry) = self.state.get_entry(span) {
49            entry.track_exited();
50        }
51    }
52
53    fn on_close(&self, id: Id, ctx: Context<'_, S>) {
54        let span = ctx.span(&id).expect("span must already exist!");
55        if let Some(entry) = self.state.get_entry(span) {
56            entry.track_closed();
57        }
58    }
59
60    unsafe fn downcast_raw(&self, id: TypeId) -> Option<*const ()> {
61        match id {
62            id if id == TypeId::of::<Self>() => Some(self as *const _ as *const ()),
63            _ => None,
64        }
65    }
66}