ssh_commander_core/postgres/mod.rs

//! PostgreSQL pool for the database explorer.
//!
//! Sprint 6 model: each `PgPool` represents one configured Postgres
//! profile and manages up to `max_size` underlying connections. Callers
//! identify their work by `session_id` (typically a UUID per query
//! tab). The pool maps `session_id → connection` so a tab's cursor
//! always lives on the same wire across `execute → fetch_page →
//! close_query`. Idle connections are reused for fresh sessions.
//!
//! ## Why per-session leasing
//!
//! `tokio_postgres::Client` is `Sync`, but a Postgres *session* is
//! single-threaded by protocol: one transaction at a time, one cursor
//! at a time. To let two query tabs run independent paginated SELECTs
//! in parallel, each must hold its own connection for the cursor's
//! lifetime. Without that, opening cursor A and then cursor B on the
//! same wire kills A — exactly the failure Sprint 5's single-cursor
//! invariant produced.
//!
//! ## Tunnel sharing
//!
//! When the profile uses an SSH tunnel, the listener is opened *once*
//! at pool construction and shared by every pooled connection — they
//! each open their own SSH `direct-tcpip` channel via the same local
//! port. Single tunnel per profile keeps SSH session usage minimal.
//!
//! ## Thread safety
//!
//! `PgPool` is `Send + Sync`. All public methods take `&self` and
//! acquire internal locks for short windows (the pool's own metadata
//! Mutex during lease/release, and a per-connection Mutex during
//! cursor bookkeeping). FFI callers wrap one `Arc<PgPool>` per
//! managed connection id.
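//!
//! ## Usage sketch
//!
//! An illustrative walk through the intended call flow (not compiled);
//! it assumes a reachable server, a valid `PgConfig`, and uses placeholder
//! session ids where the UI would pass its per-tab UUIDs:
//!
//! ```ignore
//! let pool = PgPool::connect(cfg, None).await?;
//! // Two tabs lease two independent connections, so their cursors coexist.
//! let a = pool.execute("tab-a", "SELECT * FROM orders", 100).await?;
//! let b = pool.execute("tab-b", "SELECT * FROM events", 100).await?;
//! // When a tab closes, hand its connection back to the idle list.
//! pool.release_session("tab-a").await;
//! ```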

pub mod config;
pub mod edit;
pub mod exec;
pub mod introspect;
pub mod tunnel;

pub use config::{PgAuthMethod, PgConfig, PgTlsMode, SshTunnelRef};
pub use edit::{InsertColumnInput, InsertedRow, UpdateOutcome};
pub use exec::{ActiveCursor, ColumnMeta, ExecutionOutcome, PageResult};
pub use introspect::{
    ColumnDetail, DbSummary, ObjectType, ObjectTypeKind, Relation, RelationKind, Routine,
    RoutineKind, SchemaContents, SchemaSummary, Sequence,
};
pub use tunnel::SshTunnel;

use std::collections::HashMap;
use std::sync::{Arc, Weak};
use std::time::{Duration, Instant};

use rustls::ClientConfig as RustlsClientConfig;
use tokio::sync::{Mutex, RwLock};
use tokio::task::JoinHandle;
use tokio_postgres::config::SslMode as PgSslMode;
use tokio_postgres::{CancelToken, Client, Config as PgDriverConfig};
use tokio_postgres_rustls::MakeRustlsConnect;
use tokio_util::sync::CancellationToken;

use crate::ssh::SshClient;

/// Default upper bound on connections per pool. Five is plenty for an
/// interactive explorer (you'd need six query tabs running at once to
/// hit it) and well below the default `max_connections=100` Postgres
/// quota that managed providers tend to set.
const DEFAULT_MAX_POOL_SIZE: usize = 5;

/// Default idle-connection lifetime. Five minutes balances "alt-tabbing
/// out and back doesn't pay reconnect cost" against "polite to managed
/// providers that bill on connection-hours". Per-profile override on
/// `PgConfig.idle_timeout_secs`.
const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(300);

/// How often the eviction loop wakes up. Independent of the idle
/// timeout — connections live at most `idle_timeout + EVICTION_INTERVAL`.
/// Not exposed on `PgConfig`; the cadence is global and the value is
/// already much smaller than typical idle_timeout settings.
const EVICTION_INTERVAL: Duration = Duration::from_secs(30);

/// Default minimum idle connections to keep alive past `idle_timeout`.
/// One warm connection means the next query doesn't pay connect
/// latency. Per-profile override on `PgConfig.min_idle_connections`.
const DEFAULT_MIN_IDLE_CONNECTIONS: usize = 1;

/// Stable session id used for the schema browser's introspection
/// calls (`list_databases`, `list_schemas`, `list_relations`). Uses
/// a name that can't collide with a UUID-generated tab session id.
pub const BROWSER_SESSION_ID: &str = "_browser";

/// Errors surfaced from the Postgres explorer layer.
#[derive(Debug, thiserror::Error)]
pub enum PgError {
    #[error("postgres connect failed: {0}")]
    Connect(String),
    #[error("postgres auth failed: {0}")]
    Auth(String),
    #[error("postgres tls setup failed: {0}")]
    Tls(String),
    #[error("ssh tunnel error: {0}")]
    Tunnel(String),
    /// Tunnel was requested but the referenced SSH connection isn't
    /// registered (or has been closed).
    #[error("ssh tunnel source not found: {0}")]
    TunnelSourceMissing(String),
    /// The cursor referenced by a fetch_page / close_query call no
    /// longer exists. Sprint 6: now scoped to the session — only
    /// fires if the same session opened a new cursor in between, or
    /// the session was released.
    #[error("cursor no longer available: {0}")]
    CursorExpired(String),
    /// The pool is at `max_size` and all connections are currently
    /// leased to other sessions. Caller should wait and retry, or
    /// release another session.
    #[error("pool exhausted: {0} of {1} connections leased")]
    PoolExhausted(usize, usize),
    #[error("postgres driver error: {0}")]
    Driver(#[from] tokio_postgres::Error),
}

// ============================================================================
// Pool
// ============================================================================

pub struct PgPool {
    config: PgConfig,
    /// Optional SSH tunnel shared across all pooled connections. The
    /// listener stays bound for the pool's lifetime; per-connection
    /// `direct-tcpip` channels open lazily as each connection
    /// dials in.
    tunnel: Option<Arc<SshTunnel>>,
    /// TLS connector built once at pool init and reused for every
    /// new connection plus every server-side cancel. Sharing
    /// matters for TLS-only Postgres deployments (RDS, Supabase,
    /// Neon) where the server rejects a plaintext cancel handshake;
    /// the cancel must use the same TLS posture as the data wire.
    tls_connector: TlsConnectorKind,
    /// Pool state guarded by a single Mutex. Acquire/release windows
    /// are short — the actual SQL round trips happen against the
    /// per-connection Mutex, not this one.
    inner: Mutex<PoolInner>,
    max_size: usize,
    /// Per-pool idle-connection lifetime. Read once at construction
    /// from `PgConfig.idle_timeout_secs` (or `DEFAULT_IDLE_TIMEOUT`).
    idle_timeout: Duration,
    /// Per-pool minimum-idle floor. Eviction won't drop below this
    /// even when entries are aged.
    min_idle: usize,
    /// Cached side-connections for browsing databases other than
    /// `config.database`. Postgres connections are bound to one
    /// database at connect time; the schema browser tree shows
    /// every database on the server, so expanding a non-default
    /// one needs its own connection. Keyed by database name.
    /// Lazily populated on first cross-database introspection;
    /// torn down by `shutdown`.
    secondary_browsers: Mutex<HashMap<String, Arc<Mutex<PooledConnection>>>>,
    /// Signal to the background eviction task that it should stop.
    /// `shutdown` cancels it explicitly; the task also self-exits
    /// when the pool's `Weak<Self>` upgrade fails (i.e. all `Arc`s
    /// have been dropped).
    eviction_cancel: CancellationToken,
}

/// One idle connection plus the moment it returned to idle. The
/// eviction loop reads `since` to decide what to drop. Newly opened
/// connections also enter idle with `since = now`, so a cold pool
/// doesn't immediately evict.
struct IdleEntry {
    since: Instant,
    conn: Arc<Mutex<PooledConnection>>,
}

/// Erased TLS strategy for both data connections and cancel
/// requests. `MakeRustlsConnect` clones cheaply (the inner
/// `rustls::ClientConfig` is `Arc`-shared), so reusing the same
/// instance for many connections is fine.
#[derive(Clone)]
enum TlsConnectorKind {
    NoTls,
    Rustls(MakeRustlsConnect),
}

struct PoolInner {
    /// Connections free to lease, with the timestamp they returned to
    /// idle. The eviction loop scans this list every
    /// `EVICTION_INTERVAL`.
    idle: Vec<IdleEntry>,
    /// Active leases, keyed by caller-supplied session id.
    leased: HashMap<String, Arc<Mutex<PooledConnection>>>,
    /// Total connections in existence (idle + leased + currently
    /// being opened). Bounds growth against `max_size`.
    total: usize,
}

struct PooledConnection {
    client: Client,
    cancel_token: CancelToken,
    /// At-most-one active cursor on this wire. Protected by the
    /// per-connection mutex (callers hold it for the duration of
    /// any cursor op).
    active_cursor: Option<ActiveCursor>,
    /// Background task that drives this connection's wire protocol.
    /// Aborted when the connection is dropped from the pool.
    connection_task: Option<JoinHandle<()>>,
}

impl PgPool {
    /// Open a pool with `min_size = 1` (one connection eagerly
    /// established) and `max_size = DEFAULT_MAX_POOL_SIZE` unless the
    /// profile overrides it. Eager initial connect surfaces auth and
    /// network errors immediately rather than deferring them to the
    /// first query.
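    ///
    /// # Example
    ///
    /// Illustrative sketch (not compiled). `PgConfig::local` is the same
    /// helper the unit tests below use; the database and user names are
    /// placeholders. No tunnel is requested, so no `SshClient` is passed.
    ///
    /// ```ignore
    /// let mut cfg = PgConfig::local("appdb", "svc_user");
    /// cfg.tls = PgTlsMode::Require;
    /// let pool = PgPool::connect(cfg, None).await?;
    /// ```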
    pub async fn connect(
        cfg: PgConfig,
        ssh_client: Option<Arc<RwLock<SshClient>>>,
    ) -> Result<Arc<Self>, PgError> {
        // Open the tunnel once if requested. Subsequent connections
        // dial 127.0.0.1:<local_port> independently.
        let tunnel: Option<Arc<SshTunnel>> = if let Some(tunnel_ref) = cfg.ssh_tunnel.as_ref() {
            let Some(ssh) = ssh_client else {
                return Err(PgError::Tunnel(
                    "ssh tunnel requested but no ssh client supplied".into(),
                ));
            };
            let t = SshTunnel::open(ssh, tunnel_ref.remote_host.clone(), tunnel_ref.remote_port)
                .await
                .map_err(|e| PgError::Tunnel(format!("failed to open ssh tunnel: {e}")))?;
            Some(Arc::new(t))
        } else {
            None
        };

        // Build the TLS connector once. Subsequent opens (and the
        // initial open below) reuse it; so does `cancel`.
        let tls_connector = build_tls_connector(&cfg)?;

        // Eager-connect the first connection so authentication errors
        // surface up front.
        let first = open_one(&cfg, tunnel.as_deref(), &tls_connector).await?;

        let now = Instant::now();
        // Apply per-profile overrides on top of the built-in
        // defaults. `0` is a valid `min_idle` (evict every idle
        // connection once it ages out), so we pass it through as-is.
        let max_size = cfg
            .max_pool_size
            .map(|n| n as usize)
            .filter(|&n| n > 0)
            .unwrap_or(DEFAULT_MAX_POOL_SIZE);
        let idle_timeout = cfg
            .idle_timeout_secs
            .map(Duration::from_secs)
            .unwrap_or(DEFAULT_IDLE_TIMEOUT);
        let min_idle = cfg
            .min_idle_connections
            .map(|n| n as usize)
            .unwrap_or(DEFAULT_MIN_IDLE_CONNECTIONS);

        let pool = Arc::new(Self {
            config: cfg,
            tunnel,
            tls_connector,
            inner: Mutex::new(PoolInner {
                idle: vec![IdleEntry {
                    since: now,
                    conn: Arc::new(Mutex::new(first)),
                }],
                leased: HashMap::new(),
                total: 1,
            }),
            max_size,
            idle_timeout,
            min_idle,
            secondary_browsers: Mutex::new(HashMap::new()),
            eviction_cancel: CancellationToken::new(),
        });

        // Spawn the eviction loop. We hand it a `Weak<Self>` so the
        // pool can drop without the task keeping it alive — and a
        // clone of the cancel token so an explicit `shutdown` can
        // wake it immediately rather than waiting for the next tick.
        let weak = Arc::downgrade(&pool);
        let cancel = pool.eviction_cancel.clone();
        tokio::spawn(run_eviction(weak, cancel));

        Ok(pool)
    }

    // ------------------------------------------------------------------
    // High-level operations
    // ------------------------------------------------------------------

    /// Schema introspection runs on the connection bound to the
    /// well-known browser session. Re-using the same session means
    /// the browser doesn't churn through pool slots on every tree
    /// refresh.
    pub async fn list_databases(&self) -> Result<Vec<DbSummary>, PgError> {
        let conn = self.lease_for_session(BROWSER_SESSION_ID).await?;
        let guard = conn.lock().await;
        Ok(introspect::list_databases(&guard.client).await?)
    }

    pub async fn list_schemas(&self) -> Result<Vec<SchemaSummary>, PgError> {
        self.list_schemas_in(None).await
    }

    /// List schemas in `database`, opening (and caching) a side
    /// connection when it differs from the connection's default DB.
    /// Postgres binds a connection to one database at startup, so
    /// browsing other DBs in the tree requires this routing —
    /// otherwise every database expansion would show the connected
    /// DB's schemas.
    pub async fn list_schemas_in(
        &self,
        database: Option<&str>,
    ) -> Result<Vec<SchemaSummary>, PgError> {
        let conn = self.browser_connection_for(database).await?;
        let guard = conn.lock().await;
        Ok(introspect::list_schemas(&guard.client).await?)
    }

    pub async fn list_relations(&self, schema: &str) -> Result<Vec<Relation>, PgError> {
        self.list_relations_in(schema, None).await
    }

    pub async fn list_relations_in(
        &self,
        schema: &str,
        database: Option<&str>,
    ) -> Result<Vec<Relation>, PgError> {
        let conn = self.browser_connection_for(database).await?;
        let guard = conn.lock().await;
        Ok(introspect::list_relations(&guard.client, schema).await?)
    }

    /// Unified schema-contents fetch — tables, views, mat-views,
    /// sequences, routines, and object types in one call. Replaces
    /// the per-category round-trips the older tree did, and is what
    /// the DataGrip-style 6-category tree expects.
    pub async fn list_schema_contents_in(
        &self,
        schema: &str,
        database: Option<&str>,
    ) -> Result<SchemaContents, PgError> {
        let conn = self.browser_connection_for(database).await?;
        let guard = conn.lock().await;
        Ok(introspect::list_schema_contents(&guard.client, schema).await?)
    }

    /// Resolve a browser connection for `database`. When `None` or
    /// matching the pool's default DB, returns the regular browser
    /// session's lease. Otherwise opens (or returns cached) a
    /// secondary connection bound to that database.
    async fn browser_connection_for(
        &self,
        database: Option<&str>,
    ) -> Result<Arc<Mutex<PooledConnection>>, PgError> {
        let target = database.unwrap_or(self.config.database.as_str());
        if target == self.config.database {
            return self.lease_for_session(BROWSER_SESSION_ID).await;
        }
        // Cached secondary?
        {
            let map = self.secondary_browsers.lock().await;
            if let Some(c) = map.get(target) {
                return Ok(c.clone());
            }
        }
        // Open a fresh connection bound to the target database.
        // Reuse credentials, TLS posture, and tunnel — only the
        // database name differs.
        let mut cfg = self.config.clone();
        cfg.database = target.to_string();
        let conn = open_one(&cfg, self.tunnel.as_deref(), &self.tls_connector).await?;
        let arc = Arc::new(Mutex::new(conn));
        let mut map = self.secondary_browsers.lock().await;
        // Race: another task may have inserted while we were
        // opening. Prefer the existing entry to avoid leaking the
        // freshly-opened one — drop ours, return theirs.
        if let Some(existing) = map.get(target) {
            return Ok(existing.clone());
        }
        map.insert(target.to_string(), arc.clone());
        Ok(arc)
    }

    /// Describe a relation's columns for the INSERT form. Runs on
    /// the browser session to avoid churning a query tab's lease.
    pub async fn describe_columns(
        &self,
        schema: &str,
        table: &str,
    ) -> Result<Vec<ColumnDetail>, PgError> {
        let conn = self.lease_for_session(BROWSER_SESSION_ID).await?;
        let guard = conn.lock().await;
        Ok(introspect::describe_columns(&guard.client, schema, table).await?)
    }

    /// Run a SQL statement on the connection assigned to `session_id`,
    /// leasing one if the session is new. Closes any previously-active
    /// cursor on that connection (per-session behavior — other sessions
    /// are unaffected).
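    ///
    /// # Example
    ///
    /// Illustrative sketch (not compiled). The `cursor_id` placeholder
    /// stands for the id of the cursor this call opened, which the caller
    /// learns from the returned [`ExecutionOutcome`].
    ///
    /// ```ignore
    /// let outcome = pool.execute("tab-a", "SELECT * FROM big_table", 200).await?;
    /// // Page through the open cursor on the same session's connection.
    /// let page = pool.fetch_page("tab-a", &cursor_id, 200).await?;
    /// pool.close_query("tab-a", &cursor_id).await?;
    /// ```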
    pub async fn execute(
        &self,
        session_id: &str,
        sql: &str,
        page_size: usize,
    ) -> Result<ExecutionOutcome, PgError> {
        let conn = self.lease_for_session(session_id).await?;
        let mut guard = conn.lock().await;
        let previous = guard.active_cursor.take();
        let (outcome, new_cursor) =
            exec::open_query(&guard.client, sql, page_size, previous).await?;
        guard.active_cursor = new_cursor;
        // If the cursor closed (no more rows), the connection is
        // logically idle — but we keep the lease so the same session
        // continues to land on the same wire for follow-up commands
        // (helpful for SET / temporary tables). Explicit
        // `release_session` returns it to idle.
        Ok(outcome)
    }

    pub async fn fetch_page(
        &self,
        session_id: &str,
        cursor_id: &str,
        count: usize,
    ) -> Result<PageResult, PgError> {
        let conn = self
            .leased_only(session_id)
            .await
            .ok_or_else(|| PgError::CursorExpired(format!("no active session {session_id}")))?;
        let guard = conn.lock().await;
        let Some(cursor) = guard.active_cursor.as_ref() else {
            return Err(PgError::CursorExpired(format!(
                "session {session_id} has no active cursor"
            )));
        };
        if cursor.cursor_id != cursor_id {
            return Err(PgError::CursorExpired(format!(
                "session {session_id} active cursor is {} (looking for {cursor_id})",
                cursor.cursor_id
            )));
        }
        let cursor_clone = cursor.clone();
        let client = &guard.client;
        exec::fetch_page(client, &cursor_clone, count).await
    }

    /// Update a single cell on `(schema, table)` identified by ctid.
    /// Runs on the session's connection so users editing in one tab
    /// don't block on another tab's pagination cursor. Returns
    /// `UpdateOutcome { rows_affected }` — the UI treats `0` as
    /// "row no longer there, please refresh".
    #[allow(clippy::too_many_arguments)]
    pub async fn update_cell(
        &self,
        session_id: &str,
        schema: &str,
        table: &str,
        column: &str,
        column_type: &str,
        new_value: Option<&str>,
        ctid: &str,
    ) -> Result<UpdateOutcome, PgError> {
        let conn = self.lease_for_session(session_id).await?;
        let guard = conn.lock().await;
        edit::update_cell(
            &guard.client,
            schema,
            table,
            column,
            column_type,
            new_value,
            ctid,
        )
        .await
    }

    /// Insert one row, returning the requested columns. Runs on the
    /// session's connection so any session-local state (SET,
    /// transactions) applies. See [`edit::insert_row`] for the SQL
    /// shape and parameter rules.
    pub async fn insert_row(
        &self,
        session_id: &str,
        schema: &str,
        table: &str,
        inputs: &[InsertColumnInput],
        return_columns: &[String],
    ) -> Result<InsertedRow, PgError> {
        let conn = self.lease_for_session(session_id).await?;
        let guard = conn.lock().await;
        edit::insert_row(&guard.client, schema, table, inputs, return_columns).await
    }

    /// Delete one or more rows by ctid on the session's connection.
    /// Returns the actual rows-deleted count — callers compare
    /// against the requested count to spot "some rows were already
    /// gone" (concurrent edit / delete from another session).
    pub async fn delete_rows(
        &self,
        session_id: &str,
        schema: &str,
        table: &str,
        ctids: &[String],
    ) -> Result<UpdateOutcome, PgError> {
        let conn = self.lease_for_session(session_id).await?;
        let guard = conn.lock().await;
        edit::delete_rows(&guard.client, schema, table, ctids).await
    }

    pub async fn close_query(&self, session_id: &str, cursor_id: &str) -> Result<(), PgError> {
        let Some(conn) = self.leased_only(session_id).await else {
            return Ok(()); // Nothing to close — idempotent.
        };
        let mut guard = conn.lock().await;
        if let Some(c) = guard.active_cursor.as_ref()
            && c.cursor_id == cursor_id
        {
            let cursor = guard.active_cursor.take().expect("just checked");
            let client = &guard.client;
            exec::close_query(client, &cursor).await;
        }
        Ok(())
    }

    /// Server-side cancel for whatever query is in flight on the
    /// session's connection. Uses the same TLS posture as the
    /// data wire — TLS-only Postgres deployments (RDS, Supabase,
    /// Neon) reject plaintext cancels, which would silently leave
    /// the in-flight query running until it timed out. No-op if
    /// the session has no lease.
    pub async fn cancel(&self, session_id: &str) -> Result<(), PgError> {
        let Some(conn) = self.leased_only(session_id).await else {
            return Ok(());
        };
        let token = {
            let guard = conn.lock().await;
            guard.cancel_token.clone()
        };
        match &self.tls_connector {
            TlsConnectorKind::NoTls => token.cancel_query(tokio_postgres::NoTls).await,
            TlsConnectorKind::Rustls(connector) => token.cancel_query(connector.clone()).await,
        }
        .map_err(PgError::Driver)
    }

    /// Release a session's lease. Closes any active cursor first so
    /// the underlying connection returns to idle in a clean state.
    pub async fn release_session(&self, session_id: &str) {
        let Some(conn) = self.take_lease(session_id).await else {
            return;
        };
        // Close any open cursor + transaction so the connection is
        // safe to hand to a different session.
        {
            let mut guard = conn.lock().await;
            if let Some(cursor) = guard.active_cursor.take() {
                exec::close_query(&guard.client, &cursor).await;
            }
        }
        // Return to idle, stamping the moment of release so the
        // eviction loop can age it out.
        let mut inner = self.inner.lock().await;
        inner.idle.push(IdleEntry {
            since: Instant::now(),
            conn,
        });
    }

    /// Tear down all connections. Used on `disconnect`.
    pub async fn shutdown(&self) {
        // Wake the eviction loop so it exits promptly rather than
        // sleeping on the next tick.
        self.eviction_cancel.cancel();

        let mut inner = self.inner.lock().await;
        let mut conns: Vec<Arc<Mutex<PooledConnection>>> =
            inner.idle.drain(..).map(|e| e.conn).collect();
        conns.extend(inner.leased.drain().map(|(_, c)| c));
        inner.total = 0;
        drop(inner);
        // Also close any secondary cross-database browser
        // connections we opened.
        let secondaries: Vec<Arc<Mutex<PooledConnection>>> = {
            let mut map = self.secondary_browsers.lock().await;
            map.drain().map(|(_, c)| c).collect()
        };
        let conns_with_secondaries = conns.into_iter().chain(secondaries);
        let conns: Vec<Arc<Mutex<PooledConnection>>> = conns_with_secondaries.collect();
        for conn in conns {
            // Best-effort task abort; the wire closes when `Client`
            // is dropped (which happens when the last Arc to this
            // PooledConnection drops — usually right here).
            let mut guard = conn.lock().await;
            if let Some(task) = guard.connection_task.take() {
                task.abort();
            }
        }
    }

    // ------------------------------------------------------------------
    // Internal: leasing
    // ------------------------------------------------------------------

    /// Get the connection currently leased to `session_id`, opening
    /// a new one (and leasing it) when the session is new.
    async fn lease_for_session(
        &self,
        session_id: &str,
    ) -> Result<Arc<Mutex<PooledConnection>>, PgError> {
        // Fast path: existing lease.
        {
            let inner = self.inner.lock().await;
            if let Some(c) = inner.leased.get(session_id) {
                return Ok(c.clone());
            }
        }

        // Try to grab an idle connection first. LIFO (`pop`) keeps
        // the most-recently-released connection warm — which is the
        // youngest, most-likely-to-survive eviction next round.
        let from_idle = {
            let mut inner = self.inner.lock().await;
            inner.idle.pop().map(|e| e.conn)
        };
        if let Some(conn) = from_idle {
            self.assign_lease(session_id, conn.clone()).await;
            return Ok(conn);
        }

        // No idle connections. Open a new one if we have room.
        let need_new = {
            let inner = self.inner.lock().await;
            if inner.total >= self.max_size {
                return Err(PgError::PoolExhausted(inner.total, self.max_size));
            }
            true
        };
        if need_new {
            // Reserve the slot before the network round trip so two
            // simultaneous lease requests don't both think there's
            // room.
            {
                let mut inner = self.inner.lock().await;
                if inner.total >= self.max_size {
                    return Err(PgError::PoolExhausted(inner.total, self.max_size));
                }
                inner.total += 1;
            }
            let new_conn =
                match open_one(&self.config, self.tunnel.as_deref(), &self.tls_connector).await {
                    Ok(c) => c,
                    Err(e) => {
                        // Roll back the slot reservation on failure so
                        // the pool can try again later.
                        let mut inner = self.inner.lock().await;
                        inner.total = inner.total.saturating_sub(1);
                        return Err(e);
                    }
                };
            let conn = Arc::new(Mutex::new(new_conn));
            self.assign_lease(session_id, conn.clone()).await;
            return Ok(conn);
        }
        unreachable!()
    }

    async fn assign_lease(&self, session_id: &str, conn: Arc<Mutex<PooledConnection>>) {
        let mut inner = self.inner.lock().await;
        inner.leased.insert(session_id.to_string(), conn);
    }

    /// Evict idle connections older than the pool's configured
    /// `idle_timeout`, keeping at least `min_idle` alive. Runs from
    /// the background eviction task; safe to call manually too.
    async fn evict_idle(&self) {
        let now = Instant::now();
        let to_drop: Vec<Arc<Mutex<PooledConnection>>>;
        {
            let mut inner = self.inner.lock().await;
            // Take the idle list out so we can decide which entries
            // to keep without holding the lock during async work.
            // Single-pass keep/drop: iterate in storage order (most
            // recently released last, since `lease` pops from the
            // end); the first `min_idle` we encounter are pinned
            // regardless of age.
            let snapshot = std::mem::take(&mut inner.idle);
            let mut keep: Vec<IdleEntry> = Vec::with_capacity(snapshot.len());
            let mut drop_list: Vec<Arc<Mutex<PooledConnection>>> = Vec::new();
            for entry in snapshot.into_iter() {
                let aged = now.duration_since(entry.since) >= self.idle_timeout;
                if !aged || keep.len() < self.min_idle {
                    keep.push(entry);
                } else {
                    drop_list.push(entry.conn);
                    inner.total = inner.total.saturating_sub(1);
                }
            }
            inner.idle = keep;
            to_drop = drop_list;
        }

        if !to_drop.is_empty() {
            tracing::debug!(
                target: "postgres::pool",
                count = to_drop.len(),
                "evicted idle postgres connections"
            );
        }

        // Abort the connection tasks. Dropping each `Arc` at the end
        // of its loop iteration closes the wire once the last
        // reference goes — which is here, since `idle` held the
        // only one.
        for conn in to_drop {
            let mut guard = conn.lock().await;
            if let Some(task) = guard.connection_task.take() {
                task.abort();
            }
        }
    }

    /// Look up the lease for `session_id` without opening a new one.
    async fn leased_only(&self, session_id: &str) -> Option<Arc<Mutex<PooledConnection>>> {
        let inner = self.inner.lock().await;
        inner.leased.get(session_id).cloned()
    }

    /// Remove and return the lease for `session_id`.
    async fn take_lease(&self, session_id: &str) -> Option<Arc<Mutex<PooledConnection>>> {
        let mut inner = self.inner.lock().await;
        inner.leased.remove(session_id)
    }
}

impl Drop for PgPool {
    fn drop(&mut self) {
        // Cancel the eviction loop first so it doesn't try to upgrade
        // the soon-dead `Weak<Self>` and log spurious work.
        self.eviction_cancel.cancel();

        // Best-effort: abort connection tasks. Async shutdown is the
        // documented path for clean teardown; this guards against
        // forgotten shutdown calls.
        if let Ok(mut inner) = self.inner.try_lock() {
            let mut conns: Vec<Arc<Mutex<PooledConnection>>> =
                inner.idle.drain(..).map(|e| e.conn).collect();
            conns.extend(inner.leased.drain().map(|(_, c)| c));
            for conn in conns {
                if let Ok(mut guard) = conn.try_lock()
                    && let Some(task) = guard.connection_task.take()
                {
                    task.abort();
                }
            }
        }
        // Best-effort close of secondary cross-database browsers.
        if let Ok(mut map) = self.secondary_browsers.try_lock() {
            for (_, conn) in map.drain() {
                if let Ok(mut guard) = conn.try_lock()
                    && let Some(task) = guard.connection_task.take()
                {
                    task.abort();
                }
            }
        }
    }
}

// ============================================================================
// Background eviction (private)
// ============================================================================

/// Background loop that prunes idle connections aged past the pool's
/// configured `idle_timeout`. Holds a `Weak<PgPool>` so it can't extend
/// the pool's lifetime — when all `Arc<PgPool>`s are dropped, the next
/// upgrade fails and the loop exits.
async fn run_eviction(pool: Weak<PgPool>, cancel: CancellationToken) {
    let mut ticker = tokio::time::interval(EVICTION_INTERVAL);
    // Skip the immediate first tick — `connect` just opened a
    // connection so there's nothing to evict yet.
    ticker.tick().await;
    loop {
        tokio::select! {
            _ = cancel.cancelled() => return,
            _ = ticker.tick() => {
                let Some(pool) = pool.upgrade() else { return };
                pool.evict_idle().await;
                // Drop the strong ref before the next sleep so the
                // pool can be reclaimed promptly when the manager
                // discards it.
                drop(pool);
            }
        }
    }
}

// ============================================================================
// Connection construction (private)
// ============================================================================

async fn open_one(
    cfg: &PgConfig,
    tunnel: Option<&SshTunnel>,
    tls: &TlsConnectorKind,
) -> Result<PooledConnection, PgError> {
    let driver_cfg = build_driver_config(cfg, tunnel)?;
    match tls {
        TlsConnectorKind::NoTls => {
            let (client, connection) = driver_cfg
                .connect(tokio_postgres::NoTls)
                .await
                .map_err(classify_connect_error)?;
            Ok(spawn_connection(client, connection))
        }
        TlsConnectorKind::Rustls(connector) => {
            // `MakeRustlsConnect` clones cheaply (Arc-shared
            // ClientConfig); each connect consumes its own clone.
            let (client, connection) = driver_cfg
                .connect(connector.clone())
                .await
                .map_err(classify_connect_error)?;
            Ok(spawn_connection(client, connection))
        }
    }
}

/// Build the TLS connector used for both the data wire and cancel
/// handshakes for this pool. Installing the rustls crypto provider
/// once per pool is sufficient — the call is idempotent across
/// multiple pools (the second install returns Err and we ignore it).
fn build_tls_connector(cfg: &PgConfig) -> Result<TlsConnectorKind, PgError> {
    match cfg.tls {
        PgTlsMode::Disable => Ok(TlsConnectorKind::NoTls),
        PgTlsMode::Prefer | PgTlsMode::Require | PgTlsMode::VerifyFull => {
            let _ = rustls::crypto::ring::default_provider().install_default();
            let tls_config = build_rustls_config(cfg.tls)?;
            Ok(TlsConnectorKind::Rustls(MakeRustlsConnect::new(tls_config)))
        }
    }
}

fn build_driver_config(
    cfg: &PgConfig,
    tunnel: Option<&SshTunnel>,
) -> Result<PgDriverConfig, PgError> {
    let mut driver = PgDriverConfig::new();
    if let Some(t) = tunnel {
        driver.host("127.0.0.1").port(t.local_port());
    } else {
        driver.host(&cfg.host).port(cfg.port);
    }
    driver.dbname(&cfg.database).user(&cfg.user);

    let password = match &cfg.auth {
        PgAuthMethod::Password { password } => password.clone(),
        PgAuthMethod::Keychain { account } => crate::keychain::load_password(
            crate::keychain::CredentialKind::PostgresPassword,
            account,
        )
        .map_err(|e| PgError::Auth(format!("keychain load failed for {account}: {e}")))?
        .ok_or_else(|| {
            PgError::Auth(format!("no keychain entry for postgres account {account}"))
        })?,
    };
    if !password.is_empty() {
        driver.password(password);
    }

    if let Some(name) = &cfg.application_name {
        driver.application_name(name);
    }
    if let Some(secs) = cfg.connect_timeout_secs {
        driver.connect_timeout(Duration::from_secs(secs));
    }

    driver.ssl_mode(match cfg.tls {
        PgTlsMode::Disable => PgSslMode::Disable,
        PgTlsMode::Prefer => PgSslMode::Prefer,
        PgTlsMode::Require | PgTlsMode::VerifyFull => PgSslMode::Require,
    });
    Ok(driver)
}

fn build_rustls_config(mode: PgTlsMode) -> Result<RustlsClientConfig, PgError> {
    let mut roots = rustls::RootCertStore::empty();
    let native = rustls_native_certs::load_native_certs();
    for cert in native.certs {
        let _ = roots.add(cert);
    }

    let cfg = match mode {
        PgTlsMode::VerifyFull => RustlsClientConfig::builder()
            .with_root_certificates(roots)
            .with_no_client_auth(),
        _ => RustlsClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(std::sync::Arc::new(NoCertVerifier))
            .with_no_client_auth(),
    };
    Ok(cfg)
}

fn classify_connect_error(e: tokio_postgres::Error) -> PgError {
    if let Some(db_err) = e.as_db_error() {
        let code = db_err.code().code();
        if code == "28P01" || code == "28000" {
            return PgError::Auth(db_err.message().to_string());
        }
    }
    PgError::Connect(e.to_string())
}

fn spawn_connection<S, T>(
    client: Client,
    connection: tokio_postgres::Connection<S, T>,
) -> PooledConnection
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
    T: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
{
    let cancel_token = client.cancel_token();
    let task = tokio::spawn(async move {
        if let Err(e) = connection.await {
            tracing::warn!("postgres connection task ended with error: {e}");
        }
    });
    PooledConnection {
        client,
        cancel_token,
        active_cursor: None,
        connection_task: Some(task),
    }
}

#[derive(Debug)]
struct NoCertVerifier;

impl rustls::client::danger::ServerCertVerifier for NoCertVerifier {
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::RSA_PKCS1_SHA384,
            rustls::SignatureScheme::RSA_PKCS1_SHA512,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::ED25519,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
        ]
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `PgPool::connect` rejects a tunneled config when no SSH client
    /// is supplied. The macOS bridge always supplies one when the
    /// config requests a tunnel; this guards against accidental
    /// bypass in tests / library usage.
    #[test]
    fn pool_connect_with_tunnel_requires_ssh_client() {
        let cfg = PgConfig {
            ssh_tunnel: Some(SshTunnelRef {
                ssh_connection_id: "ssh-1".to_string(),
                remote_host: "db".to_string(),
                remote_port: 5432,
            }),
            ..PgConfig::local("db", "u")
        };
        let rt = tokio::runtime::Runtime::new().unwrap();
        match rt.block_on(PgPool::connect(cfg, None)) {
            Err(PgError::Tunnel(detail)) => {
                assert!(detail.contains("ssh client"));
            }
            Err(other) => panic!("expected Tunnel error, got {other:?}"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }

    #[test]
    fn driver_config_uses_correct_ssl_mode() {
        let mut cfg = PgConfig::local("db", "u");
        cfg.tls = PgTlsMode::Require;
        let driver = build_driver_config(&cfg, None).expect("driver cfg");
        assert!(matches!(driver.get_ssl_mode(), PgSslMode::Require));

        cfg.tls = PgTlsMode::Disable;
        let driver = build_driver_config(&cfg, None).expect("driver cfg");
        assert!(matches!(driver.get_ssl_mode(), PgSslMode::Disable));
    }

    #[test]
    fn driver_config_omits_password_when_empty() {
        let cfg = PgConfig::local("db", "u");
        let driver = build_driver_config(&cfg, None).expect("driver cfg");
        assert!(driver.get_password().is_none());
    }

    /// Pure logic test of the eviction policy: given idle entries
    /// with various ages, the keep/drop split honors the pool's
    /// `idle_timeout` and `min_idle` settings. Doesn't open real
    /// connections — the policy is just arithmetic over `(idx, since)`
    /// pairs.
    #[test]
    fn eviction_policy_keeps_min_idle_and_drops_aged() {
        // The keep/drop math against (idle_timeout, min_idle) inputs.
        // Mirrors `evict_idle` so a future change to the loop forces
        // a test update.
        fn run_policy(
            entries: Vec<(usize, Instant)>,
            now: Instant,
            idle_timeout: Duration,
            min_idle: usize,
        ) -> (Vec<usize>, Vec<usize>) {
            let mut keep: Vec<usize> = Vec::new();
            let mut drop_idx: Vec<usize> = Vec::new();
            for (idx, since) in entries {
                let aged = now.duration_since(since) >= idle_timeout;
                if !aged || keep.len() < min_idle {
                    keep.push(idx);
                } else {
                    drop_idx.push(idx);
                }
            }
            (keep, drop_idx)
        }

        let now = Instant::now();
        let timeout = Duration::from_secs(300);
        let aged = now - timeout - Duration::from_secs(1);
        let fresh = now - Duration::from_secs(10);

        // min_idle = 1 (default): one aged entry survives, fresh
        // always survives, rest evicted.
        let (keep, drop_idx) = run_policy(
            vec![(0, aged), (1, aged), (2, fresh), (3, aged)],
            now,
            timeout,
            1,
        );
        assert_eq!(keep, vec![0, 2]);
        assert_eq!(drop_idx, vec![1, 3]);

        // min_idle = 0: all aged entries evicted, only fresh survives.
        let (keep, drop_idx) = run_policy(vec![(0, aged), (1, fresh), (2, aged)], now, timeout, 0);
        assert_eq!(keep, vec![1]);
        assert_eq!(drop_idx, vec![0, 2]);
    }
}