1use crate::{
2 OdbcArguments, OdbcBufferSettings, OdbcColumn, OdbcConnectOptions, OdbcParameterCollection,
3 OdbcQueryResult, OdbcRow, OdbcStatement, OdbcTypeInfo, OdbcValue, OdbcValueKind, Result,
4};
5use futures_core::future::BoxFuture;
6use futures_core::stream::BoxStream;
7use futures_util::{future, stream, StreamExt};
8use odbc_api::buffers::{AnyColumnBufferSlice, BufferDesc, ColumnarDynBuffer, NullableSlice};
9use odbc_api::{Cursor, DataType, Nullable, ResultSetMetadata};
10use sqlx_core::column::Column;
11use sqlx_core::executor::{Execute, Executor};
12use sqlx_core::transaction::Transaction;
13use sqlx_core::Either;
14use std::future::Future;
15use std::sync::Arc;
16
/// A synchronous ODBC connection exposed through sqlx's `Connection` and `Executor` traits.
///
/// The async trait methods in this module call the `*_blocking` helpers directly, so all
/// driver work runs on the task that polls them.
pub struct OdbcConnection {
    /// Driver connection handle, opened from the process-wide environment returned by
    /// `odbc_api::environment()` (hence the `'static` lifetime).
    conn: odbc_api::Connection<'static>,
    /// Fetch strategy (batch size / optional maximum column size) copied from the connect
    /// options; `collect_rows` uses it to choose buffered or unbuffered fetching.
    buffer_settings: OdbcBufferSettings,
    /// `0` when no transaction is open and `1` while one is; nesting is rejected by
    /// `begin_blocking`.
    transaction_depth: usize,
}
26
27impl std::fmt::Debug for OdbcConnection {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("OdbcConnection").finish_non_exhaustive()
30 }
31}
32
33impl OdbcConnection {
34 pub fn connect_blocking(options: &OdbcConnectOptions) -> Result<Self> {
36 let env = odbc_api::environment().map_err(|error| {
37 crate::OdbcError::Configuration(format!(
38 "failed to initialize the process-wide ODBC environment: {error}"
39 ))
40 })?;
41
42 let conn = env
43 .connect_with_connection_string(options.connection_string(), Default::default())
44 .map_err(|error| {
45 crate::error::database_error_with_context(
46 error,
47 "failed to open ODBC connection using the supplied connection string",
48 )
49 })?;
50
51 Ok(Self {
52 conn,
53 buffer_settings: options.buffer_settings,
54 transaction_depth: 0,
55 })
56 }
57
58 pub fn ping_blocking(&mut self) -> Result<()> {
60 let query = self
61 .conn
62 .database_management_system_name()
63 .map(|name| ping_query_for_dbms_name(&name))
64 .unwrap_or("SELECT 1");
65 self.conn.execute(query, (), None).map_err(|error| {
66 crate::error::database_error_with_context(
67 error,
68 format!("ODBC ping query failed: `{query}`"),
69 )
70 })?;
71 Ok(())
72 }
73
74 pub fn dbms_name(&self) -> Result<String> {
76 self.conn
77 .database_management_system_name()
78 .map_err(|error| {
79 crate::error::database_error_with_context(
80 error,
81 "failed to read the ODBC DBMS name from SQLGetInfo",
82 )
83 })
84 }
85
86 pub(crate) fn begin_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
87 if self.transaction_depth > 0 {
88 return Err(sqlx_core::Error::InvalidSavePointStatement);
89 }
90
91 self.conn.set_autocommit(false).map_err(|error| {
92 crate::error::database_error_with_context(
93 error,
94 "failed to disable ODBC autocommit while beginning a transaction",
95 )
96 })?;
97 self.transaction_depth = 1;
98 Ok(())
99 }
100
101 pub(crate) fn commit_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
102 if self.transaction_depth == 0 {
103 return Ok(());
104 }
105
106 self.conn.commit().map_err(|error| {
107 crate::error::database_error_with_context(
108 error,
109 "failed to commit the active ODBC transaction",
110 )
111 })?;
112 self.conn.set_autocommit(true).map_err(|error| {
113 crate::error::database_error_with_context(
114 error,
115 "failed to restore ODBC autocommit after commit",
116 )
117 })?;
118 self.transaction_depth = 0;
119 Ok(())
120 }
121
122 pub(crate) fn rollback_blocking(&mut self) -> std::result::Result<(), sqlx_core::Error> {
123 if self.transaction_depth == 0 {
124 return Ok(());
125 }
126
127 self.conn.rollback().map_err(|error| {
128 crate::error::database_error_with_context(
129 error,
130 "failed to roll back the active ODBC transaction",
131 )
132 })?;
133 self.conn.set_autocommit(true).map_err(|error| {
134 crate::error::database_error_with_context(
135 error,
136 "failed to restore ODBC autocommit after rollback",
137 )
138 })?;
139 self.transaction_depth = 0;
140 Ok(())
141 }
142
143 pub(crate) fn start_rollback(&mut self) {
144 if self.transaction_depth == 0 {
145 return;
146 }
147
148 if self.conn.rollback().is_ok() {
149 let _ = self.conn.set_autocommit(true);
150 self.transaction_depth = 0;
151 }
152 }
153
154 pub(crate) const fn transaction_depth(&self) -> usize {
155 self.transaction_depth
156 }
157
158 pub fn prepare_blocking(
160 &mut self,
161 sql: sqlx_core::sql_str::SqlStr,
162 ) -> std::result::Result<OdbcStatement, sqlx_core::Error> {
163 let mut prepared = self.conn.prepare(sql.as_str()).map_err(|error| {
164 crate::error::database_error_with_context(
165 error,
166 format!(
167 "failed to prepare ODBC statement: `{}`",
168 sql_preview(sql.as_str())
169 ),
170 )
171 })?;
172 let parameters = prepared.num_params().map_err(|error| {
173 crate::error::database_error_with_context(
174 error,
175 format!(
176 "failed to read ODBC parameter metadata for prepared statement: `{}`",
177 sql_preview(sql.as_str())
178 ),
179 )
180 })?;
181 let columns = collect_prepared_columns(&mut prepared, parameters)?;
182
183 Ok(OdbcStatement::new(sql, columns, usize::from(parameters)))
184 }
185
186 pub(crate) fn run_blocking_sql(
187 &mut self,
188 sql: &str,
189 arguments: Option<&OdbcArguments>,
190 ) -> std::result::Result<OdbcExecution, sqlx_core::Error> {
191 let mut statement = self.conn.preallocate().map_err(|error| {
192 crate::error::database_error_with_context(
193 error,
194 format!(
195 "failed to allocate an ODBC statement for query: `{}`",
196 sql_preview(sql)
197 ),
198 )
199 })?;
200 let parameters = odbc_parameters(arguments);
201
202 if let Some(cursor) = statement
203 .execute(sql, parameters.as_slice())
204 .map_err(|error| {
205 crate::error::database_error_with_context(
206 error,
207 format!("failed to execute ODBC query: `{}`", sql_preview(sql)),
208 )
209 })?
210 {
211 return collect_rows(cursor, self.buffer_settings).map(OdbcExecution::Rows);
212 }
213
214 let rows_affected = statement
215 .row_count()
216 .map_err(|error| {
217 crate::error::database_error_with_context(
218 error,
219 format!(
220 "failed to read ODBC row count for query: `{}`",
221 sql_preview(sql)
222 ),
223 )
224 })?
225 .unwrap_or(0);
226
227 let rows_affected = rows_affected.try_into().map_err(|_| {
228 sqlx_core::Error::Protocol("ODBC row count does not fit in u64".to_owned())
229 })?;
230
231 Ok(OdbcExecution::Done(OdbcQueryResult::new(rows_affected)))
232 }
233}
234
235impl sqlx_core::connection::Connection for OdbcConnection {
236 type Database = crate::Odbc;
237 type Options = OdbcConnectOptions;
238
239 async fn close(self) -> std::result::Result<(), sqlx_core::Error> {
240 drop(self);
241 Ok(())
242 }
243
244 async fn close_hard(self) -> std::result::Result<(), sqlx_core::Error> {
245 drop(self);
246 Ok(())
247 }
248
249 async fn ping(&mut self) -> std::result::Result<(), sqlx_core::Error> {
250 self.ping_blocking().map_err(Into::into)
251 }
252
253 fn begin(
254 &mut self,
255 ) -> impl Future<Output = std::result::Result<Transaction<'_, Self::Database>, sqlx_core::Error>>
256 + Send
257 + '_ {
258 Transaction::begin(self, None)
259 }
260
261 fn shrink_buffers(&mut self) {}
262
263 async fn flush(&mut self) -> std::result::Result<(), sqlx_core::Error> {
264 Ok(())
265 }
266
267 fn should_flush(&self) -> bool {
268 false
269 }
270}
271
272impl<'c> Executor<'c> for &'c mut OdbcConnection {
273 type Database = crate::Odbc;
274
275 fn fetch_many<'e, 'q, E>(
276 self,
277 mut query: E,
278 ) -> BoxStream<'e, std::result::Result<Either<OdbcQueryResult, OdbcRow>, sqlx_core::Error>>
279 where
280 'c: 'e,
281 E: Execute<'q, Self::Database>,
282 'q: 'e,
283 E: 'q,
284 {
285 let arguments = query.take_arguments().map_err(sqlx_core::Error::Encode);
286 let sql = query.sql();
287
288 stream::once(async move {
289 let arguments = arguments?;
290 self.run_blocking_sql(sql.as_str(), arguments.as_ref())
291 })
292 .map(|result| match result {
293 Ok(OdbcExecution::Done(result)) => {
294 stream::once(future::ready(Ok(Either::Left(result)))).boxed()
295 }
296 Ok(OdbcExecution::Rows(rows)) => stream::iter(
297 rows.into_iter()
298 .map(|row| Ok(Either::Right(row)))
299 .chain(std::iter::once(Ok(Either::Left(OdbcQueryResult::new(0))))),
300 )
301 .boxed(),
302 Err(error) => stream::once(future::ready(Err(error))).boxed(),
303 })
304 .flatten()
305 .boxed()
306 }
307
308 fn fetch_optional<'e, 'q, E>(
309 self,
310 mut query: E,
311 ) -> BoxFuture<'e, std::result::Result<Option<OdbcRow>, sqlx_core::Error>>
312 where
313 'c: 'e,
314 E: Execute<'q, Self::Database>,
315 'q: 'e,
316 E: 'q,
317 {
318 let arguments = query.take_arguments().map_err(sqlx_core::Error::Encode);
319 let sql = query.sql();
320
321 Box::pin(async move {
322 let arguments = arguments?;
323
324 match self.run_blocking_sql(sql.as_str(), arguments.as_ref())? {
325 OdbcExecution::Rows(rows) => Ok(rows.into_iter().next()),
326 OdbcExecution::Done(_) => Ok(None),
327 }
328 })
329 }
330
331 fn prepare_with<'e>(
332 self,
333 sql: sqlx_core::sql_str::SqlStr,
334 _parameters: &[crate::OdbcTypeInfo],
335 ) -> BoxFuture<'e, std::result::Result<OdbcStatement, sqlx_core::Error>>
336 where
337 'c: 'e,
338 {
339 Box::pin(async move { self.prepare_blocking(sql) })
340 }
341}
342
/// Outcome of running one SQL statement through `OdbcConnection::run_blocking_sql`.
pub(crate) enum OdbcExecution {
    /// The statement produced no result set; carries the affected-row count.
    Done(OdbcQueryResult),
    /// The statement produced a result set; every row has already been fetched.
    Rows(Vec<OdbcRow>),
}
347
348fn odbc_parameters(arguments: Option<&OdbcArguments>) -> OdbcParameterCollection {
349 arguments
350 .map(OdbcArguments::to_odbc_parameter_collection)
351 .unwrap_or_default()
352}
353
/// Chooses the query `ping_blocking` sends to the DBMS named `dbms_name`.
///
/// `dbms_name` is the name reported by the driver. Names that identify the Db2 family
/// (containing `DB2`, `DB/2`, `ISERIES`, `AS/400`, or the word `IBM I`) get
/// `SELECT 1 FROM SYSIBM.SYSDUMMY1`, presumably because those servers reject a FROM-less
/// `SELECT`. Every other name gets `SELECT 1`. Matching is ASCII case-insensitive.
///
/// `IBM I` must stand alone as a word (end of string or followed by a non-alphanumeric
/// character). Otherwise names of unrelated IBM products such as "IBM Informix" or
/// "IBM IMS" would also be treated as Db2.
fn ping_query_for_dbms_name(dbms_name: &str) -> &'static str {
    const DB2_PING: &str = "SELECT 1 FROM SYSIBM.SYSDUMMY1";
    const DEFAULT_PING: &str = "SELECT 1";
    const DB2_FAMILY_MARKERS: [&str; 4] = ["DB2", "DB/2", "ISERIES", "AS/400"];
    const IBM_I_MARKER: &str = "IBM I";

    let dbms_name = dbms_name.to_ascii_uppercase();

    let names_db2_family = DB2_FAMILY_MARKERS
        .iter()
        .any(|&marker| dbms_name.contains(marker));

    // A plain substring test would also accept "IBM INFORMIX ..." / "IBM IMS ...",
    // so require that "IBM I" is not the prefix of a longer word.
    let names_ibm_i = dbms_name
        .match_indices(IBM_I_MARKER)
        .any(|(start, marker)| {
            let rest = &dbms_name[start + marker.len()..];
            !rest.starts_with(|next: char| next.is_ascii_alphanumeric())
        });

    if names_db2_family || names_ibm_i {
        DB2_PING
    } else {
        DEFAULT_PING
    }
}
368
369fn collect_columns(
370 cursor: &mut impl ResultSetMetadata,
371) -> std::result::Result<Vec<OdbcColumn>, sqlx_core::Error> {
372 let count = cursor.num_result_cols().map_err(|error| {
373 crate::error::database_error_with_context(error, "failed to read ODBC result-column count")
374 })?;
375 let count = usize::try_from(count).map_err(|_| {
376 sqlx_core::Error::Protocol(format!("ODBC returned a negative column count: {count}"))
377 })?;
378
379 let mut columns = Vec::with_capacity(count);
380 for ordinal in 0..count {
381 let column_number = u16::try_from(ordinal + 1).map_err(|_| {
382 sqlx_core::Error::Protocol(format!("ODBC column index exceeds u16: {}", ordinal + 1))
383 })?;
384
385 let mut description = odbc_api::ColumnDescription::default();
386 cursor
387 .describe_col(column_number, &mut description)
388 .map_err(|error| {
389 crate::error::database_error_with_context(
390 error,
391 format!("failed to describe ODBC result column {column_number}"),
392 )
393 })?;
394 let name = description
395 .name_to_string()
396 .unwrap_or_else(|_| format!("col{ordinal}"));
397
398 columns.push(OdbcColumn::new(
399 ordinal,
400 name,
401 OdbcTypeInfo::new(description.data_type),
402 ));
403 }
404
405 Ok(columns)
406}
407
408fn collect_prepared_columns(
409 prepared: &mut impl PreparedStatementMetadata,
410 parameter_count: u16,
411) -> std::result::Result<Vec<OdbcColumn>, sqlx_core::Error> {
412 match collect_columns(prepared) {
413 Ok(columns) => Ok(columns),
414 Err(error) if parameter_count > 0 => {
415 validate_parameter_metadata(prepared, parameter_count)?;
416 log::debug!("ODBC driver deferred result-column metadata until execution: {error}");
417 Ok(Vec::new())
418 }
419 Err(error) => Err(error),
420 }
421}
422
/// Metadata access needed by `collect_prepared_columns`: result-set description (through
/// the `ResultSetMetadata` supertrait) plus per-parameter description.
trait PreparedStatementMetadata: ResultSetMetadata {
    /// Asks the driver to describe parameter `index` (1-based, as iterated by
    /// `validate_parameter_metadata`). Only success or failure is reported; the description
    /// itself is discarded.
    fn describe_prepared_parameter(
        &mut self,
        index: u16,
    ) -> std::result::Result<(), odbc_api::Error>;
}
429
430impl<S> PreparedStatementMetadata for odbc_api::Prepared<S>
431where
432 S: odbc_api::handles::AsStatementRef,
433{
434 fn describe_prepared_parameter(
435 &mut self,
436 index: u16,
437 ) -> std::result::Result<(), odbc_api::Error> {
438 self.describe_param(index).map(|_| ())
439 }
440}
441
442fn validate_parameter_metadata(
443 prepared: &mut impl PreparedStatementMetadata,
444 parameter_count: u16,
445) -> std::result::Result<(), sqlx_core::Error> {
446 for index in 1..=parameter_count {
447 prepared
448 .describe_prepared_parameter(index)
449 .map_err(|error| {
450 crate::error::database_error_with_context(
451 error,
452 format!("failed to describe ODBC parameter {index}"),
453 )
454 })?;
455 }
456
457 Ok(())
458}
459
460fn collect_rows<C>(
461 cursor: C,
462 settings: OdbcBufferSettings,
463) -> std::result::Result<Vec<OdbcRow>, sqlx_core::Error>
464where
465 C: Cursor + ResultSetMetadata,
466{
467 if let Some(max_column_size) = settings.max_column_size {
468 collect_rows_buffered(cursor, settings.batch_size, max_column_size)
469 } else {
470 collect_rows_unbuffered(cursor)
471 }
472}
473
/// A result column paired with the buffer layout chosen for it during buffered fetching.
#[derive(Debug)]
struct ColumnBinding {
    /// Column metadata as described by the driver.
    column: OdbcColumn,
    /// Buffer descriptor produced by `map_buffer_desc` for this column's data type.
    buffer_desc: BufferDesc,
}
479
480fn collect_rows_buffered<C>(
481 cursor: C,
482 batch_size: usize,
483 max_column_size: usize,
484) -> std::result::Result<Vec<OdbcRow>, sqlx_core::Error>
485where
486 C: Cursor + ResultSetMetadata,
487{
488 let mut cursor = cursor;
489 let bindings = build_buffer_bindings(&mut cursor, max_column_size)?;
490 let buffer_descriptions = bindings
491 .iter()
492 .map(|binding| binding.buffer_desc)
493 .collect::<Vec<_>>();
494 let mut row_set_cursor = cursor
495 .bind_buffer(ColumnarDynBuffer::from_descs(
496 batch_size,
497 buffer_descriptions,
498 ))
499 .map_err(|error| {
500 crate::error::database_error_with_context(
501 error,
502 format!(
503 "ODBC buffered fetching could not be enabled with batch_size={batch_size}; \
504 this driver may reject the row-array or row-binding statement attributes \
505 used for column-wise buffered fetching, so use \
506 OdbcConnectOptions::max_column_size(None) to fetch rows unbuffered"
507 ),
508 )
509 })?;
510 let columns: Arc<[OdbcColumn]> = bindings
511 .iter()
512 .map(|binding| binding.column.clone())
513 .collect::<Vec<_>>()
514 .into();
515 let mut rows = Vec::new();
516
517 while let Some(batch) = row_set_cursor.fetch().map_err(|error| {
518 crate::error::database_error_with_context(error, "ODBC buffered fetch failed")
519 })? {
520 let column_values = bindings
521 .iter()
522 .enumerate()
523 .map(|(index, binding)| {
524 buffered_column_values(batch.column(index), binding).map_err(|error| {
525 sqlx_core::Error::Protocol(format!(
526 "ODBC buffered fetch could not convert column {} (`{}`) using buffer {:?}: {error}",
527 binding.column.ordinal() + 1,
528 binding.column.name(),
529 binding.buffer_desc
530 ))
531 })
532 })
533 .collect::<std::result::Result<Vec<_>, _>>()?;
534
535 let mut column_iters = column_values
536 .into_iter()
537 .map(Vec::into_iter)
538 .collect::<Vec<_>>();
539
540 for row_index in 0..batch.num_rows() {
541 let values = column_iters
542 .iter_mut()
543 .map(|values| {
544 values.next().map(OdbcValue::new).ok_or_else(|| {
545 sqlx_core::Error::Protocol(format!(
546 "ODBC buffered fetch produced too few values for row {}",
547 row_index + 1
548 ))
549 })
550 })
551 .collect::<std::result::Result<Vec<_>, _>>()?;
552 rows.push(OdbcRow::new_shared(Arc::clone(&columns), values));
553 }
554 }
555
556 Ok(rows)
557}
558
559fn build_buffer_bindings(
560 cursor: &mut impl ResultSetMetadata,
561 max_column_size: usize,
562) -> std::result::Result<Vec<ColumnBinding>, sqlx_core::Error> {
563 collect_columns(cursor).map(|columns| {
564 columns
565 .into_iter()
566 .map(|column| ColumnBinding {
567 buffer_desc: map_buffer_desc(column.type_info().data_type(), max_column_size),
568 column,
569 })
570 .collect()
571 })
572}
573
574fn map_buffer_desc(data_type: DataType, max_column_size: usize) -> BufferDesc {
575 match data_type {
576 DataType::TinyInt | DataType::SmallInt | DataType::Integer | DataType::BigInt => {
577 BufferDesc::I64 { nullable: true }
578 }
579 DataType::Real => BufferDesc::F32 { nullable: true },
580 DataType::Float { .. } | DataType::Double => BufferDesc::F64 { nullable: true },
581 DataType::Bit => BufferDesc::Bit { nullable: true },
582 DataType::Date => BufferDesc::Date { nullable: true },
583 DataType::Time { .. } => BufferDesc::Time { nullable: true },
584 DataType::Timestamp { .. } => BufferDesc::Timestamp { nullable: true },
585 DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => {
586 BufferDesc::Binary {
587 max_bytes: max_column_size,
588 }
589 }
590 DataType::Char { .. }
591 | DataType::WChar { .. }
592 | DataType::Varchar { .. }
593 | DataType::WVarchar { .. }
594 | DataType::LongVarchar { .. }
595 | DataType::WLongVarchar { .. }
596 | DataType::Other { .. }
597 | DataType::Unknown
598 | DataType::Decimal { .. }
599 | DataType::Numeric { .. } => BufferDesc::Text {
600 max_str_len: max_column_size,
601 },
602 }
603}
604
605fn buffered_column_values(
606 slice: AnyColumnBufferSlice<'_>,
607 binding: &ColumnBinding,
608) -> std::result::Result<Vec<OdbcValueKind>, sqlx_core::Error> {
609 let desc = binding.buffer_desc;
610 Ok(match desc {
611 BufferDesc::I8 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
612 OdbcValueKind::TinyInt(value)
613 })?,
614 BufferDesc::I16 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
615 OdbcValueKind::SmallInt(value)
616 })?,
617 BufferDesc::I32 { nullable } => buffered_numeric(&slice, desc, nullable, |value| {
618 OdbcValueKind::Integer(value)
619 })?,
620 BufferDesc::I64 { nullable } => {
621 buffered_numeric(&slice, desc, nullable, OdbcValueKind::BigInt)?
622 }
623 BufferDesc::U8 { nullable } => buffered_numeric(&slice, desc, nullable, |value: u8| {
624 OdbcValueKind::BigInt(i64::from(value))
625 })?,
626 BufferDesc::F32 { nullable } => {
627 buffered_numeric(&slice, desc, nullable, OdbcValueKind::Real)?
628 }
629 BufferDesc::F64 { nullable } => {
630 buffered_numeric(&slice, desc, nullable, OdbcValueKind::Double)?
631 }
632 BufferDesc::Bit { nullable } => {
633 buffered_numeric(&slice, desc, nullable, |value: odbc_api::Bit| {
634 OdbcValueKind::Bit(value.as_bool())
635 })?
636 }
637 BufferDesc::Date { nullable } => {
638 buffered_numeric(&slice, desc, nullable, OdbcValueKind::Date)?
639 }
640 BufferDesc::Time { nullable } => {
641 buffered_numeric(&slice, desc, nullable, OdbcValueKind::Time)?
642 }
643 BufferDesc::Timestamp { nullable } => {
644 buffered_numeric(&slice, desc, nullable, OdbcValueKind::Timestamp)?
645 }
646 BufferDesc::Text { .. } => {
647 let text = expect_buffer_slice(slice.as_text(), desc)?;
648 text.iter()
649 .map(|value| {
650 value
651 .map(|bytes| {
652 OdbcValueKind::Text(String::from_utf8_lossy(bytes).into_owned())
653 })
654 .unwrap_or(OdbcValueKind::Null)
655 })
656 .collect()
657 }
658 BufferDesc::WText { .. } => {
659 let text = expect_buffer_slice(slice.as_wide_text(), desc)?;
660 text.iter()
661 .map(|value| {
662 value
663 .map(|chars| OdbcValueKind::Text(String::from_utf16_lossy(chars.into())))
664 .unwrap_or(OdbcValueKind::Null)
665 })
666 .collect()
667 }
668 BufferDesc::Binary { .. } => {
669 let binary = expect_buffer_slice(slice.as_binary(), desc)?;
670 binary
671 .iter()
672 .map(|value| {
673 value
674 .map(|bytes| OdbcValueKind::Binary(bytes.to_vec()))
675 .unwrap_or(OdbcValueKind::Null)
676 })
677 .collect()
678 }
679 BufferDesc::Numeric => {
680 return Err(sqlx_core::Error::Protocol(format!(
681 "unsupported ODBC buffer descriptor: {desc:?}"
682 )))
683 }
684 })
685}
686
687fn buffered_numeric<T, F>(
688 slice: &AnyColumnBufferSlice<'_>,
689 desc: BufferDesc,
690 nullable: bool,
691 map: F,
692) -> std::result::Result<Vec<OdbcValueKind>, sqlx_core::Error>
693where
694 T: Copy + odbc_api::Pod,
695 F: FnMut(T) -> OdbcValueKind,
696{
697 if nullable {
698 Ok(buffered_nullable_numeric(
699 expect_buffer_slice(slice.as_nullable_slice::<T>(), desc)?,
700 map,
701 ))
702 } else {
703 Ok(expect_buffer_slice(slice.as_slice::<T>(), desc)?
704 .iter()
705 .copied()
706 .map(map)
707 .collect())
708 }
709}
710
711fn buffered_nullable_numeric<T, F>(slice: NullableSlice<'_, T>, mut map: F) -> Vec<OdbcValueKind>
712where
713 T: Copy,
714 F: FnMut(T) -> OdbcValueKind,
715{
716 slice
717 .map(|value| value.copied().map(&mut map).unwrap_or(OdbcValueKind::Null))
718 .collect()
719}
720
721fn expect_buffer_slice<T>(
722 slice: Option<T>,
723 desc: BufferDesc,
724) -> std::result::Result<T, sqlx_core::Error> {
725 slice.ok_or_else(|| {
726 sqlx_core::Error::Protocol(format!(
727 "ODBC column buffer {desc:?} did not match fetched slice"
728 ))
729 })
730}
731
732fn collect_rows_unbuffered<C>(mut cursor: C) -> std::result::Result<Vec<OdbcRow>, sqlx_core::Error>
733where
734 C: Cursor + ResultSetMetadata,
735{
736 let columns: Arc<[OdbcColumn]> = collect_columns(&mut cursor)?.into();
737 let mut rows = Vec::new();
738
739 while let Some(mut cursor_row) = cursor.next_row().map_err(|error| {
740 crate::error::database_error_with_context(
741 error,
742 "ODBC unbuffered fetch failed while reading the next row",
743 )
744 })? {
745 let mut values = Vec::with_capacity(columns.len());
746
747 for column in columns.iter() {
748 let column_number = u16::try_from(sqlx_core::column::Column::ordinal(column) + 1)
749 .map_err(|_| {
750 sqlx_core::Error::Protocol("ODBC column index exceeds u16".to_owned())
751 })?;
752 values.push(fetch_value(&mut cursor_row, column_number, column)?);
753 }
754
755 rows.push(OdbcRow::new_shared(Arc::clone(&columns), values));
756 }
757
758 Ok(rows)
759}
760
761fn fetch_value(
762 row: &mut odbc_api::CursorRow<'_>,
763 column_number: u16,
764 column: &OdbcColumn,
765) -> std::result::Result<OdbcValue, sqlx_core::Error> {
766 let data_type = column.type_info().data_type();
767
768 let kind = match data_type {
769 DataType::Bit => {
770 let mut value = Nullable::<odbc_api::Bit>::null();
771 row.get_data(column_number, &mut value).map_err(|error| {
772 crate::error::database_error_with_context_lazy(error, || {
773 fetch_context(column, data_type)
774 })
775 })?;
776 value
777 .into_opt()
778 .map(|value| OdbcValueKind::Bit(value.as_bool()))
779 .unwrap_or(OdbcValueKind::Null)
780 }
781 DataType::TinyInt => fetch_nullable(
782 row,
783 column_number,
784 column,
785 data_type,
786 OdbcValueKind::TinyInt,
787 )?,
788 DataType::SmallInt => fetch_nullable(
789 row,
790 column_number,
791 column,
792 data_type,
793 OdbcValueKind::SmallInt,
794 )?,
795 DataType::Integer => fetch_nullable(
796 row,
797 column_number,
798 column,
799 data_type,
800 OdbcValueKind::Integer,
801 )?,
802 DataType::BigInt => {
803 fetch_nullable(row, column_number, column, data_type, OdbcValueKind::BigInt)?
804 }
805 DataType::Real => {
806 fetch_nullable(row, column_number, column, data_type, OdbcValueKind::Real)?
807 }
808 DataType::Float { .. } | DataType::Double => {
809 fetch_nullable(row, column_number, column, data_type, OdbcValueKind::Double)?
810 }
811 DataType::Date => {
812 fetch_nullable(row, column_number, column, data_type, OdbcValueKind::Date)?
813 }
814 DataType::Time { .. } => {
815 fetch_nullable(row, column_number, column, data_type, OdbcValueKind::Time)?
816 }
817 DataType::Timestamp { .. } => fetch_nullable(
818 row,
819 column_number,
820 column,
821 data_type,
822 OdbcValueKind::Timestamp,
823 )?,
824 DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => {
825 let mut value = Vec::new();
826 if row.get_binary(column_number, &mut value).map_err(|error| {
827 crate::error::database_error_with_context_lazy(error, || {
828 fetch_context(column, data_type)
829 })
830 })? {
831 OdbcValueKind::Binary(value)
832 } else {
833 OdbcValueKind::Null
834 }
835 }
836 _ => {
837 let mut value = Vec::new();
838 if row
839 .get_wide_text(column_number, &mut value)
840 .map_err(|error| {
841 crate::error::database_error_with_context_lazy(error, || {
842 fetch_context(column, data_type)
843 })
844 })?
845 {
846 OdbcValueKind::Text(String::from_utf16_lossy(&value))
847 } else {
848 OdbcValueKind::Null
849 }
850 }
851 };
852
853 Ok(OdbcValue::new(kind))
854}
855
856fn fetch_nullable<T, F>(
857 row: &mut odbc_api::CursorRow<'_>,
858 column_number: u16,
859 column: &OdbcColumn,
860 data_type: DataType,
861 map: F,
862) -> std::result::Result<OdbcValueKind, sqlx_core::Error>
863where
864 T: Default + Copy + odbc_api::parameter::CElement + odbc_api::handles::CDataMut,
865 Nullable<T>: odbc_api::parameter::CElement + odbc_api::handles::CDataMut,
866 F: FnOnce(T) -> OdbcValueKind,
867{
868 let mut value = Nullable::<T>::null();
869 row.get_data(column_number, &mut value).map_err(|error| {
870 crate::error::database_error_with_context_lazy(error, || fetch_context(column, data_type))
871 })?;
872 Ok(value.into_opt().map(map).unwrap_or(OdbcValueKind::Null))
873}
874
875fn fetch_context(column: &OdbcColumn, data_type: DataType) -> String {
876 format!(
877 "failed to fetch ODBC column {} (`{}`) as {data_type:?}",
878 column.ordinal() + 1,
879 column.name()
880 )
881}
882
/// Renders `sql` as a compact, single-line snippet for error messages.
///
/// Every whitespace run (spaces, tabs, newlines) collapses to a single space. A result
/// longer than `MAX_CHARS` characters is cut to `MAX_CHARS - 3` characters followed by
/// `...`, so the preview never exceeds `MAX_CHARS` characters.
///
/// Lengths are counted in `char`s, not bytes. A byte-length check would append `...` to
/// short non-ASCII statements (e.g. 100 two-byte characters) even though nothing was cut.
fn sql_preview(sql: &str) -> String {
    const MAX_CHARS: usize = 160;
    const ELLIPSIS: &str = "...";

    let compact = sql.split_whitespace().collect::<Vec<_>>().join(" ");
    if compact.chars().count() <= MAX_CHARS {
        return compact;
    }

    let mut preview: String = compact.chars().take(MAX_CHARS - ELLIPSIS.len()).collect();
    preview.push_str(ELLIPSIS);
    preview
}
895
#[cfg(test)]
mod tests {
    use super::*;

    /// Ping query expected for Db2-family servers.
    const DB2_PING: &str = "SELECT 1 FROM SYSIBM.SYSDUMMY1";
    /// Ping query expected for every other server.
    const DEFAULT_PING: &str = "SELECT 1";

    #[test]
    fn buffered_fetch_maps_numeric_types_to_nullable_64_bit_buffers() {
        for data_type in [DataType::TinyInt, DataType::Integer, DataType::BigInt] {
            assert!(
                matches!(
                    map_buffer_desc(data_type, 64),
                    BufferDesc::I64 { nullable: true }
                ),
                "{data_type:?} should map to a nullable i64 buffer"
            );
        }
    }

    #[test]
    fn buffered_fetch_uses_configured_limits_for_variable_sized_data() {
        let text_desc = map_buffer_desc(DataType::Varchar { length: None }, 32);
        let binary_desc = map_buffer_desc(DataType::Varbinary { length: None }, 16);

        assert_eq!(text_desc, BufferDesc::Text { max_str_len: 32 });
        assert_eq!(binary_desc, BufferDesc::Binary { max_bytes: 16 });
    }

    #[test]
    fn ping_query_uses_db2_dummy_table_for_db2_drivers() {
        let db2_names = [
            "DB2",
            "DB2 UDB for AS/400",
            "IBM DB2 for i",
            "iSeries",
            "IBM i",
        ];
        for dbms_name in db2_names {
            assert_eq!(
                DB2_PING,
                ping_query_for_dbms_name(dbms_name),
                "dbms name: {dbms_name}"
            );
        }
    }

    #[test]
    fn ping_query_keeps_select_one_for_non_db2_drivers() {
        for dbms_name in ["DuckDB", "Microsoft SQL Server", "PostgreSQL"] {
            assert_eq!(
                DEFAULT_PING,
                ping_query_for_dbms_name(dbms_name),
                "dbms name: {dbms_name}"
            );
        }
    }
}