trace_context/
lib.rs

//! Extract and inject [W3C TraceContext](https://w3c.github.io/trace-context/) headers.
//!
//! ## Examples
//! ```
//! let mut headers = http::HeaderMap::new();
//! headers.insert(
//!     "traceparent",
//!     "00-0af7651916cd43dd8448eb211c80319c-00f067aa0ba902b7-01".parse().unwrap()
//! );
//!
//! let context = trace_context::TraceContext::extract(&headers).unwrap();
//!
//! assert_eq!(context.trace_id(), u128::from_str_radix("0af7651916cd43dd8448eb211c80319c", 16).unwrap());
//! assert_eq!(context.parent_id(), u64::from_str_radix("00f067aa0ba902b7", 16).ok());
//! assert_eq!(context.sampled(), true);
//! ```
17
18#![deny(unsafe_code)]
19
20use rand::Rng;
21use std::fmt;
22
/// A W3C Trace Context for one unit of work.
///
/// Holds the fields of a `traceparent` header together with this context's own `id`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TraceContext {
    /// Identifier of this context; it is written into the parent-id position
    /// when the context is injected into outgoing headers.
    id: u64,
    /// `traceparent` version field.
    version: u8,
    /// Identifier shared by every context in the same trace.
    trace_id: u128,
    /// `id` of the parent context, or `None` for a root context.
    parent_id: Option<u64>,
    /// `traceparent` trace-flags byte; bit 0 is the sampled flag.
    flags: u8,
}
32
33impl TraceContext {
34    /// Create and return TraceContext object based on `traceparent` HTTP header.
35    ///
36    /// ## Examples
37    /// ```
38    /// let mut headers = http::HeaderMap::new();
39    /// headers.insert("traceparent", "00-0af7651916cd43dd8448eb211c80319c-00f067aa0ba902b7-01".parse().unwrap());
40    /// 
41    /// let context = trace_context::TraceContext::extract(&headers).unwrap();
42    /// 
43    /// assert_eq!(context.trace_id(), u128::from_str_radix("0af7651916cd43dd8448eb211c80319c", 16).unwrap());
44    /// assert_eq!(context.parent_id(), u64::from_str_radix("00f067aa0ba902b7",
45    /// 16).ok());
46    /// assert_eq!(context.sampled(), true);
47    /// ```
48    pub fn extract(headers: &http::HeaderMap) -> Result<Self, std::num::ParseIntError> {
49        let mut rng = rand::thread_rng();
50
51        let traceparent = match headers.get("traceparent") {
52            Some(header) => header.to_str().unwrap(),
53            None => return Ok(Self::new_root()),
54        };
55
56        let parts: Vec<&str> = traceparent.split('-').collect();
57
58        Ok(Self {
59            id: rng.gen(),
60            version: u8::from_str_radix(parts[0], 16)?,
61            trace_id: u128::from_str_radix(parts[1], 16)?,
62            parent_id: Some(u64::from_str_radix(parts[2], 16)?),
63            flags: u8::from_str_radix(parts[3], 16)?
64        })
65    }
66
67    pub fn new_root() -> Self {
68        let mut rng = rand::thread_rng();
69
70        Self {
71            id: rng.gen(),
72            version: 0,
73            trace_id: rng.gen(),
74            parent_id: None,
75            flags: 1
76        }
77    }
78
79    /// Add the traceparent header to the http headers
80    ///
81    /// ## Examples
82    /// ```
83    /// let mut input_headers = http::HeaderMap::new();
84    /// input_headers.insert("traceparent", "00-00000000000000000000000000000001-0000000000000002-01".parse().unwrap());
85    ///
86    /// let parent = trace_context::TraceContext::extract(&input_headers).unwrap();
87    ///
88    /// let mut output_headers = http::HeaderMap::new();
89    /// parent.inject(&mut output_headers);
90    ///
91    /// let child = trace_context::TraceContext::extract(&output_headers).unwrap();
92    ///
93    /// assert_eq!(child.version(), parent.version());
94    /// assert_eq!(child.trace_id(), parent.trace_id());
95    /// assert_eq!(child.parent_id(), Some(parent.id()));
96    /// assert_eq!(child.flags(), parent.flags());
97    /// ```
98    pub fn inject(&self, headers: &mut http::HeaderMap) {
99        headers.insert("traceparent", format!("{}", self).parse().unwrap());
100    }
101
102    pub fn child(&self) -> Self {
103        let mut rng = rand::thread_rng();
104
105        Self {
106            id: rng.gen(),
107            version: self.version,
108            trace_id: self.trace_id,
109            parent_id: Some(self.id),
110            flags: self.flags,
111        }
112    }
113
114    pub fn id(&self) -> u64 {
115        self.id
116    }
117
118    pub fn version(&self) -> u8 {
119        self.version
120    }
121
122    pub fn trace_id(&self) -> u128 {
123        self.trace_id
124    }
125
126    pub fn parent_id(&self) -> Option<u64> {
127        self.parent_id
128    }
129
130    pub fn flags(&self) -> u8 {
131        self.flags
132    }
133
134    /// Returns true if the trace is sampled
135    ///
136    /// ## Examples
137    /// ```
138    /// let mut headers = http::HeaderMap::new();
139    /// headers.insert("traceparent", "00-00000000000000000000000000000001-0000000000000002-01".parse().unwrap());
140    /// let context = trace_context::TraceContext::extract(&headers).unwrap();
141    /// assert_eq!(context.sampled(), true);
142    /// ```
143    pub fn sampled(&self) -> bool {
144        (self.flags & 0b00000001) == 1
145    }
146
147    /// Change sampled flag
148    ///
149    /// ## Examples
150    /// ```
151    /// let mut context = trace_context::TraceContext::new_root();
152    /// assert_eq!(context.sampled(), true);
153    /// context.set_sampled(false);
154    /// assert_eq!(context.sampled(), false);
155    /// ```
156    pub fn set_sampled(&mut self, sampled: bool) {
157        let x = sampled as u8;
158        self.flags ^= (x ^ self.flags) & (1 << 0);
159    }
160}
161
162impl fmt::Display for TraceContext {
163    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
164        write!(
165            f,
166            "{:02x}-{:032}-{:016x}-{:02x}",
167            self.version, self.trace_id, self.id, self.flags
168        )
169    }
170}
171
#[cfg(test)]
mod test {
    mod extract {
        use crate::TraceContext;

        /// Result type shared by every test so failures can propagate with `?`.
        type TestResult = Result<(), Box<dyn std::error::Error + Send + Sync + 'static>>;

        /// Build a header map holding a single `traceparent` value.
        fn headers_with(traceparent: &str) -> http::HeaderMap {
            let mut headers = http::HeaderMap::new();
            headers.insert("traceparent", traceparent.parse().unwrap());
            headers
        }

        #[test]
        fn default() -> TestResult {
            let context = TraceContext::extract(&headers_with("00-01-deadbeef-00"))?;

            assert_eq!(context.version(), 0);
            assert_eq!(context.trace_id(), 1);
            assert_eq!(context.parent_id().unwrap(), 0xdead_beef);
            assert_eq!(context.flags(), 0);
            assert!(!context.sampled());
            Ok(())
        }

        #[test]
        fn no_header() -> TestResult {
            let context = TraceContext::extract(&http::HeaderMap::new())?;

            assert_eq!(context.version(), 0);
            assert_eq!(context.parent_id(), None);
            assert_eq!(context.flags(), 1);
            assert!(context.sampled());
            Ok(())
        }

        #[test]
        fn not_sampled() -> TestResult {
            let context = TraceContext::extract(&headers_with("00-01-02-00"))?;
            assert!(!context.sampled());
            Ok(())
        }

        #[test]
        fn sampled() -> TestResult {
            let context = TraceContext::extract(&headers_with("00-01-02-01"))?;
            assert!(context.sampled());
            Ok(())
        }
    }
}