sunset/
ssh_chapoly.rs

1#![cfg_attr(fuzzing, allow(dead_code))]
2
3#[allow(unused_imports)]
4use {
5    crate::error::*,
6    log::{debug, error, info, log, trace, warn},
7};
8
9use chacha20::cipher::{
10    KeyIvInit, StreamCipher, StreamCipherSeek, StreamCipherSeekCore,
11};
12use chacha20::ChaCha20;
13use digest::KeyInit;
14use poly1305::universal_hash::generic_array::GenericArray;
15use poly1305::universal_hash::UniversalHash;
16use poly1305::Poly1305;
17use subtle::ConstantTimeEq;
18use zeroize::{Zeroize, ZeroizeOnDrop};
19
20use pretty_hex::PrettyHex;
21
22use crate::*;
23use encrypt::SSH_LENGTH_SIZE;
24
25#[derive(Clone, ZeroizeOnDrop)]
26/// `chacha20-poly1305@openssh.com` authenticated cipher
27pub struct SSHChaPoly {
28    /// Length key
29    k1: [u8; 32],
30    /// Packet key
31    k2: [u8; 32],
32}
33
34impl SSHChaPoly {
35    pub const TAG_LEN: usize = 16;
36    pub const KEY_LEN: usize = 64;
37
38    /// `key` must be 64 bytes
39    pub fn new_from_slice(key: &[u8]) -> Result<Self> {
40        if key.len() != Self::KEY_LEN {
41            return Err(Error::BadKey);
42        }
43        let mut k1 = [0u8; 32];
44        let mut k2 = [0u8; 32];
45        k1.copy_from_slice(&key[32..64]);
46        k2.copy_from_slice(&key[..32]);
47        Ok(Self { k1, k2 })
48    }
49
50    fn cha20(key: &[u8; 32], seq: u32) -> ChaCha20 {
51        let mut nonce = [0u8; 12];
52        nonce[8..].copy_from_slice(&seq.to_be_bytes());
53        ChaCha20::new(key.into(), (&nonce).into())
54    }
55
56    /// Decrypts the packet length.
57    ///
58    /// `buf` must be at least 4 bytes, extra data is ignored.
59    pub fn packet_length(&self, seq: u32, buf: &[u8]) -> Result<u32> {
60        if buf.len() < SSH_LENGTH_SIZE {
61            return Err(Error::BadDecrypt);
62        }
63        let mut b: [u8; SSH_LENGTH_SIZE] =
64            buf[..SSH_LENGTH_SIZE].try_into().unwrap();
65        let mut c = Self::cha20(&self.k1, seq);
66        c.apply_keystream(&mut b);
67        trace!("packet_length {:?}", b.hex_dump());
68        Ok(u32::from_be_bytes(b))
69    }
70
71    /// Decrypts in-place and validates the MAC.
72    ///
73    /// Length has already been decrypted by `packet_length()`.
74    pub fn decrypt(&self, seq: u32, msg: &mut [u8], mac: &[u8]) -> Result<()> {
75        if msg.len() < SSH_LENGTH_SIZE {
76            return Err(Error::BadDecrypt);
77        }
78        if mac.len() != Self::TAG_LEN {
79            return Err(Error::BadDecrypt);
80        }
81
82        let mut c = Self::cha20(&self.k2, seq);
83        let mut poly_key = [0u8; 32];
84        c.apply_keystream(&mut poly_key);
85
86        // check tag
87        let msg_tag = poly1305::Tag::from_slice(mac);
88        let poly = Poly1305::new((&poly_key).into());
89        // compute_unpadded() adds the necessary trailing 1 byte when padding output
90        let tag = poly.compute_unpadded(msg);
91        let good: bool = tag.ct_eq(msg_tag).into();
92        if !good {
93            return Err(Error::BadDecrypt);
94        }
95
96        // decrypt payload
97        let (_, payload) = msg.split_at_mut(SSH_LENGTH_SIZE);
98        // set block counter to 1
99        c.seek(64u32);
100        c.apply_keystream(payload);
101        Ok(())
102    }
103
104    /// Encrypt in-place, including length, payload, MAC.
105    pub fn encrypt(&self, seq: u32, msg: &mut [u8], mac: &mut [u8]) -> Result<()> {
106        if msg.len() < SSH_LENGTH_SIZE {
107            return Err(Error::BadDecrypt);
108        }
109        if mac.len() != Self::TAG_LEN {
110            return Err(Error::BadDecrypt);
111        }
112
113        // encrypt length
114        let l = (msg.len() - SSH_LENGTH_SIZE) as u32;
115        let msg_len = &mut msg[..SSH_LENGTH_SIZE];
116        msg_len.copy_from_slice(&(l.to_be_bytes()));
117        let mut c = Self::cha20(&self.k1, seq);
118        c.apply_keystream(msg_len);
119
120        let mut c = Self::cha20(&self.k2, seq);
121
122        // encrypt payload
123        let (_, payload) = msg.split_at_mut(SSH_LENGTH_SIZE);
124        // set block counter to 1
125        c.seek(64u32);
126        c.apply_keystream(payload);
127
128        // compute tag
129        let mut poly_key = [0u8; 32];
130        // set block counter to 0
131        c.seek(0u32);
132        c.apply_keystream(&mut poly_key);
133        let poly = Poly1305::new((&poly_key).into());
134        let tag = poly.compute_unpadded(msg);
135        mac.copy_from_slice(tag.as_slice());
136
137        Ok(())
138    }
139}