Skip to main content

sochdb_vector/simd/
visibility.rs

1//! MVCC Visibility Check Kernel
2//!
3//! This module provides SIMD-accelerated visibility checking for MVCC
4//! (Multi-Version Concurrency Control) operations.
5//!
6//! # Algorithm
7//!
8//! A row is visible if:
9//! ```text
10//! visible[i] = (commit_ts[i] != 0) && (commit_ts[i] < snapshot_ts)
11//! ```
12//!
13//! With transaction ID awareness:
14//! ```text
15//! visible[i] = ((commit_ts[i] != 0) && (commit_ts[i] < snapshot_ts)) || (txn_id[i] == current_txn)
16//! ```
17//!
18//! # Boolean Logic
19//!
20//! ```text
21//! visible = (commit ≠ 0) ∧ (commit < snapshot)
22//!         = ¬(commit = 0) ∧ (commit < snapshot)
23//! ```
24//!
25//! # SIMD Strategy
26//!
27//! - **AVX2**: Process 4 u64 timestamps per 256-bit register
28//! - **NEON**: Process 2 u64 timestamps per 128-bit register
29
30use super::dispatch::cpu_features;
31
32/// Check visibility for a batch of rows based on commit timestamps.
33///
34/// # Arguments
35/// * `commit_timestamps` - Array of commit timestamps (0 = uncommitted)
36/// * `snapshot_ts` - The snapshot timestamp for visibility check
37/// * `visible_mask` - Output: 1 if visible, 0 if not visible
38///
39/// # Panics
40/// Panics if `visible_mask.len() < commit_timestamps.len()`
41#[inline]
42pub fn visibility_check(commit_timestamps: &[u64], snapshot_ts: u64, visible_mask: &mut [u8]) {
43    let n_rows = commit_timestamps.len();
44    assert!(
45        visible_mask.len() >= n_rows,
46        "visible_mask buffer too small"
47    );
48
49    if n_rows == 0 {
50        return;
51    }
52
53    let features = cpu_features();
54
55    #[cfg(target_arch = "x86_64")]
56    {
57        if features.has_avx2 {
58            unsafe { visibility_check_avx2(commit_timestamps, snapshot_ts, visible_mask) };
59            return;
60        }
61    }
62
63    #[cfg(target_arch = "aarch64")]
64    {
65        if features.has_neon {
66            unsafe { visibility_check_neon(commit_timestamps, snapshot_ts, visible_mask) };
67            return;
68        }
69    }
70
71    visibility_check_scalar(commit_timestamps, snapshot_ts, visible_mask);
72}
73
74/// Check visibility with transaction ID awareness (for self-visibility).
75///
76/// A row is visible if:
77/// - `(commit_ts != 0 && commit_ts < snapshot_ts)`, OR
78/// - `txn_id == current_txn_id` (self-visibility)
79///
80/// # Arguments
81/// * `commit_timestamps` - Array of commit timestamps (0 = uncommitted)
82/// * `txn_ids` - Array of transaction IDs that wrote each row
83/// * `snapshot_ts` - The snapshot timestamp for visibility check
84/// * `current_txn_id` - The current transaction's ID
85/// * `visible_mask` - Output: 1 if visible, 0 if not visible
86#[inline]
87pub fn visibility_check_with_txn(
88    commit_timestamps: &[u64],
89    txn_ids: &[u64],
90    snapshot_ts: u64,
91    current_txn_id: u64,
92    visible_mask: &mut [u8],
93) {
94    let n_rows = commit_timestamps.len();
95    assert_eq!(txn_ids.len(), n_rows, "txn_ids length mismatch");
96    assert!(
97        visible_mask.len() >= n_rows,
98        "visible_mask buffer too small"
99    );
100
101    if n_rows == 0 {
102        return;
103    }
104
105    let features = cpu_features();
106
107    #[cfg(target_arch = "x86_64")]
108    {
109        if features.has_avx2 {
110            unsafe {
111                visibility_check_with_txn_avx2(
112                    commit_timestamps,
113                    txn_ids,
114                    snapshot_ts,
115                    current_txn_id,
116                    visible_mask,
117                )
118            };
119            return;
120        }
121    }
122
123    #[cfg(target_arch = "aarch64")]
124    {
125        if features.has_neon {
126            unsafe {
127                visibility_check_with_txn_neon(
128                    commit_timestamps,
129                    txn_ids,
130                    snapshot_ts,
131                    current_txn_id,
132                    visible_mask,
133                )
134            };
135            return;
136        }
137    }
138
139    visibility_check_with_txn_scalar(
140        commit_timestamps,
141        txn_ids,
142        snapshot_ts,
143        current_txn_id,
144        visible_mask,
145    );
146}
147
148// ============================================================================
149// x86_64 AVX2 Implementation
150// ============================================================================
151
/// AVX2 kernel: classifies four `u64` commit timestamps per iteration.
///
/// Writes `visible_mask[i] = 1` when `commit_timestamps[i] != 0` and
/// `commit_timestamps[i] < snapshot_ts` (unsigned comparison), else `0`.
/// Rows that do not fill a whole 4-lane vector are handled by a scalar tail.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2.
///
/// # Panics
/// Panics if `visible_mask` is shorter than `commit_timestamps`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn visibility_check_avx2(
    commit_timestamps: &[u64],
    snapshot_ts: u64,
    visible_mask: &mut [u8],
) {
    use std::arch::x86_64::*;

    let n_rows = commit_timestamps.len();
    let mut i = 0;

    // SAFETY: AVX2 availability is the caller's contract. Every unaligned
    // 32-byte load starts at element `i` with `i + 4 <= n_rows`, so it stays
    // inside `commit_timestamps`.
    unsafe {
        // `_mm256_cmpgt_epi64` compares lanes as *signed* integers. Flipping
        // bit 63 of both operands maps unsigned order onto signed order, so
        // timestamps >= 2^63 (and snapshots such as `u64::MAX`) compare
        // correctly.
        let sign_bit = _mm256_set1_epi64x(i64::MIN);
        // `as i64` only reinterprets the bit pattern; no value is lost.
        let snapshot_biased = _mm256_xor_si256(_mm256_set1_epi64x(snapshot_ts as i64), sign_bit);
        let zero_vec = _mm256_setzero_si256();

        // Process 4 rows at a time (256 bits / 64 bits = 4)
        while i + 4 <= n_rows {
            let commits = _mm256_loadu_si256(commit_timestamps.as_ptr().add(i) as *const __m256i);

            // Lanes that are all-ones where commit_ts == 0 (uncommitted).
            let eq_zero = _mm256_cmpeq_epi64(commits, zero_vec);

            // Unsigned commit_ts < snapshot_ts, via the sign-bit bias.
            let commits_biased = _mm256_xor_si256(commits, sign_bit);
            let less_than = _mm256_cmpgt_epi64(snapshot_biased, commits_biased);

            // visible = (NOT eq_zero) AND less_than
            let visible = _mm256_andnot_si256(eq_zero, less_than);

            // Each lane is all-ones or all-zeros; bit 63 of lane k lands in
            // bit k of the movemask result.
            let mask_bits = _mm256_movemask_pd(_mm256_castsi256_pd(visible));
            for lane in 0..4usize {
                visible_mask[i + lane] = ((mask_bits >> lane) & 1) as u8;
            }

            i += 4;
        }
    }

    // Scalar tail for the remaining 0..=3 rows
    for row in i..n_rows {
        let commit = commit_timestamps[row];
        visible_mask[row] = u8::from(commit != 0 && commit < snapshot_ts);
    }
}
207
/// AVX2 kernel with self-visibility: classifies four rows per iteration.
///
/// Writes `visible_mask[i] = 1` when either
/// `commit_timestamps[i] != 0 && commit_timestamps[i] < snapshot_ts`
/// (unsigned comparison) or `txn_ids[i] == current_txn_id`, else `0`.
/// Rows that do not fill a whole 4-lane vector are handled by a scalar tail.
///
/// # Safety
/// The caller must ensure the running CPU supports AVX2 and that `txn_ids`
/// holds at least `commit_timestamps.len()` elements (the vector loads from
/// `txn_ids` are not bounds-checked).
///
/// # Panics
/// Panics if `visible_mask` is shorter than `commit_timestamps`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn visibility_check_with_txn_avx2(
    commit_timestamps: &[u64],
    txn_ids: &[u64],
    snapshot_ts: u64,
    current_txn_id: u64,
    visible_mask: &mut [u8],
) {
    use std::arch::x86_64::*;

    let n_rows = commit_timestamps.len();
    let mut i = 0;

    // SAFETY: AVX2 availability and `txn_ids.len() >= n_rows` are the
    // caller's contract. Every unaligned 32-byte load starts at element `i`
    // with `i + 4 <= n_rows`, so it stays inside both input slices.
    unsafe {
        // `_mm256_cmpgt_epi64` compares lanes as *signed* integers. Flipping
        // bit 63 of both operands maps unsigned order onto signed order, so
        // timestamps >= 2^63 (and snapshots such as `u64::MAX`) compare
        // correctly.
        let sign_bit = _mm256_set1_epi64x(i64::MIN);
        // `as i64` only reinterprets the bit pattern; no value is lost.
        let snapshot_biased = _mm256_xor_si256(_mm256_set1_epi64x(snapshot_ts as i64), sign_bit);
        let zero_vec = _mm256_setzero_si256();
        let current_txn_vec = _mm256_set1_epi64x(current_txn_id as i64);

        while i + 4 <= n_rows {
            // Load 4 commit timestamps and txn IDs
            let commits = _mm256_loadu_si256(commit_timestamps.as_ptr().add(i) as *const __m256i);
            let txns = _mm256_loadu_si256(txn_ids.as_ptr().add(i) as *const __m256i);

            // Own writes are always visible (equality needs no sign bias).
            let own_write = _mm256_cmpeq_epi64(txns, current_txn_vec);

            // Lanes that are all-ones where commit_ts == 0 (uncommitted).
            let eq_zero = _mm256_cmpeq_epi64(commits, zero_vec);

            // Unsigned commit_ts < snapshot_ts, via the sign-bit bias.
            let commits_biased = _mm256_xor_si256(commits, sign_bit);
            let less_than = _mm256_cmpgt_epi64(snapshot_biased, commits_biased);

            // visible = ((NOT eq_zero) AND less_than) OR own_write
            let committed_visible = _mm256_andnot_si256(eq_zero, less_than);
            let visible = _mm256_or_si256(committed_visible, own_write);

            // Each lane is all-ones or all-zeros; bit 63 of lane k lands in
            // bit k of the movemask result.
            let mask_bits = _mm256_movemask_pd(_mm256_castsi256_pd(visible));
            for lane in 0..4usize {
                visible_mask[i + lane] = ((mask_bits >> lane) & 1) as u8;
            }

            i += 4;
        }
    }

    // Scalar tail for the remaining 0..=3 rows
    for row in i..n_rows {
        let commit = commit_timestamps[row];
        let txn = txn_ids[row];
        let visible = (commit != 0 && commit < snapshot_ts) || txn == current_txn_id;
        visible_mask[row] = u8::from(visible);
    }
}
267
268// ============================================================================
269// aarch64 NEON Implementation
270// ============================================================================
271
/// NEON kernel: classifies two `u64` commit timestamps per iteration.
///
/// Writes `visible_mask[i] = 1` when `commit_timestamps[i] != 0` and
/// `commit_timestamps[i] < snapshot_ts` (unsigned comparison), else `0`.
/// An odd trailing row is handled by a scalar tail.
///
/// # Safety
/// The caller must ensure the running CPU supports NEON.
///
/// # Panics
/// Panics if `visible_mask` is shorter than `commit_timestamps`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn visibility_check_neon(
    commit_timestamps: &[u64],
    snapshot_ts: u64,
    visible_mask: &mut [u8],
) {
    use std::arch::aarch64::*;

    let n_rows = commit_timestamps.len();
    let mut i = 0;

    // SAFETY: NEON availability is the caller's contract. Every 16-byte load
    // starts at element `i` with `i + 2 <= n_rows`, so it stays inside
    // `commit_timestamps`.
    unsafe {
        let snapshot_vec = vdupq_n_u64(snapshot_ts);
        let zero_vec = vdupq_n_u64(0);

        // Process 2 rows at a time (128 bits / 64 bits = 2)
        while i + 2 <= n_rows {
            // Load 2 commit timestamps
            let commits = vld1q_u64(commit_timestamps.as_ptr().add(i));

            // Check: commit_ts != 0 (invert the all-ones "equals zero" lanes)
            let eq_zero = vceqq_u64(commits, zero_vec);
            let not_zero = vreinterpretq_u64_u8(vmvnq_u8(vreinterpretq_u8_u64(eq_zero)));

            // Check: commit_ts < snapshot_ts with a true unsigned compare.
            // A subtract-and-test-bit-63 trick is only valid when the two
            // operands differ by less than 2^63, so it misclassifies large
            // snapshots such as `u64::MAX`.
            let less_than = vcltq_u64(commits, snapshot_vec);

            // Combine: not_zero AND less_than (both are all-ones/all-zeros lanes)
            let visible = vandq_u64(not_zero, less_than);

            // Extract to mask bytes
            visible_mask[i] = u8::from(vgetq_lane_u64(visible, 0) != 0);
            visible_mask[i + 1] = u8::from(vgetq_lane_u64(visible, 1) != 0);

            i += 2;
        }
    }

    // Scalar tail for the remaining 0..=1 rows
    for row in i..n_rows {
        let commit = commit_timestamps[row];
        visible_mask[row] = u8::from(commit != 0 && commit < snapshot_ts);
    }
}
337
/// NEON kernel with self-visibility: classifies two rows per iteration.
///
/// Writes `visible_mask[i] = 1` when either
/// `commit_timestamps[i] != 0 && commit_timestamps[i] < snapshot_ts`
/// (unsigned comparison) or `txn_ids[i] == current_txn_id`, else `0`.
/// An odd trailing row is handled by a scalar tail.
///
/// # Safety
/// The caller must ensure the running CPU supports NEON and that `txn_ids`
/// holds at least `commit_timestamps.len()` elements (the vector loads from
/// `txn_ids` are not bounds-checked).
///
/// # Panics
/// Panics if `visible_mask` is shorter than `commit_timestamps`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn visibility_check_with_txn_neon(
    commit_timestamps: &[u64],
    txn_ids: &[u64],
    snapshot_ts: u64,
    current_txn_id: u64,
    visible_mask: &mut [u8],
) {
    use std::arch::aarch64::*;

    let n_rows = commit_timestamps.len();
    let mut i = 0;

    // SAFETY: NEON availability and `txn_ids.len() >= n_rows` are the
    // caller's contract. Every 16-byte load starts at element `i` with
    // `i + 2 <= n_rows`, so it stays inside both input slices.
    unsafe {
        let snapshot_vec = vdupq_n_u64(snapshot_ts);
        let zero_vec = vdupq_n_u64(0);
        let current_txn_vec = vdupq_n_u64(current_txn_id);

        while i + 2 <= n_rows {
            let commits = vld1q_u64(commit_timestamps.as_ptr().add(i));
            let txns = vld1q_u64(txn_ids.as_ptr().add(i));

            // Check: txn_id == current_txn_id (own writes always visible)
            let own_write = vceqq_u64(txns, current_txn_vec);

            // Check: commit_ts != 0 (invert the all-ones "equals zero" lanes)
            let eq_zero = vceqq_u64(commits, zero_vec);
            let not_zero = vreinterpretq_u64_u8(vmvnq_u8(vreinterpretq_u8_u64(eq_zero)));

            // Check: commit_ts < snapshot_ts with a true unsigned compare.
            // A subtract-and-test-bit-63 trick is only valid when the two
            // operands differ by less than 2^63, so it misclassifies large
            // snapshots such as `u64::MAX`.
            let less_than = vcltq_u64(commits, snapshot_vec);

            // Combine: (not_zero AND less_than) OR own_write
            let committed_visible = vandq_u64(not_zero, less_than);
            let visible = vorrq_u64(committed_visible, own_write);

            // Extract to mask bytes
            visible_mask[i] = u8::from(vgetq_lane_u64(visible, 0) != 0);
            visible_mask[i + 1] = u8::from(vgetq_lane_u64(visible, 1) != 0);

            i += 2;
        }
    }

    // Scalar tail for the remaining 0..=1 rows
    for row in i..n_rows {
        let commit = commit_timestamps[row];
        let txn = txn_ids[row];
        let visible = (commit != 0 && commit < snapshot_ts) || txn == current_txn_id;
        visible_mask[row] = u8::from(visible);
    }
}
403
404// ============================================================================
405// Scalar Fallback
406// ============================================================================
407
/// Portable one-row-at-a-time visibility check.
///
/// Sets `visible_mask[i]` to 1 when `commit_timestamps[i]` is non-zero and
/// strictly below `snapshot_ts`, and to 0 otherwise. Indexing panics if
/// `visible_mask` is shorter than `commit_timestamps`.
#[inline]
fn visibility_check_scalar(commit_timestamps: &[u64], snapshot_ts: u64, visible_mask: &mut [u8]) {
    commit_timestamps
        .iter()
        .enumerate()
        .for_each(|(row, &ts)| {
            let committed_before_snapshot = ts != 0 && ts < snapshot_ts;
            visible_mask[row] = u8::from(committed_before_snapshot);
        });
}
419
/// Portable one-row-at-a-time visibility check with self-visibility.
///
/// Sets `visible_mask[i]` to 1 when `commit_timestamps[i]` is non-zero and
/// strictly below `snapshot_ts`, or when `txn_ids[i]` equals
/// `current_txn_id`; otherwise sets it to 0. Indexing panics if `txn_ids` or
/// `visible_mask` is shorter than `commit_timestamps`.
#[inline]
fn visibility_check_with_txn_scalar(
    commit_timestamps: &[u64],
    txn_ids: &[u64],
    snapshot_ts: u64,
    current_txn_id: u64,
    visible_mask: &mut [u8],
) {
    commit_timestamps
        .iter()
        .enumerate()
        .for_each(|(row, &ts)| {
            let writer = txn_ids[row];
            let committed_before_snapshot = ts != 0 && ts < snapshot_ts;
            let is_own_write = writer == current_txn_id;
            visible_mask[row] = u8::from(committed_before_snapshot || is_own_write);
        });
}
436
#[cfg(test)]
mod tests {
    use super::*;

    /// One uncommitted row, two rows before the snapshot, two rows after it.
    #[test]
    fn test_visibility_basic() {
        let commit_ts: [u64; 5] = [0, 100, 200, 300, 400];
        let mut out = [0u8; 5];

        visibility_check(&commit_ts, 250, &mut out);

        // 0   -> uncommitted        -> hidden
        // 100 -> 100 < 250          -> visible
        // 200 -> 200 < 250          -> visible
        // 300 -> 300 >= 250         -> hidden
        // 400 -> 400 >= 250         -> hidden
        assert_eq!(out, [0, 1, 1, 0, 0]);
    }

    /// The last row is uncommitted but was written by the current transaction.
    #[test]
    fn test_visibility_with_txn() {
        let commit_ts: [u64; 5] = [0, 100, 200, 300, 0];
        let writers: [u64; 5] = [10, 20, 30, 40, 50];
        let mut out = [0u8; 5];

        visibility_check_with_txn(&commit_ts, &writers, 250, 50, &mut out);

        // row 0: uncommitted, txn 10 != 50      -> hidden
        // row 1: 100 < 250                      -> visible
        // row 2: 200 < 250                      -> visible
        // row 3: 300 >= 250, txn 40 != 50       -> hidden
        // row 4: uncommitted, txn 50 == 50      -> visible (self-visibility)
        assert_eq!(out, [0, 1, 1, 0, 1]);
    }

    /// Batch lengths around the SIMD widths must agree with the scalar loop.
    #[test]
    fn test_visibility_alignment() {
        for len in [1usize, 2, 3, 4, 5, 7, 9, 15, 17] {
            let commit_ts: Vec<u64> = (0..len).map(|row| (row * 100) as u64).collect();
            let mut dispatched = vec![0u8; len];
            let mut reference = vec![0u8; len];

            visibility_check(&commit_ts, 500, &mut dispatched);
            visibility_check_scalar(&commit_ts, 500, &mut reference);

            assert_eq!(dispatched, reference, "Mismatch for n_rows={}", len);
        }
    }

    /// Uniform batches: all uncommitted, all equal to the snapshot, all below it.
    #[test]
    fn test_visibility_edge_cases() {
        // (commit timestamp for every row, initial mask fill, expected mask value)
        let scenarios: [(u64, u8, u8); 3] = [
            (0, 1, 0),   // all zeros -> hidden
            (100, 1, 0), // all equal to snapshot -> hidden
            (99, 0, 1),  // all less than snapshot -> visible
        ];

        for (commit, prefill, want) in scenarios {
            let commit_ts = vec![commit; 10];
            let mut out = vec![prefill; 10];
            visibility_check(&commit_ts, 100, &mut out);
            assert!(out.iter().all(|&m| m == want));
        }
    }
}