zero_mysql/sync/
conn.rs

1use crate::buffer::BufferSet;
2use crate::buffer_pool::PooledBufferSet;
3use crate::constant::CapabilityFlags;
4use crate::error::{Error, Result};
5use crate::protocol::command::bulk_exec::{write_bulk_execute, BulkExec, BulkFlags, BulkParamsSet};
6use crate::protocol::command::prepared::write_execute;
7use crate::protocol::command::prepared::Exec;
8use crate::protocol::command::prepared::{read_prepare_ok, write_prepare};
9use crate::protocol::command::query::write_query;
10use crate::protocol::command::query::Query;
11use crate::protocol::command::utility::write_ping;
12use crate::protocol::command::utility::write_reset_connection;
13use crate::protocol::command::utility::DropHandler;
14use crate::protocol::command::utility::FirstRowHandler;
15use crate::protocol::command::Action;
16use crate::protocol::command::ColumnDefinition;
17use crate::protocol::connection::{Handshake, HandshakeAction, InitialHandshake};
18use crate::protocol::packet::PacketHeader;
19use crate::protocol::primitive::read_string_lenenc;
20use crate::protocol::r#trait::{param::Params, BinaryResultSetHandler, TextResultSetHandler};
21use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
22use crate::protocol::TextRowPayload;
23use crate::PreparedStatement;
24use core::hint::unlikely;
25use core::io::BorrowedBuf;
26use std::net::TcpStream;
27use std::os::unix::net::UnixStream;
28use zerocopy::FromZeros;
29use zerocopy::{FromBytes, IntoBytes};
30
31use super::stream::Stream;
32
33pub struct Conn {
34    stream: Stream,
35    buffer_set: PooledBufferSet,
36    initial_handshake: InitialHandshake,
37    capability_flags: CapabilityFlags,
38    mariadb_capabilities: crate::constant::MariadbCapabilityFlags,
39    in_transaction: bool,
40}
41
42impl Conn {
43    pub(crate) fn set_in_transaction(&mut self, value: bool) {
44        self.in_transaction = value;
45    }
46
47    /// Create a new MySQL connection from connection options
48    pub fn new<O: TryInto<crate::opts::Opts>>(opts: O) -> Result<Self>
49    where
50        Error: From<O::Error>,
51    {
52        let opts: crate::opts::Opts = opts.try_into()?;
53
54        let stream = if let Some(socket_path) = &opts.socket {
55            let stream = UnixStream::connect(socket_path)?;
56            Stream::unix(stream)
57        } else {
58            let host = opts.host.as_ref().ok_or_else(|| {
59                Error::BadConfigError("Missing host in connection options".to_string())
60            })?;
61
62            let addr = format!("{}:{}", host, opts.port);
63            let stream = TcpStream::connect(&addr)?;
64            stream.set_nodelay(opts.tcp_nodelay)?;
65            Stream::tcp(stream)
66        };
67
68        Self::new_with_stream(stream, &opts)
69    }
70
71    /// Create a new MySQL connection with an existing stream
72    pub fn new_with_stream(stream: Stream, opts: &crate::opts::Opts) -> Result<Self> {
73        let mut conn_stream = stream;
74        let mut buffer_set = opts.buffer_pool.get_buffer_set();
75
76        #[cfg(feature = "sync-tls")]
77        let host = opts.host.clone().unwrap_or_default();
78
79        let mut handshake = Handshake::new(opts);
80
81        loop {
82            match handshake.step(&mut buffer_set)? {
83                HandshakeAction::ReadPacket(buffer) => {
84                    buffer.clear();
85                    read_payload(&mut conn_stream, buffer)?;
86                }
87                HandshakeAction::WritePacket { sequence_id } => {
88                    write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id)?;
89                    buffer_set.read_buffer.clear();
90                    read_payload(&mut conn_stream, &mut buffer_set.read_buffer)?;
91                }
92                #[cfg(feature = "sync-tls")]
93                HandshakeAction::UpgradeTls { sequence_id } => {
94                    write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id)?;
95                    conn_stream = conn_stream.upgrade_to_tls(&host)?;
96                }
97                #[cfg(not(feature = "sync-tls"))]
98                HandshakeAction::UpgradeTls { .. } => {
99                    return Err(Error::BadConfigError(
100                        "TLS requested but sync-tls feature is not enabled".to_string(),
101                    ));
102                }
103                HandshakeAction::Finished => break,
104            }
105        }
106
107        let (initial_handshake, capability_flags, mariadb_capabilities) = handshake.finish()?;
108
109        let conn = Self {
110            stream: conn_stream,
111            buffer_set,
112            initial_handshake,
113            capability_flags,
114            mariadb_capabilities,
115            in_transaction: false,
116        };
117
118        // Upgrade to Unix socket if connected via TCP to loopback
119        let mut conn = if opts.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
120            conn.try_upgrade_to_unix_socket(opts)
121        } else {
122            conn
123        };
124
125        // Execute init command if specified
126        if let Some(init_command) = &opts.init_command {
127            conn.query_drop(init_command)?;
128        }
129
130        Ok(conn)
131    }
132
133    pub fn server_version(&self) -> &[u8] {
134        &self.buffer_set.initial_handshake[self.initial_handshake.server_version.clone()]
135    }
136
137    /// Get the negotiated capability flags
138    pub fn capability_flags(&self) -> CapabilityFlags {
139        self.capability_flags
140    }
141
142    /// Check if the server is MySQL (as opposed to MariaDB)
143    pub fn is_mysql(&self) -> bool {
144        self.capability_flags.is_mysql()
145    }
146
147    /// Check if the server is MariaDB (as opposed to MySQL)
148    pub fn is_mariadb(&self) -> bool {
149        self.capability_flags.is_mariadb()
150    }
151
152    /// Get the connection ID assigned by the server
153    pub fn connection_id(&self) -> u64 {
154        self.initial_handshake.connection_id as u64
155    }
156
157    /// Get the server status flags from the initial handshake
158    pub fn status_flags(&self) -> crate::constant::ServerStatusFlags {
159        self.initial_handshake.status_flags
160    }
161
162    /// Try to upgrade to Unix socket connection.
163    /// Returns upgraded conn on success, original conn on failure.
164    fn try_upgrade_to_unix_socket(mut self, opts: &crate::opts::Opts) -> Self {
165        // Query the server for its Unix socket path
166        let mut handler = SocketPathHandler { path: None };
167        if self.query("SELECT @@socket", &mut handler).is_err() {
168            return self;
169        }
170
171        let socket_path = match handler.path {
172            Some(p) if !p.is_empty() => p,
173            _ => return self,
174        };
175
176        // Connect via Unix socket
177        let unix_stream = match UnixStream::connect(&socket_path) {
178            Ok(s) => s,
179            Err(_) => return self,
180        };
181        let stream = Stream::unix(unix_stream);
182
183        // Create new connection over Unix socket (re-handshakes)
184        // Disable upgrade_to_unix_socket to prevent infinite recursion
185        let mut opts_unix = opts.clone();
186        opts_unix.upgrade_to_unix_socket = false;
187
188        match Self::new_with_stream(stream, &opts_unix) {
189            Ok(new_conn) => new_conn,
190            Err(_) => self,
191        }
192    }
193
194    fn write_payload(&mut self) -> Result<()> {
195        let mut sequence_id = 0_u8;
196        let mut buffer = self.buffer_set.write_buffer_mut().as_mut_slice();
197
198        loop {
199            let chunk_size = buffer[4..].len().min(0xFFFFFF);
200            PacketHeader::mut_from_bytes(&mut buffer[0..4])?
201                .encode_in_place(chunk_size, sequence_id);
202            self.stream.write_all(&buffer[..4 + chunk_size])?;
203
204            if chunk_size < 0xFFFFFF {
205                break;
206            }
207
208            sequence_id = sequence_id.wrapping_add(1);
209            buffer = &mut buffer[0xFFFFFF..];
210        }
211        self.stream.flush()?;
212        Ok(())
213    }
214
215    /// Returns `Ok(statement_id) on success
216    pub fn prepare(&mut self, sql: &str) -> Result<PreparedStatement> {
217        use crate::protocol::command::ColumnDefinitions;
218
219        self.buffer_set.read_buffer.clear();
220
221        write_prepare(self.buffer_set.new_write_buffer(), sql);
222
223        self.write_payload()?;
224        let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
225
226        if unlikely(
227            !self.buffer_set.read_buffer.is_empty() && self.buffer_set.read_buffer[0] == 0xFF,
228        ) {
229            Err(ErrPayloadBytes(&self.buffer_set.read_buffer))?
230        }
231
232        let prepare_ok = read_prepare_ok(&self.buffer_set.read_buffer)?;
233        let statement_id = prepare_ok.statement_id();
234        let num_params = prepare_ok.num_params();
235        let num_columns = prepare_ok.num_columns();
236
237        // Skip param definitions (we don't cache them)
238        if num_params > 0 {
239            for _ in 0..num_params {
240                let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
241            }
242        }
243
244        // Read and cache column definitions for MARIADB_CLIENT_CACHE_METADATA support
245        let column_definitions = if num_columns > 0 {
246            read_column_definition_packets(
247                &mut self.stream,
248                &mut self.buffer_set.column_definition_buffer,
249                num_columns as usize,
250            )?;
251            Some(ColumnDefinitions::new(
252                num_columns as usize,
253                std::mem::take(&mut self.buffer_set.column_definition_buffer),
254            )?)
255        } else {
256            None
257        };
258
259        let mut stmt = PreparedStatement::new(statement_id);
260        if let Some(col_defs) = column_definitions {
261            stmt.set_column_definitions(col_defs);
262        }
263        Ok(stmt)
264    }
265
266    fn drive_exec<H: BinaryResultSetHandler>(
267        &mut self,
268        stmt: &mut PreparedStatement,
269        handler: &mut H,
270    ) -> Result<()> {
271        let cache_metadata = self
272            .mariadb_capabilities
273            .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
274        let mut exec = Exec::new(handler, stmt, cache_metadata);
275
276        loop {
277            match exec.step(&mut self.buffer_set)? {
278                Action::NeedPacket(buffer) => {
279                    buffer.clear();
280                    let _ = read_payload(&mut self.stream, buffer)?;
281                }
282                Action::ReadColumnMetadata { num_columns } => {
283                    read_column_definition_packets(
284                        &mut self.stream,
285                        &mut self.buffer_set.column_definition_buffer,
286                        num_columns,
287                    )?;
288                }
289                Action::Finished => return Ok(()),
290            }
291        }
292    }
293
294    pub fn exec<'conn, P, H>(
295        &'conn mut self,
296        stmt: &'conn mut PreparedStatement,
297        params: P,
298        handler: &mut H,
299    ) -> Result<()>
300    where
301        P: Params,
302        H: BinaryResultSetHandler,
303    {
304        write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
305        self.write_payload()?;
306        self.drive_exec(stmt, handler)
307    }
308
309    fn drive_bulk_exec<H: BinaryResultSetHandler>(
310        &mut self,
311        stmt: &mut PreparedStatement,
312        handler: &mut H,
313    ) -> Result<()> {
314        let cache_metadata = self
315            .mariadb_capabilities
316            .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
317        let mut bulk_exec = BulkExec::new(handler, stmt, cache_metadata);
318
319        loop {
320            match bulk_exec.step(&mut self.buffer_set)? {
321                Action::NeedPacket(buffer) => {
322                    buffer.clear();
323                    let _ = read_payload(&mut self.stream, buffer)?;
324                }
325                Action::ReadColumnMetadata { num_columns } => {
326                    read_column_definition_packets(
327                        &mut self.stream,
328                        &mut self.buffer_set.column_definition_buffer,
329                        num_columns,
330                    )?;
331                }
332                Action::Finished => return Ok(()),
333            }
334        }
335    }
336
337    /// Execute a bulk prepared statement with a result set handler
338    pub fn exec_bulk<P, I, H>(
339        &mut self,
340        stmt: &mut PreparedStatement,
341        params: P,
342        flags: BulkFlags,
343        handler: &mut H,
344    ) -> Result<()>
345    where
346        P: BulkParamsSet + IntoIterator<Item = I>,
347        I: Params,
348        H: BinaryResultSetHandler,
349    {
350        if !self.is_mariadb() {
351            // Fallback to multiple exec_drop 'conn, for non-MariaDB servers'conn
352            for param in params {
353                self.exec_drop(stmt, param)?;
354            }
355            Ok(())
356        } else {
357            // Use MariaDB bulk execute protocol
358            write_bulk_execute(self.buffer_set.new_write_buffer(), stmt.id(), params, flags)?;
359            self.write_payload()?;
360            self.drive_bulk_exec(stmt, handler)
361        }
362    }
363
364    /// Execute a prepared statement and return only the first row, dropping the rest
365    ///
366    /// # Returns
367    /// * `Ok(true)` - First row was found and processed
368    /// * `Ok(false)` - No rows in result set
369    /// * `Err(Error)` - Query execution or handler callback failed
370    pub fn exec_first<'conn, P, H>(
371        &'conn mut self,
372        stmt: &'conn mut PreparedStatement,
373        params: P,
374        handler: &mut H,
375    ) -> Result<bool>
376    where
377        P: Params,
378        H: BinaryResultSetHandler,
379    {
380        write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
381        self.write_payload()?;
382        let mut first_row_handler = FirstRowHandler::new(handler);
383        self.drive_exec(stmt, &mut first_row_handler)?;
384        Ok(first_row_handler.found_row)
385    }
386
387    /// Execute a prepared statement and discard all results
388    pub fn exec_drop<P>(&mut self, stmt: &mut PreparedStatement, params: P) -> Result<()>
389    where
390        P: Params,
391    {
392        write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
393        self.write_payload()?;
394        self.drive_exec(stmt, &mut DropHandler::default())
395    }
396
397    fn drive_query<H: TextResultSetHandler>(&mut self, handler: &mut H) -> Result<()> {
398        let mut query = Query::new(handler);
399
400        loop {
401            match query.step(&mut self.buffer_set)? {
402                Action::NeedPacket(buffer) => {
403                    buffer.clear();
404                    let _ = read_payload(&mut self.stream, buffer)?;
405                }
406                Action::ReadColumnMetadata { num_columns } => {
407                    read_column_definition_packets(
408                        &mut self.stream,
409                        &mut self.buffer_set.column_definition_buffer,
410                        num_columns,
411                    )?;
412                }
413                Action::Finished => return Ok(()),
414            }
415        }
416    }
417
418    /// Execute a text protocol SQL query
419    pub fn query<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
420    where
421        H: TextResultSetHandler,
422    {
423        write_query(self.buffer_set.new_write_buffer(), sql);
424        self.write_payload()?;
425        self.drive_query(handler)
426    }
427
428    /// Execute a text protocol SQL query and discard the result
429    pub fn query_drop(&mut self, sql: &str) -> Result<()> {
430        write_query(self.buffer_set.new_write_buffer(), sql);
431        self.write_payload()?;
432        self.drive_query(&mut DropHandler::default())
433    }
434
435    /// Send a ping to the server to check if the connection is alive
436    ///
437    /// This sends a COM_PING command to the MySQL server and waits for an OK response.
438    pub fn ping(&mut self) -> Result<()> {
439        write_ping(self.buffer_set.new_write_buffer());
440        self.write_payload()?;
441        self.buffer_set.read_buffer.clear();
442        let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
443        Ok(())
444    }
445
446    /// Reset the connection to its initial state
447    pub fn reset(&mut self) -> Result<()> {
448        write_reset_connection(self.buffer_set.new_write_buffer());
449        self.write_payload()?;
450        self.buffer_set.read_buffer.clear();
451        let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer)?;
452        self.in_transaction = false;
453        Ok(())
454    }
455
456    /// Execute a closure within a transaction
457    ///
458    /// # Errors
459    /// Returns `Error::NestedTransaction` if called while already in a transaction
460    pub fn run_transaction<F, R>(&mut self, f: F) -> Result<R>
461    where
462        F: FnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
463    {
464        if self.in_transaction {
465            return Err(Error::NestedTransaction);
466        }
467
468        self.in_transaction = true;
469
470        if let Err(e) = self.query_drop("BEGIN") {
471            self.in_transaction = false;
472            return Err(e);
473        }
474
475        let tx = super::transaction::Transaction::new(self.connection_id());
476        let result = f(self, tx);
477
478        // If the transaction was not explicitly committed or rolled back, roll it back
479        if self.in_transaction {
480            let rollback_result = self.query_drop("ROLLBACK");
481            self.in_transaction = false;
482
483            // Return the first error (either from closure or rollback)
484            if let Err(e) = result {
485                return Err(e);
486            }
487            rollback_result?;
488        }
489
490        result
491    }
492}
493
494/// Read a complete MySQL payload, concatenating payloads if they span multiple 16MB chunks
495/// Returns the sequence_id of the last packet read.
496fn read_payload(reader: &mut Stream, buffer: &mut Vec<u8>) -> Result<u8> {
497    buffer.clear();
498
499    let mut header = PacketHeader::new_zeroed();
500    reader.read_exact(header.as_mut_bytes())?;
501
502    let length = header.length();
503    let mut sequence_id = header.sequence_id;
504
505    buffer.reserve(length);
506
507    {
508        let spare = buffer.spare_capacity_mut();
509        let mut buf: BorrowedBuf<'_> = (&mut spare[..length]).into();
510        reader.read_buf_exact(buf.unfilled())?;
511        // SAFETY: read_buf_exact filled exactly `length` bytes
512        unsafe {
513            buffer.set_len(length);
514        }
515    }
516
517    let mut current_length = length;
518    while current_length == 0xFFFFFF {
519        reader.read_exact(header.as_mut_bytes())?;
520
521        current_length = header.length();
522        sequence_id = header.sequence_id;
523
524        buffer.reserve(current_length);
525        let spare = buffer.spare_capacity_mut();
526        let mut buf: BorrowedBuf<'_> = (&mut spare[..current_length]).into();
527        reader.read_buf_exact(buf.unfilled())?;
528        // SAFETY: read_buf_exact filled exactly `current_length` bytes
529        unsafe {
530            buffer.set_len(buffer.len() + current_length);
531        }
532    }
533
534    Ok(sequence_id)
535}
536
537fn read_column_definition_packets(
538    reader: &mut Stream,
539    out: &mut Vec<u8>,
540    num_columns: usize,
541) -> Result<u8> {
542    out.clear();
543    let mut header = PacketHeader::new_zeroed();
544
545    // For each column, write [4 bytes len][payload]
546    for _ in 0..num_columns {
547        reader.read_exact(header.as_mut_bytes())?;
548        let length = header.length();
549        out.extend((length as u32).to_ne_bytes());
550
551        out.reserve(length);
552        let spare = out.spare_capacity_mut();
553        let mut buf: BorrowedBuf<'_> = (&mut spare[..length]).into();
554        reader.read_buf_exact(buf.unfilled())?;
555        // SAFETY: read_buf_exact filled exactly `length` bytes
556        unsafe {
557            out.set_len(out.len() + length);
558        }
559    }
560
561    Ok(header.sequence_id)
562}
563
564fn write_handshake_payload(
565    stream: &mut Stream,
566    buffer_set: &mut BufferSet,
567    sequence_id: u8,
568) -> Result<()> {
569    let mut buffer = buffer_set.write_buffer_mut().as_mut_slice();
570    let mut seq_id = sequence_id;
571
572    loop {
573        let chunk_size = buffer[4..].len().min(0xFFFFFF);
574        PacketHeader::mut_from_bytes(&mut buffer[0..4])?.encode_in_place(chunk_size, seq_id);
575        stream.write_all(&buffer[..4 + chunk_size])?;
576
577        if chunk_size < 0xFFFFFF {
578            break;
579        }
580
581        seq_id = seq_id.wrapping_add(1);
582        buffer = &mut buffer[0xFFFFFF..];
583    }
584    stream.flush()?;
585    Ok(())
586}
587
588/// Handler to capture socket path from SELECT @@socket query
589struct SocketPathHandler {
590    path: Option<String>,
591}
592
593impl TextResultSetHandler for SocketPathHandler {
594    fn no_result_set(&mut self, _: OkPayloadBytes) -> Result<()> {
595        Ok(())
596    }
597    fn resultset_start(&mut self, _: &[ColumnDefinition<'_>]) -> Result<()> {
598        Ok(())
599    }
600    fn resultset_end(&mut self, _: OkPayloadBytes) -> Result<()> {
601        Ok(())
602    }
603    fn row(&mut self, _: &[ColumnDefinition<'_>], row: TextRowPayload<'_>) -> Result<()> {
604        // 0xFB indicates NULL value
605        if row.0.first() == Some(&0xFB) {
606            return Ok(());
607        }
608        // Parse the first length-encoded string
609        let (value, _) = read_string_lenenc(row.0)?;
610        if !value.is_empty() {
611            self.path = Some(String::from_utf8_lossy(value).into_owned());
612        }
613        Ok(())
614    }
615}