// sochdb_storage/dirty_tracking.rs
1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Batched Dirty Tracking with MPSC Queue
19//!
20//! This module replaces per-write mutex acquisition for dirty tracking
21//! with a batched MPSC approach that dramatically reduces lock contention.
22//!
23//! ## Problem: Lock Convoying
24//!
25//! Per-write mutex acquisition is the canonical scalability killer:
26//! - Serializes otherwise-parallel writers
27//! - Causes lock convoying under contention
28//! - N writers → N lock acquisitions per batch
29//!
30//! ## Solution: Batched MPSC + Thread-Local Buffering
31//!
32//! ```text
33//! ┌────────────────────────────────────────────────────────────────┐
34//! │ Writer Threads │
35//! │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
36//! │ │ Thread 1 │ │ Thread 2 │ │ Thread 3 │ │ Thread N │ │
37//! │ │ Buffer │ │ Buffer │ │ Buffer │ │ Buffer │ │
38//! │ └────┬─────┘ └────┬─────┘ └────┬─────┘ └────┬─────┘ │
39//! │ │ │ │ │ │
40//! │ └─────────────┼─────────────┼─────────────┘ │
41//! │ │ │ │
42//! │ ▼ ▼ │
43//! │ ┌──────────────────────────────┐ │
44//! │ │ MPSC Channel │ │
45//! │ │ (crossbeam-channel) │ │
46//! │ └──────────────┬───────────────┘ │
47//! │ │ │
48//! │ ▼ │
49//! │ ┌──────────────────────────────┐ │
50//! │ │ Aggregator Thread │ │
51//! │ │ (drains every 10ms or │ │
52//! │ │ every 1000 entries) │ │
53//! │ └──────────────────────────────┘ │
54//! └────────────────────────────────────────────────────────────────┘
55//!
56//! Lock acquisitions: O(W) → O(W/B) where W=writes, B=batch size
57//! ```
58//!
59//! ## Performance
60//!
61//! - Thread-local buffer: Zero contention during writes
62//! - MPSC send: ~20ns (vs ~200ns for mutex under contention)
63//! - Batch flush: Amortized over B writes
64//!
65//! ## Usage
66//!
67//! ```ignore
//! let tracker = BatchedDirtyTracker::new();
//! tracker.start(); // Spawn the supervised aggregator
//!
//! // In transaction:
//! tracker.mark_dirty(txn_id, key_fingerprint); // Lock-free, buffers locally
//!
//! // At commit:
//! tracker.flush_thread_buffer(); // Sends the buffered batch to the aggregator
75//! ```
76
77use std::cell::RefCell;
78use std::collections::HashSet;
79use std::sync::Arc;
80use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
81use std::thread;
82
83use crossbeam_channel::{self, Receiver, Sender};
84use parking_lot::Mutex;
85
86use crate::supervisor::{SupervisedWorker, Supervisor, WorkerHealth, WorkerStep};
87use crate::txn_arena::KeyFingerprint;
88
/// Initial capacity reserved for a thread-local dirty-key buffer.
///
/// The buffer is re-created with this capacity after every flush.
const DEFAULT_BATCH_SIZE: usize = 64;
91
/// Longest time, in milliseconds, that one aggregator step blocks waiting for
/// an event before it returns control to its caller.
const MAX_FLUSH_INTERVAL_MS: u64 = 10;
94
95// ============================================================================
96// DirtyEvent - Batched Dirty Key Notification
97// ============================================================================
98
99/// Event sent through MPSC channel
100#[derive(Debug)]
101pub enum DirtyEvent {
102 /// Batch of dirty key fingerprints from a transaction
103 Batch {
104 txn_id: u64,
105 keys: Vec<KeyFingerprint>,
106 },
107 /// Epoch advance request
108 AdvanceEpoch,
109 /// Shutdown signal
110 Shutdown,
111}
112
113// ============================================================================
114// ThreadLocalBuffer - Per-Thread Dirty Key Accumulator
115// ============================================================================
116
117/// Thread-local buffer for accumulating dirty keys
118struct ThreadLocalBuffer {
119 /// Current transaction ID
120 txn_id: u64,
121 /// Accumulated dirty key fingerprints
122 keys: Vec<KeyFingerprint>,
123 /// Sender to aggregator
124 sender: Sender<DirtyEvent>,
125}
126
127impl ThreadLocalBuffer {
128 fn new(sender: Sender<DirtyEvent>) -> Self {
129 Self {
130 txn_id: 0,
131 keys: Vec::with_capacity(DEFAULT_BATCH_SIZE),
132 sender,
133 }
134 }
135
136 /// Mark a key as dirty (no lock, no send)
137 #[inline]
138 fn mark_dirty(&mut self, txn_id: u64, key_fingerprint: KeyFingerprint) {
139 if self.txn_id != txn_id {
140 // New transaction - flush old buffer if any
141 self.flush();
142 self.txn_id = txn_id;
143 }
144 self.keys.push(key_fingerprint);
145 }
146
147 /// Flush accumulated keys to aggregator
148 fn flush(&mut self) {
149 if !self.keys.is_empty() {
150 let keys = std::mem::take(&mut self.keys);
151 // Best-effort send - don't block if channel is full
152 let _ = self.sender.try_send(DirtyEvent::Batch {
153 txn_id: self.txn_id,
154 keys,
155 });
156 self.keys = Vec::with_capacity(DEFAULT_BATCH_SIZE);
157 }
158 }
159}
160
161// ============================================================================
162// BatchedDirtyTracker - Lock-Free Dirty Tracking
163// ============================================================================
164
165/// Batched dirty tracker with MPSC queue
166///
167/// Provides lock-free dirty tracking for multi-threaded writes.
168/// Each thread accumulates dirty keys locally, then sends them
169/// in batches through an MPSC channel to an aggregator.
170pub struct BatchedDirtyTracker {
171 /// MPSC sender (cloned for each thread)
172 sender: Sender<DirtyEvent>,
173 /// MPSC receiver (for aggregator thread)
174 receiver: Receiver<DirtyEvent>,
175 /// Supervised aggregator worker handle (panic-contained, auto-restarting)
176 aggregator_handle: Mutex<Option<SupervisedWorker>>,
177 /// Running flag (shared with the supervised worker for cooperative shutdown)
178 running: Arc<AtomicBool>,
179 /// Current epoch
180 current_epoch: AtomicU64,
181 /// Aggregated dirty keys per epoch
182 epochs: [Mutex<HashSet<KeyFingerprint>>; 4],
183 /// Statistics
184 stats: DirtyTrackingStats,
185}
186
/// Counters describing the work performed by a dirty tracker.
///
/// Every counter starts at zero (see the derived [`Default`]) and is updated
/// with atomic operations, so the struct can be read through a shared
/// reference while the tracker is running.
#[derive(Debug, Default)]
pub struct DirtyTrackingStats {
    /// Total events received
    pub events_received: AtomicU64,
    /// Total keys tracked
    pub keys_tracked: AtomicU64,
    /// Total batches received
    pub batches_received: AtomicU64,
    /// Current epoch
    pub current_epoch: AtomicU64,
}
209
/// Number of slots in the ring of per-epoch dirty-key sets.
///
/// An epoch is stored in slot `epoch % EPOCH_RING_SIZE`.
const EPOCH_RING_SIZE: usize = 4;
211
212impl BatchedDirtyTracker {
213 /// Create a new batched dirty tracker
214 pub fn new() -> Arc<Self> {
215 let (sender, receiver) = crossbeam_channel::bounded(1024);
216
217 let tracker = Arc::new(Self {
218 sender,
219 receiver,
220 aggregator_handle: Mutex::new(None),
221 running: Arc::new(AtomicBool::new(false)),
222 current_epoch: AtomicU64::new(0),
223 epochs: [
224 Mutex::new(HashSet::new()),
225 Mutex::new(HashSet::new()),
226 Mutex::new(HashSet::new()),
227 Mutex::new(HashSet::new()),
228 ],
229 stats: DirtyTrackingStats::default(),
230 });
231
232 tracker
233 }
234
235 /// Start the aggregator thread.
236 ///
237 /// The aggregator runs under a [`Supervisor`]: if processing a single event
238 /// panics (e.g. a poisoned lock), the panic is contained and counted and the
239 /// loop restarts rather than silently dying and stalling dirty tracking.
240 pub fn start(self: &Arc<Self>) {
241 if self.running.swap(true, Ordering::SeqCst) {
242 return; // Already running
243 }
244
245 let tracker = Arc::clone(self);
246 let running = Arc::clone(&self.running);
247 let worker =
248 Supervisor::new("dirty-aggregator").spawn(running, move || tracker.aggregator_step());
249
250 *self.aggregator_handle.lock() = Some(worker);
251 }
252
253 /// Health/liveness counters for the supervised aggregator, if running.
254 pub fn aggregator_health(&self) -> Option<WorkerHealth> {
255 self.aggregator_handle.lock().as_ref().map(|w| w.health())
256 }
257
258 /// Stop the aggregator thread
259 pub fn stop(&self) {
260 if !self.running.swap(false, Ordering::SeqCst) {
261 return; // Already stopped
262 }
263
264 // Send shutdown signal to unblock a parked recv immediately.
265 let _ = self.sender.send(DirtyEvent::Shutdown);
266
267 // Wait for the supervised aggregator to finish.
268 if let Some(worker) = self.aggregator_handle.lock().take() {
269 worker.join();
270 }
271
272 // Drain any remaining events so observers see a consistent final state.
273 while let Ok(event) = self.receiver.try_recv() {
274 if matches!(event, DirtyEvent::Shutdown) {
275 break;
276 }
277 self.process_event(event);
278 }
279 }
280
281 /// Get a sender for a thread to use
282 pub fn get_sender(&self) -> Sender<DirtyEvent> {
283 self.sender.clone()
284 }
285
286 /// Mark a key as dirty using a thread-local buffer
287 ///
288 /// This is the zero-contention hot path used by writers.
289 #[inline]
290 pub fn mark_dirty(&self, txn_id: u64, key_fingerprint: KeyFingerprint) {
291 thread_local! {
292 static BUFFER: RefCell<Option<ThreadLocalBuffer>> = const { RefCell::new(None) };
293 }
294
295 BUFFER.with(|cell| {
296 let mut buffer = cell.borrow_mut();
297 if buffer.is_none() {
298 *buffer = Some(ThreadLocalBuffer::new(self.sender.clone()));
299 }
300 buffer.as_mut().unwrap().mark_dirty(txn_id, key_fingerprint);
301 });
302 }
303
304 /// Flush the current thread's buffer
305 pub fn flush_thread_buffer(&self) {
306 thread_local! {
307 static BUFFER: RefCell<Option<ThreadLocalBuffer>> = const { RefCell::new(None) };
308 }
309
310 BUFFER.with(|cell| {
311 if let Some(buffer) = cell.borrow_mut().as_mut() {
312 buffer.flush();
313 }
314 });
315 }
316
317 /// Send a batch of dirty keys directly (for transaction commit)
318 #[inline]
319 pub fn send_batch(&self, txn_id: u64, keys: Vec<KeyFingerprint>) {
320 if keys.is_empty() {
321 return;
322 }
323 let _ = self.sender.try_send(DirtyEvent::Batch { txn_id, keys });
324 }
325
326 /// Advance to next epoch, returning the old epoch's dirty keys
327 pub fn advance_epoch(&self) -> (u64, Vec<KeyFingerprint>) {
328 // Send epoch advance event to ensure all pending events are processed
329 let _ = self.sender.try_send(DirtyEvent::AdvanceEpoch);
330
331 let old_epoch = self.current_epoch.fetch_add(1, Ordering::SeqCst);
332 let old_idx = (old_epoch as usize) % EPOCH_RING_SIZE;
333
334 // Drain the old epoch
335 let mut guard = self.epochs[old_idx].lock();
336 let keys: Vec<_> = guard.drain().collect();
337
338 self.stats
339 .current_epoch
340 .store(old_epoch + 1, Ordering::Relaxed);
341
342 (old_epoch, keys)
343 }
344
345 /// Get current epoch
346 pub fn current_epoch(&self) -> u64 {
347 self.current_epoch.load(Ordering::Relaxed)
348 }
349
350 /// Get statistics
351 pub fn stats(&self) -> &DirtyTrackingStats {
352 &self.stats
353 }
354
355 /// One supervised iteration of the aggregator loop.
356 ///
357 /// Blocks up to `MAX_FLUSH_INTERVAL_MS` for the next event, processes it, and
358 /// reports whether the worker should continue. A panic inside
359 /// [`process_event`](Self::process_event) is contained by the [`Supervisor`].
360 fn aggregator_step(&self) -> WorkerStep {
361 use crossbeam_channel::RecvTimeoutError;
362
363 let timeout = std::time::Duration::from_millis(MAX_FLUSH_INTERVAL_MS);
364 match self.receiver.recv_timeout(timeout) {
365 Ok(DirtyEvent::Shutdown) => WorkerStep::Stop,
366 Ok(event) => {
367 self.process_event(event);
368 WorkerStep::Continue
369 }
370 Err(RecvTimeoutError::Timeout) => WorkerStep::Continue,
371 Err(RecvTimeoutError::Disconnected) => WorkerStep::Stop,
372 }
373 }
374
375 /// Process a single event
376 fn process_event(&self, event: DirtyEvent) {
377 match event {
378 DirtyEvent::Batch { txn_id: _, keys } => {
379 let epoch = self.current_epoch.load(Ordering::Relaxed);
380 let idx = (epoch as usize) % EPOCH_RING_SIZE;
381
382 let mut guard = self.epochs[idx].lock();
383 let key_count = keys.len();
384 guard.extend(keys);
385
386 self.stats.events_received.fetch_add(1, Ordering::Relaxed);
387 self.stats
388 .keys_tracked
389 .fetch_add(key_count as u64, Ordering::Relaxed);
390 self.stats.batches_received.fetch_add(1, Ordering::Relaxed);
391 }
392 DirtyEvent::AdvanceEpoch => {
393 // Epoch advance is handled by the caller
394 }
395 DirtyEvent::Shutdown => {
396 // Will exit the loop
397 }
398 }
399 }
400}
401
402impl Default for BatchedDirtyTracker {
403 fn default() -> Self {
404 let (sender, receiver) = crossbeam_channel::bounded(1024);
405 Self {
406 sender,
407 receiver,
408 aggregator_handle: Mutex::new(None),
409 running: Arc::new(AtomicBool::new(false)),
410 current_epoch: AtomicU64::new(0),
411 epochs: [
412 Mutex::new(HashSet::new()),
413 Mutex::new(HashSet::new()),
414 Mutex::new(HashSet::new()),
415 Mutex::new(HashSet::new()),
416 ],
417 stats: DirtyTrackingStats::default(),
418 }
419 }
420}
421
422impl Drop for BatchedDirtyTracker {
423 fn drop(&mut self) {
424 self.stop();
425 }
426}
427
428// ============================================================================
429// TxnDirtyBuffer - Transaction-Local Dirty Key Buffer
430// ============================================================================
431
432/// Transaction-local dirty key buffer for commit-time batching
433///
434/// Instead of tracking dirty keys globally during the transaction,
435/// this buffer accumulates them locally and flushes once at commit.
436/// This is simpler than the thread-local approach and works well
437/// for single-threaded transaction execution.
438pub struct TxnDirtyBuffer {
439 /// Transaction ID
440 txn_id: u64,
441 /// Accumulated dirty key fingerprints
442 keys: Vec<KeyFingerprint>,
443}
444
445impl TxnDirtyBuffer {
446 /// Create a new transaction dirty buffer
447 #[inline]
448 pub fn new(txn_id: u64) -> Self {
449 Self {
450 txn_id,
451 keys: Vec::with_capacity(64),
452 }
453 }
454
455 /// Create with expected capacity
456 #[inline]
457 pub fn with_capacity(txn_id: u64, capacity: usize) -> Self {
458 Self {
459 txn_id,
460 keys: Vec::with_capacity(capacity),
461 }
462 }
463
464 /// Record a dirty key (no lock, just local append)
465 #[inline]
466 pub fn record(&mut self, key_fingerprint: KeyFingerprint) {
467 self.keys.push(key_fingerprint);
468 }
469
470 /// Record multiple dirty keys
471 #[inline]
472 pub fn record_many(&mut self, key_fingerprints: impl IntoIterator<Item = KeyFingerprint>) {
473 self.keys.extend(key_fingerprints);
474 }
475
476 /// Get the transaction ID
477 #[inline]
478 pub fn txn_id(&self) -> u64 {
479 self.txn_id
480 }
481
482 /// Get the number of dirty keys
483 #[inline]
484 pub fn len(&self) -> usize {
485 self.keys.len()
486 }
487
488 /// Check if empty
489 #[inline]
490 pub fn is_empty(&self) -> bool {
491 self.keys.is_empty()
492 }
493
494 /// Drain the buffer and return all keys
495 #[inline]
496 pub fn drain(&mut self) -> Vec<KeyFingerprint> {
497 std::mem::take(&mut self.keys)
498 }
499
500 /// Flush to a BatchedDirtyTracker
501 #[inline]
502 pub fn flush_to(&mut self, tracker: &BatchedDirtyTracker) {
503 if !self.keys.is_empty() {
504 tracker.send_batch(self.txn_id, std::mem::take(&mut self.keys));
505 self.keys = Vec::with_capacity(64);
506 }
507 }
508
509 /// Clear the buffer without flushing
510 #[inline]
511 pub fn clear(&mut self) {
512 self.keys.clear();
513 }
514}
515
516// ============================================================================
517// Tests
518// ============================================================================
519
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    /// Upper bound on how long a test waits for the aggregator thread.
    const AGGREGATOR_WAIT: Duration = Duration::from_secs(2);

    /// Recording keys grows the buffer; draining returns them and empties it.
    #[test]
    fn test_txn_dirty_buffer() {
        let mut buffer = TxnDirtyBuffer::new(1);

        buffer.record(KeyFingerprint::from_bytes(b"key1"));
        buffer.record(KeyFingerprint::from_bytes(b"key2"));
        buffer.record(KeyFingerprint::from_bytes(b"key3"));

        assert_eq!(buffer.len(), 3);

        let keys = buffer.drain();
        assert_eq!(keys.len(), 3);
        assert!(buffer.is_empty());
    }

    /// A batch sent to a running tracker is aggregated into the current epoch.
    #[test]
    fn test_batched_tracker_basic() {
        let tracker = BatchedDirtyTracker::new();
        tracker.start();

        // Send some events directly
        tracker.send_batch(
            1,
            vec![
                KeyFingerprint::from_bytes(b"key1"),
                KeyFingerprint::from_bytes(b"key2"),
            ],
        );

        // Poll with a deadline instead of a fixed sleep: returns as soon as the
        // aggregator has handled the batch and tolerates a slow machine.
        let deadline = Instant::now() + AGGREGATOR_WAIT;
        while tracker.stats().batches_received.load(Ordering::Relaxed) == 0
            && Instant::now() < deadline
        {
            thread::sleep(Duration::from_millis(1));
        }

        assert!(
            tracker.stats().batches_received.load(Ordering::Relaxed) >= 1,
            "aggregator did not process the batch in time"
        );

        // Advance epoch to collect what was aggregated for epoch 0.
        let (epoch, keys) = tracker.advance_epoch();
        assert_eq!(epoch, 0);
        assert_eq!(keys.len(), 2);

        tracker.stop();
    }

    /// Advancing returns the old epoch's keys and leaves the next epoch empty.
    #[test]
    fn test_epoch_rotation() {
        let tracker = BatchedDirtyTracker::new();

        // Directly insert into epochs without starting the aggregator thread
        {
            let mut guard = tracker.epochs[0].lock();
            guard.insert(KeyFingerprint::from_bytes(b"key1"));
            guard.insert(KeyFingerprint::from_bytes(b"key2"));
        }

        let (epoch, keys) = tracker.advance_epoch();
        assert_eq!(epoch, 0);
        assert_eq!(keys.len(), 2);

        // New epoch should be empty
        let (epoch2, keys2) = tracker.advance_epoch();
        assert_eq!(epoch2, 1);
        assert!(keys2.is_empty());
    }
}