// sqlx_otel/pool.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use opentelemetry_semantic_conventions::metric as semconv_metric;
5
6use crate::annotations::{Annotated, QueryAnnotations};
7use crate::attributes::{ConnectionAttributes, QueryTextMode};
8use crate::connection::PoolConnection;
9use crate::database::Database;
10use crate::metrics::Metrics;
11use crate::transaction::Transaction;
12
13/// Shared state propagated to every wrapper type derived from a pool.
14#[derive(Debug, Clone)]
15pub(crate) struct SharedState {
16    pub attrs: Arc<ConnectionAttributes>,
17    pub metrics: Arc<Metrics>,
18}
19
20/// Builder for constructing an instrumented [`Pool`] from a raw `sqlx::Pool`.
21///
22/// The builder auto-extracts connection attributes (host, port, namespace) from the
23/// pool's connect options via the [`Database`] trait, then allows overriding any of them
24/// before calling [`build()`](Self::build).
25///
26/// # Example
27///
28/// ```ignore
29/// let pool = PoolBuilder::from(sqlx_pool)
30///     .with_database("my_db")
31///     .build();
32/// ```
33#[derive(Debug)]
34pub struct PoolBuilder<DB: sqlx::Database> {
35    pool: sqlx::Pool<DB>,
36    host: Option<String>,
37    port: Option<u16>,
38    namespace: Option<String>,
39    network_peer_address: Option<String>,
40    network_peer_port: Option<u16>,
41    query_text_mode: QueryTextMode,
42    pool_name: Option<String>,
43    pool_metrics_interval: Duration,
44}
45
46impl<DB: Database> From<sqlx::Pool<DB>> for PoolBuilder<DB> {
47    /// Create a builder from an existing `sqlx::Pool`, auto-extracting connection
48    /// attributes from the backend's connect options.
49    fn from(pool: sqlx::Pool<DB>) -> Self {
50        let (host, port, namespace) = DB::connection_attributes(&pool);
51        Self {
52            pool,
53            host,
54            port,
55            namespace,
56            network_peer_address: None,
57            network_peer_port: None,
58            query_text_mode: QueryTextMode::default(),
59            pool_name: None,
60            pool_metrics_interval: Duration::from_secs(10),
61        }
62    }
63}
64
65impl<DB: Database> PoolBuilder<DB> {
66    /// Override the `db.namespace` attribute (the database name).
67    #[must_use]
68    pub fn with_database(mut self, database: impl Into<String>) -> Self {
69        self.namespace = Some(database.into());
70        self
71    }
72
73    /// Override the `server.address` attribute (the logical hostname).
74    #[must_use]
75    pub fn with_host(mut self, host: impl Into<String>) -> Self {
76        self.host = Some(host.into());
77        self
78    }
79
80    /// Override the `server.port` attribute.
81    #[must_use]
82    pub fn with_port(mut self, port: u16) -> Self {
83        self.port = Some(port);
84        self
85    }
86
87    /// Set the `network.peer.address` attribute (the resolved IP address).
88    #[must_use]
89    pub fn with_network_peer_address(mut self, address: impl Into<String>) -> Self {
90        self.network_peer_address = Some(address.into());
91        self
92    }
93
94    /// Set the `network.peer.port` attribute (the resolved port).
95    #[must_use]
96    pub fn with_network_peer_port(mut self, port: u16) -> Self {
97        self.network_peer_port = Some(port);
98        self
99    }
100
101    /// Configure how `db.query.text` is captured on spans. Defaults to
102    /// [`QueryTextMode::Full`].
103    #[must_use]
104    pub fn with_query_text_mode(mut self, mode: QueryTextMode) -> Self {
105        self.query_text_mode = mode;
106        self
107    }
108
109    /// Set the `db.client.connection.pool.name` attribute.
110    ///
111    /// When a runtime feature (e.g. `runtime-tokio`) is also enabled, a background task is
112    /// spawned that periodically records `db.client.connection.count` (idle/used). See
113    /// [`with_pool_metrics_interval`](Self::with_pool_metrics_interval) to configure the
114    /// polling frequency.
115    #[must_use]
116    pub fn with_pool_name(mut self, name: impl Into<String>) -> Self {
117        self.pool_name = Some(name.into());
118        self
119    }
120
121    /// Set the polling interval for `db.client.connection.count`. Defaults to 10 seconds.
122    ///
123    /// Has no effect unless [`with_pool_name`](Self::with_pool_name) is also called and a
124    /// runtime feature is enabled.
125    #[must_use]
126    pub fn with_pool_metrics_interval(mut self, interval: Duration) -> Self {
127        self.pool_metrics_interval = interval;
128        self
129    }
130
131    /// Consume the builder and produce an instrumented [`Pool`].
132    #[must_use]
133    pub fn build(self) -> Pool<DB> {
134        let metrics_shutdown = self.spawn_pool_metrics_task();
135
136        let attrs = Arc::new(ConnectionAttributes {
137            system: DB::SYSTEM,
138            host: self.host,
139            port: self.port,
140            namespace: self.namespace,
141            network_peer_address: self.network_peer_address,
142            network_peer_port: self.network_peer_port,
143            query_text_mode: self.query_text_mode,
144        });
145        let metrics = Arc::new(Metrics::new());
146        let meter = opentelemetry::global::meter("sqlx-otel");
147
148        // Record static pool configuration gauges once – these never change.
149        let max_conns = i64::from(self.pool.options().get_max_connections());
150        let min_conns = i64::from(self.pool.options().get_min_connections());
151        let base_attrs = attrs.base_key_values();
152
153        meter
154            .i64_gauge(semconv_metric::DB_CLIENT_CONNECTION_MAX)
155            .with_description("The maximum number of open connections allowed.")
156            .build()
157            .record(max_conns, &base_attrs);
158        meter
159            .i64_gauge(semconv_metric::DB_CLIENT_CONNECTION_IDLE_MAX)
160            .with_description("The maximum number of idle open connections allowed.")
161            .build()
162            .record(max_conns, &base_attrs);
163        meter
164            .i64_gauge(semconv_metric::DB_CLIENT_CONNECTION_IDLE_MIN)
165            .with_description("The minimum number of idle open connections allowed.")
166            .build()
167            .record(min_conns, &base_attrs);
168
169        Pool {
170            inner: self.pool,
171            state: SharedState { attrs, metrics },
172            metrics_shutdown,
173            wait_time: Arc::new(
174                meter
175                    .f64_histogram(semconv_metric::DB_CLIENT_CONNECTION_WAIT_TIME)
176                    .with_unit("s")
177                    .with_description(
178                        "The time it took to obtain an open connection from the pool.",
179                    )
180                    .build(),
181            ),
182            use_time: Arc::new(
183                meter
184                    .f64_histogram(semconv_metric::DB_CLIENT_CONNECTION_USE_TIME)
185                    .with_unit("s")
186                    .with_description(
187                        "The time between borrowing a connection and returning it to the pool.",
188                    )
189                    .build(),
190            ),
191            timeouts: Arc::new(
192                meter
193                    .u64_counter(semconv_metric::DB_CLIENT_CONNECTION_TIMEOUTS)
194                    .with_description(
195                        "The number of connection pool acquire attempts that timed out.",
196                    )
197                    .build(),
198            ),
199            pending_requests: Arc::new(
200                meter
201                    .i64_up_down_counter(semconv_metric::DB_CLIENT_CONNECTION_PENDING_REQUESTS)
202                    .with_description("The number of pending requests for an open connection.")
203                    .build(),
204            ),
205        }
206    }
207
208    /// Spawn the pool metrics background task if a pool name is set and a runtime is
209    /// available. Returns the shutdown handle (or `None`).
210    fn spawn_pool_metrics_task(&self) -> Option<crate::pool_metrics::ShutdownHandle> {
211        let name = self.pool_name.as_ref()?;
212
213        // Prefer tokio if both runtimes are enabled.
214        #[cfg(feature = "runtime-tokio")]
215        {
216            Some(
217                crate::pool_metrics::spawn::<crate::runtime::TokioRuntime, DB>(
218                    self.pool.clone(),
219                    name.clone(),
220                    self.pool_metrics_interval,
221                ),
222            )
223        }
224
225        #[cfg(all(feature = "runtime-async-std", not(feature = "runtime-tokio")))]
226        {
227            Some(crate::pool_metrics::spawn::<
228                crate::runtime::AsyncStdRuntime,
229                DB,
230            >(
231                self.pool.clone(),
232                name.clone(),
233                self.pool_metrics_interval,
234            ))
235        }
236
237        #[cfg(not(any(feature = "runtime-tokio", feature = "runtime-async-std")))]
238        {
239            let _ = name;
240            None
241        }
242    }
243}
244
245/// An instrumented wrapper around `sqlx::Pool` that emits OpenTelemetry spans and metrics
246/// for every database operation.
247///
248/// Create one via [`PoolBuilder`]:
249///
250/// ```ignore
251/// let pool: Pool<Postgres> = PoolBuilder::from(sqlx_pool).build();
252/// ```
253///
254/// All connections acquired from this pool inherit its shared attributes and metric
255/// instruments.
256#[derive(Debug)]
257pub struct Pool<DB: sqlx::Database> {
258    pub(crate) inner: sqlx::Pool<DB>,
259    pub(crate) state: SharedState,
260    /// Dropping this handle signals the background polling task to stop.
261    metrics_shutdown: Option<crate::pool_metrics::ShutdownHandle>,
262    /// Histogram for `db.client.connection.wait_time`, recorded on each `acquire()`.
263    wait_time: Arc<opentelemetry::metrics::Histogram<f64>>,
264    /// Histogram for `db.client.connection.use_time`, recorded when a connection is dropped.
265    pub(crate) use_time: Arc<opentelemetry::metrics::Histogram<f64>>,
266    /// Counter for `db.client.connection.timeouts`, incremented on `PoolTimedOut`.
267    timeouts: Arc<opentelemetry::metrics::Counter<u64>>,
268    /// Up/down counter for `db.client.connection.pending_requests`, tracks callers
269    /// currently waiting in `acquire()`.
270    pending_requests: Arc<opentelemetry::metrics::UpDownCounter<i64>>,
271}
272
273impl<DB: sqlx::Database> Clone for Pool<DB> {
274    fn clone(&self) -> Self {
275        Self {
276            inner: self.inner.clone(),
277            state: self.state.clone(),
278            metrics_shutdown: self.metrics_shutdown.clone(),
279            wait_time: self.wait_time.clone(),
280            use_time: self.use_time.clone(),
281            timeouts: self.timeouts.clone(),
282            pending_requests: self.pending_requests.clone(),
283        }
284    }
285}
286
287impl<DB: Database> Pool<DB> {
288    /// Acquire a pooled connection instrumented for OpenTelemetry.
289    ///
290    /// Records `db.client.connection.wait_time` (time spent waiting for a connection),
291    /// tracks `db.client.connection.pending_requests`, and increments
292    /// `db.client.connection.timeouts` on `PoolTimedOut`.
293    ///
294    /// # Errors
295    ///
296    /// Returns `sqlx::Error` if a connection cannot be obtained from the pool (e.g.
297    /// timeout, pool closed).
298    pub async fn acquire(&self) -> Result<PoolConnection<DB>, sqlx::Error> {
299        let attrs = self.state.attrs.base_key_values();
300        self.pending_requests.add(1, &attrs);
301        let start = std::time::Instant::now();
302        let result = self.inner.acquire().await;
303        self.pending_requests.add(-1, &attrs);
304        self.wait_time.record(start.elapsed().as_secs_f64(), &attrs);
305
306        if let Err(sqlx::Error::PoolTimedOut) = &result {
307            self.timeouts.add(1, &attrs);
308        }
309
310        result.map(|inner| PoolConnection {
311            inner,
312            state: self.state.clone(),
313            use_time: self.use_time.clone(),
314            acquired_at: std::time::Instant::now(),
315            base_attrs: attrs,
316        })
317    }
318
319    /// Begin a new transaction instrumented for OpenTelemetry.
320    ///
321    /// # Errors
322    ///
323    /// Returns `sqlx::Error` if beginning the transaction fails.
324    pub async fn begin(&self) -> Result<Transaction<'_, DB>, sqlx::Error> {
325        self.inner.begin().await.map(|inner| Transaction {
326            inner,
327            state: self.state.clone(),
328        })
329    }
330
331    /// Shut down the pool, waiting for all connections to be released.
332    pub async fn close(&self) {
333        self.inner.close().await;
334    }
335
336    /// Returns `true` if the pool has been closed.
337    #[must_use]
338    pub fn is_closed(&self) -> bool {
339        self.inner.is_closed()
340    }
341
342    /// Return an annotated executor that attaches per-query semantic convention attributes
343    /// to every span created by the next operation.
344    ///
345    /// The returned wrapper borrows the pool and implements `sqlx::Executor` with the
346    /// same instrumentation, but with annotation values threaded through to span creation.
347    ///
348    /// # Example
349    ///
350    /// ```ignore
351    /// pool.with_annotations(QueryAnnotations::new()
352    ///         .operation("SELECT")
353    ///         .collection("users"))
354    ///     .fetch_all("SELECT * FROM users")
355    ///     .await?;
356    /// ```
357    #[must_use]
358    pub fn with_annotations(&self, annotations: QueryAnnotations) -> Annotated<'_, Self> {
359        Annotated {
360            inner: self,
361            annotations,
362            state: self.state.clone(),
363        }
364    }
365
366    /// Shorthand for annotating the next operation with `db.operation.name` and
367    /// `db.collection.name`.
368    ///
369    /// Equivalent to `self.with_annotations(QueryAnnotations::new().operation(op).collection(coll))`.
370    #[must_use]
371    pub fn with_operation(
372        &self,
373        operation: impl Into<String>,
374        collection: impl Into<String>,
375    ) -> Annotated<'_, Self> {
376        self.with_annotations(
377            QueryAnnotations::new()
378                .operation(operation)
379                .collection(collection),
380        )
381    }
382}