Skip to main content

tork_core/testing/
recorder.rs

1//! A log recorder for asserting on logs in tests.
2
3use std::sync::{Arc, Mutex};
4
5use serde_json::{Map, Value};
6use tracing::field::{Field, Visit};
7use tracing::Event;
8use tracing_subscriber::layer::{Context, Layer};
9
10/// A single captured log record.
11#[derive(Clone, Debug)]
12pub struct LogRecord {
13    /// The level (`INFO`, `ERROR`, ...).
14    pub level: String,
15    /// The logger context (for example a service name).
16    pub context: String,
17    /// The log message.
18    pub message: String,
19    /// The structured fields.
20    pub fields: Map<String, Value>,
21}
22
23/// Captures log records for assertions in tests.
24///
25/// Attach it with [`TestClientBuilder::logger`](super::TestClientBuilder::logger);
26/// the client routes its request logs to this recorder for the duration of the
27/// test. Works with the default current-thread test runtime.
28#[derive(Clone, Default)]
29pub struct LogRecorder {
30    records: Arc<Mutex<Vec<LogRecord>>>,
31}
32
33impl LogRecorder {
34    /// Creates an empty recorder.
35    pub fn new() -> Self {
36        Self::default()
37    }
38
39    /// Returns a snapshot of the captured records.
40    pub fn records(&self) -> Vec<LogRecord> {
41        self.records
42            .lock()
43            .expect("recorder mutex poisoned")
44            .clone()
45    }
46
47    /// Removes all captured records.
48    pub fn clear(&self) {
49        self.records
50            .lock()
51            .expect("recorder mutex poisoned")
52            .clear();
53    }
54
55    /// Returns `true` if any record has the given context.
56    pub fn contains_context(&self, context: &str) -> bool {
57        self.records()
58            .iter()
59            .any(|record| record.context == context)
60    }
61
62    /// Returns `true` if any record's message contains `text`.
63    pub fn contains_message(&self, text: &str) -> bool {
64        self.records()
65            .iter()
66            .any(|record| record.message.contains(text))
67    }
68}
69
70impl<S: tracing::Subscriber> Layer<S> for LogRecorder {
71    fn on_event(&self, event: &Event<'_>, _ctx: Context<'_, S>) {
72        let mut visitor = RecordVisitor::default();
73        event.record(&mut visitor);
74        let record = LogRecord {
75            level: event.metadata().level().to_string(),
76            context: visitor
77                .context
78                .unwrap_or_else(|| event.metadata().target().to_owned()),
79            message: visitor.message.unwrap_or_default(),
80            fields: visitor.fields,
81        };
82        self.records
83            .lock()
84            .expect("recorder mutex poisoned")
85            .push(record);
86    }
87}
88
89/// Extracts the Tork fields from an event into a [`LogRecord`].
90#[derive(Default)]
91struct RecordVisitor {
92    message: Option<String>,
93    context: Option<String>,
94    fields: Map<String, Value>,
95}
96
97impl RecordVisitor {
98    fn set(&mut self, name: &str, value: String) {
99        match name {
100            "message" => self.message = Some(value),
101            "tork.context" => self.context = Some(value),
102            "tork.fields" => {
103                if let Ok(Value::Object(map)) = serde_json::from_str::<Value>(&value) {
104                    self.fields = map;
105                }
106            }
107            _ => {}
108        }
109    }
110}
111
112impl Visit for RecordVisitor {
113    fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
114        self.set(field.name(), format!("{value:?}"));
115    }
116
117    fn record_str(&mut self, field: &Field, value: &str) {
118        self.set(field.name(), value.to_owned());
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use tracing_subscriber::layer::SubscriberExt;

    /// Runs `emit` with a fresh recorder installed as the default subscriber
    /// and returns that recorder for inspection.
    fn capture(emit: impl FnOnce()) -> LogRecorder {
        let recorder = LogRecorder::new();
        let subscriber = tracing_subscriber::registry().with(recorder.clone());
        tracing::subscriber::with_default(subscriber, emit);
        recorder
    }

    #[test]
    fn recorder_helpers_find_context_and_message() {
        let recorder = capture(|| {
            tracing::info!(
                tork.context = "Orders",
                tork.fields = "{\"id\":1}",
                "created order"
            );
        });

        assert!(recorder.contains_context("Orders"));
        assert!(recorder.contains_message("created order"));

        let records = recorder.records();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].level, "INFO");
        assert_eq!(records[0].fields["id"], Value::from(1));
    }

    #[test]
    fn context_falls_back_to_event_target() {
        let recorder = capture(|| tracing::warn!("no explicit context"));

        let records = recorder.records();
        assert_eq!(records.len(), 1);
        // Without `tork.context`, the event target (this module's path) is used.
        assert_eq!(records[0].context, module_path!());
        assert!(records[0].fields.is_empty());
    }

    #[test]
    fn clear_discards_captured_records() {
        let recorder = capture(|| tracing::info!("to be cleared"));
        assert!(recorder.contains_message("to be cleared"));

        recorder.clear();

        assert!(recorder.records().is_empty());
        assert!(!recorder.contains_message("to be cleared"));
    }

    #[test]
    fn visitor_ignores_invalid_json_fields_payload() {
        let mut visitor = RecordVisitor::default();
        visitor.set("tork.fields", "not-json".to_owned());
        visitor.set("message", "hello".to_owned());

        assert_eq!(visitor.message.as_deref(), Some("hello"));
        assert!(visitor.fields.is_empty());
    }

    #[test]
    fn visitor_ignores_non_object_json_fields_payload() {
        let mut visitor = RecordVisitor::default();
        visitor.set("tork.fields", "[1, 2, 3]".to_owned());

        assert!(visitor.fields.is_empty());
    }
}
155
/// Asserts a recorder captured a log with the given context and message substring.
///
/// `context` must equal a record's context exactly; `message` only has to
/// occur somewhere in that record's message. Each argument expression is
/// evaluated exactly once, before any matching, and is only borrowed. Side
/// effects therefore run once, and the caller can keep using owned values
/// afterwards. On failure the panic message lists every captured record.
///
/// ```ignore
/// assert_logs!(recorder, context = "OrderService", message = "Listing orders");
/// ```
#[macro_export]
macro_rules! assert_logs {
    ($recorder:expr, context = $context:expr, message = $message:expr $(,)?) => {{
        let records = $recorder.records();
        // Bind once so the expressions are not re-evaluated for every record
        // (and again for the failure message).
        let context = &$context;
        let message = &$message;
        assert!(
            records
                .iter()
                .any(|record| record.context == *context && record.message.contains(*message)),
            "no log with context {:?} and message containing {:?}; captured: {:?}",
            context,
            message,
            records,
        );
    }};
}
171            $message,
172            records,
173        );
174    }};
175}