1use crate::{
2 MssqlArguments, MssqlBufferSettings, MssqlColumn, MssqlConnectOptions,
3 MssqlQueryResult, MssqlRow, MssqlStatement, MssqlTypeInfo, MssqlValue, MssqlValueKind, Result,
4};
5use futures_core::future::BoxFuture;
6use futures_core::stream::BoxStream;
7use futures_util::{future, stream, StreamExt};
8use odbc_api::buffers::{AnyColumnBufferSlice, BufferDesc, ColumnarDynBuffer, NullableSlice};
9use odbc_api::{ConnectionTransitions, Cursor, DataType, Nullable, ResultSetMetadata};
10use sqlx_core::column::Column;
11use sqlx_core::common::StatementCache;
12use sqlx_core::executor::{Execute, Executor};
13use sqlx_core::sql_str::SqlStr;
14use sqlx_core::transaction::Transaction;
15use sqlx_core::Either;
16use std::future::Future;
17use std::sync::atomic::{AtomicUsize, Ordering};
18use std::sync::Arc;
19
20type PreparedStatement =
21 odbc_api::Prepared<odbc_api::handles::StatementConnection<odbc_api::SharedConnection<'static>>>;
22type ExecuteResult = std::result::Result<Either<MssqlQueryResult, MssqlRow>, sqlx_core::Error>;
23type ExecuteSender = flume::Sender<ExecuteResult>;
24
25enum Command {
30 Execute {
31 sql: SqlStr,
32 args: Option<MssqlArguments>,
33 persistent: bool,
34 response: ExecuteSender,
35 },
36 Prepare {
37 sql: SqlStr,
38 response: tokio::sync::oneshot::Sender<
39 std::result::Result<MssqlStatement, sqlx_core::Error>,
40 >,
41 },
42 Ping {
43 response: tokio::sync::oneshot::Sender<std::result::Result<(), sqlx_core::Error>>,
44 },
45 Begin {
46 response: tokio::sync::oneshot::Sender<std::result::Result<(), sqlx_core::Error>>,
47 },
48 Commit {
49 response: tokio::sync::oneshot::Sender<std::result::Result<(), sqlx_core::Error>>,
50 },
51 Rollback {
52 response: tokio::sync::oneshot::Sender<std::result::Result<(), sqlx_core::Error>>,
53 },
54 StartRollback,
55 ExecSql {
56 sql: String,
57 response: tokio::sync::oneshot::Sender<std::result::Result<(), sqlx_core::Error>>,
58 },
59 ScalarI64 {
60 sql: String,
61 response:
62 tokio::sync::oneshot::Sender<std::result::Result<Option<i64>, sqlx_core::Error>>,
63 },
64 Shutdown {
65 signal: tokio::sync::oneshot::Sender<()>,
66 },
67 ListMigrations {
69 sql: String,
70 response:
71 tokio::sync::oneshot::Sender<std::result::Result<Vec<(i64, Vec<u8>)>, sqlx_core::Error>>,
72 },
73 #[cfg(feature = "migrate")]
76 ApplyMigration {
77 sql: String,
78 insert_sql: String,
79 version: i64,
80 no_tx: bool,
81 response: tokio::sync::oneshot::Sender<std::result::Result<std::time::Duration, sqlx_core::Error>>,
82 },
83 #[cfg(feature = "migrate")]
86 RevertMigration {
87 sql: String,
88 delete_sql: String,
89 version: i64,
90 no_tx: bool,
91 response: tokio::sync::oneshot::Sender<std::result::Result<std::time::Duration, sqlx_core::Error>>,
92 },
93}
94
95struct ConnectionActor {
100 conn: odbc_api::SharedConnection<'static>,
101 stmt_cache: StatementCache<PreparedStatement>,
102 transaction_depth: usize,
103 buffer_settings: MssqlBufferSettings,
104}
105
106impl ConnectionActor {
107 fn run(mut self, rx: flume::Receiver<Command>) {
108 for cmd in rx {
111 match cmd {
114 Command::Execute {
115 sql,
116 args,
117 persistent,
118 response,
119 } => {
120 let _ = self.handle_execute(sql, args, persistent, &response);
121 }
122 Command::Prepare { sql, response } => {
123 let _ = response.send(self.handle_prepare(sql));
124 }
125 Command::Ping { response } => {
126 let _ = response.send(self.handle_ping());
127 }
128 Command::Begin { response } => {
129 let _ = response.send(self.handle_begin());
130 }
131 Command::Commit { response } => {
132 let _ = response.send(self.handle_commit());
133 }
134 Command::Rollback { response } => {
135 let _ = response.send(self.handle_rollback());
136 }
137 Command::StartRollback => {
138 self.handle_start_rollback();
139 }
140 Command::ExecSql { sql, response } => {
141 let _ = response.send(self.handle_exec_sql(&sql));
142 }
143 Command::ScalarI64 { sql, response } => {
144 let _ = response.send(self.handle_scalar_i64(&sql));
145 }
146 Command::Shutdown { signal } => {
147 let _ = signal.send(());
148 return;
149 }
150 Command::ListMigrations { sql, response } => {
151 let _ = response.send(self.handle_list_migrations(&sql));
152 }
153 #[cfg(feature = "migrate")]
154 Command::ApplyMigration {
155 sql,
156 insert_sql,
157 version,
158 no_tx,
159 response,
160 } => {
161 let _ = response.send(self.handle_apply_migration(&sql, &insert_sql, version, no_tx));
162 }
163 #[cfg(feature = "migrate")]
164 Command::RevertMigration {
165 sql,
166 delete_sql,
167 version,
168 no_tx,
169 response,
170 } => {
171 let _ = response.send(self.handle_revert_migration(&sql, &delete_sql, version, no_tx));
172 }
173 }
174 }
175 }
177
178 fn handle_execute(
183 &mut self,
184 sql: SqlStr,
185 arguments: Option<MssqlArguments>,
186 persistent: bool,
187 tx: &ExecuteSender,
188 ) -> std::result::Result<(), sqlx_core::Error> {
189 let has_arguments = arguments.as_ref().is_some_and(|a| !a.is_empty());
190 let parameters = arguments
191 .as_ref()
192 .map(MssqlArguments::to_odbc_parameter_collection)
193 .unwrap_or_default();
194
195 if persistent && has_arguments {
196 if let Some(prepared) = self.stmt_cache.get_mut(sql.as_str()) {
197 let mut conn_guard = self.conn.lock().map_err(|_| {
199 sqlx_core::Error::Protocol(
200 "ODBC execute: failed to lock connection".to_owned(),
201 )
202 })?;
203 let has_cursor = prepared
204 .execute(parameters.as_slice())
205 .map_err(|error| {
206 crate::error::database_error_with_context_lazy(error, || {
207 format!(
208 "failed to execute cached ODBC statement: `{}`",
209 sql_preview(sql.as_str())
210 )
211 })
212 })?
213 .is_some();
214 drop(conn_guard);
215
216 if has_cursor {
217 let mut conn_guard = self.conn.lock().map_err(|_| {
219 sqlx_core::Error::Protocol(
220 "ODBC execute: failed to lock connection".to_owned(),
221 )
222 })?;
223 let cursor = prepared
224 .execute(parameters.as_slice())
225 .map_err(|error| {
226 crate::error::database_error_with_context_lazy(error, || {
227 format!(
228 "failed to execute cached ODBC statement: `{}`",
229 sql_preview(sql.as_str())
230 )
231 })
232 })?
233 .expect("has_cursor was true");
234 drop(conn_guard);
235 return stream_result_sets(cursor, self.buffer_settings, tx);
236 }
237
238 let ra = prepared.row_count().map_err(|error| {
239 crate::error::database_error_with_context_lazy(error, || {
240 format!(
241 "failed to read ODBC row count for cached statement: `{}`",
242 sql_preview(sql.as_str())
243 )
244 })
245 })?;
246 return send_rows_affected(ra, tx);
247 } else {
248 let mut prepared =
250 self.conn.clone().into_prepared(sql.as_str()).map_err(|error| {
251 crate::error::database_error_with_context_lazy(error, || {
252 format!(
253 "failed to prepare cached ODBC statement: `{}`",
254 sql_preview(sql.as_str())
255 )
256 })
257 })?;
258
259 let mut conn_guard = self.conn.lock().map_err(|_| {
260 sqlx_core::Error::Protocol(
261 "ODBC execute: failed to lock connection".to_owned(),
262 )
263 })?;
264 let has_cursor = prepared
265 .execute(parameters.as_slice())
266 .map_err(|error| {
267 crate::error::database_error_with_context_lazy(error, || {
268 format!(
269 "failed to execute cached ODBC statement: `{}`",
270 sql_preview(sql.as_str())
271 )
272 })
273 })?
274 .is_some();
275 drop(conn_guard);
276
277 if has_cursor {
278 let mut conn_guard = self.conn.lock().map_err(|_| {
280 sqlx_core::Error::Protocol(
281 "ODBC execute: failed to lock connection".to_owned(),
282 )
283 })?;
284 let cursor = prepared
285 .execute(parameters.as_slice())
286 .map_err(|error| {
287 crate::error::database_error_with_context_lazy(error, || {
288 format!(
289 "failed to execute cached ODBC statement: `{}`",
290 sql_preview(sql.as_str())
291 )
292 })
293 })?
294 .expect("has_cursor was true");
295 drop(conn_guard);
296 return stream_result_sets(cursor, self.buffer_settings, tx);
297 }
298
299 let ra = prepared.row_count().map_err(|error| {
300 crate::error::database_error_with_context_lazy(error, || {
301 format!(
302 "failed to read ODBC row count for cached statement: `{}`",
303 sql_preview(sql.as_str())
304 )
305 })
306 })?;
307 self.stmt_cache.insert(sql.as_str(), prepared);
308 return send_rows_affected(ra, tx);
309 }
310 } else {
311 let mut statement = self.conn.clone().into_preallocated().map_err(|error| {
313 crate::error::database_error_with_context_lazy(error, || {
314 format!(
315 "failed to allocate an ODBC statement for query: `{}`",
316 sql_preview(sql.as_str())
317 )
318 })
319 })?;
320 if let Some(cursor) = statement
321 .execute(sql.as_str(), parameters.as_slice())
322 .map_err(|error| {
323 crate::error::database_error_with_context_lazy(error, || {
324 format!(
325 "failed to execute ODBC query: `{}`",
326 sql_preview(sql.as_str())
327 )
328 })
329 })? {
330 return stream_result_sets(cursor, self.buffer_settings, tx);
331 }
332 let rows_affected = statement.row_count().map_err(|error| {
333 crate::error::database_error_with_context_lazy(error, || {
334 format!(
335 "failed to read ODBC row count for query: `{}`",
336 sql_preview(sql.as_str())
337 )
338 })
339 })?;
340 send_rows_affected(rows_affected, tx)
341 }
342 }
343
344 fn handle_prepare(
345 &mut self,
346 sql: SqlStr,
347 ) -> std::result::Result<MssqlStatement, sqlx_core::Error> {
348 if let Some(prepared) = self.stmt_cache.get_mut(sql.as_str()) {
349 let parameters = prepared.num_params().map_err(|error| {
350 sqlx_core::Error::from(crate::error::database_error_with_context(
351 error,
352 format!(
353 "failed to read ODBC parameter metadata for cached statement: `{}`",
354 sql_preview(sql.as_str())
355 ),
356 ))
357 })?;
358 let columns = collect_prepared_columns(prepared, parameters)?;
359 return Ok(MssqlStatement::new(sql, columns, usize::from(parameters)));
360 }
361
362 let mut prepared = self.conn.clone().into_prepared(sql.as_str()).map_err(|error| {
363 sqlx_core::Error::from(crate::error::database_error_with_context(
364 error,
365 format!(
366 "failed to prepare MSSQL ODBC statement: `{}`",
367 sql_preview(sql.as_str())
368 ),
369 ))
370 })?;
371 let parameters = prepared.num_params().map_err(|error| {
372 sqlx_core::Error::from(crate::error::database_error_with_context(
373 error,
374 format!(
375 "failed to read ODBC parameter metadata for prepared statement: `{}`",
376 sql_preview(sql.as_str())
377 ),
378 ))
379 })?;
380 let columns = collect_prepared_columns(&mut prepared, parameters)?;
381 if self.stmt_cache.is_enabled() {
382 self.stmt_cache.insert(sql.as_str(), prepared);
383 }
384
385 Ok(MssqlStatement::new(sql, columns, usize::from(parameters)))
386 }
387
388 fn handle_ping(&mut self) -> std::result::Result<(), sqlx_core::Error> {
389 let mut conn_guard = self.conn.lock().map_err(|_| {
390 sqlx_core::Error::Protocol("failed to lock connection for ping".into())
391 })?;
392 conn_guard.execute("SELECT 1", (), None).map_err(|error| {
393 sqlx_core::Error::from(crate::error::database_error_with_context(
394 error,
395 "MSSQL ping query failed: `SELECT 1`",
396 ))
397 })?;
398 Ok(())
399 }
400
401 fn handle_begin(&mut self) -> std::result::Result<(), sqlx_core::Error> {
402 if self.transaction_depth == 0 {
403 let mut conn_guard = self.conn.lock().map_err(|_| {
404 sqlx_core::Error::Protocol(
405 "MSSQL ODBC begin: failed to lock connection".to_owned(),
406 )
407 })?;
408 conn_guard.set_autocommit(false).map_err(|error| {
409 sqlx_core::Error::from(crate::error::database_error_with_context(
410 error,
411 "failed to disable ODBC autocommit while beginning a transaction",
412 ))
413 })?;
414 } else {
415 let savepoint = format!("sqlx_sp_{}", self.transaction_depth);
416 let mut conn_guard = self.conn.lock().map_err(|_| {
417 sqlx_core::Error::Protocol(
418 "MSSQL ODBC begin (savepoint): failed to lock connection".to_owned(),
419 )
420 })?;
421 conn_guard
422 .execute(&format!("SAVE TRANSACTION {savepoint}"), (), None)
423 .map_err(|error| {
424 sqlx_core::Error::from(crate::error::database_error_with_context(
425 error,
426 format!(
427 "failed to create save point `{savepoint}` for nested transaction"
428 ),
429 ))
430 })?;
431 }
432 self.transaction_depth += 1;
433 Ok(())
434 }
435
436 fn handle_commit(&mut self) -> std::result::Result<(), sqlx_core::Error> {
437 if self.transaction_depth == 0 {
438 return Ok(());
439 }
440
441 if self.transaction_depth == 1 {
442 let mut conn_guard = self.conn.lock().map_err(|_| {
443 sqlx_core::Error::Protocol(
444 "MSSQL ODBC commit: failed to lock connection".to_owned(),
445 )
446 })?;
447 conn_guard.commit().map_err(|error| {
448 sqlx_core::Error::from(crate::error::database_error_with_context(
449 error,
450 "failed to commit the active MSSQL ODBC transaction",
451 ))
452 })?;
453 conn_guard.set_autocommit(true).map_err(|error| {
454 sqlx_core::Error::from(crate::error::database_error_with_context(
455 error,
456 "failed to restore ODBC autocommit after commit",
457 ))
458 })?;
459 self.transaction_depth = 0;
460 } else {
461 self.transaction_depth -= 1;
462 }
463 Ok(())
464 }
465
466 fn handle_rollback(&mut self) -> std::result::Result<(), sqlx_core::Error> {
467 if self.transaction_depth == 0 {
468 return Ok(());
469 }
470
471 if self.transaction_depth == 1 {
472 let mut conn_guard = self.conn.lock().map_err(|_| {
473 sqlx_core::Error::Protocol(
474 "MSSQL ODBC rollback: failed to lock connection".to_owned(),
475 )
476 })?;
477 conn_guard.rollback().map_err(|error| {
478 sqlx_core::Error::from(crate::error::database_error_with_context(
479 error,
480 "failed to roll back the active ODBC transaction",
481 ))
482 })?;
483 conn_guard.set_autocommit(true).map_err(|error| {
484 sqlx_core::Error::from(crate::error::database_error_with_context(
485 error,
486 "failed to restore ODBC autocommit after rollback",
487 ))
488 })?;
489 self.transaction_depth = 0;
490 } else {
491 let savepoint = format!("sqlx_sp_{}", self.transaction_depth - 1);
492 let mut conn_guard = self.conn.lock().map_err(|_| {
493 sqlx_core::Error::Protocol(
494 "MSSQL ODBC rollback (savepoint): failed to lock connection".to_owned(),
495 )
496 })?;
497 conn_guard
498 .execute(&format!("ROLLBACK TRANSACTION {savepoint}"), (), None)
499 .map_err(|error| {
500 sqlx_core::Error::from(crate::error::database_error_with_context(
501 error,
502 format!("failed to roll back to save point `{savepoint}`"),
503 ))
504 })?;
505 self.transaction_depth -= 1;
506 }
507 Ok(())
508 }
509
510 fn handle_start_rollback(&mut self) {
511 if self.transaction_depth == 0 {
512 return;
513 }
514
515 if self.transaction_depth == 1 {
516 if let Ok(mut conn_guard) = self.conn.lock() {
517 let _ = conn_guard.rollback();
518 let _ = conn_guard.set_autocommit(true);
519 }
520 self.transaction_depth = 0;
521 } else {
522 let savepoint = format!("sqlx_sp_{}", self.transaction_depth - 1);
523 if let Ok(mut conn_guard) = self.conn.lock() {
524 let _ = conn_guard.execute(
525 &format!("ROLLBACK TRANSACTION {savepoint}"),
526 (),
527 None,
528 );
529 }
530 self.transaction_depth -= 1;
531 }
532 }
533
534 fn handle_exec_sql(&self, sql: &str) -> std::result::Result<(), sqlx_core::Error> {
535 let mut conn_guard = self.conn.lock().map_err(|_| {
536 sqlx_core::Error::Protocol("failed to lock the shared ODBC connection".into())
537 })?;
538 conn_guard.execute(sql, (), None).map_err(|error| {
539 sqlx_core::Error::from(crate::error::database_error_with_context(
540 error,
541 format!("failed to execute SQL: `{}`", sql_preview(sql)),
542 ))
543 })?;
544 Ok(())
545 }
546
547 fn handle_scalar_i64(&self, sql: &str) -> std::result::Result<Option<i64>, sqlx_core::Error> {
548 let mut conn_guard = self.conn.lock().map_err(|_| {
549 sqlx_core::Error::Protocol("failed to lock the shared ODBC connection".into())
550 })?;
551 let mut cursor = conn_guard
552 .execute(sql, (), None)
553 .map_err(|error| {
554 sqlx_core::Error::from(crate::error::database_error_with_context(
555 error,
556 format!("scalar query failed: `{}`", sql_preview(sql)),
557 ))
558 })?
559 .ok_or_else(|| {
560 sqlx_core::Error::Protocol(format!(
561 "scalar query returned no result set: `{}`",
562 sql_preview(sql),
563 ))
564 })?;
565
566 if let Some(mut row) = cursor.next_row().map_err(|error| {
567 sqlx_core::Error::from(crate::error::database_error_with_context(
568 error,
569 "scalar query next row",
570 ))
571 })? {
572 let mut value: Nullable<i64> = Nullable::null();
573 row.get_data(1, &mut value).map_err(|error| {
574 sqlx_core::Error::from(crate::error::database_error_with_context(
575 error,
576 "scalar query column 1",
577 ))
578 })?;
579 Ok(value.into_opt())
580 } else {
581 Ok(None)
582 }
583 }
584
585 fn handle_list_migrations(
586 &self,
587 sql: &str,
588 ) -> std::result::Result<Vec<(i64, Vec<u8>)>, sqlx_core::Error> {
589 let mut conn_guard = self.conn.lock().map_err(|_| {
590 sqlx_core::Error::Protocol("failed to lock the shared ODBC connection".into())
591 })?;
592 let mut cursor = conn_guard
593 .execute(sql, (), None)
594 .map_err(|error| {
595 sqlx_core::Error::from(crate::error::database_error_with_context(
596 error,
597 "failed to query applied migrations",
598 ))
599 })?
600 .ok_or_else(|| {
601 sqlx_core::Error::Protocol(
602 "list_applied_migrations returned no result set".into(),
603 )
604 })?;
605
606 let mut migrations = Vec::new();
607 while let Some(mut row) = cursor.next_row().map_err(|error| {
608 sqlx_core::Error::from(crate::error::database_error_with_context(
609 error,
610 "failed to read applied migration row",
611 ))
612 })? {
613 let mut version: Nullable<i64> = Nullable::null();
614 row.get_data(1, &mut version).map_err(|error| {
615 sqlx_core::Error::from(crate::error::database_error_with_context(
616 error,
617 "failed to read migration version",
618 ))
619 })?;
620
621 let mut checksum_bytes = Vec::new();
622 let has_value = row.get_binary(2, &mut checksum_bytes).map_err(|error| {
623 sqlx_core::Error::from(crate::error::database_error_with_context(
624 error,
625 "failed to read migration checksum",
626 ))
627 })?;
628
629 if let Some(version) = version.into_opt() {
630 migrations.push((version, if has_value { checksum_bytes } else { vec![] }));
631 }
632 }
633
634 Ok(migrations)
635 }
636
/// Runs a migration script and then its tracking insert, returning the
/// elapsed wall-clock time.
///
/// Unless `no_tx` is set, both statements run with autocommit disabled and
/// are committed together. If any step fails in that mode, the transaction
/// is rolled back (best effort) and autocommit is restored before the error
/// is returned, so the connection is not left in manual-commit mode with
/// uncommitted work.
///
/// # Errors
/// Returns the first failure among: locking the connection, disabling
/// autocommit, running `sql`, running `insert_sql`, committing, and
/// restoring autocommit.
#[cfg(feature = "migrate")]
fn handle_apply_migration(
    &mut self,
    sql: &str,
    insert_sql: &str,
    version: i64,
    no_tx: bool,
) -> std::result::Result<std::time::Duration, sqlx_core::Error> {
    let start = std::time::Instant::now();
    let mut conn_guard = self.conn.lock().map_err(|_| {
        sqlx_core::Error::Protocol(
            "failed to lock the shared ODBC connection for migration".into(),
        )
    })?;

    if !no_tx {
        conn_guard.set_autocommit(false).map_err(|error| {
            sqlx_core::Error::from(crate::error::database_error_with_context(
                error,
                "failed to start transaction for migration apply",
            ))
        })?;
    }

    // Collect the outcome instead of returning early so the cleanup below
    // always runs.
    let mut outcome = conn_guard
        .execute(sql, (), None)
        .map(|_cursor| ())
        .map_err(|error| {
            sqlx_core::Error::from(crate::error::database_error_with_context(
                error,
                format!("migration {version} failed"),
            ))
        });

    if outcome.is_ok() {
        outcome = conn_guard
            .execute(insert_sql, (), None)
            .map(|_cursor| ())
            .map_err(|error| {
                sqlx_core::Error::from(crate::error::database_error_with_context(
                    error,
                    format!("failed to insert tracking record for migration {version}"),
                ))
            });
    }

    if !no_tx {
        if outcome.is_ok() {
            outcome = conn_guard.commit().map_err(|error| {
                sqlx_core::Error::from(crate::error::database_error_with_context(
                    error,
                    format!("failed to commit migration {version}"),
                ))
            });
        }
        if outcome.is_err() {
            // Best effort: the original failure is the one worth reporting.
            let _ = conn_guard.rollback();
        }
        let restored = conn_guard.set_autocommit(true).map_err(|error| {
            sqlx_core::Error::from(crate::error::database_error_with_context(
                error,
                "failed to restore autocommit after migration apply",
            ))
        });
        // An earlier failure takes precedence over a failed restore.
        outcome = outcome.and(restored);
    }

    outcome?;
    Ok(start.elapsed())
}
692
/// Runs a revert script and then its tracking delete, returning the elapsed
/// wall-clock time.
///
/// Unless `no_tx` is set, both statements run with autocommit disabled and
/// are committed together. If any step fails in that mode, the transaction
/// is rolled back (best effort) and autocommit is restored before the error
/// is returned, so the connection is not left in manual-commit mode with
/// uncommitted work.
///
/// # Errors
/// Returns the first failure among: locking the connection, disabling
/// autocommit, running `sql`, running `delete_sql`, committing, and
/// restoring autocommit.
#[cfg(feature = "migrate")]
fn handle_revert_migration(
    &mut self,
    sql: &str,
    delete_sql: &str,
    version: i64,
    no_tx: bool,
) -> std::result::Result<std::time::Duration, sqlx_core::Error> {
    let start = std::time::Instant::now();
    let mut conn_guard = self.conn.lock().map_err(|_| {
        sqlx_core::Error::Protocol(
            "failed to lock the shared ODBC connection for migration".into(),
        )
    })?;

    if !no_tx {
        conn_guard.set_autocommit(false).map_err(|error| {
            sqlx_core::Error::from(crate::error::database_error_with_context(
                error,
                "failed to start transaction for migration revert",
            ))
        })?;
    }

    // Collect the outcome instead of returning early so the cleanup below
    // always runs.
    let mut outcome = conn_guard
        .execute(sql, (), None)
        .map(|_cursor| ())
        .map_err(|error| {
            sqlx_core::Error::from(crate::error::database_error_with_context(
                error,
                format!("revert migration {version} failed"),
            ))
        });

    if outcome.is_ok() {
        outcome = conn_guard
            .execute(delete_sql, (), None)
            .map(|_cursor| ())
            .map_err(|error| {
                sqlx_core::Error::from(crate::error::database_error_with_context(
                    error,
                    format!("failed to delete tracking record for migration {version}"),
                ))
            });
    }

    if !no_tx {
        if outcome.is_ok() {
            outcome = conn_guard.commit().map_err(|error| {
                sqlx_core::Error::from(crate::error::database_error_with_context(
                    error,
                    format!("failed to commit migration revert {version}"),
                ))
            });
        }
        if outcome.is_err() {
            // Best effort: the original failure is the one worth reporting.
            let _ = conn_guard.rollback();
        }
        let restored = conn_guard.set_autocommit(true).map_err(|error| {
            sqlx_core::Error::from(crate::error::database_error_with_context(
                error,
                "failed to restore autocommit after migration revert",
            ))
        });
        // An earlier failure takes precedence over a failed restore.
        outcome = outcome.and(restored);
    }

    outcome?;
    Ok(start.elapsed())
}
748}
749
750pub struct MssqlConnection {
752 cmd_tx: flume::Sender<Command>,
753 buffer_settings: MssqlBufferSettings,
754 transaction_depth: AtomicUsize,
755}
756
757impl std::fmt::Debug for MssqlConnection {
758 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
759 f.debug_struct("MssqlConnection").finish_non_exhaustive()
760 }
761}
762
763impl MssqlConnection {
764 pub fn connect_blocking(options: &MssqlConnectOptions) -> Result<Self> {
767 let env = odbc_api::environment().map_err(|error| {
768 crate::MssqlError::Configuration(format!(
769 "failed to initialize the process-wide ODBC environment: {error}"
770 ))
771 })?;
772
773 let raw_conn = env
774 .connect_with_connection_string(options.connection_string(), Default::default())
775 .map_err(|error| {
776 crate::error::database_error_with_context(
777 error,
778 "failed to open MSSQL ODBC connection using the supplied connection string",
779 )
780 })?;
781
782 let conn: odbc_api::SharedConnection<'static> =
784 std::sync::Arc::new(std::sync::Mutex::new(raw_conn));
785
786 let (cmd_tx, cmd_rx) = flume::unbounded();
787
788 let actor = ConnectionActor {
789 conn,
790 stmt_cache: StatementCache::new(options.statement_cache_capacity),
791 transaction_depth: 0,
792 buffer_settings: options.buffer_settings,
793 };
794
795 std::thread::spawn(move || actor.run(cmd_rx));
799
800 Ok(Self {
801 cmd_tx,
802 buffer_settings: options.buffer_settings,
803 transaction_depth: AtomicUsize::new(0),
804 })
805 }
806
807 pub fn ping_blocking(&self) -> std::result::Result<(), sqlx_core::Error> {
809 send_command_blocking(&self.cmd_tx, |tx| Command::Ping { response: tx })?
810 }
811
812 pub fn dbms_name(&self) -> std::result::Result<String, sqlx_core::Error> {
814 send_command_blocking(&self.cmd_tx, |tx| {
815 Command::ExecSql {
816 sql: "SELECT 1 /* dbms_name */".into(),
817 response: tx,
818 }
819 })?;
820 Ok("MSSQL via ODBC".to_owned())
821 }
822
823 pub(crate) fn begin_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
825 let r = send_command_blocking(&self.cmd_tx, |tx| Command::Begin { response: tx })?;
826 if r.is_ok() {
827 self.transaction_depth.fetch_add(1, Ordering::SeqCst);
828 }
829 r
830 }
831
832 pub(crate) fn commit_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
834 let depth = self.transaction_depth.load(Ordering::SeqCst);
835 if depth == 0 {
836 return Ok(());
837 }
838 let r = send_command_blocking(&self.cmd_tx, |tx| Command::Commit { response: tx })?;
839 if r.is_ok() {
840 if depth == 1 {
841 self.transaction_depth.store(0, Ordering::SeqCst);
842 } else {
843 self.transaction_depth.fetch_sub(1, Ordering::SeqCst);
844 }
845 }
846 r
847 }
848
849 pub(crate) fn rollback_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
851 let depth = self.transaction_depth.load(Ordering::SeqCst);
852 if depth == 0 {
853 return Ok(());
854 }
855 let r = send_command_blocking(&self.cmd_tx, |tx| Command::Rollback { response: tx })?;
856 if r.is_ok() {
857 if depth == 1 {
858 self.transaction_depth.store(0, Ordering::SeqCst);
859 } else {
860 self.transaction_depth.fetch_sub(1, Ordering::SeqCst);
861 }
862 }
863 r
864 }
865
866 pub(crate) fn start_rollback(&mut self) {
868 let _ = self.cmd_tx.try_send(Command::StartRollback);
869 self.transaction_depth.store(0, Ordering::SeqCst);
870 }
871
872 pub(crate) fn transaction_depth(&self) -> usize {
874 self.transaction_depth.load(Ordering::SeqCst)
875 }
876
877 pub(crate) fn set_transaction_depth(&mut self, depth: usize) {
879 self.transaction_depth.store(depth, Ordering::SeqCst);
880 }
881
882 pub fn prepare_blocking(
884 &self,
885 sql: sqlx_core::sql_str::SqlStr,
886 ) -> std::result::Result<MssqlStatement, sqlx_core::Error> {
887 send_command_blocking(&self.cmd_tx, |tx| Command::Prepare { sql, response: tx })?
888 }
889
/// Executes `sql` on the actor and blocks until it finishes; any result set
/// is discarded.
///
/// # Errors
/// Fails when the actor is unreachable or the statement fails.
#[cfg(feature = "migrate")]
pub(crate) fn exec_sql_blocking(&self, sql: &str) -> std::result::Result<(), sqlx_core::Error> {
    let statement = sql.to_owned();
    send_command_blocking(&self.cmd_tx, |reply| Command::ExecSql {
        sql: statement,
        response: reply,
    })
    .and_then(|outcome| outcome)
}
900
/// Runs `sql` on the actor and returns the first column of the first row as
/// an `i64`, or `None` when there is no row or the value is NULL.
///
/// # Errors
/// Fails when the actor is unreachable or the query fails.
#[cfg(feature = "migrate")]
pub(crate) fn scalar_i64_blocking(
    &self,
    sql: &str,
) -> std::result::Result<Option<i64>, sqlx_core::Error> {
    let statement = sql.to_owned();
    send_command_blocking(&self.cmd_tx, |reply| Command::ScalarI64 {
        sql: statement,
        response: reply,
    })
    .and_then(|outcome| outcome)
}
914
/// Runs `sql` on the actor and returns the `(version, checksum)` pairs it
/// yields.
///
/// # Errors
/// Fails when the actor is unreachable or the query fails.
#[cfg(feature = "migrate")]
pub(crate) fn list_migrations_blocking(
    &self,
    sql: &str,
) -> std::result::Result<Vec<(i64, Vec<u8>)>, sqlx_core::Error> {
    let statement = sql.to_owned();
    send_command_blocking(&self.cmd_tx, |reply| Command::ListMigrations {
        sql: statement,
        response: reply,
    })
    .and_then(|outcome| outcome)
}
928
/// Applies a migration on the actor and returns how long it took.
///
/// `sql` is the migration script and `insert_sql` the tracking insert.
/// `no_tx` disables the surrounding transaction.
///
/// # Errors
/// Fails when the actor is unreachable or the migration fails.
#[cfg(feature = "migrate")]
pub(crate) fn apply_migration_blocking(
    &self,
    sql: &str,
    insert_sql: &str,
    version: i64,
    no_tx: bool,
) -> std::result::Result<std::time::Duration, sqlx_core::Error> {
    let script = sql.to_owned();
    let tracking = insert_sql.to_owned();
    send_command_blocking(&self.cmd_tx, |reply| Command::ApplyMigration {
        sql: script,
        insert_sql: tracking,
        version,
        no_tx,
        response: reply,
    })
    .and_then(|outcome| outcome)
}
948
/// Reverts a migration on the actor and returns how long it took.
///
/// `sql` is the revert script and `delete_sql` the tracking delete.
/// `no_tx` disables the surrounding transaction.
///
/// # Errors
/// Fails when the actor is unreachable or the revert fails.
#[cfg(feature = "migrate")]
pub(crate) fn revert_migration_blocking(
    &self,
    sql: &str,
    delete_sql: &str,
    version: i64,
    no_tx: bool,
) -> std::result::Result<std::time::Duration, sqlx_core::Error> {
    let script = sql.to_owned();
    let tracking = delete_sql.to_owned();
    send_command_blocking(&self.cmd_tx, |reply| Command::RevertMigration {
        sql: script,
        delete_sql: tracking,
        version,
        no_tx,
        response: reply,
    })
    .and_then(|outcome| outcome)
}
968
969 pub(crate) fn execute_receiver(
971 &self,
972 sql: sqlx_core::sql_str::SqlStr,
973 persistent: bool,
974 arguments: Option<MssqlArguments>,
975 ) -> flume::Receiver<ExecuteResult> {
976 let (tx, rx) = flume::bounded(64);
977 if self
978 .cmd_tx
979 .send(Command::Execute {
980 sql,
981 args: arguments,
982 persistent,
983 response: tx,
984 })
985 .is_err()
986 {
987 let _ = rx.drain();
989 }
990 rx
991 }
992}
993
994impl Drop for MssqlConnection {
996 fn drop(&mut self) {}
997}
998
999impl sqlx_core::connection::Connection for MssqlConnection {
1004 type Database = crate::Mssql;
1005 type Options = MssqlConnectOptions;
1006
1007 async fn close(self) -> std::result::Result<(), sqlx_core::Error> {
1008 drop(self);
1009 Ok(())
1010 }
1011
1012 async fn close_hard(self) -> std::result::Result<(), sqlx_core::Error> {
1013 drop(self);
1014 Ok(())
1015 }
1016
1017 async fn ping(&mut self) -> std::result::Result<(), sqlx_core::Error> {
1018 send_command_async(&self.cmd_tx, |tx| Command::Ping { response: tx }).await?
1019 }
1020
1021 fn begin(
1022 &mut self,
1023 ) -> impl Future<Output = std::result::Result<Transaction<'_, Self::Database>, sqlx_core::Error>>
1024 + Send
1025 + '_ {
1026 Transaction::begin(self, None)
1027 }
1028
1029 fn shrink_buffers(&mut self) {}
1030
1031 async fn flush(&mut self) -> std::result::Result<(), sqlx_core::Error> {
1032 Ok(())
1033 }
1034
1035 fn should_flush(&self) -> bool {
1036 false
1037 }
1038
1039 fn cached_statements_size(&self) -> usize
1040 where
1041 Self::Database: sqlx_core::database::HasStatementCache,
1042 {
1043 0
1046 }
1047
1048 async fn clear_cached_statements(&mut self) -> std::result::Result<(), sqlx_core::Error>
1049 where
1050 Self::Database: sqlx_core::database::HasStatementCache,
1051 {
1052 Ok(())
1056 }
1057}
1058
1059impl<'c> Executor<'c> for &'c mut MssqlConnection {
1064 type Database = crate::Mssql;
1065
1066 fn fetch_many<'e, 'q, E>(
1067 self,
1068 mut query: E,
1069 ) -> BoxStream<'e, std::result::Result<Either<MssqlQueryResult, MssqlRow>, sqlx_core::Error>>
1070 where
1071 'c: 'e,
1072 E: Execute<'q, Self::Database>,
1073 'q: 'e,
1074 E: 'q,
1075 {
1076 let arguments = query.take_arguments().map_err(sqlx_core::Error::Encode);
1077 let persistent = query.persistent();
1078 let sql = query.sql();
1079
1080 match arguments {
1081 Ok(arguments) => {
1082 receiver_to_stream(self.execute_receiver(sql, persistent, arguments))
1083 }
1084 Err(error) => stream::once(future::ready(Err(error))).boxed(),
1085 }
1086 }
1087
1088 fn fetch_optional<'e, 'q, E>(
1089 self,
1090 mut query: E,
1091 ) -> BoxFuture<'e, std::result::Result<Option<MssqlRow>, sqlx_core::Error>>
1092 where
1093 'c: 'e,
1094 E: Execute<'q, Self::Database>,
1095 'q: 'e,
1096 E: 'q,
1097 {
1098 let arguments = query.take_arguments().map_err(sqlx_core::Error::Encode);
1099 let persistent = query.persistent();
1100 let sql = query.sql();
1101
1102 Box::pin(async move {
1103 let rx = self.execute_receiver(sql, persistent, arguments?);
1104 while let Ok(item) = rx.recv_async().await {
1105 match item? {
1106 Either::Right(row) => return Ok(Some(row)),
1107 Either::Left(_) => {}
1108 }
1109 }
1110 Ok(None)
1111 })
1112 }
1113
1114 fn prepare_with<'e>(
1115 self,
1116 sql: sqlx_core::sql_str::SqlStr,
1117 _parameters: &[crate::MssqlTypeInfo],
1118 ) -> BoxFuture<'e, std::result::Result<MssqlStatement, sqlx_core::Error>>
1119 where
1120 'c: 'e,
1121 {
1122 let cmd_tx = self.cmd_tx.clone();
1123 Box::pin(async move {
1124 send_command_async(&cmd_tx, |tx| Command::Prepare { sql, response: tx }).await?
1125 })
1126 }
1127
/// Describes `sql` by preparing it on the actor.
///
/// The description carries the statement's columns and its parameter count.
/// Nullability is reported as unknown (`None`) for every column.
///
/// # Errors
/// Fails when the actor is unreachable or preparing the statement fails.
#[cfg(feature = "offline")]
fn describe<'e>(
    self,
    sql: sqlx_core::sql_str::SqlStr,
) -> BoxFuture<'e, std::result::Result<sqlx_core::describe::Describe<Self::Database>, sqlx_core::Error>>
where
    'c: 'e,
{
    use sqlx_core::statement::Statement;
    let commands = self.cmd_tx.clone();
    Box::pin(async move {
        let prepared = send_command_async(&commands, |reply| Command::Prepare {
            sql,
            response: reply,
        })
        .await??;

        let parameter_count = match prepared.parameters() {
            Some(Either::Left(types)) => types.len(),
            Some(Either::Right(count)) => count,
            None => 0,
        };
        let columns = prepared.columns().to_vec();
        let nullable = vec![None; columns.len()];

        Ok(sqlx_core::describe::Describe {
            columns,
            parameters: Some(Either::Right(parameter_count)),
            nullable,
        })
    })
}
1158}
1159
1160async fn send_command_async<T: Send>(
1165 cmd_tx: &flume::Sender<Command>,
1166 make_cmd: impl FnOnce(tokio::sync::oneshot::Sender<T>) -> Command,
1167) -> std::result::Result<T, sqlx_core::Error> {
1168 let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
1169 let cmd = make_cmd(resp_tx);
1170 cmd_tx.send(cmd).map_err(|_| {
1171 sqlx_core::Error::Protocol(
1172 "MSSQL ODBC connection actor has shut down".to_owned(),
1173 )
1174 })?;
1175 resp_rx.await.map_err(|_| {
1176 sqlx_core::Error::Protocol(
1177 "MSSQL ODBC connection actor response channel closed".to_owned(),
1178 )
1179 })
1180}
1181
1182fn send_command_blocking<T: Send>(
1187 cmd_tx: &flume::Sender<Command>,
1188 make_cmd: impl FnOnce(tokio::sync::oneshot::Sender<T>) -> Command,
1189) -> std::result::Result<T, sqlx_core::Error> {
1190 let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
1191 let cmd = make_cmd(resp_tx);
1192 cmd_tx.send(cmd).map_err(|_| {
1193 sqlx_core::Error::Protocol(
1194 "MSSQL ODBC connection actor has shut down".to_owned(),
1195 )
1196 })?;
1197 resp_rx.blocking_recv().map_err(|_| {
1198 sqlx_core::Error::Protocol(
1199 "MSSQL ODBC connection actor response channel closed".to_owned(),
1200 )
1201 })
1202}
1203
1204fn receiver_to_stream<'e>(
1209 rx: flume::Receiver<ExecuteResult>,
1210) -> BoxStream<'e, ExecuteResult> {
1211 stream::unfold(rx, |rx| async move {
1212 rx.recv_async().await.ok().map(|item| (item, rx))
1213 })
1214 .boxed()
1215}
1216
1217fn send_rows_affected(
1222 rows_affected: Option<usize>,
1223 tx: &ExecuteSender,
1224) -> std::result::Result<(), sqlx_core::Error> {
1225 let rows_affected = rows_affected
1226 .unwrap_or(0)
1227 .try_into()
1228 .map_err(|_| sqlx_core::Error::Protocol("ODBC row count does not fit in u64".to_owned()))?;
1229 send_done(tx, rows_affected);
1230 Ok(())
1231}
1232
1233fn send_done(tx: &ExecuteSender, rows_affected: u64) -> bool {
1234 tx.send(Ok(Either::Left(MssqlQueryResult::new(rows_affected))))
1235 .is_ok()
1236}
1237
1238fn send_row(tx: &ExecuteSender, row: MssqlRow) -> bool {
1239 tx.send(Ok(Either::Right(row))).is_ok()
1240}
1241
1242pub(crate) fn collect_columns(
1243 cursor: &mut impl ResultSetMetadata,
1244) -> std::result::Result<Vec<MssqlColumn>, sqlx_core::Error> {
1245 let count = cursor.num_result_cols().map_err(|error| {
1246 crate::error::database_error_with_context(error, "failed to read ODBC result-column count")
1247 })?;
1248 let count = usize::try_from(count).map_err(|_| {
1249 sqlx_core::Error::Protocol(format!("ODBC returned a negative column count: {count}"))
1250 })?;
1251
1252 let mut columns = Vec::with_capacity(count);
1253 for ordinal in 0..count {
1254 let column_number = u16::try_from(ordinal + 1).map_err(|_| {
1255 sqlx_core::Error::Protocol(format!("ODBC column index exceeds u16: {}", ordinal + 1))
1256 })?;
1257
1258 let mut description = odbc_api::ColumnDescription::default();
1259 cursor
1260 .describe_col(column_number, &mut description)
1261 .map_err(|error| {
1262 crate::error::database_error_with_context(
1263 error,
1264 format!("failed to describe ODBC result column {column_number}"),
1265 )
1266 })?;
1267 let name = description
1268 .name_to_string()
1269 .unwrap_or_else(|_| format!("col{ordinal}"));
1270
1271 columns.push(MssqlColumn::new(
1272 ordinal,
1273 name,
1274 MssqlTypeInfo::new(description.data_type),
1275 ));
1276 }
1277
1278 Ok(columns)
1279}
1280
1281fn collect_prepared_columns(
1282 prepared: &mut impl PreparedStatementMetadata,
1283 parameter_count: u16,
1284) -> std::result::Result<Vec<MssqlColumn>, sqlx_core::Error> {
1285 match collect_columns(prepared) {
1286 Ok(columns) => Ok(columns),
1287 Err(error) if parameter_count > 0 => {
1288 validate_parameter_metadata(prepared, parameter_count)?;
1289 log::debug!("ODBC driver deferred result-column metadata until execution: {error}");
1290 Ok(Vec::new())
1291 }
1292 Err(error) => Err(error),
1293 }
1294}
1295
1296trait PreparedStatementMetadata: ResultSetMetadata {
1297 fn describe_prepared_parameter(
1298 &mut self,
1299 index: u16,
1300 ) -> std::result::Result<(), odbc_api::Error>;
1301}
1302
1303impl<S> PreparedStatementMetadata for odbc_api::Prepared<S>
1304where
1305 S: odbc_api::handles::AsStatementRef,
1306{
1307 fn describe_prepared_parameter(
1308 &mut self,
1309 index: u16,
1310 ) -> std::result::Result<(), odbc_api::Error> {
1311 self.describe_param(index).map(|_| ())
1312 }
1313}
1314
1315fn validate_parameter_metadata(
1316 prepared: &mut impl PreparedStatementMetadata,
1317 parameter_count: u16,
1318) -> std::result::Result<(), sqlx_core::Error> {
1319 for index in 1..=parameter_count {
1320 prepared
1321 .describe_prepared_parameter(index)
1322 .map_err(|error| {
1323 crate::error::database_error_with_context(
1324 error,
1325 format!("failed to describe ODBC parameter {index}"),
1326 )
1327 })?;
1328 }
1329
1330 Ok(())
1331}
1332
1333fn stream_result_sets<C>(
1334 mut cursor: C,
1335 settings: MssqlBufferSettings,
1336 tx: &ExecuteSender,
1337) -> std::result::Result<(), sqlx_core::Error>
1338where
1339 C: Cursor + ResultSetMetadata,
1340{
1341 loop {
1342 if cursor.num_result_cols().map_err(|error| {
1343 crate::error::database_error_with_context(
1344 error,
1345 "failed to read ODBC result-column count",
1346 )
1347 })? == 0
1348 {
1349 send_done(tx, 0);
1350 } else if let Some(max_column_size) = settings.max_column_size {
1351 let (receiver_open, finished_cursor) =
1352 stream_rows_buffered(cursor, settings.batch_size, max_column_size, tx)?;
1353 if !receiver_open {
1354 return Ok(());
1355 }
1356 cursor = finished_cursor;
1357 } else if !stream_rows_unbuffered(&mut cursor, tx)? {
1358 return Ok(());
1359 }
1360
1361 match cursor.more_results().map_err(|error| {
1362 crate::error::database_error_with_context(error, "failed to advance ODBC result set")
1363 })? {
1364 Some(next_cursor) => cursor = next_cursor,
1365 None => return Ok(()),
1366 }
1367 }
1368}
1369
1370#[derive(Debug)]
1371struct ColumnBinding {
1372 column: MssqlColumn,
1373 buffer_desc: BufferDesc,
1374}
1375
1376fn stream_rows_buffered<C>(
1377 cursor: C,
1378 batch_size: usize,
1379 max_column_size: usize,
1380 tx: &ExecuteSender,
1381) -> std::result::Result<(bool, C), sqlx_core::Error>
1382where
1383 C: Cursor + ResultSetMetadata,
1384{
1385 let mut cursor = cursor;
1386 let bindings = build_buffer_bindings(&mut cursor, max_column_size)?;
1387 let buffer_descriptions = bindings
1388 .iter()
1389 .map(|binding| binding.buffer_desc)
1390 .collect::<Vec<_>>();
1391 let mut row_set_cursor = cursor
1392 .bind_buffer(ColumnarDynBuffer::from_descs(
1393 batch_size,
1394 buffer_descriptions,
1395 ))
1396 .map_err(|error| {
1397 crate::error::database_error_with_context(
1398 error,
1399 format!(
1400 "ODBC buffered fetching could not be enabled with batch_size={batch_size}; \
1401 this driver may reject the row-array or row-binding statement attributes \
1402 used for column-wise buffered fetching, so use \
1403 MssqlConnectOptions::max_column_size(None) to fetch rows unbuffered"
1404 ),
1405 )
1406 })?;
1407 let columns: Arc<[MssqlColumn]> = bindings
1408 .iter()
1409 .map(|binding| binding.column.clone())
1410 .collect::<Vec<_>>()
1411 .into();
1412
1413 while let Some(batch) = row_set_cursor.fetch().map_err(|error| {
1414 crate::error::database_error_with_context(error, "ODBC buffered fetch failed")
1415 })? {
1416 let column_values = bindings
1417 .iter()
1418 .enumerate()
1419 .map(|(index, binding)| {
1420 buffered_column_values(batch.column(index), binding).map_err(|error| {
1421 sqlx_core::Error::Protocol(format!(
1422 "ODBC buffered fetch could not convert column {} (`{}`) using buffer {:?}: {error}",
1423 binding.column.ordinal() + 1,
1424 binding.column.name(),
1425 binding.buffer_desc
1426 ))
1427 })
1428 })
1429 .collect::<std::result::Result<Vec<_>, _>>()?;
1430
1431 let mut column_iters = column_values
1432 .into_iter()
1433 .map(Vec::into_iter)
1434 .collect::<Vec<_>>();
1435
1436 for row_index in 0..batch.num_rows() {
1437 let values = column_iters
1438 .iter_mut()
1439 .map(|values| {
1440 values.next().map(MssqlValue::new).ok_or_else(|| {
1441 sqlx_core::Error::Protocol(format!(
1442 "ODBC buffered fetch produced too few values for row {}",
1443 row_index + 1
1444 ))
1445 })
1446 })
1447 .collect::<std::result::Result<Vec<_>, _>>()?;
1448 if !send_row(tx, MssqlRow::new_shared(Arc::clone(&columns), values)) {
1449 let (cursor, _) = row_set_cursor.unbind().map_err(|error| {
1450 crate::error::database_error_with_context(
1451 error,
1452 "ODBC buffered fetch could not unbind row buffer after receiver closed",
1453 )
1454 })?;
1455 return Ok((false, cursor));
1456 }
1457 }
1458 }
1459
1460 send_done(tx, 0);
1461 let (cursor, _) = row_set_cursor.unbind().map_err(|error| {
1462 crate::error::database_error_with_context(
1463 error,
1464 "ODBC buffered fetch could not unbind row buffer",
1465 )
1466 })?;
1467 Ok((true, cursor))
1468}
1469
1470fn build_buffer_bindings(
1471 cursor: &mut impl ResultSetMetadata,
1472 max_column_size: usize,
1473) -> std::result::Result<Vec<ColumnBinding>, sqlx_core::Error> {
1474 collect_columns(cursor).map(|columns| {
1475 columns
1476 .into_iter()
1477 .map(|column| ColumnBinding {
1478 buffer_desc: map_buffer_desc(column.type_info().data_type(), max_column_size),
1479 column,
1480 })
1481 .collect()
1482 })
1483}
1484
1485fn map_buffer_desc(data_type: DataType, max_column_size: usize) -> BufferDesc {
1486 match data_type {
1487 DataType::TinyInt | DataType::SmallInt | DataType::Integer | DataType::BigInt => {
1488 BufferDesc::I64 { nullable: true }
1489 }
1490 DataType::Real => BufferDesc::F32 { nullable: true },
1491 DataType::Float { .. } | DataType::Double => BufferDesc::F64 { nullable: true },
1492 DataType::Bit => BufferDesc::Bit { nullable: true },
1493 DataType::Date => BufferDesc::Date { nullable: true },
1494 DataType::Time { .. } => BufferDesc::Time { nullable: true },
1495 DataType::Timestamp { .. } => BufferDesc::Timestamp { nullable: true },
1496 DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => {
1497 BufferDesc::Binary {
1498 max_bytes: max_column_size,
1499 }
1500 }
1501 DataType::WChar { .. } | DataType::WVarchar { .. } | DataType::WLongVarchar { .. } => {
1504 BufferDesc::WText {
1505 max_str_len: max_column_size,
1506 }
1507 }
1508 DataType::Char { .. }
1510 | DataType::Varchar { .. }
1511 | DataType::LongVarchar { .. }
1512 | DataType::Other { .. }
1513 | DataType::Unknown
1514 | DataType::Decimal { .. }
1515 | DataType::Numeric { .. } => BufferDesc::Text {
1516 max_str_len: max_column_size,
1517 },
1518 }
1519}
1520
1521fn buffered_column_values(
1522 slice: AnyColumnBufferSlice<'_>,
1523 binding: &ColumnBinding,
1524) -> std::result::Result<Vec<MssqlValueKind>, sqlx_core::Error> {
1525 let desc = binding.buffer_desc;
1526 Ok(match desc {
1527 BufferDesc::I8 { nullable } => buffered_numeric(&slice, desc, nullable, |value: i8| {
1528 MssqlValueKind::TinyInt(i16::from(value))
1529 })?,
1530 BufferDesc::I16 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
1531 MssqlValueKind::SmallInt(value)
1532 })?,
1533 BufferDesc::I32 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
1534 MssqlValueKind::Integer(value)
1535 })?,
1536 BufferDesc::I64 { nullable } => {
1537 buffered_numeric(&slice, desc, nullable, MssqlValueKind::BigInt)?
1538 }
1539 BufferDesc::U8 { nullable } => buffered_numeric(&slice, desc, nullable, |value: u8| {
1540 MssqlValueKind::BigInt(i64::from(value))
1541 })?,
1542 BufferDesc::F32 { nullable } => {
1543 buffered_numeric(&slice, desc, nullable, MssqlValueKind::Real)?
1544 }
1545 BufferDesc::F64 { nullable } => {
1546 buffered_numeric(&slice, desc, nullable, MssqlValueKind::Double)?
1547 }
1548 BufferDesc::Bit { nullable } => {
1549 buffered_numeric(&slice, desc, nullable, |value: odbc_api::Bit| {
1550 MssqlValueKind::Bit(value.as_bool())
1551 })?
1552 }
1553 BufferDesc::Date { nullable } => {
1554 buffered_numeric(&slice, desc, nullable, MssqlValueKind::Date)?
1555 }
1556 BufferDesc::Time { nullable } => {
1557 buffered_numeric(&slice, desc, nullable, MssqlValueKind::Time)?
1558 }
1559 BufferDesc::Timestamp { nullable } => {
1560 buffered_numeric(&slice, desc, nullable, MssqlValueKind::Timestamp)?
1561 }
1562 BufferDesc::Text { .. } => {
1563 let text = expect_buffer_slice(slice.as_text(), desc)?;
1564 text.iter()
1565 .map(|value| {
1566 value
1567 .map(|bytes| {
1568 MssqlValueKind::Text(String::from_utf8_lossy(bytes).into_owned())
1569 })
1570 .unwrap_or(MssqlValueKind::Null)
1571 })
1572 .collect()
1573 }
1574 BufferDesc::WText { .. } => {
1575 let text = expect_buffer_slice(slice.as_wide_text(), desc)?;
1576 text.iter()
1577 .map(|value| {
1578 value
1579 .map(|chars| MssqlValueKind::Text(String::from_utf16_lossy(chars.into())))
1580 .unwrap_or(MssqlValueKind::Null)
1581 })
1582 .collect()
1583 }
1584 BufferDesc::Binary { .. } => {
1585 let binary = expect_buffer_slice(slice.as_binary(), desc)?;
1586 binary
1587 .iter()
1588 .map(|value| {
1589 value
1590 .map(|bytes| MssqlValueKind::Binary(bytes.to_vec()))
1591 .unwrap_or(MssqlValueKind::Null)
1592 })
1593 .collect()
1594 }
1595 BufferDesc::Numeric => {
1596 return Err(sqlx_core::Error::Protocol(format!(
1597 "unsupported ODBC buffer descriptor: {desc:?}"
1598 )))
1599 }
1600 })
1601}
1602
1603fn buffered_numeric<T, F>(
1604 slice: &AnyColumnBufferSlice<'_>,
1605 desc: BufferDesc,
1606 nullable: bool,
1607 map: F,
1608) -> std::result::Result<Vec<MssqlValueKind>, sqlx_core::Error>
1609where
1610 T: Copy + odbc_api::Pod,
1611 F: FnMut(T) -> MssqlValueKind,
1612{
1613 if nullable {
1614 Ok(buffered_nullable_numeric(
1615 expect_buffer_slice(slice.as_nullable_slice::<T>(), desc)?,
1616 map,
1617 ))
1618 } else {
1619 Ok(expect_buffer_slice(slice.as_slice::<T>(), desc)?
1620 .iter()
1621 .copied()
1622 .map(map)
1623 .collect())
1624 }
1625}
1626
1627fn buffered_nullable_numeric<T, F>(slice: NullableSlice<'_, T>, mut map: F) -> Vec<MssqlValueKind>
1628where
1629 T: Copy,
1630 F: FnMut(T) -> MssqlValueKind,
1631{
1632 slice
1633 .map(|value| value.copied().map(&mut map).unwrap_or(MssqlValueKind::Null))
1634 .collect()
1635}
1636
1637fn expect_buffer_slice<T>(
1638 slice: Option<T>,
1639 desc: BufferDesc,
1640) -> std::result::Result<T, sqlx_core::Error> {
1641 slice.ok_or_else(|| {
1642 sqlx_core::Error::Protocol(format!(
1643 "ODBC column buffer {desc:?} did not match fetched slice"
1644 ))
1645 })
1646}
1647
1648fn stream_rows_unbuffered<C>(
1649 cursor: &mut C,
1650 tx: &ExecuteSender,
1651) -> std::result::Result<bool, sqlx_core::Error>
1652where
1653 C: Cursor + ResultSetMetadata,
1654{
1655 let columns: Arc<[MssqlColumn]> = collect_columns(cursor)?.into();
1656
1657 while let Some(mut cursor_row) = cursor.next_row().map_err(|error| {
1658 crate::error::database_error_with_context(
1659 error,
1660 "ODBC unbuffered fetch failed while reading the next row",
1661 )
1662 })? {
1663 let mut values = Vec::with_capacity(columns.len());
1664
1665 for column in columns.iter() {
1666 let column_number = u16::try_from(sqlx_core::column::Column::ordinal(column) + 1)
1667 .map_err(|_| {
1668 sqlx_core::Error::Protocol("ODBC column index exceeds u16".to_owned())
1669 })?;
1670 values.push(fetch_value(&mut cursor_row, column_number, column)?);
1671 }
1672
1673 if !send_row(tx, MssqlRow::new_shared(Arc::clone(&columns), values)) {
1674 return Ok(false);
1675 }
1676 }
1677
1678 send_done(tx, 0);
1679 Ok(true)
1680}
1681
1682fn fetch_value(
1683 row: &mut odbc_api::CursorRow<'_>,
1684 column_number: u16,
1685 column: &MssqlColumn,
1686) -> std::result::Result<MssqlValue, sqlx_core::Error> {
1687 let data_type = column.type_info().data_type();
1688
1689 let kind = match data_type {
1690 DataType::Bit => {
1691 let mut value = Nullable::<odbc_api::Bit>::null();
1692 row.get_data(column_number, &mut value).map_err(|error| {
1693 crate::error::database_error_with_context_lazy(error, || {
1694 fetch_context(column, data_type)
1695 })
1696 })?;
1697 value
1698 .into_opt()
1699 .map(|value| MssqlValueKind::Bit(value.as_bool()))
1700 .unwrap_or(MssqlValueKind::Null)
1701 }
1702 DataType::TinyInt => {
1703 let mut value = Nullable::<i16>::null();
1706 row.get_data(column_number, &mut value).map_err(|error| {
1707 crate::error::database_error_with_context_lazy(error, || {
1708 fetch_context(column, data_type)
1709 })
1710 })?;
1711 value
1712 .into_opt()
1713 .map(MssqlValueKind::TinyInt)
1714 .unwrap_or(MssqlValueKind::Null)
1715 }
1716 DataType::SmallInt => fetch_nullable(
1717 row,
1718 column_number,
1719 column,
1720 data_type,
1721 MssqlValueKind::SmallInt,
1722 )?,
1723 DataType::Integer => fetch_nullable(
1724 row,
1725 column_number,
1726 column,
1727 data_type,
1728 MssqlValueKind::Integer,
1729 )?,
1730 DataType::BigInt => {
1731 fetch_nullable(row, column_number, column, data_type, MssqlValueKind::BigInt)?
1732 }
1733 DataType::Real => {
1734 fetch_nullable(row, column_number, column, data_type, MssqlValueKind::Real)?
1735 }
1736 DataType::Float { .. } | DataType::Double => {
1737 fetch_nullable(row, column_number, column, data_type, MssqlValueKind::Double)?
1738 }
1739 DataType::Date => {
1740 fetch_nullable(row, column_number, column, data_type, MssqlValueKind::Date)?
1741 }
1742 DataType::Time { .. } => {
1743 fetch_nullable(row, column_number, column, data_type, MssqlValueKind::Time)?
1744 }
1745 DataType::Timestamp { .. } => fetch_nullable(
1746 row,
1747 column_number,
1748 column,
1749 data_type,
1750 MssqlValueKind::Timestamp,
1751 )?,
1752 DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => {
1753 let mut value = Vec::new();
1754 if row.get_binary(column_number, &mut value).map_err(|error| {
1755 crate::error::database_error_with_context_lazy(error, || {
1756 fetch_context(column, data_type)
1757 })
1758 })? {
1759 MssqlValueKind::Binary(value)
1760 } else {
1761 MssqlValueKind::Null
1762 }
1763 }
1764 DataType::Other {
1765 data_type: sql_type, ..
1766 } if sql_type.0 == -11 => {
1767 let mut value = Vec::new();
1769 if row.get_binary(column_number, &mut value).map_err(|error| {
1770 crate::error::database_error_with_context_lazy(error, || {
1771 fetch_context(column, data_type)
1772 })
1773 })? {
1774 if value.len() == 16 {
1775 let mut guid = [0u8; 16];
1776 guid.copy_from_slice(&value);
1777 MssqlValueKind::Guid(guid)
1778 } else {
1779 MssqlValueKind::Text(String::from_utf16_lossy(
1781 &value.iter().map(|&b| b as u16).collect::<Vec<_>>(),
1782 ))
1783 }
1784 } else {
1785 MssqlValueKind::Null
1786 }
1787 }
1788 _ => {
1789 let mut value = Vec::new();
1790 if row
1791 .get_wide_text(column_number, &mut value)
1792 .map_err(|error| {
1793 crate::error::database_error_with_context_lazy(error, || {
1794 fetch_context(column, data_type)
1795 })
1796 })?
1797 {
1798 MssqlValueKind::Text(String::from_utf16_lossy(&value))
1799 } else {
1800 MssqlValueKind::Null
1801 }
1802 }
1803 };
1804
1805 Ok(MssqlValue::new(kind))
1806}
1807
1808fn fetch_nullable<T, F>(
1809 row: &mut odbc_api::CursorRow<'_>,
1810 column_number: u16,
1811 column: &MssqlColumn,
1812 data_type: DataType,
1813 map: F,
1814) -> std::result::Result<MssqlValueKind, sqlx_core::Error>
1815where
1816 T: Default + Copy + odbc_api::parameter::CElement + odbc_api::handles::CDataMut,
1817 Nullable<T>: odbc_api::parameter::CElement + odbc_api::handles::CDataMut,
1818 F: FnOnce(T) -> MssqlValueKind,
1819{
1820 let mut value = Nullable::<T>::null();
1821 row.get_data(column_number, &mut value).map_err(|error| {
1822 crate::error::database_error_with_context_lazy(error, || fetch_context(column, data_type))
1823 })?;
1824 Ok(value.into_opt().map(map).unwrap_or(MssqlValueKind::Null))
1825}
1826
1827fn fetch_context(column: &MssqlColumn, data_type: DataType) -> String {
1828 format!(
1829 "failed to fetch ODBC column {} (`{}`) as {data_type:?}",
1830 column.ordinal() + 1,
1831 column.name()
1832 )
1833}
1834
/// Builds a short, single-line preview of `sql`.
///
/// Runs of whitespace (including newlines) are collapsed to single spaces
/// and surrounding whitespace is removed. If the result is longer than 160
/// characters it is cut so that, together with a trailing `"..."`, it is
/// exactly 160 characters long.
///
/// Lengths are measured in characters, not bytes, so non-ASCII text that
/// already fits is returned unchanged rather than gaining an ellipsis.
fn sql_preview(sql: &str) -> String {
    const MAX_LEN: usize = 160;
    const ELLIPSIS: &str = "...";

    let mut compact = sql.split_whitespace().collect::<Vec<_>>().join(" ");

    // `nth(MAX_LEN)` finds a character only when there are more than MAX_LEN of them.
    let overflows = compact.chars().nth(MAX_LEN).is_some();
    if overflows {
        let keep = MAX_LEN - ELLIPSIS.len();
        // Byte offset of the first character to drop; always a char boundary.
        let cut = compact
            .char_indices()
            .nth(keep)
            .map_or(compact.len(), |(offset, _)| offset);
        compact.truncate(cut);
        compact.push_str(ELLIPSIS);
    }

    compact
}
1847
1848pub(crate) async fn offload_blocking<F, T>(f: F) -> std::result::Result<T, sqlx_core::Error>
1853where
1854 F: FnOnce() -> std::result::Result<T, sqlx_core::Error> + Send + 'static,
1855 T: Send + 'static,
1856{
1857 tokio::task::spawn_blocking(f)
1858 .await
1859 .map_err(|e| sqlx_core::Error::Protocol(format!("blocking task panicked: {e}")))?
1860}
1861
#[cfg(test)]
mod tests {
    use super::*;

    /// Integer types share a nullable 64-bit buffer.
    #[test]
    fn buffered_fetch_maps_numeric_types_to_nullable_64_bit_buffers() {
        for data_type in [DataType::TinyInt, DataType::Integer, DataType::BigInt] {
            assert_eq!(
                map_buffer_desc(data_type, 64),
                BufferDesc::I64 { nullable: true }
            );
        }
    }

    /// Variable-sized buffers take their size from the configured limit.
    #[test]
    fn buffered_fetch_uses_configured_limits_for_variable_sized_data() {
        let text = map_buffer_desc(DataType::Varchar { length: None }, 32);
        assert_eq!(text, BufferDesc::Text { max_str_len: 32 });

        let binary = map_buffer_desc(DataType::Varbinary { length: None }, 16);
        assert_eq!(binary, BufferDesc::Binary { max_bytes: 16 });
    }

    /// Wide character types are fetched through wide-text buffers.
    #[test]
    fn buffered_fetch_maps_wide_char_types_to_wtext() {
        let cases = [
            (DataType::WChar { length: None }, 64),
            (DataType::WVarchar { length: None }, 128),
            (DataType::WLongVarchar { length: None }, 256),
        ];
        for (data_type, limit) in cases {
            assert_eq!(
                map_buffer_desc(data_type, limit),
                BufferDesc::WText { max_str_len: limit }
            );
        }
    }

    /// Narrow character types are fetched through text buffers.
    #[test]
    fn buffered_fetch_maps_narrow_char_types_to_text() {
        let cases = [
            DataType::Char { length: None },
            DataType::Varchar { length: None },
            DataType::LongVarchar { length: None },
        ];
        for data_type in cases {
            assert_eq!(
                map_buffer_desc(data_type, 64),
                BufferDesc::Text { max_str_len: 64 }
            );
        }
    }
}