sochdb_storage/dirty_tracking.rs
1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Batched Dirty Tracking with MPSC Queue
19//!
20//! This module replaces per-write mutex acquisition for dirty tracking
21//! with a batched MPSC approach that dramatically reduces lock contention.
22//!
23//! ## Problem: Lock Convoying
24//!
25//! Per-write mutex acquisition is the canonical scalability killer:
26//! - Serializes otherwise-parallel writers
27//! - Causes lock convoying under contention
28//! - N writers → N lock acquisitions per batch
29//!
30//! ## Solution: Batched MPSC + Thread-Local Buffering
31//!
32//! ```text
33//! ┌────────────────────────────────────────────────────────────────┐
34//! │ Writer Threads │
35//! │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
36//! │ │ Thread 1 │ │ Thread 2 │ │ Thread 3 │ │ Thread N │ │
37//! │ │ Buffer │ │ Buffer │ │ Buffer │ │ Buffer │ │
38//! │ └────┬─────┘ └────┬─────┘ └────┬─────┘ └────┬─────┘ │
39//! │ │ │ │ │ │
40//! │ └─────────────┼─────────────┼─────────────┘ │
41//! │ │ │ │
42//! │ ▼ ▼ │
43//! │ ┌──────────────────────────────┐ │
44//! │ │ MPSC Channel │ │
45//! │ │ (crossbeam-channel) │ │
46//! │ └──────────────┬───────────────┘ │
47//! │ │ │
48//! │ ▼ │
49//! │ ┌──────────────────────────────┐ │
50//! │ │ Aggregator Thread │ │
51//! │ │ (drains every 10ms or │ │
52//! │ │ every 1000 entries) │ │
53//! │ └──────────────────────────────┘ │
54//! └────────────────────────────────────────────────────────────────┘
55//!
56//! Lock acquisitions: O(W) → O(W/B) where W=writes, B=batch size
57//! ```
58//!
59//! ## Performance
60//!
61//! - Thread-local buffer: Zero contention during writes
62//! - MPSC send: ~20ns (vs ~200ns for mutex under contention)
63//! - Batch flush: Amortized over B writes
64//!
65//! ## Usage
66//!
67//! ```ignore
68//! let tracker = BatchedDirtyTracker::new();
69//!
70//! // In transaction:
71//! tracker.mark_dirty(key); // Lock-free, buffers locally
72//!
73//! // At commit:
74//! tracker.flush_buffer(); // Sends batch to aggregator
75//! ```
76
77use std::cell::RefCell;
78use std::collections::HashSet;
79use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
80use std::sync::Arc;
81use std::thread::{self, JoinHandle};
82
83use crossbeam_channel::{self, Receiver, Sender};
84use parking_lot::Mutex;
85
86use crate::txn_arena::KeyFingerprint;
87
88/// Default batch size for thread-local buffer
89const DEFAULT_BATCH_SIZE: usize = 64;
90
91/// Maximum wait time before flushing (in milliseconds)
92const MAX_FLUSH_INTERVAL_MS: u64 = 10;
93
94// ============================================================================
95// DirtyEvent - Batched Dirty Key Notification
96// ============================================================================
97
98/// Event sent through MPSC channel
99#[derive(Debug)]
100pub enum DirtyEvent {
101 /// Batch of dirty key fingerprints from a transaction
102 Batch {
103 txn_id: u64,
104 keys: Vec<KeyFingerprint>,
105 },
106 /// Epoch advance request
107 AdvanceEpoch,
108 /// Shutdown signal
109 Shutdown,
110}
111
112// ============================================================================
113// ThreadLocalBuffer - Per-Thread Dirty Key Accumulator
114// ============================================================================
115
116/// Thread-local buffer for accumulating dirty keys
117struct ThreadLocalBuffer {
118 /// Current transaction ID
119 txn_id: u64,
120 /// Accumulated dirty key fingerprints
121 keys: Vec<KeyFingerprint>,
122 /// Sender to aggregator
123 sender: Sender<DirtyEvent>,
124}
125
126impl ThreadLocalBuffer {
127 fn new(sender: Sender<DirtyEvent>) -> Self {
128 Self {
129 txn_id: 0,
130 keys: Vec::with_capacity(DEFAULT_BATCH_SIZE),
131 sender,
132 }
133 }
134
135 /// Mark a key as dirty (no lock, no send)
136 #[inline]
137 fn mark_dirty(&mut self, txn_id: u64, key_fingerprint: KeyFingerprint) {
138 if self.txn_id != txn_id {
139 // New transaction - flush old buffer if any
140 self.flush();
141 self.txn_id = txn_id;
142 }
143 self.keys.push(key_fingerprint);
144 }
145
146 /// Flush accumulated keys to aggregator
147 fn flush(&mut self) {
148 if !self.keys.is_empty() {
149 let keys = std::mem::take(&mut self.keys);
150 // Best-effort send - don't block if channel is full
151 let _ = self.sender.try_send(DirtyEvent::Batch {
152 txn_id: self.txn_id,
153 keys,
154 });
155 self.keys = Vec::with_capacity(DEFAULT_BATCH_SIZE);
156 }
157 }
158}
159
160// ============================================================================
161// BatchedDirtyTracker - Lock-Free Dirty Tracking
162// ============================================================================
163
164/// Batched dirty tracker with MPSC queue
165///
166/// Provides lock-free dirty tracking for multi-threaded writes.
167/// Each thread accumulates dirty keys locally, then sends them
168/// in batches through an MPSC channel to an aggregator.
169pub struct BatchedDirtyTracker {
170 /// MPSC sender (cloned for each thread)
171 sender: Sender<DirtyEvent>,
172 /// MPSC receiver (for aggregator thread)
173 receiver: Receiver<DirtyEvent>,
174 /// Aggregator thread handle
175 aggregator_handle: Mutex<Option<JoinHandle<()>>>,
176 /// Running flag
177 running: AtomicBool,
178 /// Current epoch
179 current_epoch: AtomicU64,
180 /// Aggregated dirty keys per epoch
181 epochs: [Mutex<HashSet<KeyFingerprint>>; 4],
182 /// Statistics
183 stats: DirtyTrackingStats,
184}
185
186/// Dirty tracking statistics
187pub struct DirtyTrackingStats {
188 /// Total events received
189 pub events_received: AtomicU64,
190 /// Total keys tracked
191 pub keys_tracked: AtomicU64,
192 /// Total batches received
193 pub batches_received: AtomicU64,
194 /// Current epoch
195 pub current_epoch: AtomicU64,
196}
197
198impl Default for DirtyTrackingStats {
199 fn default() -> Self {
200 Self {
201 events_received: AtomicU64::new(0),
202 keys_tracked: AtomicU64::new(0),
203 batches_received: AtomicU64::new(0),
204 current_epoch: AtomicU64::new(0),
205 }
206 }
207}
208
209const EPOCH_RING_SIZE: usize = 4;
210
211impl BatchedDirtyTracker {
212 /// Create a new batched dirty tracker
213 pub fn new() -> Arc<Self> {
214 let (sender, receiver) = crossbeam_channel::bounded(1024);
215
216 let tracker = Arc::new(Self {
217 sender,
218 receiver,
219 aggregator_handle: Mutex::new(None),
220 running: AtomicBool::new(false),
221 current_epoch: AtomicU64::new(0),
222 epochs: [
223 Mutex::new(HashSet::new()),
224 Mutex::new(HashSet::new()),
225 Mutex::new(HashSet::new()),
226 Mutex::new(HashSet::new()),
227 ],
228 stats: DirtyTrackingStats::default(),
229 });
230
231 tracker
232 }
233
234 /// Start the aggregator thread
235 pub fn start(self: &Arc<Self>) {
236 if self.running.swap(true, Ordering::SeqCst) {
237 return; // Already running
238 }
239
240 let tracker = Arc::clone(self);
241 let handle = thread::spawn(move || {
242 tracker.aggregator_loop();
243 });
244
245 *self.aggregator_handle.lock() = Some(handle);
246 }
247
248 /// Stop the aggregator thread
249 pub fn stop(&self) {
250 if !self.running.swap(false, Ordering::SeqCst) {
251 return; // Already stopped
252 }
253
254 // Send shutdown signal
255 let _ = self.sender.send(DirtyEvent::Shutdown);
256
257 // Wait for aggregator to finish
258 if let Some(handle) = self.aggregator_handle.lock().take() {
259 let _ = handle.join();
260 }
261 }
262
263 /// Get a sender for a thread to use
264 pub fn get_sender(&self) -> Sender<DirtyEvent> {
265 self.sender.clone()
266 }
267
268 /// Mark a key as dirty using a thread-local buffer
269 ///
270 /// This is the zero-contention hot path used by writers.
271 #[inline]
272 pub fn mark_dirty(&self, txn_id: u64, key_fingerprint: KeyFingerprint) {
273 thread_local! {
274 static BUFFER: RefCell<Option<ThreadLocalBuffer>> = const { RefCell::new(None) };
275 }
276
277 BUFFER.with(|cell| {
278 let mut buffer = cell.borrow_mut();
279 if buffer.is_none() {
280 *buffer = Some(ThreadLocalBuffer::new(self.sender.clone()));
281 }
282 buffer.as_mut().unwrap().mark_dirty(txn_id, key_fingerprint);
283 });
284 }
285
286 /// Flush the current thread's buffer
287 pub fn flush_thread_buffer(&self) {
288 thread_local! {
289 static BUFFER: RefCell<Option<ThreadLocalBuffer>> = const { RefCell::new(None) };
290 }
291
292 BUFFER.with(|cell| {
293 if let Some(buffer) = cell.borrow_mut().as_mut() {
294 buffer.flush();
295 }
296 });
297 }
298
299 /// Send a batch of dirty keys directly (for transaction commit)
300 #[inline]
301 pub fn send_batch(&self, txn_id: u64, keys: Vec<KeyFingerprint>) {
302 if keys.is_empty() {
303 return;
304 }
305 let _ = self.sender.try_send(DirtyEvent::Batch { txn_id, keys });
306 }
307
308 /// Advance to next epoch, returning the old epoch's dirty keys
309 pub fn advance_epoch(&self) -> (u64, Vec<KeyFingerprint>) {
310 // Send epoch advance event to ensure all pending events are processed
311 let _ = self.sender.try_send(DirtyEvent::AdvanceEpoch);
312
313 let old_epoch = self.current_epoch.fetch_add(1, Ordering::SeqCst);
314 let old_idx = (old_epoch as usize) % EPOCH_RING_SIZE;
315
316 // Drain the old epoch
317 let mut guard = self.epochs[old_idx].lock();
318 let keys: Vec<_> = guard.drain().collect();
319
320 self.stats.current_epoch.store(old_epoch + 1, Ordering::Relaxed);
321
322 (old_epoch, keys)
323 }
324
325 /// Get current epoch
326 pub fn current_epoch(&self) -> u64 {
327 self.current_epoch.load(Ordering::Relaxed)
328 }
329
330 /// Get statistics
331 pub fn stats(&self) -> &DirtyTrackingStats {
332 &self.stats
333 }
334
335 /// Aggregator loop - runs in background thread
336 fn aggregator_loop(&self) {
337 use crossbeam_channel::RecvTimeoutError;
338
339 let timeout = std::time::Duration::from_millis(MAX_FLUSH_INTERVAL_MS);
340
341 while self.running.load(Ordering::Relaxed) {
342 match self.receiver.recv_timeout(timeout) {
343 Ok(event) => {
344 self.process_event(event);
345 }
346 Err(RecvTimeoutError::Timeout) => {
347 // No events for a while - that's fine
348 }
349 Err(RecvTimeoutError::Disconnected) => {
350 break;
351 }
352 }
353 }
354
355 // Drain remaining events on shutdown
356 while let Ok(event) = self.receiver.try_recv() {
357 if matches!(event, DirtyEvent::Shutdown) {
358 break;
359 }
360 self.process_event(event);
361 }
362 }
363
364 /// Process a single event
365 fn process_event(&self, event: DirtyEvent) {
366 match event {
367 DirtyEvent::Batch { txn_id: _, keys } => {
368 let epoch = self.current_epoch.load(Ordering::Relaxed);
369 let idx = (epoch as usize) % EPOCH_RING_SIZE;
370
371 let mut guard = self.epochs[idx].lock();
372 let key_count = keys.len();
373 guard.extend(keys);
374
375 self.stats.events_received.fetch_add(1, Ordering::Relaxed);
376 self.stats.keys_tracked.fetch_add(key_count as u64, Ordering::Relaxed);
377 self.stats.batches_received.fetch_add(1, Ordering::Relaxed);
378 }
379 DirtyEvent::AdvanceEpoch => {
380 // Epoch advance is handled by the caller
381 }
382 DirtyEvent::Shutdown => {
383 // Will exit the loop
384 }
385 }
386 }
387}
388
389impl Default for BatchedDirtyTracker {
390 fn default() -> Self {
391 let (sender, receiver) = crossbeam_channel::bounded(1024);
392 Self {
393 sender,
394 receiver,
395 aggregator_handle: Mutex::new(None),
396 running: AtomicBool::new(false),
397 current_epoch: AtomicU64::new(0),
398 epochs: [
399 Mutex::new(HashSet::new()),
400 Mutex::new(HashSet::new()),
401 Mutex::new(HashSet::new()),
402 Mutex::new(HashSet::new()),
403 ],
404 stats: DirtyTrackingStats::default(),
405 }
406 }
407}
408
409impl Drop for BatchedDirtyTracker {
410 fn drop(&mut self) {
411 self.stop();
412 }
413}
414
415// ============================================================================
416// TxnDirtyBuffer - Transaction-Local Dirty Key Buffer
417// ============================================================================
418
419/// Transaction-local dirty key buffer for commit-time batching
420///
421/// Instead of tracking dirty keys globally during the transaction,
422/// this buffer accumulates them locally and flushes once at commit.
423/// This is simpler than the thread-local approach and works well
424/// for single-threaded transaction execution.
425pub struct TxnDirtyBuffer {
426 /// Transaction ID
427 txn_id: u64,
428 /// Accumulated dirty key fingerprints
429 keys: Vec<KeyFingerprint>,
430}
431
432impl TxnDirtyBuffer {
433 /// Create a new transaction dirty buffer
434 #[inline]
435 pub fn new(txn_id: u64) -> Self {
436 Self {
437 txn_id,
438 keys: Vec::with_capacity(64),
439 }
440 }
441
442 /// Create with expected capacity
443 #[inline]
444 pub fn with_capacity(txn_id: u64, capacity: usize) -> Self {
445 Self {
446 txn_id,
447 keys: Vec::with_capacity(capacity),
448 }
449 }
450
451 /// Record a dirty key (no lock, just local append)
452 #[inline]
453 pub fn record(&mut self, key_fingerprint: KeyFingerprint) {
454 self.keys.push(key_fingerprint);
455 }
456
457 /// Record multiple dirty keys
458 #[inline]
459 pub fn record_many(&mut self, key_fingerprints: impl IntoIterator<Item = KeyFingerprint>) {
460 self.keys.extend(key_fingerprints);
461 }
462
463 /// Get the transaction ID
464 #[inline]
465 pub fn txn_id(&self) -> u64 {
466 self.txn_id
467 }
468
469 /// Get the number of dirty keys
470 #[inline]
471 pub fn len(&self) -> usize {
472 self.keys.len()
473 }
474
475 /// Check if empty
476 #[inline]
477 pub fn is_empty(&self) -> bool {
478 self.keys.is_empty()
479 }
480
481 /// Drain the buffer and return all keys
482 #[inline]
483 pub fn drain(&mut self) -> Vec<KeyFingerprint> {
484 std::mem::take(&mut self.keys)
485 }
486
487 /// Flush to a BatchedDirtyTracker
488 #[inline]
489 pub fn flush_to(&mut self, tracker: &BatchedDirtyTracker) {
490 if !self.keys.is_empty() {
491 tracker.send_batch(self.txn_id, std::mem::take(&mut self.keys));
492 self.keys = Vec::with_capacity(64);
493 }
494 }
495
496 /// Clear the buffer without flushing
497 #[inline]
498 pub fn clear(&mut self) {
499 self.keys.clear();
500 }
501}
502
503// ============================================================================
504// Tests
505// ============================================================================
506
507#[cfg(test)]
508mod tests {
509 use super::*;
510 use std::time::Duration;
511
512 #[test]
513 fn test_txn_dirty_buffer() {
514 let mut buffer = TxnDirtyBuffer::new(1);
515
516 buffer.record(KeyFingerprint::from_bytes(b"key1"));
517 buffer.record(KeyFingerprint::from_bytes(b"key2"));
518 buffer.record(KeyFingerprint::from_bytes(b"key3"));
519
520 assert_eq!(buffer.len(), 3);
521
522 let keys = buffer.drain();
523 assert_eq!(keys.len(), 3);
524 assert!(buffer.is_empty());
525 }
526
527 #[test]
528 fn test_batched_tracker_basic() {
529 let tracker = BatchedDirtyTracker::new();
530 tracker.start();
531
532 // Send some events directly
533 tracker.send_batch(1, vec![
534 KeyFingerprint::from_bytes(b"key1"),
535 KeyFingerprint::from_bytes(b"key2"),
536 ]);
537
538 // Give aggregator time to process
539 thread::sleep(Duration::from_millis(50));
540
541 // Advance epoch to collect
542 let (_epoch, keys) = tracker.advance_epoch();
543
544 // Keys should have been processed
545 assert!(tracker.stats().batches_received.load(Ordering::Relaxed) >= 1);
546
547 tracker.stop();
548 }
549
550 #[test]
551 fn test_epoch_rotation() {
552 let tracker = BatchedDirtyTracker::new();
553
554 // Directly insert into epochs without starting the aggregator thread
555 {
556 let mut guard = tracker.epochs[0].lock();
557 guard.insert(KeyFingerprint::from_bytes(b"key1"));
558 guard.insert(KeyFingerprint::from_bytes(b"key2"));
559 }
560
561 let (epoch, keys) = tracker.advance_epoch();
562 assert_eq!(epoch, 0);
563 assert_eq!(keys.len(), 2);
564
565 // New epoch should be empty
566 let (epoch2, keys2) = tracker.advance_epoch();
567 assert_eq!(epoch2, 1);
568 assert!(keys2.is_empty());
569 }
570}