// ruvector_dag/qudag/crypto/ml_kem.rs
use zeroize::Zeroize;
16
/// Length in bytes of an ML-KEM-768 public key.
pub const ML_KEM_768_PUBLIC_KEY_SIZE: usize = 1184;

/// Length in bytes of an ML-KEM-768 secret key.
pub const ML_KEM_768_SECRET_KEY_SIZE: usize = 2400;

/// Length in bytes of an ML-KEM-768 ciphertext.
pub const ML_KEM_768_CIPHERTEXT_SIZE: usize = 1088;

/// Length in bytes of the shared secret produced by encapsulation and decapsulation.
pub const SHARED_SECRET_SIZE: usize = 32;
23
24#[derive(Clone)]
25pub struct MlKem768PublicKey(pub [u8; ML_KEM_768_PUBLIC_KEY_SIZE]);
26
27#[derive(Clone, Zeroize)]
28#[zeroize(drop)]
29pub struct MlKem768SecretKey(pub [u8; ML_KEM_768_SECRET_KEY_SIZE]);
30
31#[derive(Clone)]
32pub struct EncapsulatedKey {
33 pub ciphertext: [u8; ML_KEM_768_CIPHERTEXT_SIZE],
34 pub shared_secret: [u8; SHARED_SECRET_SIZE],
35}
36
/// Namespace type for the ML-KEM-768 operations.
///
/// It carries no data; key generation, encapsulation and decapsulation are
/// associated functions supplied by whichever backend module is compiled in.
pub struct MlKem768;
38
/// ML-KEM-768 backend built on `pqcrypto_kyber::kyber768`, compiled when the
/// `production-crypto` feature is enabled.
#[cfg(feature = "production-crypto")]
mod production {
    use super::*;
    use pqcrypto_kyber::kyber768;
    use pqcrypto_traits::kem::{Ciphertext, PublicKey, SecretKey, SharedSecret};

    /// Copies `bytes` into a fixed-size array, returning `err` when the
    /// slice length is not exactly `N`.
    fn to_array<const N: usize>(bytes: &[u8], err: KemError) -> Result<[u8; N], KemError> {
        if bytes.len() != N {
            return Err(err);
        }
        let mut out = [0u8; N];
        out.copy_from_slice(bytes);
        Ok(out)
    }

    impl MlKem768 {
        /// Generates a fresh key pair with the Kyber768 backend.
        ///
        /// # Errors
        ///
        /// Returns `KemError::InvalidPublicKey` if the backend's public key
        /// does not have the expected length, and
        /// `KemError::DecapsulationFailed` if its secret key does not.
        pub fn generate_keypair() -> Result<(MlKem768PublicKey, MlKem768SecretKey), KemError> {
            let (backend_pk, backend_sk) = kyber768::keypair();

            // Public key is validated first, then the secret key.
            let public = to_array(backend_pk.as_bytes(), KemError::InvalidPublicKey)?;
            let secret = to_array(backend_sk.as_bytes(), KemError::DecapsulationFailed)?;

            Ok((MlKem768PublicKey(public), MlKem768SecretKey(secret)))
        }

        /// Encapsulates a new shared secret to `pk`.
        ///
        /// # Errors
        ///
        /// Returns `KemError::InvalidPublicKey` if the backend rejects the
        /// key bytes, `KemError::DecapsulationFailed` if the backend's shared
        /// secret has an unexpected length, and `KemError::InvalidCiphertext`
        /// if its ciphertext has an unexpected length.
        pub fn encapsulate(pk: &MlKem768PublicKey) -> Result<EncapsulatedKey, KemError> {
            let backend_pk =
                kyber768::PublicKey::from_bytes(&pk.0).map_err(|_| KemError::InvalidPublicKey)?;

            let (backend_ss, backend_ct) = kyber768::encapsulate(&backend_pk);

            // The shared-secret length is validated before the ciphertext length.
            let shared_secret = to_array(backend_ss.as_bytes(), KemError::DecapsulationFailed)?;
            let ciphertext = to_array(backend_ct.as_bytes(), KemError::InvalidCiphertext)?;

            Ok(EncapsulatedKey {
                ciphertext,
                shared_secret,
            })
        }

        /// Recovers the shared secret carried by `ciphertext` using `sk`.
        ///
        /// # Errors
        ///
        /// Returns `KemError::DecapsulationFailed` if the backend rejects the
        /// secret key bytes or yields a shared secret of unexpected length,
        /// and `KemError::InvalidCiphertext` if it rejects the ciphertext
        /// bytes.
        pub fn decapsulate(
            sk: &MlKem768SecretKey,
            ciphertext: &[u8; ML_KEM_768_CIPHERTEXT_SIZE],
        ) -> Result<[u8; SHARED_SECRET_SIZE], KemError> {
            let backend_sk = kyber768::SecretKey::from_bytes(&sk.0)
                .map_err(|_| KemError::DecapsulationFailed)?;
            let backend_ct = kyber768::Ciphertext::from_bytes(ciphertext)
                .map_err(|_| KemError::InvalidCiphertext)?;

            let backend_ss = kyber768::decapsulate(&backend_ct, &backend_sk);
            to_array(backend_ss.as_bytes(), KemError::DecapsulationFailed)
        }
    }
}
127
128#[cfg(not(feature = "production-crypto"))]
133mod placeholder {
134 use super::*;
135 use sha2::{Digest, Sha256};
136
137 impl MlKem768 {
138 pub fn generate_keypair() -> Result<(MlKem768PublicKey, MlKem768SecretKey), KemError> {
143 let mut pk = [0u8; ML_KEM_768_PUBLIC_KEY_SIZE];
144 let mut sk = [0u8; ML_KEM_768_SECRET_KEY_SIZE];
145
146 getrandom::getrandom(&mut pk).map_err(|_| KemError::RngFailed)?;
147 getrandom::getrandom(&mut sk).map_err(|_| KemError::RngFailed)?;
148
149 Ok((MlKem768PublicKey(pk), MlKem768SecretKey(sk)))
150 }
151
152 pub fn encapsulate(pk: &MlKem768PublicKey) -> Result<EncapsulatedKey, KemError> {
157 let mut ephemeral = [0u8; 32];
158 getrandom::getrandom(&mut ephemeral).map_err(|_| KemError::RngFailed)?;
159
160 let mut ciphertext = [0u8; ML_KEM_768_CIPHERTEXT_SIZE];
161
162 let pk_hash = Self::sha256(&pk.0[..64]);
163 for i in 0..32 {
164 ciphertext[i] = ephemeral[i] ^ pk_hash[i];
165 }
166
167 let padding = Self::sha256(&ephemeral);
168 for i in 32..ML_KEM_768_CIPHERTEXT_SIZE {
169 ciphertext[i] = padding[i % 32];
170 }
171
172 let shared_secret = Self::hkdf_sha256(&ephemeral, &pk.0[..32], b"ml-kem-768-shared");
173
174 Ok(EncapsulatedKey {
175 ciphertext,
176 shared_secret,
177 })
178 }
179
180 pub fn decapsulate(
185 sk: &MlKem768SecretKey,
186 ciphertext: &[u8; ML_KEM_768_CIPHERTEXT_SIZE],
187 ) -> Result<[u8; SHARED_SECRET_SIZE], KemError> {
188 let sk_hash = Self::sha256(&sk.0[..64]);
189 let mut ephemeral = [0u8; 32];
190 for i in 0..32 {
191 ephemeral[i] = ciphertext[i] ^ sk_hash[i];
192 }
193
194 let expected_padding = Self::sha256(&ephemeral);
195 for i in 32..64.min(ML_KEM_768_CIPHERTEXT_SIZE) {
196 if ciphertext[i] != expected_padding[i % 32] {
197 return Err(KemError::InvalidCiphertext);
198 }
199 }
200
201 let shared_secret = Self::hkdf_sha256(&ephemeral, &sk.0[..32], b"ml-kem-768-shared");
202 Ok(shared_secret)
203 }
204
205 fn hkdf_sha256(ikm: &[u8], salt: &[u8], info: &[u8]) -> [u8; SHARED_SECRET_SIZE] {
206 let prk = Self::hmac_sha256(salt, ikm);
207 let mut okm_input = Vec::with_capacity(info.len() + 1);
208 okm_input.extend_from_slice(info);
209 okm_input.push(1);
210 Self::hmac_sha256(&prk, &okm_input)
211 }
212
213 fn hmac_sha256(key: &[u8], message: &[u8]) -> [u8; 32] {
214 const BLOCK_SIZE: usize = 64;
215
216 let mut key_block = [0u8; BLOCK_SIZE];
217 if key.len() > BLOCK_SIZE {
218 let hash = Self::sha256(key);
219 key_block[..32].copy_from_slice(&hash);
220 } else {
221 key_block[..key.len()].copy_from_slice(key);
222 }
223
224 let mut ipad = [0x36u8; BLOCK_SIZE];
225 let mut opad = [0x5cu8; BLOCK_SIZE];
226 for i in 0..BLOCK_SIZE {
227 ipad[i] ^= key_block[i];
228 opad[i] ^= key_block[i];
229 }
230
231 let mut inner = Vec::with_capacity(BLOCK_SIZE + message.len());
232 inner.extend_from_slice(&ipad);
233 inner.extend_from_slice(message);
234 let inner_hash = Self::sha256(&inner);
235
236 let mut outer = Vec::with_capacity(BLOCK_SIZE + 32);
237 outer.extend_from_slice(&opad);
238 outer.extend_from_slice(&inner_hash);
239 Self::sha256(&outer)
240 }
241
242 fn sha256(data: &[u8]) -> [u8; 32] {
243 let mut hasher = Sha256::new();
244 hasher.update(data);
245 let result = hasher.finalize();
246 let mut output = [0u8; 32];
247 output.copy_from_slice(&result);
248 output
249 }
250 }
251}
252
253#[derive(Debug, thiserror::Error)]
254pub enum KemError {
255 #[error("Random number generation failed")]
256 RngFailed,
257 #[error("Invalid public key")]
258 InvalidPublicKey,
259 #[error("Invalid ciphertext")]
260 InvalidCiphertext,
261 #[error("Decapsulation failed")]
262 DecapsulationFailed,
263}
264
/// Reports whether this build was compiled with the `production-crypto`
/// feature.
///
/// The answer is fixed at compile time, so the function is `const` and can
/// be used in constant expressions as well as ordinary calls.
#[must_use]
pub const fn is_production() -> bool {
    cfg!(feature = "production-crypto")
}