Skip to main content

rustolio_utils/crypto/
encapsulation.rs

//
// SPDX-License-Identifier: MPL-2.0
//
// Copyright (c) 2026 Tobias Binnewies. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//
10
11use ml_kem::{
12    kem::{self, Decapsulate as _},
13    EncapsulateDeterministic, EncodedSizeUser, KemCore as _, MlKem1024, MlKem1024Params,
14};
15
16use super::rand;
17
/// Result type returned by the byte-parsing constructors of this module.
pub type Result<T> = std::result::Result<T, Error>;

/// Reasons why raw bytes could not be turned into a key or ciphertext.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Error {
    /// Returned by `DecapsulationKey::from_bytes` when the input is rejected.
    InvalidDecapsulationKey,
    /// Returned by `EncapsulationKey::from_bytes` when the input is rejected.
    InvalidEncapsulationKey,
    /// Returned by `Encapsulated::from_bytes` when the input is rejected.
    InvalidEncapsulated,
}

impl std::fmt::Display for Error {
    /// Formats the error as its variant name (the derived `Debug` output),
    /// e.g. `InvalidEncapsulated`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}

impl std::error::Error for Error {}
34
35#[derive(Debug, Clone)]
36#[repr(transparent)]
37pub struct DecapsulationKey(kem::DecapsulationKey<MlKem1024Params>);
38
39#[derive(Debug, Clone)]
40#[repr(transparent)]
41pub struct EncapsulationKey(kem::EncapsulationKey<MlKem1024Params>);
42
43impl DecapsulationKey {
44    pub fn generate() -> rand::Result<Self> {
45        let d = rand::array()?;
46        let z = rand::array()?;
47        let (dk, _) = MlKem1024::generate_deterministic(&d.into(), &z.into());
48        Ok(DecapsulationKey(dk))
49    }
50
51    pub fn encapsulation_key(&self) -> &EncapsulationKey {
52        unsafe {
53            // SAFETY: #[repr(transparent)] guarantees identical memory layout
54            std::mem::transmute(self.0.encapsulation_key())
55        }
56    }
57
58    pub fn to_bytes(&self) -> [u8; 3168] {
59        self.0.as_bytes().into()
60    }
61
62    pub fn from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self> {
63        let Ok(enc) = bytes.as_ref().try_into() else {
64            return Err(Error::InvalidDecapsulationKey);
65        };
66        Ok(Self(kem::DecapsulationKey::from_bytes(enc)))
67    }
68
69    pub fn decapsulate(&self, ct: &Encapsulated) -> SharedSecret {
70        SharedSecret(
71            self.0
72                .decapsulate(&ct.0.into())
73                .unwrap() // Infallible
74                .into(),
75        )
76    }
77}
78
79impl EncapsulationKey {
80    pub fn to_bytes(&self) -> [u8; 1568] {
81        self.0.as_bytes().into()
82    }
83
84    pub fn from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self> {
85        let Ok(enc) = bytes.as_ref().try_into() else {
86            return Err(Error::InvalidEncapsulationKey);
87        };
88        Ok(Self(kem::EncapsulationKey::from_bytes(enc)))
89    }
90
91    pub fn encapsulate(&self) -> rand::Result<(Encapsulated, SharedSecret)> {
92        let seed = rand::array()?;
93        let (ct, ss) = self.0.encapsulate_deterministic(&seed.into()).unwrap(); // Infallible
94        Ok((Encapsulated(ct.into()), SharedSecret(ss.into())))
95    }
96}
97
/// 32-byte shared secret agreed on by `EncapsulationKey::encapsulate` and
/// `DecapsulationKey::decapsulate`.
///
/// `Debug` is redacted. Equality uses the derived comparison, which is not
/// guaranteed to run in constant time, so avoid comparing secrets on
/// attacker-observable paths.
#[derive(Clone, PartialEq, Eq)]
pub struct SharedSecret([u8; 32]);

impl SharedSecret {
    /// Borrows the raw secret bytes.
    pub fn as_bytes(&self) -> &[u8; 32] {
        &self.0
    }

    /// Copies the raw secret bytes out.
    pub fn to_bytes(&self) -> [u8; 32] {
        self.0
    }
}

impl std::fmt::Debug for SharedSecret {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the secret itself (it would end up in logs/assert output).
        f.debug_struct("SharedSecret").finish_non_exhaustive()
    }
}
100
101#[derive(Debug, Clone, PartialEq, Eq)]
102#[repr(transparent)]
103pub struct Encapsulated([u8; 1568]);
104
105impl Encapsulated {
106    pub fn from_bytes(bytes: &[u8]) -> Result<&Self> {
107        if bytes.len() != 1568 {
108            return Err(Error::InvalidEncapsulated);
109        }
110        Ok(unsafe {
111            // SAFETY: #[repr(transparent)] & length checked
112            &*bytes.as_ptr().cast()
113        })
114    }
115
116    pub fn to_bytes(&self) -> [u8; 1568] {
117        self.0
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Encapsulating to a key and decapsulating with its owner yields the same secret.
    #[test]
    fn test_encapsulation() {
        let dk = DecapsulationKey::generate().unwrap();
        let ek = dk.encapsulation_key().clone();
        let (ct, ss) = ek.encapsulate().unwrap();
        let ss_ = dk.decapsulate(&ct);
        assert_eq!(ss, ss_);
    }

    /// A bogus (all-zero) ciphertext decapsulates without error but to a different secret.
    #[test]
    fn test_encapsulation_fail() {
        let dk = DecapsulationKey::generate().unwrap();
        let ek = dk.encapsulation_key().clone();
        let (ct, ss) = ek.encapsulate().unwrap();

        let b = [0; 1568];
        let ct_ = Encapsulated::from_bytes(&b).unwrap();
        let ss_ = dk.decapsulate(ct_);

        assert_ne!(&ct, ct_);
        assert_ne!(ss, ss_);
    }

    /// Both key types survive a `to_bytes`/`from_bytes` round trip and still agree.
    #[test]
    fn test_key_roundtrip() {
        let dk = DecapsulationKey::generate().unwrap();
        let dk_ = DecapsulationKey::from_bytes(dk.to_bytes()).unwrap();
        assert_eq!(dk.to_bytes(), dk_.to_bytes());

        let ek_bytes = dk.encapsulation_key().to_bytes();
        let ek = EncapsulationKey::from_bytes(ek_bytes).unwrap();
        assert_eq!(ek.to_bytes(), ek_bytes);

        let (ct, ss) = ek.encapsulate().unwrap();
        assert_eq!(dk_.decapsulate(&ct), ss);
    }

    /// A ciphertext survives a byte round trip and still decapsulates correctly.
    #[test]
    fn test_ciphertext_roundtrip() {
        let dk = DecapsulationKey::generate().unwrap();
        let (ct, ss) = dk.encapsulation_key().encapsulate().unwrap();
        let bytes = ct.to_bytes();
        let ct_ = Encapsulated::from_bytes(&bytes).unwrap();
        assert_eq!(ct_, &ct);
        assert_eq!(dk.decapsulate(ct_), ss);
    }

    /// Inputs of the wrong length are rejected with the matching error variant.
    #[test]
    fn test_invalid_lengths() {
        assert_eq!(
            DecapsulationKey::from_bytes([0u8; 3167]).err(),
            Some(Error::InvalidDecapsulationKey)
        );
        assert_eq!(
            EncapsulationKey::from_bytes([0u8; 1569]).err(),
            Some(Error::InvalidEncapsulationKey)
        );
        assert_eq!(
            EncapsulationKey::from_bytes(Vec::<u8>::new()).err(),
            Some(Error::InvalidEncapsulationKey)
        );
        assert_eq!(
            Encapsulated::from_bytes(&[0u8; 1567]).err(),
            Some(Error::InvalidEncapsulated)
        );
    }
}