sochdb_storage/
dirty_tracking.rs

1// Copyright 2025 Sushanth (https://github.com/sushanthpy)
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Batched Dirty Tracking with MPSC Queue
16//!
17//! This module replaces per-write mutex acquisition for dirty tracking
18//! with a batched MPSC approach that dramatically reduces lock contention.
19//!
20//! ## Problem: Lock Convoying
21//!
22//! Per-write mutex acquisition is the canonical scalability killer:
23//! - Serializes otherwise-parallel writers
24//! - Causes lock convoying under contention
25//! - N writers → N lock acquisitions per batch
26//!
27//! ## Solution: Batched MPSC + Thread-Local Buffering
28//!
29//! ```text
30//! ┌────────────────────────────────────────────────────────────────┐
31//! │                    Writer Threads                               │
32//! │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐       │
33//! │  │ Thread 1 │  │ Thread 2 │  │ Thread 3 │  │ Thread N │       │
34//! │  │  Buffer  │  │  Buffer  │  │  Buffer  │  │  Buffer  │       │
35//! │  └────┬─────┘  └────┬─────┘  └────┬─────┘  └────┬─────┘       │
36//! │       │             │             │             │              │
37//! │       └─────────────┼─────────────┼─────────────┘              │
38//! │                     │             │                            │
39//! │                     ▼             ▼                            │
40//! │              ┌──────────────────────────────┐                  │
41//! │              │      MPSC Channel            │                  │
42//! │              │  (crossbeam-channel)         │                  │
43//! │              └──────────────┬───────────────┘                  │
44//! │                             │                                  │
45//! │                             ▼                                  │
46//! │              ┌──────────────────────────────┐                  │
47//! │              │   Aggregator Thread          │                  │
48//! │              │   (drains every 10ms or     │                  │
49//! │              │    every 1000 entries)       │                  │
50//! │              └──────────────────────────────┘                  │
51//! └────────────────────────────────────────────────────────────────┘
52//!
53//! Lock acquisitions: O(W) → O(W/B) where W=writes, B=batch size
54//! ```
55//!
56//! ## Performance
57//!
58//! - Thread-local buffer: Zero contention during writes
59//! - MPSC send: ~20ns (vs ~200ns for mutex under contention)
60//! - Batch flush: Amortized over B writes
61//!
62//! ## Usage
63//!
64//! ```ignore
65//! let tracker = BatchedDirtyTracker::new();
66//!
67//! // In transaction:
68//! tracker.mark_dirty(key);  // Lock-free, buffers locally
69//!
70//! // At commit:
71//! tracker.flush_buffer();   // Sends batch to aggregator
72//! ```
73
74use std::cell::RefCell;
75use std::collections::HashSet;
76use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
77use std::sync::Arc;
78use std::thread::{self, JoinHandle};
79
80use crossbeam_channel::{self, Receiver, Sender};
81use parking_lot::Mutex;
82
83use crate::txn_arena::KeyFingerprint;
84
85/// Default batch size for thread-local buffer
86const DEFAULT_BATCH_SIZE: usize = 64;
87
88/// Maximum wait time before flushing (in milliseconds)
89const MAX_FLUSH_INTERVAL_MS: u64 = 10;
90
91// ============================================================================
92// DirtyEvent - Batched Dirty Key Notification
93// ============================================================================
94
95/// Event sent through MPSC channel
96#[derive(Debug)]
97pub enum DirtyEvent {
98    /// Batch of dirty key fingerprints from a transaction
99    Batch {
100        txn_id: u64,
101        keys: Vec<KeyFingerprint>,
102    },
103    /// Epoch advance request
104    AdvanceEpoch,
105    /// Shutdown signal
106    Shutdown,
107}
108
109// ============================================================================
110// ThreadLocalBuffer - Per-Thread Dirty Key Accumulator
111// ============================================================================
112
113/// Thread-local buffer for accumulating dirty keys
114struct ThreadLocalBuffer {
115    /// Current transaction ID
116    txn_id: u64,
117    /// Accumulated dirty key fingerprints
118    keys: Vec<KeyFingerprint>,
119    /// Sender to aggregator
120    sender: Sender<DirtyEvent>,
121}
122
123impl ThreadLocalBuffer {
124    fn new(sender: Sender<DirtyEvent>) -> Self {
125        Self {
126            txn_id: 0,
127            keys: Vec::with_capacity(DEFAULT_BATCH_SIZE),
128            sender,
129        }
130    }
131
132    /// Mark a key as dirty (no lock, no send)
133    #[inline]
134    fn mark_dirty(&mut self, txn_id: u64, key_fingerprint: KeyFingerprint) {
135        if self.txn_id != txn_id {
136            // New transaction - flush old buffer if any
137            self.flush();
138            self.txn_id = txn_id;
139        }
140        self.keys.push(key_fingerprint);
141    }
142
143    /// Flush accumulated keys to aggregator
144    fn flush(&mut self) {
145        if !self.keys.is_empty() {
146            let keys = std::mem::take(&mut self.keys);
147            // Best-effort send - don't block if channel is full
148            let _ = self.sender.try_send(DirtyEvent::Batch {
149                txn_id: self.txn_id,
150                keys,
151            });
152            self.keys = Vec::with_capacity(DEFAULT_BATCH_SIZE);
153        }
154    }
155}
156
157// ============================================================================
158// BatchedDirtyTracker - Lock-Free Dirty Tracking
159// ============================================================================
160
161/// Batched dirty tracker with MPSC queue
162///
163/// Provides lock-free dirty tracking for multi-threaded writes.
164/// Each thread accumulates dirty keys locally, then sends them
165/// in batches through an MPSC channel to an aggregator.
166pub struct BatchedDirtyTracker {
167    /// MPSC sender (cloned for each thread)
168    sender: Sender<DirtyEvent>,
169    /// MPSC receiver (for aggregator thread)
170    receiver: Receiver<DirtyEvent>,
171    /// Aggregator thread handle
172    aggregator_handle: Mutex<Option<JoinHandle<()>>>,
173    /// Running flag
174    running: AtomicBool,
175    /// Current epoch
176    current_epoch: AtomicU64,
177    /// Aggregated dirty keys per epoch
178    epochs: [Mutex<HashSet<KeyFingerprint>>; 4],
179    /// Statistics
180    stats: DirtyTrackingStats,
181}
182
183/// Dirty tracking statistics
184pub struct DirtyTrackingStats {
185    /// Total events received
186    pub events_received: AtomicU64,
187    /// Total keys tracked
188    pub keys_tracked: AtomicU64,
189    /// Total batches received
190    pub batches_received: AtomicU64,
191    /// Current epoch
192    pub current_epoch: AtomicU64,
193}
194
195impl Default for DirtyTrackingStats {
196    fn default() -> Self {
197        Self {
198            events_received: AtomicU64::new(0),
199            keys_tracked: AtomicU64::new(0),
200            batches_received: AtomicU64::new(0),
201            current_epoch: AtomicU64::new(0),
202        }
203    }
204}
205
206const EPOCH_RING_SIZE: usize = 4;
207
208impl BatchedDirtyTracker {
209    /// Create a new batched dirty tracker
210    pub fn new() -> Arc<Self> {
211        let (sender, receiver) = crossbeam_channel::bounded(1024);
212        
213        let tracker = Arc::new(Self {
214            sender,
215            receiver,
216            aggregator_handle: Mutex::new(None),
217            running: AtomicBool::new(false),
218            current_epoch: AtomicU64::new(0),
219            epochs: [
220                Mutex::new(HashSet::new()),
221                Mutex::new(HashSet::new()),
222                Mutex::new(HashSet::new()),
223                Mutex::new(HashSet::new()),
224            ],
225            stats: DirtyTrackingStats::default(),
226        });
227        
228        tracker
229    }
230
231    /// Start the aggregator thread
232    pub fn start(self: &Arc<Self>) {
233        if self.running.swap(true, Ordering::SeqCst) {
234            return; // Already running
235        }
236
237        let tracker = Arc::clone(self);
238        let handle = thread::spawn(move || {
239            tracker.aggregator_loop();
240        });
241
242        *self.aggregator_handle.lock() = Some(handle);
243    }
244
245    /// Stop the aggregator thread
246    pub fn stop(&self) {
247        if !self.running.swap(false, Ordering::SeqCst) {
248            return; // Already stopped
249        }
250
251        // Send shutdown signal
252        let _ = self.sender.send(DirtyEvent::Shutdown);
253
254        // Wait for aggregator to finish
255        if let Some(handle) = self.aggregator_handle.lock().take() {
256            let _ = handle.join();
257        }
258    }
259
260    /// Get a sender for a thread to use
261    pub fn get_sender(&self) -> Sender<DirtyEvent> {
262        self.sender.clone()
263    }
264
265    /// Mark a key as dirty using a thread-local buffer
266    ///
267    /// This is the zero-contention hot path used by writers.
268    #[inline]
269    pub fn mark_dirty(&self, txn_id: u64, key_fingerprint: KeyFingerprint) {
270        thread_local! {
271            static BUFFER: RefCell<Option<ThreadLocalBuffer>> = const { RefCell::new(None) };
272        }
273
274        BUFFER.with(|cell| {
275            let mut buffer = cell.borrow_mut();
276            if buffer.is_none() {
277                *buffer = Some(ThreadLocalBuffer::new(self.sender.clone()));
278            }
279            buffer.as_mut().unwrap().mark_dirty(txn_id, key_fingerprint);
280        });
281    }
282
283    /// Flush the current thread's buffer
284    pub fn flush_thread_buffer(&self) {
285        thread_local! {
286            static BUFFER: RefCell<Option<ThreadLocalBuffer>> = const { RefCell::new(None) };
287        }
288
289        BUFFER.with(|cell| {
290            if let Some(buffer) = cell.borrow_mut().as_mut() {
291                buffer.flush();
292            }
293        });
294    }
295
296    /// Send a batch of dirty keys directly (for transaction commit)
297    #[inline]
298    pub fn send_batch(&self, txn_id: u64, keys: Vec<KeyFingerprint>) {
299        if keys.is_empty() {
300            return;
301        }
302        let _ = self.sender.try_send(DirtyEvent::Batch { txn_id, keys });
303    }
304
305    /// Advance to next epoch, returning the old epoch's dirty keys
306    pub fn advance_epoch(&self) -> (u64, Vec<KeyFingerprint>) {
307        // Send epoch advance event to ensure all pending events are processed
308        let _ = self.sender.try_send(DirtyEvent::AdvanceEpoch);
309        
310        let old_epoch = self.current_epoch.fetch_add(1, Ordering::SeqCst);
311        let old_idx = (old_epoch as usize) % EPOCH_RING_SIZE;
312        
313        // Drain the old epoch
314        let mut guard = self.epochs[old_idx].lock();
315        let keys: Vec<_> = guard.drain().collect();
316        
317        self.stats.current_epoch.store(old_epoch + 1, Ordering::Relaxed);
318        
319        (old_epoch, keys)
320    }
321
322    /// Get current epoch
323    pub fn current_epoch(&self) -> u64 {
324        self.current_epoch.load(Ordering::Relaxed)
325    }
326
327    /// Get statistics
328    pub fn stats(&self) -> &DirtyTrackingStats {
329        &self.stats
330    }
331
332    /// Aggregator loop - runs in background thread
333    fn aggregator_loop(&self) {
334        use crossbeam_channel::RecvTimeoutError;
335        
336        let timeout = std::time::Duration::from_millis(MAX_FLUSH_INTERVAL_MS);
337        
338        while self.running.load(Ordering::Relaxed) {
339            match self.receiver.recv_timeout(timeout) {
340                Ok(event) => {
341                    self.process_event(event);
342                }
343                Err(RecvTimeoutError::Timeout) => {
344                    // No events for a while - that's fine
345                }
346                Err(RecvTimeoutError::Disconnected) => {
347                    break;
348                }
349            }
350        }
351        
352        // Drain remaining events on shutdown
353        while let Ok(event) = self.receiver.try_recv() {
354            if matches!(event, DirtyEvent::Shutdown) {
355                break;
356            }
357            self.process_event(event);
358        }
359    }
360
361    /// Process a single event
362    fn process_event(&self, event: DirtyEvent) {
363        match event {
364            DirtyEvent::Batch { txn_id: _, keys } => {
365                let epoch = self.current_epoch.load(Ordering::Relaxed);
366                let idx = (epoch as usize) % EPOCH_RING_SIZE;
367                
368                let mut guard = self.epochs[idx].lock();
369                let key_count = keys.len();
370                guard.extend(keys);
371                
372                self.stats.events_received.fetch_add(1, Ordering::Relaxed);
373                self.stats.keys_tracked.fetch_add(key_count as u64, Ordering::Relaxed);
374                self.stats.batches_received.fetch_add(1, Ordering::Relaxed);
375            }
376            DirtyEvent::AdvanceEpoch => {
377                // Epoch advance is handled by the caller
378            }
379            DirtyEvent::Shutdown => {
380                // Will exit the loop
381            }
382        }
383    }
384}
385
386impl Default for BatchedDirtyTracker {
387    fn default() -> Self {
388        let (sender, receiver) = crossbeam_channel::bounded(1024);
389        Self {
390            sender,
391            receiver,
392            aggregator_handle: Mutex::new(None),
393            running: AtomicBool::new(false),
394            current_epoch: AtomicU64::new(0),
395            epochs: [
396                Mutex::new(HashSet::new()),
397                Mutex::new(HashSet::new()),
398                Mutex::new(HashSet::new()),
399                Mutex::new(HashSet::new()),
400            ],
401            stats: DirtyTrackingStats::default(),
402        }
403    }
404}
405
406impl Drop for BatchedDirtyTracker {
407    fn drop(&mut self) {
408        self.stop();
409    }
410}
411
412// ============================================================================
413// TxnDirtyBuffer - Transaction-Local Dirty Key Buffer
414// ============================================================================
415
416/// Transaction-local dirty key buffer for commit-time batching
417///
418/// Instead of tracking dirty keys globally during the transaction,
419/// this buffer accumulates them locally and flushes once at commit.
420/// This is simpler than the thread-local approach and works well
421/// for single-threaded transaction execution.
422pub struct TxnDirtyBuffer {
423    /// Transaction ID
424    txn_id: u64,
425    /// Accumulated dirty key fingerprints
426    keys: Vec<KeyFingerprint>,
427}
428
429impl TxnDirtyBuffer {
430    /// Create a new transaction dirty buffer
431    #[inline]
432    pub fn new(txn_id: u64) -> Self {
433        Self {
434            txn_id,
435            keys: Vec::with_capacity(64),
436        }
437    }
438
439    /// Create with expected capacity
440    #[inline]
441    pub fn with_capacity(txn_id: u64, capacity: usize) -> Self {
442        Self {
443            txn_id,
444            keys: Vec::with_capacity(capacity),
445        }
446    }
447
448    /// Record a dirty key (no lock, just local append)
449    #[inline]
450    pub fn record(&mut self, key_fingerprint: KeyFingerprint) {
451        self.keys.push(key_fingerprint);
452    }
453
454    /// Record multiple dirty keys
455    #[inline]
456    pub fn record_many(&mut self, key_fingerprints: impl IntoIterator<Item = KeyFingerprint>) {
457        self.keys.extend(key_fingerprints);
458    }
459
460    /// Get the transaction ID
461    #[inline]
462    pub fn txn_id(&self) -> u64 {
463        self.txn_id
464    }
465
466    /// Get the number of dirty keys
467    #[inline]
468    pub fn len(&self) -> usize {
469        self.keys.len()
470    }
471
472    /// Check if empty
473    #[inline]
474    pub fn is_empty(&self) -> bool {
475        self.keys.is_empty()
476    }
477
478    /// Drain the buffer and return all keys
479    #[inline]
480    pub fn drain(&mut self) -> Vec<KeyFingerprint> {
481        std::mem::take(&mut self.keys)
482    }
483
484    /// Flush to a BatchedDirtyTracker
485    #[inline]
486    pub fn flush_to(&mut self, tracker: &BatchedDirtyTracker) {
487        if !self.keys.is_empty() {
488            tracker.send_batch(self.txn_id, std::mem::take(&mut self.keys));
489            self.keys = Vec::with_capacity(64);
490        }
491    }
492
493    /// Clear the buffer without flushing
494    #[inline]
495    pub fn clear(&mut self) {
496        self.keys.clear();
497    }
498}
499
500// ============================================================================
501// Tests
502// ============================================================================
503
504#[cfg(test)]
505mod tests {
506    use super::*;
507    use std::time::Duration;
508
509    #[test]
510    fn test_txn_dirty_buffer() {
511        let mut buffer = TxnDirtyBuffer::new(1);
512        
513        buffer.record(KeyFingerprint::from_bytes(b"key1"));
514        buffer.record(KeyFingerprint::from_bytes(b"key2"));
515        buffer.record(KeyFingerprint::from_bytes(b"key3"));
516        
517        assert_eq!(buffer.len(), 3);
518        
519        let keys = buffer.drain();
520        assert_eq!(keys.len(), 3);
521        assert!(buffer.is_empty());
522    }
523
524    #[test]
525    fn test_batched_tracker_basic() {
526        let tracker = BatchedDirtyTracker::new();
527        tracker.start();
528        
529        // Send some events directly
530        tracker.send_batch(1, vec![
531            KeyFingerprint::from_bytes(b"key1"),
532            KeyFingerprint::from_bytes(b"key2"),
533        ]);
534        
535        // Give aggregator time to process
536        thread::sleep(Duration::from_millis(50));
537        
538        // Advance epoch to collect
539        let (_epoch, keys) = tracker.advance_epoch();
540        
541        // Keys should have been processed
542        assert!(tracker.stats().batches_received.load(Ordering::Relaxed) >= 1);
543        
544        tracker.stop();
545    }
546
547    #[test]
548    fn test_epoch_rotation() {
549        let tracker = BatchedDirtyTracker::new();
550        
551        // Directly insert into epochs without starting the aggregator thread
552        {
553            let mut guard = tracker.epochs[0].lock();
554            guard.insert(KeyFingerprint::from_bytes(b"key1"));
555            guard.insert(KeyFingerprint::from_bytes(b"key2"));
556        }
557        
558        let (epoch, keys) = tracker.advance_epoch();
559        assert_eq!(epoch, 0);
560        assert_eq!(keys.len(), 2);
561        
562        // New epoch should be empty
563        let (epoch2, keys2) = tracker.advance_epoch();
564        assert_eq!(epoch2, 1);
565        assert!(keys2.is_empty());
566    }
567}