1use std::{convert, sync::Arc};
2
3use super::{Error, Result, Statement};
4use crate::types::{self, EnumType, FromSql, FromSqlError, ListType, ValueRef};
5
6use arrow::{
7 array::{self, Array, ArrayRef, DictionaryArray, FixedSizeListArray, ListArray, MapArray, StructArray},
8 datatypes::*,
9};
10use fallible_iterator::FallibleIterator;
11use fallible_streaming_iterator::FallibleStreamingIterator;
12use rust_decimal::prelude::*;
13
14#[must_use = "Rows is lazy and will do nothing unless consumed"]
16pub struct Rows<'stmt> {
17 pub(crate) stmt: Option<&'stmt Statement<'stmt>>,
18 arr: Arc<Option<StructArray>>,
19 row: Option<Row<'stmt>>,
20 current_row: usize,
21 current_batch_row: usize,
22}
23
24impl<'stmt> Rows<'stmt> {
25 #[inline]
26 fn reset(&mut self) {
27 self.current_row = 0;
28 self.current_batch_row = 0;
29 self.arr = Arc::new(None);
30 }
31
32 #[allow(clippy::should_implement_trait)] #[inline]
46 pub fn next(&mut self) -> Result<Option<&Row<'stmt>>> {
47 self.advance()?;
48 Ok((*self).get())
49 }
50
51 #[inline]
52 fn batch_row_count(&self) -> usize {
53 if self.arr.is_none() {
54 return 0;
55 }
56 self.arr.as_ref().as_ref().unwrap().len()
57 }
58
59 #[inline]
71 pub fn map<F, B>(self, f: F) -> Map<'stmt, F>
72 where
73 F: FnMut(&Row<'_>) -> Result<B>,
74 {
75 Map { rows: self, f }
76 }
77
78 #[inline]
81 pub fn mapped<F, B>(self, f: F) -> MappedRows<'stmt, F>
82 where
83 F: FnMut(&Row<'_>) -> Result<B>,
84 {
85 MappedRows { rows: self, map: f }
86 }
87
88 #[inline]
92 pub fn and_then<F, T, E>(self, f: F) -> AndThenRows<'stmt, F>
93 where
94 F: FnMut(&Row<'_>) -> Result<T, E>,
95 {
96 AndThenRows { rows: self, map: f }
97 }
98
99 pub fn as_ref(&self) -> Option<&Statement<'stmt>> {
101 self.stmt
102 }
103}
104
105impl<'stmt> Rows<'stmt> {
106 #[inline]
107 pub(crate) fn new(stmt: &'stmt Statement<'stmt>) -> Rows<'stmt> {
108 Rows {
109 stmt: Some(stmt),
110 arr: Arc::new(None),
111 row: None,
112 current_row: 0,
113 current_batch_row: 0,
114 }
115 }
116
117 #[inline]
118 pub(crate) fn get_expected_row(&mut self) -> Result<&Row<'stmt>> {
119 match self.next()? {
120 Some(row) => Ok(row),
121 None => Err(Error::QueryReturnedNoRows),
122 }
123 }
124}
125
126#[must_use = "iterators are lazy and do nothing unless consumed"]
129pub struct Map<'stmt, F> {
130 rows: Rows<'stmt>,
131 f: F,
132}
133
134impl<F, B> FallibleIterator for Map<'_, F>
135where
136 F: FnMut(&Row<'_>) -> Result<B>,
137{
138 type Error = Error;
139 type Item = B;
140
141 #[inline]
142 fn next(&mut self) -> Result<Option<B>> {
143 match self.rows.next()? {
144 Some(v) => Ok(Some((self.f)(v)?)),
145 None => Ok(None),
146 }
147 }
148}
149
150#[must_use = "iterators are lazy and do nothing unless consumed"]
155pub struct MappedRows<'stmt, F> {
156 rows: Rows<'stmt>,
157 map: F,
158}
159
160impl<T, F> Iterator for MappedRows<'_, F>
161where
162 F: FnMut(&Row<'_>) -> Result<T>,
163{
164 type Item = Result<T>;
165
166 #[inline]
167 fn next(&mut self) -> Option<Result<T>> {
168 let map = &mut self.map;
169 self.rows.next().transpose().map(|row_result| row_result.and_then(map))
170 }
171}
172
173#[must_use = "iterators are lazy and do nothing unless consumed"]
176pub struct AndThenRows<'stmt, F> {
177 rows: Rows<'stmt>,
178 map: F,
179}
180
181impl<T, E, F> Iterator for AndThenRows<'_, F>
182where
183 E: convert::From<Error>,
184 F: FnMut(&Row<'_>) -> Result<T, E>,
185{
186 type Item = Result<T, E>;
187
188 #[inline]
189 fn next(&mut self) -> Option<Self::Item> {
190 let map = &mut self.map;
191 self.rows
192 .next()
193 .transpose()
194 .map(|row_result| row_result.map_err(E::from).and_then(map))
195 }
196}
197
198impl<'stmt> FallibleStreamingIterator for Rows<'stmt> {
217 type Error = Error;
218 type Item = Row<'stmt>;
219
220 #[inline]
221 fn advance(&mut self) -> Result<()> {
222 match self.stmt {
223 Some(stmt) => {
224 if self.current_row < stmt.row_count() {
225 if self.current_batch_row >= self.batch_row_count() {
226 self.arr = Arc::new(stmt.step());
227 if self.arr.is_none() {
228 self.row = None;
229 return Ok(());
230 }
231 self.current_batch_row = 0;
232 }
233 self.row = Some(Row {
234 stmt,
235 arr: self.arr.clone(),
236 current_row: self.current_batch_row,
237 });
238 self.current_row += 1;
239 self.current_batch_row += 1;
240 Ok(())
241 } else {
242 self.reset();
243 self.row = None;
244 Ok(())
245 }
246 }
247 None => {
248 self.row = None;
249 Ok(())
250 }
251 }
252 }
253
254 #[inline]
255 fn get(&self) -> Option<&Row<'stmt>> {
256 self.row.as_ref()
257 }
258}
259
260pub struct Row<'stmt> {
262 pub(crate) stmt: &'stmt Statement<'stmt>,
263 arr: Arc<Option<StructArray>>,
264 current_row: usize,
265}
266
267#[allow(clippy::needless_lifetimes)]
268impl<'stmt> Row<'stmt> {
269 pub fn get_unwrap<I: RowIndex, T: FromSql>(&self, idx: I) -> T {
282 self.get(idx).unwrap()
283 }
284
285 pub fn get<I: RowIndex, T: FromSql>(&self, idx: I) -> Result<T> {
302 let idx = idx.idx(self.stmt)?;
303 let value = self.value_ref(self.current_row, idx);
304 FromSql::column_result(value).map_err(|err| match err {
305 FromSqlError::InvalidType => {
306 Error::InvalidColumnType(idx, self.stmt.column_name_unwrap(idx).into(), value.data_type())
307 }
308 FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
309 FromSqlError::Other(err) => Error::FromSqlConversionFailure(idx, value.data_type(), err),
310 #[cfg(feature = "uuid")]
311 FromSqlError::InvalidUuidSize(_) => {
312 Error::InvalidColumnType(idx, self.stmt.column_name_unwrap(idx).into(), value.data_type())
313 }
314 })
315 }
316
317 pub fn get_ref<I: RowIndex>(&self, idx: I) -> Result<ValueRef<'_>> {
333 let idx = idx.idx(self.stmt)?;
334 let val_ref = self.value_ref(self.current_row, idx);
338 Ok(val_ref)
339 }
340
341 fn value_ref(&self, row: usize, col: usize) -> ValueRef<'_> {
342 let column = self.arr.as_ref().as_ref().unwrap().column(col);
343 Self::value_ref_internal(row, col, column)
344 }
345
346 pub(crate) fn value_ref_internal(row: usize, col: usize, column: &ArrayRef) -> ValueRef {
347 if column.is_null(row) {
348 return ValueRef::Null;
349 }
350 match column.data_type() {
353 DataType::Utf8 => {
354 let array = column.as_any().downcast_ref::<array::StringArray>().unwrap();
355
356 if array.is_null(row) {
357 return ValueRef::Null;
358 }
359 ValueRef::from(array.value(row))
360 }
361 DataType::LargeUtf8 => {
362 let array = column.as_any().downcast_ref::<array::LargeStringArray>().unwrap();
363
364 if array.is_null(row) {
365 return ValueRef::Null;
366 }
367 ValueRef::from(array.value(row))
368 }
369 DataType::Binary => {
370 let array = column.as_any().downcast_ref::<array::BinaryArray>().unwrap();
371
372 if array.is_null(row) {
373 return ValueRef::Null;
374 }
375 ValueRef::Blob(array.value(row))
376 }
377 DataType::LargeBinary => {
378 let array = column.as_any().downcast_ref::<array::LargeBinaryArray>().unwrap();
379
380 if array.is_null(row) {
381 return ValueRef::Null;
382 }
383 ValueRef::Blob(array.value(row))
384 }
385 DataType::Boolean => {
386 let array = column.as_any().downcast_ref::<array::BooleanArray>().unwrap();
387
388 if array.is_null(row) {
389 return ValueRef::Null;
390 }
391 ValueRef::Boolean(array.value(row))
392 }
393 DataType::Int8 => {
394 let array = column.as_any().downcast_ref::<array::Int8Array>().unwrap();
395
396 if array.is_null(row) {
397 return ValueRef::Null;
398 }
399 ValueRef::TinyInt(array.value(row))
400 }
401 DataType::Int16 => {
402 let array = column.as_any().downcast_ref::<array::Int16Array>().unwrap();
403
404 if array.is_null(row) {
405 return ValueRef::Null;
406 }
407 ValueRef::SmallInt(array.value(row))
408 }
409 DataType::Int32 => {
410 let array = column.as_any().downcast_ref::<array::Int32Array>().unwrap();
411
412 if array.is_null(row) {
413 return ValueRef::Null;
414 }
415 ValueRef::Int(array.value(row))
416 }
417 DataType::Int64 => {
418 let array = column.as_any().downcast_ref::<array::Int64Array>().unwrap();
419
420 if array.is_null(row) {
421 return ValueRef::Null;
422 }
423 ValueRef::BigInt(array.value(row))
424 }
425 DataType::UInt8 => {
426 let array = column.as_any().downcast_ref::<array::UInt8Array>().unwrap();
427
428 if array.is_null(row) {
429 return ValueRef::Null;
430 }
431 ValueRef::UTinyInt(array.value(row))
432 }
433 DataType::UInt16 => {
434 let array = column.as_any().downcast_ref::<array::UInt16Array>().unwrap();
435
436 if array.is_null(row) {
437 return ValueRef::Null;
438 }
439 ValueRef::USmallInt(array.value(row))
440 }
441 DataType::UInt32 => {
442 let array = column.as_any().downcast_ref::<array::UInt32Array>().unwrap();
443
444 if array.is_null(row) {
445 return ValueRef::Null;
446 }
447 ValueRef::UInt(array.value(row))
448 }
449 DataType::UInt64 => {
450 let array = column.as_any().downcast_ref::<array::UInt64Array>().unwrap();
451
452 if array.is_null(row) {
453 return ValueRef::Null;
454 }
455 ValueRef::UBigInt(array.value(row))
456 }
457 DataType::Float16 => {
458 let array = column.as_any().downcast_ref::<array::Float32Array>().unwrap();
459
460 if array.is_null(row) {
461 return ValueRef::Null;
462 }
463 ValueRef::Float(array.value(row))
464 }
465 DataType::Float32 => {
466 let array = column.as_any().downcast_ref::<array::Float32Array>().unwrap();
467
468 if array.is_null(row) {
469 return ValueRef::Null;
470 }
471 ValueRef::Float(array.value(row))
472 }
473 DataType::Float64 => {
474 let array = column.as_any().downcast_ref::<array::Float64Array>().unwrap();
475
476 if array.is_null(row) {
477 return ValueRef::Null;
478 }
479 ValueRef::Double(array.value(row))
480 }
481 DataType::Decimal128(..) => {
482 let array = column.as_any().downcast_ref::<array::Decimal128Array>().unwrap();
483
484 if array.is_null(row) {
485 return ValueRef::Null;
486 }
487 if array.scale() == 0 {
489 return ValueRef::HugeInt(array.value(row));
490 }
491 ValueRef::Decimal(Decimal::from_i128_with_scale(array.value(row), array.scale() as u32))
492 }
493 DataType::Timestamp(unit, _) if *unit == TimeUnit::Second => {
494 let array = column.as_any().downcast_ref::<array::TimestampSecondArray>().unwrap();
495
496 if array.is_null(row) {
497 return ValueRef::Null;
498 }
499 ValueRef::Timestamp(types::TimeUnit::Second, array.value(row))
500 }
501 DataType::Timestamp(unit, _) if *unit == TimeUnit::Millisecond => {
502 let array = column
503 .as_any()
504 .downcast_ref::<array::TimestampMillisecondArray>()
505 .unwrap();
506
507 if array.is_null(row) {
508 return ValueRef::Null;
509 }
510 ValueRef::Timestamp(types::TimeUnit::Millisecond, array.value(row))
511 }
512 DataType::Timestamp(unit, _) if *unit == TimeUnit::Microsecond => {
513 let array = column
514 .as_any()
515 .downcast_ref::<array::TimestampMicrosecondArray>()
516 .unwrap();
517
518 if array.is_null(row) {
519 return ValueRef::Null;
520 }
521 ValueRef::Timestamp(types::TimeUnit::Microsecond, array.value(row))
522 }
523 DataType::Timestamp(unit, _) if *unit == TimeUnit::Nanosecond => {
524 let array = column
525 .as_any()
526 .downcast_ref::<array::TimestampNanosecondArray>()
527 .unwrap();
528
529 if array.is_null(row) {
530 return ValueRef::Null;
531 }
532 ValueRef::Timestamp(types::TimeUnit::Nanosecond, array.value(row))
533 }
534 DataType::Date32 => {
535 let array = column.as_any().downcast_ref::<array::Date32Array>().unwrap();
536
537 if array.is_null(row) {
538 return ValueRef::Null;
539 }
540 ValueRef::Date32(array.value(row))
541 }
542 DataType::Time64(TimeUnit::Microsecond) => {
543 let array = column.as_any().downcast_ref::<array::Time64MicrosecondArray>().unwrap();
544
545 if array.is_null(row) {
546 return ValueRef::Null;
547 }
548 ValueRef::Time64(types::TimeUnit::Microsecond, array.value(row))
549 }
550 DataType::Interval(unit) => match unit {
551 IntervalUnit::MonthDayNano => {
552 let array = column
553 .as_any()
554 .downcast_ref::<array::IntervalMonthDayNanoArray>()
555 .unwrap();
556
557 if array.is_null(row) {
558 return ValueRef::Null;
559 }
560
561 let value = array.value(row);
562
563 ValueRef::Interval {
564 months: value.months,
565 days: value.days,
566 nanos: value.nanoseconds,
567 }
568 }
569 _ => unimplemented!("{:?}", unit),
570 },
571 DataType::LargeList(..) => {
584 let arr = column.as_any().downcast_ref::<array::LargeListArray>().unwrap();
585
586 ValueRef::List(ListType::Large(arr), row)
587 }
588 DataType::List(..) => {
589 let arr = column.as_any().downcast_ref::<ListArray>().unwrap();
590
591 ValueRef::List(ListType::Regular(arr), row)
592 }
593 DataType::Dictionary(key_type, ..) => {
594 let column = column.as_any();
595 ValueRef::Enum(
596 match key_type.as_ref() {
597 DataType::UInt8 => {
598 EnumType::UInt8(column.downcast_ref::<DictionaryArray<UInt8Type>>().unwrap())
599 }
600 DataType::UInt16 => {
601 EnumType::UInt16(column.downcast_ref::<DictionaryArray<UInt16Type>>().unwrap())
602 }
603 DataType::UInt32 => {
604 EnumType::UInt32(column.downcast_ref::<DictionaryArray<UInt32Type>>().unwrap())
605 }
606 typ => panic!("Unsupported key type: {typ:?}"),
607 },
608 row,
609 )
610 }
611 DataType::Struct(_) => {
612 let res = column.as_any().downcast_ref::<StructArray>().unwrap();
613 ValueRef::Struct(res, row)
614 }
615 DataType::Map(..) => {
616 let arr = column.as_any().downcast_ref::<MapArray>().unwrap();
617 ValueRef::Map(arr, row)
618 }
619 DataType::FixedSizeList(..) => {
620 let arr = column.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
621 ValueRef::Array(arr, row)
622 }
623 DataType::Union(..) => ValueRef::Union(column, row),
624 _ => unreachable!("invalid value: {}, {}", col, column.data_type()),
625 }
626 }
627
628 pub fn get_ref_unwrap<I: RowIndex>(&self, idx: I) -> ValueRef<'_> {
644 self.get_ref(idx).unwrap()
645 }
646}
647
648impl<'stmt> AsRef<Statement<'stmt>> for Row<'stmt> {
649 fn as_ref(&self) -> &Statement<'stmt> {
650 self.stmt
651 }
652}
653
654mod sealed {
655 pub trait Sealed {}
658 impl Sealed for usize {}
659 impl Sealed for &str {}
660}
661
662pub trait RowIndex: sealed::Sealed {
666 fn idx(&self, stmt: &Statement<'_>) -> Result<usize>;
669}
670
671impl RowIndex for usize {
672 #[inline]
673 fn idx(&self, stmt: &Statement<'_>) -> Result<usize> {
674 if *self >= stmt.column_count() {
675 Err(Error::InvalidColumnIndex(*self))
676 } else {
677 Ok(*self)
678 }
679 }
680}
681
682impl RowIndex for &'_ str {
683 #[inline]
684 fn idx(&self, stmt: &Statement<'_>) -> Result<usize> {
685 stmt.column_index(self)
686 }
687}
688
689macro_rules! tuple_try_from_row {
690 ($($field:ident),*) => {
691 impl<'a, $($field,)*> convert::TryFrom<&'a Row<'a>> for ($($field,)*) where $($field: FromSql,)* {
692 type Error = crate::Error;
693
694 #[allow(unused_assignments, unused_variables, unused_mut)]
697 fn try_from(row: &'a Row<'a>) -> Result<Self> {
698 let mut index = 0;
699 $(
700 #[allow(non_snake_case)]
701 let $field = row.get::<_, $field>(index)?;
702 index += 1;
703 )*
704 Ok(($($field,)*))
705 }
706 }
707 }
708}
709
710macro_rules! tuples_try_from_row {
711 () => {
712 tuple_try_from_row!();
714 };
715 ($first:ident $(, $remaining:ident)*) => {
716 tuple_try_from_row!($first $(, $remaining)*);
717 tuples_try_from_row!($($remaining),*);
718 };
719}
720
721tuples_try_from_row!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P);
722
#[cfg(test)]
mod tests {
    // The explicit `|row| T::try_from(row)` closures are kept on purpose.
    // NOTE(review): presumably passing `try_from` directly trips lifetime
    // inference — confirm before removing this allow.
    #![allow(clippy::redundant_closure)]
    use crate::{Connection, Result};
    use std::convert::TryFrom;

    #[test]
    fn test_try_from_row_for_tuple_1() -> Result<()> {
        use crate::ToSql;

        let conn = Connection::open_in_memory()?;
        let no_params = std::iter::empty::<&dyn ToSql>();
        conn.execute("CREATE TABLE test (a INTEGER)", crate::params_from_iter(no_params))?;
        conn.execute("INSERT INTO test VALUES (42)", [])?;

        let single = conn.query_row("SELECT a FROM test", [], |row| <(u32,)>::try_from(row))?;
        assert_eq!(single, (42,));

        // Asking for more columns than the query returns must fail.
        let too_wide = conn.query_row("SELECT a FROM test", [], |row| <(u32, u32)>::try_from(row));
        assert!(too_wide.is_err());
        Ok(())
    }

    #[test]
    fn test_try_from_row_for_tuple_2() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.execute("CREATE TABLE test (a INTEGER, b INTEGER)", [])?;
        conn.execute("INSERT INTO test VALUES (42, 47)", [])?;

        let pair = conn.query_row("SELECT a, b FROM test", [], |row| <(u32, u32)>::try_from(row))?;
        assert_eq!(pair, (42, 47));

        let too_wide = conn.query_row("SELECT a, b FROM test", [], |row| <(u32, u32, u32)>::try_from(row));
        assert!(too_wide.is_err());
        Ok(())
    }

    #[test]
    fn test_try_from_row_for_tuple_16() -> Result<()> {
        type BigTuple = (
            u32, u32, u32, u32, u32, u32, u32, u32, u32, u32, u32, u32, u32, u32, u32, u32,
        );

        let create_table = "CREATE TABLE test (
            a INTEGER,
            b INTEGER,
            c INTEGER,
            d INTEGER,
            e INTEGER,
            f INTEGER,
            g INTEGER,
            h INTEGER,
            i INTEGER,
            j INTEGER,
            k INTEGER,
            l INTEGER,
            m INTEGER,
            n INTEGER,
            o INTEGER,
            p INTEGER
        )";

        let insert_values = "INSERT INTO test VALUES (
            0,
            1,
            2,
            3,
            4,
            5,
            6,
            7,
            8,
            9,
            10,
            11,
            12,
            13,
            14,
            15
        )";

        let conn = Connection::open_in_memory()?;
        conn.execute(create_table, [])?;
        conn.execute(insert_values, [])?;
        let val = conn.query_row("SELECT * FROM test", [], |row| BigTuple::try_from(row))?;

        // std only implements `PartialEq` for tuples of up to 12 elements, so
        // compare the 16 fields as a sequence instead.
        let columns = [
            val.0, val.1, val.2, val.3, val.4, val.5, val.6, val.7, val.8, val.9, val.10, val.11, val.12, val.13,
            val.14, val.15,
        ];
        let expected: Vec<u32> = (0..16).collect();
        assert_eq!(columns.to_vec(), expected);

        Ok(())
    }
}