use crate::{
commonio::{self, Buffer},
crypter::{Decrypter, Encrypter, DEFAULT_BLOCK_SIZE, MIN_BLOCK_SIZE},
key::{PublicKey, SecretKey},
};
use std::{
cmp,
io::{self, Write},
};
#[cfg(feature = "io-async")]
pub use crate::async_::write::*;
/// A [`Write`] adapter that accepts saltlick ciphertext and writes the
/// decrypted plaintext to an inner writer `W`.
///
/// Call [`SaltlickDecrypter::finalize`] once all ciphertext has been written.
/// It reports an error if the stream is incomplete and hands back the inner
/// writer. If the value is dropped instead, finalization is attempted on a
/// best-effort basis and any error is ignored.
#[derive(Debug)]
pub struct SaltlickDecrypter<W: Write> {
    // Output staging buffer; `flush` pushes its contents into `inner`.
    buffer: Buffer,
    // Decryption state machine driven by `commonio::write` / `write_finalized`.
    decrypter: Decrypter,
    // The wrapped writer. `None` only after `finalize` has taken it out, which
    // also tells `Drop` that there is nothing left to do.
    inner: Option<W>,
}
impl<W: Write> SaltlickDecrypter<W> {
    /// Creates a decrypter that uses the default buffer capacity.
    ///
    /// `public_key`/`secret_key` are the keypair the stream was encrypted for.
    /// Decrypted plaintext is written to `writer`.
    pub fn new(public_key: PublicKey, secret_key: SecretKey, writer: W) -> SaltlickDecrypter<W> {
        Self::with_capacity(DEFAULT_BLOCK_SIZE, public_key, secret_key, writer)
    }

    /// Creates a decrypter whose secret key is supplied lazily by `lookup_fn`.
    ///
    /// `lookup_fn` is passed to `Decrypter::new_deferred`. It receives a
    /// `PublicKey` (presumably the one identified by the stream header — TODO
    /// confirm) and returns the matching `SecretKey`, or `None` if it has none.
    pub fn new_deferred<F>(writer: W, lookup_fn: F) -> SaltlickDecrypter<W>
    where
        F: FnOnce(&PublicKey) -> Option<SecretKey> + 'static,
    {
        Self::deferred_with_capacity(DEFAULT_BLOCK_SIZE, writer, lookup_fn)
    }

    /// Like [`SaltlickDecrypter::new`], but with an explicit buffer `capacity`.
    ///
    /// `capacity` is raised to `MIN_BLOCK_SIZE` if it is smaller.
    pub fn with_capacity(
        capacity: usize,
        public_key: PublicKey,
        secret_key: SecretKey,
        writer: W,
    ) -> SaltlickDecrypter<W> {
        let capacity = cmp::max(capacity, MIN_BLOCK_SIZE);
        SaltlickDecrypter {
            buffer: Buffer::new(capacity),
            decrypter: Decrypter::new(public_key, secret_key),
            inner: Some(writer),
        }
    }

    /// Like [`SaltlickDecrypter::new_deferred`], but with an explicit buffer
    /// `capacity`.
    ///
    /// `capacity` is raised to `MIN_BLOCK_SIZE` if it is smaller.
    pub fn deferred_with_capacity<F>(
        capacity: usize,
        writer: W,
        lookup_fn: F,
    ) -> SaltlickDecrypter<W>
    where
        F: FnOnce(&PublicKey) -> Option<SecretKey> + 'static,
    {
        let capacity = cmp::max(capacity, MIN_BLOCK_SIZE);
        SaltlickDecrypter {
            buffer: Buffer::new(capacity),
            decrypter: Decrypter::new_deferred(lookup_fn),
            inner: Some(writer),
        }
    }

    /// Finishes decryption, writes any remaining plaintext, and returns the
    /// inner writer.
    ///
    /// # Errors
    ///
    /// Returns the error from `commonio::write_finalized`. This covers, for
    /// example, a ciphertext stream that ended before it was complete, or a
    /// failure writing to the inner writer. On error, the inner writer is
    /// dropped and finalization is not retried.
    pub fn finalize(mut self) -> Result<W, io::Error> {
        // Take the writer out *before* finalizing. If finalization fails and
        // `self` is dropped on the `?` path, `Drop` then sees `inner == None`
        // and will not run `flush` + `write_finalized` a second time against
        // the same writer.
        let mut writer = self.inner.take().expect("inner writer missing");
        commonio::write_finalized(&mut writer, &mut self.decrypter, &mut self.buffer)?;
        Ok(writer)
    }
}
impl<W: Write> Write for SaltlickDecrypter<W> {
fn write(&mut self, input: &[u8]) -> io::Result<usize> {
let writer = self.inner.as_mut().expect("inner writer missing");
commonio::write(writer, &mut self.decrypter, &mut self.buffer, input)
}
fn flush(&mut self) -> io::Result<()> {
let writer = self.inner.as_mut().expect("inner writer missing");
self.buffer.flush(writer)?;
writer.flush()
}
}
impl<W: Write> Drop for SaltlickDecrypter<W> {
    /// Best-effort completion: flushes buffered output and finalizes the
    /// stream, discarding any error.
    ///
    /// This is skipped when `finalize` has already taken the writer, or when
    /// `buffer.panicked()` reports true (presumably set when a write to the
    /// inner writer panicked — TODO confirm against `commonio::Buffer`).
    fn drop(&mut self) {
        let sink = match self.inner.as_mut() {
            Some(sink) if !self.buffer.panicked() => sink,
            _ => return,
        };
        let _ = self.buffer.flush(sink);
        let _ = commonio::write_finalized(sink, &mut self.decrypter, &mut self.buffer);
    }
}
/// A [`Write`] adapter that encrypts data written to it for a `PublicKey` and
/// writes the resulting saltlick ciphertext to an inner writer `W`.
///
/// Call [`SaltlickEncrypter::finalize`] once all plaintext has been written,
/// to finish the stream and get the inner writer back. If the value is dropped
/// instead, finalization is attempted on a best-effort basis and any error is
/// ignored.
#[derive(Debug)]
pub struct SaltlickEncrypter<W: Write> {
    // Output staging buffer; `flush` pushes its contents into `inner`.
    buffer: Buffer,
    // Encryption state machine driven by `commonio::write` / `write_finalized`.
    encrypter: Encrypter,
    // The wrapped writer. `None` only after `finalize` has taken it out, which
    // also tells `Drop` that there is nothing left to do.
    inner: Option<W>,
}
impl<W: Write> SaltlickEncrypter<W> {
    /// Creates an encrypter for `public_key` that uses the default buffer
    /// capacity and writes ciphertext to `writer`.
    pub fn new(public_key: PublicKey, writer: W) -> SaltlickEncrypter<W> {
        Self::with_capacity(DEFAULT_BLOCK_SIZE, public_key, writer)
    }

    /// Like [`SaltlickEncrypter::new`], but with an explicit buffer `capacity`.
    ///
    /// `capacity` is raised to `MIN_BLOCK_SIZE` if it is smaller.
    pub fn with_capacity(
        capacity: usize,
        public_key: PublicKey,
        writer: W,
    ) -> SaltlickEncrypter<W> {
        let capacity = cmp::max(capacity, MIN_BLOCK_SIZE);
        SaltlickEncrypter {
            buffer: Buffer::new(capacity),
            encrypter: Encrypter::new(public_key),
            inner: Some(writer),
        }
    }

    /// Sets the block size used by the underlying encrypter.
    ///
    /// This forwards to `Encrypter::set_block_size`; see that method for valid
    /// values and for when the change takes effect.
    pub fn set_block_size(&mut self, block_size: usize) {
        self.encrypter.set_block_size(block_size);
    }

    /// Finishes the encrypted stream, writes the remaining ciphertext, and
    /// returns the inner writer.
    ///
    /// # Errors
    ///
    /// Returns the error from `commonio::write_finalized`, for example a
    /// failure writing to the inner writer. On error, the inner writer is
    /// dropped and finalization is not retried.
    pub fn finalize(mut self) -> Result<W, io::Error> {
        // Take the writer out *before* finalizing. If finalization fails and
        // `self` is dropped on the `?` path, `Drop` then sees `inner == None`
        // and will not write a second (possibly duplicated) stream trailer to
        // the same writer.
        let mut writer = self.inner.take().expect("inner writer missing");
        commonio::write_finalized(&mut writer, &mut self.encrypter, &mut self.buffer)?;
        Ok(writer)
    }
}
impl<W: Write> Write for SaltlickEncrypter<W> {
fn write(&mut self, input: &[u8]) -> io::Result<usize> {
let writer = self.inner.as_mut().expect("inner writer missing");
commonio::write(writer, &mut self.encrypter, &mut self.buffer, input)
}
fn flush(&mut self) -> io::Result<()> {
let writer = self.inner.as_mut().expect("inner writer missing");
self.buffer.flush(writer)?;
writer.flush()
}
}
impl<W: Write> Drop for SaltlickEncrypter<W> {
    /// Best-effort completion: flushes buffered output and, unless the thread
    /// is unwinding from a panic, finalizes the stream. Errors are discarded;
    /// call `finalize` to observe them.
    ///
    /// Nothing is done when `finalize` has already taken the writer, or when
    /// `buffer.panicked()` reports true (presumably set when a write to the
    /// inner writer panicked — TODO confirm against `commonio::Buffer`).
    fn drop(&mut self) {
        let writer = match self.inner.as_mut() {
            Some(writer) if !self.buffer.panicked() => writer,
            _ => return,
        };
        let _ = self.buffer.flush(writer);
        // Finalization is what lets a decrypter tell a complete stream from a
        // truncated one. While unwinding from a panic, the producer never
        // finished writing. Sealing the stream here would make partial output
        // look complete, so leave it unfinalized and let decryption report it
        // as incomplete.
        if !std::thread::panicking() {
            let _ = commonio::write_finalized(writer, &mut self.encrypter, &mut self.buffer);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{SaltlickDecrypter, SaltlickEncrypter};
    use crate::{key::gen_keypair, testutils::random_bytes};
    use std::{cmp, io::Write, iter};

    /// Plaintext sizes shared by the round-trip tests, from one byte up to 10 MiB.
    const SIZES: [usize; 6] = [
        1,
        10 * 1024,
        32 * 1024,
        100 * 1024,
        200 * 1024,
        10 * 1024 * 1024,
    ];

    /// A single `write_all` through an encrypter chained into a deferred-key
    /// decrypter round-trips the data.
    #[test]
    fn single_write_test() {
        for &size in &SIZES {
            let random_data = random_bytes(0, size);
            let (public_key, secret_key) = gen_keypair();
            let decrypter = SaltlickDecrypter::new_deferred(Vec::new(), |_| Some(secret_key));
            let mut encrypter = SaltlickEncrypter::new(public_key, decrypter);
            encrypter.write_all(&random_data[..]).unwrap();
            let decrypter = encrypter.finalize().unwrap();
            let output = decrypter.finalize().unwrap();
            assert_eq!(&random_data[..], &output[..]);
        }
    }

    /// Many small writes of growing length (1, 8, 15, ... bytes), each
    /// followed by `flush`, round-trip with a non-default block size.
    #[test]
    fn multiple_write_test() {
        for &size in &SIZES {
            let random_data = random_bytes(0, size);
            let (public_key, secret_key) = gen_keypair();
            let decrypter = SaltlickDecrypter::new_deferred(Vec::new(), |_| Some(secret_key));
            let mut encrypter = SaltlickEncrypter::new(public_key, decrypter);
            encrypter.set_block_size(16 * 1024);
            let mut written = 0;
            for take in iter::successors(Some(1usize), |n| Some(n + 7)) {
                let end = cmp::min(written + take, size);
                encrypter.write_all(&random_data[written..end]).unwrap();
                encrypter.flush().unwrap();
                written += take;
                if written >= size {
                    break;
                }
            }
            let decrypter = encrypter.finalize().unwrap();
            let output = decrypter.finalize().unwrap();
            assert_eq!(&random_data[..], &output[..]);
        }
    }

    /// Dropping the chain without calling `finalize` still delivers all data.
    #[test]
    fn drop_flush_test() {
        for &size in &SIZES {
            let random_data = random_bytes(0, size);
            let (public_key, secret_key) = gen_keypair();
            let mut output = Vec::new();
            {
                let decrypter = SaltlickDecrypter::new(public_key.clone(), secret_key, &mut output);
                let mut encrypter = SaltlickEncrypter::new(public_key, decrypter);
                encrypter.write_all(&random_data[..]).unwrap();
            }
            assert_eq!(&random_data[..], &output[..]);
        }
    }

    /// Flipping a byte near the end of the ciphertext makes decryption fail.
    #[test]
    fn corrupt_value_test() {
        let random_data = random_bytes(0, 100 * 1024);
        let (public_key, secret_key) = gen_keypair();
        let mut encrypter = SaltlickEncrypter::new(public_key.clone(), Vec::new());
        encrypter.write_all(&random_data[..]).unwrap();
        let mut ciphertext = encrypter.finalize().unwrap();
        let index = ciphertext.len() - 5;
        ciphertext[index] = ciphertext[index].wrapping_add(1);
        let mut decrypter = SaltlickDecrypter::new(public_key, secret_key, Vec::new());
        assert!(decrypter.write_all(&ciphertext[..]).is_err());
    }

    /// A truncated ciphertext is accepted by `write` but rejected by `finalize`.
    #[test]
    fn incomplete_stream_test() {
        let random_data = random_bytes(0, 100 * 1024);
        let (public_key, secret_key) = gen_keypair();
        let mut encrypter = SaltlickEncrypter::new(public_key.clone(), Vec::new());
        encrypter.write_all(&random_data[..]).unwrap();
        let mut ciphertext = encrypter.finalize().unwrap();
        ciphertext.resize(ciphertext.len() - 5, 0);
        let mut decrypter = SaltlickDecrypter::new(public_key, secret_key, Vec::new());
        decrypter.write_all(&ciphertext[..]).unwrap();
        assert!(decrypter.finalize().is_err());
    }
}