zer_compute/batch_sizer.rs
1//! Auto-tunes the GPU batch size from available VRAM using the exact buffer layout.
2
3use crate::soa::STRING_STRIDE;
4
/// Default share of the available VRAM that a comparison batch may claim.
///
/// The remainder is deliberately left free as headroom for the DeBERTa judge
/// (Phase 7), whose footprint is roughly 1–2 GB depending on the model variant.
const DEFAULT_VRAM_UTILIZATION: f32 = 0.75;
8
/// Smallest pair count for which the GPU kernel launch overhead is worth paying.
///
/// For batches under this size `DeviceComparator` silently takes the CPU path instead.
pub const GPU_BATCH_MIN: usize = 1_000;
12
/// Sizes GPU comparison batches so that they fit safely within VRAM.
///
/// The calculation uses the exact per-pair device memory layout of the GPU
/// compare kernel. No field-length estimation is required, because every
/// string is padded to [`STRING_STRIDE`] bytes on the device regardless of its
/// actual content.
///
/// # Per-pair device memory layout
///
/// The byte counts below assume `STRING_STRIDE == 64`.
///
/// | Buffer | Bytes per pair |
/// |---|---|
/// | `d_data_a` + `d_data_b` (u8, `STRING_STRIDE` each) | `2 * n_fields * 64` |
/// | `d_lens_a` + `d_lens_b` (u16) | `2 * n_fields * 2` |
/// | `d_ids_a` + `d_ids_b` (u64) | `2 * 8` |
/// | `d_weights` + `d_probs` (f32) | `2 * 4` |
/// | `d_levels` (u32) | `n_fields * 4` |
///
/// Total: `n_fields * 136 + 24` bytes per pair (exact, no estimation).
///
/// # Example
///
/// ```
/// use zer_compute::batch_sizer::BatchSizer;
///
/// let sizer = BatchSizer::new();
/// // 3 GB available VRAM (e.g. after OS + model overhead), 10 fields.
/// // 10 * 136 + 24 = 1,384 bytes per pair, i.e. roughly 1.7 million pairs
/// // at the default 75 % utilization.
/// let available = 3u64 * 1024 * 1024 * 1024;
/// let max = sizer.max_batch_size(available, 10);
/// assert!(max > 1_000_000, "should fit well over a million pairs");
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BatchSizer {
    /// Fraction of available VRAM to commit to the comparison batch. Default: 0.75.
    ///
    /// Expected to lie in `(0, 1]`; [`BatchSizer::with_utilization`] enforces
    /// that range, direct assignment to this field does not.
    pub vram_utilization: f32,
}
47
48impl Default for BatchSizer {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl BatchSizer {
55 pub fn new() -> Self {
56 Self { vram_utilization: DEFAULT_VRAM_UTILIZATION }
57 }
58
59 /// Override the utilization fraction (0.0 < fraction ≤ 1.0).
60 pub fn with_utilization(mut self, fraction: f32) -> Self {
61 assert!(fraction > 0.0 && fraction <= 1.0, "utilization must be in (0, 1]");
62 self.vram_utilization = fraction;
63 self
64 }
65
66 /// Compute the maximum number of pairs that fit in `available_vram_bytes` VRAM
67 /// for a schema with `num_fields` fields.
68 ///
69 /// The formula matches the GPU compare kernel buffer layout exactly, no avg_field_len
70 /// estimate is needed because device buffers always use `STRING_STRIDE` bytes per string.
71 ///
72 /// Returns at least 1 so callers never divide by zero.
73 pub fn max_batch_size(
74 &self,
75 available_vram_bytes: u64,
76 num_fields: usize,
77 ) -> usize {
78 let bytes_per_pair: usize =
79 2 * num_fields * STRING_STRIDE // d_data_a + d_data_b (u8)
80 + 2 * num_fields * 2 // d_lens_a + d_lens_b (u16)
81 + 2 * 8 // d_ids_a + d_ids_b (u64)
82 + 2 * 4 // d_weights + d_probs (f32)
83 + num_fields * 4; // d_levels (u32)
84
85 let usable = (available_vram_bytes as f64 * self.vram_utilization as f64) as u64;
86 (usable / bytes_per_pair as u64).max(1) as usize
87 }
88
89 /// Minimum batch size to justify a GPU kernel launch. Batches smaller than
90 /// this are routed to the CPU path transparently.
91 pub const fn min_batch_for_gpu() -> usize {
92 GPU_BATCH_MIN
93 }
94}
95
96// ── Unit tests ────────────────────────────────────────────────────────────────
97
#[cfg(test)]
mod tests {
    use super::*;

    /// One gibibyte, in bytes.
    const GIB: u64 = 1024 * 1024 * 1024;

    #[test]
    fn max_batch_grows_with_vram() {
        let sizer = BatchSizer::new();
        let small = sizer.max_batch_size(GIB, 10);
        let large = sizer.max_batch_size(8 * GIB, 10);
        assert!(large > small);
    }

    #[test]
    fn max_batch_never_zero() {
        let sizer = BatchSizer::new();
        // Even with absurdly many fields it must return at least 1.
        assert_eq!(sizer.max_batch_size(1, 1000), 1);
        // The floor also applies when no VRAM is available at all.
        assert_eq!(sizer.max_batch_size(0, 10), 1);
    }

    #[test]
    fn three_gb_vram_fits_millions() {
        let sizer = BatchSizer::new();
        // 3 GB is a realistic headroom figure after OS + model overhead.
        // 10 fields * 136 + 24 = 1,384 bytes/pair -> ~1.7M pairs in 3 GB at 75%.
        let available = 3 * GIB;
        let max = sizer.max_batch_size(available, 10);
        assert!(max > 1_000_000, "expected >1M pairs, got {max}");
    }

    #[test]
    fn min_batch_constant_is_positive() {
        assert!(BatchSizer::min_batch_for_gpu() > 0);
    }

    #[test]
    fn utilization_scales_result() {
        let full = BatchSizer::new().with_utilization(1.0).max_batch_size(1_000_000, 5);
        let half = BatchSizer::new().with_utilization(0.5).max_batch_size(1_000_000, 5);
        assert!(full > half);
    }

    #[test]
    fn formula_matches_compare_pool_layout() {
        // Verify the formula matches the GPU compare kernel buffer layout.
        // For n=1 field, 1 pair the device allocates:
        //   d_data_a/b:      2 * 1 * 64 = 128 bytes
        //   d_lens_a/b:      2 * 1 * 2  =   4 bytes
        //   d_ids_a/b:       2 * 8      =  16 bytes
        //   d_weights/probs: 2 * 4      =   8 bytes
        //   d_levels:        1 * 4      =   4 bytes
        //   total                       = 160 bytes/pair
        let bytes_per_pair_1field = 2 * STRING_STRIDE + 2 * 2 + 16 + 8 + 4;
        assert_eq!(bytes_per_pair_1field, 160);

        let sizer = BatchSizer::new().with_utilization(1.0);
        let max = sizer.max_batch_size(160, 1);
        assert_eq!(max, 1, "exactly one pair should fit in 160 bytes for 1 field");

        // The 160-byte case alone cannot distinguish the formula from the
        // `max(1)` floor, so pin the boundaries around whole multiples too.
        let per_pair = bytes_per_pair_1field as u64;
        assert_eq!(sizer.max_batch_size(2 * per_pair - 1, 1), 1);
        assert_eq!(sizer.max_batch_size(2 * per_pair, 1), 2);
        assert_eq!(sizer.max_batch_size(3 * per_pair, 1), 3);
    }
}
157}