Skip to main content

sentinel_driver/connection/
client.rs

1use super::{
2    pipeline, startup, BackendMessage, BytesMut, CancelToken, Config, Connection, Duration, Error,
3    PgConnection, PipelineBatch, Result, StatementCache, ToSql, TransactionStatus,
4};
5use crate::config::{LoadBalanceHosts, TargetSessionAttrs};
6
7impl Connection {
8    /// Connect to PostgreSQL and perform the startup handshake.
9    ///
10    /// With multiple hosts configured, tries each host in order (or shuffled
11    /// if `load_balance_hosts=random`) until one succeeds and matches the
12    /// required `target_session_attrs`.
13    pub async fn connect(config: Config) -> Result<Self> {
14        let mut hosts: Vec<(String, u16)> = config.hosts().to_vec();
15
16        if hosts.is_empty() {
17            hosts.push(("localhost".to_string(), 5432));
18        }
19
20        if config.load_balance_hosts() == LoadBalanceHosts::Random {
21            use rand::seq::SliceRandom;
22            use rand::thread_rng;
23            hosts.shuffle(&mut thread_rng());
24        }
25
26        let mut last_error: Option<Error> = None;
27
28        for (host, port) in &hosts {
29            match Self::try_connect_host(&config, host, *port).await {
30                Ok(conn) => return Ok(conn),
31                Err(e) => {
32                    tracing::debug!(host = %host, port = %port, error = %e, "host failed");
33                    last_error = Some(e);
34                }
35            }
36        }
37
38        Err(last_error.unwrap_or_else(|| Error::AllHostsFailed("no hosts configured".to_string())))
39    }
40
41    /// Try connecting to a single host, performing startup and session attrs check.
42    async fn try_connect_host(config: &Config, host: &str, port: u16) -> Result<Self> {
43        let mut conn = PgConnection::connect_host(config, host, port).await?;
44        let result = startup::startup(&mut conn, config).await?;
45
46        // Check target_session_attrs after successful auth
47        if config.target_session_attrs() != TargetSessionAttrs::Any {
48            startup::check_session_attrs(&mut conn, config.target_session_attrs()).await?;
49        }
50
51        let query_timeout = config.statement_timeout();
52
53        Ok(Self {
54            conn,
55            config: config.clone(),
56            connected_host: host.to_string(),
57            connected_port: port,
58            process_id: result.process_id,
59            secret_key: result.secret_key,
60            transaction_status: result.transaction_status,
61            stmt_cache: StatementCache::new(),
62            query_timeout,
63            is_broken: false,
64        })
65    }
66
67    /// Close the connection gracefully.
68    pub async fn close(self) -> Result<()> {
69        self.conn.close().await
70    }
71
72    /// Get a cancel token for this connection.
73    ///
74    /// The token can be cloned and sent to another task to cancel a
75    /// running query. See [`CancelToken`] for details.
76    pub fn cancel_token(&self) -> CancelToken {
77        CancelToken::new(
78            &self.connected_host,
79            self.connected_port,
80            self.process_id,
81            self.secret_key,
82        )
83    }
84
85    /// Returns `true` if the connection is using TLS.
86    /// Returns the configuration used for this connection.
87    pub fn config(&self) -> &Config {
88        &self.config
89    }
90
91    /// Returns the host this connection is connected to.
92    pub fn connected_host(&self) -> &str {
93        &self.connected_host
94    }
95
96    /// Returns the port this connection is connected to.
97    pub fn connected_port(&self) -> u16 {
98        self.connected_port
99    }
100
101    pub fn is_tls(&self) -> bool {
102        self.conn.is_tls()
103    }
104
105    /// Returns `true` if connected via Unix domain socket.
106    #[cfg(unix)]
107    pub fn is_unix(&self) -> bool {
108        self.conn.is_unix()
109    }
110
111    /// The server process ID for this connection.
112    pub fn process_id(&self) -> i32 {
113        self.process_id
114    }
115
116    /// Returns the configured query timeout, if any.
117    pub fn query_timeout(&self) -> Option<Duration> {
118        self.query_timeout
119    }
120
121    /// Returns `true` if the connection has been marked broken by a timeout.
122    ///
123    /// A broken connection should be discarded — the server state is
124    /// indeterminate after a cancelled query.
125    pub fn is_broken(&self) -> bool {
126        self.is_broken
127    }
128
129    /// Current transaction status.
130    pub fn transaction_status(&self) -> TransactionStatus {
131        self.transaction_status
132    }
133
134    /// Access the underlying PgConnection mutably.
135    pub(crate) fn pg_connection_mut(&mut self) -> &mut PgConnection {
136        &mut self.conn
137    }
138
139    // ── Internal ─────────────────────────────────────
140
141    pub(crate) async fn query_internal(
142        &mut self,
143        sql: &str,
144        params: &[&(dyn ToSql + Sync)],
145    ) -> Result<pipeline::QueryResult> {
146        // Encode parameters
147        let param_types: Vec<u32> = params.iter().map(|p| p.oid().0).collect();
148        let mut encoded_params: Vec<Option<Vec<u8>>> = Vec::with_capacity(params.len());
149
150        for param in params {
151            if param.is_null() {
152                encoded_params.push(None);
153            } else {
154                let mut buf = BytesMut::new();
155                param.to_sql(&mut buf)?;
156                encoded_params.push(Some(buf.to_vec()));
157            }
158        }
159
160        // Use pipeline for single query (same protocol, consistent code path)
161        let mut batch = PipelineBatch::new();
162        batch.add(sql.to_string(), param_types, encoded_params);
163
164        let mut results = batch.execute(&mut self.conn).await?;
165
166        results
167            .pop()
168            .ok_or_else(|| Error::protocol("pipeline returned no results"))
169    }
170
171    pub(crate) async fn drain_until_ready(&mut self) -> Result<()> {
172        loop {
173            if let BackendMessage::ReadyForQuery { transaction_status } = self.conn.recv().await? {
174                self.transaction_status = transaction_status;
175                return Ok(());
176            }
177        }
178    }
179}