rust_rsa_tool/
lib.rs

use aes::cipher::{block_padding::Pkcs7, BlockDecryptMut, BlockEncryptMut, KeyInit, KeyIvInit};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use rand::rngs::OsRng;
use rand::RngCore;
use rsa::pkcs8::{DecodePrivateKey, DecodePublicKey, EncodePrivateKey, EncodePublicKey};
use rsa::traits::PublicKeyParts;
use rsa::{Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey};
use std::fs::File;
use std::io::{Read, Write};
use std::path::Path;
use thiserror::Error;
11
12type Aes128CbcEnc = cbc::Encryptor<aes::Aes128>;
13type Aes128CbcDec = cbc::Decryptor<aes::Aes128>;
14type Aes128EcbEnc = ecb::Encryptor<aes::Aes128>;
15type Aes128EcbDec = ecb::Decryptor<aes::Aes128>;
16
17/// RSA最大加密明文大小 (for 2048-bit key with PKCS#1 v1.5 padding)
18const MAX_ENCRYPT_BLOCK: usize = 117;
19
20/// Custom error types
21#[derive(Error, Debug)]
22pub enum RsaUtilsError {
23    #[error("RSA encryption error: {0}")]
24    RsaError(#[from] rsa::Error),
25
26    #[error("IO error: {0}")]
27    IoError(#[from] std::io::Error),
28
29    #[error("Base64 decode error: {0}")]
30    Base64Error(#[from] base64::DecodeError),
31
32    #[error("PKCS8 error: {0}")]
33    Pkcs8Error(String),
34
35    #[error("Decryption error: {0}")]
36    DecryptionError(String),
37
38    #[error("Cipher error: {0}")]
39    CipherError(String),
40}
41
42/// RSA key pair container
43pub struct KeyPair {
44    pub public_key: RsaPublicKey,
45    pub private_key: RsaPrivateKey,
46}
47
48/// Initialize and generate RSA key pair (2048-bit)
49pub fn init_key() -> Result<KeyPair, RsaUtilsError> {
50    let mut rng = OsRng;
51    let bits = 2048;
52    let private_key = RsaPrivateKey::new(&mut rng, bits)?;
53    let public_key = RsaPublicKey::from(&private_key);
54
55    Ok(KeyPair {
56        public_key,
57        private_key,
58    })
59}
60
61/// Get public key from base64 encoded string (X.509/SPKI format)
62pub fn get_public_key(key_str: &str) -> Result<RsaPublicKey, RsaUtilsError> {
63    let key_bytes = BASE64.decode(key_str)?;
64
65    // Try to parse as SPKI (X.509) format
66    let public_key = RsaPublicKey::from_public_key_der(&key_bytes)
67        .map_err(|e| RsaUtilsError::Pkcs8Error(format!("Failed to parse public key: {}", e)))?;
68
69    Ok(public_key)
70}
71
72/// Get private key from base64 encoded string (PKCS#8 format)
73pub fn get_private_key(key_str: &str) -> Result<RsaPrivateKey, RsaUtilsError> {
74    let key_bytes = BASE64.decode(key_str)?;
75
76    // Try to parse as PKCS#8 format
77    let private_key = RsaPrivateKey::from_pkcs8_der(&key_bytes)
78        .map_err(|e| RsaUtilsError::Pkcs8Error(format!("Failed to parse private key: {}", e)))?;
79
80    Ok(private_key)
81}
82
83/// Encode public key to base64 string (X.509/SPKI format)
84pub fn encode_public_key(public_key: &RsaPublicKey) -> Result<String, RsaUtilsError> {
85    let der = public_key
86        .to_public_key_der()
87        .map_err(|e| RsaUtilsError::Pkcs8Error(format!("Failed to encode public key: {}", e)))?;
88    Ok(BASE64.encode(der.as_bytes()))
89}
90
91/// Encode private key to base64 string (PKCS#8 format)
92pub fn encode_private_key(private_key: &RsaPrivateKey) -> Result<String, RsaUtilsError> {
93    let der = private_key
94        .to_pkcs8_der()
95        .map_err(|e| RsaUtilsError::Pkcs8Error(format!("Failed to encode private key: {}", e)))?;
96    Ok(BASE64.encode(der.as_bytes()))
97}
98
99/// Encrypt data with RSA public key (supports data larger than key size via chunking)
100pub fn encrypt(plain_text: &[u8], public_key_str: &str) -> Result<Vec<u8>, RsaUtilsError> {
101    let public_key = get_public_key(public_key_str)?;
102    let mut rng = OsRng;
103
104    let mut result = Vec::new();
105    let mut offset = 0;
106    let input_len = plain_text.len();
107
108    while offset < input_len {
109        let chunk_size = std::cmp::min(MAX_ENCRYPT_BLOCK, input_len - offset);
110        let chunk = &plain_text[offset..offset + chunk_size];
111
112        let encrypted_chunk = public_key.encrypt(&mut rng, Pkcs1v15Encrypt, chunk)?;
113        result.extend_from_slice(&encrypted_chunk);
114
115        offset += chunk_size;
116    }
117
118    Ok(result)
119}
120
121/// Encrypt file using hybrid encryption (AES for file content, RSA for AES key)
122/// This matches the Java implementation's approach
123pub fn encrypt_file<P: AsRef<Path>>(
124    input_path: P,
125    output_path: P,
126    public_key_str: &str,
127) -> Result<(), RsaUtilsError> {
128    // Generate random AES key (128-bit)
129    let mut aes_key = [0u8; 16];
130    let mut iv = [0u8; 16];
131    OsRng.fill_bytes(&mut aes_key);
132    OsRng.fill_bytes(&mut iv);
133
134    // Get RSA public key
135    let public_key = get_public_key(public_key_str)?;
136    let mut rng = OsRng;
137
138    // Wrap (encrypt) the AES key with RSA
139    let mut key_to_wrap = Vec::new();
140    key_to_wrap.extend_from_slice(&aes_key);
141    key_to_wrap.extend_from_slice(&iv);
142
143    let wrapped_key = public_key.encrypt(&mut rng, Pkcs1v15Encrypt, &key_to_wrap)?;
144
145    // Open input and output files
146    let mut input_file = File::open(input_path)?;
147    let mut output_file = File::create(output_path)?;
148
149    // Write wrapped key length and wrapped key
150    output_file.write_all(&(wrapped_key.len() as u32).to_be_bytes())?;
151    output_file.write_all(&wrapped_key)?;
152
153    // Encrypt file content with AES
154    let cipher = Aes128CbcEnc::new(&aes_key.into(), &iv.into());
155    encrypt_stream(&mut input_file, &mut output_file, cipher)?;
156
157    Ok(())
158}
159
160/// Encrypt file using Java-compatible format (AES/ECB mode)
161/// This generates files that can be decrypted by Java's default Cipher.getInstance("AES")
162pub fn encrypt_file_java_ecb<P: AsRef<Path>>(
163    input_path: P,
164    output_path: P,
165    public_key_str: &str,
166) -> Result<(), RsaUtilsError> {
167    // Generate random AES key (128-bit)
168    // Note: ECB mode doesn't use IV
169    let mut aes_key = [0u8; 16];
170    OsRng.fill_bytes(&mut aes_key);
171
172    // Get RSA public key
173    let public_key = get_public_key(public_key_str)?;
174    let mut rng = OsRng;
175
176    // Wrap (encrypt) only the AES key with RSA (no IV for ECB mode)
177    let wrapped_key = public_key.encrypt(&mut rng, Pkcs1v15Encrypt, &aes_key)?;
178
179    // Open input and output files
180    let mut input_file = File::open(input_path)?;
181    let mut output_file = File::create(output_path)?;
182
183    // Write wrapped key length and wrapped key
184    output_file.write_all(&(wrapped_key.len() as u32).to_be_bytes())?;
185    output_file.write_all(&wrapped_key)?;
186
187    // Encrypt file content with AES-ECB
188    let cipher = Aes128EcbEnc::new(&aes_key.into());
189    encrypt_stream_ecb(&mut input_file, &mut output_file, cipher)?;
190
191    Ok(())
192}
193
194/// Decrypt file using hybrid decryption (RSA for AES key, AES for file content)
195pub fn decrypt_file<P: AsRef<Path>>(
196    input_path: P,
197    output_path: P,
198    private_key_str: &str,
199) -> Result<(), RsaUtilsError> {
200    // Get RSA private key
201    let private_key = get_private_key(private_key_str)?;
202
203    // Open input and output files
204    let mut input_file = File::open(input_path)?;
205    let mut output_file = File::create(output_path)?;
206
207    // Read wrapped key length
208    let mut length_bytes = [0u8; 4];
209    input_file.read_exact(&mut length_bytes)?;
210    let wrapped_key_len = u32::from_be_bytes(length_bytes) as usize;
211
212    // Read wrapped key
213    let mut wrapped_key = vec![0u8; wrapped_key_len];
214    input_file.read_exact(&mut wrapped_key)?;
215
216    // Unwrap (decrypt) the AES key with RSA
217    let unwrapped = private_key
218        .decrypt(Pkcs1v15Encrypt, &wrapped_key)
219        .map_err(|e| RsaUtilsError::DecryptionError(format!("Failed to unwrap key: {}", e)))?;
220
221    if unwrapped.len() != 32 {
222        return Err(RsaUtilsError::DecryptionError(
223            "Invalid unwrapped key size".to_string(),
224        ));
225    }
226
227    let mut aes_key = [0u8; 16];
228    let mut iv = [0u8; 16];
229    aes_key.copy_from_slice(&unwrapped[0..16]);
230    iv.copy_from_slice(&unwrapped[16..32]);
231
232    // Decrypt file content with AES
233    let cipher = Aes128CbcDec::new(&aes_key.into(), &iv.into());
234    decrypt_stream(&mut input_file, &mut output_file, cipher)?;
235
236    Ok(())
237}
238
239/// Decrypt file encrypted by Java (using AES/ECB mode)
240/// Java's default Cipher.getInstance("AES") uses ECB mode without IV
241pub fn decrypt_file_java_ecb<P: AsRef<Path>>(
242    input_path: P,
243    output_path: P,
244    private_key_str: &str,
245) -> Result<(), RsaUtilsError> {
246    // Get RSA private key
247    let private_key = get_private_key(private_key_str)?;
248
249    // Open input and output files
250    let mut input_file = File::open(input_path)?;
251    let mut output_file = File::create(output_path)?;
252
253    // Read wrapped key length
254    let mut length_bytes = [0u8; 4];
255    input_file.read_exact(&mut length_bytes)?;
256    let wrapped_key_len = u32::from_be_bytes(length_bytes) as usize;
257
258    // Read wrapped key
259    let mut wrapped_key = vec![0u8; wrapped_key_len];
260    input_file.read_exact(&mut wrapped_key)?;
261
262    // Unwrap (decrypt) the AES key with RSA
263    let aes_key_bytes = private_key
264        .decrypt(Pkcs1v15Encrypt, &wrapped_key)
265        .map_err(|e| RsaUtilsError::DecryptionError(format!("Failed to unwrap key: {}", e)))?;
266
267    if aes_key_bytes.len() != 16 {
268        return Err(RsaUtilsError::DecryptionError(
269            format!("Invalid unwrapped key size: expected 16 bytes, got {}", aes_key_bytes.len()),
270        ));
271    }
272
273    let mut aes_key = [0u8; 16];
274    aes_key.copy_from_slice(&aes_key_bytes[0..16]);
275
276    // Decrypt file content with AES-ECB
277    let cipher = Aes128EcbDec::new(&aes_key.into());
278    decrypt_stream_ecb(&mut input_file, &mut output_file, cipher)?;
279
280    Ok(())
281}
282
283/// Encrypt data stream with AES cipher (CBC mode)
284fn encrypt_stream<R: Read, W: Write>(
285    input: &mut R,
286    output: &mut W,
287    cipher: Aes128CbcEnc,
288) -> Result<(), RsaUtilsError> {
289    let mut buffer = Vec::new();
290    input.read_to_end(&mut buffer)?;
291
292    // Pad and encrypt
293    let ciphertext = cipher.encrypt_padded_vec_mut::<Pkcs7>(&buffer);
294
295    output.write_all(&ciphertext)?;
296    Ok(())
297}
298
299/// Encrypt data stream with AES cipher (ECB mode - for Java compatibility)
300fn encrypt_stream_ecb<R: Read, W: Write>(
301    input: &mut R,
302    output: &mut W,
303    cipher: Aes128EcbEnc,
304) -> Result<(), RsaUtilsError> {
305    let mut buffer = Vec::new();
306    input.read_to_end(&mut buffer)?;
307
308    // Pad and encrypt
309    let ciphertext = cipher.encrypt_padded_vec_mut::<Pkcs7>(&buffer);
310
311    output.write_all(&ciphertext)?;
312    Ok(())
313}
314
315/// Decrypt data stream with AES cipher (CBC mode)
316fn decrypt_stream<R: Read, W: Write>(
317    input: &mut R,
318    output: &mut W,
319    cipher: Aes128CbcDec,
320) -> Result<(), RsaUtilsError> {
321    let mut buffer = Vec::new();
322    input.read_to_end(&mut buffer)?;
323
324    // Decrypt and unpad
325    let plaintext = cipher
326        .decrypt_padded_vec_mut::<Pkcs7>(&buffer)
327        .map_err(|e| RsaUtilsError::DecryptionError(format!("Decryption failed: {}", e)))?;
328
329    output.write_all(&plaintext)?;
330    Ok(())
331}
332
333/// Decrypt data stream with AES cipher (ECB mode - for Java compatibility)
334fn decrypt_stream_ecb<R: Read, W: Write>(
335    input: &mut R,
336    output: &mut W,
337    cipher: Aes128EcbDec,
338) -> Result<(), RsaUtilsError> {
339    let mut buffer = Vec::new();
340    input.read_to_end(&mut buffer)?;
341
342    // Decrypt and unpad
343    let plaintext = cipher
344        .decrypt_padded_vec_mut::<Pkcs7>(&buffer)
345        .map_err(|e| RsaUtilsError::DecryptionError(format!("Decryption failed: {}", e)))?;
346
347    output.write_all(&plaintext)?;
348    Ok(())
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use rsa::traits::PublicKeyParts;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file containing `data`; it is deleted when the handle is dropped.
    fn temp_file_with(data: &[u8]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(data).unwrap();
        file.flush().unwrap();
        file
    }

    #[test]
    fn test_key_generation() {
        let key_pair = init_key().unwrap();
        assert_eq!(key_pair.private_key.size(), 256); // 2048 bits = 256 bytes
    }

    #[test]
    fn test_key_encoding_decoding() {
        let key_pair = init_key().unwrap();

        // Public key round trip.
        let pub_key_str = encode_public_key(&key_pair.public_key).unwrap();
        let decoded_pub = get_public_key(&pub_key_str).unwrap();
        assert_eq!(key_pair.public_key.n(), decoded_pub.n());

        // Private key round trip.
        let priv_key_str = encode_private_key(&key_pair.private_key).unwrap();
        let decoded_priv = get_private_key(&priv_key_str).unwrap();
        assert_eq!(key_pair.private_key.n(), decoded_priv.n());
    }

    #[test]
    fn test_small_data_encryption() {
        let key_pair = init_key().unwrap();
        let pub_key_str = encode_public_key(&key_pair.public_key).unwrap();

        let plain_text = b"Hello, RSA!";
        let encrypted = encrypt(plain_text, &pub_key_str).unwrap();

        // A short message fits in one block, so one decryption suffices.
        let decrypted = key_pair
            .private_key
            .decrypt(Pkcs1v15Encrypt, &encrypted)
            .unwrap();
        assert_eq!(plain_text, &decrypted[..]);

        // Empty input produces empty output.
        assert!(encrypt(b"", &pub_key_str).unwrap().is_empty());
    }

    #[test]
    fn test_large_data_encryption_is_chunked() {
        let key_pair = init_key().unwrap();
        let pub_key_str = encode_public_key(&key_pair.public_key).unwrap();
        let block_len = key_pair.private_key.size();

        // Larger than any single PKCS#1 v1.5 block, so several chunks are required.
        let plain_text: Vec<u8> = (0..1000u32).map(|i| (i % 251) as u8).collect();
        let encrypted = encrypt(&plain_text, &pub_key_str).unwrap();

        assert!(encrypted.len() > block_len);
        assert_eq!(encrypted.len() % block_len, 0);

        let decrypted: Vec<u8> = encrypted
            .chunks(block_len)
            .flat_map(|block| key_pair.private_key.decrypt(Pkcs1v15Encrypt, block).unwrap())
            .collect();
        assert_eq!(plain_text, decrypted);
    }

    #[test]
    fn test_file_encryption_decryption() {
        let key_pair = init_key().unwrap();
        let pub_key_str = encode_public_key(&key_pair.public_key).unwrap();
        let priv_key_str = encode_private_key(&key_pair.private_key).unwrap();

        let test_data = b"This is a test file for RSA encryption!\nIt has multiple lines.\nAnd some more content to make it interesting.";
        let input_file = temp_file_with(test_data);
        let encrypted_file = NamedTempFile::new().unwrap();
        let decrypted_file = NamedTempFile::new().unwrap();

        encrypt_file(input_file.path(), encrypted_file.path(), &pub_key_str).unwrap();
        decrypt_file(encrypted_file.path(), decrypted_file.path(), &priv_key_str).unwrap();

        let decrypted_content = std::fs::read(decrypted_file.path()).unwrap();
        assert_eq!(test_data, &decrypted_content[..]);
    }

    #[test]
    fn test_file_encryption_decryption_java_ecb() {
        let key_pair = init_key().unwrap();
        let pub_key_str = encode_public_key(&key_pair.public_key).unwrap();
        let priv_key_str = encode_private_key(&key_pair.private_key).unwrap();

        // Spans several AES blocks and is not block-aligned.
        let test_data = b"ECB mode test data that is not 16-byte aligned!";
        let input_file = temp_file_with(test_data);
        let encrypted_file = NamedTempFile::new().unwrap();
        let decrypted_file = NamedTempFile::new().unwrap();

        encrypt_file_java_ecb(input_file.path(), encrypted_file.path(), &pub_key_str).unwrap();
        decrypt_file_java_ecb(encrypted_file.path(), decrypted_file.path(), &priv_key_str)
            .unwrap();

        let decrypted_content = std::fs::read(decrypted_file.path()).unwrap();
        assert_eq!(&test_data[..], &decrypted_content[..]);
    }

    #[test]
    fn test_decrypt_rejects_malformed_input() {
        let key_pair = init_key().unwrap();
        let priv_key_str = encode_private_key(&key_pair.private_key).unwrap();
        let output_file = NamedTempFile::new().unwrap();

        // Too short to hold even the 4-byte wrapped-key length.
        let truncated = temp_file_with(&[0u8, 1]);
        assert!(decrypt_file(truncated.path(), output_file.path(), &priv_key_str).is_err());
        assert!(
            decrypt_file_java_ecb(truncated.path(), output_file.path(), &priv_key_str).is_err()
        );

        // Header announces a wrapped key far longer than the data that follows it.
        let mut bogus = 1024u32.to_be_bytes().to_vec();
        bogus.extend_from_slice(&[0u8; 16]);
        let bogus_file = temp_file_with(&bogus);
        assert!(decrypt_file(bogus_file.path(), output_file.path(), &priv_key_str).is_err());
        assert!(
            decrypt_file_java_ecb(bogus_file.path(), output_file.path(), &priv_key_str).is_err()
        );
    }
}