1use crate::PreparedStatement;
2use crate::buffer::BufferSet;
3use crate::buffer_pool::PooledBufferSet;
4use crate::constant::CapabilityFlags;
5use crate::error::{Error, Result};
6use crate::nightly::unlikely;
7use crate::protocol::TextRowPayload;
8use crate::protocol::command::Action;
9use crate::protocol::command::ColumnDefinition;
10use crate::protocol::command::bulk_exec::{BulkExec, BulkFlags, BulkParamsSet, write_bulk_execute};
11use crate::protocol::command::prepared::Exec;
12use crate::protocol::command::prepared::write_execute;
13use crate::protocol::command::prepared::{read_prepare_ok, write_prepare};
14use crate::protocol::command::query::Query;
15use crate::protocol::command::query::write_query;
16use crate::protocol::command::utility::DropHandler;
17use crate::protocol::command::utility::FirstHandler;
18use crate::protocol::command::utility::write_ping;
19use crate::protocol::command::utility::write_reset_connection;
20use crate::protocol::connection::{Handshake, HandshakeAction, InitialHandshake};
21use crate::protocol::packet::PacketHeader;
22use crate::protocol::primitive::read_string_lenenc;
23use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
24use crate::protocol::r#trait::{BinaryResultSetHandler, TextResultSetHandler, param::Params};
25use std::net::TcpStream;
26#[cfg(unix)]
27use std::os::unix::net::UnixStream;
28use zerocopy::FromZeros;
29use zerocopy::{FromBytes, IntoBytes};
30
31use super::stream::Stream;
32
33pub struct Conn {
34 stream: Stream,
35 buffer_set: PooledBufferSet,
36 initial_handshake: InitialHandshake,
37 capability_flags: CapabilityFlags,
38 mariadb_capabilities: crate::constant::MariadbCapabilityFlags,
39 in_transaction: bool,
40 is_broken: bool,
41}
42
43impl Conn {
44 pub(crate) fn set_in_transaction(&mut self, value: bool) {
45 self.in_transaction = value;
46 }
47
48 pub fn in_transaction(&self) -> bool {
50 self.in_transaction
51 }
52
53 pub fn new<O: TryInto<crate::opts::Opts>>(opts: O) -> Result<Self>
55 where
56 Error: From<O::Error>,
57 {
58 let opts: crate::opts::Opts = opts.try_into()?;
59
60 #[cfg(unix)]
61 let stream = if let Some(socket_path) = &opts.socket {
62 let stream = UnixStream::connect(socket_path)?;
63 Stream::unix(stream)
64 } else {
65 if opts.host.is_empty() {
66 return Err(Error::BadUsageError(
67 "Missing host in connection options".to_string(),
68 ));
69 }
70 let addr = format!("{}:{}", opts.host, opts.port);
71 let stream = TcpStream::connect(&addr)?;
72 stream.set_nodelay(opts.tcp_nodelay)?;
73 Stream::tcp(stream)
74 };
75
76 #[cfg(not(unix))]
77 let stream = {
78 if opts.socket.is_some() {
79 return Err(Error::BadUsageError(
80 "Unix sockets are not supported on this platform".to_string(),
81 ));
82 }
83 if opts.host.is_empty() {
84 return Err(Error::BadUsageError(
85 "Missing host in connection options".to_string(),
86 ));
87 }
88 let addr = format!("{}:{}", opts.host, opts.port);
89 let stream = TcpStream::connect(&addr)?;
90 stream.set_nodelay(opts.tcp_nodelay)?;
91 Stream::tcp(stream)
92 };
93
94 Self::new_with_stream(stream, &opts)
95 }
96
97 pub fn new_with_stream(stream: Stream, opts: &crate::opts::Opts) -> Result<Self> {
99 let mut conn_stream = stream;
100 let mut buffer_set = opts.buffer_pool.get_buffer_set();
101
102 #[cfg(feature = "sync-tls")]
103 let host = opts.host.clone();
104
105 let mut handshake = Handshake::new(opts);
106
107 loop {
108 match handshake.step(&mut buffer_set)? {
109 HandshakeAction::ReadPacket(buffer) => {
110 buffer.clear();
111 read_payload(&mut conn_stream, buffer)?;
112 }
113 HandshakeAction::WritePacket { sequence_id } => {
114 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id)?;
115 buffer_set.read_buffer.clear();
116 read_payload(&mut conn_stream, &mut buffer_set.read_buffer)?;
117 }
118 #[cfg(feature = "sync-tls")]
119 HandshakeAction::UpgradeTls { sequence_id } => {
120 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id)?;
121 conn_stream = conn_stream.upgrade_to_tls(&host)?;
122 }
123 #[cfg(not(feature = "sync-tls"))]
124 HandshakeAction::UpgradeTls { .. } => {
125 return Err(Error::BadUsageError(
126 "TLS requested but sync-tls feature is not enabled".to_string(),
127 ));
128 }
129 HandshakeAction::Finished => break,
130 }
131 }
132
133 let (initial_handshake, capability_flags, mariadb_capabilities) = handshake.finish()?;
134
135 let conn = Self {
136 stream: conn_stream,
137 buffer_set,
138 initial_handshake,
139 capability_flags,
140 mariadb_capabilities,
141 in_transaction: false,
142 is_broken: false,
143 };
144
145 #[cfg(unix)]
147 let mut conn = if opts.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
148 conn.try_upgrade_to_unix_socket(opts)
149 } else {
150 conn
151 };
152 #[cfg(not(unix))]
153 let mut conn = conn;
154
155 if let Some(init_command) = &opts.init_command {
157 conn.query_drop(init_command)?;
158 }
159
160 Ok(conn)
161 }
162
163 pub fn server_version(&self) -> &[u8] {
165 &self.buffer_set.initial_handshake[self.initial_handshake.server_version.clone()]
166 }
167
168 pub fn capability_flags(&self) -> CapabilityFlags {
170 self.capability_flags
171 }
172
173 pub fn is_mysql(&self) -> bool {
175 self.capability_flags.is_mysql()
176 }
177
178 pub fn is_mariadb(&self) -> bool {
180 self.capability_flags.is_mariadb()
181 }
182
183 pub fn connection_id(&self) -> u64 {
185 self.initial_handshake.connection_id as u64
186 }
187
188 pub fn status_flags(&self) -> crate::constant::ServerStatusFlags {
190 self.initial_handshake.status_flags
191 }
192
193 pub fn is_broken(&self) -> bool {
197 self.is_broken
198 }
199
200 #[inline]
201 fn check_error<T>(&mut self, result: Result<T>) -> Result<T> {
202 if let Err(e) = &result
203 && e.is_conn_broken()
204 {
205 self.is_broken = true;
206 }
207 result
208 }
209
210 #[cfg(unix)]
213 fn try_upgrade_to_unix_socket(mut self, opts: &crate::opts::Opts) -> Self {
214 let mut handler = SocketPathHandler { path: None };
216 if self.query("SELECT @@socket", &mut handler).is_err() {
217 return self;
218 }
219
220 let socket_path = match handler.path {
221 Some(p) if !p.is_empty() => p,
222 _ => return self,
223 };
224
225 let unix_stream = match UnixStream::connect(&socket_path) {
227 Ok(s) => s,
228 Err(_) => return self,
229 };
230 let stream = Stream::unix(unix_stream);
231
232 let mut opts_unix = opts.clone();
235 opts_unix.upgrade_to_unix_socket = false;
236
237 match Self::new_with_stream(stream, &opts_unix) {
238 Ok(new_conn) => new_conn,
239 Err(_) => self,
240 }
241 }
242
243 fn write_payload(&mut self) -> Result<()> {
244 let mut sequence_id = 0_u8;
245 let mut buffer = self.buffer_set.write_buffer_mut().as_mut_slice();
246
247 loop {
248 let chunk_size = buffer[4..].len().min(0xFFFFFF);
249 PacketHeader::mut_from_bytes(&mut buffer[0..4])?
250 .encode_in_place(chunk_size, sequence_id);
251 self.stream.write_all(&buffer[..4 + chunk_size])?;
252
253 if chunk_size < 0xFFFFFF {
254 break;
255 }
256
257 sequence_id = sequence_id.wrapping_add(1);
258 buffer = &mut buffer[0xFFFFFF..];
259 }
260 self.stream.flush()?;
261 Ok(())
262 }
263
264 pub fn prepare(&mut self, sql: &str) -> Result<PreparedStatement> {
266 let result = self.prepare_inner(sql);
267 self.check_error(result)
268 }
269
270 fn prepare_inner(&mut self, sql: &str) -> Result<PreparedStatement> {
271 use crate::protocol::command::ColumnDefinitions;
272
273 self.buffer_set.read_buffer.clear();
274
275 write_prepare(self.buffer_set.new_write_buffer(), sql);
276
277 self.write_payload()?;
278 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
279
280 if unlikely(
281 !self.buffer_set.read_buffer.is_empty() && self.buffer_set.read_buffer[0] == 0xFF,
282 ) {
283 Err(ErrPayloadBytes(&self.buffer_set.read_buffer))?
284 }
285
286 let prepare_ok = read_prepare_ok(&self.buffer_set.read_buffer)?;
287 let statement_id = prepare_ok.statement_id();
288 let num_params = prepare_ok.num_params();
289 let num_columns = prepare_ok.num_columns();
290
291 if num_params > 0 {
293 for _ in 0..num_params {
294 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
295 }
296 }
297
298 let column_definitions = if num_columns > 0 {
300 read_column_definition_packets(
301 &mut self.stream,
302 &mut self.buffer_set.column_definition_buffer,
303 num_columns as usize,
304 )?;
305 Some(ColumnDefinitions::new(
306 num_columns as usize,
307 std::mem::take(&mut self.buffer_set.column_definition_buffer),
308 )?)
309 } else {
310 None
311 };
312
313 let mut stmt = PreparedStatement::new(statement_id);
314 if let Some(col_defs) = column_definitions {
315 stmt.set_column_definitions(col_defs);
316 }
317 Ok(stmt)
318 }
319
320 fn drive_exec<H: BinaryResultSetHandler>(
321 &mut self,
322 stmt: &mut PreparedStatement,
323 handler: &mut H,
324 ) -> Result<()> {
325 let cache_metadata = self
326 .mariadb_capabilities
327 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
328 let mut exec = Exec::new(handler, stmt, cache_metadata);
329
330 loop {
331 match exec.step(&mut self.buffer_set)? {
332 Action::NeedPacket(buffer) => {
333 buffer.clear();
334 let _ = read_payload(&mut self.stream, buffer)?;
335 }
336 Action::ReadColumnMetadata { num_columns } => {
337 read_column_definition_packets(
338 &mut self.stream,
339 &mut self.buffer_set.column_definition_buffer,
340 num_columns,
341 )?;
342 }
343 Action::Finished => return Ok(()),
344 }
345 }
346 }
347
348 pub fn exec<'conn, P, H>(
352 &'conn mut self,
353 stmt: &'conn mut PreparedStatement,
354 params: P,
355 handler: &mut H,
356 ) -> Result<()>
357 where
358 P: Params,
359 H: BinaryResultSetHandler,
360 {
361 let result = self.exec_inner(stmt, params, handler);
362 self.check_error(result)
363 }
364
365 fn exec_inner<'conn, P, H>(
366 &'conn mut self,
367 stmt: &'conn mut PreparedStatement,
368 params: P,
369 handler: &mut H,
370 ) -> Result<()>
371 where
372 P: Params,
373 H: BinaryResultSetHandler,
374 {
375 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
376 self.write_payload()?;
377 self.drive_exec(stmt, handler)
378 }
379
380 fn drive_bulk_exec<H: BinaryResultSetHandler>(
381 &mut self,
382 stmt: &mut PreparedStatement,
383 handler: &mut H,
384 ) -> Result<()> {
385 let cache_metadata = self
386 .mariadb_capabilities
387 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
388 let mut bulk_exec = BulkExec::new(handler, stmt, cache_metadata);
389
390 loop {
391 match bulk_exec.step(&mut self.buffer_set)? {
392 Action::NeedPacket(buffer) => {
393 buffer.clear();
394 let _ = read_payload(&mut self.stream, buffer)?;
395 }
396 Action::ReadColumnMetadata { num_columns } => {
397 read_column_definition_packets(
398 &mut self.stream,
399 &mut self.buffer_set.column_definition_buffer,
400 num_columns,
401 )?;
402 }
403 Action::Finished => return Ok(()),
404 }
405 }
406 }
407
408 pub fn exec_bulk_insert_or_update<P, I, H>(
413 &mut self,
414 stmt: &mut PreparedStatement,
415 params: P,
416 flags: BulkFlags,
417 handler: &mut H,
418 ) -> Result<()>
419 where
420 P: BulkParamsSet + IntoIterator<Item = I>,
421 I: Params,
422 H: BinaryResultSetHandler,
423 {
424 let result = self.exec_bulk_insert_or_update_inner(stmt, params, flags, handler);
425 self.check_error(result)
426 }
427
428 fn exec_bulk_insert_or_update_inner<P, I, H>(
429 &mut self,
430 stmt: &mut PreparedStatement,
431 params: P,
432 flags: BulkFlags,
433 handler: &mut H,
434 ) -> Result<()>
435 where
436 P: BulkParamsSet + IntoIterator<Item = I>,
437 I: Params,
438 H: BinaryResultSetHandler,
439 {
440 if !self.is_mariadb() {
441 for param in params {
443 self.exec_inner(stmt, param, &mut DropHandler::default())?;
444 }
445 Ok(())
446 } else {
447 write_bulk_execute(self.buffer_set.new_write_buffer(), stmt.id(), params, flags)?;
449 self.write_payload()?;
450 self.drive_bulk_exec(stmt, handler)
451 }
452 }
453
454 pub fn exec_first<Row, P>(
456 &mut self,
457 stmt: &mut PreparedStatement,
458 params: P,
459 ) -> Result<Option<Row>>
460 where
461 Row: for<'buf> crate::raw::FromRow<'buf>,
462 P: Params,
463 {
464 let result = self.exec_first_inner(stmt, params);
465 self.check_error(result)
466 }
467
468 fn exec_first_inner<Row, P>(
469 &mut self,
470 stmt: &mut PreparedStatement,
471 params: P,
472 ) -> Result<Option<Row>>
473 where
474 Row: for<'buf> crate::raw::FromRow<'buf>,
475 P: Params,
476 {
477 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
478 self.write_payload()?;
479 let mut handler = FirstHandler::<Row>::default();
480 self.drive_exec(stmt, &mut handler)?;
481 Ok(handler.take())
482 }
483
484 pub fn exec_drop<P>(&mut self, stmt: &mut PreparedStatement, params: P) -> Result<()>
486 where
487 P: Params,
488 {
489 self.exec(stmt, params, &mut DropHandler::default())
490 }
491
492 pub fn exec_collect<Row, P>(
494 &mut self,
495 stmt: &mut PreparedStatement,
496 params: P,
497 ) -> Result<Vec<Row>>
498 where
499 Row: for<'buf> crate::raw::FromRow<'buf>,
500 P: Params,
501 {
502 let mut handler = crate::handler::CollectHandler::<Row>::default();
503 self.exec(stmt, params, &mut handler)?;
504 Ok(handler.into_rows())
505 }
506
507 pub fn exec_foreach<Row, P, F>(
511 &mut self,
512 stmt: &mut PreparedStatement,
513 params: P,
514 f: F,
515 ) -> Result<()>
516 where
517 Row: for<'buf> crate::raw::FromRow<'buf>,
518 P: Params,
519 F: FnMut(Row) -> Result<()>,
520 {
521 let mut handler = crate::handler::ForEachHandler::<Row, F>::new(f);
522 self.exec(stmt, params, &mut handler)
523 }
524
525 pub fn exec_foreach_ref<Row, P, F>(
556 &mut self,
557 stmt: &mut PreparedStatement,
558 params: P,
559 f: F,
560 ) -> Result<()>
561 where
562 Row: for<'buf> crate::ref_row::RefFromRow<'buf>,
563 P: Params,
564 F: for<'buf> FnMut(&'buf Row) -> Result<()>,
565 {
566 let mut handler = crate::handler::ForEachRefHandler::<Row, F>::new(f);
567 self.exec(stmt, params, &mut handler)
568 }
569
570 fn drive_query<H: TextResultSetHandler>(&mut self, handler: &mut H) -> Result<()> {
571 let mut query = Query::new(handler);
572
573 loop {
574 match query.step(&mut self.buffer_set)? {
575 Action::NeedPacket(buffer) => {
576 buffer.clear();
577 let _ = read_payload(&mut self.stream, buffer)?;
578 }
579 Action::ReadColumnMetadata { num_columns } => {
580 read_column_definition_packets(
581 &mut self.stream,
582 &mut self.buffer_set.column_definition_buffer,
583 num_columns,
584 )?;
585 }
586 Action::Finished => return Ok(()),
587 }
588 }
589 }
590
591 pub fn query<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
593 where
594 H: TextResultSetHandler,
595 {
596 let result = self.query_inner(sql, handler);
597 self.check_error(result)
598 }
599
600 fn query_inner<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
601 where
602 H: TextResultSetHandler,
603 {
604 write_query(self.buffer_set.new_write_buffer(), sql);
605 self.write_payload()?;
606 self.drive_query(handler)
607 }
608
609 pub fn query_drop(&mut self, sql: &str) -> Result<()> {
611 let result = self.query_drop_inner(sql);
612 self.check_error(result)
613 }
614
615 fn query_drop_inner(&mut self, sql: &str) -> Result<()> {
616 write_query(self.buffer_set.new_write_buffer(), sql);
617 self.write_payload()?;
618 self.drive_query(&mut DropHandler::default())
619 }
620
621 pub fn ping(&mut self) -> Result<()> {
625 let result = self.ping_inner();
626 self.check_error(result)
627 }
628
629 fn ping_inner(&mut self) -> Result<()> {
630 write_ping(self.buffer_set.new_write_buffer());
631 self.write_payload()?;
632 self.buffer_set.read_buffer.clear();
633 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
634 Ok(())
635 }
636
637 pub fn reset(&mut self) -> Result<()> {
639 let result = self.reset_inner();
640 self.check_error(result)
641 }
642
643 fn reset_inner(&mut self) -> Result<()> {
644 write_reset_connection(self.buffer_set.new_write_buffer());
645 self.write_payload()?;
646 self.buffer_set.read_buffer.clear();
647 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
648 self.in_transaction = false;
649 Ok(())
650 }
651
652 pub fn transaction<F, R>(&mut self, f: F) -> Result<R>
657 where
658 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
659 {
660 if self.in_transaction {
661 return Err(Error::NestedTransaction);
662 }
663
664 self.in_transaction = true;
665
666 if let Err(e) = self.query_drop("BEGIN") {
667 self.in_transaction = false;
668 return Err(e);
669 }
670
671 let tx = super::transaction::Transaction::new(self.connection_id());
672 let result = f(self, tx);
673
674 if self.in_transaction {
676 self.in_transaction = false;
677 match &result {
678 Ok(_) => self.query_drop("COMMIT")?,
679 Err(_) => {
680 let _ = self.query_drop("ROLLBACK");
681 }
682 }
683 }
684
685 result
686 }
687}
688
689fn read_payload(reader: &mut Stream, buffer: &mut Vec<u8>) -> Result<u8> {
692 buffer.clear();
693
694 let mut header = PacketHeader::new_zeroed();
695 reader.read_exact(header.as_mut_bytes())?;
696
697 let length = header.length();
698 let mut sequence_id = header.sequence_id;
699
700 buffer.reserve(length);
701
702 {
703 let spare = buffer.spare_capacity_mut();
704 reader.read_buf_exact(&mut spare[..length])?;
705 unsafe {
707 buffer.set_len(length);
708 }
709 }
710
711 let mut current_length = length;
712 while current_length == 0xFFFFFF {
713 reader.read_exact(header.as_mut_bytes())?;
714
715 current_length = header.length();
716 sequence_id = header.sequence_id;
717
718 buffer.reserve(current_length);
719 let spare = buffer.spare_capacity_mut();
720 reader.read_buf_exact(&mut spare[..current_length])?;
721 unsafe {
723 buffer.set_len(buffer.len() + current_length);
724 }
725 }
726
727 Ok(sequence_id)
728}
729
730fn read_column_definition_packets(
731 reader: &mut Stream,
732 out: &mut Vec<u8>,
733 num_columns: usize,
734) -> Result<u8> {
735 out.clear();
736 let mut header = PacketHeader::new_zeroed();
737
738 for _ in 0..num_columns {
740 reader.read_exact(header.as_mut_bytes())?;
741 let length = header.length();
742 out.extend((length as u32).to_ne_bytes());
743
744 out.reserve(length);
745 let spare = out.spare_capacity_mut();
746 reader.read_buf_exact(&mut spare[..length])?;
747 unsafe {
749 out.set_len(out.len() + length);
750 }
751 }
752
753 Ok(header.sequence_id)
754}
755
756fn write_handshake_payload(
757 stream: &mut Stream,
758 buffer_set: &mut BufferSet,
759 sequence_id: u8,
760) -> Result<()> {
761 let mut buffer = buffer_set.write_buffer_mut().as_mut_slice();
762 let mut seq_id = sequence_id;
763
764 loop {
765 let chunk_size = buffer[4..].len().min(0xFFFFFF);
766 PacketHeader::mut_from_bytes(&mut buffer[0..4])?.encode_in_place(chunk_size, seq_id);
767 stream.write_all(&buffer[..4 + chunk_size])?;
768
769 if chunk_size < 0xFFFFFF {
770 break;
771 }
772
773 seq_id = seq_id.wrapping_add(1);
774 buffer = &mut buffer[0xFFFFFF..];
775 }
776 stream.flush()?;
777 Ok(())
778}
779
780#[cfg(unix)]
782struct SocketPathHandler {
783 path: Option<String>,
784}
785
786#[cfg(unix)]
787impl TextResultSetHandler for SocketPathHandler {
788 fn no_result_set(&mut self, _: OkPayloadBytes) -> Result<()> {
789 Ok(())
790 }
791 fn resultset_start(&mut self, _: &[ColumnDefinition<'_>]) -> Result<()> {
792 Ok(())
793 }
794 fn resultset_end(&mut self, _: OkPayloadBytes) -> Result<()> {
795 Ok(())
796 }
797 fn row(&mut self, _: &[ColumnDefinition<'_>], row: TextRowPayload<'_>) -> Result<()> {
798 if row.0.first() == Some(&0xFB) {
800 return Ok(());
801 }
802 let (value, _) = read_string_lenenc(row.0)?;
804 if !value.is_empty() {
805 self.path = Some(String::from_utf8_lossy(value).into_owned());
806 }
807 Ok(())
808 }
809}