zino_orm/
lib.rs

1#![cfg_attr(docsrs, feature(doc_auto_cfg))]
2#![doc = include_str!("../README.md")]
3#![doc(html_favicon_url = "https://zino.cc/assets/zino-logo.png")]
4#![doc(html_logo_url = "https://zino.cc/assets/zino-logo.svg")]
5#![allow(async_fn_in_trait)]
6
7use smallvec::SmallVec;
8use std::sync::{
9    atomic::{AtomicBool, AtomicUsize, Ordering::Relaxed},
10    OnceLock,
11};
12use zino_core::{extension::TomlTableExt, state::State, LazyLock};
13
14mod accessor;
15mod aggregate;
16mod column;
17mod entity;
18mod executor;
19mod helper;
20mod join;
21mod manager;
22mod mutation;
23mod pool;
24mod query;
25mod row;
26mod schema;
27mod transaction;
28mod value;
29mod window;
30
31pub use accessor::ModelAccessor;
32pub use aggregate::Aggregation;
33pub use column::EncodeColumn;
34pub use entity::Entity;
35pub use executor::Executor;
36pub use helper::ModelHelper;
37pub use join::JoinOn;
38pub use manager::PoolManager;
39pub use mutation::MutationBuilder;
40pub use pool::ConnectionPool;
41pub use query::QueryBuilder;
42pub use row::DecodeRow;
43pub use schema::Schema;
44pub use transaction::Transaction;
45pub use value::IntoSqlValue;
46pub use window::Window;
47
48#[cfg(feature = "orm-sqlx")]
49mod decode;
50#[cfg(feature = "orm-sqlx")]
51mod scalar;
52
53#[cfg(feature = "orm-sqlx")]
54pub use decode::{decode, decode_array, decode_decimal, decode_optional, decode_uuid};
55#[cfg(feature = "orm-sqlx")]
56pub use scalar::ScalarQuery;
57
58cfg_if::cfg_if! {
59    if #[cfg(any(feature = "orm-mariadb", feature = "orm-mysql", feature = "orm-tidb"))] {
60        mod mysql;
61
62        /// Driver name.
63        static DRIVER_NAME: &str = if cfg!(feature = "orm-mariadb") {
64            "mariadb"
65        } else if cfg!(feature = "orm-tidb") {
66            "tidb"
67        } else {
68            "mysql"
69        };
70
71        /// MySQL database driver.
72        pub type DatabaseDriver = sqlx::MySql;
73
74        /// MySQL database pool.
75        pub type DatabasePool = sqlx::MySqlPool;
76
77        /// MySQL database connection.
78        pub type DatabaseConnection = sqlx::MySqlConnection;
79
80        /// A single row from the MySQL database.
81        pub type DatabaseRow = sqlx::mysql::MySqlRow;
82    } else if #[cfg(feature = "orm-postgres")] {
83        mod postgres;
84
85        /// Driver name.
86        static DRIVER_NAME: &str = "postgres";
87
88        /// PostgreSQL database driver.
89        pub type DatabaseDriver = sqlx::Postgres;
90
91        /// PostgreSQL database pool.
92        pub type DatabasePool = sqlx::PgPool;
93
94        /// PostgreSQL database connection.
95        pub type DatabaseConnection = sqlx::PgConnection;
96
97        /// A single row from the PostgreSQL database.
98        pub type DatabaseRow = sqlx::postgres::PgRow;
99    } else {
100        mod sqlite;
101
102        /// Driver name.
103        static DRIVER_NAME: &str = "sqlite";
104
105        /// SQLite database driver.
106        pub type DatabaseDriver = sqlx::Sqlite;
107
108        /// SQLite database pool.
109        pub type DatabasePool = sqlx::SqlitePool;
110
111        /// SQLite database connection.
112        pub type DatabaseConnection = sqlx::SqliteConnection;
113
114        /// A single row from the SQLite database.
115        pub type DatabaseRow = sqlx::sqlite::SqliteRow;
116    }
117}
118
119/// A list of database connection pools.
120#[derive(Debug)]
121struct ConnectionPools(SmallVec<[ConnectionPool; 4]>);
122
123impl ConnectionPools {
124    /// Returns a connection pool with the specific name.
125    pub(crate) fn get_pool(&self, name: &str) -> Option<&ConnectionPool> {
126        let mut pool = None;
127        for cp in self.0.iter().filter(|cp| cp.name() == name) {
128            if cp.is_available() {
129                return Some(cp);
130            } else {
131                pool = Some(cp);
132            }
133        }
134        pool
135    }
136}
137
138/// Global access to the shared connection pools.
139#[derive(Debug, Clone, Copy, Default)]
140pub struct GlobalPool;
141
142impl GlobalPool {
143    /// Gets the connection pool for the specific service.
144    #[inline]
145    pub fn get(name: &str) -> Option<&'static ConnectionPool> {
146        SHARED_CONNECTION_POOLS.get_pool(name)
147    }
148
149    /// Iterates over the shared connection pools and
150    /// attempts to establish a database connection for each of them.
151    #[inline]
152    pub async fn connect_all() {
153        for cp in SHARED_CONNECTION_POOLS.0.iter() {
154            cp.check_availability().await;
155        }
156    }
157
158    /// Shuts down the shared connection pools to ensure all connections are gracefully closed.
159    #[inline]
160    pub async fn close_all() {
161        for cp in SHARED_CONNECTION_POOLS.0.iter() {
162            cp.close().await;
163        }
164    }
165}
166
167/// Shared connection pools.
168static SHARED_CONNECTION_POOLS: LazyLock<ConnectionPools> = LazyLock::new(|| {
169    let config = State::shared().config();
170    let mut database_type = DRIVER_NAME;
171    if let Some(database) = config.get_table("database") {
172        if let Some(driver) = database.get_str("type") {
173            database_type = driver;
174        }
175        if let Some(time_zone) = database.get_str("time-zone") {
176            TIME_ZONE
177                .set(time_zone)
178                .expect("fail to set time zone for the database session");
179        }
180        if let Some(max_rows) = database.get_usize("max-rows") {
181            MAX_ROWS.store(max_rows, Relaxed);
182        }
183        if let Some(auto_migration) = database.get_bool("auto-migration") {
184            AUTO_MIGRATION.store(auto_migration, Relaxed);
185        }
186        if let Some(debug_only) = database.get_bool("debug-only") {
187            DEBUG_ONLY.store(debug_only, Relaxed);
188        }
189    }
190
191    // Database connection pools.
192    let databases = config.get_array(database_type).unwrap_or_else(|| {
193        panic!(
194            "field `{database_type}` should be an array of tables; \
195                please use `[[{database_type}]]` to configure a list of database services"
196        )
197    });
198    let pools = databases
199        .iter()
200        .filter_map(|v| v.as_table())
201        .map(ConnectionPool::with_config)
202        .collect();
203    let driver = DRIVER_NAME;
204    if database_type == driver {
205        tracing::warn!(driver, "connect to database services lazily");
206    } else {
207        tracing::error!(
208            driver,
209            "invalid database type `{database_type}` for the driver `{driver}`"
210        );
211    }
212    ConnectionPools(pools)
213});
214
215/// Database namespace prefix.
216static NAMESPACE_PREFIX: LazyLock<&'static str> = LazyLock::new(|| {
217    State::shared()
218        .get_config("database")
219        .and_then(|config| {
220            config
221                .get_str("namespace")
222                .filter(|s| !s.is_empty())
223                .map(|s| [s, ":"].concat().leak())
224        })
225        .unwrap_or_default()
226});
227
228/// Database table prefix.
229static TABLE_PREFIX: LazyLock<&'static str> = LazyLock::new(|| {
230    State::shared()
231        .get_config("database")
232        .and_then(|config| {
233            config
234                .get_str("namespace")
235                .filter(|s| !s.is_empty())
236                .map(|s| [s, "_"].concat().leak())
237        })
238        .unwrap_or_default()
239});
240
/// Maximum number of rows returned by a query.
///
/// Defaults to 10 000. It is overwritten by `[database] max-rows` when the
/// shared connection pools are first initialized.
static MAX_ROWS: AtomicUsize = AtomicUsize::new(10_000);

/// Whether automatic migration is enabled.
///
/// Enabled by default. It is overwritten by `[database] auto-migration` when
/// the shared connection pools are first initialized.
static AUTO_MIGRATION: AtomicBool = AtomicBool::new(true);

/// Whether debug-only mode is enabled.
///
/// Disabled by default. It is overwritten by `[database] debug-only` when the
/// shared connection pools are first initialized.
static DEBUG_ONLY: AtomicBool = AtomicBool::new(false);

/// Time zone for the database session.
///
/// Unset unless `[database] time-zone` is configured. It is set once, when the
/// shared connection pools are first initialized.
static TIME_ZONE: OnceLock<&'static str> = OnceLock::new();