// rustolio_utils/crypto/encapsulation.rs
use ml_kem::{
    kem::{self, Decapsulate as _},
    EncapsulateDeterministic, EncodedSizeUser, KemCore as _, MlKem1024, MlKem1024Params,
};

use super::rand;
17
/// Result type used by the `from_bytes` constructors in this module.
pub type Result<T> = std::result::Result<T, Error>;

/// Error returned when decoding a key or ciphertext from bytes fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Error {
    /// The input could not be decoded as an ML-KEM-1024 decapsulation key.
    InvalidDecapsulationKey,
    /// The input could not be decoded as an ML-KEM-1024 encapsulation key.
    InvalidEncapsulationKey,
    /// The input could not be decoded as an ML-KEM-1024 ciphertext.
    InvalidEncapsulated,
}

impl std::fmt::Display for Error {
    /// Writes a human-readable description; `Debug` still yields the variant name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Error::InvalidDecapsulationKey => "invalid ML-KEM-1024 decapsulation key",
            Error::InvalidEncapsulationKey => "invalid ML-KEM-1024 encapsulation key",
            Error::InvalidEncapsulated => "invalid ML-KEM-1024 ciphertext",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for Error {}
34
35#[derive(Debug, Clone)]
36#[repr(transparent)]
37pub struct DecapsulationKey(kem::DecapsulationKey<MlKem1024Params>);
38
39#[derive(Debug, Clone)]
40#[repr(transparent)]
41pub struct EncapsulationKey(kem::EncapsulationKey<MlKem1024Params>);
42
43impl DecapsulationKey {
44 pub fn generate() -> rand::Result<Self> {
45 let d = rand::array()?;
46 let z = rand::array()?;
47 let (dk, _) = MlKem1024::generate_deterministic(&d.into(), &z.into());
48 Ok(DecapsulationKey(dk))
49 }
50
51 pub fn encapsulation_key(&self) -> &EncapsulationKey {
52 unsafe {
53 std::mem::transmute(self.0.encapsulation_key())
55 }
56 }
57
58 pub fn to_bytes(&self) -> [u8; 3168] {
59 self.0.as_bytes().into()
60 }
61
62 pub fn from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self> {
63 let Ok(enc) = bytes.as_ref().try_into() else {
64 return Err(Error::InvalidDecapsulationKey);
65 };
66 Ok(Self(kem::DecapsulationKey::from_bytes(enc)))
67 }
68
69 pub fn decapsulate(&self, ct: &Encapsulated) -> SharedSecret {
70 SharedSecret(
71 self.0
72 .decapsulate(&ct.0.into())
73 .unwrap() .into(),
75 )
76 }
77}
78
79impl EncapsulationKey {
80 pub fn to_bytes(&self) -> [u8; 1568] {
81 self.0.as_bytes().into()
82 }
83
84 pub fn from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self> {
85 let Ok(enc) = bytes.as_ref().try_into() else {
86 return Err(Error::InvalidEncapsulationKey);
87 };
88 Ok(Self(kem::EncapsulationKey::from_bytes(enc)))
89 }
90
91 pub fn encapsulate(&self) -> rand::Result<(Encapsulated, SharedSecret)> {
92 let seed = rand::array()?;
93 let (ct, ss) = self.0.encapsulate_deterministic(&seed.into()).unwrap(); Ok((Encapsulated(ct.into()), SharedSecret(ss.into())))
95 }
96}
97
/// The 32-byte shared secret agreed on by encapsulation and decapsulation.
///
/// `Debug` output is redacted, and equality examines every byte rather than stopping at
/// the first mismatch, to reduce leakage through logs and comparison timing.
#[derive(Clone)]
pub struct SharedSecret([u8; 32]);

impl SharedSecret {
    /// Borrows the raw secret bytes, e.g. to feed a key-derivation function.
    pub fn as_bytes(&self) -> &[u8; 32] {
        &self.0
    }

    /// Returns a copy of the raw secret bytes.
    pub fn to_bytes(&self) -> [u8; 32] {
        self.0
    }
}

impl PartialEq for SharedSecret {
    fn eq(&self, other: &Self) -> bool {
        // OR together the XOR of every byte pair instead of returning at the first
        // difference. This is best-effort only: the optimizer gives no constant-time
        // guarantee, so use a dedicated crate (e.g. `subtle`) if one is required.
        let diff = self
            .0
            .iter()
            .zip(other.0.iter())
            .fold(0u8, |acc, (a, b)| acc | (a ^ b));
        std::hint::black_box(diff) == 0
    }
}

impl Eq for SharedSecret {}

impl std::fmt::Debug for SharedSecret {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately redacted: never print the secret.
        f.write_str("SharedSecret(..)")
    }
}
100
101#[derive(Debug, Clone, PartialEq, Eq)]
102#[repr(transparent)]
103pub struct Encapsulated([u8; 1568]);
104
105impl Encapsulated {
106 pub fn from_bytes(bytes: &[u8]) -> Result<&Self> {
107 if bytes.len() != 1568 {
108 return Err(Error::InvalidEncapsulated);
109 }
110 Ok(unsafe {
111 &*bytes.as_ptr().cast()
113 })
114 }
115
116 pub fn to_bytes(&self) -> [u8; 1568] {
117 self.0
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// A ciphertext encapsulated to a key decapsulates to the same shared secret.
    #[test]
    fn test_encapsulation() {
        let dk = DecapsulationKey::generate().unwrap();
        let ek = dk.encapsulation_key().clone();
        let (ct, ss) = ek.encapsulate().unwrap();
        let ss_ = dk.decapsulate(&ct);
        assert_eq!(ss, ss_);
    }

    /// Decapsulating an unrelated (all-zero) ciphertext yields a different secret
    /// instead of panicking.
    #[test]
    fn test_encapsulation_fail() {
        let dk = DecapsulationKey::generate().unwrap();
        let ek = dk.encapsulation_key().clone();
        let (ct, ss) = ek.encapsulate().unwrap();

        let b = [0; 1568];
        let ct_ = Encapsulated::from_bytes(&b).unwrap();
        let ss_ = dk.decapsulate(ct_);

        assert_ne!(&ct, ct_);
        assert_ne!(ss, ss_);
    }

    /// Keys and ciphertexts survive a `to_bytes` / `from_bytes` round trip, and the
    /// restored values still interoperate.
    #[test]
    fn test_bytes_roundtrip() {
        let dk = DecapsulationKey::generate().unwrap();
        let dk_restored = DecapsulationKey::from_bytes(dk.to_bytes()).unwrap();
        assert_eq!(dk.to_bytes(), dk_restored.to_bytes());

        let ek_bytes = dk.encapsulation_key().to_bytes();
        let ek_restored = EncapsulationKey::from_bytes(ek_bytes).unwrap();
        assert_eq!(ek_bytes, ek_restored.to_bytes());

        let (ct, ss) = ek_restored.encapsulate().unwrap();
        let ct_bytes = ct.to_bytes();
        let ct_restored = Encapsulated::from_bytes(&ct_bytes).unwrap();
        assert_eq!(&ct, ct_restored);
        assert_eq!(ss, dk_restored.decapsulate(ct_restored));
    }

    /// Every `from_bytes` rejects input of the wrong length with its own error variant.
    #[test]
    fn test_from_bytes_wrong_length() {
        assert_eq!(
            DecapsulationKey::from_bytes([0u8; 3167]).err(),
            Some(Error::InvalidDecapsulationKey)
        );
        assert_eq!(
            EncapsulationKey::from_bytes([0u8; 1569]).err(),
            Some(Error::InvalidEncapsulationKey)
        );
        assert_eq!(
            Encapsulated::from_bytes(&[]).err(),
            Some(Error::InvalidEncapsulated)
        );
        assert_eq!(
            Encapsulated::from_bytes(&[0u8; 1567]).err(),
            Some(Error::InvalidEncapsulated)
        );
    }
}
147}