1use crate::PreparedStatement;
2use crate::buffer::BufferSet;
3use crate::buffer_pool::PooledBufferSet;
4use crate::constant::CapabilityFlags;
5use crate::error::{Error, Result};
6use crate::protocol::TextRowPayload;
7use crate::protocol::command::Action;
8use crate::protocol::command::ColumnDefinition;
9use crate::protocol::command::bulk_exec::{BulkExec, BulkFlags, BulkParamsSet, write_bulk_execute};
10use crate::protocol::command::prepared::Exec;
11use crate::protocol::command::prepared::write_execute;
12use crate::protocol::command::prepared::{read_prepare_ok, write_prepare};
13use crate::protocol::command::query::Query;
14use crate::protocol::command::query::write_query;
15use crate::protocol::command::utility::DropHandler;
16use crate::protocol::command::utility::FirstRowHandler;
17use crate::protocol::command::utility::write_ping;
18use crate::protocol::command::utility::write_reset_connection;
19use crate::protocol::connection::{Handshake, HandshakeAction, InitialHandshake};
20use crate::protocol::packet::PacketHeader;
21use crate::protocol::primitive::read_string_lenenc;
22use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
23use crate::protocol::r#trait::{BinaryResultSetHandler, TextResultSetHandler, param::Params};
24use core::hint::unlikely;
25use core::io::BorrowedBuf;
26use std::net::TcpStream;
27#[cfg(unix)]
28use std::os::unix::net::UnixStream;
29use zerocopy::FromZeros;
30use zerocopy::{FromBytes, IntoBytes};
31
32use super::stream::Stream;
33
/// A synchronous client connection to a MySQL or MariaDB server.
///
/// Note: fields are dropped in declaration order, so `stream` is dropped before
/// `buffer_set`.
pub struct Conn {
    /// Transport used for every packet read and write (TCP, Unix socket, or — with the
    /// `sync-tls` feature — TLS).
    stream: Stream,
    /// Buffers checked out of `opts.buffer_pool`: the read buffer, the write buffer, the
    /// column-definition scratch buffer and the raw initial-handshake bytes.
    buffer_set: PooledBufferSet,
    /// Parsed initial handshake; its `server_version` range indexes into
    /// `buffer_set.initial_handshake`.
    initial_handshake: InitialHandshake,
    /// Capability flags returned by `Handshake::finish`.
    capability_flags: CapabilityFlags,
    /// MariaDB extended capability flags returned by `Handshake::finish`; consulted for
    /// `MARIADB_CLIENT_CACHE_METADATA` when executing prepared statements.
    mariadb_capabilities: crate::constant::MariadbCapabilityFlags,
    /// Whether a transaction started by `run_transaction` is still open.
    in_transaction: bool,
    /// Set once a command fails with an error for which `Error::is_conn_broken` is true.
    is_broken: bool,
}
43
44impl Conn {
45 pub(crate) fn set_in_transaction(&mut self, value: bool) {
46 self.in_transaction = value;
47 }
48
49 pub fn new<O: TryInto<crate::opts::Opts>>(opts: O) -> Result<Self>
51 where
52 Error: From<O::Error>,
53 {
54 let opts: crate::opts::Opts = opts.try_into()?;
55
56 #[cfg(unix)]
57 let stream = if let Some(socket_path) = &opts.socket {
58 let stream = UnixStream::connect(socket_path)?;
59 Stream::unix(stream)
60 } else {
61 if opts.host.is_empty() {
62 return Err(Error::BadUsageError(
63 "Missing host in connection options".to_string(),
64 ));
65 }
66 let addr = format!("{}:{}", opts.host, opts.port);
67 let stream = TcpStream::connect(&addr)?;
68 stream.set_nodelay(opts.tcp_nodelay)?;
69 Stream::tcp(stream)
70 };
71
72 #[cfg(not(unix))]
73 let stream = {
74 if opts.socket.is_some() {
75 return Err(Error::BadUsageError(
76 "Unix sockets are not supported on this platform".to_string(),
77 ));
78 }
79 if opts.host.is_empty() {
80 return Err(Error::BadUsageError(
81 "Missing host in connection options".to_string(),
82 ));
83 }
84 let addr = format!("{}:{}", opts.host, opts.port);
85 let stream = TcpStream::connect(&addr)?;
86 stream.set_nodelay(opts.tcp_nodelay)?;
87 Stream::tcp(stream)
88 };
89
90 Self::new_with_stream(stream, &opts)
91 }
92
93 pub fn new_with_stream(stream: Stream, opts: &crate::opts::Opts) -> Result<Self> {
95 let mut conn_stream = stream;
96 let mut buffer_set = opts.buffer_pool.get_buffer_set();
97
98 #[cfg(feature = "sync-tls")]
99 let host = opts.host.clone();
100
101 let mut handshake = Handshake::new(opts);
102
103 loop {
104 match handshake.step(&mut buffer_set)? {
105 HandshakeAction::ReadPacket(buffer) => {
106 buffer.clear();
107 read_payload(&mut conn_stream, buffer)?;
108 }
109 HandshakeAction::WritePacket { sequence_id } => {
110 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id)?;
111 buffer_set.read_buffer.clear();
112 read_payload(&mut conn_stream, &mut buffer_set.read_buffer)?;
113 }
114 #[cfg(feature = "sync-tls")]
115 HandshakeAction::UpgradeTls { sequence_id } => {
116 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id)?;
117 conn_stream = conn_stream.upgrade_to_tls(&host)?;
118 }
119 #[cfg(not(feature = "sync-tls"))]
120 HandshakeAction::UpgradeTls { .. } => {
121 return Err(Error::BadUsageError(
122 "TLS requested but sync-tls feature is not enabled".to_string(),
123 ));
124 }
125 HandshakeAction::Finished => break,
126 }
127 }
128
129 let (initial_handshake, capability_flags, mariadb_capabilities) = handshake.finish()?;
130
131 let conn = Self {
132 stream: conn_stream,
133 buffer_set,
134 initial_handshake,
135 capability_flags,
136 mariadb_capabilities,
137 in_transaction: false,
138 is_broken: false,
139 };
140
141 #[cfg(unix)]
143 let mut conn = if opts.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
144 conn.try_upgrade_to_unix_socket(opts)
145 } else {
146 conn
147 };
148 #[cfg(not(unix))]
149 let mut conn = conn;
150
151 if let Some(init_command) = &opts.init_command {
153 conn.query_drop(init_command)?;
154 }
155
156 Ok(conn)
157 }
158
159 pub fn server_version(&self) -> &[u8] {
161 &self.buffer_set.initial_handshake[self.initial_handshake.server_version.clone()]
162 }
163
164 pub fn capability_flags(&self) -> CapabilityFlags {
166 self.capability_flags
167 }
168
169 pub fn is_mysql(&self) -> bool {
171 self.capability_flags.is_mysql()
172 }
173
174 pub fn is_mariadb(&self) -> bool {
176 self.capability_flags.is_mariadb()
177 }
178
179 pub fn connection_id(&self) -> u64 {
181 self.initial_handshake.connection_id as u64
182 }
183
184 pub fn status_flags(&self) -> crate::constant::ServerStatusFlags {
186 self.initial_handshake.status_flags
187 }
188
189 pub fn is_broken(&self) -> bool {
193 self.is_broken
194 }
195
196 #[inline]
197 fn check_error<T>(&mut self, result: Result<T>) -> Result<T> {
198 if let Err(e) = &result
199 && e.is_conn_broken()
200 {
201 self.is_broken = true;
202 }
203 result
204 }
205
206 #[cfg(unix)]
209 fn try_upgrade_to_unix_socket(mut self, opts: &crate::opts::Opts) -> Self {
210 let mut handler = SocketPathHandler { path: None };
212 if self.query("SELECT @@socket", &mut handler).is_err() {
213 return self;
214 }
215
216 let socket_path = match handler.path {
217 Some(p) if !p.is_empty() => p,
218 _ => return self,
219 };
220
221 let unix_stream = match UnixStream::connect(&socket_path) {
223 Ok(s) => s,
224 Err(_) => return self,
225 };
226 let stream = Stream::unix(unix_stream);
227
228 let mut opts_unix = opts.clone();
231 opts_unix.upgrade_to_unix_socket = false;
232
233 match Self::new_with_stream(stream, &opts_unix) {
234 Ok(new_conn) => new_conn,
235 Err(_) => self,
236 }
237 }
238
239 fn write_payload(&mut self) -> Result<()> {
240 let mut sequence_id = 0_u8;
241 let mut buffer = self.buffer_set.write_buffer_mut().as_mut_slice();
242
243 loop {
244 let chunk_size = buffer[4..].len().min(0xFFFFFF);
245 PacketHeader::mut_from_bytes(&mut buffer[0..4])?
246 .encode_in_place(chunk_size, sequence_id);
247 self.stream.write_all(&buffer[..4 + chunk_size])?;
248
249 if chunk_size < 0xFFFFFF {
250 break;
251 }
252
253 sequence_id = sequence_id.wrapping_add(1);
254 buffer = &mut buffer[0xFFFFFF..];
255 }
256 self.stream.flush()?;
257 Ok(())
258 }
259
260 pub fn prepare(&mut self, sql: &str) -> Result<PreparedStatement> {
262 let result = self.prepare_inner(sql);
263 self.check_error(result)
264 }
265
266 fn prepare_inner(&mut self, sql: &str) -> Result<PreparedStatement> {
267 use crate::protocol::command::ColumnDefinitions;
268
269 self.buffer_set.read_buffer.clear();
270
271 write_prepare(self.buffer_set.new_write_buffer(), sql);
272
273 self.write_payload()?;
274 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
275
276 if unlikely(
277 !self.buffer_set.read_buffer.is_empty() && self.buffer_set.read_buffer[0] == 0xFF,
278 ) {
279 Err(ErrPayloadBytes(&self.buffer_set.read_buffer))?
280 }
281
282 let prepare_ok = read_prepare_ok(&self.buffer_set.read_buffer)?;
283 let statement_id = prepare_ok.statement_id();
284 let num_params = prepare_ok.num_params();
285 let num_columns = prepare_ok.num_columns();
286
287 if num_params > 0 {
289 for _ in 0..num_params {
290 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
291 }
292 }
293
294 let column_definitions = if num_columns > 0 {
296 read_column_definition_packets(
297 &mut self.stream,
298 &mut self.buffer_set.column_definition_buffer,
299 num_columns as usize,
300 )?;
301 Some(ColumnDefinitions::new(
302 num_columns as usize,
303 std::mem::take(&mut self.buffer_set.column_definition_buffer),
304 )?)
305 } else {
306 None
307 };
308
309 let mut stmt = PreparedStatement::new(statement_id);
310 if let Some(col_defs) = column_definitions {
311 stmt.set_column_definitions(col_defs);
312 }
313 Ok(stmt)
314 }
315
316 fn drive_exec<H: BinaryResultSetHandler>(
317 &mut self,
318 stmt: &mut PreparedStatement,
319 handler: &mut H,
320 ) -> Result<()> {
321 let cache_metadata = self
322 .mariadb_capabilities
323 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
324 let mut exec = Exec::new(handler, stmt, cache_metadata);
325
326 loop {
327 match exec.step(&mut self.buffer_set)? {
328 Action::NeedPacket(buffer) => {
329 buffer.clear();
330 let _ = read_payload(&mut self.stream, buffer)?;
331 }
332 Action::ReadColumnMetadata { num_columns } => {
333 read_column_definition_packets(
334 &mut self.stream,
335 &mut self.buffer_set.column_definition_buffer,
336 num_columns,
337 )?;
338 }
339 Action::Finished => return Ok(()),
340 }
341 }
342 }
343
344 pub fn exec<'conn, P, H>(
348 &'conn mut self,
349 stmt: &'conn mut PreparedStatement,
350 params: P,
351 handler: &mut H,
352 ) -> Result<()>
353 where
354 P: Params,
355 H: BinaryResultSetHandler,
356 {
357 let result = self.exec_inner(stmt, params, handler);
358 self.check_error(result)
359 }
360
361 fn exec_inner<'conn, P, H>(
362 &'conn mut self,
363 stmt: &'conn mut PreparedStatement,
364 params: P,
365 handler: &mut H,
366 ) -> Result<()>
367 where
368 P: Params,
369 H: BinaryResultSetHandler,
370 {
371 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
372 self.write_payload()?;
373 self.drive_exec(stmt, handler)
374 }
375
376 fn drive_bulk_exec<H: BinaryResultSetHandler>(
377 &mut self,
378 stmt: &mut PreparedStatement,
379 handler: &mut H,
380 ) -> Result<()> {
381 let cache_metadata = self
382 .mariadb_capabilities
383 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
384 let mut bulk_exec = BulkExec::new(handler, stmt, cache_metadata);
385
386 loop {
387 match bulk_exec.step(&mut self.buffer_set)? {
388 Action::NeedPacket(buffer) => {
389 buffer.clear();
390 let _ = read_payload(&mut self.stream, buffer)?;
391 }
392 Action::ReadColumnMetadata { num_columns } => {
393 read_column_definition_packets(
394 &mut self.stream,
395 &mut self.buffer_set.column_definition_buffer,
396 num_columns,
397 )?;
398 }
399 Action::Finished => return Ok(()),
400 }
401 }
402 }
403
404 pub fn exec_bulk_insert_or_update<P, I, H>(
406 &mut self,
407 stmt: &mut PreparedStatement,
408 params: P,
409 flags: BulkFlags,
410 handler: &mut H,
411 ) -> Result<()>
412 where
413 P: BulkParamsSet + IntoIterator<Item = I>,
414 I: Params,
415 H: BinaryResultSetHandler,
416 {
417 let result = self.exec_bulk_insert_or_update_inner(stmt, params, flags, handler);
418 self.check_error(result)
419 }
420
421 fn exec_bulk_insert_or_update_inner<P, I, H>(
422 &mut self,
423 stmt: &mut PreparedStatement,
424 params: P,
425 flags: BulkFlags,
426 handler: &mut H,
427 ) -> Result<()>
428 where
429 P: BulkParamsSet + IntoIterator<Item = I>,
430 I: Params,
431 H: BinaryResultSetHandler,
432 {
433 if !self.is_mariadb() {
434 for param in params {
436 self.exec_inner(stmt, param, &mut DropHandler::default())?;
437 }
438 Ok(())
439 } else {
440 write_bulk_execute(self.buffer_set.new_write_buffer(), stmt.id(), params, flags)?;
442 self.write_payload()?;
443 self.drive_bulk_exec(stmt, handler)
444 }
445 }
446
447 pub fn exec_first<'conn, P, H>(
454 &'conn mut self,
455 stmt: &'conn mut PreparedStatement,
456 params: P,
457 handler: &mut H,
458 ) -> Result<bool>
459 where
460 P: Params,
461 H: BinaryResultSetHandler,
462 {
463 let result = self.exec_first_inner(stmt, params, handler);
464 self.check_error(result)
465 }
466
467 fn exec_first_inner<'conn, P, H>(
468 &'conn mut self,
469 stmt: &'conn mut PreparedStatement,
470 params: P,
471 handler: &mut H,
472 ) -> Result<bool>
473 where
474 P: Params,
475 H: BinaryResultSetHandler,
476 {
477 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
478 self.write_payload()?;
479 let mut first_row_handler = FirstRowHandler::new(handler);
480 self.drive_exec(stmt, &mut first_row_handler)?;
481 Ok(first_row_handler.found_row)
482 }
483
484 pub fn exec_drop<P>(&mut self, stmt: &mut PreparedStatement, params: P) -> Result<()>
486 where
487 P: Params,
488 {
489 self.exec(stmt, params, &mut DropHandler::default())
490 }
491
492 pub fn exec_rows<Row, P>(&mut self, stmt: &mut PreparedStatement, params: P) -> Result<Vec<Row>>
494 where
495 Row: for<'buf> crate::raw::FromRawRow<'buf>,
496 P: Params,
497 {
498 let mut handler = crate::handler::CollectHandler::<Row>::default();
499 self.exec(stmt, params, &mut handler)?;
500 Ok(handler.into_rows())
501 }
502
503 fn drive_query<H: TextResultSetHandler>(&mut self, handler: &mut H) -> Result<()> {
504 let mut query = Query::new(handler);
505
506 loop {
507 match query.step(&mut self.buffer_set)? {
508 Action::NeedPacket(buffer) => {
509 buffer.clear();
510 let _ = read_payload(&mut self.stream, buffer)?;
511 }
512 Action::ReadColumnMetadata { num_columns } => {
513 read_column_definition_packets(
514 &mut self.stream,
515 &mut self.buffer_set.column_definition_buffer,
516 num_columns,
517 )?;
518 }
519 Action::Finished => return Ok(()),
520 }
521 }
522 }
523
524 pub fn query<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
526 where
527 H: TextResultSetHandler,
528 {
529 let result = self.query_inner(sql, handler);
530 self.check_error(result)
531 }
532
533 fn query_inner<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
534 where
535 H: TextResultSetHandler,
536 {
537 write_query(self.buffer_set.new_write_buffer(), sql);
538 self.write_payload()?;
539 self.drive_query(handler)
540 }
541
542 pub fn query_drop(&mut self, sql: &str) -> Result<()> {
544 let result = self.query_drop_inner(sql);
545 self.check_error(result)
546 }
547
548 fn query_drop_inner(&mut self, sql: &str) -> Result<()> {
549 write_query(self.buffer_set.new_write_buffer(), sql);
550 self.write_payload()?;
551 self.drive_query(&mut DropHandler::default())
552 }
553
554 pub fn ping(&mut self) -> Result<()> {
558 let result = self.ping_inner();
559 self.check_error(result)
560 }
561
562 fn ping_inner(&mut self) -> Result<()> {
563 write_ping(self.buffer_set.new_write_buffer());
564 self.write_payload()?;
565 self.buffer_set.read_buffer.clear();
566 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
567 Ok(())
568 }
569
570 pub fn reset(&mut self) -> Result<()> {
572 let result = self.reset_inner();
573 self.check_error(result)
574 }
575
576 fn reset_inner(&mut self) -> Result<()> {
577 write_reset_connection(self.buffer_set.new_write_buffer());
578 self.write_payload()?;
579 self.buffer_set.read_buffer.clear();
580 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
581 self.in_transaction = false;
582 Ok(())
583 }
584
585 pub fn run_transaction<F, R>(&mut self, f: F) -> Result<R>
590 where
591 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
592 {
593 if self.in_transaction {
594 return Err(Error::NestedTransaction);
595 }
596
597 self.in_transaction = true;
598
599 if let Err(e) = self.query_drop("BEGIN") {
600 self.in_transaction = false;
601 return Err(e);
602 }
603
604 let tx = super::transaction::Transaction::new(self.connection_id());
605 let result = f(self, tx);
606
607 if self.in_transaction {
609 let rollback_result = self.query_drop("ROLLBACK");
610 self.in_transaction = false;
611
612 if let Err(e) = result {
614 return Err(e);
615 }
616 rollback_result?;
617 }
618
619 result
620 }
621}
622
623fn read_payload(reader: &mut Stream, buffer: &mut Vec<u8>) -> Result<u8> {
626 buffer.clear();
627
628 let mut header = PacketHeader::new_zeroed();
629 reader.read_exact(header.as_mut_bytes())?;
630
631 let length = header.length();
632 let mut sequence_id = header.sequence_id;
633
634 buffer.reserve(length);
635
636 {
637 let spare = buffer.spare_capacity_mut();
638 let mut buf: BorrowedBuf<'_> = (&mut spare[..length]).into();
639 reader.read_buf_exact(buf.unfilled())?;
640 unsafe {
642 buffer.set_len(length);
643 }
644 }
645
646 let mut current_length = length;
647 while current_length == 0xFFFFFF {
648 reader.read_exact(header.as_mut_bytes())?;
649
650 current_length = header.length();
651 sequence_id = header.sequence_id;
652
653 buffer.reserve(current_length);
654 let spare = buffer.spare_capacity_mut();
655 let mut buf: BorrowedBuf<'_> = (&mut spare[..current_length]).into();
656 reader.read_buf_exact(buf.unfilled())?;
657 unsafe {
659 buffer.set_len(buffer.len() + current_length);
660 }
661 }
662
663 Ok(sequence_id)
664}
665
666fn read_column_definition_packets(
667 reader: &mut Stream,
668 out: &mut Vec<u8>,
669 num_columns: usize,
670) -> Result<u8> {
671 out.clear();
672 let mut header = PacketHeader::new_zeroed();
673
674 for _ in 0..num_columns {
676 reader.read_exact(header.as_mut_bytes())?;
677 let length = header.length();
678 out.extend((length as u32).to_ne_bytes());
679
680 out.reserve(length);
681 let spare = out.spare_capacity_mut();
682 let mut buf: BorrowedBuf<'_> = (&mut spare[..length]).into();
683 reader.read_buf_exact(buf.unfilled())?;
684 unsafe {
686 out.set_len(out.len() + length);
687 }
688 }
689
690 Ok(header.sequence_id)
691}
692
693fn write_handshake_payload(
694 stream: &mut Stream,
695 buffer_set: &mut BufferSet,
696 sequence_id: u8,
697) -> Result<()> {
698 let mut buffer = buffer_set.write_buffer_mut().as_mut_slice();
699 let mut seq_id = sequence_id;
700
701 loop {
702 let chunk_size = buffer[4..].len().min(0xFFFFFF);
703 PacketHeader::mut_from_bytes(&mut buffer[0..4])?.encode_in_place(chunk_size, seq_id);
704 stream.write_all(&buffer[..4 + chunk_size])?;
705
706 if chunk_size < 0xFFFFFF {
707 break;
708 }
709
710 seq_id = seq_id.wrapping_add(1);
711 buffer = &mut buffer[0xFFFFFF..];
712 }
713 stream.flush()?;
714 Ok(())
715}
716
/// Text-protocol result handler that captures the value returned by `SELECT @@socket`
/// (used by `Conn::try_upgrade_to_unix_socket`).
#[cfg(unix)]
struct SocketPathHandler {
    /// Socket path from the last row whose value was neither NULL nor empty; `None` if no
    /// such row was seen.
    path: Option<String>,
}
722
723#[cfg(unix)]
724impl TextResultSetHandler for SocketPathHandler {
725 fn no_result_set(&mut self, _: OkPayloadBytes) -> Result<()> {
726 Ok(())
727 }
728 fn resultset_start(&mut self, _: &[ColumnDefinition<'_>]) -> Result<()> {
729 Ok(())
730 }
731 fn resultset_end(&mut self, _: OkPayloadBytes) -> Result<()> {
732 Ok(())
733 }
734 fn row(&mut self, _: &[ColumnDefinition<'_>], row: TextRowPayload<'_>) -> Result<()> {
735 if row.0.first() == Some(&0xFB) {
737 return Ok(());
738 }
739 let (value, _) = read_string_lenenc(row.0)?;
741 if !value.is_empty() {
742 self.path = Some(String::from_utf8_lossy(value).into_owned());
743 }
744 Ok(())
745 }
746}