1use std::fmt::Debug;
2use std::num::NonZeroUsize;
3use std::time::Duration;
4
5use tokio::{
6 sync::{mpsc::UnboundedReceiver, oneshot},
7 time::interval,
8};
9
10use crate::{CloudWatchClient, client::NoopClient, dispatch::LogEvent, guard::ShutdownSignal};
11
12#[derive(Debug, Clone)]
14pub struct ExportConfig {
15 batch_size: NonZeroUsize,
17 interval: Duration,
19 destination: LogDestination,
21}
22
/// The CloudWatch Logs target that exported batches are written to.
///
/// Both names default to empty strings.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct LogDestination {
    /// Name of the CloudWatch log group.
    pub log_group_name: String,
    /// Name of the CloudWatch log stream.
    pub log_stream_name: String,
}
31
32impl Default for ExportConfig {
33 fn default() -> Self {
34 Self {
35 batch_size: NonZeroUsize::new(5).unwrap(),
36 interval: Duration::from_secs(5),
37 destination: LogDestination::default(),
38 }
39 }
40}
41
42impl ExportConfig {
43 pub fn with_batch_size<T>(self, batch_size: T) -> Self
45 where
46 T: TryInto<NonZeroUsize>,
47 <T as TryInto<NonZeroUsize>>::Error: Debug,
48 {
49 Self {
50 batch_size: batch_size
51 .try_into()
52 .expect("batch size must be greater than or equal to 1"),
53 ..self
54 }
55 }
56
57 pub fn with_interval(self, interval: Duration) -> Self {
59 Self { interval, ..self }
60 }
61
62 pub fn with_log_group_name(self, log_group_name: impl Into<String>) -> Self {
64 Self {
65 destination: LogDestination {
66 log_group_name: log_group_name.into(),
67 log_stream_name: self.destination.log_stream_name,
68 },
69 ..self
70 }
71 }
72
73 pub fn with_log_stream_name(self, log_stream_name: impl Into<String>) -> Self {
75 Self {
76 destination: LogDestination {
77 log_stream_name: log_stream_name.into(),
78 log_group_name: self.destination.log_group_name,
79 },
80 ..self
81 }
82 }
83}
84
/// Buffers [`LogEvent`]s and hands them to a client in batches.
///
/// Created with [`BatchExporter::new`] and driven by `run`, which consumes the exporter.
pub(crate) struct BatchExporter<C> {
    /// Receives each flushed batch through `put_logs`.
    client: C,
    /// Events received since the last flush.
    queue: Vec<LogEvent>,
    /// Batch size, flush interval and destination used while exporting.
    config: ExportConfig,
}
90
91impl Default for BatchExporter<NoopClient> {
92 fn default() -> Self {
93 Self::new(NoopClient::new(), ExportConfig::default())
94 }
95}
96
97impl<C> BatchExporter<C> {
98 pub(crate) fn new(client: C, config: ExportConfig) -> Self {
99 Self {
100 client,
101 config,
102 queue: Vec::new(),
103 }
104 }
105}
106
107impl<C> BatchExporter<C>
108where
109 C: CloudWatchClient + Send + Sync + 'static,
110{
111 pub(crate) async fn run(
112 mut self,
113 mut rx: UnboundedReceiver<LogEvent>,
114 mut shutdown_rx: oneshot::Receiver<ShutdownSignal>,
115 ) {
116 let mut interval = interval(self.config.interval);
117 let mut shutdown_signal = None;
118
119 loop {
120 tokio::select! {
121 _ = interval.tick() => {
122 if self.queue.is_empty() {
123 continue;
124 }
125 }
126
127 event = rx.recv() => {
128 let Some(event) = event else {
129 break;
130 };
131
132 self.queue.push(event);
133 if self.queue.len() < <NonZeroUsize as Into<usize>>::into(self.config.batch_size) {
134 continue
135 }
136 }
137
138 received_shutdown = &mut shutdown_rx => {
139 if let Ok(signal) = received_shutdown {
140 shutdown_signal = Some(signal);
141 }
142 while let Ok(event) = rx.try_recv() {
143 self.queue.push(event);
144 }
145 break;
146 }
147 }
148 self.flush().await;
149 }
150 self.flush().await;
151 if let Some(shutdown_signal) = shutdown_signal {
152 shutdown_signal.ack();
153 }
154 }
155
156 async fn flush(&mut self) {
157 let logs: Vec<LogEvent> = Self::take_from_queue(&mut self.queue);
158
159 if logs.is_empty() {
160 return;
161 }
162
163 if let Err(err) = self
164 .client
165 .put_logs(self.config.destination.clone(), logs)
166 .await
167 {
168 eprintln!(
169 "[tracing-cloudwatch] Unable to put logs to cloudwatch. Error: {err:?} {:?}",
170 self.config.destination
171 );
172 }
173 }
174
175 fn take_from_queue(queue: &mut Vec<LogEvent>) -> Vec<LogEvent> {
176 if cfg!(feature = "ordered_logs") {
177 let mut logs = std::mem::take(queue);
178 logs.sort_by_key(|log| log.timestamp);
179 logs
180 } else {
181 std::mem::take(queue)
182 }
183 }
184}
185
// Unit tests for draining/ordering the queue, plus integration tests that run the
// exporter against an in-memory client.
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use chrono::{DateTime, Utc};
    use std::sync::{Arc, Mutex};
    use tokio::time::{sleep, timeout};
    use tracing_subscriber::layer::SubscriberExt;

    // One day in nanoseconds; used to build fixed, well-separated timestamps.
    const ONE_DAY_NS: i64 = 86_400_000_000_000;
    const DAY_ONE: DateTime<Utc> = DateTime::from_timestamp_nanos(0 + ONE_DAY_NS);
    const DAY_TWO: DateTime<Utc> = DateTime::from_timestamp_nanos(0 + (ONE_DAY_NS * 2));
    const DAY_THREE: DateTime<Utc> = DateTime::from_timestamp_nanos(0 + (ONE_DAY_NS * 3));

    // Without the `ordered_logs` feature, `take_from_queue` must keep insertion order.
    #[cfg(not(feature = "ordered_logs"))]
    #[test]
    fn does_not_order_logs_by_default() {
        let mut unordered_queue = vec![
            LogEvent {
                message: "1".to_string(),
                timestamp: DAY_ONE,
            },
            LogEvent {
                message: "3".to_string(),
                timestamp: DAY_THREE,
            },
            LogEvent {
                message: "2".to_string(),
                timestamp: DAY_TWO,
            },
        ];
        let still_unordered_queue =
            BatchExporter::<NoopClient>::take_from_queue(&mut unordered_queue);

        // Expect the original sequence: day one, day three, day two.
        let mut still_unordered_queue_iter = still_unordered_queue.iter();
        assert_eq!(
            DAY_ONE,
            still_unordered_queue_iter.next().unwrap().timestamp
        );
        assert_eq!(
            DAY_THREE,
            still_unordered_queue_iter.next().unwrap().timestamp
        );
        assert_eq!(
            DAY_TWO,
            still_unordered_queue_iter.next().unwrap().timestamp
        );
    }

    // With the `ordered_logs` feature, `take_from_queue` must sort by timestamp.
    #[cfg(feature = "ordered_logs")]
    mod ordering {
        use super::*;

        // Asserts that timestamps are strictly increasing, starting after the Unix epoch.
        fn assert_is_ordered(logs: Vec<LogEvent>) {
            let mut last_timestamp = DateTime::from_timestamp_nanos(0);

            for log in logs {
                assert!(
                    log.timestamp > last_timestamp,
                    "Not true: {} > {}",
                    log.timestamp,
                    last_timestamp
                );
                last_timestamp = log.timestamp;
            }
        }

        #[test]
        fn orders_logs_when_enabled() {
            let mut unordered_queue = vec![
                LogEvent {
                    message: "1".to_string(),
                    timestamp: DAY_ONE,
                },
                LogEvent {
                    message: "3".to_string(),
                    timestamp: DAY_THREE,
                },
                LogEvent {
                    message: "2".to_string(),
                    timestamp: DAY_TWO,
                },
            ];
            let ordered_queue = BatchExporter::<NoopClient>::take_from_queue(&mut unordered_queue);
            assert_is_ordered(ordered_queue);
        }
    }

    // Test client that ignores the destination and records every exported event in memory.
    #[derive(Clone, Default)]
    struct RecordingClient {
        // Shared between clones so a test can inspect what the exporter sent.
        logs: Arc<Mutex<Vec<LogEvent>>>,
    }

    #[async_trait]
    impl CloudWatchClient for RecordingClient {
        async fn put_logs(
            &self,
            _dest: LogDestination,
            logs: Vec<LogEvent>,
        ) -> Result<(), crate::client::PutLogsError> {
            self.logs.lock().unwrap().extend(logs);
            Ok(())
        }
    }

    impl RecordingClient {
        // Number of events recorded so far.
        fn exported_count(&self) -> usize {
            self.logs.lock().unwrap().len()
        }

        // Messages of all recorded events, in the order they were recorded.
        fn exported_messages(&self) -> Vec<String> {
            self.logs
                .lock()
                .unwrap()
                .iter()
                .map(|event| event.message.clone())
                .collect()
        }
    }

    // Polls every 10ms until at least `expected` events were exported; panics after 1s.
    async fn wait_for_exported_count(client: &RecordingClient, expected: usize) {
        timeout(Duration::from_secs(1), async {
            loop {
                if client.exported_count() >= expected {
                    break;
                }
                sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("timed out waiting for exported log events");
    }

    // Every event queued before shutdown must be exported, even though the sender is
    // dropped and the shutdown signal is sent before `run` starts.
    #[tokio::test(flavor = "current_thread")]
    async fn drains_all_buffered_events_on_shutdown() {
        let client = RecordingClient::default();
        // The batch size exceeds `total`, so reaching the batch size never triggers a flush.
        let exporter = BatchExporter::new(
            client.clone(),
            ExportConfig::default()
                .with_batch_size(10_000)
                .with_interval(Duration::from_secs(60))
                .with_log_group_name("group")
                .with_log_stream_name("stream"),
        );

        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<ShutdownSignal>();
        // Bound to a name (not `_`) so the ack receiver stays alive for the whole test;
        // it is not asserted on.
        let (shutdown_signal, _ack_rx) = ShutdownSignal::new();

        let total = 512;
        for idx in 0..total {
            tx.send(LogEvent {
                message: format!("event-{idx}"),
                timestamp: Utc::now(),
            })
            .unwrap();
        }
        drop(tx);
        shutdown_tx.send(shutdown_signal).unwrap();

        exporter.run(rx, shutdown_rx).await;

        assert_eq!(
            client.exported_count(),
            total,
            "all events queued before shutdown should be exported"
        );
    }

    // End-to-end: events emitted through a `tracing` registry with the CloudWatch layer
    // reach the client.
    #[tokio::test(flavor = "current_thread")]
    async fn exports_events_with_registry_on_guard_shutdown() {
        let client = RecordingClient::default();
        let (cw_layer, guard) = crate::layer()
            .with_code_location(false)
            .with_target(false)
            .with_client(
                client.clone(),
                ExportConfig::default()
                    .with_batch_size(1024)
                    .with_interval(Duration::from_secs(60))
                    .with_log_group_name("group")
                    .with_log_stream_name("stream"),
            );

        let subscriber = tracing_subscriber::registry().with(cw_layer);
        tracing::subscriber::with_default(subscriber, || {
            tracing::info!("integration-log-1");
            tracing::warn!("integration-log-2");
        });

        // NOTE(review): assumes `guard.shutdown()` resolves only after the exporter's final
        // flush (see the guard module); the assertions below rely on that.
        guard.shutdown().await;

        let messages = client.exported_messages();
        assert_eq!(messages.len(), 2);
        assert!(
            messages
                .iter()
                .any(|message| message.contains("integration-log-1"))
        );
        assert!(
            messages
                .iter()
                .any(|message| message.contains("integration-log-2"))
        );
    }

    // With a batch size of 2, two events should be exported without waiting for shutdown.
    #[tokio::test(flavor = "current_thread")]
    async fn exports_when_batch_size_is_reached() {
        let client = RecordingClient::default();
        let (cw_layer, guard) = crate::layer()
            .with_code_location(false)
            .with_target(false)
            .with_client(
                client.clone(),
                ExportConfig::default()
                    .with_batch_size(2)
                    .with_interval(Duration::from_secs(60))
                    .with_log_group_name("group")
                    .with_log_stream_name("stream"),
            );

        let subscriber = tracing_subscriber::registry().with(cw_layer);
        // Presumably yields so the exporter task started by `with_client` begins running
        // before events are emitted — TODO confirm against the layer setup.
        sleep(Duration::from_millis(20)).await;

        tracing::subscriber::with_default(subscriber, || {
            tracing::info!("batch-log-1");
            tracing::info!("batch-log-2");
        });

        wait_for_exported_count(&client, 2).await;
        guard.shutdown().await;
    }

    // One event never fills the 1024-event batch, so the export must come from a flush
    // other than the batch-size one, and it must happen before shutdown is requested.
    #[tokio::test(flavor = "current_thread")]
    async fn exports_without_shutdown_when_batch_not_full() {
        let client = RecordingClient::default();
        let (cw_layer, guard) = crate::layer()
            .with_code_location(false)
            .with_target(false)
            .with_client(
                client.clone(),
                ExportConfig::default()
                    .with_batch_size(1024)
                    .with_interval(Duration::from_millis(200))
                    .with_log_group_name("group")
                    .with_log_stream_name("stream"),
            );

        let subscriber = tracing_subscriber::registry().with(cw_layer);
        // Presumably yields so the exporter task started by `with_client` begins running
        // before the event is emitted — TODO confirm against the layer setup.
        sleep(Duration::from_millis(20)).await;

        tracing::subscriber::with_default(subscriber, || {
            tracing::info!("interval-log-1");
        });

        wait_for_exported_count(&client, 1).await;
        guard.shutdown().await;
    }
}