Skip to main content

zer_compute/
batch_sizer.rs

1//! Auto-tunes the GPU batch size from available VRAM using the exact buffer layout.
2
3use crate::soa::STRING_STRIDE;
4
/// Default share of VRAM that a comparison batch is allowed to target.
///
/// The remainder is kept free as headroom for the DeBERTa judge (Phase 7), which
/// can occupy ~1–2 GB depending on the model variant.
const DEFAULT_VRAM_UTILIZATION: f32 = 0.75;
8
/// Smallest number of pairs for which the GPU kernel launch overhead pays off.
///
/// Batches below this threshold are not sent to the GPU: `DeviceComparator`
/// silently falls back to the CPU path for them.
pub const GPU_BATCH_MIN: usize = 1_000;
12
/// Sizes GPU comparison batches so that they fit safely within VRAM.
///
/// The calculation relies on the exact per-pair device memory layout of the GPU
/// compare kernel. Every string is padded to [`STRING_STRIDE`] bytes on the device
/// regardless of its actual content, so no field-length estimation is required.
///
/// # Per-pair device memory layout
///
/// | Buffer | Bytes per pair |
/// |---|---|
/// | `d_data_a` + `d_data_b` (u8, STRING_STRIDE each) | `2 × n_fields × 64` |
/// | `d_lens_a` + `d_lens_b` (u16) | `2 × n_fields × 2` |
/// | `d_ids_a`  + `d_ids_b`  (u64) | `2 × 8` |
/// | `d_weights` + `d_probs` (f32) | `2 × 4` |
/// | `d_levels`              (u32) | `n_fields × 4` |
///
/// Total: `n_fields × 136 + 24` bytes per pair (exact, no estimation).
///
/// # Example
///
/// ```
/// use zer_compute::batch_sizer::BatchSizer;
///
/// let sizer = BatchSizer::new();
/// // 3 GB available VRAM (e.g. after OS + model overhead), 10 fields
/// let available = 3u64 * 1024 * 1024 * 1024;
/// let max = sizer.max_batch_size(available, 10);
/// assert!(max > 1_000_000, "should easily fit millions of pairs");
/// ```
#[derive(Clone, Debug)]
pub struct BatchSizer {
    /// Share of the available VRAM that the comparison batch may occupy.
    /// Defaults to 0.75.
    pub vram_utilization: f32,
}
47
48impl Default for BatchSizer {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl BatchSizer {
55    pub fn new() -> Self {
56        Self { vram_utilization: DEFAULT_VRAM_UTILIZATION }
57    }
58
59    /// Override the utilization fraction (0.0 < fraction ≤ 1.0).
60    pub fn with_utilization(mut self, fraction: f32) -> Self {
61        assert!(fraction > 0.0 && fraction <= 1.0, "utilization must be in (0, 1]");
62        self.vram_utilization = fraction;
63        self
64    }
65
66    /// Compute the maximum number of pairs that fit in `available_vram_bytes` VRAM
67    /// for a schema with `num_fields` fields.
68    ///
69    /// The formula matches the GPU compare kernel buffer layout exactly, no avg_field_len
70    /// estimate is needed because device buffers always use `STRING_STRIDE` bytes per string.
71    ///
72    /// Returns at least 1 so callers never divide by zero.
73    pub fn max_batch_size(
74        &self,
75        available_vram_bytes: u64,
76        num_fields: usize,
77    ) -> usize {
78        let bytes_per_pair: usize =
79              2 * num_fields * STRING_STRIDE   // d_data_a + d_data_b (u8)
80            + 2 * num_fields * 2              // d_lens_a + d_lens_b (u16)
81            + 2 * 8                           // d_ids_a  + d_ids_b  (u64)
82            + 2 * 4                           // d_weights + d_probs (f32)
83            + num_fields * 4;                 // d_levels            (u32)
84
85        let usable = (available_vram_bytes as f64 * self.vram_utilization as f64) as u64;
86        (usable / bytes_per_pair as u64).max(1) as usize
87    }
88
89    /// Minimum batch size to justify a GPU kernel launch. Batches smaller than
90    /// this are routed to the CPU path transparently.
91    pub const fn min_batch_for_gpu() -> usize {
92        GPU_BATCH_MIN
93    }
94}
95
96// ── Unit tests ────────────────────────────────────────────────────────────────
97
#[cfg(test)]
mod tests {
    use super::*;

    /// One gibibyte, in bytes.
    const GIB: u64 = 1024 * 1024 * 1024;

    #[test]
    fn max_batch_grows_with_vram() {
        let sizer = BatchSizer::new();
        let small = sizer.max_batch_size(GIB, 10);
        let large = sizer.max_batch_size(8 * GIB, 10);
        assert!(large > small);
    }

    #[test]
    fn max_batch_never_zero() {
        let sizer = BatchSizer::new();
        // Even with absurdly many fields it must return at least 1
        let r = sizer.max_batch_size(1, 1000);
        assert_eq!(r, 1);
    }

    #[test]
    fn three_gb_vram_fits_millions() {
        let sizer = BatchSizer::new();
        // 3 GB is a realistic headroom figure after OS + model overhead.
        // 10 fields × 136 + 24 = 1,384 bytes/pair; 75% of 3 GiB is
        // 2,415,919,104 bytes → about 1.75M pairs.
        let available = 3 * GIB;
        let max = sizer.max_batch_size(available, 10);
        assert!(max > 1_000_000, "expected >1M pairs, got {max}");
    }

    #[test]
    fn min_batch_constant_is_positive() {
        assert!(BatchSizer::min_batch_for_gpu() > 0);
    }

    #[test]
    fn utilization_scales_result() {
        let full = BatchSizer::new().with_utilization(1.0).max_batch_size(1_000_000, 5);
        let half = BatchSizer::new().with_utilization(0.5).max_batch_size(1_000_000, 5);
        assert!(full > half);
    }

    #[test]
    #[should_panic(expected = "utilization must be in (0, 1]")]
    fn utilization_above_one_is_rejected() {
        let _ = BatchSizer::new().with_utilization(1.5);
    }

    #[test]
    #[should_panic(expected = "utilization must be in (0, 1]")]
    fn zero_utilization_is_rejected() {
        let _ = BatchSizer::new().with_utilization(0.0);
    }

    #[test]
    fn formula_matches_compare_pool_layout() {
        // Verify the formula matches the GPU compare kernel buffer layout.
        // For n=1 field, 1 pair the device allocates:
        //   d_data_a/b: 2 × 1 × 64 = 128 bytes
        //   d_lens_a/b: 2 × 1 × 2  =   4 bytes
        //   d_ids_a/b:  2 × 8       =  16 bytes
        //   d_weights/probs: 2 × 4  =   8 bytes
        //   d_levels:   1 × 4       =   4 bytes
        //   total = 160 bytes/pair
        let n_fields = 1;
        let bytes_per_pair_1field =
            2 * n_fields * STRING_STRIDE + 2 * n_fields * 2 + 16 + 8 + n_fields * 4;
        assert_eq!(bytes_per_pair_1field, 160);

        let sizer = BatchSizer::new().with_utilization(1.0);
        let max = sizer.max_batch_size(160, n_fields);
        assert_eq!(max, 1, "exactly one pair should fit in 160 bytes for 1 field");
    }
}