witchcraft_log/
mdc.rs

1// Copyright 2021 Palantir Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
//! A Mapped Diagnostic Context (MDC) for Witchcraft loggers.
//!
//! An MDC is a thread-local map of extra parameters, split into safe and unsafe entries. Witchcraft
//! logging implementations should include the contents of the MDC in service logs.
18use conjure_object::Any;
19use pin_project::{pin_project, pinned_drop};
20use serde::Serialize;
21use std::cell::RefCell;
22use std::collections::{hash_map, HashMap};
23use std::future::Future;
24use std::mem;
25use std::pin::Pin;
26use std::sync::{Arc, OnceLock};
27use std::task::{Context, Poll};
28
29static EMPTY: OnceLock<Map> = OnceLock::new();
30
31thread_local! {
32    static MDC: RefCell<Snapshot> = RefCell::new(Snapshot::new());
33}
34
35/// Inserts a new safe parameter into the MDC.
36///
37/// # Panics
38///
39/// Panics if the value cannot be serialized into an [`Any`].
40pub fn insert_safe<T>(key: &'static str, value: T) -> Option<Any>
41where
42    T: Serialize,
43{
44    MDC.with(|v| v.borrow_mut().safe_mut().insert(key, value))
45}
46
47/// Inserts a new unsafe parameter into the MDC.
48///
49/// # Panics
50///
51/// Panics if the value cannot be serialized into an [`Any`].
52pub fn insert_unsafe<T>(key: &'static str, value: T) -> Option<Any>
53where
54    T: Serialize,
55{
56    MDC.with(|v| v.borrow_mut().unsafe_mut().insert(key, value))
57}
58
59/// Removes the specified safe parameter from the MDC.
60pub fn remove_safe(key: &str) -> Option<Any> {
61    MDC.with(|v| v.borrow_mut().safe_mut().remove(key))
62}
63
64/// Removes the specified unsafe parameter from the MDC.
65pub fn remove_unsafe(key: &str) -> Option<Any> {
66    MDC.with(|v| v.borrow_mut().unsafe_mut().remove(key))
67}
68
69/// Takes a snapshot of the MDC.
70///
71/// The snapshot and MDC are not connected - updates to the snapshot will not affect the MDC and vice versa.
72pub fn snapshot() -> Snapshot {
73    MDC.with(|v| v.borrow().clone())
74}
75
76/// Clears the contents of the MDC.
77pub fn clear() {
78    MDC.with(|v| {
79        let mut mdc = v.borrow_mut();
80        mdc.safe_mut().clear();
81        mdc.unsafe_mut().clear();
82    });
83}
84
85/// Overwrites the MDC with a snapshot, returning the previous state.
86pub fn set(snapshot: Snapshot) -> Snapshot {
87    MDC.with(|v| mem::replace(&mut *v.borrow_mut(), snapshot))
88}
89
90/// Swaps the MDC with a snapshot in-place.
91pub fn swap(snapshot: &mut Snapshot) {
92    MDC.with(|v| mem::swap(&mut *v.borrow_mut(), snapshot));
93}
94
95/// Wraps a future with a layer that maintains the MDC across polls.
96///
97/// The future will begin executing with the MDC state at the time this function is called, and
98/// updates to the MDC within calls to `poll` will be propagated forward.
99pub fn bind<F>(future: F) -> Bind<F> {
100    Bind {
101        future: Some(future),
102        snapshot: snapshot(),
103    }
104}
105
106/// Creates a guard object which will reset the MDC to the state it was previously in on drop.
107pub fn scope() -> Scope {
108    Scope { old: snapshot() }
109}
110
111/// A map of MDC entries.
112#[derive(Clone, Debug, PartialEq, Eq)]
113pub struct Map {
114    map: Arc<HashMap<&'static str, Any>>,
115}
116
117impl Default for Map {
118    #[inline]
119    fn default() -> Self {
120        EMPTY
121            .get_or_init(|| Map {
122                map: Arc::new(HashMap::new()),
123            })
124            .clone()
125    }
126}
127
128impl Map {
129    /// Returns a new, empty map.
130    #[inline]
131    pub fn new() -> Self {
132        Map::default()
133    }
134
135    /// Removes all entries from the map.
136    #[inline]
137    pub fn clear(&mut self) {
138        // try to preserve capacity if we're the unique owner
139        match Arc::get_mut(&mut self.map) {
140            Some(map) => map.clear(),
141            None => *self = Map::new(),
142        }
143    }
144
145    /// Returns the number of entries in the map.
146    #[inline]
147    pub fn len(&self) -> usize {
148        self.map.len()
149    }
150
151    /// Determines if the map is empty.
152    #[inline]
153    pub fn is_empty(&self) -> bool {
154        self.map.is_empty()
155    }
156
157    /// Looks up a value in the map.
158    #[inline]
159    pub fn get(&self, key: &str) -> Option<&Any> {
160        self.map.get(key)
161    }
162
163    /// Determines if the map contains the specified key.
164    #[inline]
165    pub fn contains_key(&self, key: &str) -> bool {
166        self.map.contains_key(key)
167    }
168
169    /// Inserts a new entry into the map, returning the old value corresponding to the key.
170    ///
171    /// # Panics
172    ///
173    /// Panics if the value cannot be serialized into an [`Any`].
174    #[inline]
175    pub fn insert<V>(&mut self, key: &'static str, value: V) -> Option<Any>
176    where
177        V: Serialize,
178    {
179        let value = Any::new(value).expect("value failed to serialize");
180        Arc::make_mut(&mut self.map).insert(key, value)
181    }
182
183    /// Removes an entry from the map, returning its value.
184    #[inline]
185    pub fn remove(&mut self, key: &str) -> Option<Any> {
186        Arc::make_mut(&mut self.map).remove(key)
187    }
188
189    /// Returns an iterator over the entries in the map.
190    #[inline]
191    pub fn iter(&self) -> Iter<'_> {
192        Iter {
193            it: self.map.iter(),
194        }
195    }
196}
197
198impl<'a> IntoIterator for &'a Map {
199    type Item = (&'static str, &'a Any);
200
201    type IntoIter = Iter<'a>;
202
203    #[inline]
204    fn into_iter(self) -> Self::IntoIter {
205        self.iter()
206    }
207}
208
209/// An iterator over the entries in a [`Map`].
210pub struct Iter<'a> {
211    it: hash_map::Iter<'a, &'static str, Any>,
212}
213
214impl<'a> Iterator for Iter<'a> {
215    type Item = (&'static str, &'a Any);
216
217    #[inline]
218    fn next(&mut self) -> Option<Self::Item> {
219        self.it.next().map(|(k, v)| (*k, v))
220    }
221
222    #[inline]
223    fn size_hint(&self) -> (usize, Option<usize>) {
224        self.it.size_hint()
225    }
226}
227
228impl ExactSizeIterator for Iter<'_> {
229    #[inline]
230    fn len(&self) -> usize {
231        self.it.len()
232    }
233}
234
235/// A portable snapshot of the MDC.
236#[derive(Clone, Default, Debug, PartialEq, Eq)]
237pub struct Snapshot {
238    safe: Map,
239    unsafe_: Map,
240}
241
242impl Snapshot {
243    /// Returns a new, empty snapshot.
244    #[inline]
245    pub fn new() -> Self {
246        Snapshot::default()
247    }
248
249    /// Returns a shared reference to the safe entries in the snapshot.
250    #[inline]
251    pub fn safe(&self) -> &Map {
252        &self.safe
253    }
254
255    /// Returns a mutable reference to the safe entries in the snapshot.
256    #[inline]
257    pub fn safe_mut(&mut self) -> &mut Map {
258        &mut self.safe
259    }
260
261    /// Returns a shared reference to the unsafe entries in the snapshot.
262    #[inline]
263    pub fn unsafe_(&self) -> &Map {
264        &self.unsafe_
265    }
266
267    /// Returns a shared reference to the unsafe entries in the snapshot.
268    #[inline]
269    pub fn unsafe_mut(&mut self) -> &mut Map {
270        &mut self.unsafe_
271    }
272}
273
274/// A guard object which resets the MDC to an earlier state when it drops.
275pub struct Scope {
276    old: Snapshot,
277}
278
279impl Drop for Scope {
280    fn drop(&mut self) {
281        swap(&mut self.old);
282    }
283}
284
285/// A future which manages the MDC across polls to a delegate.
286#[pin_project(PinnedDrop)]
287pub struct Bind<F> {
288    #[pin]
289    future: Option<F>,
290    snapshot: Snapshot,
291}
292
293#[pinned_drop]
294impl<F> PinnedDrop for Bind<F> {
295    fn drop(self: Pin<&mut Self>) {
296        let mut this = self.project();
297        let _guard = scope_with(this.snapshot);
298        this.future.set(None);
299    }
300}
301
302impl<F> Future for Bind<F>
303where
304    F: Future,
305{
306    type Output = F::Output;
307
308    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
309        let this = self.project();
310        let _guard = scope_with(this.snapshot);
311        this.future.as_pin_mut().unwrap().poll(cx)
312    }
313}
314
315/// Sets the MDC state to the snapshot, resetting it to the state it was previously on drop.
316fn scope_with(snapshot: &mut Snapshot) -> ScopeWith<'_> {
317    swap(snapshot);
318    ScopeWith { snapshot }
319}
320
321struct ScopeWith<'a> {
322    snapshot: &'a mut Snapshot,
323}
324
325impl Drop for ScopeWith<'_> {
326    fn drop(&mut self) {
327        swap(self.snapshot);
328    }
329}
330
#[cfg(test)]
mod test {
    use conjure_object::Any;

    use crate::mdc;

    /// Asserts that the safe `foo` entry of the current MDC holds `expected`.
    fn assert_foo(expected: &str) {
        let current = mdc::snapshot();
        let actual = current.safe().get("foo").unwrap();
        assert_eq!(actual, &Any::new(expected).unwrap());
    }

    #[test]
    fn scope() {
        mdc::clear();
        mdc::insert_safe("foo", "bar");

        let guard = mdc::scope();
        mdc::insert_safe("foo", "baz");
        assert_foo("baz");

        drop(guard);
        assert_foo("bar");
    }

    #[test]
    fn bind() {
        mdc::clear();
        mdc::insert_safe("foo", "bar");

        let future = mdc::bind(async {
            mdc::insert_safe("foo", "baz");
            assert_foo("baz");
        });
        futures_executor::block_on(future);

        assert_foo("bar");
    }
}