//! Trace persistence and retrieval SQL logic (`scouter_sql/sql/traits/trace.rs`).
//!
//! Provides batch inserts for trace spans and baggage, paginated trace listing,
//! span/metric retrieval, and materialized-view maintenance.
1use crate::sql::error::SqlError;
2use crate::sql::query::Queries;
3
4use crate::sql::aggregator::get_trace_cache;
5use crate::sql::utils::EntityBytea;
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use itertools::multiunzip;
9use scouter_types::sql::{TraceFilters, TraceListItem, TraceMetricBucket, TraceSpan};
10use scouter_types::{
11    TraceBaggageRecord, TraceCursor, TraceId, TracePaginationResponse, TraceSpanRecord,
12};
13use sqlx::{postgres::PgQueryResult, types::Json, Pool, Postgres};
14use std::collections::HashMap;
15use tracing::{error, instrument};
16#[async_trait]
17pub trait TraceSqlLogic {
    /// Attempts to insert multiple trace span records into the database in a batch.
    ///
    /// Each span is first fed to the shared trace cache (`get_trace_cache`) so
    /// in-memory trace aggregates stay in sync with what is written to the database.
    ///
    /// # Arguments
    /// * `pool` - The database connection pool
    /// * `spans` - The trace span records to insert
    ///
    /// # Errors
    /// Returns `SqlError` if the batch insert fails; the error is logged before
    /// being propagated.
    async fn insert_span_batch(
        pool: &Pool<Postgres>,
        spans: &[TraceSpanRecord],
    ) -> Result<PgQueryResult, SqlError> {
        let query = Queries::InsertTraceSpan.get_query();
        let capacity = spans.len();

        // we are pre-allocating here instead of using multiunzip because multiunzip has
        // a max limit of 12 tuples (we bind 23 columns below)
        let mut created_at = Vec::with_capacity(capacity);
        let mut span_id = Vec::with_capacity(capacity);
        let mut trace_id = Vec::with_capacity(capacity);
        let mut parent_span_id = Vec::with_capacity(capacity);
        let mut flags = Vec::with_capacity(capacity);
        let mut trace_state = Vec::with_capacity(capacity);
        let mut scope_name = Vec::with_capacity(capacity);
        let mut scope_version = Vec::with_capacity(capacity);
        let mut span_name = Vec::with_capacity(capacity);
        let mut span_kind = Vec::with_capacity(capacity);
        let mut start_time = Vec::with_capacity(capacity);
        let mut end_time = Vec::with_capacity(capacity);
        let mut duration_ms = Vec::with_capacity(capacity);
        let mut status_code = Vec::with_capacity(capacity);
        let mut status_message = Vec::with_capacity(capacity);
        let mut attributes = Vec::with_capacity(capacity);
        let mut events = Vec::with_capacity(capacity);
        let mut links = Vec::with_capacity(capacity);
        let mut labels = Vec::with_capacity(capacity);
        let mut input = Vec::with_capacity(capacity);
        let mut output = Vec::with_capacity(capacity);
        let mut service_name = Vec::with_capacity(capacity);
        let mut resource_attributes = Vec::with_capacity(capacity);

        // Single iteration for maximum efficiency
        for span in spans {
            // add to trace cache
            get_trace_cache().await.update_trace(span).await;

            // String-like fields are borrowed (`as_str`/`as_deref`) to avoid
            // per-span allocations; JSON payloads must be cloned into `Json`.
            created_at.push(span.created_at);
            span_id.push(span.span_id.as_bytes());
            trace_id.push(span.trace_id.as_bytes());
            parent_span_id.push(span.parent_span_id.as_ref().map(|id| id.as_bytes()));
            flags.push(span.flags);
            trace_state.push(span.trace_state.as_str());
            scope_name.push(span.scope_name.as_str());
            scope_version.push(span.scope_version.as_deref());
            span_name.push(span.span_name.as_str());
            span_kind.push(span.span_kind.as_str());
            start_time.push(span.start_time);
            end_time.push(span.end_time);
            duration_ms.push(span.duration_ms);
            status_code.push(span.status_code);
            status_message.push(span.status_message.as_str());
            attributes.push(Json(span.attributes.clone()));
            events.push(Json(span.events.clone()));
            links.push(Json(span.links.clone()));
            labels.push(span.label.as_deref());
            input.push(Json(span.input.clone()));
            output.push(Json(span.output.clone()));
            service_name.push(span.service_name.as_str());
            resource_attributes.push(Json(span.resource_attributes.clone()));
        }

        // NOTE: the bind order below must match the parameter order of the
        // InsertTraceSpan query exactly — do not reorder.
        let query_result = sqlx::query(query)
            .bind(created_at)
            .bind(span_id)
            .bind(trace_id)
            .bind(parent_span_id)
            .bind(flags)
            .bind(trace_state)
            .bind(scope_name)
            .bind(scope_version)
            .bind(span_name)
            .bind(span_kind)
            .bind(start_time)
            .bind(end_time)
            .bind(duration_ms)
            .bind(status_code)
            .bind(status_message)
            .bind(attributes)
            .bind(events)
            .bind(links)
            .bind(labels)
            .bind(input)
            .bind(output)
            .bind(service_name)
            .bind(resource_attributes)
            .execute(pool)
            .await
            .inspect_err(|e| error!("Error inserting trace spans: {:?}", e))?;

        Ok(query_result)
    }
116
117    /// Attempts to insert multiple trace baggage records into the database in a batch.
118    ///
119    /// # Arguments
120    /// * `pool` - The database connection pool
121    /// * `baggage` - The trace baggage records to insert
122    async fn insert_trace_baggage_batch(
123        pool: &Pool<Postgres>,
124        baggage: &[TraceBaggageRecord],
125    ) -> Result<PgQueryResult, SqlError> {
126        let query = Queries::InsertTraceBaggage.get_query();
127
128        let (created_at, trace_id, scope, key, value): (
129            Vec<DateTime<Utc>>,
130            Vec<&[u8]>,
131            Vec<&str>,
132            Vec<&str>,
133            Vec<&str>,
134        ) = multiunzip(baggage.iter().map(|b| {
135            (
136                b.created_at,
137                b.trace_id.as_bytes() as &[u8],
138                b.scope.as_str(),
139                b.key.as_str(),
140                b.value.as_str(),
141            )
142        }));
143
144        let query_result = sqlx::query(query)
145            .bind(created_at)
146            .bind(trace_id)
147            .bind(scope)
148            .bind(key)
149            .bind(value)
150            .execute(pool)
151            .await
152            .inspect_err(|e| error!("Error inserting trace baggage: {:?}", e))?;
153
154        Ok(query_result)
155    }
156
157    /// Attempts to retrieve trace baggage records for a given trace ID.
158    /// # Arguments
159    /// * `pool` - The database connection pool
160    /// * `trace_id` - The trace ID to retrieve baggage for. This is always the hex encoded id
161    /// # Returns
162    /// * A vector of `TraceBaggageRecord` associated with the trace ID
163    async fn get_trace_baggage_records(
164        pool: &Pool<Postgres>,
165        trace_id: &str,
166    ) -> Result<Vec<TraceBaggageRecord>, SqlError> {
167        let bytes = TraceId::hex_to_bytes(trace_id)?;
168
169        let query = Queries::GetTraceBaggage.get_query();
170
171        let baggage_items: Result<Vec<TraceBaggageRecord>, SqlError> = sqlx::query_as(query)
172            .bind(bytes.as_slice())
173            .fetch_all(pool)
174            .await
175            .map_err(SqlError::SqlxError);
176
177        baggage_items
178    }
179
180    /// Attempts to retrieve paginated trace records from the database based on provided filters.
181    /// # Arguments
182    /// * `pool` - The database connection pool
183    /// * `filters` - The filters to apply for retrieving traces
184    /// # Returns
185    /// * A vector of `TraceListItem` matching the filters
186    #[instrument(skip_all)]
187    async fn get_paginated_traces(
188        pool: &Pool<Postgres>,
189        filters: TraceFilters,
190    ) -> Result<TracePaginationResponse, SqlError> {
191        let default_start = Utc::now() - chrono::Duration::hours(24);
192        let default_end = Utc::now();
193        let limit = filters.limit.unwrap_or(50);
194        let direction = filters.direction.as_deref().unwrap_or("next");
195        let trace_id_bytes = filters.parsed_trace_ids()?;
196        let cursor_trace_id_bytes = filters.parsed_cursor_trace_id()?;
197
198        let query = Queries::GetPaginatedTraces.get_query();
199
200        let tag_filters_json = filters.attribute_filters.as_ref().and_then(|tags| {
201            if tags.is_empty() {
202                None
203            } else {
204                // Parse "key:value" or "key=value" format into structured JSON
205                let tag_filters: Vec<HashMap<String, String>> = tags
206                    .iter()
207                    .filter_map(|tag| {
208                        let parts: Vec<&str> = tag.splitn(2, [':', '=']).collect();
209                        if parts.len() == 2 {
210                            Some(HashMap::from([
211                                ("key".to_string(), parts[0].trim().to_string()),
212                                ("value".to_string(), parts[1].trim().to_string()),
213                            ]))
214                        } else {
215                            None
216                        }
217                    })
218                    .collect();
219
220                if tag_filters.is_empty() {
221                    None
222                } else {
223                    Some(Json(tag_filters))
224                }
225            }
226        });
227
228        let entity_bytea = match filters.entity_uid {
229            Some(ref uid) => Some(EntityBytea::from_uuid(uid)?.0), // consume the uid
230            None => None,
231        };
232
233        let mut items: Vec<TraceListItem> = sqlx::query_as(query)
234            .bind(filters.service_name)
235            .bind(filters.has_errors)
236            .bind(filters.start_time.unwrap_or(default_start))
237            .bind(filters.end_time.unwrap_or(default_end))
238            .bind(limit)
239            .bind(filters.cursor_start_time)
240            .bind(cursor_trace_id_bytes)
241            .bind(direction)
242            .bind(tag_filters_json)
243            .bind(trace_id_bytes)
244            .bind(false)
245            .bind(entity_bytea)
246            .fetch_all(pool)
247            .await
248            .map_err(SqlError::SqlxError)?;
249
250        let has_more = items.len() > limit as usize;
251
252        // Remove the extra item
253        if has_more {
254            items.pop();
255        }
256
257        // Determine next/previous based on direction
258        let (has_next, next_cursor, has_previous, previous_cursor) = match direction {
259            "next" => {
260                // Forward pagination
261                let next_cursor = if has_more {
262                    items.last().map(|last| TraceCursor {
263                        start_time: last.start_time,
264                        trace_id: last.trace_id.clone(),
265                    })
266                } else {
267                    None
268                };
269
270                let previous_cursor = items.first().map(|first| TraceCursor {
271                    start_time: first.start_time,
272                    trace_id: first.trace_id.clone(),
273                });
274
275                (
276                    has_more,
277                    next_cursor,
278                    filters.cursor_start_time.is_some(),
279                    previous_cursor,
280                )
281            }
282            "previous" => {
283                // Backward pagination
284                let previous_cursor = if has_more {
285                    items.first().map(|first| TraceCursor {
286                        start_time: first.start_time,
287                        trace_id: first.trace_id.clone(),
288                    })
289                } else {
290                    None
291                };
292
293                let next_cursor = items.last().map(|last| TraceCursor {
294                    start_time: last.start_time,
295                    trace_id: last.trace_id.clone(),
296                });
297
298                (
299                    filters.cursor_start_time.is_some(),
300                    next_cursor,
301                    has_more,
302                    previous_cursor,
303                )
304            }
305            _ => (false, None, false, None),
306        };
307
308        Ok(TracePaginationResponse {
309            items,
310            has_next,
311            next_cursor,
312            has_previous,
313            previous_cursor,
314        })
315    }
316
317    /// Attempts to retrieve trace spans for a given trace ID.
318    /// # Arguments
319    /// * `pool` - The database connection pool
320    /// * `trace_id` - The trace ID to retrieve spans for
321    /// # Returns
322    /// * A vector of `TraceSpan` associated with the trace ID
323    async fn get_trace_spans(
324        pool: &Pool<Postgres>,
325        trace_id: &str,
326        service_name: Option<&str>,
327    ) -> Result<Vec<TraceSpan>, SqlError> {
328        let query = Queries::GetTraceSpans.get_query();
329        // check if service name is None or empty string, if so we want to bind None to the query, otherwise bind the service name
330        let service_name_param = service_name.filter(|&name| !name.trim().is_empty());
331        let trace_id_bytes = TraceId::hex_to_bytes(trace_id)?;
332        let trace_items: Result<Vec<TraceSpan>, SqlError> = sqlx::query_as(query)
333            .bind(trace_id_bytes)
334            .bind(service_name_param)
335            .fetch_all(pool)
336            .await
337            .map_err(SqlError::SqlxError);
338
339        trace_items
340    }
341
342    /// Attempts to retrieve trace spans for a given trace ID.
343    /// # Arguments
344    /// * `pool` - The database connection pool
345    /// * `trace_id` - The trace ID to retrieve spans for
346    /// # Returns
347    /// * A vector of `TraceSpan` associated with the trace ID
348    async fn get_trace_metrics(
349        pool: &Pool<Postgres>,
350        service_name: Option<&str>,
351        start_time: DateTime<Utc>,
352        end_time: DateTime<Utc>,
353        bucket_interval_str: &str,
354        attribute_filters: Option<Vec<String>>,
355        entity_uid: Option<String>,
356    ) -> Result<Vec<TraceMetricBucket>, SqlError> {
357        let tag_filters_json = attribute_filters.as_ref().and_then(|tags| {
358            if tags.is_empty() {
359                None
360            } else {
361                // Parse "key:value" or "key=value" format into structured JSON
362                let tag_filters: Vec<HashMap<String, String>> = tags
363                    .iter()
364                    .filter_map(|tag| {
365                        let parts: Vec<&str> = tag.splitn(2, [':', '=']).collect();
366                        if parts.len() == 2 {
367                            Some(HashMap::from([
368                                ("key".to_string(), parts[0].trim().to_string()),
369                                ("value".to_string(), parts[1].trim().to_string()),
370                            ]))
371                        } else {
372                            None
373                        }
374                    })
375                    .collect();
376
377                if tag_filters.is_empty() {
378                    None
379                } else {
380                    Some(Json(tag_filters))
381                }
382            }
383        });
384
385        let entity_bytea = match entity_uid {
386            Some(uid) => Some(EntityBytea::from_uuid(&uid)?.0), // consume the uid
387            None => None,
388        };
389
390        let query = Queries::GetTraceMetrics.get_query();
391        let trace_items: Result<Vec<TraceMetricBucket>, SqlError> = sqlx::query_as(query)
392            .bind(service_name)
393            .bind(start_time)
394            .bind(end_time)
395            .bind(bucket_interval_str)
396            .bind(tag_filters_json)
397            .bind(false)
398            .bind(entity_bytea)
399            .fetch_all(pool)
400            .await
401            .map_err(SqlError::SqlxError);
402
403        trace_items
404    }
405
406    async fn refresh_trace_summary(pool: &Pool<Postgres>) -> Result<PgQueryResult, SqlError> {
407        let query_result = sqlx::query("REFRESH MATERIALIZED VIEW scouter.trace_summary;")
408            .execute(pool)
409            .await?;
410
411        Ok(query_result)
412    }
413
414    /// Attempts to retrieve trace spans based on tag filters.
415    /// # Arguments
416    /// * `pool` - The database connection pool
417    /// * `entity_type` - The entity type to filter spans
418    /// * `tag_filters` - The tag filters to apply
419    /// * `match_all` - Whether to match all tags or any
420    /// * `service_name` - Optional service name to filter spans
421    /// # Returns
422    /// * A vector of `TraceSpan` matching the tag filters
423    async fn get_spans_from_tags(
424        pool: &Pool<Postgres>,
425        entity_type: &str,
426        tag_filters: Vec<HashMap<String, String>>,
427        match_all: bool,
428        service_name: Option<&str>,
429    ) -> Result<Vec<TraceSpan>, SqlError> {
430        let query = Queries::GetSpansByTags.get_query();
431
432        sqlx::query_as(query)
433            .bind(entity_type)
434            .bind(Json(tag_filters))
435            .bind(match_all)
436            .bind(service_name)
437            .fetch_all(pool)
438            .await
439            .map_err(SqlError::SqlxError)
440    }
441}