switchgear_components/axum/middleware/
logger.rs

1use axum::extract::{ConnectInfo, Request};
2use axum::http::Version;
3use axum::response::Response;
4use chrono::{DateTime, Utc};
5use client_ip::{
6    cf_connecting_ip, cloudfront_viewer_address, fly_client_ip, rightmost_forwarded,
7    rightmost_x_forwarded_for, true_client_ip, x_real_ip,
8};
9use log::{log, Level};
10use std::future::Future;
11use std::net::SocketAddr;
12use std::pin::Pin;
13use std::task::{Context, Poll};
14use tower::{Layer, Service};
15
/// [`Layer`] that logs every request handled by the wrapped service in
/// Common Log Format (CLF).
///
/// Records are emitted through the `log` crate under the target
/// `clf::<service_name>`, so each service's access log can be filtered
/// independently.
#[derive(Clone, Debug)]
pub struct ClfLogger {
    /// `log` target used for every record, e.g. `clf::api`.
    service_log_target: String,
}

impl ClfLogger {
    /// Creates a logger layer whose records use the `log` target
    /// `clf::{service_name}`.
    pub fn new(service_name: &str) -> Self {
        Self {
            service_log_target: format!("clf::{service_name}"),
        }
    }
}
28
29impl<S> Layer<S> for ClfLogger {
30    type Service = ClfLoggerService<S>;
31
32    fn layer(&self, inner: S) -> Self::Service {
33        ClfLoggerService {
34            inner,
35            service_log_target: self.service_log_target.clone(),
36        }
37    }
38}
39
/// Service produced by the `ClfLogger` layer: forwards each request to
/// `inner` and writes one CLF line per successful response.
#[derive(Clone, Debug)]
pub struct ClfLoggerService<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// `log` target for this service's records.
    service_log_target: String,
}
45
46impl<S> Service<Request> for ClfLoggerService<S>
47where
48    S: Service<Request, Response = Response> + Clone + Send + 'static,
49    S::Future: Send + 'static,
50    S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
51{
52    type Response = S::Response;
53    type Error = S::Error;
54    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
55
56    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
57        self.inner.poll_ready(cx)
58    }
59
60    fn call(&mut self, req: Request) -> Self::Future {
61        let not_ready_inner = self.inner.clone();
62        let mut inner = std::mem::replace(&mut self.inner, not_ready_inner);
63
64        let service_name = self.service_log_target.clone();
65
66        Box::pin(async move {
67            let method = req.method().clone();
68            let uri = req.uri().clone();
69            let version = match req.version() {
70                Version::HTTP_09 => "HTTP/0.9",
71                Version::HTTP_10 => "HTTP/1.0",
72                Version::HTTP_11 => "HTTP/1.1",
73                Version::HTTP_2 => "HTTP/2.0",
74                Version::HTTP_3 => "HTTP/3.0",
75                _ => "HTTP/1.1",
76            };
77
78            let host = cf_connecting_ip(req.headers())
79                .ok()
80                .or_else(|| cloudfront_viewer_address(req.headers()).ok())
81                .or_else(|| fly_client_ip(req.headers()).ok())
82                .or_else(|| x_real_ip(req.headers()).ok())
83                .or_else(|| true_client_ip(req.headers()).ok())
84                .or_else(|| rightmost_forwarded(req.headers()).ok())
85                .or_else(|| rightmost_x_forwarded_for(req.headers()).ok())
86                .or_else(|| {
87                    req.extensions()
88                        .get::<ConnectInfo<SocketAddr>>()
89                        .map(|ci| ci.ip())
90                });
91
92            let host = host.map_or_else(|| "-".to_string(), |a| a.to_string());
93
94            let response = inner.call(req).await?;
95
96            let status = response.status();
97            let status_code = status.as_u16();
98
99            // strftime format: %d/%b/%Y:%H:%M:%S %z
100            let now: DateTime<Utc> = Utc::now();
101            let timestamp = format!("[{}]", now.format("%d/%b/%Y:%H:%M:%S %z"));
102
103            let level = if status.is_server_error() {
104                Level::Error
105            } else if status.is_client_error() {
106                Level::Warn
107            } else {
108                Level::Info
109            };
110
111            // host ident authuser timestamp request-line status bytes
112            log!(target:&service_name, level, "{host} - - {timestamp} {method} {uri} {version} {status_code} -");
113
114            Ok(response)
115        })
116    }
117}