Skip to main content

sqlx_mssql_odbc_core/
connection.rs

1use crate::{
2    MssqlArguments, MssqlBufferSettings, MssqlColumn, MssqlConnectOptions,
3    MssqlQueryResult, MssqlRow, MssqlStatement, MssqlTypeInfo, MssqlValue, MssqlValueKind, Result,
4};
5use futures_core::future::BoxFuture;
6use futures_core::stream::BoxStream;
7use futures_util::{future, stream, StreamExt};
8use odbc_api::buffers::{AnyColumnBufferSlice, BufferDesc, ColumnarDynBuffer, NullableSlice};
9use odbc_api::{ConnectionTransitions, Cursor, DataType, Nullable, ResultSetMetadata};
10use sqlx_core::column::Column;
11use sqlx_core::common::StatementCache;
12use sqlx_core::executor::{Execute, Executor};
13use sqlx_core::sql_str::SqlStr;
14use sqlx_core::transaction::Transaction;
15use sqlx_core::Either;
16use std::future::Future;
17use std::sync::atomic::{AtomicUsize, Ordering};
18use std::sync::Arc;
19
20type PreparedStatement =
21    odbc_api::Prepared<odbc_api::handles::StatementConnection<odbc_api::SharedConnection<'static>>>;
22type ExecuteResult = std::result::Result<Either<MssqlQueryResult, MssqlRow>, sqlx_core::Error>;
23type ExecuteSender = flume::Sender<ExecuteResult>;
24
25// ============================================================================
26// Command enum — sent from the async handle to the actor thread
27// ============================================================================
28
29enum Command {
30    Execute {
31        sql: SqlStr,
32        args: Option<MssqlArguments>,
33        persistent: bool,
34        response: ExecuteSender,
35    },
36    Prepare {
37        sql: SqlStr,
38        response: tokio::sync::oneshot::Sender<
39            std::result::Result<MssqlStatement, sqlx_core::Error>,
40        >,
41    },
42    Ping {
43        response: tokio::sync::oneshot::Sender<std::result::Result<(), sqlx_core::Error>>,
44    },
45    Begin {
46        response: tokio::sync::oneshot::Sender<std::result::Result<(), sqlx_core::Error>>,
47    },
48    Commit {
49        response: tokio::sync::oneshot::Sender<std::result::Result<(), sqlx_core::Error>>,
50    },
51    Rollback {
52        response: tokio::sync::oneshot::Sender<std::result::Result<(), sqlx_core::Error>>,
53    },
54    StartRollback,
55    ExecSql {
56        sql: String,
57        response: tokio::sync::oneshot::Sender<std::result::Result<(), sqlx_core::Error>>,
58    },
59    ScalarI64 {
60        sql: String,
61        response:
62            tokio::sync::oneshot::Sender<std::result::Result<Option<i64>, sqlx_core::Error>>,
63    },
64    Shutdown {
65        signal: tokio::sync::oneshot::Sender<()>,
66    },
67    /// Returns `Vec<(version, checksum_bytes)>` from the migrations table.
68    ListMigrations {
69        sql: String,
70        response:
71            tokio::sync::oneshot::Sender<std::result::Result<Vec<(i64, Vec<u8>)>, sqlx_core::Error>>,
72    },
73    /// Applies a migration: starts a transaction, runs SQL, inserts tracking
74    /// record, commits. If `no_tx` is true the transaction is skipped.
75    #[cfg(feature = "migrate")]
76    ApplyMigration {
77        sql: String,
78        insert_sql: String,
79        version: i64,
80        no_tx: bool,
81        response: tokio::sync::oneshot::Sender<std::result::Result<std::time::Duration, sqlx_core::Error>>,
82    },
83    /// Reverts a migration: starts a transaction, runs SQL, deletes tracking
84    /// record, commits. If `no_tx` is true the transaction is skipped.
85    #[cfg(feature = "migrate")]
86    RevertMigration {
87        sql: String,
88        delete_sql: String,
89        version: i64,
90        no_tx: bool,
91        response: tokio::sync::oneshot::Sender<std::result::Result<std::time::Duration, sqlx_core::Error>>,
92    },
93}
94
95// ============================================================================
96// ConnectionActor — owns the ODBC connection on a dedicated blocking thread
97// ============================================================================
98
99struct ConnectionActor {
100    conn: odbc_api::SharedConnection<'static>,
101    stmt_cache: StatementCache<PreparedStatement>,
102    transaction_depth: usize,
103    buffer_settings: MssqlBufferSettings,
104}
105
106impl ConnectionActor {
107    fn run(mut self, rx: flume::Receiver<Command>) {
108        // The channel iterator blocks on recv() and returns None when the
109        // channel is closed (all senders dropped).
110        for cmd in rx {
111            // Ignore errors from response senders — the consumer may have
112            // dropped their receiver (stream cancelled, etc.).
113            match cmd {
114                Command::Execute {
115                    sql,
116                    args,
117                    persistent,
118                    response,
119                } => {
120                    let _ = self.handle_execute(sql, args, persistent, &response);
121                }
122                Command::Prepare { sql, response } => {
123                    let _ = response.send(self.handle_prepare(sql));
124                }
125                Command::Ping { response } => {
126                    let _ = response.send(self.handle_ping());
127                }
128                Command::Begin { response } => {
129                    let _ = response.send(self.handle_begin());
130                }
131                Command::Commit { response } => {
132                    let _ = response.send(self.handle_commit());
133                }
134                Command::Rollback { response } => {
135                    let _ = response.send(self.handle_rollback());
136                }
137                Command::StartRollback => {
138                    self.handle_start_rollback();
139                }
140                Command::ExecSql { sql, response } => {
141                    let _ = response.send(self.handle_exec_sql(&sql));
142                }
143                Command::ScalarI64 { sql, response } => {
144                    let _ = response.send(self.handle_scalar_i64(&sql));
145                }
146                Command::Shutdown { signal } => {
147                    let _ = signal.send(());
148                    return;
149                }
150                Command::ListMigrations { sql, response } => {
151                    let _ = response.send(self.handle_list_migrations(&sql));
152                }
153                #[cfg(feature = "migrate")]
154                Command::ApplyMigration {
155                    sql,
156                    insert_sql,
157                    version,
158                    no_tx,
159                    response,
160                } => {
161                    let _ = response.send(self.handle_apply_migration(&sql, &insert_sql, version, no_tx));
162                }
163                #[cfg(feature = "migrate")]
164                Command::RevertMigration {
165                    sql,
166                    delete_sql,
167                    version,
168                    no_tx,
169                    response,
170                } => {
171                    let _ = response.send(self.handle_revert_migration(&sql, &delete_sql, version, no_tx));
172                }
173            }
174        }
175        // Channel closed — exit loop, dropping self and the SharedConnection.
176    }
177
178    // ---------------------------------------------------------------
179    // Command handlers
180    // ---------------------------------------------------------------
181
182    fn handle_execute(
183        &mut self,
184        sql: SqlStr,
185        arguments: Option<MssqlArguments>,
186        persistent: bool,
187        tx: &ExecuteSender,
188    ) -> std::result::Result<(), sqlx_core::Error> {
189        let has_arguments = arguments.as_ref().is_some_and(|a| !a.is_empty());
190        let parameters = arguments
191            .as_ref()
192            .map(MssqlArguments::to_odbc_parameter_collection)
193            .unwrap_or_default();
194
195        if persistent && has_arguments {
196            if let Some(prepared) = self.stmt_cache.get_mut(sql.as_str()) {
197                // Execute from cache.
198                let mut conn_guard = self.conn.lock().map_err(|_| {
199                    sqlx_core::Error::Protocol(
200                        "ODBC execute: failed to lock connection".to_owned(),
201                    )
202                })?;
203                let has_cursor = prepared
204                    .execute(parameters.as_slice())
205                    .map_err(|error| {
206                        crate::error::database_error_with_context_lazy(error, || {
207                            format!(
208                                "failed to execute cached ODBC statement: `{}`",
209                                sql_preview(sql.as_str())
210                            )
211                        })
212                    })?
213                    .is_some();
214                drop(conn_guard);
215
216                if has_cursor {
217                    // Re-execute to get the cursor (avoid borrow conflict).
218                    let mut conn_guard = self.conn.lock().map_err(|_| {
219                        sqlx_core::Error::Protocol(
220                            "ODBC execute: failed to lock connection".to_owned(),
221                        )
222                    })?;
223                    let cursor = prepared
224                        .execute(parameters.as_slice())
225                        .map_err(|error| {
226                            crate::error::database_error_with_context_lazy(error, || {
227                                format!(
228                                    "failed to execute cached ODBC statement: `{}`",
229                                    sql_preview(sql.as_str())
230                                )
231                            })
232                        })?
233                        .expect("has_cursor was true");
234                    drop(conn_guard);
235                    return stream_result_sets(cursor, self.buffer_settings, tx);
236                }
237
238                let ra = prepared.row_count().map_err(|error| {
239                    crate::error::database_error_with_context_lazy(error, || {
240                        format!(
241                            "failed to read ODBC row count for cached statement: `{}`",
242                            sql_preview(sql.as_str())
243                        )
244                    })
245                })?;
246                return send_rows_affected(ra, tx);
247            } else {
248                // Prepare and cache
249                let mut prepared =
250                    self.conn.clone().into_prepared(sql.as_str()).map_err(|error| {
251                        crate::error::database_error_with_context_lazy(error, || {
252                            format!(
253                                "failed to prepare cached ODBC statement: `{}`",
254                                sql_preview(sql.as_str())
255                            )
256                        })
257                    })?;
258
259                let mut conn_guard = self.conn.lock().map_err(|_| {
260                    sqlx_core::Error::Protocol(
261                        "ODBC execute: failed to lock connection".to_owned(),
262                    )
263                })?;
264                let has_cursor = prepared
265                    .execute(parameters.as_slice())
266                    .map_err(|error| {
267                        crate::error::database_error_with_context_lazy(error, || {
268                            format!(
269                                "failed to execute cached ODBC statement: `{}`",
270                                sql_preview(sql.as_str())
271                            )
272                        })
273                    })?
274                    .is_some();
275                drop(conn_guard);
276
277                if has_cursor {
278                    // Re-execute to get the cursor for streaming.
279                    let mut conn_guard = self.conn.lock().map_err(|_| {
280                        sqlx_core::Error::Protocol(
281                            "ODBC execute: failed to lock connection".to_owned(),
282                        )
283                    })?;
284                    let cursor = prepared
285                        .execute(parameters.as_slice())
286                        .map_err(|error| {
287                            crate::error::database_error_with_context_lazy(error, || {
288                                format!(
289                                    "failed to execute cached ODBC statement: `{}`",
290                                    sql_preview(sql.as_str())
291                                )
292                            })
293                        })?
294                        .expect("has_cursor was true");
295                    drop(conn_guard);
296                    return stream_result_sets(cursor, self.buffer_settings, tx);
297                }
298
299                let ra = prepared.row_count().map_err(|error| {
300                    crate::error::database_error_with_context_lazy(error, || {
301                        format!(
302                            "failed to read ODBC row count for cached statement: `{}`",
303                            sql_preview(sql.as_str())
304                        )
305                    })
306                })?;
307                self.stmt_cache.insert(sql.as_str(), prepared);
308                return send_rows_affected(ra, tx);
309            }
310        } else {
311            // Unprepared (one-shot) path
312            let mut statement = self.conn.clone().into_preallocated().map_err(|error| {
313                crate::error::database_error_with_context_lazy(error, || {
314                    format!(
315                        "failed to allocate an ODBC statement for query: `{}`",
316                        sql_preview(sql.as_str())
317                    )
318                })
319            })?;
320            if let Some(cursor) = statement
321                .execute(sql.as_str(), parameters.as_slice())
322                .map_err(|error| {
323                    crate::error::database_error_with_context_lazy(error, || {
324                        format!(
325                            "failed to execute ODBC query: `{}`",
326                            sql_preview(sql.as_str())
327                        )
328                    })
329                })? {
330                return stream_result_sets(cursor, self.buffer_settings, tx);
331            }
332            let rows_affected = statement.row_count().map_err(|error| {
333                crate::error::database_error_with_context_lazy(error, || {
334                    format!(
335                        "failed to read ODBC row count for query: `{}`",
336                        sql_preview(sql.as_str())
337                    )
338                })
339            })?;
340            send_rows_affected(rows_affected, tx)
341        }
342    }
343
344    fn handle_prepare(
345        &mut self,
346        sql: SqlStr,
347    ) -> std::result::Result<MssqlStatement, sqlx_core::Error> {
348        if let Some(prepared) = self.stmt_cache.get_mut(sql.as_str()) {
349            let parameters = prepared.num_params().map_err(|error| {
350                sqlx_core::Error::from(crate::error::database_error_with_context(
351                    error,
352                    format!(
353                        "failed to read ODBC parameter metadata for cached statement: `{}`",
354                        sql_preview(sql.as_str())
355                    ),
356                ))
357            })?;
358            let columns = collect_prepared_columns(prepared, parameters)?;
359            return Ok(MssqlStatement::new(sql, columns, usize::from(parameters)));
360        }
361
362        let mut prepared = self.conn.clone().into_prepared(sql.as_str()).map_err(|error| {
363            sqlx_core::Error::from(crate::error::database_error_with_context(
364                error,
365                format!(
366                    "failed to prepare MSSQL ODBC statement: `{}`",
367                    sql_preview(sql.as_str())
368                ),
369            ))
370        })?;
371        let parameters = prepared.num_params().map_err(|error| {
372            sqlx_core::Error::from(crate::error::database_error_with_context(
373                error,
374                format!(
375                    "failed to read ODBC parameter metadata for prepared statement: `{}`",
376                    sql_preview(sql.as_str())
377                ),
378            ))
379        })?;
380        let columns = collect_prepared_columns(&mut prepared, parameters)?;
381        if self.stmt_cache.is_enabled() {
382            self.stmt_cache.insert(sql.as_str(), prepared);
383        }
384
385        Ok(MssqlStatement::new(sql, columns, usize::from(parameters)))
386    }
387
388    fn handle_ping(&mut self) -> std::result::Result<(), sqlx_core::Error> {
389        let mut conn_guard = self.conn.lock().map_err(|_| {
390            sqlx_core::Error::Protocol("failed to lock connection for ping".into())
391        })?;
392        conn_guard.execute("SELECT 1", (), None).map_err(|error| {
393            sqlx_core::Error::from(crate::error::database_error_with_context(
394                error,
395                "MSSQL ping query failed: `SELECT 1`",
396            ))
397        })?;
398        Ok(())
399    }
400
401    fn handle_begin(&mut self) -> std::result::Result<(), sqlx_core::Error> {
402        if self.transaction_depth == 0 {
403            let mut conn_guard = self.conn.lock().map_err(|_| {
404                sqlx_core::Error::Protocol(
405                    "MSSQL ODBC begin: failed to lock connection".to_owned(),
406                )
407            })?;
408            conn_guard.set_autocommit(false).map_err(|error| {
409                sqlx_core::Error::from(crate::error::database_error_with_context(
410                    error,
411                    "failed to disable ODBC autocommit while beginning a transaction",
412                ))
413            })?;
414        } else {
415            let savepoint = format!("sqlx_sp_{}", self.transaction_depth);
416            let mut conn_guard = self.conn.lock().map_err(|_| {
417                sqlx_core::Error::Protocol(
418                    "MSSQL ODBC begin (savepoint): failed to lock connection".to_owned(),
419                )
420            })?;
421            conn_guard
422                .execute(&format!("SAVE TRANSACTION {savepoint}"), (), None)
423                .map_err(|error| {
424                    sqlx_core::Error::from(crate::error::database_error_with_context(
425                        error,
426                        format!(
427                            "failed to create save point `{savepoint}` for nested transaction"
428                        ),
429                    ))
430                })?;
431        }
432        self.transaction_depth += 1;
433        Ok(())
434    }
435
436    fn handle_commit(&mut self) -> std::result::Result<(), sqlx_core::Error> {
437        if self.transaction_depth == 0 {
438            return Ok(());
439        }
440
441        if self.transaction_depth == 1 {
442            let mut conn_guard = self.conn.lock().map_err(|_| {
443                sqlx_core::Error::Protocol(
444                    "MSSQL ODBC commit: failed to lock connection".to_owned(),
445                )
446            })?;
447            conn_guard.commit().map_err(|error| {
448                sqlx_core::Error::from(crate::error::database_error_with_context(
449                    error,
450                    "failed to commit the active MSSQL ODBC transaction",
451                ))
452            })?;
453            conn_guard.set_autocommit(true).map_err(|error| {
454                sqlx_core::Error::from(crate::error::database_error_with_context(
455                    error,
456                    "failed to restore ODBC autocommit after commit",
457                ))
458            })?;
459            self.transaction_depth = 0;
460        } else {
461            self.transaction_depth -= 1;
462        }
463        Ok(())
464    }
465
466    fn handle_rollback(&mut self) -> std::result::Result<(), sqlx_core::Error> {
467        if self.transaction_depth == 0 {
468            return Ok(());
469        }
470
471        if self.transaction_depth == 1 {
472            let mut conn_guard = self.conn.lock().map_err(|_| {
473                sqlx_core::Error::Protocol(
474                    "MSSQL ODBC rollback: failed to lock connection".to_owned(),
475                )
476            })?;
477            conn_guard.rollback().map_err(|error| {
478                sqlx_core::Error::from(crate::error::database_error_with_context(
479                    error,
480                    "failed to roll back the active ODBC transaction",
481                ))
482            })?;
483            conn_guard.set_autocommit(true).map_err(|error| {
484                sqlx_core::Error::from(crate::error::database_error_with_context(
485                    error,
486                    "failed to restore ODBC autocommit after rollback",
487                ))
488            })?;
489            self.transaction_depth = 0;
490        } else {
491            let savepoint = format!("sqlx_sp_{}", self.transaction_depth - 1);
492            let mut conn_guard = self.conn.lock().map_err(|_| {
493                sqlx_core::Error::Protocol(
494                    "MSSQL ODBC rollback (savepoint): failed to lock connection".to_owned(),
495                )
496            })?;
497            conn_guard
498                .execute(&format!("ROLLBACK TRANSACTION {savepoint}"), (), None)
499                .map_err(|error| {
500                    sqlx_core::Error::from(crate::error::database_error_with_context(
501                        error,
502                        format!("failed to roll back to save point `{savepoint}`"),
503                    ))
504                })?;
505            self.transaction_depth -= 1;
506        }
507        Ok(())
508    }
509
510    fn handle_start_rollback(&mut self) {
511        if self.transaction_depth == 0 {
512            return;
513        }
514
515        if self.transaction_depth == 1 {
516            if let Ok(mut conn_guard) = self.conn.lock() {
517                let _ = conn_guard.rollback();
518                let _ = conn_guard.set_autocommit(true);
519            }
520            self.transaction_depth = 0;
521        } else {
522            let savepoint = format!("sqlx_sp_{}", self.transaction_depth - 1);
523            if let Ok(mut conn_guard) = self.conn.lock() {
524                let _ = conn_guard.execute(
525                    &format!("ROLLBACK TRANSACTION {savepoint}"),
526                    (),
527                    None,
528                );
529            }
530            self.transaction_depth -= 1;
531        }
532    }
533
534    fn handle_exec_sql(&self, sql: &str) -> std::result::Result<(), sqlx_core::Error> {
535        let mut conn_guard = self.conn.lock().map_err(|_| {
536            sqlx_core::Error::Protocol("failed to lock the shared ODBC connection".into())
537        })?;
538        conn_guard.execute(sql, (), None).map_err(|error| {
539            sqlx_core::Error::from(crate::error::database_error_with_context(
540                error,
541                format!("failed to execute SQL: `{}`", sql_preview(sql)),
542            ))
543        })?;
544        Ok(())
545    }
546
547    fn handle_scalar_i64(&self, sql: &str) -> std::result::Result<Option<i64>, sqlx_core::Error> {
548        let mut conn_guard = self.conn.lock().map_err(|_| {
549            sqlx_core::Error::Protocol("failed to lock the shared ODBC connection".into())
550        })?;
551        let mut cursor = conn_guard
552            .execute(sql, (), None)
553            .map_err(|error| {
554                sqlx_core::Error::from(crate::error::database_error_with_context(
555                    error,
556                    format!("scalar query failed: `{}`", sql_preview(sql)),
557                ))
558            })?
559            .ok_or_else(|| {
560                sqlx_core::Error::Protocol(format!(
561                    "scalar query returned no result set: `{}`",
562                    sql_preview(sql),
563                ))
564            })?;
565
566        if let Some(mut row) = cursor.next_row().map_err(|error| {
567            sqlx_core::Error::from(crate::error::database_error_with_context(
568                error,
569                "scalar query next row",
570            ))
571        })? {
572            let mut value: Nullable<i64> = Nullable::null();
573            row.get_data(1, &mut value).map_err(|error| {
574                sqlx_core::Error::from(crate::error::database_error_with_context(
575                    error,
576                    "scalar query column 1",
577                ))
578            })?;
579            Ok(value.into_opt())
580        } else {
581            Ok(None)
582        }
583    }
584
585    fn handle_list_migrations(
586        &self,
587        sql: &str,
588    ) -> std::result::Result<Vec<(i64, Vec<u8>)>, sqlx_core::Error> {
589        let mut conn_guard = self.conn.lock().map_err(|_| {
590            sqlx_core::Error::Protocol("failed to lock the shared ODBC connection".into())
591        })?;
592        let mut cursor = conn_guard
593            .execute(sql, (), None)
594            .map_err(|error| {
595                sqlx_core::Error::from(crate::error::database_error_with_context(
596                    error,
597                    "failed to query applied migrations",
598                ))
599            })?
600            .ok_or_else(|| {
601                sqlx_core::Error::Protocol(
602                    "list_applied_migrations returned no result set".into(),
603                )
604            })?;
605
606        let mut migrations = Vec::new();
607        while let Some(mut row) = cursor.next_row().map_err(|error| {
608            sqlx_core::Error::from(crate::error::database_error_with_context(
609                error,
610                "failed to read applied migration row",
611            ))
612        })? {
613            let mut version: Nullable<i64> = Nullable::null();
614            row.get_data(1, &mut version).map_err(|error| {
615                sqlx_core::Error::from(crate::error::database_error_with_context(
616                    error,
617                    "failed to read migration version",
618                ))
619            })?;
620
621            let mut checksum_bytes = Vec::new();
622            let has_value = row.get_binary(2, &mut checksum_bytes).map_err(|error| {
623                sqlx_core::Error::from(crate::error::database_error_with_context(
624                    error,
625                    "failed to read migration checksum",
626                ))
627            })?;
628
629            if let Some(version) = version.into_opt() {
630                migrations.push((version, if has_value { checksum_bytes } else { vec![] }));
631            }
632        }
633
634        Ok(migrations)
635    }
636
637    #[cfg(feature = "migrate")]
638    fn handle_apply_migration(
639        &mut self,
640        sql: &str,
641        insert_sql: &str,
642        version: i64,
643        no_tx: bool,
644    ) -> std::result::Result<std::time::Duration, sqlx_core::Error> {
645        let start = std::time::Instant::now();
646        let mut conn_guard = self.conn.lock().map_err(|_| {
647            sqlx_core::Error::Protocol(
648                "failed to lock the shared ODBC connection for migration".into(),
649            )
650        })?;
651
652        if !no_tx {
653            conn_guard.set_autocommit(false).map_err(|error| {
654                sqlx_core::Error::from(crate::error::database_error_with_context(
655                    error,
656                    "failed to start transaction for migration apply",
657                ))
658            })?;
659        }
660
661        conn_guard.execute(sql, (), None).map_err(|error| {
662            sqlx_core::Error::from(crate::error::database_error_with_context(
663                error,
664                format!("migration {version} failed"),
665            ))
666        })?;
667
668        conn_guard.execute(insert_sql, (), None).map_err(|error| {
669            sqlx_core::Error::from(crate::error::database_error_with_context(
670                error,
671                format!("failed to insert tracking record for migration {version}"),
672            ))
673        })?;
674
675        if !no_tx {
676            conn_guard.commit().map_err(|error| {
677                sqlx_core::Error::from(crate::error::database_error_with_context(
678                    error,
679                    format!("failed to commit migration {version}"),
680                ))
681            })?;
682            conn_guard.set_autocommit(true).map_err(|error| {
683                sqlx_core::Error::from(crate::error::database_error_with_context(
684                    error,
685                    "failed to restore autocommit after migration apply",
686                ))
687            })?;
688        }
689
690        Ok(start.elapsed())
691    }
692
693    #[cfg(feature = "migrate")]
694    fn handle_revert_migration(
695        &mut self,
696        sql: &str,
697        delete_sql: &str,
698        version: i64,
699        no_tx: bool,
700    ) -> std::result::Result<std::time::Duration, sqlx_core::Error> {
701        let start = std::time::Instant::now();
702        let mut conn_guard = self.conn.lock().map_err(|_| {
703            sqlx_core::Error::Protocol(
704                "failed to lock the shared ODBC connection for migration".into(),
705            )
706        })?;
707
708        if !no_tx {
709            conn_guard.set_autocommit(false).map_err(|error| {
710                sqlx_core::Error::from(crate::error::database_error_with_context(
711                    error,
712                    "failed to start transaction for migration revert",
713                ))
714            })?;
715        }
716
717        conn_guard.execute(sql, (), None).map_err(|error| {
718            sqlx_core::Error::from(crate::error::database_error_with_context(
719                error,
720                format!("revert migration {version} failed"),
721            ))
722        })?;
723
724        conn_guard.execute(delete_sql, (), None).map_err(|error| {
725            sqlx_core::Error::from(crate::error::database_error_with_context(
726                error,
727                format!("failed to delete tracking record for migration {version}"),
728            ))
729        })?;
730
731        if !no_tx {
732            conn_guard.commit().map_err(|error| {
733                sqlx_core::Error::from(crate::error::database_error_with_context(
734                    error,
735                    format!("failed to commit migration revert {version}"),
736                ))
737            })?;
738            conn_guard.set_autocommit(true).map_err(|error| {
739                sqlx_core::Error::from(crate::error::database_error_with_context(
740                    error,
741                    "failed to restore autocommit after migration revert",
742                ))
743            })?;
744        }
745
746        Ok(start.elapsed())
747    }
748}
749
750/// MSSQL connection backed by an actor thread that owns the ODBC connection.
751pub struct MssqlConnection {
752    cmd_tx: flume::Sender<Command>,
753    buffer_settings: MssqlBufferSettings,
754    transaction_depth: AtomicUsize,
755}
756
757impl std::fmt::Debug for MssqlConnection {
758    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
759        f.debug_struct("MssqlConnection").finish_non_exhaustive()
760    }
761}
762
763impl MssqlConnection {
764    /// Opens a blocking MSSQL ODBC connection with the provided options and
765    /// spawns an actor thread to own it.
766    pub fn connect_blocking(options: &MssqlConnectOptions) -> Result<Self> {
767        let env = odbc_api::environment().map_err(|error| {
768            crate::MssqlError::Configuration(format!(
769                "failed to initialize the process-wide ODBC environment: {error}"
770            ))
771        })?;
772
773        let raw_conn = env
774            .connect_with_connection_string(options.connection_string(), Default::default())
775            .map_err(|error| {
776                crate::error::database_error_with_context(
777                    error,
778                    "failed to open MSSQL ODBC connection using the supplied connection string",
779                )
780            })?;
781
782        // Wrap in SharedConnection so PreparedStatement can own the connection.
783        let conn: odbc_api::SharedConnection<'static> =
784            std::sync::Arc::new(std::sync::Mutex::new(raw_conn));
785
786        let (cmd_tx, cmd_rx) = flume::unbounded();
787
788        let actor = ConnectionActor {
789            conn,
790            stmt_cache: StatementCache::new(options.statement_cache_capacity),
791            transaction_depth: 0,
792            buffer_settings: options.buffer_settings,
793        };
794
795        // Spawn the actor on a dedicated OS thread so this function can be
796        // called from contexts where no Tokio runtime exists (for example,
797        // compile-time query checking in proc macros).
798        std::thread::spawn(move || actor.run(cmd_rx));
799
800        Ok(Self {
801            cmd_tx,
802            buffer_settings: options.buffer_settings,
803            transaction_depth: AtomicUsize::new(0),
804        })
805    }
806
807    /// Executes a minimal connectivity query.
808    pub fn ping_blocking(&self) -> std::result::Result<(), sqlx_core::Error> {
809        send_command_blocking(&self.cmd_tx, |tx| Command::Ping { response: tx })?
810    }
811
812    /// Returns the DBMS name reported by the ODBC driver.
813    pub fn dbms_name(&self) -> std::result::Result<String, sqlx_core::Error> {
814        send_command_blocking(&self.cmd_tx, |tx| {
815            Command::ExecSql {
816                sql: "SELECT 1 /* dbms_name */".into(),
817                response: tx,
818            }
819        })?;
820        Ok("MSSQL via ODBC".to_owned())
821    }
822
823    /// Begins a transaction (synchronous, called from TransactionManager).
824    pub(crate) fn begin_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
825        let r = send_command_blocking(&self.cmd_tx, |tx| Command::Begin { response: tx })?;
826        if r.is_ok() {
827            self.transaction_depth.fetch_add(1, Ordering::SeqCst);
828        }
829        r
830    }
831
832    /// Commits the current transaction (synchronous, called from TransactionManager).
833    pub(crate) fn commit_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
834        let depth = self.transaction_depth.load(Ordering::SeqCst);
835        if depth == 0 {
836            return Ok(());
837        }
838        let r = send_command_blocking(&self.cmd_tx, |tx| Command::Commit { response: tx })?;
839        if r.is_ok() {
840            if depth == 1 {
841                self.transaction_depth.store(0, Ordering::SeqCst);
842            } else {
843                self.transaction_depth.fetch_sub(1, Ordering::SeqCst);
844            }
845        }
846        r
847    }
848
849    /// Rolls back the current transaction (synchronous, called from TransactionManager).
850    pub(crate) fn rollback_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
851        let depth = self.transaction_depth.load(Ordering::SeqCst);
852        if depth == 0 {
853            return Ok(());
854        }
855        let r = send_command_blocking(&self.cmd_tx, |tx| Command::Rollback { response: tx })?;
856        if r.is_ok() {
857            if depth == 1 {
858                self.transaction_depth.store(0, Ordering::SeqCst);
859            } else {
860                self.transaction_depth.fetch_sub(1, Ordering::SeqCst);
861            }
862        }
863        r
864    }
865
866    /// Starts a rollback without blocking (called from Drop path).
867    pub(crate) fn start_rollback(&mut self) {
868        let _ = self.cmd_tx.try_send(Command::StartRollback);
869        self.transaction_depth.store(0, Ordering::SeqCst);
870    }
871
872    /// Returns the current transaction depth.
873    pub(crate) fn transaction_depth(&self) -> usize {
874        self.transaction_depth.load(Ordering::SeqCst)
875    }
876
877    /// Sets the transaction depth (used by TransactionManager).
878    pub(crate) fn set_transaction_depth(&mut self, depth: usize) {
879        self.transaction_depth.store(depth, Ordering::SeqCst);
880    }
881
882    /// Prepares a statement and returns the metadata reported by the ODBC driver.
883    pub fn prepare_blocking(
884        &self,
885        sql: sqlx_core::sql_str::SqlStr,
886    ) -> std::result::Result<MssqlStatement, sqlx_core::Error> {
887        send_command_blocking(&self.cmd_tx, |tx| Command::Prepare { sql, response: tx })?
888    }
889
890    /// Executes a SQL statement directly with no parameters and discards any result set.
891    #[cfg(feature = "migrate")]
892    pub(crate) fn exec_sql_blocking(&self, sql: &str) -> std::result::Result<(), sqlx_core::Error> {
893        send_command_blocking(&self.cmd_tx, |tx| {
894            Command::ExecSql {
895                sql: sql.to_owned(),
896                response: tx,
897            }
898        })?
899    }
900
901    /// Executes a SQL query and returns the first column of the first row as an `i64`.
902    #[cfg(feature = "migrate")]
903    pub(crate) fn scalar_i64_blocking(
904        &self,
905        sql: &str,
906    ) -> std::result::Result<Option<i64>, sqlx_core::Error> {
907        send_command_blocking(&self.cmd_tx, |tx| {
908            Command::ScalarI64 {
909                sql: sql.to_owned(),
910                response: tx,
911            }
912        })?
913    }
914
915    /// Executes a SQL query and returns rows as a list of (i64, binary) tuples.
916    #[cfg(feature = "migrate")]
917    pub(crate) fn list_migrations_blocking(
918        &self,
919        sql: &str,
920    ) -> std::result::Result<Vec<(i64, Vec<u8>)>, sqlx_core::Error> {
921        send_command_blocking(&self.cmd_tx, |tx| {
922            Command::ListMigrations {
923                sql: sql.to_owned(),
924                response: tx,
925            }
926        })?
927    }
928
929    /// Applies a migration via the actor. Returns the elapsed duration.
930    #[cfg(feature = "migrate")]
931    pub(crate) fn apply_migration_blocking(
932        &self,
933        sql: &str,
934        insert_sql: &str,
935        version: i64,
936        no_tx: bool,
937    ) -> std::result::Result<std::time::Duration, sqlx_core::Error> {
938        send_command_blocking(&self.cmd_tx, |tx| {
939            Command::ApplyMigration {
940                sql: sql.to_owned(),
941                insert_sql: insert_sql.to_owned(),
942                version,
943                no_tx,
944                response: tx,
945            }
946        })?
947    }
948
949    /// Reverts a migration via the actor. Returns the elapsed duration.
950    #[cfg(feature = "migrate")]
951    pub(crate) fn revert_migration_blocking(
952        &self,
953        sql: &str,
954        delete_sql: &str,
955        version: i64,
956        no_tx: bool,
957    ) -> std::result::Result<std::time::Duration, sqlx_core::Error> {
958        send_command_blocking(&self.cmd_tx, |tx| {
959            Command::RevertMigration {
960                sql: sql.to_owned(),
961                delete_sql: delete_sql.to_owned(),
962                version,
963                no_tx,
964                response: tx,
965            }
966        })?
967    }
968
969    /// Creates a receiver that the actor will stream query results into.
970    pub(crate) fn execute_receiver(
971        &self,
972        sql: sqlx_core::sql_str::SqlStr,
973        persistent: bool,
974        arguments: Option<MssqlArguments>,
975    ) -> flume::Receiver<ExecuteResult> {
976        let (tx, rx) = flume::bounded(64);
977        if self
978            .cmd_tx
979            .send(Command::Execute {
980                sql,
981                args: arguments,
982                persistent,
983                response: tx,
984            })
985            .is_err()
986        {
987            // Actor has shut down — drain the rx so recv_async returns None
988            let _ = rx.drain();
989        }
990        rx
991    }
992}
993
994// Dropping cmd_tx closes the channel, causing the actor loop to exit.
995impl Drop for MssqlConnection {
996    fn drop(&mut self) {}
997}
998
999// ============================================================================
1000// Connection trait
1001// ============================================================================
1002
1003impl sqlx_core::connection::Connection for MssqlConnection {
1004    type Database = crate::Mssql;
1005    type Options = MssqlConnectOptions;
1006
1007    async fn close(self) -> std::result::Result<(), sqlx_core::Error> {
1008        drop(self);
1009        Ok(())
1010    }
1011
1012    async fn close_hard(self) -> std::result::Result<(), sqlx_core::Error> {
1013        drop(self);
1014        Ok(())
1015    }
1016
1017    async fn ping(&mut self) -> std::result::Result<(), sqlx_core::Error> {
1018        send_command_async(&self.cmd_tx, |tx| Command::Ping { response: tx }).await?
1019    }
1020
1021    fn begin(
1022        &mut self,
1023    ) -> impl Future<Output = std::result::Result<Transaction<'_, Self::Database>, sqlx_core::Error>>
1024           + Send
1025           + '_ {
1026        Transaction::begin(self, None)
1027    }
1028
1029    fn shrink_buffers(&mut self) {}
1030
1031    async fn flush(&mut self) -> std::result::Result<(), sqlx_core::Error> {
1032        Ok(())
1033    }
1034
1035    fn should_flush(&self) -> bool {
1036        false
1037    }
1038
1039    fn cached_statements_size(&self) -> usize
1040    where
1041        Self::Database: sqlx_core::database::HasStatementCache,
1042    {
1043        // The statement cache lives on the actor thread; we can't query it
1044        // synchronously. Return 0 — callers use this only for diagnostics.
1045        0
1046    }
1047
1048    async fn clear_cached_statements(&mut self) -> std::result::Result<(), sqlx_core::Error>
1049    where
1050        Self::Database: sqlx_core::database::HasStatementCache,
1051    {
1052        // The cache lives on the actor; clearing it requires a new command.
1053        // For now this is a no-op since the cache is per-connection and
1054        // bounded by `statement_cache_capacity`.
1055        Ok(())
1056    }
1057}
1058
1059// ============================================================================
1060// Executor trait
1061// ============================================================================
1062
1063impl<'c> Executor<'c> for &'c mut MssqlConnection {
1064    type Database = crate::Mssql;
1065
1066    fn fetch_many<'e, 'q, E>(
1067        self,
1068        mut query: E,
1069    ) -> BoxStream<'e, std::result::Result<Either<MssqlQueryResult, MssqlRow>, sqlx_core::Error>>
1070    where
1071        'c: 'e,
1072        E: Execute<'q, Self::Database>,
1073        'q: 'e,
1074        E: 'q,
1075    {
1076        let arguments = query.take_arguments().map_err(sqlx_core::Error::Encode);
1077        let persistent = query.persistent();
1078        let sql = query.sql();
1079
1080        match arguments {
1081            Ok(arguments) => {
1082                receiver_to_stream(self.execute_receiver(sql, persistent, arguments))
1083            }
1084            Err(error) => stream::once(future::ready(Err(error))).boxed(),
1085        }
1086    }
1087
1088    fn fetch_optional<'e, 'q, E>(
1089        self,
1090        mut query: E,
1091    ) -> BoxFuture<'e, std::result::Result<Option<MssqlRow>, sqlx_core::Error>>
1092    where
1093        'c: 'e,
1094        E: Execute<'q, Self::Database>,
1095        'q: 'e,
1096        E: 'q,
1097    {
1098        let arguments = query.take_arguments().map_err(sqlx_core::Error::Encode);
1099        let persistent = query.persistent();
1100        let sql = query.sql();
1101
1102        Box::pin(async move {
1103            let rx = self.execute_receiver(sql, persistent, arguments?);
1104            while let Ok(item) = rx.recv_async().await {
1105                match item? {
1106                    Either::Right(row) => return Ok(Some(row)),
1107                    Either::Left(_) => {}
1108                }
1109            }
1110            Ok(None)
1111        })
1112    }
1113
1114    fn prepare_with<'e>(
1115        self,
1116        sql: sqlx_core::sql_str::SqlStr,
1117        _parameters: &[crate::MssqlTypeInfo],
1118    ) -> BoxFuture<'e, std::result::Result<MssqlStatement, sqlx_core::Error>>
1119    where
1120        'c: 'e,
1121    {
1122        let cmd_tx = self.cmd_tx.clone();
1123        Box::pin(async move {
1124            send_command_async(&cmd_tx, |tx| Command::Prepare { sql, response: tx }).await?
1125        })
1126    }
1127
1128    #[cfg(feature = "offline")]
1129    fn describe<'e>(
1130        self,
1131        sql: sqlx_core::sql_str::SqlStr,
1132    ) -> BoxFuture<'e, std::result::Result<sqlx_core::describe::Describe<Self::Database>, sqlx_core::Error>>
1133    where
1134        'c: 'e,
1135    {
1136        use sqlx_core::statement::Statement;
1137        let cmd_tx = self.cmd_tx.clone();
1138        Box::pin(async move {
1139            let statement =
1140                send_command_async(&cmd_tx, |tx| Command::Prepare { sql, response: tx }).await??;
1141            let columns = statement.columns().to_vec();
1142            let column_count = columns.len();
1143            let parameter_count = statement
1144                .parameters()
1145                .map(|p| match p {
1146                    Either::Left(types) => types.len(),
1147                    Either::Right(count) => count,
1148                })
1149                .unwrap_or(0);
1150
1151            Ok(sqlx_core::describe::Describe {
1152                columns,
1153                parameters: Some(Either::Right(parameter_count)),
1154                nullable: vec![None; column_count],
1155            })
1156        })
1157    }
1158}
1159
1160// ============================================================================
1161// Helper: send a command and await a oneshot response (async)
1162// ============================================================================
1163
1164async fn send_command_async<T: Send>(
1165    cmd_tx: &flume::Sender<Command>,
1166    make_cmd: impl FnOnce(tokio::sync::oneshot::Sender<T>) -> Command,
1167) -> std::result::Result<T, sqlx_core::Error> {
1168    let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
1169    let cmd = make_cmd(resp_tx);
1170    cmd_tx.send(cmd).map_err(|_| {
1171        sqlx_core::Error::Protocol(
1172            "MSSQL ODBC connection actor has shut down".to_owned(),
1173        )
1174    })?;
1175    resp_rx.await.map_err(|_| {
1176        sqlx_core::Error::Protocol(
1177            "MSSQL ODBC connection actor response channel closed".to_owned(),
1178        )
1179    })
1180}
1181
1182// ============================================================================
1183// Helper: send a command and wait for a oneshot response (blocking)
1184// ============================================================================
1185
1186fn send_command_blocking<T: Send>(
1187    cmd_tx: &flume::Sender<Command>,
1188    make_cmd: impl FnOnce(tokio::sync::oneshot::Sender<T>) -> Command,
1189) -> std::result::Result<T, sqlx_core::Error> {
1190    let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
1191    let cmd = make_cmd(resp_tx);
1192    cmd_tx.send(cmd).map_err(|_| {
1193        sqlx_core::Error::Protocol(
1194            "MSSQL ODBC connection actor has shut down".to_owned(),
1195        )
1196    })?;
1197    resp_rx.blocking_recv().map_err(|_| {
1198        sqlx_core::Error::Protocol(
1199            "MSSQL ODBC connection actor response channel closed".to_owned(),
1200        )
1201    })
1202}
1203
1204// ============================================================================
1205// Helper: convert a flume receiver to a BoxStream
1206// ============================================================================
1207
1208fn receiver_to_stream<'e>(
1209    rx: flume::Receiver<ExecuteResult>,
1210) -> BoxStream<'e, ExecuteResult> {
1211    stream::unfold(rx, |rx| async move {
1212        rx.recv_async().await.ok().map(|item| (item, rx))
1213    })
1214    .boxed()
1215}
1216
1217// ============================================================================
1218// Helper: send query-result rows via the execute channel
1219// ============================================================================
1220
1221fn send_rows_affected(
1222    rows_affected: Option<usize>,
1223    tx: &ExecuteSender,
1224) -> std::result::Result<(), sqlx_core::Error> {
1225    let rows_affected = rows_affected
1226        .unwrap_or(0)
1227        .try_into()
1228        .map_err(|_| sqlx_core::Error::Protocol("ODBC row count does not fit in u64".to_owned()))?;
1229    send_done(tx, rows_affected);
1230    Ok(())
1231}
1232
1233fn send_done(tx: &ExecuteSender, rows_affected: u64) -> bool {
1234    tx.send(Ok(Either::Left(MssqlQueryResult::new(rows_affected))))
1235        .is_ok()
1236}
1237
1238fn send_row(tx: &ExecuteSender, row: MssqlRow) -> bool {
1239    tx.send(Ok(Either::Right(row))).is_ok()
1240}
1241
1242pub(crate) fn collect_columns(
1243    cursor: &mut impl ResultSetMetadata,
1244) -> std::result::Result<Vec<MssqlColumn>, sqlx_core::Error> {
1245    let count = cursor.num_result_cols().map_err(|error| {
1246        crate::error::database_error_with_context(error, "failed to read ODBC result-column count")
1247    })?;
1248    let count = usize::try_from(count).map_err(|_| {
1249        sqlx_core::Error::Protocol(format!("ODBC returned a negative column count: {count}"))
1250    })?;
1251
1252    let mut columns = Vec::with_capacity(count);
1253    for ordinal in 0..count {
1254        let column_number = u16::try_from(ordinal + 1).map_err(|_| {
1255            sqlx_core::Error::Protocol(format!("ODBC column index exceeds u16: {}", ordinal + 1))
1256        })?;
1257
1258        let mut description = odbc_api::ColumnDescription::default();
1259        cursor
1260            .describe_col(column_number, &mut description)
1261            .map_err(|error| {
1262                crate::error::database_error_with_context(
1263                    error,
1264                    format!("failed to describe ODBC result column {column_number}"),
1265                )
1266            })?;
1267        let name = description
1268            .name_to_string()
1269            .unwrap_or_else(|_| format!("col{ordinal}"));
1270
1271        columns.push(MssqlColumn::new(
1272            ordinal,
1273            name,
1274            MssqlTypeInfo::new(description.data_type),
1275        ));
1276    }
1277
1278    Ok(columns)
1279}
1280
1281fn collect_prepared_columns(
1282    prepared: &mut impl PreparedStatementMetadata,
1283    parameter_count: u16,
1284) -> std::result::Result<Vec<MssqlColumn>, sqlx_core::Error> {
1285    match collect_columns(prepared) {
1286        Ok(columns) => Ok(columns),
1287        Err(error) if parameter_count > 0 => {
1288            validate_parameter_metadata(prepared, parameter_count)?;
1289            log::debug!("ODBC driver deferred result-column metadata until execution: {error}");
1290            Ok(Vec::new())
1291        }
1292        Err(error) => Err(error),
1293    }
1294}
1295
1296trait PreparedStatementMetadata: ResultSetMetadata {
1297    fn describe_prepared_parameter(
1298        &mut self,
1299        index: u16,
1300    ) -> std::result::Result<(), odbc_api::Error>;
1301}
1302
1303impl<S> PreparedStatementMetadata for odbc_api::Prepared<S>
1304where
1305    S: odbc_api::handles::AsStatementRef,
1306{
1307    fn describe_prepared_parameter(
1308        &mut self,
1309        index: u16,
1310    ) -> std::result::Result<(), odbc_api::Error> {
1311        self.describe_param(index).map(|_| ())
1312    }
1313}
1314
1315fn validate_parameter_metadata(
1316    prepared: &mut impl PreparedStatementMetadata,
1317    parameter_count: u16,
1318) -> std::result::Result<(), sqlx_core::Error> {
1319    for index in 1..=parameter_count {
1320        prepared
1321            .describe_prepared_parameter(index)
1322            .map_err(|error| {
1323                crate::error::database_error_with_context(
1324                    error,
1325                    format!("failed to describe ODBC parameter {index}"),
1326                )
1327            })?;
1328    }
1329
1330    Ok(())
1331}
1332
1333fn stream_result_sets<C>(
1334    mut cursor: C,
1335    settings: MssqlBufferSettings,
1336    tx: &ExecuteSender,
1337) -> std::result::Result<(), sqlx_core::Error>
1338where
1339    C: Cursor + ResultSetMetadata,
1340{
1341    loop {
1342        if cursor.num_result_cols().map_err(|error| {
1343            crate::error::database_error_with_context(
1344                error,
1345                "failed to read ODBC result-column count",
1346            )
1347        })? == 0
1348        {
1349            send_done(tx, 0);
1350        } else if let Some(max_column_size) = settings.max_column_size {
1351            let (receiver_open, finished_cursor) =
1352                stream_rows_buffered(cursor, settings.batch_size, max_column_size, tx)?;
1353            if !receiver_open {
1354                return Ok(());
1355            }
1356            cursor = finished_cursor;
1357        } else if !stream_rows_unbuffered(&mut cursor, tx)? {
1358            return Ok(());
1359        }
1360
1361        match cursor.more_results().map_err(|error| {
1362            crate::error::database_error_with_context(error, "failed to advance ODBC result set")
1363        })? {
1364            Some(next_cursor) => cursor = next_cursor,
1365            None => return Ok(()),
1366        }
1367    }
1368}
1369
1370#[derive(Debug)]
1371struct ColumnBinding {
1372    column: MssqlColumn,
1373    buffer_desc: BufferDesc,
1374}
1375
1376fn stream_rows_buffered<C>(
1377    cursor: C,
1378    batch_size: usize,
1379    max_column_size: usize,
1380    tx: &ExecuteSender,
1381) -> std::result::Result<(bool, C), sqlx_core::Error>
1382where
1383    C: Cursor + ResultSetMetadata,
1384{
1385    let mut cursor = cursor;
1386    let bindings = build_buffer_bindings(&mut cursor, max_column_size)?;
1387    let buffer_descriptions = bindings
1388        .iter()
1389        .map(|binding| binding.buffer_desc)
1390        .collect::<Vec<_>>();
1391    let mut row_set_cursor = cursor
1392        .bind_buffer(ColumnarDynBuffer::from_descs(
1393            batch_size,
1394            buffer_descriptions,
1395        ))
1396        .map_err(|error| {
1397            crate::error::database_error_with_context(
1398                error,
1399                format!(
1400                    "ODBC buffered fetching could not be enabled with batch_size={batch_size}; \
1401                     this driver may reject the row-array or row-binding statement attributes \
1402                     used for column-wise buffered fetching, so use \
1403                     MssqlConnectOptions::max_column_size(None) to fetch rows unbuffered"
1404                ),
1405            )
1406        })?;
1407    let columns: Arc<[MssqlColumn]> = bindings
1408        .iter()
1409        .map(|binding| binding.column.clone())
1410        .collect::<Vec<_>>()
1411        .into();
1412
1413    while let Some(batch) = row_set_cursor.fetch().map_err(|error| {
1414        crate::error::database_error_with_context(error, "ODBC buffered fetch failed")
1415    })? {
1416        let column_values = bindings
1417            .iter()
1418            .enumerate()
1419            .map(|(index, binding)| {
1420                buffered_column_values(batch.column(index), binding).map_err(|error| {
1421                    sqlx_core::Error::Protocol(format!(
1422                        "ODBC buffered fetch could not convert column {} (`{}`) using buffer {:?}: {error}",
1423                        binding.column.ordinal() + 1,
1424                        binding.column.name(),
1425                        binding.buffer_desc
1426                    ))
1427                })
1428            })
1429            .collect::<std::result::Result<Vec<_>, _>>()?;
1430
1431        let mut column_iters = column_values
1432            .into_iter()
1433            .map(Vec::into_iter)
1434            .collect::<Vec<_>>();
1435
1436        for row_index in 0..batch.num_rows() {
1437            let values = column_iters
1438                .iter_mut()
1439                .map(|values| {
1440                    values.next().map(MssqlValue::new).ok_or_else(|| {
1441                        sqlx_core::Error::Protocol(format!(
1442                            "ODBC buffered fetch produced too few values for row {}",
1443                            row_index + 1
1444                        ))
1445                    })
1446                })
1447                .collect::<std::result::Result<Vec<_>, _>>()?;
1448            if !send_row(tx, MssqlRow::new_shared(Arc::clone(&columns), values)) {
1449                let (cursor, _) = row_set_cursor.unbind().map_err(|error| {
1450                    crate::error::database_error_with_context(
1451                        error,
1452                        "ODBC buffered fetch could not unbind row buffer after receiver closed",
1453                    )
1454                })?;
1455                return Ok((false, cursor));
1456            }
1457        }
1458    }
1459
1460    send_done(tx, 0);
1461    let (cursor, _) = row_set_cursor.unbind().map_err(|error| {
1462        crate::error::database_error_with_context(
1463            error,
1464            "ODBC buffered fetch could not unbind row buffer",
1465        )
1466    })?;
1467    Ok((true, cursor))
1468}
1469
1470fn build_buffer_bindings(
1471    cursor: &mut impl ResultSetMetadata,
1472    max_column_size: usize,
1473) -> std::result::Result<Vec<ColumnBinding>, sqlx_core::Error> {
1474    collect_columns(cursor).map(|columns| {
1475        columns
1476            .into_iter()
1477            .map(|column| ColumnBinding {
1478                buffer_desc: map_buffer_desc(column.type_info().data_type(), max_column_size),
1479                column,
1480            })
1481            .collect()
1482    })
1483}
1484
1485fn map_buffer_desc(data_type: DataType, max_column_size: usize) -> BufferDesc {
1486    match data_type {
1487        DataType::TinyInt | DataType::SmallInt | DataType::Integer | DataType::BigInt => {
1488            BufferDesc::I64 { nullable: true }
1489        }
1490        DataType::Real => BufferDesc::F32 { nullable: true },
1491        DataType::Float { .. } | DataType::Double => BufferDesc::F64 { nullable: true },
1492        DataType::Bit => BufferDesc::Bit { nullable: true },
1493        DataType::Date => BufferDesc::Date { nullable: true },
1494        DataType::Time { .. } => BufferDesc::Time { nullable: true },
1495        DataType::Timestamp { .. } => BufferDesc::Timestamp { nullable: true },
1496        DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => {
1497            BufferDesc::Binary {
1498                max_bytes: max_column_size,
1499            }
1500        }
1501        // Wide character types use SQL_C_WCHAR buffers (UTF-16) to avoid
1502        // codepage-dependent corruption of non-ASCII data.
1503        DataType::WChar { .. } | DataType::WVarchar { .. } | DataType::WLongVarchar { .. } => {
1504            BufferDesc::WText {
1505                max_str_len: max_column_size,
1506            }
1507        }
1508        // Narrow character types and fallback types use SQL_C_CHAR.
1509        DataType::Char { .. }
1510        | DataType::Varchar { .. }
1511        | DataType::LongVarchar { .. }
1512        | DataType::Other { .. }
1513        | DataType::Unknown
1514        | DataType::Decimal { .. }
1515        | DataType::Numeric { .. } => BufferDesc::Text {
1516            max_str_len: max_column_size,
1517        },
1518    }
1519}
1520
1521fn buffered_column_values(
1522    slice: AnyColumnBufferSlice<'_>,
1523    binding: &ColumnBinding,
1524) -> std::result::Result<Vec<MssqlValueKind>, sqlx_core::Error> {
1525    let desc = binding.buffer_desc;
1526    Ok(match desc {
1527        BufferDesc::I8 { nullable } => buffered_numeric(&slice, desc, nullable, |value: i8| {
1528            MssqlValueKind::TinyInt(i16::from(value))
1529        })?,
1530        BufferDesc::I16 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
1531            MssqlValueKind::SmallInt(value)
1532        })?,
1533        BufferDesc::I32 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
1534            MssqlValueKind::Integer(value)
1535        })?,
1536        BufferDesc::I64 { nullable } => {
1537            buffered_numeric(&slice, desc, nullable, MssqlValueKind::BigInt)?
1538        }
1539        BufferDesc::U8 { nullable } => buffered_numeric(&slice, desc, nullable, |value: u8| {
1540            MssqlValueKind::BigInt(i64::from(value))
1541        })?,
1542        BufferDesc::F32 { nullable } => {
1543            buffered_numeric(&slice, desc, nullable, MssqlValueKind::Real)?
1544        }
1545        BufferDesc::F64 { nullable } => {
1546            buffered_numeric(&slice, desc, nullable, MssqlValueKind::Double)?
1547        }
1548        BufferDesc::Bit { nullable } => {
1549            buffered_numeric(&slice, desc, nullable, |value: odbc_api::Bit| {
1550                MssqlValueKind::Bit(value.as_bool())
1551            })?
1552        }
1553        BufferDesc::Date { nullable } => {
1554            buffered_numeric(&slice, desc, nullable, MssqlValueKind::Date)?
1555        }
1556        BufferDesc::Time { nullable } => {
1557            buffered_numeric(&slice, desc, nullable, MssqlValueKind::Time)?
1558        }
1559        BufferDesc::Timestamp { nullable } => {
1560            buffered_numeric(&slice, desc, nullable, MssqlValueKind::Timestamp)?
1561        }
1562        BufferDesc::Text { .. } => {
1563            let text = expect_buffer_slice(slice.as_text(), desc)?;
1564            text.iter()
1565                .map(|value| {
1566                    value
1567                        .map(|bytes| {
1568                            MssqlValueKind::Text(String::from_utf8_lossy(bytes).into_owned())
1569                        })
1570                        .unwrap_or(MssqlValueKind::Null)
1571                })
1572                .collect()
1573        }
1574        BufferDesc::WText { .. } => {
1575            let text = expect_buffer_slice(slice.as_wide_text(), desc)?;
1576            text.iter()
1577                .map(|value| {
1578                    value
1579                        .map(|chars| MssqlValueKind::Text(String::from_utf16_lossy(chars.into())))
1580                        .unwrap_or(MssqlValueKind::Null)
1581                })
1582                .collect()
1583        }
1584        BufferDesc::Binary { .. } => {
1585            let binary = expect_buffer_slice(slice.as_binary(), desc)?;
1586            binary
1587                .iter()
1588                .map(|value| {
1589                    value
1590                        .map(|bytes| MssqlValueKind::Binary(bytes.to_vec()))
1591                        .unwrap_or(MssqlValueKind::Null)
1592                })
1593                .collect()
1594        }
1595        BufferDesc::Numeric => {
1596            return Err(sqlx_core::Error::Protocol(format!(
1597                "unsupported ODBC buffer descriptor: {desc:?}"
1598            )))
1599        }
1600    })
1601}
1602
1603fn buffered_numeric<T, F>(
1604    slice: &AnyColumnBufferSlice<'_>,
1605    desc: BufferDesc,
1606    nullable: bool,
1607    map: F,
1608) -> std::result::Result<Vec<MssqlValueKind>, sqlx_core::Error>
1609where
1610    T: Copy + odbc_api::Pod,
1611    F: FnMut(T) -> MssqlValueKind,
1612{
1613    if nullable {
1614        Ok(buffered_nullable_numeric(
1615            expect_buffer_slice(slice.as_nullable_slice::<T>(), desc)?,
1616            map,
1617        ))
1618    } else {
1619        Ok(expect_buffer_slice(slice.as_slice::<T>(), desc)?
1620            .iter()
1621            .copied()
1622            .map(map)
1623            .collect())
1624    }
1625}
1626
1627fn buffered_nullable_numeric<T, F>(slice: NullableSlice<'_, T>, mut map: F) -> Vec<MssqlValueKind>
1628where
1629    T: Copy,
1630    F: FnMut(T) -> MssqlValueKind,
1631{
1632    slice
1633        .map(|value| value.copied().map(&mut map).unwrap_or(MssqlValueKind::Null))
1634        .collect()
1635}
1636
1637fn expect_buffer_slice<T>(
1638    slice: Option<T>,
1639    desc: BufferDesc,
1640) -> std::result::Result<T, sqlx_core::Error> {
1641    slice.ok_or_else(|| {
1642        sqlx_core::Error::Protocol(format!(
1643            "ODBC column buffer {desc:?} did not match fetched slice"
1644        ))
1645    })
1646}
1647
1648fn stream_rows_unbuffered<C>(
1649    cursor: &mut C,
1650    tx: &ExecuteSender,
1651) -> std::result::Result<bool, sqlx_core::Error>
1652where
1653    C: Cursor + ResultSetMetadata,
1654{
1655    let columns: Arc<[MssqlColumn]> = collect_columns(cursor)?.into();
1656
1657    while let Some(mut cursor_row) = cursor.next_row().map_err(|error| {
1658        crate::error::database_error_with_context(
1659            error,
1660            "ODBC unbuffered fetch failed while reading the next row",
1661        )
1662    })? {
1663        let mut values = Vec::with_capacity(columns.len());
1664
1665        for column in columns.iter() {
1666            let column_number = u16::try_from(sqlx_core::column::Column::ordinal(column) + 1)
1667                .map_err(|_| {
1668                    sqlx_core::Error::Protocol("ODBC column index exceeds u16".to_owned())
1669                })?;
1670            values.push(fetch_value(&mut cursor_row, column_number, column)?);
1671        }
1672
1673        if !send_row(tx, MssqlRow::new_shared(Arc::clone(&columns), values)) {
1674            return Ok(false);
1675        }
1676    }
1677
1678    send_done(tx, 0);
1679    Ok(true)
1680}
1681
1682fn fetch_value(
1683    row: &mut odbc_api::CursorRow<'_>,
1684    column_number: u16,
1685    column: &MssqlColumn,
1686) -> std::result::Result<MssqlValue, sqlx_core::Error> {
1687    let data_type = column.type_info().data_type();
1688
1689    let kind = match data_type {
1690        DataType::Bit => {
1691            let mut value = Nullable::<odbc_api::Bit>::null();
1692            row.get_data(column_number, &mut value).map_err(|error| {
1693                crate::error::database_error_with_context_lazy(error, || {
1694                    fetch_context(column, data_type)
1695                })
1696            })?;
1697            value
1698                .into_opt()
1699                .map(|value| MssqlValueKind::Bit(value.as_bool()))
1700                .unwrap_or(MssqlValueKind::Null)
1701        }
1702        DataType::TinyInt => {
1703            // MSSQL TINYINT is unsigned (0-255), so read as i16 to avoid
1704            // signed overflow of values > 127.
1705            let mut value = Nullable::<i16>::null();
1706            row.get_data(column_number, &mut value).map_err(|error| {
1707                crate::error::database_error_with_context_lazy(error, || {
1708                    fetch_context(column, data_type)
1709                })
1710            })?;
1711            value
1712                .into_opt()
1713                .map(MssqlValueKind::TinyInt)
1714                .unwrap_or(MssqlValueKind::Null)
1715        }
1716        DataType::SmallInt => fetch_nullable(
1717            row,
1718            column_number,
1719            column,
1720            data_type,
1721            MssqlValueKind::SmallInt,
1722        )?,
1723        DataType::Integer => fetch_nullable(
1724            row,
1725            column_number,
1726            column,
1727            data_type,
1728            MssqlValueKind::Integer,
1729        )?,
1730        DataType::BigInt => {
1731            fetch_nullable(row, column_number, column, data_type, MssqlValueKind::BigInt)?
1732        }
1733        DataType::Real => {
1734            fetch_nullable(row, column_number, column, data_type, MssqlValueKind::Real)?
1735        }
1736        DataType::Float { .. } | DataType::Double => {
1737            fetch_nullable(row, column_number, column, data_type, MssqlValueKind::Double)?
1738        }
1739        DataType::Date => {
1740            fetch_nullable(row, column_number, column, data_type, MssqlValueKind::Date)?
1741        }
1742        DataType::Time { .. } => {
1743            fetch_nullable(row, column_number, column, data_type, MssqlValueKind::Time)?
1744        }
1745        DataType::Timestamp { .. } => fetch_nullable(
1746            row,
1747            column_number,
1748            column,
1749            data_type,
1750            MssqlValueKind::Timestamp,
1751        )?,
1752        DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => {
1753            let mut value = Vec::new();
1754            if row.get_binary(column_number, &mut value).map_err(|error| {
1755                crate::error::database_error_with_context_lazy(error, || {
1756                    fetch_context(column, data_type)
1757                })
1758            })? {
1759                MssqlValueKind::Binary(value)
1760            } else {
1761                MssqlValueKind::Null
1762            }
1763        }
1764        DataType::Other {
1765            data_type: sql_type, ..
1766        } if sql_type.0 == -11 => {
1767            // SQL_GUID / UNIQUEIDENTIFIER in MSSQL
1768            let mut value = Vec::new();
1769            if row.get_binary(column_number, &mut value).map_err(|error| {
1770                crate::error::database_error_with_context_lazy(error, || {
1771                    fetch_context(column, data_type)
1772                })
1773            })? {
1774                if value.len() == 16 {
1775                    let mut guid = [0u8; 16];
1776                    guid.copy_from_slice(&value);
1777                    MssqlValueKind::Guid(guid)
1778                } else {
1779                    // Fallback: treat GUID data as text
1780                    MssqlValueKind::Text(String::from_utf16_lossy(
1781                        &value.iter().map(|&b| b as u16).collect::<Vec<_>>(),
1782                    ))
1783                }
1784            } else {
1785                MssqlValueKind::Null
1786            }
1787        }
1788        _ => {
1789            let mut value = Vec::new();
1790            if row
1791                .get_wide_text(column_number, &mut value)
1792                .map_err(|error| {
1793                    crate::error::database_error_with_context_lazy(error, || {
1794                        fetch_context(column, data_type)
1795                    })
1796                })?
1797            {
1798                MssqlValueKind::Text(String::from_utf16_lossy(&value))
1799            } else {
1800                MssqlValueKind::Null
1801            }
1802        }
1803    };
1804
1805    Ok(MssqlValue::new(kind))
1806}
1807
1808fn fetch_nullable<T, F>(
1809    row: &mut odbc_api::CursorRow<'_>,
1810    column_number: u16,
1811    column: &MssqlColumn,
1812    data_type: DataType,
1813    map: F,
1814) -> std::result::Result<MssqlValueKind, sqlx_core::Error>
1815where
1816    T: Default + Copy + odbc_api::parameter::CElement + odbc_api::handles::CDataMut,
1817    Nullable<T>: odbc_api::parameter::CElement + odbc_api::handles::CDataMut,
1818    F: FnOnce(T) -> MssqlValueKind,
1819{
1820    let mut value = Nullable::<T>::null();
1821    row.get_data(column_number, &mut value).map_err(|error| {
1822        crate::error::database_error_with_context_lazy(error, || fetch_context(column, data_type))
1823    })?;
1824    Ok(value.into_opt().map(map).unwrap_or(MssqlValueKind::Null))
1825}
1826
1827fn fetch_context(column: &MssqlColumn, data_type: DataType) -> String {
1828    format!(
1829        "failed to fetch ODBC column {} (`{}`) as {data_type:?}",
1830        column.ordinal() + 1,
1831        column.name()
1832    )
1833}
1834
/// Builds a compact, single-line preview of `sql`.
///
/// Every run of whitespace (including newlines and tabs) collapses to a single
/// space, and leading and trailing whitespace is dropped. If the result is longer
/// than 160 characters, it is cut to 157 characters and `...` is appended, so
/// the preview is never longer than 160 characters. All lengths are counted
/// in `char`s, not bytes.
fn sql_preview(sql: &str) -> String {
    const MAX_LEN: usize = 160;
    const ELLIPSIS: &str = "...";

    let compact = sql.split_whitespace().collect::<Vec<_>>().join(" ");

    // Measure in chars, the same unit the truncation below uses. Comparing the
    // byte length would append "..." to short non-ASCII text without cutting it.
    if compact.chars().count() <= MAX_LEN {
        return compact;
    }

    let mut preview: String = compact.chars().take(MAX_LEN - ELLIPSIS.len()).collect();
    preview.push_str(ELLIPSIS);
    preview
}
1847
1848/// Offloads a blocking operation to Tokio's blocking thread pool.
1849///
1850/// The closure must satisfy `Send + 'static` so it can be moved across
1851/// threads.
1852pub(crate) async fn offload_blocking<F, T>(f: F) -> std::result::Result<T, sqlx_core::Error>
1853where
1854    F: FnOnce() -> std::result::Result<T, sqlx_core::Error> + Send + 'static,
1855    T: Send + 'static,
1856{
1857    tokio::task::spawn_blocking(f)
1858        .await
1859        .map_err(|e| sqlx_core::Error::Protocol(format!("blocking task panicked: {e}")))?
1860}
1861
#[cfg(test)]
mod tests {
    use super::*;

    /// Integer SQL types of every width share one nullable 64-bit buffer.
    #[test]
    fn buffered_fetch_maps_numeric_types_to_nullable_64_bit_buffers() {
        for data_type in [DataType::TinyInt, DataType::Integer, DataType::BigInt] {
            assert_eq!(
                map_buffer_desc(data_type, 64),
                BufferDesc::I64 { nullable: true },
                "unexpected buffer for {data_type:?}"
            );
        }
    }

    /// Variable-length text and binary buffers are sized by the supplied limit.
    #[test]
    fn buffered_fetch_uses_configured_limits_for_variable_sized_data() {
        let cases = [
            (DataType::Varchar { length: None }, 32, BufferDesc::Text { max_str_len: 32 }),
            (DataType::Varbinary { length: None }, 16, BufferDesc::Binary { max_bytes: 16 }),
        ];
        for (data_type, limit, expected) in cases {
            assert_eq!(
                map_buffer_desc(data_type, limit),
                expected,
                "unexpected buffer for {data_type:?}"
            );
        }
    }

    /// Wide (UTF-16) character types map to `WText` buffers.
    #[test]
    fn buffered_fetch_maps_wide_char_types_to_wtext() {
        let cases = [
            (DataType::WChar { length: None }, 64, BufferDesc::WText { max_str_len: 64 }),
            (DataType::WVarchar { length: None }, 128, BufferDesc::WText { max_str_len: 128 }),
            (DataType::WLongVarchar { length: None }, 256, BufferDesc::WText { max_str_len: 256 }),
        ];
        for (data_type, limit, expected) in cases {
            assert_eq!(
                map_buffer_desc(data_type, limit),
                expected,
                "unexpected buffer for {data_type:?}"
            );
        }
    }

    /// Narrow character types map to `Text` buffers.
    #[test]
    fn buffered_fetch_maps_narrow_char_types_to_text() {
        let narrow_types = [
            DataType::Char { length: None },
            DataType::Varchar { length: None },
            DataType::LongVarchar { length: None },
        ];
        for data_type in narrow_types {
            assert_eq!(
                map_buffer_desc(data_type, 64),
                BufferDesc::Text { max_str_len: 64 },
                "unexpected buffer for {data_type:?}"
            );
        }
    }
}