1use std::sync::Arc;
4
5use serde::Serialize;
6use serde_json::{Map, Value};
7use tracing::Level;
8
9use super::event::LogEvent;
10use crate::error::Result;
11use crate::extract::{FromRequest, RequestContext};
12
/// Context label assigned to a [`Logger`] that is built from an incoming request.
const DEFAULT_CONTEXT: &str = "app";
/// Name of the request header whose value, when present and readable as a
/// string, is recorded as the `request_id` base field.
const REQUEST_ID_HEADER: &str = "x-request-id";
17
/// Handle for producing structured log events and spans.
///
/// A logger pairs a context label with a shared chain of base fields. Both
/// parts live behind `Arc`s, so cloning a logger only bumps reference counts.
#[derive(Clone)]
pub struct Logger {
    /// Label copied into every event and span created by this logger.
    context: Arc<str>,
    /// Fields copied into every event and span created by this logger.
    base: Arc<LogFields>,
}
29
/// Immutable, `Arc`-linked chain of key/value pairs attached to a [`Logger`].
///
/// Adding a field wraps the existing chain in a new `Field` node instead of
/// copying it, so loggers derived from one another share their common prefix.
enum LogFields {
    /// End of the chain: no further fields.
    Empty,
    /// One key/value pair layered on top of `parent`.
    Field {
        /// The fields that were present before this one was added.
        parent: Arc<LogFields>,
        /// Field name.
        key: &'static str,
        /// Field value, already converted to JSON.
        value: Value,
    },
}
38
39impl Logger {
40 pub fn new(context: impl AsRef<str>) -> Self {
42 Self {
43 context: Arc::from(context.as_ref()),
44 base: Arc::new(LogFields::Empty),
45 }
46 }
47
48 pub(crate) fn framework(context: &'static str) -> Self {
50 Self::new(context)
51 }
52
53 pub fn context(&self) -> &str {
55 &self.context
56 }
57
58 pub fn for_context(&self, context: impl AsRef<str>) -> Logger {
60 Logger {
61 context: Arc::from(context.as_ref()),
62 base: self.base.clone(),
63 }
64 }
65
66 pub fn with_field<T: Serialize>(&self, key: &'static str, value: T) -> Logger {
68 if let Ok(value) = serde_json::to_value(value) {
69 return Logger {
70 context: self.context.clone(),
71 base: Arc::new(LogFields::Field {
72 parent: self.base.clone(),
73 key,
74 value,
75 }),
76 };
77 }
78 self.clone()
79 }
80
81 fn event(&self, level: Level, message: impl Into<String>) -> LogEvent {
83 let mut fields = Map::new();
84 populate_fields(&self.base, &mut fields);
85 LogEvent {
86 level,
87 context: self.context.clone(),
88 message: message.into(),
89 fields,
90 error: None,
91 }
92 }
93
94 pub fn trace(&self, message: impl Into<String>) -> LogEvent {
96 self.event(Level::TRACE, message)
97 }
98
99 pub fn debug(&self, message: impl Into<String>) -> LogEvent {
101 self.event(Level::DEBUG, message)
102 }
103
104 pub fn info(&self, message: impl Into<String>) -> LogEvent {
106 self.event(Level::INFO, message)
107 }
108
109 pub fn warn(&self, message: impl Into<String>) -> LogEvent {
111 self.event(Level::WARN, message)
112 }
113
114 pub fn error(&self, message: impl Into<String>) -> LogEvent {
116 self.event(Level::ERROR, message)
117 }
118
119 pub fn span(&self, name: impl Into<String>) -> super::LogSpan {
121 let mut fields = Map::new();
122 populate_fields(&self.base, &mut fields);
123 super::LogSpan::new(self.context.clone(), name, fields)
124 }
125
126 pub fn instrument(&self, name: impl Into<String>) -> super::LogSpan {
128 let mut fields = Map::new();
129 populate_fields(&self.base, &mut fields);
130 super::LogSpan::new(self.context.clone(), name, fields)
131 }
132}
133
134fn populate_fields(fields: &Arc<LogFields>, out: &mut Map<String, Value>) {
135 match fields.as_ref() {
136 LogFields::Empty => {}
137 LogFields::Field { parent, key, value } => {
138 populate_fields(parent, out);
139 out.insert((*key).to_owned(), value.clone());
140 }
141 }
142}
143
144impl FromRequest for Logger {
145 fn from_request(
146 ctx: &RequestContext,
147 ) -> impl std::future::Future<Output = Result<Self>> + Send {
148 let mut base: Vec<(&'static str, Value)> = Vec::new();
149 if let Some(request_id) = ctx
150 .headers()
151 .get(REQUEST_ID_HEADER)
152 .and_then(|value| value.to_str().ok())
153 {
154 base.push(("request_id", Value::String(request_id.to_owned())));
155 }
156 base.push(("method", Value::String(ctx.method().to_string())));
157 base.push(("path", Value::String(ctx.uri().path().to_owned())));
158
159 let logger = Logger {
160 context: Arc::from(DEFAULT_CONTEXT),
161 base: base
162 .into_iter()
163 .fold(Arc::new(LogFields::Empty), |parent, (key, value)| {
164 Arc::new(LogFields::Field { parent, key, value })
165 }),
166 };
167 async move { Ok(logger) }
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::sync::{Arc, Mutex};

    use super::super::format::{JsonFormat, TorkFormat};
    use crate::extract::FromRequest;
    use crate::{box_body, PathParams, RequestContext, StateMap};
    use bytes::Bytes;
    use http_body_util::Full;
    use serde::ser::Error as _;
    use serde::Serializer;
    use tracing_subscriber::fmt::MakeWriter;
    use tracing_subscriber::prelude::*;

    /// In-memory sink that collects everything the subscriber writes.
    #[derive(Clone)]
    struct BufWriter(Arc<Mutex<Vec<u8>>>);

    /// Value whose serialization always fails.
    struct BadSerialize;

    impl serde::Serialize for BadSerialize {
        // `std::result::Result` is spelled out so this signature does not
        // depend on the shape of the crate's own `Result` alias, which the
        // glob import above brings into scope.
        fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            Err(S::Error::custom("nope"))
        }
    }

    impl Write for BufWriter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    impl<'a> MakeWriter<'a> for BufWriter {
        type Writer = BufWriter;

        fn make_writer(&'a self) -> Self::Writer {
            self.clone()
        }
    }

    /// Runs `emit` under a JSON-formatting subscriber that writes to memory
    /// and returns everything that was written.
    fn capture_output(emit: impl FnOnce()) -> String {
        let buffer = Arc::new(Mutex::new(Vec::new()));
        let layer = tracing_subscriber::fmt::layer()
            .event_format(TorkFormat::Json(JsonFormat {
                service_name: "svc".to_owned(),
            }))
            .with_writer(BufWriter(buffer.clone()));
        let subscriber = tracing_subscriber::registry().with(layer);

        tracing::subscriber::with_default(subscriber, emit);

        let bytes = buffer.lock().unwrap().clone();
        String::from_utf8(bytes).unwrap()
    }

    #[test]
    fn emits_context_message_and_fields() {
        let output = capture_output(|| {
            Logger::new("PaymentService")
                .with_field("tenant", "acme")
                .info("Charging user")
                .field("user_id", 42)
                .emit();
        });

        assert!(
            output.contains("\"context\":\"PaymentService\""),
            "{output}"
        );
        assert!(output.contains("\"message\":\"Charging user\""), "{output}");
        assert!(output.contains("\"user_id\":42"), "{output}");
        assert!(output.contains("\"tenant\":\"acme\""), "{output}");
    }

    #[test]
    fn for_context_and_framework_preserve_base_fields() {
        let logger = Logger::framework("startup").with_field("tenant", "acme");
        let relabeled = logger.for_context("payments");

        assert_eq!(logger.context(), "startup");
        assert_eq!(relabeled.context(), "payments");

        let output = capture_output(|| {
            relabeled.info("Boot").emit();
        });
        assert!(output.contains("\"context\":\"payments\""), "{output}");
        assert!(output.contains("\"tenant\":\"acme\""), "{output}");
    }

    #[test]
    fn with_field_ignores_unserializable_values() {
        let logger = Logger::new("logger").with_field("tenant", BadSerialize);

        let output = capture_output(|| {
            logger.info("Hello").emit();
        });
        assert!(!output.contains("tenant"), "{output}");
    }

    #[test]
    fn trace_debug_warn_error_span_and_instrument_cover_helper_methods() {
        let output = capture_output(|| {
            Logger::framework("boot").trace("trace").emit();
            Logger::new("worker").debug("debug").emit();
            Logger::new("worker").warn("warn").emit();
            Logger::new("worker").error("error").emit();
            let _ = Logger::new("worker").span("span").enter();
            let _ = Logger::new("worker").instrument("task");
        });

        assert!(output.contains("\"message\":\"trace\""), "{output}");
        assert!(output.contains("\"message\":\"debug\""), "{output}");
        assert!(output.contains("\"message\":\"warn\""), "{output}");
        assert!(output.contains("\"message\":\"error\""), "{output}");
    }

    #[tokio::test]
    async fn from_request_uses_request_metadata_and_default_context() {
        let head = http::Request::builder()
            .method("GET")
            .uri("/logs")
            .header("x-request-id", "req-123")
            .body(())
            .unwrap()
            .into_parts()
            .0;
        let ctx = RequestContext::new(
            head,
            PathParams::new(),
            Arc::new(StateMap::new()),
            box_body(Full::new(Bytes::new())),
        );

        let logger = Logger::from_request(&ctx).await.unwrap();
        assert_eq!(logger.context(), "app");

        let output = capture_output(|| {
            logger.info("Hello").emit();
        });
        assert!(output.contains("\"request_id\":\"req-123\""), "{output}");
        assert!(output.contains("\"method\":\"GET\""), "{output}");
        assert!(output.contains("\"path\":\"/logs\""), "{output}");
    }
}