// zer_compute/batch_sizer.rs

1//! Auto-tunes the GPU batch size from available VRAM using the exact buffer layout.
2
3use crate::soa::STRING_STRIDE;
4
5/// Fraction of VRAM to target by default. Leaves headroom for the DeBERTa judge
6/// (Phase 7), which can occupy ~1–2 GB depending on the model variant.
7const DEFAULT_VRAM_UTILIZATION: f32 = 0.75;
8
9/// Minimum number of pairs required to justify the GPU kernel launch overhead.
10/// Below this threshold `DeviceComparator` silently falls back to the CPU path.
11pub const GPU_BATCH_MIN: usize = 1_000;
12
/// Computes the maximum batch size that fits safely within VRAM.
///
/// Uses the exact per-pair device memory layout of the GPU compare kernel. No
/// field-length estimation is required, because every string is padded to
/// [`STRING_STRIDE`] bytes on the device regardless of its actual content.
///
/// # Per-pair device memory layout
///
/// | Buffer | Bytes per pair |
/// |---|---|
/// | `d_data_a` + `d_data_b` (u8, `STRING_STRIDE` each) | `2 × n_fields × 64` |
/// | `d_lens_a` + `d_lens_b` (u16) | `2 × n_fields × 2` |
/// | `d_ids_a` + `d_ids_b` (u64) | `2 × 8` |
/// | `d_weights` + `d_probs` (f32) | `2 × 4` |
/// | `d_levels` (u32) | `n_fields × 4` |
///
/// Total: `n_fields × 136 + 24` bytes per pair with `STRING_STRIDE = 64`
/// (exact, no estimation).
///
/// # Example
///
/// ```
/// use zer_compute::batch_sizer::BatchSizer;
///
/// let sizer = BatchSizer::new();
/// // 3 GiB available VRAM (e.g. after OS + model overhead), 10 fields
/// let available = 3u64 * 1024 * 1024 * 1024;
/// let max = sizer.max_batch_size(available, 10);
/// assert!(max > 1_000_000, "should easily fit millions of pairs");
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct BatchSizer {
    /// Fraction of available VRAM to commit to the comparison batch. Default: 0.75.
    ///
    /// Prefer setting this through [`BatchSizer::with_utilization`], which checks
    /// that the value lies in `(0, 1]`. Assigning the field directly bypasses
    /// that check.
    pub vram_utilization: f32,
}

48impl Default for BatchSizer {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl BatchSizer {
55    pub fn new() -> Self {
56        Self {
57            vram_utilization: DEFAULT_VRAM_UTILIZATION,
58        }
59    }
60
61    /// Override the utilization fraction (0.0 < fraction ≤ 1.0).
62    pub fn with_utilization(mut self, fraction: f32) -> Self {
63        assert!(
64            fraction > 0.0 && fraction <= 1.0,
65            "utilization must be in (0, 1]"
66        );
67        self.vram_utilization = fraction;
68        self
69    }
70
71    /// Compute the maximum number of pairs that fit in `available_vram_bytes` VRAM
72    /// for a schema with `num_fields` fields.
73    ///
74    /// The formula matches the GPU compare kernel buffer layout exactly, no avg_field_len
75    /// estimate is needed because device buffers always use `STRING_STRIDE` bytes per string.
76    ///
77    /// Returns at least 1 so callers never divide by zero.
78    pub fn max_batch_size(&self, available_vram_bytes: u64, num_fields: usize) -> usize {
79        let bytes_per_pair: usize = 2 * num_fields * STRING_STRIDE   // d_data_a + d_data_b (u8)
80            + 2 * num_fields * 2              // d_lens_a + d_lens_b (u16)
81            + 2 * 8                           // d_ids_a  + d_ids_b  (u64)
82            + 2 * 4                           // d_weights + d_probs (f32)
83            + num_fields * 4; // d_levels            (u32)
84
85        let usable = (available_vram_bytes as f64 * self.vram_utilization as f64) as u64;
86        (usable / bytes_per_pair as u64).max(1) as usize
87    }
88
89    /// Minimum batch size to justify a GPU kernel launch. Batches smaller than
90    /// this are routed to the CPU path transparently.
91    pub const fn min_batch_for_gpu() -> usize {
92        GPU_BATCH_MIN
93    }
94}
95
// ── Unit tests ────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// One gibibyte, so VRAM sizes read naturally in the tests below.
    const GIB: u64 = 1024 * 1024 * 1024;

    #[test]
    fn max_batch_grows_with_vram() {
        let sizer = BatchSizer::new();
        let small = sizer.max_batch_size(GIB, 10);
        let large = sizer.max_batch_size(8 * GIB, 10);
        assert!(large > small);
    }

    #[test]
    fn max_batch_never_zero() {
        let sizer = BatchSizer::new();
        // With absurdly many fields not a single pair fits, but it must still return 1.
        let r = sizer.max_batch_size(1, 1000);
        assert_eq!(r, 1);
    }

    #[test]
    fn three_gb_vram_fits_millions() {
        let sizer = BatchSizer::new();
        // 3 GiB is a realistic headroom figure after OS + model overhead.
        // 10 fields × 136 + 24 = 1,384 bytes/pair → ~1.75M pairs in 3 GiB at 75%.
        let available = 3 * GIB;
        let max = sizer.max_batch_size(available, 10);
        assert!(max > 1_000_000, "expected >1M pairs, got {max}");
    }

    #[test]
    fn min_batch_constant_is_positive() {
        assert!(BatchSizer::min_batch_for_gpu() > 0);
    }

    #[test]
    fn utilization_scales_result() {
        let full = BatchSizer::new()
            .with_utilization(1.0)
            .max_batch_size(1_000_000, 5);
        let half = BatchSizer::new()
            .with_utilization(0.5)
            .max_batch_size(1_000_000, 5);
        assert!(full > half);
    }

    #[test]
    #[should_panic(expected = "utilization must be in (0, 1]")]
    fn with_utilization_rejects_zero() {
        let _ = BatchSizer::new().with_utilization(0.0);
    }

    #[test]
    #[should_panic(expected = "utilization must be in (0, 1]")]
    fn with_utilization_rejects_above_one() {
        let _ = BatchSizer::new().with_utilization(1.5);
    }

    #[test]
    fn default_matches_new() {
        assert_eq!(
            BatchSizer::default().vram_utilization,
            BatchSizer::new().vram_utilization
        );
    }

    #[test]
    fn formula_matches_compare_pool_layout() {
        // Verify the formula matches the GPU compare kernel buffer layout.
        // For n=1 field, 1 pair the device allocates:
        //   d_data_a/b:      2 × 1 × 64 = 128 bytes
        //   d_lens_a/b:      2 × 1 × 2  =   4 bytes
        //   d_ids_a/b:       2 × 8      =  16 bytes
        //   d_weights/probs: 2 × 4      =   8 bytes
        //   d_levels:        1 × 4      =   4 bytes
        //   total = 160 bytes/pair
        let bytes_per_pair_1field = 2 * STRING_STRIDE + 2 * 2 + 16 + 8 + 4;
        assert_eq!(bytes_per_pair_1field, 160);

        let sizer = BatchSizer::new().with_utilization(1.0);
        let max = sizer.max_batch_size(160, 1);
        assert_eq!(
            max, 1,
            "exactly one pair should fit in 160 bytes for 1 field"
        );
    }

    #[test]
    fn max_batch_is_floor_of_usable_over_bytes_per_pair() {
        let sizer = BatchSizer::new().with_utilization(1.0);
        // 10 fields → 1,384 bytes/pair.
        assert_eq!(sizer.max_batch_size(1_384 * 1_000, 10), 1_000);
        assert_eq!(sizer.max_batch_size(1_384 * 1_000 - 1, 10), 999);
        // 1 field → 160 bytes/pair: one byte short of two pairs still yields one.
        assert_eq!(sizer.max_batch_size(319, 1), 1);
        assert_eq!(sizer.max_batch_size(320, 1), 2);
    }
}