rust_auth_utils/
rsa.rs

1// based on https://github.com/better-auth/utils/blob/main/src/rsa.ts
2
3use crate::types::ExportKeyFormat;
4use rsa::{
5    pkcs8::{DecodePublicKey, EncodePrivateKey, EncodePublicKey},
6    Oaep, Pss, RsaPrivateKey, RsaPublicKey,
7};
8use sha2::{Digest, Sha256, Sha384, Sha512};
9
// Modulus size in bits used by `RSA::generate_key_pair` when the caller passes
// `None`; also the size of the placeholder private key created in `RSA::import_key`.
10const DEFAULT_RSA_KEY_SIZE: usize = 2048;
// Salt length handed to `Pss::new_with_salt` by `RSA::sign` / `RSA::verify` when the
// caller passes `None` for `salt_length`.
11const DEFAULT_SALT_LENGTH: usize = 32;
12
/// An RSA private/public key pair bundled with the hash algorithm that the `RSA`
/// helpers use by default for OAEP padding and PSS signatures.
///
/// NOTE: pairs returned by `RSA::import_key` for the SPKI format carry a freshly
/// generated private key that is unrelated to the imported public key.
13#[derive(Clone)]
14pub struct RsaKeyPair {
    /// Private half; read by `RSA::decrypt`, `RSA::sign` and `RSA::export_key_private`.
15    pub private_key: RsaPrivateKey,
    /// Public half; read by `RSA::encrypt`, `RSA::verify` and `RSA::export_key_public`.
16    pub public_key: RsaPublicKey,
    /// Hash used for OAEP padding, and for PSS when `sign`/`verify` get no explicit hash.
17    pub hash_algorithm: HashAlgorithm,
18}
19
/// SHA-2 hash functions selectable for message digests, OAEP padding and PSS
/// signatures in this module.
///
/// The type is `Copy`, comparable and printable so callers can store it, compare
/// configurations and include it in diagnostics.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum HashAlgorithm {
    /// SHA-256.
    SHA256,
    /// SHA-384.
    SHA384,
    /// SHA-512.
    SHA512,
}
26
27impl HashAlgorithm {
28    fn digest(&self, data: &[u8]) -> Vec<u8> {
29        match self {
30            HashAlgorithm::SHA256 => {
31                let mut hasher = Sha256::new();
32                hasher.update(data);
33                hasher.finalize().to_vec()
34            }
35            HashAlgorithm::SHA384 => {
36                let mut hasher = Sha384::new();
37                hasher.update(data);
38                hasher.finalize().to_vec()
39            }
40            HashAlgorithm::SHA512 => {
41                let mut hasher = Sha512::new();
42                hasher.update(data);
43                hasher.finalize().to_vec()
44            }
45        }
46    }
47
48    fn oaep_padding(&self) -> Oaep {
49        match self {
50            HashAlgorithm::SHA256 => Oaep::new::<Sha256>(),
51            HashAlgorithm::SHA384 => Oaep::new::<Sha384>(),
52            HashAlgorithm::SHA512 => Oaep::new::<Sha512>(),
53        }
54    }
55}
56
57pub struct RSA;
58
59impl RSA {
60    pub async fn generate_key_pair(
61        modulus_length: Option<usize>,
62        hash: Option<HashAlgorithm>,
63    ) -> Result<RsaKeyPair, Box<dyn std::error::Error>> {
64        let mut rng = rand::thread_rng();
65        let bits = modulus_length.unwrap_or(DEFAULT_RSA_KEY_SIZE);
66        let private_key = RsaPrivateKey::new(&mut rng, bits)?;
67        let public_key = RsaPublicKey::from(&private_key);
68        let hash_algorithm = hash.unwrap_or(HashAlgorithm::SHA256);
69        Ok(RsaKeyPair {
70            private_key,
71            public_key,
72            hash_algorithm,
73        })
74    }
75
76    pub async fn export_key_public(
77        key: &RsaKeyPair,
78        format: ExportKeyFormat,
79    ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
80        match format {
81            ExportKeyFormat::SPKI => Ok(key.public_key.to_public_key_der()?.as_bytes().to_vec()),
82            ExportKeyFormat::JWK => {
83                // TODO: Implement JWK format
84                Err("JWK format not yet implemented".into())
85            }
86            _ => Err("Unsupported export format for public key".into()),
87        }
88    }
89
90    pub async fn export_key_private(
91        key: &RsaKeyPair,
92        format: ExportKeyFormat,
93    ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
94        match format {
95            ExportKeyFormat::PKCS8 => Ok(key.private_key.to_pkcs8_der()?.as_bytes().to_vec()),
96            ExportKeyFormat::JWK => {
97                // TODO: Implement JWK format
98                Err("JWK format not yet implemented".into())
99            }
100            _ => Err("Unsupported export format for private key".into()),
101        }
102    }
103
104    pub async fn import_key(
105        key_data: &[u8],
106        format: ExportKeyFormat,
107        _for_encryption: bool,
108        hash: Option<HashAlgorithm>,
109    ) -> Result<RsaKeyPair, Box<dyn std::error::Error>> {
110        let hash_algorithm = hash.unwrap_or(HashAlgorithm::SHA256);
111        match format {
112            ExportKeyFormat::SPKI => {
113                let public_key = RsaPublicKey::from_public_key_der(key_data)?;
114                // Note: private key will be None for imported public keys
115                let private_key =
116                    RsaPrivateKey::new(&mut rand::thread_rng(), DEFAULT_RSA_KEY_SIZE)?;
117                Ok(RsaKeyPair {
118                    private_key,
119                    public_key,
120                    hash_algorithm,
121                })
122            }
123            ExportKeyFormat::JWK => {
124                // TODO: Implement JWK format
125                Err("JWK format not yet implemented".into())
126            }
127            _ => Err("Unsupported import format".into()),
128        }
129    }
130
131    pub async fn encrypt(
132        key: &RsaKeyPair,
133        data: impl AsRef<[u8]>,
134    ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
135        let mut rng = rand::thread_rng();
136        let padding = key.hash_algorithm.oaep_padding();
137        Ok(key.public_key.encrypt(&mut rng, padding, data.as_ref())?)
138    }
139
140    pub async fn decrypt(
141        key: &RsaKeyPair,
142        encrypted_data: impl AsRef<[u8]>,
143    ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
144        let padding = key.hash_algorithm.oaep_padding();
145        Ok(key.private_key.decrypt(padding, encrypted_data.as_ref())?)
146    }
147
148    pub async fn sign(
149        key: &RsaKeyPair,
150        data: impl AsRef<[u8]>,
151        salt_length: Option<usize>,
152        hash: Option<HashAlgorithm>,
153    ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
154        let hash_alg = hash.unwrap_or(key.hash_algorithm);
155        let hashed = hash_alg.digest(data.as_ref());
156
157        let mut rng = rand::thread_rng();
158        let salt_len = salt_length.unwrap_or(DEFAULT_SALT_LENGTH);
159        let padding = match hash_alg {
160            HashAlgorithm::SHA256 => Pss::new_with_salt::<Sha256>(salt_len),
161            HashAlgorithm::SHA384 => Pss::new_with_salt::<Sha384>(salt_len),
162            HashAlgorithm::SHA512 => Pss::new_with_salt::<Sha512>(salt_len),
163        };
164        Ok(key.private_key.sign_with_rng(&mut rng, padding, &hashed)?)
165    }
166
167    pub async fn verify(
168        key: &RsaKeyPair,
169        signature: impl AsRef<[u8]>,
170        data: impl AsRef<[u8]>,
171        salt_length: Option<usize>,
172        hash: Option<HashAlgorithm>,
173    ) -> Result<bool, Box<dyn std::error::Error>> {
174        let hash_alg = hash.unwrap_or(key.hash_algorithm);
175        let hashed = hash_alg.digest(data.as_ref());
176
177        let salt_len = salt_length.unwrap_or(DEFAULT_SALT_LENGTH);
178        let padding = match hash_alg {
179            HashAlgorithm::SHA256 => Pss::new_with_salt::<Sha256>(salt_len),
180            HashAlgorithm::SHA384 => Pss::new_with_salt::<Sha384>(salt_len),
181            HashAlgorithm::SHA512 => Pss::new_with_salt::<Sha512>(salt_len),
182        };
183        Ok(key
184            .public_key
185            .verify(padding, &hashed, signature.as_ref())
186            .is_ok())
187    }
188}