Skip to main content

sochdb_storage/
dirty_tracking.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Batched Dirty Tracking with MPSC Queue
19//!
20//! This module replaces per-write mutex acquisition for dirty tracking
21//! with a batched MPSC approach that dramatically reduces lock contention.
22//!
23//! ## Problem: Lock Convoying
24//!
25//! Per-write mutex acquisition is the canonical scalability killer:
26//! - Serializes otherwise-parallel writers
27//! - Causes lock convoying under contention
28//! - N writers → N lock acquisitions per batch
29//!
30//! ## Solution: Batched MPSC + Thread-Local Buffering
31//!
32//! ```text
33//! ┌────────────────────────────────────────────────────────────────┐
34//! │                    Writer Threads                               │
35//! │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐       │
36//! │  │ Thread 1 │  │ Thread 2 │  │ Thread 3 │  │ Thread N │       │
37//! │  │  Buffer  │  │  Buffer  │  │  Buffer  │  │  Buffer  │       │
38//! │  └────┬─────┘  └────┬─────┘  └────┬─────┘  └────┬─────┘       │
39//! │       │             │             │             │              │
40//! │       └─────────────┼─────────────┼─────────────┘              │
41//! │                     │             │                            │
42//! │                     ▼             ▼                            │
43//! │              ┌──────────────────────────────┐                  │
44//! │              │      MPSC Channel            │                  │
45//! │              │  (crossbeam-channel)         │                  │
46//! │              └──────────────┬───────────────┘                  │
47//! │                             │                                  │
48//! │                             ▼                                  │
49//! │              ┌──────────────────────────────┐                  │
50//! │              │   Aggregator Thread          │                  │
//! │              │   (one event at a time;      │                  │
//! │              │    10ms recv timeout)        │                  │
53//! │              └──────────────────────────────┘                  │
54//! └────────────────────────────────────────────────────────────────┘
55//!
56//! Lock acquisitions: O(W) → O(W/B) where W=writes, B=batch size
57//! ```
58//!
59//! ## Performance
60//!
61//! - Thread-local buffer: Zero contention during writes
62//! - MPSC send: ~20ns (vs ~200ns for mutex under contention)
63//! - Batch flush: Amortized over B writes
64//!
65//! ## Usage
66//!
67//! ```ignore
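//! let tracker = BatchedDirtyTracker::new(); // Arc<BatchedDirtyTracker>
//! tracker.start();                           // Spawn the supervised aggregator
//!
//! // In transaction:
//! tracker.mark_dirty(txn_id, key);  // No lock; buffers in a thread-local Vec
//!
//! // At commit:
//! tracker.flush_thread_buffer();    // Sends the buffered batch to the aggregator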
75//! ```
76
77use std::cell::RefCell;
78use std::collections::HashSet;
79use std::sync::Arc;
80use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
81use std::thread;
82
83use crossbeam_channel::{self, Receiver, Sender};
84use parking_lot::Mutex;
85
86use crate::supervisor::{SupervisedWorker, Supervisor, WorkerHealth, WorkerStep};
87use crate::txn_arena::KeyFingerprint;
88
/// Initial capacity of dirty-key buffers.
///
/// This only pre-sizes the buffer; it is not a flush threshold. A buffer keeps
/// growing until it is flushed explicitly or a new transaction starts on it.
const DEFAULT_BATCH_SIZE: usize = 64;

/// Receive timeout, in milliseconds, for a single aggregator step.
///
/// Events are applied as soon as they are received; this only bounds how long one
/// step blocks waiting for an event before handing control back to the supervisor.
const MAX_FLUSH_INTERVAL_MS: u64 = 10;
94
95// ============================================================================
96// DirtyEvent - Batched Dirty Key Notification
97// ============================================================================
98
99/// Event sent through MPSC channel
100#[derive(Debug)]
101pub enum DirtyEvent {
102    /// Batch of dirty key fingerprints from a transaction
103    Batch {
104        txn_id: u64,
105        keys: Vec<KeyFingerprint>,
106    },
107    /// Epoch advance request
108    AdvanceEpoch,
109    /// Shutdown signal
110    Shutdown,
111}
112
113// ============================================================================
114// ThreadLocalBuffer - Per-Thread Dirty Key Accumulator
115// ============================================================================
116
117/// Thread-local buffer for accumulating dirty keys
118struct ThreadLocalBuffer {
119    /// Current transaction ID
120    txn_id: u64,
121    /// Accumulated dirty key fingerprints
122    keys: Vec<KeyFingerprint>,
123    /// Sender to aggregator
124    sender: Sender<DirtyEvent>,
125}
126
127impl ThreadLocalBuffer {
128    fn new(sender: Sender<DirtyEvent>) -> Self {
129        Self {
130            txn_id: 0,
131            keys: Vec::with_capacity(DEFAULT_BATCH_SIZE),
132            sender,
133        }
134    }
135
136    /// Mark a key as dirty (no lock, no send)
137    #[inline]
138    fn mark_dirty(&mut self, txn_id: u64, key_fingerprint: KeyFingerprint) {
139        if self.txn_id != txn_id {
140            // New transaction - flush old buffer if any
141            self.flush();
142            self.txn_id = txn_id;
143        }
144        self.keys.push(key_fingerprint);
145    }
146
147    /// Flush accumulated keys to aggregator
148    fn flush(&mut self) {
149        if !self.keys.is_empty() {
150            let keys = std::mem::take(&mut self.keys);
151            // Best-effort send - don't block if channel is full
152            let _ = self.sender.try_send(DirtyEvent::Batch {
153                txn_id: self.txn_id,
154                keys,
155            });
156            self.keys = Vec::with_capacity(DEFAULT_BATCH_SIZE);
157        }
158    }
159}
160
161// ============================================================================
162// BatchedDirtyTracker - Lock-Free Dirty Tracking
163// ============================================================================
164
165/// Batched dirty tracker with MPSC queue
166///
167/// Provides lock-free dirty tracking for multi-threaded writes.
168/// Each thread accumulates dirty keys locally, then sends them
169/// in batches through an MPSC channel to an aggregator.
170pub struct BatchedDirtyTracker {
171    /// MPSC sender (cloned for each thread)
172    sender: Sender<DirtyEvent>,
173    /// MPSC receiver (for aggregator thread)
174    receiver: Receiver<DirtyEvent>,
175    /// Supervised aggregator worker handle (panic-contained, auto-restarting)
176    aggregator_handle: Mutex<Option<SupervisedWorker>>,
177    /// Running flag (shared with the supervised worker for cooperative shutdown)
178    running: Arc<AtomicBool>,
179    /// Current epoch
180    current_epoch: AtomicU64,
181    /// Aggregated dirty keys per epoch
182    epochs: [Mutex<HashSet<KeyFingerprint>>; 4],
183    /// Statistics
184    stats: DirtyTrackingStats,
185}
186
/// Dirty tracking statistics.
///
/// Plain atomic counters, all starting at zero. The tracker updates them with
/// relaxed ordering, so treat each one as an independent, approximate reading.
#[derive(Debug, Default)]
pub struct DirtyTrackingStats {
    /// Events applied by the aggregator (currently only `Batch` events are counted)
    pub events_received: AtomicU64,
    /// Total key fingerprints received in batches (duplicates included)
    pub keys_tracked: AtomicU64,
    /// Total batches received
    pub batches_received: AtomicU64,
    /// Mirror of the tracker's current epoch, updated when the epoch advances
    pub current_epoch: AtomicU64,
}
209
210const EPOCH_RING_SIZE: usize = 4;
211
212impl BatchedDirtyTracker {
213    /// Create a new batched dirty tracker
214    pub fn new() -> Arc<Self> {
215        let (sender, receiver) = crossbeam_channel::bounded(1024);
216
217        let tracker = Arc::new(Self {
218            sender,
219            receiver,
220            aggregator_handle: Mutex::new(None),
221            running: Arc::new(AtomicBool::new(false)),
222            current_epoch: AtomicU64::new(0),
223            epochs: [
224                Mutex::new(HashSet::new()),
225                Mutex::new(HashSet::new()),
226                Mutex::new(HashSet::new()),
227                Mutex::new(HashSet::new()),
228            ],
229            stats: DirtyTrackingStats::default(),
230        });
231
232        tracker
233    }
234
235    /// Start the aggregator thread.
236    ///
237    /// The aggregator runs under a [`Supervisor`]: if processing a single event
238    /// panics (e.g. a poisoned lock), the panic is contained and counted and the
239    /// loop restarts rather than silently dying and stalling dirty tracking.
240    pub fn start(self: &Arc<Self>) {
241        if self.running.swap(true, Ordering::SeqCst) {
242            return; // Already running
243        }
244
245        let tracker = Arc::clone(self);
246        let running = Arc::clone(&self.running);
247        let worker =
248            Supervisor::new("dirty-aggregator").spawn(running, move || tracker.aggregator_step());
249
250        *self.aggregator_handle.lock() = Some(worker);
251    }
252
253    /// Health/liveness counters for the supervised aggregator, if running.
254    pub fn aggregator_health(&self) -> Option<WorkerHealth> {
255        self.aggregator_handle.lock().as_ref().map(|w| w.health())
256    }
257
258    /// Stop the aggregator thread
259    pub fn stop(&self) {
260        if !self.running.swap(false, Ordering::SeqCst) {
261            return; // Already stopped
262        }
263
264        // Send shutdown signal to unblock a parked recv immediately.
265        let _ = self.sender.send(DirtyEvent::Shutdown);
266
267        // Wait for the supervised aggregator to finish.
268        if let Some(worker) = self.aggregator_handle.lock().take() {
269            worker.join();
270        }
271
272        // Drain any remaining events so observers see a consistent final state.
273        while let Ok(event) = self.receiver.try_recv() {
274            if matches!(event, DirtyEvent::Shutdown) {
275                break;
276            }
277            self.process_event(event);
278        }
279    }
280
281    /// Get a sender for a thread to use
282    pub fn get_sender(&self) -> Sender<DirtyEvent> {
283        self.sender.clone()
284    }
285
286    /// Mark a key as dirty using a thread-local buffer
287    ///
288    /// This is the zero-contention hot path used by writers.
289    #[inline]
290    pub fn mark_dirty(&self, txn_id: u64, key_fingerprint: KeyFingerprint) {
291        thread_local! {
292            static BUFFER: RefCell<Option<ThreadLocalBuffer>> = const { RefCell::new(None) };
293        }
294
295        BUFFER.with(|cell| {
296            let mut buffer = cell.borrow_mut();
297            if buffer.is_none() {
298                *buffer = Some(ThreadLocalBuffer::new(self.sender.clone()));
299            }
300            buffer.as_mut().unwrap().mark_dirty(txn_id, key_fingerprint);
301        });
302    }
303
304    /// Flush the current thread's buffer
305    pub fn flush_thread_buffer(&self) {
306        thread_local! {
307            static BUFFER: RefCell<Option<ThreadLocalBuffer>> = const { RefCell::new(None) };
308        }
309
310        BUFFER.with(|cell| {
311            if let Some(buffer) = cell.borrow_mut().as_mut() {
312                buffer.flush();
313            }
314        });
315    }
316
317    /// Send a batch of dirty keys directly (for transaction commit)
318    #[inline]
319    pub fn send_batch(&self, txn_id: u64, keys: Vec<KeyFingerprint>) {
320        if keys.is_empty() {
321            return;
322        }
323        let _ = self.sender.try_send(DirtyEvent::Batch { txn_id, keys });
324    }
325
326    /// Advance to next epoch, returning the old epoch's dirty keys
327    pub fn advance_epoch(&self) -> (u64, Vec<KeyFingerprint>) {
328        // Send epoch advance event to ensure all pending events are processed
329        let _ = self.sender.try_send(DirtyEvent::AdvanceEpoch);
330
331        let old_epoch = self.current_epoch.fetch_add(1, Ordering::SeqCst);
332        let old_idx = (old_epoch as usize) % EPOCH_RING_SIZE;
333
334        // Drain the old epoch
335        let mut guard = self.epochs[old_idx].lock();
336        let keys: Vec<_> = guard.drain().collect();
337
338        self.stats
339            .current_epoch
340            .store(old_epoch + 1, Ordering::Relaxed);
341
342        (old_epoch, keys)
343    }
344
345    /// Get current epoch
346    pub fn current_epoch(&self) -> u64 {
347        self.current_epoch.load(Ordering::Relaxed)
348    }
349
350    /// Get statistics
351    pub fn stats(&self) -> &DirtyTrackingStats {
352        &self.stats
353    }
354
355    /// One supervised iteration of the aggregator loop.
356    ///
357    /// Blocks up to `MAX_FLUSH_INTERVAL_MS` for the next event, processes it, and
358    /// reports whether the worker should continue. A panic inside
359    /// [`process_event`](Self::process_event) is contained by the [`Supervisor`].
360    fn aggregator_step(&self) -> WorkerStep {
361        use crossbeam_channel::RecvTimeoutError;
362
363        let timeout = std::time::Duration::from_millis(MAX_FLUSH_INTERVAL_MS);
364        match self.receiver.recv_timeout(timeout) {
365            Ok(DirtyEvent::Shutdown) => WorkerStep::Stop,
366            Ok(event) => {
367                self.process_event(event);
368                WorkerStep::Continue
369            }
370            Err(RecvTimeoutError::Timeout) => WorkerStep::Continue,
371            Err(RecvTimeoutError::Disconnected) => WorkerStep::Stop,
372        }
373    }
374
375    /// Process a single event
376    fn process_event(&self, event: DirtyEvent) {
377        match event {
378            DirtyEvent::Batch { txn_id: _, keys } => {
379                let epoch = self.current_epoch.load(Ordering::Relaxed);
380                let idx = (epoch as usize) % EPOCH_RING_SIZE;
381
382                let mut guard = self.epochs[idx].lock();
383                let key_count = keys.len();
384                guard.extend(keys);
385
386                self.stats.events_received.fetch_add(1, Ordering::Relaxed);
387                self.stats
388                    .keys_tracked
389                    .fetch_add(key_count as u64, Ordering::Relaxed);
390                self.stats.batches_received.fetch_add(1, Ordering::Relaxed);
391            }
392            DirtyEvent::AdvanceEpoch => {
393                // Epoch advance is handled by the caller
394            }
395            DirtyEvent::Shutdown => {
396                // Will exit the loop
397            }
398        }
399    }
400}
401
402impl Default for BatchedDirtyTracker {
403    fn default() -> Self {
404        let (sender, receiver) = crossbeam_channel::bounded(1024);
405        Self {
406            sender,
407            receiver,
408            aggregator_handle: Mutex::new(None),
409            running: Arc::new(AtomicBool::new(false)),
410            current_epoch: AtomicU64::new(0),
411            epochs: [
412                Mutex::new(HashSet::new()),
413                Mutex::new(HashSet::new()),
414                Mutex::new(HashSet::new()),
415                Mutex::new(HashSet::new()),
416            ],
417            stats: DirtyTrackingStats::default(),
418        }
419    }
420}
421
422impl Drop for BatchedDirtyTracker {
423    fn drop(&mut self) {
424        self.stop();
425    }
426}
427
428// ============================================================================
429// TxnDirtyBuffer - Transaction-Local Dirty Key Buffer
430// ============================================================================
431
432/// Transaction-local dirty key buffer for commit-time batching
433///
434/// Instead of tracking dirty keys globally during the transaction,
435/// this buffer accumulates them locally and flushes once at commit.
436/// This is simpler than the thread-local approach and works well
437/// for single-threaded transaction execution.
438pub struct TxnDirtyBuffer {
439    /// Transaction ID
440    txn_id: u64,
441    /// Accumulated dirty key fingerprints
442    keys: Vec<KeyFingerprint>,
443}
444
445impl TxnDirtyBuffer {
446    /// Create a new transaction dirty buffer
447    #[inline]
448    pub fn new(txn_id: u64) -> Self {
449        Self {
450            txn_id,
451            keys: Vec::with_capacity(64),
452        }
453    }
454
455    /// Create with expected capacity
456    #[inline]
457    pub fn with_capacity(txn_id: u64, capacity: usize) -> Self {
458        Self {
459            txn_id,
460            keys: Vec::with_capacity(capacity),
461        }
462    }
463
464    /// Record a dirty key (no lock, just local append)
465    #[inline]
466    pub fn record(&mut self, key_fingerprint: KeyFingerprint) {
467        self.keys.push(key_fingerprint);
468    }
469
470    /// Record multiple dirty keys
471    #[inline]
472    pub fn record_many(&mut self, key_fingerprints: impl IntoIterator<Item = KeyFingerprint>) {
473        self.keys.extend(key_fingerprints);
474    }
475
476    /// Get the transaction ID
477    #[inline]
478    pub fn txn_id(&self) -> u64 {
479        self.txn_id
480    }
481
482    /// Get the number of dirty keys
483    #[inline]
484    pub fn len(&self) -> usize {
485        self.keys.len()
486    }
487
488    /// Check if empty
489    #[inline]
490    pub fn is_empty(&self) -> bool {
491        self.keys.is_empty()
492    }
493
494    /// Drain the buffer and return all keys
495    #[inline]
496    pub fn drain(&mut self) -> Vec<KeyFingerprint> {
497        std::mem::take(&mut self.keys)
498    }
499
500    /// Flush to a BatchedDirtyTracker
501    #[inline]
502    pub fn flush_to(&mut self, tracker: &BatchedDirtyTracker) {
503        if !self.keys.is_empty() {
504            tracker.send_batch(self.txn_id, std::mem::take(&mut self.keys));
505            self.keys = Vec::with_capacity(64);
506        }
507    }
508
509    /// Clear the buffer without flushing
510    #[inline]
511    pub fn clear(&mut self) {
512        self.keys.clear();
513    }
514}
515
516// ============================================================================
517// Tests
518// ============================================================================
519
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    /// Wait until the aggregator has applied at least `expected` batches.
    /// The test fails instead of hanging if that never happens.
    fn wait_for_batches(tracker: &BatchedDirtyTracker, expected: u64) {
        let deadline = Instant::now() + Duration::from_secs(5);
        while tracker.stats().batches_received.load(Ordering::Relaxed) < expected {
            assert!(
                Instant::now() < deadline,
                "aggregator did not process {} batch(es) in time",
                expected
            );
            thread::sleep(Duration::from_millis(1));
        }
    }

    #[test]
    fn test_txn_dirty_buffer() {
        let mut buffer = TxnDirtyBuffer::new(1);

        buffer.record(KeyFingerprint::from_bytes(b"key1"));
        buffer.record(KeyFingerprint::from_bytes(b"key2"));
        buffer.record(KeyFingerprint::from_bytes(b"key3"));

        assert_eq!(buffer.len(), 3);

        let keys = buffer.drain();
        assert_eq!(keys.len(), 3);
        assert!(buffer.is_empty());
    }

    #[test]
    fn test_batched_tracker_basic() {
        let tracker = BatchedDirtyTracker::new();
        tracker.start();

        // Send some events directly
        tracker.send_batch(
            1,
            vec![
                KeyFingerprint::from_bytes(b"key1"),
                KeyFingerprint::from_bytes(b"key2"),
            ],
        );

        // Poll rather than sleep a fixed amount, so slow machines don't flake
        wait_for_batches(&tracker, 1);

        // Advance epoch to collect what the aggregator applied
        let (epoch, keys) = tracker.advance_epoch();
        assert_eq!(epoch, 0);
        assert_eq!(keys.len(), 2);

        tracker.stop();
    }

    #[test]
    fn test_epoch_rotation() {
        let tracker = BatchedDirtyTracker::new();

        // Directly insert into epochs without starting the aggregator thread
        {
            let mut guard = tracker.epochs[0].lock();
            guard.insert(KeyFingerprint::from_bytes(b"key1"));
            guard.insert(KeyFingerprint::from_bytes(b"key2"));
        }

        let (epoch, keys) = tracker.advance_epoch();
        assert_eq!(epoch, 0);
        assert_eq!(keys.len(), 2);

        // New epoch should be empty
        let (epoch2, keys2) = tracker.advance_epoch();
        assert_eq!(epoch2, 1);
        assert!(keys2.is_empty());
    }
}