1use crate::error::SecurityError;
33use crate::quantum_crypto::ant_quic_integration::{MlDsaPublicKey, MlDsaSecretKey, MlDsaSignature};
34use crate::{NetworkAddress, P2PError, Result};
35use blake3::Hash;
36use serde::{Deserialize, Serialize};
37use std::collections::HashMap;
38use std::fmt;
39use std::time::{Duration, SystemTime, UNIX_EPOCH};
40use uuid::Uuid;
41
/// Size budget, in bytes, for a serialized DHT record (64 KiB).
///
/// NOTE(review): this limit is not enforced by the record/validation code in this module;
/// confirm where (if anywhere) it is checked.
pub const MAX_DHT_RECORD_SIZE: usize = 65_536;

/// Maximum number of endpoints a single peer record may advertise.
pub const MAX_ENDPOINTS_PER_PEER: usize = 16;

/// Largest accepted record TTL: one day, expressed in seconds.
pub const MAX_TTL_SECONDS: u32 = 86_400;

/// Suggested record TTL: five minutes, expressed in seconds.
pub const DEFAULT_TTL_SECONDS: u32 = 300;
53
/// 32-byte user identifier.
///
/// `UserId::from_public_key` derives it as the BLAKE3 hash of an ML-DSA public key's
/// bytes; `UserId::from_bytes` wraps an existing 32-byte value unchanged.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct UserId {
    /// Raw identifier bytes.
    pub hash: [u8; 32],
}
61
62impl UserId {
63 pub fn from_public_key(public_key: &MlDsaPublicKey) -> Self {
65 let hash = blake3::hash(public_key.as_bytes());
66 Self { hash: hash.into() }
67 }
68
69 pub fn from_bytes(bytes: [u8; 32]) -> Self {
71 Self { hash: bytes }
72 }
73
74 pub fn as_bytes(&self) -> &[u8; 32] {
76 &self.hash
77 }
78}
79
80impl fmt::Display for UserId {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 write!(f, "{}", hex::encode(&self.hash[..8]))
83 }
84}
85
/// Identifier that distinguishes one `PeerEndpoint` from another, wrapping a UUID.
///
/// `EndpointId::new` (and `Default`) generate a random version-4 UUID.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EndpointId {
    /// Underlying UUID value.
    pub uuid: Uuid,
}
92
93impl EndpointId {
94 pub fn new() -> Self {
96 Self {
97 uuid: Uuid::new_v4(),
98 }
99 }
100
101 pub fn from_uuid(uuid: Uuid) -> Self {
103 Self { uuid }
104 }
105}
106
107impl Default for EndpointId {
108 fn default() -> Self {
109 Self::new()
110 }
111}
112
113impl fmt::Display for EndpointId {
114 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115 write!(f, "{}", self.uuid)
116 }
117}
118
/// Identifier of a network node, as listed in `PeerEndpoint::coordinator_nodes`.
/// It is a free-form string; its format is not constrained in this module.
pub type NodeId = String;
121
/// NAT classification of an endpoint, using the classic cone/symmetric taxonomy.
///
/// `NatType::supports_hole_punching` treats every variant except `Symmetric` and
/// `Unknown` as hole-punchable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum NatType {
    /// No NAT between the endpoint and the public network.
    NoNat,
    /// Full-cone NAT.
    FullCone,
    /// Restricted-cone (address-restricted) NAT.
    RestrictedCone,
    /// Port-restricted cone NAT.
    PortRestricted,
    /// Symmetric NAT.
    Symmetric,
    /// NAT behaviour has not been determined.
    Unknown,
}
138
139impl NatType {
140 pub fn supports_hole_punching(&self) -> bool {
142 matches!(
143 self,
144 NatType::NoNat | NatType::FullCone | NatType::RestrictedCone | NatType::PortRestricted
145 )
146 }
147
148 pub fn hole_punching_difficulty(&self) -> u8 {
150 match self {
151 NatType::NoNat => 100,
152 NatType::FullCone => 90,
153 NatType::RestrictedCone => 70,
154 NatType::PortRestricted => 50,
155 NatType::Symmetric => 10,
156 NatType::Unknown => 0,
157 }
158 }
159}
160
161impl fmt::Display for NatType {
162 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163 match self {
164 NatType::NoNat => write!(f, "No NAT"),
165 NatType::FullCone => write!(f, "Full Cone"),
166 NatType::RestrictedCone => write!(f, "Restricted Cone"),
167 NatType::PortRestricted => write!(f, "Port Restricted"),
168 NatType::Symmetric => write!(f, "Symmetric"),
169 NatType::Unknown => write!(f, "Unknown"),
170 }
171 }
172}
173
/// One network endpoint at which a peer can be reached, together with NAT-traversal hints.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PeerEndpoint {
    /// Identifier of this endpoint.
    pub endpoint_id: EndpointId,

    /// Externally observed address of the endpoint (presumably its public, post-NAT
    /// address — confirm with the discovery code).
    pub external_address: NetworkAddress,

    /// NAT classification of this endpoint.
    pub nat_type: NatType,

    /// Nodes associated with this endpoint for connection coordination.
    /// NOTE(review): presumably relay/rendezvous nodes used for hole punching — confirm.
    pub coordinator_nodes: Vec<NodeId>,

    /// Optional free-form device description.
    pub device_info: Option<String>,

    /// Unix time in seconds when this entry was created or last refreshed.
    pub last_updated: u64,
}
195
196impl PeerEndpoint {
197 pub fn new(
199 endpoint_id: EndpointId,
200 external_address: NetworkAddress,
201 nat_type: NatType,
202 coordinator_nodes: Vec<NodeId>,
203 device_info: Option<String>,
204 ) -> Self {
205 Self {
206 endpoint_id,
207 external_address,
208 nat_type,
209 coordinator_nodes,
210 device_info,
211 last_updated: current_timestamp(),
212 }
213 }
214
215 pub fn is_stale(&self, max_age: Duration) -> bool {
217 let age = current_timestamp().saturating_sub(self.last_updated);
218 age > max_age.as_secs()
219 }
220
221 pub fn refresh(&mut self) {
223 self.last_updated = current_timestamp();
224 }
225}
226
/// A signed DHT entry advertising a user's reachable endpoints.
///
/// Every field except `signature` is fed into `create_signable_message`, so the signature
/// commits to all of them. Fields are public; changing one after `sign` will normally make
/// `verify_signature` fail (caveat: a `None` name and a `Some("")` name encode identically).
#[derive(Clone)]
pub struct PeerDHTRecord {
    /// Record format version; `new` sets it to `CURRENT_VERSION`.
    pub version: u8,

    /// Identity this record claims to describe.
    pub user_id: UserId,

    /// ML-DSA public key that `verify_signature` checks the signature against.
    pub public_key: MlDsaPublicKey,

    /// Caller-supplied sequence number. Presumably increases on each republish so newer
    /// records supersede older ones — confirm with the DHT storage logic.
    pub sequence_number: u64,

    /// Optional human-readable name; `new` requires 1..=255 bytes when present.
    pub name: Option<String>,

    /// Endpoints of this peer; `new` requires between 1 and `MAX_ENDPOINTS_PER_PEER`.
    pub endpoints: Vec<PeerEndpoint>,

    /// Lifetime in seconds, counted from `timestamp`; `new` requires 1..=`MAX_TTL_SECONDS`.
    pub ttl: u32,

    /// Creation time as Unix seconds, set by `new`.
    pub timestamp: u64,

    /// ML-DSA signature over `create_signable_message()`; an all-zero placeholder until `sign`.
    pub signature: MlDsaSignature,
}
257
impl PeerDHTRecord {
    /// Format version stamped into every record built by `new`; it is also the first byte of
    /// the signable message.
    pub const CURRENT_VERSION: u8 = 1;

    /// Builds an unsigned record after validating `name`, `endpoints` and `ttl`.
    ///
    /// `timestamp` is set to the current Unix time in whole seconds and `signature` is a
    /// placeholder of 3309 zero bytes, so the result does not verify until `sign` is called.
    ///
    /// NOTE(review): `user_id` is stored as given; nothing here checks that it equals
    /// `UserId::from_public_key(&public_key)`. Confirm whether that binding is enforced
    /// elsewhere.
    ///
    /// # Errors
    /// Returns `P2PError::Config(ConfigError::InvalidValue)` when `validate_inputs` rejects
    /// the name, endpoint list or TTL.
    pub fn new(
        user_id: UserId,
        public_key: MlDsaPublicKey,
        sequence_number: u64,
        name: Option<String>,
        endpoints: Vec<PeerEndpoint>,
        ttl: u32,
    ) -> Result<Self> {
        Self::validate_inputs(&name, &endpoints, ttl)?;

        Ok(Self {
            version: Self::CURRENT_VERSION,
            user_id,
            public_key,
            sequence_number,
            name,
            endpoints,
            timestamp: current_timestamp(),
            ttl,
            signature: {
                // All-zero placeholder until `sign` overwrites it. 3309 bytes presumably matches
                // the ML-DSA-65 signature length — confirm against the quantum_crypto module.
                let sig_bytes = [0u8; 3309];
                MlDsaSignature(Box::new(sig_bytes))
            },
        })
    }

    /// Checks the caller-supplied fields of a new record.
    ///
    /// Accepts:
    /// - `name`: `None`, or 1..=255 bytes. `len()` counts UTF-8 bytes even though the
    ///   error text says "chars".
    /// - `endpoints`: 1..=`MAX_ENDPOINTS_PER_PEER` entries.
    /// - `ttl`: 1..=`MAX_TTL_SECONDS`.
    ///
    /// # Errors
    /// Returns `P2PError::Config(ConfigError::InvalidValue)` naming the offending field.
    fn validate_inputs(name: &Option<String>, endpoints: &[PeerEndpoint], ttl: u32) -> Result<()> {
        // Name is optional, but when present it must be non-empty and at most 255 bytes.
        if let Some(name) = name {
            if name.len() > 255 {
                return Err(P2PError::Config(crate::error::ConfigError::InvalidValue {
                    field: "name".to_string().into(),
                    reason: format!("Name too long (max 255), got {} chars", name.len()).into(),
                }));
            }
            if name.is_empty() {
                return Err(P2PError::Config(crate::error::ConfigError::InvalidValue {
                    field: "name".to_string().into(),
                    reason: "Name cannot be empty".to_string().into(),
                }));
            }
        }

        // A record must advertise at least one endpoint and no more than the cap.
        if endpoints.is_empty() {
            return Err(P2PError::Config(crate::error::ConfigError::InvalidValue {
                field: "endpoints".to_string().into(),
                reason: "At least one endpoint required".to_string().into(),
            }));
        }
        if endpoints.len() > MAX_ENDPOINTS_PER_PEER {
            return Err(P2PError::Config(crate::error::ConfigError::InvalidValue {
                field: "endpoints".to_string().into(),
                reason: format!(
                    "Too many endpoints ({}, max {})",
                    endpoints.len(),
                    MAX_ENDPOINTS_PER_PEER
                )
                .into(),
            }));
        }

        // TTL must be a positive number of seconds no larger than MAX_TTL_SECONDS.
        if ttl == 0 {
            return Err(P2PError::Config(crate::error::ConfigError::InvalidValue {
                field: "ttl".to_string().into(),
                reason: "TTL cannot be zero".to_string().into(),
            }));
        }
        if ttl > MAX_TTL_SECONDS {
            return Err(P2PError::Config(crate::error::ConfigError::InvalidValue {
                field: "ttl".to_string().into(),
                reason: format!("TTL too large ({}, max {})", ttl, MAX_TTL_SECONDS).into(),
            }));
        }

        Ok(())
    }

    /// Serializes every field except `signature` into the exact byte string that is signed.
    ///
    /// Layout, in order (integers are big-endian):
    /// 1. `version` (1 byte)
    /// 2. `user_id.hash` (32 bytes)
    /// 3. `public_key.as_bytes()` (no length prefix)
    /// 4. `sequence_number` (u64)
    /// 5. name length (u32), then the name's UTF-8 bytes; `None` is written as length 0,
    ///    which is indistinguishable from `Some("")`
    /// 6. `bincode` encoding of `endpoints`, prefixed by its length (u32)
    /// 7. `timestamp` (u64)
    /// 8. `ttl` (u32)
    ///
    /// Changing this layout invalidates every signature produced with the old one.
    ///
    /// # Errors
    /// Returns `P2PError::Storage(StorageError::Database)` if `bincode` fails to serialize
    /// the endpoints.
    pub fn create_signable_message(&self) -> Result<Vec<u8>> {
        let mut message = Vec::new();

        message.push(self.version);

        message.extend_from_slice(&self.user_id.hash);

        message.extend_from_slice(self.public_key.as_bytes());

        message.extend_from_slice(&self.sequence_number.to_be_bytes());

        // Length-prefixed name; absent names are encoded as a zero length.
        if let Some(ref name) = self.name {
            let name_bytes = name.as_bytes();
            message.extend_from_slice(&(name_bytes.len() as u32).to_be_bytes());
            message.extend_from_slice(name_bytes);
        } else {
            message.extend_from_slice(&0u32.to_be_bytes());
        }

        // Endpoints go through bincode, so their byte form depends on bincode's encoding of
        // PeerEndpoint.
        let endpoints_data = bincode::serialize(&self.endpoints).map_err(|e| {
            P2PError::Storage(crate::error::StorageError::Database(
                format!("Failed to serialize endpoints: {}", e).into(),
            ))
        })?;
        message.extend_from_slice(&(endpoints_data.len() as u32).to_be_bytes());
        message.extend_from_slice(&endpoints_data);

        message.extend_from_slice(&self.timestamp.to_be_bytes());

        message.extend_from_slice(&self.ttl.to_be_bytes());

        Ok(message)
    }

    /// Signs the current record contents with `signing_key` and stores the signature.
    ///
    /// `timestamp` is not refreshed here; it keeps the value set by `new` (or by the caller).
    ///
    /// # Errors
    /// Propagates `create_signable_message` errors. A signing failure is reported as
    /// `SecurityError::SignatureVerificationFailed` (the variant name is reused for signing).
    pub fn sign(&mut self, signing_key: &MlDsaSecretKey) -> Result<()> {
        let message = self.create_signable_message()?;
        self.signature =
            crate::quantum_crypto::ml_dsa_sign(signing_key, &message).map_err(|e| {
                P2PError::Security(SecurityError::SignatureVerificationFailed(
                    format!("ML-DSA signing failed: {:?}", e).into(),
                ))
            })?;
        Ok(())
    }

    /// Verifies `signature` over the current contents using the embedded `public_key`.
    ///
    /// NOTE(review): this does not check that `user_id` was derived from `public_key`, so a
    /// record signed by any key verifies regardless of the `user_id` it claims. Confirm that
    /// callers enforce the binding.
    ///
    /// # Errors
    /// Returns `SecurityError::SignatureVerificationFailed` if the verifier errors or reports
    /// a mismatch. Also propagates `create_signable_message` errors.
    pub fn verify_signature(&self) -> Result<()> {
        let message = self.create_signable_message()?;
        let ok = crate::quantum_crypto::ml_dsa_verify(&self.public_key, &message, &self.signature)
            .map_err(|e| {
                P2PError::Security(SecurityError::SignatureVerificationFailed(
                    format!("ML-DSA verify error: {:?}", e).into(),
                ))
            })?;
        if ok {
            Ok(())
        } else {
            Err(P2PError::Security(
                SecurityError::SignatureVerificationFailed(
                    "Failed to verify signature".to_string().into(),
                ),
            ))
        }
    }

    /// Returns `true` once strictly more than `ttl` seconds have elapsed since `timestamp`.
    ///
    /// A `timestamp` in the future counts as an age of zero. At exactly `ttl` seconds this is
    /// still `false`, while `remaining_ttl` already returns 0.
    pub fn is_expired(&self) -> bool {
        let age = current_timestamp().saturating_sub(self.timestamp);
        age > self.ttl as u64
    }

    /// Seconds of lifetime left, or 0 once at least `ttl` seconds have elapsed.
    pub fn remaining_ttl(&self) -> u32 {
        let age = current_timestamp().saturating_sub(self.timestamp);
        if age >= self.ttl as u64 {
            0
        } else {
            // age < ttl <= u32::MAX here, so the cast cannot truncate.
            self.ttl - age as u32
        }
    }

    /// BLAKE3 hash of `user_id`, `sequence_number` and `timestamp` (integers big-endian).
    ///
    /// It does NOT cover `public_key`, `name`, `endpoints`, `ttl` or `signature`. Two records
    /// that differ only in those fields hash identically, so this value must not be used as
    /// an integrity or signature-cache key.
    pub fn content_hash(&self) -> Hash {
        let mut hasher = blake3::Hasher::new();
        hasher.update(&self.user_id.hash);
        hasher.update(&self.sequence_number.to_be_bytes());
        hasher.update(&self.timestamp.to_be_bytes());
        hasher.finalize()
    }
}
445
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this returns `0` instead of failing.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
453
454pub struct SignatureCache {
456 cache: HashMap<Hash, bool>,
457 max_size: usize,
458}
459
460impl SignatureCache {
461 pub fn new(max_size: usize) -> Self {
463 Self {
464 cache: HashMap::new(),
465 max_size,
466 }
467 }
468
469 pub fn verify_cached(&mut self, record: &PeerDHTRecord) -> Result<()> {
471 let hash = record.content_hash();
472
473 if let Some(&result) = self.cache.get(&hash) {
475 return if result {
476 Ok(())
477 } else {
478 Err(P2PError::Security(
479 SecurityError::SignatureVerificationFailed(
480 "Invalid signature in cache".to_string().into(),
481 ),
482 ))
483 };
484 }
485
486 let result = record.verify_signature();
488 let success = result.is_ok();
489
490 if self.cache.len() >= self.max_size {
492 if let Some(key) = self.cache.keys().next().cloned() {
494 self.cache.remove(&key);
495 }
496 }
497 self.cache.insert(hash, success);
498
499 result
500 }
501
502 pub fn clear(&mut self) {
504 self.cache.clear();
505 }
506}
507
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates a fresh ML-DSA keypair, returned as `(secret, public)`.
    fn create_test_keypair() -> (MlDsaSecretKey, MlDsaPublicKey) {
        let (public_key, secret_key) = crate::quantum_crypto::generate_ml_dsa_keypair().unwrap();
        (secret_key, public_key)
    }

    /// Builds a full-cone endpoint at 192.168.1.1:8080 with a single coordinator.
    fn create_test_endpoint() -> PeerEndpoint {
        PeerEndpoint::new(
            EndpointId::new(),
            "192.168.1.1:8080".parse::<NetworkAddress>().unwrap(),
            NatType::FullCone,
            vec!["coordinator1".to_string()],
            Some("test-device".to_string()),
        )
    }

    /// Builds an unsigned record for a fresh keypair and returns it with the secret key.
    fn create_test_record() -> (MlDsaSecretKey, PeerDHTRecord) {
        let (secret_key, public_key) = create_test_keypair();
        let user_id = UserId::from_public_key(&public_key);
        let record = PeerDHTRecord::new(
            user_id,
            public_key,
            1,
            Some("test-user".to_string()),
            vec![create_test_endpoint()],
            DEFAULT_TTL_SECONDS,
        )
        .unwrap();
        (secret_key, record)
    }

    #[test]
    fn test_user_id_generation() {
        let (_, public_key) = create_test_keypair();
        let user_id = UserId::from_public_key(&public_key);

        // Derivation is deterministic for the same key.
        let user_id2 = UserId::from_public_key(&public_key);
        assert_eq!(user_id, user_id2);
    }

    #[test]
    fn test_user_id_display_is_truncated_hex() {
        let user_id = UserId::from_bytes([0xab; 32]);
        assert_eq!(user_id.to_string(), "abababababababab");
        assert_eq!(user_id.as_bytes(), &[0xab; 32]);
    }

    #[test]
    fn test_nat_type_hole_punching() {
        assert!(NatType::NoNat.supports_hole_punching());
        assert!(NatType::FullCone.supports_hole_punching());
        assert!(NatType::RestrictedCone.supports_hole_punching());
        assert!(NatType::PortRestricted.supports_hole_punching());
        assert!(!NatType::Symmetric.supports_hole_punching());
        assert!(!NatType::Unknown.supports_hole_punching());
    }

    #[test]
    fn test_nat_type_display_and_score_order() {
        assert_eq!(NatType::NoNat.to_string(), "No NAT");
        assert_eq!(NatType::PortRestricted.to_string(), "Port Restricted");
        // Higher score means easier traversal.
        assert!(
            NatType::NoNat.hole_punching_difficulty()
                > NatType::Symmetric.hole_punching_difficulty()
        );
        assert_eq!(NatType::Unknown.hole_punching_difficulty(), 0);
    }

    #[test]
    fn test_peer_endpoint_creation() {
        let endpoint = create_test_endpoint();
        assert!(!endpoint.is_stale(Duration::from_secs(60)));

        let mut old_endpoint = endpoint.clone();
        old_endpoint.last_updated = current_timestamp() - 120;
        assert!(old_endpoint.is_stale(Duration::from_secs(60)));

        // Refreshing resets staleness.
        old_endpoint.refresh();
        assert!(!old_endpoint.is_stale(Duration::from_secs(60)));
    }

    #[test]
    fn test_dht_record_creation_and_signing() {
        let (secret_key, mut record) = create_test_record();

        record.sign(&secret_key).unwrap();

        assert!(record.verify_signature().is_ok());

        assert!(!record.is_expired());
        assert!(record.remaining_ttl() > 0);
    }

    #[test]
    fn test_unsigned_record_fails_verification() {
        // `new` installs an all-zero placeholder signature.
        let (_, record) = create_test_record();
        assert!(record.verify_signature().is_err());

        // A negative verdict must also be reported on a repeated (cached) lookup.
        let mut cache = SignatureCache::new(100);
        assert!(cache.verify_cached(&record).is_err());
        assert!(cache.verify_cached(&record).is_err());
    }

    #[test]
    fn test_tampered_record_fails_verification() {
        let (secret_key, mut record) = create_test_record();
        record.sign(&secret_key).unwrap();

        let mut bumped_sequence = record.clone();
        bumped_sequence.sequence_number += 1;
        assert!(bumped_sequence.verify_signature().is_err());

        let mut extra_endpoint = record.clone();
        extra_endpoint.endpoints.push(create_test_endpoint());
        assert!(extra_endpoint.verify_signature().is_err());

        let mut longer_ttl = record.clone();
        longer_ttl.ttl = MAX_TTL_SECONDS;
        assert!(longer_ttl.verify_signature().is_err());

        // The untouched original still verifies.
        assert!(record.verify_signature().is_ok());
    }

    #[test]
    fn test_expired_record() {
        let (_, mut record) = create_test_record();
        record.timestamp = current_timestamp() - u64::from(record.ttl) - 10;
        assert!(record.is_expired());
        assert_eq!(record.remaining_ttl(), 0);
    }

    #[test]
    fn test_signature_cache() {
        let (secret_key, mut record) = create_test_record();

        record.sign(&secret_key).unwrap();

        let mut cache = SignatureCache::new(100);

        // First call verifies and populates the cache.
        assert!(cache.verify_cached(&record).is_ok());

        // Second call is served from the cache.
        assert!(cache.verify_cached(&record).is_ok());

        // Clearing forces re-verification, which must still succeed.
        cache.clear();
        assert!(cache.verify_cached(&record).is_ok());
    }

    #[test]
    fn test_validation_limits() {
        let (_, public_key) = create_test_keypair();
        let user_id = UserId::from_public_key(&public_key);

        // Name longer than 255 bytes.
        let long_name = "a".repeat(256);
        let result = PeerDHTRecord::new(
            user_id.clone(),
            public_key.clone(),
            1,
            Some(long_name),
            vec![create_test_endpoint()],
            DEFAULT_TTL_SECONDS,
        );
        assert!(result.is_err());

        // More than MAX_ENDPOINTS_PER_PEER endpoints.
        let many_endpoints = vec![create_test_endpoint(); MAX_ENDPOINTS_PER_PEER + 1];
        let result = PeerDHTRecord::new(
            user_id.clone(),
            public_key.clone(),
            1,
            Some("test".to_string()),
            many_endpoints,
            DEFAULT_TTL_SECONDS,
        );
        assert!(result.is_err());

        // TTL above MAX_TTL_SECONDS.
        let result = PeerDHTRecord::new(
            user_id,
            public_key,
            1,
            Some("test".to_string()),
            vec![create_test_endpoint()],
            MAX_TTL_SECONDS + 1,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_validation_rejects_empty_and_zero_inputs() {
        let (_, public_key) = create_test_keypair();
        let user_id = UserId::from_public_key(&public_key);

        // Empty name.
        let result = PeerDHTRecord::new(
            user_id.clone(),
            public_key.clone(),
            1,
            Some(String::new()),
            vec![create_test_endpoint()],
            DEFAULT_TTL_SECONDS,
        );
        assert!(result.is_err());

        // No endpoints at all.
        let result = PeerDHTRecord::new(
            user_id.clone(),
            public_key.clone(),
            1,
            Some("test".to_string()),
            Vec::new(),
            DEFAULT_TTL_SECONDS,
        );
        assert!(result.is_err());

        // Zero TTL.
        let result = PeerDHTRecord::new(
            user_id,
            public_key,
            1,
            Some("test".to_string()),
            vec![create_test_endpoint()],
            0,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_validation_accepts_boundaries() {
        let (_, public_key) = create_test_keypair();
        let user_id = UserId::from_public_key(&public_key);

        // Exactly at every upper limit.
        let result = PeerDHTRecord::new(
            user_id.clone(),
            public_key.clone(),
            1,
            Some("a".repeat(255)),
            vec![create_test_endpoint(); MAX_ENDPOINTS_PER_PEER],
            MAX_TTL_SECONDS,
        );
        assert!(result.is_ok());

        // No name, single endpoint, minimal TTL.
        let result = PeerDHTRecord::new(
            user_id,
            public_key,
            1,
            None,
            vec![create_test_endpoint()],
            1,
        );
        assert!(result.is_ok());
    }
}