tracing_causality/
channel.rs

1use crate::Update;
2use std::{
3    collections::BTreeSet,
4    fmt::Debug,
5    sync::{
6        atomic::{AtomicBool, Ordering},
7        Arc,
8    },
9};
10use tracing_core::span::Id;
11
12/// Constructs a channel of [`Trace`] updates bounded at `capacity`.
13pub(crate) fn bounded(id: Id, capacity: usize) -> (Sender, Updates) {
14    let (sender, receiver) = flume::bounded(capacity);
15    let overflow_flag = OverflowFlag::default();
16    let sender = Sender {
17        id,
18        sender,
19        overflow_flag: overflow_flag.clone(),
20    };
21    let updates = Updates {
22        receiver,
23        overflow_flag,
24    };
25    (sender, updates)
26}
27
/// A sticky boolean shared between the two halves of a channel.
///
/// Cloning an `OverflowFlag` produces another handle to the *same* underlying
/// atomic, so raising it through one handle is observable through all others.
#[derive(Default, Clone, Debug)]
struct OverflowFlag {
    flag: Arc<AtomicBool>,
}

impl OverflowFlag {
    /// Borrows the atomic shared by every clone of this flag.
    fn cell(&self) -> &AtomicBool {
        &self.flag
    }

    /// Raises the flag. There is no way to lower it again.
    fn set(&self) {
        self.cell().store(true, Ordering::Release);
    }

    /// Returns `true` once the flag has been raised through any of its handles.
    fn check(&self) -> bool {
        self.cell().load(Ordering::Acquire)
    }
}
42
#[cfg(test)]
mod test_overflow_flag {
    use super::OverflowFlag;

    #[test]
    fn set_and_check() {
        let flag = OverflowFlag::default();
        // the starting value of the flag is `false`
        assert!(!flag.check());
        // calling `set` makes the flag value `true`
        flag.set();
        assert!(flag.check());
        // calling `set` again does NOT toggle the flag to `false`
        flag.set();
        assert!(flag.check());
    }

    #[test]
    fn clones_share_state() {
        let flag = OverflowFlag::default();
        let clone = flag.clone();
        assert!(!clone.check());
        // raising the flag through one handle is visible through the other;
        // this is what lets a `Sender` notify its `Updates`
        clone.set();
        assert!(flag.check());
    }
}
60
61/// A stream of [`Update`]s that affect a [`Trace`][crate::Trace].
62pub struct Updates {
63    receiver: flume::Receiver<Update>,
64    overflow_flag: OverflowFlag,
65}
66
67impl Updates {
68    pub fn is_empty(&self) -> bool {
69        self.receiver.is_empty()
70    }
71
72    /// Returns `true` if the span that these `Updates` are being provided for
73    /// has been closed.
74    ///
75    pub fn is_disconnected(&self) -> bool {
76        self.receiver.is_disconnected()
77    }
78
79    pub fn next(&self) -> Option<Update> {
80        self.receiver.try_recv().ok()
81    }
82
83    pub fn into_iter(self) -> impl Iterator<Item = Update> {
84        self.receiver
85            .into_iter()
86            .take_while(move |_| !self.overflow_flag.check())
87    }
88
89    pub fn into_stream(self) -> impl futures_core::stream::Stream {
90        use futures::stream::StreamExt;
91        self.receiver
92            .into_stream()
93            .take_while(move |_| std::future::ready(!self.overflow_flag.check()))
94    }
95
96    /// Produces an iterator over all [`Update`]s currently sitting in the
97    /// the channel.
98    pub fn drain(&self) -> impl ExactSizeIterator<Item = Update> + '_ {
99        self.receiver.drain()
100    }
101}
102
103impl Default for Updates {
104    fn default() -> Self {
105        let (_, receiver) = flume::bounded(0);
106        let overflow_flag = OverflowFlag::default();
107
108        Self {
109            receiver,
110            overflow_flag,
111        }
112    }
113}
114
115#[derive(Clone, Debug)]
116pub(crate) struct Sender {
117    id: Id,
118    sender: flume::Sender<Update>,
119    overflow_flag: OverflowFlag,
120}
121
122impl Sender {
123    fn try_send(&self, update: Update) -> Result<(), ()> {
124        use flume::TrySendError::{Disconnected, Full};
125
126        self.sender
127            .try_send(update)
128            .map_err(|err| match err {
129                Full(_) => {
130                    // notify receivers that an update has been dropped
131                    self.overflow_flag.set();
132                }
133                Disconnected(_) => {
134                    // don't take any action here if disconnected
135                }
136            })
137            .map(|_| {})
138    }
139
140    pub(crate) fn broadcast(listeners: &mut BTreeSet<Self>, update: Update) {
141        listeners.retain(|listener| listener.try_send(update.clone()).is_ok());
142    }
143}
144
145impl std::hash::Hash for Sender {
146    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
147        self.id.hash(state);
148    }
149}
150
151impl Eq for Sender {}
152
153impl Ord for Sender {
154    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
155        self.id.into_u64().cmp(&other.id.into_u64())
156    }
157}
158
159impl PartialOrd for Sender {
160    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
161        self.id.into_u64().partial_cmp(&other.id.into_u64())
162    }
163}
164
165impl PartialEq for Sender {
166    fn eq(&self, other: &Self) -> bool {
167        self.id.eq(&other.id)
168    }
169}
170
#[cfg(test)]
mod test_sender {
    use super::*;
    use tracing_core::span::Id;
    use tracing_subscriber::{prelude::*, registry::Registry};

    use crate as tracing_causality;

    /// Builds the crate's `Span` record (id + metadata) for a live span handle.
    ///
    /// Panics if the span is disabled and so has no id or metadata.
    fn span_record(span: &tracing::Span) -> tracing_causality::Span {
        tracing_causality::Span {
            id: span.id().unwrap(),
            metadata: span.metadata().unwrap(),
        }
    }

    /// Opens a `cause` span and, inside it, a `consequence` span. Returns both
    /// handles alongside the matching `Update::OpenDirect`.
    ///
    /// The handles are returned so the caller can keep both spans open for the
    /// rest of the test. A subscriber must already be set as the default.
    fn open_direct() -> (tracing::Span, tracing::Span, Update) {
        let cause = tracing::trace_span!("cause");
        let consequence = cause.in_scope(|| tracing::trace_span!("consequence"));
        let update = Update::OpenDirect {
            cause: span_record(&cause),
            consequence: span_record(&consequence),
        };
        (cause, consequence, update)
    }

    /// `Updates::is_disconnected` should produce `true` as soon as the last
    /// sender is dropped, even if there remain updates.
    #[test]
    fn should_disconnect_if_sender_dropped() {
        let _guard = Registry::default().set_default();

        let (sender, updates) = bounded(Id::from_u64(1), 1);
        assert!(!updates.is_disconnected());

        let (_cause, _consequence, update) = open_direct();
        sender
            .try_send(update.clone())
            .expect("sending should succeed");
        assert!(!updates.is_disconnected());

        drop(sender);
        assert!(updates.is_disconnected());
        assert_eq!(updates.next(), Some(update));
    }

    #[test]
    fn try_send_success() {
        let _guard = Registry::default().set_default();

        let (sender, updates) = bounded(Id::from_u64(1), 1);
        let (_cause, _consequence, update) = open_direct();

        assert!(sender.try_send(update.clone()).is_ok());
        assert_eq!(updates.next(), Some(update));
        assert!(updates.is_empty());
    }

    #[test]
    fn try_send_err_disconnected() {
        let _guard = Registry::default().set_default();

        // drop `Updates` immediately
        let (sender, _) = bounded(Id::from_u64(1), 1);
        let (_cause, _consequence, update) = open_direct();

        assert!(sender.try_send(update).is_err());
        // a disconnected receiver is not an overflow
        assert!(!sender.overflow_flag.check());
    }

    #[test]
    fn try_send_err_full() {
        let _guard = Registry::default().set_default();

        // set capacity to 0 to overflow on first send
        let (sender, updates) = bounded(Id::from_u64(1), 0);
        let (_cause, _consequence, update) = open_direct();

        assert!(!updates.overflow_flag.check());
        assert!(sender.try_send(update).is_err());
        assert!(updates.overflow_flag.check());
        assert_eq!(updates.next(), None);
    }
}