sqlx_macros_core/database/mod.rs

1use sqlx_core::config;
2use sqlx_core::connection::Connection;
3use sqlx_core::database::Database;
4use sqlx_core::describe::Describe;
5use sqlx_core::executor::Executor;
6use sqlx_core::sql_str::AssertSqlSafe;
7use sqlx_core::sql_str::SqlSafeStr;
8use sqlx_core::type_checking::TypeChecking;
9use std::collections::hash_map;
10use std::collections::HashMap;
11use std::sync::{LazyLock, Mutex};
12
// Presumably holds the per-driver `DatabaseExt` implementations; it is only compiled when at
// least one of the driver features below is enabled.
#[cfg(any(feature = "_sqlite", feature = "mysql", feature = "postgres"))]
mod impls;
15
16pub trait DatabaseExt: Database + TypeChecking {
17    const DATABASE_PATH: &'static str;
18    const ROW_PATH: &'static str;
19
20    fn db_path() -> syn::Path {
21        syn::parse_str(Self::DATABASE_PATH).unwrap()
22    }
23
24    fn row_path() -> syn::Path {
25        syn::parse_str(Self::ROW_PATH).unwrap()
26    }
27
28    fn describe_blocking(
29        query: &str,
30        database_url: &str,
31        driver_config: &config::drivers::Config,
32    ) -> sqlx_core::Result<Describe<Self>>;
33}
34
35#[allow(dead_code)]
36pub struct CachingDescribeBlocking<DB: DatabaseExt> {
37    connections: LazyLock<Mutex<HashMap<String, DB::Connection>>>,
38}
39
40#[allow(dead_code)]
41impl<DB: DatabaseExt> CachingDescribeBlocking<DB> {
42    #[allow(clippy::new_without_default, reason = "internal API")]
43    pub const fn new() -> Self {
44        CachingDescribeBlocking {
45            connections: LazyLock::new(|| Mutex::new(HashMap::new())),
46        }
47    }
48
49    pub fn describe(
50        &self,
51        query: &str,
52        database_url: &str,
53        _driver_config: &config::drivers::Config,
54    ) -> sqlx_core::Result<Describe<DB>>
55    where
56        for<'a> &'a mut DB::Connection: Executor<'a, Database = DB>,
57    {
58        let mut cache = self
59            .connections
60            .lock()
61            .expect("previous panic in describe call");
62
63        crate::block_on(async {
64            let conn = match cache.entry(database_url.to_string()) {
65                hash_map::Entry::Occupied(hit) => hit.into_mut(),
66                hash_map::Entry::Vacant(miss) => {
67                    let conn = miss.insert(DB::Connection::connect(database_url).await?);
68
69                    #[cfg(feature = "postgres")]
70                    if DB::NAME == sqlx_postgres::Postgres::NAME {
71                        conn.execute(
72                            "
73                            DO $$
74                            BEGIN
75                                IF EXISTS (
76                                    SELECT 1
77                                    FROM pg_settings
78                                    WHERE name = 'plan_cache_mode'
79                                ) THEN
80                                    SET SESSION plan_cache_mode = 'force_generic_plan';
81                                END IF;
82                            END $$;
83                        ",
84                        )
85                        .await?;
86                    }
87                    conn
88                }
89            };
90
91            match conn
92                .describe(AssertSqlSafe(query.to_string()).into_sql_str())
93                .await
94            {
95                Ok(describe) => Ok(describe),
96                Err(e) => {
97                    if matches!(e, sqlx_core::Error::Io(_) | sqlx_core::Error::Protocol(_)) {
98                        cache.remove(database_url);
99                    }
100
101                    Err(e)
102                }
103            }
104        })
105    }
106}