1#![allow(missing_docs)]
32
use crate::error::StorageError;
use crate::peer_record::UserId;
use crate::{P2PError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tokio::sync::Mutex;
use tokio::time::interval;
43
/// Maximum number of recently seen entries retained per peer for
/// exact-duplicate (replay) detection.
const MAX_SEQUENCE_HISTORY: usize = 1000;

/// Default interval at which the background task flushes counters to disk.
const DEFAULT_SYNC_INTERVAL: Duration = Duration::from_secs(30);

/// Entries with timestamps older than this are rejected as `TooOld` and are
/// eligible for history cleanup.
const MAX_SEQUENCE_AGE: Duration = Duration::from_secs(3600);

/// Replay-protection system that tracks a monotonically increasing sequence
/// number per peer, with periodic persistence of all counters to a single
/// postcard-encoded file.
pub struct MonotonicCounterSystem {
    /// Per-peer counter state, keyed by user id. Guarded by a std `RwLock`;
    /// guards are never held across `.await` points.
    counters: Arc<RwLock<HashMap<UserId, PeerCounter>>>,
    /// File the counters are persisted to.
    storage_path: PathBuf,
    /// How often the background sync task flushes counters to disk.
    sync_interval: Duration,
    /// Handle to the background sync task, if one has been started.
    sync_task: Option<tokio::task::JoinHandle<()>>,
    /// Aggregate validation/persistence statistics (async mutex: updated
    /// from async contexts).
    stats: Arc<Mutex<CounterStats>>,
}
66
/// Per-peer replay-protection state: the highest accepted sequence number
/// plus a bounded history of recently seen (sequence, hash) pairs.
///
/// Serialized with postcard for persistence — field order is part of the
/// on-disk format; do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PeerCounter {
    /// Most recently applied sequence number.
    pub current_sequence: u64,
    /// Highest sequence number accepted as valid (the high-water mark).
    pub last_valid_sequence: u64,
    /// Recently accepted entries, bounded by `MAX_SEQUENCE_HISTORY`, used
    /// to detect exact-duplicate replays.
    pub sequence_history: Vec<SequenceEntry>,
    /// Unix timestamp (seconds) of the last accepted update.
    pub last_updated: u64,
    /// Diagnostic counter for replay attempts observed from this peer.
    pub replay_attempts: u64,
    /// Diagnostic counter for sequence gaps observed from this peer.
    pub sequence_gaps: u64,
}
83
/// One accepted message record: its sequence number, receipt timestamp, and
/// 32-byte message hash (the hash enables exact-duplicate detection even for
/// a re-sent sequence number).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SequenceEntry {
    /// Sequence number the message carried.
    pub sequence: u64,
    /// Unix timestamp (seconds) when the entry was recorded.
    pub timestamp: u64,
    /// BLAKE3-sized (32-byte) hash of the message payload.
    pub message_hash: [u8; 32],
}
94
/// Aggregate statistics across all peers tracked by the counter system.
#[derive(Debug, Clone, Default)]
pub struct CounterStats {
    /// Total number of sequence validations performed.
    pub total_processed: u64,
    /// Total validations that resulted in `Replay`.
    pub total_replays: u64,
    /// Total validations that resulted in `Gap`.
    pub total_gaps: u64,
    /// Number of peers currently in the counter map (recomputed on read).
    pub peers_tracked: usize,
    /// Number of successful persistence (disk flush) operations.
    pub persistence_ops: u64,
    /// Running average validation latency in microseconds.
    pub avg_validation_time_us: u64,
    /// Reserved cache counters — not currently updated by this module.
    pub cache_hits: u64,
    /// Reserved cache counters — not currently updated by this module.
    pub cache_misses: u64,
}
115
/// Outcome of validating one incoming (sequence, message-hash) pair.
#[derive(Debug, Clone, PartialEq)]
pub enum SequenceValidationResult {
    /// Strictly the next expected sequence — accepted.
    Valid,
    /// Exact duplicate of a seen message, or sequence at/below the peer's
    /// high-water mark.
    Replay,
    /// Message timestamp is older than `MAX_SEQUENCE_AGE`.
    TooOld,
    /// Sequence jumped past the expected next value; carries both the
    /// expected and the received number.
    Gap { expected: u64, received: u64 },
    /// Message timestamp is more than 60 seconds ahead of the local clock.
    FromFuture,
}
130
/// One element of a `batch_update` call: a claimed sequence update from a
/// peer, including the message's own timestamp.
pub struct BatchUpdateRequest {
    /// Peer the update belongs to.
    pub user_id: UserId,
    /// Claimed sequence number.
    pub sequence: u64,
    /// 32-byte hash of the message payload.
    pub message_hash: [u8; 32],
    /// Claimed message timestamp (Unix seconds); checked against the local
    /// clock during validation.
    pub timestamp: u64,
}
142
/// Per-request outcome of a `batch_update` call.
pub struct BatchUpdateResult {
    /// Peer the result belongs to.
    pub user_id: UserId,
    /// Validation outcome for this request.
    pub result: SequenceValidationResult,
    /// True iff the update was `Valid` and was applied to the counter.
    pub applied: bool,
}
152
153impl MonotonicCounterSystem {
154 pub async fn new(storage_path: PathBuf) -> Result<Self> {
156 Self::new_with_sync_interval(storage_path, DEFAULT_SYNC_INTERVAL).await
157 }
158
159 pub async fn new_with_sync_interval(
161 storage_path: PathBuf,
162 sync_interval: Duration,
163 ) -> Result<Self> {
164 if let Some(parent) = storage_path.parent() {
166 tokio::fs::create_dir_all(parent).await.map_err(|e| {
167 P2PError::Storage(StorageError::Database(
168 format!("Failed to create storage directory: {e}").into(),
169 ))
170 })?;
171 }
172
173 let counters = Self::load_counters(&storage_path).await?;
175
176 let system = Self {
177 counters: Arc::new(RwLock::new(counters)),
178 storage_path,
179 sync_interval,
180 sync_task: None,
181 stats: Arc::new(Mutex::new(CounterStats::default())),
182 };
183
184 Ok(system)
185 }
186
187 pub async fn start_sync_task(&mut self) -> Result<()> {
189 if self.sync_task.is_some() {
190 return Ok(()); }
192
193 let counters = self.counters.clone();
194 let storage_path = self.storage_path.clone();
195 let sync_interval = self.sync_interval;
196 let stats = self.stats.clone();
197
198 let task = tokio::spawn(async move {
199 let mut interval = interval(sync_interval);
200 loop {
201 interval.tick().await;
202
203 if let Err(e) = Self::sync_counters(&counters, &storage_path, &stats).await {
204 tracing::warn!("Failed to sync counters to storage: {}", e);
205 }
206 }
207 });
208
209 self.sync_task = Some(task);
210 Ok(())
211 }
212
213 pub async fn stop_sync_task(&mut self) {
215 if let Some(task) = self.sync_task.take() {
216 task.abort();
217 }
218 }
219
220 pub async fn validate_sequence(
222 &self,
223 user_id: &UserId,
224 sequence: u64,
225 message_hash: [u8; 32],
226 ) -> Result<SequenceValidationResult> {
227 let start_time = Instant::now();
228 let timestamp = current_timestamp();
229
230 let validation_result = {
232 let mut counters = self.counters.write().map_err(|_| {
233 P2PError::Storage(StorageError::LockPoisoned(
234 "write lock failed".to_string().into(),
235 ))
236 })?;
237 let peer_counter = counters
238 .entry(user_id.clone())
239 .or_insert_with(PeerCounter::new);
240
241 let result =
243 self.validate_sequence_internal(peer_counter, sequence, message_hash, timestamp);
244
245 if let SequenceValidationResult::Valid = result {
247 peer_counter.apply_sequence_update(sequence, message_hash, timestamp);
249 }
250
251 result
252 };
253
254 self.update_validation_stats(start_time, &validation_result)
256 .await;
257
258 Ok(validation_result)
259 }
260
261 fn validate_sequence_internal(
263 &self,
264 peer_counter: &PeerCounter,
265 sequence: u64,
266 message_hash: [u8; 32],
267 timestamp: u64,
268 ) -> SequenceValidationResult {
269 let current_time = current_timestamp();
271 if timestamp > current_time + 60 {
272 return SequenceValidationResult::FromFuture;
273 }
274
275 if timestamp < current_time.saturating_sub(MAX_SEQUENCE_AGE.as_secs()) {
277 return SequenceValidationResult::TooOld;
278 }
279
280 if peer_counter.has_seen_sequence(sequence, message_hash) {
282 return SequenceValidationResult::Replay;
283 }
284
285 if sequence > peer_counter.last_valid_sequence + 1 {
287 return SequenceValidationResult::Gap {
288 expected: peer_counter.last_valid_sequence + 1,
289 received: sequence,
290 };
291 }
292
293 if sequence <= peer_counter.last_valid_sequence {
295 return SequenceValidationResult::Replay;
296 }
297
298 SequenceValidationResult::Valid
299 }
300
301 pub async fn batch_update(
303 &self,
304 requests: Vec<BatchUpdateRequest>,
305 ) -> Result<Vec<BatchUpdateResult>> {
306 let mut results = Vec::with_capacity(requests.len());
307
308 {
310 let mut counters = self.counters.write().map_err(|_| {
311 P2PError::Storage(StorageError::LockPoisoned(
312 "write lock failed".to_string().into(),
313 ))
314 })?;
315
316 for request in requests {
317 let peer_counter = counters
318 .entry(request.user_id.clone())
319 .or_insert_with(PeerCounter::new);
320
321 let validation_result = self.validate_sequence_internal(
322 peer_counter,
323 request.sequence,
324 request.message_hash,
325 request.timestamp,
326 );
327
328 let applied = matches!(validation_result, SequenceValidationResult::Valid);
329
330 if applied {
331 peer_counter.apply_sequence_update(
332 request.sequence,
333 request.message_hash,
334 request.timestamp,
335 );
336 }
337
338 results.push(BatchUpdateResult {
339 user_id: request.user_id,
340 result: validation_result,
341 applied,
342 });
343 }
344 }
345
346 {
348 let mut stats = self.stats.lock().await;
349 stats.total_processed += results.len() as u64;
350 stats.total_replays += results
351 .iter()
352 .filter(|r| matches!(r.result, SequenceValidationResult::Replay))
353 .count() as u64;
354 stats.total_gaps += results
355 .iter()
356 .filter(|r| matches!(r.result, SequenceValidationResult::Gap { .. }))
357 .count() as u64;
358 }
359
360 Ok(results)
361 }
362
363 pub async fn get_stats(&self) -> CounterStats {
365 let stats = self.stats.lock().await;
366 let mut current_stats = stats.clone();
367
368 let counters = self.counters.read().unwrap_or_else(|e| e.into_inner());
370 current_stats.peers_tracked = counters.len();
371
372 current_stats
373 }
374
375 pub async fn get_peer_counter(&self, user_id: &UserId) -> Option<PeerCounter> {
377 let counters = self.counters.read().ok()?;
378 counters.get(user_id).cloned()
379 }
380
381 pub async fn reset_peer_counter(&self, user_id: &UserId) -> Result<()> {
383 let mut counters = self.counters.write().map_err(|_| {
384 P2PError::Storage(StorageError::LockPoisoned(
385 "write lock failed".to_string().into(),
386 ))
387 })?;
388 counters.remove(user_id);
389 Ok(())
390 }
391
392 pub async fn cleanup_old_sequences(&self) -> Result<()> {
394 let current_time = current_timestamp();
395 let cutoff_time = current_time.saturating_sub(MAX_SEQUENCE_AGE.as_secs());
396
397 let mut counters = self.counters.write().map_err(|_| {
398 P2PError::Storage(StorageError::LockPoisoned(
399 "write lock failed".to_string().into(),
400 ))
401 })?;
402 for (_, peer_counter) in counters.iter_mut() {
403 peer_counter.cleanup_old_sequences(cutoff_time);
404 }
405
406 Ok(())
407 }
408
409 async fn load_counters(storage_path: &PathBuf) -> Result<HashMap<UserId, PeerCounter>> {
411 if !storage_path.exists() {
412 return Ok(HashMap::new());
413 }
414
415 let data = tokio::fs::read(storage_path).await.map_err(|e| {
416 P2PError::Storage(StorageError::Database(
417 format!("Failed to read counters file: {e}").into(),
418 ))
419 })?;
420
421 let counters: HashMap<UserId, PeerCounter> = postcard::from_bytes(&data).map_err(|e| {
422 P2PError::Storage(StorageError::Database(
423 format!("Failed to deserialize counters: {e}").into(),
424 ))
425 })?;
426
427 Ok(counters)
428 }
429
430 async fn sync_counters(
432 counters: &Arc<RwLock<HashMap<UserId, PeerCounter>>>,
433 storage_path: &PathBuf,
434 stats: &Arc<Mutex<CounterStats>>,
435 ) -> Result<()> {
436 let counters_snapshot = {
437 let counters = counters.read().map_err(|_| {
438 P2PError::Storage(StorageError::LockPoisoned(
439 "read lock failed".to_string().into(),
440 ))
441 })?;
442 counters.clone()
443 };
444
445 let data = postcard::to_stdvec(&counters_snapshot).map_err(|e| {
446 P2PError::Storage(StorageError::Database(
447 format!("Failed to serialize counters: {e}").into(),
448 ))
449 })?;
450
451 tokio::fs::write(storage_path, data).await.map_err(|e| {
452 P2PError::Storage(StorageError::Database(
453 format!("Failed to write counters file: {e}").into(),
454 ))
455 })?;
456
457 {
459 let mut stats = stats.lock().await;
460 stats.persistence_ops += 1;
461 }
462
463 Ok(())
464 }
465
466 async fn update_validation_stats(
468 &self,
469 start_time: Instant,
470 result: &SequenceValidationResult,
471 ) {
472 let elapsed = start_time.elapsed().as_micros() as u64;
473 let mut stats = self.stats.lock().await;
474
475 let total_ops = stats.total_processed + 1;
477 stats.avg_validation_time_us =
478 (stats.avg_validation_time_us * stats.total_processed + elapsed) / total_ops;
479
480 stats.total_processed = total_ops;
481
482 match result {
483 SequenceValidationResult::Replay => stats.total_replays += 1,
484 SequenceValidationResult::Gap { .. } => stats.total_gaps += 1,
485 _ => {}
486 }
487 }
488}
489
490impl PeerCounter {
491 pub fn new() -> Self {
493 Self {
494 current_sequence: 0,
495 last_valid_sequence: 0,
496 sequence_history: Vec::new(),
497 last_updated: current_timestamp(),
498 replay_attempts: 0,
499 sequence_gaps: 0,
500 }
501 }
502
503 pub fn has_seen_sequence(&self, sequence: u64, message_hash: [u8; 32]) -> bool {
505 self.sequence_history
506 .iter()
507 .any(|entry| entry.sequence == sequence && entry.message_hash == message_hash)
508 }
509
510 pub fn apply_sequence_update(&mut self, sequence: u64, message_hash: [u8; 32], timestamp: u64) {
512 self.current_sequence = sequence;
514 self.last_valid_sequence = sequence;
515 self.last_updated = timestamp;
516
517 self.sequence_history.push(SequenceEntry {
519 sequence,
520 timestamp,
521 message_hash,
522 });
523
524 if self.sequence_history.len() > MAX_SEQUENCE_HISTORY {
526 self.sequence_history.remove(0);
527 }
528 }
529
530 pub fn cleanup_old_sequences(&mut self, cutoff_time: u64) {
532 self.sequence_history
533 .retain(|entry| entry.timestamp >= cutoff_time);
534 }
535
536 pub fn next_expected_sequence(&self) -> u64 {
538 self.last_valid_sequence + 1
539 }
540}
541
impl Default for PeerCounter {
    /// Equivalent to [`PeerCounter::new`]: an empty counter stamped with the
    /// current time.
    fn default() -> Self {
        Self::new()
    }
}
547
/// Returns the current wall-clock time as whole seconds since the Unix
/// epoch, or 0 if the system clock reads earlier than the epoch.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
555
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[tokio::test]
    async fn test_sequence_validation() {
        let dir = tempdir().unwrap();
        let sys = MonotonicCounterSystem::new(dir.path().join("counters.bin"))
            .await
            .unwrap();

        let peer = UserId::from_bytes([1; 32]);
        let hash = *blake3::hash(b"test message").as_bytes();

        // First use of sequence 1 is accepted.
        assert_eq!(
            sys.validate_sequence(&peer, 1, hash).await.unwrap(),
            SequenceValidationResult::Valid
        );

        // Re-sending the identical message is flagged as a replay.
        assert_eq!(
            sys.validate_sequence(&peer, 1, hash).await.unwrap(),
            SequenceValidationResult::Replay
        );

        // The next consecutive sequence is accepted.
        assert_eq!(
            sys.validate_sequence(&peer, 2, hash).await.unwrap(),
            SequenceValidationResult::Valid
        );

        // Skipping ahead reports a gap carrying the expected next value.
        assert_eq!(
            sys.validate_sequence(&peer, 5, hash).await.unwrap(),
            SequenceValidationResult::Gap {
                expected: 3,
                received: 5
            }
        );
    }

    #[tokio::test]
    async fn test_batch_updates() {
        let dir = tempdir().unwrap();
        let sys = MonotonicCounterSystem::new(dir.path().join("counters.bin"))
            .await
            .unwrap();

        let hash = *blake3::hash(b"test message").as_bytes();

        // Two distinct peers, each sending their first sequence.
        let requests: Vec<BatchUpdateRequest> = [[1u8; 32], [2u8; 32]]
            .into_iter()
            .map(|id| BatchUpdateRequest {
                user_id: UserId::from_bytes(id),
                sequence: 1,
                message_hash: hash,
                timestamp: current_timestamp(),
            })
            .collect();

        let outcomes = sys.batch_update(requests).await.unwrap();
        assert_eq!(outcomes.len(), 2);
        for outcome in &outcomes {
            assert!(outcome.applied);
            assert!(matches!(outcome.result, SequenceValidationResult::Valid));
        }
    }

    #[tokio::test]
    async fn test_persistence() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("counters.bin");
        let peer = UserId::from_bytes([1; 32]);
        let hash = *blake3::hash(b"test message").as_bytes();

        // First instance: accept two sequences, then force a flush to disk.
        {
            let sys = MonotonicCounterSystem::new(path.clone()).await.unwrap();
            sys.validate_sequence(&peer, 1, hash).await.unwrap();
            sys.validate_sequence(&peer, 2, hash).await.unwrap();
            MonotonicCounterSystem::sync_counters(&sys.counters, &path, &sys.stats)
                .await
                .unwrap();
        }

        // Second instance: the persisted state is reloaded from the file.
        {
            let sys = MonotonicCounterSystem::new(path).await.unwrap();
            let counter = sys.get_peer_counter(&peer).await.unwrap();
            assert_eq!(counter.last_valid_sequence, 2);
            assert_eq!(counter.sequence_history.len(), 2);
        }
    }

    #[tokio::test]
    async fn test_old_sequence_cleanup() {
        let dir = tempdir().unwrap();
        let sys = MonotonicCounterSystem::new(dir.path().join("counters.bin"))
            .await
            .unwrap();

        let peer = UserId::from_bytes([1; 32]);
        let hash = *blake3::hash(b"test message").as_bytes();

        for seq in 1..=10 {
            sys.validate_sequence(&peer, seq, hash).await.unwrap();
        }

        let before = sys.get_peer_counter(&peer).await.unwrap();
        assert_eq!(before.sequence_history.len(), 10);

        // Every entry is fresh, so cleanup must not remove any of them.
        sys.cleanup_old_sequences().await.unwrap();
        let after = sys.get_peer_counter(&peer).await.unwrap();
        assert_eq!(after.sequence_history.len(), 10);
    }

    #[tokio::test]
    async fn test_statistics() {
        let dir = tempdir().unwrap();
        let sys = MonotonicCounterSystem::new(dir.path().join("counters.bin"))
            .await
            .unwrap();

        let peer = UserId::from_bytes([1; 32]);
        let hash = *blake3::hash(b"test message").as_bytes();

        sys.validate_sequence(&peer, 1, hash).await.unwrap(); // valid
        sys.validate_sequence(&peer, 1, hash).await.unwrap(); // replay
        sys.validate_sequence(&peer, 5, hash).await.unwrap(); // gap

        let stats = sys.get_stats().await;
        assert_eq!(stats.total_processed, 3);
        assert_eq!(stats.total_replays, 1);
        assert_eq!(stats.total_gaps, 1);
        assert_eq!(stats.peers_tracked, 1);
    }
}