1use super::{DatabasePool, pool::ConnectionPool};
2use std::time::Duration;
3use toml::value::Table;
4use zino_core::extension::TomlTableExt;
5
6pub trait PoolManager {
8 fn with_config(config: &'static Table) -> Self;
10
11 async fn check_availability(&self) -> bool;
13
14 async fn close(&self);
16}
17
#[cfg(feature = "orm-sqlx")]
impl PoolManager for ConnectionPool<DatabasePool> {
    /// Builds a lazily-connected sqlx pool from the given configuration table.
    ///
    /// Recognized keys (defaults in parentheses): `name` (`"main"`), `database` (required),
    /// `statement-cache-capacity`, `max-connections` (16), `min-connections` (1),
    /// `max-lifetime` (24h), `idle-timeout` (1h), `acquire-timeout` (60s) and
    /// `health-check-interval` in seconds (60). Driver-specific keys are read by
    /// `new_connect_options`.
    ///
    /// # Panics
    ///
    /// Panics if the `database` field is missing or is not a string.
    fn with_config(config: &'static Table) -> Self {
        use sqlx::{Connection, Executor, pool::PoolOptions};

        let name = config.get_str("name").unwrap_or("main");
        let database = config
            .get_str("database")
            .expect("field `database` should be a str");

        let mut connect_options = new_connect_options(database, config);
        if let Some(statement_cache_capacity) = config.get_usize("statement-cache-capacity") {
            connect_options = connect_options.statement_cache_capacity(statement_cache_capacity);
        }

        let max_connections = config.get_u32("max-connections").unwrap_or(16);
        let mut min_connections = config.get_u32("min-connections").unwrap_or(1);
        if min_connections > max_connections {
            // A pool minimum above its maximum is an inconsistent configuration;
            // keep the pool usable but make the misconfiguration visible.
            tracing::warn!(
                "`min-connections` ({min_connections}) exceeds `max-connections` \
                 ({max_connections}) for the `{name}` service; clamping it"
            );
            min_connections = max_connections;
        }
        let max_lifetime = config
            .get_duration("max-lifetime")
            .unwrap_or_else(|| Duration::from_secs(24 * 60 * 60));
        let idle_timeout = config
            .get_duration("idle-timeout")
            .unwrap_or_else(|| Duration::from_secs(60 * 60));
        let acquire_timeout = config
            .get_duration("acquire-timeout")
            .unwrap_or_else(|| Duration::from_secs(60));
        // Whole seconds a connection may sit idle before it is pinged on acquire.
        let health_check_interval = config.get_u64("health-check-interval").unwrap_or(60);

        let pool = PoolOptions::<super::DatabaseDriver>::new()
            .max_connections(max_connections)
            .min_connections(min_connections)
            .max_lifetime(max_lifetime)
            .idle_timeout(idle_timeout)
            .acquire_timeout(acquire_timeout)
            // Replace sqlx's unconditional pre-acquire test with the cheaper hook below,
            // which only pings connections that have been idle for a while.
            .test_before_acquire(false)
            .before_acquire(move |conn, meta| {
                Box::pin(async move {
                    if meta.idle_for.as_secs() <= health_check_interval {
                        return Ok(true);
                    }

                    // Always ping a long-idle connection, even if this pool is not (yet)
                    // registered in `GlobalPool`; only the availability flag depends on it.
                    let ping_result = conn.ping().await;
                    if let Some(cp) = super::GlobalPool::get(name) {
                        cp.store_availability(ping_result.is_ok());
                    }
                    if let Err(err) = ping_result {
                        tracing::error!(
                            "fail to ping the database for the `{name}` service: {err}"
                        );
                        return Err(err);
                    }
                    Ok(true)
                })
            })
            .after_connect(|conn, _meta| {
                Box::pin(async move {
                    // Apply the configured session time zone (if any) to every new connection.
                    // NOTE(review): the value is interpolated into SQL; it is assumed to come
                    // from trusted configuration — confirm where `TIME_ZONE` is set.
                    if let Some(time_zone) = super::TIME_ZONE.get() {
                        if cfg!(any(
                            feature = "orm-mariadb",
                            feature = "orm-mysql",
                            feature = "orm-tidb"
                        )) {
                            let sql = format!("SET time_zone = '{time_zone}';");
                            conn.execute(sql.as_str()).await?;
                        } else if cfg!(feature = "orm-postgres") {
                            let sql = format!("SET TIME ZONE '{time_zone}';");
                            conn.execute(sql.as_str()).await?;
                        }
                    }
                    Ok(())
                })
            })
            .connect_lazy_with(connect_options);
        Self::new(name, database, pool)
    }

    /// Tries to acquire a connection, records the outcome via `store_availability`,
    /// and returns whether the acquisition succeeded.
    async fn check_availability(&self) -> bool {
        let name = self.name();
        if let Err(err) = self.pool().acquire().await {
            tracing::error!("fail to acquire a connection for the `{name}` service: {err}");
            self.store_availability(false);
            false
        } else {
            tracing::info!("acquire a connection for the `{name}` service successfully");
            self.store_availability(true);
            true
        }
    }

    /// Logs a warning and closes the underlying sqlx pool.
    async fn close(&self) {
        let name = self.name();
        tracing::warn!("closing the connection pool for the `{name}` service");
        self.pool().close().await;
    }
}
114
115cfg_if::cfg_if! {
116 if #[cfg(any(feature = "orm-mariadb", feature = "orm-mysql", feature = "orm-tidb"))] {
117 use sqlx::mysql::{MySqlConnectOptions, MySqlSslMode};
118 use zino_core::state::State;
119
120 fn new_connect_options(database: &'static str, config: &'static Table) -> MySqlConnectOptions {
122 let username = config
123 .get_str("username")
124 .expect("field `username` should be a str");
125 let password =
126 State::decrypt_password(config).expect("field `password` should be a str");
127
128 let mut connect_options = MySqlConnectOptions::new()
129 .database(database)
130 .username(username)
131 .password(password.as_ref());
132 if let Some(host) = config.get_str("host") {
133 connect_options = connect_options.host(host);
134 }
135 if let Some(port) = config.get_u16("port") {
136 connect_options = connect_options.port(port);
137 }
138 if let Some(ssl_mode) = config.get_str("ssl-mode").and_then(|s| s.parse().ok()) {
139 connect_options = connect_options.ssl_mode(ssl_mode);
140 } else {
141 connect_options = connect_options.ssl_mode(MySqlSslMode::Disabled);
142 }
143 connect_options
144 }
145 } else if #[cfg(feature = "orm-postgres")] {
146 use sqlx::postgres::{PgConnectOptions, PgSslMode};
147 use zino_core::state::State;
148
149 fn new_connect_options(database: &'static str, config: &'static Table) -> PgConnectOptions {
151 let username = config
152 .get_str("username")
153 .expect("field `username` should be a str");
154 let password =
155 State::decrypt_password(config).expect("field `password` should be a str");
156
157 let mut connect_options = PgConnectOptions::new()
158 .database(database)
159 .username(username)
160 .password(password.as_ref());
161 if let Some(host) = config.get_str("host") {
162 connect_options = connect_options.host(host);
163 }
164 if let Some(port) = config.get_u16("port") {
165 connect_options = connect_options.port(port);
166 }
167 if let Some(ssl_mode) = config.get_str("ssl-mode").and_then(|s| s.parse().ok()) {
168 connect_options = connect_options.ssl_mode(ssl_mode);
169 } else {
170 connect_options = connect_options.ssl_mode(PgSslMode::Disable);
171 }
172 connect_options
173 }
174 } else {
175 use sqlx::sqlite::SqliteConnectOptions;
176 use zino_core::application::{Agent, Application};
177
178 fn new_connect_options(database: &'static str, config: &'static Table) -> SqliteConnectOptions {
180 let mut connect_options = SqliteConnectOptions::new().create_if_missing(true);
181 if let Some(read_only) = config.get_bool("read-only") {
182 connect_options = connect_options.read_only(read_only);
183 }
184
185 let database_path = Agent::parse_path(database);
186 connect_options.filename(database_path)
187 }
188 }
189}