Skip to main content

tork_core/logging/
logger.rs

1//! The injectable, context-aware logger.
2
3use std::sync::Arc;
4
5use serde::Serialize;
6use serde_json::{Map, Value};
7use tracing::Level;
8
9use super::event::LogEvent;
10use crate::error::Result;
11use crate::extract::{FromRequest, RequestContext};
12
/// Context assigned to a logger resolved through [`FromRequest`], which has no
/// caller-supplied context at that point.
const DEFAULT_CONTEXT: &str = "app";
/// Name of the request header whose value, when present and convertible to a
/// string, is recorded as the `request_id` base field.
const REQUEST_ID_HEADER: &str = "x-request-id";
17
18/// A context-aware logger.
19///
20/// Injected into handlers and services; the `#[derive(Inject)]` macro gives a
21/// `logger: Logger` field the surrounding struct's name as its context. Each log
22/// line carries that context and any request-scoped fields (request id, method,
23/// path) captured when the logger was resolved.
24#[derive(Clone)]
25pub struct Logger {
26    context: Arc<str>,
27    base: Arc<LogFields>,
28}
29
30enum LogFields {
31    Empty,
32    Field {
33        parent: Arc<LogFields>,
34        key: &'static str,
35        value: Value,
36    },
37}
38
39impl Logger {
40    /// Creates a logger with the given context and no base fields.
41    pub fn new(context: impl AsRef<str>) -> Self {
42        Self {
43            context: Arc::from(context.as_ref()),
44            base: Arc::new(LogFields::Empty),
45        }
46    }
47
48    /// Creates a framework-internal logger (used for startup and request logs).
49    pub(crate) fn framework(context: &'static str) -> Self {
50        Self::new(context)
51    }
52
53    /// Returns the logger's context (the name shown in `[Context]`).
54    pub fn context(&self) -> &str {
55        &self.context
56    }
57
58    /// Returns a logger with a different context, keeping the base fields.
59    pub fn for_context(&self, context: impl AsRef<str>) -> Logger {
60        Logger {
61            context: Arc::from(context.as_ref()),
62            base: self.base.clone(),
63        }
64    }
65
66    /// Returns a logger with an extra field included on every record.
67    pub fn with_field<T: Serialize>(&self, key: &'static str, value: T) -> Logger {
68        if let Ok(value) = serde_json::to_value(value) {
69            return Logger {
70                context: self.context.clone(),
71                base: Arc::new(LogFields::Field {
72                    parent: self.base.clone(),
73                    key,
74                    value,
75                }),
76            };
77        }
78        self.clone()
79    }
80
81    /// Starts a record at the given level.
82    fn event(&self, level: Level, message: impl Into<String>) -> LogEvent {
83        let mut fields = Map::new();
84        populate_fields(&self.base, &mut fields);
85        LogEvent {
86            level,
87            context: self.context.clone(),
88            message: message.into(),
89            fields,
90            error: None,
91        }
92    }
93
94    /// Starts a `TRACE` record.
95    pub fn trace(&self, message: impl Into<String>) -> LogEvent {
96        self.event(Level::TRACE, message)
97    }
98
99    /// Starts a `DEBUG` record.
100    pub fn debug(&self, message: impl Into<String>) -> LogEvent {
101        self.event(Level::DEBUG, message)
102    }
103
104    /// Starts an `INFO` record.
105    pub fn info(&self, message: impl Into<String>) -> LogEvent {
106        self.event(Level::INFO, message)
107    }
108
109    /// Starts a `WARN` record.
110    pub fn warn(&self, message: impl Into<String>) -> LogEvent {
111        self.event(Level::WARN, message)
112    }
113
114    /// Starts an `ERROR` record.
115    pub fn error(&self, message: impl Into<String>) -> LogEvent {
116        self.event(Level::ERROR, message)
117    }
118
119    /// Builds a span for an operation, to [`enter`](super::LogSpan::enter) a scope.
120    pub fn span(&self, name: impl Into<String>) -> super::LogSpan {
121        let mut fields = Map::new();
122        populate_fields(&self.base, &mut fields);
123        super::LogSpan::new(self.context.clone(), name, fields)
124    }
125
126    /// Builds a span to [`run`](super::LogSpan::run) a future inside.
127    pub fn instrument(&self, name: impl Into<String>) -> super::LogSpan {
128        let mut fields = Map::new();
129        populate_fields(&self.base, &mut fields);
130        super::LogSpan::new(self.context.clone(), name, fields)
131    }
132}
133
134fn populate_fields(fields: &Arc<LogFields>, out: &mut Map<String, Value>) {
135    match fields.as_ref() {
136        LogFields::Empty => {}
137        LogFields::Field { parent, key, value } => {
138            populate_fields(parent, out);
139            out.insert((*key).to_owned(), value.clone());
140        }
141    }
142}
143
144impl FromRequest for Logger {
145    fn from_request(
146        ctx: &RequestContext,
147    ) -> impl std::future::Future<Output = Result<Self>> + Send {
148        let mut base: Vec<(&'static str, Value)> = Vec::new();
149        if let Some(request_id) = ctx
150            .headers()
151            .get(REQUEST_ID_HEADER)
152            .and_then(|value| value.to_str().ok())
153        {
154            base.push(("request_id", Value::String(request_id.to_owned())));
155        }
156        base.push(("method", Value::String(ctx.method().to_string())));
157        base.push(("path", Value::String(ctx.uri().path().to_owned())));
158
159        let logger = Logger {
160            context: Arc::from(DEFAULT_CONTEXT),
161            base: base
162                .into_iter()
163                .fold(Arc::new(LogFields::Empty), |parent, (key, value)| {
164                    Arc::new(LogFields::Field { parent, key, value })
165                }),
166        };
167        async move { Ok(logger) }
168    }
169}
170
#[cfg(test)]
mod tests {
    //! Tests for [`Logger`]: each one emits records under a JSON-formatting
    //! subscriber and inspects the captured text.

    use super::*;
    use std::io::Write;
    use std::sync::{Arc, Mutex};

    use super::super::format::{JsonFormat, TorkFormat};
    use crate::extract::FromRequest;
    use crate::{box_body, PathParams, RequestContext, StateMap};
    use bytes::Bytes;
    use http_body_util::Full;
    use serde::ser::Error as _;
    use serde::Serializer;
    use tracing_subscriber::fmt::MakeWriter;
    use tracing_subscriber::prelude::*;

    /// In-memory writer that appends everything to a shared byte buffer.
    #[derive(Clone)]
    struct BufWriter(Arc<Mutex<Vec<u8>>>);

    /// Value whose serialization always fails.
    struct BadSerialize;

    impl serde::Serialize for BadSerialize {
        // `std::result::Result` is written out in full because `use super::*`
        // also brings the crate's own `Result` alias into scope.
        fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            Err(S::Error::custom("nope"))
        }
    }

    impl Write for BufWriter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(buf);
            Ok(buf.len())
        }
        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    impl<'a> MakeWriter<'a> for BufWriter {
        type Writer = BufWriter;
        fn make_writer(&'a self) -> Self::Writer {
            self.clone()
        }
    }

    /// Runs `emit` with a JSON-formatting subscriber installed as the default
    /// for the duration of the call and returns everything it wrote.
    fn capture_json(emit: impl FnOnce()) -> String {
        let buffer = Arc::new(Mutex::new(Vec::new()));
        let layer = tracing_subscriber::fmt::layer()
            .event_format(TorkFormat::Json(JsonFormat {
                service_name: "svc".to_owned(),
            }))
            .with_writer(BufWriter(buffer.clone()));
        let subscriber = tracing_subscriber::registry().with(layer);

        tracing::subscriber::with_default(subscriber, emit);

        let bytes = buffer.lock().unwrap().clone();
        String::from_utf8(bytes).unwrap()
    }

    #[test]
    fn emits_context_message_and_fields() {
        let output = capture_json(|| {
            Logger::new("PaymentService")
                .with_field("tenant", "acme")
                .info("Charging user")
                .field("user_id", 42)
                .emit();
        });

        assert!(
            output.contains("\"context\":\"PaymentService\""),
            "{output}"
        );
        assert!(output.contains("\"message\":\"Charging user\""), "{output}");
        assert!(output.contains("\"user_id\":42"), "{output}");
        assert!(output.contains("\"tenant\":\"acme\""), "{output}");
    }

    #[test]
    fn for_context_and_framework_preserve_base_fields() {
        let logger = Logger::framework("startup").with_field("tenant", "acme");
        let relabeled = logger.for_context("payments");

        assert_eq!(logger.context(), "startup");
        assert_eq!(relabeled.context(), "payments");

        let output = capture_json(|| {
            relabeled.info("Boot").emit();
        });
        assert!(output.contains("\"context\":\"payments\""), "{output}");
        assert!(output.contains("\"tenant\":\"acme\""), "{output}");
    }

    #[test]
    fn with_field_ignores_unserializable_values() {
        let logger = Logger::new("logger").with_field("tenant", BadSerialize);

        let output = capture_json(|| {
            logger.info("Hello").emit();
        });
        assert!(!output.contains("tenant"), "{output}");
    }

    #[test]
    fn trace_debug_warn_error_span_and_instrument_cover_helper_methods() {
        let output = capture_json(|| {
            Logger::framework("boot").trace("trace").emit();
            Logger::new("worker").debug("debug").emit();
            Logger::new("worker").warn("warn").emit();
            Logger::new("worker").error("error").emit();
            let _ = Logger::new("worker").span("span").enter();
            let _ = Logger::new("worker").instrument("task");
        });

        assert!(output.contains("\"message\":\"trace\""), "{output}");
        assert!(output.contains("\"message\":\"debug\""), "{output}");
        assert!(output.contains("\"message\":\"warn\""), "{output}");
        assert!(output.contains("\"message\":\"error\""), "{output}");
    }

    #[tokio::test]
    async fn from_request_uses_request_metadata_and_default_context() {
        let head = http::Request::builder()
            .method("GET")
            .uri("/logs")
            .header("x-request-id", "req-123")
            .body(())
            .unwrap()
            .into_parts()
            .0;
        let ctx = RequestContext::new(
            head,
            PathParams::new(),
            Arc::new(StateMap::new()),
            box_body(Full::new(Bytes::new())),
        );

        let logger = Logger::from_request(&ctx).await.unwrap();
        assert_eq!(logger.context(), "app");

        let output = capture_json(|| {
            logger.info("Hello").emit();
        });
        assert!(output.contains("\"request_id\":\"req-123\""), "{output}");
        assert!(output.contains("\"method\":\"GET\""), "{output}");
        assert!(output.contains("\"path\":\"/logs\""), "{output}");
    }
}