zero_mysql/protocol/command/prepared.rs

1use crate::buffer::BufferSet;
2use crate::constant::CommandByte;
3use crate::error::{Error, Result, eyre};
4use crate::protocol::BinaryRowPayload;
5use crate::protocol::primitive::*;
6use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
7use crate::protocol::r#trait::param::Params;
8use zerocopy::byteorder::little_endian::{U16 as U16LE, U32 as U32LE};
9use zerocopy::{FromBytes, Immutable, KnownLayout};
10
/// Prepared statement OK response (zero-copy)
///
/// Fixed-size body of a `COM_STMT_PREPARE_OK` packet: everything after the leading
/// status byte, which `read_prepare_ok` consumes before viewing the remaining bytes
/// as this struct. Multi-byte fields are little-endian (`U16LE`/`U32LE`) and every
/// field has alignment 1, so the struct is 4 + 2 + 2 + 1 + 2 = 11 bytes with no padding.
///
/// Field order mirrors the wire layout and must not be changed.
#[repr(C, packed)]
#[derive(Debug, Clone, Copy, FromBytes, KnownLayout, Immutable)]
pub struct PrepareOk {
    /// Statement id assigned by the server (used by execute/close/reset commands).
    statement_id: U32LE,
    /// Number of columns in the statement's result set.
    num_columns: U16LE,
    /// Number of parameters in the prepared statement.
    num_params: U16LE,
    /// Reserved filler byte; never read.
    _reserved: u8,
    /// Warning count.
    // MySQL >= 5.7 and MariaDB all expect at least 12 bytes: https://github.com/launchbadge/sqlx/issues/3335
    warning_count: U16LE,
}
21
22impl PrepareOk {
23    /// Get the statement ID
24    pub fn statement_id(&self) -> u32 {
25        self.statement_id.get()
26    }
27
28    /// Get the number of columns in the result set
29    pub fn num_columns(&self) -> u16 {
30        self.num_columns.get()
31    }
32
33    /// Get the number of parameters in the prepared statement
34    pub fn num_params(&self) -> u16 {
35        self.num_params.get()
36    }
37
38    /// Get the warning count
39    pub fn warning_count(&self) -> u16 {
40        self.warning_count.get()
41    }
42}
43
44/// Write COM_STMT_PREPARE command
45pub fn write_prepare(out: &mut Vec<u8>, sql: &str) {
46    write_int_1(out, CommandByte::StmtPrepare as u8);
47    out.extend_from_slice(sql.as_bytes());
48}
49
50/// Read COM_STMT_PREPARE response
51pub fn read_prepare_ok(payload: &[u8]) -> Result<&PrepareOk> {
52    let (status, data) = read_int_1(payload)?;
53    debug_assert_eq!(status, 0x00);
54    Ok(PrepareOk::ref_from_bytes(&data[..11])?)
55}
56
57/// Write COM_STMT_EXECUTE command
58pub fn write_execute<P: Params>(out: &mut Vec<u8>, statement_id: u32, params: P) -> Result<()> {
59    write_int_1(out, CommandByte::StmtExecute as u8);
60    write_int_4(out, statement_id);
61
62    // flags (1 byte) - CURSOR_TYPE_NO_CURSOR
63    write_int_1(out, 0x00);
64
65    // iteration count (4 bytes) - always 1
66    write_int_4(out, 1);
67
68    let num_params = params.len();
69
70    if num_params > 0 {
71        // NULL bitmap: (num_params + 7) / 8 bytes
72        params.encode_null_bitmap(out);
73
74        // new-params-bound-flag (1 byte)
75        let send_types_to_server = true;
76        if send_types_to_server {
77            write_int_1(out, 0x01);
78            params.encode_types(out);
79        } else {
80            write_int_1(out, 0x00);
81        }
82
83        params.encode_values(out)?; // Ignore errors for now (non-priority)
84    }
85    Ok(())
86}
87
88/// Read COM_STMT_EXECUTE response
89/// This can be either an OK packet or a result set
90pub fn read_execute_response(payload: &[u8], cache_metadata: bool) -> Result<ExecuteResponse<'_>> {
91    if payload.is_empty() {
92        return Err(Error::LibraryBug(eyre!(
93            "read_execute_response: empty payload"
94        )));
95    }
96
97    match payload[0] {
98        0x00 => Ok(ExecuteResponse::Ok(OkPayloadBytes(payload))),
99        0xFF => Err(ErrPayloadBytes(payload).into()),
100        _ => {
101            let (column_count, rest) = read_int_lenenc(payload)?;
102
103            // If MARIADB_CLIENT_CACHE_METADATA is set, read the metadata_follows flag
104            let has_column_metadata = if cache_metadata {
105                if rest.is_empty() {
106                    return Err(Error::LibraryBug(eyre!(
107                        "read_execute_response: missing metadata_follows flag"
108                    )));
109                }
110                rest[0] != 0
111            } else {
112                // Without caching, metadata always follows
113                true
114            };
115
116            Ok(ExecuteResponse::ResultSet {
117                column_count,
118                has_column_metadata,
119            })
120        }
121    }
122}
123
/// Execute response variants
///
/// Classification of the first packet the server sends in reply to COM_STMT_EXECUTE,
/// as produced by `read_execute_response`.
#[derive(Debug)]
pub enum ExecuteResponse<'a> {
    /// The statement produced no result set; holds the raw OK packet payload.
    Ok(OkPayloadBytes<'a>),
    /// A binary result set follows.
    ///
    /// `column_count` is the decoded length-encoded column count. `has_column_metadata`
    /// says whether column-definition packets follow: it is always `true` unless metadata
    /// caching is enabled and the server's `metadata_follows` byte was zero.
    ResultSet {
        column_count: u64,
        has_column_metadata: bool,
    },
}
133
134/// Read binary protocol row from execute response
135pub fn read_binary_row<'a>(payload: &'a [u8], num_columns: usize) -> Result<BinaryRowPayload<'a>> {
136    crate::protocol::command::resultset::read_binary_row(payload, num_columns)
137}
138
139/// Write COM_STMT_CLOSE command
140pub fn write_close_statement(out: &mut Vec<u8>, statement_id: u32) {
141    write_int_1(out, CommandByte::StmtClose as u8);
142    write_int_4(out, statement_id);
143}
144
145/// Write COM_STMT_RESET command
146pub fn write_reset_statement(out: &mut Vec<u8>, statement_id: u32) {
147    write_int_1(out, CommandByte::StmtReset as u8);
148    write_int_4(out, statement_id);
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn prepare_ok_has_alignment_of_1() {
        assert_eq!(std::mem::align_of::<PrepareOk>(), 1);
    }

    /// The fixed part of COM_STMT_PREPARE_OK after the status byte is
    /// 4 + 2 + 2 + 1 + 2 = 11 bytes; `PrepareOk` is viewed directly over those bytes.
    #[test]
    fn prepare_ok_is_11_bytes() {
        assert_eq!(std::mem::size_of::<PrepareOk>(), 11);
    }

    #[test]
    fn read_prepare_ok_decodes_little_endian_fields() {
        let payload: [u8; 12] = [
            0x00, // status
            0x78, 0x56, 0x34, 0x12, // statement_id = 0x12345678
            0x02, 0x00, // num_columns = 2
            0x03, 0x00, // num_params = 3
            0x00, // reserved
            0x04, 0x00, // warning_count = 4
        ];
        let prepare_ok = match read_prepare_ok(&payload) {
            Ok(prepare_ok) => prepare_ok,
            Err(_) => panic!("well-formed COM_STMT_PREPARE_OK payload was rejected"),
        };
        assert_eq!(prepare_ok.statement_id(), 0x1234_5678);
        assert_eq!(prepare_ok.num_columns(), 2);
        assert_eq!(prepare_ok.num_params(), 3);
        assert_eq!(prepare_ok.warning_count(), 4);
    }

    #[test]
    fn write_prepare_appends_command_byte_and_sql() {
        let mut out = Vec::new();
        write_prepare(&mut out, "SELECT 1");
        assert_eq!(out.first().copied(), Some(CommandByte::StmtPrepare as u8));
        assert_eq!(&out[1..], b"SELECT 1");
    }

    #[test]
    fn read_execute_response_classifies_first_packet() {
        assert!(matches!(
            read_execute_response(&[0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00], false),
            Ok(ExecuteResponse::Ok(_))
        ));
        assert!(matches!(
            read_execute_response(&[0x03], false),
            Ok(ExecuteResponse::ResultSet {
                column_count: 3,
                has_column_metadata: true
            })
        ));
        assert!(matches!(
            read_execute_response(&[0x03, 0x00], true),
            Ok(ExecuteResponse::ResultSet {
                column_count: 3,
                has_column_metadata: false
            })
        ));
        assert!(matches!(
            read_execute_response(&[0x03, 0x01], true),
            Ok(ExecuteResponse::ResultSet {
                column_count: 3,
                has_column_metadata: true
            })
        ));
        // Missing metadata_follows byte and empty payload are both errors.
        assert!(read_execute_response(&[0x03], true).is_err());
        assert!(read_execute_response(&[], false).is_err());
    }
}
160
161// ============================================================================
162// State Machine API for exec_fold
163// ============================================================================
164
165use crate::PreparedStatement;
166use crate::protocol::command::ColumnDefinitions;
167use crate::protocol::r#trait::BinaryResultSetHandler;
168
/// Internal state of the Exec state machine
///
/// Each variant records what `Exec::step` will process on its next call.
enum ExecState {
    /// Initial state - need to read first packet
    Start,
    /// Reading the first response packet
    ///
    /// Also re-entered after an OK/EOF packet whose status flags contain
    /// `SERVER_MORE_RESULTS_EXISTS`, to process the next result.
    ReadingFirstPacket,
    /// Reading column definitions (processing the buffer after reading all packets)
    ///
    /// `num_columns` is the column count announced by the first packet.
    ReadingColumns { num_columns: usize },
    /// Reading rows
    ///
    /// Rows (0x00 header) are forwarded to the handler until an EOF/OK packet (0xFE).
    ReadingRows { num_columns: usize },
    /// Finished
    ///
    /// Any further `step` call returns an error.
    Finished,
}
182
/// State machine for executing prepared statements (binary protocol) with integrated handler
///
/// The handler is provided at construction and called directly by the state machine.
/// The `step()` method returns actions indicating what I/O operation is needed next.
pub struct Exec<'h, 'stmt, H> {
    /// Current position in the response sequence.
    state: ExecState,
    /// Receives result-set events: `resultset_start`, `row`, `resultset_end`, and
    /// `no_result_set` for statements that return only an OK packet.
    handler: &'h mut H,
    /// Statement being executed; column definitions read from the server are stored on it
    /// and read back while processing rows.
    stmt: &'stmt mut PreparedStatement,
    /// Passed to `read_execute_response`: whether a `metadata_follows` byte trails the
    /// column count (MariaDB metadata caching).
    cache_metadata: bool,
}
193
194impl<'h, 'stmt, H: BinaryResultSetHandler> Exec<'h, 'stmt, H> {
195    /// Create a new Exec state machine with the given handler and prepared statement
196    pub fn new(
197        handler: &'h mut H,
198        stmt: &'stmt mut PreparedStatement,
199        cache_metadata: bool,
200    ) -> Self {
201        Self {
202            state: ExecState::Start,
203            handler,
204            stmt,
205            cache_metadata,
206        }
207    }
208
209    /// Drive the state machine forward
210    ///
211    /// # Arguments
212    /// * `buffer_set` - The buffer set containing buffers to read from/write to
213    ///
214    /// # Returns
215    /// * `Action::NeedPacket(&mut Vec<u8>)` - Needs more data in the specified buffer
216    /// * `Action::Finished` - Processing complete
217    pub fn step<'buf>(
218        &mut self,
219        buffer_set: &'buf mut BufferSet,
220    ) -> Result<crate::protocol::command::Action<'buf>> {
221        use crate::protocol::command::Action;
222        match &mut self.state {
223            ExecState::Start => {
224                // Request the first packet
225                self.state = ExecState::ReadingFirstPacket;
226                Ok(Action::NeedPacket(&mut buffer_set.read_buffer))
227            }
228
229            ExecState::ReadingFirstPacket => {
230                let payload = &buffer_set.read_buffer[..];
231                let response = read_execute_response(payload, self.cache_metadata)?;
232
233                match response {
234                    ExecuteResponse::Ok(ok_bytes) => {
235                        // Parse OK packet to check status flags
236                        use crate::constant::ServerStatusFlags;
237                        use crate::protocol::response::OkPayload;
238
239                        let ok_payload = OkPayload::try_from(ok_bytes)?;
240                        self.handler.no_result_set(ok_bytes)?;
241
242                        // Check if there are more results to come
243                        if ok_payload
244                            .status_flags
245                            .contains(ServerStatusFlags::SERVER_MORE_RESULTS_EXISTS)
246                        {
247                            // More resultsets coming, go to ReadingFirstPacket to process next result
248                            self.state = ExecState::ReadingFirstPacket;
249                            Ok(Action::NeedPacket(&mut buffer_set.read_buffer))
250                        } else {
251                            // No more results, we're done
252                            self.state = ExecState::Finished;
253                            Ok(Action::Finished)
254                        }
255                    }
256                    ExecuteResponse::ResultSet {
257                        column_count,
258                        has_column_metadata,
259                    } => {
260                        let num_columns = column_count as usize;
261
262                        if has_column_metadata {
263                            // Server sent metadata, signal that we need to read N column packets
264                            self.state = ExecState::ReadingColumns { num_columns };
265                            Ok(Action::ReadColumnMetadata { num_columns })
266                        } else {
267                            // No metadata from server, use cached definitions
268                            if let Some(cols) = self.stmt.column_definitions() {
269                                self.handler.resultset_start(cols)?;
270                                self.state = ExecState::ReadingRows { num_columns };
271                                Ok(Action::NeedPacket(&mut buffer_set.read_buffer))
272                            } else {
273                                // No cache available but server didn't send metadata - error
274                                Err(Error::LibraryBug(eyre!(
275                                    "no cached column definitions available"
276                                )))
277                            }
278                        }
279                    }
280                }
281            }
282
283            ExecState::ReadingColumns { num_columns } => {
284                // Parse all column definitions from the buffer
285                // The buffer contains [len(u32)][payload][len(u32)][payload]...
286                let column_defs = ColumnDefinitions::new(
287                    *num_columns,
288                    std::mem::take(&mut buffer_set.column_definition_buffer),
289                )?;
290
291                // Cache the column definitions in the prepared statement
292                self.handler.resultset_start(column_defs.definitions())?;
293                self.stmt.set_column_definitions(column_defs);
294
295                // Move to reading rows
296                self.state = ExecState::ReadingRows {
297                    num_columns: *num_columns,
298                };
299                Ok(Action::NeedPacket(&mut buffer_set.read_buffer))
300            }
301
302            ExecState::ReadingRows { num_columns } => {
303                let payload = &buffer_set.read_buffer[..];
304                match payload[0] {
305                    0x00 => {
306                        let row = read_binary_row(payload, *num_columns)?;
307                        let cols = self.stmt.column_definitions().ok_or_else(|| {
308                            Error::LibraryBug(eyre!("no column definitions while reading rows"))
309                        })?;
310                        self.handler.row(cols, row)?;
311                        Ok(Action::NeedPacket(&mut buffer_set.read_buffer))
312                    }
313                    0xFE => {
314                        // Parse OK packet to check status flags
315                        use crate::constant::ServerStatusFlags;
316                        use crate::protocol::response::OkPayload;
317
318                        let eof_bytes = OkPayloadBytes(payload);
319                        eof_bytes.assert_eof()?;
320                        let ok_payload = OkPayload::try_from(eof_bytes)?;
321                        self.handler.resultset_end(eof_bytes)?;
322
323                        // Check if there are more results to come
324                        if ok_payload
325                            .status_flags
326                            .contains(ServerStatusFlags::SERVER_MORE_RESULTS_EXISTS)
327                        {
328                            // More resultsets coming, go to ReadingFirstPacket to process next result
329                            self.state = ExecState::ReadingFirstPacket;
330                            Ok(Action::NeedPacket(&mut buffer_set.read_buffer))
331                        } else {
332                            // No more results, we're done
333                            self.state = ExecState::Finished;
334                            Ok(Action::Finished)
335                        }
336                    }
337                    header => Err(Error::LibraryBug(eyre!(
338                        "unexpected row packet header: 0x{:02X}",
339                        header
340                    ))),
341                }
342            }
343
344            ExecState::Finished => {
345                Err(Error::LibraryBug(eyre!("Exec::step called after finished")))
346            }
347        }
348    }
349}