tide_tracing_middleware/
lib.rs

1use std::collections::HashSet;
2use std::convert::TryFrom;
3use std::fmt::{self, Display, Error as fmtError, Formatter, Result as fmtResult};
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use futures::AsyncRead;
9use pin_project::{pin_project, pinned_drop};
10use regex::{Regex, RegexSet};
11use tide::http::headers::HeaderName;
12use tide::{Body, Middleware, Next, Request, Response};
13use time::OffsetDateTime;
14use tracing::{error, info, Span};
15use tracing_futures::Instrument;
16
17/// `TracingMiddleware` for logging request and response info to the terminal.
18///
19/// ## Usage
20///
21/// Create `TracingMiddleware` middleware with the specified `format`.
22/// Default `TracingMiddleware` could be created with `default` method, it uses the
23/// default format:
24///
25/// ```plain
26/// %a "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T
27/// ```
28///
29/// ```rust
30/// use tide::{Request, Response, StatusCode};
31/// use tide_tracing_middleware::TracingMiddleware;
32/// use tracing::Level;
33/// use tracing_subscriber::FmtSubscriber;
34///
35/// #[async_std::main]
36/// async fn main() -> tide::Result<()> {
37///     FmtSubscriber::builder().with_max_level(Level::DEBUG).init();
38///
39///     let mut app = tide::new();
40///     app.with(TracingMiddleware::default());
41///     app.at("/index").get(index);
42///     app.listen("127.0.0.1:8080").await?;
43///     Ok(())
44/// }
45///
46/// async fn index(_req: Request<()>) -> tide::Result {
47///     let res = Response::builder(StatusCode::Ok)
48///         .body("hello world!")
49///         .build();
50///     Ok(res)
51/// }
52/// ```
53///
54/// ## Format
55///
56/// - `%%`: The percent sign
57/// - `%a`: Remote IP-address (IP-address of proxy if using reverse proxy)
58/// - `%t`: Time when the request was started to process (in rfc3339 format)
59/// - `%r`: First line of request
60/// - `%s`: Response status code
61/// - `%b`: Size of response body in bytes, not including HTTP headers
62/// - `%T`: Time taken to serve the request, in seconds with floating fraction in .06f format
63/// - `%D`: Time taken to serve the request, in milliseconds
64/// - `%U`: Request URL
65/// - `%M`: Request method
66/// - `%V`: Request HTTP version
67/// - `%Q`: Request URL's query string
68/// - `%{r}a`: Real IP remote address **\***
69/// - `%{FOO}i`: request.headers['FOO']
70/// - `%{FOO}o`: response.headers['FOO']
71/// - `%{FOO}e`: os.environ['FOO']
72/// - `%{FOO}xi`: [custom request replacement](TracingMiddleware::custom_request_replace) labelled "FOO"
73/// - `%{FOO}xo`: [custom response replacement](TracingMiddleware::custom_response_replace) labelled "FOO"
74///
75pub struct TracingMiddleware<State: Clone + Send + Sync + 'static> {
76    inner: Arc<Inner<State>>,
77}
78
79struct Inner<State: Clone + Send + Sync + 'static> {
80    format: Format<State>,
81    exclude: HashSet<String>,
82    exclude_regex: RegexSet,
83    gen_tracing_span: Option<fn(&Request<State>) -> Span>,
84}
85
86impl<State> TracingMiddleware<State>
87where
88    State: Clone + Send + Sync + 'static,
89{
90    /// Create `TracingMiddleware` middleware with the specified `format`.
91    pub fn new(s: &str) -> Self {
92        Self {
93            inner: Arc::new(Inner {
94                format: Format::new(s),
95                exclude: HashSet::new(),
96                exclude_regex: RegexSet::empty(),
97                gen_tracing_span: None,
98            }),
99        }
100    }
101
102    /// Ignore and do not log access info for specified path.
103    pub fn exclude<T: Into<String>>(mut self, path: T) -> Self {
104        Arc::get_mut(&mut self.inner)
105            .unwrap()
106            .exclude
107            .insert(path.into());
108        self
109    }
110
111    /// Ignore and do not log access info for paths that match regex
112    pub fn exclude_regex<T: Into<String>>(mut self, path: T) -> Self {
113        let inner = Arc::get_mut(&mut self.inner).unwrap();
114        let mut patterns = inner.exclude_regex.patterns().to_vec();
115        patterns.push(path.into());
116        let regex_set = RegexSet::new(patterns).unwrap();
117        inner.exclude_regex = regex_set;
118        self
119    }
120
121    /// Register a function that receives a Request and returns a String for use in the
122    /// log line. The label passed as the first argument should match a replacement substring in
123    /// the logger format like `%{label}xi`.
124    ///
125    /// It is convention to print "-" to indicate no output instead of an empty string.
126    pub fn custom_request_replace(
127        mut self,
128        label: &str,
129        f: impl Fn(&Request<State>) -> String + Send + Sync + 'static,
130    ) -> Self {
131        let inner = Arc::get_mut(&mut self.inner).unwrap();
132
133        let ft = inner.format.0.iter_mut().find(
134            |ft| matches!(ft, FormatText::CustomRequest(unit_label, _) if label == unit_label),
135        );
136
137        if let Some(FormatText::CustomRequest(_, request_fn)) = ft {
138            // replace into None or previously registered fn using same label
139            request_fn.replace(CustomRequestFn {
140                inner_fn: Arc::new(f),
141            });
142        } else {
143            // non-printed request replacement function diagnostic
144            error!(
145                "Attempted to register custom request logging function for nonexistent label: {}",
146                label
147            );
148        }
149
150        self
151    }
152
153    /// Register a function that receives a Response and returns a String for use in the
154    /// log line. The label passed as the first argument should match a replacement substring in
155    /// the logger format like `%{label}xo`.
156    ///
157    /// It is convention to print "-" to indicate no output instead of an empty string.
158    pub fn custom_response_replace(
159        mut self,
160        label: &str,
161        f: impl Fn(&Response) -> String + Send + Sync + 'static,
162    ) -> Self {
163        let inner = Arc::get_mut(&mut self.inner).unwrap();
164
165        let ft = inner.format.0.iter_mut().find(
166            |ft| matches!(ft, FormatText::CustomResponse(unit_label, _) if label == unit_label),
167        );
168
169        if let Some(FormatText::CustomResponse(_, response_fn)) = ft {
170            // replace into None or previously registered fn using same label
171            response_fn.replace(CustomResponseFn {
172                inner_fn: Arc::new(f),
173            });
174        } else {
175            // non-printed response replacement function diagnostic
176            error!(
177                "Attempted to register custom response logging function for nonexistent label: {}",
178                label
179            );
180        }
181
182        self
183    }
184
185    pub fn gen_tracing_span(mut self, f: fn(&Request<State>) -> Span) -> Self {
186        let inner = Arc::get_mut(&mut self.inner).unwrap();
187        inner.gen_tracing_span.replace(f);
188        self
189    }
190}
191
192impl<State: Clone + Send + Sync + 'static> Default for TracingMiddleware<State> {
193    /// Create `TracingMiddleware` middleware with format:
194    ///
195    /// ```ignore
196    /// %a "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T
197    /// ```
198    fn default() -> Self {
199        Self {
200            inner: Arc::new(Inner {
201                format: Format::default(),
202                exclude: HashSet::new(),
203                exclude_regex: RegexSet::empty(),
204                gen_tracing_span: None,
205            }),
206        }
207    }
208}
209
210#[tide::utils::async_trait]
211impl<State> Middleware<State> for TracingMiddleware<State>
212where
213    State: Clone + Send + Sync + 'static,
214{
215    async fn handle(&self, request: Request<State>, next: Next<'_, State>) -> tide::Result {
216        let path = request.url().path();
217        if self.inner.exclude.contains(path) || self.inner.exclude_regex.is_match(path) {
218            return Ok(next.run(request).await);
219        }
220
221        let now = OffsetDateTime::now_utc();
222        let mut format = self.inner.format.clone();
223        for unit in &mut format.0 {
224            unit.render_request(now, &request);
225        }
226
227        let span = if let Some(f) = self.inner.gen_tracing_span.as_ref() {
228            f(&request)
229        } else {
230            Span::none()
231        };
232        let cloned_span = span.clone();
233
234        let mut resp = next.run(request).instrument(span).await;
235
236        for unit in &mut format.0 {
237            unit.render_response(&resp);
238        }
239
240        let body = resp.take_body();
241        let body_len = body.len();
242        let body_mime = body.mime().clone();
243        let mut new_body = Body::from_reader(
244            futures::io::BufReader::new(StreamLog {
245                body,
246                format,
247                size: 0,
248                time: now,
249                span: cloned_span,
250            }),
251            body_len,
252        );
253        new_body.set_mime(body_mime);
254
255        resp.set_body(new_body);
256        Ok(resp)
257    }
258}
259
260#[doc(hidden)]
261#[derive(Debug, Clone)]
262struct Format<State: Clone + Send + Sync + 'static>(Vec<FormatText<State>>);
263
264impl<State: Clone + Send + Sync + 'static> Format<State> {
265    /// Create a `Format` from a format string.
266    ///
267    /// Returns `None` if the format string syntax is incorrect.
268    fn new(s: &str) -> Format<State> {
269        let fmt = Regex::new(r"%(\{([A-Za-z0-9\-_]+)\}([aioe]|xi|xo)|[atPrUsbTDMVQ]?)").unwrap();
270
271        let mut idx = 0;
272        let mut results = Vec::new();
273        for cap in fmt.captures_iter(s) {
274            let m = cap.get(0).unwrap();
275            let pos = m.start();
276            if idx != pos {
277                results.push(FormatText::Str(s[idx..pos].to_owned()));
278            }
279            idx = m.end();
280
281            if let Some(key) = cap.get(2) {
282                results.push(match cap.get(3).unwrap().as_str() {
283                    "a" => {
284                        if key.as_str() == "r" {
285                            FormatText::RealIPRemoteAddr
286                        } else {
287                            unreachable!()
288                        }
289                    }
290                    "i" => FormatText::RequestHeader(HeaderName::try_from(key.as_str()).unwrap()),
291                    "o" => FormatText::ResponseHeader(HeaderName::try_from(key.as_str()).unwrap()),
292                    "e" => FormatText::EnvironHeader(key.as_str().to_owned()),
293                    "xi" => FormatText::CustomRequest(key.as_str().to_owned(), None),
294                    "xo" => FormatText::CustomResponse(key.as_str().to_owned(), None),
295                    _ => unreachable!(),
296                })
297            } else {
298                let m = cap.get(1).unwrap();
299                results.push(match m.as_str() {
300                    "%" => FormatText::Percent,
301                    "a" => FormatText::RemoteAddr,
302                    "t" => FormatText::RequestTime,
303                    "r" => FormatText::RequestLine,
304                    "s" => FormatText::ResponseStatus,
305                    "b" => FormatText::ResponseSize,
306                    "M" => FormatText::Method,
307                    "V" => FormatText::Version,
308                    "Q" => FormatText::Query,
309                    "U" => FormatText::UrlPath,
310                    "T" => FormatText::Time,
311                    "D" => FormatText::TimeMillis,
312                    _ => FormatText::Str(m.as_str().to_owned()),
313                });
314            }
315        }
316        if idx != s.len() {
317            results.push(FormatText::Str(s[idx..].to_owned()));
318        }
319
320        Format(results)
321    }
322}
323
324impl<State: Clone + Send + Sync + 'static> Default for Format<State> {
325    /// Return the default formatting style for the `TracingMiddleware`:
326    fn default() -> Self {
327        Format::new(r#"%a "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T"#)
328    }
329}
330
331/// A string of text to be logged. This is either one of the data
332/// fields supported by the `TracingMiddleware`, or a custom `String`.
333#[doc(hidden)]
334#[non_exhaustive]
335#[derive(Debug, Clone)]
336enum FormatText<State: Clone + Send + Sync + 'static> {
337    Str(String),
338    Percent,
339    RequestLine,
340    RequestTime,
341    ResponseStatus,
342    ResponseSize,
343    Time,
344    TimeMillis,
345    RemoteAddr,
346    RealIPRemoteAddr,
347    Method,
348    Version,
349    UrlPath,
350    Query,
351    RequestHeader(HeaderName),
352    ResponseHeader(HeaderName),
353    EnvironHeader(String),
354    CustomRequest(String, Option<CustomRequestFn<State>>),
355    CustomResponse(String, Option<CustomResponseFn>),
356}
357
358#[doc(hidden)]
359#[derive(Clone)]
360pub struct CustomRequestFn<State: Clone + Send + Sync + 'static> {
361    inner_fn: Arc<dyn Fn(&Request<State>) -> String + Sync + Send>,
362}
363
364impl<State> CustomRequestFn<State>
365where
366    State: Clone + Send + Sync + 'static,
367{
368    fn call(&self, req: &Request<State>) -> String {
369        (self.inner_fn)(req)
370    }
371}
372
373impl<State> fmt::Debug for CustomRequestFn<State>
374where
375    State: Clone + Send + Sync + 'static,
376{
377    fn fmt(&self, f: &mut Formatter<'_>) -> fmtResult {
378        f.write_str("custom_request_fn")
379    }
380}
381
382#[doc(hidden)]
383#[derive(Clone)]
384pub struct CustomResponseFn {
385    inner_fn: Arc<dyn Fn(&Response) -> String + Sync + Send>,
386}
387
388impl CustomResponseFn {
389    fn call(&self, resp: &Response) -> String {
390        (self.inner_fn)(resp)
391    }
392}
393
394impl fmt::Debug for CustomResponseFn {
395    fn fmt(&self, f: &mut Formatter<'_>) -> fmtResult {
396        f.write_str("custom_response_fn")
397    }
398}
399
400impl<State> FormatText<State>
401where
402    State: Clone + Send + Sync + 'static,
403{
404    fn render_request(&mut self, now: OffsetDateTime, req: &Request<State>) {
405        match &*self {
406            FormatText::RequestLine => {
407                *self = if let Some(query_str) = req.url().query() {
408                    FormatText::Str(format!(
409                        "{} {}?{} {}",
410                        req.method(),
411                        req.url().path(),
412                        query_str,
413                        req.version().as_ref().map_or("?", |v| v.as_ref())
414                    ))
415                } else {
416                    FormatText::Str(format!(
417                        "{} {} {}",
418                        req.method(),
419                        req.url().path(),
420                        req.version().as_ref().map_or("?", |v| v.as_ref())
421                    ))
422                };
423            }
424            FormatText::Method => *self = FormatText::Str(req.method().to_string()),
425            FormatText::Version => {
426                *self = FormatText::Str(
427                    req.version()
428                        .as_ref()
429                        .map_or("?".to_owned(), |v| v.to_string()),
430                )
431            }
432            FormatText::Query => {
433                *self = FormatText::Str(req.url().query().map_or("-".to_owned(), |v| v.to_string()))
434            }
435            FormatText::UrlPath => *self = FormatText::Str(req.url().path().to_string()),
436            FormatText::RequestTime => *self = FormatText::Str(now.format("%Y-%m-%dT%H:%M:%S")),
437            FormatText::RequestHeader(ref name) => {
438                let s = if let Some(val) = req.header(name) {
439                    if let Some(v) = val.get(0) {
440                        v.as_str()
441                    } else {
442                        "_"
443                    }
444                } else {
445                    "-"
446                };
447                *self = FormatText::Str(s.to_string());
448            }
449            FormatText::RemoteAddr => {
450                *self = if let Some(addr) = req.remote() {
451                    FormatText::Str(addr.to_string())
452                } else {
453                    FormatText::Str("-".to_string())
454                };
455            }
456            FormatText::RealIPRemoteAddr => {
457                *self = if let Some(remote) = req.peer_addr() {
458                    FormatText::Str(remote.to_string())
459                } else {
460                    FormatText::Str("-".to_string())
461                };
462            }
463            FormatText::CustomRequest(_, request_fn) => {
464                *self = match request_fn {
465                    Some(f) => FormatText::Str(f.call(req)),
466                    None => FormatText::Str("-".to_owned()),
467                };
468            }
469            _ => (),
470        }
471    }
472
473    fn render_response(&mut self, resp: &Response) {
474        match &*self {
475            FormatText::ResponseStatus => {
476                *self = FormatText::Str(format!("{}", resp.status() as u16))
477            }
478            FormatText::ResponseHeader(name) => {
479                let s = if let Some(val) = resp.header(name) {
480                    if let Some(v) = val.get(0) {
481                        v.as_str()
482                    } else {
483                        "-"
484                    }
485                } else {
486                    "-"
487                };
488                *self = FormatText::Str(s.to_string())
489            }
490            FormatText::CustomResponse(_, response_fn) => {
491                *self = match response_fn {
492                    Some(f) => FormatText::Str(f.call(resp)),
493                    None => FormatText::Str("-".to_owned()),
494                };
495            }
496            _ => (),
497        }
498    }
499
500    fn render(
501        &self,
502        fmt: &mut Formatter<'_>,
503        size: usize,
504        entry_time: OffsetDateTime,
505    ) -> Result<(), fmtError> {
506        match *self {
507            FormatText::Str(ref string) => fmt.write_str(string),
508            FormatText::Percent => "%".fmt(fmt),
509            FormatText::ResponseSize => size.fmt(fmt),
510            FormatText::Time => {
511                let rt = OffsetDateTime::now_utc() - entry_time;
512                let rt = rt.as_seconds_f64();
513                fmt.write_fmt(format_args!("{:.6}", rt))
514            }
515            FormatText::TimeMillis => {
516                let rt = OffsetDateTime::now_utc() - entry_time;
517                let rt = (rt.whole_nanoseconds() as f64) / 1_000_000.0;
518                fmt.write_fmt(format_args!("{:.6}", rt))
519            }
520            FormatText::EnvironHeader(ref name) => {
521                if let Ok(val) = std::env::var(name) {
522                    fmt.write_fmt(format_args!("{}", val))
523                } else {
524                    "-".fmt(fmt)
525                }
526            }
527            _ => Ok(()),
528        }
529    }
530}
531
532#[pin_project(PinnedDrop)]
533struct StreamLog<State: Clone + Send + Sync + 'static> {
534    #[pin]
535    body: Body,
536    format: Format<State>,
537    size: usize,
538    time: OffsetDateTime,
539    span: Span,
540}
541
542#[pinned_drop]
543impl<State: Clone + Send + Sync + 'static> PinnedDrop for StreamLog<State> {
544    fn drop(self: Pin<&mut Self>) {
545        let render = |fmt: &mut Formatter<'_>| {
546            for unit in &self.format.0 {
547                unit.render(fmt, self.size, self.time)?;
548            }
549            Ok(())
550        };
551        info!(parent: &self.span, "{}", FormatDisplay(&render));
552    }
553}
554
555impl<State> AsyncRead for StreamLog<State>
556where
557    State: Clone + Send + Sync + 'static,
558{
559    fn poll_read(
560        self: Pin<&mut Self>,
561        cx: &mut Context<'_>,
562        buf: &mut [u8],
563    ) -> Poll<std::io::Result<usize>> {
564        let this = self.project();
565        let res = this.body.poll_read(cx, buf);
566        if let Poll::Ready(size) = &res {
567            *this.size += if let Ok(n) = size { *n } else { 0 };
568        }
569        res
570    }
571}
572
/// Adapter that implements `Display` for a closure that writes into a `Formatter`, so the
/// closure's output can be used with `{}` without first being collected into a `String`.
struct FormatDisplay<'a>(&'a dyn Fn(&mut Formatter<'_>) -> Result<(), fmtError>);

impl Display for FormatDisplay<'_> {
    fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), fmtError> {
        let write_into = self.0;
        write_into(fmt)
    }
}