1use futures_core::future::BoxFuture;
2use futures_core::stream::BoxStream;
3use futures_util::{future, stream, StreamExt};
4use native_tls::Certificate;
5use sqlx_core::connection::Connection;
6use sqlx_core::decode::Decode;
7use sqlx_core::error::Error;
8use sqlx_core::executor::{Execute, Executor};
9use sqlx_core::transaction::Transaction;
10use sqlx_core::value::Value;
11use sqlx_core::Either;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::TcpStream;
14use tokio_native_tls::TlsConnector;
15
16use crate::protocol::login::{build_login7_packet, Login7Error};
17use crate::protocol::packet::{PacketHeader, PacketStatus, PacketType, PACKET_HEADER_LEN};
18use crate::protocol::pre_login::{build_pre_login_packet, parse_server_encrypt, PreLoginError};
19use crate::protocol::query::{build_sql_batch_packet, parse_query_response, QueryOutput};
20use crate::protocol::rpc::{
21 build_execute_sql_packet, build_prepare_packet, build_unprepare_packet,
22};
23use crate::protocol::token::{
24 parse_login_response, EnvChange, LoginResponse, ServerError, TokenParseError,
25};
26use crate::tls::TlsPreloginStream;
27use crate::{
28 ssrp, Encrypt, Mssql, MssqlArguments, MssqlConnectOptions, MssqlQueryResult, MssqlRow,
29 MssqlStatement, MssqlTypeInfo,
30};
31
/// A client connection to SQL Server over the TDS protocol (PRELOGIN/LOGIN7 handshake,
/// SQL batches and RPC requests).
#[derive(Debug)]
pub struct MssqlConnection {
    /// Transport to the server; `None` once the connection has been closed.
    stream: Option<MssqlWireStream>,
    /// Transaction nesting depth maintained by the increment/decrement/clear helpers.
    transaction_depth: usize,
    /// Descriptor from the server's most recent begin-transaction environment change
    /// (reset to 0 on commit/rollback); passed to every request packet builder.
    transaction_descriptor: u64,
}
39
40impl MssqlConnection {
41 pub async fn establish(options: &MssqlConnectOptions) -> Result<Self, Error> {
43 let mut stream = MssqlWireStream::connect(options).await?;
44
45 let pre_login = build_pre_login_packet(options).map_err(pre_login_error)?;
46 stream.write_all(&pre_login).await?;
47
48 let pre_login_response = stream.read_message().await?;
49 if pre_login_response.packet_type != PacketType::TABULAR_RESULT {
50 return Err(Error::Protocol(format!(
51 "expected SQL Server PRELOGIN response as tabular result, got packet type 0x{:02x}",
52 pre_login_response.packet_type.code()
53 )));
54 }
55
56 let server_encrypt =
57 parse_server_encrypt(&pre_login_response.payload).map_err(pre_login_error)?;
58 let encrypted = negotiate_encryption(options.encrypt(), server_encrypt)?;
59
60 if encrypted {
61 stream.enable_tls(options).await?;
62 }
63
64 let login = build_login7_packet(options).map_err(login_error)?;
65 stream.write_all(&login).await?;
66
67 let login_response = stream.read_message().await?;
68 if login_response.packet_type != PacketType::TABULAR_RESULT {
69 return Err(Error::Protocol(format!(
70 "expected SQL Server LOGIN7 response as tabular result, got packet type 0x{:02x}",
71 login_response.packet_type.code()
72 )));
73 }
74
75 match parse_login_response(&login_response.payload).map_err(token_error)? {
76 LoginResponse::Success { env_changes, .. } => {
77 let mut conn = Self {
78 stream: Some(stream),
79 transaction_depth: 0,
80 transaction_descriptor: 0,
81 };
82 conn.apply_env_changes(&env_changes);
83 Ok(conn)
84 }
85 LoginResponse::ServerError(error) => Err(server_error(error)),
86 }
87 }
88
89 fn apply_env_changes(&mut self, env_changes: &[EnvChange]) {
90 for change in env_changes {
91 match change {
92 EnvChange::PacketSize(size) => {
93 if let Some(stream) = self.stream.as_mut() {
94 stream.packet_size = (*size).clamp(512, 32767) as usize;
95 }
96 }
97 EnvChange::BeginTransaction(descriptor) => {
98 self.transaction_descriptor = *descriptor;
99 }
100 EnvChange::CommitTransaction(_) | EnvChange::RollbackTransaction(_) => {
101 self.transaction_descriptor = 0;
102 }
103 _ => {}
104 }
105 }
106 }
107
108 pub const fn transaction_depth(&self) -> usize {
110 self.transaction_depth
111 }
112
113 pub(crate) fn increment_transaction_depth(&mut self) {
114 self.transaction_depth += 1;
115 }
116
117 pub(crate) fn decrement_transaction_depth(&mut self) {
118 self.transaction_depth = self.transaction_depth.saturating_sub(1);
119 }
120
121 pub(crate) fn clear_transaction_depth(&mut self) {
122 self.transaction_depth = 0;
123 }
124
125 pub(crate) async fn run_sql_batch(&mut self, sql: &str) -> Result<QueryOutput, Error> {
126 let transaction_descriptor = self.transaction_descriptor;
127 let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
128 let packet = build_sql_batch_packet(sql, stream.packet_size, transaction_descriptor)
129 .map_err(frame_error)?;
130 stream.write_all(&packet).await?;
131
132 self.read_query_response().await
133 }
134
135 pub(crate) async fn run_execute_sql(
136 &mut self,
137 sql: &str,
138 arguments: Option<&MssqlArguments>,
139 ) -> Result<QueryOutput, Error> {
140 match arguments {
141 Some(arguments) if !arguments.is_empty() => {
142 let transaction_descriptor = self.transaction_descriptor;
143 let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
144 let packet = build_execute_sql_packet(
145 sql,
146 arguments,
147 stream.packet_size,
148 transaction_descriptor,
149 )
150 .map_err(|error| {
151 Error::Protocol(format!("failed to encode SQL Server RPC: {error}"))
152 })?;
153 stream.write_all(&packet).await?;
154 self.read_query_response().await
155 }
156 _ => self.run_sql_batch(sql).await,
157 }
158 }
159
160 pub(crate) async fn run_prepare(
161 &mut self,
162 sql: &str,
163 parameters: &[MssqlTypeInfo],
164 ) -> Result<QueryOutput, Error> {
165 let transaction_descriptor = self.transaction_descriptor;
166 let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
167 let packet =
168 build_prepare_packet(sql, parameters, stream.packet_size, transaction_descriptor)
169 .map_err(|error| {
170 Error::Protocol(format!("failed to encode SQL Server prepare RPC: {error}"))
171 })?;
172 stream.write_all(&packet).await?;
173
174 let output = self.read_query_response().await?;
175
176 if let Some(statement_id) = first_i32_return_value(&output)? {
177 let transaction_descriptor = self.transaction_descriptor;
178 let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
179 let packet =
180 build_unprepare_packet(statement_id, stream.packet_size, transaction_descriptor)
181 .map_err(|error| {
182 Error::Protocol(format!(
183 "failed to encode SQL Server unprepare RPC: {error}"
184 ))
185 })?;
186 stream.write_all(&packet).await?;
187 let _ = self.read_query_response().await?;
188 }
189
190 Ok(output)
191 }
192
193 async fn read_query_response(&mut self) -> Result<QueryOutput, Error> {
194 let stream = self.stream.as_mut().ok_or_else(wire_not_implemented)?;
195 let response = stream.read_message().await?;
196 if response.packet_type != PacketType::TABULAR_RESULT {
197 return Err(Error::Protocol(format!(
198 "expected SQL Server query response as tabular result, got packet type 0x{:02x}",
199 response.packet_type.code()
200 )));
201 }
202
203 let output = parse_query_response(&response.payload)?;
204 self.apply_env_changes(&output.env_changes);
205 Ok(output)
206 }
207}
208
209impl Connection for MssqlConnection {
210 type Database = Mssql;
211 type Options = MssqlConnectOptions;
212
213 async fn close(mut self) -> Result<(), Error> {
214 if let Some(mut stream) = self.stream.take() {
215 stream.shutdown().await?;
216 }
217
218 Ok(())
219 }
220
221 async fn close_hard(mut self) -> Result<(), Error> {
222 if let Some(mut stream) = self.stream.take() {
223 stream.shutdown().await?;
224 }
225
226 Ok(())
227 }
228
229 async fn ping(&mut self) -> Result<(), Error> {
230 if self.stream.is_some() {
231 Ok(())
232 } else {
233 Err(wire_not_implemented())
234 }
235 }
236
237 fn begin(
238 &mut self,
239 ) -> impl std::future::Future<Output = Result<Transaction<'_, Self::Database>, Error>> + Send + '_
240 {
241 Transaction::begin(self, None)
242 }
243
244 fn shrink_buffers(&mut self) {}
245
246 async fn flush(&mut self) -> Result<(), Error> {
247 Ok(())
248 }
249
250 fn should_flush(&self) -> bool {
251 false
252 }
253}
254
255impl<'c> Executor<'c> for &'c mut MssqlConnection {
256 type Database = Mssql;
257
258 fn fetch_many<'e, 'q, E>(
259 self,
260 mut query: E,
261 ) -> BoxStream<'e, Result<Either<MssqlQueryResult, MssqlRow>, Error>>
262 where
263 'c: 'e,
264 E: Execute<'q, Self::Database>,
265 'q: 'e,
266 E: 'q,
267 {
268 let arguments = query.take_arguments().map_err(Error::Encode);
269 let sql = query.sql();
270
271 stream::once(async move {
272 let arguments = arguments?;
273 self.run_execute_sql(sql.as_str(), arguments.as_ref()).await
274 })
275 .map(|result| match result {
276 Ok(output) => stream_query_output(output),
277 Err(error) => stream::once(future::ready(Err(error))).boxed(),
278 })
279 .flatten()
280 .boxed()
281 }
282
283 fn fetch_optional<'e, 'q, E>(
284 self,
285 mut query: E,
286 ) -> BoxFuture<'e, Result<Option<MssqlRow>, Error>>
287 where
288 'c: 'e,
289 E: Execute<'q, Self::Database>,
290 'q: 'e,
291 E: 'q,
292 {
293 let arguments = query.take_arguments().map_err(Error::Encode);
294 let sql = query.sql();
295
296 Box::pin(async move {
297 let arguments = arguments?;
298 Ok(self
299 .run_execute_sql(sql.as_str(), arguments.as_ref())
300 .await?
301 .rows
302 .into_iter()
303 .next())
304 })
305 }
306
307 fn prepare_with<'e>(
308 self,
309 sql: sqlx_core::sql_str::SqlStr,
310 parameters: &'e [crate::MssqlTypeInfo],
311 ) -> BoxFuture<'e, Result<MssqlStatement, Error>>
312 where
313 'c: 'e,
314 {
315 Box::pin(async move {
316 let output = self.run_prepare(sql.as_str(), parameters).await?;
317 let parameters = if parameters.is_empty() {
318 None
319 } else {
320 Some(Either::Left(parameters.to_vec()))
321 };
322
323 Ok(MssqlStatement::with_parameters(
324 sql,
325 output.columns,
326 parameters,
327 ))
328 })
329 }
330}
331
332fn first_i32_return_value(output: &QueryOutput) -> Result<Option<i32>, Error> {
333 output
334 .return_values
335 .first()
336 .map(|value| {
337 <i32 as Decode<Mssql>>::decode(value.as_ref()).map_err(|error| Error::ColumnDecode {
338 index: "return value".to_owned(),
339 source: error,
340 })
341 })
342 .transpose()
343}
344
345pub(crate) fn wire_not_implemented() -> Error {
346 Error::Protocol("SQL Server connection stream is not available".to_owned())
347}
348
/// The TDS transport: a plain or TLS socket plus the current packet size.
struct MssqlWireStream {
    /// Underlying socket, possibly upgraded to TLS.
    stream: MssqlStream,
    /// Largest packet length `read_message` accepts; also passed to the request packet
    /// builders. It starts at the requested size and follows packet-size environment changes.
    packet_size: usize,
}
353
354impl std::fmt::Debug for MssqlWireStream {
355 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356 f.debug_struct("MssqlWireStream")
357 .field("encrypted", &matches!(self.stream, MssqlStream::Tls(_)))
358 .field("packet_size", &self.packet_size)
359 .finish()
360 }
361}
362
/// Socket state of a `MssqlWireStream`.
enum MssqlStream {
    /// Plain TCP, either before the TLS upgrade or when no encryption was negotiated.
    Raw(TcpStream),
    /// TLS session whose handshake was carried through `TlsPreloginStream`.
    Tls(tokio_native_tls::TlsStream<TlsPreloginStream<TcpStream>>),
    /// Placeholder while `enable_tls` has moved the raw socket out. It is also left in place
    /// if the TLS handshake fails. Any I/O in this state fails with `taken_stream_error`.
    Taken,
}
368
369impl MssqlWireStream {
370 async fn connect(options: &MssqlConnectOptions) -> Result<Self, Error> {
371 let port = match (options.port(), options.instance()) {
372 (Some(port), _) => port,
373 (None, Some(instance)) => ssrp::resolve_instance_port(options.host(), instance).await?,
374 (None, None) => 1433,
375 };
376
377 let stream = TcpStream::connect((options.host(), port)).await?;
378 let packet_size = usize::try_from(options.requested_packet_size()).map_err(|_| {
379 Error::Protocol(format!(
380 "SQL Server packet size {} does not fit usize",
381 options.requested_packet_size()
382 ))
383 })?;
384
385 Ok(Self {
386 stream: MssqlStream::Raw(stream),
387 packet_size,
388 })
389 }
390
391 async fn write_all(&mut self, bytes: &[u8]) -> Result<(), Error> {
392 match &mut self.stream {
393 MssqlStream::Raw(stream) => {
394 stream.write_all(bytes).await?;
395 stream.flush().await?;
396 }
397 MssqlStream::Tls(stream) => {
398 stream.write_all(bytes).await?;
399 stream.flush().await?;
400 }
401 MssqlStream::Taken => return Err(taken_stream_error()),
402 }
403 Ok(())
404 }
405
406 async fn shutdown(&mut self) -> Result<(), Error> {
407 match &mut self.stream {
408 MssqlStream::Raw(stream) => stream.shutdown().await?,
409 MssqlStream::Tls(stream) => stream.shutdown().await?,
410 MssqlStream::Taken => return Err(taken_stream_error()),
411 }
412 Ok(())
413 }
414
415 async fn enable_tls(&mut self, options: &MssqlConnectOptions) -> Result<(), Error> {
416 let stream = match std::mem::replace(&mut self.stream, MssqlStream::Taken) {
417 MssqlStream::Raw(stream) => stream,
418 other => {
419 self.stream = other;
420 return Ok(());
421 }
422 };
423
424 let mut stream = TlsPreloginStream::new(stream);
425 stream.start_handshake();
426
427 let domain = options
428 .hostname_in_certificate()
429 .unwrap_or_else(|| options.host());
430 let connector = build_tls_connector(options)?;
431 let mut stream = connector
432 .connect(domain, stream)
433 .await
434 .map_err(|error| Error::Tls(error.into()))?;
435 stream.get_mut().get_mut().get_mut().finish_handshake();
436
437 self.stream = MssqlStream::Tls(stream);
438 Ok(())
439 }
440
441 async fn read_message(&mut self) -> Result<WireMessage, Error> {
442 let mut packet_type = None;
443 let mut expected_packet_id = None;
444 let mut payload = Vec::new();
445
446 loop {
447 let mut header_bytes = [0u8; PACKET_HEADER_LEN];
448 self.read_exact(&mut header_bytes).await?;
449 let header = PacketHeader::decode(&header_bytes).map_err(packet_error)?;
450
451 if let Some(packet_type) = packet_type {
452 if header.packet_type != packet_type {
453 return Err(Error::Protocol(format!(
454 "mismatched SQL Server packet type: expected 0x{:02x}, got 0x{:02x}",
455 packet_type.code(),
456 header.packet_type.code()
457 )));
458 }
459 } else {
460 packet_type = Some(header.packet_type);
461 }
462
463 if let Some(packet_id) = expected_packet_id {
464 if header.packet_id != packet_id {
465 return Err(Error::Protocol(format!(
466 "non-contiguous SQL Server packet id: expected {packet_id}, got {}",
467 header.packet_id
468 )));
469 }
470 }
471
472 let packet_len = usize::from(header.length);
473 if packet_len > self.packet_size {
474 return Err(Error::Protocol(format!(
475 "SQL Server packet length {packet_len} exceeds negotiated packet size {}",
476 self.packet_size
477 )));
478 }
479
480 let payload_len = packet_len.checked_sub(PACKET_HEADER_LEN).ok_or_else(|| {
481 Error::Protocol("SQL Server packet header length underflow".to_owned())
482 })?;
483 let old_len = payload.len();
484 payload.resize(old_len + payload_len, 0);
485 self.read_exact(&mut payload[old_len..]).await?;
486
487 expected_packet_id = Some(header.packet_id.wrapping_add(1));
488
489 if header.status == PacketStatus::END_OF_MESSAGE {
490 return Ok(WireMessage {
491 packet_type: packet_type.expect("packet_type is set after first header"),
492 payload,
493 });
494 }
495 }
496 }
497
498 async fn read_exact(&mut self, bytes: &mut [u8]) -> Result<(), Error> {
499 match &mut self.stream {
500 MssqlStream::Raw(stream) => {
501 stream.read_exact(bytes).await?;
502 }
503 MssqlStream::Tls(stream) => {
504 stream.read_exact(bytes).await?;
505 }
506 MssqlStream::Taken => return Err(taken_stream_error()),
507 }
508
509 Ok(())
510 }
511}
512
/// A complete TDS message reassembled by `MssqlWireStream::read_message`.
#[derive(Debug)]
struct WireMessage {
    /// Packet type shared by every packet of the message.
    packet_type: PacketType,
    /// Concatenated packet payloads, with headers stripped.
    payload: Vec<u8>,
}
518
519fn negotiate_encryption(requested: Encrypt, server: Encrypt) -> std::result::Result<bool, Error> {
520 match (requested, server) {
521 (Encrypt::NotSupported, Encrypt::NotSupported | Encrypt::Off) => Ok(false),
522 (Encrypt::NotSupported, Encrypt::On | Encrypt::Required) => Err(Error::Protocol(
523 "SQL Server requires encryption, but the client URL requested encrypt=not_supported"
524 .to_owned(),
525 )),
526 (Encrypt::Required, Encrypt::Off | Encrypt::NotSupported) => Err(Error::Tls(
527 "SQL Server TLS encryption is required but not supported by the server".into(),
528 )),
529 (Encrypt::On | Encrypt::Required, Encrypt::On | Encrypt::Required) => Ok(true),
530 (Encrypt::Off, _) | (_, Encrypt::Off) => Err(Error::Protocol(
531 "SQL Server login-only TLS fallback is not implemented yet; use encrypt=mandatory or encrypt=strict for encrypted connections, or encrypt=not_supported for plaintext development servers"
532 .to_owned(),
533 )),
534 (Encrypt::On, Encrypt::NotSupported) => Ok(false),
535 }
536}
537
538fn build_tls_connector(options: &MssqlConnectOptions) -> Result<TlsConnector, Error> {
539 let mut builder = native_tls::TlsConnector::builder();
540 builder.danger_accept_invalid_certs(options.trust_server_certificate());
541 builder.danger_accept_invalid_hostnames(options.hostname_in_certificate().is_none());
542
543 if let Some(path) = options.ssl_root_cert() {
544 let cert = std::fs::read(path).map_err(Error::Io)?;
545 let cert = Certificate::from_pem(&cert)
546 .or_else(|_| Certificate::from_der(&cert))
547 .map_err(|error| Error::Tls(error.into()))?;
548 builder.add_root_certificate(cert);
549 }
550
551 builder
552 .build()
553 .map(TlsConnector::from)
554 .map_err(|error| Error::Tls(error.into()))
555}
556
557fn taken_stream_error() -> Error {
558 Error::Protocol("SQL Server stream was used while TLS upgrade was in progress".to_owned())
559}
560
561fn server_error(error: ServerError) -> Error {
562 Error::Protocol(format!(
563 "SQL Server error {} (state {}, class {}): {}",
564 error.number, error.state, error.class, error.message
565 ))
566}
567
568fn packet_error(error: crate::protocol::packet::PacketHeaderError) -> Error {
569 Error::Protocol(error.to_string())
570}
571
572fn pre_login_error(error: PreLoginError) -> Error {
573 Error::Protocol(error.to_string())
574}
575
576fn login_error(error: Login7Error) -> Error {
577 Error::Protocol(error.to_string())
578}
579
580fn token_error(error: TokenParseError) -> Error {
581 Error::Protocol(error.to_string())
582}
583
584fn frame_error(error: crate::protocol::packet::PacketFrameError) -> Error {
585 Error::Protocol(error.to_string())
586}
587
588fn stream_query_output(
589 output: QueryOutput,
590) -> BoxStream<'static, Result<Either<MssqlQueryResult, MssqlRow>, Error>> {
591 stream::iter(
592 output
593 .rows
594 .into_iter()
595 .map(|row| Ok(Either::Right(row)))
596 .chain(std::iter::once(Ok(Either::Left(output.result)))),
597 )
598 .boxed()
599}
600
#[cfg(test)]
mod tests {
    use super::*;

    /// Both sides wanting encryption (mandatory or strict) must upgrade to TLS.
    #[test]
    fn negotiates_full_tls_for_required_or_mandatory_encryption() {
        let cases = [
            (Encrypt::On, Encrypt::On),
            (Encrypt::Required, Encrypt::Required),
        ];
        for (client, server) in cases {
            let upgrade = negotiate_encryption(client, server).unwrap();
            assert!(upgrade);
        }
    }

    /// Plaintext is only chosen when the client opts out and the server does not insist.
    #[test]
    fn allows_plaintext_only_when_explicitly_requested_and_supported() {
        let plaintext = negotiate_encryption(Encrypt::NotSupported, Encrypt::Off).unwrap();
        assert!(!plaintext);

        let refused = negotiate_encryption(Encrypt::NotSupported, Encrypt::Required);
        assert!(refused.is_err());
    }

    /// Login-only encryption is not implemented, so every combination needing it fails.
    #[test]
    fn rejects_login_only_tls_fallback_until_downgrade_is_available() {
        let cases = [(Encrypt::Off, Encrypt::On), (Encrypt::On, Encrypt::Off)];
        for (client, server) in cases {
            assert!(negotiate_encryption(client, server).is_err());
        }
    }
}