// tracing_cloudwatch/export.rs

1use std::fmt::Debug;
2use std::num::NonZeroUsize;
3use std::time::Duration;
4
5use tokio::{
6    sync::{mpsc::UnboundedReceiver, oneshot},
7    time::interval,
8};
9
10use crate::{CloudWatchClient, client::NoopClient, dispatch::LogEvent, guard::ShutdownSignal};
11
/// Configurations to control the behavior of exporting logs to CloudWatch.
#[derive(Debug, Clone)]
pub struct ExportConfig {
    /// The number of logs to retain in the buffer within the interval period.
    /// When the buffer reaches this size, the batch is flushed immediately.
    batch_size: NonZeroUsize,
    /// The interval for putting logs. A flush is attempted on every tick,
    /// even if fewer than `batch_size` events are buffered.
    interval: Duration,
    /// Where logs are sent (CloudWatch log group and log stream names).
    destination: LogDestination,
}
22
/// Where logs are sent.
///
/// Identifies the CloudWatch log group and log stream that receive
/// exported events. `Default` yields empty names for both.
#[derive(Debug, Clone, Default)]
pub struct LogDestination {
    /// The name of the log group.
    pub log_group_name: String,
    /// The name of the log stream.
    pub log_stream_name: String,
}
31
32impl Default for ExportConfig {
33    fn default() -> Self {
34        Self {
35            batch_size: NonZeroUsize::new(5).unwrap(),
36            interval: Duration::from_secs(5),
37            destination: LogDestination::default(),
38        }
39    }
40}
41
42impl ExportConfig {
43    /// Set batch size.
44    pub fn with_batch_size<T>(self, batch_size: T) -> Self
45    where
46        T: TryInto<NonZeroUsize>,
47        <T as TryInto<NonZeroUsize>>::Error: Debug,
48    {
49        Self {
50            batch_size: batch_size
51                .try_into()
52                .expect("batch size must be greater than or equal to 1"),
53            ..self
54        }
55    }
56
57    /// Set export interval.
58    pub fn with_interval(self, interval: Duration) -> Self {
59        Self { interval, ..self }
60    }
61
62    /// Set log group name.
63    pub fn with_log_group_name(self, log_group_name: impl Into<String>) -> Self {
64        Self {
65            destination: LogDestination {
66                log_group_name: log_group_name.into(),
67                log_stream_name: self.destination.log_stream_name,
68            },
69            ..self
70        }
71    }
72
73    /// Set log stream name.
74    pub fn with_log_stream_name(self, log_stream_name: impl Into<String>) -> Self {
75        Self {
76            destination: LogDestination {
77                log_stream_name: log_stream_name.into(),
78                log_group_name: self.destination.log_group_name,
79            },
80            ..self
81        }
82    }
83}
84
/// Buffers log events and exports them in batches via the client `C`.
pub(crate) struct BatchExporter<C> {
    /// Client used to put log batches (e.g. to CloudWatch).
    client: C,
    /// Events buffered since the last flush.
    queue: Vec<LogEvent>,
    /// Batch size, flush interval, and destination settings.
    config: ExportConfig,
}
90
/// Default exporter: discards all logs via [`NoopClient`], using the
/// default [`ExportConfig`].
impl Default for BatchExporter<NoopClient> {
    fn default() -> Self {
        Self::new(NoopClient::new(), ExportConfig::default())
    }
}
96
97impl<C> BatchExporter<C> {
98    pub(crate) fn new(client: C, config: ExportConfig) -> Self {
99        Self {
100            client,
101            config,
102            queue: Vec::new(),
103        }
104    }
105}
106
107impl<C> BatchExporter<C>
108where
109    C: CloudWatchClient + Send + Sync + 'static,
110{
111    pub(crate) async fn run(
112        mut self,
113        mut rx: UnboundedReceiver<LogEvent>,
114        mut shutdown_rx: oneshot::Receiver<ShutdownSignal>,
115    ) {
116        let mut interval = interval(self.config.interval);
117        let mut shutdown_signal = None;
118
119        loop {
120            tokio::select! {
121                 _ = interval.tick() => {
122                    if self.queue.is_empty() {
123                        continue;
124                    }
125                }
126
127                event = rx.recv() => {
128                    let Some(event) = event else {
129                        break;
130                    };
131
132                    self.queue.push(event);
133                    if self.queue.len() < <NonZeroUsize as Into<usize>>::into(self.config.batch_size) {
134                        continue
135                    }
136                }
137
138                received_shutdown = &mut shutdown_rx => {
139                    if let Ok(signal) = received_shutdown {
140                        shutdown_signal = Some(signal);
141                    }
142                    while let Ok(event) = rx.try_recv() {
143                        self.queue.push(event);
144                    }
145                    break;
146                }
147            }
148            self.flush().await;
149        }
150        self.flush().await;
151        if let Some(shutdown_signal) = shutdown_signal {
152            shutdown_signal.ack();
153        }
154    }
155
156    async fn flush(&mut self) {
157        let logs: Vec<LogEvent> = Self::take_from_queue(&mut self.queue);
158
159        if logs.is_empty() {
160            return;
161        }
162
163        if let Err(err) = self
164            .client
165            .put_logs(self.config.destination.clone(), logs)
166            .await
167        {
168            eprintln!(
169                "[tracing-cloudwatch] Unable to put logs to cloudwatch. Error: {err:?} {:?}",
170                self.config.destination
171            );
172        }
173    }
174
175    fn take_from_queue(queue: &mut Vec<LogEvent>) -> Vec<LogEvent> {
176        if cfg!(feature = "ordered_logs") {
177            let mut logs = std::mem::take(queue);
178            logs.sort_by_key(|log| log.timestamp);
179            logs
180        } else {
181            std::mem::take(queue)
182        }
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use chrono::{DateTime, Utc};
    use std::sync::{Arc, Mutex};
    use tokio::time::{sleep, timeout};
    use tracing_subscriber::layer::SubscriberExt;

    // Nanoseconds in one day; used to build well-separated timestamps.
    const ONE_DAY_NS: i64 = 86_400_000_000_000;
    const DAY_ONE: DateTime<Utc> = DateTime::from_timestamp_nanos(0 + ONE_DAY_NS);
    const DAY_TWO: DateTime<Utc> = DateTime::from_timestamp_nanos(0 + (ONE_DAY_NS * 2));
    const DAY_THREE: DateTime<Utc> = DateTime::from_timestamp_nanos(0 + (ONE_DAY_NS * 3));

    // Without the `ordered_logs` feature, `take_from_queue` must preserve
    // insertion order (no sorting).
    #[cfg(not(feature = "ordered_logs"))]
    #[test]
    fn does_not_order_logs_by_default() {
        let mut unordered_queue = vec![
            LogEvent {
                message: "1".to_string(),
                timestamp: DAY_ONE,
            },
            LogEvent {
                message: "3".to_string(),
                timestamp: DAY_THREE,
            },
            LogEvent {
                message: "2".to_string(),
                timestamp: DAY_TWO,
            },
        ];
        let still_unordered_queue =
            BatchExporter::<NoopClient>::take_from_queue(&mut unordered_queue);

        // Expect the original 1, 3, 2 insertion order to survive.
        let mut still_unordered_queue_iter = still_unordered_queue.iter();
        assert_eq!(
            DAY_ONE,
            still_unordered_queue_iter.next().unwrap().timestamp
        );
        assert_eq!(
            DAY_THREE,
            still_unordered_queue_iter.next().unwrap().timestamp
        );
        assert_eq!(
            DAY_TWO,
            still_unordered_queue_iter.next().unwrap().timestamp
        );
    }

    #[cfg(feature = "ordered_logs")]
    mod ordering {
        use super::*;

        // Asserts timestamps are strictly increasing across the batch.
        fn assert_is_ordered(logs: Vec<LogEvent>) {
            let mut last_timestamp = DateTime::from_timestamp_nanos(0);

            for log in logs {
                assert!(
                    log.timestamp > last_timestamp,
                    "Not true: {} > {}",
                    log.timestamp,
                    last_timestamp
                );
                last_timestamp = log.timestamp;
            }
        }

        // With `ordered_logs` enabled, `take_from_queue` sorts by timestamp.
        #[test]
        fn orders_logs_when_enabled() {
            let mut unordered_queue = vec![
                LogEvent {
                    message: "1".to_string(),
                    timestamp: DAY_ONE,
                },
                LogEvent {
                    message: "3".to_string(),
                    timestamp: DAY_THREE,
                },
                LogEvent {
                    message: "2".to_string(),
                    timestamp: DAY_TWO,
                },
            ];
            let ordered_queue = BatchExporter::<NoopClient>::take_from_queue(&mut unordered_queue);
            assert_is_ordered(ordered_queue);
        }
    }

    // Test double: records every exported event in shared memory so the
    // tests can inspect what was "sent".
    #[derive(Clone, Default)]
    struct RecordingClient {
        logs: Arc<Mutex<Vec<LogEvent>>>,
    }

    #[async_trait]
    impl CloudWatchClient for RecordingClient {
        async fn put_logs(
            &self,
            _dest: LogDestination,
            logs: Vec<LogEvent>,
        ) -> Result<(), crate::client::PutLogsError> {
            self.logs.lock().unwrap().extend(logs);
            Ok(())
        }
    }

    impl RecordingClient {
        // Number of events exported so far.
        fn exported_count(&self) -> usize {
            self.logs.lock().unwrap().len()
        }

        // Messages of all exported events, in export order.
        fn exported_messages(&self) -> Vec<String> {
            self.logs
                .lock()
                .unwrap()
                .iter()
                .map(|event| event.message.clone())
                .collect()
        }
    }

    // Polls until `expected` events have been exported, or fails after 1s.
    async fn wait_for_exported_count(client: &RecordingClient, expected: usize) {
        timeout(Duration::from_secs(1), async {
            loop {
                if client.exported_count() >= expected {
                    break;
                }
                sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("timed out waiting for exported log events");
    }

    // Events queued before shutdown must all be drained and exported, even
    // when neither the batch size nor the interval would trigger a flush.
    #[tokio::test(flavor = "current_thread")]
    async fn drains_all_buffered_events_on_shutdown() {
        let client = RecordingClient::default();
        let exporter = BatchExporter::new(
            client.clone(),
            ExportConfig::default()
                .with_batch_size(10_000)
                .with_interval(Duration::from_secs(60))
                .with_log_group_name("group")
                .with_log_stream_name("stream"),
        );

        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<ShutdownSignal>();
        let (shutdown_signal, _ack_rx) = ShutdownSignal::new();

        let total = 512;
        for idx in 0..total {
            tx.send(LogEvent {
                message: format!("event-{idx}"),
                timestamp: Utc::now(),
            })
            .unwrap();
        }
        drop(tx);
        shutdown_tx.send(shutdown_signal).unwrap();

        exporter.run(rx, shutdown_rx).await;

        assert_eq!(
            client.exported_count(),
            total,
            "all events queued before shutdown should be exported"
        );
    }

    // End-to-end through the tracing layer: events emitted before the guard
    // shuts down are exported during shutdown.
    #[tokio::test(flavor = "current_thread")]
    async fn exports_events_with_registry_on_guard_shutdown() {
        let client = RecordingClient::default();
        let (cw_layer, guard) = crate::layer()
            .with_code_location(false)
            .with_target(false)
            .with_client(
                client.clone(),
                ExportConfig::default()
                    .with_batch_size(1024)
                    .with_interval(Duration::from_secs(60))
                    .with_log_group_name("group")
                    .with_log_stream_name("stream"),
            );

        let subscriber = tracing_subscriber::registry().with(cw_layer);
        tracing::subscriber::with_default(subscriber, || {
            tracing::info!("integration-log-1");
            tracing::warn!("integration-log-2");
        });

        guard.shutdown().await;

        let messages = client.exported_messages();
        assert_eq!(messages.len(), 2);
        assert!(
            messages
                .iter()
                .any(|message| message.contains("integration-log-1"))
        );
        assert!(
            messages
                .iter()
                .any(|message| message.contains("integration-log-2"))
        );
    }

    // Reaching `batch_size` triggers a flush without waiting for the
    // interval or for shutdown.
    #[tokio::test(flavor = "current_thread")]
    async fn exports_when_batch_size_is_reached() {
        let client = RecordingClient::default();
        let (cw_layer, guard) = crate::layer()
            .with_code_location(false)
            .with_target(false)
            .with_client(
                client.clone(),
                ExportConfig::default()
                    .with_batch_size(2)
                    .with_interval(Duration::from_secs(60))
                    .with_log_group_name("group")
                    .with_log_stream_name("stream"),
            );

        let subscriber = tracing_subscriber::registry().with(cw_layer);
        // Let the exporter consume the initial immediate interval tick while the queue is empty.
        sleep(Duration::from_millis(20)).await;

        tracing::subscriber::with_default(subscriber, || {
            tracing::info!("batch-log-1");
            tracing::info!("batch-log-2");
        });

        wait_for_exported_count(&client, 2).await;
        guard.shutdown().await;
    }

    // The interval tick alone must export a partial batch — no shutdown and
    // no full batch required.
    #[tokio::test(flavor = "current_thread")]
    async fn exports_without_shutdown_when_batch_not_full() {
        let client = RecordingClient::default();
        let (cw_layer, guard) = crate::layer()
            .with_code_location(false)
            .with_target(false)
            .with_client(
                client.clone(),
                ExportConfig::default()
                    .with_batch_size(1024)
                    .with_interval(Duration::from_millis(200))
                    .with_log_group_name("group")
                    .with_log_stream_name("stream"),
            );

        let subscriber = tracing_subscriber::registry().with(cw_layer);
        // Let the exporter consume the initial immediate interval tick while the queue is empty.
        sleep(Duration::from_millis(20)).await;

        tracing::subscriber::with_default(subscriber, || {
            tracing::info!("interval-log-1");
        });

        wait_for_exported_count(&client, 1).await;
        guard.shutdown().await;
    }
}