Skip to main content

scouter_sql/sql/traits/
trace.rs

1use crate::sql::error::SqlError;
2use crate::sql::query::Queries;
3
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use itertools::multiunzip;
7use scouter_types::sql::TraceSpan;
8use scouter_types::{TraceBaggageRecord, TraceId};
9use sqlx::{postgres::PgQueryResult, types::Json, Pool, Postgres};
10use std::collections::HashMap;
11use tracing::error;
12#[async_trait]
13pub trait TraceSqlLogic {
14    /// Attempts to insert multiple trace baggage records into the database in a batch.
15    ///
16    /// # Arguments
17    /// * `pool` - The database connection pool
18    /// * `baggage` - The trace baggage records to insert
19    async fn insert_trace_baggage_batch(
20        pool: &Pool<Postgres>,
21        baggage: &[TraceBaggageRecord],
22    ) -> Result<PgQueryResult, SqlError> {
23        let query = Queries::InsertTraceBaggage.get_query();
24
25        let (created_at, trace_id, scope, key, value): (
26            Vec<DateTime<Utc>>,
27            Vec<&[u8]>,
28            Vec<&str>,
29            Vec<&str>,
30            Vec<&str>,
31        ) = multiunzip(baggage.iter().map(|b| {
32            (
33                b.created_at,
34                b.trace_id.as_bytes() as &[u8],
35                b.scope.as_str(),
36                b.key.as_str(),
37                b.value.as_str(),
38            )
39        }));
40
41        let query_result = sqlx::query(query)
42            .bind(created_at)
43            .bind(trace_id)
44            .bind(scope)
45            .bind(key)
46            .bind(value)
47            .execute(pool)
48            .await
49            .inspect_err(|e| error!("Error inserting trace baggage: {:?}", e))?;
50
51        Ok(query_result)
52    }
53
54    /// Attempts to retrieve trace baggage records for a given trace ID.
55    /// # Arguments
56    /// * `pool` - The database connection pool
57    /// * `trace_id` - The trace ID to retrieve baggage for. This is always the hex encoded id
58    /// # Returns
59    /// * A vector of `TraceBaggageRecord` associated with the trace ID
60    async fn get_trace_baggage_records(
61        pool: &Pool<Postgres>,
62        trace_id: &str,
63    ) -> Result<Vec<TraceBaggageRecord>, SqlError> {
64        let bytes = TraceId::hex_to_bytes(trace_id)?;
65
66        let query = Queries::GetTraceBaggage.get_query();
67
68        let baggage_items: Result<Vec<TraceBaggageRecord>, SqlError> = sqlx::query_as(query)
69            .bind(bytes.as_slice())
70            .fetch_all(pool)
71            .await
72            .map_err(SqlError::SqlxError);
73
74        baggage_items
75    }
76
77    /// Attempts to retrieve trace spans based on tag filters.
78    /// # Arguments
79    /// * `pool` - The database connection pool
80    /// * `entity_type` - The entity type to filter spans
81    /// * `tag_filters` - The tag filters to apply
82    /// * `match_all` - Whether to match all tags or any
83    /// * `service_name` - Optional service name to filter spans
84    /// # Returns
85    /// * A vector of `TraceSpan` matching the tag filters
86    async fn get_spans_from_tags(
87        pool: &Pool<Postgres>,
88        entity_type: &str,
89        tag_filters: Vec<HashMap<String, String>>,
90        match_all: bool,
91        service_name: Option<&str>,
92    ) -> Result<Vec<TraceSpan>, SqlError> {
93        let query = Queries::GetSpansByTags.get_query();
94
95        sqlx::query_as(query)
96            .bind(entity_type)
97            .bind(Json(tag_filters))
98            .bind(match_all)
99            .bind(service_name)
100            .fetch_all(pool)
101            .await
102            .map_err(SqlError::SqlxError)
103    }
104
105    /// Resolve `entity_uid` (UUID string) to raw 16-byte trace IDs via `scouter.trace_entities`.
106    /// Returns empty `Vec` when `entity_uid` is invalid or no rows match.
107    async fn get_trace_ids_for_entity(
108        pool: &Pool<Postgres>,
109        entity_uid: &str,
110    ) -> Result<Vec<Vec<u8>>, SqlError> {
111        let uuid: uuid::Uuid = entity_uid.parse().map_err(SqlError::UuidError)?;
112        let uid_bytes = uuid.as_bytes().to_vec();
113        sqlx::query_scalar::<_, Vec<u8>>(
114            "SELECT trace_id FROM scouter.trace_entities WHERE entity_uid = $1",
115        )
116        .bind(uid_bytes)
117        .fetch_all(pool)
118        .await
119        .map_err(SqlError::SqlxError)
120    }
121}