1use logform::LogInfo;
2use std::{collections::HashMap, sync::Arc};
3use tracing::{Event, Level, Subscriber};
4use tracing_subscriber::{layer::Context, registry::LookupSpan, Layer};
5use winston::Logger;
6
/// Fields captured for a single span, keyed by field name.
///
/// The layer stores one of these in a span's extensions in `on_new_span`,
/// updates it in `on_record`, and reads it back in `on_event` so the values
/// can be copied into the emitted log record's metadata.
struct SpanFields(HashMap<String, serde_json::Value>);
8
/// A `tracing_subscriber` [`Layer`] that forwards `tracing` events to a
/// winston [`Logger`].
pub struct WinstonLayer {
    /// Shared handle to the logger that receives every converted event.
    logger: Arc<Logger>,
}
16
17impl WinstonLayer {
18 pub fn new(logger: impl Into<Arc<Logger>>) -> Self {
19 Self {
20 logger: logger.into(),
21 }
22 }
23}
24
/// Extension trait that turns a logger handle into a [`WinstonLayer`].
///
/// Implemented in this file for [`Logger`] and for `Arc<Logger>`.
pub trait LoggerTracingExt {
    /// Consumes `self` and returns a layer that logs through it.
    fn layer(self) -> WinstonLayer;
}
72
73impl LoggerTracingExt for Logger {
74 fn layer(self) -> WinstonLayer {
75 WinstonLayer::new(self)
76 }
77}
78
79impl LoggerTracingExt for Arc<Logger> {
80 fn layer(self) -> WinstonLayer {
81 WinstonLayer::new(self)
82 }
83}
84
85impl<S> Layer<S> for WinstonLayer
86where
87 S: Subscriber + for<'a> LookupSpan<'a>,
88{
89 fn on_new_span(
90 &self,
91 attrs: &tracing::span::Attributes<'_>,
92 id: &tracing::span::Id,
93 ctx: Context<'_, S>,
94 ) {
95 let span = ctx.span(id).expect("span not found, this is a bug");
96 let mut fields = HashMap::new();
97 fields.insert(
99 "span".to_string(),
100 serde_json::Value::String(span.name().to_string()),
101 );
102 attrs.record(&mut FieldVisitor(&mut fields));
103 span.extensions_mut().insert(SpanFields(fields));
104 }
105
106 fn on_record(
107 &self,
108 id: &tracing::span::Id,
109 values: &tracing::span::Record<'_>,
110 ctx: Context<'_, S>,
111 ) {
112 let span = ctx.span(id).expect("span not found, this is a bug");
113 let mut extensions = span.extensions_mut();
114 if let Some(sf) = extensions.get_mut::<SpanFields>() {
115 values.record(&mut FieldVisitor(&mut sf.0));
116 }
117 }
118
119 fn enabled(&self, metadata: &tracing::Metadata<'_>, _ctx: Context<'_, S>) -> bool {
120 self.logger
121 .is_level_enabled_fast(map_level(metadata.level()))
122 }
123
124 fn max_level_hint(&self) -> Option<tracing_subscriber::filter::LevelFilter> {
125 use tracing::Level;
126 use tracing_subscriber::filter::LevelFilter;
127 for (level, filter) in [
128 (Level::TRACE, LevelFilter::TRACE),
129 (Level::DEBUG, LevelFilter::DEBUG),
130 (Level::INFO, LevelFilter::INFO),
131 (Level::WARN, LevelFilter::WARN),
132 (Level::ERROR, LevelFilter::ERROR),
133 ] {
134 if self.logger.is_level_enabled_fast(map_level(&level)) {
135 return Some(filter);
136 }
137 }
138 Some(LevelFilter::OFF)
139 }
140
141 fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
142 let level = map_level(event.metadata().level()).to_string();
143
144 let mut fields: HashMap<String, serde_json::Value> = HashMap::new();
145
146 if let Some(scope) = ctx.event_scope(event) {
149 let spans: Vec<_> = scope.collect();
150 for span in spans.iter().rev() {
151 if let Some(sf) = span.extensions().get::<SpanFields>() {
152 for (k, v) in &sf.0 {
153 fields.insert(k.clone(), v.clone());
154 }
155 }
156 }
157 }
158
159 event.record(&mut FieldVisitor(&mut fields));
161
162 let message = fields
164 .remove("message")
165 .map(|v| match v {
166 serde_json::Value::String(s) => s,
167 other => other.to_string(),
168 })
169 .unwrap_or_default();
170
171 fields.insert(
172 "target".to_string(),
173 serde_json::Value::String(event.metadata().target().to_string()),
174 );
175
176 if let Some(file) = event.metadata().file() {
177 fields.insert(
178 "file".to_string(),
179 serde_json::Value::String(file.to_string()),
180 );
181 }
182 if let Some(line) = event.metadata().line() {
183 fields.insert("line".to_string(), serde_json::Value::Number(line.into()));
184 }
185
186 self.logger.log(LogInfo {
187 level,
188 message,
189 meta: fields,
190 });
191 }
192}
193
194fn map_level(level: &Level) -> &'static str {
195 match *level {
196 Level::ERROR => "error",
197 Level::WARN => "warn",
198 Level::INFO => "info",
199 Level::DEBUG => "debug",
200 Level::TRACE => "trace",
201 }
202}
203
/// `tracing` field visitor that writes every visited field into the borrowed
/// map as a `serde_json::Value`, keyed by the field's name.
struct FieldVisitor<'a>(&'a mut HashMap<String, serde_json::Value>);
205
206impl tracing::field::Visit for FieldVisitor<'_> {
207 fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
208 let number = serde_json::Number::from_f64(value).unwrap_or_else(|| 0.into());
209 self.0
210 .insert(field.name().to_string(), serde_json::Value::Number(number));
211 }
212
213 fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
214 self.0.insert(
215 field.name().to_string(),
216 serde_json::Value::Number(value.into()),
217 );
218 }
219
220 fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
221 self.0.insert(
222 field.name().to_string(),
223 serde_json::Value::Number(value.into()),
224 );
225 }
226
227 fn record_i128(&mut self, field: &tracing::field::Field, value: i128) {
228 self.0.insert(
230 field.name().to_string(),
231 serde_json::Value::String(value.to_string()),
232 );
233 }
234
235 fn record_u128(&mut self, field: &tracing::field::Field, value: u128) {
236 self.0.insert(
237 field.name().to_string(),
238 serde_json::Value::String(value.to_string()),
239 );
240 }
241
242 fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
243 self.0
244 .insert(field.name().to_string(), serde_json::Value::Bool(value));
245 }
246
247 fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
248 self.0.insert(
249 field.name().to_string(),
250 serde_json::Value::String(value.to_string()),
251 );
252 }
253
254 fn record_error(
255 &mut self,
256 field: &tracing::field::Field,
257 value: &(dyn std::error::Error + 'static),
258 ) {
259 self.0.insert(
260 field.name().to_string(),
261 serde_json::Value::String(value.to_string()),
262 );
263 }
264
265 fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
266 self.0.insert(
267 field.name().to_string(),
268 serde_json::Value::String(format!("{value:?}")),
269 );
270 }
271}
272
/// Convenience re-exports: importing this module's contents brings
/// [`LoggerTracingExt`] (and its `layer()` method) into scope.
pub mod prelude {
    pub use super::LoggerTracingExt;
}
276
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use tracing_subscriber::prelude::*;
    use winston_transport::Transport;

    /// Transport that appends every record it receives to a shared vector so
    /// tests can inspect what the logger emitted.
    #[derive(Clone)]
    struct CaptureTransport(Arc<Mutex<Vec<LogInfo>>>);

    impl Transport<LogInfo> for CaptureTransport {
        fn log(&self, info: LogInfo) {
            let mut sink = self.0.lock().unwrap();
            sink.push(info);
        }
    }

    /// Builds a trace-level logger wired to a [`CaptureTransport`] and returns
    /// it together with the shared vector the transport writes into.
    fn make_logger_and_capture() -> (Arc<Logger>, Arc<Mutex<Vec<LogInfo>>>) {
        let captured = Arc::new(Mutex::new(Vec::new()));
        let transport = CaptureTransport(Arc::clone(&captured));
        let logger = Logger::builder()
            .level("trace")
            .format(logform::passthrough())
            .transport(transport)
            .build();
        (Arc::new(logger), captured)
    }

    /// Fields attached directly to an event end up in the record's metadata.
    #[test]
    fn event_fields_become_meta() {
        let (logger, captured) = make_logger_and_capture();
        let layer = Arc::clone(&logger).layer();
        let _guard = tracing_subscriber::registry().with(layer).set_default();

        tracing::info!(user_id = 42u64, "login");
        logger.flush().unwrap();

        let records = captured.lock().unwrap();
        assert_eq!(records.len(), 1);
        let record = &records[0];
        assert_eq!(record.level, "info");
        assert_eq!(record.message, "login");
        assert_eq!(record.meta["user_id"], serde_json::json!(42u64));
    }

    /// Fields of the enclosing span, plus its name, are copied into events.
    #[test]
    fn span_fields_propagate_into_events() {
        let (logger, captured) = make_logger_and_capture();
        let layer = Arc::clone(&logger).layer();
        let _guard = tracing_subscriber::registry().with(layer).set_default();

        let span = tracing::info_span!("request", request_id = "abc-123");
        let _enter = span.enter();
        tracing::warn!("something went wrong");
        logger.flush().unwrap();

        let records = captured.lock().unwrap();
        assert_eq!(records.len(), 1);
        let record = &records[0];
        assert_eq!(record.level, "warn");
        assert_eq!(record.message, "something went wrong");
        assert_eq!(record.meta["request_id"], serde_json::json!("abc-123"));
        assert_eq!(record.meta["span"], serde_json::json!("request"));
    }

    /// When a span and an event share a field name, the event's value wins.
    #[test]
    fn event_fields_override_span_fields() {
        let (logger, captured) = make_logger_and_capture();
        let layer = Arc::clone(&logger).layer();
        let _guard = tracing_subscriber::registry().with(layer).set_default();

        let span = tracing::info_span!("work", key = "from-span");
        let _enter = span.enter();
        tracing::info!(key = "from-event", "override");
        logger.flush().unwrap();

        let records = captured.lock().unwrap();
        assert_eq!(records[0].meta["key"], serde_json::json!("from-event"));
    }

    /// Each tracing level maps to the matching lowercase logger level, and
    /// records arrive in emission order.
    #[test]
    fn level_mapping() {
        let (logger, captured) = make_logger_and_capture();
        let layer = Arc::clone(&logger).layer();
        let _guard = tracing_subscriber::registry().with(layer).set_default();

        tracing::error!("e");
        tracing::warn!("w");
        tracing::info!("i");
        tracing::debug!("d");
        tracing::trace!("t");
        logger.flush().unwrap();

        let records = captured.lock().unwrap();
        let seen: Vec<&str> = records.iter().map(|record| record.level.as_str()).collect();
        assert_eq!(seen, ["error", "warn", "info", "debug", "trace"]);
    }
}