1use zino_core::error::Error;
2
3pub trait Executor {
5 type Row;
7
8 type QueryResult;
10
11 async fn execute(self, sql: &str) -> Result<Self::QueryResult, Error>;
13
14 async fn execute_with<T: ToString>(
16 self,
17 sql: &str,
18 arguments: &[T],
19 ) -> Result<Self::QueryResult, Error>;
20
21 async fn fetch(self, sql: &str) -> Result<Vec<Self::Row>, Error>;
23
24 async fn fetch_with<T: ToString>(
26 self,
27 sql: &str,
28 arguments: &[T],
29 ) -> Result<Vec<Self::Row>, Error>;
30
31 async fn fetch_one(self, sql: &str) -> Result<Self::Row, Error>;
33
34 async fn fetch_optional(self, sql: &str) -> Result<Option<Self::Row>, Error>;
36
37 async fn fetch_optional_with<T: ToString>(
39 self,
40 sql: &str,
41 arguments: &[T],
42 ) -> Result<Option<Self::Row>, Error>;
43}
44
/// Converts a sqlx error into the crate's [`Error`] type.
///
/// When the failure is `PoolTimedOut`, all global pools are reconnected via
/// `GlobalPool::connect_all` before the error is handed back; presumably this
/// lets subsequent calls obtain a connection again. The original error is
/// still reported to the caller either way.
#[cfg(feature = "orm-sqlx")]
async fn convert_sqlx_error(err: sqlx::error::Error) -> Error {
    if matches!(err, sqlx::error::Error::PoolTimedOut) {
        super::GlobalPool::connect_all().await;
    }
    err.into()
}

/// Expands to the body of an sqlx-backed [`Executor`] implementation.
///
/// The expansion resolves `super::DatabaseRow`, `super::DatabaseDriver`,
/// `super::MAX_ROWS`, `super::GlobalPool` (through `convert_sqlx_error`) and
/// `Error` at the invocation site, so it must be invoked from this module.
#[cfg(feature = "orm-sqlx")]
macro_rules! impl_sqlx_executor {
    () => {
        type Row = super::DatabaseRow;
        type QueryResult = <super::DatabaseDriver as sqlx::Database>::QueryResult;

        async fn execute(self, sql: &str) -> Result<Self::QueryResult, Error> {
            // An empty argument slice binds nothing, so this issues the same
            // query as `sqlx::query(sql)` while sharing the error handling.
            self.execute_with::<&str>(sql, &[]).await
        }

        async fn execute_with<T: ToString>(
            self,
            sql: &str,
            arguments: &[T],
        ) -> Result<Self::QueryResult, Error> {
            let mut query = sqlx::query(sql);
            for arg in arguments {
                // Every argument is bound as its string representation.
                query = query.bind(arg.to_string());
            }
            match query.execute(self).await {
                Ok(result) => Ok(result),
                Err(err) => Err(convert_sqlx_error(err).await),
            }
        }

        async fn fetch(self, sql: &str) -> Result<Vec<Self::Row>, Error> {
            // Same delegation as `execute`: no arguments means no binds.
            self.fetch_with::<&str>(sql, &[]).await
        }

        async fn fetch_with<T: ToString>(
            self,
            sql: &str,
            arguments: &[T],
        ) -> Result<Vec<Self::Row>, Error> {
            use futures::StreamExt;
            use std::sync::atomic::Ordering::Relaxed;

            let mut query = sqlx::query(sql);
            for arg in arguments {
                query = query.bind(arg.to_string());
            }

            let mut stream = query.fetch(self);
            // Global cap on how many rows are returned; further rows are discarded.
            let mut max_rows = super::MAX_ROWS.load(Relaxed);
            let mut rows = Vec::with_capacity(stream.size_hint().0.min(max_rows));
            while let Some(result) = stream.next().await {
                match result {
                    Ok(row) if max_rows > 0 => {
                        rows.push(row);
                        max_rows -= 1;
                    }
                    Err(err) => return Err(convert_sqlx_error(err).await),
                    // The cap is exhausted but rows remain: stop reading and let
                    // the stream drop. Note that one row past the cap is pulled
                    // before this arm is reached.
                    Ok(_) => break,
                }
            }
            Ok(rows)
        }

        async fn fetch_one(self, sql: &str) -> Result<Self::Row, Error> {
            match sqlx::query(sql).fetch_one(self).await {
                Ok(row) => Ok(row),
                Err(err) => Err(convert_sqlx_error(err).await),
            }
        }

        async fn fetch_optional(self, sql: &str) -> Result<Option<Self::Row>, Error> {
            self.fetch_optional_with::<&str>(sql, &[]).await
        }

        async fn fetch_optional_with<T: ToString>(
            self,
            sql: &str,
            arguments: &[T],
        ) -> Result<Option<Self::Row>, Error> {
            let mut query = sqlx::query(sql);
            for arg in arguments {
                query = query.bind(arg.to_string());
            }
            match query.fetch_optional(self).await {
                Ok(row) => Ok(row),
                Err(err) => Err(convert_sqlx_error(err).await),
            }
        }
    };
}
187
/// Runs SQL through a shared reference to the sqlx connection pool.
#[cfg(feature = "orm-sqlx")]
impl Executor for &'_ sqlx::Pool<super::DatabaseDriver> {
    impl_sqlx_executor! {}
}
192
/// Runs SQL on a single, exclusively borrowed database connection.
#[cfg(feature = "orm-sqlx")]
impl Executor for &'_ mut super::DatabaseConnection {
    impl_sqlx_executor! {}
}