1#![warn(missing_docs)]
25#![allow(clippy::style)]
26
27mod grpc;
28mod headers;
29
30use std::net::IpAddr;
31use core::{cmp, fmt, ptr, task};
32use core::pin::Pin;
33use core::future::Future;
34
35pub use tracing;
36
37pub const REQUEST_ID: http::HeaderName = http::HeaderName::from_static("x-request-id");
39pub type MakeSpan = fn() -> tracing::Span;
41pub type ExtractClientIp = fn(&http::request::Parts) -> Option<IpAddr>;
43
44#[inline]
45fn default_client_ip(_: &http::request::Parts) -> Option<IpAddr> {
46 None
47}
48
/// Wire protocol detected for a request; decides how the response status is interpreted.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum Protocol {
    /// Plain HTTP: the recorded status is the HTTP response status code.
    Http,
    /// gRPC: the recorded status is taken from the `grpc-status` header.
    Grpc,
}

impl Protocol {
    /// Classifies a request by the raw value of its `Content-Type` header.
    ///
    /// Any value starting with `application/grpc` is treated as [`Protocol::Grpc`]. This
    /// includes `application/grpc+proto` and `application/grpc-web`. Media type and
    /// subtype are case-insensitive (RFC 9110 §8.3.1), so the prefix is compared
    /// ASCII-case-insensitively. Everything else, including an empty value, is
    /// [`Protocol::Http`].
    #[inline(always)]
    pub fn from_content_type(typ: &[u8]) -> Self {
        const GRPC_PREFIX: &[u8] = b"application/grpc";

        // `get` yields `None` for inputs shorter than the prefix instead of panicking.
        match typ.get(..GRPC_PREFIX.len()) {
            Some(prefix) if prefix.eq_ignore_ascii_case(GRPC_PREFIX) => Self::Grpc,
            _ => Self::Http,
        }
    }

    /// Returns the variant name (`"Http"` or `"Grpc"`), used as the `protocol` span field.
    #[inline(always)]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Grpc => "Grpc",
            Self::Http => "Http",
        }
    }
}

impl fmt::Debug for Protocol {
    /// Formats as the quoted variant name, e.g. `"Grpc"`.
    #[inline(always)]
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self.as_str(), fmt)
    }
}

impl fmt::Display for Protocol {
    /// Formats as the bare variant name, e.g. `Grpc`.
    #[inline(always)]
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self.as_str(), fmt)
    }
}
94
/// Fixed-size inline storage for a request id, so ids never need a heap allocation.
type RequestIdBuffer = [u8; 64];

/// Request identifier stored inline (at most 64 bytes).
///
/// When taken verbatim from an incoming header it may contain non-UTF-8 bytes,
/// which is why [`RequestId::as_str`] returns an `Option`.
#[derive(Clone)]
pub struct RequestId {
    /// Backing storage; only the first `len` bytes belong to the id.
    buffer: RequestIdBuffer,
    /// Number of meaningful bytes in `buffer`; the constructors keep it at most 64.
    len: u8,
}
105
106impl RequestId {
107 fn from_bytes(bytes: &[u8]) -> Self {
108 let mut buffer: RequestIdBuffer = [0; 64];
109
110 let len = cmp::min(buffer.len(), bytes.len());
111
112 unsafe {
113 ptr::copy_nonoverlapping(bytes.as_ptr(), buffer.as_mut_ptr(), len)
114 };
115
116 Self {
117 buffer,
118 len: len as _,
119 }
120 }
121
122 fn from_uuid(uuid: uuid::Uuid) -> Self {
123 let mut buffer: RequestIdBuffer = [0; 64];
124 let uuid = uuid.as_hyphenated();
125 let len = uuid.encode_lower(&mut buffer).len();
126
127 Self {
128 buffer,
129 len: len as _,
130 }
131 }
132
133 #[inline]
134 pub const fn as_bytes(&self) -> &[u8] {
136 unsafe {
137 core::slice::from_raw_parts(self.buffer.as_ptr(), self.len as _)
138 }
139 }
140
141 #[inline(always)]
142 pub const fn as_str(&self) -> Option<&str> {
144 match core::str::from_utf8(self.as_bytes()) {
145 Ok(header) => Some(header),
146 Err(_) => None,
147 }
148 }
149}
150
151impl fmt::Debug for RequestId {
152 #[inline(always)]
153 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
154 match self.as_str() {
155 Some(id) => fmt::Debug::fmt(id, fmt),
156 None => fmt::Debug::fmt(self.as_bytes(), fmt),
157 }
158 }
159}
160
161impl fmt::Display for RequestId {
162 #[inline(always)]
163 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
164 match self.as_str() {
165 Some(id) => fmt::Display::fmt(id, fmt),
166 None => fmt::Display::fmt("<non-utf8>", fmt),
167 }
168 }
169}
170
/// Defines `pub fn $fn() -> tracing::Span` that creates a span named `$name` at level `$level`.
///
/// Every field this crate records later (`http.method`, `http.url`, `http.request_id`,
/// `http.user_agent`, `http.version`, `http.headers`, `protocol`, `http.client.ip`,
/// `http.status_code`, `error.message`) is declared up front as empty. `tracing` only
/// accepts `record` calls for fields declared when the span was created. The generated
/// function has the `MakeSpan` signature.
///
/// ```ignore
/// make_request_spanner!(make_span("request", tracing::Level::INFO));
/// ```
#[macro_export]
macro_rules! make_request_spanner {
    ($fn:ident($name:literal, $level:expr)) => {
        #[track_caller]
        pub fn $fn() -> $crate::tracing::Span {
            $crate::tracing::span!(
                $level,
                $name,
                http.method = $crate::tracing::field::Empty,
                http.url = $crate::tracing::field::Empty,
                http.request_id = $crate::tracing::field::Empty,
                http.user_agent = $crate::tracing::field::Empty,
                http.version = $crate::tracing::field::Empty,
                http.headers = $crate::tracing::field::Empty,
                protocol = $crate::tracing::field::Empty,
                http.client.ip = $crate::tracing::field::Empty,
                http.status_code = $crate::tracing::field::Empty,
                error.message = $crate::tracing::field::Empty,
            )
        }
    };
}
209
210#[derive(Clone, Debug)]
211pub struct RequestInfo {
215 pub protocol: Protocol,
217 pub request_id: RequestId,
219 pub client_ip: Option<IpAddr>,
221}
222
223pub struct RequestSpan {
227 pub span: tracing::Span,
229 pub info: RequestInfo,
231}
232
233impl RequestSpan {
234 pub fn new(span: tracing::Span, extract_client_ip: ExtractClientIp, parts: &http::request::Parts) -> Self {
236 let _entered = span.enter();
237
238 let client_ip = (extract_client_ip)(parts);
239 let protocol = parts.headers
240 .get(http::header::CONTENT_TYPE)
241 .map_or(Protocol::Http, |content_type| Protocol::from_content_type(content_type.as_bytes()));
242
243 let request_id = if let Some(request_id) = parts.headers.get(REQUEST_ID) {
244 RequestId::from_bytes(request_id.as_bytes())
245 } else {
246 RequestId::from_uuid(uuid::Uuid::new_v4())
247 };
248
249 if let Some(user_agent) = parts.headers.get(http::header::USER_AGENT).and_then(|header| header.to_str().ok()) {
250 span.record("http.user_agent", user_agent);
251 }
252 span.record("http.method", parts.method.as_str());
253 span.record("http.version", tracing::field::debug(&parts.version));
254 span.record("http.url", parts.uri.path());
255 if let Some(request_id) = request_id.as_str() {
256 span.record("http.request_id", &request_id);
257 } else {
258 span.record("http.request_id", request_id.as_bytes());
259 }
260 if let Some(client_ip) = client_ip {
261 span.record("http.client.ip", tracing::field::display(client_ip));
262 }
263 span.record("protocol", protocol.as_str());
264
265 drop(_entered);
266
267 Self {
268 span,
269 info: RequestInfo {
270 protocol,
271 request_id,
272 client_ip
273 }
274 }
275 }
276}
277
278#[derive(Clone)]
279pub struct HttpRequestLayer {
281 make_span: MakeSpan,
282 inspect_headers: &'static [&'static http::HeaderName],
283 extract_client_ip: ExtractClientIp,
284}
285
286impl HttpRequestLayer {
287 #[inline]
288 pub fn new(make_span: MakeSpan) -> Self {
290 Self {
291 make_span,
292 inspect_headers: &[],
293 extract_client_ip: default_client_ip
294 }
295 }
296
297 #[inline]
298 pub fn with_inspect_headers(mut self, inspect_headers: &'static [&'static http::HeaderName]) -> Self {
302 self.inspect_headers = inspect_headers;
303 self
304 }
305
306 pub fn with_extract_client_ip(mut self, extract_client_ip: ExtractClientIp) -> Self {
310 self.extract_client_ip = extract_client_ip;
311 self
312 }
313}
314
315impl<S> tower_layer::Layer<S> for HttpRequestLayer {
316 type Service = HttpRequestService<S>;
317 #[inline(always)]
318 fn layer(&self, inner: S) -> Self::Service {
319 HttpRequestService {
320 layer: self.clone(),
321 inner,
322 }
323 }
324}
325
326pub struct HttpRequestService<S> {
328 layer: HttpRequestLayer,
329 inner: S
330}
331
332impl<ReqBody, ResBody, S: tower_service::Service<http::Request<ReqBody>, Response = http::Response<ResBody>>> tower_service::Service<http::Request<ReqBody>> for HttpRequestService<S> where S::Error: std::error::Error {
333 type Response = S::Response;
334 type Error = S::Error;
335 type Future = ResponseFut<S::Future>;
336
337 #[inline(always)]
338 fn poll_ready(&mut self, ctx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
339 self.inner.poll_ready(ctx)
340 }
341
342 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
343 let (parts, body) = req.into_parts();
344 let RequestSpan { span, info } = RequestSpan::new((self.layer.make_span)(), self.layer.extract_client_ip, &parts);
345
346 let _entered = span.enter();
347 if !self.layer.inspect_headers.is_empty() {
348 span.record("http.headers", tracing::field::debug(headers::InspectHeaders {
349 header_list: self.layer.inspect_headers,
350 headers: &parts.headers
351 }));
352 }
353 let request_id = info.request_id.clone();
354 let protocol = info.protocol;
355 let mut req = http::Request::from_parts(parts, body);
356 req.extensions_mut().insert(info);
357 let inner = self.inner.call(req);
358
359 drop(_entered);
360 ResponseFut {
361 inner,
362 span,
363 protocol,
364 request_id
365 }
366 }
367}
368
369pub struct ResponseFut<F> {
371 inner: F,
372 span: tracing::Span,
373 protocol: Protocol,
374 request_id: RequestId,
375}
376
/// Polls the inner service's future inside the request span and records its outcome.
///
/// On success, the request id is echoed in the `x-request-id` response header when it is
/// a valid header value. `http.status_code` is then set to the HTTP status for
/// [`Protocol::Http`], or to the parsed `grpc-status` header for [`Protocol::Grpc`]
/// (`2`, gRPC UNKNOWN, when that header is absent). On error, `http.status_code` is set
/// to `500` (HTTP) or `13` (gRPC INTERNAL) and `error.message` to the error's `Display`
/// output.
impl<ResBody, E: std::error::Error, F: Future<Output = Result<http::Response<ResBody>, E>>> Future for ResponseFut<F> {
    type Output = F::Output;

    fn poll(self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
        // Manual pin projection: `inner` is structurally pinned, the other fields are not.
        // SAFETY: `inner` is only ever accessed through a `Pin` and is never moved out.
        // `ResponseFut` has no `Drop` impl and no manual `Unpin` impl in this file, so
        // the auto-`Unpin` impl requires `F: Unpin`. `span`, `protocol` and `request_id`
        // are only borrowed or copied, never pinned.
        let (fut, span, protocol, request_id) = unsafe {
            let this = self.get_unchecked_mut();
            (
                Pin::new_unchecked(&mut this.inner),
                &this.span,
                this.protocol,
                &this.request_id,
            )
        };
        // Enter the request span for this poll so events from the inner future are attributed to it.
        let _entered = span.enter();
        match Future::poll(fut, ctx) {
            task::Poll::Ready(Ok(mut resp)) => {
                // Echo the id back; `insert` replaces any `x-request-id` the inner service set.
                // Skipped silently when the stored bytes are not a valid header value.
                if let Ok(request_id) = http::HeaderValue::from_bytes(request_id.as_bytes()) {
                    resp.headers_mut().insert(REQUEST_ID, request_id);
                }
                let status = match protocol {
                    Protocol::Http => resp.status().as_u16(),
                    // NOTE(review): gRPC servers normally send `grpc-status` in trailers (only
                    // "Trailers-Only" responses carry it in headers), so successful gRPC
                    // responses will likely be recorded as 2 here. Confirm against the server in use.
                    Protocol::Grpc => match resp.headers().get("grpc-status") {
                        Some(status) => grpc::parse_grpc_status(status.as_bytes()),
                        // 2 = gRPC UNKNOWN.
                        None => 2,
                    }
                };
                span.record("http.status_code", status);

                task::Poll::Ready(Ok(resp))
            }
            task::Poll::Ready(Err(error)) => {
                // No response was produced: report 500 for HTTP, 13 (INTERNAL) for gRPC.
                let status = match protocol {
                    Protocol::Http => 500u16,
                    Protocol::Grpc => 13,
                };
                span.record("http.status_code", status);
                span.record("error.message", tracing::field::display(&error));
                task::Poll::Ready(Err(error))
            },
            task::Poll::Pending => task::Poll::Pending
        }
    }
}