1use std::ops::AsyncFnOnce;
2
3use tokio::net::TcpStream;
4#[cfg(unix)]
5use tokio::net::UnixStream;
6use tracing::instrument;
7use zerocopy::{FromBytes, FromZeros, IntoBytes};
8
9use crate::PreparedStatement;
10use crate::buffer::BufferSet;
11use crate::buffer_pool::PooledBufferSet;
12use crate::constant::CapabilityFlags;
13use crate::error::{Error, Result};
14use crate::protocol::TextRowPayload;
15use crate::protocol::command::Action;
16use crate::protocol::command::ColumnDefinition;
17use crate::protocol::command::bulk_exec::{BulkExec, BulkFlags, BulkParamsSet, write_bulk_execute};
18use crate::protocol::command::prepared::{Exec, read_prepare_ok, write_execute, write_prepare};
19use crate::protocol::command::query::{Query, write_query};
20use crate::protocol::command::utility::{
21 DropHandler, FirstHandler, write_ping, write_reset_connection,
22};
23use crate::protocol::connection::{Handshake, HandshakeAction, InitialHandshake};
24use crate::protocol::packet::PacketHeader;
25use crate::protocol::primitive::read_string_lenenc;
26use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
27use crate::protocol::r#trait::{BinaryResultSetHandler, TextResultSetHandler, param::Params};
28
29use super::stream::Stream;
30
/// An asynchronous client connection to a MySQL or MariaDB server.
///
/// Field order is significant for drop order and is intentionally left as is.
pub struct Conn {
    /// Transport carrying all packet traffic (built from a TCP or Unix stream).
    stream: Stream,
    /// Buffers obtained from the buffer pool configured in the connection options.
    buffer_set: PooledBufferSet,
    /// Parsed view of the server's initial handshake, as returned by the handshake driver.
    initial_handshake: InitialHandshake,
    /// Capability flags returned by the completed handshake.
    capability_flags: CapabilityFlags,
    /// MariaDB-specific capability flags returned by the completed handshake.
    mariadb_capabilities: crate::constant::MariadbCapabilityFlags,
    /// Client-side marker that a transaction is currently open on this connection.
    in_transaction: bool,
    /// Set once an operation failed with an error reporting `is_conn_broken()`.
    is_broken: bool,
}
40
41impl Conn {
42 pub async fn new<O: TryInto<crate::opts::Opts>>(opts: O) -> Result<Self>
44 where
45 Error: From<O::Error>,
46 {
47 let opts: crate::opts::Opts = opts.try_into()?;
48
49 #[cfg(unix)]
50 let stream = if let Some(socket_path) = &opts.socket {
51 let stream = UnixStream::connect(socket_path).await?;
52 Stream::unix(stream)
53 } else {
54 if opts.host.is_empty() {
55 return Err(Error::BadUsageError(
56 "Missing host in connection options".to_string(),
57 ));
58 }
59 let addr = format!("{}:{}", opts.host, opts.port);
60 let stream = TcpStream::connect(&addr).await?;
61 stream.set_nodelay(opts.tcp_nodelay)?;
62 Stream::tcp(stream)
63 };
64
65 #[cfg(not(unix))]
66 let stream = {
67 if opts.socket.is_some() {
68 return Err(Error::BadUsageError(
69 "Unix sockets are not supported on this platform".to_string(),
70 ));
71 }
72 if opts.host.is_empty() {
73 return Err(Error::BadUsageError(
74 "Missing host in connection options".to_string(),
75 ));
76 }
77 let addr = format!("{}:{}", opts.host, opts.port);
78 let stream = TcpStream::connect(&addr).await?;
79 stream.set_nodelay(opts.tcp_nodelay)?;
80 Stream::tcp(stream)
81 };
82
83 Self::new_with_stream(stream, &opts).await
84 }
85
86 pub async fn new_with_stream(stream: Stream, opts: &crate::opts::Opts) -> Result<Self> {
88 let mut conn_stream = stream;
89 let mut buffer_set = opts.buffer_pool.get_buffer_set();
90
91 #[cfg(feature = "tokio-tls")]
92 let host = opts.host.clone();
93
94 let mut handshake = Handshake::new(opts);
95
96 loop {
97 match handshake.step(&mut buffer_set)? {
98 HandshakeAction::ReadPacket(buffer) => {
99 buffer.clear();
100 read_payload(&mut conn_stream, buffer).await?;
101 }
102 HandshakeAction::WritePacket { sequence_id } => {
103 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id).await?;
104 buffer_set.read_buffer.clear();
105 read_payload(&mut conn_stream, &mut buffer_set.read_buffer).await?;
106 }
107 #[cfg(feature = "tokio-tls")]
108 HandshakeAction::UpgradeTls { sequence_id } => {
109 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id).await?;
110 conn_stream = conn_stream.upgrade_to_tls(&host).await?;
111 }
112 #[cfg(not(feature = "tokio-tls"))]
113 HandshakeAction::UpgradeTls { .. } => {
114 return Err(Error::BadUsageError(
115 "TLS requested but tokio-tls feature is not enabled".to_string(),
116 ));
117 }
118 HandshakeAction::Finished => break,
119 }
120 }
121
122 let (initial_handshake, capability_flags, mariadb_capabilities) = handshake.finish()?;
123
124 let conn = Self {
125 stream: conn_stream,
126 buffer_set,
127 initial_handshake,
128 capability_flags,
129 mariadb_capabilities,
130 in_transaction: false,
131 is_broken: false,
132 };
133
134 #[cfg(unix)]
136 let mut conn = if opts.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
137 conn.try_upgrade_to_unix_socket(opts).await
138 } else {
139 conn
140 };
141 #[cfg(not(unix))]
142 let mut conn = conn;
143
144 if let Some(init_command) = &opts.init_command {
146 conn.query_drop(init_command).await?;
147 }
148
149 Ok(conn)
150 }
151
    /// Returns the server version as raw bytes.
    ///
    /// The bytes are the slice of the stored initial-handshake buffer covered by
    /// the `server_version` range recorded in the parsed handshake.
    pub fn server_version(&self) -> &[u8] {
        &self.buffer_set.initial_handshake[self.initial_handshake.server_version.clone()]
    }
155
    /// Returns the capability flags stored on this connection by the handshake.
    pub fn capability_flags(&self) -> CapabilityFlags {
        self.capability_flags
    }
160
    /// Returns whether the stored capability flags identify the server as MySQL
    /// (delegates to `CapabilityFlags::is_mysql`).
    pub fn is_mysql(&self) -> bool {
        self.capability_flags.is_mysql()
    }
165
    /// Returns whether the stored capability flags identify the server as MariaDB
    /// (delegates to `CapabilityFlags::is_mariadb`).
    pub fn is_mariadb(&self) -> bool {
        self.capability_flags.is_mariadb()
    }
170
    /// Returns the connection id recorded in the initial handshake, cast to `u64`.
    pub fn connection_id(&self) -> u64 {
        self.initial_handshake.connection_id as u64
    }
175
    /// Returns the server status flags recorded in the initial handshake.
    ///
    /// NOTE(review): nothing in this file refreshes these flags after the
    /// handshake, so they presumably reflect the state at connect time — confirm.
    pub fn status_flags(&self) -> crate::constant::ServerStatusFlags {
        self.initial_handshake.status_flags
    }
180
    /// Returns `true` once an operation on this connection failed with an error
    /// for which `is_conn_broken()` reported `true`.
    pub fn is_broken(&self) -> bool {
        self.is_broken
    }
185
186 #[inline]
187 fn check_error<T>(&mut self, result: Result<T>) -> Result<T> {
188 if let Err(e) = &result
189 && e.is_conn_broken()
190 {
191 self.is_broken = true;
192 }
193 result
194 }
195
    /// Overwrites the client-side "transaction open" marker with `value`.
    pub(crate) fn set_in_transaction(&mut self, value: bool) {
        self.in_transaction = value;
    }
199
    /// Returns the client-side "transaction open" marker.
    pub fn in_transaction(&self) -> bool {
        self.in_transaction
    }
204
205 #[cfg(unix)]
208 async fn try_upgrade_to_unix_socket(mut self, opts: &crate::opts::Opts) -> Self {
209 let mut handler = SocketPathHandler { path: None };
211 if self.query("SELECT @@socket", &mut handler).await.is_err() {
212 return self;
213 }
214
215 let socket_path = match handler.path {
216 Some(p) if !p.is_empty() => p,
217 _ => return self,
218 };
219
220 let unix_stream = match UnixStream::connect(&socket_path).await {
222 Ok(s) => s,
223 Err(_) => return self,
224 };
225 let stream = Stream::unix(unix_stream);
226
227 let mut opts_unix = opts.clone();
230 opts_unix.upgrade_to_unix_socket = false;
231
232 match Box::pin(Self::new_with_stream(stream, &opts_unix)).await {
233 Ok(new_conn) => new_conn,
234 Err(_) => self,
235 }
236 }
237
238 #[instrument(skip_all)]
240 async fn write_payload(&mut self) -> Result<()> {
241 let mut sequence_id = 0_u8;
242 let mut buffer = self.buffer_set.write_buffer_mut().as_mut_slice();
243
244 loop {
245 let chunk_size = buffer[4..].len().min(0xFFFFFF);
246 PacketHeader::mut_from_bytes(&mut buffer[0..4])?
247 .encode_in_place(chunk_size, sequence_id);
248 self.stream.write_all(&buffer[..4 + chunk_size]).await?;
249
250 if chunk_size < 0xFFFFFF {
251 break;
252 }
253
254 sequence_id = sequence_id.wrapping_add(1);
255 buffer = &mut buffer[0xFFFFFF..];
256 }
257 self.stream.flush().await?;
258 Ok(())
259 }
260
261 pub async fn prepare(&mut self, sql: &str) -> Result<PreparedStatement> {
265 let result = self.prepare_inner(sql).await;
266 self.check_error(result)
267 }
268
269 async fn prepare_inner(&mut self, sql: &str) -> Result<PreparedStatement> {
270 use crate::protocol::command::ColumnDefinitions;
271
272 self.buffer_set.read_buffer.clear();
273
274 write_prepare(self.buffer_set.new_write_buffer(), sql);
275
276 self.write_payload().await?;
277
278 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
279
280 if !self.buffer_set.read_buffer.is_empty() && self.buffer_set.read_buffer[0] == 0xFF {
281 Err(ErrPayloadBytes(&self.buffer_set.read_buffer))?
282 }
283
284 let prepare_ok = read_prepare_ok(&self.buffer_set.read_buffer)?;
285 let statement_id = prepare_ok.statement_id();
286 let num_params = prepare_ok.num_params();
287 let num_columns = prepare_ok.num_columns();
288
289 for _ in 0..num_params {
291 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
292 }
293
294 let column_definitions = if num_columns > 0 {
296 self.read_column_definition_packets(num_columns as usize)
297 .await?;
298 Some(ColumnDefinitions::new(
299 num_columns as usize,
300 std::mem::take(&mut self.buffer_set.column_definition_buffer),
301 )?)
302 } else {
303 None
304 };
305
306 let mut stmt = PreparedStatement::new(statement_id);
307 if let Some(col_defs) = column_definitions {
308 stmt.set_column_definitions(col_defs);
309 }
310 Ok(stmt)
311 }
312
313 #[tracing::instrument(skip_all)]
314 async fn read_column_definition_packets(&mut self, num_columns: usize) -> Result<u8> {
315 let mut header = PacketHeader::new_zeroed();
316 let out = &mut self.buffer_set.column_definition_buffer;
317 out.clear();
318
319 for _ in 0..num_columns {
321 self.stream.read_exact(header.as_mut_bytes()).await?;
322 let length = header.length();
323 out.extend((length as u32).to_ne_bytes());
324
325 out.reserve(length);
326 let spare = out.spare_capacity_mut();
327 self.stream.read_buf_exact(&mut spare[..length]).await?;
328 unsafe {
330 out.set_len(out.len() + length);
331 }
332 }
333
334 Ok(header.sequence_id)
335 }
336
337 async fn drive_exec<H: BinaryResultSetHandler>(
338 &mut self,
339 stmt: &mut crate::PreparedStatement,
340 handler: &mut H,
341 ) -> Result<()> {
342 let cache_metadata = self
343 .mariadb_capabilities
344 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
345 let mut exec = Exec::new(handler, stmt, cache_metadata);
346
347 loop {
348 match exec.step(&mut self.buffer_set)? {
349 Action::NeedPacket(buffer) => {
350 buffer.clear();
351 let _ = read_payload(&mut self.stream, buffer).await?;
352 }
353 Action::ReadColumnMetadata { num_columns } => {
354 self.read_column_definition_packets(num_columns).await?;
355 }
356 Action::Finished => return Ok(()),
357 }
358 }
359 }
360
361 async fn drive_query<H: TextResultSetHandler>(&mut self, handler: &mut H) -> Result<()> {
362 let mut query = Query::new(handler);
363
364 loop {
365 match query.step(&mut self.buffer_set)? {
366 Action::NeedPacket(buffer) => {
367 buffer.clear();
368 let _ = read_payload(&mut self.stream, buffer).await?;
369 }
370 Action::ReadColumnMetadata { num_columns } => {
371 self.read_column_definition_packets(num_columns).await?;
372 }
373 Action::Finished => return Ok(()),
374 }
375 }
376 }
377
378 pub async fn exec<P, H>(
380 &mut self,
381 stmt: &mut PreparedStatement,
382 params: P,
383 handler: &mut H,
384 ) -> Result<()>
385 where
386 P: Params,
387 H: BinaryResultSetHandler,
388 {
389 let result = self.exec_inner(stmt, params, handler).await;
390 self.check_error(result)
391 }
392
393 async fn exec_inner<P, H>(
394 &mut self,
395 stmt: &mut PreparedStatement,
396 params: P,
397 handler: &mut H,
398 ) -> Result<()>
399 where
400 P: Params,
401 H: BinaryResultSetHandler,
402 {
403 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
404 self.write_payload().await?;
405 self.drive_exec(stmt, handler).await
406 }
407
408 async fn drive_bulk_exec<H: BinaryResultSetHandler>(
409 &mut self,
410 stmt: &mut crate::PreparedStatement,
411 handler: &mut H,
412 ) -> Result<()> {
413 let cache_metadata = self
414 .mariadb_capabilities
415 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
416 let mut bulk_exec = BulkExec::new(handler, stmt, cache_metadata);
417
418 loop {
419 match bulk_exec.step(&mut self.buffer_set)? {
420 Action::NeedPacket(buffer) => {
421 buffer.clear();
422 let _ = read_payload(&mut self.stream, buffer).await?;
423 }
424 Action::ReadColumnMetadata { num_columns } => {
425 self.read_column_definition_packets(num_columns).await?;
426 }
427 Action::Finished => return Ok(()),
428 }
429 }
430 }
431
432 pub async fn exec_bulk_insert_or_update<P, I, H>(
434 &mut self,
435 stmt: &mut PreparedStatement,
436 params: P,
437 flags: BulkFlags,
438 handler: &mut H,
439 ) -> Result<()>
440 where
441 P: BulkParamsSet + IntoIterator<Item = I>,
442 I: Params,
443 H: BinaryResultSetHandler,
444 {
445 let result = self
446 .exec_bulk_insert_or_update_inner(stmt, params, flags, handler)
447 .await;
448 self.check_error(result)
449 }
450
451 async fn exec_bulk_insert_or_update_inner<P, I, H>(
452 &mut self,
453 stmt: &mut PreparedStatement,
454 params: P,
455 flags: BulkFlags,
456 handler: &mut H,
457 ) -> Result<()>
458 where
459 P: BulkParamsSet + IntoIterator<Item = I>,
460 I: Params,
461 H: BinaryResultSetHandler,
462 {
463 if !self.is_mariadb() {
464 for param in params {
466 self.exec_inner(stmt, param, &mut DropHandler::default())
467 .await?;
468 }
469 Ok(())
470 } else {
471 write_bulk_execute(self.buffer_set.new_write_buffer(), stmt.id(), params, flags)?;
473 self.write_payload().await?;
474 self.drive_bulk_exec(stmt, handler).await
475 }
476 }
477
478 pub async fn exec_first<Row, P>(
480 &mut self,
481 stmt: &mut PreparedStatement,
482 params: P,
483 ) -> Result<Option<Row>>
484 where
485 Row: for<'buf> crate::raw::FromRawRow<'buf>,
486 P: Params,
487 {
488 let result = self.exec_first_inner(stmt, params).await;
489 self.check_error(result)
490 }
491
492 async fn exec_first_inner<Row, P>(
493 &mut self,
494 stmt: &mut PreparedStatement,
495 params: P,
496 ) -> Result<Option<Row>>
497 where
498 Row: for<'buf> crate::raw::FromRawRow<'buf>,
499 P: Params,
500 {
501 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
502 self.write_payload().await?;
503 let mut handler = FirstHandler::<Row>::default();
504 self.drive_exec(stmt, &mut handler).await?;
505 Ok(handler.take())
506 }
507
508 #[instrument(skip_all)]
510 pub async fn exec_drop<P>(&mut self, stmt: &mut PreparedStatement, params: P) -> Result<()>
511 where
512 P: Params,
513 {
514 self.exec(stmt, params, &mut DropHandler::default()).await
515 }
516
517 pub async fn exec_collect<Row, P>(
519 &mut self,
520 stmt: &mut PreparedStatement,
521 params: P,
522 ) -> Result<Vec<Row>>
523 where
524 Row: for<'buf> crate::raw::FromRawRow<'buf>,
525 P: Params,
526 {
527 let mut handler = crate::handler::CollectHandler::<Row>::default();
528 self.exec(stmt, params, &mut handler).await?;
529 Ok(handler.into_rows())
530 }
531
532 pub async fn exec_foreach<Row, P, F>(
536 &mut self,
537 stmt: &mut PreparedStatement,
538 params: P,
539 f: F,
540 ) -> Result<()>
541 where
542 Row: for<'buf> crate::raw::FromRawRow<'buf>,
543 P: Params,
544 F: FnMut(Row) -> Result<()>,
545 {
546 let mut handler = crate::handler::ForEachHandler::<Row, F>::new(f);
547 self.exec(stmt, params, &mut handler).await
548 }
549
550 pub async fn query<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
552 where
553 H: TextResultSetHandler,
554 {
555 let result = self.query_inner(sql, handler).await;
556 self.check_error(result)
557 }
558
559 async fn query_inner<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
560 where
561 H: TextResultSetHandler,
562 {
563 write_query(self.buffer_set.new_write_buffer(), sql);
564 self.write_payload().await?;
565 self.drive_query(handler).await
566 }
567
568 #[instrument(skip_all)]
570 pub async fn query_drop(&mut self, sql: &str) -> Result<()> {
571 let result = self.query_drop_inner(sql).await;
572 self.check_error(result)
573 }
574
575 async fn query_drop_inner(&mut self, sql: &str) -> Result<()> {
576 write_query(self.buffer_set.new_write_buffer(), sql);
577 self.write_payload().await?;
578 self.drive_query(&mut DropHandler::default()).await
579 }
580
581 pub async fn ping(&mut self) -> Result<()> {
585 let result = self.ping_inner().await;
586 self.check_error(result)
587 }
588
589 async fn ping_inner(&mut self) -> Result<()> {
590 write_ping(self.buffer_set.new_write_buffer());
591 self.write_payload().await?;
592 self.buffer_set.read_buffer.clear();
593 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
594 Ok(())
595 }
596
597 pub async fn reset(&mut self) -> Result<()> {
599 let result = self.reset_inner().await;
600 self.check_error(result)
601 }
602
603 async fn reset_inner(&mut self) -> Result<()> {
604 write_reset_connection(self.buffer_set.new_write_buffer());
605 self.write_payload().await?;
606 self.buffer_set.read_buffer.clear();
607 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
608 self.in_transaction = false;
609 Ok(())
610 }
611
612 pub async fn transaction<F, R>(&mut self, f: F) -> Result<R>
617 where
618 F: AsyncFnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
619 {
620 if self.in_transaction {
621 return Err(Error::NestedTransaction);
622 }
623
624 self.in_transaction = true;
625
626 if let Err(err) = self.query_drop("BEGIN").await {
627 self.in_transaction = false;
628 return Err(err);
629 }
630
631 let tx = super::transaction::Transaction::new(self.connection_id());
632 let result = f(self, tx).await;
633
634 if self.in_transaction {
636 self.in_transaction = false;
637 match &result {
638 Ok(_) => self.query_drop("COMMIT").await?,
639 Err(_) => {
640 let _ = self.query_drop("ROLLBACK").await;
641 }
642 }
643 }
644
645 result
646 }
647}
648
649#[instrument(skip_all)]
652async fn read_payload(reader: &mut Stream, buffer: &mut Vec<u8>) -> Result<u8> {
653 let mut packet_header = PacketHeader::new_zeroed();
654
655 buffer.clear();
656 reader.read_exact(packet_header.as_mut_bytes()).await?;
657
658 let length = packet_header.length();
659 let mut sequence_id = packet_header.sequence_id;
660
661 buffer.reserve(length);
662
663 {
665 let spare = buffer.spare_capacity_mut();
666 reader.read_buf_exact(&mut spare[..length]).await?;
667 unsafe {
669 buffer.set_len(length);
670 }
671 }
672
673 let mut current_length = length;
674 while current_length == 0xFFFFFF {
675 reader.read_exact(packet_header.as_mut_bytes()).await?;
676
677 current_length = packet_header.length();
678 sequence_id = packet_header.sequence_id;
679
680 buffer.reserve(current_length);
681 let spare = buffer.spare_capacity_mut();
682 reader.read_buf_exact(&mut spare[..current_length]).await?;
683 unsafe {
685 buffer.set_len(buffer.len() + current_length);
686 }
687 }
688
689 Ok(sequence_id)
690}
691
692async fn write_handshake_payload(
693 stream: &mut Stream,
694 buffer_set: &mut BufferSet,
695 sequence_id: u8,
696) -> Result<()> {
697 let mut buffer = buffer_set.write_buffer_mut().as_mut_slice();
698 let mut seq_id = sequence_id;
699
700 loop {
701 let chunk_size = buffer[4..].len().min(0xFFFFFF);
702 PacketHeader::mut_from_bytes(&mut buffer[0..4])?.encode_in_place(chunk_size, seq_id);
703 stream.write_all(&buffer[..4 + chunk_size]).await?;
704
705 if chunk_size < 0xFFFFFF {
706 break;
707 }
708
709 seq_id = seq_id.wrapping_add(1);
710 buffer = &mut buffer[0xFFFFFF..];
711 }
712 stream.flush().await?;
713 Ok(())
714}
715
/// Result handler that captures the value returned by `SELECT @@socket`.
#[cfg(unix)]
struct SocketPathHandler {
    /// The socket path taken from a result row, if a non-empty one was seen.
    path: Option<String>,
}
721
722#[cfg(unix)]
723impl TextResultSetHandler for SocketPathHandler {
724 fn no_result_set(&mut self, _: OkPayloadBytes) -> Result<()> {
725 Ok(())
726 }
727 fn resultset_start(&mut self, _: &[ColumnDefinition<'_>]) -> Result<()> {
728 Ok(())
729 }
730 fn resultset_end(&mut self, _: OkPayloadBytes) -> Result<()> {
731 Ok(())
732 }
733 fn row(&mut self, _: &[ColumnDefinition<'_>], row: TextRowPayload<'_>) -> Result<()> {
734 if row.0.first() == Some(&0xFB) {
736 return Ok(());
737 }
738 let (value, _) = read_string_lenenc(row.0)?;
740 if !value.is_empty() {
741 self.path = Some(String::from_utf8_lossy(value).into_owned());
742 }
743 Ok(())
744 }
745}