sentinel_driver/connection/
client.rs1use super::{
2 pipeline, startup, BackendMessage, BytesMut, CancelToken, Config, Connection, Duration, Error,
3 PgConnection, PipelineBatch, Result, StatementCache, ToSql, TransactionStatus,
4};
5use crate::config::{LoadBalanceHosts, TargetSessionAttrs};
6
7impl Connection {
8 pub async fn connect(config: Config) -> Result<Self> {
14 let mut hosts: Vec<(String, u16)> = config.hosts().to_vec();
15
16 if hosts.is_empty() {
17 hosts.push(("localhost".to_string(), 5432));
18 }
19
20 if config.load_balance_hosts() == LoadBalanceHosts::Random {
21 use rand::seq::SliceRandom;
22 use rand::thread_rng;
23 hosts.shuffle(&mut thread_rng());
24 }
25
26 let mut last_error: Option<Error> = None;
27
28 for (host, port) in &hosts {
29 match Self::try_connect_host(&config, host, *port).await {
30 Ok(conn) => return Ok(conn),
31 Err(e) => {
32 tracing::debug!(host = %host, port = %port, error = %e, "host failed");
33 last_error = Some(e);
34 }
35 }
36 }
37
38 Err(last_error.unwrap_or_else(|| Error::AllHostsFailed("no hosts configured".to_string())))
39 }
40
41 async fn try_connect_host(config: &Config, host: &str, port: u16) -> Result<Self> {
43 let mut conn = PgConnection::connect_host(config, host, port).await?;
44 let result = startup::startup(&mut conn, config).await?;
45
46 if config.target_session_attrs() != TargetSessionAttrs::Any {
48 startup::check_session_attrs(&mut conn, config.target_session_attrs()).await?;
49 }
50
51 let query_timeout = config.statement_timeout();
52
53 Ok(Self {
54 conn,
55 config: config.clone(),
56 connected_host: host.to_string(),
57 connected_port: port,
58 process_id: result.process_id,
59 secret_key: result.secret_key,
60 transaction_status: result.transaction_status,
61 stmt_cache: StatementCache::new(),
62 query_timeout,
63 is_broken: false,
64 instrumentation: config
65 .instrumentation
66 .clone()
67 .unwrap_or_else(crate::instrumentation::noop),
68 })
69 }
70
71 pub async fn close(self) -> Result<()> {
73 self.conn.close().await
74 }
75
76 pub fn cancel_token(&self) -> CancelToken {
81 CancelToken::new(
82 &self.connected_host,
83 self.connected_port,
84 self.process_id,
85 self.secret_key,
86 )
87 }
88
89 pub fn config(&self) -> &Config {
92 &self.config
93 }
94
95 pub fn connected_host(&self) -> &str {
97 &self.connected_host
98 }
99
100 pub fn connected_port(&self) -> u16 {
102 self.connected_port
103 }
104
105 pub fn is_tls(&self) -> bool {
106 self.conn.is_tls()
107 }
108
109 #[cfg(unix)]
111 pub fn is_unix(&self) -> bool {
112 self.conn.is_unix()
113 }
114
115 pub fn process_id(&self) -> i32 {
117 self.process_id
118 }
119
120 pub fn query_timeout(&self) -> Option<Duration> {
122 self.query_timeout
123 }
124
125 pub fn is_broken(&self) -> bool {
130 self.is_broken
131 }
132
133 pub fn transaction_status(&self) -> TransactionStatus {
135 self.transaction_status
136 }
137
138 pub(crate) fn pg_connection_mut(&mut self) -> &mut PgConnection {
140 &mut self.conn
141 }
142
143 pub(crate) async fn query_internal(
146 &mut self,
147 sql: &str,
148 params: &[&(dyn ToSql + Sync)],
149 ) -> Result<pipeline::QueryResult> {
150 self.instr().on_event(&crate::Event::ExecuteStart {
151 stmt: crate::StmtRef::Inline { sql },
152 param_count: params.len(),
153 });
154 let started = std::time::Instant::now();
155 let res = self.query_internal_inner(sql, params).await;
156 let duration = started.elapsed();
157 let (rows, outcome) = match &res {
158 Ok(pipeline::QueryResult::Rows(v)) => (v.len() as u64, crate::Outcome::Ok),
159 Ok(pipeline::QueryResult::Command(r)) => (r.rows_affected, crate::Outcome::Ok),
160 Err(e) => (0, crate::Outcome::Err(e)),
161 };
162 self.instr().on_event(&crate::Event::ExecuteFinish {
163 stmt: crate::StmtRef::Inline { sql },
164 rows,
165 duration,
166 outcome,
167 });
168 res
169 }
170
171 async fn query_internal_inner(
172 &mut self,
173 sql: &str,
174 params: &[&(dyn ToSql + Sync)],
175 ) -> Result<pipeline::QueryResult> {
176 let param_types: Vec<u32> = params.iter().map(|p| p.oid().0).collect();
178 let mut encoded_params: Vec<Option<Vec<u8>>> = Vec::with_capacity(params.len());
179
180 for param in params {
181 if param.is_null() {
182 encoded_params.push(None);
183 } else {
184 let mut buf = BytesMut::new();
185 param.to_sql(&mut buf)?;
186 encoded_params.push(Some(buf.to_vec()));
187 }
188 }
189
190 let mut batch = PipelineBatch::new();
192 batch.add(sql.to_string(), param_types, encoded_params);
193
194 let mut results = batch.execute(&mut self.conn).await?;
195
196 results
197 .pop()
198 .ok_or_else(|| Error::protocol("pipeline returned no results"))
199 }
200
201 pub(crate) async fn drain_until_ready(&mut self) -> Result<()> {
202 loop {
203 if let BackendMessage::ReadyForQuery { transaction_status } = self.conn.recv().await? {
204 self.transaction_status = transaction_status;
205 return Ok(());
206 }
207 }
208 }
209
210 pub fn set_instrumentation(&mut self, instr: std::sync::Arc<dyn crate::Instrumentation>) {
213 self.instrumentation = instr;
214 }
215
216 pub fn instrumentation(&self) -> &std::sync::Arc<dyn crate::Instrumentation> {
220 &self.instrumentation
221 }
222
223 pub(crate) fn instr(&self) -> &dyn crate::Instrumentation {
225 &*self.instrumentation
226 }
227}