1use std::{convert, ffi::c_void, fmt, mem, os::raw::c_char, ptr, str};
2
3use arrow::{array::StructArray, datatypes::SchemaRef};
4
5use super::{ffi, AndThenRows, Connection, Error, MappedRows, Params, RawStatement, Result, Row, Rows, ValueRef};
6#[cfg(feature = "polars")]
7use crate::{arrow2, polars_dataframe::Polars};
8use crate::{
9 arrow_batch::{Arrow, ArrowStream},
10 error::result_from_duckdb_prepare,
11 types::{TimeUnit, ToSql, ToSqlOutput},
12};
13
14pub struct Statement<'conn> {
16 conn: &'conn Connection,
17 pub(crate) stmt: RawStatement,
18}
19
20impl Statement<'_> {
21 #[inline]
63 pub fn execute<P: Params>(&mut self, params: P) -> Result<usize> {
64 params.__bind_in(self)?;
65 self.execute_with_bound_parameters()
66 }
67
68 #[inline]
82 pub fn insert<P: Params>(&mut self, params: P) -> Result<()> {
83 let changes = self.execute(params)?;
84 match changes {
85 1 => Ok(()),
86 _ => Err(Error::StatementChangedRows(changes)),
87 }
88 }
89
90 #[inline]
107 pub fn query_arrow<P: Params>(&mut self, params: P) -> Result<Arrow<'_>> {
108 self.execute(params)?;
109 Ok(Arrow::new(self))
110 }
111
112 #[inline]
130 pub fn stream_arrow<P: Params>(&mut self, params: P, schema: SchemaRef) -> Result<ArrowStream<'_>> {
131 params.__bind_in(self)?;
132 self.stmt.execute_streaming()?;
133 Ok(ArrowStream::new(self, schema))
134 }
135
136 #[cfg(feature = "polars")]
174 #[inline]
175 pub fn query_polars<P: Params>(&mut self, params: P) -> Result<Polars<'_>> {
176 self.execute(params)?;
177 Ok(Polars::new(self))
178 }
179
180 #[inline]
239 pub fn query<P: Params>(&mut self, params: P) -> Result<Rows<'_>> {
240 self.execute(params)?;
241 Ok(Rows::new(self))
242 }
243
244 pub fn query_map<T, P, F>(&mut self, params: P, f: F) -> Result<MappedRows<'_, F>>
275 where
276 P: Params,
277 F: FnMut(&Row<'_>) -> Result<T>,
278 {
279 self.query(params).map(|rows| rows.mapped(f))
280 }
281
282 #[inline]
311 pub fn query_and_then<T, E, P, F>(&mut self, params: P, f: F) -> Result<AndThenRows<'_, F>>
312 where
313 P: Params,
314 E: convert::From<Error>,
315 F: FnMut(&Row<'_>) -> Result<T, E>,
316 {
317 self.query(params).map(|rows| rows.and_then(f))
318 }
319
320 #[inline]
323 pub fn exists<P: Params>(&mut self, params: P) -> Result<bool> {
324 let mut rows = self.query(params)?;
325 let exists = rows.next()?.is_some();
326 Ok(exists)
327 }
328
329 pub fn query_row<T, P, F>(&mut self, params: P, f: F) -> Result<T>
345 where
346 P: Params,
347 F: FnOnce(&Row<'_>) -> Result<T>,
348 {
349 self.query(params)?.get_expected_row().and_then(f)
350 }
351
352 #[inline]
354 pub fn row_count(&self) -> usize {
355 self.stmt.row_count()
356 }
357
358 #[inline]
360 pub fn step(&self) -> Option<StructArray> {
361 self.stmt.step()
362 }
363
364 #[inline]
366 pub fn stream_step(&self, schema: SchemaRef) -> Option<StructArray> {
367 self.stmt.streaming_step(schema)
368 }
369
370 #[cfg(feature = "polars")]
371 #[inline]
373 pub fn step2(&self) -> Option<arrow2::array::StructArray> {
374 self.stmt.step2()
375 }
376
377 #[inline]
378 pub(crate) fn bind_parameters<P>(&mut self, params: P) -> Result<()>
379 where
380 P: IntoIterator,
381 P::Item: ToSql,
382 {
383 let expected = self.stmt.bind_parameter_count();
384 let mut index = 0;
385 for p in params.into_iter() {
386 index += 1; if index > expected {
388 break;
389 }
390 self.bind_parameter(&p, index)?;
391 }
392 if index != expected {
393 Err(Error::InvalidParameterCount(index, expected))
394 } else {
395 Ok(())
396 }
397 }
398
399 #[inline]
401 pub fn parameter_count(&self) -> usize {
402 self.stmt.bind_parameter_count()
403 }
404
405 #[inline]
443 pub fn raw_bind_parameter<T: ToSql>(&mut self, one_based_col_index: usize, param: T) -> Result<()> {
444 self.bind_parameter(¶m, one_based_col_index)
447 }
448
449 #[inline]
464 pub fn raw_execute(&mut self) -> Result<usize> {
465 self.execute_with_bound_parameters()
466 }
467
468 #[inline]
481 pub fn raw_query(&self) -> Rows<'_> {
482 Rows::new(self)
483 }
484
485 #[inline]
490 pub fn schema(&self) -> SchemaRef {
491 self.stmt.schema()
492 }
493
494 fn bind_parameter<P: ?Sized + ToSql>(&self, param: &P, col: usize) -> Result<()> {
496 let value = param.to_sql()?;
497
498 let ptr = unsafe { self.stmt.ptr() };
499 let value = match value {
500 ToSqlOutput::Borrowed(v) => v,
501 ToSqlOutput::Owned(ref v) => ValueRef::from(v),
502 };
503 let rc = match value {
505 ValueRef::Null => unsafe { ffi::duckdb_bind_null(ptr, col as u64) },
506 ValueRef::Boolean(i) => unsafe { ffi::duckdb_bind_boolean(ptr, col as u64, i) },
507 ValueRef::TinyInt(i) => unsafe { ffi::duckdb_bind_int8(ptr, col as u64, i) },
508 ValueRef::SmallInt(i) => unsafe { ffi::duckdb_bind_int16(ptr, col as u64, i) },
509 ValueRef::Int(i) => unsafe { ffi::duckdb_bind_int32(ptr, col as u64, i) },
510 ValueRef::BigInt(i) => unsafe { ffi::duckdb_bind_int64(ptr, col as u64, i) },
511 ValueRef::HugeInt(i) => unsafe {
512 let hi = ffi::duckdb_hugeint {
513 lower: i as u64,
514 upper: (i >> 64) as i64,
515 };
516 ffi::duckdb_bind_hugeint(ptr, col as u64, hi)
517 },
518 ValueRef::UTinyInt(i) => unsafe { ffi::duckdb_bind_uint8(ptr, col as u64, i) },
519 ValueRef::USmallInt(i) => unsafe { ffi::duckdb_bind_uint16(ptr, col as u64, i) },
520 ValueRef::UInt(i) => unsafe { ffi::duckdb_bind_uint32(ptr, col as u64, i) },
521 ValueRef::UBigInt(i) => unsafe { ffi::duckdb_bind_uint64(ptr, col as u64, i) },
522 ValueRef::Float(r) => unsafe { ffi::duckdb_bind_float(ptr, col as u64, r) },
523 ValueRef::Double(r) => unsafe { ffi::duckdb_bind_double(ptr, col as u64, r) },
524 ValueRef::Text(s) => unsafe {
525 ffi::duckdb_bind_varchar_length(ptr, col as u64, s.as_ptr() as *const c_char, s.len() as u64)
526 },
527 ValueRef::Blob(b) => unsafe {
528 ffi::duckdb_bind_blob(ptr, col as u64, b.as_ptr() as *const c_void, b.len() as u64)
529 },
530 ValueRef::Timestamp(u, i) => unsafe {
531 let micros = match u {
532 TimeUnit::Second => i * 1_000_000,
533 TimeUnit::Millisecond => i * 1_000,
534 TimeUnit::Microsecond => i,
535 TimeUnit::Nanosecond => i / 1_000,
536 };
537 ffi::duckdb_bind_timestamp(ptr, col as u64, ffi::duckdb_timestamp { micros })
538 },
539 ValueRef::Interval { months, days, nanos } => unsafe {
540 let micros = nanos / 1_000;
541 ffi::duckdb_bind_interval(ptr, col as u64, ffi::duckdb_interval { months, days, micros })
542 },
543 _ => unreachable!("not supported: {}", value.data_type()),
544 };
545 result_from_duckdb_prepare(rc, ptr)
546 }
547
548 #[inline]
549 fn execute_with_bound_parameters(&mut self) -> Result<usize> {
550 self.stmt.execute()
551 }
552
553 #[inline]
557 pub(crate) unsafe fn into_raw(mut self) -> RawStatement {
558 let mut stmt = RawStatement::new(ptr::null_mut());
559 mem::swap(&mut stmt, &mut self.stmt);
560 stmt
561 }
562}
563
564impl fmt::Debug for Statement<'_> {
565 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
566 let sql = if self.stmt.is_null() {
567 Ok("")
568 } else {
569 str::from_utf8(self.stmt.sql().unwrap().to_bytes())
570 };
571 f.debug_struct("Statement")
572 .field("conn", self.conn)
573 .field("stmt", &self.stmt)
574 .field("sql", &sql)
575 .finish()
576 }
577}
578
579impl Statement<'_> {
580 #[inline]
581 pub(super) fn new(conn: &Connection, stmt: RawStatement) -> Statement<'_> {
582 Statement { conn, stmt }
583 }
584}
585
#[cfg(test)]
mod test {
    use crate::{params_from_iter, types::ToSql, Connection, Error, Result};

    /// `Connection::execute` reports one changed row per single-row insert,
    /// and bound parameters filter later aggregate queries.
    #[test]
    fn test_execute() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch("CREATE TABLE foo(x INTEGER)")?;

        assert_eq!(db.execute("INSERT INTO foo(x) VALUES (?)", [&2i32])?, 1);
        assert_eq!(db.execute("INSERT INTO foo(x) VALUES (?)", [&3i32])?, 1);

        // Both rows exceed 0, so the sum is 2 + 3.
        assert_eq!(
            5i32,
            db.query_row::<i32, _, _>("SELECT SUM(x) FROM foo WHERE x > ?", [&0i32], |r| r.get(0))?
        );
        // Only the row holding 3 exceeds 2.
        assert_eq!(
            3i32,
            db.query_row::<i32, _, _>("SELECT SUM(x) FROM foo WHERE x > ?", [&2i32], |r| r.get(0))?
        );
        Ok(())
    }

    /// A prepared INSERT executes with a bound text parameter, and a prepared
    /// SELECT sees the inserted row.
    #[test]
    fn test_stmt_execute() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let sql = r#"
            CREATE SEQUENCE seq;
            CREATE TABLE test (id INTEGER DEFAULT NEXTVAL('seq'), name TEXT NOT NULL, flag INTEGER);
        "#;
        db.execute_batch(sql)?;

        let mut stmt = db.prepare("INSERT INTO test (name) VALUES (?)")?;
        stmt.execute([&"one"])?;

        let mut stmt = db.prepare("SELECT COUNT(*) FROM test WHERE name = ?")?;
        assert_eq!(1i32, stmt.query_row::<i32, _, _>([&"one"], |r| r.get(0))?);
        Ok(())
    }

    /// `Statement::query` returns rows whose columns can be read by index.
    #[test]
    fn test_query() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let sql = r#"
            CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, flag INTEGER);
            INSERT INTO test(id, name) VALUES (1, 'one');
        "#;
        db.execute_batch(sql)?;

        let mut stmt = db.prepare("SELECT id FROM test where name = ?")?;
        {
            let mut rows = stmt.query([&"one"])?;
            let id: Result<i32> = rows.next()?.unwrap().get(0);
            assert_eq!(Ok(1), id);
        }
        Ok(())
    }

    /// `query_and_then` yields the closure's per-row `Ok`/`Err` results.
    #[test]
    fn test_query_and_then() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let sql = r#"
            CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, flag INTEGER);
            INSERT INTO test(id, name) VALUES (1, 'one');
            INSERT INTO test(id, name) VALUES (2, 'one');
        "#;
        db.execute_batch(sql)?;

        let mut stmt = db.prepare("SELECT id FROM test where name = ? ORDER BY id ASC")?;
        let mut rows = stmt.query_and_then([&"one"], |row| {
            let id: i32 = row.get(0)?;
            if id == 1 {
                Ok(id)
            } else {
                // Sentinel error so the test can tell the second row apart.
                Err(Error::ExecuteReturnedResults)
            }
        })?;

        // The first row (id = 1) maps to Ok.
        let first_id: i32 = rows.next().unwrap()?;
        assert_eq!(1, first_id);

        // The second row (id = 2) must surface the sentinel error. The wildcard
        // `Err(_)` arm is intentional: any other error fails the test.
        #[allow(clippy::match_wild_err_arm)]
        match rows.next().unwrap() {
            Ok(_) => panic!("invalid Ok"),
            Err(Error::ExecuteReturnedResults) => (),
            Err(_) => panic!("invalid Err"),
        }
        Ok(())
    }

    /// Executing with fewer parameters than placeholders is an error.
    #[test]
    fn test_unbound_parameters_are_error() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let sql = "CREATE TABLE test (x TEXT, y TEXT)";
        db.execute_batch(sql)?;

        let mut stmt = db.prepare("INSERT INTO test (x, y) VALUES (?, ?)")?;
        assert!(stmt.execute([&"one"]).is_err());
        Ok(())
    }

    /// A column omitted from an INSERT reads back as `None`.
    #[test]
    fn test_insert_empty_text_is_none() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let sql = "CREATE TABLE test (x TEXT, y TEXT)";
        db.execute_batch(sql)?;

        let mut stmt = db.prepare("INSERT INTO test (x) VALUES (?)")?;
        stmt.execute([&"one"])?;

        let result: Option<String> = db.query_row("SELECT y FROM test WHERE x = 'one'", [], |row| row.get(0))?;
        assert!(result.is_none());
        Ok(())
    }

    /// `raw_bind_parameter` + `raw_execute` + `raw_query` work in any binding
    /// order, including unsigned values at their maximum.
    #[test]
    fn test_raw_binding() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch("CREATE TABLE test (name TEXT, value INTEGER)")?;
        {
            let mut stmt = db.prepare("INSERT INTO test (name, value) VALUES (?, ?)")?;

            // Bind out of order on purpose: positions, not call order, matter.
            stmt.raw_bind_parameter(2, 50i32)?;
            stmt.raw_bind_parameter(1, "example")?;
            let n = stmt.raw_execute()?;
            assert_eq!(n, 1);
        }

        {
            let mut stmt = db.prepare("SELECT name, value FROM test WHERE value = ?")?;
            stmt.raw_bind_parameter(1, 50)?;
            stmt.raw_execute()?;
            let mut rows = stmt.raw_query();
            {
                let row = rows.next()?.unwrap();
                let name: String = row.get(0)?;
                assert_eq!(name, "example");
                let value: i32 = row.get(1)?;
                assert_eq!(value, 50);
            }
            assert!(rows.next()?.is_none());
        }

        {
            let db = Connection::open_in_memory()?;
            db.execute_batch("CREATE TABLE test (name TEXT, value UINTEGER)")?;
            let mut stmt = db.prepare("INSERT INTO test(name, value) VALUES (?, ?)")?;
            stmt.raw_bind_parameter(1, "negative")?;
            stmt.raw_bind_parameter(2, u32::MAX)?;
            let n = stmt.raw_execute()?;
            assert_eq!(n, 1);
            assert_eq!(
                u32::MAX,
                db.query_row::<u32, _, _>("SELECT value FROM test", [], |r| r.get(0))?
            );
        }

        {
            let db = Connection::open_in_memory()?;
            db.execute_batch("CREATE TABLE test (name TEXT, value UBIGINT)")?;
            let mut stmt = db.prepare("INSERT INTO test(name, value) VALUES (?, ?)")?;
            stmt.raw_bind_parameter(1, "negative")?;
            stmt.raw_bind_parameter(2, u64::MAX)?;
            let n = stmt.raw_execute()?;
            assert_eq!(n, 1);
            assert_eq!(
                u64::MAX,
                db.query_row::<u64, _, _>("SELECT value FROM test", [], |r| r.get(0))?
            );
        }

        Ok(())
    }

    /// `insert` fails on a UNIQUE violation and reports the changed-row count
    /// when more than one row is inserted.
    // NOTE(review): this test uses an in-memory database, so the Windows
    // file-write reason on the ignore attribute looks stale; confirm before removing.
    #[test]
    #[cfg_attr(windows, ignore = "Windows doesn't allow concurrent writes to a file")]
    fn test_insert_duplicate() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch("CREATE TABLE foo(x INTEGER UNIQUE)")?;
        let mut stmt = db.prepare("INSERT INTO foo (x) VALUES (?)")?;
        stmt.insert([1i32])?;
        stmt.insert([2i32])?;
        assert!(stmt.insert([1i32]).is_err());
        let mut multi = db.prepare("INSERT INTO foo (x) SELECT 3 UNION ALL SELECT 4")?;
        match multi.insert([]).unwrap_err() {
            Error::StatementChangedRows(2) => (),
            err => panic!("Unexpected error {err}"),
        }
        Ok(())
    }

    /// Parameterless `insert` works against different tables.
    #[test]
    fn test_insert_different_tables() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch(
            r"
            CREATE TABLE foo(x INTEGER);
            CREATE TABLE bar(x INTEGER);
            ",
        )?;

        db.prepare("INSERT INTO foo VALUES (10)")?.insert([])?;
        db.prepare("INSERT INTO bar VALUES (10)")?.insert([])?;
        Ok(())
    }

    /// `exists` is true exactly when the query returns at least one row, and a
    /// statement can be reused with different parameters.
    #[test]
    fn test_exists() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let sql = "BEGIN;
                   CREATE TABLE foo(x INTEGER);
                   INSERT INTO foo VALUES(1);
                   INSERT INTO foo VALUES(2);
                   END;";
        db.execute_batch(sql)?;
        let mut stmt = db.prepare("SELECT 1 FROM foo WHERE x = ?")?;
        assert!(stmt.exists([1i32])?);
        assert!(stmt.exists([2i32])?);
        assert!(!stmt.exists([0i32])?);
        Ok(())
    }

    /// `query_row` passes the first matching row to the closure.
    #[test]
    fn test_query_row() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let sql = "BEGIN;
                   CREATE TABLE foo(x INTEGER, y INTEGER);
                   INSERT INTO foo VALUES(1, 3);
                   INSERT INTO foo VALUES(2, 4);
                   END;";
        db.execute_batch(sql)?;
        let mut stmt = db.prepare("SELECT y FROM foo WHERE x = ?")?;
        let y: Result<i32> = stmt.query_row([1i32], |r| r.get(0));
        assert_eq!(3i32, y?);
        Ok(())
    }

    /// Columns can be read by name.
    #[test]
    fn test_query_by_column_name() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let sql = "BEGIN;
                   CREATE TABLE foo(x INTEGER, y INTEGER);
                   INSERT INTO foo VALUES(1, 3);
                   END;";
        db.execute_batch(sql)?;
        let mut stmt = db.prepare("SELECT y FROM foo")?;
        let y: Result<i64> = stmt.query_row([], |r| r.get("y"));
        assert_eq!(3i64, y?);
        Ok(())
    }

    /// After execution, `schema()` reports the result columns as nullable
    /// Arrow fields.
    #[test]
    fn test_get_schema_of_executed_result() -> Result<()> {
        use arrow::datatypes::{DataType, Field, Schema};
        let db = Connection::open_in_memory()?;
        let sql = "BEGIN;
                   CREATE TABLE foo(x STRING, y INTEGER);
                   INSERT INTO foo VALUES('hello', 3);
                   END;";
        db.execute_batch(sql)?;
        let mut stmt = db.prepare("SELECT x, y FROM foo")?;
        // Propagate execution failures instead of discarding them; otherwise a
        // failed execute would surface later as an unrelated panic in `schema()`.
        stmt.execute([])?;
        let schema = stmt.schema();
        assert_eq!(
            *schema,
            Schema::new(vec![
                Field::new("x", DataType::Utf8, true),
                Field::new("y", DataType::Int32, true)
            ])
        );
        Ok(())
    }

    /// `schema()` panics if the statement was never executed.
    #[test]
    #[should_panic(expected = "called `Option::unwrap()` on a `None` value")]
    fn test_unexecuted_schema_panics() {
        let db = Connection::open_in_memory().unwrap();
        let sql = "BEGIN;
                   CREATE TABLE foo(x STRING, y INTEGER);
                   INSERT INTO foo VALUES('hello', 3);
                   END;";
        db.execute_batch(sql).unwrap();
        let stmt = db.prepare("SELECT x, y FROM foo").unwrap();
        let _ = stmt.schema();
    }

    /// Column-name lookup ignores case (`Y` is found as `y`).
    #[test]
    fn test_query_by_column_name_ignore_case() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let sql = "BEGIN;
                   CREATE TABLE foo(x INTEGER, y INTEGER);
                   INSERT INTO foo VALUES(1, 3);
                   END;";
        db.execute_batch(sql)?;
        let mut stmt = db.prepare("SELECT y as Y FROM foo")?;
        let y: Result<i64> = stmt.query_row([], |r| r.get("y"));
        assert_eq!(3i64, y?);
        Ok(())
    }

    /// Heterogeneous slices and the `params_from_iter` variants all bind.
    #[test]
    #[ignore]
    fn test_bind_parameters() -> Result<()> {
        let db = Connection::open_in_memory()?;
        // Mixed parameter types via trait objects.
        db.query_row("SELECT ?1, ?2, ?3", [&1u8 as &dyn ToSql, &"one", &Some("one")], |row| {
            row.get::<_, u8>(0)
        })?;
        let data = vec![1, 2, 3];
        // Borrowed Vec, slice and owned Vec are all accepted.
        db.query_row("SELECT ?1, ?2, ?3", params_from_iter(&data), |row| row.get::<_, u8>(0))?;
        db.query_row("SELECT ?1, ?2, ?3", params_from_iter(data.as_slice()), |row| {
            row.get::<_, u8>(0)
        })?;
        db.query_row("SELECT ?1, ?2, ?3", params_from_iter(data), |row| row.get::<_, u8>(0))?;

        use std::collections::BTreeSet;
        let data: BTreeSet<String> = ["one", "two", "three"].iter().map(|s| (*s).to_string()).collect();
        db.query_row("SELECT ?1, ?2, ?3", params_from_iter(&data), |row| {
            row.get::<_, String>(0)
        })?;

        let data = [0; 3];
        db.query_row("SELECT ?1, ?2, ?3", params_from_iter(&data), |row| row.get::<_, u8>(0))?;
        db.query_row("SELECT ?1, ?2, ?3", params_from_iter(data.iter()), |row| {
            row.get::<_, u8>(0)
        })?;
        Ok(())
    }

    /// Preparing an empty string fails.
    #[test]
    fn test_empty_stmt() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let stmt = conn.prepare("");
        assert!(stmt.is_err());

        Ok(())
    }

    /// Preparing a string that holds only a comment fails.
    #[test]
    fn test_comment_empty_stmt() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        assert!(conn.prepare("/*SELECT 1;*/").is_err());
        Ok(())
    }

    /// A leading comment before real SQL is accepted and executes.
    #[test]
    fn test_comment_and_sql_stmt() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let mut stmt = conn.prepare("/*...*/ SELECT 1;")?;
        stmt.execute([])?;
        assert_eq!(1, stmt.column_count());
        Ok(())
    }

    /// Round-trips non-ASCII text.
    // NOTE(review): the `encoding` pragmas look SQLite-specific, which may be why
    // this is ignored; the literal below also looks mis-encoded, but the assertion
    // only needs the same string to round-trip.
    #[test]
    #[ignore]
    fn test_utf16_conversion() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.pragma_update(None, "encoding", &"UTF-16le")?;
        let encoding: String = db.pragma_query_value(None, "encoding", |row| row.get(0))?;
        assert_eq!("UTF-16le", encoding);
        db.execute_batch("CREATE TABLE foo(x TEXT)")?;
        let expected = "ใในใ";
        db.execute("INSERT INTO foo(x) VALUES (?)", [&expected])?;
        let actual: String = db.query_row("SELECT x FROM foo", [], |row| row.get(0))?;
        assert_eq!(expected, actual);
        Ok(())
    }

    /// Text with an embedded NUL byte survives a bind/cast round trip.
    #[test]
    #[ignore]
    fn test_nul_byte() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let expected = "a\x00b";
        let actual: String = db.query_row("SELECT CAST(? AS VARCHAR)", [expected], |row| row.get(0))?;
        assert_eq!(expected, actual);
        Ok(())
    }
}