1#![doc(
8 html_logo_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png",
9 html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
10)]
11
12use std::{
13 collections::HashMap,
14 fmt,
15 path::PathBuf,
16 sync::{Arc, Mutex},
17 time::Duration,
18};
19
20use crypto::keys::bip39;
21use iota_stronghold::{
22 procedures::{
23 BIP39Generate, BIP39Recover, Curve, Ed25519Sign, KeyType as StrongholdKeyType,
24 MnemonicLanguage, PublicKey, Slip10Derive, Slip10DeriveInput, Slip10Generate,
25 StrongholdProcedure,
26 },
27 Client, Location,
28};
29use serde::{de::Visitor, Deserialize, Deserializer};
30use stronghold::{Error, Result, Stronghold};
31use tauri::{
32 plugin::{Builder as PluginBuilder, TauriPlugin},
33 Manager, Runtime, State,
34};
35use zeroize::{Zeroize, Zeroizing};
36
37#[cfg(feature = "kdf")]
38pub mod kdf;
39
40pub mod stronghold;
41
/// Signature of the password hashing callback: it receives the password as `&str`
/// and returns the bytes that `initialize` passes on to `Stronghold::new`.
/// The `Send + Sync` bounds allow the boxed callback to be kept in managed plugin state.
type PasswordHashFn = dyn Fn(&str) -> Vec<u8> + Send + Sync;
43
/// Managed plugin state: the strongholds registered by `initialize`, keyed by the
/// snapshot path they were registered with and guarded by a mutex.
#[derive(Default)]
struct StrongholdCollection(Arc<Mutex<HashMap<PathBuf, Stronghold>>>);
46
/// Managed plugin state holding the boxed password hashing callback that
/// `initialize` invokes on the incoming password.
struct PasswordHashFunction(Box<PasswordHashFn>);
48
/// Byte payload used in command arguments, accepted either as text or as raw bytes.
///
/// The enum is `#[serde(untagged)]`, so serde tries the variants in declaration
/// order: a value is only read as `Raw` when it does not deserialize as a string.
#[derive(Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
#[serde(untagged)]
enum BytesDto {
    /// Text payload; its UTF-8 bytes are what the conversions below expose.
    Text(String),
    /// Raw byte payload, used verbatim.
    Raw(Vec<u8>),
}
55
56impl AsRef<[u8]> for BytesDto {
57 fn as_ref(&self) -> &[u8] {
58 match self {
59 Self::Text(t) => t.as_ref(),
60 Self::Raw(b) => b.as_ref(),
61 }
62 }
63}
64
65impl From<BytesDto> for Vec<u8> {
66 fn from(v: BytesDto) -> Self {
67 match v {
68 BytesDto::Text(t) => t.as_bytes().to_vec(),
69 BytesDto::Raw(b) => b,
70 }
71 }
72}
73
/// Deserializable counterpart of a stronghold `Location`.
///
/// The variant is selected by the `type` tag and its fields are read from `payload`.
#[derive(Deserialize)]
#[serde(tag = "type", content = "payload")]
enum LocationDto {
    /// Converted with `Location::generic(vault, record)`.
    Generic { vault: BytesDto, record: BytesDto },
    /// Converted with `Location::counter(vault, counter)`.
    Counter { vault: BytesDto, counter: usize },
}
80
81impl From<LocationDto> for Location {
82 fn from(dto: LocationDto) -> Location {
83 match dto {
84 LocationDto::Generic { vault, record } => Location::generic(vault, record),
85 LocationDto::Counter { vault, counter } => Location::counter(vault, counter),
86 }
87 }
88}
89
/// Deserializable counterpart of `Slip10DeriveInput`.
///
/// The variant is selected by the `type` tag and the location is read from `payload`.
#[derive(Deserialize)]
#[serde(tag = "type", content = "payload")]
#[allow(clippy::upper_case_acronyms)]
enum Slip10DeriveInputDto {
    /// Converted to `Slip10DeriveInput::Seed` with the given location.
    Seed(LocationDto),
    /// Converted to `Slip10DeriveInput::Key` with the given location.
    Key(LocationDto),
}
97
98impl From<Slip10DeriveInputDto> for Slip10DeriveInput {
99 fn from(dto: Slip10DeriveInputDto) -> Slip10DeriveInput {
100 match dto {
101 Slip10DeriveInputDto::Seed(location) => Slip10DeriveInput::Seed(location.into()),
102 Slip10DeriveInputDto::Key(location) => Slip10DeriveInput::Key(location.into()),
103 }
104 }
105}
106
/// Key type accepted by the `PublicKey` procedure.
///
/// Deserialized from the case-insensitive strings `"ed25519"` and `"x25519"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeyType {
    /// Ed25519 key.
    Ed25519,
    /// X25519 key.
    X25519,
}
111
112impl From<KeyType> for StrongholdKeyType {
113 fn from(ty: KeyType) -> StrongholdKeyType {
114 match ty {
115 KeyType::Ed25519 => StrongholdKeyType::Ed25519,
116 KeyType::X25519 => StrongholdKeyType::X25519,
117 }
118 }
119}
120
121impl<'de> Deserialize<'de> for KeyType {
122 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
123 where
124 D: Deserializer<'de>,
125 {
126 struct KeyTypeVisitor;
127
128 impl Visitor<'_> for KeyTypeVisitor {
129 type Value = KeyType;
130
131 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
132 formatter.write_str("ed25519 or x25519")
133 }
134
135 fn visit_str<E>(self, value: &str) -> std::result::Result<Self::Value, E>
136 where
137 E: serde::de::Error,
138 {
139 match value.to_lowercase().as_str() {
140 "ed25519" => Ok(KeyType::Ed25519),
141 "x25519" => Ok(KeyType::X25519),
142 _ => Err(serde::de::Error::custom("unknown key type")),
143 }
144 }
145 }
146
147 deserializer.deserialize_str(KeyTypeVisitor)
148 }
149}
150
/// Procedure request accepted by the `execute_procedure` command.
///
/// The variant is selected by the `type` tag and its arguments are read from
/// `payload`; each variant is converted into the `StrongholdProcedure` of the
/// corresponding name.
#[derive(Deserialize)]
#[serde(tag = "type", content = "payload")]
#[allow(clippy::upper_case_acronyms)]
enum ProcedureDto {
    /// Converted to `Slip10Generate`; `size_bytes` is read from `sizeBytes` and
    /// forwarded unchanged.
    SLIP10Generate {
        output: LocationDto,
        #[serde(rename = "sizeBytes")]
        size_bytes: Option<usize>,
    },
    /// Converted to `Slip10Derive`; the conversion always uses `Curve::Ed25519`.
    SLIP10Derive {
        chain: Vec<u32>,
        input: Slip10DeriveInputDto,
        output: LocationDto,
    },
    /// Converted to `BIP39Recover`; a missing `passphrase` becomes the default
    /// (empty) string.
    BIP39Recover {
        mnemonic: String,
        passphrase: Option<String>,
        output: LocationDto,
    },
    /// Converted to `BIP39Generate`; a missing `passphrase` becomes the default
    /// (empty) string and the language is always `MnemonicLanguage::English`.
    BIP39Generate {
        passphrase: Option<String>,
        output: LocationDto,
    },
    /// Converted to `PublicKey`; `ty` is read from `type` and `private_key` from
    /// `privateKey`.
    PublicKey {
        #[serde(rename = "type")]
        ty: KeyType,
        #[serde(rename = "privateKey")]
        private_key: LocationDto,
    },
    /// Converted to `Ed25519Sign`; `private_key` is read from `privateKey` and
    /// the UTF-8 bytes of `msg` are used as the message.
    Ed25519Sign {
        #[serde(rename = "privateKey")]
        private_key: LocationDto,
        msg: String,
    },
}
186
187impl From<ProcedureDto> for StrongholdProcedure {
188 fn from(dto: ProcedureDto) -> StrongholdProcedure {
189 match dto {
190 ProcedureDto::SLIP10Generate { output, size_bytes } => {
191 StrongholdProcedure::Slip10Generate(Slip10Generate {
192 output: output.into(),
193 size_bytes,
194 })
195 }
196 ProcedureDto::SLIP10Derive {
197 chain,
198 input,
199 output,
200 } => StrongholdProcedure::Slip10Derive(Slip10Derive {
201 curve: Curve::Ed25519,
202 chain,
203 input: input.into(),
204 output: output.into(),
205 }),
206 ProcedureDto::BIP39Recover {
207 mnemonic,
208 passphrase,
209 output,
210 } => StrongholdProcedure::BIP39Recover(BIP39Recover {
211 mnemonic: bip39::Mnemonic::from(mnemonic),
212 passphrase: bip39::Passphrase::from(passphrase.unwrap_or_default()),
213 output: output.into(),
214 }),
215 ProcedureDto::BIP39Generate { passphrase, output } => {
216 StrongholdProcedure::BIP39Generate(BIP39Generate {
217 passphrase: bip39::Passphrase::from(passphrase.unwrap_or_default()),
218 output: output.into(),
219 language: MnemonicLanguage::English,
220 })
221 }
222 ProcedureDto::PublicKey { ty, private_key } => {
223 StrongholdProcedure::PublicKey(PublicKey {
224 ty: ty.into(),
225 private_key: private_key.into(),
226 })
227 }
228 ProcedureDto::Ed25519Sign { private_key, msg } => {
229 StrongholdProcedure::Ed25519Sign(Ed25519Sign {
230 private_key: private_key.into(),
231 msg: msg.as_bytes().to_vec(),
232 })
233 }
234 }
235 }
236}
237
238#[tauri::command]
239async fn initialize(
240 collection: State<'_, StrongholdCollection>,
241 hash_function: State<'_, PasswordHashFunction>,
242 snapshot_path: PathBuf,
243 mut password: String,
244) -> Result<()> {
245 let hash = (hash_function.0)(&password);
246 password.zeroize();
247 let stronghold = Stronghold::new(snapshot_path.clone(), hash)?;
248
249 collection
250 .0
251 .lock()
252 .unwrap()
253 .insert(snapshot_path, stronghold);
254
255 Ok(())
256}
257
258#[tauri::command]
259async fn destroy(
260 collection: State<'_, StrongholdCollection>,
261 snapshot_path: PathBuf,
262) -> Result<()> {
263 let mut collection = collection.0.lock().unwrap();
264 if let Some(stronghold) = collection.remove(&snapshot_path) {
265 if let Err(e) = stronghold.save() {
266 collection.insert(snapshot_path, stronghold);
267 return Err(e);
268 }
269 }
270 Ok(())
271}
272
273#[tauri::command]
274async fn save(collection: State<'_, StrongholdCollection>, snapshot_path: PathBuf) -> Result<()> {
275 let collection = collection.0.lock().unwrap();
276 if let Some(stronghold) = collection.get(&snapshot_path) {
277 stronghold.save()?;
278 }
279 Ok(())
280}
281
282#[tauri::command]
283async fn create_client(
284 collection: State<'_, StrongholdCollection>,
285 snapshot_path: PathBuf,
286 client: BytesDto,
287) -> Result<()> {
288 let stronghold = get_stronghold(collection, snapshot_path)?;
289 stronghold.create_client(client)?;
290 Ok(())
291}
292
293#[tauri::command]
294async fn load_client(
295 collection: State<'_, StrongholdCollection>,
296 snapshot_path: PathBuf,
297 client: BytesDto,
298) -> Result<()> {
299 let stronghold = get_stronghold(collection, snapshot_path)?;
300 stronghold.load_client(client)?;
301 Ok(())
302}
303
304#[tauri::command]
305async fn get_store_record(
306 collection: State<'_, StrongholdCollection>,
307 snapshot_path: PathBuf,
308 client: BytesDto,
309 key: String,
310) -> Result<Option<Vec<u8>>> {
311 let client = get_client(collection, snapshot_path, client)?;
312 client.store().get(key.as_ref()).map_err(Into::into)
313}
314
315#[tauri::command]
316async fn save_store_record(
317 collection: State<'_, StrongholdCollection>,
318 snapshot_path: PathBuf,
319 client: BytesDto,
320 key: String,
321 value: Vec<u8>,
322 lifetime: Option<Duration>,
323) -> Result<Option<Vec<u8>>> {
324 let client = get_client(collection, snapshot_path, client)?;
325 client
326 .store()
327 .insert(key.as_bytes().to_vec(), value, lifetime)
328 .map_err(Into::into)
329}
330
331#[tauri::command]
332async fn remove_store_record(
333 collection: State<'_, StrongholdCollection>,
334 snapshot_path: PathBuf,
335 client: BytesDto,
336 key: String,
337) -> Result<Option<Vec<u8>>> {
338 let client = get_client(collection, snapshot_path, client)?;
339 client.store().delete(key.as_ref()).map_err(Into::into)
340}
341
342#[tauri::command]
343async fn save_secret(
344 collection: State<'_, StrongholdCollection>,
345 snapshot_path: PathBuf,
346 client: BytesDto,
347 vault: BytesDto,
348 record_path: BytesDto,
349 secret: Vec<u8>,
350) -> Result<()> {
351 let client = get_client(collection, snapshot_path, client)?;
352 client
353 .vault(&vault)
354 .write_secret(
355 Location::generic(vault, record_path),
356 Zeroizing::new(secret),
357 )
358 .map_err(Into::into)
359}
360
361#[tauri::command]
362async fn remove_secret(
363 collection: State<'_, StrongholdCollection>,
364 snapshot_path: PathBuf,
365 client: BytesDto,
366 vault: BytesDto,
367 record_path: BytesDto,
368) -> Result<()> {
369 let client = get_client(collection, snapshot_path, client)?;
370 client
371 .vault(vault)
372 .delete_secret(record_path)
373 .map(|_| ())
374 .map_err(Into::into)
375}
376
377#[tauri::command]
378async fn execute_procedure(
379 collection: State<'_, StrongholdCollection>,
380 snapshot_path: PathBuf,
381 client: BytesDto,
382 procedure: ProcedureDto,
383) -> Result<Vec<u8>> {
384 let client = get_client(collection, snapshot_path, client)?;
385 client
386 .execute_procedure(StrongholdProcedure::from(procedure))
387 .map(Into::into)
388 .map_err(Into::into)
389}
390
391fn get_stronghold(
392 collection: State<'_, StrongholdCollection>,
393 snapshot_path: PathBuf,
394) -> Result<iota_stronghold::Stronghold> {
395 let collection = collection.0.lock().unwrap();
396 if let Some(stronghold) = collection.get(&snapshot_path) {
397 Ok(stronghold.inner().clone())
398 } else {
399 Err(Error::StrongholdNotInitialized)
400 }
401}
402
403fn get_client(
404 collection: State<'_, StrongholdCollection>,
405 snapshot_path: PathBuf,
406 client: BytesDto,
407) -> Result<Client> {
408 let collection = collection.0.lock().unwrap();
409 if let Some(stronghold) = collection.get(&snapshot_path) {
410 stronghold.get_client(client).map_err(Into::into)
411 } else {
412 Err(Error::StrongholdNotInitialized)
413 }
414}
415
/// How the plugin hashes passwords; resolved into a boxed callback in `Builder::build`.
enum PasswordHashFunctionKind {
    /// Hash with `kdf::KeyDerivation::argon2`, passing the stored path as its
    /// second argument (the salt path given to `Builder::with_argon2`).
    #[cfg(feature = "kdf")]
    Argon2(PathBuf),
    /// Hash with a caller-supplied callback.
    Custom(Box<PasswordHashFn>),
}
421
/// Builder for the stronghold plugin; it records which password hash function
/// the built plugin will use.
pub struct Builder {
    // Turned into the managed `PasswordHashFunction` state when the plugin is set up.
    password_hash_function: PasswordHashFunctionKind,
}
425
426impl Builder {
427 pub fn new<F: Fn(&str) -> Vec<u8> + Send + Sync + 'static>(password_hash_function: F) -> Self {
428 Self {
429 password_hash_function: PasswordHashFunctionKind::Custom(Box::new(
430 password_hash_function,
431 )),
432 }
433 }
434
435 #[cfg(feature = "kdf")]
453 pub fn with_argon2(salt_path: &std::path::Path) -> Self {
454 Self {
455 password_hash_function: PasswordHashFunctionKind::Argon2(salt_path.to_owned()),
456 }
457 }
458
459 pub fn build<R: Runtime>(self) -> TauriPlugin<R> {
460 let password_hash_function = self.password_hash_function;
461
462 let plugin_builder = PluginBuilder::new("stronghold").setup(move |app, _api| {
463 app.manage(StrongholdCollection::default());
464 app.manage(PasswordHashFunction(match password_hash_function {
465 #[cfg(feature = "kdf")]
466 PasswordHashFunctionKind::Argon2(path) => {
467 Box::new(move |p| kdf::KeyDerivation::argon2(p, &path))
468 }
469 PasswordHashFunctionKind::Custom(f) => f,
470 }));
471 Ok(())
472 });
473
474 Builder::invoke_stronghold_handlers_and_build(plugin_builder)
475 }
476
477 fn invoke_stronghold_handlers_and_build<R: Runtime>(
478 builder: PluginBuilder<R>,
479 ) -> TauriPlugin<R> {
480 builder
481 .invoke_handler(tauri::generate_handler![
482 initialize,
483 destroy,
484 save,
485 create_client,
486 load_client,
487 get_store_record,
488 save_store_record,
489 remove_store_record,
490 save_secret,
491 remove_secret,
492 execute_procedure,
493 ])
494 .build()
495 }
496}