1use tokio::net::{TcpStream, UnixStream};
2use tracing::instrument;
3use zerocopy::{FromBytes, FromZeros, IntoBytes};
4
5use crate::PreparedStatement;
6use crate::buffer::BufferSet;
7use crate::buffer_pool::PooledBufferSet;
8use crate::constant::CapabilityFlags;
9use crate::error::{Error, Result};
10use crate::protocol::TextRowPayload;
11use crate::protocol::command::Action;
12use crate::protocol::command::ColumnDefinition;
13use crate::protocol::command::bulk_exec::{BulkExec, BulkFlags, BulkParamsSet, write_bulk_execute};
14use crate::protocol::command::prepared::{Exec, read_prepare_ok, write_execute, write_prepare};
15use crate::protocol::command::query::{Query, write_query};
16use crate::protocol::command::utility::{
17 DropHandler, FirstRowHandler, write_ping, write_reset_connection,
18};
19use crate::protocol::connection::{Handshake, HandshakeAction, InitialHandshake};
20use crate::protocol::packet::PacketHeader;
21use crate::protocol::primitive::read_string_lenenc;
22use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
23use crate::protocol::r#trait::{BinaryResultSetHandler, TextResultSetHandler, param::Params};
24
25use super::stream::Stream;
26
/// An async connection to a MySQL or MariaDB server.
///
/// Created by [`Conn::new`] / [`Conn::new_with_stream`], which perform the handshake
/// before returning.
pub struct Conn {
    /// Transport used for all packet I/O (TCP or Unix socket, optionally upgraded to TLS
    /// during the handshake when the `tokio-tls` feature is enabled).
    stream: Stream,
    /// Pooled buffers (read buffer, write buffer, column-definition buffer, and the raw
    /// initial handshake) reused across commands.
    buffer_set: PooledBufferSet,
    /// Parsed initial handshake; its `server_version` field is a byte range into
    /// `buffer_set.initial_handshake`.
    initial_handshake: InitialHandshake,
    /// Capability flags produced by the handshake.
    capability_flags: CapabilityFlags,
    /// MariaDB extended capability flags produced by the handshake; consulted for
    /// `MARIADB_CLIENT_CACHE_METADATA` when executing prepared statements.
    mariadb_capabilities: crate::constant::MariadbCapabilityFlags,
    /// True while `run_transaction` considers a transaction open; cleared by `reset`.
    in_transaction: bool,
    /// Set once a public operation fails with an error that `is_conn_broken` reports.
    is_broken: bool,
}
36
37impl Conn {
38 pub async fn new<O: TryInto<crate::opts::Opts>>(opts: O) -> Result<Self>
40 where
41 Error: From<O::Error>,
42 {
43 let opts: crate::opts::Opts = opts.try_into()?;
44
45 let stream = if let Some(socket_path) = &opts.socket {
46 let stream = UnixStream::connect(socket_path).await?;
47 Stream::unix(stream)
48 } else {
49 if opts.host.is_empty() {
50 return Err(Error::BadUsageError(
51 "Missing host in connection options".to_string(),
52 ));
53 }
54 let addr = format!("{}:{}", opts.host, opts.port);
55 let stream = TcpStream::connect(&addr).await?;
56 stream.set_nodelay(opts.tcp_nodelay)?;
57 Stream::tcp(stream)
58 };
59
60 Self::new_with_stream(stream, &opts).await
61 }
62
63 pub async fn new_with_stream(stream: Stream, opts: &crate::opts::Opts) -> Result<Self> {
65 let mut conn_stream = stream;
66 let mut buffer_set = opts.buffer_pool.get_buffer_set();
67
68 #[cfg(feature = "tokio-tls")]
69 let host = opts.host.clone();
70
71 let mut handshake = Handshake::new(opts);
72
73 loop {
74 match handshake.step(&mut buffer_set)? {
75 HandshakeAction::ReadPacket(buffer) => {
76 buffer.clear();
77 read_payload(&mut conn_stream, buffer).await?;
78 }
79 HandshakeAction::WritePacket { sequence_id } => {
80 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id).await?;
81 buffer_set.read_buffer.clear();
82 read_payload(&mut conn_stream, &mut buffer_set.read_buffer).await?;
83 }
84 #[cfg(feature = "tokio-tls")]
85 HandshakeAction::UpgradeTls { sequence_id } => {
86 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id).await?;
87 conn_stream = conn_stream.upgrade_to_tls(&host).await?;
88 }
89 #[cfg(not(feature = "tokio-tls"))]
90 HandshakeAction::UpgradeTls { .. } => {
91 return Err(Error::BadUsageError(
92 "TLS requested but tokio-tls feature is not enabled".to_string(),
93 ));
94 }
95 HandshakeAction::Finished => break,
96 }
97 }
98
99 let (initial_handshake, capability_flags, mariadb_capabilities) = handshake.finish()?;
100
101 let conn = Self {
102 stream: conn_stream,
103 buffer_set,
104 initial_handshake,
105 capability_flags,
106 mariadb_capabilities,
107 in_transaction: false,
108 is_broken: false,
109 };
110
111 let mut conn = if opts.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
113 conn.try_upgrade_to_unix_socket(opts).await
114 } else {
115 conn
116 };
117
118 if let Some(init_command) = &opts.init_command {
120 conn.query_drop(init_command).await?;
121 }
122
123 Ok(conn)
124 }
125
126 pub fn server_version(&self) -> &[u8] {
127 &self.buffer_set.initial_handshake[self.initial_handshake.server_version.clone()]
128 }
129
130 pub fn capability_flags(&self) -> CapabilityFlags {
132 self.capability_flags
133 }
134
135 pub fn is_mysql(&self) -> bool {
137 self.capability_flags.is_mysql()
138 }
139
140 pub fn is_mariadb(&self) -> bool {
142 self.capability_flags.is_mariadb()
143 }
144
145 pub fn connection_id(&self) -> u64 {
147 self.initial_handshake.connection_id as u64
148 }
149
150 pub fn status_flags(&self) -> crate::constant::ServerStatusFlags {
152 self.initial_handshake.status_flags
153 }
154
155 pub fn is_broken(&self) -> bool {
157 self.is_broken
158 }
159
160 #[inline]
161 fn check_error<T>(&mut self, result: Result<T>) -> Result<T> {
162 if let Err(e) = &result
163 && e.is_conn_broken()
164 {
165 self.is_broken = true;
166 }
167 result
168 }
169
170 pub(crate) fn set_in_transaction(&mut self, value: bool) {
171 self.in_transaction = value;
172 }
173
174 async fn try_upgrade_to_unix_socket(mut self, opts: &crate::opts::Opts) -> Self {
177 let mut handler = SocketPathHandler { path: None };
179 if self.query("SELECT @@socket", &mut handler).await.is_err() {
180 return self;
181 }
182
183 let socket_path = match handler.path {
184 Some(p) if !p.is_empty() => p,
185 _ => return self,
186 };
187
188 let unix_stream = match UnixStream::connect(&socket_path).await {
190 Ok(s) => s,
191 Err(_) => return self,
192 };
193 let stream = Stream::unix(unix_stream);
194
195 let mut opts_unix = opts.clone();
198 opts_unix.upgrade_to_unix_socket = false;
199
200 match Box::pin(Self::new_with_stream(stream, &opts_unix)).await {
201 Ok(new_conn) => new_conn,
202 Err(_) => self,
203 }
204 }
205
206 #[instrument(skip_all)]
208 async fn write_payload(&mut self) -> Result<()> {
209 let mut sequence_id = 0_u8;
210 let mut buffer = self.buffer_set.write_buffer_mut().as_mut_slice();
211
212 loop {
213 let chunk_size = buffer[4..].len().min(0xFFFFFF);
214 PacketHeader::mut_from_bytes(&mut buffer[0..4])?
215 .encode_in_place(chunk_size, sequence_id);
216 self.stream.write_all(&buffer[..4 + chunk_size]).await?;
217
218 if chunk_size < 0xFFFFFF {
219 break;
220 }
221
222 sequence_id = sequence_id.wrapping_add(1);
223 buffer = &mut buffer[0xFFFFFF..];
224 }
225 self.stream.flush().await?;
226 Ok(())
227 }
228
229 pub async fn prepare(&mut self, sql: &str) -> Result<PreparedStatement> {
233 let result = self.prepare_inner(sql).await;
234 self.check_error(result)
235 }
236
237 async fn prepare_inner(&mut self, sql: &str) -> Result<PreparedStatement> {
238 use crate::protocol::command::ColumnDefinitions;
239
240 self.buffer_set.read_buffer.clear();
241
242 write_prepare(self.buffer_set.new_write_buffer(), sql);
243
244 self.write_payload().await?;
245
246 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
247
248 if !self.buffer_set.read_buffer.is_empty() && self.buffer_set.read_buffer[0] == 0xFF {
249 Err(ErrPayloadBytes(&self.buffer_set.read_buffer))?
250 }
251
252 let prepare_ok = read_prepare_ok(&self.buffer_set.read_buffer)?;
253 let statement_id = prepare_ok.statement_id();
254 let num_params = prepare_ok.num_params();
255 let num_columns = prepare_ok.num_columns();
256
257 for _ in 0..num_params {
259 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
260 }
261
262 let column_definitions = if num_columns > 0 {
264 self.read_column_definition_packets(num_columns as usize)
265 .await?;
266 Some(ColumnDefinitions::new(
267 num_columns as usize,
268 std::mem::take(&mut self.buffer_set.column_definition_buffer),
269 )?)
270 } else {
271 None
272 };
273
274 let mut stmt = PreparedStatement::new(statement_id);
275 if let Some(col_defs) = column_definitions {
276 stmt.set_column_definitions(col_defs);
277 }
278 Ok(stmt)
279 }
280
281 #[tracing::instrument(skip_all)]
282 async fn read_column_definition_packets(&mut self, num_columns: usize) -> Result<u8> {
283 let mut header = PacketHeader::new_zeroed();
284 let out = &mut self.buffer_set.column_definition_buffer;
285 out.clear();
286
287 for _ in 0..num_columns {
289 self.stream.read_exact(header.as_mut_bytes()).await?;
290 let length = header.length();
291 out.extend((length as u32).to_ne_bytes());
292
293 out.reserve(length);
294 let spare = out.spare_capacity_mut();
295 self.stream.read_buf_exact(&mut spare[..length]).await?;
296 unsafe {
298 out.set_len(out.len() + length);
299 }
300 }
301
302 Ok(header.sequence_id)
303 }
304
305 async fn drive_exec<H: BinaryResultSetHandler>(
306 &mut self,
307 stmt: &mut crate::PreparedStatement,
308 handler: &mut H,
309 ) -> Result<()> {
310 let cache_metadata = self
311 .mariadb_capabilities
312 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
313 let mut exec = Exec::new(handler, stmt, cache_metadata);
314
315 loop {
316 match exec.step(&mut self.buffer_set)? {
317 Action::NeedPacket(buffer) => {
318 buffer.clear();
319 let _ = read_payload(&mut self.stream, buffer).await?;
320 }
321 Action::ReadColumnMetadata { num_columns } => {
322 self.read_column_definition_packets(num_columns).await?;
323 }
324 Action::Finished => return Ok(()),
325 }
326 }
327 }
328
329 async fn drive_query<H: TextResultSetHandler>(&mut self, handler: &mut H) -> Result<()> {
330 let mut query = Query::new(handler);
331
332 loop {
333 match query.step(&mut self.buffer_set)? {
334 Action::NeedPacket(buffer) => {
335 buffer.clear();
336 let _ = read_payload(&mut self.stream, buffer).await?;
337 }
338 Action::ReadColumnMetadata { num_columns } => {
339 self.read_column_definition_packets(num_columns).await?;
340 }
341 Action::Finished => return Ok(()),
342 }
343 }
344 }
345
346 pub async fn exec<P, H>(
348 &mut self,
349 stmt: &mut PreparedStatement,
350 params: P,
351 handler: &mut H,
352 ) -> Result<()>
353 where
354 P: Params,
355 H: BinaryResultSetHandler,
356 {
357 let result = self.exec_inner(stmt, params, handler).await;
358 self.check_error(result)
359 }
360
361 async fn exec_inner<P, H>(
362 &mut self,
363 stmt: &mut PreparedStatement,
364 params: P,
365 handler: &mut H,
366 ) -> Result<()>
367 where
368 P: Params,
369 H: BinaryResultSetHandler,
370 {
371 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
372 self.write_payload().await?;
373 self.drive_exec(stmt, handler).await
374 }
375
376 async fn drive_bulk_exec<H: BinaryResultSetHandler>(
377 &mut self,
378 stmt: &mut crate::PreparedStatement,
379 handler: &mut H,
380 ) -> Result<()> {
381 let cache_metadata = self
382 .mariadb_capabilities
383 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
384 let mut bulk_exec = BulkExec::new(handler, stmt, cache_metadata);
385
386 loop {
387 match bulk_exec.step(&mut self.buffer_set)? {
388 Action::NeedPacket(buffer) => {
389 buffer.clear();
390 let _ = read_payload(&mut self.stream, buffer).await?;
391 }
392 Action::ReadColumnMetadata { num_columns } => {
393 self.read_column_definition_packets(num_columns).await?;
394 }
395 Action::Finished => return Ok(()),
396 }
397 }
398 }
399
400 pub async fn exec_bulk_insert_or_update<P, I, H>(
402 &mut self,
403 stmt: &mut PreparedStatement,
404 params: P,
405 flags: BulkFlags,
406 handler: &mut H,
407 ) -> Result<()>
408 where
409 P: BulkParamsSet + IntoIterator<Item = I>,
410 I: Params,
411 H: BinaryResultSetHandler,
412 {
413 let result = self
414 .exec_bulk_insert_or_update_inner(stmt, params, flags, handler)
415 .await;
416 self.check_error(result)
417 }
418
419 async fn exec_bulk_insert_or_update_inner<P, I, H>(
420 &mut self,
421 stmt: &mut PreparedStatement,
422 params: P,
423 flags: BulkFlags,
424 handler: &mut H,
425 ) -> Result<()>
426 where
427 P: BulkParamsSet + IntoIterator<Item = I>,
428 I: Params,
429 H: BinaryResultSetHandler,
430 {
431 if !self.is_mariadb() {
432 for param in params {
434 self.exec_inner(stmt, param, &mut DropHandler::default())
435 .await?;
436 }
437 Ok(())
438 } else {
439 write_bulk_execute(self.buffer_set.new_write_buffer(), stmt.id(), params, flags)?;
441 self.write_payload().await?;
442 self.drive_bulk_exec(stmt, handler).await
443 }
444 }
445
446 pub async fn exec_first<P, H>(
453 &mut self,
454 stmt: &mut PreparedStatement,
455 params: P,
456 handler: &mut H,
457 ) -> Result<bool>
458 where
459 P: Params,
460 H: BinaryResultSetHandler,
461 {
462 let result = self.exec_first_inner(stmt, params, handler).await;
463 self.check_error(result)
464 }
465
466 async fn exec_first_inner<P, H>(
467 &mut self,
468 stmt: &mut PreparedStatement,
469 params: P,
470 handler: &mut H,
471 ) -> Result<bool>
472 where
473 P: Params,
474 H: BinaryResultSetHandler,
475 {
476 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
477 self.write_payload().await?;
478 let mut first_row_handler = FirstRowHandler::new(handler);
479 self.drive_exec(stmt, &mut first_row_handler).await?;
480 Ok(first_row_handler.found_row)
481 }
482
483 #[instrument(skip_all)]
485 pub async fn exec_drop<P>(&mut self, stmt: &mut PreparedStatement, params: P) -> Result<()>
486 where
487 P: Params,
488 {
489 self.exec(stmt, params, &mut DropHandler::default()).await
490 }
491
492 pub async fn query<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
494 where
495 H: TextResultSetHandler,
496 {
497 let result = self.query_inner(sql, handler).await;
498 self.check_error(result)
499 }
500
501 async fn query_inner<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
502 where
503 H: TextResultSetHandler,
504 {
505 write_query(self.buffer_set.new_write_buffer(), sql);
506 self.write_payload().await?;
507 self.drive_query(handler).await
508 }
509
510 #[instrument(skip_all)]
512 pub async fn query_drop(&mut self, sql: &str) -> Result<()> {
513 let result = self.query_drop_inner(sql).await;
514 self.check_error(result)
515 }
516
517 async fn query_drop_inner(&mut self, sql: &str) -> Result<()> {
518 write_query(self.buffer_set.new_write_buffer(), sql);
519 self.write_payload().await?;
520 self.drive_query(&mut DropHandler::default()).await
521 }
522
523 pub async fn ping(&mut self) -> Result<()> {
527 let result = self.ping_inner().await;
528 self.check_error(result)
529 }
530
531 async fn ping_inner(&mut self) -> Result<()> {
532 write_ping(self.buffer_set.new_write_buffer());
533 self.write_payload().await?;
534 self.buffer_set.read_buffer.clear();
535 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
536 Ok(())
537 }
538
539 pub async fn reset(&mut self) -> Result<()> {
541 let result = self.reset_inner().await;
542 self.check_error(result)
543 }
544
545 async fn reset_inner(&mut self) -> Result<()> {
546 write_reset_connection(self.buffer_set.new_write_buffer());
547 self.write_payload().await?;
548 self.buffer_set.read_buffer.clear();
549 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
550 self.in_transaction = false;
551 Ok(())
552 }
553
554 pub async fn run_transaction<F, Fut, R>(&mut self, f: F) -> Result<R>
559 where
560 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Fut,
561 Fut: core::future::Future<Output = Result<R>>,
562 {
563 if self.in_transaction {
564 return Err(Error::NestedTransaction);
565 }
566
567 self.in_transaction = true;
568
569 if let Err(err) = self.query_drop("BEGIN").await {
570 self.in_transaction = false;
571 return Err(err);
572 }
573
574 let tx = super::transaction::Transaction::new(self.connection_id());
575 let result = f(self, tx).await;
576
577 if self.in_transaction {
579 let rollback_result = self.query_drop("ROLLBACK").await;
580 self.in_transaction = false;
581
582 if let Err(e) = result {
584 return Err(e);
585 }
586 rollback_result?;
587 }
588
589 result
590 }
591}
592
593#[instrument(skip_all)]
596async fn read_payload(reader: &mut Stream, buffer: &mut Vec<u8>) -> Result<u8> {
597 let mut packet_header = PacketHeader::new_zeroed();
598
599 buffer.clear();
600 reader.read_exact(packet_header.as_mut_bytes()).await?;
601
602 let length = packet_header.length();
603 let mut sequence_id = packet_header.sequence_id;
604
605 buffer.reserve(length);
606
607 {
609 let spare = buffer.spare_capacity_mut();
610 reader.read_buf_exact(&mut spare[..length]).await?;
611 unsafe {
613 buffer.set_len(length);
614 }
615 }
616
617 let mut current_length = length;
618 while current_length == 0xFFFFFF {
619 reader.read_exact(packet_header.as_mut_bytes()).await?;
620
621 current_length = packet_header.length();
622 sequence_id = packet_header.sequence_id;
623
624 buffer.reserve(current_length);
625 let spare = buffer.spare_capacity_mut();
626 reader.read_buf_exact(&mut spare[..current_length]).await?;
627 unsafe {
629 buffer.set_len(buffer.len() + current_length);
630 }
631 }
632
633 Ok(sequence_id)
634}
635
636async fn write_handshake_payload(
637 stream: &mut Stream,
638 buffer_set: &mut BufferSet,
639 sequence_id: u8,
640) -> Result<()> {
641 let mut buffer = buffer_set.write_buffer_mut().as_mut_slice();
642 let mut seq_id = sequence_id;
643
644 loop {
645 let chunk_size = buffer[4..].len().min(0xFFFFFF);
646 PacketHeader::mut_from_bytes(&mut buffer[0..4])?.encode_in_place(chunk_size, seq_id);
647 stream.write_all(&buffer[..4 + chunk_size]).await?;
648
649 if chunk_size < 0xFFFFFF {
650 break;
651 }
652
653 seq_id = seq_id.wrapping_add(1);
654 buffer = &mut buffer[0xFFFFFF..];
655 }
656 stream.flush().await?;
657 Ok(())
658}
659
/// Text result-set handler that captures the value returned by `SELECT @@socket`.
struct SocketPathHandler {
    /// Socket path from the result row; stays `None` if the value was NULL or empty.
    path: Option<String>,
}
664
665impl TextResultSetHandler for SocketPathHandler {
666 fn no_result_set(&mut self, _: OkPayloadBytes) -> Result<()> {
667 Ok(())
668 }
669 fn resultset_start(&mut self, _: &[ColumnDefinition<'_>]) -> Result<()> {
670 Ok(())
671 }
672 fn resultset_end(&mut self, _: OkPayloadBytes) -> Result<()> {
673 Ok(())
674 }
675 fn row(&mut self, _: &[ColumnDefinition<'_>], row: TextRowPayload<'_>) -> Result<()> {
676 if row.0.first() == Some(&0xFB) {
678 return Ok(());
679 }
680 let (value, _) = read_string_lenenc(row.0)?;
682 if !value.is_empty() {
683 self.path = Some(String::from_utf8_lossy(value).into_owned());
684 }
685 Ok(())
686 }
687}