1use alloc::vec::Vec;
7use core::fmt;
8
9use fips204::traits::{SerDes, Verifier};
10use primitives::PqScheme;
11use tide_fn_dsa_vrfy::{FalconProfile, VerifyingKey as TideVerifyingKey, VerifyingKeyStandard};
12
13#[derive(Debug, Clone, PartialEq, Eq, Hash)]
15pub struct PqPublicKey {
16 scheme: PqScheme,
17 data: Vec<u8>,
18}
19
20impl PqPublicKey {
21 pub fn from_scheme_and_bytes(scheme: PqScheme, data: &[u8]) -> Result<Self, PqError> {
24 let expected = scheme.pubkey_len();
25 if data.len() != expected {
26 return Err(PqError::InvalidKeyLength { expected, got: data.len() });
27 }
28 Ok(Self { scheme, data: data.to_vec() })
29 }
30
31 pub fn from_prefixed_slice(data: &[u8]) -> Result<Self, PqError> {
33 let (&prefix, raw) = data.split_first().ok_or(PqError::EmptyData)?;
34 let scheme = PqScheme::from_prefix(prefix).ok_or(PqError::UnknownScheme(prefix))?;
35 Self::from_scheme_and_bytes(scheme, raw)
36 }
37
38 pub fn scheme(&self) -> PqScheme {
40 self.scheme
41 }
42
43 pub fn as_bytes(&self) -> &[u8] {
45 &self.data
46 }
47}
48
49#[derive(Debug, Clone, PartialEq, Eq)]
51pub struct PqSignature {
52 data: Vec<u8>,
53}
54
55impl PqSignature {
56 pub fn from_slice(data: &[u8]) -> Self {
58 Self { data: data.to_vec() }
59 }
60
61 pub fn as_bytes(&self) -> &[u8] {
63 &self.data
64 }
65
66 pub fn verify_msg32(&self, msg: &[u8; 32], pk: &PqPublicKey) -> Result<(), PqError> {
68 self.verify_message(msg, pk)
69 }
70
71 pub fn verify_msg64(&self, msg: &[u8; 64], pk: &PqPublicKey) -> Result<(), PqError> {
73 self.verify_message(msg, pk)
74 }
75
76 pub fn verify_msg32_allow_legacy(
79 &self,
80 msg: &[u8; 32],
81 pk: &PqPublicKey,
82 ) -> Result<(), PqError> {
83 self.verify_message_allow_legacy(msg, pk)
84 }
85
86 pub fn verify_msg32_legacy(&self, msg: &[u8; 32], pk: &PqPublicKey) -> Result<(), PqError> {
89 self.verify_message_legacy(msg, pk)
90 }
91
92 pub fn verify_msg64_legacy(&self, msg: &[u8; 64], pk: &PqPublicKey) -> Result<(), PqError> {
95 self.verify_message_legacy(msg, pk)
96 }
97
98 pub fn verify_msg64_allow_legacy(
101 &self,
102 msg: &[u8; 64],
103 pk: &PqPublicKey,
104 ) -> Result<(), PqError> {
105 self.verify_message_allow_legacy(msg, pk)
106 }
107
108 fn verify_message(&self, msg: &[u8], pk: &PqPublicKey) -> Result<(), PqError> {
109 match pk.scheme() {
110 PqScheme::Falcon512 | PqScheme::Falcon1024 => {
111 let vk =
112 VerifyingKeyStandard::decode(pk.as_bytes()).ok_or(PqError::BackendFailure)?;
113 if vk.verify_falcon(FalconProfile::PqClean, self.as_bytes(), msg) {
114 Ok(())
115 } else {
116 Err(PqError::VerificationFailed)
117 }
118 }
119 PqScheme::MlDsa44 => {
120 let pk_arr: [u8; 1312] =
121 pk.as_bytes().try_into().map_err(|_| PqError::BackendFailure)?;
122 let pk_obj = fips204::ml_dsa_44::PublicKey::try_from_bytes(pk_arr)
123 .map_err(|_| PqError::BackendFailure)?;
124 let sig_arr: [u8; 2420] =
125 self.as_bytes().try_into().map_err(|_| PqError::VerificationFailed)?;
126 if pk_obj.verify(msg, &sig_arr, &[]) {
127 Ok(())
128 } else {
129 Err(PqError::VerificationFailed)
130 }
131 }
132 PqScheme::MlDsa65 => {
133 let pk_arr: [u8; 1952] =
134 pk.as_bytes().try_into().map_err(|_| PqError::BackendFailure)?;
135 let pk_obj = fips204::ml_dsa_65::PublicKey::try_from_bytes(pk_arr)
136 .map_err(|_| PqError::BackendFailure)?;
137 let sig_arr: [u8; 3309] =
138 self.as_bytes().try_into().map_err(|_| PqError::VerificationFailed)?;
139 if pk_obj.verify(msg, &sig_arr, &[]) {
140 Ok(())
141 } else {
142 Err(PqError::VerificationFailed)
143 }
144 }
145 PqScheme::MlDsa87 => {
146 let pk_arr: [u8; 2592] =
147 pk.as_bytes().try_into().map_err(|_| PqError::BackendFailure)?;
148 let pk_obj = fips204::ml_dsa_87::PublicKey::try_from_bytes(pk_arr)
149 .map_err(|_| PqError::BackendFailure)?;
150 let sig_arr: [u8; 4627] =
151 self.as_bytes().try_into().map_err(|_| PqError::VerificationFailed)?;
152 if pk_obj.verify(msg, &sig_arr, &[]) {
153 Ok(())
154 } else {
155 Err(PqError::VerificationFailed)
156 }
157 }
158 }
159 }
160
161 fn verify_message_legacy(&self, msg: &[u8], pk: &PqPublicKey) -> Result<(), PqError> {
162 if pk.scheme() != PqScheme::Falcon512 {
163 return self.verify_message(msg, pk);
164 }
165 let vk = VerifyingKeyStandard::decode(pk.as_bytes()).ok_or(PqError::BackendFailure)?;
166 if vk.verify_falcon(FalconProfile::TidecoinLegacyFalcon512, self.as_bytes(), msg) {
167 Ok(())
168 } else {
169 Err(PqError::VerificationFailed)
170 }
171 }
172
173 fn verify_message_allow_legacy(&self, msg: &[u8], pk: &PqPublicKey) -> Result<(), PqError> {
174 match self.verify_message(msg, pk) {
175 Ok(()) => Ok(()),
176 Err(PqError::VerificationFailed) if pk.scheme() == PqScheme::Falcon512 => {
177 self.verify_message_legacy(msg, pk)
178 }
179 Err(err) => Err(err),
180 }
181 }
182}
183
/// Errors produced when parsing PQ public keys or verifying PQ signatures.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PqError {
    /// The leading prefix byte does not map to a known scheme
    /// (see `PqPublicKey::from_prefixed_slice`).
    UnknownScheme(u8),
    /// Key material has the wrong length for its scheme.
    InvalidKeyLength {
        /// Length required by the scheme.
        expected: usize,
        /// Length actually supplied.
        got: usize,
    },
    /// The signature was rejected. For ML-DSA this is also returned when the
    /// signature has the wrong length.
    VerificationFailed,
    /// The cryptographic backend could not decode or convert the public key.
    BackendFailure,
    /// An empty input was supplied where a prefixed key was expected.
    EmptyData,
}

impl fmt::Display for PqError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            PqError::UnknownScheme(prefix) => {
                write!(f, "unknown PQ scheme prefix: 0x{prefix:02x}")
            }
            PqError::InvalidKeyLength { expected, got } => {
                write!(f, "invalid key length: expected {expected} bytes, got {got}")
            }
            PqError::VerificationFailed => f.write_str("signature verification failed"),
            PqError::BackendFailure => f.write_str("backend PQ operation failed"),
            PqError::EmptyData => f.write_str("empty data"),
        }
    }
}
217
// `std::error::Error` integration, compiled only when the `std` feature is enabled.
// The impl relies entirely on the trait's default methods. No `PqError` variant
// wraps an inner error, so there is nothing to expose via `source()`.
#[cfg(feature = "std")]
impl std::error::Error for PqError {}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Every scheme this module dispatches on in `PqSignature::verify_message`.
    const ALL_SCHEMES: [PqScheme; 5] = [
        PqScheme::Falcon512,
        PqScheme::Falcon1024,
        PqScheme::MlDsa44,
        PqScheme::MlDsa65,
        PqScheme::MlDsa87,
    ];

    #[test]
    fn scheme_prefix_round_trip() {
        for scheme in ALL_SCHEMES {
            assert_eq!(PqScheme::from_prefix(scheme.prefix()), Some(scheme));
        }
    }

    #[test]
    fn prefixed_pubkey_rejects_wrong_length() {
        let err =
            PqPublicKey::from_prefixed_slice(&[PqScheme::Falcon512.prefix(), 1, 2]).unwrap_err();
        assert_eq!(err, PqError::InvalidKeyLength { expected: 897, got: 2 });
    }

    #[test]
    fn prefixed_pubkey_rejects_empty_input() {
        let err = PqPublicKey::from_prefixed_slice(&[]).unwrap_err();
        assert_eq!(err, PqError::EmptyData);
    }

    #[test]
    fn prefixed_pubkey_rejects_unknown_prefix() {
        // Discover an unassigned prefix instead of hard-coding one, so the test does
        // not depend on the concrete prefix table.
        let unknown = (0..=u8::MAX)
            .find(|&b| PqScheme::from_prefix(b).is_none())
            .expect("only five schemes exist, so some prefix byte must be unassigned");
        let err = PqPublicKey::from_prefixed_slice(&[unknown]).unwrap_err();
        assert_eq!(err, PqError::UnknownScheme(unknown));
    }

    #[test]
    fn prefixed_pubkey_accepts_exact_length() {
        // 1 prefix byte + 897 key bytes (the Falcon-512 length asserted above).
        let mut encoded = [0u8; 898];
        encoded[0] = PqScheme::Falcon512.prefix();
        let pk = PqPublicKey::from_prefixed_slice(&encoded).unwrap();
        assert_eq!(pk.scheme(), PqScheme::Falcon512);
        assert_eq!(pk.as_bytes(), &encoded[1..]);
    }

    #[test]
    fn pubkey_length_is_checked_for_every_scheme() {
        // Large enough for every supported scheme's public key.
        let zeros = [0u8; 4096];
        for scheme in ALL_SCHEMES {
            let len = scheme.pubkey_len();

            let pk = PqPublicKey::from_scheme_and_bytes(scheme, &zeros[..len]).unwrap();
            assert_eq!(pk.scheme(), scheme);
            assert_eq!(pk.as_bytes(), &zeros[..len]);

            let short = PqPublicKey::from_scheme_and_bytes(scheme, &zeros[..len - 1]);
            assert_eq!(
                short.unwrap_err(),
                PqError::InvalidKeyLength { expected: len, got: len - 1 }
            );

            let long = PqPublicKey::from_scheme_and_bytes(scheme, &zeros[..len + 1]);
            assert_eq!(
                long.unwrap_err(),
                PqError::InvalidKeyLength { expected: len, got: len + 1 }
            );
        }
    }

    #[test]
    fn signature_keeps_bytes_verbatim() {
        let sig = PqSignature::from_slice(&[1, 2, 3]);
        assert_eq!(sig.as_bytes(), &[1u8, 2, 3][..]);
        assert!(PqSignature::from_slice(&[]).as_bytes().is_empty());
    }
}