1use super::MySqlStream;
2use crate::connection::stream::Waiting;
3use crate::describe::Describe;
4use crate::error::Error;
5use crate::executor::{Execute, Executor};
6use crate::ext::ustr::UStr;
7use crate::io::MySqlBufExt;
8use crate::logger::QueryLogger;
9use crate::protocol::response::Status;
10use crate::protocol::statement::{
11 BinaryRow, Execute as StatementExecute, Prepare, PrepareOk, StmtClose,
12};
13use crate::protocol::text::{ColumnDefinition, ColumnFlags, Query, TextRow};
14use crate::statement::{MySqlStatement, MySqlStatementMetadata};
15use crate::HashMap;
16use crate::{
17 MySql, MySqlArguments, MySqlColumn, MySqlConnection, MySqlQueryResult, MySqlRow, MySqlTypeInfo,
18 MySqlValueFormat,
19};
20use either::Either;
21use futures_core::future::BoxFuture;
22use futures_core::stream::BoxStream;
23use futures_core::Stream;
24use futures_util::TryStreamExt;
25use sqlx_core::column::{ColumnOrigin, TableColumn};
26use sqlx_core::sql_str::SqlStr;
27use std::{pin::pin, sync::Arc};
28
29impl MySqlConnection {
30 async fn prepare_statement(
31 &mut self,
32 sql: &str,
33 ) -> Result<(u32, MySqlStatementMetadata), Error> {
34 self.inner
38 .stream
39 .send_packet(Prepare { query: sql })
40 .await?;
41
42 let ok: PrepareOk = self.inner.stream.recv().await?;
43
44 if ok.params > 0 {
48 for _ in 0..ok.params {
49 let _def: ColumnDefinition = self.inner.stream.recv().await?;
50 }
51
52 self.inner.stream.maybe_recv_eof().await?;
53 }
54
55 let mut columns = Vec::new();
60
61 let column_names = if ok.columns > 0 {
62 recv_result_metadata(&mut self.inner.stream, ok.columns as usize, &mut columns).await?
63 } else {
64 Default::default()
65 };
66
67 let id = ok.statement_id;
68 let metadata = MySqlStatementMetadata {
69 parameters: ok.params as usize,
70 columns: Arc::new(columns),
71 column_names: Arc::new(column_names),
72 };
73
74 Ok((id, metadata))
75 }
76
77 async fn get_or_prepare_statement(
78 &mut self,
79 sql: &str,
80 ) -> Result<(u32, MySqlStatementMetadata), Error> {
81 if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
82 return Ok((*statement).clone());
84 }
85
86 let (id, metadata) = self.prepare_statement(sql).await?;
87
88 if let Some((id, _)) = self
90 .inner
91 .cache_statement
92 .insert(sql, (id, metadata.clone()))
93 {
94 self.inner
95 .stream
96 .send_packet(StmtClose { statement: id })
97 .await?;
98 }
99
100 Ok((id, metadata))
101 }
102
103 #[allow(clippy::needless_lifetimes)]
104 pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
105 &'c mut self,
106 sql: SqlStr,
107 arguments: Option<MySqlArguments>,
108 persistent: bool,
109 ) -> Result<impl Stream<Item = Result<Either<MySqlQueryResult, MySqlRow>, Error>> + 'e, Error>
110 {
111 let mut logger = QueryLogger::new(sql, self.inner.log_settings.clone());
112
113 self.inner.stream.wait_until_ready().await?;
114 self.inner.stream.waiting.push_back(Waiting::Result);
115
116 Ok(try_stream! {
117 let sql = logger.sql().as_str();
118
119 let mut columns = Arc::new(Vec::new());
123
124 let (mut column_names, format, mut needs_metadata) = if let Some(arguments) = arguments {
125 if persistent && self.inner.cache_statement.is_enabled() {
126 let (id, metadata) = self
127 .get_or_prepare_statement(sql)
128 .await?;
129
130 if arguments.types.len() != metadata.parameters {
131 return Err(
132 err_protocol!(
133 "prepared statement expected {} parameters but {} parameters were provided",
134 metadata.parameters,
135 arguments.types.len()
136 )
137 );
138 }
139
140 self.inner.stream
142 .send_packet(StatementExecute {
143 statement: id,
144 arguments: &arguments,
145 })
146 .await?;
147
148 (metadata.column_names, MySqlValueFormat::Binary, false)
149 } else {
150 let (id, metadata) = self
151 .prepare_statement(sql)
152 .await?;
153
154 if arguments.types.len() != metadata.parameters {
155 return Err(
156 err_protocol!(
157 "prepared statement expected {} parameters but {} parameters were provided",
158 metadata.parameters,
159 arguments.types.len()
160 )
161 );
162 }
163
164 self.inner.stream
166 .send_packet(StatementExecute {
167 statement: id,
168 arguments: &arguments,
169 })
170 .await?;
171
172 self.inner.stream.send_packet(StmtClose { statement: id }).await?;
173
174 (metadata.column_names, MySqlValueFormat::Binary, false)
175 }
176 } else {
177 self.inner.stream.send_packet(Query(sql)).await?;
179
180 (Arc::default(), MySqlValueFormat::Text, true)
181 };
182
183 loop {
184 let mut packet = self.inner.stream.recv_packet().await?;
187
188 if packet[0] == 0x00 || packet[0] == 0xff {
189 let ok = packet.ok()?;
192
193 self.inner.status_flags = ok.status;
194
195 let rows_affected = ok.affected_rows;
196 logger.increase_rows_affected(rows_affected);
197 let done = MySqlQueryResult {
198 rows_affected,
199 last_insert_id: ok.last_insert_id,
200 };
201
202 r#yield!(Either::Left(done));
203
204 if ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
205 continue;
207 }
208
209 self.inner.stream.waiting.pop_front();
210 return Ok(());
211 }
212
213 *self.inner.stream.waiting.front_mut().unwrap() = Waiting::Row;
215
216 let num_columns = packet.get_uint_lenenc(); let num_columns = usize::try_from(num_columns)
218 .map_err(|_| err_protocol!("column count overflows usize: {num_columns}"))?;
219
220 if needs_metadata {
221 column_names = Arc::new(recv_result_metadata(&mut self.inner.stream, num_columns, Arc::make_mut(&mut columns)).await?);
222 } else {
223 needs_metadata = true;
226
227 recv_result_columns(&mut self.inner.stream, num_columns, Arc::make_mut(&mut columns)).await?;
228 }
229
230 loop {
232 let packet = self.inner.stream.recv_packet().await?;
233
234 if packet[0] == 0xfe {
235 let (rows_affected, last_insert_id, status) = if packet.len() < 9 {
236 let eof = packet.eof(self.inner.stream.capabilities)?;
238 (0, 0, eof.status)
239 } else {
240 let ok = packet.ok()?;
242 (ok.affected_rows, ok.last_insert_id, ok.status)
243 };
244
245 self.inner.status_flags = status;
246 r#yield!(Either::Left(MySqlQueryResult {
247 rows_affected,
248 last_insert_id,
249 }));
250
251 if status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
252 *self.inner.stream.waiting.front_mut().unwrap() = Waiting::Result;
253 break;
254 }
255 self.inner.stream.waiting.pop_front();
256 return Ok(());
257 }
258
259 let row = match format {
260 MySqlValueFormat::Binary => packet.decode_with::<BinaryRow, _>(&columns)?.0,
261 MySqlValueFormat::Text => packet.decode_with::<TextRow, _>(&columns)?.0,
262 };
263
264 let v = Either::Right(MySqlRow {
265 row,
266 format,
267 columns: Arc::clone(&columns),
268 column_names: Arc::clone(&column_names),
269 });
270
271 logger.increment_rows_returned();
272
273 r#yield!(v);
274 }
275 }
276 })
277 }
278}
279
280impl<'c> Executor<'c> for &'c mut MySqlConnection {
281 type Database = MySql;
282
283 fn fetch_many<'e, 'q, E>(
284 self,
285 mut query: E,
286 ) -> BoxStream<'e, Result<Either<MySqlQueryResult, MySqlRow>, Error>>
287 where
288 'c: 'e,
289 E: Execute<'q, Self::Database>,
290 'q: 'e,
291 E: 'q,
292 {
293 let arguments = query.take_arguments().map_err(Error::Encode);
294 let persistent = query.persistent();
295
296 Box::pin(try_stream! {
297 let sql = query.sql();
298 let arguments = arguments?;
299 let mut s = pin!(self.run(sql, arguments, persistent).await?);
300
301 while let Some(v) = s.try_next().await? {
302 r#yield!(v);
303 }
304
305 Ok(())
306 })
307 }
308
309 fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result<Option<MySqlRow>, Error>>
310 where
311 'c: 'e,
312 E: Execute<'q, Self::Database>,
313 'q: 'e,
314 E: 'q,
315 {
316 let mut s = self.fetch_many(query);
317
318 Box::pin(async move {
319 while let Some(v) = s.try_next().await? {
320 if let Either::Right(r) = v {
321 return Ok(Some(r));
322 }
323 }
324
325 Ok(None)
326 })
327 }
328
329 fn prepare_with<'e>(
330 self,
331 sql: SqlStr,
332 _parameters: &'e [MySqlTypeInfo],
333 ) -> BoxFuture<'e, Result<MySqlStatement, Error>>
334 where
335 'c: 'e,
336 {
337 Box::pin(async move {
338 self.inner.stream.wait_until_ready().await?;
339
340 let metadata = if self.inner.cache_statement.is_enabled() {
341 self.get_or_prepare_statement(sql.as_str()).await?.1
342 } else {
343 let (id, metadata) = self.prepare_statement(sql.as_str()).await?;
344
345 self.inner
346 .stream
347 .send_packet(StmtClose { statement: id })
348 .await?;
349
350 metadata
351 };
352
353 Ok(MySqlStatement {
354 sql,
355 metadata: metadata.clone(),
357 })
358 })
359 }
360
361 #[doc(hidden)]
362 fn describe<'e>(self, sql: SqlStr) -> BoxFuture<'e, Result<Describe<MySql>, Error>>
363 where
364 'c: 'e,
365 {
366 Box::pin(async move {
367 self.inner.stream.wait_until_ready().await?;
368
369 let (id, metadata) = self.prepare_statement(sql.as_str()).await?;
370
371 self.inner
372 .stream
373 .send_packet(StmtClose { statement: id })
374 .await?;
375
376 let columns = (*metadata.columns).clone();
377
378 let nullable = columns
379 .iter()
380 .map(|col| {
381 col.flags
382 .map(|flags| !flags.contains(ColumnFlags::NOT_NULL))
383 })
384 .collect();
385
386 Ok(Describe {
387 parameters: Some(Either::Right(metadata.parameters)),
388 columns,
389 nullable,
390 })
391 })
392 }
393}
394
395async fn recv_result_columns(
396 stream: &mut MySqlStream,
397 num_columns: usize,
398 columns: &mut Vec<MySqlColumn>,
399) -> Result<(), Error> {
400 columns.clear();
401 columns.reserve(num_columns);
402
403 for ordinal in 0..num_columns {
404 columns.push(recv_next_result_column(&stream.recv().await?, ordinal)?);
405 }
406
407 if num_columns > 0 {
408 stream.maybe_recv_eof().await?;
409 }
410
411 Ok(())
412}
413
414fn recv_next_result_column(def: &ColumnDefinition, ordinal: usize) -> Result<MySqlColumn, Error> {
415 let column_name = def.name()?;
418
419 let name = match (def.name()?, def.alias()?) {
420 (_, alias) if !alias.is_empty() => UStr::new(alias),
421 (name, _) => UStr::new(name),
422 };
423
424 let table = def.table()?;
425
426 let origin = if table.is_empty() {
427 ColumnOrigin::Expression
428 } else {
429 let schema = def.schema()?;
430
431 ColumnOrigin::Table(TableColumn {
432 table: if !schema.is_empty() {
433 format!("{schema}.{table}").into()
434 } else {
435 table.into()
436 },
437 name: column_name.into(),
438 })
439 };
440
441 let type_info = MySqlTypeInfo::from_column(def);
442
443 Ok(MySqlColumn {
444 name,
445 type_info,
446 ordinal,
447 flags: Some(def.flags),
448 origin,
449 })
450}
451
452async fn recv_result_metadata(
453 stream: &mut MySqlStream,
454 num_columns: usize,
455 columns: &mut Vec<MySqlColumn>,
456) -> Result<HashMap<UStr, usize>, Error> {
457 let mut column_names = HashMap::with_capacity(num_columns);
461
462 columns.clear();
463 columns.reserve(num_columns);
464
465 for ordinal in 0..num_columns {
466 let def: ColumnDefinition = stream.recv().await?;
467
468 let column = recv_next_result_column(&def, ordinal)?;
469
470 column_names.insert(column.name.clone(), ordinal);
471 columns.push(column);
472 }
473
474 stream.maybe_recv_eof().await?;
475
476 Ok(column_names)
477}