yb_tokio_postgres/
connect.rs

1use crate::client::{Addr, SocketConfig};
2use crate::config::{Host, LoadBalanceHosts, TargetSessionAttrs};
3use crate::connect_raw::connect_raw;
4use crate::connect_socket::connect_socket;
5use crate::tls::MakeTlsConnect;
6use crate::{Client, Config, Connection, Error, NoTls, SimpleQueryMessage, Socket};
7use futures_util::{future, pin_mut, Future, FutureExt, Stream};
8use lazy_static::lazy_static;
9use log::{debug, info};
10use rand::seq::SliceRandom;
11use rand::Rng;
12use std::collections::HashMap;
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::sync::Mutex;
15use std::task::Poll;
16use std::time::Instant;
17use std::{cmp, io};
18use tokio::net;
19use tokio::sync::Mutex as TokioMutex;
20
21lazy_static! {
22    static ref CONNECTION_COUNT_MAP: Mutex<HashMap<Host, i64>> = {
23        let mut m = HashMap::new();
24        let host_list_primary = HOST_INFO_PRIMAY.lock().unwrap().clone();
25        let host_list_rr = HOST_INFO_RR.lock().unwrap().clone();
26        let host_list = [host_list_primary, host_list_rr].concat();
27        let size = host_list.len();
28        for i in 0..size {
29            let host = host_list.get(i);
30            if host.is_some() {
31                m.insert(host.unwrap().clone(), 0);
32            }
33        }
34        Mutex::new(m)
35    };
36    static ref LAST_TIME_META_DATA_FETCHED: TokioMutex<Instant> = {
37        let m = Instant::now();
38        TokioMutex::new(m)
39    };
40    static ref HOST_INFO_PRIMAY: Mutex<Vec<Host>> = {
41        let m = Vec::new();
42        Mutex::new(m)
43    };
44    static ref HOST_INFO_RR: Mutex<Vec<Host>> = {
45        let m = Vec::new();
46        Mutex::new(m)
47    };
48    static ref FAILED_HOSTS: Mutex<HashMap<Host, Instant>> = {
49        let m = HashMap::new();
50        Mutex::new(m)
51    };
52    pub(crate) static ref PLACEMENT_INFO_MAP_PRIMARY: Mutex<HashMap<String, Vec<Host>>> = {
53        let m = HashMap::new();
54        Mutex::new(m)
55    };
56    pub(crate) static ref PLACEMENT_INFO_MAP_RR: Mutex<HashMap<String, Vec<Host>>> = {
57        let m = HashMap::new();
58        Mutex::new(m)
59    };
60    static ref PUBLIC_HOST_MAP: Mutex<HashMap<Host, Host>> = {
61        let m = HashMap::new();
62        Mutex::new(m)
63    };
64    static ref HOST_TO_PORT_MAP: Mutex<HashMap<Host, u16>> = {
65        let m = HashMap::new();
66        Mutex::new(m)
67    };
68}
69
70static USE_PUBLIC_IP: AtomicBool = AtomicBool::new(false);
71
72pub async fn connect<T>(
73    mut tls: T,
74    config: &Config,
75) -> Result<(Client, Connection<Socket, T::Stream>), Error>
76where
77    T: MakeTlsConnect<Socket>,
78{
79    if config.host.is_empty() && config.hostaddr.is_empty() {
80        return Err(Error::config("both host and hostaddr are missing".into()));
81    }
82
83    if !config.host.is_empty()
84        && !config.hostaddr.is_empty()
85        && config.host.len() != config.hostaddr.len()
86    {
87        let msg = format!(
88            "number of hosts ({}) is different from number of hostaddrs ({})",
89            config.host.len(),
90            config.hostaddr.len(),
91        );
92        return Err(Error::config(msg.into()));
93    }
94
95    // At this point, either one of the following two scenarios could happen:
96    // (1) either config.host or config.hostaddr must be empty;
97    // (2) if both config.host and config.hostaddr are NOT empty; their lengths must be equal.
98    let num_hosts = cmp::max(config.host.len(), config.hostaddr.len());
99
100    if config.port.len() > 1 && config.port.len() != num_hosts {
101        return Err(Error::config("invalid number of ports".into()));
102    }
103
104    let mut indices = (0..num_hosts).collect::<Vec<_>>();
105    if config.load_balance_hosts == LoadBalanceHosts::Random {
106        indices.shuffle(&mut rand::thread_rng());
107    }
108
109    let mut error = None;
110    for i in indices {
111        let host = config.host.get(i);
112        let hostaddr = config.hostaddr.get(i);
113        let port = config
114            .port
115            .get(i)
116            .or_else(|| config.port.first())
117            .copied()
118            .unwrap_or(5433);
119
120        // The value of host is used as the hostname for TLS validation,
121        let hostname = match host {
122            Some(Host::Tcp(host)) => Some(host.clone()),
123            // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
124            #[cfg(unix)]
125            Some(Host::Unix(_)) => None,
126            None => None,
127        };
128
129        // Try to use the value of hostaddr to establish the TCP connection,
130        // fallback to host if hostaddr is not present.
131        let addr = match hostaddr {
132            Some(ipaddr) => Host::Tcp(ipaddr.to_string()),
133            None => host.cloned().unwrap(),
134        };
135
136        match connect_host(addr, hostname, port, &mut tls, config).await {
137            Ok((client, connection)) => return Ok((client, connection)),
138            Err(e) => error = Some(e),
139        }
140    }
141
142    Err(error.unwrap())
143}
144
145pub async fn yb_connect<T>(
146    mut tls: T,
147    config: &Config,
148) -> Result<(Client, Connection<Socket, T::Stream>), Error>
149where
150    T: MakeTlsConnect<Socket>,
151{
152    if config.host.is_empty() && config.hostaddr.is_empty() {
153        return Err(Error::config("both host and hostaddr are missing".into()));
154    }
155
156    if !config.host.is_empty()
157        && !config.hostaddr.is_empty()
158        && config.host.len() != config.hostaddr.len()
159    {
160        let msg = format!(
161            "number of hosts ({}) is different from number of hostaddrs ({})",
162            config.host.len(),
163            config.hostaddr.len(),
164        );
165        return Err(Error::config(msg.into()));
166    }
167
168    // At this point, either one of the following two scenarios could happen:
169    // (1) either config.host or config.hostaddr must be empty;
170    // (2) if both config.host and config.hostaddr are NOT empty; their lengths must be equal.
171    let num_hosts = cmp::max(config.host.len(), config.hostaddr.len());
172
173    if config.port.len() > 1 && config.port.len() != num_hosts {
174        return Err(Error::config("invalid number of ports".into()));
175    }
176
177    if !check_and_refresh(config).await {
178        return Err(Error::connect(io::Error::new(
179            io::ErrorKind::ConnectionRefused,
180            "could not create control connection",
181        )));
182    }
183
184    let host_to_port_map = HOST_TO_PORT_MAP.lock().unwrap().clone();
185
186    loop {
187        let newhost = get_least_loaded_server(config);
188        let mut host = match newhost {
189            Ok(host) => host,
190            Err(e) => {
191                // Throw the error
192                return Err(e);
193            }
194        };
195
196        increase_connection_count(host.clone());
197
198        //check if we are to use public hosts
199        if USE_PUBLIC_IP.load(Ordering::SeqCst) {
200            let public_host_map = PUBLIC_HOST_MAP.lock().unwrap().clone();
201            let public_host = public_host_map.get(&host.clone());
202            if public_host.is_none() {
203                info!("Public host not available for private host {:?}, adding this to failed host list and trying another server", host.clone());
204                decrease_connection_count(host.clone());
205                add_to_failed_host_list(host.clone());
206                continue;
207            } else {
208                host = public_host.unwrap().clone();
209            }
210        }
211
212        let hostname = match host.clone() {
213            Host::Tcp(host) => Some(host),
214            // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
215            #[cfg(unix)]
216            Host::Unix(_) => None,
217        };
218
219        info!("Creating connection to {:?}", hostname.clone());
220        match connect_host(
221            host.clone(),
222            hostname.clone(),
223            host_to_port_map[&(host.clone())],
224            &mut tls,
225            config,
226        )
227        .await
228        {
229            Ok((client, connection)) => return Ok((client, connection)),
230            Err(_e) => {
231                info!("Not able to create connection to {:?}, adding it to failed host list and trying a different host.",  hostname.clone());
232                decrease_connection_count(host.clone());
233                add_to_failed_host_list(host);
234            }
235        }
236    }
237}
238
239fn increase_connection_count(host: Host) {
240    let mut conn_map = CONNECTION_COUNT_MAP.lock().unwrap();
241    let count = conn_map.get(&host);
242    if count.is_none() {
243        conn_map.insert(host.clone(), 1);
244        debug!("Increasing connection count for {:?} to 1", host.clone());
245    } else {
246        let mut conn_count: i64 = *count.unwrap();
247        conn_count += 1;
248        conn_map.insert(host.clone(), conn_count);
249        debug!(
250            "Increasing connection count for {:?} by one: {}",
251            host.clone(),
252            conn_count
253        );
254    }
255}
256
257pub(crate) fn decrease_connection_count(host: Host) {
258    let mut conn_map = CONNECTION_COUNT_MAP.lock().unwrap();
259    let count = conn_map.get(&host);
260    if count.is_some() {
261        let mut conn_count: i64 = *count.unwrap();
262        if conn_count != 0 {
263            conn_count -= 1;
264            conn_map.insert(host.clone(), conn_count);
265            debug!(
266                "Decremented connection count for {:?} by one: {}",
267                host.clone(),
268                conn_count
269            );
270        }
271    }
272}
273
274fn get_least_loaded_server(config: &Config) -> Result<Host, Error> {
275    let conn_map = CONNECTION_COUNT_MAP.lock().unwrap().clone();
276    let host_list_primary = HOST_INFO_PRIMAY.lock().unwrap().clone();
277    let host_list_rr = HOST_INFO_RR.lock().unwrap().clone();
278    let failed_host_list = FAILED_HOSTS.lock().unwrap().clone();
279    let placement_info_map_primary = PLACEMENT_INFO_MAP_PRIMARY.lock().unwrap().clone();
280    let placement_info_map_rr = PLACEMENT_INFO_MAP_RR.lock().unwrap().clone();
281    let mut least_host: Vec<Host> = Vec::new();
282
283    let mut host_list: Vec<Host>;
284    let mut placement_info_map: HashMap<String, Vec<Host>>;
285
286    if config.load_balance == "only-rr" || config.load_balance == "prefer-rr" {
287        host_list = host_list_rr.clone();
288        placement_info_map = placement_info_map_rr.clone();
289    } else if config.load_balance == "only-primary" || config.load_balance == "prefer-primary" {
290        host_list = host_list_primary.clone();
291        placement_info_map = placement_info_map_primary.clone();
292    } else {
293        host_list = host_list_rr.clone();
294        placement_info_map = placement_info_map_rr.clone();
295        host_list.extend(host_list_primary.clone());
296        for (key, value) in placement_info_map_primary {
297            if let Some(vec) = placement_info_map.get_mut(&key) {
298                vec.extend(value);
299            } else {
300                placement_info_map.insert(key, value);
301            }
302        }
303    }
304
305    if !config.topology_keys.is_empty() {
306        for i in 0..config.topology_keys.len() as i64 {
307            let mut server: Vec<Host> = Vec::new();
308            let prefered_zone = config.topology_keys.get(&(i + 1)).unwrap();
309            for placement_info in prefered_zone.iter() {
310                let to_check_star: Vec<&str> = placement_info.split(".").collect();
311                if to_check_star[2] == "*" {
312                    let star_placement_info: String =
313                        to_check_star[0].to_owned() + "." + to_check_star[1];
314                    let append_hosts = placement_info_map.get(&star_placement_info);
315                    if let Some(append_hosts_value) = append_hosts {
316                        server.extend(append_hosts_value.to_owned());
317                    }
318                } else {
319                    let append_hosts = placement_info_map.get(placement_info);
320                    if let Some(append_hosts_value) = append_hosts {
321                        server.extend(append_hosts_value.to_owned());
322                    }
323                }
324            }
325            least_host = get_least_loaded_hosts(server, conn_map.clone(), failed_host_list.clone());
326
327            if !least_host.is_empty() {
328                break;
329            }
330        }
331    }
332
333    if least_host.is_empty() {
334        if !(config.load_balance == "prefer-primary" || config.load_balance == "prefer-rr") {
335            if config.topology_keys.is_empty() || !config.fallback_to_topology_keys_only {
336                least_host = get_least_loaded_hosts(host_list, conn_map.clone(), failed_host_list.clone());
337            } else {
338                return Err(Error::connect(io::Error::new(
339                    io::ErrorKind::ConnectionRefused,
340                    "no preferred server available, fallback-to-topology-keys-only is set to true",
341                )));
342            }
343        } else {
344            least_host = get_least_loaded_hosts(host_list, conn_map.clone(), failed_host_list.clone());
345            if least_host.is_empty() {
346                if config.load_balance == "prefer-rr"{
347                    least_host = get_least_loaded_hosts(host_list_primary, conn_map.clone(), failed_host_list.clone());
348                } else {
349                    least_host = get_least_loaded_hosts(host_list_rr, conn_map.clone(), failed_host_list.clone());
350                }
351            }
352        }
353    }
354
355    if !least_host.is_empty() {
356        info!(
357            "Following hosts have the least number of connections: {:?}, chosing one randomly",
358            least_host
359        );
360        let num = rand::thread_rng().gen_range(0..least_host.len());
361        Ok(least_host.get(num).cloned().expect("least loaded host value is None"))
362    } else {
363        Err(Error::connect(io::Error::new(
364            io::ErrorKind::ConnectionRefused,
365            "could not find a server to connect to",
366        )))
367    }
368}
369
370fn get_least_loaded_hosts(hosts: Vec<Host>, conn_map: HashMap<Host, i64>, failed_hosts: HashMap<Host, Instant>) -> Vec<Host> {
371    let mut min_count = i64::MAX;
372    let mut least_host: Vec<Host> = Vec::new();
373    for host in hosts.iter() {
374        if !failed_hosts.contains_key(host) {
375            let count = conn_map.get(host);
376            let mut counter: i64 = 0;
377            if count.is_some() {
378                counter = *count.unwrap();
379            }
380            if min_count > counter {
381                min_count = counter;
382                least_host.clear();
383                least_host.push(host.clone());
384            } else if min_count == counter {
385                least_host.push(host.clone());
386            }
387        }
388    }
389    least_host
390}
391
392async fn check_and_refresh(config: &Config) -> bool {
393    let mut refresh_time = LAST_TIME_META_DATA_FETCHED.lock().await;
394    let host_list_primary = HOST_INFO_PRIMAY.lock().unwrap().clone();
395    let host_list_rr = HOST_INFO_RR.lock().unwrap().clone();
396    let host_list = [host_list_primary, host_list_rr].concat();
397    if host_list.is_empty() {
398        info!("Connecting to the server for the first time");
399        if let Ok((client, connection)) = connect(NoTls, config).await {
400            let handle = tokio::spawn(async move {
401                if let Err(e) = connection.await {
402                    eprintln!("connection error: {}", e);
403                }
404            });
405            info!("Control connection created to one of {:?}", config.host);
406            refresh(client, config).await;
407            let start = Instant::now();
408            *refresh_time = start;
409            info!("Resetting LAST_TIME_META_DATA_FETCHED");
410            handle.abort();
411            return true;
412        } else {
413            info!("Failed to establish control connection to available servers");
414            return false;
415        }
416    } else {
417        let duration = refresh_time.elapsed();
418        if duration > config.yb_servers_refresh_interval {
419            let host_to_port_map = HOST_TO_PORT_MAP.lock().unwrap().clone();
420            let mut index = 0;
421            while index < host_list.len() {
422                let host = host_list.get(index);
423                let mut conn_host = host.unwrap().to_owned();
424                //check if we are to use public hosts
425                if USE_PUBLIC_IP.load(Ordering::SeqCst) {
426                    let public_host_map = PUBLIC_HOST_MAP.lock().unwrap().clone();
427                    let public_host = public_host_map.get(&conn_host.clone());
428                    if public_host.is_none() {
429                        info!("Public host not available for private host {:?}, adding this to failed host list and trying another server", conn_host.clone());
430                        add_to_failed_host_list(host.cloned().unwrap());
431                        index += 1;
432                        continue;
433                    } else {
434                        conn_host = public_host.unwrap().clone();
435                    }
436                }
437
438                // The value of host is used as the hostname for TLS validation,
439                let hostname = match conn_host.clone() {
440                    Host::Tcp(host) => Some(host),
441                    // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
442                    #[cfg(unix)]
443                    Host::Unix(_) => None,
444                };
445
446                if let Ok((client, connection)) = connect_host(
447                    conn_host.clone(),
448                    hostname.clone(),
449                    host_to_port_map[&(conn_host.clone())],
450                    &mut NoTls,
451                    config,
452                )
453                .await
454                {
455                    let handle = tokio::spawn(async move {
456                        if let Err(e) = connection.await {
457                            eprintln!("connection error: {}", e);
458                        }
459                    });
460                    info!("Control connection created to {:?}", hostname.clone());
461                    refresh(client, config).await;
462                    let start = Instant::now();
463                    *refresh_time = start;
464                    info!("Resetting LAST_TIME_META_DATA_FETCHED");
465                    handle.abort();
466                    return true;
467                } else {
468                    info!("Failed to establish control connection to {:?}, adding this to failed host list and trying another server", hostname.clone());
469                    add_to_failed_host_list(host.cloned().unwrap());
470                    index += 1;
471                }
472            }
473            info!("Failed to establish control connection to available servers");
474            return false;
475        }
476    }
477    true
478}
479
480fn add_to_failed_host_list(host: Host) {
481    let mut failedhostlist = FAILED_HOSTS.lock().unwrap();
482    failedhostlist.insert(host.clone(), Instant::now());
483    info!("Added {:?} to failed host list", host.clone());
484}
485
486async fn refresh(client: Client, config: &Config) {
487    let socket_config = client.get_socket_config();
488    let mut control_conn_host: String = String::new();
489    if socket_config.is_some() {
490        control_conn_host = socket_config.unwrap().hostname.unwrap();
491    }
492
493    info!("Executing query: `select * from yb_servers()` to fetch list of servers");
494    let rows = client
495        .query("select * from yb_servers()", &[])
496        .await
497        .unwrap();
498
499    let mut host_list_primary = HOST_INFO_PRIMAY.lock().unwrap();
500    let mut host_list_rr = HOST_INFO_RR.lock().unwrap();
501    let mut failed_host_list = FAILED_HOSTS.lock().unwrap();
502    let mut placement_info_map_primary = PLACEMENT_INFO_MAP_PRIMARY.lock().unwrap();
503    let mut placement_info_map_rr = PLACEMENT_INFO_MAP_RR.lock().unwrap();
504    let mut public_host_map = PUBLIC_HOST_MAP.lock().unwrap();
505    let mut host_to_port_map = HOST_TO_PORT_MAP.lock().unwrap();
506    for row in rows {
507        let host_string: String = row.get("host");
508        let host = Host::Tcp(host_string.to_string());
509        info!("Received entry for host {:?}", host);
510        let nodetype: String = row.get("node_type");
511        let portvalue: i64 = row.get("port");
512        let port: u16 = portvalue as u16;
513        let cloud: String = row.get("cloud");
514        let region: String = row.get("region");
515        let zone: String = row.get("zone");
516        let public_ip_string: String = row.get("public_ip");
517        let public_ip = Host::Tcp(public_ip_string.to_string());
518        let placement_zone: String = cloud.clone() + "." + &region + "." + &zone;
519        let star_placement_zone: String = cloud.clone() + "." + &region;
520
521        host_to_port_map.insert(host.clone(), port);
522        host_to_port_map.insert(public_ip.clone(), port);
523
524        if control_conn_host.eq_ignore_ascii_case(&public_ip_string) {
525            USE_PUBLIC_IP.store(true, Ordering::SeqCst);
526        }
527
528        if !failed_host_list.contains_key(&host) {
529            if nodetype == "primary" {
530                if !host_list_primary.contains(&host) {
531                    host_list_primary.push(host.clone());
532                    public_host_map.insert(host.clone(), public_ip.clone());
533                    debug!("Added {:?} to host list primary", host.clone());
534                }
535            } else {
536                if !host_list_rr.contains(&host) {
537                    host_list_rr.push(host.clone());
538                    public_host_map.insert(host.clone(), public_ip.clone());
539                    debug!("Added {:?} to host list RR", host.clone());
540                }
541            }
542        } else {
543            if failed_host_list.get(&host).unwrap().elapsed()
544                > config.failed_host_reconnect_delay_secs
545            {
546                failed_host_list.remove(&host);
547                debug!(
548                    "Marking {:?} as UP since failed-host-reconnect-delay-secs has elapsed",
549                    host.clone()
550                );
551                if nodetype == "primary" {
552                    if !host_list_primary.contains(&host) {
553                        host_list_primary.push(host.clone());
554                        public_host_map.insert(host.clone(), public_ip.clone());
555                        debug!("Added {:?} to host list primary", host.clone());
556                    }
557                } else {
558                    if !host_list_rr.contains(&host) {
559                        host_list_rr.push(host.clone());
560                        public_host_map.insert(host.clone(), public_ip.clone());
561                        debug!("Added {:?} to host list RR", host.clone());
562                    }
563                }
564                make_connection_count_zero(host.clone());
565            } else if host_list_primary.contains(&host) || host_list_rr.contains(&host) {
566                debug!(
567                    "Treating {:?} as DOWN since failed-host-reconnect-delay-secs has not elapsed",
568                    host.clone()
569                );
570                if host_list_primary.contains(&host) {
571                    let index = host_list_primary.iter().position(|x| *x == host).unwrap();
572                    host_list_primary.remove(index);
573                } else {
574                    let index = host_list_rr.iter().position(|x| *x == host).unwrap();
575                    host_list_rr.remove(index);
576                }
577                public_host_map.remove(&host);
578            }
579        }
580
581        if nodetype == "primary" {
582            if placement_info_map_primary.contains_key(&placement_zone) {
583                let mut present_hosts = placement_info_map_primary.get(&placement_zone).unwrap().to_vec();
584                if !present_hosts.contains(&host) {
585                    present_hosts.push(host.clone());
586                    placement_info_map_primary.insert(placement_zone.clone(), present_hosts.to_vec());
587                }
588            } else {
589                let mut host_vec: Vec<Host> = Vec::new();
590                host_vec.push(host.clone());
591                placement_info_map_primary.insert(placement_zone.clone(), host_vec);
592            }
593
594            if placement_info_map_primary.contains_key(&star_placement_zone) {
595                let mut star_present_hosts = placement_info_map_primary
596                    .get(&star_placement_zone)
597                    .unwrap()
598                    .to_vec();
599                if !star_present_hosts.contains(&host) {
600                    star_present_hosts.push(host.clone());
601                    placement_info_map_primary.insert(star_placement_zone.clone(), star_present_hosts.to_vec());
602                }
603            } else {
604                let mut star_host_vec: Vec<Host> = Vec::new();
605                star_host_vec.push(host.clone());
606                placement_info_map_primary.insert(star_placement_zone.clone(), star_host_vec);
607            }
608        } else {
609            if placement_info_map_rr.contains_key(&placement_zone) {
610                let mut present_hosts = placement_info_map_rr.get(&placement_zone).unwrap().to_vec();
611                if !present_hosts.contains(&host) {
612                    present_hosts.push(host.clone());
613                    placement_info_map_rr.insert(placement_zone, present_hosts.to_vec());
614                }
615            } else {
616                let mut host_vec: Vec<Host> = Vec::new();
617                host_vec.push(host.clone());
618                placement_info_map_rr.insert(placement_zone, host_vec);
619            }
620
621            if placement_info_map_rr.contains_key(&star_placement_zone) {
622                let mut star_present_hosts = placement_info_map_rr
623                    .get(&star_placement_zone)
624                    .unwrap()
625                    .to_vec();
626                if !star_present_hosts.contains(&host) {
627                    star_present_hosts.push(host.clone());
628                    placement_info_map_rr.insert(star_placement_zone, star_present_hosts.to_vec());
629                }
630            } else {
631                let mut star_host_vec: Vec<Host> = Vec::new();
632                star_host_vec.push(host.clone());
633                placement_info_map_rr.insert(star_placement_zone, star_host_vec);
634            }
635        }
636    }
637}
638
639fn make_connection_count_zero(host: Host) {
640    let mut conn_map = CONNECTION_COUNT_MAP.lock().unwrap();
641    let count = conn_map.get(&host);
642    if count.is_none() {
643        return;
644    }
645    conn_map.insert(host.clone(), 0);
646    debug!("Resetting connection count for {:?} to zero", host.clone());
647}
648
649
650async fn connect_host<T>(
651    host: Host,
652    hostname: Option<String>,
653    port: u16,
654    tls: &mut T,
655    config: &Config,
656) -> Result<(Client, Connection<Socket, T::Stream>), Error>
657where
658    T: MakeTlsConnect<Socket>,
659{
660    match host {
661        Host::Tcp(host) => {
662            let mut addrs = net::lookup_host((&*host, port))
663                .await
664                .map_err(Error::connect)?
665                .collect::<Vec<_>>();
666
667            if config.load_balance_hosts == LoadBalanceHosts::Random {
668                addrs.shuffle(&mut rand::thread_rng());
669            }
670
671            let mut last_err = None;
672            for addr in addrs {
673                match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config)
674                    .await
675                {
676                    Ok(stream) => return Ok(stream),
677                    Err(e) => {
678                        last_err = Some(e);
679                        continue;
680                    }
681                };
682            }
683
684            Err(last_err.unwrap_or_else(|| {
685                Error::connect(io::Error::new(
686                    io::ErrorKind::InvalidInput,
687                    "could not resolve any addresses",
688                ))
689            }))
690        }
691        #[cfg(unix)]
692        Host::Unix(path) => {
693            connect_once(Addr::Unix(path), hostname.as_deref(), port, tls, config).await
694        }
695    }
696}
697
/// Opens a single connection to `addr`, performs the startup handshake via
/// `connect_raw`, and enforces `target_session_attrs`.
///
/// `hostname` is passed to the TLS connector (an empty string when `None`). It
/// is also stored, together with the socket settings used here, in the client's
/// `SocketConfig`.
///
/// # Errors
///
/// Fails if any of the following happens:
/// - the socket cannot be opened;
/// - the TLS connector cannot be created;
/// - `connect_raw` fails;
/// - when `target_session_attrs` is `ReadWrite`: the server reports
///   `transaction_read_only` = `on`, the connection closes during that check,
///   or the reply contains no row.
async fn connect_once<T>(
    addr: Addr,
    hostname: Option<&str>,
    port: u16,
    tls: &mut T,
    config: &Config,
) -> Result<(Client, Connection<Socket, T::Stream>), Error>
where
    T: MakeTlsConnect<Socket>,
{
    // Keepalive settings are only passed on when keepalives are enabled.
    let socket = connect_socket(
        &addr,
        port,
        config.connect_timeout,
        config.tcp_user_timeout,
        if config.keepalives {
            Some(&config.keepalive_config)
        } else {
            None
        },
    )
    .await?;

    let tls = tls
        .make_tls_connect(hostname.unwrap_or(""))
        .map_err(|e| Error::tls(e.into()))?;
    let has_hostname = hostname.is_some();
    let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?;

    // For read-write sessions, reject servers that report read-only mode.
    // `connection` is not spawned in this function, so it is polled inside each
    // `poll_fn` alongside the query. If it finishes first, the check fails with
    // `Error::closed()`.
    if let TargetSessionAttrs::ReadWrite = config.target_session_attrs {
        let rows = client.simple_query_raw("SHOW transaction_read_only");
        pin_mut!(rows);

        let rows = future::poll_fn(|cx| {
            if connection.poll_unpin(cx)?.is_ready() {
                return Poll::Ready(Err(Error::closed()));
            }

            rows.as_mut().poll(cx)
        })
        .await?;
        pin_mut!(rows);

        // Skip non-row messages until the first data row; an exhausted stream is unexpected.
        loop {
            let next = future::poll_fn(|cx| {
                if connection.poll_unpin(cx)?.is_ready() {
                    return Poll::Ready(Some(Err(Error::closed())));
                }

                rows.as_mut().poll_next(cx)
            });

            match next.await.transpose()? {
                Some(SimpleQueryMessage::Row(row)) => {
                    if row.try_get(0)? == Some("on") {
                        return Err(Error::connect(io::Error::new(
                            io::ErrorKind::PermissionDenied,
                            "database does not allow writes",
                        )));
                    } else {
                        break;
                    }
                }
                Some(_) => {}
                None => return Err(Error::unexpected_message()),
            }
        }
    }

    // Remember how this connection was made.
    client.set_socket_config(SocketConfig {
        addr,
        hostname: hostname.map(|s| s.to_string()),
        port,
        connect_timeout: config.connect_timeout,
        tcp_user_timeout: config.tcp_user_timeout,
        keepalive: if config.keepalives {
            Some(config.keepalive_config.clone())
        } else {
            None
        },
    });

    Ok((client, connection))
}