1use super::{DatabaseDriver, column::ColumnExt, query::QueryExt, schema::Schema};
2use futures::TryStreamExt;
3use sqlx::{Decode, Row, Type};
4use std::{fmt::Display, sync::atomic::Ordering::Relaxed};
5use zino_core::{Map, error::Error, extension::JsonValueExt, model::Query};
6
7pub trait ScalarQuery<K>: Schema<PrimaryKey = K>
9where
10 K: Default + Display + PartialEq,
11{
12 async fn find_scalar<T>(query: &Query) -> Result<T, Error>
15 where
16 T: Send + Unpin + Type<DatabaseDriver> + for<'r> Decode<'r, DatabaseDriver>,
17 {
18 Self::before_query(query).await?;
19
20 let table_name = query.format_table_name::<Self>();
21 let projection = query.format_projection();
22 let filters = query.format_filters::<Self>();
23 let sort = query.format_sort();
24 let sql = format!("SELECT {projection} FROM {table_name} {filters} {sort} LIMIT 1;");
25 let mut ctx = Self::before_scan(&sql).await?;
26 ctx.set_query(sql);
27
28 let pool = Self::acquire_reader().await?.pool();
29 let scalar = sqlx::query_scalar(ctx.query()).fetch_one(pool).await?;
30 ctx.set_query_result(1, true);
31 Self::after_scan(&ctx).await?;
32 Self::after_query(&ctx).await?;
33 Ok(scalar)
34 }
35
36 async fn find_scalars<T>(query: &Query) -> Result<Vec<T>, Error>
39 where
40 T: Send + Unpin + Type<DatabaseDriver> + for<'r> Decode<'r, DatabaseDriver>,
41 {
42 Self::before_query(query).await?;
43
44 let table_name = query.format_table_name::<Self>();
45 let projection = query.format_projection();
46 let filters = query.format_filters::<Self>();
47 let sort = query.format_sort();
48 let pagination = query.format_pagination();
49 let sql = format!("SELECT {projection} FROM {table_name} {filters} {sort} {pagination};");
50 let mut ctx = Self::before_scan(&sql).await?;
51 ctx.set_query(&sql);
52
53 let pool = Self::acquire_reader().await?.pool();
54 let mut stream = sqlx::query(&sql).fetch(pool);
55 let mut max_rows = super::MAX_ROWS.load(Relaxed);
56 let estimated_rows = stream.size_hint().0;
57 if cfg!(debug_assertions) && estimated_rows > max_rows {
58 tracing::warn!(
59 "estimated number of rows {} exceeds the maximum row limit {}",
60 estimated_rows,
61 max_rows,
62 );
63 }
64
65 let mut data = Vec::with_capacity(estimated_rows.min(max_rows));
66 while let Some(row) = stream.try_next().await? {
67 if max_rows > 0 {
68 data.push(row.try_get_unchecked(0)?);
69 max_rows -= 1;
70 } else {
71 break;
72 }
73 }
74 ctx.set_query_result(u64::try_from(data.len())?, true);
75 Self::after_scan(&ctx).await?;
76 Self::after_query(&ctx).await?;
77 Ok(data)
78 }
79
80 async fn find_distinct_scalars<T>(query: &Query) -> Result<Vec<T>, Error>
83 where
84 T: Send + Unpin + Type<DatabaseDriver> + for<'r> Decode<'r, DatabaseDriver>,
85 {
86 Self::before_query(query).await?;
87
88 let table_name = query.format_table_name::<Self>();
89 let projection = query.format_projection();
90 let filters = query.format_filters::<Self>();
91 let sort = query.format_sort();
92 let pagination = query.format_pagination();
93 let sql = format!(
94 "SELECT DISTINCT {projection} FROM {table_name} \
95 {filters} {sort} {pagination};"
96 );
97 let mut ctx = Self::before_scan(&sql).await?;
98 ctx.set_query(&sql);
99
100 let pool = Self::acquire_reader().await?.pool();
101 let mut stream = sqlx::query(&sql).fetch(pool);
102 let mut max_rows = super::MAX_ROWS.load(Relaxed);
103 let estimated_rows = stream.size_hint().0;
104 if cfg!(debug_assertions) && estimated_rows > max_rows {
105 tracing::warn!(
106 "estimated number of rows {} exceeds the maximum row limit {}",
107 estimated_rows,
108 max_rows,
109 );
110 }
111
112 let mut data = Vec::with_capacity(estimated_rows.min(max_rows));
113 while let Some(row) = stream.try_next().await? {
114 if max_rows > 0 {
115 data.push(row.try_get_unchecked(0)?);
116 max_rows -= 1;
117 } else {
118 break;
119 }
120 }
121 ctx.set_query_result(u64::try_from(data.len())?, true);
122 Self::after_scan(&ctx).await?;
123 Self::after_query(&ctx).await?;
124 Ok(data)
125 }
126
127 async fn query_scalar<T>(query: &str, params: Option<&Map>) -> Result<T, Error>
129 where
130 T: Send + Unpin + Type<DatabaseDriver> + for<'r> Decode<'r, DatabaseDriver>,
131 {
132 let (sql, values) = Query::prepare_query(query, params);
133 let mut ctx = Self::before_scan(&sql).await?;
134 ctx.set_query(sql);
135
136 let mut query = sqlx::query_scalar(ctx.query());
137 let mut arguments = Vec::with_capacity(values.len());
138 for value in values {
139 query = query.bind(value.to_string_unquoted());
140 arguments.push(value.to_string_unquoted());
141 }
142
143 let pool = Self::acquire_reader().await?.pool();
144 let scalar = query.fetch_one(pool).await?;
145 ctx.append_arguments(&mut arguments);
146 ctx.set_query_result(1, true);
147 Self::after_scan(&ctx).await?;
148 Ok(scalar)
149 }
150
151 async fn query_scalars<T>(query: &str, params: Option<&Map>) -> Result<Vec<T>, Error>
153 where
154 T: Send + Unpin + Type<DatabaseDriver> + for<'r> Decode<'r, DatabaseDriver>,
155 {
156 let (sql, values) = Query::prepare_query(query, params);
157 let mut ctx = Self::before_scan(&sql).await?;
158 ctx.set_query(sql.as_ref());
159
160 let mut query = sqlx::query(&sql);
161 let mut arguments = Vec::with_capacity(values.len());
162 for value in values {
163 query = query.bind(value.to_string_unquoted());
164 arguments.push(value.to_string_unquoted());
165 }
166
167 let pool = Self::acquire_reader().await?.pool();
168 let mut stream = query.fetch(pool);
169 let mut max_rows = super::MAX_ROWS.load(Relaxed);
170 let estimated_rows = stream.size_hint().0;
171 if cfg!(debug_assertions) && estimated_rows > max_rows {
172 tracing::warn!(
173 "estimated number of rows {} exceeds the maximum row limit {}",
174 estimated_rows,
175 max_rows,
176 );
177 }
178
179 let mut data = Vec::with_capacity(estimated_rows.min(max_rows));
180 while let Some(row) = stream.try_next().await? {
181 if max_rows > 0 {
182 data.push(row.try_get_unchecked(0)?);
183 max_rows -= 1;
184 } else {
185 break;
186 }
187 }
188 ctx.append_arguments(&mut arguments);
189 ctx.set_query_result(u64::try_from(data.len())?, true);
190 Self::after_scan(&ctx).await?;
191 Ok(data)
192 }
193
194 async fn find_scalar_by_id<C, T>(primary_key: &Self::PrimaryKey, column: C) -> Result<T, Error>
197 where
198 C: AsRef<str>,
199 T: Send + Unpin + Type<DatabaseDriver> + for<'r> Decode<'r, DatabaseDriver>,
200 {
201 let primary_key_name = Self::primary_key_name();
202 let table_name = Query::escape_table_name(Self::table_name());
203 let projection = Query::format_field(column.as_ref());
204 let placeholder = Query::placeholder(1);
205 let sql = if cfg!(feature = "orm-postgres") {
206 let type_annotation = Self::primary_key_column().type_annotation();
207 format!(
208 "SELECT {projection} FROM {table_name} \
209 WHERE {primary_key_name} = ({placeholder}){type_annotation};"
210 )
211 } else {
212 format!(
213 "SELECT {projection} FROM {table_name} WHERE {primary_key_name} = {placeholder};"
214 )
215 };
216 let mut ctx = Self::before_scan(&sql).await?;
217 ctx.set_query(sql);
218
219 let pool = Self::acquire_reader().await?.pool();
220 let query = sqlx::query_scalar(ctx.query()).bind(primary_key.to_string());
221 let scalar = query.fetch_one(pool).await?;
222 ctx.set_query_result(1, true);
223 Self::after_scan(&ctx).await?;
224 Self::after_query(&ctx).await?;
225 Ok(scalar)
226 }
227
228 async fn find_primary_key(query: &Query) -> Result<K, Error>
230 where
231 K: Send + Unpin + Type<DatabaseDriver> + for<'r> Decode<'r, DatabaseDriver>,
232 {
233 Self::before_query(query).await?;
234
235 let projection = Self::PRIMARY_KEY_NAME;
236 let table_name = query.format_table_name::<Self>();
237 let filters = query.format_filters::<Self>();
238 let sort = query.format_sort();
239 let sql = format!("SELECT {projection} FROM {table_name} {filters} {sort} LIMIT 1;");
240 let mut ctx = Self::before_scan(&sql).await?;
241 ctx.set_query(sql);
242
243 let pool = Self::acquire_reader().await?.pool();
244 let scalar = sqlx::query_scalar(ctx.query()).fetch_one(pool).await?;
245 ctx.set_query_result(1, true);
246 Self::after_scan(&ctx).await?;
247 Self::after_query(&ctx).await?;
248 Ok(scalar)
249 }
250
251 async fn find_primary_keys(query: &Query) -> Result<Vec<K>, Error>
253 where
254 K: Send + Unpin + Type<DatabaseDriver> + for<'r> Decode<'r, DatabaseDriver>,
255 {
256 Self::before_query(query).await?;
257
258 let projection = Self::PRIMARY_KEY_NAME;
259 let table_name = query.format_table_name::<Self>();
260 let filters = query.format_filters::<Self>();
261 let sort = query.format_sort();
262 let pagination = query.format_pagination();
263 let sql = format!("SELECT {projection} FROM {table_name} {filters} {sort} {pagination};");
264 let mut ctx = Self::before_scan(&sql).await?;
265 ctx.set_query(&sql);
266
267 let pool = Self::acquire_reader().await?.pool();
268 let mut stream = sqlx::query(&sql).fetch(pool);
269 let mut max_rows = super::MAX_ROWS.load(Relaxed);
270 let estimated_rows = stream.size_hint().0;
271 if cfg!(debug_assertions) && estimated_rows > max_rows {
272 tracing::warn!(
273 "estimated number of rows {} exceeds the maximum row limit {}",
274 estimated_rows,
275 max_rows,
276 );
277 }
278
279 let mut data = Vec::with_capacity(estimated_rows.min(max_rows));
280 while let Some(row) = stream.try_next().await? {
281 if max_rows > 0 {
282 data.push(row.try_get_unchecked(0)?);
283 max_rows -= 1;
284 } else {
285 break;
286 }
287 }
288 ctx.set_query_result(u64::try_from(data.len())?, true);
289 Self::after_scan(&ctx).await?;
290 Self::after_query(&ctx).await?;
291 Ok(data)
292 }
293}
294
295impl<M, K> ScalarQuery<K> for M
296where
297 M: Schema<PrimaryKey = K>,
298 K: Default + Display + PartialEq,
299{
300}