tauri_plugin_sql/
wrapper.rs

1// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
2// SPDX-License-Identifier: Apache-2.0
3// SPDX-License-Identifier: MIT
4
5#[cfg(feature = "sqlite")]
6use std::fs::create_dir_all;
7
8use indexmap::IndexMap;
9use serde_json::Value as JsonValue;
10#[cfg(any(feature = "sqlite", feature = "mysql", feature = "postgres"))]
11use sqlx::{migrate::MigrateDatabase, Column, Executor, Pool, Row};
12#[cfg(any(feature = "sqlite", feature = "mysql", feature = "postgres"))]
13use tauri::Manager;
14use tauri::{AppHandle, Runtime};
15
16#[cfg(feature = "mysql")]
17use sqlx::MySql;
18#[cfg(feature = "postgres")]
19use sqlx::Postgres;
20#[cfg(feature = "sqlite")]
21use sqlx::Sqlite;
22
23use crate::LastInsertId;
24
/// A database connection pool for one of the backends compiled into the plugin.
///
/// Each backend variant only exists when its Cargo feature (`sqlite`, `mysql`
/// or `postgres`) is enabled, so matches over this enum are feature-gated too.
pub enum DbPool {
    /// Pool of SQLite connections (feature `sqlite`).
    #[cfg(feature = "sqlite")]
    Sqlite(Pool<Sqlite>),
    /// Pool of MySQL connections (feature `mysql`).
    #[cfg(feature = "mysql")]
    MySql(Pool<MySql>),
    /// Pool of PostgreSQL connections (feature `postgres`).
    #[cfg(feature = "postgres")]
    Postgres(Pool<Postgres>),
    /// Placeholder variant that exists only when no backend feature is enabled.
    #[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
    None,
}
35
36// public methods
37/* impl DbPool {
38    /// Get the inner Sqlite Pool. Returns None for MySql and Postgres pools.
39    #[cfg(feature = "sqlite")]
40    pub fn sqlite(&self) -> Option<&Pool<Sqlite>> {
41        match self {
42            DbPool::Sqlite(pool) => Some(pool),
43            _ => None,
44        }
45    }
46
47    /// Get the inner MySql Pool. Returns None for Sqlite and Postgres pools.
48    #[cfg(feature = "mysql")]
49    pub fn mysql(&self) -> Option<&Pool<MySql>> {
50        match self {
51            DbPool::MySql(pool) => Some(pool),
52            _ => None,
53        }
54    }
55
56    /// Get the inner Postgres Pool. Returns None for MySql and Sqlite pools.
57    #[cfg(feature = "postgres")]
58    pub fn postgres(&self) -> Option<&Pool<Postgres>> {
59        match self {
60            DbPool::Postgres(pool) => Some(pool),
61            _ => None,
62        }
63    }
64} */
65
66// private methods
67impl DbPool {
68    pub(crate) async fn connect<R: Runtime>(
69        conn_url: &str,
70        _app: &AppHandle<R>,
71    ) -> Result<Self, crate::Error> {
72        match conn_url
73            .split_once(':')
74            .ok_or_else(|| crate::Error::InvalidDbUrl(conn_url.to_string()))?
75            .0
76        {
77            #[cfg(feature = "sqlite")]
78            "sqlite" => {
79                let app_path = _app
80                    .path()
81                    .app_config_dir()
82                    .expect("No App config path was found!");
83
84                create_dir_all(&app_path).expect("Couldn't create app config dir");
85
86                let conn_url = &path_mapper(app_path, conn_url);
87
88                if !Sqlite::database_exists(conn_url).await.unwrap_or(false) {
89                    Sqlite::create_database(conn_url).await?;
90                }
91                Ok(Self::Sqlite(Pool::connect(conn_url).await?))
92            }
93            #[cfg(feature = "mysql")]
94            "mysql" => {
95                if !MySql::database_exists(conn_url).await.unwrap_or(false) {
96                    MySql::create_database(conn_url).await?;
97                }
98                Ok(Self::MySql(Pool::connect(conn_url).await?))
99            }
100            #[cfg(feature = "postgres")]
101            "postgres" => {
102                if !Postgres::database_exists(conn_url).await.unwrap_or(false) {
103                    Postgres::create_database(conn_url).await?;
104                }
105                Ok(Self::Postgres(Pool::connect(conn_url).await?))
106            }
107            #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mysql")))]
108            _ => Err(crate::Error::InvalidDbUrl(format!(
109                "{conn_url} - No database driver enabled!"
110            ))),
111            #[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
112            _ => Err(crate::Error::InvalidDbUrl(conn_url.to_string())),
113        }
114    }
115
116    pub(crate) async fn migrate(
117        &self,
118        _migrator: &sqlx::migrate::Migrator,
119    ) -> Result<(), crate::Error> {
120        match self {
121            #[cfg(feature = "sqlite")]
122            DbPool::Sqlite(pool) => _migrator.run(pool).await?,
123            #[cfg(feature = "mysql")]
124            DbPool::MySql(pool) => _migrator.run(pool).await?,
125            #[cfg(feature = "postgres")]
126            DbPool::Postgres(pool) => _migrator.run(pool).await?,
127            #[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
128            DbPool::None => (),
129        }
130        Ok(())
131    }
132
133    pub(crate) async fn close(&self) {
134        match self {
135            #[cfg(feature = "sqlite")]
136            DbPool::Sqlite(pool) => pool.close().await,
137            #[cfg(feature = "mysql")]
138            DbPool::MySql(pool) => pool.close().await,
139            #[cfg(feature = "postgres")]
140            DbPool::Postgres(pool) => pool.close().await,
141            #[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
142            DbPool::None => (),
143        }
144    }
145
146    pub(crate) async fn execute(
147        &self,
148        _query: String,
149        _values: Vec<JsonValue>,
150    ) -> Result<(u64, LastInsertId), crate::Error> {
151        Ok(match self {
152            #[cfg(feature = "sqlite")]
153            DbPool::Sqlite(pool) => {
154                let mut query = sqlx::query(&_query);
155                for value in _values {
156                    if value.is_null() {
157                        query = query.bind(None::<JsonValue>);
158                    } else if value.is_string() {
159                        query = query.bind(value.as_str().unwrap().to_owned())
160                    } else if let Some(number) = value.as_number() {
161                        query = query.bind(number.as_f64().unwrap_or_default())
162                    } else {
163                        query = query.bind(value);
164                    }
165                }
166                let result = pool.execute(query).await?;
167                (
168                    result.rows_affected(),
169                    LastInsertId::Sqlite(result.last_insert_rowid()),
170                )
171            }
172            #[cfg(feature = "mysql")]
173            DbPool::MySql(pool) => {
174                let mut query = sqlx::query(&_query);
175                for value in _values {
176                    if value.is_null() {
177                        query = query.bind(None::<JsonValue>);
178                    } else if value.is_string() {
179                        query = query.bind(value.as_str().unwrap().to_owned())
180                    } else if let Some(number) = value.as_number() {
181                        query = query.bind(number.as_f64().unwrap_or_default())
182                    } else {
183                        query = query.bind(value);
184                    }
185                }
186                let result = pool.execute(query).await?;
187                (
188                    result.rows_affected(),
189                    LastInsertId::MySql(result.last_insert_id()),
190                )
191            }
192            #[cfg(feature = "postgres")]
193            DbPool::Postgres(pool) => {
194                let mut query = sqlx::query(&_query);
195                for value in _values {
196                    if value.is_null() {
197                        query = query.bind(None::<JsonValue>);
198                    } else if value.is_string() {
199                        query = query.bind(value.as_str().unwrap().to_owned())
200                    } else if let Some(number) = value.as_number() {
201                        query = query.bind(number.as_f64().unwrap_or_default())
202                    } else {
203                        query = query.bind(value);
204                    }
205                }
206                let result = pool.execute(query).await?;
207                (result.rows_affected(), LastInsertId::Postgres(()))
208            }
209            #[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
210            DbPool::None => (0, LastInsertId::None),
211        })
212    }
213
214    pub(crate) async fn select(
215        &self,
216        _query: String,
217        _values: Vec<JsonValue>,
218    ) -> Result<Vec<IndexMap<String, JsonValue>>, crate::Error> {
219        Ok(match self {
220            #[cfg(feature = "sqlite")]
221            DbPool::Sqlite(pool) => {
222                let mut query = sqlx::query(&_query);
223                for value in _values {
224                    if value.is_null() {
225                        query = query.bind(None::<JsonValue>);
226                    } else if value.is_string() {
227                        query = query.bind(value.as_str().unwrap().to_owned())
228                    } else if let Some(number) = value.as_number() {
229                        query = query.bind(number.as_f64().unwrap_or_default())
230                    } else {
231                        query = query.bind(value);
232                    }
233                }
234                let rows = pool.fetch_all(query).await?;
235                let mut values = Vec::new();
236                for row in rows {
237                    let mut value = IndexMap::default();
238                    for (i, column) in row.columns().iter().enumerate() {
239                        let v = row.try_get_raw(i)?;
240
241                        let v = crate::decode::sqlite::to_json(v)?;
242
243                        value.insert(column.name().to_string(), v);
244                    }
245
246                    values.push(value);
247                }
248                values
249            }
250            #[cfg(feature = "mysql")]
251            DbPool::MySql(pool) => {
252                let mut query = sqlx::query(&_query);
253                for value in _values {
254                    if value.is_null() {
255                        query = query.bind(None::<JsonValue>);
256                    } else if value.is_string() {
257                        query = query.bind(value.as_str().unwrap().to_owned())
258                    } else if let Some(number) = value.as_number() {
259                        query = query.bind(number.as_f64().unwrap_or_default())
260                    } else {
261                        query = query.bind(value);
262                    }
263                }
264                let rows = pool.fetch_all(query).await?;
265                let mut values = Vec::new();
266                for row in rows {
267                    let mut value = IndexMap::default();
268                    for (i, column) in row.columns().iter().enumerate() {
269                        let v = row.try_get_raw(i)?;
270
271                        let v = crate::decode::mysql::to_json(v)?;
272
273                        value.insert(column.name().to_string(), v);
274                    }
275
276                    values.push(value);
277                }
278                values
279            }
280            #[cfg(feature = "postgres")]
281            DbPool::Postgres(pool) => {
282                let mut query = sqlx::query(&_query);
283                for value in _values {
284                    if value.is_null() {
285                        query = query.bind(None::<JsonValue>);
286                    } else if value.is_string() {
287                        query = query.bind(value.as_str().unwrap().to_owned())
288                    } else if let Some(number) = value.as_number() {
289                        query = query.bind(number.as_f64().unwrap_or_default())
290                    } else {
291                        query = query.bind(value);
292                    }
293                }
294                let rows = pool.fetch_all(query).await?;
295                let mut values = Vec::new();
296                for row in rows {
297                    let mut value = IndexMap::default();
298                    for (i, column) in row.columns().iter().enumerate() {
299                        let v = row.try_get_raw(i)?;
300
301                        let v = crate::decode::postgres::to_json(v)?;
302
303                        value.insert(column.name().to_string(), v);
304                    }
305
306                    values.push(value);
307                }
308                values
309            }
310            #[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
311            DbPool::None => Vec::new(),
312        })
313    }
314}
315
316#[cfg(feature = "sqlite")]
317/// Maps the user supplied DB connection string to a connection string
318/// with a fully qualified file path to the App's designed "app_path"
319fn path_mapper(mut app_path: std::path::PathBuf, connection_string: &str) -> String {
320    app_path.push(
321        connection_string
322            .split_once(':')
323            .expect("Couldn't parse the connection string for DB!")
324            .1,
325    );
326
327    format!(
328        "sqlite:{}",
329        app_path
330            .to_str()
331            .expect("Problem creating fully qualified path to Database file!")
332    )
333}