// tauri_plugin_libsql/wrapper.rs
use futures::lock::Mutex;
2use futures::FutureExt;
3use indexmap::IndexMap;
4use libsql::{params::Params, Builder as LibsqlBuilder, Connection, Database, Value};
5use serde_json::Value as JsonValue;
6use std::collections::HashMap;
7use std::panic::AssertUnwindSafe;
8use std::path::{Component, Path, PathBuf};
9use std::sync::Arc;
10
11use crate::decode;
12use crate::error::Error;
13use crate::models::{EncryptionConfig, QueryResult};
14
/// An open libsql database handle together with a live connection to it.
pub struct DbConnection {
    // Connection used for all statements/queries issued through this wrapper.
    conn: Connection,
    // Owning database handle; kept so the connection stays valid and so
    // `sync` can be forwarded to `Database::sync` for embedded replicas.
    db: Database,
}
20
21impl DbConnection {
    /// Opens a database and returns a ready-to-use connection.
    ///
    /// Dispatch order:
    /// 1. `sync_url` present → embedded replica (local file synced with a
    ///    remote server), local file resolved against `base_path`.
    /// 2. `path` starting with `libsql://` or `https://` → pure remote.
    /// 3. otherwise → plain local file resolved against `base_path`.
    ///
    /// # Errors
    /// Returns `Error::InvalidDbUrl` for invalid paths/URLs or when a
    /// required crate feature is missing, and propagates libsql
    /// build/connect errors.
    pub async fn connect(
        path: &str,
        encryption: Option<EncryptionConfig>,
        base_path: PathBuf,
        sync_url: Option<String>,
        auth_token: Option<String>,
    ) -> Result<Self, Error> {
        let path = path.to_string();
        // libsql's builder can panic on malformed URLs; catch the panic and
        // surface it as a regular error instead of crashing the app.
        // AssertUnwindSafe: on panic the result is discarded via `map_err`,
        // so no torn state captured by the closure is observed afterwards.
        let db = AssertUnwindSafe(async move {
            if let Some(url) = sync_url {
                let full_path = Self::resolve_local_path(&path, &base_path)?;
                Self::open_replica(full_path, url, auth_token.unwrap_or_default(), encryption).await
            } else if path.starts_with("libsql://") || path.starts_with("https://") {
                Self::open_remote(path, auth_token.unwrap_or_default()).await
            } else {
                let full_path = Self::resolve_local_path(&path, &base_path)?;
                Self::open_local(full_path, encryption).await
            }
        })
        .catch_unwind()
        .await
        // First `?` unwraps the panic result, second the builder result.
        .map_err(|_| {
            Error::InvalidDbUrl(
                "libsql panicked building the database — check your URL format \
                (expected libsql://… or https://…)"
                    .into(),
            )
        })??;

        let conn = db.connect()?;
        Ok(Self { conn, db })
    }
61
62 fn resolve_local_path(path: &str, base_path: &Path) -> Result<PathBuf, Error> {
65 let db_path = path.strip_prefix("sqlite:").unwrap_or(path);
66
67 if db_path == ":memory:" {
68 return Ok(PathBuf::from(":memory:"));
69 }
70
71 if PathBuf::from(db_path).is_absolute() {
72 return Ok(PathBuf::from(db_path));
73 }
74
75 let joined = base_path.join(db_path);
77 let normalised = joined.components().fold(PathBuf::new(), |mut acc, c| {
78 match c {
79 Component::ParentDir => {
80 acc.pop();
81 }
82 Component::CurDir => {}
83 _ => acc.push(c),
84 }
85 acc
86 });
87
88 if !normalised.starts_with(base_path) {
89 return Err(Error::InvalidDbUrl(format!(
90 "path '{}' escapes the base directory",
91 db_path
92 )));
93 }
94
95 Ok(normalised)
96 }
97
98 async fn open_local(
99 full_path: PathBuf,
100 encryption: Option<EncryptionConfig>,
101 ) -> Result<Database, Error> {
102 #[allow(unused_mut)]
103 let mut builder = LibsqlBuilder::new_local(&full_path.to_string_lossy().to_string());
104
105 #[cfg(feature = "encryption")]
106 if let Some(config) = encryption {
107 builder = builder.encryption_config(config.into());
108 }
109 #[cfg(not(feature = "encryption"))]
110 if encryption.is_some() {
111 return Err(Error::InvalidDbUrl(
112 "encryption feature is not enabled — rebuild with the `encryption` feature".into(),
113 ));
114 }
115
116 Ok(builder.build().await?)
117 }
118
    /// Opens an embedded replica: a local file kept in sync with a remote
    /// libsql server at `sync_url`. Compiled only with the `replication`
    /// feature.
    #[cfg(feature = "replication")]
    async fn open_replica(
        full_path: PathBuf,
        sync_url: String,
        auth_token: String,
        encryption: Option<EncryptionConfig>,
    ) -> Result<Database, Error> {
        // `mut` is only needed when the `encryption` feature is compiled in.
        #[allow(unused_mut)]
        let mut builder = LibsqlBuilder::new_remote_replica(
            full_path.to_string_lossy().to_string(),
            sync_url,
            auth_token,
        );

        #[cfg(feature = "encryption")]
        if let Some(config) = encryption {
            builder = builder.encryption_config(config.into());
        }

        let db = builder.build().await?;
        // Initial sync so the first queries already see remote data.
        db.sync().await?;
        Ok(db)
    }
143
    /// Stub used when the `replication` feature is disabled: always fails
    /// with a descriptive error so callers learn why replicas don't work.
    #[cfg(not(feature = "replication"))]
    async fn open_replica(
        _full_path: PathBuf,
        _sync_url: String,
        _auth_token: String,
        _encryption: Option<EncryptionConfig>,
    ) -> Result<Database, Error> {
        Err(Error::InvalidDbUrl(
            "embedded replica requires the `replication` feature — add features = [\"replication\"] to your Cargo.toml".into(),
        ))
    }
155
    /// Opens a purely remote database over `libsql://`/`https://`.
    /// Compiled only with the `remote` feature.
    #[cfg(feature = "remote")]
    async fn open_remote(url: String, auth_token: String) -> Result<Database, Error> {
        Ok(LibsqlBuilder::new_remote(url, auth_token).build().await?)
    }
160
    /// Stub used when the `remote` feature is disabled: always fails with
    /// a descriptive error instead of silently doing nothing.
    #[cfg(not(feature = "remote"))]
    async fn open_remote(_url: String, _auth_token: String) -> Result<Database, Error> {
        Err(Error::InvalidDbUrl(
            "remote connections require the `remote` feature — add features = [\"remote\"] to your Cargo.toml".into(),
        ))
    }
167
    /// Synchronises an embedded replica with its remote server.
    ///
    /// Delegates to the feature-gated `do_sync`; without the
    /// `replication` feature this returns `Error::OperationNotSupported`.
    pub async fn sync(&self) -> Result<(), Error> {
        Self::do_sync(&self.db).await
    }
175
    // Real implementation: forwards to `Database::sync`.
    #[cfg(feature = "replication")]
    async fn do_sync(db: &Database) -> Result<(), Error> {
        db.sync().await?;
        Ok(())
    }
181
    // Stub when `replication` is disabled: sync is never possible.
    #[cfg(not(feature = "replication"))]
    async fn do_sync(_db: &Database) -> Result<(), Error> {
        Err(Error::OperationNotSupported(
            "sync requires the `replication` feature".into(),
        ))
    }
188
189 pub async fn execute(&self, query: &str, values: Vec<JsonValue>) -> Result<QueryResult, Error> {
191 let params = json_to_params(values);
192 let rows_affected = self.conn.execute(query, params).await?;
193
194 Ok(QueryResult {
195 rows_affected,
196 last_insert_id: self.conn.last_insert_rowid(),
197 })
198 }
199
200 pub async fn select(
202 &self,
203 query: &str,
204 values: Vec<JsonValue>,
205 ) -> Result<Vec<IndexMap<String, JsonValue>>, Error> {
206 let params = json_to_params(values);
207 let mut rows = self.conn.query(query, params).await?;
208
209 let mut results = Vec::new();
210
211 while let Some(row) = rows.next().await? {
212 let mut map = IndexMap::new();
213 let column_count = row.column_count();
214
215 for i in 0..column_count {
216 if let Some(column_name) = row.column_name(i) {
217 let value = decode::to_json(&row, i)?;
218 map.insert(column_name.to_string(), value);
219 }
220 }
221
222 results.push(map);
223 }
224
225 Ok(results)
226 }
227
228 pub async fn batch(&self, queries: Vec<String>) -> Result<(), Error> {
231 self.conn.execute("BEGIN", Params::None).await?;
232 for query in &queries {
233 if let Err(e) = self.conn.execute(query.as_str(), Params::None).await {
234 let _ = self.conn.execute("ROLLBACK", Params::None).await;
235 return Err(Error::Libsql(e));
236 }
237 }
238 if let Err(e) = self.conn.execute("COMMIT", Params::None).await {
239 let _ = self.conn.execute("ROLLBACK", Params::None).await;
240 return Err(Error::Libsql(e));
241 }
242 Ok(())
243 }
244
    /// Resets the connection's state.
    ///
    /// NOTE(review): `Connection::reset` does not close the underlying
    /// database — the handle is released when `DbConnection` is dropped.
    /// Confirm this matches what the plugin's `close` command promises.
    pub async fn close(&self) {
        self.conn.reset().await;
    }
248}
249
250fn json_to_params(values: Vec<JsonValue>) -> Params {
252 if values.is_empty() {
253 return Params::None;
254 }
255
256 let params: Vec<Value> = values.into_iter().map(json_to_libsql_value).collect();
257 Params::Positional(params)
258}
259
260fn json_to_libsql_value(v: JsonValue) -> Value {
261 match v {
262 JsonValue::Null => Value::Null,
263 JsonValue::Bool(b) => Value::Integer(if b { 1 } else { 0 }),
264 JsonValue::Number(n) => {
265 if let Some(i) = n.as_i64() {
266 Value::Integer(i)
267 } else if let Some(f) = n.as_f64() {
268 Value::Real(f)
269 } else {
270 Value::Null
271 }
272 }
273 JsonValue::String(s) => Value::Text(s),
274 JsonValue::Array(ref arr) => {
275 if arr.iter().all(|v| v.is_number()) {
276 let bytes: Vec<u8> = arr
277 .iter()
278 .filter_map(|v| v.as_u64().map(|n| n as u8))
279 .collect();
280 Value::Blob(bytes)
281 } else {
282 Value::Text(v.to_string())
283 }
284 }
285 JsonValue::Object(_) => Value::Text(v.to_string()),
286 }
287}
288
/// Registry of open database connections keyed by their connection
/// string, shared (and locked) across async Tauri command invocations.
pub struct DbInstances(pub Arc<Mutex<HashMap<String, Arc<DbConnection>>>>);
291
292impl Default for DbInstances {
293 fn default() -> Self {
294 Self(Arc::new(Mutex::new(HashMap::new())))
295 }
296}