1use crate::PreparedStatement;
2use crate::buffer::BufferSet;
3use crate::buffer_pool::PooledBufferSet;
4use crate::constant::CapabilityFlags;
5use crate::error::{Error, Result};
6use crate::protocol::TextRowPayload;
7use crate::protocol::command::Action;
8use crate::protocol::command::ColumnDefinition;
9use crate::protocol::command::bulk_exec::{BulkExec, BulkFlags, BulkParamsSet, write_bulk_execute};
10use crate::protocol::command::prepared::Exec;
11use crate::protocol::command::prepared::write_execute;
12use crate::protocol::command::prepared::{read_prepare_ok, write_prepare};
13use crate::protocol::command::query::Query;
14use crate::protocol::command::query::write_query;
15use crate::protocol::command::utility::DropHandler;
16use crate::protocol::command::utility::FirstRowHandler;
17use crate::protocol::command::utility::write_ping;
18use crate::protocol::command::utility::write_reset_connection;
19use crate::protocol::connection::{Handshake, HandshakeAction, InitialHandshake};
20use crate::protocol::packet::PacketHeader;
21use crate::protocol::primitive::read_string_lenenc;
22use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
23use crate::protocol::r#trait::{BinaryResultSetHandler, TextResultSetHandler, param::Params};
24use core::hint::unlikely;
25use core::io::BorrowedBuf;
26use std::net::TcpStream;
27use std::os::unix::net::UnixStream;
28use zerocopy::FromZeros;
29use zerocopy::{FromBytes, IntoBytes};
30
31use super::stream::Stream;
32
33pub struct Conn {
34 stream: Stream,
35 buffer_set: PooledBufferSet,
36 initial_handshake: InitialHandshake,
37 capability_flags: CapabilityFlags,
38 mariadb_capabilities: crate::constant::MariadbCapabilityFlags,
39 in_transaction: bool,
40 is_broken: bool,
41}
42
43impl Conn {
44 pub(crate) fn set_in_transaction(&mut self, value: bool) {
45 self.in_transaction = value;
46 }
47
48 pub fn new<O: TryInto<crate::opts::Opts>>(opts: O) -> Result<Self>
50 where
51 Error: From<O::Error>,
52 {
53 let opts: crate::opts::Opts = opts.try_into()?;
54
55 let stream = if let Some(socket_path) = &opts.socket {
56 let stream = UnixStream::connect(socket_path)?;
57 Stream::unix(stream)
58 } else {
59 if opts.host.is_empty() {
60 return Err(Error::BadUsageError(
61 "Missing host in connection options".to_string(),
62 ));
63 }
64 let addr = format!("{}:{}", opts.host, opts.port);
65 let stream = TcpStream::connect(&addr)?;
66 stream.set_nodelay(opts.tcp_nodelay)?;
67 Stream::tcp(stream)
68 };
69
70 Self::new_with_stream(stream, &opts)
71 }
72
73 pub fn new_with_stream(stream: Stream, opts: &crate::opts::Opts) -> Result<Self> {
75 let mut conn_stream = stream;
76 let mut buffer_set = opts.buffer_pool.get_buffer_set();
77
78 #[cfg(feature = "sync-tls")]
79 let host = opts.host.clone();
80
81 let mut handshake = Handshake::new(opts);
82
83 loop {
84 match handshake.step(&mut buffer_set)? {
85 HandshakeAction::ReadPacket(buffer) => {
86 buffer.clear();
87 read_payload(&mut conn_stream, buffer)?;
88 }
89 HandshakeAction::WritePacket { sequence_id } => {
90 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id)?;
91 buffer_set.read_buffer.clear();
92 read_payload(&mut conn_stream, &mut buffer_set.read_buffer)?;
93 }
94 #[cfg(feature = "sync-tls")]
95 HandshakeAction::UpgradeTls { sequence_id } => {
96 write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id)?;
97 conn_stream = conn_stream.upgrade_to_tls(&host)?;
98 }
99 #[cfg(not(feature = "sync-tls"))]
100 HandshakeAction::UpgradeTls { .. } => {
101 return Err(Error::BadUsageError(
102 "TLS requested but sync-tls feature is not enabled".to_string(),
103 ));
104 }
105 HandshakeAction::Finished => break,
106 }
107 }
108
109 let (initial_handshake, capability_flags, mariadb_capabilities) = handshake.finish()?;
110
111 let conn = Self {
112 stream: conn_stream,
113 buffer_set,
114 initial_handshake,
115 capability_flags,
116 mariadb_capabilities,
117 in_transaction: false,
118 is_broken: false,
119 };
120
121 let mut conn = if opts.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
123 conn.try_upgrade_to_unix_socket(opts)
124 } else {
125 conn
126 };
127
128 if let Some(init_command) = &opts.init_command {
130 conn.query_drop(init_command)?;
131 }
132
133 Ok(conn)
134 }
135
136 pub fn server_version(&self) -> &[u8] {
138 &self.buffer_set.initial_handshake[self.initial_handshake.server_version.clone()]
139 }
140
141 pub fn capability_flags(&self) -> CapabilityFlags {
143 self.capability_flags
144 }
145
146 pub fn is_mysql(&self) -> bool {
148 self.capability_flags.is_mysql()
149 }
150
151 pub fn is_mariadb(&self) -> bool {
153 self.capability_flags.is_mariadb()
154 }
155
156 pub fn connection_id(&self) -> u64 {
158 self.initial_handshake.connection_id as u64
159 }
160
161 pub fn status_flags(&self) -> crate::constant::ServerStatusFlags {
163 self.initial_handshake.status_flags
164 }
165
166 pub fn is_broken(&self) -> bool {
170 self.is_broken
171 }
172
173 #[inline]
174 fn check_error<T>(&mut self, result: Result<T>) -> Result<T> {
175 if let Err(e) = &result
176 && e.is_conn_broken()
177 {
178 self.is_broken = true;
179 }
180 result
181 }
182
183 fn try_upgrade_to_unix_socket(mut self, opts: &crate::opts::Opts) -> Self {
186 let mut handler = SocketPathHandler { path: None };
188 if self.query("SELECT @@socket", &mut handler).is_err() {
189 return self;
190 }
191
192 let socket_path = match handler.path {
193 Some(p) if !p.is_empty() => p,
194 _ => return self,
195 };
196
197 let unix_stream = match UnixStream::connect(&socket_path) {
199 Ok(s) => s,
200 Err(_) => return self,
201 };
202 let stream = Stream::unix(unix_stream);
203
204 let mut opts_unix = opts.clone();
207 opts_unix.upgrade_to_unix_socket = false;
208
209 match Self::new_with_stream(stream, &opts_unix) {
210 Ok(new_conn) => new_conn,
211 Err(_) => self,
212 }
213 }
214
215 fn write_payload(&mut self) -> Result<()> {
216 let mut sequence_id = 0_u8;
217 let mut buffer = self.buffer_set.write_buffer_mut().as_mut_slice();
218
219 loop {
220 let chunk_size = buffer[4..].len().min(0xFFFFFF);
221 PacketHeader::mut_from_bytes(&mut buffer[0..4])?
222 .encode_in_place(chunk_size, sequence_id);
223 self.stream.write_all(&buffer[..4 + chunk_size])?;
224
225 if chunk_size < 0xFFFFFF {
226 break;
227 }
228
229 sequence_id = sequence_id.wrapping_add(1);
230 buffer = &mut buffer[0xFFFFFF..];
231 }
232 self.stream.flush()?;
233 Ok(())
234 }
235
236 pub fn prepare(&mut self, sql: &str) -> Result<PreparedStatement> {
238 let result = self.prepare_inner(sql);
239 self.check_error(result)
240 }
241
242 fn prepare_inner(&mut self, sql: &str) -> Result<PreparedStatement> {
243 use crate::protocol::command::ColumnDefinitions;
244
245 self.buffer_set.read_buffer.clear();
246
247 write_prepare(self.buffer_set.new_write_buffer(), sql);
248
249 self.write_payload()?;
250 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
251
252 if unlikely(
253 !self.buffer_set.read_buffer.is_empty() && self.buffer_set.read_buffer[0] == 0xFF,
254 ) {
255 Err(ErrPayloadBytes(&self.buffer_set.read_buffer))?
256 }
257
258 let prepare_ok = read_prepare_ok(&self.buffer_set.read_buffer)?;
259 let statement_id = prepare_ok.statement_id();
260 let num_params = prepare_ok.num_params();
261 let num_columns = prepare_ok.num_columns();
262
263 if num_params > 0 {
265 for _ in 0..num_params {
266 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
267 }
268 }
269
270 let column_definitions = if num_columns > 0 {
272 read_column_definition_packets(
273 &mut self.stream,
274 &mut self.buffer_set.column_definition_buffer,
275 num_columns as usize,
276 )?;
277 Some(ColumnDefinitions::new(
278 num_columns as usize,
279 std::mem::take(&mut self.buffer_set.column_definition_buffer),
280 )?)
281 } else {
282 None
283 };
284
285 let mut stmt = PreparedStatement::new(statement_id);
286 if let Some(col_defs) = column_definitions {
287 stmt.set_column_definitions(col_defs);
288 }
289 Ok(stmt)
290 }
291
292 fn drive_exec<H: BinaryResultSetHandler>(
293 &mut self,
294 stmt: &mut PreparedStatement,
295 handler: &mut H,
296 ) -> Result<()> {
297 let cache_metadata = self
298 .mariadb_capabilities
299 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
300 let mut exec = Exec::new(handler, stmt, cache_metadata);
301
302 loop {
303 match exec.step(&mut self.buffer_set)? {
304 Action::NeedPacket(buffer) => {
305 buffer.clear();
306 let _ = read_payload(&mut self.stream, buffer)?;
307 }
308 Action::ReadColumnMetadata { num_columns } => {
309 read_column_definition_packets(
310 &mut self.stream,
311 &mut self.buffer_set.column_definition_buffer,
312 num_columns,
313 )?;
314 }
315 Action::Finished => return Ok(()),
316 }
317 }
318 }
319
320 pub fn exec<'conn, P, H>(
324 &'conn mut self,
325 stmt: &'conn mut PreparedStatement,
326 params: P,
327 handler: &mut H,
328 ) -> Result<()>
329 where
330 P: Params,
331 H: BinaryResultSetHandler,
332 {
333 let result = self.exec_inner(stmt, params, handler);
334 self.check_error(result)
335 }
336
337 fn exec_inner<'conn, P, H>(
338 &'conn mut self,
339 stmt: &'conn mut PreparedStatement,
340 params: P,
341 handler: &mut H,
342 ) -> Result<()>
343 where
344 P: Params,
345 H: BinaryResultSetHandler,
346 {
347 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
348 self.write_payload()?;
349 self.drive_exec(stmt, handler)
350 }
351
352 fn drive_bulk_exec<H: BinaryResultSetHandler>(
353 &mut self,
354 stmt: &mut PreparedStatement,
355 handler: &mut H,
356 ) -> Result<()> {
357 let cache_metadata = self
358 .mariadb_capabilities
359 .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
360 let mut bulk_exec = BulkExec::new(handler, stmt, cache_metadata);
361
362 loop {
363 match bulk_exec.step(&mut self.buffer_set)? {
364 Action::NeedPacket(buffer) => {
365 buffer.clear();
366 let _ = read_payload(&mut self.stream, buffer)?;
367 }
368 Action::ReadColumnMetadata { num_columns } => {
369 read_column_definition_packets(
370 &mut self.stream,
371 &mut self.buffer_set.column_definition_buffer,
372 num_columns,
373 )?;
374 }
375 Action::Finished => return Ok(()),
376 }
377 }
378 }
379
380 pub fn exec_bulk<P, I, H>(
382 &mut self,
383 stmt: &mut PreparedStatement,
384 params: P,
385 flags: BulkFlags,
386 handler: &mut H,
387 ) -> Result<()>
388 where
389 P: BulkParamsSet + IntoIterator<Item = I>,
390 I: Params,
391 H: BinaryResultSetHandler,
392 {
393 let result = self.exec_bulk_inner(stmt, params, flags, handler);
394 self.check_error(result)
395 }
396
397 fn exec_bulk_inner<P, I, H>(
398 &mut self,
399 stmt: &mut PreparedStatement,
400 params: P,
401 flags: BulkFlags,
402 handler: &mut H,
403 ) -> Result<()>
404 where
405 P: BulkParamsSet + IntoIterator<Item = I>,
406 I: Params,
407 H: BinaryResultSetHandler,
408 {
409 if !self.is_mariadb() {
410 for param in params {
412 self.exec_inner(stmt, param, &mut DropHandler::default())?;
413 }
414 Ok(())
415 } else {
416 write_bulk_execute(self.buffer_set.new_write_buffer(), stmt.id(), params, flags)?;
418 self.write_payload()?;
419 self.drive_bulk_exec(stmt, handler)
420 }
421 }
422
423 pub fn exec_first<'conn, P, H>(
430 &'conn mut self,
431 stmt: &'conn mut PreparedStatement,
432 params: P,
433 handler: &mut H,
434 ) -> Result<bool>
435 where
436 P: Params,
437 H: BinaryResultSetHandler,
438 {
439 let result = self.exec_first_inner(stmt, params, handler);
440 self.check_error(result)
441 }
442
443 fn exec_first_inner<'conn, P, H>(
444 &'conn mut self,
445 stmt: &'conn mut PreparedStatement,
446 params: P,
447 handler: &mut H,
448 ) -> Result<bool>
449 where
450 P: Params,
451 H: BinaryResultSetHandler,
452 {
453 write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
454 self.write_payload()?;
455 let mut first_row_handler = FirstRowHandler::new(handler);
456 self.drive_exec(stmt, &mut first_row_handler)?;
457 Ok(first_row_handler.found_row)
458 }
459
460 pub fn exec_drop<P>(&mut self, stmt: &mut PreparedStatement, params: P) -> Result<()>
462 where
463 P: Params,
464 {
465 self.exec(stmt, params, &mut DropHandler::default())
466 }
467
468 pub fn exec_rows<Row, P>(&mut self, stmt: &mut PreparedStatement, params: P) -> Result<Vec<Row>>
470 where
471 Row: for<'buf> crate::raw::FromRawRow<'buf>,
472 P: Params,
473 {
474 let mut handler = crate::handler::CollectHandler::<Row>::default();
475 self.exec(stmt, params, &mut handler)?;
476 Ok(handler.into_rows())
477 }
478
479 fn drive_query<H: TextResultSetHandler>(&mut self, handler: &mut H) -> Result<()> {
480 let mut query = Query::new(handler);
481
482 loop {
483 match query.step(&mut self.buffer_set)? {
484 Action::NeedPacket(buffer) => {
485 buffer.clear();
486 let _ = read_payload(&mut self.stream, buffer)?;
487 }
488 Action::ReadColumnMetadata { num_columns } => {
489 read_column_definition_packets(
490 &mut self.stream,
491 &mut self.buffer_set.column_definition_buffer,
492 num_columns,
493 )?;
494 }
495 Action::Finished => return Ok(()),
496 }
497 }
498 }
499
500 pub fn query<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
502 where
503 H: TextResultSetHandler,
504 {
505 let result = self.query_inner(sql, handler);
506 self.check_error(result)
507 }
508
509 fn query_inner<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
510 where
511 H: TextResultSetHandler,
512 {
513 write_query(self.buffer_set.new_write_buffer(), sql);
514 self.write_payload()?;
515 self.drive_query(handler)
516 }
517
518 pub fn query_drop(&mut self, sql: &str) -> Result<()> {
520 let result = self.query_drop_inner(sql);
521 self.check_error(result)
522 }
523
524 fn query_drop_inner(&mut self, sql: &str) -> Result<()> {
525 write_query(self.buffer_set.new_write_buffer(), sql);
526 self.write_payload()?;
527 self.drive_query(&mut DropHandler::default())
528 }
529
530 pub fn ping(&mut self) -> Result<()> {
534 let result = self.ping_inner();
535 self.check_error(result)
536 }
537
538 fn ping_inner(&mut self) -> Result<()> {
539 write_ping(self.buffer_set.new_write_buffer());
540 self.write_payload()?;
541 self.buffer_set.read_buffer.clear();
542 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
543 Ok(())
544 }
545
546 pub fn reset(&mut self) -> Result<()> {
548 let result = self.reset_inner();
549 self.check_error(result)
550 }
551
552 fn reset_inner(&mut self) -> Result<()> {
553 write_reset_connection(self.buffer_set.new_write_buffer());
554 self.write_payload()?;
555 self.buffer_set.read_buffer.clear();
556 let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
557 self.in_transaction = false;
558 Ok(())
559 }
560
561 pub fn run_transaction<F, R>(&mut self, f: F) -> Result<R>
566 where
567 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
568 {
569 if self.in_transaction {
570 return Err(Error::NestedTransaction);
571 }
572
573 self.in_transaction = true;
574
575 if let Err(e) = self.query_drop("BEGIN") {
576 self.in_transaction = false;
577 return Err(e);
578 }
579
580 let tx = super::transaction::Transaction::new(self.connection_id());
581 let result = f(self, tx);
582
583 if self.in_transaction {
585 let rollback_result = self.query_drop("ROLLBACK");
586 self.in_transaction = false;
587
588 if let Err(e) = result {
590 return Err(e);
591 }
592 rollback_result?;
593 }
594
595 result
596 }
597}
598
599fn read_payload(reader: &mut Stream, buffer: &mut Vec<u8>) -> Result<u8> {
602 buffer.clear();
603
604 let mut header = PacketHeader::new_zeroed();
605 reader.read_exact(header.as_mut_bytes())?;
606
607 let length = header.length();
608 let mut sequence_id = header.sequence_id;
609
610 buffer.reserve(length);
611
612 {
613 let spare = buffer.spare_capacity_mut();
614 let mut buf: BorrowedBuf<'_> = (&mut spare[..length]).into();
615 reader.read_buf_exact(buf.unfilled())?;
616 unsafe {
618 buffer.set_len(length);
619 }
620 }
621
622 let mut current_length = length;
623 while current_length == 0xFFFFFF {
624 reader.read_exact(header.as_mut_bytes())?;
625
626 current_length = header.length();
627 sequence_id = header.sequence_id;
628
629 buffer.reserve(current_length);
630 let spare = buffer.spare_capacity_mut();
631 let mut buf: BorrowedBuf<'_> = (&mut spare[..current_length]).into();
632 reader.read_buf_exact(buf.unfilled())?;
633 unsafe {
635 buffer.set_len(buffer.len() + current_length);
636 }
637 }
638
639 Ok(sequence_id)
640}
641
642fn read_column_definition_packets(
643 reader: &mut Stream,
644 out: &mut Vec<u8>,
645 num_columns: usize,
646) -> Result<u8> {
647 out.clear();
648 let mut header = PacketHeader::new_zeroed();
649
650 for _ in 0..num_columns {
652 reader.read_exact(header.as_mut_bytes())?;
653 let length = header.length();
654 out.extend((length as u32).to_ne_bytes());
655
656 out.reserve(length);
657 let spare = out.spare_capacity_mut();
658 let mut buf: BorrowedBuf<'_> = (&mut spare[..length]).into();
659 reader.read_buf_exact(buf.unfilled())?;
660 unsafe {
662 out.set_len(out.len() + length);
663 }
664 }
665
666 Ok(header.sequence_id)
667}
668
669fn write_handshake_payload(
670 stream: &mut Stream,
671 buffer_set: &mut BufferSet,
672 sequence_id: u8,
673) -> Result<()> {
674 let mut buffer = buffer_set.write_buffer_mut().as_mut_slice();
675 let mut seq_id = sequence_id;
676
677 loop {
678 let chunk_size = buffer[4..].len().min(0xFFFFFF);
679 PacketHeader::mut_from_bytes(&mut buffer[0..4])?.encode_in_place(chunk_size, seq_id);
680 stream.write_all(&buffer[..4 + chunk_size])?;
681
682 if chunk_size < 0xFFFFFF {
683 break;
684 }
685
686 seq_id = seq_id.wrapping_add(1);
687 buffer = &mut buffer[0xFFFFFF..];
688 }
689 stream.flush()?;
690 Ok(())
691}
692
693struct SocketPathHandler {
695 path: Option<String>,
696}
697
698impl TextResultSetHandler for SocketPathHandler {
699 fn no_result_set(&mut self, _: OkPayloadBytes) -> Result<()> {
700 Ok(())
701 }
702 fn resultset_start(&mut self, _: &[ColumnDefinition<'_>]) -> Result<()> {
703 Ok(())
704 }
705 fn resultset_end(&mut self, _: OkPayloadBytes) -> Result<()> {
706 Ok(())
707 }
708 fn row(&mut self, _: &[ColumnDefinition<'_>], row: TextRowPayload<'_>) -> Result<()> {
709 if row.0.first() == Some(&0xFB) {
711 return Ok(());
712 }
713 let (value, _) = read_string_lenenc(row.0)?;
715 if !value.is_empty() {
716 self.path = Some(String::from_utf8_lossy(value).into_owned());
717 }
718 Ok(())
719 }
720}