Skip to main content

sqlx_xugu/connection/
executor.rs

1use crate::connection::StatementId;
2use crate::error::Error;
3use crate::io::AsyncStreamExt;
4use crate::protocol::message::*;
5use crate::protocol::statement::{Execute as StatementExecute, Prepare, StmtClose};
6use crate::protocol::text::{ColumnFlags, OkPacket, Query};
7use crate::protocol::ServerContext;
8use crate::statement::{XuguStatement, XuguStatementMetadata};
9use crate::{
10    Xugu, XuguArguments, XuguConnection, XuguDatabaseError, XuguQueryResult, XuguRow, XuguTypeInfo,
11};
12use futures_core::future::BoxFuture;
13use futures_core::stream::BoxStream;
14use futures_core::Stream;
15use futures_util::TryStreamExt;
16use log::Level;
17use sqlx_core::describe::Describe;
18use sqlx_core::executor::{Execute, Executor};
19use sqlx_core::logger::QueryLogger;
20use sqlx_core::{try_stream, Either, HashMap};
21use std::{borrow::Cow, pin::pin, sync::Arc};
22
23impl XuguConnection {
24    async fn prepare_statement<'c>(
25        &mut self,
26        sql: &str,
27    ) -> Result<(StatementId, XuguStatementMetadata), Error> {
28        // flush and wait until we are re-ready
29        self.wait_until_ready().await?;
30
31        let id = self.inner.gen_st_id();
32        self.inner
33            .stream
34            .send_packet(Prepare {
35                query: sql,
36                st_id: id,
37            })
38            .await?;
39
40        let mut error = None;
41        let mut columns = Vec::new();
42        let mut column_names = HashMap::new();
43        let mut params = Vec::new();
44
45        loop {
46            let message: ReceivedMessage = self.inner.stream.recv().await?;
47            let cnt = ServerContext::new(self.inner.stream.server_version);
48            match message.format {
49                BackendMessageFormat::ErrorResponse => {
50                    let err: ErrorResponse = message.decode(&mut self.inner.stream, cnt).await?;
51                    error = Some(err.error);
52                }
53                BackendMessageFormat::MessageResponse => {
54                    // 读到服务器端返回消息用对话框抛出
55                    // 警告和信息
56                    let notice: MessageResponse =
57                        message.decode(&mut self.inner.stream, cnt).await?;
58                    let (log_level, tracing_level) = (Level::Info, tracing::Level::INFO);
59                    let log_is_enabled = log::log_enabled!(
60                        target: "sqlx::xugu::notice",
61                        log_level
62                    ) || sqlx_core::private_tracing_dynamic_enabled!(
63                        target: "sqlx::xugu::notice",
64                        tracing_level
65                    );
66                    if log_is_enabled {
67                        sqlx_core::private_tracing_dynamic_event!(
68                            target: "sqlx::xugu::notice",
69                            tracing_level,
70                            message = notice.msg
71                        );
72                    }
73                }
74                BackendMessageFormat::ReadyForQuery => {
75                    let _: ReadyForQuery = message.decode(&mut self.inner.stream, cnt).await?;
76                    break;
77                }
78                BackendMessageFormat::RowDescription => {
79                    let row_columns: RowDescription =
80                        message.decode(&mut self.inner.stream, cnt).await?;
81                    (columns, column_names) = row_columns.convert_columns()?;
82                }
83                BackendMessageFormat::ParameterDescription => {
84                    let param_def: ParameterDescription =
85                        message.decode(&mut self.inner.stream, cnt).await?;
86                    params = param_def.params;
87                }
88                _ => {
89                    break;
90                }
91            }
92        }
93
94        if let Some(err) = error {
95            return Err(Error::Database(Box::new(XuguDatabaseError::from_str(&err))));
96        }
97
98        let metadata = XuguStatementMetadata {
99            parameters: Arc::new(params),
100            columns: Arc::new(columns),
101            column_names: Arc::new(column_names),
102        };
103
104        Ok((id, metadata))
105    }
106
107    async fn get_or_prepare_statement<'c>(
108        &mut self,
109        sql: &str,
110    ) -> Result<(StatementId, XuguStatementMetadata), Error> {
111        if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
112            // <XuguStatementMetadata> is internally reference-counted
113            return Ok((*statement).clone());
114        }
115
116        let (id, metadata) = self.prepare_statement(sql).await?;
117
118        // in case of the cache being full, close the least recently used statement
119        if let Some((id, _)) = self
120            .inner
121            .cache_statement
122            .insert(sql, (id, metadata.clone()))
123        {
124            // flush and wait until we are re-ready
125            self.wait_until_ready().await?;
126            self.inner.stream.send_packet(StmtClose(id)).await?;
127
128            // for StmtClose
129            let _ok: OkPacket = self.inner.stream.recv().await?;
130        }
131
132        Ok((id, metadata))
133    }
134
135    ///
136    ///
137    /// # Arguments
138    ///
139    /// * `sql`:
140    /// * `arguments`:
141    /// * `persistent`: sql 语句是否需要被缓存
142    ///
143    #[allow(clippy::needless_lifetimes)]
144    pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
145        &'c mut self,
146        sql: &'q str,
147        arguments: Option<XuguArguments<'q>>,
148        persistent: bool,
149    ) -> Result<impl Stream<Item = Result<Either<XuguQueryResult, XuguRow>, Error>> + 'e, Error>
150    {
151        let mut logger = QueryLogger::new(sql, self.inner.log_settings.clone());
152
153        self.wait_until_ready().await?;
154
155        // make a slot for the shared column data
156        // as long as a reference to a row is not held past one iteration, this enables us
157        // to re-use this memory freely between result sets
158        let (mut column_names, mut columns, mut needs_metadata) = if let Some(arguments) = arguments
159        {
160            if persistent && self.inner.cache_statement.is_enabled() {
161                let (id, metadata) = self.get_or_prepare_statement(sql).await?;
162
163                self.inner
164                    .stream
165                    .send_packet(StatementExecute {
166                        st_id: id,
167                        arguments: &arguments,
168                        params: &metadata.parameters,
169                    })
170                    .await?;
171
172                let needs_metadata = metadata.column_names.is_empty();
173                (metadata.column_names, metadata.columns, needs_metadata)
174            } else {
175                let (id, metadata) = self.prepare_statement(sql).await?;
176
177                self.inner
178                    .stream
179                    .send_packet(StatementExecute {
180                        st_id: id,
181                        arguments: &arguments,
182                        params: &metadata.parameters,
183                    })
184                    .await?;
185
186                self.inner.stream.send_packet(StmtClose(id)).await?;
187                // for StmtClose
188                self.inner.pending_ready_for_query_count += 1;
189
190                let needs_metadata = metadata.column_names.is_empty();
191                (metadata.column_names, metadata.columns, needs_metadata)
192            }
193        } else {
194            self.inner.stream.send_packet(Query(sql)).await?;
195
196            (Arc::default(), Arc::default(), true)
197        };
198
199        self.inner.pending_ready_for_query_count += 1;
200
201        let mut error = None;
202
203        let mut num_columns = 0;
204
205        Ok(try_stream! {
206            loop {
207                let message: ReceivedMessage = self.inner.stream.recv().await?;
208                let cnt = ServerContext::new(self.inner.stream.server_version);
209                match message.format {
210                    BackendMessageFormat::ErrorResponse => {
211                        let err: ErrorResponse = message.decode(&mut self.inner.stream, cnt).await?;
212                        error = Some(err.error);
213                    },
214                    BackendMessageFormat::MessageResponse => {
215                        // 读到服务器端返回消息用对话框抛出
216                        // 警告和信息
217                        let notice: MessageResponse = message.decode(&mut self.inner.stream, cnt).await?;
218                        let (log_level, tracing_level) = (Level::Info, tracing::Level::INFO);
219                        let log_is_enabled = log::log_enabled!(
220                            target: "sqlx::xugu::notice",
221                            log_level
222                        ) || sqlx_core::private_tracing_dynamic_enabled!(
223                            target: "sqlx::xugu::notice",
224                            tracing_level
225                        );
226                        if log_is_enabled {
227                            sqlx_core::private_tracing_dynamic_event!(
228                                target: "sqlx::xugu::notice",
229                                tracing_level,
230                                message = notice.msg
231                            );
232                        }
233                    },
234                    BackendMessageFormat::ReadyForQuery => {
235                        //命令结束 / 错误结束
236                        let _: ReadyForQuery = message.decode(&mut self.inner.stream, cnt).await?;
237                        self.handle_ready_for_query().await?;
238                        break;
239                    },
240                    BackendMessageFormat::InsertResponse => {
241                        let res: InsertResponse = message.decode(&mut self.inner.stream, cnt).await?;
242                        let rows_affected = 1;
243                        logger.increase_rows_affected(rows_affected);
244                        let done = XuguQueryResult {
245                            rows_affected,
246                            last_insert_id: Some(res.rowid),
247                        };
248                        r#yield!(Either::Left(done));
249                    },
250                    BackendMessageFormat::DeleteResponse => {
251                        let res: DeleteResponse = message.decode(&mut self.inner.stream, cnt).await?;
252                        let rows_affected = res.rows_affected as u64;
253                        logger.increase_rows_affected(rows_affected);
254                        let done = XuguQueryResult {
255                            rows_affected,
256                            last_insert_id: None,
257                        };
258                        r#yield!(Either::Left(done));
259                    },
260                    BackendMessageFormat::UpdateResponse => {
261                        let res: UpdateResponse = message.decode(&mut self.inner.stream, cnt).await?;
262                        let rows_affected = res.rows_affected as u64;
263                        logger.increase_rows_affected(rows_affected);
264                        let done = XuguQueryResult {
265                            rows_affected,
266                            last_insert_id: None,
267                        };
268                        r#yield!(Either::Left(done));
269                    },
270                    BackendMessageFormat::RowDescription => {
271                        // 接收列数据
272                        let row_columns: RowDescription = message.decode(&mut self.inner.stream, cnt).await?;
273                        num_columns = row_columns.fields.len();
274                        self.inner.last_num_columns = num_columns;
275                        if needs_metadata {
276                            let (columns_c, column_names_c) = row_columns.convert_columns()?;
277                            columns = Arc::new(columns_c);
278                            column_names = Arc::new(column_names_c);
279                        } else {
280                            // next time we hit here, it'll be a new result set and we'll need the
281                            // full metadata
282                            needs_metadata = true;
283                        }
284                    },
285                    BackendMessageFormat::ParameterDescription => {
286                        let _: ParameterDescription = message.decode(&mut self.inner.stream, cnt).await?;
287                    },
288                    BackendMessageFormat::DataRow => {
289                        // 接收行数据
290                        let _: DataRow = message.decode(&mut self.inner.stream, cnt).await?;
291                        let mut row = Vec::with_capacity(num_columns);
292                        for _ in 0..num_columns {
293                            let len = self.inner.stream.read_i32().await?;
294                            let buf = self.inner.stream.read_bytes(len as usize).await?;
295                            row.push(buf);
296                        }
297                        let row = Arc::new(row);
298
299                        let v = Either::Right(XuguRow {
300                            row,
301                            columns: Arc::clone(&columns),
302                            column_names: Arc::clone(&column_names),
303                        });
304
305                        logger.increment_rows_returned();
306
307                        r#yield!(v);
308                    }
309                }
310            }
311
312            if let Some(err) = error {
313                return Err(Error::Database(Box::new(XuguDatabaseError::from_str(&err))));
314            }
315
316            return Ok(());
317        })
318    }
319}
320
321impl<'c> Executor<'c> for &'c mut XuguConnection {
322    type Database = Xugu;
323
324    /// 执行多个查询,并将生成的结果作为每个查询的流返回。
325    fn fetch_many<'e, 'q, E>(
326        self,
327        mut query: E,
328    ) -> BoxStream<'e, Result<Either<XuguQueryResult, XuguRow>, Error>>
329    where
330        'c: 'e,
331        E: Execute<'q, Self::Database>,
332        'q: 'e,
333        E: 'q,
334    {
335        let sql = query.sql();
336        let arguments = query.take_arguments().map_err(Error::Encode);
337        let persistent = query.persistent();
338
339        Box::pin(try_stream! {
340            let arguments = arguments?;
341            let mut s = pin!(self.run(sql, arguments, persistent).await?);
342
343            while let Some(v) = s.try_next().await? {
344                r#yield!(v);
345            }
346
347            Ok(())
348        })
349    }
350
351    /// 执行查询并最多返回一行。
352    fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result<Option<XuguRow>, Error>>
353    where
354        'c: 'e,
355        E: Execute<'q, Self::Database>,
356        'q: 'e,
357        E: 'q,
358    {
359        let mut s = self.fetch_many(query);
360
361        Box::pin(async move {
362            while let Some(v) = s.try_next().await? {
363                if let Either::Right(r) = v {
364                    return Ok(Some(r));
365                }
366            }
367
368            Ok(None)
369        })
370    }
371
372    /// 准备 SQL 查询,其中包含参数类型信息,以检查有关其参数和结果的类型信息。
373    ///
374    /// 只有某些数据库驱动程序(PostgreSQL、MSSQL)可以利用此额外信息来影响参数类型推断。
375    fn prepare_with<'e, 'q: 'e>(
376        self,
377        sql: &'q str,
378        _parameters: &'e [XuguTypeInfo],
379    ) -> BoxFuture<'e, Result<XuguStatement<'q>, Error>>
380    where
381        'c: 'e,
382    {
383        Box::pin(async move {
384            self.wait_until_ready().await?;
385
386            let metadata = if self.inner.cache_statement.is_enabled() {
387                self.get_or_prepare_statement(sql).await?.1
388            } else {
389                let (id, metadata) = self.prepare_statement(sql).await?;
390
391                self.inner.stream.send_packet(StmtClose(id)).await?;
392                // for StmtClose
393                let _ok: OkPacket = self.inner.stream.recv().await?;
394
395                metadata
396            };
397
398            Ok(XuguStatement {
399                sql: Cow::Borrowed(sql),
400                // metadata has internal Arcs for expensive data structures
401                metadata: metadata.clone(),
402            })
403        })
404    }
405
406    /// 描述有关其参数和结果的 SQL 查询和返回类型信息。
407    ///
408    /// 查询宏中的编译时验证使用它来支持其类型推断。
409    #[doc(hidden)]
410    fn describe<'e, 'q: 'e>(self, sql: &'q str) -> BoxFuture<'e, Result<Describe<Xugu>, Error>>
411    where
412        'c: 'e,
413    {
414        Box::pin(async move {
415            self.wait_until_ready().await?;
416
417            let (id, metadata) = self.prepare_statement(sql).await?;
418
419            self.inner.stream.send_packet(StmtClose(id)).await?;
420            // for StmtClose
421            let _ok: OkPacket = self.inner.stream.recv().await?;
422
423            let columns = (*metadata.columns).clone();
424
425            let nullable = columns
426                .iter()
427                .map(|col| {
428                    col.flags
429                        .map(|flags| !flags.contains(ColumnFlags::NOT_NULL))
430                })
431                .collect();
432
433            Ok(Describe {
434                parameters: Some(Either::Right(metadata.parameters.len())),
435                columns,
436                nullable,
437            })
438        })
439    }
440}