sad_rsa/pss/
verifying_key.rs

1use super::{verify_digest, Signature};
2use crate::RsaPublicKey;
3use core::marker::PhantomData;
4use digest::{Digest, FixedOutputReset, Update};
5use signature::{hazmat::PrehashVerifier, DigestVerifier, Verifier};
6
7#[cfg(feature = "encoding")]
8use {
9    crate::encoding::ID_RSASSA_PSS,
10    const_oid::AssociatedOid,
11    pkcs8::{Document, EncodePublicKey},
12    spki::{der::AnyRef, AlgorithmIdentifierRef, AssociatedAlgorithmIdentifier},
13};
14#[cfg(feature = "serde")]
15use {
16    serdect::serde::{de, ser, Deserialize, Serialize},
17    spki::DecodePublicKey,
18};
19
20/// Verifying key for checking the validity of RSASSA-PSS signatures as
21/// described in [RFC8017 § 8.1].
22///
23/// [RFC8017 § 8.1]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.1
24#[derive(Debug)]
25pub struct VerifyingKey<D>
26where
27    D: Digest,
28{
29    pub(super) inner: RsaPublicKey,
30    pub(super) salt_len: Option<usize>,
31    pub(super) phantom: PhantomData<D>,
32}
33
34impl<D> VerifyingKey<D>
35where
36    D: Digest,
37{
38    /// Create a new RSASSA-PSS verifying key.
39    /// Digest output size is used as a salt length.
40    pub fn new(key: RsaPublicKey) -> Self {
41        Self::new_with_salt_len(key, <D as Digest>::output_size())
42    }
43
44    /// Create a new RSASSA-PSS verifying key.
45    pub fn new_with_salt_len(key: RsaPublicKey, salt_len: usize) -> Self {
46        Self {
47            inner: key,
48            salt_len: Some(salt_len),
49            phantom: Default::default(),
50        }
51    }
52
53    /// Create a new RSASSA-PSS verifying key.
54    /// Attempts to automatically detect the salt length.
55    pub fn new_with_auto_salt_len(key: RsaPublicKey) -> Self {
56        Self {
57            inner: key,
58            salt_len: None,
59            phantom: Default::default(),
60        }
61    }
62
63    /// Return specified salt length for this key
64    pub fn salt_len(&self) -> Option<usize> {
65        self.salt_len
66    }
67}
68
69//
70// `*Verifier` trait impls
71//
72
73impl<D> DigestVerifier<D, Signature> for VerifyingKey<D>
74where
75    D: Digest + FixedOutputReset + Update,
76{
77    fn verify_digest<F: Fn(&mut D) -> signature::Result<()>>(
78        &self,
79        f: F,
80        signature: &Signature,
81    ) -> signature::Result<()> {
82        let mut digest = D::new();
83        f(&mut digest)?;
84        verify_digest::<D>(
85            &self.inner,
86            &digest.finalize(),
87            &signature.inner,
88            self.salt_len,
89        )
90        .map_err(|e| e.into())
91    }
92}
93
94impl<D> PrehashVerifier<Signature> for VerifyingKey<D>
95where
96    D: Digest + FixedOutputReset,
97{
98    fn verify_prehash(&self, prehash: &[u8], signature: &Signature) -> signature::Result<()> {
99        verify_digest::<D>(&self.inner, prehash, &signature.inner, self.salt_len)
100            .map_err(|e| e.into())
101    }
102}
103
104impl<D> Verifier<Signature> for VerifyingKey<D>
105where
106    D: Digest + FixedOutputReset,
107{
108    fn verify(&self, msg: &[u8], signature: &Signature) -> signature::Result<()> {
109        verify_digest::<D>(
110            &self.inner,
111            &D::digest(msg),
112            &signature.inner,
113            self.salt_len,
114        )
115        .map_err(|e| e.into())
116    }
117}
118
119//
120// Other trait impls
121//
122
123impl<D> AsRef<RsaPublicKey> for VerifyingKey<D>
124where
125    D: Digest,
126{
127    fn as_ref(&self) -> &RsaPublicKey {
128        &self.inner
129    }
130}
131
#[cfg(feature = "encoding")]
impl<D: Digest> AssociatedAlgorithmIdentifier for VerifyingKey<D> {
    type Params = AnyRef<'static>;

    /// Advertises the PKCS#1 RSA algorithm identifier (`pkcs1::ALGORITHM_ID`)
    /// rather than the dedicated RSASSA-PSS OID.
    const ALGORITHM_IDENTIFIER: AlgorithmIdentifierRef<'static> = pkcs1::ALGORITHM_ID;
}
141
142// Implemented manually so we don't have to bind D with Clone
143impl<D> Clone for VerifyingKey<D>
144where
145    D: Digest,
146{
147    fn clone(&self) -> Self {
148        Self {
149            inner: self.inner.clone(),
150            salt_len: self.salt_len,
151            phantom: Default::default(),
152        }
153    }
154}
155
#[cfg(feature = "encoding")]
impl<D: Digest> EncodePublicKey for VerifyingKey<D> {
    /// Encodes only the wrapped RSA public key; the salt-length setting is
    /// not part of the DER output.
    fn to_public_key_der(&self) -> spki::Result<Document> {
        EncodePublicKey::to_public_key_der(&self.inner)
    }
}
165
166impl<D> From<RsaPublicKey> for VerifyingKey<D>
167where
168    D: Digest,
169{
170    fn from(key: RsaPublicKey) -> Self {
171        Self::new(key)
172    }
173}
174
175impl<D> From<VerifyingKey<D>> for RsaPublicKey
176where
177    D: Digest,
178{
179    fn from(key: VerifyingKey<D>) -> Self {
180        key.inner
181    }
182}
183
#[cfg(feature = "encoding")]
impl<D> TryFrom<pkcs8::SubjectPublicKeyInfoRef<'_>> for VerifyingKey<D>
where
    D: Digest + AssociatedOid,
{
    type Error = spki::Error;

    /// Builds a verifying key from a SubjectPublicKeyInfo.
    ///
    /// Only the RSASSA-PSS OID and the PKCS#1 RSA OID are accepted; any other
    /// algorithm OID yields `spki::Error::OidUnknown`. The resulting key uses
    /// the default salt length of [`VerifyingKey::new`].
    fn try_from(spki: pkcs8::SubjectPublicKeyInfoRef<'_>) -> spki::Result<Self> {
        let oid = spki.algorithm.oid;
        if !matches!(oid, ID_RSASSA_PSS | pkcs1::ALGORITHM_OID) {
            return Err(spki::Error::OidUnknown { oid });
        }

        let key = RsaPublicKey::try_from(spki)?;
        Ok(Self::new(key))
    }
}
204
205impl<D> PartialEq for VerifyingKey<D>
206where
207    D: Digest,
208{
209    fn eq(&self, other: &Self) -> bool {
210        self.inner == other.inner && self.salt_len == other.salt_len
211    }
212}
213
#[cfg(feature = "serde")]
impl<D> Serialize for VerifyingKey<D>
where
    D: Digest,
{
    /// Serializes the key as its DER-encoded public key document, using
    /// `serdect::slice::serialize_hex_lower_or_bin` (lower-case hex for
    /// human-readable formats, raw bytes otherwise).
    ///
    /// Only the RSA public key is written; the salt-length setting is not
    /// serialized.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        // Bound through the `serdect::serde` re-export (`ser`) so the
        // `Serializer` trait comes from the same serde crate as the
        // `Serialize` trait being implemented, without needing a separate
        // direct `serde` dependency in scope.
        S: ser::Serializer,
    {
        let der = self.to_public_key_der().map_err(ser::Error::custom)?;
        serdect::slice::serialize_hex_lower_or_bin(&der, serializer)
    }
}
227
#[cfg(feature = "serde")]
impl<'de, D> Deserialize<'de> for VerifyingKey<D>
where
    D: Digest + AssociatedOid,
{
    /// Deserializes a DER-encoded public key document, given either as a hex
    /// string or as raw bytes (`serdect::slice::deserialize_hex_or_bin_vec`).
    ///
    /// The salt length is not part of the serialized form. The decoded key is
    /// built through this type's `TryFrom<SubjectPublicKeyInfoRef>` impl,
    /// which uses the default salt length of `VerifyingKey::new`.
    fn deserialize<De>(deserializer: De) -> Result<Self, De::Error>
    where
        // Bound through the `serdect::serde` re-export (`de`) so the
        // `Deserializer` trait matches the `Deserialize` trait being
        // implemented, without relying on a direct `serde` dependency.
        De: de::Deserializer<'de>,
    {
        let der_bytes = serdect::slice::deserialize_hex_or_bin_vec(deserializer)?;
        Self::from_public_key_der(&der_bytes).map_err(de::Error::custom)
    }
}
241
#[cfg(test)]
mod tests {
    /// Round-trips a SHA-256 PSS verifying key through serde's readable
    /// (hex string) representation.
    #[test]
    #[cfg(all(feature = "hazmat", feature = "serde"))]
    fn test_serde() {
        use super::*;
        use crate::RsaPrivateKey;
        use rand::rngs::ChaCha8Rng;
        use rand_core::SeedableRng;
        use serde_test::{assert_tokens, Configure, Token};
        use sha2::Sha256;

        // A fixed seed keeps the generated key, and therefore the expected
        // hex encoding below, stable across runs.
        let mut rng = ChaCha8Rng::from_seed([42; 32]);
        // Deliberately tiny 64-bit key; the test is gated on `hazmat`,
        // presumably because `new_unchecked` requires it.
        let private_key =
            RsaPrivateKey::new_unchecked(&mut rng, 64).expect("failed to generate key");
        // `new` picks the default salt length, which is also what
        // deserialization produces, so the round-tripped key compares equal.
        let verifying_key = VerifyingKey::<Sha256>::new(private_key.to_public_key());

        let expected = [Token::Str(
            "3024300d06092a864886f70d01010105000313003010020900ab240c3361d02e370203010001",
        )];

        assert_tokens(&verifying_key.readable(), &expected);
    }
}