tauri_plugin_sql/
lib.rs

1// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
2// SPDX-License-Identifier: Apache-2.0
3// SPDX-License-Identifier: MIT
4
5//! Interface with SQL databases through [sqlx](https://github.com/launchbadge/sqlx). It supports the `sqlite`, `mysql` and `postgres` drivers, enabled by a Cargo feature.
6
7#![doc(
8    html_logo_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png",
9    html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
10)]
11
12mod commands;
13mod decode;
14mod error;
15mod wrapper;
16
17pub use error::Error;
18pub use wrapper::DbPool;
19
20use futures_core::future::BoxFuture;
21use serde::{Deserialize, Serialize};
22use sqlx::{
23    error::BoxDynError,
24    migrate::{Migration as SqlxMigration, MigrationSource, MigrationType, Migrator},
25};
26use tauri::{
27    plugin::{Builder as PluginBuilder, TauriPlugin},
28    Manager, RunEvent, Runtime,
29};
30use tokio::sync::{Mutex, RwLock};
31
32use std::collections::HashMap;
33
34#[derive(Default)]
35pub struct DbInstances(pub RwLock<HashMap<String, DbPool>>);
36
37#[derive(Serialize)]
38#[serde(untagged)]
39pub(crate) enum LastInsertId {
40    #[cfg(feature = "sqlite")]
41    Sqlite(i64),
42    #[cfg(feature = "mysql")]
43    MySql(u64),
44    #[cfg(feature = "postgres")]
45    Postgres(()),
46    #[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
47    None,
48}
49
50struct Migrations(Mutex<HashMap<String, MigrationList>>);
51
52#[derive(Default, Clone, Deserialize)]
53pub struct PluginConfig {
54    #[serde(default)]
55    preload: Vec<String>,
56}
57
/// Direction of a [`Migration`].
///
/// `Up` converts to sqlx's `MigrationType::ReversibleUp` and `Down` to
/// `MigrationType::ReversibleDown`. This fieldless enum is `Copy` and
/// comparable, so callers can use `==` instead of `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MigrationKind {
    /// Forward migration.
    Up,
    /// Reverting migration.
    Down,
}
63
64impl From<MigrationKind> for MigrationType {
65    fn from(kind: MigrationKind) -> Self {
66        match kind {
67            MigrationKind::Up => Self::ReversibleUp,
68            MigrationKind::Down => Self::ReversibleDown,
69        }
70    }
71}
72
73/// A migration definition.
74#[derive(Debug)]
75pub struct Migration {
76    pub version: i64,
77    pub description: &'static str,
78    pub sql: &'static str,
79    pub kind: MigrationKind,
80}
81
82#[derive(Debug)]
83struct MigrationList(Vec<Migration>);
84
85impl MigrationSource<'static> for MigrationList {
86    fn resolve(self) -> BoxFuture<'static, std::result::Result<Vec<SqlxMigration>, BoxDynError>> {
87        Box::pin(async move {
88            let mut migrations = Vec::new();
89            for migration in self.0 {
90                if matches!(migration.kind, MigrationKind::Up) {
91                    migrations.push(SqlxMigration::new(
92                        migration.version,
93                        migration.description.into(),
94                        migration.kind.into(),
95                        migration.sql.into(),
96                        false,
97                    ));
98                }
99            }
100            Ok(migrations)
101        })
102    }
103}
104
105/// Allows blocking on async code without creating a nested runtime.
106fn run_async_command<F: std::future::Future>(cmd: F) -> F::Output {
107    if tokio::runtime::Handle::try_current().is_ok() {
108        tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(cmd))
109    } else {
110        tauri::async_runtime::block_on(cmd)
111    }
112}
113
114/// Tauri SQL plugin builder.
115#[derive(Default)]
116pub struct Builder {
117    migrations: Option<HashMap<String, MigrationList>>,
118}
119
120impl Builder {
121    pub fn new() -> Self {
122        #[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgres")))]
123        eprintln!("No sql driver enabled. Please set at least one of the \"sqlite\", \"mysql\", \"postgres\" feature flags.");
124
125        Self::default()
126    }
127
128    /// Add migrations to a database.
129    #[must_use]
130    pub fn add_migrations(mut self, db_url: &str, migrations: Vec<Migration>) -> Self {
131        self.migrations
132            .get_or_insert(Default::default())
133            .insert(db_url.to_string(), MigrationList(migrations));
134        self
135    }
136
137    pub fn build<R: Runtime>(mut self) -> TauriPlugin<R, Option<PluginConfig>> {
138        PluginBuilder::<R, Option<PluginConfig>>::new("sql")
139            .invoke_handler(tauri::generate_handler![
140                commands::load,
141                commands::execute,
142                commands::select,
143                commands::close
144            ])
145            .setup(|app, api| {
146                let config = api.config().clone().unwrap_or_default();
147
148                run_async_command(async move {
149                    let instances = DbInstances::default();
150                    let mut lock = instances.0.write().await;
151
152                    for db in config.preload {
153                        let pool = DbPool::connect(&db, app).await?;
154
155                        if let Some(migrations) =
156                            self.migrations.as_mut().and_then(|mm| mm.remove(&db))
157                        {
158                            let migrator = Migrator::new(migrations).await?;
159                            pool.migrate(&migrator).await?;
160                        }
161
162                        lock.insert(db, pool);
163                    }
164                    drop(lock);
165
166                    app.manage(instances);
167                    app.manage(Migrations(Mutex::new(
168                        self.migrations.take().unwrap_or_default(),
169                    )));
170
171                    Ok(())
172                })
173            })
174            .on_event(|app, event| {
175                if let RunEvent::Exit = event {
176                    run_async_command(async move {
177                        let instances = &*app.state::<DbInstances>();
178                        let instances = instances.0.read().await;
179                        for value in instances.values() {
180                            value.close().await;
181                        }
182                    });
183                }
184            })
185            .build()
186    }
187}