1use crate::quantum_crypto::{MlDsa65, MlDsaOperations, MlDsaPublicKey, MlDsaSignature};
20use anyhow::Result;
21use async_trait::async_trait;
22use once_cell::sync::OnceCell;
23use serde::{Deserialize, Serialize};
24use std::fmt::Debug;
25use std::sync::Arc;
26
27#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
29pub struct Sig(Vec<u8>);
30
31#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
33pub struct PubKey(Vec<u8>);
34
35impl Sig {
36 pub fn new(bytes: Vec<u8>) -> Self {
38 Self(bytes)
39 }
40
41 pub fn as_bytes(&self) -> &[u8] {
43 &self.0
44 }
45}
46
47impl PubKey {
48 pub fn new(bytes: Vec<u8>) -> Self {
50 Self(bytes)
51 }
52
53 pub fn as_bytes(&self) -> &[u8] {
55 &self.0
56 }
57}
58
59#[async_trait]
64pub trait WriteAuth: Send + Sync + Debug {
65 async fn verify(&self, record: &[u8], sigs: &[Sig]) -> Result<bool>;
68
69 fn auth_type(&self) -> &str;
71}
72
73pub trait MlsProofVerifier: Send + Sync {
76 fn verify(&self, group_id: &[u8], epoch: u64, proof: &[u8], record: &[u8]) -> Result<bool>;
78}
79
80static MLS_VERIFIER: OnceCell<Arc<dyn MlsProofVerifier>> = OnceCell::new();
81
82pub fn set_mls_verifier(verifier: Arc<dyn MlsProofVerifier>) -> bool {
84 MLS_VERIFIER.set(verifier).is_ok()
85}
86
87#[derive(Debug, Clone)]
89pub struct SingleWriteAuth {
90 pub_key: PubKey,
91}
92
93impl SingleWriteAuth {
94 pub fn new(pub_key: PubKey) -> Self {
96 Self { pub_key }
97 }
98}
99
100#[async_trait]
101impl WriteAuth for SingleWriteAuth {
102 async fn verify(&self, record: &[u8], sigs: &[Sig]) -> Result<bool> {
103 if sigs.is_empty() {
104 return Ok(false);
105 }
106
107 let pk = MlDsaPublicKey::from_bytes(self.pub_key.as_bytes())
108 .map_err(|e| anyhow::anyhow!("invalid ML-DSA public key: {e}"))?;
109 const SIG_LEN: usize = 3309;
110 let sig_bytes = sigs[0].as_bytes();
111 if sig_bytes.len() != SIG_LEN {
112 return Ok(false);
113 }
114 let mut arr = [0u8; SIG_LEN];
115 arr.copy_from_slice(sig_bytes);
116 let sig = MlDsaSignature(Box::new(arr));
117 let ml = MlDsa65::new();
118 let ok = ml
119 .verify(&pk, record, &sig)
120 .map_err(|e| anyhow::anyhow!("ML-DSA verify failed: {e}"))?;
121 Ok(ok)
122 }
123
124 fn auth_type(&self) -> &str {
125 "single"
126 }
127}
128
129#[derive(Debug, Clone)]
131pub struct DelegatedWriteAuth {
132 authorized_keys: Vec<PubKey>,
133}
134
135impl DelegatedWriteAuth {
136 pub fn new(authorized_keys: Vec<PubKey>) -> Self {
138 Self { authorized_keys }
139 }
140
141 pub fn add_key(&mut self, key: PubKey) {
143 if !self.authorized_keys.contains(&key) {
144 self.authorized_keys.push(key);
145 }
146 }
147}
148
149#[async_trait]
150impl WriteAuth for DelegatedWriteAuth {
151 async fn verify(&self, record: &[u8], sigs: &[Sig]) -> Result<bool> {
152 if sigs.is_empty() || self.authorized_keys.is_empty() {
153 return Ok(false);
154 }
155 const SIG_LEN: usize = 3309;
156 let sig_bytes = sigs[0].as_bytes();
157 if sig_bytes.len() != SIG_LEN {
158 return Ok(false);
159 }
160 let mut arr = [0u8; SIG_LEN];
161 arr.copy_from_slice(sig_bytes);
162 let sig = MlDsaSignature(Box::new(arr));
163 let ml = MlDsa65::new();
164 for ak in &self.authorized_keys {
165 if let Ok(pk) = MlDsaPublicKey::from_bytes(ak.as_bytes())
166 && let Ok(valid) = ml.verify(&pk, record, &sig)
167 && valid
168 {
169 return Ok(true);
170 }
171 }
172 Ok(false)
173 }
174
175 fn auth_type(&self) -> &str {
176 "delegated"
177 }
178}
179
/// Authorizes writes via an MLS proof for one group at one epoch, checked by
/// the globally installed [`MlsProofVerifier`].
#[derive(Clone, Debug)]
pub struct MlsWriteAuth {
    /// MLS group identifier handed to the verifier.
    group_id: Vec<u8>,
    /// Epoch handed to the verifier.
    epoch: u64,
}
187
188impl MlsWriteAuth {
189 pub fn new(group_id: Vec<u8>, epoch: u64) -> Self {
191 Self { group_id, epoch }
192 }
193}
194
195#[async_trait]
196impl WriteAuth for MlsWriteAuth {
197 async fn verify(&self, record: &[u8], sigs: &[Sig]) -> Result<bool> {
198 let verifier = match MLS_VERIFIER.get() {
200 Some(v) => v.clone(),
201 None => return Ok(false),
202 };
203
204 let proof = match sigs.first() {
206 Some(s) => s.as_bytes(),
207 None => return Ok(false),
208 };
209
210 verifier.verify(&self.group_id, self.epoch, proof, record)
211 }
212
213 fn auth_type(&self) -> &str {
214 "mls"
215 }
216}
217
218#[derive(Debug, Clone)]
222pub struct ThresholdWriteAuth {
223 threshold: usize,
224 total: usize,
225 pub_keys: Vec<PubKey>,
226}
227
228impl ThresholdWriteAuth {
229 pub fn new(threshold: usize, total: usize, pub_keys: Vec<PubKey>) -> Result<Self> {
231 if threshold > total {
232 anyhow::bail!("Threshold cannot exceed total");
233 }
234 if threshold == 0 {
235 anyhow::bail!("Threshold must be at least 1");
236 }
237 if pub_keys.len() != total {
238 anyhow::bail!("Public keys count must equal total");
239 }
240
241 Ok(Self {
242 threshold,
243 total,
244 pub_keys,
245 })
246 }
247
248 pub fn from_pub_keys(threshold: usize, total: usize, pub_keys: Vec<PubKey>) -> Result<Self> {
250 Self::new(threshold, total, pub_keys)
251 }
252}
253
254#[async_trait]
255impl WriteAuth for ThresholdWriteAuth {
256 async fn verify(&self, _record: &[u8], sigs: &[Sig]) -> Result<bool> {
257 if sigs.len() < self.threshold {
259 return Ok(false);
260 }
261
262 if sigs.len() > self.total {
264 return Ok(false);
265 }
266
267 Ok(sigs.len() >= self.threshold && self.pub_keys.len() == self.total)
271 }
272
273 fn auth_type(&self) -> &str {
274 "threshold"
275 }
276}
277
278#[derive(Debug)]
280pub struct CompositeWriteAuth {
281 auths: Vec<Box<dyn WriteAuth>>,
282 require_all: bool,
283}
284
285impl CompositeWriteAuth {
286 pub fn all(auths: Vec<Box<dyn WriteAuth>>) -> Self {
288 Self {
289 auths,
290 require_all: true,
291 }
292 }
293
294 pub fn any(auths: Vec<Box<dyn WriteAuth>>) -> Self {
296 Self {
297 auths,
298 require_all: false,
299 }
300 }
301}
302
303#[async_trait]
304impl WriteAuth for CompositeWriteAuth {
305 async fn verify(&self, record: &[u8], sigs: &[Sig]) -> Result<bool> {
306 if self.require_all {
307 for auth in &self.auths {
309 if !auth.verify(record, sigs).await? {
310 return Ok(false);
311 }
312 }
313 Ok(true)
314 } else {
315 for auth in &self.auths {
317 if auth.verify(record, sigs).await? {
318 return Ok(true);
319 }
320 }
321 Ok(false)
322 }
323 }
324
325 fn auth_type(&self) -> &str {
326 if self.require_all {
327 "composite_all"
328 } else {
329 "composite_any"
330 }
331 }
332}
333
334#[cfg(test)]
335mod tests {
336 use super::*;
337
338 #[tokio::test]
339 async fn test_single_write_auth() {
340 let pub_key = PubKey::new(vec![0u8; 1952]);
342 let auth = SingleWriteAuth::new(pub_key);
343
344 let record = b"test record";
345 let sig = Sig::new(vec![0u8; 3309]);
347
348 let result = auth.verify(record, &[sig]).await;
351 assert!(result.is_err() || !result.unwrap());
352
353 assert_eq!(auth.auth_type(), "single");
354 }
355
356 #[tokio::test]
357 async fn test_delegated_write_auth() {
358 let key1 = PubKey::new(vec![0u8; 1952]);
360 let key2 = PubKey::new(vec![1u8; 1952]);
361 let mut auth = DelegatedWriteAuth::new(vec![key1.clone()]);
362 auth.add_key(key2);
363
364 let record = b"test record";
365 let sig = Sig::new(vec![0u8; 3309]);
367
368 let result = auth.verify(record, &[sig]).await;
371 assert!(result.is_err() || !result.unwrap());
372
373 assert_eq!(auth.auth_type(), "delegated");
374 }
375
376 #[tokio::test]
377 async fn test_threshold_auth() {
378 let keys = vec![
381 PubKey::new(vec![1; 32]),
382 PubKey::new(vec![2; 32]),
383 PubKey::new(vec![3; 32]),
384 ];
385
386 let auth = ThresholdWriteAuth::from_pub_keys(2, 3, keys).unwrap();
387
388 let sigs = vec![Sig::new(vec![1; 64]), Sig::new(vec![2; 64])];
390
391 let record = b"test";
392 let result = auth.verify(record, &sigs).await.unwrap();
394 assert!(result); assert_eq!(auth.threshold, 2);
397 assert_eq!(auth.total, 3);
398
399 let insufficient_sigs = vec![Sig::new(vec![1; 64])];
401 let result2 = auth.verify(record, &insufficient_sigs).await.unwrap();
402 assert!(!result2); }
404
405 #[tokio::test]
406 async fn test_composite_auth_all() {
407 let auth1 = Box::new(SingleWriteAuth::new(PubKey::new(vec![0u8; 1952])));
409 let auth2 = Box::new(SingleWriteAuth::new(PubKey::new(vec![1u8; 1952])));
410
411 let composite = CompositeWriteAuth::all(vec![auth1, auth2]);
412
413 let sig = Sig::new(vec![0u8; 3309]);
415 let result = composite.verify(b"test", &[sig]).await;
418 assert!(result.is_err() || !result.unwrap());
419
420 assert_eq!(composite.auth_type(), "composite_all");
421 }
422}