1use crate::sql::error::SqlError;
2use crate::sql::query::Queries;
3use crate::sql::utils::UuidBytea;
4use chrono::{DateTime, Duration, Utc};
5use dashmap::DashMap;
6use scouter_dataframe::parquet::tracing::summary::TraceSummaryService;
7use scouter_types::{
8 Attribute, TraceId, TraceSpanRecord, TraceSummaryRecord, SCOUTER_ENTITY, SCOUTER_QUEUE_RECORD,
9};
10use sqlx::PgPool;
11use std::collections::HashSet;
12use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
13use std::sync::Arc;
14use tokio::sync::RwLock;
15use tokio::time::{interval, Duration as StdDuration};
16use tracing::{error, info, warn};
/// Maximum number of trace aggregates upserted per database round-trip in `flush_traces`.
const TRACE_BATCH_SIZE: usize = 1000;
18
/// Process-wide singleton for the Delta Lake summary writer.
/// `None` until `init_trace_summary_service` is called; summary writes are
/// skipped while unset (see `upsert_batch`).
static TRACE_SUMMARY_SERVICE: std::sync::RwLock<Option<Arc<TraceSummaryService>>> =
    std::sync::RwLock::new(None);
23
24pub fn init_trace_summary_service(service: Arc<TraceSummaryService>) -> Result<(), SqlError> {
26 let mut guard = TRACE_SUMMARY_SERVICE
27 .write()
28 .map_err(|e| SqlError::TraceCacheError(format!("Failed to acquire write lock: {}", e)))?;
29 *guard = Some(service);
30 info!("TraceSummaryService global singleton registered in aggregator");
31 Ok(())
32}
33
34pub fn get_trace_summary_service() -> Option<Arc<TraceSummaryService>> {
36 TRACE_SUMMARY_SERVICE.read().ok()?.clone()
37}
38
/// Soft cap on the total number of spans held across all cached traces;
/// exceeding it triggers an emergency flush in the background task.
const MAX_TOTAL_SPANS: u64 = 1_000_000;
40
/// Pairs the shared cache with the shutdown flag of its background flush task.
struct TraceCacheHandle {
    cache: Arc<TraceCache>,
    // Setting this to true makes the spawned flush loop exit on its next tick.
    shutdown_flag: Arc<AtomicBool>,
}

/// Process-wide singleton installed by `init_trace_cache`; `None` until then.
static TRACE_CACHE: RwLock<Option<TraceCacheHandle>> = RwLock::const_new(None);
48
/// In-memory aggregate of all spans observed so far for a single trace.
#[derive(Debug, Clone)]
pub struct TraceAggregator {
    /// Identifier shared by every span of the trace.
    pub trace_id: TraceId,
    /// Service/scope metadata; overwritten by the root span when it arrives.
    pub service_name: String,
    pub scope_name: String,
    pub scope_version: String,
    /// Name of the root (parentless) span; empty until the root span is seen.
    pub root_operation: String,
    /// Earliest span start time observed.
    pub start_time: DateTime<Utc>,
    /// Latest span end time observed, if any.
    pub end_time: Option<DateTime<Utc>>,
    /// Status code 2 appears to denote an error — sticky once set (see `update_from_span`).
    pub status_code: i32,
    pub status_message: String,
    /// Number of spans folded into this aggregate.
    pub span_count: i64,
    /// Number of spans with error status.
    pub error_count: i64,
    pub resource_attributes: Vec<Attribute>,
    /// When this aggregator was created (wall clock, not span time).
    pub first_seen: DateTime<Utc>,
    /// Last time a span was folded in; drives staleness-based flushing.
    pub last_updated: DateTime<Utc>,
    /// Scouter entity UUIDs harvested from span/event attributes.
    pub entity_tags: HashSet<UuidBytea>,
    /// Scouter queue-record UUIDs harvested from span/event attributes.
    pub queue_tags: HashSet<UuidBytea>,
}
68
69fn extract_value_to_set(attr: &Attribute, set: &mut HashSet<UuidBytea>) -> Option<UuidBytea> {
70 if let serde_json::Value::String(s) = &attr.value {
71 match UuidBytea::from_uuid(s) {
72 Ok(uid) => {
73 set.insert(uid.clone());
74 return Some(uid);
75 }
76 Err(e) => {
77 warn!(%s, "Failed to parse value as UUID for attribute key '{}': {}", attr.key, e)
78 }
79 }
80 }
81 None
82}
83
84impl TraceAggregator {
85 pub fn add_ids(&mut self, span: &TraceSpanRecord) {
87 for event in &span.events {
88 for attr in &event.attributes {
89 if attr.key == SCOUTER_QUEUE_RECORD {
90 extract_value_to_set(attr, &mut self.queue_tags);
91 }
92 if attr.key.starts_with(SCOUTER_ENTITY) {
93 extract_value_to_set(attr, &mut self.entity_tags);
94 }
95 }
96 }
97 for attr in &span.attributes {
98 if attr.key == SCOUTER_QUEUE_RECORD {
99 extract_value_to_set(attr, &mut self.queue_tags);
100 }
101 if attr.key.starts_with(SCOUTER_ENTITY) {
102 extract_value_to_set(attr, &mut self.entity_tags);
103 }
104 }
105 }
106
107 pub fn new_from_span(span: &TraceSpanRecord) -> Self {
108 let now = Utc::now();
109 let mut aggregator = Self {
110 trace_id: span.trace_id.clone(),
111 service_name: span.service_name.clone(),
112 scope_name: span.scope_name.clone(),
113 scope_version: span.scope_version.clone().unwrap_or_default(),
114 root_operation: if span.parent_span_id.is_none() {
115 span.span_name.clone()
116 } else {
117 String::new()
118 },
119 start_time: span.start_time,
120 end_time: Some(span.end_time),
121 status_code: span.status_code,
122 status_message: span.status_message.clone(),
123 span_count: 1,
124 error_count: if span.status_code == 2 { 1 } else { 0 },
125 resource_attributes: span.resource_attributes.clone(),
126 first_seen: now,
127 last_updated: now,
128 entity_tags: HashSet::new(),
129 queue_tags: HashSet::new(),
130 };
131 aggregator.add_ids(span);
132 aggregator
133 }
134
135 pub fn update_from_span(&mut self, span: &TraceSpanRecord) {
136 if span.start_time < self.start_time {
137 self.start_time = span.start_time;
138 }
139 if let Some(current_end) = self.end_time {
140 if span.end_time > current_end {
141 self.end_time = Some(span.end_time);
142 }
143 } else {
144 self.end_time = Some(span.end_time);
145 }
146
147 if span.parent_span_id.is_none() {
148 self.root_operation = span.span_name.clone();
149 self.service_name = span.service_name.clone();
150 self.scope_name = span.scope_name.clone();
151 if let Some(version) = &span.scope_version {
152 self.scope_version = version.clone();
153 }
154 self.resource_attributes = span.resource_attributes.clone();
155 }
156
157 if span.status_code == 2 {
158 self.error_count += 1;
159 self.status_code = 2;
160 self.status_message = span.status_message.clone();
161 }
162
163 self.span_count += 1;
164 self.last_updated = Utc::now();
165 self.add_ids(span);
166 }
167
168 pub fn duration_ms(&self) -> Option<i64> {
169 self.end_time
170 .map(|end| (end - self.start_time).num_milliseconds())
171 }
172
173 pub fn is_stale(&self, stale_duration: Duration) -> bool {
174 (Utc::now() - self.last_updated) >= stale_duration
175 }
176
177 pub fn to_summary_record(&self) -> TraceSummaryRecord {
179 let entity_ids: Vec<String> = self
180 .entity_tags
181 .iter()
182 .map(|e| uuid::Uuid::from_bytes(e.0).to_string())
183 .collect();
184 let queue_ids: Vec<String> = self
185 .queue_tags
186 .iter()
187 .map(|q| uuid::Uuid::from_bytes(q.0).to_string())
188 .collect();
189 TraceSummaryRecord {
190 trace_id: self.trace_id.clone(),
191 service_name: self.service_name.clone(),
192 scope_name: self.scope_name.clone(),
193 scope_version: self.scope_version.clone(),
194 root_operation: self.root_operation.clone(),
195 start_time: self.start_time,
196 end_time: self.end_time,
197 status_code: self.status_code,
198 status_message: self.status_message.clone(),
199 span_count: self.span_count,
200 error_count: self.error_count,
201 resource_attributes: self.resource_attributes.clone(),
202 entity_ids,
203 queue_ids,
204 }
205 }
206}
207
/// Concurrent cache of in-flight trace aggregates, keyed by trace id.
pub struct TraceCache {
    traces: DashMap<TraceId, TraceAggregator>,
    // Soft limit on cached traces; used for pressure warnings and
    // emergency flushes, not enforced on insert.
    max_traces: usize,
    // Running total of spans across all cached aggregates.
    total_span_count: AtomicU64,
}
213
214impl TraceCache {
215 fn new(max_traces: usize) -> Self {
216 Self {
217 traces: DashMap::new(),
218 max_traces,
219 total_span_count: AtomicU64::new(0),
220 }
221 }
222
223 pub async fn update_trace(self: &Arc<Self>, span: &TraceSpanRecord) {
225 let current_traces = self.traces.len();
226 let current_spans = self.total_span_count.load(Ordering::Relaxed);
227
228 let trace_pressure = (current_traces * 100) / self.max_traces;
230 let span_pressure = (current_spans * 100) / MAX_TOTAL_SPANS;
231 let max_pressure = trace_pressure.max(span_pressure as usize);
232
233 if max_pressure >= 90 {
235 warn!(
236 current_traces,
237 current_spans,
238 max_pressure,
239 "TraceCache high memory pressure, will flush on next interval"
240 );
241 }
242 self.traces
243 .entry(span.trace_id.clone())
244 .and_modify(|agg| {
245 agg.update_from_span(span);
246 self.total_span_count.fetch_add(1, Ordering::Relaxed);
247 })
248 .or_insert_with(|| {
249 self.total_span_count.fetch_add(1, Ordering::Relaxed);
250 TraceAggregator::new_from_span(span)
251 });
252 }
253
254 pub async fn flush_traces(
255 &self,
256 pool: &PgPool,
257 stale_threshold: Duration,
258 ) -> Result<usize, SqlError> {
259 let stale_keys: Vec<TraceId> = self
260 .traces
261 .iter()
262 .filter(|entry| entry.value().is_stale(stale_threshold))
263 .map(|entry| entry.key().clone())
264 .collect();
265
266 if stale_keys.is_empty() {
267 return Ok(0);
268 }
269
270 let mut to_flush = Vec::with_capacity(stale_keys.len());
271 let mut spans_freed = 0u64;
272
273 for id in stale_keys {
274 if let Some((_, agg)) = self.traces.remove(&id) {
275 spans_freed += agg.span_count as u64;
276 to_flush.push((id, agg));
277 }
278 }
279
280 self.total_span_count
281 .fetch_sub(spans_freed, Ordering::Relaxed);
282
283 let count = to_flush.len();
284 info!(
285 flushed_traces = count,
286 freed_spans = spans_freed,
287 remaining_traces = self.traces.len(),
288 remaining_spans = self.total_span_count.load(Ordering::Relaxed),
289 "Flushed stale traces"
290 );
291
292 for chunk in to_flush.chunks(TRACE_BATCH_SIZE) {
293 self.upsert_batch(pool, chunk).await?;
294 }
295 Ok(count)
296 }
297
298 async fn upsert_batch(
303 &self,
304 pool: &PgPool,
305 traces: &[(TraceId, TraceAggregator)],
306 ) -> Result<(), SqlError> {
307 if let Some(summary_service) = get_trace_summary_service() {
309 let records: Vec<TraceSummaryRecord> = traces
310 .iter()
311 .map(|(_, agg)| agg.to_summary_record())
312 .collect();
313 if let Err(e) = summary_service.write_summaries(records).await {
314 error!("Failed to write trace summaries to Delta Lake: {:?}", e);
315 }
316 }
317
318 let mut entity_trace_ids = Vec::new();
320 let mut entity_uids = Vec::new();
321 let mut entity_tagged_ats = Vec::new();
322 let now = Utc::now();
323
324 for (trace_id, agg) in traces {
325 for entity_uid in &agg.entity_tags {
326 entity_trace_ids.push(trace_id.as_bytes());
327 entity_uids.push(entity_uid.as_bytes());
328 entity_tagged_ats.push(now);
329 }
330 }
331
332 if !entity_trace_ids.is_empty() {
333 sqlx::query(Queries::InsertTraceEntityTags.get_query())
334 .bind(&entity_trace_ids)
335 .bind(&entity_uids)
336 .bind(&entity_tagged_ats)
337 .execute(pool)
338 .await?;
339 }
340
341 Ok(())
342 }
343}
344
345pub async fn init_trace_cache(
349 pool: PgPool,
350 flush_interval: Duration,
351 stale_threshold: Duration,
352 max_traces: usize,
353) -> Result<(), SqlError> {
354 let old_cache = {
356 let guard = TRACE_CACHE.read().await;
357 guard.as_ref().map(|handle| {
358 handle.shutdown_flag.store(true, Ordering::SeqCst);
359 handle.cache.clone()
360 })
361 };
362
363 if let Some(cache) = old_cache {
365 info!("Flushing previous TraceCache before re-initialization...");
366 if let Err(e) = cache.flush_traces(&pool, Duration::seconds(-1)).await {
367 error!(error = %e, "Failed to flush previous TraceCache");
368 }
369 }
370
371 let cache = Arc::new(TraceCache::new(max_traces));
372 let shutdown_flag = Arc::new(AtomicBool::new(false));
373
374 {
375 let mut guard = TRACE_CACHE.write().await;
376 *guard = Some(TraceCacheHandle {
377 cache: cache.clone(),
378 shutdown_flag: shutdown_flag.clone(),
379 });
380 }
381
382 let flush_std_duration = StdDuration::from_secs(flush_interval.num_seconds() as u64);
383 let task_shutdown = shutdown_flag.clone();
384
385 tokio::spawn(async move {
386 let mut ticker = interval(flush_std_duration);
387 loop {
388 ticker.tick().await;
389
390 if task_shutdown.load(Ordering::SeqCst) {
391 info!("TraceCache background flush task shutting down");
392 break;
393 }
394
395 let current_traces = cache.traces.len();
396 let current_spans = cache.total_span_count.load(Ordering::Relaxed);
397
398 let threshold = if current_traces > max_traces || current_spans > MAX_TOTAL_SPANS {
399 warn!(
400 current_traces,
401 current_spans, "Emergency flush triggered due to memory pressure"
402 );
403 Duration::seconds(0)
404 } else {
405 stale_threshold
406 };
407
408 if let Err(e) = cache.flush_traces(&pool, threshold).await {
409 error!(error = %e, "Flush task failed");
410 }
411 }
412 });
413
414 info!("TraceCache initialized");
415 Ok(())
416}
417
418pub async fn get_trace_cache() -> Arc<TraceCache> {
420 TRACE_CACHE
421 .read()
422 .await
423 .as_ref()
424 .expect("TraceCache not initialized")
425 .cache
426 .clone()
427}
428
429pub async fn shutdown_trace_cache(pool: &PgPool) -> Result<usize, SqlError> {
431 let cache_to_flush = {
432 let guard = TRACE_CACHE.read().await;
433 guard.as_ref().map(|handle| {
434 handle.shutdown_flag.store(true, Ordering::SeqCst);
435 handle.cache.clone()
436 })
437 };
438
439 if let Some(cache) = cache_to_flush {
440 info!("Flushing TraceCache for shutdown...");
441 cache.flush_traces(pool, Duration::seconds(-1)).await
442 } else {
443 Ok(0)
444 }
445}