//! scouter_sql/sql/aggregator.rs
//!
//! In-memory trace aggregation cache with a background task that periodically
//! flushes stale trace aggregations (and their entity tags) to Postgres.
1use crate::sql::error::SqlError;
2use crate::sql::query::Queries;
3use crate::sql::utils::EntityBytea;
4use chrono::{DateTime, Duration, Timelike, Utc};
5use dashmap::DashMap;
6use scouter_types::{Attribute, TraceId, TraceSpanRecord, SCOUTER_ENTITY};
7use sqlx::PgPool;
8use std::collections::HashSet;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use tokio::time::{interval, Duration as StdDuration};
13use tracing::{error, info, warn};
14const TRACE_BATCH_SIZE: usize = 1000;
15
16const MAX_TOTAL_SPANS: u64 = 1_000_000;
17
18/// Cache handle to manage trace aggregations
19struct TraceCacheHandle {
20    cache: Arc<TraceCache>,
21    shutdown_flag: Arc<AtomicBool>,
22}
23
24static TRACE_CACHE: RwLock<Option<TraceCacheHandle>> = RwLock::const_new(None);
25
26#[derive(Debug, Clone)]
27pub struct TraceAggregator {
28    pub trace_id: TraceId,
29    pub service_name: String,
30    pub scope_name: String,
31    pub scope_version: String,
32    pub root_operation: String,
33    pub start_time: DateTime<Utc>,
34    pub end_time: Option<DateTime<Utc>>,
35    pub status_code: i32,
36    pub status_message: String,
37    pub span_count: i64,
38    pub error_count: i64,
39    pub resource_attributes: Vec<Attribute>,
40    pub first_seen: DateTime<Utc>,
41    pub last_updated: DateTime<Utc>,
42    pub entity_tags: HashSet<EntityBytea>,
43}
44
45impl TraceAggregator {
46    pub fn bucket_time(&self) -> DateTime<Utc> {
47        self.start_time
48            .with_minute(0)
49            .unwrap()
50            .with_second(0)
51            .unwrap()
52            .with_nanosecond(0)
53            .unwrap()
54    }
55
56    /// Extracts specific entity attributes from span attributes and adds them to the aggregator's entity_tags set
57    /// Arguments:
58    /// - `span`: The TraceSpanRecord from which to extract entity attributes
59    pub fn add_entities(&mut self, span: &TraceSpanRecord) {
60        for attr in &span.attributes {
61            if attr.key.starts_with(SCOUTER_ENTITY) {
62                // Value should be string in the format "{uid}"
63                let entity = match &attr.value {
64                    serde_json::Value::String(s) => s.clone(),
65                    _ => continue, // Skip if not a string
66                };
67                match EntityBytea::from_uuid(&entity) {
68                    Ok(uid) => {
69                        self.entity_tags.insert(uid);
70                    }
71                    Err(e) => {
72                        warn!(%entity, "Failed to parse entity UID from attribute value in span {}: {}", span.span_id, e);
73                    }
74                }
75            }
76        }
77    }
78
79    pub fn new_from_span(span: &TraceSpanRecord) -> Self {
80        let now = Utc::now();
81        let mut aggregator = Self {
82            trace_id: span.trace_id.clone(),
83            service_name: span.service_name.clone(),
84            scope_name: span.scope_name.clone(),
85            scope_version: span.scope_version.clone().unwrap_or_default(),
86            root_operation: if span.parent_span_id.is_none() {
87                span.span_name.clone()
88            } else {
89                String::new()
90            },
91            start_time: span.start_time,
92            end_time: Some(span.end_time),
93            status_code: span.status_code,
94            status_message: span.status_message.clone(),
95            span_count: 1,
96            error_count: if span.status_code == 2 { 1 } else { 0 },
97            resource_attributes: span.resource_attributes.clone(),
98            first_seen: now,
99            last_updated: now,
100            entity_tags: HashSet::new(),
101        };
102        aggregator.add_entities(span);
103        aggregator
104    }
105
106    pub fn update_from_span(&mut self, span: &TraceSpanRecord) {
107        if span.start_time < self.start_time {
108            self.start_time = span.start_time;
109        }
110        if let Some(current_end) = self.end_time {
111            if span.end_time > current_end {
112                self.end_time = Some(span.end_time);
113            }
114        } else {
115            self.end_time = Some(span.end_time);
116        }
117
118        if span.parent_span_id.is_none() {
119            self.root_operation = span.span_name.clone();
120            self.service_name = span.service_name.clone();
121            self.scope_name = span.scope_name.clone();
122            if let Some(version) = &span.scope_version {
123                self.scope_version = version.clone();
124            }
125            self.resource_attributes = span.resource_attributes.clone();
126        }
127
128        if span.status_code == 2 {
129            self.error_count += 1;
130            self.status_code = 2;
131            self.status_message = span.status_message.clone();
132        }
133
134        self.span_count += 1;
135        self.last_updated = Utc::now();
136        self.add_entities(span);
137    }
138
139    pub fn duration_ms(&self) -> Option<i64> {
140        self.end_time
141            .map(|end| (end - self.start_time).num_milliseconds())
142    }
143
144    pub fn is_stale(&self, stale_duration: Duration) -> bool {
145        (Utc::now() - self.last_updated) >= stale_duration
146    }
147}
148
149pub struct TraceCache {
150    traces: DashMap<TraceId, TraceAggregator>,
151    max_traces: usize,
152    total_span_count: AtomicU64,
153}
154
155impl TraceCache {
156    fn new(max_traces: usize) -> Self {
157        Self {
158            traces: DashMap::new(),
159            max_traces,
160            total_span_count: AtomicU64::new(0),
161        }
162    }
163
164    /// Update trace aggregation from a span. Uses Arc<Self> to enable background flushing.
165    pub async fn update_trace(self: &Arc<Self>, span: &TraceSpanRecord) {
166        let current_traces = self.traces.len();
167        let current_spans = self.total_span_count.load(Ordering::Relaxed);
168
169        // Check trace and span pressure
170        let trace_pressure = (current_traces * 100) / self.max_traces;
171        let span_pressure = (current_spans * 100) / MAX_TOTAL_SPANS;
172        let max_pressure = trace_pressure.max(span_pressure as usize);
173
174        // If near capacity, log warning (background flush task will handle it)
175        if max_pressure >= 90 {
176            warn!(
177                current_traces,
178                current_spans,
179                max_pressure,
180                "TraceCache high memory pressure, will flush on next interval"
181            );
182        }
183        self.traces
184            .entry(span.trace_id.clone())
185            .and_modify(|agg| {
186                agg.update_from_span(span);
187                self.total_span_count.fetch_add(1, Ordering::Relaxed);
188            })
189            .or_insert_with(|| {
190                self.total_span_count.fetch_add(1, Ordering::Relaxed);
191                TraceAggregator::new_from_span(span)
192            });
193    }
194
195    pub async fn flush_traces(
196        &self,
197        pool: &PgPool,
198        stale_threshold: Duration,
199    ) -> Result<usize, SqlError> {
200        let stale_keys: Vec<TraceId> = self
201            .traces
202            .iter()
203            .filter(|entry| entry.value().is_stale(stale_threshold))
204            .map(|entry| entry.key().clone())
205            .collect();
206
207        if stale_keys.is_empty() {
208            return Ok(0);
209        }
210
211        let mut to_flush = Vec::with_capacity(stale_keys.len());
212        let mut spans_freed = 0u64;
213
214        for id in stale_keys {
215            if let Some((_, agg)) = self.traces.remove(&id) {
216                spans_freed += agg.span_count as u64;
217                to_flush.push((id, agg));
218            }
219        }
220
221        // Decrement by actual span count, not trace count
222        self.total_span_count
223            .fetch_sub(spans_freed, Ordering::Relaxed);
224
225        let count = to_flush.len();
226        info!(
227            flushed_traces = count,
228            freed_spans = spans_freed,
229            remaining_traces = self.traces.len(),
230            remaining_spans = self.total_span_count.load(Ordering::Relaxed),
231            "Flushed stale traces"
232        );
233
234        for chunk in to_flush.chunks(TRACE_BATCH_SIZE) {
235            self.upsert_batch(pool, chunk).await?;
236        }
237        Ok(count)
238    }
239
240    async fn upsert_batch(
241        &self,
242        pool: &PgPool,
243        traces: &[(TraceId, TraceAggregator)],
244    ) -> Result<(), SqlError> {
245        let mut created_ats = Vec::with_capacity(traces.len());
246        let mut bucket_times = Vec::with_capacity(traces.len());
247        let mut trace_ids = Vec::with_capacity(traces.len());
248        let mut service_names = Vec::with_capacity(traces.len());
249        let mut scope_names = Vec::with_capacity(traces.len());
250        let mut scope_versions = Vec::with_capacity(traces.len());
251        let mut root_operations = Vec::with_capacity(traces.len());
252        let mut start_times = Vec::with_capacity(traces.len());
253        let mut end_times = Vec::with_capacity(traces.len());
254        let mut durations_ms = Vec::with_capacity(traces.len());
255        let mut status_codes = Vec::with_capacity(traces.len());
256        let mut status_messages = Vec::with_capacity(traces.len());
257        let mut span_counts = Vec::with_capacity(traces.len());
258        let mut error_counts = Vec::with_capacity(traces.len());
259        let mut resource_attrs = Vec::with_capacity(traces.len());
260
261        // Collect entity tags for batch insert
262        let mut entity_trace_ids = Vec::new();
263        let mut entity_uids = Vec::new();
264        let mut entity_tagged_ats = Vec::new();
265        let now = Utc::now();
266
267        for (trace_id, agg) in traces {
268            created_ats.push(agg.first_seen);
269            bucket_times.push(agg.bucket_time());
270            trace_ids.push(trace_id.as_bytes());
271            service_names.push(&agg.service_name);
272            scope_names.push(&agg.scope_name);
273            scope_versions.push(&agg.scope_version);
274            root_operations.push(&agg.root_operation);
275            start_times.push(agg.start_time);
276            end_times.push(agg.end_time.unwrap_or(agg.start_time));
277            durations_ms.push(agg.duration_ms().unwrap_or(0));
278            status_codes.push(agg.status_code);
279            status_messages.push(&agg.status_message);
280            span_counts.push(agg.span_count);
281            error_counts.push(agg.error_count);
282            resource_attrs.push(serde_json::to_value(&agg.resource_attributes)?);
283
284            // Collect entity tags for this trace
285            for entity_uid in &agg.entity_tags {
286                entity_trace_ids.push(trace_id.as_bytes());
287                entity_uids.push(entity_uid.as_bytes());
288                entity_tagged_ats.push(now);
289            }
290        }
291
292        // 1. Upsert trace aggregations
293        sqlx::query(Queries::UpsertTrace.get_query())
294            .bind(&bucket_times)
295            .bind(&created_ats)
296            .bind(&trace_ids)
297            .bind(&service_names as &[&String])
298            .bind(&scope_names as &[&String])
299            .bind(&scope_versions as &[&String])
300            .bind(&root_operations as &[&String])
301            .bind(&start_times)
302            .bind(&end_times)
303            .bind(&durations_ms)
304            .bind(&status_codes)
305            .bind(&status_messages as &[&String])
306            .bind(&span_counts)
307            .bind(&error_counts)
308            .bind(&resource_attrs)
309            .execute(pool)
310            .await?;
311
312        // 2. Batch insert all entity tags in one query
313        if !entity_trace_ids.is_empty() {
314            sqlx::query(Queries::InsertTraceEntityTags.get_query())
315                .bind(&entity_trace_ids)
316                .bind(&entity_uids)
317                .bind(&entity_tagged_ats)
318                .execute(pool)
319                .await?;
320        }
321
322        Ok(())
323    }
324}
325
326/// Initialize the TraceCache, replacing any previous instance.
327/// The old background flush task is signaled to stop and any remaining
328/// traces are flushed with the NEW pool before the cache is swapped.
329pub async fn init_trace_cache(
330    pool: PgPool,
331    flush_interval: Duration,
332    stale_threshold: Duration,
333    max_traces: usize,
334) -> Result<(), SqlError> {
335    // Shut down any existing cache first
336    let old_cache = {
337        let guard = TRACE_CACHE.read().await;
338        guard.as_ref().map(|handle| {
339            handle.shutdown_flag.store(true, Ordering::SeqCst);
340            handle.cache.clone()
341        })
342    };
343
344    // Flush outside so we dont hold the lock
345    if let Some(cache) = old_cache {
346        info!("Flushing previous TraceCache before re-initialization...");
347        if let Err(e) = cache.flush_traces(&pool, Duration::seconds(-1)).await {
348            error!(error = %e, "Failed to flush previous TraceCache");
349        }
350    }
351
352    let cache = Arc::new(TraceCache::new(max_traces));
353    let shutdown_flag = Arc::new(AtomicBool::new(false));
354
355    {
356        let mut guard = TRACE_CACHE.write().await;
357        *guard = Some(TraceCacheHandle {
358            cache: cache.clone(),
359            shutdown_flag: shutdown_flag.clone(),
360        });
361    }
362
363    let flush_std_duration = StdDuration::from_secs(flush_interval.num_seconds() as u64);
364    let task_shutdown = shutdown_flag.clone();
365
366    tokio::spawn(async move {
367        let mut ticker = interval(flush_std_duration);
368        loop {
369            ticker.tick().await;
370
371            if task_shutdown.load(Ordering::SeqCst) {
372                info!("TraceCache background flush task shutting down");
373                break;
374            }
375
376            let current_traces = cache.traces.len();
377            let current_spans = cache.total_span_count.load(Ordering::Relaxed);
378
379            let threshold = if current_traces > max_traces || current_spans > MAX_TOTAL_SPANS {
380                warn!(
381                    current_traces,
382                    current_spans, "Emergency flush triggered due to memory pressure"
383                );
384                Duration::seconds(0)
385            } else {
386                stale_threshold
387            };
388
389            if let Err(e) = cache.flush_traces(&pool, threshold).await {
390                error!(error = %e, "Flush task failed");
391            }
392        }
393    });
394
395    info!("TraceCache initialized");
396    Ok(())
397}
398
399/// Get access to the current TraceCache
400pub async fn get_trace_cache() -> Arc<TraceCache> {
401    TRACE_CACHE
402        .read()
403        .await
404        .as_ref()
405        .expect("TraceCache not initialized")
406        .cache
407        .clone()
408}
409
410/// Flush all remaining traces during shutdown
411pub async fn shutdown_trace_cache(pool: &PgPool) -> Result<usize, SqlError> {
412    let cache_to_flush = {
413        let guard = TRACE_CACHE.read().await;
414        guard.as_ref().map(|handle| {
415            handle.shutdown_flag.store(true, Ordering::SeqCst);
416            handle.cache.clone()
417        })
418    };
419
420    if let Some(cache) = cache_to_flush {
421        info!("Flushing TraceCache for shutdown...");
422        cache.flush_traces(pool, Duration::seconds(-1)).await
423    } else {
424        Ok(0)
425    }
426}