zino_orm/
executor.rs

1use zino_core::error::Error;
2
3/// Executing queries against the database.
4pub trait Executor {
5    /// A type for the database row.
6    type Row;
7
8    /// A type for the query result.
9    type QueryResult;
10
11    /// Executes the query and return the total number of rows affected.
12    async fn execute(self, sql: &str) -> Result<Self::QueryResult, Error>;
13
14    /// Executes the query with arguments and return the total number of rows affected.
15    async fn execute_with<T: ToString>(
16        self,
17        sql: &str,
18        arguments: &[T],
19    ) -> Result<Self::QueryResult, Error>;
20
21    /// Executes the query and return all the generated results.
22    async fn fetch(self, sql: &str) -> Result<Vec<Self::Row>, Error>;
23
24    /// Executes the query with arguments and return all the generated results.
25    async fn fetch_with<T: ToString>(
26        self,
27        sql: &str,
28        arguments: &[T],
29    ) -> Result<Vec<Self::Row>, Error>;
30
31    /// Executes the query and returns exactly one row.
32    async fn fetch_one(self, sql: &str) -> Result<Self::Row, Error>;
33
34    /// Executes the query and returns at most one row.
35    async fn fetch_optional(self, sql: &str) -> Result<Option<Self::Row>, Error>;
36
37    /// Executes the query with arguments and returns at most one row.
38    async fn fetch_optional_with<T: ToString>(
39        self,
40        sql: &str,
41        arguments: &[T],
42    ) -> Result<Option<Self::Row>, Error>;
43}
44
// Expands to every item of an `Executor` impl for an sqlx-backed executor.
//
// Shared conventions of the generated methods:
// - arguments are bound to the query as their `to_string()` representation;
// - when sqlx reports `PoolTimedOut`, `GlobalPool::connect_all()` is awaited
//   before the error is converted and returned to the caller;
// - the multi-row fetchers collect at most `MAX_ROWS` rows.
#[cfg(feature = "orm-sqlx")]
macro_rules! impl_sqlx_executor {
    () => {
        type Row = super::DatabaseRow;
        type QueryResult = <super::DatabaseDriver as sqlx::Database>::QueryResult;

        async fn execute(self, sql: &str) -> Result<Self::QueryResult, Error> {
            let err = match sqlx::query(sql).execute(self).await {
                Ok(outcome) => return Ok(outcome),
                Err(err) => err,
            };
            if matches!(err, sqlx::error::Error::PoolTimedOut) {
                super::GlobalPool::connect_all().await;
            }
            Err(err.into())
        }

        async fn execute_with<T: ToString>(
            self,
            sql: &str,
            arguments: &[T],
        ) -> Result<Self::QueryResult, Error> {
            let mut query = sqlx::query(sql);
            for value in arguments.iter().map(ToString::to_string) {
                query = query.bind(value);
            }

            let err = match query.execute(self).await {
                Ok(outcome) => return Ok(outcome),
                Err(err) => err,
            };
            if matches!(err, sqlx::error::Error::PoolTimedOut) {
                super::GlobalPool::connect_all().await;
            }
            Err(err.into())
        }

        async fn fetch(self, sql: &str) -> Result<Vec<Self::Row>, Error> {
            use futures::StreamExt;
            use std::sync::atomic::Ordering::Relaxed;

            let mut stream = sqlx::query(sql).fetch(self);
            let row_limit = super::MAX_ROWS.load(Relaxed);
            // Only the lower bound of the hint is used as the row estimate.
            let (lower_bound, _) = stream.size_hint();
            if cfg!(debug_assertions) && lower_bound > row_limit {
                tracing::warn!(
                    "estimated number of rows {} exceeds the maximum row limit {}",
                    lower_bound,
                    row_limit,
                );
            }

            let mut rows = Vec::with_capacity(lower_bound.min(row_limit));
            while let Some(item) = stream.next().await {
                let row = match item {
                    Ok(row) => row,
                    Err(err) => {
                        if matches!(err, sqlx::error::Error::PoolTimedOut) {
                            super::GlobalPool::connect_all().await;
                        }
                        return Err(err.into());
                    }
                };
                // Stop at the first row that would exceed the limit.
                if rows.len() >= row_limit {
                    break;
                }
                rows.push(row);
            }
            Ok(rows)
        }

        async fn fetch_with<T: ToString>(
            self,
            sql: &str,
            arguments: &[T],
        ) -> Result<Vec<Self::Row>, Error> {
            use futures::StreamExt;
            use std::sync::atomic::Ordering::Relaxed;

            let mut query = sqlx::query(sql);
            for value in arguments.iter().map(ToString::to_string) {
                query = query.bind(value);
            }

            let mut stream = query.fetch(self);
            let row_limit = super::MAX_ROWS.load(Relaxed);
            // Only the lower bound of the hint is used as the row estimate.
            let (lower_bound, _) = stream.size_hint();
            if cfg!(debug_assertions) && lower_bound > row_limit {
                tracing::warn!(
                    "estimated number of rows {} exceeds the maximum row limit {}",
                    lower_bound,
                    row_limit,
                );
            }

            let mut rows = Vec::with_capacity(lower_bound.min(row_limit));
            while let Some(item) = stream.next().await {
                let row = match item {
                    Ok(row) => row,
                    Err(err) => {
                        if matches!(err, sqlx::error::Error::PoolTimedOut) {
                            super::GlobalPool::connect_all().await;
                        }
                        return Err(err.into());
                    }
                };
                // Stop at the first row that would exceed the limit.
                if rows.len() >= row_limit {
                    break;
                }
                rows.push(row);
            }
            Ok(rows)
        }

        async fn fetch_one(self, sql: &str) -> Result<Self::Row, Error> {
            let err = match sqlx::query(sql).fetch_one(self).await {
                Ok(row) => return Ok(row),
                Err(err) => err,
            };
            if matches!(err, sqlx::error::Error::PoolTimedOut) {
                super::GlobalPool::connect_all().await;
            }
            Err(err.into())
        }

        async fn fetch_optional(self, sql: &str) -> Result<Option<Self::Row>, Error> {
            let err = match sqlx::query(sql).fetch_optional(self).await {
                Ok(maybe_row) => return Ok(maybe_row),
                Err(err) => err,
            };
            if matches!(err, sqlx::error::Error::PoolTimedOut) {
                super::GlobalPool::connect_all().await;
            }
            Err(err.into())
        }

        async fn fetch_optional_with<T: ToString>(
            self,
            sql: &str,
            arguments: &[T],
        ) -> Result<Option<Self::Row>, Error> {
            let mut query = sqlx::query(sql);
            for value in arguments.iter().map(ToString::to_string) {
                query = query.bind(value);
            }

            let err = match query.fetch_optional(self).await {
                Ok(maybe_row) => return Ok(maybe_row),
                Err(err) => err,
            };
            if matches!(err, sqlx::error::Error::PoolTimedOut) {
                super::GlobalPool::connect_all().await;
            }
            Err(err.into())
        }
    };
}
205
/// Implements [`Executor`] for a shared reference to the sqlx pool of
/// `super::DatabaseDriver`; every item is generated by `impl_sqlx_executor!`.
#[cfg(feature = "orm-sqlx")]
impl Executor for &sqlx::Pool<super::DatabaseDriver> {
    impl_sqlx_executor!();
}
210
/// Implements [`Executor`] for a mutable reference to
/// `super::DatabaseConnection`; every item is generated by `impl_sqlx_executor!`.
#[cfg(feature = "orm-sqlx")]
impl Executor for &mut super::DatabaseConnection {
    impl_sqlx_executor!();
}