1use futures_core::future::BoxFuture;
2use futures_core::stream::BoxStream;
3use futures_util::{future, stream, StreamExt};
4use native_tls::Certificate;
5use sqlx_core::connection::Connection;
6use sqlx_core::decode::Decode;
7use sqlx_core::error::Error;
8use sqlx_core::executor::{Execute, Executor};
9use sqlx_core::transaction::Transaction;
10use sqlx_core::value::Value;
11use sqlx_core::Either;
12use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
13use tokio::net::TcpStream;
14use tokio_native_tls::TlsConnector;
15
16use crate::error::server_error;
17use crate::protocol::login::build_login7_packet;
18use crate::protocol::packet::{PacketHeader, PacketStatus, PacketType, PACKET_HEADER_LEN};
19use crate::protocol::pre_login::{build_pre_login_packet, parse_server_encrypt};
20use crate::protocol::query::{build_sql_batch_packet, parse_query_response, QueryOutput};
21use crate::protocol::rpc::{
22 build_execute_sql_packet, build_prepare_packet, build_unprepare_packet,
23};
24use crate::protocol::token::{parse_login_response, EnvChange, LoginResponse};
25use crate::tls::TlsPreloginStream;
26use crate::{
27 ssrp, Encrypt, Mssql, MssqlArguments, MssqlConnectOptions, MssqlQueryResult, MssqlRow,
28 MssqlStatement, MssqlTypeInfo,
29};
30
31#[derive(Debug)]
33pub struct MssqlConnection {
34 stream: Option<MssqlWireStream>,
35 transaction_depth: usize,
36 transaction_descriptor: u64,
37 pending_rollback_sql: Option<&'static str>,
38}
39
40impl MssqlConnection {
41 pub async fn establish(options: &MssqlConnectOptions) -> Result<Self, Error> {
43 let mut stream = MssqlWireStream::connect(options).await?;
44
45 let pre_login = build_pre_login_packet(options).map_err(|error| {
46 Error::Protocol(format!(
47 "failed to build SQL Server PRELOGIN packet: {error}"
48 ))
49 })?;
50 stream
51 .write_all(&pre_login)
52 .await
53 .map_err(|error| context_error("failed to send SQL Server PRELOGIN packet", error))?;
54
55 let pre_login_response = stream
56 .read_message()
57 .await
58 .map_err(|error| context_error("failed to read SQL Server PRELOGIN response", error))?;
59 if pre_login_response.packet_type != PacketType::TABULAR_RESULT {
60 return Err(Error::Protocol(format!(
61 "expected SQL Server PRELOGIN response as tabular result, got packet type 0x{:02x}",
62 pre_login_response.packet_type.code()
63 )));
64 }
65
66 let server_encrypt =
67 parse_server_encrypt(&pre_login_response.payload).map_err(|error| {
68 Error::Protocol(format!(
69 "failed to parse SQL Server PRELOGIN response: {error}"
70 ))
71 })?;
72 let encrypted = negotiate_encryption(options.encrypt(), server_encrypt)?;
73
74 if encrypted {
75 stream.enable_tls(options).await?;
76 }
77
78 let login = build_login7_packet(options).map_err(|error| {
79 Error::Protocol(format!("failed to build SQL Server LOGIN7 packet: {error}"))
80 })?;
81 stream
82 .write_all(&login)
83 .await
84 .map_err(|error| context_error("failed to send SQL Server LOGIN7 packet", error))?;
85
86 let login_response = stream
87 .read_message()
88 .await
89 .map_err(|error| context_error("failed to read SQL Server LOGIN7 response", error))?;
90 if login_response.packet_type != PacketType::TABULAR_RESULT {
91 return Err(Error::Protocol(format!(
92 "expected SQL Server LOGIN7 response as tabular result, got packet type 0x{:02x}",
93 login_response.packet_type.code()
94 )));
95 }
96
97 match parse_login_response(&login_response.payload).map_err(|error| {
98 Error::Protocol(format!(
99 "failed to parse SQL Server LOGIN7 response: {error}"
100 ))
101 })? {
102 LoginResponse::Success { env_changes, .. } => {
103 let mut conn = Self {
104 stream: Some(stream),
105 transaction_depth: 0,
106 transaction_descriptor: 0,
107 pending_rollback_sql: None,
108 };
109 conn.apply_env_changes(&env_changes);
110 Ok(conn)
111 }
112 LoginResponse::ServerError(error) => Err(server_error(error)),
113 }
114 }
115
116 fn apply_env_changes(&mut self, env_changes: &[EnvChange]) {
117 for change in env_changes {
118 match change {
119 EnvChange::PacketSize(size) => {
120 if let Some(stream) = self.stream.as_mut() {
121 stream.packet_size = (*size).clamp(512, 32767) as usize;
122 }
123 }
124 EnvChange::BeginTransaction(descriptor) => {
125 self.transaction_descriptor = *descriptor;
126 }
127 EnvChange::CommitTransaction(_) | EnvChange::RollbackTransaction(_) => {
128 self.transaction_descriptor = 0;
129 }
130 _ => {}
131 }
132 }
133 }
134
135 pub const fn transaction_depth(&self) -> usize {
137 self.transaction_depth
138 }
139
140 pub(crate) fn increment_transaction_depth(&mut self) {
141 self.transaction_depth += 1;
142 }
143
144 pub(crate) fn decrement_transaction_depth(&mut self) {
145 self.transaction_depth = self.transaction_depth.saturating_sub(1);
146 }
147
148 pub(crate) fn clear_transaction_depth(&mut self) {
149 self.transaction_depth = 0;
150 }
151
152 pub(crate) async fn run_sql_batch(&mut self, sql: &str) -> Result<QueryOutput, Error> {
153 self.flush_pending_rollback().await?;
154 self.run_sql_batch_direct(sql).await
155 }
156
157 async fn run_sql_batch_direct(&mut self, sql: &str) -> Result<QueryOutput, Error> {
158 let transaction_descriptor = self.transaction_descriptor;
159 let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
160 let packet = build_sql_batch_packet(sql, stream.packet_size, transaction_descriptor)
161 .map_err(frame_error)?;
162 stream
163 .write_all(&packet)
164 .await
165 .map_err(|error| context_error("failed to send SQL Server SQL batch packet", error))?;
166
167 self.read_query_response("SQL batch").await
168 }
169
170 pub(crate) fn queue_rollback(&mut self) {
171 let sql = match self.transaction_depth {
172 0 => return,
173 1 => {
174 self.transaction_depth = 0;
175 "ROLLBACK TRANSACTION"
176 }
177 _ => {
178 self.transaction_depth -= 1;
179 "ROLLBACK TRANSACTION sqlx_savepoint"
180 }
181 };
182
183 self.pending_rollback_sql = Some(sql);
184 }
185
186 async fn flush_pending_rollback(&mut self) -> Result<(), Error> {
187 let Some(sql) = self.pending_rollback_sql.take() else {
188 return Ok(());
189 };
190
191 self.run_sql_batch_direct(sql).await?;
192 Ok(())
193 }
194
195 pub(crate) async fn run_execute_sql(
196 &mut self,
197 sql: &str,
198 arguments: Option<&MssqlArguments>,
199 ) -> Result<QueryOutput, Error> {
200 self.flush_pending_rollback().await?;
201
202 match arguments {
203 Some(arguments) if !arguments.is_empty() => {
204 let transaction_descriptor = self.transaction_descriptor;
205 let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
206 let packet = build_execute_sql_packet(
207 sql,
208 arguments,
209 stream.packet_size,
210 transaction_descriptor,
211 )
212 .map_err(|error| {
213 Error::Protocol(format!("failed to encode SQL Server RPC: {error}"))
214 })?;
215 stream.write_all(&packet).await.map_err(|error| {
216 context_error("failed to send SQL Server RPC execute packet", error)
217 })?;
218 self.read_query_response("RPC execute").await
219 }
220 _ => self.run_sql_batch_direct(sql).await,
221 }
222 }
223
224 pub(crate) async fn run_prepare(
225 &mut self,
226 sql: &str,
227 parameters: &[MssqlTypeInfo],
228 ) -> Result<QueryOutput, Error> {
229 self.flush_pending_rollback().await?;
230
231 let transaction_descriptor = self.transaction_descriptor;
232 let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
233 let packet =
234 build_prepare_packet(sql, parameters, stream.packet_size, transaction_descriptor)
235 .map_err(|error| {
236 Error::Protocol(format!("failed to encode SQL Server prepare RPC: {error}"))
237 })?;
238 stream.write_all(&packet).await.map_err(|error| {
239 context_error("failed to send SQL Server prepare RPC packet", error)
240 })?;
241
242 let output = self.read_query_response("prepare RPC").await?;
243
244 if let Some(statement_id) = first_i32_return_value(&output)? {
245 let transaction_descriptor = self.transaction_descriptor;
246 let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
247 let packet =
248 build_unprepare_packet(statement_id, stream.packet_size, transaction_descriptor)
249 .map_err(|error| {
250 Error::Protocol(format!(
251 "failed to encode SQL Server unprepare RPC: {error}"
252 ))
253 })?;
254 stream.write_all(&packet).await.map_err(|error| {
255 context_error("failed to send SQL Server unprepare RPC packet", error)
256 })?;
257 let _ = self.read_query_response("unprepare RPC").await?;
258 }
259
260 Ok(output)
261 }
262
263 async fn read_query_response(&mut self, operation: &'static str) -> Result<QueryOutput, Error> {
264 let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
265 let response = stream.read_message().await.map_err(|error| {
266 context_error(
267 format!("failed to read SQL Server {operation} response"),
268 error,
269 )
270 })?;
271 if response.packet_type != PacketType::TABULAR_RESULT {
272 return Err(Error::Protocol(format!(
273 "expected SQL Server query response as tabular result, got packet type 0x{:02x}",
274 response.packet_type.code()
275 )));
276 }
277
278 let output = parse_query_response(&response.payload).map_err(|error| {
279 Error::Protocol(format!(
280 "failed to parse SQL Server {operation} response: {error}"
281 ))
282 })?;
283 self.apply_env_changes(&output.env_changes);
284 Ok(output)
285 }
286}
287
288impl Connection for MssqlConnection {
289 type Database = Mssql;
290 type Options = MssqlConnectOptions;
291
292 async fn close(mut self) -> Result<(), Error> {
293 self.flush_pending_rollback().await?;
294
295 if let Some(mut stream) = self.stream.take() {
296 stream.shutdown().await?;
297 }
298
299 Ok(())
300 }
301
302 async fn close_hard(mut self) -> Result<(), Error> {
303 if let Some(mut stream) = self.stream.take() {
304 stream.shutdown().await?;
305 }
306
307 Ok(())
308 }
309
310 async fn ping(&mut self) -> Result<(), Error> {
311 self.flush_pending_rollback().await?;
312
313 if self.stream.is_some() {
314 Ok(())
315 } else {
316 Err(wire_not_implemented())
317 }
318 }
319
320 fn begin(
321 &mut self,
322 ) -> impl std::future::Future<Output = Result<Transaction<'_, Self::Database>, Error>> + Send + '_
323 {
324 Transaction::begin(self, None)
325 }
326
327 fn shrink_buffers(&mut self) {}
328
329 async fn flush(&mut self) -> Result<(), Error> {
330 Ok(())
331 }
332
333 fn should_flush(&self) -> bool {
334 false
335 }
336}
337
338impl<'c> Executor<'c> for &'c mut MssqlConnection {
339 type Database = Mssql;
340
341 fn fetch_many<'e, 'q, E>(
342 self,
343 mut query: E,
344 ) -> BoxStream<'e, Result<Either<MssqlQueryResult, MssqlRow>, Error>>
345 where
346 'c: 'e,
347 E: Execute<'q, Self::Database>,
348 'q: 'e,
349 E: 'q,
350 {
351 let arguments = query.take_arguments().map_err(Error::Encode);
352 let sql = query.sql();
353
354 stream::once(async move {
355 let arguments = arguments?;
356 self.run_execute_sql(sql.as_str(), arguments.as_ref()).await
357 })
358 .map(|result| match result {
359 Ok(output) => stream_query_output(output),
360 Err(error) => stream::once(future::ready(Err(error))).boxed(),
361 })
362 .flatten()
363 .boxed()
364 }
365
366 fn fetch_optional<'e, 'q, E>(
367 self,
368 mut query: E,
369 ) -> BoxFuture<'e, Result<Option<MssqlRow>, Error>>
370 where
371 'c: 'e,
372 E: Execute<'q, Self::Database>,
373 'q: 'e,
374 E: 'q,
375 {
376 let arguments = query.take_arguments().map_err(Error::Encode);
377 let sql = query.sql();
378
379 Box::pin(async move {
380 let arguments = arguments?;
381 Ok(self
382 .run_execute_sql(sql.as_str(), arguments.as_ref())
383 .await?
384 .rows
385 .into_iter()
386 .next())
387 })
388 }
389
390 fn prepare_with<'e>(
391 self,
392 sql: sqlx_core::sql_str::SqlStr,
393 parameters: &'e [crate::MssqlTypeInfo],
394 ) -> BoxFuture<'e, Result<MssqlStatement, Error>>
395 where
396 'c: 'e,
397 {
398 Box::pin(async move {
399 let output = self.run_prepare(sql.as_str(), parameters).await?;
400 let parameters = if parameters.is_empty() {
401 None
402 } else {
403 Some(Either::Left(parameters.to_vec()))
404 };
405
406 Ok(MssqlStatement::with_parameters(
407 sql,
408 output.columns,
409 parameters,
410 ))
411 })
412 }
413}
414
415fn first_i32_return_value(output: &QueryOutput) -> Result<Option<i32>, Error> {
416 output
417 .return_values
418 .first()
419 .map(|value| {
420 <i32 as Decode<Mssql>>::decode(value.as_ref()).map_err(|error| Error::ColumnDecode {
421 index: "return value".to_owned(),
422 source: error,
423 })
424 })
425 .transpose()
426}
427
428pub(crate) fn wire_not_implemented() -> Error {
429 Error::Protocol("SQL Server connection stream is not available".to_owned())
430}
431
432struct MssqlWireStream {
433 stream: MssqlStream,
434 packet_size: usize,
435}
436
437impl std::fmt::Debug for MssqlWireStream {
438 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
439 f.debug_struct("MssqlWireStream")
440 .field("encrypted", &matches!(self.stream, MssqlStream::Tls(_)))
441 .field("packet_size", &self.packet_size)
442 .finish()
443 }
444}
445
446enum MssqlStream {
447 Raw(TcpStream),
448 Tls(tokio_native_tls::TlsStream<TlsPreloginStream<TcpStream>>),
449 Taken,
450}
451
452impl MssqlWireStream {
453 async fn connect(options: &MssqlConnectOptions) -> Result<Self, Error> {
454 let port = match (options.port(), options.instance()) {
455 (Some(port), _) => port,
456 (None, Some(instance)) => ssrp::resolve_instance_port(options.host(), instance).await?,
457 (None, None) => 1433,
458 };
459
460 let stream = TcpStream::connect((options.host(), port))
461 .await
462 .map_err(|error| {
463 Error::Io(std::io::Error::new(
464 error.kind(),
465 format!(
466 "failed to connect to SQL Server at {}:{port}{}: {error}",
467 options.host(),
468 options
469 .instance()
470 .map(|instance| format!(" (instance={instance})"))
471 .unwrap_or_default()
472 ),
473 ))
474 })?;
475 let packet_size = usize::try_from(options.requested_packet_size()).map_err(|_| {
476 Error::Protocol(format!(
477 "SQL Server packet size {} does not fit usize",
478 options.requested_packet_size()
479 ))
480 })?;
481
482 Ok(Self {
483 stream: MssqlStream::Raw(stream),
484 packet_size,
485 })
486 }
487
488 async fn write_all(&mut self, bytes: &[u8]) -> Result<(), Error> {
489 match &mut self.stream {
490 MssqlStream::Raw(stream) => {
491 write_tds_packets(stream, bytes).await?;
492 }
493 MssqlStream::Tls(stream) => {
494 write_tds_packets(stream, bytes).await?;
495 }
496 MssqlStream::Taken => return Err(taken_stream_error()),
497 }
498 Ok(())
499 }
500
501 async fn shutdown(&mut self) -> Result<(), Error> {
502 match &mut self.stream {
503 MssqlStream::Raw(stream) => stream.shutdown().await?,
504 MssqlStream::Tls(stream) => stream.shutdown().await?,
505 MssqlStream::Taken => return Err(taken_stream_error()),
506 }
507 Ok(())
508 }
509
510 async fn enable_tls(&mut self, options: &MssqlConnectOptions) -> Result<(), Error> {
511 let stream = match std::mem::replace(&mut self.stream, MssqlStream::Taken) {
512 MssqlStream::Raw(stream) => stream,
513 other => {
514 self.stream = other;
515 return Ok(());
516 }
517 };
518
519 let mut stream = TlsPreloginStream::new(stream);
520 stream.start_handshake();
521
522 let domain = options
523 .hostname_in_certificate()
524 .unwrap_or_else(|| options.host());
525 let connector = build_tls_connector(options)?;
526 let mut stream = connector
527 .connect(domain, stream)
528 .await
529 .map_err(|error| {
530 Error::Tls(
531 std::io::Error::other(format!(
532 "SQL Server TLS handshake failed for host `{}` during the TDS PRELOGIN encryption upgrade \
533 (encrypt={:?}, trust_server_certificate={}, hostname_in_certificate={}, ssl_root_cert={}): {}",
534 domain,
535 options.encrypt(),
536 options.trust_server_certificate(),
537 options.hostname_in_certificate().unwrap_or("<not set>"),
538 options.ssl_root_cert().is_some(),
539 error
540 ))
541 .into(),
542 )
543 })?;
544 stream.get_mut().get_mut().get_mut().finish_handshake();
545
546 self.stream = MssqlStream::Tls(stream);
547 Ok(())
548 }
549
550 async fn read_message(&mut self) -> Result<WireMessage, Error> {
551 let mut packet_type = None;
552 let mut expected_packet_id = None;
553 let mut payload = Vec::new();
554
555 loop {
556 let mut header_bytes = [0u8; PACKET_HEADER_LEN];
557 self.read_exact(&mut header_bytes).await?;
558 let header = PacketHeader::decode(&header_bytes).map_err(packet_error)?;
559
560 if let Some(packet_type) = packet_type {
561 if header.packet_type != packet_type {
562 return Err(Error::Protocol(format!(
563 "mismatched SQL Server packet type: expected 0x{:02x}, got 0x{:02x}",
564 packet_type.code(),
565 header.packet_type.code()
566 )));
567 }
568 } else {
569 packet_type = Some(header.packet_type);
570 }
571
572 if let Some(packet_id) = expected_packet_id {
573 if header.packet_id != packet_id {
574 return Err(Error::Protocol(format!(
575 "non-contiguous SQL Server packet id: expected {packet_id}, got {}",
576 header.packet_id
577 )));
578 }
579 }
580
581 let packet_len = usize::from(header.length);
582 if packet_len > self.packet_size {
583 return Err(Error::Protocol(format!(
584 "SQL Server packet length {packet_len} exceeds negotiated packet size {}",
585 self.packet_size
586 )));
587 }
588
589 let payload_len = packet_len.checked_sub(PACKET_HEADER_LEN).ok_or_else(|| {
590 Error::Protocol("SQL Server packet header length underflow".to_owned())
591 })?;
592 let old_len = payload.len();
593 payload.resize(old_len + payload_len, 0);
594 self.read_exact(&mut payload[old_len..]).await?;
595
596 expected_packet_id = Some(header.packet_id.wrapping_add(1));
597
598 if header.status == PacketStatus::END_OF_MESSAGE {
599 return Ok(WireMessage {
600 packet_type: packet_type.expect("packet_type is set after first header"),
601 payload,
602 });
603 }
604 }
605 }
606
607 async fn read_exact(&mut self, bytes: &mut [u8]) -> Result<(), Error> {
608 match &mut self.stream {
609 MssqlStream::Raw(stream) => {
610 stream.read_exact(bytes).await?;
611 }
612 MssqlStream::Tls(stream) => {
613 stream.read_exact(bytes).await?;
614 }
615 MssqlStream::Taken => return Err(taken_stream_error()),
616 }
617
618 Ok(())
619 }
620}
621
622async fn write_tds_packets<S>(stream: &mut S, bytes: &[u8]) -> Result<(), Error>
623where
624 S: AsyncWrite + Unpin,
625{
626 let mut offset = 0usize;
627
628 while offset < bytes.len() {
629 let packet = tds_packet_slice(bytes, offset)?;
630 stream.write_all(packet).await?;
631 offset += packet.len();
632 }
633
634 stream.flush().await?;
635 Ok(())
636}
637
638fn tds_packet_slice(bytes: &[u8], offset: usize) -> Result<&[u8], Error> {
639 let header_end = offset
640 .checked_add(PACKET_HEADER_LEN)
641 .ok_or_else(|| Error::Protocol("SQL Server outbound packet offset overflow".to_owned()))?;
642 let header_bytes = bytes.get(offset..header_end).ok_or_else(|| {
643 Error::Protocol("SQL Server outbound packet buffer ended inside a header".to_owned())
644 })?;
645 let header = PacketHeader::decode(header_bytes).map_err(packet_error)?;
646 let packet_len = usize::from(header.length);
647 let packet_end = offset
648 .checked_add(packet_len)
649 .ok_or_else(|| Error::Protocol("SQL Server outbound packet length overflow".to_owned()))?;
650
651 bytes.get(offset..packet_end).ok_or_else(|| {
652 Error::Protocol("SQL Server outbound packet buffer ended inside a packet".to_owned())
653 })
654}
655
656#[derive(Debug)]
657struct WireMessage {
658 packet_type: PacketType,
659 payload: Vec<u8>,
660}
661
662fn negotiate_encryption(requested: Encrypt, server: Encrypt) -> std::result::Result<bool, Error> {
663 match (requested, server) {
664 (Encrypt::NotSupported, Encrypt::NotSupported | Encrypt::Off) => Ok(false),
665 (Encrypt::NotSupported, Encrypt::On | Encrypt::Required) => Err(Error::Protocol(
666 "SQL Server requires encryption, but the client URL requested encrypt=not_supported"
667 .to_owned(),
668 )),
669 (Encrypt::Required, Encrypt::Off | Encrypt::NotSupported) => Err(Error::Tls(
670 "SQL Server TLS encryption is required but not supported by the server".into(),
671 )),
672 (Encrypt::On | Encrypt::Required, Encrypt::On | Encrypt::Required) => Ok(true),
673 (Encrypt::Off, _) | (_, Encrypt::Off) => Err(Error::Protocol(
674 "SQL Server login-only TLS fallback is not implemented yet; use encrypt=mandatory or encrypt=strict for encrypted connections, or encrypt=not_supported for plaintext development servers"
675 .to_owned(),
676 )),
677 (Encrypt::On, Encrypt::NotSupported) => Ok(false),
678 }
679}
680
681fn build_tls_connector(options: &MssqlConnectOptions) -> Result<TlsConnector, Error> {
682 let mut builder = native_tls::TlsConnector::builder();
683 builder.danger_accept_invalid_certs(options.trust_server_certificate());
684 builder.danger_accept_invalid_hostnames(options.hostname_in_certificate().is_none());
685
686 if let Some(path) = options.ssl_root_cert() {
687 let cert = std::fs::read(path).map_err(|error| {
688 Error::Io(std::io::Error::new(
689 error.kind(),
690 format!(
691 "failed to read SQL Server ssl_root_cert `{}`: {error}",
692 path.display()
693 ),
694 ))
695 })?;
696 let cert = Certificate::from_pem(&cert)
697 .or_else(|_| Certificate::from_der(&cert))
698 .map_err(|error| {
699 Error::Tls(
700 format!(
701 "failed to parse SQL Server ssl_root_cert `{}` as PEM or DER: {error}",
702 path.display()
703 )
704 .into(),
705 )
706 })?;
707 builder.add_root_certificate(cert);
708 }
709
710 builder.build().map(TlsConnector::from).map_err(|error| {
711 Error::Tls(format!("failed to build SQL Server TLS connector: {error}").into())
712 })
713}
714
715fn taken_stream_error() -> Error {
716 Error::Protocol("SQL Server stream was used while TLS upgrade was in progress".to_owned())
717}
718
719fn packet_error(error: crate::protocol::packet::PacketHeaderError) -> Error {
720 Error::Protocol(error.to_string())
721}
722
723fn frame_error(error: crate::protocol::packet::PacketFrameError) -> Error {
724 Error::Protocol(error.to_string())
725}
726
727fn context_error(context: impl Into<String>, error: Error) -> Error {
728 let context = context.into();
729
730 match error {
731 Error::Io(error) => Error::Io(std::io::Error::new(
732 error.kind(),
733 format!("{context}: {error}"),
734 )),
735 Error::Tls(error) => Error::Tls(format!("{context}: {error}").into()),
736 Error::Protocol(message) => Error::Protocol(format!("{context}: {message}")),
737 error => Error::Protocol(format!("{context}: {error}")),
738 }
739}
740
741fn stream_query_output(
742 output: QueryOutput,
743) -> BoxStream<'static, Result<Either<MssqlQueryResult, MssqlRow>, Error>> {
744 stream::iter(
745 output
746 .rows
747 .into_iter()
748 .map(|row| Ok(Either::Right(row)))
749 .chain(std::iter::once(Ok(Either::Left(output.result)))),
750 )
751 .boxed()
752}
753
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn negotiates_full_tls_for_required_or_mandatory_encryption() {
        let pairs = [
            (Encrypt::On, Encrypt::On),
            (Encrypt::Required, Encrypt::Required),
        ];
        for (client, server) in pairs {
            assert!(negotiate_encryption(client, server).unwrap());
        }
    }

    #[test]
    fn allows_plaintext_only_when_explicitly_requested_and_supported() {
        let plaintext = negotiate_encryption(Encrypt::NotSupported, Encrypt::Off);
        assert!(!plaintext.unwrap());

        let refused = negotiate_encryption(Encrypt::NotSupported, Encrypt::Required);
        assert!(refused.is_err());
    }

    #[test]
    fn rejects_login_only_tls_fallback_until_downgrade_is_available() {
        let pairs = [(Encrypt::Off, Encrypt::On), (Encrypt::On, Encrypt::Off)];
        for (client, server) in pairs {
            assert!(negotiate_encryption(client, server).is_err());
        }
    }

    #[test]
    fn slices_encoded_outbound_packets_by_header_length() {
        let bytes = crate::protocol::packet::encode_message(PacketType::RPC, &[0; 11], 12).unwrap();

        let mut offset = 0;
        for expected_len in [12, 12, 11] {
            let packet = tds_packet_slice(&bytes, offset).unwrap();
            assert_eq!(expected_len, packet.len());
            offset += packet.len();
        }
    }

    #[test]
    fn rejects_truncated_outbound_packet() {
        let bytes = crate::protocol::packet::encode_message(PacketType::RPC, &[0; 11], 12).unwrap();
        let truncated = &bytes[..bytes.len() - 1];

        let message = tds_packet_slice(truncated, 24).unwrap_err().to_string();
        assert!(message.contains("ended inside a packet"));
    }
}
797}