// zero_mysql/tokio/conn.rs

1use std::ops::AsyncFnOnce;
2
3use tokio::net::TcpStream;
4#[cfg(unix)]
5use tokio::net::UnixStream;
6use tracing::instrument;
7use zerocopy::{FromBytes, FromZeros, IntoBytes};
8
9use crate::PreparedStatement;
10use crate::buffer::BufferSet;
11use crate::buffer_pool::PooledBufferSet;
12use crate::constant::CapabilityFlags;
13use crate::error::{Error, Result};
14use crate::protocol::TextRowPayload;
15use crate::protocol::command::Action;
16use crate::protocol::command::ColumnDefinition;
17use crate::protocol::command::bulk_exec::{BulkExec, BulkFlags, BulkParamsSet, write_bulk_execute};
18use crate::protocol::command::prepared::{Exec, read_prepare_ok, write_execute, write_prepare};
19use crate::protocol::command::query::{Query, write_query};
20use crate::protocol::command::utility::{
21    DropHandler, FirstHandler, write_ping, write_reset_connection,
22};
23use crate::protocol::connection::{Handshake, HandshakeAction, InitialHandshake};
24use crate::protocol::packet::PacketHeader;
25use crate::protocol::primitive::read_string_lenenc;
26use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
27use crate::protocol::r#trait::{BinaryResultSetHandler, TextResultSetHandler, param::Params};
28
29use super::stream::Stream;
30
31pub struct Conn {
32    stream: Stream,
33    buffer_set: PooledBufferSet,
34    initial_handshake: InitialHandshake,
35    capability_flags: CapabilityFlags,
36    mariadb_capabilities: crate::constant::MariadbCapabilityFlags,
37    in_transaction: bool,
38    is_broken: bool,
39}
40
41impl Conn {
42    /// Create a new MySQL connection from connection options (async)
43    pub async fn new<O: TryInto<crate::opts::Opts>>(opts: O) -> Result<Self>
44    where
45        Error: From<O::Error>,
46    {
47        let opts: crate::opts::Opts = opts.try_into()?;
48
49        #[cfg(unix)]
50        let stream = if let Some(socket_path) = &opts.socket {
51            let stream = UnixStream::connect(socket_path).await?;
52            Stream::unix(stream)
53        } else {
54            if opts.host.is_empty() {
55                return Err(Error::BadUsageError(
56                    "Missing host in connection options".to_string(),
57                ));
58            }
59            let addr = format!("{}:{}", opts.host, opts.port);
60            let stream = TcpStream::connect(&addr).await?;
61            stream.set_nodelay(opts.tcp_nodelay)?;
62            Stream::tcp(stream)
63        };
64
65        #[cfg(not(unix))]
66        let stream = {
67            if opts.socket.is_some() {
68                return Err(Error::BadUsageError(
69                    "Unix sockets are not supported on this platform".to_string(),
70                ));
71            }
72            if opts.host.is_empty() {
73                return Err(Error::BadUsageError(
74                    "Missing host in connection options".to_string(),
75                ));
76            }
77            let addr = format!("{}:{}", opts.host, opts.port);
78            let stream = TcpStream::connect(&addr).await?;
79            stream.set_nodelay(opts.tcp_nodelay)?;
80            Stream::tcp(stream)
81        };
82
83        Self::new_with_stream(stream, &opts).await
84    }
85
86    /// Create a new MySQL connection with an existing stream (async)
87    pub async fn new_with_stream(stream: Stream, opts: &crate::opts::Opts) -> Result<Self> {
88        let mut conn_stream = stream;
89        let mut buffer_set = opts.buffer_pool.get_buffer_set();
90
91        #[cfg(feature = "tokio-tls")]
92        let host = opts.host.clone();
93
94        let mut handshake = Handshake::new(opts);
95
96        loop {
97            match handshake.step(&mut buffer_set)? {
98                HandshakeAction::ReadPacket(buffer) => {
99                    buffer.clear();
100                    read_payload(&mut conn_stream, buffer).await?;
101                }
102                HandshakeAction::WritePacket { sequence_id } => {
103                    write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id).await?;
104                    buffer_set.read_buffer.clear();
105                    read_payload(&mut conn_stream, &mut buffer_set.read_buffer).await?;
106                }
107                #[cfg(feature = "tokio-tls")]
108                HandshakeAction::UpgradeTls { sequence_id } => {
109                    write_handshake_payload(&mut conn_stream, &mut buffer_set, sequence_id).await?;
110                    conn_stream = conn_stream.upgrade_to_tls(&host).await?;
111                }
112                #[cfg(not(feature = "tokio-tls"))]
113                HandshakeAction::UpgradeTls { .. } => {
114                    return Err(Error::BadUsageError(
115                        "TLS requested but tokio-tls feature is not enabled".to_string(),
116                    ));
117                }
118                HandshakeAction::Finished => break,
119            }
120        }
121
122        let (initial_handshake, capability_flags, mariadb_capabilities) = handshake.finish()?;
123
124        let conn = Self {
125            stream: conn_stream,
126            buffer_set,
127            initial_handshake,
128            capability_flags,
129            mariadb_capabilities,
130            in_transaction: false,
131            is_broken: false,
132        };
133
134        // Upgrade to Unix socket if connected via TCP to loopback
135        #[cfg(unix)]
136        let mut conn = if opts.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
137            conn.try_upgrade_to_unix_socket(opts).await
138        } else {
139            conn
140        };
141        #[cfg(not(unix))]
142        let mut conn = conn;
143
144        // Execute init command if specified
145        if let Some(init_command) = &opts.init_command {
146            conn.query_drop(init_command).await?;
147        }
148
149        Ok(conn)
150    }
151
152    pub fn server_version(&self) -> &[u8] {
153        &self.buffer_set.initial_handshake[self.initial_handshake.server_version.clone()]
154    }
155
156    /// Get the negotiated capability flags
157    pub fn capability_flags(&self) -> CapabilityFlags {
158        self.capability_flags
159    }
160
161    /// Check if the server is MySQL (as opposed to MariaDB)
162    pub fn is_mysql(&self) -> bool {
163        self.capability_flags.is_mysql()
164    }
165
166    /// Check if the server is MariaDB (as opposed to MySQL)
167    pub fn is_mariadb(&self) -> bool {
168        self.capability_flags.is_mariadb()
169    }
170
171    /// Get the connection ID assigned by the server
172    pub fn connection_id(&self) -> u64 {
173        self.initial_handshake.connection_id as u64
174    }
175
176    /// Get the server status flags from the initial handshake
177    pub fn status_flags(&self) -> crate::constant::ServerStatusFlags {
178        self.initial_handshake.status_flags
179    }
180
181    /// Check if the connection is broken due to a previous I/O error
182    pub fn is_broken(&self) -> bool {
183        self.is_broken
184    }
185
186    #[inline]
187    fn check_error<T>(&mut self, result: Result<T>) -> Result<T> {
188        if let Err(e) = &result
189            && e.is_conn_broken()
190        {
191            self.is_broken = true;
192        }
193        result
194    }
195
196    pub(crate) fn set_in_transaction(&mut self, value: bool) {
197        self.in_transaction = value;
198    }
199
200    /// Returns true if the connection is currently in a transaction
201    pub fn in_transaction(&self) -> bool {
202        self.in_transaction
203    }
204
205    /// Try to upgrade to Unix socket connection.
206    /// Returns upgraded conn on success, original conn on failure.
207    #[cfg(unix)]
208    async fn try_upgrade_to_unix_socket(mut self, opts: &crate::opts::Opts) -> Self {
209        // Query the server for its Unix socket path
210        let mut handler = SocketPathHandler { path: None };
211        if self.query("SELECT @@socket", &mut handler).await.is_err() {
212            return self;
213        }
214
215        let socket_path = match handler.path {
216            Some(p) if !p.is_empty() => p,
217            _ => return self,
218        };
219
220        // Connect via Unix socket
221        let unix_stream = match UnixStream::connect(&socket_path).await {
222            Ok(s) => s,
223            Err(_) => return self,
224        };
225        let stream = Stream::unix(unix_stream);
226
227        // Create new connection over Unix socket (re-handshakes)
228        // Disable upgrade_to_unix_socket to prevent infinite recursion
229        let mut opts_unix = opts.clone();
230        opts_unix.upgrade_to_unix_socket = false;
231
232        match Box::pin(Self::new_with_stream(stream, &opts_unix)).await {
233            Ok(new_conn) => new_conn,
234            Err(_) => self,
235        }
236    }
237
238    /// Write a MySQL packet from write_buffer asynchronously, splitting it into 16MB chunks if necessary
239    #[instrument(skip_all)]
240    async fn write_payload(&mut self) -> Result<()> {
241        let mut sequence_id = 0_u8;
242        let mut buffer = self.buffer_set.write_buffer_mut().as_mut_slice();
243
244        loop {
245            let chunk_size = buffer[4..].len().min(0xFFFFFF);
246            PacketHeader::mut_from_bytes(&mut buffer[0..4])?
247                .encode_in_place(chunk_size, sequence_id);
248            self.stream.write_all(&buffer[..4 + chunk_size]).await?;
249
250            if chunk_size < 0xFFFFFF {
251                break;
252            }
253
254            sequence_id = sequence_id.wrapping_add(1);
255            buffer = &mut buffer[0xFFFFFF..];
256        }
257        self.stream.flush().await?;
258        Ok(())
259    }
260
261    /// Prepare a statement and return the PreparedStatement (async)
262    ///
263    /// Returns `Ok(PreparedStatement)` on success.
264    pub async fn prepare(&mut self, sql: &str) -> Result<PreparedStatement> {
265        let result = self.prepare_inner(sql).await;
266        self.check_error(result)
267    }
268
269    async fn prepare_inner(&mut self, sql: &str) -> Result<PreparedStatement> {
270        use crate::protocol::command::ColumnDefinitions;
271
272        self.buffer_set.read_buffer.clear();
273
274        write_prepare(self.buffer_set.new_write_buffer(), sql);
275
276        self.write_payload().await?;
277
278        let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
279
280        if !self.buffer_set.read_buffer.is_empty() && self.buffer_set.read_buffer[0] == 0xFF {
281            Err(ErrPayloadBytes(&self.buffer_set.read_buffer))?
282        }
283
284        let prepare_ok = read_prepare_ok(&self.buffer_set.read_buffer)?;
285        let statement_id = prepare_ok.statement_id();
286        let num_params = prepare_ok.num_params();
287        let num_columns = prepare_ok.num_columns();
288
289        // Skip param definitions (we don't cache them)
290        for _ in 0..num_params {
291            let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
292        }
293
294        // Read and cache column definitions for MARIADB_CLIENT_CACHE_METADATA support
295        let column_definitions = if num_columns > 0 {
296            self.read_column_definition_packets(num_columns as usize)
297                .await?;
298            Some(ColumnDefinitions::new(
299                num_columns as usize,
300                std::mem::take(&mut self.buffer_set.column_definition_buffer),
301            )?)
302        } else {
303            None
304        };
305
306        let mut stmt = PreparedStatement::new(statement_id);
307        if let Some(col_defs) = column_definitions {
308            stmt.set_column_definitions(col_defs);
309        }
310        Ok(stmt)
311    }
312
313    #[tracing::instrument(skip_all)]
314    async fn read_column_definition_packets(&mut self, num_columns: usize) -> Result<u8> {
315        let mut header = PacketHeader::new_zeroed();
316        let out = &mut self.buffer_set.column_definition_buffer;
317        out.clear();
318
319        // For each column, write [4 bytes len][payload]
320        for _ in 0..num_columns {
321            self.stream.read_exact(header.as_mut_bytes()).await?;
322            let length = header.length();
323            out.extend((length as u32).to_ne_bytes());
324
325            out.reserve(length);
326            let spare = out.spare_capacity_mut();
327            self.stream.read_buf_exact(&mut spare[..length]).await?;
328            // SAFETY: read_buf_exact filled exactly `length` bytes
329            unsafe {
330                out.set_len(out.len() + length);
331            }
332        }
333
334        Ok(header.sequence_id)
335    }
336
337    async fn drive_exec<H: BinaryResultSetHandler>(
338        &mut self,
339        stmt: &mut crate::PreparedStatement,
340        handler: &mut H,
341    ) -> Result<()> {
342        let cache_metadata = self
343            .mariadb_capabilities
344            .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
345        let mut exec = Exec::new(handler, stmt, cache_metadata);
346
347        loop {
348            match exec.step(&mut self.buffer_set)? {
349                Action::NeedPacket(buffer) => {
350                    buffer.clear();
351                    let _ = read_payload(&mut self.stream, buffer).await?;
352                }
353                Action::ReadColumnMetadata { num_columns } => {
354                    self.read_column_definition_packets(num_columns).await?;
355                }
356                Action::Finished => return Ok(()),
357            }
358        }
359    }
360
361    async fn drive_query<H: TextResultSetHandler>(&mut self, handler: &mut H) -> Result<()> {
362        let mut query = Query::new(handler);
363
364        loop {
365            match query.step(&mut self.buffer_set)? {
366                Action::NeedPacket(buffer) => {
367                    buffer.clear();
368                    let _ = read_payload(&mut self.stream, buffer).await?;
369                }
370                Action::ReadColumnMetadata { num_columns } => {
371                    self.read_column_definition_packets(num_columns).await?;
372                }
373                Action::Finished => return Ok(()),
374            }
375        }
376    }
377
378    /// Execute a prepared statement with a result set handler (async)
379    pub async fn exec<P, H>(
380        &mut self,
381        stmt: &mut PreparedStatement,
382        params: P,
383        handler: &mut H,
384    ) -> Result<()>
385    where
386        P: Params,
387        H: BinaryResultSetHandler,
388    {
389        let result = self.exec_inner(stmt, params, handler).await;
390        self.check_error(result)
391    }
392
393    async fn exec_inner<P, H>(
394        &mut self,
395        stmt: &mut PreparedStatement,
396        params: P,
397        handler: &mut H,
398    ) -> Result<()>
399    where
400        P: Params,
401        H: BinaryResultSetHandler,
402    {
403        write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
404        self.write_payload().await?;
405        self.drive_exec(stmt, handler).await
406    }
407
408    async fn drive_bulk_exec<H: BinaryResultSetHandler>(
409        &mut self,
410        stmt: &mut crate::PreparedStatement,
411        handler: &mut H,
412    ) -> Result<()> {
413        let cache_metadata = self
414            .mariadb_capabilities
415            .contains(crate::constant::MariadbCapabilityFlags::MARIADB_CLIENT_CACHE_METADATA);
416        let mut bulk_exec = BulkExec::new(handler, stmt, cache_metadata);
417
418        loop {
419            match bulk_exec.step(&mut self.buffer_set)? {
420                Action::NeedPacket(buffer) => {
421                    buffer.clear();
422                    let _ = read_payload(&mut self.stream, buffer).await?;
423                }
424                Action::ReadColumnMetadata { num_columns } => {
425                    self.read_column_definition_packets(num_columns).await?;
426                }
427                Action::Finished => return Ok(()),
428            }
429        }
430    }
431
432    /// Execute a bulk prepared statement with a result set handler (async)
433    pub async fn exec_bulk_insert_or_update<P, I, H>(
434        &mut self,
435        stmt: &mut PreparedStatement,
436        params: P,
437        flags: BulkFlags,
438        handler: &mut H,
439    ) -> Result<()>
440    where
441        P: BulkParamsSet + IntoIterator<Item = I>,
442        I: Params,
443        H: BinaryResultSetHandler,
444    {
445        let result = self
446            .exec_bulk_insert_or_update_inner(stmt, params, flags, handler)
447            .await;
448        self.check_error(result)
449    }
450
451    async fn exec_bulk_insert_or_update_inner<P, I, H>(
452        &mut self,
453        stmt: &mut PreparedStatement,
454        params: P,
455        flags: BulkFlags,
456        handler: &mut H,
457    ) -> Result<()>
458    where
459        P: BulkParamsSet + IntoIterator<Item = I>,
460        I: Params,
461        H: BinaryResultSetHandler,
462    {
463        if !self.is_mariadb() {
464            // Fallback to multiple exec_drop for non-MariaDB servers
465            for param in params {
466                self.exec_inner(stmt, param, &mut DropHandler::default())
467                    .await?;
468            }
469            Ok(())
470        } else {
471            // Use MariaDB bulk execute protocol
472            write_bulk_execute(self.buffer_set.new_write_buffer(), stmt.id(), params, flags)?;
473            self.write_payload().await?;
474            self.drive_bulk_exec(stmt, handler).await
475        }
476    }
477
478    /// Execute a prepared statement and return only the first row, dropping the rest (async).
479    pub async fn exec_first<Row, P>(
480        &mut self,
481        stmt: &mut PreparedStatement,
482        params: P,
483    ) -> Result<Option<Row>>
484    where
485        Row: for<'buf> crate::raw::FromRawRow<'buf>,
486        P: Params,
487    {
488        let result = self.exec_first_inner(stmt, params).await;
489        self.check_error(result)
490    }
491
492    async fn exec_first_inner<Row, P>(
493        &mut self,
494        stmt: &mut PreparedStatement,
495        params: P,
496    ) -> Result<Option<Row>>
497    where
498        Row: for<'buf> crate::raw::FromRawRow<'buf>,
499        P: Params,
500    {
501        write_execute(self.buffer_set.new_write_buffer(), stmt.id(), params)?;
502        self.write_payload().await?;
503        let mut handler = FirstHandler::<Row>::default();
504        self.drive_exec(stmt, &mut handler).await?;
505        Ok(handler.take())
506    }
507
508    /// Execute a prepared statement and discard all results (async)
509    #[instrument(skip_all)]
510    pub async fn exec_drop<P>(&mut self, stmt: &mut PreparedStatement, params: P) -> Result<()>
511    where
512        P: Params,
513    {
514        self.exec(stmt, params, &mut DropHandler::default()).await
515    }
516
517    /// Execute a prepared statement and collect all rows into a Vec (async).
518    pub async fn exec_collect<Row, P>(
519        &mut self,
520        stmt: &mut PreparedStatement,
521        params: P,
522    ) -> Result<Vec<Row>>
523    where
524        Row: for<'buf> crate::raw::FromRawRow<'buf>,
525        P: Params,
526    {
527        let mut handler = crate::handler::CollectHandler::<Row>::default();
528        self.exec(stmt, params, &mut handler).await?;
529        Ok(handler.into_rows())
530    }
531
532    /// Execute a prepared statement and call a closure for each row (async).
533    ///
534    /// The closure can return an error to stop iteration early.
535    pub async fn exec_foreach<Row, P, F>(
536        &mut self,
537        stmt: &mut PreparedStatement,
538        params: P,
539        f: F,
540    ) -> Result<()>
541    where
542        Row: for<'buf> crate::raw::FromRawRow<'buf>,
543        P: Params,
544        F: FnMut(Row) -> Result<()>,
545    {
546        let mut handler = crate::handler::ForEachHandler::<Row, F>::new(f);
547        self.exec(stmt, params, &mut handler).await
548    }
549
550    /// Execute a text protocol SQL query (async)
551    pub async fn query<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
552    where
553        H: TextResultSetHandler,
554    {
555        let result = self.query_inner(sql, handler).await;
556        self.check_error(result)
557    }
558
559    async fn query_inner<H>(&mut self, sql: &str, handler: &mut H) -> Result<()>
560    where
561        H: TextResultSetHandler,
562    {
563        write_query(self.buffer_set.new_write_buffer(), sql);
564        self.write_payload().await?;
565        self.drive_query(handler).await
566    }
567
568    /// Execute a text protocol SQL query and discard all results (async)
569    #[instrument(skip_all)]
570    pub async fn query_drop(&mut self, sql: &str) -> Result<()> {
571        let result = self.query_drop_inner(sql).await;
572        self.check_error(result)
573    }
574
575    async fn query_drop_inner(&mut self, sql: &str) -> Result<()> {
576        write_query(self.buffer_set.new_write_buffer(), sql);
577        self.write_payload().await?;
578        self.drive_query(&mut DropHandler::default()).await
579    }
580
581    /// Send a ping to the server to check if the connection is alive (async)
582    ///
583    /// This sends a COM_PING command to the MySQL server and waits for an OK response.
584    pub async fn ping(&mut self) -> Result<()> {
585        let result = self.ping_inner().await;
586        self.check_error(result)
587    }
588
589    async fn ping_inner(&mut self) -> Result<()> {
590        write_ping(self.buffer_set.new_write_buffer());
591        self.write_payload().await?;
592        self.buffer_set.read_buffer.clear();
593        let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
594        Ok(())
595    }
596
597    /// Reset the connection to its initial state (async)
598    pub async fn reset(&mut self) -> Result<()> {
599        let result = self.reset_inner().await;
600        self.check_error(result)
601    }
602
603    async fn reset_inner(&mut self) -> Result<()> {
604        write_reset_connection(self.buffer_set.new_write_buffer());
605        self.write_payload().await?;
606        self.buffer_set.read_buffer.clear();
607        let _ = read_payload(&mut self.stream, &mut self.buffer_set.read_buffer).await?;
608        self.in_transaction = false;
609        Ok(())
610    }
611
612    /// Execute a closure within a transaction (async)
613    ///
614    /// # Errors
615    /// Returns `Error::NestedTransaction` if called while already in a transaction
616    pub async fn transaction<F, R>(&mut self, f: F) -> Result<R>
617    where
618        F: AsyncFnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
619    {
620        if self.in_transaction {
621            return Err(Error::NestedTransaction);
622        }
623
624        self.in_transaction = true;
625
626        if let Err(err) = self.query_drop("BEGIN").await {
627            self.in_transaction = false;
628            return Err(err);
629        }
630
631        let tx = super::transaction::Transaction::new(self.connection_id());
632        let result = f(self, tx).await;
633
634        // If no explicit commit/rollback was called, commit on Ok, rollback on Err
635        if self.in_transaction {
636            self.in_transaction = false;
637            match &result {
638                Ok(_) => self.query_drop("COMMIT").await?,
639                Err(_) => {
640                    let _ = self.query_drop("ROLLBACK").await;
641                }
642            }
643        }
644
645        result
646    }
647}
648
649/// Read a complete MySQL payload asynchronously, concatenating packets if they span multiple 16MB chunks
650/// Returns the sequence_id of the last packet read.
651#[instrument(skip_all)]
652async fn read_payload(reader: &mut Stream, buffer: &mut Vec<u8>) -> Result<u8> {
653    let mut packet_header = PacketHeader::new_zeroed();
654
655    buffer.clear();
656    reader.read_exact(packet_header.as_mut_bytes()).await?;
657
658    let length = packet_header.length();
659    let mut sequence_id = packet_header.sequence_id;
660
661    buffer.reserve(length);
662
663    // read the first payload
664    {
665        let spare = buffer.spare_capacity_mut();
666        reader.read_buf_exact(&mut spare[..length]).await?;
667        // SAFETY: read_buf_exact filled exactly `length` bytes
668        unsafe {
669            buffer.set_len(length);
670        }
671    }
672
673    let mut current_length = length;
674    while current_length == 0xFFFFFF {
675        reader.read_exact(packet_header.as_mut_bytes()).await?;
676
677        current_length = packet_header.length();
678        sequence_id = packet_header.sequence_id;
679
680        buffer.reserve(current_length);
681        let spare = buffer.spare_capacity_mut();
682        reader.read_buf_exact(&mut spare[..current_length]).await?;
683        // SAFETY: read_buf_exact filled exactly `current_length` bytes
684        unsafe {
685            buffer.set_len(buffer.len() + current_length);
686        }
687    }
688
689    Ok(sequence_id)
690}
691
692async fn write_handshake_payload(
693    stream: &mut Stream,
694    buffer_set: &mut BufferSet,
695    sequence_id: u8,
696) -> Result<()> {
697    let mut buffer = buffer_set.write_buffer_mut().as_mut_slice();
698    let mut seq_id = sequence_id;
699
700    loop {
701        let chunk_size = buffer[4..].len().min(0xFFFFFF);
702        PacketHeader::mut_from_bytes(&mut buffer[0..4])?.encode_in_place(chunk_size, seq_id);
703        stream.write_all(&buffer[..4 + chunk_size]).await?;
704
705        if chunk_size < 0xFFFFFF {
706            break;
707        }
708
709        seq_id = seq_id.wrapping_add(1);
710        buffer = &mut buffer[0xFFFFFF..];
711    }
712    stream.flush().await?;
713    Ok(())
714}
715
716/// Handler to capture socket path from SELECT @@socket query
717#[cfg(unix)]
718struct SocketPathHandler {
719    path: Option<String>,
720}
721
722#[cfg(unix)]
723impl TextResultSetHandler for SocketPathHandler {
724    fn no_result_set(&mut self, _: OkPayloadBytes) -> Result<()> {
725        Ok(())
726    }
727    fn resultset_start(&mut self, _: &[ColumnDefinition<'_>]) -> Result<()> {
728        Ok(())
729    }
730    fn resultset_end(&mut self, _: OkPayloadBytes) -> Result<()> {
731        Ok(())
732    }
733    fn row(&mut self, _: &[ColumnDefinition<'_>], row: TextRowPayload<'_>) -> Result<()> {
734        // 0xFB indicates NULL value
735        if row.0.first() == Some(&0xFB) {
736            return Ok(());
737        }
738        // Parse the first length-encoded string
739        let (value, _) = read_string_lenenc(row.0)?;
740        if !value.is_empty() {
741            self.path = Some(String::from_utf8_lossy(value).into_owned());
742        }
743        Ok(())
744    }
745}