Skip to main content

sqlx_sqlserver/
connection.rs

1use futures_core::future::BoxFuture;
2use futures_core::stream::BoxStream;
3use futures_util::{future, stream, StreamExt};
4use native_tls::Certificate;
5use sqlx_core::connection::Connection;
6use sqlx_core::decode::Decode;
7use sqlx_core::error::Error;
8use sqlx_core::executor::{Execute, Executor};
9use sqlx_core::transaction::Transaction;
10use sqlx_core::value::Value;
11use sqlx_core::Either;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::TcpStream;
14use tokio_native_tls::TlsConnector;
15
16use crate::protocol::login::{build_login7_packet, Login7Error};
17use crate::protocol::packet::{PacketHeader, PacketStatus, PacketType, PACKET_HEADER_LEN};
18use crate::protocol::pre_login::{build_pre_login_packet, parse_server_encrypt, PreLoginError};
19use crate::protocol::query::{build_sql_batch_packet, parse_query_response, QueryOutput};
20use crate::protocol::rpc::{
21    build_execute_sql_packet, build_prepare_packet, build_unprepare_packet,
22};
23use crate::protocol::token::{
24    parse_login_response, EnvChange, LoginResponse, ServerError, TokenParseError,
25};
26use crate::tls::TlsPreloginStream;
27use crate::{
28    ssrp, Encrypt, Mssql, MssqlArguments, MssqlConnectOptions, MssqlQueryResult, MssqlRow,
29    MssqlStatement, MssqlTypeInfo,
30};
31
/// SQL Server connection.
///
/// Holds the wire stream to the server together with the transaction state that is
/// sent along with every request built by this connection.
#[derive(Debug)]
pub struct MssqlConnection {
    /// Wire stream to the server; `None` after `close`/`close_hard` has taken it.
    stream: Option<MssqlWireStream>,
    /// Transaction depth counter, adjusted only through the `*_transaction_depth` helpers.
    transaction_depth: usize,
    /// Descriptor from the last BeginTransaction ENVCHANGE; reset to 0 on commit or rollback.
    transaction_descriptor: u64,
}
39
40impl MssqlConnection {
41    /// Establishes a SQL Server TCP connection and completes PRELOGIN and LOGIN7.
42    pub async fn establish(options: &MssqlConnectOptions) -> Result<Self, Error> {
43        let mut stream = MssqlWireStream::connect(options).await?;
44
45        let pre_login = build_pre_login_packet(options).map_err(pre_login_error)?;
46        stream.write_all(&pre_login).await?;
47
48        let pre_login_response = stream.read_message().await?;
49        if pre_login_response.packet_type != PacketType::TABULAR_RESULT {
50            return Err(Error::Protocol(format!(
51                "expected SQL Server PRELOGIN response as tabular result, got packet type 0x{:02x}",
52                pre_login_response.packet_type.code()
53            )));
54        }
55
56        let server_encrypt =
57            parse_server_encrypt(&pre_login_response.payload).map_err(pre_login_error)?;
58        let encrypted = negotiate_encryption(options.encrypt(), server_encrypt)?;
59
60        if encrypted {
61            stream.enable_tls(options).await?;
62        }
63
64        let login = build_login7_packet(options).map_err(login_error)?;
65        stream.write_all(&login).await?;
66
67        let login_response = stream.read_message().await?;
68        if login_response.packet_type != PacketType::TABULAR_RESULT {
69            return Err(Error::Protocol(format!(
70                "expected SQL Server LOGIN7 response as tabular result, got packet type 0x{:02x}",
71                login_response.packet_type.code()
72            )));
73        }
74
75        match parse_login_response(&login_response.payload).map_err(token_error)? {
76            LoginResponse::Success { env_changes, .. } => {
77                let mut conn = Self {
78                    stream: Some(stream),
79                    transaction_depth: 0,
80                    transaction_descriptor: 0,
81                };
82                conn.apply_env_changes(&env_changes);
83                Ok(conn)
84            }
85            LoginResponse::ServerError(error) => Err(server_error(error)),
86        }
87    }
88
89    fn apply_env_changes(&mut self, env_changes: &[EnvChange]) {
90        for change in env_changes {
91            match change {
92                EnvChange::PacketSize(size) => {
93                    if let Some(stream) = self.stream.as_mut() {
94                        stream.packet_size = (*size).clamp(512, 32767) as usize;
95                    }
96                }
97                EnvChange::BeginTransaction(descriptor) => {
98                    self.transaction_descriptor = *descriptor;
99                }
100                EnvChange::CommitTransaction(_) | EnvChange::RollbackTransaction(_) => {
101                    self.transaction_descriptor = 0;
102                }
103                _ => {}
104            }
105        }
106    }
107
108    /// Returns the current transaction depth tracked by the connection.
109    pub const fn transaction_depth(&self) -> usize {
110        self.transaction_depth
111    }
112
113    pub(crate) fn increment_transaction_depth(&mut self) {
114        self.transaction_depth += 1;
115    }
116
117    pub(crate) fn decrement_transaction_depth(&mut self) {
118        self.transaction_depth = self.transaction_depth.saturating_sub(1);
119    }
120
121    pub(crate) fn clear_transaction_depth(&mut self) {
122        self.transaction_depth = 0;
123    }
124
125    pub(crate) async fn run_sql_batch(&mut self, sql: &str) -> Result<QueryOutput, Error> {
126        let transaction_descriptor = self.transaction_descriptor;
127        let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
128        let packet = build_sql_batch_packet(sql, stream.packet_size, transaction_descriptor)
129            .map_err(frame_error)?;
130        stream.write_all(&packet).await?;
131
132        self.read_query_response().await
133    }
134
135    pub(crate) async fn run_execute_sql(
136        &mut self,
137        sql: &str,
138        arguments: Option<&MssqlArguments>,
139    ) -> Result<QueryOutput, Error> {
140        match arguments {
141            Some(arguments) if !arguments.is_empty() => {
142                let transaction_descriptor = self.transaction_descriptor;
143                let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
144                let packet = build_execute_sql_packet(
145                    sql,
146                    arguments,
147                    stream.packet_size,
148                    transaction_descriptor,
149                )
150                .map_err(|error| {
151                    Error::Protocol(format!("failed to encode SQL Server RPC: {error}"))
152                })?;
153                stream.write_all(&packet).await?;
154                self.read_query_response().await
155            }
156            _ => self.run_sql_batch(sql).await,
157        }
158    }
159
160    pub(crate) async fn run_prepare(
161        &mut self,
162        sql: &str,
163        parameters: &[MssqlTypeInfo],
164    ) -> Result<QueryOutput, Error> {
165        let transaction_descriptor = self.transaction_descriptor;
166        let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
167        let packet =
168            build_prepare_packet(sql, parameters, stream.packet_size, transaction_descriptor)
169                .map_err(|error| {
170                    Error::Protocol(format!("failed to encode SQL Server prepare RPC: {error}"))
171                })?;
172        stream.write_all(&packet).await?;
173
174        let output = self.read_query_response().await?;
175
176        if let Some(statement_id) = first_i32_return_value(&output)? {
177            let transaction_descriptor = self.transaction_descriptor;
178            let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
179            let packet =
180                build_unprepare_packet(statement_id, stream.packet_size, transaction_descriptor)
181                    .map_err(|error| {
182                        Error::Protocol(format!(
183                            "failed to encode SQL Server unprepare RPC: {error}"
184                        ))
185                    })?;
186            stream.write_all(&packet).await?;
187            let _ = self.read_query_response().await?;
188        }
189
190        Ok(output)
191    }
192
193    async fn read_query_response(&mut self) -> Result<QueryOutput, Error> {
194        let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
195        let response = stream.read_message().await?;
196        if response.packet_type != PacketType::TABULAR_RESULT {
197            return Err(Error::Protocol(format!(
198                "expected SQL Server query response as tabular result, got packet type 0x{:02x}",
199                response.packet_type.code()
200            )));
201        }
202
203        let output = parse_query_response(&response.payload)?;
204        self.apply_env_changes(&output.env_changes);
205        Ok(output)
206    }
207}
208
209impl Connection for MssqlConnection {
210    type Database = Mssql;
211    type Options = MssqlConnectOptions;
212
213    async fn close(mut self) -> Result<(), Error> {
214        if let Some(mut stream) = self.stream.take() {
215            stream.shutdown().await?;
216        }
217
218        Ok(())
219    }
220
221    async fn close_hard(mut self) -> Result<(), Error> {
222        if let Some(mut stream) = self.stream.take() {
223            stream.shutdown().await?;
224        }
225
226        Ok(())
227    }
228
229    async fn ping(&mut self) -> Result<(), Error> {
230        if self.stream.is_some() {
231            Ok(())
232        } else {
233            Err(wire_not_implemented())
234        }
235    }
236
237    fn begin(
238        &mut self,
239    ) -> impl std::future::Future<Output = Result<Transaction<'_, Self::Database>, Error>> + Send + '_
240    {
241        Transaction::begin(self, None)
242    }
243
244    fn shrink_buffers(&mut self) {}
245
246    async fn flush(&mut self) -> Result<(), Error> {
247        Ok(())
248    }
249
250    fn should_flush(&self) -> bool {
251        false
252    }
253}
254
255impl<'c> Executor<'c> for &'c mut MssqlConnection {
256    type Database = Mssql;
257
258    fn fetch_many<'e, 'q, E>(
259        self,
260        mut query: E,
261    ) -> BoxStream<'e, Result<Either<MssqlQueryResult, MssqlRow>, Error>>
262    where
263        'c: 'e,
264        E: Execute<'q, Self::Database>,
265        'q: 'e,
266        E: 'q,
267    {
268        let arguments = query.take_arguments().map_err(Error::Encode);
269        let sql = query.sql();
270
271        stream::once(async move {
272            let arguments = arguments?;
273            self.run_execute_sql(sql.as_str(), arguments.as_ref()).await
274        })
275        .map(|result| match result {
276            Ok(output) => stream_query_output(output),
277            Err(error) => stream::once(future::ready(Err(error))).boxed(),
278        })
279        .flatten()
280        .boxed()
281    }
282
283    fn fetch_optional<'e, 'q, E>(
284        self,
285        mut query: E,
286    ) -> BoxFuture<'e, Result<Option<MssqlRow>, Error>>
287    where
288        'c: 'e,
289        E: Execute<'q, Self::Database>,
290        'q: 'e,
291        E: 'q,
292    {
293        let arguments = query.take_arguments().map_err(Error::Encode);
294        let sql = query.sql();
295
296        Box::pin(async move {
297            let arguments = arguments?;
298            Ok(self
299                .run_execute_sql(sql.as_str(), arguments.as_ref())
300                .await?
301                .rows
302                .into_iter()
303                .next())
304        })
305    }
306
307    fn prepare_with<'e>(
308        self,
309        sql: sqlx_core::sql_str::SqlStr,
310        parameters: &'e [crate::MssqlTypeInfo],
311    ) -> BoxFuture<'e, Result<MssqlStatement, Error>>
312    where
313        'c: 'e,
314    {
315        Box::pin(async move {
316            let output = self.run_prepare(sql.as_str(), parameters).await?;
317            let parameters = if parameters.is_empty() {
318                None
319            } else {
320                Some(Either::Left(parameters.to_vec()))
321            };
322
323            Ok(MssqlStatement::with_parameters(
324                sql,
325                output.columns,
326                parameters,
327            ))
328        })
329    }
330}
331
332fn first_i32_return_value(output: &QueryOutput) -> Result<Option<i32>, Error> {
333    output
334        .return_values
335        .first()
336        .map(|value| {
337            <i32 as Decode<Mssql>>::decode(value.as_ref()).map_err(|error| Error::ColumnDecode {
338                index: "return value".to_owned(),
339                source: error,
340            })
341        })
342        .transpose()
343}
344
345pub(crate) fn wire_not_implemented() -> Error {
346    Error::Protocol("SQL Server connection stream is not available".to_owned())
347}
348
/// Transport to the server plus the packet size used when framing and reading packets.
struct MssqlWireStream {
    /// Underlying transport: plain TCP, TLS, or the `Taken` placeholder.
    stream: MssqlStream,
    /// Size passed to the packet builders and the upper bound enforced on incoming packets.
    packet_size: usize,
}
353
354impl std::fmt::Debug for MssqlWireStream {
355    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356        f.debug_struct("MssqlWireStream")
357            .field("encrypted", &matches!(self.stream, MssqlStream::Tls(_)))
358            .field("packet_size", &self.packet_size)
359            .finish()
360    }
361}
362
/// Transport state of a `MssqlWireStream`.
enum MssqlStream {
    /// Plain TCP stream.
    Raw(TcpStream),
    /// TLS session running over the TCP stream wrapped in `TlsPreloginStream`.
    Tls(tokio_native_tls::TlsStream<TlsPreloginStream<TcpStream>>),
    /// Placeholder left behind while `enable_tls` has moved the raw stream out;
    /// reads, writes and shutdown on it return an error.
    Taken,
}
368
369impl MssqlWireStream {
370    async fn connect(options: &MssqlConnectOptions) -> Result<Self, Error> {
371        let port = match (options.port(), options.instance()) {
372            (Some(port), _) => port,
373            (None, Some(instance)) => ssrp::resolve_instance_port(options.host(), instance).await?,
374            (None, None) => 1433,
375        };
376
377        let stream = TcpStream::connect((options.host(), port)).await?;
378        let packet_size = usize::try_from(options.requested_packet_size()).map_err(|_| {
379            Error::Protocol(format!(
380                "SQL Server packet size {} does not fit usize",
381                options.requested_packet_size()
382            ))
383        })?;
384
385        Ok(Self {
386            stream: MssqlStream::Raw(stream),
387            packet_size,
388        })
389    }
390
391    async fn write_all(&mut self, bytes: &[u8]) -> Result<(), Error> {
392        match &mut self.stream {
393            MssqlStream::Raw(stream) => {
394                stream.write_all(bytes).await?;
395                stream.flush().await?;
396            }
397            MssqlStream::Tls(stream) => {
398                stream.write_all(bytes).await?;
399                stream.flush().await?;
400            }
401            MssqlStream::Taken => return Err(taken_stream_error()),
402        }
403        Ok(())
404    }
405
406    async fn shutdown(&mut self) -> Result<(), Error> {
407        match &mut self.stream {
408            MssqlStream::Raw(stream) => stream.shutdown().await?,
409            MssqlStream::Tls(stream) => stream.shutdown().await?,
410            MssqlStream::Taken => return Err(taken_stream_error()),
411        }
412        Ok(())
413    }
414
415    async fn enable_tls(&mut self, options: &MssqlConnectOptions) -> Result<(), Error> {
416        let stream = match std::mem::replace(&mut self.stream, MssqlStream::Taken) {
417            MssqlStream::Raw(stream) => stream,
418            other => {
419                self.stream = other;
420                return Ok(());
421            }
422        };
423
424        let mut stream = TlsPreloginStream::new(stream);
425        stream.start_handshake();
426
427        let domain = options
428            .hostname_in_certificate()
429            .unwrap_or_else(|| options.host());
430        let connector = build_tls_connector(options)?;
431        let mut stream = connector
432            .connect(domain, stream)
433            .await
434            .map_err(|error| Error::Tls(error.into()))?;
435        stream.get_mut().get_mut().get_mut().finish_handshake();
436
437        self.stream = MssqlStream::Tls(stream);
438        Ok(())
439    }
440
441    async fn read_message(&mut self) -> Result<WireMessage, Error> {
442        let mut packet_type = None;
443        let mut expected_packet_id = None;
444        let mut payload = Vec::new();
445
446        loop {
447            let mut header_bytes = [0u8; PACKET_HEADER_LEN];
448            self.read_exact(&mut header_bytes).await?;
449            let header = PacketHeader::decode(&header_bytes).map_err(packet_error)?;
450
451            if let Some(packet_type) = packet_type {
452                if header.packet_type != packet_type {
453                    return Err(Error::Protocol(format!(
454                        "mismatched SQL Server packet type: expected 0x{:02x}, got 0x{:02x}",
455                        packet_type.code(),
456                        header.packet_type.code()
457                    )));
458                }
459            } else {
460                packet_type = Some(header.packet_type);
461            }
462
463            if let Some(packet_id) = expected_packet_id {
464                if header.packet_id != packet_id {
465                    return Err(Error::Protocol(format!(
466                        "non-contiguous SQL Server packet id: expected {packet_id}, got {}",
467                        header.packet_id
468                    )));
469                }
470            }
471
472            let packet_len = usize::from(header.length);
473            if packet_len > self.packet_size {
474                return Err(Error::Protocol(format!(
475                    "SQL Server packet length {packet_len} exceeds negotiated packet size {}",
476                    self.packet_size
477                )));
478            }
479
480            let payload_len = packet_len.checked_sub(PACKET_HEADER_LEN).ok_or_else(|| {
481                Error::Protocol("SQL Server packet header length underflow".to_owned())
482            })?;
483            let old_len = payload.len();
484            payload.resize(old_len + payload_len, 0);
485            self.read_exact(&mut payload[old_len..]).await?;
486
487            expected_packet_id = Some(header.packet_id.wrapping_add(1));
488
489            if header.status == PacketStatus::END_OF_MESSAGE {
490                return Ok(WireMessage {
491                    packet_type: packet_type.expect("packet_type is set after first header"),
492                    payload,
493                });
494            }
495        }
496    }
497
498    async fn read_exact(&mut self, bytes: &mut [u8]) -> Result<(), Error> {
499        match &mut self.stream {
500            MssqlStream::Raw(stream) => {
501                stream.read_exact(bytes).await?;
502            }
503            MssqlStream::Tls(stream) => {
504                stream.read_exact(bytes).await?;
505            }
506            MssqlStream::Taken => return Err(taken_stream_error()),
507        }
508
509        Ok(())
510    }
511}
512
/// One complete message as reassembled by `MssqlWireStream::read_message`.
#[derive(Debug)]
struct WireMessage {
    /// Packet type shared by every packet of the message.
    packet_type: PacketType,
    /// Payloads of all packets of the message, concatenated, without the packet headers.
    payload: Vec<u8>,
}
518
519fn negotiate_encryption(requested: Encrypt, server: Encrypt) -> std::result::Result<bool, Error> {
520    match (requested, server) {
521        (Encrypt::NotSupported, Encrypt::NotSupported | Encrypt::Off) => Ok(false),
522        (Encrypt::NotSupported, Encrypt::On | Encrypt::Required) => Err(Error::Protocol(
523            "SQL Server requires encryption, but the client URL requested encrypt=not_supported"
524                .to_owned(),
525        )),
526        (Encrypt::Required, Encrypt::Off | Encrypt::NotSupported) => Err(Error::Tls(
527            "SQL Server TLS encryption is required but not supported by the server".into(),
528        )),
529        (Encrypt::On | Encrypt::Required, Encrypt::On | Encrypt::Required) => Ok(true),
530        (Encrypt::Off, _) | (_, Encrypt::Off) => Err(Error::Protocol(
531            "SQL Server login-only TLS fallback is not implemented yet; use encrypt=mandatory or encrypt=strict for encrypted connections, or encrypt=not_supported for plaintext development servers"
532                .to_owned(),
533        )),
534        (Encrypt::On, Encrypt::NotSupported) => Ok(false),
535    }
536}
537
538fn build_tls_connector(options: &MssqlConnectOptions) -> Result<TlsConnector, Error> {
539    let mut builder = native_tls::TlsConnector::builder();
540    builder.danger_accept_invalid_certs(options.trust_server_certificate());
541    builder.danger_accept_invalid_hostnames(options.hostname_in_certificate().is_none());
542
543    if let Some(path) = options.ssl_root_cert() {
544        let cert = std::fs::read(path).map_err(Error::Io)?;
545        let cert = Certificate::from_pem(&cert)
546            .or_else(|_| Certificate::from_der(&cert))
547            .map_err(|error| Error::Tls(error.into()))?;
548        builder.add_root_certificate(cert);
549    }
550
551    builder
552        .build()
553        .map(TlsConnector::from)
554        .map_err(|error| Error::Tls(error.into()))
555}
556
557fn taken_stream_error() -> Error {
558    Error::Protocol("SQL Server stream was used while TLS upgrade was in progress".to_owned())
559}
560
561fn server_error(error: ServerError) -> Error {
562    Error::Protocol(format!(
563        "SQL Server error {} (state {}, class {}): {}",
564        error.number, error.state, error.class, error.message
565    ))
566}
567
568fn packet_error(error: crate::protocol::packet::PacketHeaderError) -> Error {
569    Error::Protocol(error.to_string())
570}
571
572fn pre_login_error(error: PreLoginError) -> Error {
573    Error::Protocol(error.to_string())
574}
575
576fn login_error(error: Login7Error) -> Error {
577    Error::Protocol(error.to_string())
578}
579
580fn token_error(error: TokenParseError) -> Error {
581    Error::Protocol(error.to_string())
582}
583
584fn frame_error(error: crate::protocol::packet::PacketFrameError) -> Error {
585    Error::Protocol(error.to_string())
586}
587
588fn stream_query_output(
589    output: QueryOutput,
590) -> BoxStream<'static, Result<Either<MssqlQueryResult, MssqlRow>, Error>> {
591    stream::iter(
592        output
593            .rows
594            .into_iter()
595            .map(|row| Ok(Either::Right(row)))
596            .chain(std::iter::once(Ok(Either::Left(output.result)))),
597    )
598    .boxed()
599}
600
#[cfg(test)]
mod tests {
    use super::*;

    /// Both sides asking for encryption yields a fully encrypted session.
    #[test]
    fn negotiates_full_tls_for_required_or_mandatory_encryption() {
        let both_on = negotiate_encryption(Encrypt::On, Encrypt::On);
        assert!(both_on.unwrap());

        let both_required = negotiate_encryption(Encrypt::Required, Encrypt::Required);
        assert!(both_required.unwrap());
    }

    /// Plaintext needs an explicit client opt-out and a server that does not insist on TLS.
    #[test]
    fn allows_plaintext_only_when_explicitly_requested_and_supported() {
        let encrypted = negotiate_encryption(Encrypt::NotSupported, Encrypt::Off).unwrap();
        assert!(!encrypted);

        let refused = negotiate_encryption(Encrypt::NotSupported, Encrypt::Required);
        assert!(refused.is_err());
    }

    /// Combinations that would need login-only encryption are rejected.
    #[test]
    fn rejects_login_only_tls_fallback_until_downgrade_is_available() {
        let client_off = negotiate_encryption(Encrypt::Off, Encrypt::On);
        assert!(client_off.is_err());

        let server_off = negotiate_encryption(Encrypt::On, Encrypt::Off);
        assert!(server_off.is_err());
    }
}