// saorsa_core/monotonic_counter.rs
1// Copyright 2024 Saorsa Labs Limited
2//
3// This software is dual-licensed under:
4// - GNU Affero General Public License v3.0 or later (AGPL-3.0-or-later)
5// - Commercial License
6//
7// For AGPL-3.0 license, see LICENSE-AGPL-3.0
8// For commercial licensing, contact: david@saorsalabs.com
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under these licenses is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
14//! # Monotonic Counter System for Replay Attack Prevention
15//!
16//! This module provides a secure monotonic counter system that prevents replay attacks
17//! by ensuring sequence numbers always increase and cannot be reused.
18//!
19//! ## Security Features
20//! - Atomic operations to prevent race conditions
21//! - Persistent storage with crash recovery
22//! - Sequence validation with gap detection
23//! - Memory-efficient tracking for multiple peers
24//!
25//! ## Performance Features
26//! - Batch updates for multiple counters
27//! - Efficient in-memory caching
28//! - Background persistence to avoid blocking
29//! - Configurable sync intervals
30
31#![allow(missing_docs)]
32
use crate::error::StorageError;
use crate::peer_record::UserId;
use crate::{P2PError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tokio::sync::Mutex;
use tokio::time::interval;
43
/// Maximum number of sequence numbers to remember per peer.
/// Bounds per-peer memory: once exceeded, the oldest history entry is evicted.
const MAX_SEQUENCE_HISTORY: usize = 1000;

/// Default sync interval for persisting counters to disk.
const DEFAULT_SYNC_INTERVAL: Duration = Duration::from_secs(30);

/// Maximum age for sequence numbers before they're considered stale
/// (used both to reject old timestamps and to expire history entries).
const MAX_SEQUENCE_AGE: Duration = Duration::from_secs(3600); // 1 hour
52
/// Monotonic counter system for preventing replay attacks.
///
/// Keeps an in-memory map of per-peer counters and can periodically persist
/// it to `storage_path` via an optional background task
/// (see `start_sync_task`).
pub struct MonotonicCounterSystem {
    /// In-memory counter cache, keyed by peer `UserId`.
    counters: Arc<RwLock<HashMap<UserId, PeerCounter>>>,
    /// Path of the file counters are serialized to (postcard format).
    storage_path: PathBuf,
    /// How often the background task flushes counters to disk.
    sync_interval: Duration,
    /// Background sync task handle; `None` until `start_sync_task` is called.
    sync_task: Option<tokio::task::JoinHandle<()>>,
    /// System statistics, shared with the background sync task.
    stats: Arc<Mutex<CounterStats>>,
}
66
/// Per-peer counter state.
///
/// Serialized as-is to persistent storage, so changing these fields changes
/// the on-disk format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PeerCounter {
    /// Current sequence number (kept equal to `last_valid_sequence` after
    /// each accepted update).
    pub current_sequence: u64,
    /// Last valid sequence number received.
    pub last_valid_sequence: u64,
    /// History of recent sequence numbers to detect replays
    /// (bounded by `MAX_SEQUENCE_HISTORY`).
    pub sequence_history: Vec<SequenceEntry>,
    /// Unix timestamp (seconds) of the last accepted update.
    pub last_updated: u64,
    /// Number of replay attempts detected.
    // NOTE(review): never incremented anywhere in this file — confirm whether
    // this is updated elsewhere or is dead state.
    pub replay_attempts: u64,
    /// Sequence number gaps detected.
    // NOTE(review): never incremented anywhere in this file — confirm intent.
    pub sequence_gaps: u64,
}
83
/// Sequence number entry with metadata, stored in a peer's history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SequenceEntry {
    /// Sequence number.
    pub sequence: u64,
    /// Unix timestamp (seconds) when received; used for age-based cleanup.
    pub timestamp: u64,
    /// Hash of the message, so replay detection matches on the exact
    /// (sequence, message) pair rather than the sequence number alone.
    pub message_hash: [u8; 32],
}
94
/// Statistics for monitoring counter system performance.
///
/// Updated by `validate_sequence`, `batch_update` and the persistence path;
/// `peers_tracked` is recomputed live in `get_stats`.
#[derive(Debug, Clone, Default)]
pub struct CounterStats {
    /// Total sequence numbers processed.
    pub total_processed: u64,
    /// Total replay attempts detected.
    pub total_replays: u64,
    /// Total sequence gaps detected.
    pub total_gaps: u64,
    /// Number of peers tracked (snapshot of the counter map size).
    pub peers_tracked: usize,
    /// Number of persistence operations.
    pub persistence_ops: u64,
    /// Running average validation time in microseconds.
    pub avg_validation_time_us: u64,
    /// Number of cache hits.
    // NOTE(review): never updated anywhere in this file — confirm whether a
    // caller maintains it or it is vestigial.
    pub cache_hits: u64,
    /// Number of cache misses.
    // NOTE(review): never updated anywhere in this file — confirm intent.
    pub cache_misses: u64,
}
115
/// Result of sequence validation.
#[derive(Debug, Clone, PartialEq)]
pub enum SequenceValidationResult {
    /// Sequence is valid and accepted (counter state was advanced).
    Valid,
    /// Sequence is a replay: either the exact (sequence, hash) pair was seen
    /// before, or the sequence is not greater than the last valid one.
    Replay,
    /// Sequence's timestamp is older than `MAX_SEQUENCE_AGE`.
    TooOld,
    /// Sequence skips ahead: intermediate sequence numbers are missing.
    Gap { expected: u64, received: u64 },
    /// Sequence's timestamp is more than 60 s ahead of local time
    /// (clock skew protection).
    FromFuture,
}
130
/// Batch update request for multiple counters.
pub struct BatchUpdateRequest {
    /// User ID of the peer the sequence belongs to.
    pub user_id: UserId,
    /// Sequence number to validate.
    pub sequence: u64,
    /// Message hash for deduplication.
    pub message_hash: [u8; 32],
    /// Caller-supplied Unix timestamp (seconds); checked against local time
    /// for staleness and clock skew.
    pub timestamp: u64,
}
142
/// Result of a single entry in a batch update.
pub struct BatchUpdateResult {
    /// User ID the result refers to.
    pub user_id: UserId,
    /// Validation result for this entry.
    pub result: SequenceValidationResult,
    /// Whether the counter state was advanced (true only for `Valid`).
    pub applied: bool,
}
152
153impl MonotonicCounterSystem {
154    /// Create a new monotonic counter system
155    pub async fn new(storage_path: PathBuf) -> Result<Self> {
156        Self::new_with_sync_interval(storage_path, DEFAULT_SYNC_INTERVAL).await
157    }
158
159    /// Create a new monotonic counter system with custom sync interval
160    pub async fn new_with_sync_interval(
161        storage_path: PathBuf,
162        sync_interval: Duration,
163    ) -> Result<Self> {
164        // Ensure storage directory exists
165        if let Some(parent) = storage_path.parent() {
166            tokio::fs::create_dir_all(parent).await.map_err(|e| {
167                P2PError::Storage(StorageError::Database(
168                    format!("Failed to create storage directory: {e}").into(),
169                ))
170            })?;
171        }
172
173        // Load existing counters from storage
174        let counters = Self::load_counters(&storage_path).await?;
175
176        let system = Self {
177            counters: Arc::new(RwLock::new(counters)),
178            storage_path,
179            sync_interval,
180            sync_task: None,
181            stats: Arc::new(Mutex::new(CounterStats::default())),
182        };
183
184        Ok(system)
185    }
186
187    /// Start the background sync task
188    pub async fn start_sync_task(&mut self) -> Result<()> {
189        if self.sync_task.is_some() {
190            return Ok(()); // Already started
191        }
192
193        let counters = self.counters.clone();
194        let storage_path = self.storage_path.clone();
195        let sync_interval = self.sync_interval;
196        let stats = self.stats.clone();
197
198        let task = tokio::spawn(async move {
199            let mut interval = interval(sync_interval);
200            loop {
201                interval.tick().await;
202
203                if let Err(e) = Self::sync_counters(&counters, &storage_path, &stats).await {
204                    tracing::warn!("Failed to sync counters to storage: {}", e);
205                }
206            }
207        });
208
209        self.sync_task = Some(task);
210        Ok(())
211    }
212
213    /// Stop the background sync task
214    pub async fn stop_sync_task(&mut self) {
215        if let Some(task) = self.sync_task.take() {
216            task.abort();
217        }
218    }
219
220    /// Validate and update sequence number for a peer
221    pub async fn validate_sequence(
222        &self,
223        user_id: &UserId,
224        sequence: u64,
225        message_hash: [u8; 32],
226    ) -> Result<SequenceValidationResult> {
227        let start_time = Instant::now();
228        let timestamp = current_timestamp();
229
230        // Get or create peer counter
231        let validation_result = {
232            let mut counters = self.counters.write().map_err(|_| {
233                P2PError::Storage(StorageError::LockPoisoned(
234                    "write lock failed".to_string().into(),
235                ))
236            })?;
237            let peer_counter = counters
238                .entry(user_id.clone())
239                .or_insert_with(PeerCounter::new);
240
241            // Validate the sequence number
242            let result =
243                self.validate_sequence_internal(peer_counter, sequence, message_hash, timestamp);
244
245            // Update statistics
246            if let SequenceValidationResult::Valid = result {
247                // Apply the update
248                peer_counter.apply_sequence_update(sequence, message_hash, timestamp);
249            }
250
251            result
252        };
253
254        // Update performance statistics
255        self.update_validation_stats(start_time, &validation_result)
256            .await;
257
258        Ok(validation_result)
259    }
260
261    /// Validate sequence number without updating state
262    fn validate_sequence_internal(
263        &self,
264        peer_counter: &PeerCounter,
265        sequence: u64,
266        message_hash: [u8; 32],
267        timestamp: u64,
268    ) -> SequenceValidationResult {
269        // Check if timestamp is too far in the future (clock skew protection)
270        let current_time = current_timestamp();
271        if timestamp > current_time + 60 {
272            return SequenceValidationResult::FromFuture;
273        }
274
275        // Check if sequence is too old
276        if timestamp < current_time.saturating_sub(MAX_SEQUENCE_AGE.as_secs()) {
277            return SequenceValidationResult::TooOld;
278        }
279
280        // Check for replay attack (already seen this sequence)
281        if peer_counter.has_seen_sequence(sequence, message_hash) {
282            return SequenceValidationResult::Replay;
283        }
284
285        // Check for sequence gaps
286        if sequence > peer_counter.last_valid_sequence + 1 {
287            return SequenceValidationResult::Gap {
288                expected: peer_counter.last_valid_sequence + 1,
289                received: sequence,
290            };
291        }
292
293        // Check if sequence is older than last valid (out of order)
294        if sequence <= peer_counter.last_valid_sequence {
295            return SequenceValidationResult::Replay;
296        }
297
298        SequenceValidationResult::Valid
299    }
300
301    /// Process batch updates for multiple peers
302    pub async fn batch_update(
303        &self,
304        requests: Vec<BatchUpdateRequest>,
305    ) -> Result<Vec<BatchUpdateResult>> {
306        let mut results = Vec::with_capacity(requests.len());
307
308        // Process all requests atomically
309        {
310            let mut counters = self.counters.write().map_err(|_| {
311                P2PError::Storage(StorageError::LockPoisoned(
312                    "write lock failed".to_string().into(),
313                ))
314            })?;
315
316            for request in requests {
317                let peer_counter = counters
318                    .entry(request.user_id.clone())
319                    .or_insert_with(PeerCounter::new);
320
321                let validation_result = self.validate_sequence_internal(
322                    peer_counter,
323                    request.sequence,
324                    request.message_hash,
325                    request.timestamp,
326                );
327
328                let applied = matches!(validation_result, SequenceValidationResult::Valid);
329
330                if applied {
331                    peer_counter.apply_sequence_update(
332                        request.sequence,
333                        request.message_hash,
334                        request.timestamp,
335                    );
336                }
337
338                results.push(BatchUpdateResult {
339                    user_id: request.user_id,
340                    result: validation_result,
341                    applied,
342                });
343            }
344        }
345
346        // Update batch statistics
347        {
348            let mut stats = self.stats.lock().await;
349            stats.total_processed += results.len() as u64;
350            stats.total_replays += results
351                .iter()
352                .filter(|r| matches!(r.result, SequenceValidationResult::Replay))
353                .count() as u64;
354            stats.total_gaps += results
355                .iter()
356                .filter(|r| matches!(r.result, SequenceValidationResult::Gap { .. }))
357                .count() as u64;
358        }
359
360        Ok(results)
361    }
362
363    /// Get current statistics
364    pub async fn get_stats(&self) -> CounterStats {
365        let stats = self.stats.lock().await;
366        let mut current_stats = stats.clone();
367
368        // Update live statistics
369        let counters = self.counters.read().unwrap_or_else(|e| e.into_inner());
370        current_stats.peers_tracked = counters.len();
371
372        current_stats
373    }
374
375    /// Get counter state for a specific peer
376    pub async fn get_peer_counter(&self, user_id: &UserId) -> Option<PeerCounter> {
377        let counters = self.counters.read().ok()?;
378        counters.get(user_id).cloned()
379    }
380
381    /// Reset counter for a peer (use with caution)
382    pub async fn reset_peer_counter(&self, user_id: &UserId) -> Result<()> {
383        let mut counters = self.counters.write().map_err(|_| {
384            P2PError::Storage(StorageError::LockPoisoned(
385                "write lock failed".to_string().into(),
386            ))
387        })?;
388        counters.remove(user_id);
389        Ok(())
390    }
391
392    /// Cleanup old sequence entries
393    pub async fn cleanup_old_sequences(&self) -> Result<()> {
394        let current_time = current_timestamp();
395        let cutoff_time = current_time.saturating_sub(MAX_SEQUENCE_AGE.as_secs());
396
397        let mut counters = self.counters.write().map_err(|_| {
398            P2PError::Storage(StorageError::LockPoisoned(
399                "write lock failed".to_string().into(),
400            ))
401        })?;
402        for (_, peer_counter) in counters.iter_mut() {
403            peer_counter.cleanup_old_sequences(cutoff_time);
404        }
405
406        Ok(())
407    }
408
409    /// Load counters from persistent storage
410    async fn load_counters(storage_path: &PathBuf) -> Result<HashMap<UserId, PeerCounter>> {
411        if !storage_path.exists() {
412            return Ok(HashMap::new());
413        }
414
415        let data = tokio::fs::read(storage_path).await.map_err(|e| {
416            P2PError::Storage(StorageError::Database(
417                format!("Failed to read counters file: {e}").into(),
418            ))
419        })?;
420
421        let counters: HashMap<UserId, PeerCounter> = postcard::from_bytes(&data).map_err(|e| {
422            P2PError::Storage(StorageError::Database(
423                format!("Failed to deserialize counters: {e}").into(),
424            ))
425        })?;
426
427        Ok(counters)
428    }
429
430    /// Sync counters to persistent storage
431    async fn sync_counters(
432        counters: &Arc<RwLock<HashMap<UserId, PeerCounter>>>,
433        storage_path: &PathBuf,
434        stats: &Arc<Mutex<CounterStats>>,
435    ) -> Result<()> {
436        let counters_snapshot = {
437            let counters = counters.read().map_err(|_| {
438                P2PError::Storage(StorageError::LockPoisoned(
439                    "read lock failed".to_string().into(),
440                ))
441            })?;
442            counters.clone()
443        };
444
445        let data = postcard::to_stdvec(&counters_snapshot).map_err(|e| {
446            P2PError::Storage(StorageError::Database(
447                format!("Failed to serialize counters: {e}").into(),
448            ))
449        })?;
450
451        tokio::fs::write(storage_path, data).await.map_err(|e| {
452            P2PError::Storage(StorageError::Database(
453                format!("Failed to write counters file: {e}").into(),
454            ))
455        })?;
456
457        // Update statistics
458        {
459            let mut stats = stats.lock().await;
460            stats.persistence_ops += 1;
461        }
462
463        Ok(())
464    }
465
466    /// Update validation statistics
467    async fn update_validation_stats(
468        &self,
469        start_time: Instant,
470        result: &SequenceValidationResult,
471    ) {
472        let elapsed = start_time.elapsed().as_micros() as u64;
473        let mut stats = self.stats.lock().await;
474
475        // Update running average
476        let total_ops = stats.total_processed + 1;
477        stats.avg_validation_time_us =
478            (stats.avg_validation_time_us * stats.total_processed + elapsed) / total_ops;
479
480        stats.total_processed = total_ops;
481
482        match result {
483            SequenceValidationResult::Replay => stats.total_replays += 1,
484            SequenceValidationResult::Gap { .. } => stats.total_gaps += 1,
485            _ => {}
486        }
487    }
488}
489
490impl PeerCounter {
491    /// Create a new peer counter
492    pub fn new() -> Self {
493        Self {
494            current_sequence: 0,
495            last_valid_sequence: 0,
496            sequence_history: Vec::new(),
497            last_updated: current_timestamp(),
498            replay_attempts: 0,
499            sequence_gaps: 0,
500        }
501    }
502
503    /// Check if we've seen this sequence number before
504    pub fn has_seen_sequence(&self, sequence: u64, message_hash: [u8; 32]) -> bool {
505        self.sequence_history
506            .iter()
507            .any(|entry| entry.sequence == sequence && entry.message_hash == message_hash)
508    }
509
510    /// Apply a sequence update
511    pub fn apply_sequence_update(&mut self, sequence: u64, message_hash: [u8; 32], timestamp: u64) {
512        // Update current sequence
513        self.current_sequence = sequence;
514        self.last_valid_sequence = sequence;
515        self.last_updated = timestamp;
516
517        // Add to history
518        self.sequence_history.push(SequenceEntry {
519            sequence,
520            timestamp,
521            message_hash,
522        });
523
524        // Maintain history size limit
525        if self.sequence_history.len() > MAX_SEQUENCE_HISTORY {
526            self.sequence_history.remove(0);
527        }
528    }
529
530    /// Cleanup old sequences from history
531    pub fn cleanup_old_sequences(&mut self, cutoff_time: u64) {
532        self.sequence_history
533            .retain(|entry| entry.timestamp >= cutoff_time);
534    }
535
536    /// Get the next expected sequence number
537    pub fn next_expected_sequence(&self) -> u64 {
538        self.last_valid_sequence + 1
539    }
540}
541
542impl Default for PeerCounter {
543    fn default() -> Self {
544        Self::new()
545    }
546}
547
/// Current Unix time in whole seconds (0 if the system clock reads before
/// the epoch, which `duration_since` reports as an error).
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
555
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;
    use tokio::test;

    /// Single-peer lifecycle: accept, detect replay, accept next, detect gap.
    #[test]
    async fn test_sequence_validation() {
        let temp_dir = tempdir().unwrap();
        let storage_path = temp_dir.path().join("counters.bin");
        let system = MonotonicCounterSystem::new(storage_path).await.unwrap();

        let user_id = UserId::from_bytes([1; 32]);
        let message_hash = *blake3::hash(b"test message").as_bytes();

        // First sequence should be valid
        let result = system
            .validate_sequence(&user_id, 1, message_hash)
            .await
            .unwrap();
        assert_eq!(result, SequenceValidationResult::Valid);

        // Replay should be detected (same sequence AND same message hash)
        let result = system
            .validate_sequence(&user_id, 1, message_hash)
            .await
            .unwrap();
        assert_eq!(result, SequenceValidationResult::Replay);

        // Next sequence should be valid
        let result = system
            .validate_sequence(&user_id, 2, message_hash)
            .await
            .unwrap();
        assert_eq!(result, SequenceValidationResult::Valid);

        // Gap should be detected: 3 and 4 were skipped, so expected is 3
        let result = system
            .validate_sequence(&user_id, 5, message_hash)
            .await
            .unwrap();
        assert_eq!(
            result,
            SequenceValidationResult::Gap {
                expected: 3,
                received: 5
            }
        );
    }

    /// Batch updates for two independent peers are all valid and applied.
    #[test]
    async fn test_batch_updates() {
        let temp_dir = tempdir().unwrap();
        let storage_path = temp_dir.path().join("counters.bin");
        let system = MonotonicCounterSystem::new(storage_path).await.unwrap();

        let user_id1 = UserId::from_bytes([1; 32]);
        let user_id2 = UserId::from_bytes([2; 32]);
        let message_hash = *blake3::hash(b"test message").as_bytes();

        let requests = vec![
            BatchUpdateRequest {
                user_id: user_id1.clone(),
                sequence: 1,
                message_hash,
                timestamp: current_timestamp(),
            },
            BatchUpdateRequest {
                user_id: user_id2.clone(),
                sequence: 1,
                message_hash,
                timestamp: current_timestamp(),
            },
        ];

        let results = system.batch_update(requests).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.applied));
        assert!(
            results
                .iter()
                .all(|r| matches!(r.result, SequenceValidationResult::Valid))
        );
    }

    /// Counters survive a restart: sync to disk, rebuild the system from the
    /// same path, and verify the loaded state.
    #[test]
    async fn test_persistence() {
        let temp_dir = tempdir().unwrap();
        let storage_path = temp_dir.path().join("counters.bin");

        // Create system and add some counters
        {
            let system = MonotonicCounterSystem::new(storage_path.clone())
                .await
                .unwrap();
            let user_id = UserId::from_bytes([1; 32]);
            let message_hash = *blake3::hash(b"test message").as_bytes();

            system
                .validate_sequence(&user_id, 1, message_hash)
                .await
                .unwrap();
            system
                .validate_sequence(&user_id, 2, message_hash)
                .await
                .unwrap();

            // Force sync (the background task is never started in this test)
            MonotonicCounterSystem::sync_counters(&system.counters, &storage_path, &system.stats)
                .await
                .unwrap();
        }

        // Create new system and verify counters are loaded
        {
            let system = MonotonicCounterSystem::new(storage_path).await.unwrap();
            let user_id = UserId::from_bytes([1; 32]);
            let counter = system.get_peer_counter(&user_id).await.unwrap();

            assert_eq!(counter.last_valid_sequence, 2);
            assert_eq!(counter.sequence_history.len(), 2);
        }
    }

    /// Age-based cleanup must not remove freshly-recorded history entries.
    #[test]
    async fn test_old_sequence_cleanup() {
        let temp_dir = tempdir().unwrap();
        let storage_path = temp_dir.path().join("counters.bin");
        let system = MonotonicCounterSystem::new(storage_path).await.unwrap();

        let user_id = UserId::from_bytes([1; 32]);
        let message_hash = *blake3::hash(b"test message").as_bytes();

        // Add some sequences
        for i in 1..=10 {
            system
                .validate_sequence(&user_id, i, message_hash)
                .await
                .unwrap();
        }

        // Verify we have sequences
        let counter = system.get_peer_counter(&user_id).await.unwrap();
        assert_eq!(counter.sequence_history.len(), 10);

        // Cleanup shouldn't remove anything (sequences are recent)
        system.cleanup_old_sequences().await.unwrap();
        let counter = system.get_peer_counter(&user_id).await.unwrap();
        assert_eq!(counter.sequence_history.len(), 10);
    }

    /// Replay and gap outcomes are tallied into the system statistics.
    #[test]
    async fn test_statistics() {
        let temp_dir = tempdir().unwrap();
        let storage_path = temp_dir.path().join("counters.bin");
        let system = MonotonicCounterSystem::new(storage_path).await.unwrap();

        let user_id = UserId::from_bytes([1; 32]);
        let message_hash = *blake3::hash(b"test message").as_bytes();

        // Process some sequences
        system
            .validate_sequence(&user_id, 1, message_hash)
            .await
            .unwrap();
        system
            .validate_sequence(&user_id, 1, message_hash)
            .await
            .unwrap(); // Replay
        system
            .validate_sequence(&user_id, 5, message_hash)
            .await
            .unwrap(); // Gap

        let stats = system.get_stats().await;
        assert_eq!(stats.total_processed, 3);
        assert_eq!(stats.total_replays, 1);
        assert_eq!(stats.total_gaps, 1);
        assert_eq!(stats.peers_tracked, 1);
    }
}