Skip to main content

sochdb_storage/
dirty_tracking.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Batched Dirty Tracking with MPSC Queue
19//!
20//! This module replaces per-write mutex acquisition for dirty tracking
21//! with a batched MPSC approach that dramatically reduces lock contention.
22//!
23//! ## Problem: Lock Convoying
24//!
25//! Per-write mutex acquisition is the canonical scalability killer:
26//! - Serializes otherwise-parallel writers
27//! - Causes lock convoying under contention
28//! - N writers → N lock acquisitions per batch
29//!
30//! ## Solution: Batched MPSC + Thread-Local Buffering
31//!
32//! ```text
33//! ┌────────────────────────────────────────────────────────────────┐
34//! │                    Writer Threads                               │
35//! │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐       │
36//! │  │ Thread 1 │  │ Thread 2 │  │ Thread 3 │  │ Thread N │       │
37//! │  │  Buffer  │  │  Buffer  │  │  Buffer  │  │  Buffer  │       │
38//! │  └────┬─────┘  └────┬─────┘  └────┬─────┘  └────┬─────┘       │
39//! │       │             │             │             │              │
40//! │       └─────────────┼─────────────┼─────────────┘              │
41//! │                     │             │                            │
42//! │                     ▼             ▼                            │
43//! │              ┌──────────────────────────────┐                  │
44//! │              │      MPSC Channel            │                  │
45//! │              │  (crossbeam-channel)         │                  │
46//! │              └──────────────┬───────────────┘                  │
47//! │                             │                                  │
48//! │                             ▼                                  │
49//! │              ┌──────────────────────────────┐                  │
50//! │              │   Aggregator Thread          │                  │
51//! │              │   (drains every 10ms or     │                  │
52//! │              │    every 1000 entries)       │                  │
53//! │              └──────────────────────────────┘                  │
54//! └────────────────────────────────────────────────────────────────┘
55//!
56//! Lock acquisitions: O(W) → O(W/B) where W=writes, B=batch size
57//! ```
58//!
59//! ## Performance
60//!
61//! - Thread-local buffer: Zero contention during writes
62//! - MPSC send: ~20ns (vs ~200ns for mutex under contention)
63//! - Batch flush: Amortized over B writes
64//!
65//! ## Usage
66//!
67//! ```ignore
68//! let tracker = BatchedDirtyTracker::new();
69//!
70//! // In transaction:
71//! tracker.mark_dirty(txn_id, key_fingerprint);  // Lock-free, buffers locally
72//!
73//! // At commit:
74//! tracker.flush_thread_buffer();   // Sends batch to aggregator
75//! ```
76
77use std::cell::RefCell;
78use std::collections::HashSet;
79use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
80use std::sync::Arc;
81use std::thread::{self, JoinHandle};
82
83use crossbeam_channel::{self, Receiver, Sender};
84use parking_lot::Mutex;
85
86use crate::txn_arena::KeyFingerprint;
87
88/// Default batch size for thread-local buffer
89const DEFAULT_BATCH_SIZE: usize = 64;
90
91/// Maximum wait time before flushing (in milliseconds)
92const MAX_FLUSH_INTERVAL_MS: u64 = 10;
93
94// ============================================================================
95// DirtyEvent - Batched Dirty Key Notification
96// ============================================================================
97
98/// Event sent through MPSC channel
99#[derive(Debug)]
100pub enum DirtyEvent {
101    /// Batch of dirty key fingerprints from a transaction
102    Batch {
103        txn_id: u64,
104        keys: Vec<KeyFingerprint>,
105    },
106    /// Epoch advance request
107    AdvanceEpoch,
108    /// Shutdown signal
109    Shutdown,
110}
111
112// ============================================================================
113// ThreadLocalBuffer - Per-Thread Dirty Key Accumulator
114// ============================================================================
115
116/// Thread-local buffer for accumulating dirty keys
117struct ThreadLocalBuffer {
118    /// Current transaction ID
119    txn_id: u64,
120    /// Accumulated dirty key fingerprints
121    keys: Vec<KeyFingerprint>,
122    /// Sender to aggregator
123    sender: Sender<DirtyEvent>,
124}
125
126impl ThreadLocalBuffer {
127    fn new(sender: Sender<DirtyEvent>) -> Self {
128        Self {
129            txn_id: 0,
130            keys: Vec::with_capacity(DEFAULT_BATCH_SIZE),
131            sender,
132        }
133    }
134
135    /// Mark a key as dirty (no lock, no send)
136    #[inline]
137    fn mark_dirty(&mut self, txn_id: u64, key_fingerprint: KeyFingerprint) {
138        if self.txn_id != txn_id {
139            // New transaction - flush old buffer if any
140            self.flush();
141            self.txn_id = txn_id;
142        }
143        self.keys.push(key_fingerprint);
144    }
145
146    /// Flush accumulated keys to aggregator
147    fn flush(&mut self) {
148        if !self.keys.is_empty() {
149            let keys = std::mem::take(&mut self.keys);
150            // Best-effort send - don't block if channel is full
151            let _ = self.sender.try_send(DirtyEvent::Batch {
152                txn_id: self.txn_id,
153                keys,
154            });
155            self.keys = Vec::with_capacity(DEFAULT_BATCH_SIZE);
156        }
157    }
158}
159
160// ============================================================================
161// BatchedDirtyTracker - Lock-Free Dirty Tracking
162// ============================================================================
163
164/// Batched dirty tracker with MPSC queue
165///
166/// Provides lock-free dirty tracking for multi-threaded writes.
167/// Each thread accumulates dirty keys locally, then sends them
168/// in batches through an MPSC channel to an aggregator.
169pub struct BatchedDirtyTracker {
170    /// MPSC sender (cloned for each thread)
171    sender: Sender<DirtyEvent>,
172    /// MPSC receiver (for aggregator thread)
173    receiver: Receiver<DirtyEvent>,
174    /// Aggregator thread handle
175    aggregator_handle: Mutex<Option<JoinHandle<()>>>,
176    /// Running flag
177    running: AtomicBool,
178    /// Current epoch
179    current_epoch: AtomicU64,
180    /// Aggregated dirty keys per epoch
181    epochs: [Mutex<HashSet<KeyFingerprint>>; 4],
182    /// Statistics
183    stats: DirtyTrackingStats,
184}
185
/// Dirty tracking statistics
///
/// All counters start at zero (see the derived `Default`) and are plain
/// atomics, so they can be read through a shared reference.
#[derive(Debug, Default)]
pub struct DirtyTrackingStats {
    /// Total events received
    pub events_received: AtomicU64,
    /// Total keys tracked
    pub keys_tracked: AtomicU64,
    /// Total batches received
    pub batches_received: AtomicU64,
    /// Current epoch
    pub current_epoch: AtomicU64,
}
208
209const EPOCH_RING_SIZE: usize = 4;
210
211impl BatchedDirtyTracker {
212    /// Create a new batched dirty tracker
213    pub fn new() -> Arc<Self> {
214        let (sender, receiver) = crossbeam_channel::bounded(1024);
215        
216        let tracker = Arc::new(Self {
217            sender,
218            receiver,
219            aggregator_handle: Mutex::new(None),
220            running: AtomicBool::new(false),
221            current_epoch: AtomicU64::new(0),
222            epochs: [
223                Mutex::new(HashSet::new()),
224                Mutex::new(HashSet::new()),
225                Mutex::new(HashSet::new()),
226                Mutex::new(HashSet::new()),
227            ],
228            stats: DirtyTrackingStats::default(),
229        });
230        
231        tracker
232    }
233
234    /// Start the aggregator thread
235    pub fn start(self: &Arc<Self>) {
236        if self.running.swap(true, Ordering::SeqCst) {
237            return; // Already running
238        }
239
240        let tracker = Arc::clone(self);
241        let handle = thread::spawn(move || {
242            tracker.aggregator_loop();
243        });
244
245        *self.aggregator_handle.lock() = Some(handle);
246    }
247
248    /// Stop the aggregator thread
249    pub fn stop(&self) {
250        if !self.running.swap(false, Ordering::SeqCst) {
251            return; // Already stopped
252        }
253
254        // Send shutdown signal
255        let _ = self.sender.send(DirtyEvent::Shutdown);
256
257        // Wait for aggregator to finish
258        if let Some(handle) = self.aggregator_handle.lock().take() {
259            let _ = handle.join();
260        }
261    }
262
263    /// Get a sender for a thread to use
264    pub fn get_sender(&self) -> Sender<DirtyEvent> {
265        self.sender.clone()
266    }
267
268    /// Mark a key as dirty using a thread-local buffer
269    ///
270    /// This is the zero-contention hot path used by writers.
271    #[inline]
272    pub fn mark_dirty(&self, txn_id: u64, key_fingerprint: KeyFingerprint) {
273        thread_local! {
274            static BUFFER: RefCell<Option<ThreadLocalBuffer>> = const { RefCell::new(None) };
275        }
276
277        BUFFER.with(|cell| {
278            let mut buffer = cell.borrow_mut();
279            if buffer.is_none() {
280                *buffer = Some(ThreadLocalBuffer::new(self.sender.clone()));
281            }
282            buffer.as_mut().unwrap().mark_dirty(txn_id, key_fingerprint);
283        });
284    }
285
286    /// Flush the current thread's buffer
287    pub fn flush_thread_buffer(&self) {
288        thread_local! {
289            static BUFFER: RefCell<Option<ThreadLocalBuffer>> = const { RefCell::new(None) };
290        }
291
292        BUFFER.with(|cell| {
293            if let Some(buffer) = cell.borrow_mut().as_mut() {
294                buffer.flush();
295            }
296        });
297    }
298
299    /// Send a batch of dirty keys directly (for transaction commit)
300    #[inline]
301    pub fn send_batch(&self, txn_id: u64, keys: Vec<KeyFingerprint>) {
302        if keys.is_empty() {
303            return;
304        }
305        let _ = self.sender.try_send(DirtyEvent::Batch { txn_id, keys });
306    }
307
308    /// Advance to next epoch, returning the old epoch's dirty keys
309    pub fn advance_epoch(&self) -> (u64, Vec<KeyFingerprint>) {
310        // Send epoch advance event to ensure all pending events are processed
311        let _ = self.sender.try_send(DirtyEvent::AdvanceEpoch);
312        
313        let old_epoch = self.current_epoch.fetch_add(1, Ordering::SeqCst);
314        let old_idx = (old_epoch as usize) % EPOCH_RING_SIZE;
315        
316        // Drain the old epoch
317        let mut guard = self.epochs[old_idx].lock();
318        let keys: Vec<_> = guard.drain().collect();
319        
320        self.stats.current_epoch.store(old_epoch + 1, Ordering::Relaxed);
321        
322        (old_epoch, keys)
323    }
324
325    /// Get current epoch
326    pub fn current_epoch(&self) -> u64 {
327        self.current_epoch.load(Ordering::Relaxed)
328    }
329
330    /// Get statistics
331    pub fn stats(&self) -> &DirtyTrackingStats {
332        &self.stats
333    }
334
335    /// Aggregator loop - runs in background thread
336    fn aggregator_loop(&self) {
337        use crossbeam_channel::RecvTimeoutError;
338        
339        let timeout = std::time::Duration::from_millis(MAX_FLUSH_INTERVAL_MS);
340        
341        while self.running.load(Ordering::Relaxed) {
342            match self.receiver.recv_timeout(timeout) {
343                Ok(event) => {
344                    self.process_event(event);
345                }
346                Err(RecvTimeoutError::Timeout) => {
347                    // No events for a while - that's fine
348                }
349                Err(RecvTimeoutError::Disconnected) => {
350                    break;
351                }
352            }
353        }
354        
355        // Drain remaining events on shutdown
356        while let Ok(event) = self.receiver.try_recv() {
357            if matches!(event, DirtyEvent::Shutdown) {
358                break;
359            }
360            self.process_event(event);
361        }
362    }
363
364    /// Process a single event
365    fn process_event(&self, event: DirtyEvent) {
366        match event {
367            DirtyEvent::Batch { txn_id: _, keys } => {
368                let epoch = self.current_epoch.load(Ordering::Relaxed);
369                let idx = (epoch as usize) % EPOCH_RING_SIZE;
370                
371                let mut guard = self.epochs[idx].lock();
372                let key_count = keys.len();
373                guard.extend(keys);
374                
375                self.stats.events_received.fetch_add(1, Ordering::Relaxed);
376                self.stats.keys_tracked.fetch_add(key_count as u64, Ordering::Relaxed);
377                self.stats.batches_received.fetch_add(1, Ordering::Relaxed);
378            }
379            DirtyEvent::AdvanceEpoch => {
380                // Epoch advance is handled by the caller
381            }
382            DirtyEvent::Shutdown => {
383                // Will exit the loop
384            }
385        }
386    }
387}
388
389impl Default for BatchedDirtyTracker {
390    fn default() -> Self {
391        let (sender, receiver) = crossbeam_channel::bounded(1024);
392        Self {
393            sender,
394            receiver,
395            aggregator_handle: Mutex::new(None),
396            running: AtomicBool::new(false),
397            current_epoch: AtomicU64::new(0),
398            epochs: [
399                Mutex::new(HashSet::new()),
400                Mutex::new(HashSet::new()),
401                Mutex::new(HashSet::new()),
402                Mutex::new(HashSet::new()),
403            ],
404            stats: DirtyTrackingStats::default(),
405        }
406    }
407}
408
409impl Drop for BatchedDirtyTracker {
410    fn drop(&mut self) {
411        self.stop();
412    }
413}
414
415// ============================================================================
416// TxnDirtyBuffer - Transaction-Local Dirty Key Buffer
417// ============================================================================
418
419/// Transaction-local dirty key buffer for commit-time batching
420///
421/// Instead of tracking dirty keys globally during the transaction,
422/// this buffer accumulates them locally and flushes once at commit.
423/// This is simpler than the thread-local approach and works well
424/// for single-threaded transaction execution.
425pub struct TxnDirtyBuffer {
426    /// Transaction ID
427    txn_id: u64,
428    /// Accumulated dirty key fingerprints
429    keys: Vec<KeyFingerprint>,
430}
431
432impl TxnDirtyBuffer {
433    /// Create a new transaction dirty buffer
434    #[inline]
435    pub fn new(txn_id: u64) -> Self {
436        Self {
437            txn_id,
438            keys: Vec::with_capacity(64),
439        }
440    }
441
442    /// Create with expected capacity
443    #[inline]
444    pub fn with_capacity(txn_id: u64, capacity: usize) -> Self {
445        Self {
446            txn_id,
447            keys: Vec::with_capacity(capacity),
448        }
449    }
450
451    /// Record a dirty key (no lock, just local append)
452    #[inline]
453    pub fn record(&mut self, key_fingerprint: KeyFingerprint) {
454        self.keys.push(key_fingerprint);
455    }
456
457    /// Record multiple dirty keys
458    #[inline]
459    pub fn record_many(&mut self, key_fingerprints: impl IntoIterator<Item = KeyFingerprint>) {
460        self.keys.extend(key_fingerprints);
461    }
462
463    /// Get the transaction ID
464    #[inline]
465    pub fn txn_id(&self) -> u64 {
466        self.txn_id
467    }
468
469    /// Get the number of dirty keys
470    #[inline]
471    pub fn len(&self) -> usize {
472        self.keys.len()
473    }
474
475    /// Check if empty
476    #[inline]
477    pub fn is_empty(&self) -> bool {
478        self.keys.is_empty()
479    }
480
481    /// Drain the buffer and return all keys
482    #[inline]
483    pub fn drain(&mut self) -> Vec<KeyFingerprint> {
484        std::mem::take(&mut self.keys)
485    }
486
487    /// Flush to a BatchedDirtyTracker
488    #[inline]
489    pub fn flush_to(&mut self, tracker: &BatchedDirtyTracker) {
490        if !self.keys.is_empty() {
491            tracker.send_batch(self.txn_id, std::mem::take(&mut self.keys));
492            self.keys = Vec::with_capacity(64);
493        }
494    }
495
496    /// Clear the buffer without flushing
497    #[inline]
498    pub fn clear(&mut self) {
499        self.keys.clear();
500    }
501}
502
503// ============================================================================
504// Tests
505// ============================================================================
506
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    /// Recording and draining a transaction-local buffer.
    #[test]
    fn test_txn_dirty_buffer() {
        let mut buffer = TxnDirtyBuffer::new(1);
        assert_eq!(buffer.txn_id(), 1);
        assert!(buffer.is_empty());

        buffer.record(KeyFingerprint::from_bytes(b"key1"));
        buffer.record(KeyFingerprint::from_bytes(b"key2"));
        buffer.record(KeyFingerprint::from_bytes(b"key3"));

        assert_eq!(buffer.len(), 3);

        let keys = buffer.drain();
        assert_eq!(keys.len(), 3);
        assert!(buffer.is_empty());
    }

    /// A batch sent to a running tracker is aggregated into the current epoch.
    #[test]
    fn test_batched_tracker_basic() {
        let tracker = BatchedDirtyTracker::new();
        tracker.start();

        // Send some events directly
        tracker.send_batch(
            1,
            vec![
                KeyFingerprint::from_bytes(b"key1"),
                KeyFingerprint::from_bytes(b"key2"),
            ],
        );

        // Poll with a deadline instead of a fixed sleep so the test neither
        // flakes on a slow machine nor waits longer than necessary.
        let deadline = Instant::now() + Duration::from_secs(5);
        while tracker.stats().batches_received.load(Ordering::Relaxed) == 0 {
            assert!(
                Instant::now() < deadline,
                "aggregator did not process the batch in time"
            );
            thread::sleep(Duration::from_millis(1));
        }

        // Advance epoch to collect what the aggregator stored
        let (epoch, keys) = tracker.advance_epoch();
        assert_eq!(epoch, 0);
        assert_eq!(keys.len(), 2);
        assert_eq!(tracker.stats().keys_tracked.load(Ordering::Relaxed), 2);

        tracker.stop();
    }

    /// Advancing the epoch drains the old slot and moves to an empty one.
    #[test]
    fn test_epoch_rotation() {
        let tracker = BatchedDirtyTracker::new();

        // Directly insert into epochs without starting the aggregator thread
        {
            let mut guard = tracker.epochs[0].lock();
            guard.insert(KeyFingerprint::from_bytes(b"key1"));
            guard.insert(KeyFingerprint::from_bytes(b"key2"));
        }

        let (epoch, keys) = tracker.advance_epoch();
        assert_eq!(epoch, 0);
        assert_eq!(keys.len(), 2);

        // New epoch should be empty
        let (epoch2, keys2) = tracker.advance_epoch();
        assert_eq!(epoch2, 1);
        assert!(keys2.is_empty());
        assert_eq!(tracker.current_epoch(), 2);
    }
}