1#![warn(missing_docs)]
33#![allow(clippy::style)]
34
35mod grpc;
36mod headers;
37#[cfg(feature = "opentelemetry")]
38pub mod opentelemetry;
39#[cfg(feature = "datadog")]
40pub mod datadog;
41
42use std::net::IpAddr;
43use core::{cmp, fmt, ptr, task};
44use core::pin::Pin;
45use core::future::Future;
46
47pub use tracing;
48
/// Header name (`x-request-id`) read from incoming requests and echoed back on responses.
pub const REQUEST_ID: http::HeaderName = http::HeaderName::from_static("x-request-id");
/// Constructor for the per-request span; functions generated by [`make_request_spanner!`] have this signature.
pub type MakeSpan = fn() -> tracing::Span;
/// Extracts the client IP from the request head; returning `None` leaves `client.address` unrecorded.
pub type ExtractClientIp = fn(&http::request::Parts) -> Option<IpAddr>;
55
56#[inline]
57fn default_client_ip(_: &http::request::Parts) -> Option<IpAddr> {
58 None
59}
60
/// Wire protocol of an incoming request, as inferred from its `Content-Type`.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum Protocol {
    /// Plain HTTP request (the default when the content type is not gRPC).
    Http,
    /// gRPC request: content type starts with `application/grpc`.
    Grpc,
}

impl Protocol {
    /// Classifies a request from the raw bytes of its `Content-Type` header.
    ///
    /// Any value beginning with `application/grpc` (for example
    /// `application/grpc+proto`) is [`Protocol::Grpc`]; anything else,
    /// including an empty value, is [`Protocol::Http`].
    #[inline]
    pub fn from_content_type(typ: &[u8]) -> Self {
        const GRPC_PREFIX: &[u8] = b"application/grpc";

        match typ.strip_prefix(GRPC_PREFIX) {
            Some(_) => Protocol::Grpc,
            None => Protocol::Http,
        }
    }

    /// Returns the lowercase protocol name: `"http"` or `"grpc"`.
    #[inline]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Protocol::Http => "http",
            Protocol::Grpc => "grpc",
        }
    }
}

impl fmt::Debug for Protocol {
    /// Formats as the quoted protocol name, e.g. `"grpc"`.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <str as fmt::Debug>::fmt(self.as_str(), f)
    }
}

impl fmt::Display for Protocol {
    /// Formats as the bare protocol name, honouring width/alignment flags.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
106
/// Fixed-capacity inline storage for a request id (64 bytes).
type RequestIdBuffer = [u8; 64];

/// Request identifier stored inline, without heap allocation.
///
/// Holds at most 64 bytes; longer ids are truncated when constructed.
#[derive(Clone)]
pub struct RequestId {
    // Inline storage; only the first `len` bytes are meaningful, the rest stay zero.
    buffer: RequestIdBuffer,
    // Number of valid bytes in `buffer`; the constructors keep it <= 64.
    len: u8,
}
117
118impl RequestId {
119 fn from_bytes(bytes: &[u8]) -> Self {
120 let mut buffer: RequestIdBuffer = [0; 64];
121
122 let len = cmp::min(buffer.len(), bytes.len());
123
124 unsafe {
125 ptr::copy_nonoverlapping(bytes.as_ptr(), buffer.as_mut_ptr(), len)
126 };
127
128 Self {
129 buffer,
130 len: len as _,
131 }
132 }
133
134 fn from_uuid(uuid: uuid::Uuid) -> Self {
135 let mut buffer: RequestIdBuffer = [0; 64];
136 let uuid = uuid.as_hyphenated();
137 let len = uuid.encode_lower(&mut buffer).len();
138
139 Self {
140 buffer,
141 len: len as _,
142 }
143 }
144
145 #[inline]
146 pub const fn as_bytes(&self) -> &[u8] {
148 unsafe {
149 core::slice::from_raw_parts(self.buffer.as_ptr(), self.len as _)
150 }
151 }
152
153 #[inline(always)]
154 pub const fn as_str(&self) -> Option<&str> {
156 match core::str::from_utf8(self.as_bytes()) {
157 Ok(header) => Some(header),
158 Err(_) => None,
159 }
160 }
161}
162
163impl fmt::Debug for RequestId {
164 #[inline(always)]
165 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
166 match self.as_str() {
167 Some(id) => fmt::Debug::fmt(id, fmt),
168 None => fmt::Debug::fmt(self.as_bytes(), fmt),
169 }
170 }
171}
172
173impl fmt::Display for RequestId {
174 #[inline(always)]
175 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
176 match self.as_str() {
177 Some(id) => fmt::Display::fmt(id, fmt),
178 None => fmt::Display::fmt("<non-utf8>", fmt),
179 }
180 }
181}
182
/// Defines `pub fn $fn() -> tracing::Span`, a constructor for the server request span.
///
/// The generated function creates a span named `$name` at `$level`. It sets
/// `span.kind = "server"` and declares, initially empty, the fields that this crate
/// records later: `http.request.method`, `url.path`, `url.query`, `url.scheme`,
/// `http.request_id`, `user_agent.original`, `http.headers`,
/// `network.protocol.name`, `network.protocol.version`, `client.address`,
/// `http.response.status_code`, `error.type` and `error.message`.
///
/// Any tokens after the level are appended verbatim to the `tracing::span!` field
/// list, so extra fields can be declared, e.g.
/// `make_request_spanner!(make_span("request", tracing::Level::INFO, custom = tracing::field::Empty));`.
/// Without extra fields: `make_request_spanner!(make_span("request", tracing::Level::INFO));`.
#[macro_export]
macro_rules! make_request_spanner {
    // No extra fields: forward to the general arm with an empty field tail.
    ($fn:ident($name:literal, $level:expr)) => {
        $crate::make_request_spanner!($fn($name, $level,));
    };
    // General arm: `$fields` are spliced after the built-in field list.
    ($fn:ident($name:literal, $level:expr, $($fields:tt)*)) => {
        #[track_caller]
        pub fn $fn() -> $crate::tracing::Span {
            use $crate::tracing::field;

            // NOTE(review): `tracing::span!` puts the level into static callsite
            // metadata, so `$level` presumably has to be a constant `Level` — confirm.
            $crate::tracing::span!(
                $level,
                $name,
                span.kind = "server",
                // Request attributes, filled in when the request arrives.
                http.request.method = field::Empty,
                url.path = field::Empty,
                url.query = field::Empty,
                url.scheme = field::Empty,
                http.request_id = field::Empty,
                user_agent.original = field::Empty,
                http.headers = field::Empty,
                network.protocol.name = field::Empty,
                network.protocol.version = field::Empty,
                client.address = field::Empty,
                // Response attributes, filled in when the response future completes.
                http.response.status_code = field::Empty,
                error.type = field::Empty,
                error.message = field::Empty,
                $(
                    $fields
                )*
            )
        }
    };
}
255
/// Per-request metadata derived while creating the request span.
///
/// `HttpRequestService` inserts it into the request extensions before calling the
/// inner service.
#[derive(Clone, Debug)]
pub struct RequestInfo {
    /// Protocol inferred from the `Content-Type` header.
    pub protocol: Protocol,
    /// Request id taken from the `x-request-id` header when available, otherwise generated.
    pub request_id: RequestId,
    /// Client address reported by the configured [`ExtractClientIp`] function, if any.
    pub client_ip: Option<IpAddr>,
}
268
/// A request span paired with the [`RequestInfo`] extracted while populating it.
pub struct RequestSpan {
    /// Span with the request attributes recorded.
    pub span: tracing::Span,
    /// Metadata derived from the request head.
    pub info: RequestInfo,
}
278
279impl RequestSpan {
280 pub fn new(span: tracing::Span, extract_client_ip: ExtractClientIp, parts: &http::request::Parts) -> Self {
282 let _entered = span.enter();
283
284 let client_ip = (extract_client_ip)(parts);
285 let protocol = parts.headers
286 .get(http::header::CONTENT_TYPE)
287 .map_or(Protocol::Http, |content_type| Protocol::from_content_type(content_type.as_bytes()));
288
289 let request_id = if let Some(request_id) = parts.headers.get(REQUEST_ID) {
290 RequestId::from_bytes(request_id.as_bytes())
291 } else {
292 RequestId::from_uuid(uuid::Uuid::new_v4())
293 };
294
295 if let Some(user_agent) = parts.headers.get(http::header::USER_AGENT).and_then(|header| header.to_str().ok()) {
296 span.record("user_agent.original", user_agent);
297 }
298 span.record("http.request.method", parts.method.as_str());
299 span.record("url.path", parts.uri.path());
300 if let Some(query) = parts.uri.query() {
301 span.record("url.query", query);
302 }
303 if let Some(scheme) = parts.uri.scheme() {
304 span.record("url.scheme", scheme.as_str());
305 }
306 if let Some(request_id) = request_id.as_str() {
307 span.record("http.request_id", &request_id);
308 } else {
309 span.record("http.request_id", request_id.as_bytes());
310 }
311 if let Some(client_ip) = client_ip {
312 span.record("client.address", tracing::field::display(client_ip));
313 }
314 span.record("network.protocol.name", protocol.as_str());
315 if let Protocol::Http = protocol {
316 match parts.version {
317 http::Version::HTTP_09 => span.record("network.protocol.version", 0.9),
318 http::Version::HTTP_10 => span.record("network.protocol.version", 1.0),
319 http::Version::HTTP_11 => span.record("network.protocol.version", 1.1),
320 http::Version::HTTP_2 => span.record("network.protocol.version", 2),
321 http::Version::HTTP_3 => span.record("network.protocol.version", 3),
322 _ => span.record("network.protocol.version", 0),
324 };
325 }
326
327 drop(_entered);
328
329 Self {
330 span,
331 info: RequestInfo {
332 protocol,
333 request_id,
334 client_ip
335 }
336 }
337 }
338}
339
/// `tower` layer that wraps a service in [`HttpRequestService`], which creates and
/// populates a tracing span for every request.
#[derive(Clone)]
pub struct HttpRequestLayer {
    // Creates the span for each request.
    make_span: MakeSpan,
    // Headers whose values are recorded into `http.headers`; an empty list disables it.
    inspect_headers: &'static [&'static http::HeaderName],
    // Determines the client IP recorded as `client.address`.
    extract_client_ip: ExtractClientIp,
}
347
348impl HttpRequestLayer {
349 #[inline]
350 pub fn new(make_span: MakeSpan) -> Self {
352 Self {
353 make_span,
354 inspect_headers: &[],
355 extract_client_ip: default_client_ip
356 }
357 }
358
359 #[inline]
360 pub fn with_inspect_headers(mut self, inspect_headers: &'static [&'static http::HeaderName]) -> Self {
364 self.inspect_headers = inspect_headers;
365 self
366 }
367
368 pub fn with_extract_client_ip(mut self, extract_client_ip: ExtractClientIp) -> Self {
372 self.extract_client_ip = extract_client_ip;
373 self
374 }
375}
376
377impl<S> tower_layer::Layer<S> for HttpRequestLayer {
378 type Service = HttpRequestService<S>;
379 #[inline(always)]
380 fn layer(&self, inner: S) -> Self::Service {
381 HttpRequestService {
382 layer: self.clone(),
383 inner,
384 }
385 }
386}
387
/// Service produced by [`HttpRequestLayer`]: traces each call to the wrapped service.
pub struct HttpRequestService<S> {
    // Copy of the layer configuration (span constructor, inspected headers, IP extractor).
    layer: HttpRequestLayer,
    // The wrapped service.
    inner: S
}
393
394impl<ReqBody, ResBody, S: tower_service::Service<http::Request<ReqBody>, Response = http::Response<ResBody>>> tower_service::Service<http::Request<ReqBody>> for HttpRequestService<S> where S::Error: std::error::Error {
395 type Response = S::Response;
396 type Error = S::Error;
397 type Future = ResponseFut<S::Future>;
398
399 #[inline(always)]
400 fn poll_ready(&mut self, ctx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
401 self.inner.poll_ready(ctx)
402 }
403
404 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
405 let (parts, body) = req.into_parts();
406 let RequestSpan { span, info } = RequestSpan::new((self.layer.make_span)(), self.layer.extract_client_ip, &parts);
407
408 let mut req = http::Request::from_parts(parts, body);
409 #[cfg(feature = "opentelemetry")]
410 opentelemetry::on_request(&span, &req);
411 #[cfg(feature = "datadog")]
412 datadog::on_request(&span, &req);
413
414 let _entered = span.enter();
415 if !self.layer.inspect_headers.is_empty() {
416 span.record("http.headers", tracing::field::debug(headers::InspectHeaders {
417 header_list: self.layer.inspect_headers,
418 headers: req.headers()
419 }));
420 }
421 let request_id = info.request_id.clone();
422 let protocol = info.protocol;
423 req.extensions_mut().insert(info);
424
425 let inner = self.inner.call(req);
426
427 drop(_entered);
428 ResponseFut {
429 inner,
430 span,
431 protocol,
432 request_id
433 }
434 }
435}
436
/// Future returned by [`HttpRequestService`]; records the response outcome on the
/// request span once the inner future resolves.
pub struct ResponseFut<F> {
    // Inner service future; polled in place through a pin projection in `poll`.
    inner: F,
    // Request span, entered on every poll.
    span: tracing::Span,
    // Protocol used to decide how the response status is interpreted.
    protocol: Protocol,
    // Id inserted as `x-request-id` on successful responses when it is a valid header value.
    request_id: RequestId,
}
444
445impl<ResBody, E: std::error::Error, F: Future<Output = Result<http::Response<ResBody>, E>>> Future for ResponseFut<F> {
446 type Output = F::Output;
447
448 fn poll(self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
449 let (fut, span, protocol, request_id) = unsafe {
450 let this = self.get_unchecked_mut();
451 (
452 Pin::new_unchecked(&mut this.inner),
453 &this.span,
454 this.protocol,
455 &this.request_id,
456 )
457 };
458 let _entered = span.enter();
459 match Future::poll(fut, ctx) {
460 task::Poll::Ready(Ok(mut resp)) => {
461 if let Ok(request_id) = http::HeaderValue::from_bytes(request_id.as_bytes()) {
462 resp.headers_mut().insert(REQUEST_ID, request_id);
463 }
464 let status = match protocol {
465 Protocol::Http => resp.status().as_u16(),
466 Protocol::Grpc => match resp.headers().get("grpc-status") {
467 Some(status) => grpc::parse_grpc_status(status.as_bytes()),
468 None => 2,
469 }
470 };
471 span.record("http.response.status_code", status);
472
473 #[cfg(feature = "opentelemetry")]
474 opentelemetry::on_response_ok(&span, &mut resp);
475 #[cfg(feature = "datadog")]
476 datadog::on_response_ok(&span, &mut resp);
477
478 task::Poll::Ready(Ok(resp))
479 }
480 task::Poll::Ready(Err(error)) => {
481 let status = match protocol {
482 Protocol::Http => 500u16,
483 Protocol::Grpc => 13,
484 };
485 span.record("http.response.status_code", status);
486 span.record("error.type", core::any::type_name::<E>());
487 span.record("error.message", tracing::field::display(&error));
488
489 #[cfg(feature = "opentelemetry")]
490 opentelemetry::on_response_error(&span, &error);
491 #[cfg(feature = "datadog")]
492 datadog::on_response_error(&span, &error);
493
494 task::Poll::Ready(Err(error))
495 },
496 task::Poll::Pending => task::Poll::Pending
497 }
498 }
499}