zero_mysql/tokio/conn.rs

1use tokio::net::{TcpStream, UnixStream};
2use tracing::instrument;
3use zerocopy::{FromBytes, FromZeros, IntoBytes};
4
5use crate::PreparedStatement;
6use crate::buffer::BufferSet;
7use crate::buffer_pool::PooledBufferSet;
8use crate::constant::CapabilityFlags;
9use crate::error::{Error, Result};
10use crate::protocol::command::Action;
11use crate::protocol::command::bulk_exec::{BulkExec, BulkFlags, BulkParamsSet, write_bulk_execute};
12use crate::protocol::command::prepared::{Exec, read_prepare_ok, write_execute, write_prepare};
13use crate::protocol::command::query::{Query, write_query};
14use crate::protocol::command::utility::{
15    DropHandler, FirstRowHandler, write_ping, write_reset_connection,
16};
17use crate::protocol::command::ColumnDefinition;
18use crate::protocol::connection::{Handshake, HandshakeAction, InitialHandshake};
19use crate::protocol::packet::PacketHeader;
20use crate::protocol::primitive::read_string_lenenc;
21use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
22use crate::protocol::r#trait::{BinaryResultSetHandler, TextResultSetHandler, param::Params};
23use crate::protocol::TextRowPayload;
24
25use super::stream::Stream;
26
27pub struct Conn {
28    stream: Stream,
29    buffer_set: PooledBufferSet,
30    initial_handshake: InitialHandshake,
31    capability_flags: CapabilityFlags,
32    mariadb_capabilities: crate::constant::MariadbCapabilityFlags,
33    in_transaction: bool,
34}
35
36impl Conn {
37    /// Create a new MySQL connection from connection options (async)
38    pub async fn new<O: TryInto<crate::opts::Opts>>(opts: O) -> Result<Self>
39    where
40        Error: From<O::Error>,
41    {
42        let opts: crate::opts::Opts = opts.try_into()?;
43
44        let stream = if let Some(socket_path) = &opts.socket {
45            let stream = UnixStream::connect(socket_path).await?;
46            Stream::unix(stream)
47        } else {
48            let host = opts.host.as_ref().ok_or_else(|| {
49                Error::BadConfigError("Missing host in connection options".to_string())
50            })?;
51
52            let addr = format!("{}:{}", host, opts.port);
53            let stream = TcpStream::connect(&addr).await?;
54            stream.set_nodelay(opts.tcp_nodelay)?;
55            Stream::tcp(stream)
56        };
57
58        Self::new_with_stream(stream, &opts).await
59    }
60
61    /// Create a new MySQL connection with an existing stream (async)
62    pub async fn new_with_stream(stream: Stream, opts: &crate::opts::Opts) -> Result<Self> {
63        let mut conn_stream = stream;
64        let mut buffer_set = opts.buffer_pool.get_buffer_set();
65
66        #[cfg(feature = "tokio-tls")]
67        let host = opts.host.clone().unwrap_or_default();
68
69        let mut handshake = Handshake::new(opts);
70
71        loop {
72            match handshake.step(&mut buffer_set)? {
73                HandshakeAction::ReadPacket(buffer) => {
74                    buffer.clear();
75                    read_payload(&mut conn_stream, buffer).await?;
76                }
77                HandshakeAction::WritePacket { sequence_id } => {
78                    write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id).await?;
79                    buffer_set.read_buffer.clear();
80                    read_payload(&mut conn_stream, &mut buffer_set.read_buffer).await?;
81                }
82                #[cfg(feature = "tokio-tls")]
83                HandshakeAction::UpgradeTls { sequence_id } => {
84                    write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id).await?;
85                    conn_stream = conn_stream.upgrade_to_tls(&host).await?;
86                }
87                #[cfg(not(feature = "tokio-tls"))]
88                HandshakeAction::UpgradeTls { .. } => {
89                    return Err(Error::BadConfigError(
90                        "TLS requested but tokio-tls feature is not enabled".to_string(),
91                    ));
92                }
93                HandshakeAction::Finished => break,
94            }
95        }
96
97        let (initial_handshake, capability_flags, mariadb_capabilities) = handshake.finish()?;
98
99        let conn = Self {
100            stream: conn_stream,
101            buffer_set,
102            initial_handshake,
103            capability_flags,
104            mariadb_capabilities,
105            in_transaction: false,
106        };
107
108        // Upgrade to Unix socket if connected via TCP to loopback
109        let mut conn = if opts.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
110            conn.try_upgrade_to_unix_socket(opts).await
111        } else {
112            conn
113        };
114
115        // Execute init command if specified
116        if let Some(init_command) = &opts.init_command {
117            conn.query_drop(init_command).await?;
118        }
119
120        Ok(conn)
121    }
122
123    pub fn server_version(&self) -> &[u8] {
124        &self.buffer_set.initial_handshake[self.initial_handshake.server_version.clone()]
125    }
126
127    /// Get the negotiated capability flags
128    pub fn capability_flags(&self) -> CapabilityFlags {
129        self.capability_flags
130    }
131
132    /// Check if the server is MySQL (as opposed to MariaDB)
133    pub fn is_mysql(&self) -> bool {
134        self.capability_flags.is_mysql()
135    }
136
137    /// Check if the server is MariaDB (as opposed to MySQL)
138    pub fn is_mariadb(&self) -> bool {
139        self.capability_flags.is_mariadb()
140    }
141
142    /// Get the connection ID assigned by the server
143    pub fn connection_id(&self) -> u64 {
144        self.initial_handshake.connection_id as u64
145    }
146
147    /// Get the server status flags from the initial handshake
148    pub fn status_flags(&self) -> crate::constant::ServerStatusFlags {
149        self.initial_handshake.status_flags
150    }
151
152    pub(crate) fn set_in_transaction(&mut self, value: bool) {
153        self.in_transaction = value;
154    }
155
156    /// Try to upgrade to Unix socket connection.
157    /// Returns upgraded conn on success, original conn on failure.
158    async fn try_upgrade_to_unix_socket(mut self, opts: &crate::opts::Opts) -> Self {
159        // Query the server for its Unix socket path
160        let mut handler = SocketPathHandler { path: None };
161        if self.query("SELECT @@socket", &mut handler).await.is_err() {
162            return self;
163        }
164
165        let socket_path = match handler.path {
166            Some(p) if !p.is_empty() => p,
167            _ => return self,
168        };
169
170        // Connect via Unix socket
171        let unix_stream = match UnixStream::connect(&socket_path).await {
172            Ok(s) => s,
173            Err(_) => return self,
174        };
175        let stream = Stream::unix(unix_stream);
176
177        // Create new connection over Unix socket (re-handshakes)
178        // Disable upgrade_to_unix_socket to prevent infinite recursion
179        let mut opts_unix = opts.clone();
180        opts_unix.upgrade_to_unix_socket = false;
181
182        match Box::pin(Self::new_with_stream(stream, &opts_unix)).await {
183            Ok(new_conn) => new_conn,
184            Err(_) => self,
185        }
186    }
187
188    /// Write a MySQL packet from write_buffer asynchronously, splitting it into 16MB chunks if necessary
189    #[instrument(skip_all)]
190    async fn write_payload(&mut self) -> Result<()> {
191        let mut sequence_id = 0_u8;
192        let mut buffer = self.buffer_set.write_buffer_mut().as_mut_slice();
193
194        loop {
195            let chunk_size = buffer[4..].len().min(0xFFFFFF);
196            PacketHeader::mut_from_bytes(&mut buffer[0..4])?
197                .encode_in_place(chunk_size, sequence_id);
198            self.stream.write_all(&buffer[..4 + chunk_size]).await?;
199
200            if chunk_size < 0xFFFFFF {
201                break;
202            }
203
204            sequence_id = sequence_id.wrapping_add(1);
205            buffer = &mut buffer[0xFFFFFF..];
206        }
207        self.stream.flush().await?;
208        Ok(())
209    }
210
211    /// Prepare a statement and return the PreparedStatement (async)
212    ///
213    /// Returns `Ok(PreparedStatement)` on success.
214    pub async fn prepare(&mut self, sql: &str) -> Result<PreparedStatement> {
215        use crate::protocol::command::ColumnDefinitions;
216
217        self.buffer_set.read_buffer.clear();
218
219        write_prepare(self.buffer_set.new_write_buffer(), sql);
220
221        self.write_payload().await?;
222
223        let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
224
225        if !self.buffer_set.read_buffer.is_empty() && self.buffer_set.read_buffer[0] == 0xFF {
226            Err(ErrPayloadBytes(&self.buffer_set.read_buffer))?
227        }
228
229        let prepare_ok = read_prepare_ok(&self.buffer_set.read_buffer)?;
230        let statement_id = prepare_ok.statement_id();
231        let num_params = prepare_ok.num_params();
232        let num_columns = prepare_ok.num_columns();
233
234        // Skip param definitions (we don't cache them)
235        for _ in 0..num_params {
236            let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
237        }
238
239        // Read and cache column definitions for MARIADB_CLIENT_CACHE_METADATA support
240        let column_definitions = if num_columns > 0 {
241            self.read_column_definition_packets(num_columns as usize)
242                .await?;
243            Some(ColumnDefinitions::new(
244                num_columns as usize,
245                std::mem::take(&mut self.buffer_set.column_definition_buffer),
246            )?)
247        } else {
248            None
249        };
250
251        let mut stmt = PreparedStatement::new(statement_id);
252        if let Some(col_defs) = column_definitions {
253            stmt.set_column_definitions(col_defs);
254        }
255        Ok(stmt)
256    }
257
258    #[tracing::instrument(skip_all)]
259    async fn read_column_definition_packets(&mut self, num_columns: usize) -> Result<u8> {
260        let mut header = PacketHeader::new_zeroed();
261        let out = &mut self.buffer_set.column_definition_buffer;
262        out.clear();
263
264        // For each column, write [4 bytes len][payload]
265        for _ in 0..num_columns {
266            self.stream.read_exact(header.as_mut_bytes()).await?;
267            let length = header.length();
268            out.extend((length as u32).to_ne_bytes());
269
270            out.reserve(length);
271            let spare = out.spare_capacity_mut();
272            self.stream.read_buf_exact(&mut spare[..length]).await?;
273            // SAFETY: read_buf_exact filled exactly `length` bytes
274            unsafe {
275                out.set_len(out.len() + length);
276            }
277        }
278
279        Ok(header.sequence_id)
280    }
281
282    async fn drive_exec<H: BinaryResultSetHandler>(
283        &mut self,
284        stmt: &mut crate::PreparedStatement,
285        handler: &mut H,
286    ) -> Result<()> {
287        let cache_metadata = self
288            .mariadb_capabilities
289            .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
290        let mut exec = Exec::new(handler, stmt, cache_metadata);
291
292        loop {
293            match exec.step(&mut self.buffer_set)? {
294                Action::NeedPacket(buffer) => {
295                    buffer.clear();
296                    let _ = read_payload(&mut self.stream, buffer).await?;
297                }
298                Action::ReadColumnMetadata { num_columns } => {
299                    self.read_column_definition_packets(num_columns).await?;
300                }
301                Action::Finished => return Ok(()),
302            }
303        }
304    }
305
306    async fn drive_query<H: TextResultSetHandler>(&mut self, handler: &mut H) -> Result<()> {
307        let mut query = Query::new(handler);
308
309        loop {
310            match query.step(&mut self.buffer_set)? {
311                Action::NeedPacket(buffer) => {
312                    buffer.clear();
313                    let _ = read_payload(&mut self.stream, buffer).await?;
314                }
315                Action::ReadColumnMetadata { num_columns } => {
316                    self.read_column_definition_packets(num_columns).await?;
317                }
318                Action::Finished => return Ok(()),
319            }
320        }
321    }
322
323    /// Execute a prepared statement with a result set handler (async)
324    pub async fn exec<P, H>(
325        &mut self,
326        stmt: &mut PreparedStatement,
327        params: P,
328        handler: &mut H,
329    ) -> Result<()>
330    where
331        P: Params,
332        H: BinaryResultSetHandler,
333    {
334        write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
335        self.write_payload().await?;
336        self.drive_exec(stmt, handler).await
337    }
338
339    async fn drive_bulk_exec<H: BinaryResultSetHandler>(
340        &mut self,
341        stmt: &mut crate::PreparedStatement,
342        handler: &mut H,
343    ) -> Result<()> {
344        let cache_metadata = self
345            .mariadb_capabilities
346            .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
347        let mut bulk_exec = BulkExec::new(handler, stmt, cache_metadata);
348
349        loop {
350            match bulk_exec.step(&mut self.buffer_set)? {
351                Action::NeedPacket(buffer) => {
352                    buffer.clear();
353                    let _ = read_payload(&mut self.stream, buffer).await?;
354                }
355                Action::ReadColumnMetadata { num_columns } => {
356                    self.read_column_definition_packets(num_columns).await?;
357                }
358                Action::Finished => return Ok(()),
359            }
360        }
361    }
362
363    /// Execute a bulk prepared statement with a result set handler (async)
364    pub async fn exec_bulk<P, I, H>(
365        &mut self,
366        stmt: &mut PreparedStatement,
367        params: P,
368        flags: BulkFlags,
369        handler: &mut H,
370    ) -> Result<()>
371    where
372        P: BulkParamsSet + IntoIterator<Item = I>,
373        I: Params,
374        H: BinaryResultSetHandler,
375    {
376        if !self.is_mariadb() {
377            // Fallback to multiple exec_drop for non-MariaDB servers
378            for param in params {
379                self.exec_drop(stmt, param).await?;
380            }
381            Ok(())
382        } else {
383            // Use MariaDB bulk execute protocol
384            write_bulk_execute(self.buffer_set.new_write_buffer(), stmt.id(), params, flags)?;
385            self.write_payload().await?;
386            self.drive_bulk_exec(stmt, handler).await
387        }
388    }
389
390    /// Execute a prepared statement and return only the first row, dropping the rest (async)
391    ///
392    /// # Returns
393    /// * `Ok(true)` - First row was found and processed
394    /// * `Ok(false)` - No rows in result set
395    /// * `Err(Error)` - Query execution or handler callback failed
396    pub async fn exec_first<P, H>(
397        &mut self,
398        stmt: &mut PreparedStatement,
399        params: P,
400        handler: &mut H,
401    ) -> Result<bool>
402    where
403        P: Params,
404        H: BinaryResultSetHandler,
405    {
406        write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
407        self.write_payload().await?;
408        let mut first_row_handler = FirstRowHandler::new(handler);
409        self.drive_exec(stmt, &mut first_row_handler).await?;
410        Ok(first_row_handler.found_row)
411    }
412
413    /// Execute a prepared statement and discard all results (async)
414    #[instrument(skip_all)]
415    pub async fn exec_drop<P>(&mut self, stmt: &mut PreparedStatement, params: P) -> Result<()>
416    where
417        P: Params,
418    {
419        write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
420        self.write_payload().await?;
421        self.drive_exec(stmt, &mut DropHandler::default()).await
422    }
423
424    /// Execute a text protocol SQL query (async)
425    pub async fn query<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
426    where
427        H: TextResultSetHandler,
428    {
429        write_query(self.buffer_set.new_write_buffer(), sql);
430        self.write_payload().await?;
431        self.drive_query(handler).await
432    }
433
434    /// Execute a text protocol SQL query and discard all results (async)
435    #[instrument(skip_all)]
436    pub async fn query_drop(&mut self, sql: &str) -> Result<()> {
437        write_query(self.buffer_set.new_write_buffer(), sql);
438        self.write_payload().await?;
439        self.drive_query(&mut DropHandler::default()).await
440    }
441
442    /// Send a ping to the server to check if the connection is alive (async)
443    ///
444    /// This sends a COM_PING command to the MySQL server and waits for an OK response.
445    pub async fn ping(&mut self) -> Result<()> {
446        write_ping(self.buffer_set.new_write_buffer());
447        self.write_payload().await?;
448        self.buffer_set.read_buffer.clear();
449        let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
450        Ok(())
451    }
452
453    /// Reset the connection to its initial state (async)
454    pub async fn reset(&mut self) -> Result<()> {
455        write_reset_connection(self.buffer_set.new_write_buffer());
456        self.write_payload().await?;
457        self.buffer_set.read_buffer.clear();
458        let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
459        self.in_transaction = false;
460        Ok(())
461    }
462
463    /// Execute a closure within a transaction (async)
464    ///
465    /// # Errors
466    /// Returns `Error::NestedTransaction` if called while already in a transaction
467    pub async fn run_transaction<F, Fut, R>(&mut self, f: F) -> Result<R>
468    where
469        F: FnOnce(&mut Conn, super::transaction::Transaction) -> Fut,
470        Fut: core::future::Future<Output = Result<R>>,
471    {
472        if self.in_transaction {
473            return Err(Error::NestedTransaction);
474        }
475
476        self.in_transaction = true;
477
478        if let Err(err) = self.query_drop("BEGIN").await {
479            self.in_transaction = false;
480            return Err(err);
481        }
482
483        let tx = super::transaction::Transaction::new(self.connection_id());
484        let result = f(self, tx).await;
485
486        // If the transaction was not explicitly committed or rolled back, roll it back
487        if self.in_transaction {
488            let rollback_result = self.query_drop("ROLLBACK").await;
489            self.in_transaction = false;
490
491            // Return the first error (either from closure or rollback)
492            if let Err(e) = result {
493                return Err(e);
494            }
495            rollback_result?;
496        }
497
498        result
499    }
500}
501
502/// Read a complete MySQL payload asynchronously, concatenating packets if they span multiple 16MB chunks
503/// Returns the sequence_id of the last packet read.
504#[instrument(skip_all)]
505async fn read_payload(reader: &mut Stream, buffer: &mut Vec<u8>) -> Result<u8> {
506    let mut packet_header = PacketHeader::new_zeroed();
507
508    buffer.clear();
509    reader.read_exact(packet_header.as_mut_bytes()).await?;
510
511    let length = packet_header.length();
512    let mut sequence_id = packet_header.sequence_id;
513
514    buffer.reserve(length);
515
516    // read the first payload
517    {
518        let spare = buffer.spare_capacity_mut();
519        reader.read_buf_exact(&mut spare[..length]).await?;
520        // SAFETY: read_buf_exact filled exactly `length` bytes
521        unsafe {
522            buffer.set_len(length);
523        }
524    }
525
526    let mut current_length = length;
527    while current_length == 0xFFFFFF {
528        reader.read_exact(packet_header.as_mut_bytes()).await?;
529
530        current_length = packet_header.length();
531        sequence_id = packet_header.sequence_id;
532
533        buffer.reserve(current_length);
534        let spare = buffer.spare_capacity_mut();
535        reader.read_buf_exact(&mut spare[..current_length]).await?;
536        // SAFETY: read_buf_exact filled exactly `current_length` bytes
537        unsafe {
538            buffer.set_len(buffer.len() + current_length);
539        }
540    }
541
542    Ok(sequence_id)
543}
544
545async fn write_handshake_payload(
546    stream: &mut Stream,
547    buffer_set: &mut BufferSet,
548    sequence_id: u8,
549) -> Result<()> {
550    let mut buffer = buffer_set.write_buffer_mut().as_mut_slice();
551    let mut seq_id = sequence_id;
552
553    loop {
554        let chunk_size = buffer[4..].len().min(0xFFFFFF);
555        PacketHeader::mut_from_bytes(&mut buffer[0..4])?.encode_in_place(chunk_size, seq_id);
556        stream.write_all(&buffer[..4 + chunk_size]).await?;
557
558        if chunk_size < 0xFFFFFF {
559            break;
560        }
561
562        seq_id = seq_id.wrapping_add(1);
563        buffer = &mut buffer[0xFFFFFF..];
564    }
565    stream.flush().await?;
566    Ok(())
567}
568
/// Text result-set handler that captures the server's Unix socket path from
/// `SELECT @@socket`.
struct SocketPathHandler {
    /// Socket path reported by the server; `None` until a usable value is seen.
    path: Option<String>,
}
573
574impl TextResultSetHandler for SocketPathHandler {
575    fn no_result_set(&mut self, _: OkPayloadBytes) -> Result<()> {
576        Ok(())
577    }
578    fn resultset_start(&mut self, _: &[ColumnDefinition<'_>]) -> Result<()> {
579        Ok(())
580    }
581    fn resultset_end(&mut self, _: OkPayloadBytes) -> Result<()> {
582        Ok(())
583    }
584    fn row(&mut self, _: &[ColumnDefinition<'_>], row: TextRowPayload<'_>) -> Result<()> {
585        // 0xFB indicates NULL value
586        if row.0.first() == Some(&0xFB) {
587            return Ok(());
588        }
589        // Parse the first length-encoded string
590        let (value, _) = read_string_lenenc(row.0)?;
591        if !value.is_empty() {
592            self.path = Some(String::from_utf8_lossy(value).into_owned());
593        }
594        Ok(())
595    }
596}