zero_mysql/protocol/command/
prepared.rs

1use crate::buffer::BufferSet;
2use crate::constant::CommandByte;
3use crate::error::{Error, Result, eyre};
4use crate::protocol::BinaryRowPayload;
5use crate::protocol::primitive::*;
6use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
7use crate::protocol::r#trait::param::Params;
8use zerocopy::byteorder::little_endian::{U16 as U16LE, U32 as U32LE};
9use zerocopy::{FromBytes, Immutable, KnownLayout};
10
11/// Prepared statement OK response (zero-copy)
12#[repr(C, packed)]
13#[derive(Debug, Clone, Copy, FromBytes, KnownLayout, Immutable)]
14pub struct PrepareOk {
15    statement_id: U32LE,
16    num_columns: U16LE,
17    num_params: U16LE,
18    _reserved: u8,
19    warning_count: U16LE, // MySQL >= 5.7 and MariaDB all expect at least 12 bytes: https://github.com/launchbadge/sqlx/issues/3335
20}
21
22impl PrepareOk {
23    /// Get the statement ID
24    pub fn statement_id(&self) -> u32 {
25        self.statement_id.get()
26    }
27
28    /// Get the number of columns in the result set
29    pub fn num_columns(&self) -> u16 {
30        self.num_columns.get()
31    }
32
33    /// Get the number of parameters in the prepared statement
34    pub fn num_params(&self) -> u16 {
35        self.num_params.get()
36    }
37
38    /// Get the warning count
39    pub fn warning_count(&self) -> u16 {
40        self.warning_count.get()
41    }
42}
43
44/// Write COM_STMT_PREPARE command
45pub fn write_prepare(out: &mut Vec<u8>, sql: &str) {
46    write_int_1(out, CommandByte::StmtPrepare as u8);
47    out.extend_from_slice(sql.as_bytes());
48}
49
50/// Read COM_STMT_PREPARE response
51pub fn read_prepare_ok(payload: &[u8]) -> Result<&PrepareOk> {
52    let (status, data) = read_int_1(payload)?;
53    debug_assert_eq!(status, 0x00);
54    Ok(PrepareOk::ref_from_bytes(&data[..11])?)
55}
56
57/// Write COM_STMT_EXECUTE command
58pub fn write_execute<P: Params>(out: &mut Vec<u8>, statement_id: u32, params: P) -> Result<()> {
59    write_int_1(out, CommandByte::StmtExecute as u8);
60    write_int_4(out, statement_id);
61
62    // flags (1 byte) - CURSOR_TYPE_NO_CURSOR
63    write_int_1(out, 0x00);
64
65    // iteration count (4 bytes) - always 1
66    write_int_4(out, 1);
67
68    let num_params = params.len();
69
70    if num_params > 0 {
71        // NULL bitmap: (num_params + 7) / 8 bytes
72        params.encode_null_bitmap(out);
73
74        // new-params-bound-flag (1 byte)
75        let send_types_to_server = true;
76        if send_types_to_server {
77            write_int_1(out, 0x01);
78            params.encode_types(out);
79        } else {
80            write_int_1(out, 0x00);
81        }
82
83        params.encode_values(out)?; // Ignore errors for now (non-priority)
84    }
85    Ok(())
86}
87
88/// Read COM_STMT_EXECUTE response
89/// This can be either an OK packet or a result set
90pub fn read_execute_response(payload: &[u8], cache_metadata: bool) -> Result<ExecuteResponse<'_>> {
91    if payload.is_empty() {
92        return Err(Error::LibraryBug(eyre!(
93            "read_execute_response: empty payload"
94        )));
95    }
96
97    match payload[0] {
98        0x00 => Ok(ExecuteResponse::Ok(OkPayloadBytes(payload))),
99        0xFF => Err(ErrPayloadBytes(payload).into()),
100        _ => {
101            let (column_count, rest) = read_int_lenenc(payload)?;
102
103            // If MARIADB_CLIENT_CACHE_METADATA is set, read the metadata_follows flag
104            let has_column_metadata = if cache_metadata {
105                if rest.is_empty() {
106                    return Err(Error::LibraryBug(eyre!(
107                        "read_execute_response: missing metadata_follows flag"
108                    )));
109                }
110                rest[0] != 0
111            } else {
112                // Without caching, metadata always follows
113                true
114            };
115
116            Ok(ExecuteResponse::ResultSet {
117                column_count,
118                has_column_metadata,
119            })
120        }
121    }
122}
123
124/// Execute response variants
125#[derive(Debug)]
126pub enum ExecuteResponse<'a> {
127    Ok(OkPayloadBytes<'a>),
128    ResultSet {
129        column_count: u64,
130        has_column_metadata: bool,
131    },
132}
133
134/// Read binary protocol row from execute response
135pub fn read_binary_row<'a>(payload: &'a [u8], num_columns: usize) -> Result<BinaryRowPayload<'a>> {
136    crate::protocol::command::resultset::read_binary_row(payload, num_columns)
137}
138
139/// Write COM_STMT_CLOSE command
140pub fn write_close_statement(out: &mut Vec<u8>, statement_id: u32) {
141    write_int_1(out, CommandByte::StmtClose as u8);
142    write_int_4(out, statement_id);
143}
144
145/// Write COM_STMT_RESET command
146pub fn write_reset_statement(out: &mut Vec<u8>, statement_id: u32) {
147    write_int_1(out, CommandByte::StmtReset as u8);
148    write_int_4(out, statement_id);
149}
150
151// ============================================================================
152// State Machine API for exec_fold
153// ============================================================================
154
155use crate::PreparedStatement;
156use crate::protocol::command::ColumnDefinitions;
157use crate::protocol::r#trait::BinaryResultSetHandler;
158
/// Where an `Exec` state machine is within a COM_STMT_EXECUTE response.
enum ExecState {
    /// Nothing requested yet; the first step asks for the first packet.
    Start,
    /// Expecting the first packet of a result: OK, ERR, or a column count.
    ReadingFirstPacket,
    /// Column-definition packets for `num_columns` columns were requested;
    /// the next step parses them from the column-definition buffer.
    ReadingColumns { num_columns: usize },
    /// Consuming binary rows of `num_columns` columns until the terminating packet.
    ReadingRows { num_columns: usize },
    /// The last result has been consumed; stepping again is an error.
    Finished,
}
172
173/// State machine for executing prepared statements (binary protocol) with integrated handler
174///
175/// The handler is provided at construction and called directly by the state machine.
176/// The `drive()` method returns actions indicating what I/O operation is needed next.
177pub struct Exec<'h, 'stmt, H> {
178    state: ExecState,
179    handler: &'h mut H,
180    stmt: &'stmt mut PreparedStatement,
181    cache_metadata: bool,
182}
183
184impl<'h, 'stmt, H: BinaryResultSetHandler> Exec<'h, 'stmt, H> {
185    /// Create a new Exec state machine with the given handler and prepared statement
186    pub fn new(
187        handler: &'h mut H,
188        stmt: &'stmt mut PreparedStatement,
189        cache_metadata: bool,
190    ) -> Self {
191        Self {
192            state: ExecState::Start,
193            handler,
194            stmt,
195            cache_metadata,
196        }
197    }
198
199    /// Drive the state machine forward
200    ///
201    /// # Arguments
202    /// * `buffer_set` - The buffer set containing buffers to read from/write to
203    ///
204    /// # Returns
205    /// * `Action::NeedPacket(&mut Vec<u8>)` - Needs more data in the specified buffer
206    /// * `Action::Finished` - Processing complete
207    pub fn step<'buf>(
208        &mut self,
209        buffer_set: &'buf mut BufferSet,
210    ) -> Result<crate::protocol::command::Action<'buf>> {
211        use crate::protocol::command::Action;
212        match &mut self.state {
213            ExecState::Start => {
214                // Request the first packet
215                self.state = ExecState::ReadingFirstPacket;
216                Ok(Action::NeedPacket(&mut buffer_set.read_buffer))
217            }
218
219            ExecState::ReadingFirstPacket => {
220                let payload = &buffer_set.read_buffer[..];
221                let response = read_execute_response(payload, self.cache_metadata)?;
222
223                match response {
224                    ExecuteResponse::Ok(ok_bytes) => {
225                        // Parse OK packet to check status flags
226                        use crate::constant::ServerStatusFlags;
227                        use crate::protocol::response::OkPayload;
228
229                        let ok_payload = OkPayload::try_from(ok_bytes)?;
230                        self.handler.no_result_set(ok_bytes)?;
231
232                        // Check if there are more results to come
233                        if ok_payload
234                            .status_flags
235                            .contains(ServerStatusFlags::SERVER_MORE_RESULTS_EXISTS)
236                        {
237                            // More resultsets coming, go to ReadingFirstPacket to process next result
238                            self.state = ExecState::ReadingFirstPacket;
239                            Ok(Action::NeedPacket(&mut buffer_set.read_buffer))
240                        } else {
241                            // No more results, we're done
242                            self.state = ExecState::Finished;
243                            Ok(Action::Finished)
244                        }
245                    }
246                    ExecuteResponse::ResultSet {
247                        column_count,
248                        has_column_metadata,
249                    } => {
250                        let num_columns = column_count as usize;
251
252                        if has_column_metadata {
253                            // Server sent metadata, signal that we need to read N column packets
254                            self.state = ExecState::ReadingColumns { num_columns };
255                            Ok(Action::ReadColumnMetadata { num_columns })
256                        } else {
257                            // No metadata from server, use cached definitions
258                            if let Some(cols) = self.stmt.column_definitions() {
259                                self.handler.resultset_start(cols)?;
260                                self.state = ExecState::ReadingRows { num_columns };
261                                Ok(Action::NeedPacket(&mut buffer_set.read_buffer))
262                            } else {
263                                // No cache available but server didn't send metadata - error
264                                Err(Error::LibraryBug(eyre!(
265                                    "no cached column definitions available"
266                                )))
267                            }
268                        }
269                    }
270                }
271            }
272
273            ExecState::ReadingColumns { num_columns } => {
274                // Parse all column definitions from the buffer
275                // The buffer contains [len(u32)][payload][len(u32)][payload]...
276                let column_defs = ColumnDefinitions::new(
277                    *num_columns,
278                    std::mem::take(&mut buffer_set.column_definition_buffer),
279                )?;
280
281                // Cache the column definitions in the prepared statement
282                self.handler.resultset_start(column_defs.definitions())?;
283                self.stmt.set_column_definitions(column_defs);
284
285                // Move to reading rows
286                self.state = ExecState::ReadingRows {
287                    num_columns: *num_columns,
288                };
289                Ok(Action::NeedPacket(&mut buffer_set.read_buffer))
290            }
291
292            ExecState::ReadingRows { num_columns } => {
293                let payload = &buffer_set.read_buffer[..];
294                match payload[0] {
295                    0x00 => {
296                        let row = read_binary_row(payload, *num_columns)?;
297                        let cols = self.stmt.column_definitions().ok_or_else(|| {
298                            Error::LibraryBug(eyre!("no column definitions while reading rows"))
299                        })?;
300                        self.handler.row(cols, row)?;
301                        Ok(Action::NeedPacket(&mut buffer_set.read_buffer))
302                    }
303                    0xFE => {
304                        // Parse OK packet to check status flags
305                        use crate::constant::ServerStatusFlags;
306                        use crate::protocol::response::OkPayload;
307
308                        let eof_bytes = OkPayloadBytes(payload);
309                        eof_bytes.assert_eof()?;
310                        let ok_payload = OkPayload::try_from(eof_bytes)?;
311                        self.handler.resultset_end(eof_bytes)?;
312
313                        // Check if there are more results to come
314                        if ok_payload
315                            .status_flags
316                            .contains(ServerStatusFlags::SERVER_MORE_RESULTS_EXISTS)
317                        {
318                            // More resultsets coming, go to ReadingFirstPacket to process next result
319                            self.state = ExecState::ReadingFirstPacket;
320                            Ok(Action::NeedPacket(&mut buffer_set.read_buffer))
321                        } else {
322                            // No more results, we're done
323                            self.state = ExecState::Finished;
324                            Ok(Action::Finished)
325                        }
326                    }
327                    header => Err(Error::LibraryBug(eyre!(
328                        "unexpected row packet header: 0x{:02X}",
329                        header
330                    ))),
331                }
332            }
333
334            ExecState::Finished => {
335                Err(Error::LibraryBug(eyre!("Exec::step called after finished")))
336            }
337        }
338    }
339}