Skip to main content

ssec_cli/
dec.rs

1use ssec_core::decrypt::{Decrypt, SsecHeaderError};
2use futures_util::{Stream, StreamExt};
3use tokio::io::AsyncWriteExt;
4use zeroize::Zeroizing;
5use indicatif::{ProgressBar, ProgressStyle};
6use std::path::PathBuf;
7use crate::cli::{DecArgs, FetchArgs};
8use crate::file::new_async_tempfile;
9use crate::password::prompt_password;
10use crate::io::IoBundle;
11use crate::{DEFINITE_BAR_STYLE, INDEFINITE_BAR_STYLE};
12
/// `indicatif` template for the spinner shown while the password is being checked
/// (the "deriving decryption key" phase).
const SPINNER_STYLE: &str = "{spinner} deriving decryption key";
14
/// Prints `$msg` to stderr through `$bar.suspend(..)`, then returns `Err(())` from the
/// enclosing function.
///
/// In this file `$bar` is the `indicatif::ProgressBar`; `suspend` hides the bar while
/// the message is printed so the two do not interleave. `$msg` must be a string
/// literal. Inline format captures such as `{e}` resolve against the caller's locals.
macro_rules! bail {
	($bar:ident, $msg:literal) => {
		$bar.suspend(|| eprintln!($msg));
		return Err(())
	}
}
23
24async fn dec_stream_to<E, S>(
25	stream: S,
26	password: Zeroizing<Vec<u8>>,
27	out_path: PathBuf,
28	show_progress: bool,
29	enc_len: Option<u64>
30) -> Result<(), ()>
31where
32	E: std::error::Error,
33	S: Stream<Item = Result<bytes::Bytes, E>> + Unpin + Send + 'static
34{
35	let progress = match show_progress {
36		true => ProgressBar::new_spinner(),
37		false => ProgressBar::hidden()
38	};
39	let stream = stream.map({
40		let progress = progress.clone();
41		move |b| {
42			if let Ok(b) = &b {
43				progress.inc(b.len() as u64);
44			}
45			b
46		}
47	});
48
49	let (dec, f_out) = tokio::join!(
50		async {
51			let dec = Decrypt::new(stream).await?;
52			Ok::<_, SsecHeaderError<E>>(tokio::task::spawn_blocking({
53				let progress = progress.clone();
54				move || {
55					progress.set_style(ProgressStyle::with_template(SPINNER_STYLE).unwrap());
56					progress.enable_steady_tick(std::time::Duration::from_millis(100));
57
58					dec.try_password(&password)
59				}
60			}).await.unwrap())
61		},
62		new_async_tempfile()
63	);
64
65	let mut dec = match dec {
66		Ok(Ok(dec)) => dec,
67		Ok(Err(_)) => {
68			bail!(progress, "password incorrect");
69		},
70		Err(SsecHeaderError::NotSsec) => {
71			bail!(progress, "input is not a SSEC file");
72		},
73		Err(SsecHeaderError::UnsupportedVersion(0)) => {
74			bail!(progress, "input is from an old version of SSEC, consider downgrading to `ssec-cli` version 0.3");
75		},
76		Err(SsecHeaderError::UnsupportedVersion(v)) => {
77			bail!(progress, "input is from a future version of SSEC (version {v:?}), consider updating `ssec-cli` to the latest version");
78		},
79		Err(SsecHeaderError::UnsupportedCompression(c)) => {
80			bail!(progress, "input has unimplemented compression (type {c:?}), consider updating `ssec-cli` to the latest version");
81		},
82		Err(SsecHeaderError::Stream(e)) => {
83			bail!(progress, "input stream failed: {e}");
84		}
85	};
86	let mut f_out = f_out.unwrap();
87
88	progress.disable_steady_tick();
89	match enc_len {
90		Some(enc_len) => {
91			progress.set_length(enc_len);
92			progress.set_style(ProgressStyle::with_template(DEFINITE_BAR_STYLE).unwrap());
93		},
94		None => progress.set_style(ProgressStyle::with_template(INDEFINITE_BAR_STYLE).unwrap())
95	}
96	progress.reset();
97
98	while let Some(bytes) = dec.next().await {
99		let b = match bytes {
100			Ok(b) => b,
101			Err(e) => {
102				bail!(progress, "{e}");
103			},
104		};
105
106		f_out.as_mut().write_all(&b).await.unwrap();
107	}
108
109	f_out.as_mut().shutdown().await.unwrap();
110
111	f_out.persist(out_path).await.unwrap();
112
113	Ok(())
114}
115
116pub async fn dec_file<B: IoBundle>(args: DecArgs, io: B) -> Result<(), ()> {
117	let password = prompt_password(io).await.map_err(|e| {
118		eprintln!("failed to read password interactively: {e}");
119	})?;
120
121	let f_in = tokio::fs::File::open(&args.in_file).await.map_err(|e| {
122		eprintln!("failed to open file {:?}: {e}", args.in_file);
123	})?;
124
125	let f_in_metadata = f_in.metadata().await.map_err(|e| {
126		eprintln!("failed to get metadata of input file: {e}");
127	})?;
128
129	let s = tokio_util::io::ReaderStream::new(f_in);
130
131	dec_stream_to(
132		s,
133		password,
134		args.out_file,
135		B::is_interactive() && !args.silent,
136		Some(f_in_metadata.len())
137	).await
138}
139
140pub async fn dec_fetch<B: IoBundle>(args: FetchArgs, io: B) -> Result<(), ()> {
141	let password = prompt_password(io).await.map_err(|e| {
142		eprintln!("failed to read password interactively: {e}");
143	})?;
144
145	let client = reqwest::Client::new();
146
147	let resp = client.get(args.url.clone()).send().await.map_err(|e| {
148		eprintln!("failed to fetch remote file {:?}: {e}", args.url);
149	})?;
150	let enc_len = resp.content_length();
151	let s = resp.bytes_stream();
152
153	dec_stream_to(
154		s,
155		password,
156		args.out_file,
157		B::is_interactive() && !args.silent,
158		enc_len
159	).await
160}