// zero_mysql/protocol/command/bulk_exec.rs

1use crate::PreparedStatement;
2use crate::buffer::BufferSet;
3use crate::constant::CommandByte;
4use crate::error::{Error, Result, eyre};
5use crate::protocol::command::ColumnDefinitions;
6use crate::protocol::command::prepared::read_binary_row;
7use crate::protocol::primitive::*;
8use crate::protocol::response::{ErrPayloadBytes, OkPayloadBytes};
9use crate::protocol::r#trait::BinaryResultSetHandler;
10use crate::protocol::r#trait::param::TypedParams;
11
12bitflags::bitflags! {
13    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
14    pub struct BulkFlags: u16 {
15        const SEND_UNIT_RESULTS = 64;
16        const SEND_TYPES_TO_SERVER = 128;
17    }
18}
19
20pub trait BulkParamsSet {
21    fn encode_types(&self, out: &mut Vec<u8>);
22    fn encode_rows(self, out: &mut Vec<u8>) -> Result<()>;
23}
24
25impl<P: TypedParams> BulkParamsSet for &[P] {
26    fn encode_types(&self, out: &mut Vec<u8>) {
27        P::encode_types(out);
28    }
29
30    fn encode_rows(self, out: &mut Vec<u8>) -> Result<()> {
31        for params in self {
32            params.encode_values_for_bulk(out)?;
33        }
34        Ok(())
35    }
36}
37
38pub fn write_bulk_execute<P: BulkParamsSet>(
39    out: &mut Vec<u8>,
40    statement_id: u32,
41    params: P,
42    flags: BulkFlags,
43) -> Result<()> {
44    write_int_1(out, CommandByte::StmtBulkExecute as u8);
45    write_int_4(out, statement_id);
46    write_int_2(out, flags.bits());
47
48    if flags.contains(BulkFlags::SEND_TYPES_TO_SERVER) {
49        params.encode_types(out);
50    }
51
52    params.encode_rows(out)?;
53    Ok(())
54}
55
56pub fn read_bulk_execute_response(
57    payload: &[u8],
58    cache_metadata: bool,
59) -> Result<BulkExecuteResponse<'_>> {
60    if payload.is_empty() {
61        return Err(Error::LibraryBug(eyre!(
62            "read_bulk_execute_response: empty payload"
63        )));
64    }
65
66    match payload[0] {
67        0x00 => Ok(BulkExecuteResponse::Ok(OkPayloadBytes(payload))),
68        0xFF => Err(ErrPayloadBytes(payload).into()),
69        _ => {
70            let (column_count, rest) = read_int_lenenc(payload)?;
71
72            // If MARIADB_CLIENT_CACHE_METADATA is set, read the metadata_follows flag
73            let has_column_metadata = if cache_metadata {
74                if rest.is_empty() {
75                    return Err(Error::LibraryBug(eyre!(
76                        "read_bulk_execute_response: missing metadata_follows flag"
77                    )));
78                }
79                rest[0] != 0
80            } else {
81                // Without caching, metadata always follows
82                true
83            };
84
85            Ok(BulkExecuteResponse::ResultSet {
86                column_count,
87                has_column_metadata,
88            })
89        }
90    }
91}
92
93#[derive(Debug)]
94pub enum BulkExecuteResponse<'a> {
95    Ok(OkPayloadBytes<'a>),
96    ResultSet {
97        column_count: u64,
98        has_column_metadata: bool,
99    },
100}
101
/// Position of a [`BulkExec`] within the server's response.
enum BulkExecState {
    /// Nothing requested yet; the first `step` asks for the initial packet.
    Start,
    /// Waiting to decode the first packet (OK, ERR, or result-set header).
    ReadingFirstPacket,
    /// Column-definition packets have been requested for `num_columns` columns.
    ReadingColumns { num_columns: usize },
    /// Consuming binary rows of `num_columns` columns until the terminating packet.
    ReadingRows { num_columns: usize },
    /// Exchange complete; further `step` calls are an error.
    Finished,
}
109
110pub struct BulkExec<'h, 'stmt, H> {
111    state: BulkExecState,
112    handler: &'h mut H,
113    stmt: &'stmt mut PreparedStatement,
114    cache_metadata: bool,
115}
116
117impl<'h, 'stmt, H: BinaryResultSetHandler> BulkExec<'h, 'stmt, H> {
118    pub fn new(
119        handler: &'h mut H,
120        stmt: &'stmt mut PreparedStatement,
121        cache_metadata: bool,
122    ) -> Self {
123        Self {
124            state: BulkExecState::Start,
125            handler,
126            stmt,
127            cache_metadata,
128        }
129    }
130
131    pub fn step<'buf>(
132        &mut self,
133        buffer_set: &'buf mut BufferSet,
134    ) -> Result<crate::protocol::command::Action<'buf>> {
135        use crate::protocol::command::Action;
136        match &mut self.state {
137            BulkExecState::Start => {
138                self.state = BulkExecState::ReadingFirstPacket;
139                Ok(Action::NeedPacket(&mut buffer_set.read_buffer))
140            }
141
142            BulkExecState::ReadingFirstPacket => {
143                let payload = &buffer_set.read_buffer[..];
144                let response = read_bulk_execute_response(payload, self.cache_metadata)?;
145
146                match response {
147                    BulkExecuteResponse::Ok(ok_bytes) => {
148                        self.handler.no_result_set(ok_bytes)?;
149                        self.state = BulkExecState::Finished;
150                        Ok(Action::Finished)
151                    }
152                    BulkExecuteResponse::ResultSet {
153                        column_count,
154                        has_column_metadata,
155                    } => {
156                        let num_columns = column_count as usize;
157
158                        if has_column_metadata {
159                            // Server sent metadata, signal that we need to read N column packets
160                            self.state = BulkExecState::ReadingColumns { num_columns };
161                            Ok(Action::ReadColumnMetadata { num_columns })
162                        } else {
163                            // No metadata from server, use cached definitions
164                            if let Some(cache) = self.stmt.column_definitions() {
165                                self.handler.resultset_start(cache)?;
166                                self.state = BulkExecState::ReadingRows { num_columns };
167                                Ok(Action::NeedPacket(&mut buffer_set.read_buffer))
168                            } else {
169                                // No cache available but server didn't send metadata - error
170                                Err(Error::LibraryBug(eyre!(
171                                    "no cached column definitions available"
172                                )))
173                            }
174                        }
175                    }
176                }
177            }
178
179            BulkExecState::ReadingColumns { num_columns } => {
180                // Parse all column definitions from the buffer
181                // The buffer contains [len(u32)][payload][len(u32)][payload]...
182                let column_defs = ColumnDefinitions::new(
183                    *num_columns,
184                    std::mem::take(&mut buffer_set.column_definition_buffer),
185                )?;
186
187                // Cache the column definitions in the prepared statement
188                self.handler.resultset_start(column_defs.definitions())?;
189                self.stmt.set_column_definitions(column_defs);
190
191                // Move to reading rows
192                self.state = BulkExecState::ReadingRows {
193                    num_columns: *num_columns,
194                };
195                Ok(Action::NeedPacket(&mut buffer_set.read_buffer))
196            }
197
198            BulkExecState::ReadingRows { num_columns } => {
199                let payload = &buffer_set.read_buffer[..];
200                match payload[0] {
201                    0x00 => {
202                        let row = read_binary_row(payload, *num_columns)?;
203                        let cols = self.stmt.column_definitions().ok_or_else(|| {
204                            Error::LibraryBug(eyre!("no column definitions while reading rows"))
205                        })?;
206                        self.handler.row(cols, row)?;
207                        Ok(Action::NeedPacket(&mut buffer_set.read_buffer))
208                    }
209                    0xFE => {
210                        let eof_bytes = OkPayloadBytes(payload);
211                        self.handler.resultset_end(eof_bytes)?;
212                        self.state = BulkExecState::Finished;
213                        Ok(Action::Finished)
214                    }
215                    header => Err(Error::LibraryBug(eyre!(
216                        "unexpected row packet header: 0x{:02X}",
217                        header
218                    ))),
219                }
220            }
221
222            BulkExecState::Finished => Err(Error::LibraryBug(eyre!(
223                "BulkExec::step called after finished"
224            ))),
225        }
226    }
227}