trace_context/lib.rs
1//! Extract and inject [W3C TraceContext](https://w3c.github.io/trace-context/) headers.
2//!
3//! ## Examples
4//! ```
5//! let mut headers = http::HeaderMap::new();
6//! headers.insert(
7//! "traceparent",
8//! "00-0af7651916cd43dd8448eb211c80319c-00f067aa0ba902b7-01".parse().unwrap()
9//! );
10//!
11//! let context = trace_context::TraceContext::extract(&headers).unwrap();
12//!
13//! assert_eq!(context.trace_id(), u128::from_str_radix("0af7651916cd43dd8448eb211c80319c", 16).unwrap());
14//! assert_eq!(context.parent_id(), u64::from_str_radix("00f067aa0ba902b7", 16).ok());
15//! assert_eq!(context.sampled(), true);
16//! ```
17
18#![deny(unsafe_code)]
19
20use rand::Rng;
21use std::fmt;
22
/// A W3C trace context: the fields carried by a `traceparent` header, plus this
/// context's own randomly generated identifier.
///
/// Every field is a plain integer, so cloning and comparing contexts is cheap.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TraceContext {
    /// Identifier of this context; written into the parent-id position when the
    /// context is formatted as a `traceparent` value.
    id: u64,
    /// `traceparent` version field.
    version: u8,
    /// 128-bit trace identifier, copied unchanged into child contexts.
    trace_id: u128,
    /// Identifier of the parent context, or `None` for a freshly created root.
    parent_id: Option<u64>,
    /// Trace-flags byte; bit 0 is the "sampled" flag.
    flags: u8,
}
32
33impl TraceContext {
34 /// Create and return TraceContext object based on `traceparent` HTTP header.
35 ///
36 /// ## Examples
37 /// ```
38 /// let mut headers = http::HeaderMap::new();
39 /// headers.insert("traceparent", "00-0af7651916cd43dd8448eb211c80319c-00f067aa0ba902b7-01".parse().unwrap());
40 ///
41 /// let context = trace_context::TraceContext::extract(&headers).unwrap();
42 ///
43 /// assert_eq!(context.trace_id(), u128::from_str_radix("0af7651916cd43dd8448eb211c80319c", 16).unwrap());
44 /// assert_eq!(context.parent_id(), u64::from_str_radix("00f067aa0ba902b7",
45 /// 16).ok());
46 /// assert_eq!(context.sampled(), true);
47 /// ```
48 pub fn extract(headers: &http::HeaderMap) -> Result<Self, std::num::ParseIntError> {
49 let mut rng = rand::thread_rng();
50
51 let traceparent = match headers.get("traceparent") {
52 Some(header) => header.to_str().unwrap(),
53 None => return Ok(Self::new_root()),
54 };
55
56 let parts: Vec<&str> = traceparent.split('-').collect();
57
58 Ok(Self {
59 id: rng.gen(),
60 version: u8::from_str_radix(parts[0], 16)?,
61 trace_id: u128::from_str_radix(parts[1], 16)?,
62 parent_id: Some(u64::from_str_radix(parts[2], 16)?),
63 flags: u8::from_str_radix(parts[3], 16)?
64 })
65 }
66
67 pub fn new_root() -> Self {
68 let mut rng = rand::thread_rng();
69
70 Self {
71 id: rng.gen(),
72 version: 0,
73 trace_id: rng.gen(),
74 parent_id: None,
75 flags: 1
76 }
77 }
78
79 /// Add the traceparent header to the http headers
80 ///
81 /// ## Examples
82 /// ```
83 /// let mut input_headers = http::HeaderMap::new();
84 /// input_headers.insert("traceparent", "00-00000000000000000000000000000001-0000000000000002-01".parse().unwrap());
85 ///
86 /// let parent = trace_context::TraceContext::extract(&input_headers).unwrap();
87 ///
88 /// let mut output_headers = http::HeaderMap::new();
89 /// parent.inject(&mut output_headers);
90 ///
91 /// let child = trace_context::TraceContext::extract(&output_headers).unwrap();
92 ///
93 /// assert_eq!(child.version(), parent.version());
94 /// assert_eq!(child.trace_id(), parent.trace_id());
95 /// assert_eq!(child.parent_id(), Some(parent.id()));
96 /// assert_eq!(child.flags(), parent.flags());
97 /// ```
98 pub fn inject(&self, headers: &mut http::HeaderMap) {
99 headers.insert("traceparent", format!("{}", self).parse().unwrap());
100 }
101
102 pub fn child(&self) -> Self {
103 let mut rng = rand::thread_rng();
104
105 Self {
106 id: rng.gen(),
107 version: self.version,
108 trace_id: self.trace_id,
109 parent_id: Some(self.id),
110 flags: self.flags,
111 }
112 }
113
114 pub fn id(&self) -> u64 {
115 self.id
116 }
117
118 pub fn version(&self) -> u8 {
119 self.version
120 }
121
122 pub fn trace_id(&self) -> u128 {
123 self.trace_id
124 }
125
126 pub fn parent_id(&self) -> Option<u64> {
127 self.parent_id
128 }
129
130 pub fn flags(&self) -> u8 {
131 self.flags
132 }
133
134 /// Returns true if the trace is sampled
135 ///
136 /// ## Examples
137 /// ```
138 /// let mut headers = http::HeaderMap::new();
139 /// headers.insert("traceparent", "00-00000000000000000000000000000001-0000000000000002-01".parse().unwrap());
140 /// let context = trace_context::TraceContext::extract(&headers).unwrap();
141 /// assert_eq!(context.sampled(), true);
142 /// ```
143 pub fn sampled(&self) -> bool {
144 (self.flags & 0b00000001) == 1
145 }
146
147 /// Change sampled flag
148 ///
149 /// ## Examples
150 /// ```
151 /// let mut context = trace_context::TraceContext::new_root();
152 /// assert_eq!(context.sampled(), true);
153 /// context.set_sampled(false);
154 /// assert_eq!(context.sampled(), false);
155 /// ```
156 pub fn set_sampled(&mut self, sampled: bool) {
157 let x = sampled as u8;
158 self.flags ^= (x ^ self.flags) & (1 << 0);
159 }
160}
161
162impl fmt::Display for TraceContext {
163 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
164 write!(
165 f,
166 "{:02x}-{:032}-{:016x}-{:02x}",
167 self.version, self.trace_id, self.id, self.flags
168 )
169 }
170}
171
#[cfg(test)]
mod test {
    mod extract {
        use crate::TraceContext;

        /// Boxed error so every test can propagate failures with `?`.
        type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
        /// Return type shared by all tests in this module.
        type TestResult = Result<(), BoxError>;

        /// Extract a context from a header map holding only `traceparent`.
        fn context_from(traceparent: &str) -> Result<TraceContext, BoxError> {
            let mut headers = http::HeaderMap::new();
            headers.insert("traceparent", traceparent.parse()?);
            Ok(TraceContext::extract(&headers)?)
        }

        #[test]
        fn default() -> TestResult {
            let context = context_from("00-01-deadbeef-00")?;
            assert_eq!(context.version(), 0);
            assert_eq!(context.trace_id(), 1);
            assert_eq!(context.parent_id().unwrap(), 3735928559);
            assert_eq!(context.flags(), 0);
            assert!(!context.sampled());
            Ok(())
        }

        #[test]
        fn no_header() -> TestResult {
            let context = TraceContext::extract(&http::HeaderMap::new())?;
            assert_eq!(context.version(), 0);
            assert_eq!(context.parent_id(), None);
            assert_eq!(context.flags(), 1);
            assert!(context.sampled());
            Ok(())
        }

        #[test]
        fn not_sampled() -> TestResult {
            assert!(!context_from("00-01-02-00")?.sampled());
            Ok(())
        }

        #[test]
        fn sampled() -> TestResult {
            assert!(context_from("00-01-02-01")?.sampled());
            Ok(())
        }
    }
}