Skip to main content

sqlx_otel/
pool.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use opentelemetry_semantic_conventions::metric as semconv_metric;
5
6use crate::annotations::{Annotated, QueryAnnotations};
7use crate::attributes::{ConnectionAttributes, QueryTextMode};
8use crate::connection::PoolConnection;
9use crate::database::Database;
10use crate::metrics::Metrics;
11use crate::transaction::Transaction;
12
13/// Shared state propagated to every wrapper type derived from a pool.
14#[derive(Debug, Clone)]
15pub(crate) struct SharedState {
16    pub attrs: Arc<ConnectionAttributes>,
17    pub metrics: Arc<Metrics>,
18}
19
20/// Builder for constructing an instrumented [`Pool`] from a raw `sqlx::Pool`.
21///
22/// The builder auto-extracts connection attributes (host, port, database namespace) from
23/// the underlying connect options via the [`Database`] trait, then lets you override any of
24/// them before calling [`build`](Self::build). Settings on the wrapped `sqlx::Pool` itself
25/// (max connections, idle timeout, etc.) should be applied to the `sqlx::Pool` *before*
26/// passing it to the builder – `sqlx-otel` does not duplicate `SQLx`'s configuration
27/// surface.
28///
29/// # Example
30///
31/// ```no_run
32/// # #[cfg(feature = "sqlite")]
33/// # async fn _doc() -> Result<(), sqlx::Error> {
34/// use sqlx_otel::{PoolBuilder, QueryTextMode};
35/// use std::time::Duration;
36///
37/// let raw = sqlx::SqlitePool::connect(":memory:").await?;
38/// let pool = PoolBuilder::from(raw)
39///     .with_database("my_db")
40///     .with_query_text_mode(QueryTextMode::Obfuscated)
41///     .with_pool_name("my-service-db")
42///     .with_pool_metrics_interval(Duration::from_secs(5))
43///     .build();
44/// # let _ = pool;
45/// # Ok(())
46/// # }
47/// ```
48#[derive(Debug)]
49pub struct PoolBuilder<DB: sqlx::Database> {
50    pool: sqlx::Pool<DB>,
51    host: Option<String>,
52    port: Option<u16>,
53    namespace: Option<String>,
54    network_peer_address: Option<String>,
55    network_peer_port: Option<u16>,
56    network_protocol_name: Option<String>,
57    network_transport: Option<String>,
58    query_text_mode: QueryTextMode,
59    pool_name: Option<String>,
60    pool_metrics_interval: Duration,
61}
62
63impl<DB: Database> From<sqlx::Pool<DB>> for PoolBuilder<DB> {
64    /// Create a builder from an existing `sqlx::Pool`, auto-extracting connection
65    /// attributes from the backend's connect options. `network.protocol.name` is
66    /// pre-populated from [`Database::DEFAULT_NETWORK_PROTOCOL_NAME`] (the wire protocol
67    /// for Postgres / `MySQL`; absent for `SQLite`); override via
68    /// [`with_network_protocol_name`](Self::with_network_protocol_name).
69    fn from(pool: sqlx::Pool<DB>) -> Self {
70        let (host, port, namespace) = DB::connection_attributes(&pool);
71        Self {
72            pool,
73            host,
74            port,
75            namespace,
76            network_peer_address: None,
77            network_peer_port: None,
78            network_protocol_name: DB::DEFAULT_NETWORK_PROTOCOL_NAME.map(String::from),
79            network_transport: None,
80            query_text_mode: QueryTextMode::default(),
81            pool_name: None,
82            pool_metrics_interval: Duration::from_secs(10),
83        }
84    }
85}
86
87impl<DB: Database> PoolBuilder<DB> {
88    /// Override the `db.namespace` attribute (the database name).
89    #[must_use]
90    pub fn with_database(mut self, database: impl Into<String>) -> Self {
91        self.namespace = Some(database.into());
92        self
93    }
94
95    /// Override the `server.address` attribute (the logical hostname).
96    #[must_use]
97    pub fn with_host(mut self, host: impl Into<String>) -> Self {
98        self.host = Some(host.into());
99        self
100    }
101
102    /// Override the `server.port` attribute.
103    #[must_use]
104    pub fn with_port(mut self, port: u16) -> Self {
105        self.port = Some(port);
106        self
107    }
108
109    /// Set the `network.peer.address` attribute (the resolved IP address).
110    #[must_use]
111    pub fn with_network_peer_address(mut self, address: impl Into<String>) -> Self {
112        self.network_peer_address = Some(address.into());
113        self
114    }
115
116    /// Set the `network.peer.port` attribute (the resolved port).
117    #[must_use]
118    pub fn with_network_peer_port(mut self, port: u16) -> Self {
119        self.network_peer_port = Some(port);
120        self
121    }
122
123    /// Override the `network.protocol.name` attribute. Defaults to the backend's wire
124    /// protocol via [`Database::DEFAULT_NETWORK_PROTOCOL_NAME`] (`"postgresql"` /
125    /// `"mysql"`; absent for `SQLite`). Override when the connection is tunnelled through
126    /// a different application-layer protocol or when reporting to a system that expects a
127    /// specific name.
128    #[must_use]
129    pub fn with_network_protocol_name(mut self, name: impl Into<String>) -> Self {
130        self.network_protocol_name = Some(name.into());
131        self
132    }
133
134    /// Set the `network.transport` attribute (the OSI L4 transport: `"tcp"`, `"udp"`,
135    /// `"pipe"`, `"unix"`, `"inproc"`). The wrapper does not infer transport from the
136    /// connect string – callers who want this attribute on spans / metrics must set it
137    /// explicitly so the value reflects the deployment configuration rather than a guess.
138    #[must_use]
139    pub fn with_network_transport(mut self, transport: impl Into<String>) -> Self {
140        self.network_transport = Some(transport.into());
141        self
142    }
143
144    /// Configure how `db.query.text` is captured on spans. Defaults to
145    /// [`QueryTextMode::Full`].
146    #[must_use]
147    pub fn with_query_text_mode(mut self, mode: QueryTextMode) -> Self {
148        self.query_text_mode = mode;
149        self
150    }
151
152    /// Set the `db.client.connection.pool.name` attribute and enable the
153    /// `db.client.connection.count` polling task.
154    ///
155    /// When a runtime feature (`runtime-tokio` or `runtime-async-std`) is also enabled, a
156    /// background task is spawned that periodically records `db.client.connection.count`
157    /// (idle / used). See [`with_pool_metrics_interval`](Self::with_pool_metrics_interval)
158    /// to configure the polling frequency. The task is cancelled when the [`Pool`] (and
159    /// every clone of it) is dropped.
160    ///
161    /// **Without a runtime feature, the name is recorded but no `connection.count` task is
162    /// spawned and the gauge is never reported.** All other operation- and pool-level
163    /// metrics still work in that configuration.
164    #[must_use]
165    pub fn with_pool_name(mut self, name: impl Into<String>) -> Self {
166        self.pool_name = Some(name.into());
167        self
168    }
169
170    /// Set the polling interval for `db.client.connection.count`. Defaults to 10 seconds.
171    ///
172    /// Has no effect unless [`with_pool_name`](Self::with_pool_name) is also called and a
173    /// runtime feature is enabled.
174    #[must_use]
175    pub fn with_pool_metrics_interval(mut self, interval: Duration) -> Self {
176        self.pool_metrics_interval = interval;
177        self
178    }
179
180    /// Consume the builder and produce an instrumented [`Pool`].
181    ///
182    /// At this point the static pool gauges (`db.client.connection.max`,
183    /// `db.client.connection.idle.max`, `db.client.connection.idle.min`) are recorded
184    /// once with the connection-level attributes – they do not change over the pool's
185    /// lifetime. The wait-time / use-time / timeout / pending-request instruments are
186    /// created here and updated inline on every `acquire()` and connection drop.
187    #[must_use]
188    pub fn build(self) -> Pool<DB> {
189        let metrics_shutdown = self.spawn_pool_metrics_task();
190
191        let attrs = Arc::new(ConnectionAttributes {
192            system: DB::SYSTEM,
193            host: self.host,
194            port: self.port,
195            namespace: self.namespace,
196            network_peer_address: self.network_peer_address,
197            network_peer_port: self.network_peer_port,
198            network_protocol_name: self.network_protocol_name,
199            network_transport: self.network_transport,
200            pool_name: self.pool_name,
201            query_text_mode: self.query_text_mode,
202        });
203        let metrics = Arc::new(Metrics::new());
204        let meter = opentelemetry::global::meter("sqlx-otel");
205
206        // Record static pool configuration gauges once – these never change.
207        let max_conns = i64::from(self.pool.options().get_max_connections());
208        let min_conns = i64::from(self.pool.options().get_min_connections());
209        let base_attrs = attrs.base_key_values();
210
211        meter
212            .i64_gauge(semconv_metric::DB_CLIENT_CONNECTION_MAX)
213            .with_description("The maximum number of open connections allowed.")
214            .build()
215            .record(max_conns, &base_attrs);
216        meter
217            .i64_gauge(semconv_metric::DB_CLIENT_CONNECTION_IDLE_MAX)
218            .with_description("The maximum number of idle open connections allowed.")
219            .build()
220            .record(max_conns, &base_attrs);
221        meter
222            .i64_gauge(semconv_metric::DB_CLIENT_CONNECTION_IDLE_MIN)
223            .with_description("The minimum number of idle open connections allowed.")
224            .build()
225            .record(min_conns, &base_attrs);
226
227        Pool {
228            inner: self.pool,
229            state: SharedState { attrs, metrics },
230            metrics_shutdown,
231            wait_time: Arc::new(
232                meter
233                    .f64_histogram(semconv_metric::DB_CLIENT_CONNECTION_WAIT_TIME)
234                    .with_unit("s")
235                    .with_description(
236                        "The time it took to obtain an open connection from the pool.",
237                    )
238                    .build(),
239            ),
240            use_time: Arc::new(
241                meter
242                    .f64_histogram(semconv_metric::DB_CLIENT_CONNECTION_USE_TIME)
243                    .with_unit("s")
244                    .with_description(
245                        "The time between borrowing a connection and returning it to the pool.",
246                    )
247                    .build(),
248            ),
249            timeouts: Arc::new(
250                meter
251                    .u64_counter(semconv_metric::DB_CLIENT_CONNECTION_TIMEOUTS)
252                    .with_description(
253                        "The number of connection pool acquire attempts that timed out.",
254                    )
255                    .build(),
256            ),
257            pending_requests: Arc::new(
258                meter
259                    .i64_up_down_counter(semconv_metric::DB_CLIENT_CONNECTION_PENDING_REQUESTS)
260                    .with_description("The number of pending requests for an open connection.")
261                    .build(),
262            ),
263        }
264    }
265
266    /// Spawn the pool metrics background task if a pool name is set and a runtime is
267    /// available. Returns the shutdown handle (or `None`).
268    fn spawn_pool_metrics_task(&self) -> Option<crate::pool_metrics::ShutdownHandle> {
269        let name = self.pool_name.as_ref()?;
270
271        // Prefer tokio if both runtimes are enabled.
272        #[cfg(feature = "runtime-tokio")]
273        {
274            Some(
275                crate::pool_metrics::spawn::<crate::runtime::TokioRuntime, DB>(
276                    self.pool.clone(),
277                    name.clone(),
278                    self.pool_metrics_interval,
279                ),
280            )
281        }
282
283        #[cfg(all(feature = "runtime-async-std", not(feature = "runtime-tokio")))]
284        {
285            Some(crate::pool_metrics::spawn::<
286                crate::runtime::AsyncStdRuntime,
287                DB,
288            >(
289                self.pool.clone(),
290                name.clone(),
291                self.pool_metrics_interval,
292            ))
293        }
294
295        #[cfg(not(any(feature = "runtime-tokio", feature = "runtime-async-std")))]
296        {
297            let _ = name;
298            None
299        }
300    }
301}
302
303/// An instrumented wrapper around `sqlx::Pool` that emits OpenTelemetry spans and metrics
304/// for every database operation.
305///
306/// Create one via [`PoolBuilder`]. The wrapper is a drop-in replacement for `sqlx::Pool`:
307/// `&Pool<DB>` implements [`sqlx::Executor`], so you can pass it straight into
308/// `sqlx::query(...)`, `sqlx::query_as(...)`, and friends. Connections acquired via
309/// [`acquire`](Self::acquire) and transactions started via [`begin`](Self::begin) inherit
310/// the same instrumentation and produce spans / metrics with identical connection-level
311/// attributes.
312///
313/// `Clone` is cheap – the inner `sqlx::Pool`, the connection-level attribute set, and the
314/// metric instruments are all `Arc`-shared. Cloning never copies state; cloned pools share
315/// the same underlying connection pool and metric stream.
316///
317/// # Example
318///
319/// ```no_run
320/// # #[cfg(feature = "sqlite")]
321/// # async fn _doc() -> Result<(), sqlx::Error> {
322/// use sqlx_otel::PoolBuilder;
323///
324/// let raw = sqlx::SqlitePool::connect(":memory:").await?;
325/// let pool = PoolBuilder::from(raw).build();
326///
327/// // Pass `&pool` anywhere a `sqlx::Executor` is expected.
328/// let row: (i64,) = sqlx::query_as("SELECT 1").fetch_one(&pool).await?;
329/// assert_eq!(row.0, 1);
330/// # Ok(())
331/// # }
332/// ```
333///
334/// See also [`with_annotations`](Self::with_annotations) for per-query semantic-convention
335/// attributes, and [`crate::QueryAnnotateExt`] for attaching annotations on the query side
336/// instead of the executor side.
337#[derive(Debug)]
338pub struct Pool<DB: sqlx::Database> {
339    pub(crate) inner: sqlx::Pool<DB>,
340    pub(crate) state: SharedState,
341    /// Dropping this handle signals the background polling task to stop.
342    metrics_shutdown: Option<crate::pool_metrics::ShutdownHandle>,
343    /// Histogram for `db.client.connection.wait_time`, recorded on each `acquire()`.
344    wait_time: Arc<opentelemetry::metrics::Histogram<f64>>,
345    /// Histogram for `db.client.connection.use_time`, recorded when a connection is dropped.
346    pub(crate) use_time: Arc<opentelemetry::metrics::Histogram<f64>>,
347    /// Counter for `db.client.connection.timeouts`, incremented on `PoolTimedOut`.
348    timeouts: Arc<opentelemetry::metrics::Counter<u64>>,
349    /// Up/down counter for `db.client.connection.pending_requests`, tracks callers
350    /// currently waiting in `acquire()`.
351    pending_requests: Arc<opentelemetry::metrics::UpDownCounter<i64>>,
352}
353
354impl<DB: sqlx::Database> Clone for Pool<DB> {
355    fn clone(&self) -> Self {
356        Self {
357            inner: self.inner.clone(),
358            state: self.state.clone(),
359            metrics_shutdown: self.metrics_shutdown.clone(),
360            wait_time: self.wait_time.clone(),
361            use_time: self.use_time.clone(),
362            timeouts: self.timeouts.clone(),
363            pending_requests: self.pending_requests.clone(),
364        }
365    }
366}
367
368impl<DB: Database> Pool<DB> {
369    /// Acquire a pooled connection instrumented for OpenTelemetry.
370    ///
371    /// Records `db.client.connection.wait_time` (time spent waiting for a connection),
372    /// tracks `db.client.connection.pending_requests` while the call is in flight, and
373    /// increments `db.client.connection.timeouts` on `sqlx::Error::PoolTimedOut`. The
374    /// returned [`PoolConnection`] records `db.client.connection.use_time` when dropped
375    /// and is itself an [`sqlx::Executor`] via `&mut conn`.
376    ///
377    /// # Errors
378    ///
379    /// Returns `sqlx::Error` if a connection cannot be obtained from the pool – typically
380    /// `PoolTimedOut` when the configured acquire timeout elapses, or `PoolClosed` after
381    /// [`close`](Self::close).
382    pub async fn acquire(&self) -> Result<PoolConnection<DB>, sqlx::Error> {
383        let attrs = self.state.attrs.base_key_values();
384        self.pending_requests.add(1, &attrs);
385        let start = std::time::Instant::now();
386        let result = self.inner.acquire().await;
387        self.pending_requests.add(-1, &attrs);
388        self.wait_time.record(start.elapsed().as_secs_f64(), &attrs);
389
390        if let Err(sqlx::Error::PoolTimedOut) = &result {
391            self.timeouts.add(1, &attrs);
392        }
393
394        result.map(|inner| PoolConnection {
395            inner,
396            state: self.state.clone(),
397            use_time: self.use_time.clone(),
398            acquired_at: std::time::Instant::now(),
399            base_attrs: attrs,
400        })
401    }
402
403    /// Begin a new transaction instrumented for OpenTelemetry.
404    ///
405    /// The returned [`Transaction`] implements `sqlx::Executor` via `&mut tx` and emits
406    /// the same per-operation spans and metrics as the pool itself. Call
407    /// [`commit`](Transaction::commit) or [`rollback`](Transaction::rollback) to terminate
408    /// it; dropping the value without doing either rolls back implicitly (per `SQLx`'s
409    /// usual behaviour).
410    ///
411    /// # Errors
412    ///
413    /// Returns `sqlx::Error` if `BEGIN` fails – typically due to a connection problem or
414    /// because the underlying connection cannot start a new transaction.
415    pub async fn begin(&self) -> Result<Transaction<'_, DB>, sqlx::Error> {
416        self.inner.begin().await.map(|inner| Transaction {
417            inner,
418            state: self.state.clone(),
419        })
420    }
421
422    /// Shut down the pool, waiting for all connections to be released.
423    pub async fn close(&self) {
424        self.inner.close().await;
425    }
426
427    /// Returns `true` if the pool has been closed.
428    #[must_use]
429    pub fn is_closed(&self) -> bool {
430        self.inner.is_closed()
431    }
432
433    /// Return an annotated executor that attaches per-query semantic-convention attributes
434    /// (`db.operation.name`, `db.collection.name`, `db.query.summary`,
435    /// `db.stored_procedure.name`) to every span created by the next operation.
436    ///
437    /// The returned wrapper borrows the pool and implements `sqlx::Executor`. Use the
438    /// query-side equivalent ([`crate::QueryAnnotateExt`]) when the annotation belongs
439    /// next to the query text rather than next to the executor.
440    ///
441    /// # Example
442    ///
443    /// ```no_run
444    /// # #[cfg(feature = "sqlite")]
445    /// # async fn _doc() -> Result<(), sqlx::Error> {
446    /// # use sqlx_otel::PoolBuilder;
447    /// use sqlx::Executor as _;
448    /// use sqlx_otel::QueryAnnotations;
449    /// # let pool = PoolBuilder::from(sqlx::SqlitePool::connect(":memory:").await?).build();
450    ///
451    /// pool.with_annotations(
452    ///     QueryAnnotations::new()
453    ///         .operation("SELECT")
454    ///         .collection("users"),
455    /// )
456    /// .fetch_all("SELECT * FROM users")
457    /// .await?;
458    /// # Ok(())
459    /// # }
460    /// ```
461    #[must_use]
462    pub fn with_annotations(&self, annotations: QueryAnnotations) -> Annotated<'_, Self> {
463        Annotated {
464            inner: self,
465            annotations,
466            state: self.state.clone(),
467        }
468    }
469
470    /// Shorthand for annotating the next operation with `db.operation.name` and
471    /// `db.collection.name`.
472    ///
473    /// Equivalent to `self.with_annotations(QueryAnnotations::new().operation(op).collection(coll))`.
474    #[must_use]
475    pub fn with_operation(
476        &self,
477        operation: impl Into<String>,
478        collection: impl Into<String>,
479    ) -> Annotated<'_, Self> {
480        self.with_annotations(
481            QueryAnnotations::new()
482                .operation(operation)
483                .collection(collection),
484        )
485    }
486}