1use crate::connection::StatementId;
2use crate::error::Error;
3use crate::io::AsyncStreamExt;
4use crate::protocol::message::*;
5use crate::protocol::statement::{Execute as StatementExecute, Prepare, StmtClose};
6use crate::protocol::text::{ColumnFlags, OkPacket, Query};
7use crate::protocol::ServerContext;
8use crate::statement::{XuguStatement, XuguStatementMetadata};
9use crate::{
10 Xugu, XuguArguments, XuguConnection, XuguDatabaseError, XuguQueryResult, XuguRow, XuguTypeInfo,
11};
12use futures_core::future::BoxFuture;
13use futures_core::stream::BoxStream;
14use futures_core::Stream;
15use futures_util::TryStreamExt;
16use log::Level;
17use sqlx_core::describe::Describe;
18use sqlx_core::executor::{Execute, Executor};
19use sqlx_core::logger::QueryLogger;
20use sqlx_core::{try_stream, Either, HashMap};
21use std::{borrow::Cow, pin::pin, sync::Arc};
22
23impl XuguConnection {
24 async fn prepare_statement<'c>(
25 &mut self,
26 sql: &str,
27 ) -> Result<(StatementId, XuguStatementMetadata), Error> {
28 self.wait_until_ready().await?;
30
31 let id = self.inner.gen_st_id();
32 self.inner
33 .stream
34 .send_packet(Prepare {
35 query: sql,
36 st_id: id,
37 })
38 .await?;
39
40 let mut error = None;
41 let mut columns = Vec::new();
42 let mut column_names = HashMap::new();
43 let mut params = Vec::new();
44
45 loop {
46 let message: ReceivedMessage = self.inner.stream.recv().await?;
47 let cnt = ServerContext::new(self.inner.stream.server_version);
48 match message.format {
49 BackendMessageFormat::ErrorResponse => {
50 let err: ErrorResponse = message.decode(&mut self.inner.stream, cnt).await?;
51 error = Some(err.error);
52 }
53 BackendMessageFormat::MessageResponse => {
54 let notice: MessageResponse =
57 message.decode(&mut self.inner.stream, cnt).await?;
58 let (log_level, tracing_level) = (Level::Info, tracing::Level::INFO);
59 let log_is_enabled = log::log_enabled!(
60 target: "sqlx::xugu::notice",
61 log_level
62 ) || sqlx_core::private_tracing_dynamic_enabled!(
63 target: "sqlx::xugu::notice",
64 tracing_level
65 );
66 if log_is_enabled {
67 sqlx_core::private_tracing_dynamic_event!(
68 target: "sqlx::xugu::notice",
69 tracing_level,
70 message = notice.msg
71 );
72 }
73 }
74 BackendMessageFormat::ReadyForQuery => {
75 let _: ReadyForQuery = message.decode(&mut self.inner.stream, cnt).await?;
76 break;
77 }
78 BackendMessageFormat::RowDescription => {
79 let row_columns: RowDescription =
80 message.decode(&mut self.inner.stream, cnt).await?;
81 (columns, column_names) = row_columns.convert_columns()?;
82 }
83 BackendMessageFormat::ParameterDescription => {
84 let param_def: ParameterDescription =
85 message.decode(&mut self.inner.stream, cnt).await?;
86 params = param_def.params;
87 }
88 _ => {
89 break;
90 }
91 }
92 }
93
94 if let Some(err) = error {
95 return Err(Error::Database(Box::new(XuguDatabaseError::from_str(&err))));
96 }
97
98 let metadata = XuguStatementMetadata {
99 parameters: Arc::new(params),
100 columns: Arc::new(columns),
101 column_names: Arc::new(column_names),
102 };
103
104 Ok((id, metadata))
105 }
106
107 async fn get_or_prepare_statement<'c>(
108 &mut self,
109 sql: &str,
110 ) -> Result<(StatementId, XuguStatementMetadata), Error> {
111 if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
112 return Ok((*statement).clone());
114 }
115
116 let (id, metadata) = self.prepare_statement(sql).await?;
117
118 if let Some((id, _)) = self
120 .inner
121 .cache_statement
122 .insert(sql, (id, metadata.clone()))
123 {
124 self.wait_until_ready().await?;
126 self.inner.stream.send_packet(StmtClose(id)).await?;
127
128 let _ok: OkPacket = self.inner.stream.recv().await?;
130 }
131
132 Ok((id, metadata))
133 }
134
135 #[allow(clippy::needless_lifetimes)]
144 pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
145 &'c mut self,
146 sql: &'q str,
147 arguments: Option<XuguArguments<'q>>,
148 persistent: bool,
149 ) -> Result<impl Stream<Item = Result<Either<XuguQueryResult, XuguRow>, Error>> + 'e, Error>
150 {
151 let mut logger = QueryLogger::new(sql, self.inner.log_settings.clone());
152
153 self.wait_until_ready().await?;
154
155 let (mut column_names, mut columns, mut needs_metadata) = if let Some(arguments) = arguments
159 {
160 if persistent && self.inner.cache_statement.is_enabled() {
161 let (id, metadata) = self.get_or_prepare_statement(sql).await?;
162
163 self.inner
164 .stream
165 .send_packet(StatementExecute {
166 st_id: id,
167 arguments: &arguments,
168 params: &metadata.parameters,
169 })
170 .await?;
171
172 let needs_metadata = metadata.column_names.is_empty();
173 (metadata.column_names, metadata.columns, needs_metadata)
174 } else {
175 let (id, metadata) = self.prepare_statement(sql).await?;
176
177 self.inner
178 .stream
179 .send_packet(StatementExecute {
180 st_id: id,
181 arguments: &arguments,
182 params: &metadata.parameters,
183 })
184 .await?;
185
186 self.inner.stream.send_packet(StmtClose(id)).await?;
187 self.inner.pending_ready_for_query_count += 1;
189
190 let needs_metadata = metadata.column_names.is_empty();
191 (metadata.column_names, metadata.columns, needs_metadata)
192 }
193 } else {
194 self.inner.stream.send_packet(Query(sql)).await?;
195
196 (Arc::default(), Arc::default(), true)
197 };
198
199 self.inner.pending_ready_for_query_count += 1;
200
201 let mut error = None;
202
203 let mut num_columns = 0;
204
205 Ok(try_stream! {
206 loop {
207 let message: ReceivedMessage = self.inner.stream.recv().await?;
208 let cnt = ServerContext::new(self.inner.stream.server_version);
209 match message.format {
210 BackendMessageFormat::ErrorResponse => {
211 let err: ErrorResponse = message.decode(&mut self.inner.stream, cnt).await?;
212 error = Some(err.error);
213 },
214 BackendMessageFormat::MessageResponse => {
215 let notice: MessageResponse = message.decode(&mut self.inner.stream, cnt).await?;
218 let (log_level, tracing_level) = (Level::Info, tracing::Level::INFO);
219 let log_is_enabled = log::log_enabled!(
220 target: "sqlx::xugu::notice",
221 log_level
222 ) || sqlx_core::private_tracing_dynamic_enabled!(
223 target: "sqlx::xugu::notice",
224 tracing_level
225 );
226 if log_is_enabled {
227 sqlx_core::private_tracing_dynamic_event!(
228 target: "sqlx::xugu::notice",
229 tracing_level,
230 message = notice.msg
231 );
232 }
233 },
234 BackendMessageFormat::ReadyForQuery => {
235 let _: ReadyForQuery = message.decode(&mut self.inner.stream, cnt).await?;
237 self.handle_ready_for_query().await?;
238 break;
239 },
240 BackendMessageFormat::InsertResponse => {
241 let res: InsertResponse = message.decode(&mut self.inner.stream, cnt).await?;
242 let rows_affected = 1;
243 logger.increase_rows_affected(rows_affected);
244 let done = XuguQueryResult {
245 rows_affected,
246 last_insert_id: Some(res.rowid),
247 };
248 r#yield!(Either::Left(done));
249 },
250 BackendMessageFormat::DeleteResponse => {
251 let res: DeleteResponse = message.decode(&mut self.inner.stream, cnt).await?;
252 let rows_affected = res.rows_affected as u64;
253 logger.increase_rows_affected(rows_affected);
254 let done = XuguQueryResult {
255 rows_affected,
256 last_insert_id: None,
257 };
258 r#yield!(Either::Left(done));
259 },
260 BackendMessageFormat::UpdateResponse => {
261 let res: UpdateResponse = message.decode(&mut self.inner.stream, cnt).await?;
262 let rows_affected = res.rows_affected as u64;
263 logger.increase_rows_affected(rows_affected);
264 let done = XuguQueryResult {
265 rows_affected,
266 last_insert_id: None,
267 };
268 r#yield!(Either::Left(done));
269 },
270 BackendMessageFormat::RowDescription => {
271 let row_columns: RowDescription = message.decode(&mut self.inner.stream, cnt).await?;
273 num_columns = row_columns.fields.len();
274 self.inner.last_num_columns = num_columns;
275 if needs_metadata {
276 let (columns_c, column_names_c) = row_columns.convert_columns()?;
277 columns = Arc::new(columns_c);
278 column_names = Arc::new(column_names_c);
279 } else {
280 needs_metadata = true;
283 }
284 },
285 BackendMessageFormat::ParameterDescription => {
286 let _: ParameterDescription = message.decode(&mut self.inner.stream, cnt).await?;
287 },
288 BackendMessageFormat::DataRow => {
289 let _: DataRow = message.decode(&mut self.inner.stream, cnt).await?;
291 let mut row = Vec::with_capacity(num_columns);
292 for _ in 0..num_columns {
293 let len = self.inner.stream.read_i32().await?;
294 let buf = self.inner.stream.read_bytes(len as usize).await?;
295 row.push(buf);
296 }
297 let row = Arc::new(row);
298
299 let v = Either::Right(XuguRow {
300 row,
301 columns: Arc::clone(&columns),
302 column_names: Arc::clone(&column_names),
303 });
304
305 logger.increment_rows_returned();
306
307 r#yield!(v);
308 }
309 }
310 }
311
312 if let Some(err) = error {
313 return Err(Error::Database(Box::new(XuguDatabaseError::from_str(&err))));
314 }
315
316 return Ok(());
317 })
318 }
319}
320
321impl<'c> Executor<'c> for &'c mut XuguConnection {
322 type Database = Xugu;
323
324 fn fetch_many<'e, 'q, E>(
326 self,
327 mut query: E,
328 ) -> BoxStream<'e, Result<Either<XuguQueryResult, XuguRow>, Error>>
329 where
330 'c: 'e,
331 E: Execute<'q, Self::Database>,
332 'q: 'e,
333 E: 'q,
334 {
335 let sql = query.sql();
336 let arguments = query.take_arguments().map_err(Error::Encode);
337 let persistent = query.persistent();
338
339 Box::pin(try_stream! {
340 let arguments = arguments?;
341 let mut s = pin!(self.run(sql, arguments, persistent).await?);
342
343 while let Some(v) = s.try_next().await? {
344 r#yield!(v);
345 }
346
347 Ok(())
348 })
349 }
350
351 fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result<Option<XuguRow>, Error>>
353 where
354 'c: 'e,
355 E: Execute<'q, Self::Database>,
356 'q: 'e,
357 E: 'q,
358 {
359 let mut s = self.fetch_many(query);
360
361 Box::pin(async move {
362 while let Some(v) = s.try_next().await? {
363 if let Either::Right(r) = v {
364 return Ok(Some(r));
365 }
366 }
367
368 Ok(None)
369 })
370 }
371
372 fn prepare_with<'e, 'q: 'e>(
376 self,
377 sql: &'q str,
378 _parameters: &'e [XuguTypeInfo],
379 ) -> BoxFuture<'e, Result<XuguStatement<'q>, Error>>
380 where
381 'c: 'e,
382 {
383 Box::pin(async move {
384 self.wait_until_ready().await?;
385
386 let metadata = if self.inner.cache_statement.is_enabled() {
387 self.get_or_prepare_statement(sql).await?.1
388 } else {
389 let (id, metadata) = self.prepare_statement(sql).await?;
390
391 self.inner.stream.send_packet(StmtClose(id)).await?;
392 let _ok: OkPacket = self.inner.stream.recv().await?;
394
395 metadata
396 };
397
398 Ok(XuguStatement {
399 sql: Cow::Borrowed(sql),
400 metadata: metadata.clone(),
402 })
403 })
404 }
405
406 #[doc(hidden)]
410 fn describe<'e, 'q: 'e>(self, sql: &'q str) -> BoxFuture<'e, Result<Describe<Xugu>, Error>>
411 where
412 'c: 'e,
413 {
414 Box::pin(async move {
415 self.wait_until_ready().await?;
416
417 let (id, metadata) = self.prepare_statement(sql).await?;
418
419 self.inner.stream.send_packet(StmtClose(id)).await?;
420 let _ok: OkPacket = self.inner.stream.recv().await?;
422
423 let columns = (*metadata.columns).clone();
424
425 let nullable = columns
426 .iter()
427 .map(|col| {
428 col.flags
429 .map(|flags| !flags.contains(ColumnFlags::NOT_NULL))
430 })
431 .collect();
432
433 Ok(Describe {
434 parameters: Some(Either::Right(metadata.parameters.len())),
435 columns,
436 nullable,
437 })
438 })
439 }
440}