Skip to main content

sqlx_sqlserver/
any.rs

1//! Runtime `Any` driver registration for SQL Server.
2//!
3//! The driver can be installed with SQLx `Any`. The current SQL Server port supports SQL batch
4//! execution and the same stable scalar RPC argument types as the native connection.
5
6use crate::{
7    Mssql, MssqlArguments, MssqlColumn, MssqlConnectOptions, MssqlConnection, MssqlQueryResult,
8    MssqlTransactionManager, MssqlType, MssqlTypeInfo,
9};
10use futures_core::future::BoxFuture;
11use futures_core::stream::BoxStream;
12use futures_util::{future, stream, FutureExt, StreamExt};
13use sqlx_core::any::driver::AnyDriver;
14use sqlx_core::any::{
15    AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow,
16    AnyStatement, AnyTypeInfo, AnyTypeInfoKind, AnyValueKind,
17};
18use sqlx_core::arguments::Arguments;
19use sqlx_core::column::Column;
20use sqlx_core::connection::{ConnectOptions, Connection};
21use sqlx_core::database::Database;
22use sqlx_core::ext::ustr::UStr;
23use sqlx_core::row::Row;
24use sqlx_core::sql_str::SqlStr;
25use sqlx_core::statement::Statement;
26use sqlx_core::transaction::TransactionManager;
27use sqlx_core::{Either, Error, HashMap};
28use std::sync::Arc;
29
30/// Installable SQL Server driver for SQLx `Any` connections.
31pub const DRIVER: AnyDriver = AnyDriver::with_migrate::<Mssql>();
32
33impl AnyConnectionBackend for MssqlConnection {
34    fn name(&self) -> &str {
35        <Mssql as Database>::NAME
36    }
37
38    fn close(self: Box<Self>) -> BoxFuture<'static, sqlx_core::Result<()>> {
39        Connection::close(*self).boxed()
40    }
41
42    fn close_hard(self: Box<Self>) -> BoxFuture<'static, sqlx_core::Result<()>> {
43        Connection::close_hard(*self).boxed()
44    }
45
46    fn ping(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
47        Connection::ping(self).boxed()
48    }
49
50    fn begin(&mut self, statement: Option<SqlStr>) -> BoxFuture<'_, sqlx_core::Result<()>> {
51        MssqlTransactionManager::begin(self, statement).boxed()
52    }
53
54    fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
55        MssqlTransactionManager::commit(self).boxed()
56    }
57
58    fn rollback(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
59        MssqlTransactionManager::rollback(self).boxed()
60    }
61
62    fn start_rollback(&mut self) {
63        MssqlTransactionManager::start_rollback(self);
64    }
65
66    fn get_transaction_depth(&self) -> usize {
67        MssqlTransactionManager::get_transaction_depth(self)
68    }
69
70    fn shrink_buffers(&mut self) {
71        Connection::shrink_buffers(self);
72    }
73
74    fn flush(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> {
75        Connection::flush(self).boxed()
76    }
77
78    fn should_flush(&self) -> bool {
79        Connection::should_flush(self)
80    }
81
82    #[cfg(feature = "migrate")]
83    fn as_migrate(
84        &mut self,
85    ) -> sqlx_core::Result<&mut (dyn sqlx_core::migrate::Migrate + Send + 'static)> {
86        Ok(self)
87    }
88
89    fn fetch_many(
90        &mut self,
91        query: SqlStr,
92        _persistent: bool,
93        arguments: Option<AnyArguments>,
94    ) -> BoxStream<'_, sqlx_core::Result<Either<AnyQueryResult, AnyRow>>> {
95        stream::once(async move {
96            let native_arguments = arguments
97                .map(convert_any_arguments)
98                .transpose()
99                .map_err(Error::Encode)?;
100            self.run_execute_sql(query.as_str(), native_arguments.as_ref())
101                .await
102        })
103        .map(|result| match result {
104            Ok(output) => {
105                let column_names = column_names(&output.columns);
106                let rows = output.rows.into_iter().map(move |row| {
107                    AnyRow::map_from(&row, Arc::clone(&column_names)).map(Either::Right)
108                });
109                let done = std::iter::once(Ok(Either::Left(map_result(output.result))));
110                stream::iter(rows.chain(done)).boxed()
111            }
112            Err(error) => stream::once(future::ready(Err(error))).boxed(),
113        })
114        .flatten()
115        .boxed()
116    }
117
118    fn fetch_optional(
119        &mut self,
120        query: SqlStr,
121        _persistent: bool,
122        arguments: Option<AnyArguments>,
123    ) -> BoxFuture<'_, sqlx_core::Result<Option<AnyRow>>> {
124        Box::pin(async move {
125            let native_arguments = arguments
126                .map(convert_any_arguments)
127                .transpose()
128                .map_err(Error::Encode)?;
129            self.run_execute_sql(query.as_str(), native_arguments.as_ref())
130                .await?
131                .rows
132                .into_iter()
133                .next()
134                .map(|row| {
135                    let column_names = column_names(row.columns());
136                    AnyRow::map_from(&row, column_names)
137                })
138                .transpose()
139        })
140    }
141
142    fn prepare_with<'c, 'q: 'c>(
143        &'c mut self,
144        sql: SqlStr,
145        parameters: &[AnyTypeInfo],
146    ) -> BoxFuture<'c, sqlx_core::Result<AnyStatement>> {
147        let parameters = parameters
148            .iter()
149            .map(mssql_type_from_any)
150            .collect::<Result<Vec<_>, _>>();
151
152        Box::pin(async move {
153            let parameters = parameters?;
154            let statement = self.run_prepare(sql.as_str(), &parameters).await?;
155            let statement = crate::MssqlStatement::with_parameters(
156                sql,
157                statement.columns,
158                if parameters.is_empty() {
159                    None
160                } else {
161                    Some(Either::Left(parameters))
162                },
163            );
164            let column_names = column_names(statement.columns());
165            AnyStatement::try_from_statement(statement, column_names)
166        })
167    }
168}
169
170fn mssql_type_from_any(type_info: &AnyTypeInfo) -> Result<MssqlTypeInfo, Error> {
171    Ok(match type_info.kind() {
172        AnyTypeInfoKind::Bool => MssqlTypeInfo::BIT,
173        AnyTypeInfoKind::SmallInt => MssqlTypeInfo::SMALLINT,
174        AnyTypeInfoKind::Integer | AnyTypeInfoKind::Null => MssqlTypeInfo::INT,
175        AnyTypeInfoKind::BigInt => MssqlTypeInfo::BIGINT,
176        AnyTypeInfoKind::Real => MssqlTypeInfo::REAL,
177        AnyTypeInfoKind::Double => MssqlTypeInfo::FLOAT,
178        AnyTypeInfoKind::Text => MssqlTypeInfo::NVARCHAR,
179        AnyTypeInfoKind::Blob => MssqlTypeInfo::VARBINARY,
180    })
181}
182
183fn convert_any_arguments(
184    arguments: AnyArguments,
185) -> Result<MssqlArguments, sqlx_core::error::BoxDynError> {
186    let mut out = MssqlArguments::default();
187
188    for argument in arguments.values.0 {
189        match argument {
190            AnyValueKind::Null(AnyTypeInfoKind::Null) => out.add(Option::<i32>::None),
191            AnyValueKind::Null(AnyTypeInfoKind::Bool) => out.add(Option::<bool>::None),
192            AnyValueKind::Null(AnyTypeInfoKind::SmallInt) => out.add(Option::<i16>::None),
193            AnyValueKind::Null(AnyTypeInfoKind::Integer) => out.add(Option::<i32>::None),
194            AnyValueKind::Null(AnyTypeInfoKind::BigInt) => out.add(Option::<i64>::None),
195            AnyValueKind::Null(AnyTypeInfoKind::Real) => out.add(Option::<f32>::None),
196            AnyValueKind::Null(AnyTypeInfoKind::Double) => out.add(Option::<f64>::None),
197            AnyValueKind::Null(AnyTypeInfoKind::Text) => out.add(Option::<String>::None),
198            AnyValueKind::Null(AnyTypeInfoKind::Blob) => out.add(Option::<Vec<u8>>::None),
199            AnyValueKind::Bool(value) => out.add(value),
200            AnyValueKind::SmallInt(value) => out.add(value),
201            AnyValueKind::Integer(value) => out.add(value),
202            AnyValueKind::BigInt(value) => out.add(value),
203            AnyValueKind::Real(value) => out.add(value),
204            AnyValueKind::Double(value) => out.add(value),
205            AnyValueKind::Text(value) => out.add(value.as_str()),
206            AnyValueKind::TextSlice(value) => out.add(value.as_ref()),
207            AnyValueKind::Blob(value) => out.add(value.as_slice()),
208            other => {
209                return Err(format!(
210                    "SQL Server Any driver does not support argument value {other:?}"
211                )
212                .into());
213            }
214        }?;
215    }
216
217    Ok(out)
218}
219
220fn map_result(result: MssqlQueryResult) -> AnyQueryResult {
221    AnyQueryResult {
222        rows_affected: result.rows_affected(),
223        last_insert_id: None,
224    }
225}
226
227fn column_names(columns: &[MssqlColumn]) -> Arc<HashMap<UStr, usize>> {
228    Arc::new(
229        columns
230            .iter()
231            .map(|column| (UStr::new(column.name()), column.ordinal()))
232            .collect(),
233    )
234}
235
236impl<'a> TryFrom<&'a AnyConnectOptions> for MssqlConnectOptions {
237    type Error = Error;
238
239    fn try_from(options: &'a AnyConnectOptions) -> Result<Self, Self::Error> {
240        MssqlConnectOptions::from_url(&options.database_url)
241    }
242}
243
244impl<'a> TryFrom<&'a MssqlTypeInfo> for AnyTypeInfo {
245    type Error = Error;
246
247    fn try_from(type_info: &'a MssqlTypeInfo) -> Result<Self, Self::Error> {
248        let kind = match type_info.kind() {
249            MssqlType::Null => AnyTypeInfoKind::Null,
250            MssqlType::Bit => AnyTypeInfoKind::Bool,
251            MssqlType::SmallInt => AnyTypeInfoKind::SmallInt,
252            MssqlType::Int => AnyTypeInfoKind::Integer,
253            MssqlType::BigInt => AnyTypeInfoKind::BigInt,
254            MssqlType::Real => AnyTypeInfoKind::Real,
255            MssqlType::Float => AnyTypeInfoKind::Double,
256            MssqlType::NVarChar | MssqlType::VarChar => AnyTypeInfoKind::Text,
257            MssqlType::VarBinary => AnyTypeInfoKind::Blob,
258            MssqlType::TinyInt | MssqlType::Other(_) => {
259                return Err(Error::AnyDriverError(
260                    format!("Any driver does not support the SQL Server type {type_info:?}").into(),
261                ));
262            }
263        };
264
265        Ok(AnyTypeInfo { kind })
266    }
267}
268
269impl<'a> TryFrom<&'a MssqlColumn> for AnyColumn {
270    type Error = Error;
271
272    fn try_from(column: &'a MssqlColumn) -> Result<Self, Self::Error> {
273        let type_info =
274            AnyTypeInfo::try_from(column.type_info()).map_err(|error| Error::ColumnDecode {
275                index: column.name().to_owned(),
276                source: error.into(),
277            })?;
278
279        Ok(Self {
280            ordinal: column.ordinal(),
281            name: UStr::new(column.name()),
282            type_info,
283        })
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Every SQL Server type the `Any` driver supports maps to the expected `Any` kind.
    #[test]
    fn maps_stable_sql_server_types_to_any_types() {
        let cases = [
            (MssqlTypeInfo::BIT, AnyTypeInfoKind::Bool),
            (MssqlTypeInfo::SMALLINT, AnyTypeInfoKind::SmallInt),
            (MssqlTypeInfo::INT, AnyTypeInfoKind::Integer),
            (MssqlTypeInfo::BIGINT, AnyTypeInfoKind::BigInt),
            (MssqlTypeInfo::REAL, AnyTypeInfoKind::Real),
            (MssqlTypeInfo::FLOAT, AnyTypeInfoKind::Double),
            (MssqlTypeInfo::NVARCHAR, AnyTypeInfoKind::Text),
            (MssqlTypeInfo::VARCHAR, AnyTypeInfoKind::Text),
            (MssqlTypeInfo::VARBINARY, AnyTypeInfoKind::Blob),
        ];

        for (mssql, expected) in &cases {
            let actual = AnyTypeInfo::try_from(mssql).unwrap().kind();
            assert_eq!(actual, *expected, "unexpected Any kind for {mssql:?}");
        }
    }

    /// Binding a parameter of any concrete `Any` kind and reading the chosen SQL Server
    /// type back must yield the same kind, so prepared parameters and result columns agree.
    #[test]
    fn any_parameter_types_round_trip_through_sql_server_types() {
        let kinds = [
            AnyTypeInfoKind::Bool,
            AnyTypeInfoKind::SmallInt,
            AnyTypeInfoKind::Integer,
            AnyTypeInfoKind::BigInt,
            AnyTypeInfoKind::Real,
            AnyTypeInfoKind::Double,
            AnyTypeInfoKind::Text,
            AnyTypeInfoKind::Blob,
        ];

        for kind in kinds {
            let mssql = mssql_type_from_any(&AnyTypeInfo { kind }).unwrap();
            let back = AnyTypeInfo::try_from(&mssql).unwrap().kind();
            assert_eq!(back, kind, "round trip changed kind via {mssql:?}");
        }
    }

    /// Untyped `NULL` parameters are bound as `INT`.
    #[test]
    fn untyped_null_parameters_bind_as_int() {
        let mssql = mssql_type_from_any(&AnyTypeInfo {
            kind: AnyTypeInfoKind::Null,
        })
        .unwrap();
        assert_eq!(
            AnyTypeInfo::try_from(&mssql).unwrap().kind(),
            AnyTypeInfoKind::Integer
        );
    }

    #[test]
    fn rejects_unstable_sql_server_types_for_any_mapping() {
        assert!(matches!(
            AnyTypeInfo::try_from(&MssqlTypeInfo::TINYINT),
            Err(Error::AnyDriverError(_))
        ));
    }
}