// switchgear_components/axum/middleware/logger.rs
use axum::extract::{ConnectInfo, Request};
use axum::http::{header, Method, StatusCode, Version};
use axum::response::Response;
use chrono::{DateTime, Utc};
use client_ip::{
    cf_connecting_ip, cloudfront_viewer_address, fly_client_ip, rightmost_forwarded,
    rightmost_x_forwarded_for, true_client_ip, x_real_ip,
};
use log::{log, Level};
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use tower::{Layer, Service};
15
/// Tower layer that logs every request handled by the wrapped service as one
/// Common Log Format–style line.
///
/// Records are emitted through the `log` crate under the target
/// `clf::<service_name>` (see [`ClfLogger::new`]).
#[derive(Clone, Debug)]
pub struct ClfLogger {
    /// `log` target used for every record emitted by services built from this layer.
    service_log_target: String,
}
20
21impl ClfLogger {
22 pub fn new(service_name: &str) -> Self {
23 Self {
24 service_log_target: format!("clf::{service_name}"),
25 }
26 }
27}
28
29impl<S> Layer<S> for ClfLogger {
30 type Service = ClfLoggerService<S>;
31
32 fn layer(&self, inner: S) -> Self::Service {
33 ClfLoggerService {
34 inner,
35 service_log_target: self.service_log_target.clone(),
36 }
37 }
38}
39
/// Service produced by [`ClfLogger`]: forwards each request to `inner` and logs
/// the outcome under `service_log_target`.
#[derive(Clone, Debug)]
pub struct ClfLoggerService<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// `log` target copied from the layer that created this service.
    service_log_target: String,
}
45
46impl<S> Service<Request> for ClfLoggerService<S>
47where
48 S: Service<Request, Response = Response> + Clone + Send + 'static,
49 S::Future: Send + 'static,
50 S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
51{
52 type Response = S::Response;
53 type Error = S::Error;
54 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
55
56 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
57 self.inner.poll_ready(cx)
58 }
59
60 fn call(&mut self, req: Request) -> Self::Future {
61 let not_ready_inner = self.inner.clone();
62 let mut inner = std::mem::replace(&mut self.inner, not_ready_inner);
63
64 let service_name = self.service_log_target.clone();
65
66 Box::pin(async move {
67 let method = req.method().clone();
68 let uri = req.uri().clone();
69 let version = match req.version() {
70 Version::HTTP_09 => "HTTP/0.9",
71 Version::HTTP_10 => "HTTP/1.0",
72 Version::HTTP_11 => "HTTP/1.1",
73 Version::HTTP_2 => "HTTP/2.0",
74 Version::HTTP_3 => "HTTP/3.0",
75 _ => "HTTP/1.1",
76 };
77
78 let host = cf_connecting_ip(req.headers())
79 .ok()
80 .or_else(|| cloudfront_viewer_address(req.headers()).ok())
81 .or_else(|| fly_client_ip(req.headers()).ok())
82 .or_else(|| x_real_ip(req.headers()).ok())
83 .or_else(|| true_client_ip(req.headers()).ok())
84 .or_else(|| rightmost_forwarded(req.headers()).ok())
85 .or_else(|| rightmost_x_forwarded_for(req.headers()).ok())
86 .or_else(|| {
87 req.extensions()
88 .get::<ConnectInfo<SocketAddr>>()
89 .map(|ci| ci.ip())
90 });
91
92 let host = host.map_or_else(|| "-".to_string(), |a| a.to_string());
93
94 let response = inner.call(req).await?;
95
96 let status = response.status();
97 let status_code = status.as_u16();
98
99 let now: DateTime<Utc> = Utc::now();
101 let timestamp = format!("[{}]", now.format("%d/%b/%Y:%H:%M:%S %z"));
102
103 let level = if status.is_server_error() {
104 Level::Error
105 } else if status.is_client_error() {
106 Level::Warn
107 } else {
108 Level::Info
109 };
110
111 log!(target:&service_name, level, "{host} - - {timestamp} {method} {uri} {version} {status_code} -");
113
114 Ok(response)
115 })
116 }
117}