ssh_vault/vault/ssh/rsa.rs

1use crate::vault::{
2    Vault, crypto::Crypto, crypto::aes256::Aes256Crypto, fingerprint::md5_fingerprint,
3};
4use anyhow::{Context, Result};
5use base64ct::{Base64, Encoding};
6use rand::rngs::OsRng;
7use rsa::{BigUint, Oaep, RsaPrivateKey, RsaPublicKey};
8use secrecy::{ExposeSecret, SecretSlice};
9use sha2::Sha256;
10use ssh_key::{PrivateKey, PublicKey, private::KeypairData, public::KeyData};
11use zeroize::Zeroize;
12
13#[derive(Debug)]
14pub struct RsaVault {
15    public_key: RsaPublicKey,
16    private_key: Option<RsaPrivateKey>,
17}
18
19impl Vault for RsaVault {
20    fn new(public: Option<PublicKey>, private: Option<PrivateKey>) -> Result<Self> {
21        match (public, private) {
22            (Some(public), None) => match public.key_data() {
23                KeyData::Rsa(key_data) => {
24                    let public_key =
25                        RsaPublicKey::try_from(key_data).context("Could not load key")?;
26
27                    Ok(Self {
28                        public_key,
29                        private_key: None,
30                    })
31                }
32                _ => Err(anyhow::anyhow!("Invalid key type for RsaVault")),
33            },
34
35            (None, Some(private)) => match private.key_data() {
36                KeypairData::Rsa(rsa_keypair) => {
37                    if private.is_encrypted() {
38                        return Err(anyhow::anyhow!("Private key is encrypted"));
39                    }
40
41                    // Extract components from ssh-key's RSA representation
42                    // Use as_bytes() or a similar method to get the &[u8] from Mpint
43                    //
44                    // <https://docs.rs/ssh-key/latest/ssh_key/private/struct.RsaPrivateKey.html>
45                    //
46                    // pub struct RsaPrivateKey {
47                    //     pub d: Mpint,
48                    //     pub iqmp: Mpint,
49                    //     pub p: Mpint,
50                    //     pub q: Mpint,
51                    // }
52                    let modulus = BigUint::from_bytes_be(rsa_keypair.public.n.as_ref());
53                    let public_exponent = BigUint::from_bytes_be(rsa_keypair.public.e.as_ref());
54                    let private_exponent = BigUint::from_bytes_be(rsa_keypair.private.d.as_ref());
55                    let prime_p = BigUint::from_bytes_be(rsa_keypair.private.p.as_ref());
56                    let prime_q = BigUint::from_bytes_be(rsa_keypair.private.q.as_ref());
57
58                    // Create the RSA private key
59                    //
60                    // Constructs an RSA key pair from individual components:
61                    //
62                    // n: RSA modulus
63                    // e: public exponent (i.e. encrypting exponent)
64                    // d: private exponent (i.e. decrypting exponent)
65                    // primes: prime factors of n: typically two primes p and q. More than two
66                    // primes can be provided for multiprime RSA, however this is generally not
67                    // recommended. If no primes are provided, a prime factor recovery algorithm
68                    // will be employed to attempt to recover the factors (as described in NIST SP
69                    // 800-56B Revision 2 Appendix C.2). This algorithm only works if there are
70                    // just two prime factors p and q (as opposed to multiprime), and e is between
71                    // 2^16 and 2^256.
72                    let private_key = RsaPrivateKey::from_components(
73                        modulus,
74                        public_exponent,
75                        private_exponent,
76                        vec![prime_p, prime_q],
77                    )?;
78
79                    // let private_key = RsaPrivateKey::try_from(key_data)?;
80
81                    let public_key = private_key.to_public_key();
82
83                    Ok(Self {
84                        public_key,
85                        private_key: Some(private_key),
86                    })
87                }
88                _ => Err(anyhow::anyhow!("Invalid key type for RsaVault")),
89            },
90
91            (Some(_), Some(_)) => Err(anyhow::anyhow!(
92                "Only one of public and private key is required"
93            )),
94
95            _ => Err(anyhow::anyhow!("Missing public and private key")),
96        }
97    }
98
99    fn create(&self, password: SecretSlice<u8>, data: &mut [u8]) -> Result<String> {
100        let crypto = Aes256Crypto::new(password.clone());
101
102        let fingerprint = md5_fingerprint(&self.public_key)?;
103
104        let encrypted_data = crypto.encrypt(data, fingerprint.as_bytes())?;
105
106        // zeroize data
107        data.zeroize();
108
109        let encrypted_password =
110            self.public_key
111                .encrypt(&mut OsRng, Oaep::new::<Sha256>(), password.expose_secret())?;
112
113        // create vault payload
114        let payload = format!(
115            "{};{}",
116            Base64::encode_string(&encrypted_password),
117            Base64::encode_string(&encrypted_data)
118        )
119        .chars()
120        .collect::<Vec<_>>()
121        .chunks(64)
122        .map(|chunk| chunk.iter().collect::<String>())
123        .collect::<Vec<_>>()
124        .join("\n");
125
126        Ok(format!("SSH-VAULT;AES256;{fingerprint}\n{payload}"))
127    }
128
129    fn view(&self, password: &[u8], data: &[u8], fingerprint: &str) -> Result<String> {
130        let get_fingerprint = md5_fingerprint(&self.public_key)?;
131
132        if get_fingerprint != fingerprint {
133            return Err(anyhow::anyhow!("Fingerprint mismatch, use correct key"));
134        }
135
136        match &self.private_key {
137            Some(private_key) => {
138                let password: SecretSlice<u8> =
139                    SecretSlice::new(private_key.decrypt(Oaep::new::<Sha256>(), password)?.into());
140
141                let crypto = Aes256Crypto::new(password);
142
143                let out = crypto.decrypt(data, fingerprint.as_bytes())?;
144                Ok(String::from_utf8(out)?)
145            }
146            None => Err(anyhow::anyhow!("Private key is required to view vault")),
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vault::Vault;
    use anyhow::{Context, Result};
    use base64ct::{Base64, Encoding};
    use secrecy::SecretSlice;
    use ssh_key::{PrivateKey, PublicKey};
    use std::path::Path;

    const PUBLIC_KEY_FILE: &str = "test_data/id_rsa.pub";
    const PRIVATE_KEY_FILE: &str = "test_data/id_rsa";

    /// Loads the RSA public key fixture.
    fn read_public_key() -> Result<PublicKey> {
        Ok(PublicKey::read_openssh_file(Path::new(PUBLIC_KEY_FILE))?)
    }

    /// Loads the (unencrypted) RSA private key fixture.
    fn read_private_key() -> Result<PrivateKey> {
        Ok(PrivateKey::read_openssh_file(Path::new(PRIVATE_KEY_FILE))?)
    }

    #[test]
    fn test_rsa_vault_using_both_keys() -> Result<()> {
        let err = RsaVault::new(Some(read_public_key()?), Some(read_private_key()?))
            .expect_err("expected error when both keys provided");

        assert_eq!(
            err.to_string(),
            "Only one of public and private key is required"
        );

        Ok(())
    }

    #[test]
    fn test_rsa_vault_using_no_keys() {
        let err = RsaVault::new(None, None).expect_err("expected error when no key provided");

        assert_eq!(err.to_string(), "Missing public and private key");
    }

    #[test]
    fn test_rsa_vault_using_public_key() -> Result<()> {
        assert!(RsaVault::new(Some(read_public_key()?), None).is_ok());
        Ok(())
    }

    #[test]
    fn test_rsa_vault_using_private_key() -> Result<()> {
        assert!(RsaVault::new(None, Some(read_private_key()?)).is_ok());
        Ok(())
    }

    #[test]
    fn test_rsa_vault_create_and_view_round_trip() -> Result<()> {
        let sealer = RsaVault::new(Some(read_public_key()?), None)?;
        let opener = RsaVault::new(None, Some(read_private_key()?))?;

        // 32 bytes: the size of an AES-256 key.
        let password: SecretSlice<u8> = SecretSlice::new(vec![0x42_u8; 32].into());
        let mut data = b"secret message".to_vec();

        let vault = sealer.create(password, &mut data)?;
        assert!(
            data.iter().all(|&byte| byte == 0),
            "plaintext buffer must be zeroized"
        );

        // Undo the framing written by `create`: header line, then a folded
        // `base64(password);base64(data)` payload.
        let (header, body) = vault.split_once('\n').context("missing vault header")?;
        let fingerprint = header
            .strip_prefix("SSH-VAULT;AES256;")
            .context("unexpected vault header")?;
        let body: String = body.lines().collect();
        let (encrypted_password, encrypted_data) =
            body.split_once(';').context("missing payload separator")?;
        let encrypted_password =
            Base64::decode_vec(encrypted_password).map_err(|e| anyhow::anyhow!("{e}"))?;
        let encrypted_data =
            Base64::decode_vec(encrypted_data).map_err(|e| anyhow::anyhow!("{e}"))?;

        assert_eq!(
            opener.view(&encrypted_password, &encrypted_data, fingerprint)?,
            "secret message"
        );

        // A public-key-only vault can seal but must not be able to open.
        assert!(sealer
            .view(&encrypted_password, &encrypted_data, fingerprint)
            .is_err());

        Ok(())
    }
}