1#![cfg_attr(fuzzing, allow(dead_code))]
2
3#[allow(unused_imports)]
4use {
5 crate::error::*,
6 log::{debug, error, info, log, trace, warn},
7};
8
9use chacha20::cipher::{
10 KeyIvInit, StreamCipher, StreamCipherSeek, StreamCipherSeekCore,
11};
12use chacha20::ChaCha20;
13use digest::KeyInit;
14use poly1305::universal_hash::generic_array::GenericArray;
15use poly1305::universal_hash::UniversalHash;
16use poly1305::Poly1305;
17use subtle::ConstantTimeEq;
18use zeroize::{Zeroize, ZeroizeOnDrop};
19
20use pretty_hex::PrettyHex;
21
22use crate::*;
23use encrypt::SSH_LENGTH_SIZE;
24
25#[derive(Clone, ZeroizeOnDrop)]
26pub struct SSHChaPoly {
28 k1: [u8; 32],
30 k2: [u8; 32],
32}
33
34impl SSHChaPoly {
35 pub const TAG_LEN: usize = 16;
36 pub const KEY_LEN: usize = 64;
37
38 pub fn new_from_slice(key: &[u8]) -> Result<Self> {
40 if key.len() != Self::KEY_LEN {
41 return Err(Error::BadKey);
42 }
43 let mut k1 = [0u8; 32];
44 let mut k2 = [0u8; 32];
45 k1.copy_from_slice(&key[32..64]);
46 k2.copy_from_slice(&key[..32]);
47 Ok(Self { k1, k2 })
48 }
49
50 fn cha20(key: &[u8; 32], seq: u32) -> ChaCha20 {
51 let mut nonce = [0u8; 12];
52 nonce[8..].copy_from_slice(&seq.to_be_bytes());
53 ChaCha20::new(key.into(), (&nonce).into())
54 }
55
56 pub fn packet_length(&self, seq: u32, buf: &[u8]) -> Result<u32> {
60 if buf.len() < SSH_LENGTH_SIZE {
61 return Err(Error::BadDecrypt);
62 }
63 let mut b: [u8; SSH_LENGTH_SIZE] =
64 buf[..SSH_LENGTH_SIZE].try_into().unwrap();
65 let mut c = Self::cha20(&self.k1, seq);
66 c.apply_keystream(&mut b);
67 trace!("packet_length {:?}", b.hex_dump());
68 Ok(u32::from_be_bytes(b))
69 }
70
71 pub fn decrypt(&self, seq: u32, msg: &mut [u8], mac: &[u8]) -> Result<()> {
75 if msg.len() < SSH_LENGTH_SIZE {
76 return Err(Error::BadDecrypt);
77 }
78 if mac.len() != Self::TAG_LEN {
79 return Err(Error::BadDecrypt);
80 }
81
82 let mut c = Self::cha20(&self.k2, seq);
83 let mut poly_key = [0u8; 32];
84 c.apply_keystream(&mut poly_key);
85
86 let msg_tag = poly1305::Tag::from_slice(mac);
88 let poly = Poly1305::new((&poly_key).into());
89 let tag = poly.compute_unpadded(msg);
91 let good: bool = tag.ct_eq(msg_tag).into();
92 if !good {
93 return Err(Error::BadDecrypt);
94 }
95
96 let (_, payload) = msg.split_at_mut(SSH_LENGTH_SIZE);
98 c.seek(64u32);
100 c.apply_keystream(payload);
101 Ok(())
102 }
103
104 pub fn encrypt(&self, seq: u32, msg: &mut [u8], mac: &mut [u8]) -> Result<()> {
106 if msg.len() < SSH_LENGTH_SIZE {
107 return Err(Error::BadDecrypt);
108 }
109 if mac.len() != Self::TAG_LEN {
110 return Err(Error::BadDecrypt);
111 }
112
113 let l = (msg.len() - SSH_LENGTH_SIZE) as u32;
115 let msg_len = &mut msg[..SSH_LENGTH_SIZE];
116 msg_len.copy_from_slice(&(l.to_be_bytes()));
117 let mut c = Self::cha20(&self.k1, seq);
118 c.apply_keystream(msg_len);
119
120 let mut c = Self::cha20(&self.k2, seq);
121
122 let (_, payload) = msg.split_at_mut(SSH_LENGTH_SIZE);
124 c.seek(64u32);
126 c.apply_keystream(payload);
127
128 let mut poly_key = [0u8; 32];
130 c.seek(0u32);
132 c.apply_keystream(&mut poly_key);
133 let poly = Poly1305::new((&poly_key).into());
134 let tag = poly.compute_unpadded(msg);
135 mac.copy_from_slice(tag.as_slice());
136
137 Ok(())
138 }
139}