use std::{borrow::Cow, collections::HashMap};
use data_encoding::HEXLOWER_PERMISSIVE;
use serde::{Deserialize, Deserializer};
use sodiumoxide::crypto::{
auth::hmacsha256,
box_::{self, Nonce, PublicKey, SecretKey},
};
use crate::errors::{ApiError, CryptoError};
/// Serde `deserialize_with` helper: decode a hex-encoded string field into raw bytes.
///
/// The value is first deserialized as an owned `String` and then hex-decoded with
/// `HEXLOWER_PERMISSIVE`.
///
/// Deserializing into an owned `String` rather than a borrowed `&[u8]` matters.
/// A borrowed slice can only be produced when the deserializer can lend the input
/// unchanged. A form value that had to be percent-decoded (or a JSON string with
/// escapes) is handed over as an owned string and would be rejected, even though
/// its decoded content is valid hex.
///
/// # Errors
///
/// Returns a deserialization error in two cases:
/// - the value cannot be deserialized as a string;
/// - the string is not valid hex, for example because it has odd length or contains
///   non-hex characters.
fn deserialize_hex_string<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where
    D: Deserializer<'de>,
{
    let hex: String = Deserialize::deserialize(deserializer)?;
    HEXLOWER_PERMISSIVE
        .decode(hex.as_bytes())
        .map_err(serde::de::Error::custom)
}
/// An incoming message, as deserialized from a form-urlencoded request body
/// (see `IncomingMessage::from_urlencoded_bytes`).
///
/// Wire field names are camelCase (e.g. `messageId`). The `nonce` and `box` fields
/// are hex-encoded on the wire and are stored here as raw bytes.
// The derive must come before the `#[serde(...)]` helper attribute it introduces;
// the reverse order triggers the `legacy_derive_helpers` future-incompatibility lint.
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct IncomingMessage {
    /// Value of the `from` field (sender identity, e.g. `ECHOECHO` in the tests).
    pub from: String,
    /// Value of the `to` field (recipient identity, e.g. `*TESTTST` in the tests).
    pub to: String,
    /// Value of the `messageId` field, kept in its string form.
    pub message_id: String,
    /// Value of the `date` field.
    /// NOTE(review): appears to be a Unix timestamp in seconds (the test payload
    /// uses 1616950936) — confirm.
    pub date: usize,
    /// Nonce for `box_data`, hex-decoded from the `nonce` field.
    #[serde(deserialize_with = "deserialize_hex_string")]
    pub nonce: Vec<u8>,
    /// Encrypted payload, hex-decoded from the `box` field.
    #[serde(rename = "box", deserialize_with = "deserialize_hex_string")]
    pub box_data: Vec<u8>,
    /// Optional `nickname` field; `None` when absent.
    pub nickname: Option<String>,
}
impl IncomingMessage {
    /// Parse a form-urlencoded request body into an `IncomingMessage`, after verifying its MAC.
    ///
    /// Before anything is deserialized, the following check is performed:
    /// - An HMAC-SHA256 keyed with `api_secret` is computed over the
    ///   (percent-decoded) values of `from`, `to`, `messageId`, `date`, `nonce` and
    ///   `box`, concatenated in that order.
    /// - The result is compared with the hex-encoded `mac` field.
    ///
    /// # Errors
    ///
    /// - `ApiError::ParseError` in any of these cases:
    ///   - `mac` or one of the authenticated fields is missing;
    ///   - `mac` is not valid hex or does not decode to exactly 32 bytes;
    ///   - the body cannot be deserialized into an `IncomingMessage`.
    /// - `ApiError::InvalidMac` if the supplied MAC does not match the computed one.
    pub fn from_urlencoded_bytes(
        bytes: impl AsRef<[u8]>,
        api_secret: &str,
    ) -> Result<Self, ApiError> {
        let bytes = bytes.as_ref();
        let values: HashMap<Cow<str>, Cow<str>> = form_urlencoded::parse(bytes).collect();

        let given_tag = Self::parse_mac(&values)?;
        let calculated_tag = Self::calculate_mac(&values, api_secret)?;
        if given_tag != calculated_tag {
            return Err(ApiError::InvalidMac);
        }

        serde_urlencoded::from_bytes(bytes)
            .map_err(|e| ApiError::ParseError(format!("Could not parse message: {}", e)))
    }

    /// Decode the hex-encoded `mac` field into an HMAC-SHA256 tag.
    fn parse_mac(values: &HashMap<Cow<'_, str>, Cow<'_, str>>) -> Result<hmacsha256::Tag, ApiError> {
        let mac_hex = values
            .get("mac")
            .ok_or_else(|| ApiError::ParseError("Missing request body field: mac".to_string()))?;
        // Decode into a growable buffer instead of `decode_mut` into a fixed
        // 32-byte array. `decode_mut` panics when the output length does not match
        // the (attacker-controlled) input length; here, wrong lengths become errors.
        let mac = HEXLOWER_PERMISSIVE
            .decode(mac_hex.as_bytes())
            .map_err(|_| ApiError::ParseError("Invalid hex bytes for MAC".to_string()))?;
        hmacsha256::Tag::from_slice(&mac).ok_or_else(|| {
            ApiError::ParseError(format!(
                "Invalid MAC: Length must be 32 bytes, but is {} bytes",
                mac.len()
            ))
        })
    }

    /// Compute the HMAC-SHA256 over the authenticated fields, in protocol order.
    fn calculate_mac(
        values: &HashMap<Cow<'_, str>, Cow<'_, str>>,
        api_secret: &str,
    ) -> Result<hmacsha256::Tag, ApiError> {
        let mut hmac_state = hmacsha256::State::init(api_secret.as_bytes());
        for &field in &["from", "to", "messageId", "date", "nonce", "box"] {
            let value = values.get(field).ok_or_else(|| {
                ApiError::ParseError(format!("Missing request body field: {}", field))
            })?;
            hmac_state.update(value.as_bytes());
        }
        Ok(hmac_state.finalize())
    }

    /// Decrypt `box_data` and strip the trailing padding.
    ///
    /// `public_key` is the sender's public key and `private_key` the recipient's
    /// secret key; the tests seal with (recipient public key, sender secret key)
    /// and open with (sender public key, recipient secret key).
    ///
    /// After decryption, the last plaintext byte gives the number of trailing bytes
    /// to remove. A final byte of `0` removes nothing.
    ///
    /// # Errors
    ///
    /// - `CryptoError::BadNonce` if `nonce` does not have the length of a `box_` nonce.
    /// - `CryptoError::DecryptionFailed` if authentication or decryption fails.
    /// - `CryptoError::BadPadding` in either of these cases:
    ///   - the plaintext is empty;
    ///   - the padding length is not smaller than the plaintext length.
    pub fn decrypt_box(
        &self,
        public_key: &PublicKey,
        private_key: &SecretKey,
    ) -> Result<Vec<u8>, CryptoError> {
        let nonce = Nonce::from_slice(&self.nonce).ok_or(CryptoError::BadNonce)?;
        let mut decrypted = box_::open(&self.box_data, &nonce, public_key, private_key)
            .map_err(|_| CryptoError::DecryptionFailed)?;
        let padding_amount = usize::from(*decrypted.last().ok_or(CryptoError::BadPadding)?);
        // At least one byte of payload must remain after removing the padding.
        if padding_amount >= decrypted.len() {
            return Err(CryptoError::BadPadding);
        }
        decrypted.truncate(decrypted.len() - padding_amount);
        Ok(decrypted)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    mod incoming_message_deserialize {
        use super::*;

        // Form body whose `mac` field was computed with `TEST_MAC_SECRET`.
        const TEST_PAYLOAD: &[u8] = b"from=ECHOECHO&to=*TESTTST&messageId=0102030405060708&date=1616950936&nonce=ffffffffffffffffffffffffffffffffffffffffffffffff&box=012345abcdef&mac=622b362e8353658ee649a5548acecc9ce9b88384d6b7e08e212446d68455b14e";
        const TEST_MAC_SECRET: &str = "nevergonnagiveyouup";

        #[test]
        fn success() {
            let parsed =
                IncomingMessage::from_urlencoded_bytes(TEST_PAYLOAD, TEST_MAC_SECRET).unwrap();
            assert_eq!(parsed.from, "ECHOECHO");
            assert_eq!(parsed.to, "*TESTTST");
            // Both hex fields are decoded to raw bytes.
            assert_eq!(parsed.nonce, vec![0xff; 24]);
            assert_eq!(parsed.box_data, vec![0x01, 0x23, 0x45, 0xab, 0xcd, 0xef]);
            assert_eq!(parsed.nickname, None);
        }

        #[test]
        fn invalid_mac() {
            // Same body, but verified with a different secret.
            let outcome =
                IncomingMessage::from_urlencoded_bytes(TEST_PAYLOAD, "nevergonnaletyoudown");
            match outcome {
                Err(ApiError::InvalidMac) => (),
                unexpected => panic!("Unexpected result: {:?}", unexpected),
            }
        }
    }

    mod decrypt_box {
        use super::*;

        /// Build a message with fixed metadata around the given nonce and box.
        fn message_with(nonce: Vec<u8>, box_data: Vec<u8>) -> IncomingMessage {
            IncomingMessage {
                from: "AAAAAAAA".into(),
                to: "*BBBBBBB".into(),
                message_id: "00112233".into(),
                date: 0,
                nonce,
                box_data,
                nickname: None,
            }
        }

        #[test]
        fn decrypt() {
            let (sender_pk, sender_sk) = box_::gen_keypair();
            let (recipient_pk, recipient_sk) = box_::gen_keypair();
            let nonce = box_::gen_nonce();
            let sealed = box_::seal(&[1, 2, 42, 3, 3, 3], &nonce, &recipient_pk, &sender_sk);
            let msg = message_with(nonce.0.to_vec(), sealed);

            // Opening with the wrong public key must fail authentication.
            let failure = msg.decrypt_box(&recipient_pk, &recipient_sk).unwrap_err();
            assert_eq!(failure, CryptoError::DecryptionFailed);

            // The trailing 3 strips three padding bytes.
            let plaintext = msg.decrypt_box(&sender_pk, &recipient_sk).unwrap();
            assert_eq!(plaintext, vec![1, 2, 42]);
        }

        #[test]
        fn decrypt_bad_nonce() {
            let (public, secret) = box_::gen_keypair();
            // A 4-byte nonce is too short to be a valid box nonce.
            let msg = message_with(vec![1, 2, 3, 4], vec![0]);
            let failure = msg.decrypt_box(&public, &secret).unwrap_err();
            assert_eq!(failure, CryptoError::BadNonce);
        }

        #[test]
        fn decrypt_bad_padding() {
            let (sender_pk, sender_sk) = box_::gen_keypair();
            let (recipient_pk, recipient_sk) = box_::gen_keypair();
            let nonce = box_::gen_nonce();
            // Last byte 42 claims more padding than the 3-byte plaintext holds.
            let sealed = box_::seal(&[1, 2, 42], &nonce, &recipient_pk, &sender_sk);
            let msg = message_with(nonce.0.to_vec(), sealed);

            let failure = msg.decrypt_box(&sender_pk, &recipient_sk).unwrap_err();
            assert_eq!(failure, CryptoError::BadPadding);
        }
    }
}