1use tokio::net::TcpStream;
2#[cfg(unix)]
3use tokio::net::UnixStream;
4use tracing::instrument;
5use zerocopy::{FromBytes, FromZeros, IntoBytes};
6
7use crate::PreparedStatement;
8use crate::buffer::BufferSet;
9use crate::buffer_pool::PooledBufferSet;
10use crate::constant::CapabilityFlags;
11use crate::error::{Error, Result};
12use crate::protocol::TextRowPayload;
13use crate::protocol::command::Action;
14use crate::protocol::command::ColumnDefinition;
15use crate::protocol::command::bulk_exec::{BulkExec, BulkFlags, BulkParamsSet, write_bulk_execute};
16use crate::protocol::command::prepared::{Exec, read_prepare_ok, write_execute, write_prepare};
17use crate::protocol::command::query::{Query, write_query};
18use crate::protocol::command::utility::{
19 DropHandler, FirstRowHandler, write_ping, write_reset_connection,
20};
21use crate::protocol::connection::{Handshake, HandshakeAction, InitialHandshake};
22use crate::protocol::packet::PacketHeader;
23use crate::protocol::primitive::read_string_lenenc;
24use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
25use crate::protocol::r#trait::{BinaryResultSetHandler, TextResultSetHandler, param::Params};
26
27use super::stream::Stream;
28
/// An asynchronous client connection to a MySQL or MariaDB server.
pub struct Conn {
    /// Transport to the server: TCP, Unix socket, or a TLS-upgraded stream.
    stream: Stream,
    /// Reusable read/write/column-definition buffers obtained from `opts.buffer_pool`.
    /// It also holds the raw initial-handshake bytes that `initial_handshake` indexes into.
    buffer_set: PooledBufferSet,
    /// Parsed initial handshake (server version range, connection id, status flags).
    initial_handshake: InitialHandshake,
    /// Capability flags produced by the handshake.
    capability_flags: CapabilityFlags,
    /// MariaDB-specific capability flags produced by the handshake (e.g. metadata caching).
    mariadb_capabilities: crate::constant::MariadbCapabilityFlags,
    /// Whether a transaction is believed to be open.
    /// Set through `run_transaction` / `set_in_transaction`; cleared by `reset`.
    in_transaction: bool,
    /// Set once an operation fails with an error classified as connection-breaking
    /// (see `check_error`).
    is_broken: bool,
}
38
39impl Conn {
40 pub async fn new<O: TryInto<crate::opts::Opts>>(opts: O) -> Result<Self>
42 where
43 Error: From<O::Error>,
44 {
45 let opts: crate::opts::Opts = opts.try_into()?;
46
47 #[cfg(unix)]
48 let stream = if let Some(socket_path) = &opts.socket {
49 let stream = UnixStream::connect(socket_path).await?;
50 Stream::unix(stream)
51 } else {
52 if opts.host.is_empty() {
53 return Err(Error::BadUsageError(
54 "Missing host in connection options".to_string(),
55 ));
56 }
57 let addr = format!("{}:{}", opts.host, opts.port);
58 let stream = TcpStream::connect(&addr).await?;
59 stream.set_nodelay(opts.tcp_nodelay)?;
60 Stream::tcp(stream)
61 };
62
63 #[cfg(not(unix))]
64 let stream = {
65 if opts.socket.is_some() {
66 return Err(Error::BadUsageError(
67 "Unix sockets are not supported on this platform".to_string(),
68 ));
69 }
70 if opts.host.is_empty() {
71 return Err(Error::BadUsageError(
72 "Missing host in connection options".to_string(),
73 ));
74 }
75 let addr = format!("{}:{}", opts.host, opts.port);
76 let stream = TcpStream::connect(&addr).await?;
77 stream.set_nodelay(opts.tcp_nodelay)?;
78 Stream::tcp(stream)
79 };
80
81 Self::new_with_stream(stream, &opts).await
82 }
83
84 pub async fn new_with_stream(stream: Stream, opts: &crate::opts::Opts) -> Result<Self> {
86 let mut conn_stream = stream;
87 let mut buffer_set = opts.buffer_pool.get_buffer_set();
88
89 #[cfg(feature = "tokio-tls")]
90 let host = opts.host.clone();
91
92 let mut handshake = Handshake::new(opts);
93
94 loop {
95 match handshake.step(&mut buffer_set)? {
96 HandshakeAction::ReadPacket(buffer) => {
97 buffer.clear();
98 read_payload(&mut conn_stream, buffer).await?;
99 }
100 HandshakeAction::WritePacket { sequence_id } => {
101 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id).await?;
102 buffer_set.read_buffer.clear();
103 read_payload(&mut conn_stream, &mut buffer_set.read_buffer).await?;
104 }
105 #[cfg(feature = "tokio-tls")]
106 HandshakeAction::UpgradeTls { sequence_id } => {
107 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id).await?;
108 conn_stream = conn_stream.upgrade_to_tls(&host).await?;
109 }
110 #[cfg(not(feature = "tokio-tls"))]
111 HandshakeAction::UpgradeTls { .. } => {
112 return Err(Error::BadUsageError(
113 "TLS requested but tokio-tls feature is not enabled".to_string(),
114 ));
115 }
116 HandshakeAction::Finished => break,
117 }
118 }
119
120 let (initial_handshake, capability_flags, mariadb_capabilities) = handshake.finish()?;
121
122 let conn = Self {
123 stream: conn_stream,
124 buffer_set,
125 initial_handshake,
126 capability_flags,
127 mariadb_capabilities,
128 in_transaction: false,
129 is_broken: false,
130 };
131
132 #[cfg(unix)]
134 let mut conn = if opts.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
135 conn.try_upgrade_to_unix_socket(opts).await
136 } else {
137 conn
138 };
139 #[cfg(not(unix))]
140 let mut conn = conn;
141
142 if let Some(init_command) = &opts.init_command {
144 conn.query_drop(init_command).await?;
145 }
146
147 Ok(conn)
148 }
149
150 pub fn server_version(&self) -> &[u8] {
151 &self.buffer_set.initial_handshake[self.initial_handshake.server_version.clone()]
152 }
153
154 pub fn capability_flags(&self) -> CapabilityFlags {
156 self.capability_flags
157 }
158
159 pub fn is_mysql(&self) -> bool {
161 self.capability_flags.is_mysql()
162 }
163
164 pub fn is_mariadb(&self) -> bool {
166 self.capability_flags.is_mariadb()
167 }
168
169 pub fn connection_id(&self) -> u64 {
171 self.initial_handshake.connection_id as u64
172 }
173
174 pub fn status_flags(&self) -> crate::constant::ServerStatusFlags {
176 self.initial_handshake.status_flags
177 }
178
179 pub fn is_broken(&self) -> bool {
181 self.is_broken
182 }
183
184 #[inline]
185 fn check_error<T>(&mut self, result: Result<T>) -> Result<T> {
186 if let Err(e) = &result
187 && e.is_conn_broken()
188 {
189 self.is_broken = true;
190 }
191 result
192 }
193
194 pub(crate) fn set_in_transaction(&mut self, value: bool) {
195 self.in_transaction = value;
196 }
197
198 #[cfg(unix)]
201 async fn try_upgrade_to_unix_socket(mut self, opts: &crate::opts::Opts) -> Self {
202 let mut handler = SocketPathHandler { path: None };
204 if self.query("SELECT @@socket", &mut handler).await.is_err() {
205 return self;
206 }
207
208 let socket_path = match handler.path {
209 Some(p) if !p.is_empty() => p,
210 _ => return self,
211 };
212
213 let unix_stream = match UnixStream::connect(&socket_path).await {
215 Ok(s) => s,
216 Err(_) => return self,
217 };
218 let stream = Stream::unix(unix_stream);
219
220 let mut opts_unix = opts.clone();
223 opts_unix.upgrade_to_unix_socket = false;
224
225 match Box::pin(Self::new_with_stream(stream, &opts_unix)).await {
226 Ok(new_conn) => new_conn,
227 Err(_) => self,
228 }
229 }
230
231 #[instrument(skip_all)]
233 async fn write_payload(&mut self) -> Result<()> {
234 let mut sequence_id = 0_u8;
235 let mut buffer = self.buffer_set.write_buffer_mut().as_mut_slice();
236
237 loop {
238 let chunk_size = buffer[4..].len().min(0xFFFFFF);
239 PacketHeader::mut_from_bytes(&mut buffer[0..4])?
240 .encode_in_place(chunk_size, sequence_id);
241 self.stream.write_all(&buffer[..4 + chunk_size]).await?;
242
243 if chunk_size < 0xFFFFFF {
244 break;
245 }
246
247 sequence_id = sequence_id.wrapping_add(1);
248 buffer = &mut buffer[0xFFFFFF..];
249 }
250 self.stream.flush().await?;
251 Ok(())
252 }
253
254 pub async fn prepare(&mut self, sql: &str) -> Result<PreparedStatement> {
258 let result = self.prepare_inner(sql).await;
259 self.check_error(result)
260 }
261
262 async fn prepare_inner(&mut self, sql: &str) -> Result<PreparedStatement> {
263 use crate::protocol::command::ColumnDefinitions;
264
265 self.buffer_set.read_buffer.clear();
266
267 write_prepare(self.buffer_set.new_write_buffer(), sql);
268
269 self.write_payload().await?;
270
271 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
272
273 if !self.buffer_set.read_buffer.is_empty() && self.buffer_set.read_buffer[0] == 0xFF {
274 Err(ErrPayloadBytes(&self.buffer_set.read_buffer))?
275 }
276
277 let prepare_ok = read_prepare_ok(&self.buffer_set.read_buffer)?;
278 let statement_id = prepare_ok.statement_id();
279 let num_params = prepare_ok.num_params();
280 let num_columns = prepare_ok.num_columns();
281
282 for _ in 0..num_params {
284 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
285 }
286
287 let column_definitions = if num_columns > 0 {
289 self.read_column_definition_packets(num_columns as usize)
290 .await?;
291 Some(ColumnDefinitions::new(
292 num_columns as usize,
293 std::mem::take(&mut self.buffer_set.column_definition_buffer),
294 )?)
295 } else {
296 None
297 };
298
299 let mut stmt = PreparedStatement::new(statement_id);
300 if let Some(col_defs) = column_definitions {
301 stmt.set_column_definitions(col_defs);
302 }
303 Ok(stmt)
304 }
305
306 #[tracing::instrument(skip_all)]
307 async fn read_column_definition_packets(&mut self, num_columns: usize) -> Result<u8> {
308 let mut header = PacketHeader::new_zeroed();
309 let out = &mut self.buffer_set.column_definition_buffer;
310 out.clear();
311
312 for _ in 0..num_columns {
314 self.stream.read_exact(header.as_mut_bytes()).await?;
315 let length = header.length();
316 out.extend((length as u32).to_ne_bytes());
317
318 out.reserve(length);
319 let spare = out.spare_capacity_mut();
320 self.stream.read_buf_exact(&mut spare[..length]).await?;
321 unsafe {
323 out.set_len(out.len() + length);
324 }
325 }
326
327 Ok(header.sequence_id)
328 }
329
330 async fn drive_exec<H: BinaryResultSetHandler>(
331 &mut self,
332 stmt: &mut crate::PreparedStatement,
333 handler: &mut H,
334 ) -> Result<()> {
335 let cache_metadata = self
336 .mariadb_capabilities
337 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
338 let mut exec = Exec::new(handler, stmt, cache_metadata);
339
340 loop {
341 match exec.step(&mut self.buffer_set)? {
342 Action::NeedPacket(buffer) => {
343 buffer.clear();
344 let _ = read_payload(&mut self.stream, buffer).await?;
345 }
346 Action::ReadColumnMetadata { num_columns } => {
347 self.read_column_definition_packets(num_columns).await?;
348 }
349 Action::Finished => return Ok(()),
350 }
351 }
352 }
353
354 async fn drive_query<H: TextResultSetHandler>(&mut self, handler: &mut H) -> Result<()> {
355 let mut query = Query::new(handler);
356
357 loop {
358 match query.step(&mut self.buffer_set)? {
359 Action::NeedPacket(buffer) => {
360 buffer.clear();
361 let _ = read_payload(&mut self.stream, buffer).await?;
362 }
363 Action::ReadColumnMetadata { num_columns } => {
364 self.read_column_definition_packets(num_columns).await?;
365 }
366 Action::Finished => return Ok(()),
367 }
368 }
369 }
370
371 pub async fn exec<P, H>(
373 &mut self,
374 stmt: &mut PreparedStatement,
375 params: P,
376 handler: &mut H,
377 ) -> Result<()>
378 where
379 P: Params,
380 H: BinaryResultSetHandler,
381 {
382 let result = self.exec_inner(stmt, params, handler).await;
383 self.check_error(result)
384 }
385
386 async fn exec_inner<P, H>(
387 &mut self,
388 stmt: &mut PreparedStatement,
389 params: P,
390 handler: &mut H,
391 ) -> Result<()>
392 where
393 P: Params,
394 H: BinaryResultSetHandler,
395 {
396 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
397 self.write_payload().await?;
398 self.drive_exec(stmt, handler).await
399 }
400
401 async fn drive_bulk_exec<H: BinaryResultSetHandler>(
402 &mut self,
403 stmt: &mut crate::PreparedStatement,
404 handler: &mut H,
405 ) -> Result<()> {
406 let cache_metadata = self
407 .mariadb_capabilities
408 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
409 let mut bulk_exec = BulkExec::new(handler, stmt, cache_metadata);
410
411 loop {
412 match bulk_exec.step(&mut self.buffer_set)? {
413 Action::NeedPacket(buffer) => {
414 buffer.clear();
415 let _ = read_payload(&mut self.stream, buffer).await?;
416 }
417 Action::ReadColumnMetadata { num_columns } => {
418 self.read_column_definition_packets(num_columns).await?;
419 }
420 Action::Finished => return Ok(()),
421 }
422 }
423 }
424
425 pub async fn exec_bulk_insert_or_update<P, I, H>(
427 &mut self,
428 stmt: &mut PreparedStatement,
429 params: P,
430 flags: BulkFlags,
431 handler: &mut H,
432 ) -> Result<()>
433 where
434 P: BulkParamsSet + IntoIterator<Item = I>,
435 I: Params,
436 H: BinaryResultSetHandler,
437 {
438 let result = self
439 .exec_bulk_insert_or_update_inner(stmt, params, flags, handler)
440 .await;
441 self.check_error(result)
442 }
443
444 async fn exec_bulk_insert_or_update_inner<P, I, H>(
445 &mut self,
446 stmt: &mut PreparedStatement,
447 params: P,
448 flags: BulkFlags,
449 handler: &mut H,
450 ) -> Result<()>
451 where
452 P: BulkParamsSet + IntoIterator<Item = I>,
453 I: Params,
454 H: BinaryResultSetHandler,
455 {
456 if !self.is_mariadb() {
457 for param in params {
459 self.exec_inner(stmt, param, &mut DropHandler::default())
460 .await?;
461 }
462 Ok(())
463 } else {
464 write_bulk_execute(self.buffer_set.new_write_buffer(), stmt.id(), params, flags)?;
466 self.write_payload().await?;
467 self.drive_bulk_exec(stmt, handler).await
468 }
469 }
470
471 pub async fn exec_first<P, H>(
478 &mut self,
479 stmt: &mut PreparedStatement,
480 params: P,
481 handler: &mut H,
482 ) -> Result<bool>
483 where
484 P: Params,
485 H: BinaryResultSetHandler,
486 {
487 let result = self.exec_first_inner(stmt, params, handler).await;
488 self.check_error(result)
489 }
490
491 async fn exec_first_inner<P, H>(
492 &mut self,
493 stmt: &mut PreparedStatement,
494 params: P,
495 handler: &mut H,
496 ) -> Result<bool>
497 where
498 P: Params,
499 H: BinaryResultSetHandler,
500 {
501 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
502 self.write_payload().await?;
503 let mut first_row_handler = FirstRowHandler::new(handler);
504 self.drive_exec(stmt, &mut first_row_handler).await?;
505 Ok(first_row_handler.found_row)
506 }
507
508 #[instrument(skip_all)]
510 pub async fn exec_drop<P>(&mut self, stmt: &mut PreparedStatement, params: P) -> Result<()>
511 where
512 P: Params,
513 {
514 self.exec(stmt, params, &mut DropHandler::default()).await
515 }
516
517 pub async fn query<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
519 where
520 H: TextResultSetHandler,
521 {
522 let result = self.query_inner(sql, handler).await;
523 self.check_error(result)
524 }
525
526 async fn query_inner<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
527 where
528 H: TextResultSetHandler,
529 {
530 write_query(self.buffer_set.new_write_buffer(), sql);
531 self.write_payload().await?;
532 self.drive_query(handler).await
533 }
534
535 #[instrument(skip_all)]
537 pub async fn query_drop(&mut self, sql: &str) -> Result<()> {
538 let result = self.query_drop_inner(sql).await;
539 self.check_error(result)
540 }
541
542 async fn query_drop_inner(&mut self, sql: &str) -> Result<()> {
543 write_query(self.buffer_set.new_write_buffer(), sql);
544 self.write_payload().await?;
545 self.drive_query(&mut DropHandler::default()).await
546 }
547
548 pub async fn ping(&mut self) -> Result<()> {
552 let result = self.ping_inner().await;
553 self.check_error(result)
554 }
555
556 async fn ping_inner(&mut self) -> Result<()> {
557 write_ping(self.buffer_set.new_write_buffer());
558 self.write_payload().await?;
559 self.buffer_set.read_buffer.clear();
560 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
561 Ok(())
562 }
563
564 pub async fn reset(&mut self) -> Result<()> {
566 let result = self.reset_inner().await;
567 self.check_error(result)
568 }
569
570 async fn reset_inner(&mut self) -> Result<()> {
571 write_reset_connection(self.buffer_set.new_write_buffer());
572 self.write_payload().await?;
573 self.buffer_set.read_buffer.clear();
574 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
575 self.in_transaction = false;
576 Ok(())
577 }
578
579 pub async fn run_transaction<F, Fut, R>(&mut self, f: F) -> Result<R>
584 where
585 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Fut,
586 Fut: core::future::Future<Output = Result<R>>,
587 {
588 if self.in_transaction {
589 return Err(Error::NestedTransaction);
590 }
591
592 self.in_transaction = true;
593
594 if let Err(err) = self.query_drop("BEGIN").await {
595 self.in_transaction = false;
596 return Err(err);
597 }
598
599 let tx = super::transaction::Transaction::new(self.connection_id());
600 let result = f(self, tx).await;
601
602 if self.in_transaction {
604 let rollback_result = self.query_drop("ROLLBACK").await;
605 self.in_transaction = false;
606
607 if let Err(e) = result {
609 return Err(e);
610 }
611 rollback_result?;
612 }
613
614 result
615 }
616}
617
618#[instrument(skip_all)]
621async fn read_payload(reader: &mut Stream, buffer: &mut Vec<u8>) -> Result<u8> {
622 let mut packet_header = PacketHeader::new_zeroed();
623
624 buffer.clear();
625 reader.read_exact(packet_header.as_mut_bytes()).await?;
626
627 let length = packet_header.length();
628 let mut sequence_id = packet_header.sequence_id;
629
630 buffer.reserve(length);
631
632 {
634 let spare = buffer.spare_capacity_mut();
635 reader.read_buf_exact(&mut spare[..length]).await?;
636 unsafe {
638 buffer.set_len(length);
639 }
640 }
641
642 let mut current_length = length;
643 while current_length == 0xFFFFFF {
644 reader.read_exact(packet_header.as_mut_bytes()).await?;
645
646 current_length = packet_header.length();
647 sequence_id = packet_header.sequence_id;
648
649 buffer.reserve(current_length);
650 let spare = buffer.spare_capacity_mut();
651 reader.read_buf_exact(&mut spare[..current_length]).await?;
652 unsafe {
654 buffer.set_len(buffer.len() + current_length);
655 }
656 }
657
658 Ok(sequence_id)
659}
660
661async fn write_handshake_payload(
662 stream: &mut Stream,
663 buffer_set: &mut BufferSet,
664 sequence_id: u8,
665) -> Result<()> {
666 let mut buffer = buffer_set.write_buffer_mut().as_mut_slice();
667 let mut seq_id = sequence_id;
668
669 loop {
670 let chunk_size = buffer[4..].len().min(0xFFFFFF);
671 PacketHeader::mut_from_bytes(&mut buffer[0..4])?.encode_in_place(chunk_size, seq_id);
672 stream.write_all(&buffer[..4 + chunk_size]).await?;
673
674 if chunk_size < 0xFFFFFF {
675 break;
676 }
677
678 seq_id = seq_id.wrapping_add(1);
679 buffer = &mut buffer[0xFFFFFF..];
680 }
681 stream.flush().await?;
682 Ok(())
683}
684
685#[cfg(unix)]
687struct SocketPathHandler {
688 path: Option<String>,
689}
690
691#[cfg(unix)]
692impl TextResultSetHandler for SocketPathHandler {
693 fn no_result_set(&mut self, _: OkPayloadBytes) -> Result<()> {
694 Ok(())
695 }
696 fn resultset_start(&mut self, _: &[ColumnDefinition<'_>]) -> Result<()> {
697 Ok(())
698 }
699 fn resultset_end(&mut self, _: OkPayloadBytes) -> Result<()> {
700 Ok(())
701 }
702 fn row(&mut self, _: &[ColumnDefinition<'_>], row: TextRowPayload<'_>) -> Result<()> {
703 if row.0.first() == Some(&0xFB) {
705 return Ok(());
706 }
707 let (value, _) = read_string_lenenc(row.0)?;
709 if !value.is_empty() {
710 self.path = Some(String::from_utf8_lossy(value).into_owned());
711 }
712 Ok(())
713 }
714}