1use crate::{
2 MssqlArguments, MssqlBufferSettings, MssqlColumn, MssqlConnectOptions,
3 MssqlQueryResult, MssqlRow, MssqlStatement, MssqlTypeInfo, MssqlValue, MssqlValueKind, Result,
4};
5use futures_core::future::BoxFuture;
6use futures_core::stream::BoxStream;
7use futures_util::{future, stream, StreamExt};
8use odbc_api::buffers::{AnyColumnBufferSlice, BufferDesc, ColumnarDynBuffer, NullableSlice};
9use odbc_api::{ConnectionTransitions, Cursor, DataType, Nullable, ResultSetMetadata};
10use sqlx_core::column::Column;
11use sqlx_core::common::StatementCache;
12use sqlx_core::executor::{Execute, Executor};
13use sqlx_core::sql_str::SqlStr;
14use sqlx_core::transaction::Transaction;
15use sqlx_core::Either;
16use std::future::Future;
17use std::sync::atomic::{AtomicUsize, Ordering};
18use std::sync::Arc;
19
20type PreparedStatement =
21 odbc_api::Prepared<odbc_api::handles::StatementConnection<odbc_api::SharedConnection<'static>>>;
22type ExecuteResult = std::result::Result<Either<MssqlQueryResult, MssqlRow>, sqlx_core::Error>;
23type ExecuteSender = flume::Sender<ExecuteResult>;
24
25enum Command {
30 Execute {
31 sql: SqlStr,
32 args: Option<MssqlArguments>,
33 persistent: bool,
34 response: ExecuteSender,
35 },
36 Prepare {
37 sql: SqlStr,
38 response: flume::Sender<
39 std::result::Result<MssqlStatement, sqlx_core::Error>,
40 >,
41 },
42 Ping {
43 response: flume::Sender<std::result::Result<(), sqlx_core::Error>>,
44 },
45 Begin {
46 response: flume::Sender<std::result::Result<(), sqlx_core::Error>>,
47 },
48 Commit {
49 response: flume::Sender<std::result::Result<(), sqlx_core::Error>>,
50 },
51 Rollback {
52 response: flume::Sender<std::result::Result<(), sqlx_core::Error>>,
53 },
54 StartRollback,
55 ExecSql {
56 sql: String,
57 response: flume::Sender<std::result::Result<(), sqlx_core::Error>>,
58 },
59 ScalarI64 {
60 sql: String,
61 response:
62 flume::Sender<std::result::Result<Option<i64>, sqlx_core::Error>>,
63 },
64 Shutdown {
65 signal: flume::Sender<()>,
66 },
67 ListMigrations {
69 sql: String,
70 response:
71 flume::Sender<std::result::Result<Vec<(i64, Vec<u8>)>, sqlx_core::Error>>,
72 },
73 #[cfg(feature = "migrate")]
76 ApplyMigration {
77 sql: String,
78 insert_sql: String,
79 version: i64,
80 no_tx: bool,
81 response: flume::Sender<std::result::Result<std::time::Duration, sqlx_core::Error>>,
82 },
83 #[cfg(feature = "migrate")]
86 RevertMigration {
87 sql: String,
88 delete_sql: String,
89 version: i64,
90 no_tx: bool,
91 response: flume::Sender<std::result::Result<std::time::Duration, sqlx_core::Error>>,
92 },
93}
94
95struct ConnectionActor {
100 conn: odbc_api::SharedConnection<'static>,
101 stmt_cache: StatementCache<PreparedStatement>,
102 transaction_depth: usize,
103 buffer_settings: MssqlBufferSettings,
104}
105
106impl ConnectionActor {
107 fn run(mut self, rx: flume::Receiver<Command>) {
108 for cmd in rx {
111 match cmd {
114 Command::Execute {
115 sql,
116 args,
117 persistent,
118 response,
119 } => {
120 let _ = self.handle_execute(sql, args, persistent, &response);
121 }
122 Command::Prepare { sql, response } => {
123 let _ = response.send(self.handle_prepare(sql));
124 }
125 Command::Ping { response } => {
126 let _ = response.send(self.handle_ping());
127 }
128 Command::Begin { response } => {
129 let _ = response.send(self.handle_begin());
130 }
131 Command::Commit { response } => {
132 let _ = response.send(self.handle_commit());
133 }
134 Command::Rollback { response } => {
135 let _ = response.send(self.handle_rollback());
136 }
137 Command::StartRollback => {
138 self.handle_start_rollback();
139 }
140 Command::ExecSql { sql, response } => {
141 let _ = response.send(self.handle_exec_sql(&sql));
142 }
143 Command::ScalarI64 { sql, response } => {
144 let _ = response.send(self.handle_scalar_i64(&sql));
145 }
146 Command::Shutdown { signal } => {
147 let _ = signal.send(());
148 return;
149 }
150 Command::ListMigrations { sql, response } => {
151 let _ = response.send(self.handle_list_migrations(&sql));
152 }
153 #[cfg(feature = "migrate")]
154 Command::ApplyMigration {
155 sql,
156 insert_sql,
157 version,
158 no_tx,
159 response,
160 } => {
161 let _ = response.send(self.handle_apply_migration(&sql, &insert_sql, version, no_tx));
162 }
163 #[cfg(feature = "migrate")]
164 Command::RevertMigration {
165 sql,
166 delete_sql,
167 version,
168 no_tx,
169 response,
170 } => {
171 let _ = response.send(self.handle_revert_migration(&sql, &delete_sql, version, no_tx));
172 }
173 }
174 }
175 }
177
178 fn handle_execute(
183 &mut self,
184 sql: SqlStr,
185 arguments: Option<MssqlArguments>,
186 persistent: bool,
187 tx: &ExecuteSender,
188 ) -> std::result::Result<(), sqlx_core::Error> {
189 let has_arguments = arguments.as_ref().is_some_and(|a| !a.is_empty());
190 let parameters = arguments
191 .as_ref()
192 .map(MssqlArguments::to_odbc_parameter_collection)
193 .unwrap_or_default();
194
195 if persistent && has_arguments {
196 if let Some(prepared) = self.stmt_cache.get_mut(sql.as_str()) {
197 let mut conn_guard = self.conn.lock().map_err(|_| {
199 sqlx_core::Error::Protocol(
200 "ODBC execute: failed to lock connection".to_owned(),
201 )
202 })?;
203 let has_cursor = prepared
204 .execute(parameters.as_slice())
205 .map_err(|error| {
206 crate::error::database_error_with_context_lazy(error, || {
207 format!(
208 "failed to execute cached ODBC statement: `{}`",
209 sql_preview(sql.as_str())
210 )
211 })
212 })?
213 .is_some();
214 drop(conn_guard);
215
216 if has_cursor {
217 let mut conn_guard = self.conn.lock().map_err(|_| {
219 sqlx_core::Error::Protocol(
220 "ODBC execute: failed to lock connection".to_owned(),
221 )
222 })?;
223 let cursor = prepared
224 .execute(parameters.as_slice())
225 .map_err(|error| {
226 crate::error::database_error_with_context_lazy(error, || {
227 format!(
228 "failed to execute cached ODBC statement: `{}`",
229 sql_preview(sql.as_str())
230 )
231 })
232 })?
233 .expect("has_cursor was true");
234 drop(conn_guard);
235 return stream_result_sets(cursor, self.buffer_settings, tx);
236 }
237
238 let ra = prepared.row_count().map_err(|error| {
239 crate::error::database_error_with_context_lazy(error, || {
240 format!(
241 "failed to read ODBC row count for cached statement: `{}`",
242 sql_preview(sql.as_str())
243 )
244 })
245 })?;
246 return send_rows_affected(ra, tx);
247 } else {
248 let mut prepared =
250 self.conn.clone().into_prepared(sql.as_str()).map_err(|error| {
251 crate::error::database_error_with_context_lazy(error, || {
252 format!(
253 "failed to prepare cached ODBC statement: `{}`",
254 sql_preview(sql.as_str())
255 )
256 })
257 })?;
258
259 let mut conn_guard = self.conn.lock().map_err(|_| {
260 sqlx_core::Error::Protocol(
261 "ODBC execute: failed to lock connection".to_owned(),
262 )
263 })?;
264 let has_cursor = prepared
265 .execute(parameters.as_slice())
266 .map_err(|error| {
267 crate::error::database_error_with_context_lazy(error, || {
268 format!(
269 "failed to execute cached ODBC statement: `{}`",
270 sql_preview(sql.as_str())
271 )
272 })
273 })?
274 .is_some();
275 drop(conn_guard);
276
277 if has_cursor {
278 let mut conn_guard = self.conn.lock().map_err(|_| {
280 sqlx_core::Error::Protocol(
281 "ODBC execute: failed to lock connection".to_owned(),
282 )
283 })?;
284 let cursor = prepared
285 .execute(parameters.as_slice())
286 .map_err(|error| {
287 crate::error::database_error_with_context_lazy(error, || {
288 format!(
289 "failed to execute cached ODBC statement: `{}`",
290 sql_preview(sql.as_str())
291 )
292 })
293 })?
294 .expect("has_cursor was true");
295 drop(conn_guard);
296 return stream_result_sets(cursor, self.buffer_settings, tx);
297 }
298
299 let ra = prepared.row_count().map_err(|error| {
300 crate::error::database_error_with_context_lazy(error, || {
301 format!(
302 "failed to read ODBC row count for cached statement: `{}`",
303 sql_preview(sql.as_str())
304 )
305 })
306 })?;
307 self.stmt_cache.insert(sql.as_str(), prepared);
308 return send_rows_affected(ra, tx);
309 }
310 } else {
311 let mut statement = self.conn.clone().into_preallocated().map_err(|error| {
313 crate::error::database_error_with_context_lazy(error, || {
314 format!(
315 "failed to allocate an ODBC statement for query: `{}`",
316 sql_preview(sql.as_str())
317 )
318 })
319 })?;
320 if let Some(cursor) = statement
321 .execute(sql.as_str(), parameters.as_slice())
322 .map_err(|error| {
323 crate::error::database_error_with_context_lazy(error, || {
324 format!(
325 "failed to execute ODBC query: `{}`",
326 sql_preview(sql.as_str())
327 )
328 })
329 })? {
330 return stream_result_sets(cursor, self.buffer_settings, tx);
331 }
332 let rows_affected = statement.row_count().map_err(|error| {
333 crate::error::database_error_with_context_lazy(error, || {
334 format!(
335 "failed to read ODBC row count for query: `{}`",
336 sql_preview(sql.as_str())
337 )
338 })
339 })?;
340 send_rows_affected(rows_affected, tx)
341 }
342 }
343
344 fn handle_prepare(
345 &mut self,
346 sql: SqlStr,
347 ) -> std::result::Result<MssqlStatement, sqlx_core::Error> {
348 if let Some(prepared) = self.stmt_cache.get_mut(sql.as_str()) {
349 let parameters = prepared.num_params().map_err(|error| {
350 sqlx_core::Error::from(crate::error::database_error_with_context(
351 error,
352 format!(
353 "failed to read ODBC parameter metadata for cached statement: `{}`",
354 sql_preview(sql.as_str())
355 ),
356 ))
357 })?;
358 let columns = collect_prepared_columns(prepared)?;
359 return Ok(MssqlStatement::new(sql, columns, usize::from(parameters)));
360 }
361
362 let mut prepared = self.conn.clone().into_prepared(sql.as_str()).map_err(|error| {
363 sqlx_core::Error::from(crate::error::database_error_with_context(
364 error,
365 format!(
366 "failed to prepare MSSQL ODBC statement: `{}`",
367 sql_preview(sql.as_str())
368 ),
369 ))
370 })?;
371 let parameters = prepared.num_params().map_err(|error| {
372 sqlx_core::Error::from(crate::error::database_error_with_context(
373 error,
374 format!(
375 "failed to read ODBC parameter metadata for prepared statement: `{}`",
376 sql_preview(sql.as_str())
377 ),
378 ))
379 })?;
380 let columns = collect_prepared_columns(&mut prepared)?;
381 if self.stmt_cache.is_enabled() {
382 self.stmt_cache.insert(sql.as_str(), prepared);
383 }
384
385 Ok(MssqlStatement::new(sql, columns, usize::from(parameters)))
386 }
387
388 fn handle_ping(&mut self) -> std::result::Result<(), sqlx_core::Error> {
389 let mut conn_guard = self.conn.lock().map_err(|_| {
390 sqlx_core::Error::Protocol("failed to lock connection for ping".into())
391 })?;
392 conn_guard.execute("SELECT 1", (), None).map_err(|error| {
393 sqlx_core::Error::from(crate::error::database_error_with_context(
394 error,
395 "MSSQL ping query failed: `SELECT 1`",
396 ))
397 })?;
398 Ok(())
399 }
400
401 fn handle_begin(&mut self) -> std::result::Result<(), sqlx_core::Error> {
402 if self.transaction_depth == 0 {
403 let mut conn_guard = self.conn.lock().map_err(|_| {
404 sqlx_core::Error::Protocol(
405 "MSSQL ODBC begin: failed to lock connection".to_owned(),
406 )
407 })?;
408 conn_guard.set_autocommit(false).map_err(|error| {
409 sqlx_core::Error::from(crate::error::database_error_with_context(
410 error,
411 "failed to disable ODBC autocommit while beginning a transaction",
412 ))
413 })?;
414 } else {
415 let savepoint = format!("sqlx_sp_{}", self.transaction_depth);
416 let mut conn_guard = self.conn.lock().map_err(|_| {
417 sqlx_core::Error::Protocol(
418 "MSSQL ODBC begin (savepoint): failed to lock connection".to_owned(),
419 )
420 })?;
421 conn_guard
422 .execute(&format!("SAVE TRANSACTION {savepoint}"), (), None)
423 .map_err(|error| {
424 sqlx_core::Error::from(crate::error::database_error_with_context(
425 error,
426 format!(
427 "failed to create save point `{savepoint}` for nested transaction"
428 ),
429 ))
430 })?;
431 }
432 self.transaction_depth += 1;
433 Ok(())
434 }
435
436 fn handle_commit(&mut self) -> std::result::Result<(), sqlx_core::Error> {
437 if self.transaction_depth == 0 {
438 return Ok(());
439 }
440
441 if self.transaction_depth == 1 {
442 let mut conn_guard = self.conn.lock().map_err(|_| {
443 sqlx_core::Error::Protocol(
444 "MSSQL ODBC commit: failed to lock connection".to_owned(),
445 )
446 })?;
447 conn_guard.commit().map_err(|error| {
448 sqlx_core::Error::from(crate::error::database_error_with_context(
449 error,
450 "failed to commit the active MSSQL ODBC transaction",
451 ))
452 })?;
453 conn_guard.set_autocommit(true).map_err(|error| {
454 sqlx_core::Error::from(crate::error::database_error_with_context(
455 error,
456 "failed to restore ODBC autocommit after commit",
457 ))
458 })?;
459 self.transaction_depth = 0;
460 } else {
461 self.transaction_depth -= 1;
462 }
463 Ok(())
464 }
465
466 fn handle_rollback(&mut self) -> std::result::Result<(), sqlx_core::Error> {
467 if self.transaction_depth == 0 {
468 return Ok(());
469 }
470
471 if self.transaction_depth == 1 {
472 let mut conn_guard = self.conn.lock().map_err(|_| {
473 sqlx_core::Error::Protocol(
474 "MSSQL ODBC rollback: failed to lock connection".to_owned(),
475 )
476 })?;
477 conn_guard.rollback().map_err(|error| {
478 sqlx_core::Error::from(crate::error::database_error_with_context(
479 error,
480 "failed to roll back the active ODBC transaction",
481 ))
482 })?;
483 conn_guard.set_autocommit(true).map_err(|error| {
484 sqlx_core::Error::from(crate::error::database_error_with_context(
485 error,
486 "failed to restore ODBC autocommit after rollback",
487 ))
488 })?;
489 self.transaction_depth = 0;
490 } else {
491 let savepoint = format!("sqlx_sp_{}", self.transaction_depth - 1);
492 let mut conn_guard = self.conn.lock().map_err(|_| {
493 sqlx_core::Error::Protocol(
494 "MSSQL ODBC rollback (savepoint): failed to lock connection".to_owned(),
495 )
496 })?;
497 conn_guard
498 .execute(&format!("ROLLBACK TRANSACTION {savepoint}"), (), None)
499 .map_err(|error| {
500 sqlx_core::Error::from(crate::error::database_error_with_context(
501 error,
502 format!("failed to roll back to save point `{savepoint}`"),
503 ))
504 })?;
505 self.transaction_depth -= 1;
506 }
507 Ok(())
508 }
509
510 fn handle_start_rollback(&mut self) {
511 if self.transaction_depth == 0 {
512 return;
513 }
514
515 if self.transaction_depth == 1 {
516 if let Ok(mut conn_guard) = self.conn.lock() {
517 let _ = conn_guard.rollback();
518 let _ = conn_guard.set_autocommit(true);
519 }
520 self.transaction_depth = 0;
521 } else {
522 let savepoint = format!("sqlx_sp_{}", self.transaction_depth - 1);
523 if let Ok(mut conn_guard) = self.conn.lock() {
524 let _ = conn_guard.execute(
525 &format!("ROLLBACK TRANSACTION {savepoint}"),
526 (),
527 None,
528 );
529 }
530 self.transaction_depth -= 1;
531 }
532 }
533
534 fn handle_exec_sql(&self, sql: &str) -> std::result::Result<(), sqlx_core::Error> {
535 let mut conn_guard = self.conn.lock().map_err(|_| {
536 sqlx_core::Error::Protocol("failed to lock the shared ODBC connection".into())
537 })?;
538 conn_guard.execute(sql, (), None).map_err(|error| {
539 sqlx_core::Error::from(crate::error::database_error_with_context(
540 error,
541 format!("failed to execute SQL: `{}`", sql_preview(sql)),
542 ))
543 })?;
544 Ok(())
545 }
546
547 fn handle_scalar_i64(&self, sql: &str) -> std::result::Result<Option<i64>, sqlx_core::Error> {
548 let mut conn_guard = self.conn.lock().map_err(|_| {
549 sqlx_core::Error::Protocol("failed to lock the shared ODBC connection".into())
550 })?;
551 let mut cursor = conn_guard
552 .execute(sql, (), None)
553 .map_err(|error| {
554 sqlx_core::Error::from(crate::error::database_error_with_context(
555 error,
556 format!("scalar query failed: `{}`", sql_preview(sql)),
557 ))
558 })?
559 .ok_or_else(|| {
560 sqlx_core::Error::Protocol(format!(
561 "scalar query returned no result set: `{}`",
562 sql_preview(sql),
563 ))
564 })?;
565
566 if let Some(mut row) = cursor.next_row().map_err(|error| {
567 sqlx_core::Error::from(crate::error::database_error_with_context(
568 error,
569 "scalar query next row",
570 ))
571 })? {
572 let mut value: Nullable<i64> = Nullable::null();
573 row.get_data(1, &mut value).map_err(|error| {
574 sqlx_core::Error::from(crate::error::database_error_with_context(
575 error,
576 "scalar query column 1",
577 ))
578 })?;
579 Ok(value.into_opt())
580 } else {
581 Ok(None)
582 }
583 }
584
585 fn handle_list_migrations(
586 &self,
587 sql: &str,
588 ) -> std::result::Result<Vec<(i64, Vec<u8>)>, sqlx_core::Error> {
589 let mut conn_guard = self.conn.lock().map_err(|_| {
590 sqlx_core::Error::Protocol("failed to lock the shared ODBC connection".into())
591 })?;
592 let mut cursor = conn_guard
593 .execute(sql, (), None)
594 .map_err(|error| {
595 sqlx_core::Error::from(crate::error::database_error_with_context(
596 error,
597 "failed to query applied migrations",
598 ))
599 })?
600 .ok_or_else(|| {
601 sqlx_core::Error::Protocol(
602 "list_applied_migrations returned no result set".into(),
603 )
604 })?;
605
606 let mut migrations = Vec::new();
607 while let Some(mut row) = cursor.next_row().map_err(|error| {
608 sqlx_core::Error::from(crate::error::database_error_with_context(
609 error,
610 "failed to read applied migration row",
611 ))
612 })? {
613 let mut version: Nullable<i64> = Nullable::null();
614 row.get_data(1, &mut version).map_err(|error| {
615 sqlx_core::Error::from(crate::error::database_error_with_context(
616 error,
617 "failed to read migration version",
618 ))
619 })?;
620
621 let mut checksum_bytes = Vec::new();
622 let has_value = row.get_binary(2, &mut checksum_bytes).map_err(|error| {
623 sqlx_core::Error::from(crate::error::database_error_with_context(
624 error,
625 "failed to read migration checksum",
626 ))
627 })?;
628
629 if let Some(version) = version.into_opt() {
630 migrations.push((version, if has_value { checksum_bytes } else { vec![] }));
631 }
632 }
633
634 Ok(migrations)
635 }
636
637 #[cfg(feature = "migrate")]
638 fn handle_apply_migration(
639 &mut self,
640 sql: &str,
641 insert_sql: &str,
642 version: i64,
643 no_tx: bool,
644 ) -> std::result::Result<std::time::Duration, sqlx_core::Error> {
645 let start = std::time::Instant::now();
646 let mut conn_guard = self.conn.lock().map_err(|_| {
647 sqlx_core::Error::Protocol(
648 "failed to lock the shared ODBC connection for migration".into(),
649 )
650 })?;
651
652 if !no_tx {
653 conn_guard.set_autocommit(false).map_err(|error| {
654 sqlx_core::Error::from(crate::error::database_error_with_context(
655 error,
656 "failed to start transaction for migration apply",
657 ))
658 })?;
659 }
660
661 conn_guard.execute(sql, (), None).map_err(|error| {
662 sqlx_core::Error::from(crate::error::database_error_with_context(
663 error,
664 format!("migration {version} failed"),
665 ))
666 })?;
667
668 conn_guard.execute(insert_sql, (), None).map_err(|error| {
669 sqlx_core::Error::from(crate::error::database_error_with_context(
670 error,
671 format!("failed to insert tracking record for migration {version}"),
672 ))
673 })?;
674
675 if !no_tx {
676 conn_guard.commit().map_err(|error| {
677 sqlx_core::Error::from(crate::error::database_error_with_context(
678 error,
679 format!("failed to commit migration {version}"),
680 ))
681 })?;
682 conn_guard.set_autocommit(true).map_err(|error| {
683 sqlx_core::Error::from(crate::error::database_error_with_context(
684 error,
685 "failed to restore autocommit after migration apply",
686 ))
687 })?;
688 }
689
690 Ok(start.elapsed())
691 }
692
693 #[cfg(feature = "migrate")]
694 fn handle_revert_migration(
695 &mut self,
696 sql: &str,
697 delete_sql: &str,
698 version: i64,
699 no_tx: bool,
700 ) -> std::result::Result<std::time::Duration, sqlx_core::Error> {
701 let start = std::time::Instant::now();
702 let mut conn_guard = self.conn.lock().map_err(|_| {
703 sqlx_core::Error::Protocol(
704 "failed to lock the shared ODBC connection for migration".into(),
705 )
706 })?;
707
708 if !no_tx {
709 conn_guard.set_autocommit(false).map_err(|error| {
710 sqlx_core::Error::from(crate::error::database_error_with_context(
711 error,
712 "failed to start transaction for migration revert",
713 ))
714 })?;
715 }
716
717 conn_guard.execute(sql, (), None).map_err(|error| {
718 sqlx_core::Error::from(crate::error::database_error_with_context(
719 error,
720 format!("revert migration {version} failed"),
721 ))
722 })?;
723
724 conn_guard.execute(delete_sql, (), None).map_err(|error| {
725 sqlx_core::Error::from(crate::error::database_error_with_context(
726 error,
727 format!("failed to delete tracking record for migration {version}"),
728 ))
729 })?;
730
731 if !no_tx {
732 conn_guard.commit().map_err(|error| {
733 sqlx_core::Error::from(crate::error::database_error_with_context(
734 error,
735 format!("failed to commit migration revert {version}"),
736 ))
737 })?;
738 conn_guard.set_autocommit(true).map_err(|error| {
739 sqlx_core::Error::from(crate::error::database_error_with_context(
740 error,
741 "failed to restore autocommit after migration revert",
742 ))
743 })?;
744 }
745
746 Ok(start.elapsed())
747 }
748}
749
750pub struct MssqlConnection {
752 cmd_tx: flume::Sender<Command>,
753 buffer_settings: MssqlBufferSettings,
754 transaction_depth: AtomicUsize,
755}
756
757impl std::fmt::Debug for MssqlConnection {
758 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
759 f.debug_struct("MssqlConnection").finish_non_exhaustive()
760 }
761}
762
763impl MssqlConnection {
764 pub fn connect_blocking(options: &MssqlConnectOptions) -> Result<Self> {
767 let env = odbc_api::environment().map_err(|error| {
768 crate::MssqlError::Configuration(format!(
769 "failed to initialize the process-wide ODBC environment: {error}"
770 ))
771 })?;
772
773 let raw_conn = env
774 .connect_with_connection_string(options.connection_string(), Default::default())
775 .map_err(|error| {
776 crate::error::database_error_with_context(
777 error,
778 "failed to open MSSQL ODBC connection using the supplied connection string",
779 )
780 })?;
781
782 let conn: odbc_api::SharedConnection<'static> =
784 std::sync::Arc::new(std::sync::Mutex::new(raw_conn));
785
786 let (cmd_tx, cmd_rx) = flume::unbounded();
787
788 let actor = ConnectionActor {
789 conn,
790 stmt_cache: StatementCache::new(options.statement_cache_capacity),
791 transaction_depth: 0,
792 buffer_settings: options.buffer_settings,
793 };
794
795 std::thread::spawn(move || actor.run(cmd_rx));
799
800 Ok(Self {
801 cmd_tx,
802 buffer_settings: options.buffer_settings,
803 transaction_depth: AtomicUsize::new(0),
804 })
805 }
806
807 pub fn ping_blocking(&self) -> std::result::Result<(), sqlx_core::Error> {
809 send_command_blocking(&self.cmd_tx, |tx| Command::Ping { response: tx })?
810 }
811
812 pub fn dbms_name(&self) -> std::result::Result<String, sqlx_core::Error> {
814 send_command_blocking(&self.cmd_tx, |tx| {
815 Command::ExecSql {
816 sql: "SELECT 1 /* dbms_name */".into(),
817 response: tx,
818 }
819 })?;
820 Ok("MSSQL via ODBC".to_owned())
821 }
822
823 pub(crate) fn begin_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
825 let r = send_command_blocking(&self.cmd_tx, |tx| Command::Begin { response: tx })?;
826 if r.is_ok() {
827 self.transaction_depth.fetch_add(1, Ordering::SeqCst);
828 }
829 r
830 }
831
832 pub(crate) fn commit_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
834 let depth = self.transaction_depth.load(Ordering::SeqCst);
835 if depth == 0 {
836 return Ok(());
837 }
838 let r = send_command_blocking(&self.cmd_tx, |tx| Command::Commit { response: tx })?;
839 if r.is_ok() {
840 if depth == 1 {
841 self.transaction_depth.store(0, Ordering::SeqCst);
842 } else {
843 self.transaction_depth.fetch_sub(1, Ordering::SeqCst);
844 }
845 }
846 r
847 }
848
849 pub(crate) fn rollback_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
851 let depth = self.transaction_depth.load(Ordering::SeqCst);
852 if depth == 0 {
853 return Ok(());
854 }
855 let r = send_command_blocking(&self.cmd_tx, |tx| Command::Rollback { response: tx })?;
856 if r.is_ok() {
857 if depth == 1 {
858 self.transaction_depth.store(0, Ordering::SeqCst);
859 } else {
860 self.transaction_depth.fetch_sub(1, Ordering::SeqCst);
861 }
862 }
863 r
864 }
865
866 pub(crate) fn start_rollback(&mut self) {
868 let _ = self.cmd_tx.try_send(Command::StartRollback);
869 self.transaction_depth.store(0, Ordering::SeqCst);
870 }
871
872 pub(crate) fn transaction_depth(&self) -> usize {
874 self.transaction_depth.load(Ordering::SeqCst)
875 }
876
877 pub(crate) fn set_transaction_depth(&mut self, depth: usize) {
879 self.transaction_depth.store(depth, Ordering::SeqCst);
880 }
881
882 pub fn prepare_blocking(
884 &self,
885 sql: sqlx_core::sql_str::SqlStr,
886 ) -> std::result::Result<MssqlStatement, sqlx_core::Error> {
887 send_command_blocking(&self.cmd_tx, |tx| Command::Prepare { sql, response: tx })?
888 }
889
890 #[cfg(feature = "migrate")]
892 pub(crate) fn exec_sql_blocking(&self, sql: &str) -> std::result::Result<(), sqlx_core::Error> {
893 send_command_blocking(&self.cmd_tx, |tx| {
894 Command::ExecSql {
895 sql: sql.to_owned(),
896 response: tx,
897 }
898 })?
899 }
900
901 #[cfg(feature = "migrate")]
903 pub(crate) fn scalar_i64_blocking(
904 &self,
905 sql: &str,
906 ) -> std::result::Result<Option<i64>, sqlx_core::Error> {
907 send_command_blocking(&self.cmd_tx, |tx| {
908 Command::ScalarI64 {
909 sql: sql.to_owned(),
910 response: tx,
911 }
912 })?
913 }
914
915 #[cfg(feature = "migrate")]
917 pub(crate) fn list_migrations_blocking(
918 &self,
919 sql: &str,
920 ) -> std::result::Result<Vec<(i64, Vec<u8>)>, sqlx_core::Error> {
921 send_command_blocking(&self.cmd_tx, |tx| {
922 Command::ListMigrations {
923 sql: sql.to_owned(),
924 response: tx,
925 }
926 })?
927 }
928
929 #[cfg(feature = "migrate")]
931 pub(crate) fn apply_migration_blocking(
932 &self,
933 sql: &str,
934 insert_sql: &str,
935 version: i64,
936 no_tx: bool,
937 ) -> std::result::Result<std::time::Duration, sqlx_core::Error> {
938 send_command_blocking(&self.cmd_tx, |tx| {
939 Command::ApplyMigration {
940 sql: sql.to_owned(),
941 insert_sql: insert_sql.to_owned(),
942 version,
943 no_tx,
944 response: tx,
945 }
946 })?
947 }
948
949 #[cfg(feature = "migrate")]
951 pub(crate) fn revert_migration_blocking(
952 &self,
953 sql: &str,
954 delete_sql: &str,
955 version: i64,
956 no_tx: bool,
957 ) -> std::result::Result<std::time::Duration, sqlx_core::Error> {
958 send_command_blocking(&self.cmd_tx, |tx| {
959 Command::RevertMigration {
960 sql: sql.to_owned(),
961 delete_sql: delete_sql.to_owned(),
962 version,
963 no_tx,
964 response: tx,
965 }
966 })?
967 }
968
969 pub(crate) fn execute_receiver(
971 &self,
972 sql: sqlx_core::sql_str::SqlStr,
973 persistent: bool,
974 arguments: Option<MssqlArguments>,
975 ) -> flume::Receiver<ExecuteResult> {
976 let (tx, rx) = flume::bounded(64);
977 if self
978 .cmd_tx
979 .send(Command::Execute {
980 sql,
981 args: arguments,
982 persistent,
983 response: tx,
984 })
985 .is_err()
986 {
987 let _ = rx.drain();
989 }
990 rx
991 }
992}
993
994impl Drop for MssqlConnection {
996 fn drop(&mut self) {}
997}
998
999impl sqlx_core::connection::Connection for MssqlConnection {
1004 type Database = crate::Mssql;
1005 type Options = MssqlConnectOptions;
1006
1007 async fn close(self) -> std::result::Result<(), sqlx_core::Error> {
1008 drop(self);
1009 Ok(())
1010 }
1011
1012 async fn close_hard(self) -> std::result::Result<(), sqlx_core::Error> {
1013 drop(self);
1014 Ok(())
1015 }
1016
1017 async fn ping(&mut self) -> std::result::Result<(), sqlx_core::Error> {
1018 send_command_async(&self.cmd_tx, |tx| Command::Ping { response: tx }).await?
1019 }
1020
1021 fn begin(
1022 &mut self,
1023 ) -> impl Future<Output = std::result::Result<Transaction<'_, Self::Database>, sqlx_core::Error>>
1024 + Send
1025 + '_ {
1026 Transaction::begin(self, None)
1027 }
1028
1029 fn shrink_buffers(&mut self) {}
1030
1031 async fn flush(&mut self) -> std::result::Result<(), sqlx_core::Error> {
1032 Ok(())
1033 }
1034
1035 fn should_flush(&self) -> bool {
1036 false
1037 }
1038
1039 fn cached_statements_size(&self) -> usize
1040 where
1041 Self::Database: sqlx_core::database::HasStatementCache,
1042 {
1043 0
1046 }
1047
1048 async fn clear_cached_statements(&mut self) -> std::result::Result<(), sqlx_core::Error>
1049 where
1050 Self::Database: sqlx_core::database::HasStatementCache,
1051 {
1052 Ok(())
1056 }
1057}
1058
1059impl<'c> Executor<'c> for &'c mut MssqlConnection {
1064 type Database = crate::Mssql;
1065
1066 fn fetch_many<'e, 'q, E>(
1067 self,
1068 mut query: E,
1069 ) -> BoxStream<'e, std::result::Result<Either<MssqlQueryResult, MssqlRow>, sqlx_core::Error>>
1070 where
1071 'c: 'e,
1072 E: Execute<'q, Self::Database>,
1073 'q: 'e,
1074 E: 'q,
1075 {
1076 let arguments = query.take_arguments().map_err(sqlx_core::Error::Encode);
1077 let persistent = query.persistent();
1078 let sql = query.sql();
1079
1080 match arguments {
1081 Ok(arguments) => {
1082 receiver_to_stream(self.execute_receiver(sql, persistent, arguments))
1083 }
1084 Err(error) => stream::once(future::ready(Err(error))).boxed(),
1085 }
1086 }
1087
1088 fn fetch_optional<'e, 'q, E>(
1089 self,
1090 mut query: E,
1091 ) -> BoxFuture<'e, std::result::Result<Option<MssqlRow>, sqlx_core::Error>>
1092 where
1093 'c: 'e,
1094 E: Execute<'q, Self::Database>,
1095 'q: 'e,
1096 E: 'q,
1097 {
1098 let arguments = query.take_arguments().map_err(sqlx_core::Error::Encode);
1099 let persistent = query.persistent();
1100 let sql = query.sql();
1101
1102 Box::pin(async move {
1103 let rx = self.execute_receiver(sql, persistent, arguments?);
1104 while let Ok(item) = rx.recv_async().await {
1105 match item? {
1106 Either::Right(row) => return Ok(Some(row)),
1107 Either::Left(_) => {}
1108 }
1109 }
1110 Ok(None)
1111 })
1112 }
1113
1114 fn prepare_with<'e>(
1115 self,
1116 sql: sqlx_core::sql_str::SqlStr,
1117 _parameters: &[crate::MssqlTypeInfo],
1118 ) -> BoxFuture<'e, std::result::Result<MssqlStatement, sqlx_core::Error>>
1119 where
1120 'c: 'e,
1121 {
1122 let cmd_tx = self.cmd_tx.clone();
1123 Box::pin(async move {
1124 send_command_async(&cmd_tx, |tx| Command::Prepare { sql, response: tx }).await?
1125 })
1126 }
1127
1128 #[cfg(feature = "offline")]
1129 fn describe<'e>(
1130 self,
1131 sql: sqlx_core::sql_str::SqlStr,
1132 ) -> BoxFuture<'e, std::result::Result<sqlx_core::describe::Describe<Self::Database>, sqlx_core::Error>>
1133 where
1134 'c: 'e,
1135 {
1136 use sqlx_core::statement::Statement;
1137 let cmd_tx = self.cmd_tx.clone();
1138 Box::pin(async move {
1139 let statement =
1140 send_command_async(&cmd_tx, |tx| Command::Prepare { sql, response: tx }).await??;
1141 let columns = statement.columns().to_vec();
1142 let column_count = columns.len();
1143 let parameter_count = statement
1144 .parameters()
1145 .map(|p| match p {
1146 Either::Left(types) => types.len(),
1147 Either::Right(count) => count,
1148 })
1149 .unwrap_or(0);
1150
1151 Ok(sqlx_core::describe::Describe {
1152 columns,
1153 parameters: Some(Either::Right(parameter_count)),
1154 nullable: vec![None; column_count],
1155 })
1156 })
1157 }
1158}
1159
1160async fn send_command_async<T: Send>(
1165 cmd_tx: &flume::Sender<Command>,
1166 make_cmd: impl FnOnce(flume::Sender<T>) -> Command,
1167) -> std::result::Result<T, sqlx_core::Error> {
1168 let (resp_tx, resp_rx) = flume::bounded(1);
1169 let cmd = make_cmd(resp_tx);
1170 cmd_tx.send(cmd).map_err(|_| {
1171 sqlx_core::Error::Protocol(
1172 "MSSQL ODBC connection actor has shut down".to_owned(),
1173 )
1174 })?;
1175 resp_rx.recv_async().await.map_err(|_| {
1176 sqlx_core::Error::Protocol(
1177 "MSSQL ODBC connection actor response channel closed".to_owned(),
1178 )
1179 })
1180}
1181
1182fn send_command_blocking<T: Send>(
1187 cmd_tx: &flume::Sender<Command>,
1188 make_cmd: impl FnOnce(flume::Sender<T>) -> Command,
1189) -> std::result::Result<T, sqlx_core::Error> {
1190 let (resp_tx, resp_rx) = flume::bounded(1);
1191 let cmd = make_cmd(resp_tx);
1192 cmd_tx.send(cmd).map_err(|_| {
1193 sqlx_core::Error::Protocol(
1194 "MSSQL ODBC connection actor has shut down".to_owned(),
1195 )
1196 })?;
1197 resp_rx.recv().map_err(|_| {
1198 sqlx_core::Error::Protocol(
1199 "MSSQL ODBC connection actor response channel closed".to_owned(),
1200 )
1201 })
1202}
1203
1204fn receiver_to_stream<'e>(
1209 rx: flume::Receiver<ExecuteResult>,
1210) -> BoxStream<'e, ExecuteResult> {
1211 stream::unfold(rx, |rx| async move {
1212 rx.recv_async().await.ok().map(|item| (item, rx))
1213 })
1214 .boxed()
1215}
1216
1217fn send_rows_affected(
1222 rows_affected: Option<usize>,
1223 tx: &ExecuteSender,
1224) -> std::result::Result<(), sqlx_core::Error> {
1225 let rows_affected = rows_affected
1226 .unwrap_or(0)
1227 .try_into()
1228 .map_err(|_| sqlx_core::Error::Protocol("ODBC row count does not fit in u64".to_owned()))?;
1229 send_done(tx, rows_affected);
1230 Ok(())
1231}
1232
1233fn send_done(tx: &ExecuteSender, rows_affected: u64) -> bool {
1234 tx.send(Ok(Either::Left(MssqlQueryResult::new(rows_affected))))
1235 .is_ok()
1236}
1237
1238fn send_row(tx: &ExecuteSender, row: MssqlRow) -> bool {
1239 tx.send(Ok(Either::Right(row))).is_ok()
1240}
1241
1242pub(crate) fn collect_columns(
1243 cursor: &mut impl ResultSetMetadata,
1244) -> std::result::Result<Vec<MssqlColumn>, sqlx_core::Error> {
1245 let count = cursor.num_result_cols().map_err(|error| {
1246 crate::error::database_error_with_context(error, "failed to read ODBC result-column count")
1247 })?;
1248 let count = usize::try_from(count).map_err(|_| {
1249 sqlx_core::Error::Protocol(format!("ODBC returned a negative column count: {count}"))
1250 })?;
1251
1252 let mut columns = Vec::with_capacity(count);
1253 for ordinal in 0..count {
1254 let column_number = u16::try_from(ordinal + 1).map_err(|_| {
1255 sqlx_core::Error::Protocol(format!("ODBC column index exceeds u16: {}", ordinal + 1))
1256 })?;
1257
1258 let mut description = odbc_api::ColumnDescription::default();
1259 cursor
1260 .describe_col(column_number, &mut description)
1261 .map_err(|error| {
1262 crate::error::database_error_with_context(
1263 error,
1264 format!("failed to describe ODBC result column {column_number}"),
1265 )
1266 })?;
1267 let name = description
1268 .name_to_string()
1269 .unwrap_or_else(|_| format!("col{ordinal}"));
1270
1271 let nullable = match description.nullability {
1272 odbc_api::Nullability::NoNulls => Some(false),
1273 odbc_api::Nullability::Nullable => Some(true),
1274 odbc_api::Nullability::Unknown => None,
1275 };
1276
1277 columns.push(MssqlColumn::new(
1278 ordinal,
1279 name,
1280 MssqlTypeInfo::new(description.data_type),
1281 nullable,
1282 ));
1283 }
1284
1285 Ok(columns)
1286}
1287
1288fn collect_prepared_columns(
1289 prepared: &mut impl ResultSetMetadata,
1290) -> std::result::Result<Vec<MssqlColumn>, sqlx_core::Error> {
1291 match collect_columns(prepared) {
1292 Ok(columns) => Ok(columns),
1293 Err(error) => Err(error),
1294 }
1295}
1296
1297fn stream_result_sets<C>(
1298 mut cursor: C,
1299 settings: MssqlBufferSettings,
1300 tx: &ExecuteSender,
1301) -> std::result::Result<(), sqlx_core::Error>
1302where
1303 C: Cursor + ResultSetMetadata,
1304{
1305 loop {
1306 if cursor.num_result_cols().map_err(|error| {
1307 crate::error::database_error_with_context(
1308 error,
1309 "failed to read ODBC result-column count",
1310 )
1311 })? == 0
1312 {
1313 send_done(tx, 0);
1314 } else if let Some(max_column_size) = settings.max_column_size {
1315 let (receiver_open, finished_cursor) =
1316 stream_rows_buffered(cursor, settings.batch_size, max_column_size, tx)?;
1317 if !receiver_open {
1318 return Ok(());
1319 }
1320 cursor = finished_cursor;
1321 } else if !stream_rows_unbuffered(&mut cursor, tx)? {
1322 return Ok(());
1323 }
1324
1325 match cursor.more_results().map_err(|error| {
1326 crate::error::database_error_with_context(error, "failed to advance ODBC result set")
1327 })? {
1328 Some(next_cursor) => cursor = next_cursor,
1329 None => return Ok(()),
1330 }
1331 }
1332}
1333
/// Pairs a result column's metadata with the buffer layout it is bound to for buffered fetching.
#[derive(Debug)]
struct ColumnBinding {
    /// Column metadata as produced by `collect_columns`.
    column: MssqlColumn,
    /// Buffer layout chosen for this column by `map_buffer_desc`.
    buffer_desc: BufferDesc,
}
1339
/// Streams the current result set through a column-wise row buffer holding `batch_size` rows.
///
/// Column buffers are chosen by `build_buffer_bindings`. Variable-length columns (text and
/// binary) are bound with `max_column_size` as their per-value size. Each fetched row is sent
/// on `tx` as an [`MssqlRow`]; once the result set is exhausted, a zero-row summary follows.
///
/// NOTE(review): `fetch()` is called without any visible truncation check, so values longer
/// than `max_column_size` may be cut short by the driver. Confirm against the odbc-api version
/// in use.
///
/// Returns `(receiver_open, cursor)`. The row buffer is unbound before returning, so the
/// caller gets the cursor back (e.g. to advance to the next result set). `receiver_open` is
/// `false` when a row could not be sent because the receiver was dropped; no summary is sent
/// in that case.
///
/// # Errors
/// Fails if any of the following fails:
/// - describing the columns;
/// - binding or unbinding the buffer;
/// - fetching a batch;
/// - converting a fetched column into values.
fn stream_rows_buffered<C>(
    cursor: C,
    batch_size: usize,
    max_column_size: usize,
    tx: &ExecuteSender,
) -> std::result::Result<(bool, C), sqlx_core::Error>
where
    C: Cursor + ResultSetMetadata,
{
    let mut cursor = cursor;
    let bindings = build_buffer_bindings(&mut cursor, max_column_size)?;
    let buffer_descriptions = bindings
        .iter()
        .map(|binding| binding.buffer_desc)
        .collect::<Vec<_>>();
    // Binding consumes the cursor; both non-error exits below recover it via `unbind`.
    let mut row_set_cursor = cursor
        .bind_buffer(ColumnarDynBuffer::from_descs(
            batch_size,
            buffer_descriptions,
        ))
        .map_err(|error| {
            crate::error::database_error_with_context(
                error,
                format!(
                    "ODBC buffered fetching could not be enabled with batch_size={batch_size}; \
                     this driver may reject the row-array or row-binding statement attributes \
                     used for column-wise buffered fetching, so use \
                     MssqlConnectOptions::max_column_size(None) to fetch rows unbuffered"
                ),
            )
        })?;
    // One column list shared by every row of this result set; each row only clones the `Arc`.
    let columns: Arc<[MssqlColumn]> = bindings
        .iter()
        .map(|binding| binding.column.clone())
        .collect::<Vec<_>>()
        .into();

    while let Some(batch) = row_set_cursor.fetch().map_err(|error| {
        crate::error::database_error_with_context(error, "ODBC buffered fetch failed")
    })? {
        // Convert the batch column by column into owned values...
        let column_values = bindings
            .iter()
            .enumerate()
            .map(|(index, binding)| {
                buffered_column_values(batch.column(index), binding).map_err(|error| {
                    sqlx_core::Error::Protocol(format!(
                        "ODBC buffered fetch could not convert column {} (`{}`) using buffer {:?}: {error}",
                        binding.column.ordinal() + 1,
                        binding.column.name(),
                        binding.buffer_desc
                    ))
                })
            })
            .collect::<std::result::Result<Vec<_>, _>>()?;

        // ...then transpose: each row takes the next value from every column's iterator.
        let mut column_iters = column_values
            .into_iter()
            .map(Vec::into_iter)
            .collect::<Vec<_>>();

        for row_index in 0..batch.num_rows() {
            let values = column_iters
                .iter_mut()
                .map(|values| {
                    values.next().map(MssqlValue::new).ok_or_else(|| {
                        sqlx_core::Error::Protocol(format!(
                            "ODBC buffered fetch produced too few values for row {}",
                            row_index + 1
                        ))
                    })
                })
                .collect::<std::result::Result<Vec<_>, _>>()?;
            if !send_row(tx, MssqlRow::new_shared(Arc::clone(&columns), values)) {
                // Receiver dropped: stop fetching, but still hand the cursor back.
                let (cursor, _) = row_set_cursor.unbind().map_err(|error| {
                    crate::error::database_error_with_context(
                        error,
                        "ODBC buffered fetch could not unbind row buffer after receiver closed",
                    )
                })?;
                return Ok((false, cursor));
            }
        }
    }

    // The send result is not checked here; the caller advances to the next result set anyway.
    send_done(tx, 0);
    let (cursor, _) = row_set_cursor.unbind().map_err(|error| {
        crate::error::database_error_with_context(
            error,
            "ODBC buffered fetch could not unbind row buffer",
        )
    })?;
    Ok((true, cursor))
}
1433
1434fn build_buffer_bindings(
1435 cursor: &mut impl ResultSetMetadata,
1436 max_column_size: usize,
1437) -> std::result::Result<Vec<ColumnBinding>, sqlx_core::Error> {
1438 collect_columns(cursor).map(|columns| {
1439 columns
1440 .into_iter()
1441 .map(|column| {
1442 let nullable = column.nullable().unwrap_or(true);
1443 ColumnBinding {
1444 buffer_desc: map_buffer_desc(column.type_info().data_type(), max_column_size, nullable),
1445 column,
1446 }
1447 })
1448 .collect()
1449 })
1450}
1451
1452fn map_buffer_desc(data_type: DataType, max_column_size: usize, nullable: bool) -> BufferDesc {
1453 match data_type {
1454 DataType::TinyInt | DataType::SmallInt | DataType::Integer | DataType::BigInt => {
1455 BufferDesc::I64 { nullable }
1456 }
1457 DataType::Real => BufferDesc::F32 { nullable },
1458 DataType::Float { .. } | DataType::Double => BufferDesc::F64 { nullable },
1459 DataType::Bit => BufferDesc::Bit { nullable },
1460 DataType::Date => BufferDesc::Date { nullable },
1461 DataType::Time { .. } => BufferDesc::Time { nullable },
1462 DataType::Timestamp { .. } => BufferDesc::Timestamp { nullable },
1463 DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => {
1464 BufferDesc::Binary {
1465 max_bytes: max_column_size,
1466 }
1467 }
1468 DataType::WChar { .. } | DataType::WVarchar { .. } | DataType::WLongVarchar { .. } => {
1471 BufferDesc::WText {
1472 max_str_len: max_column_size,
1473 }
1474 }
1475 DataType::Char { .. }
1477 | DataType::Varchar { .. }
1478 | DataType::LongVarchar { .. }
1479 | DataType::Other { .. }
1480 | DataType::Unknown
1481 | DataType::Decimal { .. }
1482 | DataType::Numeric { .. } => BufferDesc::Text {
1483 max_str_len: max_column_size,
1484 },
1485 }
1486}
1487
1488fn buffered_column_values(
1489 slice: AnyColumnBufferSlice<'_>,
1490 binding: &ColumnBinding,
1491) -> std::result::Result<Vec<MssqlValueKind>, sqlx_core::Error> {
1492 let desc = binding.buffer_desc;
1493 Ok(match desc {
1494 BufferDesc::I8 { nullable } => buffered_numeric(&slice, desc, nullable, |value: i8| {
1495 MssqlValueKind::TinyInt(i16::from(value))
1496 })?,
1497 BufferDesc::I16 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
1498 MssqlValueKind::SmallInt(value)
1499 })?,
1500 BufferDesc::I32 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
1501 MssqlValueKind::Integer(value)
1502 })?,
1503 BufferDesc::I64 { nullable } => {
1504 buffered_numeric(&slice, desc, nullable, MssqlValueKind::BigInt)?
1505 }
1506 BufferDesc::U8 { nullable } => buffered_numeric(&slice, desc, nullable, |value: u8| {
1507 MssqlValueKind::BigInt(i64::from(value))
1508 })?,
1509 BufferDesc::F32 { nullable } => {
1510 buffered_numeric(&slice, desc, nullable, MssqlValueKind::Real)?
1511 }
1512 BufferDesc::F64 { nullable } => {
1513 buffered_numeric(&slice, desc, nullable, MssqlValueKind::Double)?
1514 }
1515 BufferDesc::Bit { nullable } => {
1516 buffered_numeric(&slice, desc, nullable, |value: odbc_api::Bit| {
1517 MssqlValueKind::Bit(value.as_bool())
1518 })?
1519 }
1520 BufferDesc::Date { nullable } => {
1521 buffered_numeric(&slice, desc, nullable, MssqlValueKind::Date)?
1522 }
1523 BufferDesc::Time { nullable } => {
1524 buffered_numeric(&slice, desc, nullable, MssqlValueKind::Time)?
1525 }
1526 BufferDesc::Timestamp { nullable } => {
1527 buffered_numeric(&slice, desc, nullable, MssqlValueKind::Timestamp)?
1528 }
1529 BufferDesc::Text { .. } => {
1530 let text = expect_buffer_slice(slice.as_text(), desc)?;
1531 text.iter()
1532 .map(|value| {
1533 value
1534 .map(|bytes| {
1535 MssqlValueKind::Text(String::from_utf8_lossy(bytes).into_owned())
1536 })
1537 .unwrap_or(MssqlValueKind::Null)
1538 })
1539 .collect()
1540 }
1541 BufferDesc::WText { .. } => {
1542 let text = expect_buffer_slice(slice.as_wide_text(), desc)?;
1543 text.iter()
1544 .map(|value| {
1545 value
1546 .map(|chars| MssqlValueKind::Text(String::from_utf16_lossy(chars.into())))
1547 .unwrap_or(MssqlValueKind::Null)
1548 })
1549 .collect()
1550 }
1551 BufferDesc::Binary { .. } => {
1552 let binary = expect_buffer_slice(slice.as_binary(), desc)?;
1553 binary
1554 .iter()
1555 .map(|value| {
1556 value
1557 .map(|bytes| MssqlValueKind::Binary(bytes.to_vec()))
1558 .unwrap_or(MssqlValueKind::Null)
1559 })
1560 .collect()
1561 }
1562 BufferDesc::Numeric => {
1563 return Err(sqlx_core::Error::Protocol(format!(
1564 "unsupported ODBC buffer descriptor: {desc:?}"
1565 )))
1566 }
1567 })
1568}
1569
1570fn buffered_numeric<T, F>(
1571 slice: &AnyColumnBufferSlice<'_>,
1572 desc: BufferDesc,
1573 nullable: bool,
1574 map: F,
1575) -> std::result::Result<Vec<MssqlValueKind>, sqlx_core::Error>
1576where
1577 T: Copy + odbc_api::Pod,
1578 F: FnMut(T) -> MssqlValueKind,
1579{
1580 if nullable {
1581 Ok(buffered_nullable_numeric(
1582 expect_buffer_slice(slice.as_nullable_slice::<T>(), desc)?,
1583 map,
1584 ))
1585 } else {
1586 Ok(expect_buffer_slice(slice.as_slice::<T>(), desc)?
1587 .iter()
1588 .copied()
1589 .map(map)
1590 .collect())
1591 }
1592}
1593
1594fn buffered_nullable_numeric<T, F>(slice: NullableSlice<'_, T>, mut map: F) -> Vec<MssqlValueKind>
1595where
1596 T: Copy,
1597 F: FnMut(T) -> MssqlValueKind,
1598{
1599 slice
1600 .map(|value| value.copied().map(&mut map).unwrap_or(MssqlValueKind::Null))
1601 .collect()
1602}
1603
1604fn expect_buffer_slice<T>(
1605 slice: Option<T>,
1606 desc: BufferDesc,
1607) -> std::result::Result<T, sqlx_core::Error> {
1608 slice.ok_or_else(|| {
1609 sqlx_core::Error::Protocol(format!(
1610 "ODBC column buffer {desc:?} did not match fetched slice"
1611 ))
1612 })
1613}
1614
1615fn stream_rows_unbuffered<C>(
1616 cursor: &mut C,
1617 tx: &ExecuteSender,
1618) -> std::result::Result<bool, sqlx_core::Error>
1619where
1620 C: Cursor + ResultSetMetadata,
1621{
1622 let columns: Arc<[MssqlColumn]> = collect_columns(cursor)?.into();
1623
1624 while let Some(mut cursor_row) = cursor.next_row().map_err(|error| {
1625 crate::error::database_error_with_context(
1626 error,
1627 "ODBC unbuffered fetch failed while reading the next row",
1628 )
1629 })? {
1630 let mut values = Vec::with_capacity(columns.len());
1631
1632 for column in columns.iter() {
1633 let column_number = u16::try_from(sqlx_core::column::Column::ordinal(column) + 1)
1634 .map_err(|_| {
1635 sqlx_core::Error::Protocol("ODBC column index exceeds u16".to_owned())
1636 })?;
1637 values.push(fetch_value(&mut cursor_row, column_number, column)?);
1638 }
1639
1640 if !send_row(tx, MssqlRow::new_shared(Arc::clone(&columns), values)) {
1641 return Ok(false);
1642 }
1643 }
1644
1645 send_done(tx, 0);
1646 Ok(true)
1647}
1648
1649fn fetch_value(
1650 row: &mut odbc_api::CursorRow<'_>,
1651 column_number: u16,
1652 column: &MssqlColumn,
1653) -> std::result::Result<MssqlValue, sqlx_core::Error> {
1654 let data_type = column.type_info().data_type();
1655
1656 let kind = match data_type {
1657 DataType::Bit => {
1658 let mut value = Nullable::<odbc_api::Bit>::null();
1659 row.get_data(column_number, &mut value).map_err(|error| {
1660 crate::error::database_error_with_context_lazy(error, || {
1661 fetch_context(column, data_type)
1662 })
1663 })?;
1664 value
1665 .into_opt()
1666 .map(|value| MssqlValueKind::Bit(value.as_bool()))
1667 .unwrap_or(MssqlValueKind::Null)
1668 }
1669 DataType::TinyInt => {
1670 let mut value = Nullable::<i16>::null();
1673 row.get_data(column_number, &mut value).map_err(|error| {
1674 crate::error::database_error_with_context_lazy(error, || {
1675 fetch_context(column, data_type)
1676 })
1677 })?;
1678 value
1679 .into_opt()
1680 .map(MssqlValueKind::TinyInt)
1681 .unwrap_or(MssqlValueKind::Null)
1682 }
1683 DataType::SmallInt => fetch_nullable(
1684 row,
1685 column_number,
1686 column,
1687 data_type,
1688 MssqlValueKind::SmallInt,
1689 )?,
1690 DataType::Integer => fetch_nullable(
1691 row,
1692 column_number,
1693 column,
1694 data_type,
1695 MssqlValueKind::Integer,
1696 )?,
1697 DataType::BigInt => {
1698 fetch_nullable(row, column_number, column, data_type, MssqlValueKind::BigInt)?
1699 }
1700 DataType::Real => {
1701 fetch_nullable(row, column_number, column, data_type, MssqlValueKind::Real)?
1702 }
1703 DataType::Float { .. } | DataType::Double => {
1704 fetch_nullable(row, column_number, column, data_type, MssqlValueKind::Double)?
1705 }
1706 DataType::Date => {
1707 fetch_nullable(row, column_number, column, data_type, MssqlValueKind::Date)?
1708 }
1709 DataType::Time { .. } => {
1710 fetch_nullable(row, column_number, column, data_type, MssqlValueKind::Time)?
1711 }
1712 DataType::Timestamp { .. } => fetch_nullable(
1713 row,
1714 column_number,
1715 column,
1716 data_type,
1717 MssqlValueKind::Timestamp,
1718 )?,
1719 DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => {
1720 let mut value = Vec::new();
1721 if row.get_binary(column_number, &mut value).map_err(|error| {
1722 crate::error::database_error_with_context_lazy(error, || {
1723 fetch_context(column, data_type)
1724 })
1725 })? {
1726 MssqlValueKind::Binary(value)
1727 } else {
1728 MssqlValueKind::Null
1729 }
1730 }
1731 DataType::Other {
1732 data_type: sql_type, ..
1733 } if sql_type.0 == -11 => {
1734 let mut value = Vec::new();
1736 if row.get_binary(column_number, &mut value).map_err(|error| {
1737 crate::error::database_error_with_context_lazy(error, || {
1738 fetch_context(column, data_type)
1739 })
1740 })? {
1741 if value.len() == 16 {
1742 let mut guid = [0u8; 16];
1743 guid.copy_from_slice(&value);
1744 MssqlValueKind::Guid(guid)
1745 } else {
1746 MssqlValueKind::Text(String::from_utf16_lossy(
1748 &value.iter().map(|&b| b as u16).collect::<Vec<_>>(),
1749 ))
1750 }
1751 } else {
1752 MssqlValueKind::Null
1753 }
1754 }
1755 _ => {
1756 let mut value = Vec::new();
1757 if row
1758 .get_wide_text(column_number, &mut value)
1759 .map_err(|error| {
1760 crate::error::database_error_with_context_lazy(error, || {
1761 fetch_context(column, data_type)
1762 })
1763 })?
1764 {
1765 MssqlValueKind::Text(String::from_utf16_lossy(&value))
1766 } else {
1767 MssqlValueKind::Null
1768 }
1769 }
1770 };
1771
1772 Ok(MssqlValue::new(kind))
1773}
1774
1775fn fetch_nullable<T, F>(
1776 row: &mut odbc_api::CursorRow<'_>,
1777 column_number: u16,
1778 column: &MssqlColumn,
1779 data_type: DataType,
1780 map: F,
1781) -> std::result::Result<MssqlValueKind, sqlx_core::Error>
1782where
1783 T: Default + Copy + odbc_api::parameter::CElement + odbc_api::handles::CDataMut,
1784 Nullable<T>: odbc_api::parameter::CElement + odbc_api::handles::CDataMut,
1785 F: FnOnce(T) -> MssqlValueKind,
1786{
1787 let mut value = Nullable::<T>::null();
1788 row.get_data(column_number, &mut value).map_err(|error| {
1789 crate::error::database_error_with_context_lazy(error, || fetch_context(column, data_type))
1790 })?;
1791 Ok(value.into_opt().map(map).unwrap_or(MssqlValueKind::Null))
1792}
1793
1794fn fetch_context(column: &MssqlColumn, data_type: DataType) -> String {
1795 format!(
1796 "failed to fetch ODBC column {} (`{}`) as {data_type:?}",
1797 column.ordinal() + 1,
1798 column.name()
1799 )
1800}
1801
/// Produces a short single-line preview of `sql` for diagnostics.
///
/// Whitespace runs are collapsed to single spaces. If the result is longer than 160
/// characters, it is cut to 157 characters followed by `...`. Length is measured in `char`s,
/// not bytes, so non-ASCII SQL that fits within the limit is returned unchanged.
fn sql_preview(sql: &str) -> String {
    const MAX_LEN: usize = 160;
    const ELLIPSIS: &str = "...";

    let compact = sql.split_whitespace().collect::<Vec<_>>().join(" ");

    // `str::len` counts UTF-8 bytes. Comparing it against a char budget would append "..." to
    // short non-ASCII statements even though nothing gets cut, so count chars instead.
    if compact.chars().count() <= MAX_LEN {
        return compact;
    }

    let mut preview: String = compact.chars().take(MAX_LEN - ELLIPSIS.len()).collect();
    preview.push_str(ELLIPSIS);
    preview
}
1814
/// Runs the blocking closure `f` on Tokio's blocking thread pool and returns its result.
///
/// # Errors
/// Returns whatever error `f` returns. If the task did not complete — it panicked, or it was
/// cancelled (e.g. the runtime shut down before it ran) — a [`sqlx_core::Error::Protocol`]
/// stating which of the two happened is returned instead.
#[cfg(feature = "runtime-tokio")]
pub(crate) async fn offload_blocking<F, T>(f: F) -> std::result::Result<T, sqlx_core::Error>
where
    F: FnOnce() -> std::result::Result<T, sqlx_core::Error> + Send + 'static,
    T: Send + 'static,
{
    match tokio::task::spawn_blocking(f).await {
        Ok(result) => result,
        Err(join_error) => {
            // A JoinError is not always a panic; don't report cancellation as one.
            let reason = if join_error.is_panic() {
                "panicked"
            } else {
                "was cancelled"
            };
            Err(sqlx_core::Error::Protocol(format!(
                "blocking task {reason}: {join_error}"
            )))
        }
    }
}
1829
#[cfg(test)]
mod tests {
    use super::*;

    /// All integer widths share a nullable 64-bit buffer.
    #[test]
    fn buffered_fetch_maps_numeric_types_to_nullable_64_bit_buffers() {
        for data_type in [DataType::TinyInt, DataType::Integer, DataType::BigInt] {
            assert_eq!(
                map_buffer_desc(data_type, 64, true),
                BufferDesc::I64 { nullable: true }
            );
        }
    }

    /// Variable-length buffers take their size from the configured `max_column_size`.
    #[test]
    fn buffered_fetch_uses_configured_limits_for_variable_sized_data() {
        assert_eq!(
            map_buffer_desc(DataType::Varchar { length: None }, 32, true),
            BufferDesc::Text { max_str_len: 32 }
        );
        assert_eq!(
            map_buffer_desc(DataType::Varbinary { length: None }, 16, true),
            BufferDesc::Binary { max_bytes: 16 }
        );
    }

    /// Wide character types are bound as UTF-16 text buffers.
    #[test]
    fn buffered_fetch_maps_wide_char_types_to_wtext() {
        let cases = [
            (DataType::WChar { length: None }, 64),
            (DataType::WVarchar { length: None }, 128),
            (DataType::WLongVarchar { length: None }, 256),
        ];
        for (data_type, limit) in cases {
            assert_eq!(
                map_buffer_desc(data_type, limit, true),
                BufferDesc::WText { max_str_len: limit }
            );
        }
    }

    /// Narrow character types are bound as narrow text buffers.
    #[test]
    fn buffered_fetch_maps_narrow_char_types_to_text() {
        let narrow_types = [
            DataType::Char { length: None },
            DataType::Varchar { length: None },
            DataType::LongVarchar { length: None },
        ];
        for data_type in narrow_types {
            assert_eq!(
                map_buffer_desc(data_type, 64, true),
                BufferDesc::Text { max_str_len: 64 }
            );
        }
    }
}