1use crate::quantum_crypto::{MlDsa65, MlDsaOperations, MlDsaPublicKey, MlDsaSignature};
20use anyhow::Result;
21use async_trait::async_trait;
22use once_cell::sync::OnceCell;
23use serde::{Deserialize, Serialize};
24use std::fmt::Debug;
25use std::sync::Arc;
26
27#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
29pub struct Sig(Vec<u8>);
30
31#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
33pub struct PubKey(Vec<u8>);
34
35impl Sig {
36 pub fn new(bytes: Vec<u8>) -> Self {
38 Self(bytes)
39 }
40
41 pub fn as_bytes(&self) -> &[u8] {
43 &self.0
44 }
45}
46
47impl PubKey {
48 pub fn new(bytes: Vec<u8>) -> Self {
50 Self(bytes)
51 }
52
53 pub fn as_bytes(&self) -> &[u8] {
55 &self.0
56 }
57}
58
59#[async_trait]
64pub trait WriteAuth: Send + Sync + Debug {
65 async fn verify(&self, record: &[u8], sigs: &[Sig]) -> Result<bool>;
68
69 fn auth_type(&self) -> &str;
71}
72
73pub trait MlsProofVerifier: Send + Sync {
76 fn verify(&self, group_id: &[u8], epoch: u64, proof: &[u8], record: &[u8]) -> Result<bool>;
78}
79
80static MLS_VERIFIER: OnceCell<Arc<dyn MlsProofVerifier>> = OnceCell::new();
81
82pub fn set_mls_verifier(verifier: Arc<dyn MlsProofVerifier>) -> bool {
84 MLS_VERIFIER.set(verifier).is_ok()
85}
86
87#[derive(Debug, Clone)]
89pub struct SingleWriteAuth {
90 pub_key: PubKey,
91}
92
93impl SingleWriteAuth {
94 pub fn new(pub_key: PubKey) -> Self {
96 Self { pub_key }
97 }
98}
99
100#[async_trait]
101impl WriteAuth for SingleWriteAuth {
102 async fn verify(&self, record: &[u8], sigs: &[Sig]) -> Result<bool> {
103 let Some(first_sig) = sigs.first() else {
104 return Ok(false);
105 };
106
107 let pk = MlDsaPublicKey::from_bytes(self.pub_key.as_bytes())
108 .map_err(|e| anyhow::anyhow!("invalid ML-DSA public key: {e}"))?;
109 const SIG_LEN: usize = 3309;
110 let sig_bytes = first_sig.as_bytes();
111 if sig_bytes.len() != SIG_LEN {
112 return Ok(false);
113 }
114 let mut arr = [0u8; SIG_LEN];
115 arr.copy_from_slice(sig_bytes);
116 let sig = MlDsaSignature(Box::new(arr));
117 let ml = MlDsa65::new();
118 let ok = ml
119 .verify(&pk, record, &sig)
120 .map_err(|e| anyhow::anyhow!("ML-DSA verify failed: {e}"))?;
121 Ok(ok)
122 }
123
124 fn auth_type(&self) -> &str {
125 "single"
126 }
127}
128
129#[derive(Debug, Clone)]
131pub struct DelegatedWriteAuth {
132 authorized_keys: Vec<PubKey>,
133}
134
135impl DelegatedWriteAuth {
136 pub fn new(authorized_keys: Vec<PubKey>) -> Self {
138 Self { authorized_keys }
139 }
140
141 pub fn add_key(&mut self, key: PubKey) {
143 if !self.authorized_keys.contains(&key) {
144 self.authorized_keys.push(key);
145 }
146 }
147}
148
149#[async_trait]
150impl WriteAuth for DelegatedWriteAuth {
151 async fn verify(&self, record: &[u8], sigs: &[Sig]) -> Result<bool> {
152 let Some(first_sig) = sigs.first() else {
153 return Ok(false);
154 };
155 if self.authorized_keys.is_empty() {
156 return Ok(false);
157 }
158 const SIG_LEN: usize = 3309;
159 let sig_bytes = first_sig.as_bytes();
160 if sig_bytes.len() != SIG_LEN {
161 return Ok(false);
162 }
163 let mut arr = [0u8; SIG_LEN];
164 arr.copy_from_slice(sig_bytes);
165 let sig = MlDsaSignature(Box::new(arr));
166 let ml = MlDsa65::new();
167 for ak in &self.authorized_keys {
168 if let Ok(pk) = MlDsaPublicKey::from_bytes(ak.as_bytes())
169 && let Ok(valid) = ml.verify(&pk, record, &sig)
170 && valid
171 {
172 return Ok(true);
173 }
174 }
175 Ok(false)
176 }
177
178 fn auth_type(&self) -> &str {
179 "delegated"
180 }
181}
182
/// Write policy that checks MLS proofs for a fixed group and epoch.
///
/// The checking itself is done by the globally installed
/// [`MlsProofVerifier`].
#[derive(Debug, Clone)]
pub struct MlsWriteAuth {
    group_id: Vec<u8>,
    epoch: u64,
}

impl MlsWriteAuth {
    /// Creates a policy for MLS group `group_id` at `epoch`.
    pub fn new(group_id: Vec<u8>, epoch: u64) -> Self {
        MlsWriteAuth { group_id, epoch }
    }
}
197
198#[async_trait]
199impl WriteAuth for MlsWriteAuth {
200 async fn verify(&self, record: &[u8], sigs: &[Sig]) -> Result<bool> {
201 let verifier = match MLS_VERIFIER.get() {
203 Some(v) => v.clone(),
204 None => return Ok(false),
205 };
206
207 let proof = match sigs.first() {
209 Some(s) => s.as_bytes(),
210 None => return Ok(false),
211 };
212
213 verifier.verify(&self.group_id, self.epoch, proof, record)
214 }
215
216 fn auth_type(&self) -> &str {
217 "mls"
218 }
219}
220
221#[derive(Debug, Clone)]
225pub struct ThresholdWriteAuth {
226 threshold: usize,
227 total: usize,
228 pub_keys: Vec<PubKey>,
229}
230
231impl ThresholdWriteAuth {
232 pub fn new(threshold: usize, total: usize, pub_keys: Vec<PubKey>) -> Result<Self> {
234 if threshold > total {
235 anyhow::bail!("Threshold cannot exceed total");
236 }
237 if threshold == 0 {
238 anyhow::bail!("Threshold must be at least 1");
239 }
240 if pub_keys.len() != total {
241 anyhow::bail!("Public keys count must equal total");
242 }
243
244 Ok(Self {
245 threshold,
246 total,
247 pub_keys,
248 })
249 }
250
251 pub fn from_pub_keys(threshold: usize, total: usize, pub_keys: Vec<PubKey>) -> Result<Self> {
253 Self::new(threshold, total, pub_keys)
254 }
255}
256
257#[async_trait]
258impl WriteAuth for ThresholdWriteAuth {
259 async fn verify(&self, _record: &[u8], sigs: &[Sig]) -> Result<bool> {
260 if sigs.len() < self.threshold {
262 return Ok(false);
263 }
264
265 if sigs.len() > self.total {
267 return Ok(false);
268 }
269
270 Ok(sigs.len() >= self.threshold && self.pub_keys.len() == self.total)
274 }
275
276 fn auth_type(&self) -> &str {
277 "threshold"
278 }
279}
280
281#[derive(Debug)]
283pub struct CompositeWriteAuth {
284 auths: Vec<Box<dyn WriteAuth>>,
285 require_all: bool,
286}
287
288impl CompositeWriteAuth {
289 pub fn all(auths: Vec<Box<dyn WriteAuth>>) -> Self {
291 Self {
292 auths,
293 require_all: true,
294 }
295 }
296
297 pub fn any(auths: Vec<Box<dyn WriteAuth>>) -> Self {
299 Self {
300 auths,
301 require_all: false,
302 }
303 }
304}
305
306#[async_trait]
307impl WriteAuth for CompositeWriteAuth {
308 async fn verify(&self, record: &[u8], sigs: &[Sig]) -> Result<bool> {
309 if self.require_all {
310 for auth in &self.auths {
312 if !auth.verify(record, sigs).await? {
313 return Ok(false);
314 }
315 }
316 Ok(true)
317 } else {
318 for auth in &self.auths {
320 if auth.verify(record, sigs).await? {
321 return Ok(true);
322 }
323 }
324 Ok(false)
325 }
326 }
327
328 fn auth_type(&self) -> &str {
329 if self.require_all {
330 "composite_all"
331 } else {
332 "composite_any"
333 }
334 }
335}
336
337#[cfg(test)]
338mod tests {
339 use super::*;
340
341 #[tokio::test]
342 async fn test_single_write_auth() {
343 let pub_key = PubKey::new(vec![0u8; 1952]);
345 let auth = SingleWriteAuth::new(pub_key);
346
347 let record = b"test record";
348 let sig = Sig::new(vec![0u8; 3309]);
350
351 let result = auth.verify(record, &[sig]).await;
354 assert!(result.is_err() || !result.unwrap());
355
356 assert_eq!(auth.auth_type(), "single");
357 }
358
359 #[tokio::test]
360 async fn test_delegated_write_auth() {
361 let key1 = PubKey::new(vec![0u8; 1952]);
363 let key2 = PubKey::new(vec![1u8; 1952]);
364 let mut auth = DelegatedWriteAuth::new(vec![key1.clone()]);
365 auth.add_key(key2);
366
367 let record = b"test record";
368 let sig = Sig::new(vec![0u8; 3309]);
370
371 let result = auth.verify(record, &[sig]).await;
374 assert!(result.is_err() || !result.unwrap());
375
376 assert_eq!(auth.auth_type(), "delegated");
377 }
378
379 #[tokio::test]
380 async fn test_threshold_auth() {
381 let keys = vec![
384 PubKey::new(vec![1; 32]),
385 PubKey::new(vec![2; 32]),
386 PubKey::new(vec![3; 32]),
387 ];
388
389 let auth = ThresholdWriteAuth::from_pub_keys(2, 3, keys).unwrap();
390
391 let sigs = vec![Sig::new(vec![1; 64]), Sig::new(vec![2; 64])];
393
394 let record = b"test";
395 let result = auth.verify(record, &sigs).await.unwrap();
397 assert!(result); assert_eq!(auth.threshold, 2);
400 assert_eq!(auth.total, 3);
401
402 let insufficient_sigs = vec![Sig::new(vec![1; 64])];
404 let result2 = auth.verify(record, &insufficient_sigs).await.unwrap();
405 assert!(!result2); }
407
408 #[tokio::test]
409 async fn test_composite_auth_all() {
410 let auth1 = Box::new(SingleWriteAuth::new(PubKey::new(vec![0u8; 1952])));
412 let auth2 = Box::new(SingleWriteAuth::new(PubKey::new(vec![1u8; 1952])));
413
414 let composite = CompositeWriteAuth::all(vec![auth1, auth2]);
415
416 let sig = Sig::new(vec![0u8; 3309]);
418 let result = composite.verify(b"test", &[sig]).await;
421 assert!(result.is_err() || !result.unwrap());
422
423 assert_eq!(composite.auth_type(), "composite_all");
424 }
425}