sqlx_d1_core/
connection.rs

1use sqlx_core::{Url, Either};
2
3#[cfg(target_arch = "wasm32")]
4use {
5    crate::{error::D1Error, row::D1Row},
6    std::pin::Pin,
7    worker::{wasm_bindgen::JsValue, wasm_bindgen_futures::JsFuture, js_sys},
8};
9
10/// ## Example
11/// 
12/// ```toml
13/// # Cargo.toml
14/// 
15/// [dependencies]
16/// sqlx_d1 = { version = "0.1", features = ["macros"] }
17/// worker = { version = "0.5", features = ["d1"] }
18/// serde = { version = "1.0", features = ["derive"] }
19/// ```
20/// ```toml
21/// # wrangler.toml
22/// 
23/// [[d1_database]]
24/// binding = "DB"
25/// database_name = "..."
26/// database_id = "..."
27/// ```
28/// ```rust,ignore
29/// // src/lib.rs
30/// 
31/// #[worker::event(fetch)]
32/// async fn main(
33///     mut req: worker::Request,
34///     env: worker::Env,
35///     _ctx: worker::Context,
36/// ) -> worker::Result<worker::Response> {
37///     let d1 = env.d1("DB")?;
38///     let conn = sqlx_d1::D1Connection::new(d1);
39/// 
40///     #[derive(serde::Deserialize)]
41///     struct CreateUser {
42///         name: String,
43///         age: Option<u8>,
44///     }
45/// 
46///     let req = req.json::<CreateUser>().await?;
47/// 
48///     let id = sqlx_d1::query!(
49///         "
50///         INSERT INTO users (name, age) VALUES (?, ?)
51///         RETURNING id
52///         ",
53///             req.name,
54///             req.age
55///         )
56///         .fetch_one(&conn)
57///         .await
58///         .map_err(|e| worker::Error::RustError(e.to_string()))?
59///         .id;
60/// 
61///     worker::Response::ok(format!("Your id is {id}!"))
62/// }
63/// ```
64pub struct D1Connection {
65    #[cfg(target_arch = "wasm32")]
66    pub(crate) inner: worker_sys::D1Database,
67
68    #[cfg(not(target_arch = "wasm32"))]
69    pub(crate) inner: sqlx_sqlite::SqliteConnection,
70}
71
72const _: () = {
73    /* SAFETY: used in single-threaded Workers */
74    unsafe impl Send for D1Connection {}
75    unsafe impl Sync for D1Connection {}
76
77    #[cfg(not(target_arch = "wasm32"))]
78    macro_rules! unreachable_native_impl_of_item_for_only_wasm32 {
79        ($item_for_only_wasm32:literal) => {
80            panic!(
81                "native `{}`: Invalid use of `sqlx_d1`. Be sure to use `sqlx_d1` where the target is set to \
82                `wasm32-unknown-unknown` ! \n\
83                For this, typcally, place `.cargo/config.toml` of following content at the root of \
84                your project or workspace : \n\
85                \n\
86                [build]\n\
87                target = \"wasm32-unknown-unknown\"\n",
88                $item_for_only_wasm32
89            )
90        };
91    }
92
93    impl D1Connection {
94        pub fn new(d1: worker::D1Database) -> Self {
95            #[cfg(target_arch = "wasm32")] {
96                Self { inner: unsafe {std::mem::transmute(d1)} }
97            }
98            #[cfg(not(target_arch = "wasm32"))] {
99                let _ = d1;
100                unreachable_native_impl_of_item_for_only_wasm32!("D1Cnnection::new");
101            }
102        }
103
104        #[cfg(not(target_arch = "wasm32"))]
105        pub async fn connect(url: impl AsRef<str>) -> Result<Self, sqlx_core::Error> {
106            <Self as sqlx_core::connection::Connection>::connect(url.as_ref()).await
107        }
108    }
109
110    impl Clone for D1Connection {
111        fn clone(&self) -> Self {
112            #[cfg(target_arch = "wasm32")] {
113                Self { inner: self.inner.clone() }
114            }
115            #[cfg(not(target_arch = "wasm32"))] {
116                unreachable_native_impl_of_item_for_only_wasm32!("impl Clone for D1Connection");
117            }
118        }
119    }
120
121    impl std::fmt::Debug for D1Connection {
122        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123            f.debug_struct("D1Connection").finish()
124        }
125    }
126
127    impl sqlx_core::connection::Connection for D1Connection {
128        type Database = crate::D1;
129
130        type Options = D1ConnectOptions;
131
132        fn close(self) -> crate::ResultFuture<'static, ()> {
133            Box::pin(async {Ok(())})
134        }
135
136        fn close_hard(self) -> crate::ResultFuture<'static, ()> {
137            Box::pin(async {Ok(())})
138        }
139
140        fn ping(&mut self) -> crate::ResultFuture<'_, ()> {
141            Box::pin(async {Ok(())})
142        }
143
144        fn begin(&mut self) -> crate::ResultFuture<'_, sqlx_core::transaction::Transaction<'_, Self::Database>>
145        where
146            Self: Sized,
147        {
148            sqlx_core::transaction::Transaction::begin(self)
149        }
150
151        fn shrink_buffers(&mut self) {
152            /* do nothing */
153        }
154
155        fn flush(&mut self) -> crate::ResultFuture<'_, ()> {
156            Box::pin(async {Ok(())})
157        }
158
159        fn should_flush(&self) -> bool {
160            false
161        }
162    }
163
164    impl<'c> sqlx_core::executor::Executor<'c> for &'c mut D1Connection {
165        type Database = crate::D1;
166
167        fn fetch_many<'e, 'q: 'e, E>(
168            self,
169            #[allow(unused)]
170            mut query: E,
171        ) -> futures_core::stream::BoxStream<
172            'e,
173            Result<
174                Either<
175                    <Self::Database as sqlx_core::database::Database>::QueryResult,
176                    <Self::Database as sqlx_core::database::Database>::Row
177                >,
178                sqlx_core::Error,
179            >,
180        >
181        where
182            'c: 'e,
183            E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
184        {
185            <&'c D1Connection as sqlx_core::executor::Executor<'c>>::fetch_many(self, query)
186        }
187
188        fn fetch_optional<'e, 'q: 'e, E>(
189            self,
190            #[allow(unused)]
191            mut query: E,
192        ) -> crate::ResultFuture<'e, Option<<Self::Database as sqlx_core::database::Database>::Row>>
193        where
194            'c: 'e,
195            E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
196        {
197            <&'c D1Connection as sqlx_core::executor::Executor<'c>>::fetch_optional(self, query)
198        }
199
200        fn prepare_with<'e, 'q: 'e>(
201            self,
202            sql: &'q str,
203            _parameters: &'e [<Self::Database as sqlx_core::database::Database>::TypeInfo],
204        ) -> crate::ResultFuture<'e, <Self::Database as sqlx_core::database::Database>::Statement<'q>>
205        where
206            'c: 'e,
207        {
208            Box::pin(async {
209                Ok(crate::statement::D1Statement {
210                    sql: std::borrow::Cow::Borrowed(sql),
211                })
212            })
213        }
214
215        fn describe<'e, 'q: 'e>(
216            self,
217            #[allow(unused)]
218            sql: &'q str,
219        ) -> crate::ResultFuture<'e, sqlx_core::describe::Describe<Self::Database>>
220        where
221            'c: 'e,
222        {
223            #[cfg(target_arch = "wasm32")] {
224                unreachable!("wasm32 describe")
225            }
226            #[cfg(not(target_arch = "wasm32"))] {
227                /* compile-time verification by macros */
228
229                Box::pin(async {
230                    let sqlx_core::describe::Describe {
231                        columns,
232                        parameters,
233                        nullable
234                    } = <&mut sqlx_sqlite::SqliteConnection as sqlx_core::executor::Executor>::describe(
235                        &mut self.inner,
236                        sql
237                    ).await?;
238                    
239                    Ok(sqlx_core::describe::Describe {
240                        parameters: parameters.map(|ps| match ps {
241                            Either::Left(type_infos) => Either::Left(type_infos.into_iter().map(crate::type_info::D1TypeInfo::from_sqlite).collect()),
242                            Either::Right(n) => Either::Right(n)
243                        }),
244                        columns: columns.into_iter().map(crate::column::D1Column::from_sqlite).collect(),
245                        nullable
246                    })
247                })
248            }
249        }
250    }
251
252    impl<'c> sqlx_core::executor::Executor<'c> for &'c D1Connection {
253        type Database = crate::D1;
254
255        fn fetch_many<'e, 'q: 'e, E>(
256            self,
257            #[allow(unused)]
258            mut query: E,
259        ) -> futures_core::stream::BoxStream<
260            'e,
261            Result<
262                Either<
263                    <Self::Database as sqlx_core::database::Database>::QueryResult,
264                    <Self::Database as sqlx_core::database::Database>::Row
265                >,
266                sqlx_core::Error,
267            >,
268        >
269        where
270            'c: 'e,
271            E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
272        {
273            #[cfg(not(target_arch = "wasm32"))] {
274                unreachable_native_impl_of_item_for_only_wasm32!("impl Executor for &D1Conection");
275            }
276            #[cfg(target_arch = "wasm32")] {
277                let sql = query.sql();
278                let arguments = match query.take_arguments() {
279                    Ok(a) => a,
280                    Err(e) => return Box::pin(futures_util::stream::once(async {Err(sqlx_core::Error::Encode(e))})),
281                };
282
283                struct FetchMany<F> {
284                    raw_rows_future: F,
285                    raw_rows: Option<js_sys::ArrayIntoIter>,
286                }
287                const _: () = {
288                    /* SAFETY: used in single-threaded Workers */
289                    unsafe impl<F> Send for FetchMany<F> {}
290
291                    impl<F> FetchMany<F> {
292                        fn new(raw_rows_future: F) -> Self {
293                            Self { raw_rows_future, raw_rows: None }
294                        }
295                    }
296
297                    impl<F> futures_core::Stream for FetchMany<F>
298                    where
299                        F: Future<Output = Result<Option<js_sys::Array>, JsValue>>,
300                    {
301                        type Item = Result<
302                            Either<crate::query_result::D1QueryResult, D1Row>,
303                            sqlx_core::Error
304                        >;
305
306                        fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
307                            use std::task::Poll;
308
309                            fn pop_next(raw_rows: &mut js_sys::ArrayIntoIter) ->
310                                Option<Result<
311                                    Either<crate::query_result::D1QueryResult, D1Row>,
312                                    sqlx_core::Error
313                                >>
314                            {
315                                let raw_row = raw_rows.next()?;
316                                Some(D1Row::from_raw(raw_row).map(Either::Right))
317                            }
318
319                            let this = unsafe {self.get_unchecked_mut()};
320                            match &mut this.raw_rows {
321                                Some(raw_rows) => Poll::Ready(pop_next(raw_rows)),
322                                None => match unsafe {Pin::new_unchecked(&mut this.raw_rows_future)}.poll(cx) {
323                                    Poll::Pending => Poll::Pending,
324                                    Poll::Ready(Err(e)) => Poll::Ready(Some(Err(
325                                        sqlx_core::Error::from(D1Error::from(e))
326                                    ))),
327                                    Poll::Ready(Ok(maybe_raw_rows)) => {
328                                        this.raw_rows = Some(maybe_raw_rows.unwrap_or_else(js_sys::Array::new).into_iter());
329                                        Poll::Ready(pop_next(unsafe {this.raw_rows.as_mut().unwrap_unchecked()}))
330                                    }
331                                }
332                            }                        
333                        }
334                    }
335                };
336
337                Box::pin(FetchMany::new(async move {
338                    let mut statement = self.inner.prepare(sql).unwrap();
339                    if let Some(a) = arguments {
340                        statement = statement.bind(a.as_ref().iter().collect())?;
341                    }
342
343                    let d1_result_jsvalue = JsFuture::from(statement.all()?)
344                        .await?;
345                    worker_sys::D1Result::from(d1_result_jsvalue)
346                        .results()
347                }))
348            }
349        }
350
351        fn fetch_optional<'e, 'q: 'e, E>(
352            self,
353            #[allow(unused)]
354            mut query: E,
355        ) -> crate::ResultFuture<'e, Option<<Self::Database as sqlx_core::database::Database>::Row>>
356        where
357            'c: 'e,
358            E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
359        {
360            #[cfg(not(target_arch = "wasm32"))] {
361                unreachable_native_impl_of_item_for_only_wasm32!("impl Executor for &D1Conection");
362            }
363            #[cfg(target_arch = "wasm32")] {
364                let sql = query.sql();
365                let arguments = match query.take_arguments() {
366                    Ok(a) => a,
367                    Err(e) => return Box::pin(async {Err(sqlx_core::Error::Encode(e))}),
368                };
369
370                Box::pin(worker::send::SendFuture::new(async move {
371                    let mut statement = self.inner.prepare(sql).unwrap();
372                    if let Some(a) = arguments {
373                        statement = statement
374                            .bind(a.as_ref().iter().collect())
375                            .map_err(|e| sqlx_core::Error::Encode(Box::new(D1Error::from(e))))?;
376                    }
377
378                    let raw = JsFuture::from(statement.first(None).map_err(D1Error::from)?)
379                        .await
380                        .map_err(D1Error::from)?;
381                    if raw.is_null() {
382                        Ok(None)
383                    } else {
384                        D1Row::from_raw(raw).map(Some)
385                    }
386                }))
387            }
388        }
389
390        fn prepare_with<'e, 'q: 'e>(
391            self,
392            sql: &'q str,
393            _parameters: &'e [<Self::Database as sqlx_core::database::Database>::TypeInfo],
394        ) -> crate::ResultFuture<'e, <Self::Database as sqlx_core::database::Database>::Statement<'q>>
395        where
396            'c: 'e,
397        {
398            Box::pin(async {
399                Ok(crate::statement::D1Statement {
400                    sql: std::borrow::Cow::Borrowed(sql),
401                })
402            })
403        }
404
405        fn describe<'e, 'q: 'e>(
406            self,
407            #[allow(unused)]
408            sql: &'q str,
409        ) -> crate::ResultFuture<'e, sqlx_core::describe::Describe<Self::Database>>
410        where
411            'c: 'e,
412        {
413            #[cfg(not(target_arch = "wasm32"))] {
414                unreachable_native_impl_of_item_for_only_wasm32!("impl Executor for &D1Conection");
415            }
416            #[cfg(target_arch = "wasm32")] {
417                unreachable!("wasm32 describe")
418            }
419        }
420    }
421};
422
423/// ref: <https://developers.cloudflare.com/d1/sql-api/sql-statements/#compatible-pragma-statements>
424#[derive(Clone)]
425pub struct D1ConnectOptions {
426    pragmas: TogglePragmas,
427    #[cfg(target_arch = "wasm32")]
428    d1: worker_sys::D1Database,
429    #[cfg(not(target_arch = "wasm32"))]
430    sqlite_path: std::path::PathBuf,
431}
432const _: () = {
433    /* SAFETY: used in single-threaded Workers */
434    unsafe impl Send for D1ConnectOptions {}
435    unsafe impl Sync for D1ConnectOptions {}
436
437    #[cfg(target_arch = "wasm32")]
438    const URL_CONVERSION_UNSUPPORTED_MESSAGE: &'static str = "\
439        `sqlx_d1::D1ConnectOptions` doesn't support conversion between `Url`. \
440        Consider connect from options created by `D1ConnectOptions::new`. \
441    ";
442
443    const LOG_SETTINGS_UNSUPPORTED_MESSAGE: &'static str = "\
444        `sqlx_d1::D1ConnectOptions` doesn't support log settings.
445    ";
446
447    impl D1ConnectOptions {
448        #[cfg(target_arch = "wasm32")]
449        pub fn new(d1: worker::D1Database) -> Self {
450            Self {
451                d1: unsafe {core::mem::transmute(d1)},
452                pragmas: TogglePragmas::new(),
453            }
454        }
455    }
456
457    impl std::fmt::Debug for D1ConnectOptions {
458        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
459            f.debug_struct("D1ConnectOptions")
460                .field("pragmas", &self.pragmas)
461                .finish()
462        }
463    }
464
465    impl std::str::FromStr for D1ConnectOptions {
466        type Err = sqlx_core::Error;
467
468        fn from_str(_: &str) -> Result<Self, Self::Err> {
469            #[cfg(target_arch = "wasm32")] {
470                Err(sqlx_core::Error::Configuration(From::from(
471                    URL_CONVERSION_UNSUPPORTED_MESSAGE
472                )))
473            }
474
475            #[cfg(not(target_arch = "wasm32"))] {
476                use std::{io, fs, path::{Path, PathBuf}};
477
478                fn maybe_miniflare_d1_dir_of(dir: impl AsRef<Path>) -> PathBuf {
479                    dir.as_ref()
480                        .join(".wrangler")
481                        .join("state")
482                        .join("v3")
483                        .join("d1")
484                        .join("miniflare-D1DatabaseObject")
485                }
486            
487                const PACKAGE_ROOT: &str = env!("CARGO_MANIFEST_DIR");
488
489                let (candidate_1, candidate_2) = (
490                    maybe_miniflare_d1_dir_of(PACKAGE_ROOT),
491                    maybe_miniflare_d1_dir_of(".")
492                );
493
494                let sqlite_path = (|| -> io::Result<PathBuf> {                    
495                    let miniflare_d1_dir = match (
496                        fs::exists(&candidate_1),
497                        fs::exists(&candidate_2)
498                    ) {
499                        (Ok(true), _) => candidate_1,
500                        (_, Ok(true)) => candidate_2,
501                        (Err(e), _) | (_, Err(e)) => return Err(e),
502                        (Ok(false), Ok(false)) => return Err(io::Error::new(
503                            io::ErrorKind::NotFound,
504                            "miniflare's D1 emulating directory not found"
505                        )),
506                    };
507                    
508                    let [sqlite_path] = fs::read_dir(miniflare_d1_dir)?
509                        .filter_map(|r| r.as_ref().ok().and_then(|e| {
510                            let path = e.path();
511                            path.extension()
512                                .is_some_and(|ex| ex == "sqlite")
513                                .then_some(path)
514                        }))
515                        .collect::<Vec<_>>()
516                        .try_into()
517                        .map_err(|_| io::Error::new(
518                            io::ErrorKind::Other,
519                            "Currently, sqlx_d1 doesn't support multiple D1 bindings!"
520                        ))?;
521
522                    Ok(sqlite_path)
523                })().map_err(|_| sqlx_core::Error::WorkerCrashed)?;
524                    
525                Ok(Self {
526                    pragmas: TogglePragmas::new(),
527                    sqlite_path
528                })
529            }
530        }
531    }
532
533    impl sqlx_core::connection::ConnectOptions for D1ConnectOptions {
534        type Connection = D1Connection;
535
536        fn from_url(_url: &Url) -> Result<Self, sqlx_core::Error> {
537            #[cfg(target_arch = "wasm32")] {
538                Err(sqlx_core::Error::Configuration(From::from(
539                    URL_CONVERSION_UNSUPPORTED_MESSAGE
540                )))
541            }
542            #[cfg(not(target_arch = "wasm32"))] {
543                _url.as_str().parse()
544            }
545        }
546
547        fn to_url_lossy(&self) -> Url {
548            unreachable!("`sqlx_d1::ConnectOptions` doesn't support `ConnectOptions::to_url_lossy`")
549        }
550
551        fn connect(&self) -> crate::ResultFuture<'_, Self::Connection>
552        where
553            Self::Connection: Sized,
554        {
555            #[cfg(target_arch = "wasm32")] {
556                Box::pin(worker::send::SendFuture::new(async move {
557                    let d1 = self.d1.clone();
558
559                    if let Some(pragmas) = self.pragmas.collect() {
560                        JsFuture::from(d1.exec(&pragmas.join("\n")).map_err(D1Error::from)?)
561                            .await
562                            .map_err(D1Error::from)?;
563                    }
564
565                    Ok(D1Connection {
566                        inner: d1
567                    })
568                }))
569            }
570
571            #[cfg(not(target_arch = "wasm32"))] {
572                Box::pin(async move {
573                    use sqlx_core::{connection::Connection, executor::Executor};
574
575                    let mut sqlite_conn = sqlx_sqlite::SqliteConnection::connect(
576                        self.sqlite_path.to_str().ok_or(sqlx_core::Error::WorkerCrashed)?
577                    ).await?;
578
579                    if let Some(pragmas) = self.pragmas.collect() {
580                        for pragma in pragmas {
581                            sqlite_conn.execute(pragma).await?;
582                        }
583                    }
584                    
585                    Ok(D1Connection { inner: sqlite_conn })
586                })
587            }
588        }
589
590        fn log_statements(self, _: log::LevelFilter) -> Self {
591            unreachable!("{LOG_SETTINGS_UNSUPPORTED_MESSAGE}")
592        }
593
594        fn log_slow_statements(self, _: log::LevelFilter, _: std::time::Duration) -> Self {
595            unreachable!("{LOG_SETTINGS_UNSUPPORTED_MESSAGE}")
596        }
597    }
598};
599
/// Bit set of D1-compatible `PRAGMA`s to switch on; one bit per pragma, as
/// assigned by the `toggles!` invocation in this module.
///
/// Compatible pragmas: <https://developers.cloudflare.com/d1/sql-api/sql-statements/#compatible-pragma-statements>
#[derive(Copy, Clone)]
struct TogglePragmas(u8);
603const _: () = {
604    impl std::ops::Not for TogglePragmas {
605        type Output = Self;
606        fn not(self) -> Self::Output {
607            Self(!self.0)
608        }
609    }
610    impl std::ops::BitOrAssign for TogglePragmas {
611        fn bitor_assign(&mut self, rhs: Self) {
612            self.0 |= self.0 | rhs.0;
613        }
614    }
615    impl std::ops::BitAndAssign for TogglePragmas {
616        fn bitand_assign(&mut self, rhs: Self) {
617            self.0 &= self.0 & rhs.0;
618        }
619    }
620    
621    impl TogglePragmas {
622        const fn new() -> Self {
623            Self(0)
624        }
625    }
626};
627
// Declares the supported pragma toggles.
//
// Each `name as bits;` entry generates an associated const
// `TogglePragmas::name` holding that bit and a documented builder setter
// `D1ConnectOptions::name(self, yes: bool)`. Once for all entries it also
// generates `TogglePragmas::collect` and a `Debug` impl for `TogglePragmas`.
macro_rules! toggles {
    ($( $name:ident as $bits:literal; )*) => {
        impl TogglePragmas {
            $(
                #[allow(non_upper_case_globals)]
                const $name: Self = Self($bits);
            )*

            /// `PRAGMA <name> = on` for every enabled toggle, in declaration
            /// order; `None` when no toggle is enabled.
            fn collect(&self) -> Option<Vec<&'static str>> {
                #[allow(unused_mut)]
                let mut pragmas = Vec::new();
                $(
                    if self.0 & Self::$name.0 != 0 {
                        pragmas.push(concat!(
                            "PRAGMA ",
                            stringify!($name),
                            " = on"
                        ));
                    }
                )*
                (!pragmas.is_empty()).then_some(pragmas)
            }
        }

        impl std::fmt::Debug for TogglePragmas {
            /// Formats as a map from pragma name to `"on"` / `"off"`.
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                #[allow(unused_mut)]
                let mut map = f.debug_map();
                $(
                    map.entry(
                        &stringify!($name),
                        &if self.0 & Self::$name.0 != 0 {"on"} else {"off"}
                    );
                )*
                map.finish()
            }
        }

        impl D1ConnectOptions {
            $(
                #[doc = concat!(
                    "Enables (`true`) or disables (`false`) `PRAGMA ",
                    stringify!($name),
                    "` for connections opened with these options.\n\n",
                    "Only enabled toggles are sent, as `PRAGMA ",
                    stringify!($name),
                    " = on`. Passing `false` sends nothing, so the database's ",
                    "own default for this pragma stays in effect."
                )]
                pub fn $name(mut self, yes: bool) -> Self {
                    if yes {
                        self.pragmas |= TogglePragmas::$name;
                    } else {
                        self.pragmas &= !TogglePragmas::$name;
                    }
                    self
                }
            )*
        }
    };
}
679toggles! {
680    case_sensitive_like     as 0b0000001;
681    ignore_check_constraint as 0b0000010;
682    legacy_alter_table      as 0b0000100;
683    recursive_triggers      as 0b0001000;
684    unordered_selects       as 0b0010000;
685    foreign_keys            as 0b0100000;
686    defer_foreign_keys      as 0b1000000;
687}