1use crate::PreparedStatement;
2use crate::buffer::BufferSet;
3use crate::buffer_pool::PooledBufferSet;
4use crate::constant::CapabilityFlags;
5use crate::error::{Error, Result};
6use crate::protocol::TextRowPayload;
7use crate::protocol::command::Action;
8use crate::protocol::command::ColumnDefinition;
9use crate::protocol::command::bulk_exec::{BulkExec, BulkFlags, BulkParamsSet, write_bulk_execute};
10use crate::protocol::command::prepared::Exec;
11use crate::protocol::command::prepared::write_execute;
12use crate::protocol::command::prepared::{read_prepare_ok, write_prepare};
13use crate::protocol::command::query::Query;
14use crate::protocol::command::query::write_query;
15use crate::protocol::command::utility::DropHandler;
16use crate::protocol::command::utility::FirstHandler;
17use crate::protocol::command::utility::write_ping;
18use crate::protocol::command::utility::write_reset_connection;
19use crate::protocol::connection::{Handshake, HandshakeAction, InitialHandshake};
20use crate::protocol::packet::PacketHeader;
21use crate::protocol::primitive::read_string_lenenc;
22use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
23use crate::protocol::r#trait::{BinaryResultSetHandler, TextResultSetHandler, param::Params};
24use core::hint::unlikely;
25use core::io::BorrowedBuf;
26use std::net::TcpStream;
27#[cfg(unix)]
28use std::os::unix::net::UnixStream;
29use zerocopy::FromZeros;
30use zerocopy::{FromBytes, IntoBytes};
31
32use super::stream::Stream;
33
/// A synchronous connection to a MySQL or MariaDB server.
pub struct Conn {
    /// Transport to the server: TCP or Unix socket (possibly upgraded to TLS during the
    /// handshake when the `sync-tls` feature is enabled).
    stream: Stream,
    /// Pooled scratch buffers: the read buffer, the write buffer, the column-definition
    /// buffer and the stored initial handshake bytes.
    buffer_set: PooledBufferSet,
    /// Parsed server greeting; its `server_version` range indexes into
    /// `buffer_set.initial_handshake`.
    initial_handshake: InitialHandshake,
    /// Capability flags produced by the handshake.
    capability_flags: CapabilityFlags,
    /// MariaDB-specific capability flags produced by the handshake (consulted for
    /// `MARIADB_CLIENT_CACHE_METADATA`).
    mariadb_capabilities: crate::constant::MariadbCapabilityFlags,
    /// Whether a transaction started through `run_transaction` is still open.
    in_transaction: bool,
    /// Set once any operation fails with an error whose `is_conn_broken()` is true.
    is_broken: bool,
}
43
44impl Conn {
45 pub(crate) fn set_in_transaction(&mut self, value: bool) {
46 self.in_transaction = value;
47 }
48
49 pub fn new<O: TryInto<crate::opts::Opts>>(opts: O) -> Result<Self>
51 where
52 Error: From<O::Error>,
53 {
54 let opts: crate::opts::Opts = opts.try_into()?;
55
56 #[cfg(unix)]
57 let stream = if let Some(socket_path) = &opts.socket {
58 let stream = UnixStream::connect(socket_path)?;
59 Stream::unix(stream)
60 } else {
61 if opts.host.is_empty() {
62 return Err(Error::BadUsageError(
63 "Missing host in connection options".to_string(),
64 ));
65 }
66 let addr = format!("{}:{}", opts.host, opts.port);
67 let stream = TcpStream::connect(&addr)?;
68 stream.set_nodelay(opts.tcp_nodelay)?;
69 Stream::tcp(stream)
70 };
71
72 #[cfg(not(unix))]
73 let stream = {
74 if opts.socket.is_some() {
75 return Err(Error::BadUsageError(
76 "Unix sockets are not supported on this platform".to_string(),
77 ));
78 }
79 if opts.host.is_empty() {
80 return Err(Error::BadUsageError(
81 "Missing host in connection options".to_string(),
82 ));
83 }
84 let addr = format!("{}:{}", opts.host, opts.port);
85 let stream = TcpStream::connect(&addr)?;
86 stream.set_nodelay(opts.tcp_nodelay)?;
87 Stream::tcp(stream)
88 };
89
90 Self::new_with_stream(stream, &opts)
91 }
92
93 pub fn new_with_stream(stream: Stream, opts: &crate::opts::Opts) -> Result<Self> {
95 let mut conn_stream = stream;
96 let mut buffer_set = opts.buffer_pool.get_buffer_set();
97
98 #[cfg(feature = "sync-tls")]
99 let host = opts.host.clone();
100
101 let mut handshake = Handshake::new(opts);
102
103 loop {
104 match handshake.step(&mut buffer_set)? {
105 HandshakeAction::ReadPacket(buffer) => {
106 buffer.clear();
107 read_payload(&mut conn_stream, buffer)?;
108 }
109 HandshakeAction::WritePacket { sequence_id } => {
110 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id)?;
111 buffer_set.read_buffer.clear();
112 read_payload(&mut conn_stream, &mut buffer_set.read_buffer)?;
113 }
114 #[cfg(feature = "sync-tls")]
115 HandshakeAction::UpgradeTls { sequence_id } => {
116 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id)?;
117 conn_stream = conn_stream.upgrade_to_tls(&host)?;
118 }
119 #[cfg(not(feature = "sync-tls"))]
120 HandshakeAction::UpgradeTls { .. } => {
121 return Err(Error::BadUsageError(
122 "TLS requested but sync-tls feature is not enabled".to_string(),
123 ));
124 }
125 HandshakeAction::Finished => break,
126 }
127 }
128
129 let (initial_handshake, capability_flags, mariadb_capabilities) = handshake.finish()?;
130
131 let conn = Self {
132 stream: conn_stream,
133 buffer_set,
134 initial_handshake,
135 capability_flags,
136 mariadb_capabilities,
137 in_transaction: false,
138 is_broken: false,
139 };
140
141 #[cfg(unix)]
143 let mut conn = if opts.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
144 conn.try_upgrade_to_unix_socket(opts)
145 } else {
146 conn
147 };
148 #[cfg(not(unix))]
149 let mut conn = conn;
150
151 if let Some(init_command) = &opts.init_command {
153 conn.query_drop(init_command)?;
154 }
155
156 Ok(conn)
157 }
158
159 pub fn server_version(&self) -> &[u8] {
161 &self.buffer_set.initial_handshake[self.initial_handshake.server_version.clone()]
162 }
163
164 pub fn capability_flags(&self) -> CapabilityFlags {
166 self.capability_flags
167 }
168
169 pub fn is_mysql(&self) -> bool {
171 self.capability_flags.is_mysql()
172 }
173
174 pub fn is_mariadb(&self) -> bool {
176 self.capability_flags.is_mariadb()
177 }
178
179 pub fn connection_id(&self) -> u64 {
181 self.initial_handshake.connection_id as u64
182 }
183
184 pub fn status_flags(&self) -> crate::constant::ServerStatusFlags {
186 self.initial_handshake.status_flags
187 }
188
189 pub fn is_broken(&self) -> bool {
193 self.is_broken
194 }
195
196 #[inline]
197 fn check_error<T>(&mut self, result: Result<T>) -> Result<T> {
198 if let Err(e) = &result
199 && e.is_conn_broken()
200 {
201 self.is_broken = true;
202 }
203 result
204 }
205
206 #[cfg(unix)]
209 fn try_upgrade_to_unix_socket(mut self, opts: &crate::opts::Opts) -> Self {
210 let mut handler = SocketPathHandler { path: None };
212 if self.query("SELECT @@socket", &mut handler).is_err() {
213 return self;
214 }
215
216 let socket_path = match handler.path {
217 Some(p) if !p.is_empty() => p,
218 _ => return self,
219 };
220
221 let unix_stream = match UnixStream::connect(&socket_path) {
223 Ok(s) => s,
224 Err(_) => return self,
225 };
226 let stream = Stream::unix(unix_stream);
227
228 let mut opts_unix = opts.clone();
231 opts_unix.upgrade_to_unix_socket = false;
232
233 match Self::new_with_stream(stream, &opts_unix) {
234 Ok(new_conn) => new_conn,
235 Err(_) => self,
236 }
237 }
238
239 fn write_payload(&mut self) -> Result<()> {
240 let mut sequence_id = 0_u8;
241 let mut buffer = self.buffer_set.write_buffer_mut().as_mut_slice();
242
243 loop {
244 let chunk_size = buffer[4..].len().min(0xFFFFFF);
245 PacketHeader::mut_from_bytes(&mut buffer[0..4])?
246 .encode_in_place(chunk_size, sequence_id);
247 self.stream.write_all(&buffer[..4 + chunk_size])?;
248
249 if chunk_size < 0xFFFFFF {
250 break;
251 }
252
253 sequence_id = sequence_id.wrapping_add(1);
254 buffer = &mut buffer[0xFFFFFF..];
255 }
256 self.stream.flush()?;
257 Ok(())
258 }
259
260 pub fn prepare(&mut self, sql: &str) -> Result<PreparedStatement> {
262 let result = self.prepare_inner(sql);
263 self.check_error(result)
264 }
265
266 fn prepare_inner(&mut self, sql: &str) -> Result<PreparedStatement> {
267 use crate::protocol::command::ColumnDefinitions;
268
269 self.buffer_set.read_buffer.clear();
270
271 write_prepare(self.buffer_set.new_write_buffer(), sql);
272
273 self.write_payload()?;
274 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
275
276 if unlikely(
277 !self.buffer_set.read_buffer.is_empty() && self.buffer_set.read_buffer[0] == 0xFF,
278 ) {
279 Err(ErrPayloadBytes(&self.buffer_set.read_buffer))?
280 }
281
282 let prepare_ok = read_prepare_ok(&self.buffer_set.read_buffer)?;
283 let statement_id = prepare_ok.statement_id();
284 let num_params = prepare_ok.num_params();
285 let num_columns = prepare_ok.num_columns();
286
287 if num_params > 0 {
289 for _ in 0..num_params {
290 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
291 }
292 }
293
294 let column_definitions = if num_columns > 0 {
296 read_column_definition_packets(
297 &mut self.stream,
298 &mut self.buffer_set.column_definition_buffer,
299 num_columns as usize,
300 )?;
301 Some(ColumnDefinitions::new(
302 num_columns as usize,
303 std::mem::take(&mut self.buffer_set.column_definition_buffer),
304 )?)
305 } else {
306 None
307 };
308
309 let mut stmt = PreparedStatement::new(statement_id);
310 if let Some(col_defs) = column_definitions {
311 stmt.set_column_definitions(col_defs);
312 }
313 Ok(stmt)
314 }
315
316 fn drive_exec<H: BinaryResultSetHandler>(
317 &mut self,
318 stmt: &mut PreparedStatement,
319 handler: &mut H,
320 ) -> Result<()> {
321 let cache_metadata = self
322 .mariadb_capabilities
323 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
324 let mut exec = Exec::new(handler, stmt, cache_metadata);
325
326 loop {
327 match exec.step(&mut self.buffer_set)? {
328 Action::NeedPacket(buffer) => {
329 buffer.clear();
330 let _ = read_payload(&mut self.stream, buffer)?;
331 }
332 Action::ReadColumnMetadata { num_columns } => {
333 read_column_definition_packets(
334 &mut self.stream,
335 &mut self.buffer_set.column_definition_buffer,
336 num_columns,
337 )?;
338 }
339 Action::Finished => return Ok(()),
340 }
341 }
342 }
343
344 pub fn exec<'conn, P, H>(
348 &'conn mut self,
349 stmt: &'conn mut PreparedStatement,
350 params: P,
351 handler: &mut H,
352 ) -> Result<()>
353 where
354 P: Params,
355 H: BinaryResultSetHandler,
356 {
357 let result = self.exec_inner(stmt, params, handler);
358 self.check_error(result)
359 }
360
361 fn exec_inner<'conn, P, H>(
362 &'conn mut self,
363 stmt: &'conn mut PreparedStatement,
364 params: P,
365 handler: &mut H,
366 ) -> Result<()>
367 where
368 P: Params,
369 H: BinaryResultSetHandler,
370 {
371 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
372 self.write_payload()?;
373 self.drive_exec(stmt, handler)
374 }
375
376 fn drive_bulk_exec<H: BinaryResultSetHandler>(
377 &mut self,
378 stmt: &mut PreparedStatement,
379 handler: &mut H,
380 ) -> Result<()> {
381 let cache_metadata = self
382 .mariadb_capabilities
383 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
384 let mut bulk_exec = BulkExec::new(handler, stmt, cache_metadata);
385
386 loop {
387 match bulk_exec.step(&mut self.buffer_set)? {
388 Action::NeedPacket(buffer) => {
389 buffer.clear();
390 let _ = read_payload(&mut self.stream, buffer)?;
391 }
392 Action::ReadColumnMetadata { num_columns } => {
393 read_column_definition_packets(
394 &mut self.stream,
395 &mut self.buffer_set.column_definition_buffer,
396 num_columns,
397 )?;
398 }
399 Action::Finished => return Ok(()),
400 }
401 }
402 }
403
404 pub fn exec_bulk_insert_or_update<P, I, H>(
409 &mut self,
410 stmt: &mut PreparedStatement,
411 params: P,
412 flags: BulkFlags,
413 handler: &mut H,
414 ) -> Result<()>
415 where
416 P: BulkParamsSet + IntoIterator<Item = I>,
417 I: Params,
418 H: BinaryResultSetHandler,
419 {
420 let result = self.exec_bulk_insert_or_update_inner(stmt, params, flags, handler);
421 self.check_error(result)
422 }
423
424 fn exec_bulk_insert_or_update_inner<P, I, H>(
425 &mut self,
426 stmt: &mut PreparedStatement,
427 params: P,
428 flags: BulkFlags,
429 handler: &mut H,
430 ) -> Result<()>
431 where
432 P: BulkParamsSet + IntoIterator<Item = I>,
433 I: Params,
434 H: BinaryResultSetHandler,
435 {
436 if !self.is_mariadb() {
437 for param in params {
439 self.exec_inner(stmt, param, &mut DropHandler::default())?;
440 }
441 Ok(())
442 } else {
443 write_bulk_execute(self.buffer_set.new_write_buffer(), stmt.id(), params, flags)?;
445 self.write_payload()?;
446 self.drive_bulk_exec(stmt, handler)
447 }
448 }
449
450 pub fn exec_first<Row, P>(
452 &mut self,
453 stmt: &mut PreparedStatement,
454 params: P,
455 ) -> Result<Option<Row>>
456 where
457 Row: for<'buf> crate::raw::FromRawRow<'buf>,
458 P: Params,
459 {
460 let result = self.exec_first_inner(stmt, params);
461 self.check_error(result)
462 }
463
464 fn exec_first_inner<Row, P>(
465 &mut self,
466 stmt: &mut PreparedStatement,
467 params: P,
468 ) -> Result<Option<Row>>
469 where
470 Row: for<'buf> crate::raw::FromRawRow<'buf>,
471 P: Params,
472 {
473 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
474 self.write_payload()?;
475 let mut handler = FirstHandler::<Row>::default();
476 self.drive_exec(stmt, &mut handler)?;
477 Ok(handler.take())
478 }
479
480 pub fn exec_drop<P>(&mut self, stmt: &mut PreparedStatement, params: P) -> Result<()>
482 where
483 P: Params,
484 {
485 self.exec(stmt, params, &mut DropHandler::default())
486 }
487
488 pub fn exec_collect<Row, P>(
490 &mut self,
491 stmt: &mut PreparedStatement,
492 params: P,
493 ) -> Result<Vec<Row>>
494 where
495 Row: for<'buf> crate::raw::FromRawRow<'buf>,
496 P: Params,
497 {
498 let mut handler = crate::handler::CollectHandler::<Row>::default();
499 self.exec(stmt, params, &mut handler)?;
500 Ok(handler.into_rows())
501 }
502
503 pub fn exec_foreach<Row, P, F>(
507 &mut self,
508 stmt: &mut PreparedStatement,
509 params: P,
510 f: F,
511 ) -> Result<()>
512 where
513 Row: for<'buf> crate::raw::FromRawRow<'buf>,
514 P: Params,
515 F: FnMut(Row) -> Result<()>,
516 {
517 let mut handler = crate::handler::ForEachHandler::<Row, F>::new(f);
518 self.exec(stmt, params, &mut handler)
519 }
520
521 fn drive_query<H: TextResultSetHandler>(&mut self, handler: &mut H) -> Result<()> {
522 let mut query = Query::new(handler);
523
524 loop {
525 match query.step(&mut self.buffer_set)? {
526 Action::NeedPacket(buffer) => {
527 buffer.clear();
528 let _ = read_payload(&mut self.stream, buffer)?;
529 }
530 Action::ReadColumnMetadata { num_columns } => {
531 read_column_definition_packets(
532 &mut self.stream,
533 &mut self.buffer_set.column_definition_buffer,
534 num_columns,
535 )?;
536 }
537 Action::Finished => return Ok(()),
538 }
539 }
540 }
541
542 pub fn query<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
544 where
545 H: TextResultSetHandler,
546 {
547 let result = self.query_inner(sql, handler);
548 self.check_error(result)
549 }
550
551 fn query_inner<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
552 where
553 H: TextResultSetHandler,
554 {
555 write_query(self.buffer_set.new_write_buffer(), sql);
556 self.write_payload()?;
557 self.drive_query(handler)
558 }
559
560 pub fn query_drop(&mut self, sql: &str) -> Result<()> {
562 let result = self.query_drop_inner(sql);
563 self.check_error(result)
564 }
565
566 fn query_drop_inner(&mut self, sql: &str) -> Result<()> {
567 write_query(self.buffer_set.new_write_buffer(), sql);
568 self.write_payload()?;
569 self.drive_query(&mut DropHandler::default())
570 }
571
572 pub fn ping(&mut self) -> Result<()> {
576 let result = self.ping_inner();
577 self.check_error(result)
578 }
579
580 fn ping_inner(&mut self) -> Result<()> {
581 write_ping(self.buffer_set.new_write_buffer());
582 self.write_payload()?;
583 self.buffer_set.read_buffer.clear();
584 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
585 Ok(())
586 }
587
588 pub fn reset(&mut self) -> Result<()> {
590 let result = self.reset_inner();
591 self.check_error(result)
592 }
593
594 fn reset_inner(&mut self) -> Result<()> {
595 write_reset_connection(self.buffer_set.new_write_buffer());
596 self.write_payload()?;
597 self.buffer_set.read_buffer.clear();
598 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
599 self.in_transaction = false;
600 Ok(())
601 }
602
603 pub fn run_transaction<F, R>(&mut self, f: F) -> Result<R>
608 where
609 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
610 {
611 if self.in_transaction {
612 return Err(Error::NestedTransaction);
613 }
614
615 self.in_transaction = true;
616
617 if let Err(e) = self.query_drop("BEGIN") {
618 self.in_transaction = false;
619 return Err(e);
620 }
621
622 let tx = super::transaction::Transaction::new(self.connection_id());
623 let result = f(self, tx);
624
625 if self.in_transaction {
627 let rollback_result = self.query_drop("ROLLBACK");
628 self.in_transaction = false;
629
630 if let Err(e) = result {
632 return Err(e);
633 }
634 rollback_result?;
635 }
636
637 result
638 }
639}
640
641fn read_payload(reader: &mut Stream, buffer: &mut Vec<u8>) -> Result<u8> {
644 buffer.clear();
645
646 let mut header = PacketHeader::new_zeroed();
647 reader.read_exact(header.as_mut_bytes())?;
648
649 let length = header.length();
650 let mut sequence_id = header.sequence_id;
651
652 buffer.reserve(length);
653
654 {
655 let spare = buffer.spare_capacity_mut();
656 let mut buf: BorrowedBuf<'_> = (&mut spare[..length]).into();
657 reader.read_buf_exact(buf.unfilled())?;
658 unsafe {
660 buffer.set_len(length);
661 }
662 }
663
664 let mut current_length = length;
665 while current_length == 0xFFFFFF {
666 reader.read_exact(header.as_mut_bytes())?;
667
668 current_length = header.length();
669 sequence_id = header.sequence_id;
670
671 buffer.reserve(current_length);
672 let spare = buffer.spare_capacity_mut();
673 let mut buf: BorrowedBuf<'_> = (&mut spare[..current_length]).into();
674 reader.read_buf_exact(buf.unfilled())?;
675 unsafe {
677 buffer.set_len(buffer.len() + current_length);
678 }
679 }
680
681 Ok(sequence_id)
682}
683
684fn read_column_definition_packets(
685 reader: &mut Stream,
686 out: &mut Vec<u8>,
687 num_columns: usize,
688) -> Result<u8> {
689 out.clear();
690 let mut header = PacketHeader::new_zeroed();
691
692 for _ in 0..num_columns {
694 reader.read_exact(header.as_mut_bytes())?;
695 let length = header.length();
696 out.extend((length as u32).to_ne_bytes());
697
698 out.reserve(length);
699 let spare = out.spare_capacity_mut();
700 let mut buf: BorrowedBuf<'_> = (&mut spare[..length]).into();
701 reader.read_buf_exact(buf.unfilled())?;
702 unsafe {
704 out.set_len(out.len() + length);
705 }
706 }
707
708 Ok(header.sequence_id)
709}
710
711fn write_handshake_payload(
712 stream: &mut Stream,
713 buffer_set: &mut BufferSet,
714 sequence_id: u8,
715) -> Result<()> {
716 let mut buffer = buffer_set.write_buffer_mut().as_mut_slice();
717 let mut seq_id = sequence_id;
718
719 loop {
720 let chunk_size = buffer[4..].len().min(0xFFFFFF);
721 PacketHeader::mut_from_bytes(&mut buffer[0..4])?.encode_in_place(chunk_size, seq_id);
722 stream.write_all(&buffer[..4 + chunk_size])?;
723
724 if chunk_size < 0xFFFFFF {
725 break;
726 }
727
728 seq_id = seq_id.wrapping_add(1);
729 buffer = &mut buffer[0xFFFFFF..];
730 }
731 stream.flush()?;
732 Ok(())
733}
734
735#[cfg(unix)]
737struct SocketPathHandler {
738 path: Option<String>,
739}
740
741#[cfg(unix)]
742impl TextResultSetHandler for SocketPathHandler {
743 fn no_result_set(&mut self, _: OkPayloadBytes) -> Result<()> {
744 Ok(())
745 }
746 fn resultset_start(&mut self, _: &[ColumnDefinition<'_>]) -> Result<()> {
747 Ok(())
748 }
749 fn resultset_end(&mut self, _: OkPayloadBytes) -> Result<()> {
750 Ok(())
751 }
752 fn row(&mut self, _: &[ColumnDefinition<'_>], row: TextRowPayload<'_>) -> Result<()> {
753 if row.0.first() == Some(&0xFB) {
755 return Ok(());
756 }
757 let (value, _) = read_string_lenenc(row.0)?;
759 if !value.is_empty() {
760 self.path = Some(String::from_utf8_lossy(value).into_owned());
761 }
762 Ok(())
763 }
764}