tauri_plugin_stronghold/
lib.rs

1// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
2// SPDX-License-Identifier: Apache-2.0
3// SPDX-License-Identifier: MIT
4
5//! Store secrets and keys using the [IOTA Stronghold](https://github.com/iotaledger/stronghold.rs) encrypted database and secure runtime.
6
7#![doc(
8    html_logo_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png",
9    html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
10)]
11
12use std::{
13    collections::HashMap,
14    fmt,
15    path::PathBuf,
16    sync::{Arc, Mutex},
17    time::Duration,
18};
19
20use crypto::keys::bip39;
21use iota_stronghold::{
22    procedures::{
23        BIP39Generate, BIP39Recover, Curve, Ed25519Sign, KeyType as StrongholdKeyType,
24        MnemonicLanguage, PublicKey, Slip10Derive, Slip10DeriveInput, Slip10Generate,
25        StrongholdProcedure,
26    },
27    Client, Location,
28};
29use serde::{de::Visitor, Deserialize, Deserializer};
30use stronghold::{Error, Result, Stronghold};
31use tauri::{
32    plugin::{Builder as PluginBuilder, TauriPlugin},
33    Manager, Runtime, State,
34};
35use zeroize::{Zeroize, Zeroizing};
36
37#[cfg(feature = "kdf")]
38pub mod kdf;
39
40pub mod stronghold;
41
/// Signature of the function that turns a user password into the key bytes passed to
/// `Stronghold::new`; it must accept a password borrowed for any lifetime and be shareable
/// across threads.
type PasswordHashFn = dyn for<'a> Fn(&'a str) -> Vec<u8> + Send + Sync;
43
/// Registry of Stronghold instances keyed by their snapshot path.
///
/// The map sits behind an `Arc<Mutex<..>>` so every command handler can lock it to look up,
/// insert or remove an instance.
#[derive(Default)]
struct StrongholdCollection(Arc<Mutex<HashMap<PathBuf, Stronghold>>>);
46
/// Managed-state wrapper around the password hash function selected by [`Builder`]; the
/// `initialize` command calls it to turn the supplied password into key bytes.
struct PasswordHashFunction(Box<PasswordHashFn>);
48
/// A byte payload from the frontend that may arrive either as text or as raw bytes.
///
/// Deserialized `untagged`: serde tries the variants in declaration order, so a string becomes
/// [`BytesDto::Text`] and anything else that deserializes as `Vec<u8>` becomes
/// [`BytesDto::Raw`]. Do not reorder the variants; that would change which one wins.
#[derive(Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
#[serde(untagged)]
enum BytesDto {
    /// Text payload; used through its UTF-8 bytes.
    Text(String),
    /// Raw byte payload.
    Raw(Vec<u8>),
}
55
56impl AsRef<[u8]> for BytesDto {
57    fn as_ref(&self) -> &[u8] {
58        match self {
59            Self::Text(t) => t.as_ref(),
60            Self::Raw(b) => b.as_ref(),
61        }
62    }
63}
64
65impl From<BytesDto> for Vec<u8> {
66    fn from(v: BytesDto) -> Self {
67        match v {
68            BytesDto::Text(t) => t.as_bytes().to_vec(),
69            BytesDto::Raw(b) => b,
70        }
71    }
72}
73
/// A record location as sent by the frontend, converted into a [`Location`].
///
/// Adjacently tagged: `{ "type": "Generic", "payload": { "vault": .., "record": .. } }` or
/// `{ "type": "Counter", "payload": { "vault": .., "counter": .. } }`.
#[derive(Deserialize)]
#[serde(tag = "type", content = "payload")]
enum LocationDto {
    /// Converted with `Location::generic(vault, record)`.
    Generic { vault: BytesDto, record: BytesDto },
    /// Converted with `Location::counter(vault, counter)`.
    Counter { vault: BytesDto, counter: usize },
}
80
81impl From<LocationDto> for Location {
82    fn from(dto: LocationDto) -> Location {
83        match dto {
84            LocationDto::Generic { vault, record } => Location::generic(vault, record),
85            LocationDto::Counter { vault, counter } => Location::counter(vault, counter),
86        }
87    }
88}
89
90#[derive(Deserialize)]
91#[serde(tag = "type", content = "payload")]
92#[allow(clippy::upper_case_acronyms)]
93enum Slip10DeriveInputDto {
94    Seed(LocationDto),
95    Key(LocationDto),
96}
97
98impl From<Slip10DeriveInputDto> for Slip10DeriveInput {
99    fn from(dto: Slip10DeriveInputDto) -> Slip10DeriveInput {
100        match dto {
101            Slip10DeriveInputDto::Seed(location) => Slip10DeriveInput::Seed(location.into()),
102            Slip10DeriveInputDto::Key(location) => Slip10DeriveInput::Key(location.into()),
103        }
104    }
105}
106
/// Key algorithm selected for the `PublicKey` procedure.
///
/// Deserialized case-insensitively from the strings `"ed25519"` and `"x25519"`, and converted
/// into the Stronghold key type of the same name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeyType {
    /// Ed25519 key.
    Ed25519,
    /// X25519 key.
    X25519,
}
111
112impl From<KeyType> for StrongholdKeyType {
113    fn from(ty: KeyType) -> StrongholdKeyType {
114        match ty {
115            KeyType::Ed25519 => StrongholdKeyType::Ed25519,
116            KeyType::X25519 => StrongholdKeyType::X25519,
117        }
118    }
119}
120
121impl<'de> Deserialize<'de> for KeyType {
122    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
123    where
124        D: Deserializer<'de>,
125    {
126        struct KeyTypeVisitor;
127
128        impl Visitor<'_> for KeyTypeVisitor {
129            type Value = KeyType;
130
131            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
132                formatter.write_str("ed25519 or x25519")
133            }
134
135            fn visit_str<E>(self, value: &str) -> std::result::Result<Self::Value, E>
136            where
137                E: serde::de::Error,
138            {
139                match value.to_lowercase().as_str() {
140                    "ed25519" => Ok(KeyType::Ed25519),
141                    "x25519" => Ok(KeyType::X25519),
142                    _ => Err(serde::de::Error::custom("unknown key type")),
143                }
144            }
145        }
146
147        deserializer.deserialize_str(KeyTypeVisitor)
148    }
149}
150
/// A Stronghold procedure request as sent by the frontend.
///
/// Adjacently tagged: serialized as `{ "type": "<VariantName>", "payload": { .. } }` and
/// converted into a [`StrongholdProcedure`] through its `From` impl.
#[derive(Deserialize)]
#[serde(tag = "type", content = "payload")]
// The variant names (e.g. `SLIP10Generate`) are also the wire tags, so they keep their
// upper-case acronyms.
#[allow(clippy::upper_case_acronyms)]
enum ProcedureDto {
    /// Converted into a `Slip10Generate` procedure.
    SLIP10Generate {
        /// Where the procedure's result is written.
        output: LocationDto,
        /// Optional size in bytes; `sizeBytes` on the wire.
        #[serde(rename = "sizeBytes")]
        size_bytes: Option<usize>,
    },
    /// Converted into a `Slip10Derive` procedure (the conversion fixes the curve to Ed25519).
    SLIP10Derive {
        /// Derivation chain, passed through unchanged.
        chain: Vec<u32>,
        /// Seed or key location to derive from.
        input: Slip10DeriveInputDto,
        /// Where the procedure's result is written.
        output: LocationDto,
    },
    /// Converted into a `BIP39Recover` procedure.
    BIP39Recover {
        /// Mnemonic phrase text.
        mnemonic: String,
        /// Optional passphrase; the conversion treats `None` as the empty string.
        passphrase: Option<String>,
        /// Where the procedure's result is written.
        output: LocationDto,
    },
    /// Converted into a `BIP39Generate` procedure (the conversion fixes the English word list).
    BIP39Generate {
        /// Optional passphrase; the conversion treats `None` as the empty string.
        passphrase: Option<String>,
        /// Where the procedure's result is written.
        output: LocationDto,
    },
    /// Converted into a `PublicKey` procedure.
    PublicKey {
        /// Key algorithm; sent as `type` inside the payload.
        #[serde(rename = "type")]
        ty: KeyType,
        /// Location of the private key; `privateKey` on the wire.
        #[serde(rename = "privateKey")]
        private_key: LocationDto,
    },
    /// Converted into an `Ed25519Sign` procedure.
    Ed25519Sign {
        /// Location of the private key; `privateKey` on the wire.
        #[serde(rename = "privateKey")]
        private_key: LocationDto,
        /// Message whose UTF-8 bytes are handed to the procedure.
        msg: String,
    },
}
186
187impl From<ProcedureDto> for StrongholdProcedure {
188    fn from(dto: ProcedureDto) -> StrongholdProcedure {
189        match dto {
190            ProcedureDto::SLIP10Generate { output, size_bytes } => {
191                StrongholdProcedure::Slip10Generate(Slip10Generate {
192                    output: output.into(),
193                    size_bytes,
194                })
195            }
196            ProcedureDto::SLIP10Derive {
197                chain,
198                input,
199                output,
200            } => StrongholdProcedure::Slip10Derive(Slip10Derive {
201                curve: Curve::Ed25519,
202                chain,
203                input: input.into(),
204                output: output.into(),
205            }),
206            ProcedureDto::BIP39Recover {
207                mnemonic,
208                passphrase,
209                output,
210            } => StrongholdProcedure::BIP39Recover(BIP39Recover {
211                mnemonic: bip39::Mnemonic::from(mnemonic),
212                passphrase: bip39::Passphrase::from(passphrase.unwrap_or_default()),
213                output: output.into(),
214            }),
215            ProcedureDto::BIP39Generate { passphrase, output } => {
216                StrongholdProcedure::BIP39Generate(BIP39Generate {
217                    passphrase: bip39::Passphrase::from(passphrase.unwrap_or_default()),
218                    output: output.into(),
219                    language: MnemonicLanguage::English,
220                })
221            }
222            ProcedureDto::PublicKey { ty, private_key } => {
223                StrongholdProcedure::PublicKey(PublicKey {
224                    ty: ty.into(),
225                    private_key: private_key.into(),
226                })
227            }
228            ProcedureDto::Ed25519Sign { private_key, msg } => {
229                StrongholdProcedure::Ed25519Sign(Ed25519Sign {
230                    private_key: private_key.into(),
231                    msg: msg.as_bytes().to_vec(),
232                })
233            }
234        }
235    }
236}
237
238#[tauri::command]
239async fn initialize(
240    collection: State<'_, StrongholdCollection>,
241    hash_function: State<'_, PasswordHashFunction>,
242    snapshot_path: PathBuf,
243    mut password: String,
244) -> Result<()> {
245    let hash = (hash_function.0)(&password);
246    password.zeroize();
247    let stronghold = Stronghold::new(snapshot_path.clone(), hash)?;
248
249    collection
250        .0
251        .lock()
252        .unwrap()
253        .insert(snapshot_path, stronghold);
254
255    Ok(())
256}
257
258#[tauri::command]
259async fn destroy(
260    collection: State<'_, StrongholdCollection>,
261    snapshot_path: PathBuf,
262) -> Result<()> {
263    let mut collection = collection.0.lock().unwrap();
264    if let Some(stronghold) = collection.remove(&snapshot_path) {
265        if let Err(e) = stronghold.save() {
266            collection.insert(snapshot_path, stronghold);
267            return Err(e);
268        }
269    }
270    Ok(())
271}
272
273#[tauri::command]
274async fn save(collection: State<'_, StrongholdCollection>, snapshot_path: PathBuf) -> Result<()> {
275    let collection = collection.0.lock().unwrap();
276    if let Some(stronghold) = collection.get(&snapshot_path) {
277        stronghold.save()?;
278    }
279    Ok(())
280}
281
282#[tauri::command]
283async fn create_client(
284    collection: State<'_, StrongholdCollection>,
285    snapshot_path: PathBuf,
286    client: BytesDto,
287) -> Result<()> {
288    let stronghold = get_stronghold(collection, snapshot_path)?;
289    stronghold.create_client(client)?;
290    Ok(())
291}
292
293#[tauri::command]
294async fn load_client(
295    collection: State<'_, StrongholdCollection>,
296    snapshot_path: PathBuf,
297    client: BytesDto,
298) -> Result<()> {
299    let stronghold = get_stronghold(collection, snapshot_path)?;
300    stronghold.load_client(client)?;
301    Ok(())
302}
303
304#[tauri::command]
305async fn get_store_record(
306    collection: State<'_, StrongholdCollection>,
307    snapshot_path: PathBuf,
308    client: BytesDto,
309    key: String,
310) -> Result<Option<Vec<u8>>> {
311    let client = get_client(collection, snapshot_path, client)?;
312    client.store().get(key.as_ref()).map_err(Into::into)
313}
314
315#[tauri::command]
316async fn save_store_record(
317    collection: State<'_, StrongholdCollection>,
318    snapshot_path: PathBuf,
319    client: BytesDto,
320    key: String,
321    value: Vec<u8>,
322    lifetime: Option<Duration>,
323) -> Result<Option<Vec<u8>>> {
324    let client = get_client(collection, snapshot_path, client)?;
325    client
326        .store()
327        .insert(key.as_bytes().to_vec(), value, lifetime)
328        .map_err(Into::into)
329}
330
331#[tauri::command]
332async fn remove_store_record(
333    collection: State<'_, StrongholdCollection>,
334    snapshot_path: PathBuf,
335    client: BytesDto,
336    key: String,
337) -> Result<Option<Vec<u8>>> {
338    let client = get_client(collection, snapshot_path, client)?;
339    client.store().delete(key.as_ref()).map_err(Into::into)
340}
341
342#[tauri::command]
343async fn save_secret(
344    collection: State<'_, StrongholdCollection>,
345    snapshot_path: PathBuf,
346    client: BytesDto,
347    vault: BytesDto,
348    record_path: BytesDto,
349    secret: Vec<u8>,
350) -> Result<()> {
351    let client = get_client(collection, snapshot_path, client)?;
352    client
353        .vault(&vault)
354        .write_secret(
355            Location::generic(vault, record_path),
356            Zeroizing::new(secret),
357        )
358        .map_err(Into::into)
359}
360
361#[tauri::command]
362async fn remove_secret(
363    collection: State<'_, StrongholdCollection>,
364    snapshot_path: PathBuf,
365    client: BytesDto,
366    vault: BytesDto,
367    record_path: BytesDto,
368) -> Result<()> {
369    let client = get_client(collection, snapshot_path, client)?;
370    client
371        .vault(vault)
372        .delete_secret(record_path)
373        .map(|_| ())
374        .map_err(Into::into)
375}
376
377#[tauri::command]
378async fn execute_procedure(
379    collection: State<'_, StrongholdCollection>,
380    snapshot_path: PathBuf,
381    client: BytesDto,
382    procedure: ProcedureDto,
383) -> Result<Vec<u8>> {
384    let client = get_client(collection, snapshot_path, client)?;
385    client
386        .execute_procedure(StrongholdProcedure::from(procedure))
387        .map(Into::into)
388        .map_err(Into::into)
389}
390
391fn get_stronghold(
392    collection: State<'_, StrongholdCollection>,
393    snapshot_path: PathBuf,
394) -> Result<iota_stronghold::Stronghold> {
395    let collection = collection.0.lock().unwrap();
396    if let Some(stronghold) = collection.get(&snapshot_path) {
397        Ok(stronghold.inner().clone())
398    } else {
399        Err(Error::StrongholdNotInitialized)
400    }
401}
402
403fn get_client(
404    collection: State<'_, StrongholdCollection>,
405    snapshot_path: PathBuf,
406    client: BytesDto,
407) -> Result<Client> {
408    let collection = collection.0.lock().unwrap();
409    if let Some(stronghold) = collection.get(&snapshot_path) {
410        stronghold.get_client(client).map_err(Into::into)
411    } else {
412        Err(Error::StrongholdNotInitialized)
413    }
414}
415
/// How the password hash function was configured on [`Builder`]; resolved into a
/// [`PasswordHashFunction`] during plugin setup.
enum PasswordHashFunctionKind {
    /// Argon2 via `kdf::KeyDerivation::argon2`, which receives this path with every password.
    #[cfg(feature = "kdf")]
    Argon2(PathBuf),
    /// A caller-supplied hash function, used as is.
    Custom(Box<PasswordHashFn>),
}
421
/// Builder for the Stronghold Tauri plugin.
///
/// Created with `Builder::new` (custom password hash function) or, with the `kdf` feature,
/// `Builder::with_argon2`; finished with `Builder::build`.
pub struct Builder {
    // Password hashing strategy handed to the plugin's setup hook.
    password_hash_function: PasswordHashFunctionKind,
}
425
426impl Builder {
427    pub fn new<F: Fn(&str) -> Vec<u8> + Send + Sync + 'static>(password_hash_function: F) -> Self {
428        Self {
429            password_hash_function: PasswordHashFunctionKind::Custom(Box::new(
430                password_hash_function,
431            )),
432        }
433    }
434
435    /// Initializes [`Self`] with argon2 as password hash function.
436    ///
437    /// # Examples
438    ///
439    /// ```rust
440    /// use tauri::Manager;
441    /// tauri::Builder::default()
442    ///     .setup(|app| {
443    ///         let salt_path = app
444    ///             .path()
445    ///             .app_local_data_dir()
446    ///             .expect("could not resolve app local data path")
447    ///             .join("salt.txt");
448    ///         app.handle().plugin(tauri_plugin_stronghold::Builder::with_argon2(&salt_path).build())?;
449    ///         Ok(())
450    ///     });
451    /// ```
452    #[cfg(feature = "kdf")]
453    pub fn with_argon2(salt_path: &std::path::Path) -> Self {
454        Self {
455            password_hash_function: PasswordHashFunctionKind::Argon2(salt_path.to_owned()),
456        }
457    }
458
459    pub fn build<R: Runtime>(self) -> TauriPlugin<R> {
460        let password_hash_function = self.password_hash_function;
461
462        let plugin_builder = PluginBuilder::new("stronghold").setup(move |app, _api| {
463            app.manage(StrongholdCollection::default());
464            app.manage(PasswordHashFunction(match password_hash_function {
465                #[cfg(feature = "kdf")]
466                PasswordHashFunctionKind::Argon2(path) => {
467                    Box::new(move |p| kdf::KeyDerivation::argon2(p, &path))
468                }
469                PasswordHashFunctionKind::Custom(f) => f,
470            }));
471            Ok(())
472        });
473
474        Builder::invoke_stronghold_handlers_and_build(plugin_builder)
475    }
476
477    fn invoke_stronghold_handlers_and_build<R: Runtime>(
478        builder: PluginBuilder<R>,
479    ) -> TauriPlugin<R> {
480        builder
481            .invoke_handler(tauri::generate_handler![
482                initialize,
483                destroy,
484                save,
485                create_client,
486                load_client,
487                get_store_record,
488                save_store_record,
489                remove_store_record,
490                save_secret,
491                remove_secret,
492                execute_procedure,
493            ])
494            .build()
495    }
496}