Skip to main content

winston_tracing/
lib.rs

1use logform::LogInfo;
2use std::{collections::HashMap, sync::Arc};
3use tracing::{Event, Level, Subscriber};
4use tracing_subscriber::{layer::Context, registry::LookupSpan, Layer};
5use winston::Logger;
6
/// Field values captured for a single span, stored in that span's registry extensions.
///
/// Keys are tracing field names plus a `"span"` entry seeded with the span's name;
/// values are the JSON conversions produced by `FieldVisitor`.
struct SpanFields(HashMap<String, serde_json::Value>);
8
/// A [`tracing_subscriber::Layer`] that routes tracing events into a Winston [`Logger`].
///
/// Span fields are collected and merged into every event that fires within the span,
/// with child span fields overriding parent fields, and event fields overriding both.
///
/// Construct one with [`WinstonLayer::new`] or via [`LoggerTracingExt::layer`].
pub struct WinstonLayer {
    /// Destination logger. It is held in an `Arc` so callers can keep their own clone
    /// (e.g. to flush on shutdown) after handing the layer to a subscriber.
    logger: Arc<Logger>,
}
16
17impl WinstonLayer {
18    pub fn new(logger: impl Into<Arc<Logger>>) -> Self {
19        Self {
20            logger: logger.into(),
21        }
22    }
23}
24
/// Extension trait that lets a [`Logger`] (or `Arc<Logger>`) produce a
/// [`tracing_subscriber`] layer directly.
///
/// # Example
///
/// ```rust,no_run
/// use tracing_subscriber::prelude::*;
/// use winston::Logger;
/// use winston_tracing::prelude::*;
///
/// tracing_subscriber::registry()
///     .with(
///         Logger::builder()
///             .transport(winston::transports::stdout())
///             .build()
///             .layer(),
///     )
///     .init();
///
/// tracing::info!(user_id = 42, "user logged in");
/// ```
///
/// When you need a handle to the logger after handing it to the subscriber
/// (e.g. to flush on shutdown), wrap in `Arc` first:
///
/// ```rust,no_run
/// use std::sync::Arc;
/// use tracing_subscriber::prelude::*;
/// use winston::Logger;
/// use winston_tracing::prelude::*;
///
/// let logger = Arc::new(
///     Logger::builder()
///         .transport(winston::transports::stdout())
///         .build(),
/// );
///
/// tracing_subscriber::registry()
///     .with(Arc::clone(&logger).layer())
///     .init();
///
/// tracing::info!("hello");
/// logger.flush().unwrap();
/// ```
pub trait LoggerTracingExt {
    /// Consumes the logger (or shared logger handle) and returns a [`WinstonLayer`]
    /// that forwards tracing events to it.
    fn layer(self) -> WinstonLayer;
}
72
73impl LoggerTracingExt for Logger {
74    fn layer(self) -> WinstonLayer {
75        WinstonLayer::new(self)
76    }
77}
78
79impl LoggerTracingExt for Arc<Logger> {
80    fn layer(self) -> WinstonLayer {
81        WinstonLayer::new(self)
82    }
83}
84
85impl<S> Layer<S> for WinstonLayer
86where
87    S: Subscriber + for<'a> LookupSpan<'a>,
88{
89    fn on_new_span(
90        &self,
91        attrs: &tracing::span::Attributes<'_>,
92        id: &tracing::span::Id,
93        ctx: Context<'_, S>,
94    ) {
95        let span = ctx.span(id).expect("span not found, this is a bug");
96        let mut fields = HashMap::new();
97        // Seed with the span name so child events know which span they fired in.
98        fields.insert(
99            "span".to_string(),
100            serde_json::Value::String(span.name().to_string()),
101        );
102        attrs.record(&mut FieldVisitor(&mut fields));
103        span.extensions_mut().insert(SpanFields(fields));
104    }
105
106    fn on_record(
107        &self,
108        id: &tracing::span::Id,
109        values: &tracing::span::Record<'_>,
110        ctx: Context<'_, S>,
111    ) {
112        let span = ctx.span(id).expect("span not found, this is a bug");
113        let mut extensions = span.extensions_mut();
114        if let Some(sf) = extensions.get_mut::<SpanFields>() {
115            values.record(&mut FieldVisitor(&mut sf.0));
116        }
117    }
118
119    fn enabled(&self, metadata: &tracing::Metadata<'_>, _ctx: Context<'_, S>) -> bool {
120        self.logger
121            .is_level_enabled_fast(map_level(metadata.level()))
122    }
123
124    fn max_level_hint(&self) -> Option<tracing_subscriber::filter::LevelFilter> {
125        use tracing::Level;
126        use tracing_subscriber::filter::LevelFilter;
127        for (level, filter) in [
128            (Level::TRACE, LevelFilter::TRACE),
129            (Level::DEBUG, LevelFilter::DEBUG),
130            (Level::INFO, LevelFilter::INFO),
131            (Level::WARN, LevelFilter::WARN),
132            (Level::ERROR, LevelFilter::ERROR),
133        ] {
134            if self.logger.is_level_enabled_fast(map_level(&level)) {
135                return Some(filter);
136            }
137        }
138        Some(LevelFilter::OFF)
139    }
140
141    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
142        let level = map_level(event.metadata().level()).to_string();
143
144        let mut fields: HashMap<String, serde_json::Value> = HashMap::new();
145
146        // Walk ancestor spans outermost → innermost so that more specific
147        // (closer) spans override broader context.
148        if let Some(scope) = ctx.event_scope(event) {
149            let spans: Vec<_> = scope.collect();
150            for span in spans.iter().rev() {
151                if let Some(sf) = span.extensions().get::<SpanFields>() {
152                    for (k, v) in &sf.0 {
153                        fields.insert(k.clone(), v.clone());
154                    }
155                }
156            }
157        }
158
159        // Event fields are most specific and override span context.
160        event.record(&mut FieldVisitor(&mut fields));
161
162        // "message" is tracing's conventional field name for the primary log line.
163        let message = fields
164            .remove("message")
165            .map(|v| match v {
166                serde_json::Value::String(s) => s,
167                other => other.to_string(),
168            })
169            .unwrap_or_default();
170
171        fields.insert(
172            "target".to_string(),
173            serde_json::Value::String(event.metadata().target().to_string()),
174        );
175
176        if let Some(file) = event.metadata().file() {
177            fields.insert(
178                "file".to_string(),
179                serde_json::Value::String(file.to_string()),
180            );
181        }
182        if let Some(line) = event.metadata().line() {
183            fields.insert("line".to_string(), serde_json::Value::Number(line.into()));
184        }
185
186        self.logger.log(LogInfo {
187            level,
188            message,
189            meta: fields,
190        });
191    }
192}
193
194fn map_level(level: &Level) -> &'static str {
195    match *level {
196        Level::ERROR => "error",
197        Level::WARN => "warn",
198        Level::INFO => "info",
199        Level::DEBUG => "debug",
200        Level::TRACE => "trace",
201    }
202}
203
/// [`tracing::field::Visit`] adapter that writes each visited field into the borrowed map
/// as a JSON value keyed by the field's name.
///
/// A later write to an existing key replaces the earlier value.
struct FieldVisitor<'a>(&'a mut HashMap<String, serde_json::Value>);
205
206impl tracing::field::Visit for FieldVisitor<'_> {
207    fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
208        let number = serde_json::Number::from_f64(value).unwrap_or_else(|| 0.into());
209        self.0
210            .insert(field.name().to_string(), serde_json::Value::Number(number));
211    }
212
213    fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
214        self.0.insert(
215            field.name().to_string(),
216            serde_json::Value::Number(value.into()),
217        );
218    }
219
220    fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
221        self.0.insert(
222            field.name().to_string(),
223            serde_json::Value::Number(value.into()),
224        );
225    }
226
227    fn record_i128(&mut self, field: &tracing::field::Field, value: i128) {
228        // serde_json::Number doesn't support i128; store as string to avoid silent truncation.
229        self.0.insert(
230            field.name().to_string(),
231            serde_json::Value::String(value.to_string()),
232        );
233    }
234
235    fn record_u128(&mut self, field: &tracing::field::Field, value: u128) {
236        self.0.insert(
237            field.name().to_string(),
238            serde_json::Value::String(value.to_string()),
239        );
240    }
241
242    fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
243        self.0
244            .insert(field.name().to_string(), serde_json::Value::Bool(value));
245    }
246
247    fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
248        self.0.insert(
249            field.name().to_string(),
250            serde_json::Value::String(value.to_string()),
251        );
252    }
253
254    fn record_error(
255        &mut self,
256        field: &tracing::field::Field,
257        value: &(dyn std::error::Error + 'static),
258    ) {
259        self.0.insert(
260            field.name().to_string(),
261            serde_json::Value::String(value.to_string()),
262        );
263    }
264
265    fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
266        self.0.insert(
267            field.name().to_string(),
268            serde_json::Value::String(format!("{value:?}")),
269        );
270    }
271}
272
/// Convenience re-exports: `use winston_tracing::prelude::*;` brings
/// [`LoggerTracingExt`] (and therefore the `.layer()` method) into scope.
pub mod prelude {
    pub use super::LoggerTracingExt;
}
276
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use tracing_subscriber::prelude::*;
    use winston_transport::Transport;

    /// Transport that records every `LogInfo` it receives so tests can inspect them.
    #[derive(Clone)]
    struct CaptureTransport(Arc<Mutex<Vec<LogInfo>>>);

    impl Transport<LogInfo> for CaptureTransport {
        fn log(&self, info: LogInfo) {
            self.0.lock().unwrap().push(info);
        }
    }

    // level("trace") captures everything; passthrough() leaves LogInfo fields untouched
    // so assertions see raw level/message/meta rather than the default formatted message string.
    fn make_logger_and_capture() -> (Arc<Logger>, Arc<Mutex<Vec<LogInfo>>>) {
        let captured = Arc::new(Mutex::new(Vec::new()));
        let logger = Arc::new(
            Logger::builder()
                .level("trace")
                .format(logform::passthrough())
                .transport(CaptureTransport(captured.clone()))
                .build(),
        );
        (logger, captured)
    }

    #[test]
    fn event_fields_become_meta() {
        let (logger, captured) = make_logger_and_capture();
        let _guard = tracing_subscriber::registry()
            .with(Arc::clone(&logger).layer())
            .set_default();

        tracing::info!(user_id = 42u64, "login");
        logger.flush().unwrap();

        let logs = captured.lock().unwrap();
        assert_eq!(logs.len(), 1);
        let entry = &logs[0];
        assert_eq!(entry.level, "info");
        assert_eq!(entry.message, "login");
        assert_eq!(entry.meta["user_id"], serde_json::json!(42u64));
    }

    #[test]
    fn span_fields_propagate_into_events() {
        let (logger, captured) = make_logger_and_capture();
        let _guard = tracing_subscriber::registry()
            .with(Arc::clone(&logger).layer())
            .set_default();

        let span = tracing::info_span!("request", request_id = "abc-123");
        let _enter = span.enter();
        tracing::warn!("something went wrong");
        logger.flush().unwrap();

        let logs = captured.lock().unwrap();
        assert_eq!(logs.len(), 1);
        let entry = &logs[0];
        assert_eq!(entry.level, "warn");
        assert_eq!(entry.message, "something went wrong");
        assert_eq!(entry.meta["request_id"], serde_json::json!("abc-123"));
        assert_eq!(entry.meta["span"], serde_json::json!("request"));
    }

    #[test]
    fn event_fields_override_span_fields() {
        let (logger, captured) = make_logger_and_capture();
        let _guard = tracing_subscriber::registry()
            .with(Arc::clone(&logger).layer())
            .set_default();

        let span = tracing::info_span!("work", key = "from-span");
        let _enter = span.enter();
        tracing::info!(key = "from-event", "override");
        logger.flush().unwrap();

        let logs = captured.lock().unwrap();
        assert_eq!(logs[0].meta["key"], serde_json::json!("from-event"));
    }

    /// Documented contract: child span fields override parent fields, while
    /// fields present only on the parent still reach the event.
    #[test]
    fn child_span_fields_override_parent_span_fields() {
        let (logger, captured) = make_logger_and_capture();
        let _guard = tracing_subscriber::registry()
            .with(Arc::clone(&logger).layer())
            .set_default();

        let outer = tracing::info_span!("outer", key = "from-outer", outer_only = 1u64);
        let _outer = outer.enter();
        let inner = tracing::info_span!("inner", key = "from-inner");
        let _inner = inner.enter();
        tracing::info!("nested");
        logger.flush().unwrap();

        let logs = captured.lock().unwrap();
        assert_eq!(logs.len(), 1);
        let meta = &logs[0].meta;
        assert_eq!(meta["key"], serde_json::json!("from-inner"));
        assert_eq!(meta["outer_only"], serde_json::json!(1u64));
        assert_eq!(meta["span"], serde_json::json!("inner"));
    }

    /// Values recorded on a span after creation (via `on_record`) reach later events.
    #[test]
    fn fields_recorded_after_span_creation_propagate() {
        let (logger, captured) = make_logger_and_capture();
        let _guard = tracing_subscriber::registry()
            .with(Arc::clone(&logger).layer())
            .set_default();

        let span = tracing::info_span!("job", status = tracing::field::Empty);
        let _enter = span.enter();
        span.record("status", "done");
        tracing::info!("finished");
        logger.flush().unwrap();

        let logs = captured.lock().unwrap();
        assert_eq!(logs.len(), 1);
        assert_eq!(logs[0].meta["status"], serde_json::json!("done"));
    }

    #[test]
    fn level_mapping() {
        let (logger, captured) = make_logger_and_capture();
        let _guard = tracing_subscriber::registry()
            .with(Arc::clone(&logger).layer())
            .set_default();

        tracing::error!("e");
        tracing::warn!("w");
        tracing::info!("i");
        tracing::debug!("d");
        tracing::trace!("t");
        logger.flush().unwrap();

        let logs = captured.lock().unwrap();
        let levels: Vec<&str> = logs.iter().map(|l| l.level.as_str()).collect();
        assert_eq!(levels, ["error", "warn", "info", "debug", "trace"]);
    }
}