Skip to main content

sqlx_otel/
executor.rs

1use std::borrow::Cow;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4use std::time::Instant;
5
6use futures::Stream;
7use futures::stream::BoxStream;
8use opentelemetry::trace::{SpanKind, Status, TraceContextExt, Tracer};
9use opentelemetry::{Context as OtelContext, KeyValue};
10use opentelemetry_semantic_conventions::attribute;
11
12use crate::annotations::QueryAnnotations;
13use crate::attributes::{self, ConnectionAttributes, QueryTextMode};
14use crate::database::Database;
15use crate::metrics::Metrics;
16
17// ---------------------------------------------------------------------------
18// Span helpers
19// ---------------------------------------------------------------------------
20
21/// Build span attributes for a query, combining connection-level and per-query values.
22///
23/// When `annotations` is provided, the four per-query semantic convention attributes
24/// (`db.operation.name`, `db.collection.name`, `db.query.summary`,
25/// `db.stored_procedure.name`) are included for any field that is set.
26fn build_attributes(
27    attrs: &ConnectionAttributes,
28    sql: Option<&str>,
29    annotations: Option<&QueryAnnotations>,
30) -> Vec<KeyValue> {
31    let mut kv = attrs.base_key_values();
32    if let Some(ann) = annotations {
33        if let Some(ref op) = ann.operation {
34            kv.push(KeyValue::new(attribute::DB_OPERATION_NAME, op.clone()));
35        }
36        if let Some(ref coll) = ann.collection {
37            kv.push(KeyValue::new(attribute::DB_COLLECTION_NAME, coll.clone()));
38        }
39        if let Some(ref summary) = ann.query_summary {
40            kv.push(KeyValue::new(attribute::DB_QUERY_SUMMARY, summary.clone()));
41        }
42        if let Some(ref sp) = ann.stored_procedure {
43            kv.push(KeyValue::new(
44                attribute::DB_STORED_PROCEDURE_NAME,
45                sp.clone(),
46            ));
47        }
48    }
49    if let Some(sql) = sql {
50        match attrs.query_text_mode {
51            QueryTextMode::Full => {
52                kv.push(KeyValue::new(attribute::DB_QUERY_TEXT, sql.to_owned()));
53            }
54            QueryTextMode::Obfuscated => {
55                kv.push(KeyValue::new(
56                    attribute::DB_QUERY_TEXT,
57                    crate::obfuscate::obfuscate(sql),
58                ));
59            }
60            QueryTextMode::Off => {}
61        }
62    }
63    kv
64}
65
66/// Create an OpenTelemetry span for a database operation and return a context containing it.
67fn start_span(name: &str, span_attrs: Vec<KeyValue>) -> (OtelContext, Instant) {
68    let tracer = opentelemetry::global::tracer("sqlx-otel");
69    let span = tracer
70        .span_builder(name.to_owned())
71        .with_kind(SpanKind::Client)
72        .with_attributes(span_attrs)
73        .start(&tracer);
74    let cx = OtelContext::current_with_span(span);
75    (cx, Instant::now())
76}
77
78/// Start an instrumented query: derive the span name from the connection attributes and
79/// per-query annotations, build the span and metric attribute lists, and open the span.
80///
81/// Returns the span's context, the timing reference for `finish()`, and the metric
82/// attribute list. This consolidates the boilerplate that every `Executor` method shares
83/// before delegating to the inner `SQLx` call.
84fn begin_query_span(
85    attrs: &ConnectionAttributes,
86    sql: Option<&str>,
87    annotations: Option<&QueryAnnotations>,
88) -> (OtelContext, Instant, Vec<KeyValue>) {
89    let (op, coll, summary) = annotations.map_or((None, None, None), |a| {
90        (
91            a.operation.as_deref(),
92            a.collection.as_deref(),
93            a.query_summary.as_deref(),
94        )
95    });
96    let name = attributes::span_name(attrs.system, op, coll, summary);
97    let span_attrs = build_attributes(attrs, sql, annotations);
98    let metric_attrs = attrs.base_key_values();
99    let (cx, start) = start_span(&name, span_attrs);
100    (cx, start, metric_attrs)
101}
102
103/// Classify a `sqlx::Error` variant into a string suitable for `error.type`.
104fn error_type(err: &sqlx::Error) -> &'static str {
105    match err {
106        sqlx::Error::Configuration(_) => "Configuration",
107        sqlx::Error::Database(_) => "Database",
108        sqlx::Error::Io(_) => "Io",
109        sqlx::Error::Tls(_) => "Tls",
110        sqlx::Error::Protocol(_) => "Protocol",
111        sqlx::Error::RowNotFound => "RowNotFound",
112        sqlx::Error::TypeNotFound { .. } => "TypeNotFound",
113        sqlx::Error::ColumnIndexOutOfBounds { .. } => "ColumnIndexOutOfBounds",
114        sqlx::Error::ColumnNotFound(_) => "ColumnNotFound",
115        sqlx::Error::ColumnDecode { .. } => "ColumnDecode",
116        sqlx::Error::Decode(_) => "Decode",
117        sqlx::Error::AnyDriverError(_) => "AnyDriverError",
118        sqlx::Error::PoolTimedOut => "PoolTimedOut",
119        sqlx::Error::PoolClosed => "PoolClosed",
120        sqlx::Error::WorkerCrashed => "WorkerCrashed",
121        sqlx::Error::Migrate(_) => "Migrate",
122        _ => "Unknown",
123    }
124}
125
126/// Record an error on the span within the given context: set status, `error.type`, and
127/// add an exception event.
128fn record_error(cx: &OtelContext, err: &sqlx::Error) {
129    let span = cx.span();
130    span.set_status(Status::Error {
131        description: Cow::Owned(err.to_string()),
132    });
133    span.set_attribute(KeyValue::new(attribute::ERROR_TYPE, error_type(err)));
134    // Extract SQLSTATE or database-specific error code when available.
135    if let sqlx::Error::Database(db_err) = err {
136        if let Some(code) = db_err.code() {
137            span.set_attribute(KeyValue::new(
138                attribute::DB_RESPONSE_STATUS_CODE,
139                code.into_owned(),
140            ));
141        }
142    }
143    span.add_event(
144        "exception",
145        vec![
146            KeyValue::new("exception.type", error_type(err)),
147            KeyValue::new("exception.message", err.to_string()),
148        ],
149    );
150}
151
152/// Record success attributes (returned rows) on the span.
153fn record_rows(cx: &OtelContext, rows: u64) {
154    cx.span().set_attribute(KeyValue::new(
155        attribute::DB_RESPONSE_RETURNED_ROWS,
156        i64::try_from(rows).unwrap_or(i64::MAX),
157    ));
158}
159
160/// Record affected rows on the span (for `execute` operations).
161fn record_affected_rows(cx: &OtelContext, rows: u64) {
162    cx.span().set_attribute(KeyValue::new(
163        "db.response.affected_rows",
164        i64::try_from(rows).unwrap_or(i64::MAX),
165    ));
166}
167
168/// End the span and record metrics.
169fn finish(
170    cx: &OtelContext,
171    start: Instant,
172    rows: Option<u64>,
173    metrics: &Metrics,
174    attrs: &[KeyValue],
175) {
176    cx.span().end();
177    metrics.record(start.elapsed(), rows, attrs);
178}
179
180/// Await a future, record any error on the span, then finish. Used by `execute`, `prepare`,
181/// `prepare_with`, and `describe` which share the same instrumentation pattern.
182async fn execute_instrumented<T>(
183    fut: futures::future::BoxFuture<'_, Result<T, sqlx::Error>>,
184    cx: OtelContext,
185    start: Instant,
186    metrics: std::sync::Arc<Metrics>,
187    metric_attrs: Vec<KeyValue>,
188) -> Result<T, sqlx::Error> {
189    let result = fut.await;
190    if let Err(err) = &result {
191        record_error(&cx, err);
192    }
193    finish(&cx, start, None, &metrics, &metric_attrs);
194    result
195}
196
197// ---------------------------------------------------------------------------
198// InstrumentedStream – keeps the span alive for streaming operations
199// ---------------------------------------------------------------------------
200
/// Strategy for translating a single stream item into a row count.
trait RowCounter<T> {
    /// Number of rows that `value` adds to the running total.
    fn count(value: &T) -> u64;
}
206
207/// Counts every item as one row. Used for `fetch` (which yields `Row`).
208struct CountAll;
209
210impl<T> RowCounter<T> for CountAll {
211    fn count(_item: &T) -> u64 {
212        1
213    }
214}
215
216/// Counts only `Either::Right` items as rows. Used for `fetch_many` (which yields
217/// `Either<QueryResult, Row>`).
218struct CountRight;
219
220impl<L, R> RowCounter<sqlx::Either<L, R>> for CountRight {
221    fn count(item: &sqlx::Either<L, R>) -> u64 {
222        u64::from(item.is_right())
223    }
224}
225
226/// Counts nothing. Used for `execute_many` (which yields `QueryResult`, not rows).
227struct CountNone;
228
229impl<T> RowCounter<T> for CountNone {
230    fn count(_item: &T) -> u64 {
231        0
232    }
233}
234
/// A stream wrapper that holds an OpenTelemetry context (keeping the span alive), counts rows,
/// and records metrics when the stream completes or is dropped.
///
/// `S` is the wrapped stream type; `C` selects the `RowCounter` strategy used to turn each
/// yielded item into a row count.
struct InstrumentedStream<S, C> {
    /// The wrapped stream that items are polled from.
    inner: S,
    /// Context holding the span that errors and row counts are recorded on.
    cx: OtelContext,
    /// Start time of the operation, used to compute the duration on completion.
    start: Instant,
    /// Running total of rows counted so far.
    rows: u64,
    /// Metrics sink that the duration and row count are recorded into on completion.
    metrics: std::sync::Arc<Metrics>,
    /// Attributes attached to the recorded metrics.
    metric_attrs: Vec<KeyValue>,
    /// Set once completion has been recorded, so it happens at most once.
    finished: bool,
    /// Zero-sized marker tying the row-counting strategy `C` to this type.
    _counter: std::marker::PhantomData<C>,
}
247
248impl<S, C> InstrumentedStream<S, C> {
249    fn new(
250        inner: S,
251        cx: OtelContext,
252        start: Instant,
253        metrics: std::sync::Arc<Metrics>,
254        metric_attrs: Vec<KeyValue>,
255    ) -> Self {
256        Self {
257            inner,
258            cx,
259            start,
260            rows: 0,
261            metrics,
262            metric_attrs,
263            finished: false,
264            _counter: std::marker::PhantomData,
265        }
266    }
267
268    fn complete(&mut self) {
269        if !self.finished {
270            self.finished = true;
271            record_rows(&self.cx, self.rows);
272            finish(
273                &self.cx,
274                self.start,
275                Some(self.rows),
276                &self.metrics,
277                &self.metric_attrs,
278            );
279        }
280    }
281}
282
// `Unpin` is a safe marker trait, so this impl carries no unsafe obligations. It is written
// out explicitly because the auto-derived impl would also demand `C: Unpin` (through the
// `PhantomData<C>` field), even though `C` is only a compile-time row-counting strategy.
// Requiring just `S: Unpin` is what lets `poll_next` re-pin `inner` with `Pin::new`.
impl<S: Unpin, C> Unpin for InstrumentedStream<S, C> {}
286
287impl<S, T, C> Stream for InstrumentedStream<S, C>
288where
289    S: Stream<Item = Result<T, sqlx::Error>> + Unpin,
290    C: RowCounter<T>,
291{
292    type Item = Result<T, sqlx::Error>;
293
294    fn poll_next(mut self: Pin<&mut Self>, task_cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
295        match Pin::new(&mut self.inner).poll_next(task_cx) {
296            Poll::Ready(Some(Ok(item))) => {
297                self.rows += C::count(&item);
298                Poll::Ready(Some(Ok(item)))
299            }
300            Poll::Ready(Some(Err(err))) => {
301                record_error(&self.cx, &err);
302                Poll::Ready(Some(Err(err)))
303            }
304            Poll::Ready(None) => {
305                self.complete();
306                Poll::Ready(None)
307            }
308            Poll::Pending => Poll::Pending,
309        }
310    }
311}
312
/// Ensures the span is ended and metrics are recorded even if the stream is dropped before
/// being polled to exhaustion. `complete` is guarded by the `finished` flag, so a stream that
/// already ran to the end is not recorded twice.
impl<S, C> Drop for InstrumentedStream<S, C> {
    fn drop(&mut self) {
        self.complete();
    }
}
318
319// ---------------------------------------------------------------------------
320// Macro to reduce Executor impl boilerplate
321// ---------------------------------------------------------------------------
322
/// Generate the full `sqlx::Executor` implementation for one of our wrapper types.
///
/// Every generated method follows the same shape: clone the wrapper's shared `state`, open a
/// span via `begin_query_span` (using the SQL text and the optional annotations), delegate to
/// the inner executor, and record errors, row counts, and metrics on completion.
///
/// The SQL text is borrowed directly from `Execute::sql`, which returns a `&'q str` that is
/// not tied to the borrow of `query`. No copy of the statement is therefore made just to
/// build the span, and `query` can still be moved into the inner executor call afterwards.
///
/// Two forms are supported:
/// - `impl_executor!(Type, self => inner)` – no annotations (passes `None`).
/// - `impl_executor!(Type, self => inner, annotations: expr)` – per-query annotations.
macro_rules! impl_executor {
    ($ty:ty, $self_:ident => $inner:expr) => {
        impl_executor!(@impl $ty, $self_ => $inner, None);
    };
    ($ty:ty, $self_:ident => $inner:expr, annotations: $ann:expr) => {
        impl_executor!(@impl $ty, $self_ => $inner, $ann);
    };
    (@impl $ty:ty, $self_:ident => $inner:expr, $ann:expr) => {
        impl<'c, DB> sqlx::Executor<'c> for $ty
        where
            DB: Database,
            for<'a> &'a mut DB::Connection: sqlx::Executor<'a, Database = DB>,
        {
            type Database = DB;

            /// Execute the query and return the total number of rows affected.
            fn execute<'e, 'q: 'e, E>(
                $self_,
                query: E,
            ) -> futures::future::BoxFuture<
                'e,
                Result<<DB as sqlx::Database>::QueryResult, sqlx::Error>,
            >
            where
                E: 'q + sqlx::Execute<'q, DB>,
                'c: 'e,
            {
                let sql = query.sql();
                let state = $self_.state.clone();
                let (cx, start, metric_attrs) =
                    begin_query_span(&state.attrs, Some(sql), $ann);
                let fut = ($inner).execute(query);
                Box::pin(async move {
                    let result = fut.await;
                    match &result {
                        Ok(qr) => {
                            record_affected_rows(&cx, DB::rows_affected(qr));
                        }
                        Err(err) => {
                            record_error(&cx, err);
                        }
                    }
                    // `execute` reports affected rows on the span only; no returned-row
                    // count is passed to the metrics.
                    finish(&cx, start, None, &state.metrics, &metric_attrs);
                    result
                })
            }

            /// Execute multiple queries and return the rows affected from each query,
            /// in a stream.
            fn execute_many<'e, 'q: 'e, E>(
                $self_,
                query: E,
            ) -> BoxStream<'e, Result<<DB as sqlx::Database>::QueryResult, sqlx::Error>>
            where
                E: 'q + sqlx::Execute<'q, DB>,
                'c: 'e,
            {
                let sql = query.sql();
                let state = $self_.state.clone();
                let (cx, start, metric_attrs) =
                    begin_query_span(&state.attrs, Some(sql), $ann);
                let stream = ($inner).execute_many(query);
                Box::pin(InstrumentedStream::<_, CountNone>::new(
                    stream,
                    cx,
                    start,
                    state.metrics,
                    metric_attrs,
                ))
            }

            /// Execute the query and return the generated results as a stream.
            fn fetch<'e, 'q: 'e, E>(
                $self_,
                query: E,
            ) -> BoxStream<'e, Result<<DB as sqlx::Database>::Row, sqlx::Error>>
            where
                E: 'q + sqlx::Execute<'q, DB>,
                'c: 'e,
            {
                let sql = query.sql();
                let state = $self_.state.clone();
                let (cx, start, metric_attrs) =
                    begin_query_span(&state.attrs, Some(sql), $ann);
                let stream = ($inner).fetch(query);
                Box::pin(InstrumentedStream::<_, CountAll>::new(
                    stream,
                    cx,
                    start,
                    state.metrics,
                    metric_attrs,
                ))
            }

            /// Execute multiple queries and return the generated results as a stream
            /// from each query, in a stream.
            fn fetch_many<'e, 'q: 'e, E>(
                $self_,
                query: E,
            ) -> BoxStream<
                'e,
                Result<
                    sqlx::Either<
                        <DB as sqlx::Database>::QueryResult,
                        <DB as sqlx::Database>::Row,
                    >,
                    sqlx::Error,
                >,
            >
            where
                E: 'q + sqlx::Execute<'q, DB>,
                'c: 'e,
            {
                let sql = query.sql();
                let state = $self_.state.clone();
                let (cx, start, metric_attrs) =
                    begin_query_span(&state.attrs, Some(sql), $ann);
                let stream = ($inner).fetch_many(query);
                Box::pin(InstrumentedStream::<_, CountRight>::new(
                    stream,
                    cx,
                    start,
                    state.metrics,
                    metric_attrs,
                ))
            }

            /// Execute the query and return all the generated results, collected into
            /// a [`Vec`].
            fn fetch_all<'e, 'q: 'e, E>(
                $self_,
                query: E,
            ) -> futures::future::BoxFuture<
                'e,
                Result<Vec<<DB as sqlx::Database>::Row>, sqlx::Error>,
            >
            where
                E: 'q + sqlx::Execute<'q, DB>,
                'c: 'e,
            {
                let sql = query.sql();
                let state = $self_.state.clone();
                let (cx, start, metric_attrs) =
                    begin_query_span(&state.attrs, Some(sql), $ann);
                let fut = ($inner).fetch_all(query);
                Box::pin(async move {
                    let result = fut.await;
                    match &result {
                        Ok(rows) => {
                            let count = rows.len() as u64;
                            record_rows(&cx, count);
                            finish(&cx, start, Some(count), &state.metrics, &metric_attrs);
                        }
                        Err(err) => {
                            record_error(&cx, err);
                            finish(&cx, start, None, &state.metrics, &metric_attrs);
                        }
                    }
                    result
                })
            }

            /// Execute the query and returns exactly one row.
            fn fetch_one<'e, 'q: 'e, E>(
                $self_,
                query: E,
            ) -> futures::future::BoxFuture<
                'e,
                Result<<DB as sqlx::Database>::Row, sqlx::Error>,
            >
            where
                E: 'q + sqlx::Execute<'q, DB>,
                'c: 'e,
            {
                let sql = query.sql();
                let state = $self_.state.clone();
                let (cx, start, metric_attrs) =
                    begin_query_span(&state.attrs, Some(sql), $ann);
                let fut = ($inner).fetch_one(query);
                Box::pin(async move {
                    let result = fut.await;
                    match &result {
                        Ok(_) => {
                            record_rows(&cx, 1);
                            finish(&cx, start, Some(1), &state.metrics, &metric_attrs);
                        }
                        Err(err) => {
                            record_error(&cx, err);
                            finish(&cx, start, None, &state.metrics, &metric_attrs);
                        }
                    }
                    result
                })
            }

            /// Execute the query and returns at most one row.
            fn fetch_optional<'e, 'q: 'e, E>(
                $self_,
                query: E,
            ) -> futures::future::BoxFuture<
                'e,
                Result<Option<<DB as sqlx::Database>::Row>, sqlx::Error>,
            >
            where
                E: 'q + sqlx::Execute<'q, DB>,
                'c: 'e,
            {
                let sql = query.sql();
                let state = $self_.state.clone();
                let (cx, start, metric_attrs) =
                    begin_query_span(&state.attrs, Some(sql), $ann);
                let fut = ($inner).fetch_optional(query);
                Box::pin(async move {
                    let result = fut.await;
                    match &result {
                        Ok(maybe_row) => {
                            let count = u64::from(maybe_row.is_some());
                            record_rows(&cx, count);
                            finish(&cx, start, Some(count), &state.metrics, &metric_attrs);
                        }
                        Err(err) => {
                            record_error(&cx, err);
                            finish(&cx, start, None, &state.metrics, &metric_attrs);
                        }
                    }
                    result
                })
            }

            /// Prepare the SQL query to inspect the type information of its parameters
            /// and results.
            ///
            /// Be advised that when using the `query`, `query_as`, or `query_scalar`
            /// functions, the query is transparently prepared and executed.
            ///
            /// This explicit API is provided to allow access to the statement metadata
            /// available after it prepared but before the first row is returned.
            fn prepare<'e, 'q: 'e>(
                $self_,
                query: &'q str,
            ) -> futures::future::BoxFuture<
                'e,
                Result<<DB as sqlx::Database>::Statement<'q>, sqlx::Error>,
            >
            where
                'c: 'e,
            {
                let state = $self_.state.clone();
                let (cx, start, metric_attrs) = begin_query_span(&state.attrs, Some(query), $ann);
                let fut = ($inner).prepare(query);
                Box::pin(execute_instrumented(
                    fut, cx, start, state.metrics, metric_attrs,
                ))
            }

            /// Prepare the SQL query, with parameter type information, to inspect the
            /// type information about its parameters and results.
            ///
            /// Only some database drivers (Postgres, MSSQL) can take advantage of
            /// this extra information to influence parameter type inference.
            fn prepare_with<'e, 'q: 'e>(
                $self_,
                sql: &'q str,
                parameters: &'e [<DB as sqlx::Database>::TypeInfo],
            ) -> futures::future::BoxFuture<
                'e,
                Result<<DB as sqlx::Database>::Statement<'q>, sqlx::Error>,
            >
            where
                'c: 'e,
            {
                let state = $self_.state.clone();
                let (cx, start, metric_attrs) = begin_query_span(&state.attrs, Some(sql), $ann);
                let fut = ($inner).prepare_with(sql, parameters);
                Box::pin(execute_instrumented(
                    fut, cx, start, state.metrics, metric_attrs,
                ))
            }

            /// Describe the SQL query and return type information about its parameters
            /// and results.
            ///
            /// This is used by compile-time verification in the query macros to
            /// power their type inference.
            #[doc(hidden)]
            fn describe<'e, 'q: 'e>(
                $self_,
                sql: &'q str,
            ) -> futures::future::BoxFuture<
                'e,
                Result<sqlx::Describe<DB>, sqlx::Error>,
            >
            where
                'c: 'e,
            {
                let state = $self_.state.clone();
                let (cx, start, metric_attrs) = begin_query_span(&state.attrs, Some(sql), $ann);
                let fut = ($inner).describe(sql);
                Box::pin(execute_instrumented(
                    fut, cx, start, state.metrics, metric_attrs,
                ))
            }
        }
    };
}
636
637// ---------------------------------------------------------------------------
638// Executor impls for each wrapper type
639// ---------------------------------------------------------------------------
640
// Plain wrappers – the two-argument macro form, which passes `None` as the annotations.
// Each invocation names the wrapper type and the expression that yields the inner executor.
impl_executor!(&'_ crate::Pool<DB>, self => &self.inner);
impl_executor!(&'c mut crate::PoolConnection<DB>, self => self.inner.as_mut());
impl_executor!(&'c mut crate::Transaction<'_, DB>, self => &mut *self.inner);

// Annotated wrappers – same instrumentation with per-query annotations threaded through.
// The inner executor is reached through the wrapped value (`self.inner.inner`), and the
// wrapper's own `annotations` field is passed by reference.
impl_executor!(
    crate::annotations::Annotated<'c, crate::Pool<DB>>,
    self => &self.inner.inner,
    annotations: Some(&self.annotations)
);
impl_executor!(
    crate::annotations::AnnotatedMut<'c, crate::PoolConnection<DB>>,
    self => self.inner.inner.as_mut(),
    annotations: Some(&self.annotations)
);
impl_executor!(
    crate::annotations::AnnotatedMut<'c, crate::Transaction<'_, DB>>,
    self => &mut *self.inner.inner,
    annotations: Some(&self.annotations)
);
661
662#[cfg(test)]
663mod tests {
664    use super::*;
665    use crate::attributes::ConnectionAttributes;
666
667    #[test]
668    fn error_type_classification() {
669        // Unit variants.
670        assert_eq!(error_type(&sqlx::Error::RowNotFound), "RowNotFound");
671        assert_eq!(error_type(&sqlx::Error::PoolTimedOut), "PoolTimedOut");
672        assert_eq!(error_type(&sqlx::Error::PoolClosed), "PoolClosed");
673        assert_eq!(error_type(&sqlx::Error::WorkerCrashed), "WorkerCrashed");
674
675        // String / boxed-error variants.
676        assert_eq!(
677            error_type(&sqlx::Error::Configuration("bad".into())),
678            "Configuration"
679        );
680        assert_eq!(
681            error_type(&sqlx::Error::Io(std::io::Error::other("test"))),
682            "Io"
683        );
684        assert_eq!(error_type(&sqlx::Error::Tls("tls".into())), "Tls");
685        assert_eq!(
686            error_type(&sqlx::Error::Protocol("proto".into())),
687            "Protocol"
688        );
689        assert_eq!(error_type(&sqlx::Error::Decode("dec".into())), "Decode");
690        assert_eq!(
691            error_type(&sqlx::Error::AnyDriverError("any".into())),
692            "AnyDriverError"
693        );
694
695        // Struct variants.
696        assert_eq!(
697            error_type(&sqlx::Error::ColumnNotFound("x".into())),
698            "ColumnNotFound"
699        );
700        assert_eq!(
701            error_type(&sqlx::Error::ColumnIndexOutOfBounds { index: 5, len: 3 }),
702            "ColumnIndexOutOfBounds"
703        );
704        assert_eq!(
705            error_type(&sqlx::Error::ColumnDecode {
706                index: "0".into(),
707                source: "bad".into(),
708            }),
709            "ColumnDecode"
710        );
711        assert_eq!(
712            error_type(&sqlx::Error::TypeNotFound {
713                type_name: "Foo".into(),
714            }),
715            "TypeNotFound"
716        );
717
718        // Migrate variant (behind sqlx's "migrate" default feature).
719        assert_eq!(
720            error_type(&sqlx::Error::Migrate(Box::new(
721                sqlx::migrate::MigrateError::Execute(sqlx::Error::Protocol("test".into()))
722            ))),
723            "Migrate"
724        );
725
726        // The `_ => "Unknown"` branch covers future sqlx::Error variants that may be
727        // added in newer sqlx releases. It cannot be tested directly since we cannot
728        // construct an unknown variant, but it ensures forward compatibility.
729    }
730
731    fn test_attrs() -> ConnectionAttributes {
732        ConnectionAttributes {
733            system: "postgresql",
734            host: Some("localhost".into()),
735            port: Some(5432),
736            namespace: Some("mydb".into()),
737            network_peer_address: None,
738            network_peer_port: None,
739            query_text_mode: QueryTextMode::Full,
740        }
741    }
742
743    // ===========================================================================
744    // query text
745    // ===========================================================================
746
747    #[test]
748    fn build_attributes_with_full_query_text() {
749        let attrs = test_attrs();
750        let kv = build_attributes(&attrs, Some("SELECT 1"), None);
751        let keys: Vec<&str> = kv.iter().map(|k| k.key.as_str()).collect();
752        assert!(keys.contains(&"db.query.text"));
753    }
754
755    #[test]
756    fn build_attributes_with_off_query_text() {
757        let mut attrs = test_attrs();
758        attrs.query_text_mode = QueryTextMode::Off;
759        let kv = build_attributes(&attrs, Some("SELECT 1"), None);
760        let keys: Vec<&str> = kv.iter().map(|k| k.key.as_str()).collect();
761        assert!(!keys.contains(&"db.query.text"));
762    }
763
764    #[test]
765    fn build_attributes_obfuscated_replaces_literals() {
766        let mut attrs = test_attrs();
767        attrs.query_text_mode = QueryTextMode::Obfuscated;
768        let kv = build_attributes(
769            &attrs,
770            Some("INSERT INTO t (id, name) VALUES (1, 'alice')"),
771            None,
772        );
773        let text = kv
774            .iter()
775            .find(|k| k.key.as_str() == "db.query.text")
776            .map(|k| k.value.clone());
777        assert_eq!(
778            text,
779            Some(opentelemetry::Value::String(
780                "INSERT INTO t (id, name) VALUES (?, ?)".into()
781            ))
782        );
783    }
784
785    // ===========================================================================
786    // annotations
787    // ===========================================================================
788
789    #[test]
790    fn build_attributes_no_sql_no_annotations() {
791        let attrs = test_attrs();
792        let kv = build_attributes(&attrs, None, None);
793        let keys: Vec<&str> = kv.iter().map(|k| k.key.as_str()).collect();
794        assert!(!keys.contains(&"db.query.text"));
795        assert!(!keys.contains(&"db.operation.name"));
796        assert!(!keys.contains(&"db.collection.name"));
797        assert!(!keys.contains(&"db.query.summary"));
798        assert!(!keys.contains(&"db.stored_procedure.name"));
799        assert!(keys.contains(&"db.system.name"));
800    }
801
802    #[test]
803    fn build_attributes_with_all_annotation_fields() {
804        let attrs = test_attrs();
805        let ann = QueryAnnotations::new()
806            .operation("SELECT")
807            .collection("users")
808            .query_summary("SELECT users")
809            .stored_procedure("sp_get");
810        let kv = build_attributes(&attrs, Some("SELECT * FROM users"), Some(&ann));
811        let find = |key: &str| {
812            kv.iter()
813                .find(|k| k.key.as_str() == key)
814                .map(|k| k.value.clone())
815        };
816        assert_eq!(
817            find("db.operation.name"),
818            Some(opentelemetry::Value::String("SELECT".into()))
819        );
820        assert_eq!(
821            find("db.collection.name"),
822            Some(opentelemetry::Value::String("users".into()))
823        );
824        assert_eq!(
825            find("db.query.summary"),
826            Some(opentelemetry::Value::String("SELECT users".into()))
827        );
828        assert_eq!(
829            find("db.stored_procedure.name"),
830            Some(opentelemetry::Value::String("sp_get".into()))
831        );
832        assert_eq!(
833            find("db.query.text"),
834            Some(opentelemetry::Value::String("SELECT * FROM users".into()))
835        );
836    }
837
838    #[test]
839    fn build_attributes_annotation_field_permutations() {
840        type Setter = fn(QueryAnnotations) -> QueryAnnotations;
841
842        let attrs = test_attrs();
843        let fields: &[(&str, Setter)] = &[
844            ("db.operation.name", |a| a.operation("SELECT")),
845            ("db.collection.name", |a| a.collection("users")),
846            ("db.query.summary", |a| a.query_summary("SELECT users")),
847            ("db.stored_procedure.name", |a| a.stored_procedure("sp")),
848        ];
849
850        // Verify every permutation (2^4 = 16) of the four annotation fields: each field that is
851        // `Some` must appear in the output, and each field that is `None` must be absent.
852        for mask in 0u8..16 {
853            let mut ann = QueryAnnotations::new();
854            for (i, &(_, setter)) in fields.iter().enumerate() {
855                if mask & (1 << i) != 0 {
856                    ann = setter(ann);
857                }
858            }
859            let kv = build_attributes(&attrs, None, Some(&ann));
860            let keys: Vec<&str> = kv.iter().map(|k| k.key.as_str()).collect();
861            for (i, &(key, _)) in fields.iter().enumerate() {
862                println!(
863                    "mask: {:08b}, field: {}, key: {}; contains: {}",
864                    mask,
865                    i,
866                    key,
867                    keys.contains(&key)
868                );
869                if mask & (1 << i) != 0 {
870                    assert!(
871                        keys.contains(&key),
872                        "{key} should be present for mask {mask:#06b}"
873                    );
874                } else {
875                    assert!(
876                        !keys.contains(&key),
877                        "{key} should be absent for mask {mask:#06b}"
878                    );
879                }
880            }
881        }
882    }
883
884    use proptest::prelude::*;
885
886    /// Build a `ConnectionAttributes` from explicit option fields. Used by the proptest
887    /// strategies below so that each generated case exercises an arbitrary subset of the
888    /// optional connection-level fields.
889    fn make_connection_attributes(
890        host: Option<String>,
891        port: Option<u16>,
892        namespace: Option<String>,
893        network_peer_address: Option<String>,
894        network_peer_port: Option<u16>,
895        query_text_mode: QueryTextMode,
896    ) -> ConnectionAttributes {
897        ConnectionAttributes {
898            system: "postgresql",
899            host,
900            port,
901            namespace,
902            network_peer_address,
903            network_peer_port,
904            query_text_mode,
905        }
906    }
907
908    /// Strategy for the three `QueryTextMode` variants.
909    fn any_query_text_mode() -> impl Strategy<Value = QueryTextMode> {
910        prop_oneof![
911            Just(QueryTextMode::Full),
912            Just(QueryTextMode::Obfuscated),
913            Just(QueryTextMode::Off),
914        ]
915    }
916
917    /// Strategy for an arbitrary `QueryAnnotations` whose four fields are independently
918    /// `None` or `Some(s)` for a bounded-length string `s`.
919    fn any_annotations() -> impl Strategy<Value = QueryAnnotations> {
920        (
921            proptest::option::of(".{0,32}"),
922            proptest::option::of(".{0,32}"),
923            proptest::option::of(".{0,32}"),
924            proptest::option::of(".{0,32}"),
925        )
926            .prop_map(|(op, coll, summary, sp)| {
927                let mut ann = QueryAnnotations::new();
928                if let Some(s) = op {
929                    ann = ann.operation(s);
930                }
931                if let Some(s) = coll {
932                    ann = ann.collection(s);
933                }
934                if let Some(s) = summary {
935                    ann = ann.query_summary(s);
936                }
937                if let Some(s) = sp {
938                    ann = ann.stored_procedure(s);
939                }
940                ann
941            })
942    }
943
944    proptest! {
945        #![proptest_config(ProptestConfig::with_cases(128))]
946
947        /// Membership invariant: the keys emitted by `build_attributes` are exactly the
948        /// union of the base connection keys, the four annotation keys (each iff its
949        /// field is `Some`), and `db.query.text` (iff `sql.is_some()` and the mode is
950        /// not `Off`).
951        #[test]
952        fn build_attributes_membership_invariant(
953            host in proptest::option::of("[a-z]{1,16}"),
954            port in proptest::option::of(any::<u16>()),
955            namespace in proptest::option::of("[a-z]{1,16}"),
956            network_peer_address in proptest::option::of("[0-9.:]{1,32}"),
957            network_peer_port in proptest::option::of(any::<u16>()),
958            mode in any_query_text_mode(),
959            sql in proptest::option::of(".{0,64}"),
960            ann in any_annotations(),
961        ) {
962            let attrs = make_connection_attributes(
963                host.clone(), port, namespace.clone(),
964                network_peer_address.clone(), network_peer_port, mode,
965            );
966            let kv = build_attributes(&attrs, sql.as_deref(), Some(&ann));
967            let keys: Vec<&str> = kv.iter().map(|k| k.key.as_str()).collect();
968
969            // `db.system.name` is always present.
970            prop_assert!(keys.contains(&"db.system.name"));
971
972            // Optional connection keys appear iff their field is `Some`.
973            prop_assert_eq!(keys.contains(&"server.address"), host.is_some());
974            prop_assert_eq!(keys.contains(&"server.port"), port.is_some());
975            prop_assert_eq!(keys.contains(&"db.namespace"), namespace.is_some());
976            prop_assert_eq!(keys.contains(&"network.peer.address"), network_peer_address.is_some());
977            prop_assert_eq!(keys.contains(&"network.peer.port"), network_peer_port.is_some());
978
979            // Annotation keys appear iff their field is `Some`.
980            prop_assert_eq!(keys.contains(&"db.operation.name"), ann.operation.is_some());
981            prop_assert_eq!(keys.contains(&"db.collection.name"), ann.collection.is_some());
982            prop_assert_eq!(keys.contains(&"db.query.summary"), ann.query_summary.is_some());
983            prop_assert_eq!(keys.contains(&"db.stored_procedure.name"), ann.stored_procedure.is_some());
984
985            // `db.query.text` is emitted iff sql is provided and mode is not Off.
986            let expect_query_text = sql.is_some() && mode != QueryTextMode::Off;
987            prop_assert_eq!(keys.contains(&"db.query.text"), expect_query_text);
988        }
989
990        /// No key appears more than once in the emitted attribute list. Duplicate keys
991        /// would cause downstream OTel exporters to emit conflicting tag values.
992        #[test]
993        fn build_attributes_has_no_duplicate_keys(
994            host in proptest::option::of("[a-z]{1,16}"),
995            port in proptest::option::of(any::<u16>()),
996            namespace in proptest::option::of("[a-z]{1,16}"),
997            mode in any_query_text_mode(),
998            sql in proptest::option::of(".{0,64}"),
999            ann in any_annotations(),
1000        ) {
1001            let attrs = make_connection_attributes(host, port, namespace, None, None, mode);
1002            let kv = build_attributes(&attrs, sql.as_deref(), Some(&ann));
1003            let mut seen = std::collections::HashSet::new();
1004            for k in &kv {
1005                prop_assert!(
1006                    seen.insert(k.key.as_str().to_owned()),
1007                    "duplicate key in build_attributes output: {}",
1008                    k.key.as_str(),
1009                );
1010            }
1011        }
1012
1013        /// `build_attributes` does not panic on arbitrary unicode SQL across all three
1014        /// query-text modes, including the obfuscated path that delegates into
1015        /// `obfuscate::obfuscate`.
1016        #[test]
1017        fn build_attributes_no_panic_arbitrary_sql(
1018            sql in proptest::option::of(any::<String>()),
1019            mode in any_query_text_mode(),
1020            ann in any_annotations(),
1021        ) {
1022            let attrs = make_connection_attributes(None, None, None, None, None, mode);
1023            let _ = build_attributes(&attrs, sql.as_deref(), Some(&ann));
1024        }
1025
1026        /// When `annotations` is `None`, no annotation keys appear in the output
1027        /// regardless of any other input – the `if let Some(ann)` guard short-circuits
1028        /// the entire annotation-emission block.
1029        #[test]
1030        fn build_attributes_no_annotations_emits_no_annotation_keys(
1031            mode in any_query_text_mode(),
1032            sql in proptest::option::of(".{0,64}"),
1033        ) {
1034            let attrs = make_connection_attributes(None, None, None, None, None, mode);
1035            let kv = build_attributes(&attrs, sql.as_deref(), None);
1036            let keys: Vec<&str> = kv.iter().map(|k| k.key.as_str()).collect();
1037            prop_assert!(!keys.contains(&"db.operation.name"));
1038            prop_assert!(!keys.contains(&"db.collection.name"));
1039            prop_assert!(!keys.contains(&"db.query.summary"));
1040            prop_assert!(!keys.contains(&"db.stored_procedure.name"));
1041        }
1042    }
1043}