1use super::{DatabasePool, pool::ConnectionPool};
2use std::time::Duration;
3use toml::value::Table;
4use zino_core::extension::TomlTableExt;
5
6pub trait PoolManager {
8 fn with_config(config: &'static Table) -> Self;
10
11 async fn check_availability(&self) -> bool;
13
14 async fn close(&self);
16}
17
18#[cfg(feature = "orm-sqlx")]
19impl PoolManager for ConnectionPool<DatabasePool> {
20 fn with_config(config: &'static Table) -> Self {
21 use sqlx::{Connection, Executor, pool::PoolOptions};
22
23 let name = config.get_str("name").unwrap_or("main");
24
25 let database = config
27 .get_str("database")
28 .expect("field `database` should be a str");
29 let mut connect_options = new_connect_options(database, config);
30 if let Some(statement_cache_capacity) = config.get_usize("statement-cache-capacity") {
31 connect_options = connect_options.statement_cache_capacity(statement_cache_capacity);
32 }
33
34 let max_connections = config.get_u32("max-connections").unwrap_or(16);
36 let min_connections = config.get_u32("min-connections").unwrap_or(1);
37 let max_lifetime = config
38 .get_duration("max-lifetime")
39 .unwrap_or_else(|| Duration::from_secs(24 * 60 * 60));
40 let idle_timeout = config
41 .get_duration("idle-timeout")
42 .unwrap_or_else(|| Duration::from_secs(60 * 60));
43 let acquire_timeout = config
44 .get_duration("acquire-timeout")
45 .unwrap_or_else(|| Duration::from_secs(60));
46 let health_check_interval = config.get_u64("health-check-interval").unwrap_or(60);
47 let pool = PoolOptions::<super::DatabaseDriver>::new()
48 .max_connections(max_connections)
49 .min_connections(min_connections)
50 .max_lifetime(max_lifetime)
51 .idle_timeout(idle_timeout)
52 .acquire_timeout(acquire_timeout)
53 .test_before_acquire(false)
54 .before_acquire(move |conn, meta| {
55 Box::pin(async move {
56 if meta.idle_for.as_secs() > health_check_interval
57 && let Some(cp) = super::GlobalPool::get(name)
58 {
59 if let Err(err) = conn.ping().await {
60 let name = cp.name();
61 cp.store_availability(false);
62 tracing::error!(
63 "fail to ping the database for the `{name}` service: {err}"
64 );
65 return Err(err);
66 } else {
67 cp.store_availability(true);
68 }
69 }
70 Ok(true)
71 })
72 })
73 .after_connect(|conn, _meta| {
74 Box::pin(async move {
75 if let Some(time_zone) = super::TIME_ZONE.get() {
76 if cfg!(any(
77 feature = "orm-mariadb",
78 feature = "orm-mysql",
79 feature = "orm-tidb"
80 )) {
81 let sql = format!("SET time_zone = '{time_zone}';");
82 conn.execute(sql.as_str()).await?;
83 } else if cfg!(feature = "orm-postgres") {
84 let sql = format!("SET TIME ZONE '{time_zone}';");
85 conn.execute(sql.as_str()).await?;
86 }
87 }
88 Ok(())
89 })
90 })
91 .connect_lazy_with(connect_options);
92 let connection_pool = Self::new(name, database, pool);
93 if config.get_bool("auto-migration").is_some_and(|b| !b) {
94 connection_pool.disable_auto_migration();
95 }
96 connection_pool
97 }
98
99 async fn check_availability(&self) -> bool {
100 let name = self.name();
101 if let Err(err) = self.pool().acquire().await {
102 tracing::error!("fail to acquire a connection for the `{name}` service: {err}");
103 self.store_availability(false);
104 false
105 } else {
106 tracing::info!("acquire a connection for the `{name}` service sucessfully");
107 self.store_availability(true);
108 true
109 }
110 }
111
112 async fn close(&self) {
113 let name = self.name();
114 tracing::warn!("closing the connection pool for the `{name}` service");
115 self.pool().close().await;
116 }
117}
118
119cfg_if::cfg_if! {
120 if #[cfg(any(feature = "orm-mariadb", feature = "orm-mysql", feature = "orm-tidb"))] {
121 use sqlx::mysql::{MySqlConnectOptions, MySqlSslMode};
122 use zino_core::state::State;
123
124 fn new_connect_options(database: &'static str, config: &'static Table) -> MySqlConnectOptions {
126 let username = config
127 .get_str("username")
128 .expect("field `username` should be a str");
129 let password =
130 State::decrypt_password(config).expect("field `password` should be a str");
131
132 let mut connect_options = MySqlConnectOptions::new()
133 .database(database)
134 .username(username)
135 .password(password.as_ref());
136 if let Some(host) = config.get_str("host") {
137 connect_options = connect_options.host(host);
138 }
139 if let Some(port) = config.get_u16("port") {
140 connect_options = connect_options.port(port);
141 }
142 if let Some(ssl_mode) = config.get_str("ssl-mode").and_then(|s| s.parse().ok()) {
143 connect_options = connect_options.ssl_mode(ssl_mode);
144 } else {
145 connect_options = connect_options.ssl_mode(MySqlSslMode::Disabled);
146 }
147 connect_options
148 }
149 } else if #[cfg(feature = "orm-postgres")] {
150 use sqlx::postgres::{PgConnectOptions, PgSslMode};
151 use zino_core::state::State;
152
153 fn new_connect_options(database: &'static str, config: &'static Table) -> PgConnectOptions {
155 let username = config
156 .get_str("username")
157 .expect("field `username` should be a str");
158 let password =
159 State::decrypt_password(config).expect("field `password` should be a str");
160
161 let mut connect_options = PgConnectOptions::new()
162 .database(database)
163 .username(username)
164 .password(password.as_ref());
165 if let Some(host) = config.get_str("host") {
166 connect_options = connect_options.host(host);
167 }
168 if let Some(port) = config.get_u16("port") {
169 connect_options = connect_options.port(port);
170 }
171 if let Some(ssl_mode) = config.get_str("ssl-mode").and_then(|s| s.parse().ok()) {
172 connect_options = connect_options.ssl_mode(ssl_mode);
173 } else {
174 connect_options = connect_options.ssl_mode(PgSslMode::Disable);
175 }
176 connect_options
177 }
178 } else {
179 use sqlx::sqlite::SqliteConnectOptions;
180 use zino_core::application::{Agent, Application};
181
182 fn new_connect_options(database: &'static str, config: &'static Table) -> SqliteConnectOptions {
184 let mut connect_options = SqliteConnectOptions::new().create_if_missing(true);
185 if let Some(read_only) = config.get_bool("read-only") {
186 connect_options = connect_options.read_only(read_only);
187 }
188
189 let database_path = Agent::parse_path(database);
190 connect_options.filename(database_path)
191 }
192 }
193}