sqlx_core/mssql/connection/executor.rs

1use crate::describe::Describe;
2use crate::error::Error;
3use crate::executor::{Execute, Executor};
4use crate::logger::QueryLogger;
5use crate::mssql::connection::prepare::prepare;
6use crate::mssql::protocol::col_meta_data::Flags;
7use crate::mssql::protocol::done::Status;
8use crate::mssql::protocol::message::Message;
9use crate::mssql::protocol::packet::PacketType;
10use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest};
11use crate::mssql::protocol::sql_batch::SqlBatch;
12use crate::mssql::{
13    Mssql, MssqlArguments, MssqlConnection, MssqlQueryResult, MssqlRow, MssqlStatement,
14    MssqlTypeInfo,
15};
16use either::Either;
17use futures_core::future::BoxFuture;
18use futures_core::stream::BoxStream;
19use futures_util::TryStreamExt;
20use std::borrow::Cow;
21use std::sync::Arc;
22
23impl MssqlConnection {
24    async fn run(&mut self, query: &str, arguments: Option<MssqlArguments>) -> Result<(), Error> {
25        self.stream.wait_until_ready().await?;
26        self.stream.pending_done_count += 1;
27
28        if let Some(mut arguments) = arguments {
29            let proc = Either::Right(Procedure::ExecuteSql);
30            let mut proc_args = MssqlArguments::default();
31
32            // SQL
33            proc_args.add_unnamed(query);
34
35            if !arguments.data.is_empty() {
36                // Declarations
37                //  NAME TYPE, NAME TYPE, ...
38                proc_args.add_unnamed(&*arguments.declarations);
39
40                // Add the list of SQL parameters _after_ our RPC parameters
41                proc_args.append(&mut arguments);
42            }
43
44            self.stream.write_packet(
45                PacketType::Rpc,
46                RpcRequest {
47                    transaction_descriptor: self.stream.transaction_descriptor,
48                    arguments: &proc_args,
49                    procedure: proc,
50                    options: OptionFlags::empty(),
51                },
52            );
53        } else {
54            self.stream.write_packet(
55                PacketType::SqlBatch,
56                SqlBatch {
57                    transaction_descriptor: self.stream.transaction_descriptor,
58                    sql: query,
59                },
60            );
61        }
62
63        self.stream.flush().await?;
64
65        Ok(())
66    }
67}
68
69impl<'c> Executor<'c> for &'c mut MssqlConnection {
70    type Database = Mssql;
71
72    fn fetch_many<'e, 'q: 'e, E: 'q>(
73        self,
74        mut query: E,
75    ) -> BoxStream<'e, Result<Either<MssqlQueryResult, MssqlRow>, Error>>
76    where
77        'c: 'e,
78        E: Execute<'q, Self::Database>,
79    {
80        let sql = query.sql();
81        let arguments = query.take_arguments();
82        let mut logger = QueryLogger::new(sql, self.log_settings.clone());
83
84        Box::pin(try_stream! {
85            self.run(sql, arguments).await?;
86
87            loop {
88                let message = self.stream.recv_message().await?;
89
90                match message {
91                    Message::Row(row) => {
92                        let columns = Arc::clone(&self.stream.columns);
93                        let column_names = Arc::clone(&self.stream.column_names);
94
95                        logger.increment_rows_returned();
96
97                        r#yield!(Either::Right(MssqlRow { row, column_names, columns }));
98                    }
99
100                    Message::Done(done) | Message::DoneProc(done) => {
101                        if !done.status.contains(Status::DONE_MORE) {
102                            self.stream.handle_done(&done);
103                        }
104
105                        if done.status.contains(Status::DONE_COUNT) {
106                            let rows_affected = done.affected_rows;
107                            logger.increase_rows_affected(rows_affected);
108                            r#yield!(Either::Left(MssqlQueryResult {
109                                rows_affected,
110                            }));
111                        }
112
113                        if !done.status.contains(Status::DONE_MORE) {
114                            break;
115                        }
116                    }
117
118                    Message::DoneInProc(done) => {
119                        if done.status.contains(Status::DONE_COUNT) {
120                            let rows_affected = done.affected_rows;
121                            logger.increase_rows_affected(rows_affected);
122                            r#yield!(Either::Left(MssqlQueryResult {
123                                rows_affected,
124                            }));
125                        }
126                    }
127
128                    _ => {}
129                }
130            }
131
132            Ok(())
133        })
134    }
135
136    fn fetch_optional<'e, 'q: 'e, E: 'q>(
137        self,
138        query: E,
139    ) -> BoxFuture<'e, Result<Option<MssqlRow>, Error>>
140    where
141        'c: 'e,
142        E: Execute<'q, Self::Database>,
143    {
144        let mut s = self.fetch_many(query);
145
146        Box::pin(async move {
147            while let Some(v) = s.try_next().await? {
148                if let Either::Right(r) = v {
149                    return Ok(Some(r));
150                }
151            }
152
153            Ok(None)
154        })
155    }
156
157    fn prepare_with<'e, 'q: 'e>(
158        self,
159        sql: &'q str,
160        _parameters: &[MssqlTypeInfo],
161    ) -> BoxFuture<'e, Result<MssqlStatement<'q>, Error>>
162    where
163        'c: 'e,
164    {
165        Box::pin(async move {
166            let metadata = prepare(self, sql).await?;
167
168            Ok(MssqlStatement {
169                sql: Cow::Borrowed(sql),
170                metadata,
171            })
172        })
173    }
174
175    fn describe<'e, 'q: 'e>(
176        self,
177        sql: &'q str,
178    ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
179    where
180        'c: 'e,
181    {
182        Box::pin(async move {
183            let metadata = prepare(self, sql).await?;
184
185            let mut nullable = Vec::with_capacity(metadata.columns.len());
186
187            for col in metadata.columns.iter() {
188                nullable.push(Some(col.flags.contains(Flags::NULLABLE)));
189            }
190
191            Ok(Describe {
192                nullable,
193                columns: (metadata.columns).clone(),
194                parameters: None,
195            })
196        })
197    }
198}