//! Trace SQL trait — persistence and query logic for distributed-trace data.
//! (scouter_sql/sql/traits/trace.rs)
use crate::sql::error::SqlError;
use crate::sql::query::Queries;

use async_trait::async_trait;
use chrono::{DateTime, Utc};
use itertools::multiunzip;
use scouter_types::sql::{TraceFilters, TraceListItem, TraceMetricBucket, TraceSpan};
use scouter_types::{TraceBaggageRecord, TraceCursor, TracePaginationResponse, TraceSpanRecord};
use sqlx::{postgres::PgQueryResult, types::Json, Pool, Postgres};
use std::collections::HashMap;
use tracing::{error, instrument};
#[async_trait]
pub trait TraceSqlLogic {
    /// Attempts to insert multiple trace span records into the database in a batch.
    ///
    /// Builds one parallel vector per column and binds each vector exactly once,
    /// so the whole batch is sent to PostgreSQL in a single statement.
    ///
    /// # Arguments
    /// * `pool` - The database connection pool
    /// * `spans` - The trace span records to insert
    ///
    /// # Errors
    /// Returns `SqlError` if the insert fails; the failure is also logged.
    async fn insert_span_batch(
        pool: &Pool<Postgres>,
        spans: &[TraceSpanRecord],
    ) -> Result<PgQueryResult, SqlError> {
        let query = Queries::InsertTraceSpan.get_query();
        let capacity = spans.len();

        // we are pre-allocating here instead of using multiunzip because multiunzip has
        // a max limit of 12 tuples (we have 20 fields)
        let mut created_at = Vec::with_capacity(capacity);
        let mut span_id = Vec::with_capacity(capacity);
        let mut trace_id = Vec::with_capacity(capacity);
        let mut parent_span_id = Vec::with_capacity(capacity);
        let mut scope = Vec::with_capacity(capacity);
        let mut span_name = Vec::with_capacity(capacity);
        let mut span_kind = Vec::with_capacity(capacity);
        let mut start_time = Vec::with_capacity(capacity);
        let mut end_time = Vec::with_capacity(capacity);
        let mut duration_ms = Vec::with_capacity(capacity);
        let mut status_code = Vec::with_capacity(capacity);
        let mut status_message = Vec::with_capacity(capacity);
        let mut attributes = Vec::with_capacity(capacity);
        let mut events = Vec::with_capacity(capacity);
        let mut links = Vec::with_capacity(capacity);
        let mut labels = Vec::with_capacity(capacity);
        let mut input = Vec::with_capacity(capacity);
        let mut output = Vec::with_capacity(capacity);
        let mut service_name = Vec::with_capacity(capacity);
        let mut resource_attributes = Vec::with_capacity(capacity);

        // Single pass over the spans fills every column vector at once.
        // String columns borrow from `spans`; JSON columns must be cloned into
        // `Json(..)` wrappers for binding.
        for span in spans {
            created_at.push(span.created_at);
            span_id.push(span.span_id.as_str());
            trace_id.push(span.trace_id.as_str());
            parent_span_id.push(span.parent_span_id.as_deref());
            scope.push(span.scope.as_str());
            span_name.push(span.span_name.as_str());
            span_kind.push(span.span_kind.as_str());
            start_time.push(span.start_time);
            end_time.push(span.end_time);
            duration_ms.push(span.duration_ms);
            status_code.push(span.status_code);
            status_message.push(span.status_message.as_str());
            attributes.push(Json(span.attributes.clone()));
            events.push(Json(span.events.clone()));
            links.push(Json(span.links.clone()));
            // note: record field is `label` (singular); the column vector is `labels`
            labels.push(span.label.as_deref());
            input.push(Json(span.input.clone()));
            output.push(Json(span.output.clone()));
            service_name.push(span.service_name.as_str());
            resource_attributes.push(Json(span.resource_attributes.clone()));
        }

        // Bind order must match the parameter order of the InsertTraceSpan query.
        let query_result = sqlx::query(query)
            .bind(created_at)
            .bind(span_id)
            .bind(trace_id)
            .bind(parent_span_id)
            .bind(scope)
            .bind(span_name)
            .bind(span_kind)
            .bind(start_time)
            .bind(end_time)
            .bind(duration_ms)
            .bind(status_code)
            .bind(status_message)
            .bind(attributes)
            .bind(events)
            .bind(links)
            .bind(labels)
            .bind(input)
            .bind(output)
            .bind(service_name)
            .bind(resource_attributes)
            .execute(pool)
            .await
            .inspect_err(|e| error!("Error inserting trace spans: {:?}", e))?;

        Ok(query_result)
    }
100
101    /// Attempts to insert multiple trace baggage records into the database in a batch.
102    ///
103    /// # Arguments
104    /// * `pool` - The database connection pool
105    /// * `baggage` - The trace baggage records to insert
106    async fn insert_trace_baggage_batch(
107        pool: &Pool<Postgres>,
108        baggage: &[TraceBaggageRecord],
109    ) -> Result<PgQueryResult, SqlError> {
110        let query = Queries::InsertTraceBaggage.get_query();
111
112        let (created_at, trace_id, scope, key, value): (
113            Vec<DateTime<Utc>>,
114            Vec<&str>,
115            Vec<&str>,
116            Vec<&str>,
117            Vec<&str>,
118        ) = multiunzip(baggage.iter().map(|b| {
119            (
120                b.created_at,
121                b.trace_id.as_str(),
122                b.scope.as_str(),
123                b.key.as_str(),
124                b.value.as_str(),
125            )
126        }));
127
128        let query_result = sqlx::query(query)
129            .bind(created_at)
130            .bind(trace_id)
131            .bind(scope)
132            .bind(key)
133            .bind(value)
134            .execute(pool)
135            .await
136            .inspect_err(|e| error!("Error inserting trace baggage: {:?}", e))?;
137
138        Ok(query_result)
139    }
140
141    async fn get_trace_baggage_records(
142        pool: &Pool<Postgres>,
143        trace_id: &str,
144    ) -> Result<Vec<TraceBaggageRecord>, SqlError> {
145        let query = Queries::GetTraceBaggage.get_query();
146
147        let baggage_items: Result<Vec<TraceBaggageRecord>, SqlError> = sqlx::query_as(query)
148            .bind(trace_id)
149            .fetch_all(pool)
150            .await
151            .map_err(SqlError::SqlxError);
152
153        baggage_items
154    }
155
    /// Attempts to retrieve paginated trace records from the database based on provided filters.
    ///
    /// Uses cursor-based (keyset) pagination: `cursor_start_time` / `cursor_trace_id`
    /// mark the page boundary and `direction` ("next" / "previous") picks which side
    /// of that boundary to read.
    ///
    /// # Arguments
    /// * `pool` - The database connection pool
    /// * `filters` - The filters to apply for retrieving traces
    /// # Returns
    /// * A `TracePaginationResponse` with the matching `TraceListItem`s plus
    ///   next/previous cursors
    #[instrument(skip_all)]
    async fn get_paginated_traces(
        pool: &Pool<Postgres>,
        filters: TraceFilters,
    ) -> Result<TracePaginationResponse, SqlError> {
        // Default time window: the 24 hours ending now.
        let default_start = Utc::now() - chrono::Duration::hours(24);
        let default_end = Utc::now();
        let limit = filters.limit.unwrap_or(50);
        let direction = filters.direction.as_deref().unwrap_or("next");

        let query = Queries::GetPaginatedTraces.get_query();

        // Parse attribute filters into the [{"key": ..., "value": ...}] JSON shape
        // the query expects; bind NULL (None) when there is nothing usable.
        let tag_filters_json = filters.attribute_filters.as_ref().and_then(|tags| {
            if tags.is_empty() {
                None
            } else {
                // Parse "key:value" or "key=value" format into structured JSON
                let tag_filters: Vec<HashMap<String, String>> = tags
                    .iter()
                    .filter_map(|tag| {
                        let parts: Vec<&str> = tag.splitn(2, [':', '=']).collect();
                        if parts.len() == 2 {
                            Some(HashMap::from([
                                ("key".to_string(), parts[0].trim().to_string()),
                                ("value".to_string(), parts[1].trim().to_string()),
                            ]))
                        } else {
                            // Entries without a ':' or '=' delimiter are dropped.
                            None
                        }
                    })
                    .collect();

                if tag_filters.is_empty() {
                    None
                } else {
                    Some(Json(tag_filters))
                }
            }
        });

        let mut items: Vec<TraceListItem> = sqlx::query_as(query)
            .bind(filters.service_name)
            .bind(filters.has_errors)
            .bind(filters.status_code)
            .bind(filters.start_time.unwrap_or(default_start))
            .bind(filters.end_time.unwrap_or(default_end))
            .bind(limit)
            .bind(filters.cursor_start_time)
            .bind(filters.cursor_trace_id)
            .bind(direction)
            .bind(filters.trace_ids)
            .bind(tag_filters_json)
            .bind(false) // we dont want to match all attributes for pagination
            .fetch_all(pool)
            .await
            .map_err(SqlError::SqlxError)?;

        // More rows than `limit` means another page exists — the query is
        // expected to fetch limit + 1 rows as a look-ahead (NOTE(review):
        // confirm against the GetPaginatedTraces SQL).
        let has_more = items.len() > limit as usize;

        // Remove the extra look-ahead item so the page holds at most `limit` rows.
        if has_more {
            items.pop();
        }

        // Determine next/previous availability and cursors based on direction.
        let (has_next, next_cursor, has_previous, previous_cursor) = match direction {
            "next" => {
                // Forward pagination: the last item on the page anchors the next page.
                let next_cursor = if has_more {
                    items.last().map(|last| TraceCursor {
                        start_time: last.start_time,
                        trace_id: last.trace_id.clone(),
                    })
                } else {
                    None
                };

                // The first item anchors the previous page.
                let previous_cursor = items.first().map(|first| TraceCursor {
                    start_time: first.start_time,
                    trace_id: first.trace_id.clone(),
                });

                (
                    has_more,
                    next_cursor,
                    // A previous page exists iff we arrived here via a cursor.
                    filters.cursor_start_time.is_some(),
                    previous_cursor,
                )
            }
            "previous" => {
                // Backward pagination: cursor roles are mirrored.
                let previous_cursor = if has_more {
                    items.first().map(|first| TraceCursor {
                        start_time: first.start_time,
                        trace_id: first.trace_id.clone(),
                    })
                } else {
                    None
                };

                let next_cursor = items.last().map(|last| TraceCursor {
                    start_time: last.start_time,
                    trace_id: last.trace_id.clone(),
                });

                (
                    filters.cursor_start_time.is_some(),
                    next_cursor,
                    has_more,
                    previous_cursor,
                )
            }
            // Unknown direction: return the fetched page with no cursors.
            _ => (false, None, false, None),
        };

        Ok(TracePaginationResponse {
            items,
            has_next,
            next_cursor,
            has_previous,
            previous_cursor,
        })
    }
285
286    /// Attempts to retrieve trace spans for a given trace ID.
287    /// # Arguments
288    /// * `pool` - The database connection pool
289    /// * `trace_id` - The trace ID to retrieve spans for
290    /// # Returns
291    /// * A vector of `TraceSpan` associated with the trace ID
292    async fn get_trace_spans(
293        pool: &Pool<Postgres>,
294        trace_id: &str,
295        service_name: Option<&str>,
296    ) -> Result<Vec<TraceSpan>, SqlError> {
297        let query = Queries::GetTraceSpans.get_query();
298        let trace_items: Result<Vec<TraceSpan>, SqlError> = sqlx::query_as(query)
299            .bind(trace_id)
300            .bind(service_name)
301            .fetch_all(pool)
302            .await
303            .map_err(SqlError::SqlxError);
304
305        trace_items
306    }
307
308    /// Attempts to retrieve trace spans for a given trace ID.
309    /// # Arguments
310    /// * `pool` - The database connection pool
311    /// * `trace_id` - The trace ID to retrieve spans for
312    /// # Returns
313    /// * A vector of `TraceSpan` associated with the trace ID
314    async fn get_trace_metrics(
315        pool: &Pool<Postgres>,
316        service_name: Option<&str>,
317        start_time: DateTime<Utc>,
318        end_time: DateTime<Utc>,
319        bucket_interval_str: &str,
320        attribute_filters: Option<Vec<String>>,
321    ) -> Result<Vec<TraceMetricBucket>, SqlError> {
322        let tag_filters_json = attribute_filters.as_ref().and_then(|tags| {
323            if tags.is_empty() {
324                None
325            } else {
326                // Parse "key:value" or "key=value" format into structured JSON
327                let tag_filters: Vec<HashMap<String, String>> = tags
328                    .iter()
329                    .filter_map(|tag| {
330                        let parts: Vec<&str> = tag.splitn(2, [':', '=']).collect();
331                        if parts.len() == 2 {
332                            Some(HashMap::from([
333                                ("key".to_string(), parts[0].trim().to_string()),
334                                ("value".to_string(), parts[1].trim().to_string()),
335                            ]))
336                        } else {
337                            None
338                        }
339                    })
340                    .collect();
341
342                if tag_filters.is_empty() {
343                    None
344                } else {
345                    Some(Json(tag_filters))
346                }
347            }
348        });
349
350        let query = Queries::GetTraceMetrics.get_query();
351        let trace_items: Result<Vec<TraceMetricBucket>, SqlError> = sqlx::query_as(query)
352            .bind(service_name)
353            .bind(start_time)
354            .bind(end_time)
355            .bind(bucket_interval_str)
356            .bind(tag_filters_json)
357            .bind(false)
358            .fetch_all(pool)
359            .await
360            .map_err(SqlError::SqlxError);
361
362        trace_items
363    }
364
365    async fn refresh_trace_summary(pool: &Pool<Postgres>) -> Result<PgQueryResult, SqlError> {
366        let query_result = sqlx::query("REFRESH MATERIALIZED VIEW scouter.trace_summary;")
367            .execute(pool)
368            .await?;
369
370        Ok(query_result)
371    }
372
373    /// Attempts to retrieve trace spans based on tag filters.
374    /// # Arguments
375    /// * `pool` - The database connection pool
376    /// * `entity_type` - The entity type to filter spans
377    /// * `tag_filters` - The tag filters to apply
378    /// * `match_all` - Whether to match all tags or any
379    /// * `service_name` - Optional service name to filter spans
380    /// # Returns
381    /// * A vector of `TraceSpan` matching the tag filters
382    async fn get_spans_from_tags(
383        pool: &Pool<Postgres>,
384        entity_type: &str,
385        tag_filters: Vec<HashMap<String, String>>,
386        match_all: bool,
387        service_name: Option<&str>,
388    ) -> Result<Vec<TraceSpan>, SqlError> {
389        let query = Queries::GetSpansByTags.get_query();
390
391        sqlx::query_as(query)
392            .bind(entity_type)
393            .bind(Json(tag_filters))
394            .bind(match_all)
395            .bind(service_name)
396            .fetch_all(pool)
397            .await
398            .map_err(SqlError::SqlxError)
399    }
400}