1use crate::client::{Addr, SocketConfig};
2use crate::config::{Host, LoadBalanceHosts, TargetSessionAttrs};
3use crate::connect_raw::connect_raw;
4use crate::connect_socket::connect_socket;
5use crate::tls::MakeTlsConnect;
6use crate::{Client, Config, Connection, Error, NoTls, SimpleQueryMessage, Socket};
7use futures_util::{future, pin_mut, Future, FutureExt, Stream};
8use lazy_static::lazy_static;
9use log::{debug, info};
10use rand::seq::SliceRandom;
11use rand::Rng;
12use std::collections::HashMap;
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::sync::Mutex;
15use std::task::Poll;
16use std::time::Instant;
17use std::{cmp, io};
18use tokio::net;
19use tokio::sync::Mutex as TokioMutex;
20
21lazy_static! {
22 static ref CONNECTION_COUNT_MAP: Mutex<HashMap<Host, i64>> = {
23 let mut m = HashMap::new();
24 let host_list_primary = HOST_INFO_PRIMAY.lock().unwrap().clone();
25 let host_list_rr = HOST_INFO_RR.lock().unwrap().clone();
26 let host_list = [host_list_primary, host_list_rr].concat();
27 let size = host_list.len();
28 for i in 0..size {
29 let host = host_list.get(i);
30 if host.is_some() {
31 m.insert(host.unwrap().clone(), 0);
32 }
33 }
34 Mutex::new(m)
35 };
36 static ref LAST_TIME_META_DATA_FETCHED: TokioMutex<Instant> = {
37 let m = Instant::now();
38 TokioMutex::new(m)
39 };
40 static ref HOST_INFO_PRIMAY: Mutex<Vec<Host>> = {
41 let m = Vec::new();
42 Mutex::new(m)
43 };
44 static ref HOST_INFO_RR: Mutex<Vec<Host>> = {
45 let m = Vec::new();
46 Mutex::new(m)
47 };
48 static ref FAILED_HOSTS: Mutex<HashMap<Host, Instant>> = {
49 let m = HashMap::new();
50 Mutex::new(m)
51 };
52 pub(crate) static ref PLACEMENT_INFO_MAP_PRIMARY: Mutex<HashMap<String, Vec<Host>>> = {
53 let m = HashMap::new();
54 Mutex::new(m)
55 };
56 pub(crate) static ref PLACEMENT_INFO_MAP_RR: Mutex<HashMap<String, Vec<Host>>> = {
57 let m = HashMap::new();
58 Mutex::new(m)
59 };
60 static ref PUBLIC_HOST_MAP: Mutex<HashMap<Host, Host>> = {
61 let m = HashMap::new();
62 Mutex::new(m)
63 };
64 static ref HOST_TO_PORT_MAP: Mutex<HashMap<Host, u16>> = {
65 let m = HashMap::new();
66 Mutex::new(m)
67 };
68}
69
70static USE_PUBLIC_IP: AtomicBool = AtomicBool::new(false);
71
72pub async fn connect<T>(
73 mut tls: T,
74 config: &Config,
75) -> Result<(Client, Connection<Socket, T::Stream>), Error>
76where
77 T: MakeTlsConnect<Socket>,
78{
79 if config.host.is_empty() && config.hostaddr.is_empty() {
80 return Err(Error::config("both host and hostaddr are missing".into()));
81 }
82
83 if !config.host.is_empty()
84 && !config.hostaddr.is_empty()
85 && config.host.len() != config.hostaddr.len()
86 {
87 let msg = format!(
88 "number of hosts ({}) is different from number of hostaddrs ({})",
89 config.host.len(),
90 config.hostaddr.len(),
91 );
92 return Err(Error::config(msg.into()));
93 }
94
95 let num_hosts = cmp::max(config.host.len(), config.hostaddr.len());
99
100 if config.port.len() > 1 && config.port.len() != num_hosts {
101 return Err(Error::config("invalid number of ports".into()));
102 }
103
104 let mut indices = (0..num_hosts).collect::<Vec<_>>();
105 if config.load_balance_hosts == LoadBalanceHosts::Random {
106 indices.shuffle(&mut rand::thread_rng());
107 }
108
109 let mut error = None;
110 for i in indices {
111 let host = config.host.get(i);
112 let hostaddr = config.hostaddr.get(i);
113 let port = config
114 .port
115 .get(i)
116 .or_else(|| config.port.first())
117 .copied()
118 .unwrap_or(5433);
119
120 let hostname = match host {
122 Some(Host::Tcp(host)) => Some(host.clone()),
123 #[cfg(unix)]
125 Some(Host::Unix(_)) => None,
126 None => None,
127 };
128
129 let addr = match hostaddr {
132 Some(ipaddr) => Host::Tcp(ipaddr.to_string()),
133 None => host.cloned().unwrap(),
134 };
135
136 match connect_host(addr, hostname, port, &mut tls, config).await {
137 Ok((client, connection)) => return Ok((client, connection)),
138 Err(e) => error = Some(e),
139 }
140 }
141
142 Err(error.unwrap())
143}
144
145pub async fn yb_connect<T>(
146 mut tls: T,
147 config: &Config,
148) -> Result<(Client, Connection<Socket, T::Stream>), Error>
149where
150 T: MakeTlsConnect<Socket>,
151{
152 if config.host.is_empty() && config.hostaddr.is_empty() {
153 return Err(Error::config("both host and hostaddr are missing".into()));
154 }
155
156 if !config.host.is_empty()
157 && !config.hostaddr.is_empty()
158 && config.host.len() != config.hostaddr.len()
159 {
160 let msg = format!(
161 "number of hosts ({}) is different from number of hostaddrs ({})",
162 config.host.len(),
163 config.hostaddr.len(),
164 );
165 return Err(Error::config(msg.into()));
166 }
167
168 let num_hosts = cmp::max(config.host.len(), config.hostaddr.len());
172
173 if config.port.len() > 1 && config.port.len() != num_hosts {
174 return Err(Error::config("invalid number of ports".into()));
175 }
176
177 if !check_and_refresh(config).await {
178 return Err(Error::connect(io::Error::new(
179 io::ErrorKind::ConnectionRefused,
180 "could not create control connection",
181 )));
182 }
183
184 let host_to_port_map = HOST_TO_PORT_MAP.lock().unwrap().clone();
185
186 loop {
187 let newhost = get_least_loaded_server(config);
188 let mut host = match newhost {
189 Ok(host) => host,
190 Err(e) => {
191 return Err(e);
193 }
194 };
195
196 increase_connection_count(host.clone());
197
198 if USE_PUBLIC_IP.load(Ordering::SeqCst) {
200 let public_host_map = PUBLIC_HOST_MAP.lock().unwrap().clone();
201 let public_host = public_host_map.get(&host.clone());
202 if public_host.is_none() {
203 info!("Public host not available for private host {:?}, adding this to failed host list and trying another server", host.clone());
204 decrease_connection_count(host.clone());
205 add_to_failed_host_list(host.clone());
206 continue;
207 } else {
208 host = public_host.unwrap().clone();
209 }
210 }
211
212 let hostname = match host.clone() {
213 Host::Tcp(host) => Some(host),
214 #[cfg(unix)]
216 Host::Unix(_) => None,
217 };
218
219 info!("Creating connection to {:?}", hostname.clone());
220 match connect_host(
221 host.clone(),
222 hostname.clone(),
223 host_to_port_map[&(host.clone())],
224 &mut tls,
225 config,
226 )
227 .await
228 {
229 Ok((client, connection)) => return Ok((client, connection)),
230 Err(_e) => {
231 info!("Not able to create connection to {:?}, adding it to failed host list and trying a different host.", hostname.clone());
232 decrease_connection_count(host.clone());
233 add_to_failed_host_list(host);
234 }
235 }
236 }
237}
238
239fn increase_connection_count(host: Host) {
240 let mut conn_map = CONNECTION_COUNT_MAP.lock().unwrap();
241 let count = conn_map.get(&host);
242 if count.is_none() {
243 conn_map.insert(host.clone(), 1);
244 debug!("Increasing connection count for {:?} to 1", host.clone());
245 } else {
246 let mut conn_count: i64 = *count.unwrap();
247 conn_count += 1;
248 conn_map.insert(host.clone(), conn_count);
249 debug!(
250 "Increasing connection count for {:?} by one: {}",
251 host.clone(),
252 conn_count
253 );
254 }
255}
256
257pub(crate) fn decrease_connection_count(host: Host) {
258 let mut conn_map = CONNECTION_COUNT_MAP.lock().unwrap();
259 let count = conn_map.get(&host);
260 if count.is_some() {
261 let mut conn_count: i64 = *count.unwrap();
262 if conn_count != 0 {
263 conn_count -= 1;
264 conn_map.insert(host.clone(), conn_count);
265 debug!(
266 "Decremented connection count for {:?} by one: {}",
267 host.clone(),
268 conn_count
269 );
270 }
271 }
272}
273
274fn get_least_loaded_server(config: &Config) -> Result<Host, Error> {
275 let conn_map = CONNECTION_COUNT_MAP.lock().unwrap().clone();
276 let host_list_primary = HOST_INFO_PRIMAY.lock().unwrap().clone();
277 let host_list_rr = HOST_INFO_RR.lock().unwrap().clone();
278 let failed_host_list = FAILED_HOSTS.lock().unwrap().clone();
279 let placement_info_map_primary = PLACEMENT_INFO_MAP_PRIMARY.lock().unwrap().clone();
280 let placement_info_map_rr = PLACEMENT_INFO_MAP_RR.lock().unwrap().clone();
281 let mut least_host: Vec<Host> = Vec::new();
282
283 let mut host_list: Vec<Host>;
284 let mut placement_info_map: HashMap<String, Vec<Host>>;
285
286 if config.load_balance == "only-rr" || config.load_balance == "prefer-rr" {
287 host_list = host_list_rr.clone();
288 placement_info_map = placement_info_map_rr.clone();
289 } else if config.load_balance == "only-primary" || config.load_balance == "prefer-primary" {
290 host_list = host_list_primary.clone();
291 placement_info_map = placement_info_map_primary.clone();
292 } else {
293 host_list = host_list_rr.clone();
294 placement_info_map = placement_info_map_rr.clone();
295 host_list.extend(host_list_primary.clone());
296 for (key, value) in placement_info_map_primary {
297 if let Some(vec) = placement_info_map.get_mut(&key) {
298 vec.extend(value);
299 } else {
300 placement_info_map.insert(key, value);
301 }
302 }
303 }
304
305 if !config.topology_keys.is_empty() {
306 for i in 0..config.topology_keys.len() as i64 {
307 let mut server: Vec<Host> = Vec::new();
308 let prefered_zone = config.topology_keys.get(&(i + 1)).unwrap();
309 for placement_info in prefered_zone.iter() {
310 let to_check_star: Vec<&str> = placement_info.split(".").collect();
311 if to_check_star[2] == "*" {
312 let star_placement_info: String =
313 to_check_star[0].to_owned() + "." + to_check_star[1];
314 let append_hosts = placement_info_map.get(&star_placement_info);
315 if let Some(append_hosts_value) = append_hosts {
316 server.extend(append_hosts_value.to_owned());
317 }
318 } else {
319 let append_hosts = placement_info_map.get(placement_info);
320 if let Some(append_hosts_value) = append_hosts {
321 server.extend(append_hosts_value.to_owned());
322 }
323 }
324 }
325 least_host = get_least_loaded_hosts(server, conn_map.clone(), failed_host_list.clone());
326
327 if !least_host.is_empty() {
328 break;
329 }
330 }
331 }
332
333 if least_host.is_empty() {
334 if !(config.load_balance == "prefer-primary" || config.load_balance == "prefer-rr") {
335 if config.topology_keys.is_empty() || !config.fallback_to_topology_keys_only {
336 least_host = get_least_loaded_hosts(host_list, conn_map.clone(), failed_host_list.clone());
337 } else {
338 return Err(Error::connect(io::Error::new(
339 io::ErrorKind::ConnectionRefused,
340 "no preferred server available, fallback-to-topology-keys-only is set to true",
341 )));
342 }
343 } else {
344 least_host = get_least_loaded_hosts(host_list, conn_map.clone(), failed_host_list.clone());
345 if least_host.is_empty() {
346 if config.load_balance == "prefer-rr"{
347 least_host = get_least_loaded_hosts(host_list_primary, conn_map.clone(), failed_host_list.clone());
348 } else {
349 least_host = get_least_loaded_hosts(host_list_rr, conn_map.clone(), failed_host_list.clone());
350 }
351 }
352 }
353 }
354
355 if !least_host.is_empty() {
356 info!(
357 "Following hosts have the least number of connections: {:?}, chosing one randomly",
358 least_host
359 );
360 let num = rand::thread_rng().gen_range(0..least_host.len());
361 Ok(least_host.get(num).cloned().expect("least loaded host value is None"))
362 } else {
363 Err(Error::connect(io::Error::new(
364 io::ErrorKind::ConnectionRefused,
365 "could not find a server to connect to",
366 )))
367 }
368}
369
370fn get_least_loaded_hosts(hosts: Vec<Host>, conn_map: HashMap<Host, i64>, failed_hosts: HashMap<Host, Instant>) -> Vec<Host> {
371 let mut min_count = i64::MAX;
372 let mut least_host: Vec<Host> = Vec::new();
373 for host in hosts.iter() {
374 if !failed_hosts.contains_key(host) {
375 let count = conn_map.get(host);
376 let mut counter: i64 = 0;
377 if count.is_some() {
378 counter = *count.unwrap();
379 }
380 if min_count > counter {
381 min_count = counter;
382 least_host.clear();
383 least_host.push(host.clone());
384 } else if min_count == counter {
385 least_host.push(host.clone());
386 }
387 }
388 }
389 least_host
390}
391
392async fn check_and_refresh(config: &Config) -> bool {
393 let mut refresh_time = LAST_TIME_META_DATA_FETCHED.lock().await;
394 let host_list_primary = HOST_INFO_PRIMAY.lock().unwrap().clone();
395 let host_list_rr = HOST_INFO_RR.lock().unwrap().clone();
396 let host_list = [host_list_primary, host_list_rr].concat();
397 if host_list.is_empty() {
398 info!("Connecting to the server for the first time");
399 if let Ok((client, connection)) = connect(NoTls, config).await {
400 let handle = tokio::spawn(async move {
401 if let Err(e) = connection.await {
402 eprintln!("connection error: {}", e);
403 }
404 });
405 info!("Control connection created to one of {:?}", config.host);
406 refresh(client, config).await;
407 let start = Instant::now();
408 *refresh_time = start;
409 info!("Resetting LAST_TIME_META_DATA_FETCHED");
410 handle.abort();
411 return true;
412 } else {
413 info!("Failed to establish control connection to available servers");
414 return false;
415 }
416 } else {
417 let duration = refresh_time.elapsed();
418 if duration > config.yb_servers_refresh_interval {
419 let host_to_port_map = HOST_TO_PORT_MAP.lock().unwrap().clone();
420 let mut index = 0;
421 while index < host_list.len() {
422 let host = host_list.get(index);
423 let mut conn_host = host.unwrap().to_owned();
424 if USE_PUBLIC_IP.load(Ordering::SeqCst) {
426 let public_host_map = PUBLIC_HOST_MAP.lock().unwrap().clone();
427 let public_host = public_host_map.get(&conn_host.clone());
428 if public_host.is_none() {
429 info!("Public host not available for private host {:?}, adding this to failed host list and trying another server", conn_host.clone());
430 add_to_failed_host_list(host.cloned().unwrap());
431 index += 1;
432 continue;
433 } else {
434 conn_host = public_host.unwrap().clone();
435 }
436 }
437
438 let hostname = match conn_host.clone() {
440 Host::Tcp(host) => Some(host),
441 #[cfg(unix)]
443 Host::Unix(_) => None,
444 };
445
446 if let Ok((client, connection)) = connect_host(
447 conn_host.clone(),
448 hostname.clone(),
449 host_to_port_map[&(conn_host.clone())],
450 &mut NoTls,
451 config,
452 )
453 .await
454 {
455 let handle = tokio::spawn(async move {
456 if let Err(e) = connection.await {
457 eprintln!("connection error: {}", e);
458 }
459 });
460 info!("Control connection created to {:?}", hostname.clone());
461 refresh(client, config).await;
462 let start = Instant::now();
463 *refresh_time = start;
464 info!("Resetting LAST_TIME_META_DATA_FETCHED");
465 handle.abort();
466 return true;
467 } else {
468 info!("Failed to establish control connection to {:?}, adding this to failed host list and trying another server", hostname.clone());
469 add_to_failed_host_list(host.cloned().unwrap());
470 index += 1;
471 }
472 }
473 info!("Failed to establish control connection to available servers");
474 return false;
475 }
476 }
477 true
478}
479
480fn add_to_failed_host_list(host: Host) {
481 let mut failedhostlist = FAILED_HOSTS.lock().unwrap();
482 failedhostlist.insert(host.clone(), Instant::now());
483 info!("Added {:?} to failed host list", host.clone());
484}
485
486async fn refresh(client: Client, config: &Config) {
487 let socket_config = client.get_socket_config();
488 let mut control_conn_host: String = String::new();
489 if socket_config.is_some() {
490 control_conn_host = socket_config.unwrap().hostname.unwrap();
491 }
492
493 info!("Executing query: `select * from yb_servers()` to fetch list of servers");
494 let rows = client
495 .query("select * from yb_servers()", &[])
496 .await
497 .unwrap();
498
499 let mut host_list_primary = HOST_INFO_PRIMAY.lock().unwrap();
500 let mut host_list_rr = HOST_INFO_RR.lock().unwrap();
501 let mut failed_host_list = FAILED_HOSTS.lock().unwrap();
502 let mut placement_info_map_primary = PLACEMENT_INFO_MAP_PRIMARY.lock().unwrap();
503 let mut placement_info_map_rr = PLACEMENT_INFO_MAP_RR.lock().unwrap();
504 let mut public_host_map = PUBLIC_HOST_MAP.lock().unwrap();
505 let mut host_to_port_map = HOST_TO_PORT_MAP.lock().unwrap();
506 for row in rows {
507 let host_string: String = row.get("host");
508 let host = Host::Tcp(host_string.to_string());
509 info!("Received entry for host {:?}", host);
510 let nodetype: String = row.get("node_type");
511 let portvalue: i64 = row.get("port");
512 let port: u16 = portvalue as u16;
513 let cloud: String = row.get("cloud");
514 let region: String = row.get("region");
515 let zone: String = row.get("zone");
516 let public_ip_string: String = row.get("public_ip");
517 let public_ip = Host::Tcp(public_ip_string.to_string());
518 let placement_zone: String = cloud.clone() + "." + ®ion + "." + &zone;
519 let star_placement_zone: String = cloud.clone() + "." + ®ion;
520
521 host_to_port_map.insert(host.clone(), port);
522 host_to_port_map.insert(public_ip.clone(), port);
523
524 if control_conn_host.eq_ignore_ascii_case(&public_ip_string) {
525 USE_PUBLIC_IP.store(true, Ordering::SeqCst);
526 }
527
528 if !failed_host_list.contains_key(&host) {
529 if nodetype == "primary" {
530 if !host_list_primary.contains(&host) {
531 host_list_primary.push(host.clone());
532 public_host_map.insert(host.clone(), public_ip.clone());
533 debug!("Added {:?} to host list primary", host.clone());
534 }
535 } else {
536 if !host_list_rr.contains(&host) {
537 host_list_rr.push(host.clone());
538 public_host_map.insert(host.clone(), public_ip.clone());
539 debug!("Added {:?} to host list RR", host.clone());
540 }
541 }
542 } else {
543 if failed_host_list.get(&host).unwrap().elapsed()
544 > config.failed_host_reconnect_delay_secs
545 {
546 failed_host_list.remove(&host);
547 debug!(
548 "Marking {:?} as UP since failed-host-reconnect-delay-secs has elapsed",
549 host.clone()
550 );
551 if nodetype == "primary" {
552 if !host_list_primary.contains(&host) {
553 host_list_primary.push(host.clone());
554 public_host_map.insert(host.clone(), public_ip.clone());
555 debug!("Added {:?} to host list primary", host.clone());
556 }
557 } else {
558 if !host_list_rr.contains(&host) {
559 host_list_rr.push(host.clone());
560 public_host_map.insert(host.clone(), public_ip.clone());
561 debug!("Added {:?} to host list RR", host.clone());
562 }
563 }
564 make_connection_count_zero(host.clone());
565 } else if host_list_primary.contains(&host) || host_list_rr.contains(&host) {
566 debug!(
567 "Treating {:?} as DOWN since failed-host-reconnect-delay-secs has not elapsed",
568 host.clone()
569 );
570 if host_list_primary.contains(&host) {
571 let index = host_list_primary.iter().position(|x| *x == host).unwrap();
572 host_list_primary.remove(index);
573 } else {
574 let index = host_list_rr.iter().position(|x| *x == host).unwrap();
575 host_list_rr.remove(index);
576 }
577 public_host_map.remove(&host);
578 }
579 }
580
581 if nodetype == "primary" {
582 if placement_info_map_primary.contains_key(&placement_zone) {
583 let mut present_hosts = placement_info_map_primary.get(&placement_zone).unwrap().to_vec();
584 if !present_hosts.contains(&host) {
585 present_hosts.push(host.clone());
586 placement_info_map_primary.insert(placement_zone.clone(), present_hosts.to_vec());
587 }
588 } else {
589 let mut host_vec: Vec<Host> = Vec::new();
590 host_vec.push(host.clone());
591 placement_info_map_primary.insert(placement_zone.clone(), host_vec);
592 }
593
594 if placement_info_map_primary.contains_key(&star_placement_zone) {
595 let mut star_present_hosts = placement_info_map_primary
596 .get(&star_placement_zone)
597 .unwrap()
598 .to_vec();
599 if !star_present_hosts.contains(&host) {
600 star_present_hosts.push(host.clone());
601 placement_info_map_primary.insert(star_placement_zone.clone(), star_present_hosts.to_vec());
602 }
603 } else {
604 let mut star_host_vec: Vec<Host> = Vec::new();
605 star_host_vec.push(host.clone());
606 placement_info_map_primary.insert(star_placement_zone.clone(), star_host_vec);
607 }
608 } else {
609 if placement_info_map_rr.contains_key(&placement_zone) {
610 let mut present_hosts = placement_info_map_rr.get(&placement_zone).unwrap().to_vec();
611 if !present_hosts.contains(&host) {
612 present_hosts.push(host.clone());
613 placement_info_map_rr.insert(placement_zone, present_hosts.to_vec());
614 }
615 } else {
616 let mut host_vec: Vec<Host> = Vec::new();
617 host_vec.push(host.clone());
618 placement_info_map_rr.insert(placement_zone, host_vec);
619 }
620
621 if placement_info_map_rr.contains_key(&star_placement_zone) {
622 let mut star_present_hosts = placement_info_map_rr
623 .get(&star_placement_zone)
624 .unwrap()
625 .to_vec();
626 if !star_present_hosts.contains(&host) {
627 star_present_hosts.push(host.clone());
628 placement_info_map_rr.insert(star_placement_zone, star_present_hosts.to_vec());
629 }
630 } else {
631 let mut star_host_vec: Vec<Host> = Vec::new();
632 star_host_vec.push(host.clone());
633 placement_info_map_rr.insert(star_placement_zone, star_host_vec);
634 }
635 }
636 }
637}
638
639fn make_connection_count_zero(host: Host) {
640 let mut conn_map = CONNECTION_COUNT_MAP.lock().unwrap();
641 let count = conn_map.get(&host);
642 if count.is_none() {
643 return;
644 }
645 conn_map.insert(host.clone(), 0);
646 debug!("Resetting connection count for {:?} to zero", host.clone());
647}
648
649
650async fn connect_host<T>(
651 host: Host,
652 hostname: Option<String>,
653 port: u16,
654 tls: &mut T,
655 config: &Config,
656) -> Result<(Client, Connection<Socket, T::Stream>), Error>
657where
658 T: MakeTlsConnect<Socket>,
659{
660 match host {
661 Host::Tcp(host) => {
662 let mut addrs = net::lookup_host((&*host, port))
663 .await
664 .map_err(Error::connect)?
665 .collect::<Vec<_>>();
666
667 if config.load_balance_hosts == LoadBalanceHosts::Random {
668 addrs.shuffle(&mut rand::thread_rng());
669 }
670
671 let mut last_err = None;
672 for addr in addrs {
673 match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config)
674 .await
675 {
676 Ok(stream) => return Ok(stream),
677 Err(e) => {
678 last_err = Some(e);
679 continue;
680 }
681 };
682 }
683
684 Err(last_err.unwrap_or_else(|| {
685 Error::connect(io::Error::new(
686 io::ErrorKind::InvalidInput,
687 "could not resolve any addresses",
688 ))
689 }))
690 }
691 #[cfg(unix)]
692 Host::Unix(path) => {
693 connect_once(Addr::Unix(path), hostname.as_deref(), port, tls, config).await
694 }
695 }
696}
697
698async fn connect_once<T>(
699 addr: Addr,
700 hostname: Option<&str>,
701 port: u16,
702 tls: &mut T,
703 config: &Config,
704) -> Result<(Client, Connection<Socket, T::Stream>), Error>
705where
706 T: MakeTlsConnect<Socket>,
707{
708 let socket = connect_socket(
709 &addr,
710 port,
711 config.connect_timeout,
712 config.tcp_user_timeout,
713 if config.keepalives {
714 Some(&config.keepalive_config)
715 } else {
716 None
717 },
718 )
719 .await?;
720
721 let tls = tls
722 .make_tls_connect(hostname.unwrap_or(""))
723 .map_err(|e| Error::tls(e.into()))?;
724 let has_hostname = hostname.is_some();
725 let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?;
726
727 if let TargetSessionAttrs::ReadWrite = config.target_session_attrs {
728 let rows = client.simple_query_raw("SHOW transaction_read_only");
729 pin_mut!(rows);
730
731 let rows = future::poll_fn(|cx| {
732 if connection.poll_unpin(cx)?.is_ready() {
733 return Poll::Ready(Err(Error::closed()));
734 }
735
736 rows.as_mut().poll(cx)
737 })
738 .await?;
739 pin_mut!(rows);
740
741 loop {
742 let next = future::poll_fn(|cx| {
743 if connection.poll_unpin(cx)?.is_ready() {
744 return Poll::Ready(Some(Err(Error::closed())));
745 }
746
747 rows.as_mut().poll_next(cx)
748 });
749
750 match next.await.transpose()? {
751 Some(SimpleQueryMessage::Row(row)) => {
752 if row.try_get(0)? == Some("on") {
753 return Err(Error::connect(io::Error::new(
754 io::ErrorKind::PermissionDenied,
755 "database does not allow writes",
756 )));
757 } else {
758 break;
759 }
760 }
761 Some(_) => {}
762 None => return Err(Error::unexpected_message()),
763 }
764 }
765 }
766
767 client.set_socket_config(SocketConfig {
768 addr,
769 hostname: hostname.map(|s| s.to_string()),
770 port,
771 connect_timeout: config.connect_timeout,
772 tcp_user_timeout: config.tcp_user_timeout,
773 keepalive: if config.keepalives {
774 Some(config.keepalive_config.clone())
775 } else {
776 None
777 },
778 });
779
780 Ok((client, connection))
781}