1use crate::sql::error::SqlError;
2use crate::sql::query::Queries;
3use crate::sql::utils::EntityBytea;
4use chrono::{DateTime, Duration, Timelike, Utc};
5use dashmap::DashMap;
6use scouter_types::{Attribute, TraceId, TraceSpanRecord, SCOUTER_ENTITY};
7use sqlx::PgPool;
8use std::collections::HashSet;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use tokio::time::{interval, Duration as StdDuration};
13use tracing::{error, info, warn};
/// Maximum number of trace aggregates written per upsert statement; bounds the
/// number of bind parameters sent to Postgres in a single query.
const TRACE_BATCH_SIZE: usize = 1000;

/// Upper bound on the total spans buffered across all cached traces before the
/// background task triggers an emergency (flush-everything) pass.
const MAX_TOTAL_SPANS: u64 = 1_000_000;
17
/// Pairs the shared cache with the shutdown flag of its background flush task,
/// so re-initialization and shutdown can stop the task and flush what remains.
struct TraceCacheHandle {
    cache: Arc<TraceCache>,
    shutdown_flag: Arc<AtomicBool>,
}

/// Process-wide singleton installed by `init_trace_cache`; `None` until then.
static TRACE_CACHE: RwLock<Option<TraceCacheHandle>> = RwLock::const_new(None);
25
/// In-memory rollup of every span observed so far for a single trace.
#[derive(Debug, Clone)]
pub struct TraceAggregator {
    pub trace_id: TraceId,
    // Service/scope metadata; overwritten when the root span arrives.
    pub service_name: String,
    pub scope_name: String,
    pub scope_version: String,
    // Name of the root span; empty until a span with no parent is seen.
    pub root_operation: String,
    // Earliest start / latest end across all spans seen for this trace.
    pub start_time: DateTime<Utc>,
    pub end_time: Option<DateTime<Utc>>,
    // Sticky error status: set to 2 once any span reports status_code == 2.
    pub status_code: i32,
    pub status_message: String,
    pub span_count: i64,
    pub error_count: i64,
    pub resource_attributes: Vec<Attribute>,
    // Wall-clock bookkeeping used for persistence and staleness eviction.
    pub first_seen: DateTime<Utc>,
    pub last_updated: DateTime<Utc>,
    // Entity UIDs parsed from SCOUTER_ENTITY-prefixed span attributes.
    pub entity_tags: HashSet<EntityBytea>,
}
44
45impl TraceAggregator {
46 pub fn bucket_time(&self) -> DateTime<Utc> {
47 self.start_time
48 .with_minute(0)
49 .unwrap()
50 .with_second(0)
51 .unwrap()
52 .with_nanosecond(0)
53 .unwrap()
54 }
55
56 pub fn add_entities(&mut self, span: &TraceSpanRecord) {
60 for attr in &span.attributes {
61 if attr.key.starts_with(SCOUTER_ENTITY) {
62 let entity = match &attr.value {
64 serde_json::Value::String(s) => s.clone(),
65 _ => continue, };
67 match EntityBytea::from_uuid(&entity) {
68 Ok(uid) => {
69 self.entity_tags.insert(uid);
70 }
71 Err(e) => {
72 warn!(%entity, "Failed to parse entity UID from attribute value in span {}: {}", span.span_id, e);
73 }
74 }
75 }
76 }
77 }
78
79 pub fn new_from_span(span: &TraceSpanRecord) -> Self {
80 let now = Utc::now();
81 let mut aggregator = Self {
82 trace_id: span.trace_id.clone(),
83 service_name: span.service_name.clone(),
84 scope_name: span.scope_name.clone(),
85 scope_version: span.scope_version.clone().unwrap_or_default(),
86 root_operation: if span.parent_span_id.is_none() {
87 span.span_name.clone()
88 } else {
89 String::new()
90 },
91 start_time: span.start_time,
92 end_time: Some(span.end_time),
93 status_code: span.status_code,
94 status_message: span.status_message.clone(),
95 span_count: 1,
96 error_count: if span.status_code == 2 { 1 } else { 0 },
97 resource_attributes: span.resource_attributes.clone(),
98 first_seen: now,
99 last_updated: now,
100 entity_tags: HashSet::new(),
101 };
102 aggregator.add_entities(span);
103 aggregator
104 }
105
106 pub fn update_from_span(&mut self, span: &TraceSpanRecord) {
107 if span.start_time < self.start_time {
108 self.start_time = span.start_time;
109 }
110 if let Some(current_end) = self.end_time {
111 if span.end_time > current_end {
112 self.end_time = Some(span.end_time);
113 }
114 } else {
115 self.end_time = Some(span.end_time);
116 }
117
118 if span.parent_span_id.is_none() {
119 self.root_operation = span.span_name.clone();
120 self.service_name = span.service_name.clone();
121 self.scope_name = span.scope_name.clone();
122 if let Some(version) = &span.scope_version {
123 self.scope_version = version.clone();
124 }
125 self.resource_attributes = span.resource_attributes.clone();
126 }
127
128 if span.status_code == 2 {
129 self.error_count += 1;
130 self.status_code = 2;
131 self.status_message = span.status_message.clone();
132 }
133
134 self.span_count += 1;
135 self.last_updated = Utc::now();
136 self.add_entities(span);
137 }
138
139 pub fn duration_ms(&self) -> Option<i64> {
140 self.end_time
141 .map(|end| (end - self.start_time).num_milliseconds())
142 }
143
144 pub fn is_stale(&self, stale_duration: Duration) -> bool {
145 (Utc::now() - self.last_updated) >= stale_duration
146 }
147}
148
/// Concurrent buffer of in-flight trace aggregates, drained to Postgres by a
/// background task once traces go stale.
pub struct TraceCache {
    traces: DashMap<TraceId, TraceAggregator>,
    // Soft capacity; exceeding it triggers an emergency flush in the task loop.
    max_traces: usize,
    // Running total of spans buffered across all aggregates.
    total_span_count: AtomicU64,
}
154
impl TraceCache {
    /// Creates an empty cache with a soft limit of `max_traces` aggregates.
    fn new(max_traces: usize) -> Self {
        Self {
            traces: DashMap::new(),
            max_traces,
            total_span_count: AtomicU64::new(0),
        }
    }

    /// Merges `span` into its trace's aggregate (creating one if absent) and
    /// warns when the cache nears either the trace or the span limit.
    pub async fn update_trace(self: &Arc<Self>, span: &TraceSpanRecord) {
        let current_traces = self.traces.len();
        let current_spans = self.total_span_count.load(Ordering::Relaxed);

        // Utilization as a percentage of whichever limit is closer.
        // NOTE(review): divides by `self.max_traces`; a value of 0 would
        // panic — confirm the cache is never initialized with max_traces == 0.
        let trace_pressure = (current_traces * 100) / self.max_traces;
        let span_pressure = (current_spans * 100) / MAX_TOTAL_SPANS;
        let max_pressure = trace_pressure.max(span_pressure as usize);

        if max_pressure >= 90 {
            // Warning only; the background task performs the actual flush.
            warn!(
                current_traces,
                current_spans,
                max_pressure,
                "TraceCache high memory pressure, will flush on next interval"
            );
        }
        // Exactly one of the two closures runs per call, so total_span_count
        // grows by one while the entry's shard lock is held.
        self.traces
            .entry(span.trace_id.clone())
            .and_modify(|agg| {
                agg.update_from_span(span);
                self.total_span_count.fetch_add(1, Ordering::Relaxed);
            })
            .or_insert_with(|| {
                self.total_span_count.fetch_add(1, Ordering::Relaxed);
                TraceAggregator::new_from_span(span)
            });
    }

    /// Removes every aggregate idle for at least `stale_threshold` and writes
    /// them to Postgres in batches of `TRACE_BATCH_SIZE` traces. A negative
    /// threshold therefore flushes everything. Returns the flushed count.
    ///
    /// NOTE(review): aggregates are removed from the map before the database
    /// write; if `upsert_batch` fails, the already-removed traces are dropped.
    /// Confirm this best-effort trade-off is intended.
    pub async fn flush_traces(
        &self,
        pool: &PgPool,
        stale_threshold: Duration,
    ) -> Result<usize, SqlError> {
        // Snapshot stale keys first; removal happens in a second pass so we
        // never mutate the map while iterating it.
        let stale_keys: Vec<TraceId> = self
            .traces
            .iter()
            .filter(|entry| entry.value().is_stale(stale_threshold))
            .map(|entry| entry.key().clone())
            .collect();

        if stale_keys.is_empty() {
            return Ok(0);
        }

        let mut to_flush = Vec::with_capacity(stale_keys.len());
        let mut spans_freed = 0u64;

        for id in stale_keys {
            // remove() can return None if a concurrent flush raced us.
            if let Some((_, agg)) = self.traces.remove(&id) {
                spans_freed += agg.span_count as u64;
                to_flush.push((id, agg));
            }
        }

        self.total_span_count
            .fetch_sub(spans_freed, Ordering::Relaxed);

        let count = to_flush.len();
        info!(
            flushed_traces = count,
            freed_spans = spans_freed,
            remaining_traces = self.traces.len(),
            remaining_spans = self.total_span_count.load(Ordering::Relaxed),
            "Flushed stale traces"
        );

        // Chunk to bound the number of bind parameters per statement.
        for chunk in to_flush.chunks(TRACE_BATCH_SIZE) {
            self.upsert_batch(pool, chunk).await?;
        }
        Ok(count)
    }

    /// Writes one batch of trace aggregates (plus their entity tags) using
    /// parallel per-column Vecs bound as arrays.
    ///
    /// The `.bind(...)` order below must exactly match the parameter order in
    /// `Queries::UpsertTrace` / `Queries::InsertTraceEntityTags`.
    async fn upsert_batch(
        &self,
        pool: &PgPool,
        traces: &[(TraceId, TraceAggregator)],
    ) -> Result<(), SqlError> {
        let mut created_ats = Vec::with_capacity(traces.len());
        let mut bucket_times = Vec::with_capacity(traces.len());
        let mut trace_ids = Vec::with_capacity(traces.len());
        let mut service_names = Vec::with_capacity(traces.len());
        let mut scope_names = Vec::with_capacity(traces.len());
        let mut scope_versions = Vec::with_capacity(traces.len());
        let mut root_operations = Vec::with_capacity(traces.len());
        let mut start_times = Vec::with_capacity(traces.len());
        let mut end_times = Vec::with_capacity(traces.len());
        let mut durations_ms = Vec::with_capacity(traces.len());
        let mut status_codes = Vec::with_capacity(traces.len());
        let mut status_messages = Vec::with_capacity(traces.len());
        let mut span_counts = Vec::with_capacity(traces.len());
        let mut error_counts = Vec::with_capacity(traces.len());
        let mut resource_attrs = Vec::with_capacity(traces.len());

        // Entity-tag rows fan out per (trace, entity) pair; size is unknown
        // up front, so these start empty.
        let mut entity_trace_ids = Vec::new();
        let mut entity_uids = Vec::new();
        let mut entity_tagged_ats = Vec::new();
        let now = Utc::now();

        for (trace_id, agg) in traces {
            created_ats.push(agg.first_seen);
            bucket_times.push(agg.bucket_time());
            trace_ids.push(trace_id.as_bytes());
            service_names.push(&agg.service_name);
            scope_names.push(&agg.scope_name);
            scope_versions.push(&agg.scope_version);
            root_operations.push(&agg.root_operation);
            start_times.push(agg.start_time);
            // Traces with no recorded end are stored as zero-duration.
            end_times.push(agg.end_time.unwrap_or(agg.start_time));
            durations_ms.push(agg.duration_ms().unwrap_or(0));
            status_codes.push(agg.status_code);
            status_messages.push(&agg.status_message);
            span_counts.push(agg.span_count);
            error_counts.push(agg.error_count);
            resource_attrs.push(serde_json::to_value(&agg.resource_attributes)?);

            for entity_uid in &agg.entity_tags {
                entity_trace_ids.push(trace_id.as_bytes());
                entity_uids.push(entity_uid.as_bytes());
                entity_tagged_ats.push(now);
            }
        }

        sqlx::query(Queries::UpsertTrace.get_query())
            .bind(&bucket_times)
            .bind(&created_ats)
            .bind(&trace_ids)
            .bind(&service_names as &[&String])
            .bind(&scope_names as &[&String])
            .bind(&scope_versions as &[&String])
            .bind(&root_operations as &[&String])
            .bind(&start_times)
            .bind(&end_times)
            .bind(&durations_ms)
            .bind(&status_codes)
            .bind(&status_messages as &[&String])
            .bind(&span_counts)
            .bind(&error_counts)
            .bind(&resource_attrs)
            .execute(pool)
            .await?;

        // Only issue the second statement when there are tags to insert.
        if !entity_trace_ids.is_empty() {
            sqlx::query(Queries::InsertTraceEntityTags.get_query())
                .bind(&entity_trace_ids)
                .bind(&entity_uids)
                .bind(&entity_tagged_ats)
                .execute(pool)
                .await?;
        }

        Ok(())
    }
}
325
326pub async fn init_trace_cache(
330 pool: PgPool,
331 flush_interval: Duration,
332 stale_threshold: Duration,
333 max_traces: usize,
334) -> Result<(), SqlError> {
335 let old_cache = {
337 let guard = TRACE_CACHE.read().await;
338 guard.as_ref().map(|handle| {
339 handle.shutdown_flag.store(true, Ordering::SeqCst);
340 handle.cache.clone()
341 })
342 };
343
344 if let Some(cache) = old_cache {
346 info!("Flushing previous TraceCache before re-initialization...");
347 if let Err(e) = cache.flush_traces(&pool, Duration::seconds(-1)).await {
348 error!(error = %e, "Failed to flush previous TraceCache");
349 }
350 }
351
352 let cache = Arc::new(TraceCache::new(max_traces));
353 let shutdown_flag = Arc::new(AtomicBool::new(false));
354
355 {
356 let mut guard = TRACE_CACHE.write().await;
357 *guard = Some(TraceCacheHandle {
358 cache: cache.clone(),
359 shutdown_flag: shutdown_flag.clone(),
360 });
361 }
362
363 let flush_std_duration = StdDuration::from_secs(flush_interval.num_seconds() as u64);
364 let task_shutdown = shutdown_flag.clone();
365
366 tokio::spawn(async move {
367 let mut ticker = interval(flush_std_duration);
368 loop {
369 ticker.tick().await;
370
371 if task_shutdown.load(Ordering::SeqCst) {
372 info!("TraceCache background flush task shutting down");
373 break;
374 }
375
376 let current_traces = cache.traces.len();
377 let current_spans = cache.total_span_count.load(Ordering::Relaxed);
378
379 let threshold = if current_traces > max_traces || current_spans > MAX_TOTAL_SPANS {
380 warn!(
381 current_traces,
382 current_spans, "Emergency flush triggered due to memory pressure"
383 );
384 Duration::seconds(0)
385 } else {
386 stale_threshold
387 };
388
389 if let Err(e) = cache.flush_traces(&pool, threshold).await {
390 error!(error = %e, "Flush task failed");
391 }
392 }
393 });
394
395 info!("TraceCache initialized");
396 Ok(())
397}
398
399pub async fn get_trace_cache() -> Arc<TraceCache> {
401 TRACE_CACHE
402 .read()
403 .await
404 .as_ref()
405 .expect("TraceCache not initialized")
406 .cache
407 .clone()
408}
409
410pub async fn shutdown_trace_cache(pool: &PgPool) -> Result<usize, SqlError> {
412 let cache_to_flush = {
413 let guard = TRACE_CACHE.read().await;
414 guard.as_ref().map(|handle| {
415 handle.shutdown_flag.store(true, Ordering::SeqCst);
416 handle.cache.clone()
417 })
418 };
419
420 if let Some(cache) = cache_to_flush {
421 info!("Flushing TraceCache for shutdown...");
422 cache.flush_traces(pool, Duration::seconds(-1)).await
423 } else {
424 Ok(0)
425 }
426}