scouter_sql/sql/traits/trace.rs

1use crate::sql::error::SqlError;
2use crate::sql::query::Queries;
3
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use itertools::multiunzip;
7use scouter_types::sql::{TraceFilters, TraceListItem, TraceMetricBucket, TraceSpan};
8use scouter_types::{
9    TraceBaggageRecord, TraceCursor, TracePaginationResponse, TraceRecord, TraceSpanRecord,
10};
11use sqlx::{postgres::PgQueryResult, types::Json, Pool, Postgres};
12
13#[async_trait]
14pub trait TraceSqlLogic {
15    /// Attempts to upsert multiple trace records into the database in a batch.
16    ///
17    /// # Arguments
18    ///
19    /// * `pool` - The database connection pool
20    /// * `traces` - The trace records to insert
21    async fn upsert_trace_batch(
22        pool: &Pool<Postgres>,
23        traces: &[TraceRecord],
24    ) -> Result<PgQueryResult, SqlError> {
25        let query = Queries::UpsertTrace.get_query();
26        let capacity = traces.len();
27
28        // Pre-allocate vectors for each field for batch efficiency
29        let mut created_at = Vec::with_capacity(capacity);
30        let mut trace_id = Vec::with_capacity(capacity);
31        let mut space = Vec::with_capacity(capacity);
32        let mut name = Vec::with_capacity(capacity);
33        let mut version = Vec::with_capacity(capacity);
34        let mut scope = Vec::with_capacity(capacity);
35        let mut trace_state = Vec::with_capacity(capacity);
36        let mut start_time = Vec::with_capacity(capacity);
37        let mut end_time = Vec::with_capacity(capacity);
38        let mut duration_ms = Vec::with_capacity(capacity);
39        let mut status_code = Vec::with_capacity(capacity);
40        let mut status_message = Vec::with_capacity(capacity);
41        let mut root_span_id = Vec::with_capacity(capacity);
42        let mut span_count = Vec::with_capacity(capacity);
43
44        // Single-pass extraction for performance
45        for r in traces {
46            created_at.push(r.created_at);
47            trace_id.push(r.trace_id.as_str());
48            space.push(r.space.as_str());
49            name.push(r.name.as_str());
50            version.push(r.version.as_str());
51            scope.push(r.scope.as_str());
52            trace_state.push(r.trace_state.as_str());
53            start_time.push(r.start_time);
54            end_time.push(r.end_time);
55            duration_ms.push(r.duration_ms);
56            status_code.push(r.status_code);
57            status_message.push(r.status_message.clone());
58            root_span_id.push(r.root_span_id.as_str());
59            span_count.push(r.span_count);
60        }
61
62        let query_result = sqlx::query(&query.sql)
63            .bind(created_at)
64            .bind(trace_id)
65            .bind(space)
66            .bind(name)
67            .bind(version)
68            .bind(scope)
69            .bind(trace_state)
70            .bind(start_time)
71            .bind(end_time)
72            .bind(duration_ms)
73            .bind(status_code)
74            .bind(status_message)
75            .bind(root_span_id)
76            .bind(span_count)
77            .execute(pool)
78            .await?;
79
80        Ok(query_result)
81    }
82
83    /// Attempts to insert multiple trace span records into the database in a batch.
84    ///
85    /// # Arguments
86    /// * `pool` - The database connection pool
87    /// * `spans` - The trace span records to insert
88    async fn insert_span_batch(
89        pool: &Pool<Postgres>,
90        spans: &[TraceSpanRecord],
91    ) -> Result<PgQueryResult, SqlError> {
92        let query = Queries::InsertTraceSpan.get_query();
93        let capacity = spans.len();
94
95        // we are pre-allocating here instead of using multiunzip because multiunzip has
96        // a max limit of 12 tuples (we have 18 fields)
97        let mut created_at = Vec::with_capacity(capacity);
98        let mut span_id = Vec::with_capacity(capacity);
99        let mut trace_id = Vec::with_capacity(capacity);
100        let mut parent_span_id = Vec::with_capacity(capacity);
101        let mut space = Vec::with_capacity(capacity);
102        let mut name = Vec::with_capacity(capacity);
103        let mut version = Vec::with_capacity(capacity);
104        let mut scope = Vec::with_capacity(capacity);
105        let mut span_name = Vec::with_capacity(capacity);
106        let mut span_kind = Vec::with_capacity(capacity);
107        let mut start_time = Vec::with_capacity(capacity);
108        let mut end_time = Vec::with_capacity(capacity);
109        let mut duration_ms = Vec::with_capacity(capacity);
110        let mut status_code = Vec::with_capacity(capacity);
111        let mut status_message = Vec::with_capacity(capacity);
112        let mut attributes = Vec::with_capacity(capacity);
113        let mut events = Vec::with_capacity(capacity);
114        let mut links = Vec::with_capacity(capacity);
115        let mut labels = Vec::with_capacity(capacity);
116        let mut input = Vec::with_capacity(capacity);
117        let mut output = Vec::with_capacity(capacity);
118
119        // Single iteration for maximum efficiency
120        for span in spans {
121            created_at.push(span.created_at);
122            span_id.push(span.span_id.as_str());
123            trace_id.push(span.trace_id.as_str());
124            parent_span_id.push(span.parent_span_id.as_deref());
125            space.push(span.space.as_str());
126            name.push(span.name.as_str());
127            version.push(span.version.as_str());
128            scope.push(span.scope.as_str());
129            span_name.push(span.span_name.as_str());
130            span_kind.push(span.span_kind.as_str());
131            start_time.push(span.start_time);
132            end_time.push(span.end_time);
133            duration_ms.push(span.duration_ms);
134            status_code.push(span.status_code);
135            status_message.push(span.status_message.as_str());
136            attributes.push(Json(span.attributes.clone()));
137            events.push(Json(span.events.clone()));
138            links.push(Json(span.links.clone()));
139            labels.push(span.label.as_deref());
140            input.push(Json(span.input.clone()));
141            output.push(Json(span.output.clone()));
142        }
143
144        let query_result = sqlx::query(&query.sql)
145            .bind(created_at)
146            .bind(span_id)
147            .bind(trace_id)
148            .bind(parent_span_id)
149            .bind(space)
150            .bind(name)
151            .bind(version)
152            .bind(scope)
153            .bind(span_name)
154            .bind(span_kind)
155            .bind(start_time)
156            .bind(end_time)
157            .bind(duration_ms)
158            .bind(status_code)
159            .bind(status_message)
160            .bind(attributes)
161            .bind(events)
162            .bind(links)
163            .bind(labels)
164            .bind(input)
165            .bind(output)
166            .execute(pool)
167            .await?;
168
169        Ok(query_result)
170    }
171
172    /// Attempts to insert multiple trace baggage records into the database in a batch.
173    ///
174    /// # Arguments
175    /// * `pool` - The database connection pool
176    /// * `baggage` - The trace baggage records to insert
177    async fn insert_trace_baggage_batch(
178        pool: &Pool<Postgres>,
179        baggage: &[TraceBaggageRecord],
180    ) -> Result<PgQueryResult, SqlError> {
181        let query = Queries::InsertTraceBaggage.get_query();
182
183        let (created_at, trace_id, scope, key, value): (
184            Vec<DateTime<Utc>>,
185            Vec<&str>,
186            Vec<&str>,
187            Vec<&str>,
188            Vec<&str>,
189        ) = multiunzip(baggage.iter().map(|b| {
190            (
191                b.created_at,
192                b.trace_id.as_str(),
193                b.scope.as_str(),
194                b.key.as_str(),
195                b.value.as_str(),
196            )
197        }));
198
199        let query_result = sqlx::query(&query.sql)
200            .bind(created_at)
201            .bind(trace_id)
202            .bind(scope)
203            .bind(key)
204            .bind(value)
205            .execute(pool)
206            .await?;
207
208        Ok(query_result)
209    }
210
211    async fn get_trace_baggage_records(
212        pool: &Pool<Postgres>,
213        trace_id: &str,
214    ) -> Result<Vec<TraceBaggageRecord>, SqlError> {
215        let query = Queries::GetTraceBaggage.get_query();
216
217        let baggage_items: Result<Vec<TraceBaggageRecord>, SqlError> = sqlx::query_as(&query.sql)
218            .bind(trace_id)
219            .fetch_all(pool)
220            .await
221            .map_err(SqlError::SqlxError);
222
223        baggage_items
224    }
225
226    /// Attempts to retrieve paginated trace records from the database based on provided filters.
227    /// # Arguments
228    /// * `pool` - The database connection pool
229    /// * `filters` - The filters to apply for retrieving traces
230    /// # Returns
231    /// * A vector of `TraceListItem` matching the filters
232    async fn get_traces_paginated(
233        pool: &Pool<Postgres>,
234        filters: TraceFilters,
235    ) -> Result<TracePaginationResponse, SqlError> {
236        let default_start = Utc::now() - chrono::Duration::hours(24);
237        let default_end = Utc::now();
238        let limit = filters.limit.unwrap_or(50);
239        let direction = filters.direction.as_deref().unwrap_or("next");
240
241        let query = Queries::GetPaginatedTraces.get_query();
242
243        let mut items: Vec<TraceListItem> = sqlx::query_as(&query.sql)
244            .bind(filters.space)
245            .bind(filters.name)
246            .bind(filters.version)
247            .bind(filters.service_name)
248            .bind(filters.has_errors)
249            .bind(filters.status_code)
250            .bind(filters.start_time.unwrap_or(default_start))
251            .bind(filters.end_time.unwrap_or(default_end))
252            .bind(limit)
253            .bind(filters.cursor_created_at)
254            .bind(filters.cursor_trace_id)
255            .bind(direction)
256            .fetch_all(pool)
257            .await
258            .map_err(SqlError::SqlxError)?;
259
260        let has_more = items.len() > limit as usize;
261
262        // Remove the extra item
263        if has_more {
264            items.pop();
265        }
266
267        // Determine next/previous based on direction
268        let (has_next, next_cursor, has_previous, previous_cursor) = match direction {
269            "next" => {
270                // Forward pagination
271                let next_cursor = if has_more {
272                    items.last().map(|last| TraceCursor {
273                        created_at: last.created_at,
274                        trace_id: last.trace_id.clone(),
275                    })
276                } else {
277                    None
278                };
279
280                let previous_cursor = items.first().map(|first| TraceCursor {
281                    created_at: first.created_at,
282                    trace_id: first.trace_id.clone(),
283                });
284
285                (
286                    has_more,
287                    next_cursor,
288                    filters.cursor_created_at.is_some(),
289                    previous_cursor,
290                )
291            }
292            "previous" => {
293                // Backward pagination
294                let previous_cursor = if has_more {
295                    items.first().map(|first| TraceCursor {
296                        created_at: first.created_at,
297                        trace_id: first.trace_id.clone(),
298                    })
299                } else {
300                    None
301                };
302
303                let next_cursor = items.last().map(|last| TraceCursor {
304                    created_at: last.created_at,
305                    trace_id: last.trace_id.clone(),
306                });
307
308                (
309                    filters.cursor_created_at.is_some(),
310                    next_cursor,
311                    has_more,
312                    previous_cursor,
313                )
314            }
315            _ => (false, None, false, None),
316        };
317
318        Ok(TracePaginationResponse {
319            items,
320            has_next,
321            next_cursor,
322            has_previous,
323            previous_cursor,
324        })
325    }
326
327    /// Attempts to retrieve trace spans for a given trace ID.
328    /// # Arguments
329    /// * `pool` - The database connection pool
330    /// * `trace_id` - The trace ID to retrieve spans for
331    /// # Returns
332    /// * A vector of `TraceSpan` associated with the trace ID
333    async fn get_trace_spans(
334        pool: &Pool<Postgres>,
335        trace_id: &str,
336    ) -> Result<Vec<TraceSpan>, SqlError> {
337        let query = Queries::GetTraceSpans.get_query();
338        let trace_items: Result<Vec<TraceSpan>, SqlError> = sqlx::query_as(&query.sql)
339            .bind(trace_id)
340            .fetch_all(pool)
341            .await
342            .map_err(SqlError::SqlxError);
343
344        trace_items
345    }
346
347    /// Attempts to retrieve trace spans for a given trace ID.
348    /// # Arguments
349    /// * `pool` - The database connection pool
350    /// * `trace_id` - The trace ID to retrieve spans for
351    /// # Returns
352    /// * A vector of `TraceSpan` associated with the trace ID
353    async fn get_trace_metrics(
354        pool: &Pool<Postgres>,
355        space: Option<&str>,
356        name: Option<&str>,
357        version: Option<&str>,
358        start_time: DateTime<Utc>,
359        end_time: DateTime<Utc>,
360        bucket_interval_str: &str,
361    ) -> Result<Vec<TraceMetricBucket>, SqlError> {
362        let query = Queries::GetTraceMetrics.get_query();
363        let trace_items: Result<Vec<TraceMetricBucket>, SqlError> = sqlx::query_as(&query.sql)
364            .bind(space)
365            .bind(name)
366            .bind(version)
367            .bind(start_time)
368            .bind(end_time)
369            .bind(bucket_interval_str)
370            .fetch_all(pool)
371            .await
372            .map_err(SqlError::SqlxError);
373
374        trace_items
375    }
376
377    async fn refresh_trace_summary(pool: &Pool<Postgres>) -> Result<PgQueryResult, SqlError> {
378        let query_result = sqlx::query("REFRESH MATERIALIZED VIEW scouter.trace_summary;")
379            .execute(pool)
380            .await?;
381
382        Ok(query_result)
383    }
384}