zino_orm/
manager.rs

1use super::{DatabasePool, pool::ConnectionPool};
2use std::time::Duration;
3use toml::value::Table;
4use zino_core::extension::TomlTableExt;
5
6/// A manager of the connection pool.
7pub trait PoolManager {
8    /// Connects lazily to the database according to the config.
9    fn with_config(config: &'static Table) -> Self;
10
11    /// Checks the availability of the connection pool.
12    async fn check_availability(&self) -> bool;
13
14    /// Shuts down the connection pool.
15    async fn close(&self);
16}
17
18#[cfg(feature = "orm-sqlx")]
19impl PoolManager for ConnectionPool<DatabasePool> {
20    fn with_config(config: &'static Table) -> Self {
21        use sqlx::{Connection, Executor, pool::PoolOptions};
22
23        let name = config.get_str("name").unwrap_or("main");
24
25        // Connect options.
26        let database = config
27            .get_str("database")
28            .expect("field `database` should be a str");
29        let mut connect_options = new_connect_options(database, config);
30        if let Some(statement_cache_capacity) = config.get_usize("statement-cache-capacity") {
31            connect_options = connect_options.statement_cache_capacity(statement_cache_capacity);
32        }
33
34        // Pool options.
35        let max_connections = config.get_u32("max-connections").unwrap_or(16);
36        let min_connections = config.get_u32("min-connections").unwrap_or(1);
37        let max_lifetime = config
38            .get_duration("max-lifetime")
39            .unwrap_or_else(|| Duration::from_secs(24 * 60 * 60));
40        let idle_timeout = config
41            .get_duration("idle-timeout")
42            .unwrap_or_else(|| Duration::from_secs(60 * 60));
43        let acquire_timeout = config
44            .get_duration("acquire-timeout")
45            .unwrap_or_else(|| Duration::from_secs(60));
46        let health_check_interval = config.get_u64("health-check-interval").unwrap_or(60);
47        let pool = PoolOptions::<super::DatabaseDriver>::new()
48            .max_connections(max_connections)
49            .min_connections(min_connections)
50            .max_lifetime(max_lifetime)
51            .idle_timeout(idle_timeout)
52            .acquire_timeout(acquire_timeout)
53            .test_before_acquire(false)
54            .before_acquire(move |conn, meta| {
55                Box::pin(async move {
56                    if meta.idle_for.as_secs() > health_check_interval
57                        && let Some(cp) = super::GlobalPool::get(name)
58                    {
59                        if let Err(err) = conn.ping().await {
60                            let name = cp.name();
61                            cp.store_availability(false);
62                            tracing::error!(
63                                "fail to ping the database for the `{name}` service: {err}"
64                            );
65                            return Err(err);
66                        } else {
67                            cp.store_availability(true);
68                        }
69                    }
70                    Ok(true)
71                })
72            })
73            .after_connect(|conn, _meta| {
74                Box::pin(async move {
75                    if let Some(time_zone) = super::TIME_ZONE.get() {
76                        if cfg!(any(
77                            feature = "orm-mariadb",
78                            feature = "orm-mysql",
79                            feature = "orm-tidb"
80                        )) {
81                            let sql = format!("SET time_zone = '{time_zone}';");
82                            conn.execute(sql.as_str()).await?;
83                        } else if cfg!(feature = "orm-postgres") {
84                            let sql = format!("SET TIME ZONE '{time_zone}';");
85                            conn.execute(sql.as_str()).await?;
86                        }
87                    }
88                    Ok(())
89                })
90            })
91            .connect_lazy_with(connect_options);
92        let connection_pool = Self::new(name, database, pool);
93        if config.get_bool("auto-migration").is_some_and(|b| !b) {
94            connection_pool.disable_auto_migration();
95        }
96        connection_pool
97    }
98
99    async fn check_availability(&self) -> bool {
100        let name = self.name();
101        if let Err(err) = self.pool().acquire().await {
102            tracing::error!("fail to acquire a connection for the `{name}` service: {err}");
103            self.store_availability(false);
104            false
105        } else {
106            tracing::info!("acquire a connection for the `{name}` service sucessfully");
107            self.store_availability(true);
108            true
109        }
110    }
111
112    async fn close(&self) {
113        let name = self.name();
114        tracing::warn!("closing the connection pool for the `{name}` service");
115        self.pool().close().await;
116    }
117}
118
119cfg_if::cfg_if! {
120    if #[cfg(any(feature = "orm-mariadb", feature = "orm-mysql", feature = "orm-tidb"))] {
121        use sqlx::mysql::{MySqlConnectOptions, MySqlSslMode};
122        use zino_core::state::State;
123
124        /// Options and flags which can be used to configure a MySQL connection.
125        fn new_connect_options(database: &'static str, config: &'static Table) -> MySqlConnectOptions {
126            let username = config
127                .get_str("username")
128                .expect("field `username` should be a str");
129            let password =
130                State::decrypt_password(config).expect("field `password` should be a str");
131
132            let mut connect_options = MySqlConnectOptions::new()
133                .database(database)
134                .username(username)
135                .password(password.as_ref());
136            if let Some(host) = config.get_str("host") {
137                connect_options = connect_options.host(host);
138            }
139            if let Some(port) = config.get_u16("port") {
140                connect_options = connect_options.port(port);
141            }
142            if let Some(ssl_mode) = config.get_str("ssl-mode").and_then(|s| s.parse().ok()) {
143                connect_options = connect_options.ssl_mode(ssl_mode);
144            } else {
145                connect_options = connect_options.ssl_mode(MySqlSslMode::Disabled);
146            }
147            connect_options
148        }
149    } else if #[cfg(feature = "orm-postgres")] {
150        use sqlx::postgres::{PgConnectOptions, PgSslMode};
151        use zino_core::state::State;
152
153        /// Options and flags which can be used to configure a PostgreSQL connection.
154        fn new_connect_options(database: &'static str, config: &'static Table) -> PgConnectOptions {
155            let username = config
156                .get_str("username")
157                .expect("field `username` should be a str");
158            let password =
159                State::decrypt_password(config).expect("field `password` should be a str");
160
161            let mut connect_options = PgConnectOptions::new()
162                .database(database)
163                .username(username)
164                .password(password.as_ref());
165            if let Some(host) = config.get_str("host") {
166                connect_options = connect_options.host(host);
167            }
168            if let Some(port) = config.get_u16("port") {
169                connect_options = connect_options.port(port);
170            }
171            if let Some(ssl_mode) = config.get_str("ssl-mode").and_then(|s| s.parse().ok()) {
172                connect_options = connect_options.ssl_mode(ssl_mode);
173            } else {
174                connect_options = connect_options.ssl_mode(PgSslMode::Disable);
175            }
176            connect_options
177        }
178    } else {
179        use sqlx::sqlite::SqliteConnectOptions;
180        use zino_core::application::{Agent, Application};
181
182        /// Options and flags which can be used to configure a SQLite connection.
183        fn new_connect_options(database: &'static str, config: &'static Table) -> SqliteConnectOptions {
184            let mut connect_options = SqliteConnectOptions::new().create_if_missing(true);
185            if let Some(read_only) = config.get_bool("read-only") {
186                connect_options = connect_options.read_only(read_only);
187            }
188
189            let database_path = Agent::parse_path(database);
190            connect_options.filename(database_path)
191        }
192    }
193}