spideroak_crypto/
hpke.rs

1//! Hybrid Public Key Encryption per [RFC 9180].
2//!
3//! ## Notation
4//!
5//! - `sk`: a private key; shorthand for "*S*ecret *K*ey"
6//! - `pk`: a public key; shorthand for "*P*ublic *K*ey"
7//! - `skR`, `pkR`: a receiver's secret or public key
8//! - `skS`, `pkS`: a sender's secret or public key
9//! - `skE`, `pkE`: an ephemeral secret or public key
//! - `encap`, `decap`: key encapsulation and decapsulation; see [`Kem`].
11//!
12//! [RFC 9180]: https://www.rfc-editor.org/rfc/rfc9180.html
13
14#![forbid(unsafe_code)]
15// We use the same variable names used in the HPKE RFC.
16#![allow(non_snake_case)]
17
18use core::{fmt, iter, marker::PhantomData, num::NonZeroU16, result::Result};
19
20use buggy::{bug, Bug, BugExt};
21use generic_array::ArrayLength;
22use subtle::{Choice, ConstantTimeEq};
23use typenum::Unsigned as _;
24
25use crate::{
26    aead::{Aead, IndCca2, KeyData, Nonce, OpenError, SealError},
27    csprng::Csprng,
28    import::{ExportError, Import as _, ImportError},
29    kdf::{Expand, Kdf, KdfError, Prk},
30    kem::{Kem, KemError},
31    keys::RawSecretBytes as _,
32    AlgId,
33};
34
/// Implements the RFC 9180 `I2OSP` primitive: encodes an integer
/// as big-endian bytes.
///
/// - `i2osp!(v)` evaluates to `v.to_be_bytes()`.
/// - `i2osp!(v, N)` evaluates to a `GenericArray<u8, N>` holding
///   the big-endian encoding of `v`. It is left-padded with zeros
///   when `N` is wider than `v`. When `N` is narrower, only the
///   low-order (rightmost) bytes are kept.
macro_rules! i2osp {
    ($v:expr) => {
        $v.to_be_bytes()
    };
    ($v:expr, $n:ty) => {{
        let be = $v.to_be_bytes();
        let mut out = generic_array::GenericArray::<u8, $n>::default();
        // Right-align `be` within `out`: zero-pad on the left, or
        // drop the leading (high-order) bytes if `out` is shorter.
        //
        // NB: the compiler knows how to optimize this. Don't
        // rewrite it without verifying the assembly.
        let skip = out.len().abs_diff(be.len());
        if out.len() >= be.len() {
            out[skip..].copy_from_slice(&be);
        } else {
            out.copy_from_slice(&be[skip..]);
        }
        out
    }};
}
57
58/// An HPKE operation mode.
59#[derive(Debug)]
60pub enum Mode<'a, T> {
61    /// The most basic operation mode.
62    Base,
63    /// Extends the base mode by allowing the recipient to
64    /// authenticate that the sender possessed a particular
65    /// pre-shared key.
66    Psk(Psk<'a>),
67    /// Extends the base mode by allowing the recipient to
68    /// authenticate that the sender possessed a particular
69    /// private key.
70    Auth(T),
71    /// A combination of [`Mode::Auth`] and [`Mode::Psk`].
72    AuthPsk(T, Psk<'a>),
73}
74
75impl<T> fmt::Display for Mode<'_, T> {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        match self {
78            Self::Base => write!(f, "mode_base"),
79            Self::Psk(_) => write!(f, "mode_psk"),
80            Self::Auth(_) => write!(f, "mode_auth"),
81            Self::AuthPsk(_, _) => write!(f, "mode_auth_psk"),
82        }
83    }
84}
85
86impl<'a, T> Mode<'a, T> {
87    // The default `psk` and `psk_id` are empty strings. See
88    // section 5.1.
89    const DEFAULT_PSK: Psk<'static> = Psk {
90        psk: &[],
91        psk_id: &[],
92    };
93
94    /// Converts from `Mode<'_, T>` to `Mode<'_, &T>`.
95    pub const fn as_ref(&self) -> Mode<'_, &T> {
96        match *self {
97            Self::Base => Mode::Base,
98            Self::Psk(psk) => Mode::Psk(psk),
99            Self::Auth(ref k) => Mode::Auth(k),
100            Self::AuthPsk(ref k, psk) => Mode::AuthPsk(k, psk),
101        }
102    }
103
104    fn psk(&self) -> &Psk<'a> {
105        match self {
106            Mode::Psk(psk) => psk,
107            Mode::AuthPsk(_, psk) => psk,
108            _ => &Self::DEFAULT_PSK,
109        }
110    }
111
112    const fn id(&self) -> u8 {
113        match self {
114            Self::Base => 0x00,
115            Self::Psk(_) => 0x01,
116            Self::Auth(_) => 0x02,
117            Self::AuthPsk(_, _) => 0x03,
118        }
119    }
120}
121
/// Returned by [`Psk::new`] when the pre-shared key or its ID is
/// empty.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct InvalidPsk;

impl fmt::Display for InvalidPsk {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MSG: &str = "invalid pre-shared key: PSK or PSK ID are empty";
        f.write_str(MSG)
    }
}

impl core::error::Error for InvalidPsk {}
133
134/// A pre-shared key and its ID.
135#[derive(Copy, Clone)]
136pub struct Psk<'a> {
137    /// The pre-shared key.
138    psk: &'a [u8],
139    // The pre-shared key's ID.
140    psk_id: &'a [u8],
141}
142
143impl<'a> Psk<'a> {
144    /// Creates a [`Psk`] from a pre-shared key and its ID.
145    pub fn new(psk: &'a [u8], psk_id: &'a [u8]) -> Result<Self, InvalidPsk> {
146        // See Section 5.1, `VerifyPSKInputs`.
147        if psk.is_empty() || psk_id.is_empty() {
148            Err(InvalidPsk)
149        } else {
150            Ok(Self { psk, psk_id })
151        }
152    }
153}
154
155impl ConstantTimeEq for Psk<'_> {
156    fn ct_eq(&self, other: &Self) -> Choice {
157        self.psk.ct_eq(other.psk) & self.psk_id.ct_eq(other.psk_id)
158    }
159}
160
161impl fmt::Debug for Psk<'_> {
162    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163        f.debug_struct("Psk")
164            .field("psk_id", &self.psk_id)
165            .finish_non_exhaustive()
166    }
167}
168
169/// KEM algorithm identifiers per [IANA].
170///
171/// [IANA]: https://www.iana.org/assignments/hpke/hpke.xhtml
172#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, AlgId)]
173pub enum KemId {
174    /// DHKEM(P-256, HKDF-SHA256).
175    #[alg_id(0x0010)]
176    DhKemP256HkdfSha256,
177    /// DHKEM(P-384, HKDF-SHA384).
178    #[alg_id(0x0011)]
179    DhKemP384HkdfSha384,
180    /// DHKEM(P-521, HKDF-SHA512).
181    #[alg_id(0x0012)]
182    DhKemP521HkdfSha512,
183    /// DHKEM(CP-256, HKDF-SHA256)
184    #[alg_id(0x0013)]
185    DhKemCp256HkdfSha256,
186    /// DHKEM(CP-384, HKDF-SHA384)
187    #[alg_id(0x0014)]
188    DhKemCp384HkdfSha384,
189    /// DHKEM(CP-521, HKDF-SHA512)
190    #[alg_id(0x0015)]
191    DhKemCp521HkdfSha512,
192    /// DHKEM(secp256k1, HKDF-SHA256)
193    #[alg_id(0x0016)]
194    DhKemSecp256k1HkdfSha256,
195    /// DHKEM(X25519, HKDF-SHA256).
196    #[alg_id(0x0020)]
197    DhKemX25519HkdfSha256,
198    /// DHKEM(X448, HKDF-SHA512).
199    #[alg_id(0x0021)]
200    DhKemX448HkdfSha512,
201    /// X25519Kyber768Draft00
202    #[alg_id(0x0030)]
203    X25519Kyber768Draft00,
204    /// ML-KEM-512.
205    #[alg_id(0x040)]
206    MlKem512,
207    /// ML-KEM-768.
208    #[alg_id(0x041)]
209    MlKem768,
210    /// ML-KEM-1024.
211    #[alg_id(0x042)]
212    MlKem1024,
213    /// X-Wing.
214    #[alg_id(0x647a)]
215    XWing,
216    /// Some other KEM.
217    ///
218    /// Non-zero since 0x0000 is marked as 'reserved'.
219    #[alg_id(Other)]
220    Other(NonZeroU16),
221}
222
223impl fmt::Display for KemId {
224    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
225        match self {
226            Self::DhKemP256HkdfSha256 => write!(f, "DHKEM(P-256, HKDF-SHA256)"),
227            Self::DhKemP384HkdfSha384 => write!(f, "DHKEM(P-384, HKDF-SHA384)"),
228            Self::DhKemP521HkdfSha512 => write!(f, "DHKEM(P-521, HKDF-SHA512)"),
229            Self::DhKemCp256HkdfSha256 => write!(f, "DHKEM(CP-256, HKDF-SHA256)"),
230            Self::DhKemCp384HkdfSha384 => write!(f, "DHKEM(CP-384, HKDF-SHA384)"),
231            Self::DhKemCp521HkdfSha512 => write!(f, "DHKEM(CP-521, HKDF-SHA512)"),
232            Self::DhKemSecp256k1HkdfSha256 => write!(f, "DHKEM(secp256k1, HKDF-SHA256)"),
233            Self::DhKemX25519HkdfSha256 => write!(f, "DHKEM(X25519, HKDF-SHA256)"),
234            Self::DhKemX448HkdfSha512 => write!(f, "DHKEM(X448, HKDF-SHA512)"),
235            Self::X25519Kyber768Draft00 => write!(f, "X25519Kyber768Draft00"),
236            Self::MlKem512 => write!(f, "ML-KEM-512"),
237            Self::MlKem768 => write!(f, "ML-KEM-768"),
238            Self::MlKem1024 => write!(f, "ML-KEM-1024"),
239            Self::XWing => write!(f, "X-Wing"),
240            Self::Other(id) => write!(f, "Kem({:#02x})", id),
241        }
242    }
243}
244
245/// KDF algorithm identifiers per [IANA].
246///
247/// [IANA]: https://www.iana.org/assignments/hpke/hpke.xhtml
248#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, AlgId)]
249pub enum KdfId {
250    /// HKDF-SHA256.
251    #[alg_id(0x0001)]
252    HkdfSha256,
253    /// HKDF-SHA384.
254    #[alg_id(0x0002)]
255    HkdfSha384,
256    /// HKDF-SHA512.
257    #[alg_id(0x0003)]
258    HkdfSha512,
259    /// Some other KDF.
260    ///
261    /// Non-zero since 0x0000 is marked as 'reserved'.
262    #[alg_id(Other)]
263    Other(NonZeroU16),
264}
265
266impl fmt::Display for KdfId {
267    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
268        match self {
269            Self::HkdfSha256 => write!(f, "HkdfSha256"),
270            Self::HkdfSha384 => write!(f, "HkdfSha384"),
271            Self::HkdfSha512 => write!(f, "HkdfSha512"),
272            Self::Other(id) => write!(f, "Kdf({:#02x})", id),
273        }
274    }
275}
276
277/// AEAD algorithm identifiers per [IANA].
278///
279/// [IANA]: https://www.iana.org/assignments/hpke/hpke.xhtml
280#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, AlgId)]
281pub enum AeadId {
282    /// AES-128-GCM.
283    #[alg_id(0x0001)]
284    Aes128Gcm,
285    /// AES-256-GCM.
286    #[alg_id(0x0002)]
287    Aes256Gcm,
288    /// ChaCha20Poly1305.
289    #[alg_id(0x0003)]
290    ChaCha20Poly1305,
291    /// Some other AEAD.
292    ///
293    /// Non-zero since 0x0000 is marked as 'reserved'.
294    #[alg_id(Other)]
295    Other(NonZeroU16),
296    /// Export-only AEAD.
297    #[alg_id(0xffff)]
298    ExportOnly,
299}
300
301impl fmt::Display for AeadId {
302    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303        match self {
304            Self::Aes128Gcm => write!(f, "Aes128Gcm"),
305            Self::Aes256Gcm => write!(f, "Aes256Gcm"),
306            Self::ChaCha20Poly1305 => write!(f, "ChaCha20Poly1305"),
307            Self::Other(id) => write!(f, "Aead({:#02x})", id),
308            Self::ExportOnly => write!(f, "ExportOnly"),
309        }
310    }
311}
312
313/// A [`Kem`] that can be used by HPKE.
314pub trait HpkeKem: Kem {
315    /// Identifies the KEM algorithm per [IANA].
316    ///
317    /// [IANA]: https://www.iana.org/assignments/hpke/hpke.xhtml
318    const ID: KemId;
319}
320
321/// A [`Kdf`] that can be used by HPKE.
322pub trait HpkeKdf: Kdf {
323    /// Identifies the KDF algorithm per [IANA].
324    ///
325    /// [IANA]: https://www.iana.org/assignments/hpke/hpke.xhtml
326    const ID: KdfId;
327}
328
329/// An [`Aead`] that can be used by HPKE.
330pub trait HpkeAead: Aead + IndCca2 {
331    /// Identifies the AEAD algorithm per [IANA].
332    ///
333    /// [IANA]: https://www.iana.org/assignments/hpke/hpke.xhtml
334    const ID: AeadId;
335}
336
337/// An error from an [`Hpke`].
338#[derive(Debug, Eq, PartialEq)]
339pub enum HpkeError {
340    /// An AEAD seal operation failed.
341    Seal(SealError),
342    /// An AEAD open operation failed.
343    Open(OpenError),
344    /// A KDF operation failed.
345    Kdf(KdfError),
346    /// A KEM operation failed.
347    Kem(KemError),
348    /// A key could not be imported.
349    Import(ImportError),
350    /// A key could not be exported.
351    Export(ExportError),
352    /// The encryption context has been used to send the maximum
353    /// number of messages.
354    MessageLimitReached,
355    /// An internal bug was discovered.
356    Bug(Bug),
357}
358
359impl fmt::Display for HpkeError {
360    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361        match self {
362            Self::Seal(err) => write!(f, "{}", err),
363            Self::Open(err) => write!(f, "{}", err),
364            Self::Kdf(err) => write!(f, "{}", err),
365            Self::Kem(err) => write!(f, "{}", err),
366            Self::Import(err) => write!(f, "{}", err),
367            Self::Export(err) => write!(f, "{}", err),
368            Self::MessageLimitReached => write!(f, "message limit reached"),
369            Self::Bug(err) => write!(f, "{err}"),
370        }
371    }
372}
373
374impl core::error::Error for HpkeError {
375    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
376        match self {
377            Self::Seal(err) => Some(err),
378            Self::Open(err) => Some(err),
379            Self::Kdf(err) => Some(err),
380            Self::Kem(err) => Some(err),
381            Self::Import(err) => Some(err),
382            Self::Export(err) => Some(err),
383            Self::MessageLimitReached => None,
384            Self::Bug(err) => Some(err),
385        }
386    }
387}
388
389impl From<SealError> for HpkeError {
390    fn from(err: SealError) -> Self {
391        Self::Seal(err)
392    }
393}
394
395impl From<OpenError> for HpkeError {
396    fn from(err: OpenError) -> Self {
397        Self::Open(err)
398    }
399}
400
401impl From<KdfError> for HpkeError {
402    fn from(err: KdfError) -> Self {
403        Self::Kdf(err)
404    }
405}
406
407impl From<KemError> for HpkeError {
408    fn from(err: KemError) -> Self {
409        Self::Kem(err)
410    }
411}
412
413impl From<ImportError> for HpkeError {
414    fn from(err: ImportError) -> Self {
415        Self::Import(err)
416    }
417}
418
419impl From<ExportError> for HpkeError {
420    fn from(err: ExportError) -> Self {
421        Self::Export(err)
422    }
423}
424
425impl From<Bug> for HpkeError {
426    fn from(err: Bug) -> Self {
427        Self::Bug(err)
428    }
429}
430
431impl From<MessageLimitReached> for HpkeError {
432    fn from(_err: MessageLimitReached) -> Self {
433        Self::MessageLimitReached
434    }
435}
436
/// Hybrid Public Key Encryption (HPKE) per [RFC 9180], generic
/// over the KEM (`K`), the KDF (`F`), and the AEAD (`A`).
///
/// The HPKE operations are associated functions (they take no
/// `self`), so this type mainly serves to bind the cipher suite.
///
/// [RFC 9180]: <https://www.rfc-editor.org/rfc/rfc9180.html>
#[derive(Debug)]
pub struct Hpke<K, F, A> {
    // `PhantomData<fn() -> X>` names each parameter without owning
    // a value of that type.
    _kem: PhantomData<fn() -> K>,
    _kdf: PhantomData<fn() -> F>,
    _aead: PhantomData<fn() -> A>,
}
446
447impl<K, F, A> Hpke<K, F, A>
448where
449    K: HpkeKem,
450    F: HpkeKdf,
451    A: HpkeAead,
452{
453    /// Creates a randomized encryption context for encrypting
454    /// messages for the receiver, `pkR`.
455    ///
456    /// It returns the encryption context and an encapsulated
457    /// symmetric key which can be used by the receiver to
458    /// decrypt messages.
459    ///
460    /// The `info` parameter provides contextual binding.
461    #[allow(clippy::type_complexity)]
462    pub fn setup_send<'a, R: Csprng>(
463        rng: &mut R,
464        mode: Mode<'_, &K::DecapKey>,
465        pkR: &K::EncapKey,
466        info: impl IntoIterator<Item = &'a [u8]>,
467    ) -> Result<(K::Encap, SendCtx<K, F, A>), HpkeError> {
468        let (shared_secret, enc) = match mode {
469            Mode::Auth(skS) | Mode::AuthPsk(skS, _) => K::auth_encap::<R>(rng, pkR, skS)?,
470            Mode::Base | Mode::Psk(_) => K::encap::<R>(rng, pkR)?,
471        };
472        let ctx = Self::key_schedule(mode, &shared_secret, info)?;
473        Ok((enc, ctx.into_send_ctx()))
474    }
475
476    /// Deterministically creates an encryption context for
477    /// encrypting messages for the receiver, `pkR`.
478    ///
479    /// It returns the encryption context and an encapsulated
480    /// symmetric key which can be used by the receiver to
481    /// decrypt messages.
482    ///
483    /// The `info` parameter provides contextual binding.
484    ///
485    /// # Warning
486    ///
487    /// The security of this function relies on choosing the
488    /// correct value for `skE`. It is a catastrophic error if
489    /// you do not ensure all of the following properties:
490    ///
491    /// - it must be cryptographically secure
492    /// - it must never be reused
493    #[allow(clippy::type_complexity)]
494    pub fn setup_send_deterministically<'a>(
495        mode: Mode<'_, &K::DecapKey>,
496        pkR: &K::EncapKey,
497        info: impl IntoIterator<Item = &'a [u8]>,
498        skE: K::DecapKey,
499    ) -> Result<(K::Encap, SendCtx<K, F, A>), HpkeError> {
500        let (shared_secret, enc) = match mode {
501            Mode::Auth(skS) | Mode::AuthPsk(skS, _) => {
502                K::auth_encap_deterministically(pkR, skS, skE)?
503            }
504            Mode::Base | Mode::Psk(_) => K::encap_deterministically(pkR, skE)?,
505        };
506        let ctx = Self::key_schedule(mode, &shared_secret, info)?;
507        Ok((enc, ctx.into_send_ctx()))
508    }
509
510    /// Creates an encryption context that can decrypt messages
511    /// from a particular sender (the creator of `enc`).
512    ///
513    /// The `mode` and `info` parameters must be the same
514    /// parameters used by the sender.
515    pub fn setup_recv<'a>(
516        mode: Mode<'_, &K::EncapKey>,
517        enc: &K::Encap,
518        skR: &K::DecapKey,
519        info: impl IntoIterator<Item = &'a [u8]>,
520    ) -> Result<RecvCtx<K, F, A>, HpkeError> {
521        let shared_secret = match mode {
522            Mode::Auth(pkS) | Mode::AuthPsk(pkS, _) => K::auth_decap(enc, skR, pkS)?,
523            Mode::Base | Mode::Psk(_) => K::decap(enc, skR)?,
524        };
525        let ctx = Self::key_schedule(mode, &shared_secret, info)?;
526        Ok(ctx.into_recv_ctx())
527    }
528
529    /// The "HPKE" `suite_id`.
530    ///
531    /// ```text
532    /// suite_id = concat(
533    ///     "HPKE",
534    ///     I2OSP(kem_id, 2),
535    ///     I2OSP(kdf_id, 2),
536    ///     I2OSP(aead_id, 2),
537    /// )
538    /// ```
539    #[rustfmt::skip]
540    const HPKE_SUITE_ID: &[u8] = &[
541        b'H',
542        b'P',
543        b'K',
544        b'E',
545        i2osp!(K::ID)[0], i2osp!(K::ID)[1],
546        i2osp!(F::ID)[0], i2osp!(F::ID)[1],
547        i2osp!(A::ID)[0], i2osp!(A::ID)[1],
548    ];
549
550    /// The "HPKE-v1" domain separator for `LabeledExtract` and
551    /// `LabeledExpand`.
552    const DOMAIN: &[u8] = b"HPKE-v1";
553
554    fn key_schedule<'a, T>(
555        mode: Mode<'_, T>,
556        shared_secret: &K::Secret,
557        info: impl IntoIterator<Item = &'a [u8]>,
558    ) -> Result<Schedule<K, F, A>, HpkeError> {
559        let Psk { psk, psk_id } = mode.psk();
560
561        //  psk_id_hash = LabeledExtract("", "psk_id_hash", psk_id)
562        let psk_id_hash = Self::labeled_extract(b"", b"psk_id_hash", iter::once(psk_id).copied());
563
564        //  info_hash = LabeledExtract("", "info_hash", info)
565        let info_hash = Self::labeled_extract(b"", b"info_hash", info);
566
567        //  key_schedule_context = concat(mode, psk_id_hash, info_hash)
568        let ks_ctx = [&[mode.id()], psk_id_hash.as_bytes(), info_hash.as_bytes()];
569
570        //  secret = LabeledExtract(shared_secret, "secret", psk)
571        let secret = Self::labeled_extract(
572            shared_secret.raw_secret_bytes(),
573            b"secret",
574            iter::once(psk).copied(),
575        );
576
577        // key = LabeledExpand(secret, "key", key_schedule_context, Nk)
578        let key = Self::labeled_expand(&secret, b"key", ks_ctx)?;
579
580        // base_nonce = LabeledExpand(secret, "base_nonce",
581        //                      key_schedule_context, Nn)
582        let base_nonce = Self::labeled_expand(&secret, b"base_nonce", ks_ctx)?;
583
584        // exporter_secret = LabeledExpand(secret, "exp",
585        //                           key_schedule_context, Nh)
586        let exporter_secret = Self::labeled_expand(&secret, b"exp", ks_ctx)?;
587
588        Ok(Schedule {
589            key,
590            base_nonce,
591            exporter_secret,
592            _kem: PhantomData,
593        })
594    }
595
596    /// Performs `LabeledExtract`.
597    fn labeled_extract<'a>(
598        salt: &[u8],
599        label: &'static [u8],
600        ikm: impl IntoIterator<Item = &'a [u8]>,
601    ) -> Prk<F::PrkSize> {
602        // def LabeledExtract(salt, label, ikm):
603        //     labeled_ikm = concat("HPKE-v1", suite_id, label, ikm)
604        //     return Extract(salt, labeled_ikm)
605        let labeled_ikm = [Self::DOMAIN, Self::HPKE_SUITE_ID, label]
606            .into_iter()
607            .chain(ikm);
608        F::extract_multi(labeled_ikm, salt)
609    }
610
611    /// Performs `LabeledExpand`.
612    fn labeled_expand<'a, T: Expand>(
613        prk: &Prk<F::PrkSize>,
614        label: &'static [u8],
615        info: impl IntoIterator<Item = &'a [u8], IntoIter: Clone>,
616    ) -> Result<T, KdfError> {
617        // def LabeledExpand(prk, label, info, L):
618        //     labeled_info = concat(I2OSP(L, 2), "HPKE-v1", suite_id,
619        //                 label, info)
620        //     return Expand(prk, labeled_info, L)
621        let size = T::Size::U16.to_be_bytes();
622        let labeled_info = iter::once(size.as_slice())
623            .chain(iter::once(Self::DOMAIN))
624            .chain(iter::once(Self::HPKE_SUITE_ID))
625            .chain(iter::once(label))
626            .chain(
627                // `.map(|v| v)` shortens the lifetime from `'a`
628                // to `'1`.
629                #[allow(clippy::map_identity)]
630                info.into_iter().map(|v| v),
631            );
632        T::expand_multi::<F, _>(prk, labeled_info)
633    }
634
635    /// Performs `LabeledExpand`.
636    fn labeled_expand_into<'a>(
637        out: &mut [u8],
638        prk: &Prk<F::PrkSize>,
639        label: &'static [u8],
640        info: impl IntoIterator<Item = &'a [u8], IntoIter: Clone>,
641    ) -> Result<(), KdfError> {
642        // def LabeledExpand(prk, label, info, L):
643        //     labeled_info = concat(I2OSP(L, 2), "HPKE-v1", suite_id,
644        //                 label, info)
645        //     return Expand(prk, labeled_info, L)
646        let size = u16::try_from(out.len())
647            .map_err(|_| KdfError::OutputTooLong)?
648            .to_be_bytes();
649        let labeled_info = iter::once(size.as_slice())
650            .chain(iter::once(Self::DOMAIN))
651            .chain(iter::once(Self::HPKE_SUITE_ID))
652            .chain(iter::once(label))
653            .chain(
654                // `.map(|v| v)` shortens the lifetime from `'a`
655                // to `'1`.
656                #[allow(clippy::map_identity)]
657                info.into_iter().map(|v| v),
658            );
659        F::expand_multi(out, prk, labeled_info)
660    }
661}
662
663#[derive(Debug)]
664struct Schedule<K, F, A>
665where
666    K: HpkeKem,
667    F: HpkeKdf,
668    A: HpkeAead,
669{
670    key: KeyData<A>,
671    base_nonce: Nonce<A::NonceSize>,
672    exporter_secret: Prk<F::PrkSize>,
673    _kem: PhantomData<fn() -> K>,
674}
675
676impl<K, F, A> Schedule<K, F, A>
677where
678    K: HpkeKem,
679    F: HpkeKdf,
680    A: HpkeAead,
681{
682    fn into_send_ctx(self) -> SendCtx<K, F, A> {
683        SendCtx {
684            seal: Either::Right((self.key, self.base_nonce)),
685            export: ExportCtx::new(self.exporter_secret),
686        }
687    }
688
689    fn into_recv_ctx(self) -> RecvCtx<K, F, A> {
690        RecvCtx {
691            open: Either::Right((self.key, self.base_nonce)),
692            export: ExportCtx::new(self.exporter_secret),
693        }
694    }
695}
696
697/// Either `L` or `R`.
698#[derive(Debug)]
699enum Either<L, R> {
700    Left(L),
701    Right(R),
702}
703
704impl<L, R> Either<L, R> {
705    fn get_or_insert_left<F, E>(&mut self, f: F) -> Result<&mut L, E>
706    where
707        F: FnOnce(&R) -> Result<L, E>,
708        E: From<Bug>,
709    {
710        match self {
711            Self::Left(left) => Ok(left),
712            Self::Right(right) => {
713                *self = Self::Left(f(right)?);
714                match self {
715                    Self::Left(left) => Ok(left),
716                    Self::Right(_) => bug!("we just assigned `Self::Left`"),
717                }
718            }
719        }
720    }
721}
722
723type RawKey<A> = (KeyData<A>, Nonce<<A as Aead>::NonceSize>);
724
725/// An encryption context that encrypts messages for a particular
726/// recipient.
727pub struct SendCtx<K, F, A>
728where
729    K: HpkeKem,
730    F: HpkeKdf,
731    A: HpkeAead,
732{
733    seal: Either<SealCtx<A>, RawKey<A>>,
734    export: ExportCtx<K, F, A>,
735}
736
737impl<K, F, A> SendCtx<K, F, A>
738where
739    K: HpkeKem,
740    F: HpkeKdf,
741    A: HpkeAead,
742{
743    /// The size in bytes of the overhead added to the plaintext.
744    pub const OVERHEAD: usize = SealCtx::<A>::OVERHEAD;
745
746    // Exposed for `aranya-crypto`, do not use.
747    #[doc(hidden)]
748    pub fn into_raw_parts(self) -> Option<(KeyData<A>, Nonce<A::NonceSize>)> {
749        match self.seal {
750            Either::Left(_) => None,
751            Either::Right((key, base_nonce)) => Some((key, base_nonce)),
752        }
753    }
754
755    fn seal_ctx(&mut self) -> Result<&mut SealCtx<A>, ImportError> {
756        self.seal
757            .get_or_insert_left(|(key, nonce)| SealCtx::new(key, nonce, Seq::ZERO))
758    }
759
760    /// Encrypts and authenticates `plaintext`, returning the
761    /// sequence number.
762    ///
763    /// The resulting ciphertext is written to `dst`, which must
764    /// be at least `plaintext.len()` + [`OVERHEAD`][Self::OVERHEAD]
765    /// bytes long.
766    pub fn seal(
767        &mut self,
768        dst: &mut [u8],
769        plaintext: &[u8],
770        additional_data: &[u8],
771    ) -> Result<Seq, HpkeError> {
772        self.seal_ctx()?.seal(dst, plaintext, additional_data)
773    }
774
775    /// Encrypts and authenticates `data` in-place, returning the
776    /// sequence number.
777    pub fn seal_in_place(
778        &mut self,
779        data: impl AsMut<[u8]>,
780        tag: &mut [u8],
781        additional_data: &[u8],
782    ) -> Result<Seq, HpkeError> {
783        self.seal_ctx()?.seal_in_place(data, tag, additional_data)
784    }
785
786    /// Exports a secret from the encryption context.
787    pub fn export<T>(&self, context: &[u8]) -> Result<T, KdfError>
788    where
789        T: Expand,
790    {
791        self.export.export(context)
792    }
793
794    /// Exports a secret from the encryption context, writing it
795    /// to `out`.
796    pub fn export_into(&self, out: &mut [u8], context: &[u8]) -> Result<(), KdfError> {
797        self.export.export_into(out, context)
798    }
799}
800
801impl<K, F, A> fmt::Debug for SendCtx<K, F, A>
802where
803    K: HpkeKem,
804    F: HpkeKdf,
805    A: HpkeAead + fmt::Debug,
806{
807    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
808        f.debug_struct("SendCtx")
809            .field("seal", &self.seal)
810            .field("export", &self.export)
811            .finish()
812    }
813}
814
815/// An encryption context that can only encrypt messages for
816/// a particular recipient.
817///
818/// Unlike [`SendCtx`], it cannot export secrets.
819#[doc(hidden)]
820pub struct SealCtx<A: HpkeAead> {
821    aead: A,
822    base_nonce: Nonce<A::NonceSize>,
823    /// Incremented after each call to `seal`.
824    seq: Seq,
825}
826
827impl<A: HpkeAead> SealCtx<A> {
828    /// The size in bytes of the overhead added to the plaintext.
829    pub const OVERHEAD: usize = A::OVERHEAD;
830
831    // Exported for `aranya-crypto`. Do not use.
832    #[doc(hidden)]
833    pub fn new(
834        key: &KeyData<A>,
835        base_nonce: &Nonce<A::NonceSize>,
836        seq: Seq,
837    ) -> Result<Self, ImportError> {
838        let key = A::Key::import(key.as_bytes())?;
839        Ok(Self {
840            aead: A::new(&key),
841            base_nonce: base_nonce.clone(),
842            seq,
843        })
844    }
845
846    fn compute_nonce(&self) -> Result<Nonce<A::NonceSize>, MessageLimitReached> {
847        self.seq.compute_nonce::<A::NonceSize>(&self.base_nonce)
848    }
849
850    fn increment_seq(&mut self) -> Result<Seq, Bug> {
851        self.seq.increment::<A::NonceSize>()
852    }
853
854    /// Encrypts and authenticates `plaintext`, returning the
855    /// sequence number.
856    ///
857    /// The resulting ciphertext is written to `dst`, which must
858    /// be at least `plaintext.len()` + [`OVERHEAD`][Self::OVERHEAD]
859    /// bytes long.
860    pub fn seal(
861        &mut self,
862        dst: &mut [u8],
863        plaintext: &[u8],
864        additional_data: &[u8],
865    ) -> Result<Seq, HpkeError> {
866        let nonce = self.compute_nonce()?;
867        self.aead.seal(dst, &nonce, plaintext, additional_data)?;
868        let prev = self.increment_seq()?;
869        Ok(prev)
870    }
871
872    /// Encrypts and authenticates `data` in place, returning the
873    /// sequence number.
874    pub fn seal_in_place(
875        &mut self,
876        mut data: impl AsMut<[u8]>,
877        tag: &mut [u8],
878        additional_data: &[u8],
879    ) -> Result<Seq, HpkeError> {
880        let nonce = self.compute_nonce()?;
881        self.aead
882            .seal_in_place(&nonce, data.as_mut(), tag, additional_data)?;
883        let prev = self.increment_seq()?;
884        Ok(prev)
885    }
886
887    /// Returns the current sequence number.
888    pub fn seq(&self) -> Seq {
889        self.seq
890    }
891}
892
893impl<A: HpkeAead + fmt::Debug> fmt::Debug for SealCtx<A> {
894    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
895        f.debug_struct("SealCtx")
896            .field("aead", &self.aead)
897            .field("base_nonce", &self.base_nonce)
898            .field("seq", &self.seq)
899            .finish()
900    }
901}
902
903/// An encryption context that decrypts messages from
904/// a particular sender.
905pub struct RecvCtx<K, F, A>
906where
907    K: HpkeKem,
908    F: HpkeKdf,
909    A: HpkeAead,
910{
911    open: Either<OpenCtx<A>, RawKey<A>>,
912    export: ExportCtx<K, F, A>,
913}
914
915impl<K, F, A> RecvCtx<K, F, A>
916where
917    K: HpkeKem,
918    F: HpkeKdf,
919    A: HpkeAead,
920{
921    /// The size in bytes of the overhead added to the plaintext.
922    pub const OVERHEAD: usize = OpenCtx::<A>::OVERHEAD;
923
924    // Exposed for `aranya-crypto`, do not use.
925    #[doc(hidden)]
926    pub fn into_raw_parts(self) -> Option<(KeyData<A>, Nonce<A::NonceSize>)> {
927        match self.open {
928            Either::Left(_) => None,
929            Either::Right((key, base_nonce)) => Some((key, base_nonce)),
930        }
931    }
932
933    fn open_ctx(&mut self) -> Result<&mut OpenCtx<A>, ImportError> {
934        self.open
935            .get_or_insert_left(|(key, nonce)| OpenCtx::new(key, nonce, Seq::ZERO))
936    }
937
938    /// Decrypts and authenticates `ciphertext` using the
939    /// internal sequence number.
940    ///
941    /// The resulting plaintext is written to `dst`, which must
942    /// must be at least `ciphertext.len()` - [`Self::OVERHEAD`]
943    /// bytes long.
944    pub fn open(
945        &mut self,
946        dst: &mut [u8],
947        ciphertext: &[u8],
948        additional_data: &[u8],
949    ) -> Result<(), HpkeError> {
950        self.open_ctx()?.open(dst, ciphertext, additional_data)
951    }
952
953    /// Decrypts and authenticates `ciphertext` at a particular
954    /// sequence number.
955    ///
956    /// The resulting plaintext is written to `dst`, which must
957    /// must be at least `ciphertext.len()` - [`Self::OVERHEAD`]
958    /// bytes long.
959    pub fn open_at(
960        &mut self,
961        dst: &mut [u8],
962        ciphertext: &[u8],
963        additional_data: &[u8],
964        seq: Seq,
965    ) -> Result<(), HpkeError> {
966        self.open_ctx()?
967            .open_at(dst, ciphertext, additional_data, seq)
968    }
969
970    /// Decrypts and authenticates `ciphertext`.
971    pub fn open_in_place(
972        &mut self,
973        data: impl AsMut<[u8]>,
974        tag: &[u8],
975        additional_data: &[u8],
976    ) -> Result<(), HpkeError> {
977        self.open_ctx()?.open_in_place(data, tag, additional_data)
978    }
979
980    /// Decrypts and authenticates `ciphertext` at a particular
981    /// sequence number.
982    pub fn open_in_place_at(
983        &mut self,
984        data: impl AsMut<[u8]>,
985        tag: &[u8],
986        additional_data: &[u8],
987        seq: Seq,
988    ) -> Result<(), HpkeError> {
989        self.open_ctx()?
990            .open_in_place_at(data, tag, additional_data, seq)
991    }
992
993    /// Exports a secret from the encryption context.
994    pub fn export<T>(&self, context: &[u8]) -> Result<T, KdfError>
995    where
996        T: Expand,
997    {
998        self.export.export(context)
999    }
1000
1001    /// Exports a secret from the encryption context, writing it
1002    /// to `out`.
1003    pub fn export_into(&self, out: &mut [u8], context: &[u8]) -> Result<(), KdfError> {
1004        self.export.export_into(out, context)
1005    }
1006}
1007
1008impl<K, F, A> fmt::Debug for RecvCtx<K, F, A>
1009where
1010    K: HpkeKem,
1011    F: HpkeKdf,
1012    A: HpkeAead + fmt::Debug,
1013{
1014    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1015        f.debug_struct("RecvCtx")
1016            .field("open", &self.open)
1017            .field("export", &self.export)
1018            .finish()
1019    }
1020}
1021
1022/// An encryption context that can only decrypt messages from
1023/// a particular sender.
1024///
1025/// Unlike [`RecvCtx`], it cannot export secrets.
1026#[doc(hidden)]
1027pub struct OpenCtx<A: HpkeAead> {
1028    aead: A,
1029    base_nonce: Nonce<A::NonceSize>,
1030    /// Incremented after each call to `open`.
1031    seq: Seq,
1032}
1033
1034impl<A: HpkeAead> OpenCtx<A> {
1035    /// The size in bytes of the overhead added to the plaintext.
1036    pub const OVERHEAD: usize = A::OVERHEAD;
1037
1038    // Exported for `aranya-crypto`. Do not use.
1039    #[doc(hidden)]
1040    pub fn new(
1041        key: &KeyData<A>,
1042        base_nonce: &Nonce<A::NonceSize>,
1043        seq: Seq,
1044    ) -> Result<Self, ImportError> {
1045        let key = A::Key::import(key.as_bytes())?;
1046        Ok(Self {
1047            aead: A::new(&key),
1048            base_nonce: base_nonce.clone(),
1049            seq,
1050        })
1051    }
1052
1053    fn increment_seq(&mut self) -> Result<Seq, Bug> {
1054        self.seq.increment::<A::NonceSize>()
1055    }
1056
1057    /// Decrypts and authenticates `ciphertext`.
1058    ///
1059    /// The resulting plaintext is written to `dst`, which must
1060    /// must be at least `ciphertext.len()` - [`Self::OVERHEAD`]
1061    /// bytes long.
1062    pub fn open(
1063        &mut self,
1064        dst: &mut [u8],
1065        ciphertext: &[u8],
1066        additional_data: &[u8],
1067    ) -> Result<(), HpkeError> {
1068        self.open_at(dst, ciphertext, additional_data, self.seq)?;
1069        self.increment_seq()?;
1070        Ok(())
1071    }
1072
1073    /// Decrypts and authenticates `ciphertext` at a particular
1074    /// sequence number.
1075    ///
1076    /// The resulting plaintext is written to `dst`, which must
1077    /// must be at least `ciphertext.len()` - [`Self::OVERHEAD`]
1078    /// bytes long.
1079    pub fn open_at(
1080        &self,
1081        dst: &mut [u8],
1082        ciphertext: &[u8],
1083        additional_data: &[u8],
1084        seq: Seq,
1085    ) -> Result<(), HpkeError> {
1086        let nonce = seq.compute_nonce::<A::NonceSize>(&self.base_nonce)?;
1087        self.aead.open(dst, &nonce, ciphertext, additional_data)?;
1088        Ok(())
1089    }
1090
1091    /// Decrypts and authenticates `ciphertext`.
1092    pub fn open_in_place(
1093        &mut self,
1094        mut data: impl AsMut<[u8]>,
1095        tag: &[u8],
1096        additional_data: &[u8],
1097    ) -> Result<(), HpkeError> {
1098        self.open_in_place_at(data.as_mut(), tag, additional_data, self.seq)?;
1099        self.increment_seq()?;
1100        Ok(())
1101    }
1102
1103    /// Decrypts and authenticates `ciphertext` at a particular
1104    /// sequence number.
1105    pub fn open_in_place_at(
1106        &self,
1107        mut data: impl AsMut<[u8]>,
1108        tag: &[u8],
1109        additional_data: &[u8],
1110        seq: Seq,
1111    ) -> Result<(), HpkeError> {
1112        let nonce = seq.compute_nonce::<A::NonceSize>(&self.base_nonce)?;
1113        self.aead
1114            .open_in_place(&nonce, data.as_mut(), tag, additional_data)?;
1115        Ok(())
1116    }
1117}
1118
1119impl<A: HpkeAead + fmt::Debug> fmt::Debug for OpenCtx<A> {
1120    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1121        f.debug_struct("OpenCtx")
1122            .field("aead", &self.aead)
1123            .field("base_nonce", &self.base_nonce)
1124            .field("seq", &self.seq)
1125            .finish()
1126    }
1127}
1128
/// HPKE's message limit has been reached.
///
/// Returned when a sequence number is too large to derive a
/// per-message nonce.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct MessageLimitReached;

impl fmt::Display for MessageLimitReached {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "message limit reached")
    }
}

impl core::error::Error for MessageLimitReached {}
1140
/// Sequence numbers ensure nonce uniqueness.
///
/// A `Seq` is encoded big-endian to the nonce length and XORed
/// into a context's base nonce to form the per-message nonce, so
/// distinct sequence numbers yield distinct nonces. Ordering and
/// equality compare the underlying integer.
#[derive(Copy, Clone, Debug, Default, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct Seq {
    /// The sequence number.
    ///
    /// It's encoded as a big-endian integer (I2OSP) and XORed
    /// with the `base_nonce`.
    ///
    /// This should be the size of the nonce, but it's
    /// vanishingly unlikely that we'll ever overflow. Since
    /// encryption contexts ([`SealCtx`], etc.) can only be used
    /// serially, we can only overflow if the user actually
    /// performs 2^64-1 operations. At an impossible one
    /// nanosecond per encryption, this will take upward of 500
    /// years.
    // Consequently, for nonces of 8 or more bytes the effective
    // message limit is `u64::MAX` rather than `2^(8*Nn) - 1`.
    seq: u64,
}
1158
1159impl Seq {
1160    /// The zero value of a `Seq`.
1161    pub const ZERO: Self = Self::new(0);
1162
1163    /// Creates a sequence number.
1164    #[inline]
1165    pub const fn new(seq: u64) -> Self {
1166        Self { seq }
1167    }
1168
1169    /// Converts itself to a `u64`.
1170    #[inline]
1171    pub const fn to_u64(self) -> u64 {
1172        self.seq
1173    }
1174
1175    /// Returns the maximum allowed sequence number.
1176    ///
1177    /// Exported for `aranya-crypto`. Do not use.
1178    #[doc(hidden)]
1179    pub const fn max<N: ArrayLength>() -> u64 {
1180        // 1<<(8*N) - 1
1181        let shift = 8usize.saturating_mul(N::USIZE);
1182        match 1u64.checked_shl(shift as u32) {
1183            Some(v) => v.saturating_sub(1),
1184            None => u64::MAX,
1185        }
1186    }
1187
1188    /// Increments the sequence by one and returns the *previous*
1189    /// sequence number.
1190    fn increment<N: ArrayLength>(&mut self) -> Result<Self, Bug> {
1191        // if self.seq >= (1 << (8*Nn)) - 1:
1192        //     raise MessageLimitReachedError
1193        if self.seq >= Self::max::<N>() {
1194            // We only call `Seq::increment` after computing the
1195            // nonce, which requires `seq < Self::max`.
1196            bug!("`Seq::increment` called after limit reached");
1197        }
1198        // self.seq += 1
1199        let prev = self.seq;
1200        self.seq = prev
1201            .checked_add(1)
1202            .assume("`Seq` overflow should be impossible")?;
1203        Ok(Self { seq: prev })
1204    }
1205
1206    /// Computes the per-message nonce.
1207    fn compute_nonce<N: ArrayLength>(
1208        self,
1209        base_nonce: &Nonce<N>,
1210    ) -> Result<Nonce<N>, MessageLimitReached> {
1211        if self.seq >= Self::max::<N>() {
1212            Err(MessageLimitReached)
1213        } else {
1214            //  seq_bytes = I2OSP(seq, Nn)
1215            let seq_bytes = i2osp!(self.seq, N);
1216            // xor(self.base_nonce, seq_bytes)
1217            Ok(base_nonce ^ &Nonce::from_bytes(seq_bytes))
1218        }
1219    }
1220}
1221
1222impl fmt::Display for Seq {
1223    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1224        write!(f, "{}", self.seq)
1225    }
1226}
1227
/// Holds an HPKE context's exporter secret, from which its
/// methods derive exported secrets (RFC 9180 `Context.Export`).
///
/// `K` and `A` are not stored; they are carried as type
/// parameters so that exports are computed with the full
/// `Hpke<K, F, A>` suite.
struct ExportCtx<K, F, A>
where
    K: HpkeKem,
    F: HpkeKdf,
    A: HpkeAead,
{
    /// The secret that every export is derived from.
    exporter_secret: Prk<F::PrkSize>,
    // Records `K` and `A` without owning values of those types.
    // Function pointers are always `Send + Sync`, so this marker
    // does not constrain the struct's auto traits.
    _etc: PhantomData<fn() -> (K, A)>,
}
1237
1238impl<K, F, A> ExportCtx<K, F, A>
1239where
1240    K: HpkeKem,
1241    F: HpkeKdf,
1242    A: HpkeAead,
1243{
1244    fn new(exporter_secret: Prk<F::PrkSize>) -> Self {
1245        Self {
1246            exporter_secret,
1247            _etc: PhantomData,
1248        }
1249    }
1250
1251    /// Exports a secret from the context.
1252    fn export<T>(&self, context: &[u8]) -> Result<T, KdfError>
1253    where
1254        T: Expand,
1255    {
1256        // def Context.Export(exporter_context, L):
1257        //   return LabeledExpand(self.exporter_secret, "sec",
1258        //                        exporter_context, L)
1259        Hpke::<K, F, A>::labeled_expand(&self.exporter_secret, b"sec", [context])
1260    }
1261
1262    /// Exports a secret from the context, writing it to `out`.
1263    fn export_into(&self, out: &mut [u8], context: &[u8]) -> Result<(), KdfError> {
1264        // def Context.Export(exporter_context, L):
1265        //   return LabeledExpand(self.exporter_secret, "sec",
1266        //                        exporter_context, L)
1267        Hpke::<K, F, A>::labeled_expand_into(out, &self.exporter_secret, b"sec", [context])
1268    }
1269}
1270
1271impl<K, F, A> fmt::Debug for ExportCtx<K, F, A>
1272where
1273    K: HpkeKem,
1274    F: HpkeKdf,
1275    A: HpkeAead,
1276{
1277    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1278        f.debug_struct("ExportCtx").finish_non_exhaustive()
1279    }
1280}
1281
#[cfg(test)]
mod tests {
    #![allow(clippy::panic)]

    use std::{collections::HashSet, ops::RangeInclusive};

    use typenum::{U1, U2, U7, U8};

    use super::*;

    /// Tests that [`Seq::compute_nonce`] generates correct
    /// nonces.
    #[test]
    fn test_seq_compute_nonce() {
        let base = Nonce::<U1>::try_from_slice(&[0xfe]).expect("should be able to create nonce");
        let cases = [
            (0, Ok(&[0xfe])),
            (1, Ok(&[0xff])),
            (2, Ok(&[0xfc])),
            (4, Ok(&[0xfa])),
            (254, Ok(&[0x00])),
            (255, Err(MessageLimitReached)),
            (256, Err(MessageLimitReached)),
            (257, Err(MessageLimitReached)),
            (u64::MAX, Err(MessageLimitReached)),
        ];
        for (input, output) in cases {
            let got = Seq::new(input).compute_nonce::<U1>(&base);
            let want = output.map(|s| Nonce::try_from_slice(s).expect("unable to create nonce"));
            assert_eq!(got, want, "seq = {input}");
        }
    }

    /// Tests that all nonces are unique.
    #[test]
    fn test_seq_unique_nonce() {
        let base =
            Nonce::<U2>::try_from_slice(&[0xfe, 0xfe]).expect("should be able to create nonce");
        let mut seen = HashSet::new();
        for v in 0..u16::MAX {
            let got = Seq::new(u64::from(v))
                .compute_nonce::<U2>(&base)
                .expect("unable to create nonce");
            assert!(seen.insert(got), "duplicate nonce: {got:?}");
        }
    }

    /// Tests that [`Seq::max`] is `2^(8*N) - 1`, saturating at
    /// `u64::MAX` for nonces of eight or more bytes.
    #[test]
    fn test_seq_max() {
        assert_eq!(Seq::max::<U1>(), 0xff);
        assert_eq!(Seq::max::<U2>(), 0xffff);
        assert_eq!(Seq::max::<U7>(), 0x00ff_ffff_ffff_ffff);
        assert_eq!(Seq::max::<U8>(), u64::MAX);
        assert_eq!(Seq::max::<typenum::U12>(), u64::MAX);
    }

    /// Tests that [`Seq::increment`] returns the previous value
    /// and advances by one, up to the message limit.
    #[test]
    fn test_seq_increment() {
        let base = Nonce::<U1>::try_from_slice(&[0xfe]).expect("should be able to create nonce");

        let mut seq = Seq::ZERO;
        let prev = seq
            .increment::<U1>()
            .expect("should be able to increment");
        assert_eq!(prev, Seq::ZERO);
        assert_eq!(seq, Seq::new(1));

        // 254 is the last usable sequence number for a one-byte
        // nonce; incrementing it lands exactly on the limit.
        let mut seq = Seq::new(254);
        let prev = seq
            .increment::<U1>()
            .expect("should be able to increment");
        assert_eq!(prev, Seq::new(254));
        assert_eq!(seq, Seq::new(255));
        assert_eq!(seq.compute_nonce::<U1>(&base), Err(MessageLimitReached));
    }

    #[test]
    fn test_invalid_psk() {
        let err = Psk::new(&[], &[]).expect_err("should get `InvalidPsk`");
        assert_eq!(err, InvalidPsk);
    }

    #[test]
    fn test_psk_ct_eq() {
        let cases = [
            (true, ("abc", "123"), ("abc", "123")),
            (false, ("a", "b"), ("a", "x")),
            (false, ("a", "b"), ("x", "b")),
            (false, ("a", "b"), ("c", "d")),
        ];
        for (pass, lhs, rhs) in cases {
            let lhs = Psk::new(lhs.0.as_bytes(), lhs.1.as_bytes()).expect("should not fail");
            let rhs = Psk::new(rhs.0.as_bytes(), rhs.1.as_bytes()).expect("should not fail");
            assert_eq!(pass, bool::from(lhs.ct_eq(&rhs)));
        }
    }

    /// Tests that [`AeadId`] is assigned correctly.
    #[test]
    fn test_aead_id() {
        let unassigned = 0x0004..=0xFFFE;
        for id in unassigned {
            let want = AeadId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = want.to_be_bytes();
            let got = AeadId::try_from_be_bytes(encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `AeadId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }

    /// Tests that [`KdfId`] is assigned correctly.
    #[test]
    fn test_kdf_id() {
        let unassigned = 0x0004..=0xFFFF;
        for id in unassigned {
            let want = KdfId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = want.to_be_bytes();
            let got = KdfId::try_from_be_bytes(encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `KdfId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }

    /// Tests that [`KemId`] is assigned correctly.
    #[test]
    fn test_kem_id() {
        let unassigned: [RangeInclusive<u16>; 6] = [
            0x0001..=0x000F,
            0x0017..=0x001F,
            0x0022..=0x002F,
            0x0031..=0x0039,
            0x0043..=0x6479,
            0x647b..=0xFFFF,
        ];
        for id in unassigned.into_iter().flatten() {
            let want = KemId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = want.to_be_bytes();
            let got = KemId::try_from_be_bytes(encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `KemId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }
}