sqlx_xugu/connection/
executor.rs

1use crate::error::Error;
2use crate::io::AsyncStreamExt;
3use crate::protocol::message::*;
4use crate::protocol::statement::{Execute as StatementExecute, Prepare, StmtClose};
5use crate::protocol::text::{ColumnFlags, OkPacket, Query};
6use crate::protocol::ServerContext;
7use crate::statement::{XuguStatement, XuguStatementMetadata};
8use crate::{
9    Xugu, XuguArguments, XuguConnection, XuguDatabaseError, XuguQueryResult, XuguRow, XuguTypeInfo,
10};
11use futures_core::future::BoxFuture;
12use futures_core::stream::BoxStream;
13use futures_core::Stream;
14use futures_util::TryStreamExt;
15use log::Level;
16use sqlx_core::describe::Describe;
17use sqlx_core::executor::{Execute, Executor};
18use sqlx_core::logger::QueryLogger;
19use sqlx_core::{try_stream, Either, HashMap};
20use std::{borrow::Cow, pin::pin, sync::Arc};
21
22impl XuguConnection {
23    async fn prepare_statement<'c>(
24        &mut self,
25        sql: &str,
26    ) -> Result<(u32, XuguStatementMetadata), Error> {
27        // flush and wait until we are re-ready
28        self.wait_until_ready().await?;
29
30        let id = self.inner.gen_st_id();
31        self.inner
32            .stream
33            .send_packet(Prepare {
34                query: sql,
35                con_obj_name: &self.inner.con_obj_name,
36                st_id: id,
37            })
38            .await?;
39
40        let mut error = None;
41        let mut columns = Vec::new();
42        let mut column_names = HashMap::new();
43        let mut params = Vec::new();
44
45        loop {
46            let message: ReceivedMessage = self.inner.stream.recv().await?;
47            let cnt = ServerContext::new(self.inner.stream.server_version);
48            match message.format {
49                BackendMessageFormat::ErrorResponse => {
50                    let err: ErrorResponse = message.decode(&mut self.inner.stream, cnt).await?;
51                    error = Some(err.error);
52                }
53                BackendMessageFormat::MessageResponse => {
54                    // 读到服务器端返回消息用对话框抛出
55                    // 警告和信息
56                    let notice: MessageResponse =
57                        message.decode(&mut self.inner.stream, cnt).await?;
58                    let (log_level, tracing_level) = (Level::Info, tracing::Level::INFO);
59                    let log_is_enabled = log::log_enabled!(
60                        target: "sqlx::xugu::notice",
61                        log_level
62                    ) || sqlx_core::private_tracing_dynamic_enabled!(
63                        target: "sqlx::xugu::notice",
64                        tracing_level
65                    );
66                    if log_is_enabled {
67                        sqlx_core::private_tracing_dynamic_event!(
68                            target: "sqlx::xugu::notice",
69                            tracing_level,
70                            message = notice.msg
71                        );
72                    }
73                }
74                BackendMessageFormat::ReadyForQuery => {
75                    let _: ReadyForQuery = message.decode(&mut self.inner.stream, cnt).await?;
76                    break;
77                }
78                BackendMessageFormat::RowDescription => {
79                    let row_columns: RowDescription =
80                        message.decode(&mut self.inner.stream, cnt).await?;
81                    (columns, column_names) = row_columns.convert_columns()?;
82                }
83                BackendMessageFormat::ParameterDescription => {
84                    let param_def: ParameterDescription =
85                        message.decode(&mut self.inner.stream, cnt).await?;
86                    params = param_def.params;
87                }
88                _ => {
89                    break;
90                }
91            }
92        }
93
94        if let Some(err) = error {
95            return Err(Error::Database(Box::new(XuguDatabaseError::from_str(&err))));
96        }
97
98        let metadata = XuguStatementMetadata {
99            parameters: Arc::new(params),
100            columns: Arc::new(columns),
101            column_names: Arc::new(column_names),
102        };
103
104        Ok((id, metadata))
105    }
106
107    async fn get_or_prepare_statement<'c>(
108        &mut self,
109        sql: &str,
110    ) -> Result<(u32, XuguStatementMetadata), Error> {
111        if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
112            // <XuguStatementMetadata> is internally reference-counted
113            return Ok((*statement).clone());
114        }
115
116        let (id, metadata) = self.prepare_statement(sql).await?;
117
118        // in case of the cache being full, close the least recently used statement
119        if let Some((id, _)) = self
120            .inner
121            .cache_statement
122            .insert(sql, (id, metadata.clone()))
123        {
124            // flush and wait until we are re-ready
125            self.wait_until_ready().await?;
126            self.inner
127                .stream
128                .send_packet(StmtClose {
129                    con_obj_name: &self.inner.con_obj_name,
130                    st_id: id,
131                })
132                .await?;
133
134            // for StmtClose
135            let _ok: OkPacket = self.inner.stream.recv().await?;
136        }
137
138        Ok((id, metadata))
139    }
140
141    ///
142    ///
143    /// # Arguments
144    ///
145    /// * `sql`:
146    /// * `arguments`:
147    /// * `persistent`: sql 语句是否需要被缓存
148    ///
149    #[allow(clippy::needless_lifetimes)]
150    pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
151        &'c mut self,
152        sql: &'q str,
153        arguments: Option<XuguArguments<'q>>,
154        persistent: bool,
155    ) -> Result<impl Stream<Item = Result<Either<XuguQueryResult, XuguRow>, Error>> + 'e, Error>
156    {
157        let mut logger = QueryLogger::new(sql, self.inner.log_settings.clone());
158
159        self.wait_until_ready().await?;
160
161        // make a slot for the shared column data
162        // as long as a reference to a row is not held past one iteration, this enables us
163        // to re-use this memory freely between result sets
164        let (mut column_names, mut columns, mut needs_metadata) = if let Some(arguments) = arguments
165        {
166            if persistent && self.inner.cache_statement.is_enabled() {
167                let (id, metadata) = self.get_or_prepare_statement(sql).await?;
168
169                self.inner
170                    .stream
171                    .send_packet(StatementExecute {
172                        con_obj_name: &self.inner.con_obj_name,
173                        st_id: id,
174                        arguments: &arguments,
175                        params: &metadata.parameters,
176                    })
177                    .await?;
178
179                let needs_metadata = metadata.column_names.is_empty();
180                (metadata.column_names, metadata.columns, needs_metadata)
181            } else {
182                let (id, metadata) = self.prepare_statement(sql).await?;
183
184                self.inner
185                    .stream
186                    .send_packet(StatementExecute {
187                        con_obj_name: &self.inner.con_obj_name,
188                        st_id: id,
189                        arguments: &arguments,
190                        params: &metadata.parameters,
191                    })
192                    .await?;
193
194                self.inner
195                    .stream
196                    .send_packet(StmtClose {
197                        con_obj_name: &self.inner.con_obj_name,
198                        st_id: id,
199                    })
200                    .await?;
201                // for StmtClose
202                self.inner.pending_ready_for_query_count += 1;
203
204                let needs_metadata = metadata.column_names.is_empty();
205                (metadata.column_names, metadata.columns, needs_metadata)
206            }
207        } else {
208            self.inner.stream.send_packet(Query(sql)).await?;
209
210            (Arc::default(), Arc::default(), true)
211        };
212
213        self.inner.pending_ready_for_query_count += 1;
214
215        let mut error = None;
216
217        let mut num_columns = 0;
218
219        Ok(try_stream! {
220            loop {
221                let message: ReceivedMessage = self.inner.stream.recv().await?;
222                let cnt = ServerContext::new(self.inner.stream.server_version);
223                match message.format {
224                    BackendMessageFormat::ErrorResponse => {
225                        let err: ErrorResponse = message.decode(&mut self.inner.stream, cnt).await?;
226                        error = Some(err.error);
227                    },
228                    BackendMessageFormat::MessageResponse => {
229                        // 读到服务器端返回消息用对话框抛出
230                        // 警告和信息
231                        let notice: MessageResponse = message.decode(&mut self.inner.stream, cnt).await?;
232                        let (log_level, tracing_level) = (Level::Info, tracing::Level::INFO);
233                        let log_is_enabled = log::log_enabled!(
234                            target: "sqlx::xugu::notice",
235                            log_level
236                        ) || sqlx_core::private_tracing_dynamic_enabled!(
237                            target: "sqlx::xugu::notice",
238                            tracing_level
239                        );
240                        if log_is_enabled {
241                            sqlx_core::private_tracing_dynamic_event!(
242                                target: "sqlx::xugu::notice",
243                                tracing_level,
244                                message = notice.msg
245                            );
246                        }
247                    },
248                    BackendMessageFormat::ReadyForQuery => {
249                        //命令结束 / 错误结束
250                        let _: ReadyForQuery = message.decode(&mut self.inner.stream, cnt).await?;
251                        self.handle_ready_for_query().await?;
252                        break;
253                    },
254                    BackendMessageFormat::InsertResponse => {
255                        let res: InsertResponse = message.decode(&mut self.inner.stream, cnt).await?;
256                        let rows_affected = 1;
257                        logger.increase_rows_affected(rows_affected);
258                        let done = XuguQueryResult {
259                            rows_affected,
260                            last_insert_id: Some(res.rowid),
261                        };
262                        r#yield!(Either::Left(done));
263                    },
264                    BackendMessageFormat::DeleteResponse => {
265                        let res: DeleteResponse = message.decode(&mut self.inner.stream, cnt).await?;
266                        let rows_affected = res.rows_affected as u64;
267                        logger.increase_rows_affected(rows_affected);
268                        let done = XuguQueryResult {
269                            rows_affected,
270                            last_insert_id: None,
271                        };
272                        r#yield!(Either::Left(done));
273                    },
274                    BackendMessageFormat::UpdateResponse => {
275                        let res: UpdateResponse = message.decode(&mut self.inner.stream, cnt).await?;
276                        let rows_affected = res.rows_affected as u64;
277                        logger.increase_rows_affected(rows_affected);
278                        let done = XuguQueryResult {
279                            rows_affected,
280                            last_insert_id: None,
281                        };
282                        r#yield!(Either::Left(done));
283                    },
284                    BackendMessageFormat::RowDescription => {
285                        // 接收列数据
286                        let row_columns: RowDescription = message.decode(&mut self.inner.stream, cnt).await?;
287                        num_columns = row_columns.fields.len();
288                        self.inner.last_num_columns = num_columns;
289                        if needs_metadata {
290                            let (columns_c, column_names_c) = row_columns.convert_columns()?;
291                            columns = Arc::new(columns_c);
292                            column_names = Arc::new(column_names_c);
293                        } else {
294                            // next time we hit here, it'll be a new result set and we'll need the
295                            // full metadata
296                            needs_metadata = true;
297                        }
298                    },
299                    BackendMessageFormat::ParameterDescription => {
300                        let _: ParameterDescription = message.decode(&mut self.inner.stream, cnt).await?;
301                    },
302                    BackendMessageFormat::DataRow => {
303                        // 接收行数据
304                        let _: DataRow = message.decode(&mut self.inner.stream, cnt).await?;
305                        let mut row = Vec::with_capacity(num_columns);
306                        for _ in 0..num_columns {
307                            let len = self.inner.stream.read_i32().await?;
308                            let buf = self.inner.stream.read_bytes(len as usize).await?;
309                            row.push(buf);
310                        }
311                        let row = Arc::new(row);
312
313                        let v = Either::Right(XuguRow {
314                            row,
315                            columns: Arc::clone(&columns),
316                            column_names: Arc::clone(&column_names),
317                        });
318
319                        logger.increment_rows_returned();
320
321                        r#yield!(v);
322                    }
323                }
324            }
325
326            if let Some(err) = error {
327                return Err(Error::Database(Box::new(XuguDatabaseError::from_str(&err))));
328            }
329
330            return Ok(());
331        })
332    }
333}
334
335impl<'c> Executor<'c> for &'c mut XuguConnection {
336    type Database = Xugu;
337
338    /// 执行多个查询,并将生成的结果作为每个查询的流返回。
339    fn fetch_many<'e, 'q, E>(
340        self,
341        mut query: E,
342    ) -> BoxStream<'e, Result<Either<XuguQueryResult, XuguRow>, Error>>
343    where
344        'c: 'e,
345        E: Execute<'q, Self::Database>,
346        'q: 'e,
347        E: 'q,
348    {
349        let sql = query.sql();
350        let arguments = query.take_arguments().map_err(Error::Encode);
351        let persistent = query.persistent();
352
353        Box::pin(try_stream! {
354            let arguments = arguments?;
355            let mut s = pin!(self.run(sql, arguments, persistent).await?);
356
357            while let Some(v) = s.try_next().await? {
358                r#yield!(v);
359            }
360
361            Ok(())
362        })
363    }
364
365    /// 执行查询并最多返回一行。
366    fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result<Option<XuguRow>, Error>>
367    where
368        'c: 'e,
369        E: Execute<'q, Self::Database>,
370        'q: 'e,
371        E: 'q,
372    {
373        let mut s = self.fetch_many(query);
374
375        Box::pin(async move {
376            while let Some(v) = s.try_next().await? {
377                if let Either::Right(r) = v {
378                    return Ok(Some(r));
379                }
380            }
381
382            Ok(None)
383        })
384    }
385
386    /// 准备 SQL 查询,其中包含参数类型信息,以检查有关其参数和结果的类型信息。
387    ///
388    /// 只有某些数据库驱动程序(PostgreSQL、MSSQL)可以利用此额外信息来影响参数类型推断。
389    fn prepare_with<'e, 'q: 'e>(
390        self,
391        sql: &'q str,
392        _parameters: &'e [XuguTypeInfo],
393    ) -> BoxFuture<'e, Result<XuguStatement<'q>, Error>>
394    where
395        'c: 'e,
396    {
397        Box::pin(async move {
398            self.wait_until_ready().await?;
399
400            let metadata = if self.inner.cache_statement.is_enabled() {
401                self.get_or_prepare_statement(sql).await?.1
402            } else {
403                let (id, metadata) = self.prepare_statement(sql).await?;
404
405                self.inner
406                    .stream
407                    .send_packet(StmtClose {
408                        con_obj_name: &self.inner.con_obj_name,
409                        st_id: id,
410                    })
411                    .await?;
412                // for StmtClose
413                let _ok: OkPacket = self.inner.stream.recv().await?;
414
415                metadata
416            };
417
418            Ok(XuguStatement {
419                sql: Cow::Borrowed(sql),
420                // metadata has internal Arcs for expensive data structures
421                metadata: metadata.clone(),
422            })
423        })
424    }
425
426    /// 描述有关其参数和结果的 SQL 查询和返回类型信息。
427    ///
428    /// 查询宏中的编译时验证使用它来支持其类型推断。
429    #[doc(hidden)]
430    fn describe<'e, 'q: 'e>(self, sql: &'q str) -> BoxFuture<'e, Result<Describe<Xugu>, Error>>
431    where
432        'c: 'e,
433    {
434        Box::pin(async move {
435            self.wait_until_ready().await?;
436
437            let (id, metadata) = self.prepare_statement(sql).await?;
438
439            self.inner
440                .stream
441                .send_packet(StmtClose {
442                    con_obj_name: &self.inner.con_obj_name,
443                    st_id: id,
444                })
445                .await?;
446            // for StmtClose
447            let _ok: OkPacket = self.inner.stream.recv().await?;
448
449            let columns = (*metadata.columns).clone();
450
451            let nullable = columns
452                .iter()
453                .map(|col| {
454                    col.flags
455                        .map(|flags| !flags.contains(ColumnFlags::NOT_NULL))
456                })
457                .collect();
458
459            Ok(Describe {
460                parameters: Some(Either::Right(metadata.parameters.len())),
461                columns,
462                nullable,
463            })
464        })
465    }
466}