zino_orm/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3#![doc(html_favicon_url = "https://zino.cc/assets/zino-logo.png")]
4#![doc(html_logo_url = "https://zino.cc/assets/zino-logo.svg")]
5#![allow(async_fn_in_trait)]
6
7use smallvec::SmallVec;
8use std::sync::{
9    OnceLock,
10    atomic::{AtomicBool, AtomicUsize, Ordering::Relaxed},
11};
12use zino_core::{LazyLock, extension::TomlTableExt, state::State};
13
14mod accessor;
15mod aggregate;
16mod column;
17mod entity;
18mod executor;
19mod helper;
20mod join;
21mod manager;
22mod mutation;
23mod pool;
24mod primary_key;
25mod query;
26mod row;
27mod schema;
28mod transaction;
29mod value;
30mod window;
31
32pub use accessor::ModelAccessor;
33pub use aggregate::Aggregation;
34pub use column::EncodeColumn;
35pub use entity::{DerivedColumn, Entity, ModelColumn};
36pub use executor::Executor;
37pub use helper::ModelHelper;
38pub use join::JoinOn;
39pub use manager::PoolManager;
40pub use mutation::MutationBuilder;
41pub use pool::ConnectionPool;
42pub use primary_key::PrimaryKey;
43pub use query::QueryBuilder;
44pub use row::DecodeRow;
45pub use schema::Schema;
46pub use transaction::Transaction;
47pub use value::IntoSqlValue;
48pub use window::Window;
49
50#[cfg(feature = "orm-sqlx")]
51mod decode;
52#[cfg(feature = "orm-sqlx")]
53mod scalar;
54
55#[cfg(feature = "orm-sqlx")]
56pub use decode::{decode, decode_array, decode_decimal, decode_optional, decode_uuid};
57#[cfg(feature = "orm-sqlx")]
58pub use scalar::ScalarQuery;
59
// Compile-time selection of the concrete sqlx backend. `cfg_if!` picks the first
// branch whose predicate holds, so the MySQL family (`orm-mariadb`, `orm-mysql`,
// `orm-tidb`) takes precedence over `orm-postgres`, and SQLite is the fallback
// when none of those features is enabled.
cfg_if::cfg_if! {
    if #[cfg(any(feature = "orm-mariadb", feature = "orm-mysql", feature = "orm-tidb"))] {
        mod mysql;

        /// Driver name.
        ///
        /// MariaDB and TiDB reuse the MySQL driver types below; the name only records
        /// which flavor was enabled (`orm-mariadb` wins over `orm-tidb`, otherwise
        /// `"mysql"`). It is also the default `database.type` and hence the name of the
        /// `[[...]]` config array read by the shared connection pools.
        static DRIVER_NAME: &str = if cfg!(feature = "orm-mariadb") {
            "mariadb"
        } else if cfg!(feature = "orm-tidb") {
            "tidb"
        } else {
            "mysql"
        };

        /// MySQL database driver.
        pub type DatabaseDriver = sqlx::MySql;

        /// MySQL database pool.
        pub type DatabasePool = sqlx::MySqlPool;

        /// MySQL database connection.
        pub type DatabaseConnection = sqlx::MySqlConnection;

        /// A single row from the MySQL database.
        pub type DatabaseRow = sqlx::mysql::MySqlRow;
    } else if #[cfg(feature = "orm-postgres")] {
        mod postgres;

        /// Driver name.
        ///
        /// Also the default `database.type` and the name of the `[[postgres]]` config array.
        static DRIVER_NAME: &str = "postgres";

        /// PostgreSQL database driver.
        pub type DatabaseDriver = sqlx::Postgres;

        /// PostgreSQL database pool.
        pub type DatabasePool = sqlx::PgPool;

        /// PostgreSQL database connection.
        pub type DatabaseConnection = sqlx::PgConnection;

        /// A single row from the PostgreSQL database.
        pub type DatabaseRow = sqlx::postgres::PgRow;
    } else {
        // Fallback backend: used when no MySQL-family or PostgreSQL feature is enabled.
        mod sqlite;

        /// Driver name.
        ///
        /// Also the default `database.type` and the name of the `[[sqlite]]` config array.
        static DRIVER_NAME: &str = "sqlite";

        /// SQLite database driver.
        pub type DatabaseDriver = sqlx::Sqlite;

        /// SQLite database pool.
        pub type DatabasePool = sqlx::SqlitePool;

        /// SQLite database connection.
        pub type DatabaseConnection = sqlx::SqliteConnection;

        /// A single row from the SQLite database.
        pub type DatabaseRow = sqlx::sqlite::SqliteRow;
    }
}
120
121/// A list of database connection pools.
122#[derive(Debug)]
123struct ConnectionPools(SmallVec<[ConnectionPool; 4]>);
124
125impl ConnectionPools {
126    /// Returns a connection pool with the specific name.
127    pub(crate) fn get_pool(&self, name: &str) -> Option<&ConnectionPool> {
128        let mut pool = None;
129        for cp in self.0.iter().filter(|cp| cp.name() == name) {
130            if cp.is_available() {
131                return Some(cp);
132            } else {
133                pool = Some(cp);
134            }
135        }
136        pool
137    }
138}
139
140/// Global access to the shared connection pools.
141#[derive(Debug, Clone, Copy, Default)]
142pub struct GlobalPool;
143
144impl GlobalPool {
145    /// Gets the connection pool for the specific service.
146    #[inline]
147    pub fn get(name: &str) -> Option<&'static ConnectionPool> {
148        SHARED_CONNECTION_POOLS.get_pool(name)
149    }
150
151    /// Iterates over the shared connection pools and
152    /// attempts to establish a database connection for each of them.
153    #[inline]
154    pub async fn connect_all() {
155        for cp in SHARED_CONNECTION_POOLS.0.iter() {
156            cp.check_availability().await;
157        }
158    }
159
160    /// Shuts down the shared connection pools to ensure all connections are gracefully closed.
161    #[inline]
162    pub async fn close_all() {
163        for cp in SHARED_CONNECTION_POOLS.0.iter() {
164            cp.close().await;
165        }
166    }
167}
168
169/// Shared connection pools.
170static SHARED_CONNECTION_POOLS: LazyLock<ConnectionPools> = LazyLock::new(|| {
171    let config = State::shared().config();
172    let mut database_type = DRIVER_NAME;
173    let mut disable_auto_migration = false;
174    if let Some(database) = config.get_table("database") {
175        if let Some(driver) = database.get_str("type") {
176            database_type = driver;
177        }
178        if let Some(time_zone) = database.get_str("time-zone") {
179            TIME_ZONE
180                .set(time_zone)
181                .expect("fail to set time zone for the database session");
182        }
183        if let Some(max_rows) = database.get_usize("max-rows") {
184            MAX_ROWS.store(max_rows, Relaxed);
185        }
186        if let Some(auto_migration) = database.get_bool("auto-migration") {
187            disable_auto_migration = !auto_migration;
188        }
189        if let Some(debug_only) = database.get_bool("debug-only") {
190            DEBUG_ONLY.store(debug_only, Relaxed);
191        }
192    }
193
194    // Database connection pools.
195    let databases = config.get_array(database_type).unwrap_or_else(|| {
196        panic!(
197            "field `{database_type}` should be an array of tables; \
198                please use `[[{database_type}]]` to configure a list of database services"
199        )
200    });
201    let pools = databases
202        .iter()
203        .filter_map(|v| v.as_table())
204        .map(|config| {
205            let connection_pool = ConnectionPool::with_config(config);
206            if disable_auto_migration {
207                connection_pool.disable_auto_migration();
208            }
209            connection_pool
210        })
211        .collect();
212    let driver = DRIVER_NAME;
213    if database_type == driver {
214        tracing::warn!(driver, "connect to database services lazily");
215    } else {
216        tracing::error!(
217            driver,
218            "invalid database type `{database_type}` for the driver `{driver}`"
219        );
220    }
221    ConnectionPools(pools)
222});
223
224/// Database namespace prefix.
225static NAMESPACE_PREFIX: LazyLock<&'static str> = LazyLock::new(|| {
226    State::shared()
227        .get_config("database")
228        .and_then(|config| {
229            config
230                .get_str("namespace")
231                .filter(|s| !s.is_empty())
232                .map(|s| [s, ":"].concat().leak())
233        })
234        .unwrap_or_default()
235});
236
237/// Database table prefix.
238static TABLE_PREFIX: LazyLock<&'static str> = LazyLock::new(|| {
239    State::shared()
240        .get_config("database")
241        .and_then(|config| {
242            config
243                .get_str("namespace")
244                .filter(|s| !s.is_empty())
245                .map(|s| [s, "_"].concat().leak())
246        })
247        .unwrap_or_default()
248});
249
/// Optional time zone for the database session.
///
/// Unset by default; assigned at most once from the `time-zone` key of the
/// `[database]` config when the shared connection pools are initialized.
static TIME_ZONE: OnceLock<&'static str> = OnceLock::new();

/// Max number of returning rows.
///
/// Defaults to 10 000; overridden by the `max-rows` key of the `[database]` config.
static MAX_ROWS: AtomicUsize = AtomicUsize::new(10_000);

/// Debug-only mode.
///
/// Off by default; overridden by the `debug-only` key of the `[database]` config.
static DEBUG_ONLY: AtomicBool = AtomicBool::new(false);