1use std::collections::HashSet;
2use std::convert::TryFrom;
3use std::fmt::{self, Display, Error as fmtError, Formatter, Result as fmtResult};
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use futures::AsyncRead;
9use pin_project::{pin_project, pinned_drop};
10use regex::{Regex, RegexSet};
11use tide::http::headers::HeaderName;
12use tide::{Body, Middleware, Next, Request, Response};
13use time::OffsetDateTime;
14use tracing::{error, info, Span};
15use tracing_futures::Instrument;
16
17pub struct TracingMiddleware<State: Clone + Send + Sync + 'static> {
76 inner: Arc<Inner<State>>,
77}
78
79struct Inner<State: Clone + Send + Sync + 'static> {
80 format: Format<State>,
81 exclude: HashSet<String>,
82 exclude_regex: RegexSet,
83 gen_tracing_span: Option<fn(&Request<State>) -> Span>,
84}
85
86impl<State> TracingMiddleware<State>
87where
88 State: Clone + Send + Sync + 'static,
89{
90 pub fn new(s: &str) -> Self {
92 Self {
93 inner: Arc::new(Inner {
94 format: Format::new(s),
95 exclude: HashSet::new(),
96 exclude_regex: RegexSet::empty(),
97 gen_tracing_span: None,
98 }),
99 }
100 }
101
102 pub fn exclude<T: Into<String>>(mut self, path: T) -> Self {
104 Arc::get_mut(&mut self.inner)
105 .unwrap()
106 .exclude
107 .insert(path.into());
108 self
109 }
110
111 pub fn exclude_regex<T: Into<String>>(mut self, path: T) -> Self {
113 let inner = Arc::get_mut(&mut self.inner).unwrap();
114 let mut patterns = inner.exclude_regex.patterns().to_vec();
115 patterns.push(path.into());
116 let regex_set = RegexSet::new(patterns).unwrap();
117 inner.exclude_regex = regex_set;
118 self
119 }
120
121 pub fn custom_request_replace(
127 mut self,
128 label: &str,
129 f: impl Fn(&Request<State>) -> String + Send + Sync + 'static,
130 ) -> Self {
131 let inner = Arc::get_mut(&mut self.inner).unwrap();
132
133 let ft = inner.format.0.iter_mut().find(
134 |ft| matches!(ft, FormatText::CustomRequest(unit_label, _) if label == unit_label),
135 );
136
137 if let Some(FormatText::CustomRequest(_, request_fn)) = ft {
138 request_fn.replace(CustomRequestFn {
140 inner_fn: Arc::new(f),
141 });
142 } else {
143 error!(
145 "Attempted to register custom request logging function for nonexistent label: {}",
146 label
147 );
148 }
149
150 self
151 }
152
153 pub fn custom_response_replace(
159 mut self,
160 label: &str,
161 f: impl Fn(&Response) -> String + Send + Sync + 'static,
162 ) -> Self {
163 let inner = Arc::get_mut(&mut self.inner).unwrap();
164
165 let ft = inner.format.0.iter_mut().find(
166 |ft| matches!(ft, FormatText::CustomResponse(unit_label, _) if label == unit_label),
167 );
168
169 if let Some(FormatText::CustomResponse(_, response_fn)) = ft {
170 response_fn.replace(CustomResponseFn {
172 inner_fn: Arc::new(f),
173 });
174 } else {
175 error!(
177 "Attempted to register custom response logging function for nonexistent label: {}",
178 label
179 );
180 }
181
182 self
183 }
184
185 pub fn gen_tracing_span(mut self, f: fn(&Request<State>) -> Span) -> Self {
186 let inner = Arc::get_mut(&mut self.inner).unwrap();
187 inner.gen_tracing_span.replace(f);
188 self
189 }
190}
191
192impl<State: Clone + Send + Sync + 'static> Default for TracingMiddleware<State> {
193 fn default() -> Self {
199 Self {
200 inner: Arc::new(Inner {
201 format: Format::default(),
202 exclude: HashSet::new(),
203 exclude_regex: RegexSet::empty(),
204 gen_tracing_span: None,
205 }),
206 }
207 }
208}
209
210#[tide::utils::async_trait]
211impl<State> Middleware<State> for TracingMiddleware<State>
212where
213 State: Clone + Send + Sync + 'static,
214{
215 async fn handle(&self, request: Request<State>, next: Next<'_, State>) -> tide::Result {
216 let path = request.url().path();
217 if self.inner.exclude.contains(path) || self.inner.exclude_regex.is_match(path) {
218 return Ok(next.run(request).await);
219 }
220
221 let now = OffsetDateTime::now_utc();
222 let mut format = self.inner.format.clone();
223 for unit in &mut format.0 {
224 unit.render_request(now, &request);
225 }
226
227 let span = if let Some(f) = self.inner.gen_tracing_span.as_ref() {
228 f(&request)
229 } else {
230 Span::none()
231 };
232 let cloned_span = span.clone();
233
234 let mut resp = next.run(request).instrument(span).await;
235
236 for unit in &mut format.0 {
237 unit.render_response(&resp);
238 }
239
240 let body = resp.take_body();
241 let body_len = body.len();
242 let body_mime = body.mime().clone();
243 let mut new_body = Body::from_reader(
244 futures::io::BufReader::new(StreamLog {
245 body,
246 format,
247 size: 0,
248 time: now,
249 span: cloned_span,
250 }),
251 body_len,
252 );
253 new_body.set_mime(body_mime);
254
255 resp.set_body(new_body);
256 Ok(resp)
257 }
258}
259
260#[doc(hidden)]
261#[derive(Debug, Clone)]
262struct Format<State: Clone + Send + Sync + 'static>(Vec<FormatText<State>>);
263
264impl<State: Clone + Send + Sync + 'static> Format<State> {
265 fn new(s: &str) -> Format<State> {
269 let fmt = Regex::new(r"%(\{([A-Za-z0-9\-_]+)\}([aioe]|xi|xo)|[atPrUsbTDMVQ]?)").unwrap();
270
271 let mut idx = 0;
272 let mut results = Vec::new();
273 for cap in fmt.captures_iter(s) {
274 let m = cap.get(0).unwrap();
275 let pos = m.start();
276 if idx != pos {
277 results.push(FormatText::Str(s[idx..pos].to_owned()));
278 }
279 idx = m.end();
280
281 if let Some(key) = cap.get(2) {
282 results.push(match cap.get(3).unwrap().as_str() {
283 "a" => {
284 if key.as_str() == "r" {
285 FormatText::RealIPRemoteAddr
286 } else {
287 unreachable!()
288 }
289 }
290 "i" => FormatText::RequestHeader(HeaderName::try_from(key.as_str()).unwrap()),
291 "o" => FormatText::ResponseHeader(HeaderName::try_from(key.as_str()).unwrap()),
292 "e" => FormatText::EnvironHeader(key.as_str().to_owned()),
293 "xi" => FormatText::CustomRequest(key.as_str().to_owned(), None),
294 "xo" => FormatText::CustomResponse(key.as_str().to_owned(), None),
295 _ => unreachable!(),
296 })
297 } else {
298 let m = cap.get(1).unwrap();
299 results.push(match m.as_str() {
300 "%" => FormatText::Percent,
301 "a" => FormatText::RemoteAddr,
302 "t" => FormatText::RequestTime,
303 "r" => FormatText::RequestLine,
304 "s" => FormatText::ResponseStatus,
305 "b" => FormatText::ResponseSize,
306 "M" => FormatText::Method,
307 "V" => FormatText::Version,
308 "Q" => FormatText::Query,
309 "U" => FormatText::UrlPath,
310 "T" => FormatText::Time,
311 "D" => FormatText::TimeMillis,
312 _ => FormatText::Str(m.as_str().to_owned()),
313 });
314 }
315 }
316 if idx != s.len() {
317 results.push(FormatText::Str(s[idx..].to_owned()));
318 }
319
320 Format(results)
321 }
322}
323
324impl<State: Clone + Send + Sync + 'static> Default for Format<State> {
325 fn default() -> Self {
327 Format::new(r#"%a "%r" %s %b "%{Referer}i" "%{User-Agent}i" %T"#)
328 }
329}
330
331#[doc(hidden)]
334#[non_exhaustive]
335#[derive(Debug, Clone)]
336enum FormatText<State: Clone + Send + Sync + 'static> {
337 Str(String),
338 Percent,
339 RequestLine,
340 RequestTime,
341 ResponseStatus,
342 ResponseSize,
343 Time,
344 TimeMillis,
345 RemoteAddr,
346 RealIPRemoteAddr,
347 Method,
348 Version,
349 UrlPath,
350 Query,
351 RequestHeader(HeaderName),
352 ResponseHeader(HeaderName),
353 EnvironHeader(String),
354 CustomRequest(String, Option<CustomRequestFn<State>>),
355 CustomResponse(String, Option<CustomResponseFn>),
356}
357
358#[doc(hidden)]
359#[derive(Clone)]
360pub struct CustomRequestFn<State: Clone + Send + Sync + 'static> {
361 inner_fn: Arc<dyn Fn(&Request<State>) -> String + Sync + Send>,
362}
363
364impl<State> CustomRequestFn<State>
365where
366 State: Clone + Send + Sync + 'static,
367{
368 fn call(&self, req: &Request<State>) -> String {
369 (self.inner_fn)(req)
370 }
371}
372
373impl<State> fmt::Debug for CustomRequestFn<State>
374where
375 State: Clone + Send + Sync + 'static,
376{
377 fn fmt(&self, f: &mut Formatter<'_>) -> fmtResult {
378 f.write_str("custom_request_fn")
379 }
380}
381
382#[doc(hidden)]
383#[derive(Clone)]
384pub struct CustomResponseFn {
385 inner_fn: Arc<dyn Fn(&Response) -> String + Sync + Send>,
386}
387
388impl CustomResponseFn {
389 fn call(&self, resp: &Response) -> String {
390 (self.inner_fn)(resp)
391 }
392}
393
394impl fmt::Debug for CustomResponseFn {
395 fn fmt(&self, f: &mut Formatter<'_>) -> fmtResult {
396 f.write_str("custom_response_fn")
397 }
398}
399
400impl<State> FormatText<State>
401where
402 State: Clone + Send + Sync + 'static,
403{
404 fn render_request(&mut self, now: OffsetDateTime, req: &Request<State>) {
405 match &*self {
406 FormatText::RequestLine => {
407 *self = if let Some(query_str) = req.url().query() {
408 FormatText::Str(format!(
409 "{} {}?{} {}",
410 req.method(),
411 req.url().path(),
412 query_str,
413 req.version().as_ref().map_or("?", |v| v.as_ref())
414 ))
415 } else {
416 FormatText::Str(format!(
417 "{} {} {}",
418 req.method(),
419 req.url().path(),
420 req.version().as_ref().map_or("?", |v| v.as_ref())
421 ))
422 };
423 }
424 FormatText::Method => *self = FormatText::Str(req.method().to_string()),
425 FormatText::Version => {
426 *self = FormatText::Str(
427 req.version()
428 .as_ref()
429 .map_or("?".to_owned(), |v| v.to_string()),
430 )
431 }
432 FormatText::Query => {
433 *self = FormatText::Str(req.url().query().map_or("-".to_owned(), |v| v.to_string()))
434 }
435 FormatText::UrlPath => *self = FormatText::Str(req.url().path().to_string()),
436 FormatText::RequestTime => *self = FormatText::Str(now.format("%Y-%m-%dT%H:%M:%S")),
437 FormatText::RequestHeader(ref name) => {
438 let s = if let Some(val) = req.header(name) {
439 if let Some(v) = val.get(0) {
440 v.as_str()
441 } else {
442 "_"
443 }
444 } else {
445 "-"
446 };
447 *self = FormatText::Str(s.to_string());
448 }
449 FormatText::RemoteAddr => {
450 *self = if let Some(addr) = req.remote() {
451 FormatText::Str(addr.to_string())
452 } else {
453 FormatText::Str("-".to_string())
454 };
455 }
456 FormatText::RealIPRemoteAddr => {
457 *self = if let Some(remote) = req.peer_addr() {
458 FormatText::Str(remote.to_string())
459 } else {
460 FormatText::Str("-".to_string())
461 };
462 }
463 FormatText::CustomRequest(_, request_fn) => {
464 *self = match request_fn {
465 Some(f) => FormatText::Str(f.call(req)),
466 None => FormatText::Str("-".to_owned()),
467 };
468 }
469 _ => (),
470 }
471 }
472
473 fn render_response(&mut self, resp: &Response) {
474 match &*self {
475 FormatText::ResponseStatus => {
476 *self = FormatText::Str(format!("{}", resp.status() as u16))
477 }
478 FormatText::ResponseHeader(name) => {
479 let s = if let Some(val) = resp.header(name) {
480 if let Some(v) = val.get(0) {
481 v.as_str()
482 } else {
483 "-"
484 }
485 } else {
486 "-"
487 };
488 *self = FormatText::Str(s.to_string())
489 }
490 FormatText::CustomResponse(_, response_fn) => {
491 *self = match response_fn {
492 Some(f) => FormatText::Str(f.call(resp)),
493 None => FormatText::Str("-".to_owned()),
494 };
495 }
496 _ => (),
497 }
498 }
499
500 fn render(
501 &self,
502 fmt: &mut Formatter<'_>,
503 size: usize,
504 entry_time: OffsetDateTime,
505 ) -> Result<(), fmtError> {
506 match *self {
507 FormatText::Str(ref string) => fmt.write_str(string),
508 FormatText::Percent => "%".fmt(fmt),
509 FormatText::ResponseSize => size.fmt(fmt),
510 FormatText::Time => {
511 let rt = OffsetDateTime::now_utc() - entry_time;
512 let rt = rt.as_seconds_f64();
513 fmt.write_fmt(format_args!("{:.6}", rt))
514 }
515 FormatText::TimeMillis => {
516 let rt = OffsetDateTime::now_utc() - entry_time;
517 let rt = (rt.whole_nanoseconds() as f64) / 1_000_000.0;
518 fmt.write_fmt(format_args!("{:.6}", rt))
519 }
520 FormatText::EnvironHeader(ref name) => {
521 if let Ok(val) = std::env::var(name) {
522 fmt.write_fmt(format_args!("{}", val))
523 } else {
524 "-".fmt(fmt)
525 }
526 }
527 _ => Ok(()),
528 }
529 }
530}
531
532#[pin_project(PinnedDrop)]
533struct StreamLog<State: Clone + Send + Sync + 'static> {
534 #[pin]
535 body: Body,
536 format: Format<State>,
537 size: usize,
538 time: OffsetDateTime,
539 span: Span,
540}
541
542#[pinned_drop]
543impl<State: Clone + Send + Sync + 'static> PinnedDrop for StreamLog<State> {
544 fn drop(self: Pin<&mut Self>) {
545 let render = |fmt: &mut Formatter<'_>| {
546 for unit in &self.format.0 {
547 unit.render(fmt, self.size, self.time)?;
548 }
549 Ok(())
550 };
551 info!(parent: &self.span, "{}", FormatDisplay(&render));
552 }
553}
554
555impl<State> AsyncRead for StreamLog<State>
556where
557 State: Clone + Send + Sync + 'static,
558{
559 fn poll_read(
560 self: Pin<&mut Self>,
561 cx: &mut Context<'_>,
562 buf: &mut [u8],
563 ) -> Poll<std::io::Result<usize>> {
564 let this = self.project();
565 let res = this.body.poll_read(cx, buf);
566 if let Poll::Ready(size) = &res {
567 *this.size += if let Ok(n) = size { *n } else { 0 };
568 }
569 res
570 }
571}
572
/// Adapter that lets a rendering closure be used wherever a [`Display`] value is
/// expected. Formatting the wrapper calls the closure with the active [`Formatter`].
struct FormatDisplay<'a>(&'a dyn Fn(&mut Formatter<'_>) -> fmtResult);

impl Display for FormatDisplay<'_> {
    fn fmt(&self, out: &mut Formatter<'_>) -> fmtResult {
        let render = self.0;
        render(out)
    }
}