Skip to main content

sqlx_sqlserver/
connection.rs

1use futures_core::future::BoxFuture;
2use futures_core::stream::BoxStream;
3use futures_util::{future, stream, StreamExt};
4use native_tls::Certificate;
5use sqlx_core::connection::Connection;
6use sqlx_core::decode::Decode;
7use sqlx_core::error::Error;
8use sqlx_core::executor::{Execute, Executor};
9use sqlx_core::transaction::Transaction;
10use sqlx_core::value::Value;
11use sqlx_core::Either;
12use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
13use tokio::net::TcpStream;
14use tokio_native_tls::TlsConnector;
15
16use crate::error::server_error;
17use crate::protocol::login::build_login7_packet;
18use crate::protocol::packet::{PacketHeader, PacketStatus, PacketType, PACKET_HEADER_LEN};
19use crate::protocol::pre_login::{build_pre_login_packet, parse_server_encrypt};
20use crate::protocol::query::{build_sql_batch_packet, parse_query_response, QueryOutput};
21use crate::protocol::rpc::{
22    build_execute_sql_packet, build_prepare_packet, build_unprepare_packet,
23};
24use crate::protocol::token::{parse_login_response, EnvChange, LoginResponse};
25use crate::tls::TlsPreloginStream;
26use crate::{
27    ssrp, Encrypt, Mssql, MssqlArguments, MssqlConnectOptions, MssqlQueryResult, MssqlRow,
28    MssqlStatement, MssqlTypeInfo,
29};
30
31/// SQL Server connection.
32#[derive(Debug)]
33pub struct MssqlConnection {
34    stream: Option<MssqlWireStream>,
35    transaction_depth: usize,
36    transaction_descriptor: u64,
37    pending_rollback_sql: Option<&'static str>,
38}
39
40impl MssqlConnection {
41    /// Establishes a SQL Server TCP connection and completes PRELOGIN and LOGIN7.
42    pub async fn establish(options: &MssqlConnectOptions) -> Result<Self, Error> {
43        let mut stream = MssqlWireStream::connect(options).await?;
44
45        let pre_login = build_pre_login_packet(options).map_err(|error| {
46            Error::Protocol(format!(
47                "failed to build SQL Server PRELOGIN packet: {error}"
48            ))
49        })?;
50        stream
51            .write_all(&pre_login)
52            .await
53            .map_err(|error| context_error("failed to send SQL Server PRELOGIN packet", error))?;
54
55        let pre_login_response = stream
56            .read_message()
57            .await
58            .map_err(|error| context_error("failed to read SQL Server PRELOGIN response", error))?;
59        if pre_login_response.packet_type != PacketType::TABULAR_RESULT {
60            return Err(Error::Protocol(format!(
61                "expected SQL Server PRELOGIN response as tabular result, got packet type 0x{:02x}",
62                pre_login_response.packet_type.code()
63            )));
64        }
65
66        let server_encrypt =
67            parse_server_encrypt(&pre_login_response.payload).map_err(|error| {
68                Error::Protocol(format!(
69                    "failed to parse SQL Server PRELOGIN response: {error}"
70                ))
71            })?;
72        let encrypted = negotiate_encryption(options.encrypt(), server_encrypt)?;
73
74        if encrypted {
75            stream.enable_tls(options).await?;
76        }
77
78        let login = build_login7_packet(options).map_err(|error| {
79            Error::Protocol(format!("failed to build SQL Server LOGIN7 packet: {error}"))
80        })?;
81        stream
82            .write_all(&login)
83            .await
84            .map_err(|error| context_error("failed to send SQL Server LOGIN7 packet", error))?;
85
86        let login_response = stream
87            .read_message()
88            .await
89            .map_err(|error| context_error("failed to read SQL Server LOGIN7 response", error))?;
90        if login_response.packet_type != PacketType::TABULAR_RESULT {
91            return Err(Error::Protocol(format!(
92                "expected SQL Server LOGIN7 response as tabular result, got packet type 0x{:02x}",
93                login_response.packet_type.code()
94            )));
95        }
96
97        match parse_login_response(&login_response.payload).map_err(|error| {
98            Error::Protocol(format!(
99                "failed to parse SQL Server LOGIN7 response: {error}"
100            ))
101        })? {
102            LoginResponse::Success { env_changes, .. } => {
103                let mut conn = Self {
104                    stream: Some(stream),
105                    transaction_depth: 0,
106                    transaction_descriptor: 0,
107                    pending_rollback_sql: None,
108                };
109                conn.apply_env_changes(&env_changes);
110                Ok(conn)
111            }
112            LoginResponse::ServerError(error) => Err(server_error(error)),
113        }
114    }
115
116    fn apply_env_changes(&mut self, env_changes: &[EnvChange]) {
117        for change in env_changes {
118            match change {
119                EnvChange::PacketSize(size) => {
120                    if let Some(stream) = self.stream.as_mut() {
121                        stream.packet_size = (*size).clamp(512, 32767) as usize;
122                    }
123                }
124                EnvChange::BeginTransaction(descriptor) => {
125                    self.transaction_descriptor = *descriptor;
126                }
127                EnvChange::CommitTransaction(_) | EnvChange::RollbackTransaction(_) => {
128                    self.transaction_descriptor = 0;
129                }
130                _ => {}
131            }
132        }
133    }
134
135    /// Returns the current transaction depth tracked by the connection.
136    pub const fn transaction_depth(&self) -> usize {
137        self.transaction_depth
138    }
139
140    pub(crate) fn increment_transaction_depth(&mut self) {
141        self.transaction_depth += 1;
142    }
143
144    pub(crate) fn decrement_transaction_depth(&mut self) {
145        self.transaction_depth = self.transaction_depth.saturating_sub(1);
146    }
147
148    pub(crate) fn clear_transaction_depth(&mut self) {
149        self.transaction_depth = 0;
150    }
151
152    pub(crate) async fn run_sql_batch(&mut self, sql: &str) -> Result<QueryOutput, Error> {
153        self.flush_pending_rollback().await?;
154        self.run_sql_batch_direct(sql).await
155    }
156
157    async fn run_sql_batch_direct(&mut self, sql: &str) -> Result<QueryOutput, Error> {
158        let transaction_descriptor = self.transaction_descriptor;
159        let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
160        let packet = build_sql_batch_packet(sql, stream.packet_size, transaction_descriptor)
161            .map_err(frame_error)?;
162        stream
163            .write_all(&packet)
164            .await
165            .map_err(|error| context_error("failed to send SQL Server SQL batch packet", error))?;
166
167        self.read_query_response("SQL batch").await
168    }
169
170    pub(crate) fn queue_rollback(&mut self) {
171        let sql = match self.transaction_depth {
172            0 => return,
173            1 => {
174                self.transaction_depth = 0;
175                "ROLLBACK TRANSACTION"
176            }
177            _ => {
178                self.transaction_depth -= 1;
179                "ROLLBACK TRANSACTION sqlx_savepoint"
180            }
181        };
182
183        self.pending_rollback_sql = Some(sql);
184    }
185
186    async fn flush_pending_rollback(&mut self) -> Result<(), Error> {
187        let Some(sql) = self.pending_rollback_sql.take() else {
188            return Ok(());
189        };
190
191        self.run_sql_batch_direct(sql).await?;
192        Ok(())
193    }
194
195    pub(crate) async fn run_execute_sql(
196        &mut self,
197        sql: &str,
198        arguments: Option<&MssqlArguments>,
199    ) -> Result<QueryOutput, Error> {
200        self.flush_pending_rollback().await?;
201
202        match arguments {
203            Some(arguments) if !arguments.is_empty() => {
204                let transaction_descriptor = self.transaction_descriptor;
205                let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
206                let packet = build_execute_sql_packet(
207                    sql,
208                    arguments,
209                    stream.packet_size,
210                    transaction_descriptor,
211                )
212                .map_err(|error| {
213                    Error::Protocol(format!("failed to encode SQL Server RPC: {error}"))
214                })?;
215                stream.write_all(&packet).await.map_err(|error| {
216                    context_error("failed to send SQL Server RPC execute packet", error)
217                })?;
218                self.read_query_response("RPC execute").await
219            }
220            _ => self.run_sql_batch_direct(sql).await,
221        }
222    }
223
224    pub(crate) async fn run_prepare(
225        &mut self,
226        sql: &str,
227        parameters: &[MssqlTypeInfo],
228    ) -> Result<QueryOutput, Error> {
229        self.flush_pending_rollback().await?;
230
231        let transaction_descriptor = self.transaction_descriptor;
232        let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
233        let packet =
234            build_prepare_packet(sql, parameters, stream.packet_size, transaction_descriptor)
235                .map_err(|error| {
236                    Error::Protocol(format!("failed to encode SQL Server prepare RPC: {error}"))
237                })?;
238        stream.write_all(&packet).await.map_err(|error| {
239            context_error("failed to send SQL Server prepare RPC packet", error)
240        })?;
241
242        let output = self.read_query_response("prepare RPC").await?;
243
244        if let Some(statement_id) = first_i32_return_value(&output)? {
245            let transaction_descriptor = self.transaction_descriptor;
246            let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
247            let packet =
248                build_unprepare_packet(statement_id, stream.packet_size, transaction_descriptor)
249                    .map_err(|error| {
250                        Error::Protocol(format!(
251                            "failed to encode SQL Server unprepare RPC: {error}"
252                        ))
253                    })?;
254            stream.write_all(&packet).await.map_err(|error| {
255                context_error("failed to send SQL Server unprepare RPC packet", error)
256            })?;
257            let _ = self.read_query_response("unprepare RPC").await?;
258        }
259
260        Ok(output)
261    }
262
263    async fn read_query_response(&mut self, operation: &'static str) -> Result<QueryOutput, Error> {
264        let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
265        let response = stream.read_message().await.map_err(|error| {
266            context_error(
267                format!("failed to read SQL Server {operation} response"),
268                error,
269            )
270        })?;
271        if response.packet_type != PacketType::TABULAR_RESULT {
272            return Err(Error::Protocol(format!(
273                "expected SQL Server query response as tabular result, got packet type 0x{:02x}",
274                response.packet_type.code()
275            )));
276        }
277
278        let output = parse_query_response(&response.payload).map_err(|error| {
279            Error::Protocol(format!(
280                "failed to parse SQL Server {operation} response: {error}"
281            ))
282        })?;
283        self.apply_env_changes(&output.env_changes);
284        Ok(output)
285    }
286}
287
288impl Connection for MssqlConnection {
289    type Database = Mssql;
290    type Options = MssqlConnectOptions;
291
292    async fn close(mut self) -> Result<(), Error> {
293        self.flush_pending_rollback().await?;
294
295        if let Some(mut stream) = self.stream.take() {
296            stream.shutdown().await?;
297        }
298
299        Ok(())
300    }
301
302    async fn close_hard(mut self) -> Result<(), Error> {
303        if let Some(mut stream) = self.stream.take() {
304            stream.shutdown().await?;
305        }
306
307        Ok(())
308    }
309
310    async fn ping(&mut self) -> Result<(), Error> {
311        self.flush_pending_rollback().await?;
312
313        if self.stream.is_some() {
314            Ok(())
315        } else {
316            Err(wire_not_implemented())
317        }
318    }
319
320    fn begin(
321        &mut self,
322    ) -> impl std::future::Future<Output = Result<Transaction<'_, Self::Database>, Error>> + Send + '_
323    {
324        Transaction::begin(self, None)
325    }
326
327    fn shrink_buffers(&mut self) {}
328
329    async fn flush(&mut self) -> Result<(), Error> {
330        Ok(())
331    }
332
333    fn should_flush(&self) -> bool {
334        false
335    }
336}
337
338impl<'c> Executor<'c> for &'c mut MssqlConnection {
339    type Database = Mssql;
340
341    fn fetch_many<'e, 'q, E>(
342        self,
343        mut query: E,
344    ) -> BoxStream<'e, Result<Either<MssqlQueryResult, MssqlRow>, Error>>
345    where
346        'c: 'e,
347        E: Execute<'q, Self::Database>,
348        'q: 'e,
349        E: 'q,
350    {
351        let arguments = query.take_arguments().map_err(Error::Encode);
352        let sql = query.sql();
353
354        stream::once(async move {
355            let arguments = arguments?;
356            self.run_execute_sql(sql.as_str(), arguments.as_ref()).await
357        })
358        .map(|result| match result {
359            Ok(output) => stream_query_output(output),
360            Err(error) => stream::once(future::ready(Err(error))).boxed(),
361        })
362        .flatten()
363        .boxed()
364    }
365
366    fn fetch_optional<'e, 'q, E>(
367        self,
368        mut query: E,
369    ) -> BoxFuture<'e, Result<Option<MssqlRow>, Error>>
370    where
371        'c: 'e,
372        E: Execute<'q, Self::Database>,
373        'q: 'e,
374        E: 'q,
375    {
376        let arguments = query.take_arguments().map_err(Error::Encode);
377        let sql = query.sql();
378
379        Box::pin(async move {
380            let arguments = arguments?;
381            Ok(self
382                .run_execute_sql(sql.as_str(), arguments.as_ref())
383                .await?
384                .rows
385                .into_iter()
386                .next())
387        })
388    }
389
390    fn prepare_with<'e>(
391        self,
392        sql: sqlx_core::sql_str::SqlStr,
393        parameters: &'e [crate::MssqlTypeInfo],
394    ) -> BoxFuture<'e, Result<MssqlStatement, Error>>
395    where
396        'c: 'e,
397    {
398        Box::pin(async move {
399            let output = self.run_prepare(sql.as_str(), parameters).await?;
400            let parameters = if parameters.is_empty() {
401                None
402            } else {
403                Some(Either::Left(parameters.to_vec()))
404            };
405
406            Ok(MssqlStatement::with_parameters(
407                sql,
408                output.columns,
409                parameters,
410            ))
411        })
412    }
413}
414
415fn first_i32_return_value(output: &QueryOutput) -> Result<Option<i32>, Error> {
416    output
417        .return_values
418        .first()
419        .map(|value| {
420            <i32 as Decode<Mssql>>::decode(value.as_ref()).map_err(|error| Error::ColumnDecode {
421                index: "return value".to_owned(),
422                source: error,
423            })
424        })
425        .transpose()
426}
427
428pub(crate) fn wire_not_implemented() -> Error {
429    Error::Protocol("SQL Server connection stream is not available".to_owned())
430}
431
432struct MssqlWireStream {
433    stream: MssqlStream,
434    packet_size: usize,
435}
436
437impl std::fmt::Debug for MssqlWireStream {
438    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
439        f.debug_struct("MssqlWireStream")
440            .field("encrypted", &matches!(self.stream, MssqlStream::Tls(_)))
441            .field("packet_size", &self.packet_size)
442            .finish()
443    }
444}
445
446enum MssqlStream {
447    Raw(TcpStream),
448    Tls(tokio_native_tls::TlsStream<TlsPreloginStream<TcpStream>>),
449    Taken,
450}
451
452impl MssqlWireStream {
453    async fn connect(options: &MssqlConnectOptions) -> Result<Self, Error> {
454        let port = match (options.port(), options.instance()) {
455            (Some(port), _) => port,
456            (None, Some(instance)) => ssrp::resolve_instance_port(options.host(), instance).await?,
457            (None, None) => 1433,
458        };
459
460        let stream = TcpStream::connect((options.host(), port))
461            .await
462            .map_err(|error| {
463                Error::Io(std::io::Error::new(
464                    error.kind(),
465                    format!(
466                        "failed to connect to SQL Server at {}:{port}{}: {error}",
467                        options.host(),
468                        options
469                            .instance()
470                            .map(|instance| format!(" (instance={instance})"))
471                            .unwrap_or_default()
472                    ),
473                ))
474            })?;
475        let packet_size = usize::try_from(options.requested_packet_size()).map_err(|_| {
476            Error::Protocol(format!(
477                "SQL Server packet size {} does not fit usize",
478                options.requested_packet_size()
479            ))
480        })?;
481
482        Ok(Self {
483            stream: MssqlStream::Raw(stream),
484            packet_size,
485        })
486    }
487
488    async fn write_all(&mut self, bytes: &[u8]) -> Result<(), Error> {
489        match &mut self.stream {
490            MssqlStream::Raw(stream) => {
491                write_tds_packets(stream, bytes).await?;
492            }
493            MssqlStream::Tls(stream) => {
494                write_tds_packets(stream, bytes).await?;
495            }
496            MssqlStream::Taken => return Err(taken_stream_error()),
497        }
498        Ok(())
499    }
500
501    async fn shutdown(&mut self) -> Result<(), Error> {
502        match &mut self.stream {
503            MssqlStream::Raw(stream) => stream.shutdown().await?,
504            MssqlStream::Tls(stream) => stream.shutdown().await?,
505            MssqlStream::Taken => return Err(taken_stream_error()),
506        }
507        Ok(())
508    }
509
510    async fn enable_tls(&mut self, options: &MssqlConnectOptions) -> Result<(), Error> {
511        let stream = match std::mem::replace(&mut self.stream, MssqlStream::Taken) {
512            MssqlStream::Raw(stream) => stream,
513            other => {
514                self.stream = other;
515                return Ok(());
516            }
517        };
518
519        let mut stream = TlsPreloginStream::new(stream);
520        stream.start_handshake();
521
522        let domain = options
523            .hostname_in_certificate()
524            .unwrap_or_else(|| options.host());
525        let connector = build_tls_connector(options)?;
526        let mut stream = connector
527            .connect(domain, stream)
528            .await
529            .map_err(|error| {
530                Error::Tls(
531                    std::io::Error::other(format!(
532                        "SQL Server TLS handshake failed for host `{}` during the TDS PRELOGIN encryption upgrade \
533                         (encrypt={:?}, trust_server_certificate={}, hostname_in_certificate={}, ssl_root_cert={}): {}",
534                        domain,
535                        options.encrypt(),
536                        options.trust_server_certificate(),
537                        options.hostname_in_certificate().unwrap_or("<not set>"),
538                        options.ssl_root_cert().is_some(),
539                        error
540                    ))
541                    .into(),
542                )
543            })?;
544        stream.get_mut().get_mut().get_mut().finish_handshake();
545
546        self.stream = MssqlStream::Tls(stream);
547        Ok(())
548    }
549
550    async fn read_message(&mut self) -> Result<WireMessage, Error> {
551        let mut packet_type = None;
552        let mut expected_packet_id = None;
553        let mut payload = Vec::new();
554
555        loop {
556            let mut header_bytes = [0u8; PACKET_HEADER_LEN];
557            self.read_exact(&mut header_bytes).await?;
558            let header = PacketHeader::decode(&header_bytes).map_err(packet_error)?;
559
560            if let Some(packet_type) = packet_type {
561                if header.packet_type != packet_type {
562                    return Err(Error::Protocol(format!(
563                        "mismatched SQL Server packet type: expected 0x{:02x}, got 0x{:02x}",
564                        packet_type.code(),
565                        header.packet_type.code()
566                    )));
567                }
568            } else {
569                packet_type = Some(header.packet_type);
570            }
571
572            if let Some(packet_id) = expected_packet_id {
573                if header.packet_id != packet_id {
574                    return Err(Error::Protocol(format!(
575                        "non-contiguous SQL Server packet id: expected {packet_id}, got {}",
576                        header.packet_id
577                    )));
578                }
579            }
580
581            let packet_len = usize::from(header.length);
582            if packet_len > self.packet_size {
583                return Err(Error::Protocol(format!(
584                    "SQL Server packet length {packet_len} exceeds negotiated packet size {}",
585                    self.packet_size
586                )));
587            }
588
589            let payload_len = packet_len.checked_sub(PACKET_HEADER_LEN).ok_or_else(|| {
590                Error::Protocol("SQL Server packet header length underflow".to_owned())
591            })?;
592            let old_len = payload.len();
593            payload.resize(old_len + payload_len, 0);
594            self.read_exact(&mut payload[old_len..]).await?;
595
596            expected_packet_id = Some(header.packet_id.wrapping_add(1));
597
598            if header.status == PacketStatus::END_OF_MESSAGE {
599                return Ok(WireMessage {
600                    packet_type: packet_type.expect("packet_type is set after first header"),
601                    payload,
602                });
603            }
604        }
605    }
606
607    async fn read_exact(&mut self, bytes: &mut [u8]) -> Result<(), Error> {
608        match &mut self.stream {
609            MssqlStream::Raw(stream) => {
610                stream.read_exact(bytes).await?;
611            }
612            MssqlStream::Tls(stream) => {
613                stream.read_exact(bytes).await?;
614            }
615            MssqlStream::Taken => return Err(taken_stream_error()),
616        }
617
618        Ok(())
619    }
620}
621
622async fn write_tds_packets<S>(stream: &mut S, bytes: &[u8]) -> Result<(), Error>
623where
624    S: AsyncWrite + Unpin,
625{
626    let mut offset = 0usize;
627
628    while offset < bytes.len() {
629        let packet = tds_packet_slice(bytes, offset)?;
630        stream.write_all(packet).await?;
631        offset += packet.len();
632    }
633
634    stream.flush().await?;
635    Ok(())
636}
637
638fn tds_packet_slice(bytes: &[u8], offset: usize) -> Result<&[u8], Error> {
639    let header_end = offset
640        .checked_add(PACKET_HEADER_LEN)
641        .ok_or_else(|| Error::Protocol("SQL Server outbound packet offset overflow".to_owned()))?;
642    let header_bytes = bytes.get(offset..header_end).ok_or_else(|| {
643        Error::Protocol("SQL Server outbound packet buffer ended inside a header".to_owned())
644    })?;
645    let header = PacketHeader::decode(header_bytes).map_err(packet_error)?;
646    let packet_len = usize::from(header.length);
647    let packet_end = offset
648        .checked_add(packet_len)
649        .ok_or_else(|| Error::Protocol("SQL Server outbound packet length overflow".to_owned()))?;
650
651    bytes.get(offset..packet_end).ok_or_else(|| {
652        Error::Protocol("SQL Server outbound packet buffer ended inside a packet".to_owned())
653    })
654}
655
656#[derive(Debug)]
657struct WireMessage {
658    packet_type: PacketType,
659    payload: Vec<u8>,
660}
661
662fn negotiate_encryption(requested: Encrypt, server: Encrypt) -> std::result::Result<bool, Error> {
663    match (requested, server) {
664        (Encrypt::NotSupported, Encrypt::NotSupported | Encrypt::Off) => Ok(false),
665        (Encrypt::NotSupported, Encrypt::On | Encrypt::Required) => Err(Error::Protocol(
666            "SQL Server requires encryption, but the client URL requested encrypt=not_supported"
667                .to_owned(),
668        )),
669        (Encrypt::Required, Encrypt::Off | Encrypt::NotSupported) => Err(Error::Tls(
670            "SQL Server TLS encryption is required but not supported by the server".into(),
671        )),
672        (Encrypt::On | Encrypt::Required, Encrypt::On | Encrypt::Required) => Ok(true),
673        (Encrypt::Off, _) | (_, Encrypt::Off) => Err(Error::Protocol(
674            "SQL Server login-only TLS fallback is not implemented yet; use encrypt=mandatory or encrypt=strict for encrypted connections, or encrypt=not_supported for plaintext development servers"
675                .to_owned(),
676        )),
677        (Encrypt::On, Encrypt::NotSupported) => Ok(false),
678    }
679}
680
681fn build_tls_connector(options: &MssqlConnectOptions) -> Result<TlsConnector, Error> {
682    let mut builder = native_tls::TlsConnector::builder();
683    builder.danger_accept_invalid_certs(options.trust_server_certificate());
684    builder.danger_accept_invalid_hostnames(options.hostname_in_certificate().is_none());
685
686    if let Some(path) = options.ssl_root_cert() {
687        let cert = std::fs::read(path).map_err(|error| {
688            Error::Io(std::io::Error::new(
689                error.kind(),
690                format!(
691                    "failed to read SQL Server ssl_root_cert `{}`: {error}",
692                    path.display()
693                ),
694            ))
695        })?;
696        let cert = Certificate::from_pem(&cert)
697            .or_else(|_| Certificate::from_der(&cert))
698            .map_err(|error| {
699                Error::Tls(
700                    format!(
701                        "failed to parse SQL Server ssl_root_cert `{}` as PEM or DER: {error}",
702                        path.display()
703                    )
704                    .into(),
705                )
706            })?;
707        builder.add_root_certificate(cert);
708    }
709
710    builder.build().map(TlsConnector::from).map_err(|error| {
711        Error::Tls(format!("failed to build SQL Server TLS connector: {error}").into())
712    })
713}
714
715fn taken_stream_error() -> Error {
716    Error::Protocol("SQL Server stream was used while TLS upgrade was in progress".to_owned())
717}
718
719fn packet_error(error: crate::protocol::packet::PacketHeaderError) -> Error {
720    Error::Protocol(error.to_string())
721}
722
723fn frame_error(error: crate::protocol::packet::PacketFrameError) -> Error {
724    Error::Protocol(error.to_string())
725}
726
727fn context_error(context: impl Into<String>, error: Error) -> Error {
728    let context = context.into();
729
730    match error {
731        Error::Io(error) => Error::Io(std::io::Error::new(
732            error.kind(),
733            format!("{context}: {error}"),
734        )),
735        Error::Tls(error) => Error::Tls(format!("{context}: {error}").into()),
736        Error::Protocol(message) => Error::Protocol(format!("{context}: {message}")),
737        error => Error::Protocol(format!("{context}: {error}")),
738    }
739}
740
741fn stream_query_output(
742    output: QueryOutput,
743) -> BoxStream<'static, Result<Either<MssqlQueryResult, MssqlRow>, Error>> {
744    stream::iter(
745        output
746            .rows
747            .into_iter()
748            .map(|row| Ok(Either::Right(row)))
749            .chain(std::iter::once(Ok(Either::Left(output.result)))),
750    )
751    .boxed()
752}
753
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn negotiates_full_tls_for_required_or_mandatory_encryption() {
        let cases = [
            (Encrypt::On, Encrypt::On),
            (Encrypt::Required, Encrypt::Required),
        ];
        for (client, server) in cases {
            assert!(negotiate_encryption(client, server).unwrap());
        }
    }

    #[test]
    fn allows_plaintext_only_when_explicitly_requested_and_supported() {
        let plaintext = negotiate_encryption(Encrypt::NotSupported, Encrypt::Off).unwrap();
        assert!(!plaintext);

        let refused = negotiate_encryption(Encrypt::NotSupported, Encrypt::Required);
        assert!(refused.is_err());
    }

    #[test]
    fn rejects_login_only_tls_fallback_until_downgrade_is_available() {
        let cases = [(Encrypt::Off, Encrypt::On), (Encrypt::On, Encrypt::Off)];
        for (client, server) in cases {
            assert!(negotiate_encryption(client, server).is_err());
        }
    }

    /// An 11-byte RPC payload framed with a 12-byte packet size splits into
    /// three packets of 12, 12 and 11 bytes.
    #[test]
    fn slices_encoded_outbound_packets_by_header_length() {
        let bytes = crate::protocol::packet::encode_message(PacketType::RPC, &[0; 11], 12).unwrap();

        let mut offset = 0;
        for expected_len in [12, 12, 11] {
            let packet = tds_packet_slice(&bytes, offset).unwrap();
            assert_eq!(expected_len, packet.len());
            offset += packet.len();
        }
    }

    #[test]
    fn rejects_truncated_outbound_packet() {
        let bytes = crate::protocol::packet::encode_message(PacketType::RPC, &[0; 11], 12).unwrap();
        let truncated = &bytes[..bytes.len() - 1];

        let message = tds_packet_slice(truncated, 24).unwrap_err().to_string();
        assert!(message.contains("ended inside a packet"));
    }
}