sqlx_d1_core/
connection.rs

1use sqlx_core::{Url, Either};
2
3#[cfg(target_arch = "wasm32")]
4use {
5    crate::{error::D1Error, row::D1Row},
6    std::pin::Pin,
7    worker::{wasm_bindgen::JsValue, wasm_bindgen_futures::JsFuture, js_sys},
8};
9
10/// ## Note
11/// 
12/// Make sure to use where _**target is set to `wasm32-unknown-unknown`**_.
13/// In non-`wasm32-unknown-unknown` environment, `D1Connection` doesn't provide
14/// `new` or `impl Executor for &_`.
15/// 
16/// Typically, place following `.cargo/config.toml` at project/workspace root :
17/// 
18/// ---
19/// ```toml
20/// [build]
21/// target = "wasm32-unknown-unknown"
22/// ```
23/// ---
24/// 
25/// ## Example
26/// 
27/// ```toml
28/// # Cargo.toml
29/// 
30/// [dependencies]
31/// sqlx_d1 = { version = "0.1", features = ["macros"] }
32/// worker = { version = "0.5", features = ["d1"] }
33/// serde = { version = "1.0", features = ["derive"] }
34/// ```
35/// ```toml
36/// # wrangler.toml
37/// 
38/// [[d1_database]]
39/// binding = "DB"
40/// database_name = "..."
41/// database_id = "..."
42/// ```
43/// ```rust,ignore
44/// // src/lib.rs
45/// 
46/// #[worker::event(fetch)]
47/// async fn main(
48///     mut req: worker::Request,
49///     env: worker::Env,
50///     _ctx: worker::Context,
51/// ) -> worker::Result<worker::Response> {
52///     let d1 = env.d1("DB")?;
53///     let conn = sqlx_d1::D1Connection::new(d1);
54/// 
55///     #[derive(serde::Deserialize)]
56///     struct CreateUser {
57///         name: String,
58///         age: Option<u8>,
59///     }
60/// 
61///     let req = req.json::<CreateUser>().await?;
62/// 
63///     let id = sqlx_d1::query!(
64///         "
65///         INSERT INTO users (name, age) VALUES (?, ?)
66///         RETURNING id
67///         ",
68///             req.name,
69///             req.age
70///         )
71///         .fetch_one(&conn)
72///         .await
73///         .map_err(|e| worker::Error::RustError(e.to_string()))?
74///         .id;
75/// 
76///     worker::Response::ok(format!("Your id is {id}!"))
77/// }
78/// ```
79pub struct D1Connection {
80    #[cfg(target_arch = "wasm32")]
81    pub(crate) inner: worker_sys::D1Database,
82
83    #[cfg(not(target_arch = "wasm32"))]
84    pub(crate) inner: sqlx_sqlite::SqliteConnection,
85}
86
87const _: () = {
88    /* SAFETY: used in single-threaded Workers */
89    unsafe impl Send for D1Connection {}
90    unsafe impl Sync for D1Connection {}
91
92    impl D1Connection {
93        #[cfg(target_arch = "wasm32")]
94        pub fn new(d1: worker::D1Database) -> Self {
95            Self { inner: unsafe {std::mem::transmute(d1)} }
96        }
97
98        #[cfg(not(target_arch = "wasm32"))]
99        pub async fn connect(url: impl AsRef<str>) -> Result<Self, sqlx_core::Error> {
100            <Self as sqlx_core::connection::Connection>::connect(url.as_ref()).await
101        }
102    }
103
104    #[cfg(target_arch = "wasm32")]
105    impl Clone for D1Connection {
106        fn clone(&self) -> Self {
107            Self { inner: self.inner.clone() }
108        }
109    }
110
111    impl std::fmt::Debug for D1Connection {
112        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113            f.debug_struct("D1Connection").finish()
114        }
115    }
116
117    impl sqlx_core::connection::Connection for D1Connection {
118        type Database = crate::D1;
119
120        type Options = D1ConnectOptions;
121
122        fn close(self) -> crate::ResultFuture<'static, ()> {
123            Box::pin(async {Ok(())})
124        }
125
126        fn close_hard(self) -> crate::ResultFuture<'static, ()> {
127            Box::pin(async {Ok(())})
128        }
129
130        fn ping(&mut self) -> crate::ResultFuture<'_, ()> {
131            Box::pin(async {Ok(())})
132        }
133
134        fn begin(&mut self) -> crate::ResultFuture<'_, sqlx_core::transaction::Transaction<'_, Self::Database>>
135        where
136            Self: Sized,
137        {
138            sqlx_core::transaction::Transaction::begin(self)
139        }
140
141        fn shrink_buffers(&mut self) {
142            /* do nothing */
143        }
144
145        fn flush(&mut self) -> crate::ResultFuture<'_, ()> {
146            Box::pin(async {Ok(())})
147        }
148
149        fn should_flush(&self) -> bool {
150            false
151        }
152    }
153
154    impl<'c> sqlx_core::executor::Executor<'c> for &'c mut D1Connection {
155        type Database = crate::D1;
156
157        fn fetch_many<'e, 'q: 'e, E>(
158            self,
159            #[allow(unused)]
160            mut query: E,
161        ) -> futures_core::stream::BoxStream<
162            'e,
163            Result<
164                Either<
165                    <Self::Database as sqlx_core::database::Database>::QueryResult,
166                    <Self::Database as sqlx_core::database::Database>::Row
167                >,
168                sqlx_core::Error,
169            >,
170        >
171        where
172            'c: 'e,
173            E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
174        {
175            #[cfg(not(target_arch = "wasm32"))] {
176                unreachable!("native `Executor` impl")
177            }
178            #[cfg(target_arch = "wasm32")] {
179                <&'c D1Connection as sqlx_core::executor::Executor<'c>>::fetch_many(self, query)
180            }
181        }
182
183        fn fetch_optional<'e, 'q: 'e, E>(
184            self,
185            #[allow(unused)]
186            mut query: E,
187        ) -> crate::ResultFuture<'e, Option<<Self::Database as sqlx_core::database::Database>::Row>>
188        where
189            'c: 'e,
190            E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
191        {
192            #[cfg(not(target_arch = "wasm32"))] {
193                unreachable!("native `Executor` impl")
194            }
195            #[cfg(target_arch = "wasm32")] {
196                <&'c D1Connection as sqlx_core::executor::Executor<'c>>::fetch_optional(self, query)
197            }
198        }
199
200        fn prepare_with<'e, 'q: 'e>(
201            self,
202            sql: &'q str,
203            _parameters: &'e [<Self::Database as sqlx_core::database::Database>::TypeInfo],
204        ) -> crate::ResultFuture<'e, <Self::Database as sqlx_core::database::Database>::Statement<'q>>
205        where
206            'c: 'e,
207        {
208            Box::pin(async {
209                Ok(crate::statement::D1Statement {
210                    sql: std::borrow::Cow::Borrowed(sql),
211                })
212            })
213        }
214
215        fn describe<'e, 'q: 'e>(
216            self,
217            #[allow(unused)]
218            sql: &'q str,
219        ) -> crate::ResultFuture<'e, sqlx_core::describe::Describe<Self::Database>>
220        where
221            'c: 'e,
222        {
223            #[cfg(target_arch = "wasm32")] {
224                unreachable!("wasm32 describe")
225            }
226            #[cfg(not(target_arch = "wasm32"))] {
227                /* compile-time verification by macros */
228
229                Box::pin(async {
230                    let sqlx_core::describe::Describe {
231                        columns,
232                        parameters,
233                        nullable
234                    } = <&mut sqlx_sqlite::SqliteConnection as sqlx_core::executor::Executor>::describe(
235                        &mut self.inner,
236                        sql
237                    ).await?;
238                    
239                    Ok(sqlx_core::describe::Describe {
240                        parameters: parameters.map(|ps| match ps {
241                            Either::Left(type_infos) => Either::Left(type_infos.into_iter().map(crate::type_info::D1TypeInfo::from_sqlite).collect()),
242                            Either::Right(n) => Either::Right(n)
243                        }),
244                        columns: columns.into_iter().map(crate::column::D1Column::from_sqlite).collect(),
245                        nullable
246                    })
247                })
248            }
249        }
250    }
251
252    #[cfg(target_arch = "wasm32")]
253    impl<'c> sqlx_core::executor::Executor<'c> for &'c D1Connection {
254        type Database = crate::D1;
255
256        fn fetch_many<'e, 'q: 'e, E>(
257            self,
258            #[allow(unused)]
259            mut query: E,
260        ) -> futures_core::stream::BoxStream<
261            'e,
262            Result<
263                Either<
264                    <Self::Database as sqlx_core::database::Database>::QueryResult,
265                    <Self::Database as sqlx_core::database::Database>::Row
266                >,
267                sqlx_core::Error,
268            >,
269        >
270        where
271            'c: 'e,
272            E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
273        {
274            let sql = query.sql();
275            let arguments = match query.take_arguments() {
276                Ok(a) => a,
277                Err(e) => return Box::pin(futures_util::stream::once(async {Err(sqlx_core::Error::Encode(e))})),
278            };
279
280            struct FetchMany<F> {
281                raw_rows_future: F,
282                raw_rows: Option<js_sys::ArrayIntoIter>,
283            }
284            const _: () = {
285                /* SAFETY: used in single-threaded Workers */
286                unsafe impl<F> Send for FetchMany<F> {}
287
288                impl<F> FetchMany<F> {
289                    fn new(raw_rows_future: F) -> Self {
290                        Self { raw_rows_future, raw_rows: None }
291                    }
292                }
293
294                impl<F> futures_core::Stream for FetchMany<F>
295                where
296                    F: Future<Output = Result<Option<js_sys::Array>, JsValue>>,
297                {
298                    type Item = Result<
299                        Either<crate::query_result::D1QueryResult, D1Row>,
300                        sqlx_core::Error
301                    >;
302
303                    fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
304                        use std::task::Poll;
305
306                        fn pop_next(raw_rows: &mut js_sys::ArrayIntoIter) ->
307                            Option<Result<
308                                Either<crate::query_result::D1QueryResult, D1Row>,
309                                sqlx_core::Error
310                            >>
311                        {
312                            let raw_row = raw_rows.next()?;
313                            Some(D1Row::from_raw(raw_row).map(Either::Right))
314                        }
315
316                        let this = unsafe {self.get_unchecked_mut()};
317                        match &mut this.raw_rows {
318                            Some(raw_rows) => Poll::Ready(pop_next(raw_rows)),
319                            None => match unsafe {Pin::new_unchecked(&mut this.raw_rows_future)}.poll(cx) {
320                                Poll::Pending => Poll::Pending,
321                                Poll::Ready(Err(e)) => Poll::Ready(Some(Err(
322                                    sqlx_core::Error::from(D1Error::from(e))
323                                ))),
324                                Poll::Ready(Ok(maybe_raw_rows)) => {
325                                    this.raw_rows = Some(maybe_raw_rows.unwrap_or_else(js_sys::Array::new).into_iter());
326                                    Poll::Ready(pop_next(unsafe {this.raw_rows.as_mut().unwrap_unchecked()}))
327                                }
328                            }
329                        }                        
330                    }
331                }
332            };
333
334            Box::pin(FetchMany::new(async move {
335                let mut statement = self.inner.prepare(sql).unwrap();
336                if let Some(a) = arguments {
337                    statement = statement.bind(a.as_ref().iter().collect())?;
338                }
339
340                let d1_result_jsvalue = JsFuture::from(statement.all()?)
341                    .await?;
342                worker_sys::D1Result::from(d1_result_jsvalue)
343                    .results()
344            }))
345        }
346
347        fn fetch_optional<'e, 'q: 'e, E>(
348            self,
349            #[allow(unused)]
350            mut query: E,
351        ) -> crate::ResultFuture<'e, Option<<Self::Database as sqlx_core::database::Database>::Row>>
352        where
353            'c: 'e,
354            E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
355        {
356            let sql = query.sql();
357            let arguments = match query.take_arguments() {
358                Ok(a) => a,
359                Err(e) => return Box::pin(async {Err(sqlx_core::Error::Encode(e))}),
360            };
361
362            Box::pin(worker::send::SendFuture::new(async move {
363                let mut statement = self.inner.prepare(sql).unwrap();
364                if let Some(a) = arguments {
365                    statement = statement
366                        .bind(a.as_ref().iter().collect())
367                        .map_err(|e| sqlx_core::Error::Encode(Box::new(D1Error::from(e))))?;
368                }
369
370                let raw = JsFuture::from(statement.first(None).map_err(D1Error::from)?)
371                    .await
372                    .map_err(D1Error::from)?;
373                if raw.is_null() {
374                    Ok(None)
375                } else {
376                    D1Row::from_raw(raw).map(Some)
377                }
378            }))
379        }
380
381        fn prepare_with<'e, 'q: 'e>(
382            self,
383            sql: &'q str,
384            _parameters: &'e [<Self::Database as sqlx_core::database::Database>::TypeInfo],
385        ) -> crate::ResultFuture<'e, <Self::Database as sqlx_core::database::Database>::Statement<'q>>
386        where
387            'c: 'e,
388        {
389            Box::pin(async {
390                Ok(crate::statement::D1Statement {
391                    sql: std::borrow::Cow::Borrowed(sql),
392                })
393            })
394        }
395
396        fn describe<'e, 'q: 'e>(
397            self,
398            #[allow(unused)]
399            sql: &'q str,
400        ) -> crate::ResultFuture<'e, sqlx_core::describe::Describe<Self::Database>>
401        where
402            'c: 'e,
403        {
404            unreachable!("wasm32 describe")
405        }
406    }
407};
408
409/// ref: <https://developers.cloudflare.com/d1/sql-api/sql-statements/#compatible-pragma-statements>
410#[derive(Clone)]
411pub struct D1ConnectOptions {
412    pragmas: TogglePragmas,
413    #[cfg(target_arch = "wasm32")]
414    d1: worker_sys::D1Database,
415    #[cfg(not(target_arch = "wasm32"))]
416    sqlite_path: std::path::PathBuf,
417}
418const _: () = {
419    /* SAFETY: used in single-threaded Workers */
420    unsafe impl Send for D1ConnectOptions {}
421    unsafe impl Sync for D1ConnectOptions {}
422
423    #[cfg(target_arch = "wasm32")]
424    const URL_CONVERSION_UNSUPPORTED_MESSAGE: &'static str = "\
425        `sqlx_d1::D1ConnectOptions` doesn't support conversion between `Url`. \
426        Consider connect from options created by `D1ConnectOptions::new`. \
427    ";
428
429    const LOG_SETTINGS_UNSUPPORTED_MESSAGE: &'static str = "\
430        `sqlx_d1::D1ConnectOptions` doesn't support log settings.
431    ";
432
433    impl D1ConnectOptions {
434        #[cfg(target_arch = "wasm32")]
435        pub fn new(d1: worker::D1Database) -> Self {
436            Self {
437                d1: unsafe {core::mem::transmute(d1)},
438                pragmas: TogglePragmas::new(),
439            }
440        }
441    }
442
443    impl std::fmt::Debug for D1ConnectOptions {
444        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
445            f.debug_struct("D1ConnectOptions")
446                .field("pragmas", &self.pragmas)
447                .finish()
448        }
449    }
450
451    impl std::str::FromStr for D1ConnectOptions {
452        type Err = sqlx_core::Error;
453
454        fn from_str(_: &str) -> Result<Self, Self::Err> {
455            #[cfg(target_arch = "wasm32")] {
456                Err(sqlx_core::Error::Configuration(From::from(
457                    URL_CONVERSION_UNSUPPORTED_MESSAGE
458                )))
459            }
460
461            #[cfg(not(target_arch = "wasm32"))] {
462                use std::{io, fs, path::{Path, PathBuf}};
463
464                fn maybe_miniflare_d1_dir_of(dir: impl AsRef<Path>) -> PathBuf {
465                    dir.as_ref()
466                        .join(".wrangler")
467                        .join("state")
468                        .join("v3")
469                        .join("d1")
470                        .join("miniflare-D1DatabaseObject")
471                }
472            
473                const PACKAGE_ROOT: &str = env!("CARGO_MANIFEST_DIR");
474
475                let (candidate_1, candidate_2) = (
476                    maybe_miniflare_d1_dir_of(PACKAGE_ROOT),
477                    maybe_miniflare_d1_dir_of(".")
478                );
479
480                let sqlite_path = (|| -> io::Result<PathBuf> {                    
481                    let miniflare_d1_dir = match (
482                        fs::exists(&candidate_1),
483                        fs::exists(&candidate_2)
484                    ) {
485                        (Ok(true), _) => candidate_1,
486                        (_, Ok(true)) => candidate_2,
487                        (Err(e), _) | (_, Err(e)) => return Err(e),
488                        (Ok(false), Ok(false)) => return Err(io::Error::new(
489                            io::ErrorKind::NotFound,
490                            "miniflare's D1 emulating directory not found"
491                        )),
492                    };
493                    
494                    let [sqlite_path] = fs::read_dir(miniflare_d1_dir)?
495                        .filter_map(|r| r.as_ref().ok().and_then(|e| {
496                            let path = e.path();
497                            path.extension()
498                                .is_some_and(|ex| ex == "sqlite")
499                                .then_some(path)
500                        }))
501                        .collect::<Vec<_>>()
502                        .try_into()
503                        .map_err(|_| io::Error::new(
504                            io::ErrorKind::Other,
505                            "Currently, sqlx_d1 doesn't support multiple D1 bindings!"
506                        ))?;
507
508                    Ok(sqlite_path)
509                })().map_err(|_| sqlx_core::Error::WorkerCrashed)?;
510                    
511                Ok(Self {
512                    pragmas: TogglePragmas::new(),
513                    sqlite_path
514                })
515            }
516        }
517    }
518
519    impl sqlx_core::connection::ConnectOptions for D1ConnectOptions {
520        type Connection = D1Connection;
521
522        fn from_url(_url: &Url) -> Result<Self, sqlx_core::Error> {
523            #[cfg(target_arch = "wasm32")] {
524                Err(sqlx_core::Error::Configuration(From::from(
525                    URL_CONVERSION_UNSUPPORTED_MESSAGE
526                )))
527            }
528            #[cfg(not(target_arch = "wasm32"))] {
529                _url.as_str().parse()
530            }
531        }
532
533        fn to_url_lossy(&self) -> Url {
534            unreachable!("`sqlx_d1::ConnectOptions` doesn't support `ConnectOptions::to_url_lossy`")
535        }
536
537        fn connect(&self) -> crate::ResultFuture<'_, Self::Connection>
538        where
539            Self::Connection: Sized,
540        {
541            #[cfg(target_arch = "wasm32")] {
542                Box::pin(worker::send::SendFuture::new(async move {
543                    let d1 = self.d1.clone();
544
545                    if let Some(pragmas) = self.pragmas.collect() {
546                        JsFuture::from(d1.exec(&pragmas.join("\n")).map_err(D1Error::from)?)
547                            .await
548                            .map_err(D1Error::from)?;
549                    }
550
551                    Ok(D1Connection {
552                        inner: d1
553                    })
554                }))
555            }
556
557            #[cfg(not(target_arch = "wasm32"))] {
558                Box::pin(async move {
559                    use sqlx_core::{connection::Connection, executor::Executor};
560
561                    let mut sqlite_conn = sqlx_sqlite::SqliteConnection::connect(
562                        self.sqlite_path.to_str().ok_or(sqlx_core::Error::WorkerCrashed)?
563                    ).await?;
564
565                    if let Some(pragmas) = self.pragmas.collect() {
566                        for pragma in pragmas {
567                            sqlite_conn.execute(pragma).await?;
568                        }
569                    }
570                    
571                    Ok(D1Connection { inner: sqlite_conn })
572                })
573            }
574        }
575
576        fn log_statements(self, _: log::LevelFilter) -> Self {
577            unreachable!("{LOG_SETTINGS_UNSUPPORTED_MESSAGE}")
578        }
579
580        fn log_slow_statements(self, _: log::LevelFilter, _: std::time::Duration) -> Self {
581            unreachable!("{LOG_SETTINGS_UNSUPPORTED_MESSAGE}")
582        }
583    }
584};
585
/// Bit set of the D1-compatible pragmas to turn on when a connection is opened.
///
/// ref: <https://developers.cloudflare.com/d1/sql-api/sql-statements/#compatible-pragma-statements>
#[derive(Clone, Copy)]
struct TogglePragmas(u8);
const _: () = {
    impl std::ops::Not for TogglePragmas {
        type Output = Self;
        /// Complement of the set (used to clear a flag via `&= !flag`).
        fn not(self) -> Self::Output {
            Self(!self.0)
        }
    }
    impl std::ops::BitOrAssign for TogglePragmas {
        /// Adds every flag set in `rhs`.
        fn bitor_assign(&mut self, rhs: Self) {
            self.0 |= rhs.0;
        }
    }
    impl std::ops::BitAndAssign for TogglePragmas {
        /// Keeps only the flags also set in `rhs`.
        fn bitand_assign(&mut self, rhs: Self) {
            self.0 &= rhs.0;
        }
    }

    impl TogglePragmas {
        /// Empty set: no pragma enabled.
        const fn new() -> Self {
            Self(0)
        }
    }
};
613
/// For each `name as bits;` entry, generates:
/// - an associated constant `TogglePragmas::name` equal to `Self(bits)`,
/// - a `TogglePragmas::collect` check emitting `PRAGMA name = on` when the bit is set,
/// - a `"name": "on" | "off"` entry in `TogglePragmas`'s `Debug` output,
/// - a public builder method `D1ConnectOptions::name(bool)`.
macro_rules! toggles {
    ($( $name:ident as $bits:literal; )*) => {
        impl TogglePragmas {
            $(
                #[allow(non_upper_case_globals)]
                const $name: Self = Self($bits);
            )*

            /// `PRAGMA <name> = on` for each enabled toggle, or `None` when none is enabled.
            fn collect(&self) -> Option<Vec<&'static str>> {
                #[allow(unused_mut)]
                let mut pragmas = Vec::new();
                $(
                    if self.0 & Self::$name.0 != 0 {
                        pragmas.push(concat!(
                            "PRAGMA ",
                            stringify!($name),
                            " = on"
                        ));
                    }
                )*
                (!pragmas.is_empty()).then_some(pragmas)
            }
        }

        impl std::fmt::Debug for TogglePragmas {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                #[allow(unused_mut)]
                let mut map = f.debug_map();
                $(
                    map.entry(
                        &stringify!($name),
                        &if self.0 & Self::$name.0 != 0 {"on"} else {"off"}
                    );
                )*
                map.finish()
            }
        }

        impl D1ConnectOptions {
            $(
                #[doc = concat!(
                    "`true` runs `PRAGMA ", stringify!($name), " = on` when connecting; ",
                    "`false` (the default) leaves this pragma unset."
                )]
                #[must_use]
                pub fn $name(mut self, yes: bool) -> Self {
                    if yes {
                        self.pragmas |= TogglePragmas::$name;
                    } else {
                        self.pragmas &= !TogglePragmas::$name;
                    }
                    self
                }
            )*
        }
    };
}
665toggles! {
666    case_sensitive_like     as 0b0000001;
667    ignore_check_constraint as 0b0000010;
668    legacy_alter_table      as 0b0000100;
669    recursive_triggers      as 0b0001000;
670    unordered_selects       as 0b0010000;
671    foreign_keys            as 0b0100000;
672    defer_foreign_keys      as 0b1000000;
673}