1use zino_core::error::Error;
2
/// Runs SQL statements and queries against a database backend.
///
/// Every method takes `self` by value. The sqlx-backed implementations are
/// provided for borrowed handles (`&sqlx::Pool<_>` and `&mut DatabaseConnection`).
///
/// # Errors
///
/// Every method returns an [`Error`] when the underlying driver reports a failure.
pub trait Executor {
    /// Row type produced by the `fetch*` methods.
    type Row;

    /// Result type produced by the `execute*` methods.
    type QueryResult;

    /// Executes the SQL statement `sql` and returns its query result.
    async fn execute(self, sql: &str) -> Result<Self::QueryResult, Error>;

    /// Executes the SQL statement `sql` with `arguments` bound to its placeholders
    /// (presumably in positional order — confirm against the driver) and returns
    /// its query result. Arguments only need to implement `ToString`.
    async fn execute_with<T: ToString>(
        self,
        sql: &str,
        arguments: &[T],
    ) -> Result<Self::QueryResult, Error>;

    /// Runs the SQL query `sql` and returns the rows it produces.
    ///
    /// Implementations may cap the number of rows returned; the sqlx-backed ones
    /// stop after the parent module's `MAX_ROWS` limit.
    async fn fetch(self, sql: &str) -> Result<Vec<Self::Row>, Error>;

    /// Runs the SQL query `sql` with `arguments` bound to its placeholders and
    /// returns the rows it produces (subject to the same row cap as [`fetch`]).
    ///
    /// [`fetch`]: Executor::fetch
    async fn fetch_with<T: ToString>(
        self,
        sql: &str,
        arguments: &[T],
    ) -> Result<Vec<Self::Row>, Error>;

    /// Runs the SQL query `sql` and returns exactly one row.
    async fn fetch_one(self, sql: &str) -> Result<Self::Row, Error>;

    /// Runs the SQL query `sql` and returns at most one row, wrapped in `Option`.
    async fn fetch_optional(self, sql: &str) -> Result<Option<Self::Row>, Error>;

    /// Runs the SQL query `sql` with `arguments` bound to its placeholders and
    /// returns at most one row, wrapped in `Option`.
    async fn fetch_optional_with<T: ToString>(
        self,
        sql: &str,
        arguments: &[T],
    ) -> Result<Option<Self::Row>, Error>;
}
44
/// Converts a `sqlx` error into an [`Error`].
///
/// When the failure is `sqlx::Error::PoolTimedOut`, this first calls
/// `GlobalPool::connect_all()` before converting the error.
#[cfg(feature = "orm-sqlx")]
async fn handle_sqlx_error(err: sqlx::error::Error) -> Error {
    if matches!(err, sqlx::error::Error::PoolTimedOut) {
        super::GlobalPool::connect_all().await;
    }
    err.into()
}

/// Drains a row stream into a vector that holds at most `MAX_ROWS` rows.
///
/// When the limit is reached, the remaining rows are discarded and the stream is
/// dropped early. The first stream error is passed through `handle_sqlx_error`
/// and returned.
#[cfg(feature = "orm-sqlx")]
async fn collect_rows<S>(mut stream: S) -> Result<Vec<super::DatabaseRow>, Error>
where
    S: futures::Stream<Item = Result<super::DatabaseRow, sqlx::error::Error>> + Unpin,
{
    use futures::StreamExt;
    use std::sync::atomic::Ordering::Relaxed;

    let max_rows = super::MAX_ROWS.load(Relaxed);
    let estimated_rows = stream.size_hint().0;
    if cfg!(debug_assertions) && estimated_rows > max_rows {
        tracing::warn!(
            "estimated number of rows {} exceeds the maximum row limit {}",
            estimated_rows,
            max_rows,
        );
    }

    let mut rows = Vec::with_capacity(estimated_rows.min(max_rows));
    while let Some(result) = stream.next().await {
        match result {
            Ok(row) if rows.len() < max_rows => rows.push(row),
            Ok(_) => {
                // `size_hint` is only a lower bound (typically 0 for filtered sqlx
                // streams), so the estimate check above rarely fires. Report the
                // truncation where it actually happens instead of dropping rows silently.
                if cfg!(debug_assertions) {
                    tracing::warn!(
                        "number of rows exceeds the maximum row limit {}, extra rows are discarded",
                        max_rows,
                    );
                }
                break;
            }
            Err(err) => return Err(handle_sqlx_error(err).await),
        }
    }
    Ok(rows)
}

/// Expands to the body of an `impl Executor for ...` block whose `Self` type can be
/// passed to `sqlx::query(..).execute/fetch/...` as a sqlx executor.
///
/// Every argument of the `*_with` methods is bound as its `to_string()` value.
/// Errors go through `handle_sqlx_error`, and `fetch*` results go through
/// `collect_rows`; both helpers must be in scope at the invocation site.
#[cfg(feature = "orm-sqlx")]
macro_rules! impl_sqlx_executor {
    () => {
        type Row = super::DatabaseRow;
        type QueryResult = <super::DatabaseDriver as sqlx::Database>::QueryResult;

        async fn execute(self, sql: &str) -> Result<Self::QueryResult, Error> {
            match sqlx::query(sql).execute(self).await {
                Ok(result) => Ok(result),
                Err(err) => Err(handle_sqlx_error(err).await),
            }
        }

        async fn execute_with<T: ToString>(
            self,
            sql: &str,
            arguments: &[T],
        ) -> Result<Self::QueryResult, Error> {
            let query = arguments
                .iter()
                .fold(sqlx::query(sql), |query, arg| query.bind(arg.to_string()));
            match query.execute(self).await {
                Ok(result) => Ok(result),
                Err(err) => Err(handle_sqlx_error(err).await),
            }
        }

        async fn fetch(self, sql: &str) -> Result<Vec<Self::Row>, Error> {
            collect_rows(sqlx::query(sql).fetch(self)).await
        }

        async fn fetch_with<T: ToString>(
            self,
            sql: &str,
            arguments: &[T],
        ) -> Result<Vec<Self::Row>, Error> {
            let query = arguments
                .iter()
                .fold(sqlx::query(sql), |query, arg| query.bind(arg.to_string()));
            collect_rows(query.fetch(self)).await
        }

        async fn fetch_one(self, sql: &str) -> Result<Self::Row, Error> {
            match sqlx::query(sql).fetch_one(self).await {
                Ok(row) => Ok(row),
                Err(err) => Err(handle_sqlx_error(err).await),
            }
        }

        async fn fetch_optional(self, sql: &str) -> Result<Option<Self::Row>, Error> {
            match sqlx::query(sql).fetch_optional(self).await {
                Ok(row) => Ok(row),
                Err(err) => Err(handle_sqlx_error(err).await),
            }
        }

        async fn fetch_optional_with<T: ToString>(
            self,
            sql: &str,
            arguments: &[T],
        ) -> Result<Option<Self::Row>, Error> {
            let query = arguments
                .iter()
                .fold(sqlx::query(sql), |query, arg| query.bind(arg.to_string()));
            match query.fetch_optional(self).await {
                Ok(row) => Ok(row),
                Err(err) => Err(handle_sqlx_error(err).await),
            }
        }
    };
}
205
// Runs queries through a shared reference to the sqlx pool of the configured
// `DatabaseDriver`; all methods come from `impl_sqlx_executor!`.
#[cfg(feature = "orm-sqlx")]
impl Executor for &sqlx::Pool<super::DatabaseDriver> {
    impl_sqlx_executor!();
}
210
// Runs queries on an exclusively borrowed `DatabaseConnection`; all methods
// come from `impl_sqlx_executor!`.
#[cfg(feature = "orm-sqlx")]
impl Executor for &mut super::DatabaseConnection {
    impl_sqlx_executor!();
}