1use crate::error::Error;
2use crate::io::AsyncStreamExt;
3use crate::protocol::message::*;
4use crate::protocol::statement::{Execute as StatementExecute, Prepare, StmtClose};
5use crate::protocol::text::{ColumnFlags, OkPacket, Query};
6use crate::protocol::ServerContext;
7use crate::statement::{XuguStatement, XuguStatementMetadata};
8use crate::{
9 Xugu, XuguArguments, XuguConnection, XuguDatabaseError, XuguQueryResult, XuguRow, XuguTypeInfo,
10};
11use futures_core::future::BoxFuture;
12use futures_core::stream::BoxStream;
13use futures_core::Stream;
14use futures_util::TryStreamExt;
15use log::Level;
16use sqlx_core::describe::Describe;
17use sqlx_core::executor::{Execute, Executor};
18use sqlx_core::logger::QueryLogger;
19use sqlx_core::{try_stream, Either, HashMap};
20use std::{borrow::Cow, pin::pin, sync::Arc};
21
22impl XuguConnection {
23 async fn prepare_statement<'c>(
24 &mut self,
25 sql: &str,
26 ) -> Result<(u32, XuguStatementMetadata), Error> {
27 self.wait_until_ready().await?;
29
30 let id = self.inner.gen_st_id();
31 self.inner
32 .stream
33 .send_packet(Prepare {
34 query: sql,
35 con_obj_name: &self.inner.con_obj_name,
36 st_id: id,
37 })
38 .await?;
39
40 let mut error = None;
41 let mut columns = Vec::new();
42 let mut column_names = HashMap::new();
43 let mut params = Vec::new();
44
45 loop {
46 let message: ReceivedMessage = self.inner.stream.recv().await?;
47 let cnt = ServerContext::new(self.inner.stream.server_version);
48 match message.format {
49 BackendMessageFormat::ErrorResponse => {
50 let err: ErrorResponse = message.decode(&mut self.inner.stream, cnt).await?;
51 error = Some(err.error);
52 }
53 BackendMessageFormat::MessageResponse => {
54 let notice: MessageResponse =
57 message.decode(&mut self.inner.stream, cnt).await?;
58 let (log_level, tracing_level) = (Level::Info, tracing::Level::INFO);
59 let log_is_enabled = log::log_enabled!(
60 target: "sqlx::xugu::notice",
61 log_level
62 ) || sqlx_core::private_tracing_dynamic_enabled!(
63 target: "sqlx::xugu::notice",
64 tracing_level
65 );
66 if log_is_enabled {
67 sqlx_core::private_tracing_dynamic_event!(
68 target: "sqlx::xugu::notice",
69 tracing_level,
70 message = notice.msg
71 );
72 }
73 }
74 BackendMessageFormat::ReadyForQuery => {
75 let _: ReadyForQuery = message.decode(&mut self.inner.stream, cnt).await?;
76 break;
77 }
78 BackendMessageFormat::RowDescription => {
79 let row_columns: RowDescription =
80 message.decode(&mut self.inner.stream, cnt).await?;
81 (columns, column_names) = row_columns.convert_columns()?;
82 }
83 BackendMessageFormat::ParameterDescription => {
84 let param_def: ParameterDescription =
85 message.decode(&mut self.inner.stream, cnt).await?;
86 params = param_def.params;
87 }
88 _ => {
89 break;
90 }
91 }
92 }
93
94 if let Some(err) = error {
95 return Err(Error::Database(Box::new(XuguDatabaseError::from_str(&err))));
96 }
97
98 let metadata = XuguStatementMetadata {
99 parameters: Arc::new(params),
100 columns: Arc::new(columns),
101 column_names: Arc::new(column_names),
102 };
103
104 Ok((id, metadata))
105 }
106
107 async fn get_or_prepare_statement<'c>(
108 &mut self,
109 sql: &str,
110 ) -> Result<(u32, XuguStatementMetadata), Error> {
111 if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
112 return Ok((*statement).clone());
114 }
115
116 let (id, metadata) = self.prepare_statement(sql).await?;
117
118 if let Some((id, _)) = self
120 .inner
121 .cache_statement
122 .insert(sql, (id, metadata.clone()))
123 {
124 self.wait_until_ready().await?;
126 self.inner
127 .stream
128 .send_packet(StmtClose {
129 con_obj_name: &self.inner.con_obj_name,
130 st_id: id,
131 })
132 .await?;
133
134 let _ok: OkPacket = self.inner.stream.recv().await?;
136 }
137
138 Ok((id, metadata))
139 }
140
141 #[allow(clippy::needless_lifetimes)]
150 pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
151 &'c mut self,
152 sql: &'q str,
153 arguments: Option<XuguArguments<'q>>,
154 persistent: bool,
155 ) -> Result<impl Stream<Item = Result<Either<XuguQueryResult, XuguRow>, Error>> + 'e, Error>
156 {
157 let mut logger = QueryLogger::new(sql, self.inner.log_settings.clone());
158
159 self.wait_until_ready().await?;
160
161 let (mut column_names, mut columns, mut needs_metadata) = if let Some(arguments) = arguments
165 {
166 if persistent && self.inner.cache_statement.is_enabled() {
167 let (id, metadata) = self.get_or_prepare_statement(sql).await?;
168
169 self.inner
170 .stream
171 .send_packet(StatementExecute {
172 con_obj_name: &self.inner.con_obj_name,
173 st_id: id,
174 arguments: &arguments,
175 params: &metadata.parameters,
176 })
177 .await?;
178
179 let needs_metadata = metadata.column_names.is_empty();
180 (metadata.column_names, metadata.columns, needs_metadata)
181 } else {
182 let (id, metadata) = self.prepare_statement(sql).await?;
183
184 self.inner
185 .stream
186 .send_packet(StatementExecute {
187 con_obj_name: &self.inner.con_obj_name,
188 st_id: id,
189 arguments: &arguments,
190 params: &metadata.parameters,
191 })
192 .await?;
193
194 self.inner
195 .stream
196 .send_packet(StmtClose {
197 con_obj_name: &self.inner.con_obj_name,
198 st_id: id,
199 })
200 .await?;
201 self.inner.pending_ready_for_query_count += 1;
203
204 let needs_metadata = metadata.column_names.is_empty();
205 (metadata.column_names, metadata.columns, needs_metadata)
206 }
207 } else {
208 self.inner.stream.send_packet(Query(sql)).await?;
209
210 (Arc::default(), Arc::default(), true)
211 };
212
213 self.inner.pending_ready_for_query_count += 1;
214
215 let mut error = None;
216
217 let mut num_columns = 0;
218
219 Ok(try_stream! {
220 loop {
221 let message: ReceivedMessage = self.inner.stream.recv().await?;
222 let cnt = ServerContext::new(self.inner.stream.server_version);
223 match message.format {
224 BackendMessageFormat::ErrorResponse => {
225 let err: ErrorResponse = message.decode(&mut self.inner.stream, cnt).await?;
226 error = Some(err.error);
227 },
228 BackendMessageFormat::MessageResponse => {
229 let notice: MessageResponse = message.decode(&mut self.inner.stream, cnt).await?;
232 let (log_level, tracing_level) = (Level::Info, tracing::Level::INFO);
233 let log_is_enabled = log::log_enabled!(
234 target: "sqlx::xugu::notice",
235 log_level
236 ) || sqlx_core::private_tracing_dynamic_enabled!(
237 target: "sqlx::xugu::notice",
238 tracing_level
239 );
240 if log_is_enabled {
241 sqlx_core::private_tracing_dynamic_event!(
242 target: "sqlx::xugu::notice",
243 tracing_level,
244 message = notice.msg
245 );
246 }
247 },
248 BackendMessageFormat::ReadyForQuery => {
249 let _: ReadyForQuery = message.decode(&mut self.inner.stream, cnt).await?;
251 self.handle_ready_for_query().await?;
252 break;
253 },
254 BackendMessageFormat::InsertResponse => {
255 let res: InsertResponse = message.decode(&mut self.inner.stream, cnt).await?;
256 let rows_affected = 1;
257 logger.increase_rows_affected(rows_affected);
258 let done = XuguQueryResult {
259 rows_affected,
260 last_insert_id: Some(res.rowid),
261 };
262 r#yield!(Either::Left(done));
263 },
264 BackendMessageFormat::DeleteResponse => {
265 let res: DeleteResponse = message.decode(&mut self.inner.stream, cnt).await?;
266 let rows_affected = res.rows_affected as u64;
267 logger.increase_rows_affected(rows_affected);
268 let done = XuguQueryResult {
269 rows_affected,
270 last_insert_id: None,
271 };
272 r#yield!(Either::Left(done));
273 },
274 BackendMessageFormat::UpdateResponse => {
275 let res: UpdateResponse = message.decode(&mut self.inner.stream, cnt).await?;
276 let rows_affected = res.rows_affected as u64;
277 logger.increase_rows_affected(rows_affected);
278 let done = XuguQueryResult {
279 rows_affected,
280 last_insert_id: None,
281 };
282 r#yield!(Either::Left(done));
283 },
284 BackendMessageFormat::RowDescription => {
285 let row_columns: RowDescription = message.decode(&mut self.inner.stream, cnt).await?;
287 num_columns = row_columns.fields.len();
288 self.inner.last_num_columns = num_columns;
289 if needs_metadata {
290 let (columns_c, column_names_c) = row_columns.convert_columns()?;
291 columns = Arc::new(columns_c);
292 column_names = Arc::new(column_names_c);
293 } else {
294 needs_metadata = true;
297 }
298 },
299 BackendMessageFormat::ParameterDescription => {
300 let _: ParameterDescription = message.decode(&mut self.inner.stream, cnt).await?;
301 },
302 BackendMessageFormat::DataRow => {
303 let _: DataRow = message.decode(&mut self.inner.stream, cnt).await?;
305 let mut row = Vec::with_capacity(num_columns);
306 for _ in 0..num_columns {
307 let len = self.inner.stream.read_i32().await?;
308 let buf = self.inner.stream.read_bytes(len as usize).await?;
309 row.push(buf);
310 }
311 let row = Arc::new(row);
312
313 let v = Either::Right(XuguRow {
314 row,
315 columns: Arc::clone(&columns),
316 column_names: Arc::clone(&column_names),
317 });
318
319 logger.increment_rows_returned();
320
321 r#yield!(v);
322 }
323 }
324 }
325
326 if let Some(err) = error {
327 return Err(Error::Database(Box::new(XuguDatabaseError::from_str(&err))));
328 }
329
330 return Ok(());
331 })
332 }
333}
334
335impl<'c> Executor<'c> for &'c mut XuguConnection {
336 type Database = Xugu;
337
338 fn fetch_many<'e, 'q, E>(
340 self,
341 mut query: E,
342 ) -> BoxStream<'e, Result<Either<XuguQueryResult, XuguRow>, Error>>
343 where
344 'c: 'e,
345 E: Execute<'q, Self::Database>,
346 'q: 'e,
347 E: 'q,
348 {
349 let sql = query.sql();
350 let arguments = query.take_arguments().map_err(Error::Encode);
351 let persistent = query.persistent();
352
353 Box::pin(try_stream! {
354 let arguments = arguments?;
355 let mut s = pin!(self.run(sql, arguments, persistent).await?);
356
357 while let Some(v) = s.try_next().await? {
358 r#yield!(v);
359 }
360
361 Ok(())
362 })
363 }
364
365 fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result<Option<XuguRow>, Error>>
367 where
368 'c: 'e,
369 E: Execute<'q, Self::Database>,
370 'q: 'e,
371 E: 'q,
372 {
373 let mut s = self.fetch_many(query);
374
375 Box::pin(async move {
376 while let Some(v) = s.try_next().await? {
377 if let Either::Right(r) = v {
378 return Ok(Some(r));
379 }
380 }
381
382 Ok(None)
383 })
384 }
385
386 fn prepare_with<'e, 'q: 'e>(
390 self,
391 sql: &'q str,
392 _parameters: &'e [XuguTypeInfo],
393 ) -> BoxFuture<'e, Result<XuguStatement<'q>, Error>>
394 where
395 'c: 'e,
396 {
397 Box::pin(async move {
398 self.wait_until_ready().await?;
399
400 let metadata = if self.inner.cache_statement.is_enabled() {
401 self.get_or_prepare_statement(sql).await?.1
402 } else {
403 let (id, metadata) = self.prepare_statement(sql).await?;
404
405 self.inner
406 .stream
407 .send_packet(StmtClose {
408 con_obj_name: &self.inner.con_obj_name,
409 st_id: id,
410 })
411 .await?;
412 let _ok: OkPacket = self.inner.stream.recv().await?;
414
415 metadata
416 };
417
418 Ok(XuguStatement {
419 sql: Cow::Borrowed(sql),
420 metadata: metadata.clone(),
422 })
423 })
424 }
425
426 #[doc(hidden)]
430 fn describe<'e, 'q: 'e>(self, sql: &'q str) -> BoxFuture<'e, Result<Describe<Xugu>, Error>>
431 where
432 'c: 'e,
433 {
434 Box::pin(async move {
435 self.wait_until_ready().await?;
436
437 let (id, metadata) = self.prepare_statement(sql).await?;
438
439 self.inner
440 .stream
441 .send_packet(StmtClose {
442 con_obj_name: &self.inner.con_obj_name,
443 st_id: id,
444 })
445 .await?;
446 let _ok: OkPacket = self.inner.stream.recv().await?;
448
449 let columns = (*metadata.columns).clone();
450
451 let nullable = columns
452 .iter()
453 .map(|col| {
454 col.flags
455 .map(|flags| !flags.contains(ColumnFlags::NOT_NULL))
456 })
457 .collect();
458
459 Ok(Describe {
460 parameters: Some(Either::Right(metadata.parameters.len())),
461 columns,
462 nullable,
463 })
464 })
465 }
466}