// scouter_sql/sql/aggregator.rs

use crate::sql::error::SqlError;
use crate::sql::query::Queries;
use crate::sql::utils::UuidBytea;
use chrono::{DateTime, Duration, Utc};
use dashmap::DashMap;
use scouter_dataframe::parquet::tracing::summary::TraceSummaryService;
use scouter_types::{
    Attribute, TraceId, TraceSpanRecord, TraceSummaryRecord, SCOUTER_ENTITY, SCOUTER_QUEUE_RECORD,
};
use sqlx::PgPool;
use std::collections::HashSet;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::time::{interval, Duration as StdDuration};
use tracing::{error, info, warn};
/// Maximum number of trace aggregations written per batch in a single flush
/// (see `TraceCache::flush_traces`).
const TRACE_BATCH_SIZE: usize = 1000;

// ── Global TraceSummaryService singleton ─────────────────────────────────────
/// Uses `RwLock<Option<...>>` so tests can re-initialize with a fresh service.
///
/// This is a `std::sync` lock (not tokio's): it is only held for a quick
/// pointer clone/swap and is never held across an `.await`.
static TRACE_SUMMARY_SERVICE: std::sync::RwLock<Option<Arc<TraceSummaryService>>> =
    std::sync::RwLock::new(None);

24/// Register the global TraceSummaryService. Replaces any previously registered instance.
25pub fn init_trace_summary_service(service: Arc<TraceSummaryService>) -> Result<(), SqlError> {
26    let mut guard = TRACE_SUMMARY_SERVICE
27        .write()
28        .map_err(|e| SqlError::TraceCacheError(format!("Failed to acquire write lock: {}", e)))?;
29    *guard = Some(service);
30    info!("TraceSummaryService global singleton registered in aggregator");
31    Ok(())
32}
33
34/// Retrieve the global TraceSummaryService (if initialized).
35pub fn get_trace_summary_service() -> Option<Arc<TraceSummaryService>> {
36    TRACE_SUMMARY_SERVICE.read().ok()?.clone()
37}
38
/// Upper bound on the total number of cached spans before the background
/// task triggers an emergency (threshold-zero) flush.
const MAX_TOTAL_SPANS: u64 = 1_000_000;

/// Cache handle to manage trace aggregations
struct TraceCacheHandle {
    // Shared aggregation cache; also cloned into the background flush task.
    cache: Arc<TraceCache>,
    // Set to true to tell the background flush task to stop.
    shutdown_flag: Arc<AtomicBool>,
}

// Global slot for the active cache; tokio RwLock because it is accessed
// from async functions.
static TRACE_CACHE: RwLock<Option<TraceCacheHandle>> = RwLock::const_new(None);

/// Rolling aggregation of all spans observed so far for a single trace.
#[derive(Debug, Clone)]
pub struct TraceAggregator {
    pub trace_id: TraceId,
    // Trace-level metadata; overwritten whenever a root span (no parent) arrives.
    pub service_name: String,
    pub scope_name: String,
    pub scope_version: String,
    // Span name of the root span; empty until a root span has been seen.
    pub root_operation: String,
    // Earliest start / latest end across all spans seen so far.
    pub start_time: DateTime<Utc>,
    pub end_time: Option<DateTime<Utc>>,
    // status_code 2 is treated as an error here (presumably OTLP STATUS_CODE_ERROR
    // — TODO confirm); the most recent error's code/message is kept.
    pub status_code: i32,
    pub status_message: String,
    pub span_count: i64,
    pub error_count: i64,
    pub resource_attributes: Vec<Attribute>,
    // Bookkeeping timestamps; `last_updated` drives staleness-based flushing.
    pub first_seen: DateTime<Utc>,
    pub last_updated: DateTime<Utc>,
    // UUIDs harvested from SCOUTER_ENTITY / SCOUTER_QUEUE_RECORD attributes.
    pub entity_tags: HashSet<UuidBytea>,
    pub queue_tags: HashSet<UuidBytea>,
}

69fn extract_value_to_set(attr: &Attribute, set: &mut HashSet<UuidBytea>) -> Option<UuidBytea> {
70    if let serde_json::Value::String(s) = &attr.value {
71        match UuidBytea::from_uuid(s) {
72            Ok(uid) => {
73                set.insert(uid.clone());
74                return Some(uid);
75            }
76            Err(e) => {
77                warn!(%s, "Failed to parse value as UUID for attribute key '{}': {}", attr.key, e)
78            }
79        }
80    }
81    None
82}
83
84impl TraceAggregator {
85    /// Extracts specific attributes from span events to populate entity and queue tag sets
86    pub fn add_ids(&mut self, span: &TraceSpanRecord) {
87        for event in &span.events {
88            for attr in &event.attributes {
89                if attr.key == SCOUTER_QUEUE_RECORD {
90                    extract_value_to_set(attr, &mut self.queue_tags);
91                }
92                if attr.key.starts_with(SCOUTER_ENTITY) {
93                    extract_value_to_set(attr, &mut self.entity_tags);
94                }
95            }
96        }
97        for attr in &span.attributes {
98            if attr.key == SCOUTER_QUEUE_RECORD {
99                extract_value_to_set(attr, &mut self.queue_tags);
100            }
101            if attr.key.starts_with(SCOUTER_ENTITY) {
102                extract_value_to_set(attr, &mut self.entity_tags);
103            }
104        }
105    }
106
107    pub fn new_from_span(span: &TraceSpanRecord) -> Self {
108        let now = Utc::now();
109        let mut aggregator = Self {
110            trace_id: span.trace_id.clone(),
111            service_name: span.service_name.clone(),
112            scope_name: span.scope_name.clone(),
113            scope_version: span.scope_version.clone().unwrap_or_default(),
114            root_operation: if span.parent_span_id.is_none() {
115                span.span_name.clone()
116            } else {
117                String::new()
118            },
119            start_time: span.start_time,
120            end_time: Some(span.end_time),
121            status_code: span.status_code,
122            status_message: span.status_message.clone(),
123            span_count: 1,
124            error_count: if span.status_code == 2 { 1 } else { 0 },
125            resource_attributes: span.resource_attributes.clone(),
126            first_seen: now,
127            last_updated: now,
128            entity_tags: HashSet::new(),
129            queue_tags: HashSet::new(),
130        };
131        aggregator.add_ids(span);
132        aggregator
133    }
134
135    pub fn update_from_span(&mut self, span: &TraceSpanRecord) {
136        if span.start_time < self.start_time {
137            self.start_time = span.start_time;
138        }
139        if let Some(current_end) = self.end_time {
140            if span.end_time > current_end {
141                self.end_time = Some(span.end_time);
142            }
143        } else {
144            self.end_time = Some(span.end_time);
145        }
146
147        if span.parent_span_id.is_none() {
148            self.root_operation = span.span_name.clone();
149            self.service_name = span.service_name.clone();
150            self.scope_name = span.scope_name.clone();
151            if let Some(version) = &span.scope_version {
152                self.scope_version = version.clone();
153            }
154            self.resource_attributes = span.resource_attributes.clone();
155        }
156
157        if span.status_code == 2 {
158            self.error_count += 1;
159            self.status_code = 2;
160            self.status_message = span.status_message.clone();
161        }
162
163        self.span_count += 1;
164        self.last_updated = Utc::now();
165        self.add_ids(span);
166    }
167
168    pub fn duration_ms(&self) -> Option<i64> {
169        self.end_time
170            .map(|end| (end - self.start_time).num_milliseconds())
171    }
172
173    pub fn is_stale(&self, stale_duration: Duration) -> bool {
174        (Utc::now() - self.last_updated) >= stale_duration
175    }
176
177    /// Convert to the lightweight `TraceSummaryRecord` for Delta Lake writes.
178    pub fn to_summary_record(&self) -> TraceSummaryRecord {
179        let entity_ids: Vec<String> = self
180            .entity_tags
181            .iter()
182            .map(|e| uuid::Uuid::from_bytes(e.0).to_string())
183            .collect();
184        let queue_ids: Vec<String> = self
185            .queue_tags
186            .iter()
187            .map(|q| uuid::Uuid::from_bytes(q.0).to_string())
188            .collect();
189        TraceSummaryRecord {
190            trace_id: self.trace_id.clone(),
191            service_name: self.service_name.clone(),
192            scope_name: self.scope_name.clone(),
193            scope_version: self.scope_version.clone(),
194            root_operation: self.root_operation.clone(),
195            start_time: self.start_time,
196            end_time: self.end_time,
197            status_code: self.status_code,
198            status_message: self.status_message.clone(),
199            span_count: self.span_count,
200            error_count: self.error_count,
201            resource_attributes: self.resource_attributes.clone(),
202            entity_ids,
203            queue_ids,
204        }
205    }
206}
207
/// In-memory cache of in-flight trace aggregations, keyed by trace id.
pub struct TraceCache {
    // Concurrent map of per-trace aggregation state.
    traces: DashMap<TraceId, TraceAggregator>,
    // Soft cap on distinct traces, used for memory-pressure checks.
    max_traces: usize,
    // Running total of spans across all cached traces (relaxed atomics,
    // so the value is approximate under concurrency).
    total_span_count: AtomicU64,
}

214impl TraceCache {
215    fn new(max_traces: usize) -> Self {
216        Self {
217            traces: DashMap::new(),
218            max_traces,
219            total_span_count: AtomicU64::new(0),
220        }
221    }
222
223    /// Update trace aggregation from a span. Uses Arc<Self> to enable background flushing.
224    pub async fn update_trace(self: &Arc<Self>, span: &TraceSpanRecord) {
225        let current_traces = self.traces.len();
226        let current_spans = self.total_span_count.load(Ordering::Relaxed);
227
228        // Check trace and span pressure
229        let trace_pressure = (current_traces * 100) / self.max_traces;
230        let span_pressure = (current_spans * 100) / MAX_TOTAL_SPANS;
231        let max_pressure = trace_pressure.max(span_pressure as usize);
232
233        // If near capacity, log warning (background flush task will handle it)
234        if max_pressure >= 90 {
235            warn!(
236                current_traces,
237                current_spans,
238                max_pressure,
239                "TraceCache high memory pressure, will flush on next interval"
240            );
241        }
242        self.traces
243            .entry(span.trace_id.clone())
244            .and_modify(|agg| {
245                agg.update_from_span(span);
246                self.total_span_count.fetch_add(1, Ordering::Relaxed);
247            })
248            .or_insert_with(|| {
249                self.total_span_count.fetch_add(1, Ordering::Relaxed);
250                TraceAggregator::new_from_span(span)
251            });
252    }
253
254    pub async fn flush_traces(
255        &self,
256        pool: &PgPool,
257        stale_threshold: Duration,
258    ) -> Result<usize, SqlError> {
259        let stale_keys: Vec<TraceId> = self
260            .traces
261            .iter()
262            .filter(|entry| entry.value().is_stale(stale_threshold))
263            .map(|entry| entry.key().clone())
264            .collect();
265
266        if stale_keys.is_empty() {
267            return Ok(0);
268        }
269
270        let mut to_flush = Vec::with_capacity(stale_keys.len());
271        let mut spans_freed = 0u64;
272
273        for id in stale_keys {
274            if let Some((_, agg)) = self.traces.remove(&id) {
275                spans_freed += agg.span_count as u64;
276                to_flush.push((id, agg));
277            }
278        }
279
280        self.total_span_count
281            .fetch_sub(spans_freed, Ordering::Relaxed);
282
283        let count = to_flush.len();
284        info!(
285            flushed_traces = count,
286            freed_spans = spans_freed,
287            remaining_traces = self.traces.len(),
288            remaining_spans = self.total_span_count.load(Ordering::Relaxed),
289            "Flushed stale traces"
290        );
291
292        for chunk in to_flush.chunks(TRACE_BATCH_SIZE) {
293            self.upsert_batch(pool, chunk).await?;
294        }
295        Ok(count)
296    }
297
298    /// Write a batch of trace aggregations.
299    ///
300    /// Primary: Delta Lake via `TraceSummaryService` (span counts, timing, error rates).
301    /// Secondary: Postgres for entity tag associations only (unchanged).
302    async fn upsert_batch(
303        &self,
304        pool: &PgPool,
305        traces: &[(TraceId, TraceAggregator)],
306    ) -> Result<(), SqlError> {
307        // ── Delta Lake: write summary records ────────────────────────────────
308        if let Some(summary_service) = get_trace_summary_service() {
309            let records: Vec<TraceSummaryRecord> = traces
310                .iter()
311                .map(|(_, agg)| agg.to_summary_record())
312                .collect();
313            if let Err(e) = summary_service.write_summaries(records).await {
314                error!("Failed to write trace summaries to Delta Lake: {:?}", e);
315            }
316        }
317
318        // ── Postgres: entity tag associations only ────────────────────────────
319        let mut entity_trace_ids = Vec::new();
320        let mut entity_uids = Vec::new();
321        let mut entity_tagged_ats = Vec::new();
322        let now = Utc::now();
323
324        for (trace_id, agg) in traces {
325            for entity_uid in &agg.entity_tags {
326                entity_trace_ids.push(trace_id.as_bytes());
327                entity_uids.push(entity_uid.as_bytes());
328                entity_tagged_ats.push(now);
329            }
330        }
331
332        if !entity_trace_ids.is_empty() {
333            sqlx::query(Queries::InsertTraceEntityTags.get_query())
334                .bind(&entity_trace_ids)
335                .bind(&entity_uids)
336                .bind(&entity_tagged_ats)
337                .execute(pool)
338                .await?;
339        }
340
341        Ok(())
342    }
343}
344
345/// Initialize the TraceCache, replacing any previous instance.
346/// The old background flush task is signaled to stop and any remaining
347/// traces are flushed with the NEW pool before the cache is swapped.
348pub async fn init_trace_cache(
349    pool: PgPool,
350    flush_interval: Duration,
351    stale_threshold: Duration,
352    max_traces: usize,
353) -> Result<(), SqlError> {
354    // Shut down any existing cache first
355    let old_cache = {
356        let guard = TRACE_CACHE.read().await;
357        guard.as_ref().map(|handle| {
358            handle.shutdown_flag.store(true, Ordering::SeqCst);
359            handle.cache.clone()
360        })
361    };
362
363    // Flush outside so we dont hold the lock
364    if let Some(cache) = old_cache {
365        info!("Flushing previous TraceCache before re-initialization...");
366        if let Err(e) = cache.flush_traces(&pool, Duration::seconds(-1)).await {
367            error!(error = %e, "Failed to flush previous TraceCache");
368        }
369    }
370
371    let cache = Arc::new(TraceCache::new(max_traces));
372    let shutdown_flag = Arc::new(AtomicBool::new(false));
373
374    {
375        let mut guard = TRACE_CACHE.write().await;
376        *guard = Some(TraceCacheHandle {
377            cache: cache.clone(),
378            shutdown_flag: shutdown_flag.clone(),
379        });
380    }
381
382    let flush_std_duration = StdDuration::from_secs(flush_interval.num_seconds() as u64);
383    let task_shutdown = shutdown_flag.clone();
384
385    tokio::spawn(async move {
386        let mut ticker = interval(flush_std_duration);
387        loop {
388            ticker.tick().await;
389
390            if task_shutdown.load(Ordering::SeqCst) {
391                info!("TraceCache background flush task shutting down");
392                break;
393            }
394
395            let current_traces = cache.traces.len();
396            let current_spans = cache.total_span_count.load(Ordering::Relaxed);
397
398            let threshold = if current_traces > max_traces || current_spans > MAX_TOTAL_SPANS {
399                warn!(
400                    current_traces,
401                    current_spans, "Emergency flush triggered due to memory pressure"
402                );
403                Duration::seconds(0)
404            } else {
405                stale_threshold
406            };
407
408            if let Err(e) = cache.flush_traces(&pool, threshold).await {
409                error!(error = %e, "Flush task failed");
410            }
411        }
412    });
413
414    info!("TraceCache initialized");
415    Ok(())
416}
417
418/// Get access to the current TraceCache
419pub async fn get_trace_cache() -> Arc<TraceCache> {
420    TRACE_CACHE
421        .read()
422        .await
423        .as_ref()
424        .expect("TraceCache not initialized")
425        .cache
426        .clone()
427}
428
429/// Flush all remaining traces during shutdown
430pub async fn shutdown_trace_cache(pool: &PgPool) -> Result<usize, SqlError> {
431    let cache_to_flush = {
432        let guard = TRACE_CACHE.read().await;
433        guard.as_ref().map(|handle| {
434            handle.shutdown_flag.store(true, Ordering::SeqCst);
435            handle.cache.clone()
436        })
437    };
438
439    if let Some(cache) = cache_to_flush {
440        info!("Flushing TraceCache for shutdown...");
441        cache.flush_traces(pool, Duration::seconds(-1)).await
442    } else {
443        Ok(0)
444    }
445}