Skip to main content

tauri_plugin_libsql/
wrapper.rs

1use futures::lock::Mutex;
2use futures::FutureExt;
3use indexmap::IndexMap;
4use libsql::{params::Params, Builder as LibsqlBuilder, Connection, Database, Value};
5use serde_json::Value as JsonValue;
6use std::collections::HashMap;
7use std::panic::AssertUnwindSafe;
8use std::path::{Component, Path, PathBuf};
9use std::sync::Arc;
10
11use crate::decode;
12use crate::error::Error;
13use crate::models::{EncryptionConfig, QueryResult};
14
15/// A wrapper around libsql connection
16pub struct DbConnection {
17    conn: Connection,
18    db: Database,
19}
20
21impl DbConnection {
22    /// Connect to a libsql database.
23    ///
24    /// - Local only: `sync_url` = None
25    /// - Embedded replica (Turso): `sync_url` = Some("libsql://…"), `auth_token` = Some("…")
26    /// - Pure remote: `path` starts with "libsql://" or "https://", no `sync_url`
27    pub async fn connect(
28        path: &str,
29        encryption: Option<EncryptionConfig>,
30        base_path: PathBuf,
31        sync_url: Option<String>,
32        auth_token: Option<String>,
33    ) -> Result<Self, Error> {
34        // Wrap in catch_unwind: libsql's builder calls unwrap() internally and can
35        // panic on a malformed URL, which would cause the Tauri IPC to hang forever.
36        let path = path.to_string();
37        let db = AssertUnwindSafe(async move {
38            if let Some(url) = sync_url {
39                let full_path = Self::resolve_local_path(&path, &base_path)?;
40                Self::open_replica(full_path, url, auth_token.unwrap_or_default(), encryption).await
41            } else if path.starts_with("libsql://") || path.starts_with("https://") {
42                Self::open_remote(path, auth_token.unwrap_or_default()).await
43            } else {
44                let full_path = Self::resolve_local_path(&path, &base_path)?;
45                Self::open_local(full_path, encryption).await
46            }
47        })
48        .catch_unwind()
49        .await
50        .map_err(|_| {
51            Error::InvalidDbUrl(
52                "libsql panicked building the database — check your URL format \
53                 (expected libsql://… or https://…)"
54                    .into(),
55            )
56        })??;
57
58        let conn = db.connect()?;
59        Ok(Self { conn, db })
60    }
61
62    // ── connection mode helpers ──────────────────────────────────────────────
63
64    fn resolve_local_path(path: &str, base_path: &Path) -> Result<PathBuf, Error> {
65        let db_path = path.strip_prefix("sqlite:").unwrap_or(path);
66
67        if db_path == ":memory:" {
68            return Ok(PathBuf::from(":memory:"));
69        }
70
71        if PathBuf::from(db_path).is_absolute() {
72            return Ok(PathBuf::from(db_path));
73        }
74
75        // Normalise away `..` so a path can't escape base_path
76        let joined = base_path.join(db_path);
77        let normalised = joined.components().fold(PathBuf::new(), |mut acc, c| {
78            match c {
79                Component::ParentDir => {
80                    acc.pop();
81                }
82                Component::CurDir => {}
83                _ => acc.push(c),
84            }
85            acc
86        });
87
88        if !normalised.starts_with(base_path) {
89            return Err(Error::InvalidDbUrl(format!(
90                "path '{}' escapes the base directory",
91                db_path
92            )));
93        }
94
95        Ok(normalised)
96    }
97
98    async fn open_local(
99        full_path: PathBuf,
100        encryption: Option<EncryptionConfig>,
101    ) -> Result<Database, Error> {
102        #[allow(unused_mut)]
103        let mut builder = LibsqlBuilder::new_local(&full_path.to_string_lossy().to_string());
104
105        #[cfg(feature = "encryption")]
106        if let Some(config) = encryption {
107            builder = builder.encryption_config(config.into());
108        }
109        #[cfg(not(feature = "encryption"))]
110        if encryption.is_some() {
111            return Err(Error::InvalidDbUrl(
112                "encryption feature is not enabled — rebuild with the `encryption` feature".into(),
113            ));
114        }
115
116        Ok(builder.build().await?)
117    }
118
119    #[cfg(feature = "replication")]
120    async fn open_replica(
121        full_path: PathBuf,
122        sync_url: String,
123        auth_token: String,
124        encryption: Option<EncryptionConfig>,
125    ) -> Result<Database, Error> {
126        #[allow(unused_mut)]
127        let mut builder = LibsqlBuilder::new_remote_replica(
128            full_path.to_string_lossy().to_string(),
129            sync_url,
130            auth_token,
131        );
132
133        #[cfg(feature = "encryption")]
134        if let Some(config) = encryption {
135            builder = builder.encryption_config(config.into());
136        }
137
138        let db = builder.build().await?;
139        // Initial sync so the local replica is up-to-date on connect
140        db.sync().await?;
141        Ok(db)
142    }
143
144    #[cfg(not(feature = "replication"))]
145    async fn open_replica(
146        _full_path: PathBuf,
147        _sync_url: String,
148        _auth_token: String,
149        _encryption: Option<EncryptionConfig>,
150    ) -> Result<Database, Error> {
151        Err(Error::InvalidDbUrl(
152            "embedded replica requires the `replication` feature — add features = [\"replication\"] to your Cargo.toml".into(),
153        ))
154    }
155
156    #[cfg(feature = "remote")]
157    async fn open_remote(url: String, auth_token: String) -> Result<Database, Error> {
158        Ok(LibsqlBuilder::new_remote(url, auth_token).build().await?)
159    }
160
161    #[cfg(not(feature = "remote"))]
162    async fn open_remote(_url: String, _auth_token: String) -> Result<Database, Error> {
163        Err(Error::InvalidDbUrl(
164            "remote connections require the `remote` feature — add features = [\"remote\"] to your Cargo.toml".into(),
165        ))
166    }
167
168    // ── public API ───────────────────────────────────────────────────────────
169
170    /// Sync an embedded replica with its remote database.
171    /// No-op (returns Ok) for local-only databases when replication is disabled.
172    pub async fn sync(&self) -> Result<(), Error> {
173        Self::do_sync(&self.db).await
174    }
175
176    #[cfg(feature = "replication")]
177    async fn do_sync(db: &Database) -> Result<(), Error> {
178        db.sync().await?;
179        Ok(())
180    }
181
182    #[cfg(not(feature = "replication"))]
183    async fn do_sync(_db: &Database) -> Result<(), Error> {
184        Err(Error::OperationNotSupported(
185            "sync requires the `replication` feature".into(),
186        ))
187    }
188
189    /// Execute a query that doesn't return rows
190    pub async fn execute(&self, query: &str, values: Vec<JsonValue>) -> Result<QueryResult, Error> {
191        let params = json_to_params(values);
192        let rows_affected = self.conn.execute(query, params).await?;
193
194        Ok(QueryResult {
195            rows_affected,
196            last_insert_id: self.conn.last_insert_rowid(),
197        })
198    }
199
200    /// Execute a query that returns rows
201    pub async fn select(
202        &self,
203        query: &str,
204        values: Vec<JsonValue>,
205    ) -> Result<Vec<IndexMap<String, JsonValue>>, Error> {
206        let params = json_to_params(values);
207        let mut rows = self.conn.query(query, params).await?;
208
209        let mut results = Vec::new();
210
211        while let Some(row) = rows.next().await? {
212            let mut map = IndexMap::new();
213            let column_count = row.column_count();
214
215            for i in 0..column_count {
216                if let Some(column_name) = row.column_name(i) {
217                    let value = decode::to_json(&row, i)?;
218                    map.insert(column_name.to_string(), value);
219                }
220            }
221
222            results.push(map);
223        }
224
225        Ok(results)
226    }
227
228    /// Execute multiple SQL statements atomically inside a transaction.
229    /// Statements must not contain bound parameters — use for DDL and bulk DML only.
230    pub async fn batch(&self, queries: Vec<String>) -> Result<(), Error> {
231        self.conn.execute("BEGIN", Params::None).await?;
232        for query in &queries {
233            if let Err(e) = self.conn.execute(query.as_str(), Params::None).await {
234                let _ = self.conn.execute("ROLLBACK", Params::None).await;
235                return Err(Error::Libsql(e));
236            }
237        }
238        if let Err(e) = self.conn.execute("COMMIT", Params::None).await {
239            let _ = self.conn.execute("ROLLBACK", Params::None).await;
240            return Err(Error::Libsql(e));
241        }
242        Ok(())
243    }
244
245    pub async fn close(&self) {
246        self.conn.reset().await;
247    }
248}
249
250/// Convert JSON values to libsql params
251fn json_to_params(values: Vec<JsonValue>) -> Params {
252    if values.is_empty() {
253        return Params::None;
254    }
255
256    let params: Vec<Value> = values.into_iter().map(json_to_libsql_value).collect();
257    Params::Positional(params)
258}
259
260fn json_to_libsql_value(v: JsonValue) -> Value {
261    match v {
262        JsonValue::Null => Value::Null,
263        JsonValue::Bool(b) => Value::Integer(if b { 1 } else { 0 }),
264        JsonValue::Number(n) => {
265            if let Some(i) = n.as_i64() {
266                Value::Integer(i)
267            } else if let Some(f) = n.as_f64() {
268                Value::Real(f)
269            } else {
270                Value::Null
271            }
272        }
273        JsonValue::String(s) => Value::Text(s),
274        JsonValue::Array(ref arr) => {
275            if arr.iter().all(|v| v.is_number()) {
276                let bytes: Vec<u8> = arr
277                    .iter()
278                    .filter_map(|v| v.as_u64().map(|n| n as u8))
279                    .collect();
280                Value::Blob(bytes)
281            } else {
282                Value::Text(v.to_string())
283            }
284        }
285        JsonValue::Object(_) => Value::Text(v.to_string()),
286    }
287}
288
289/// Database instances holder
290pub struct DbInstances(pub Arc<Mutex<HashMap<String, Arc<DbConnection>>>>);
291
292impl Default for DbInstances {
293    fn default() -> Self {
294        Self(Arc::new(Mutex::new(HashMap::new())))
295    }
296}